In [78]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, AvgPool2D, ReLU, UpSampling2D, Add, Layer, Input

In [79]:
class OctaveConvolution(Layer):
    """Convolution over a two-resolution (low / high frequency) feature pair.

    Activations are carried as a pair ``(low, high)``. The ``low`` tensor is at
    half the spatial resolution of ``high``: it is produced by 2x average
    pooling and mapped back with 2x ``UpSampling2D``. Up to four ``Conv2D``
    paths mix the branches:
    L2L (low -> low), H2L (pool, then conv), H2H (high -> high) and
    L2H (conv, then upsample).

    The layer runs in one of three modes:

    * ``first=True``: call with one tensor ``x``; returns ``(low, high)``.
    * default: call with ``low, high`` (or ``[low, high]``); returns
      ``(low, high)``.
    * ``last=True``: call with ``low, high`` (or ``[low, high]``); returns a
      single tensor at the (possibly strided) ``high`` resolution with
      ``c_out`` channels.

    Args:
        c_out: Total number of output channels. Outside ``last`` mode, the low
            branch gets half of them (rounded down) and the high branch the
            rest. In ``last`` mode they all go to the single output.
        kernel_size: Kernel size for every ``Conv2D`` path. All paths use
            ``padding='same'`` and no bias.
        strides: If > 1, every input tensor is average-pooled by this factor
            (``pool_size == strides``) before the convolutions, in every mode.
        first: Build the entry layer, which splits a single tensor into a pair.
        last: Build the exit layer, which merges the pair into one tensor.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If both ``first`` and ``last`` are set, or if ``c_out`` is
            too small to give every produced branch at least one channel.

    NOTE(review): assumes 4-D inputs in Keras' default ``channels_last`` layout
    with even spatial dims. Otherwise the upsampled low branch will not match
    the high branch when the two are added -- confirm for new input sizes.
    """

    def __init__(self, c_out, kernel_size=3, strides=1, first=False, last=False, **kwargs):
        super(OctaveConvolution, self).__init__(**kwargs)
        if first and last:
            raise ValueError('OctaveConvolution cannot be both `first` and `last`.')

        # Keep constructor arguments for get_config().
        self.c_out = c_out
        self.kernel_size = kernel_size
        self.strides = strides
        self.first = first
        self.last = last

        # Fraction of output channels assigned to the low-frequency branch.
        # The last layer produces only a full-resolution output.
        alpha_out = 0 if last else 0.5
        low_channels = int(alpha_out * c_out)
        high_channels = c_out - low_channels
        if high_channels < 1 or (not last and low_channels < 1):
            raise ValueError(
                f'c_out={c_out} is too small: need at least {1 if last else 2} channels.')

        if self.strides > 1:
            # pool_size must equal strides so every input pixel is averaged.
            # AvgPool2D's default pool_size of 2 would skip pixels for strides > 2.
            self.reduce_low = AvgPool2D(pool_size=strides, strides=strides)
            self.reduce_high = AvgPool2D(pool_size=strides, strides=strides)

        def make_conv(filters):
            # All paths share kernel size, 'same' padding and no bias.
            return Conv2D(filters, kernel_size=kernel_size, padding='same', use_bias=False)

        reads_low = not first   # L2L and L2H consume the low branch
        writes_low = not last   # L2L and H2L produce the low branch

        # Only build the paths this mode uses. A 0-filter Conv2D is rejected by
        # newer Keras. Assignment order matches the original so the layer's
        # weight order is unchanged.
        if reads_low and writes_low:
            self.L2L = make_conv(low_channels)
        if writes_low:
            self.H2L_pool = AvgPool2D(pool_size=2, strides=2)
            self.H2L = make_conv(low_channels)
        self.H2H = make_conv(high_channels)
        if reads_low:
            self.L2H = make_conv(high_channels)
            self.L2H_up = UpSampling2D(size=(2, 2))

    def call(self, *args):
        """Apply the octave convolution.

        Args:
            *args: ``x`` in ``first`` mode; otherwise ``low, high`` as two
                positional tensors, or a single ``[low, high]`` list/tuple.

        Returns:
            ``(low, high)`` unless ``last`` is set, in which case it returns a
            single tensor at the high-branch resolution.
        """
        # Also accept the conventional Keras form `layer([low, high])`.
        if len(args) == 1 and isinstance(args[0], (list, tuple)):
            args = tuple(args[0])

        if self.first:
            x = args[0]
            if self.strides > 1:
                x = self.reduce_high(x)
            low_filter = self.H2L(self.H2L_pool(x))
            high_filter = self.H2H(x)
            return low_filter, high_filter

        low, high = args[0], args[1]
        if self.strides > 1:
            low = self.reduce_low(low)
            high = self.reduce_high(high)

        if self.last:
            # Merge both branches into one full-resolution output.
            return self.L2H_up(self.L2H(low)) + self.H2H(high)

        low_filter = self.L2L(low) + self.H2L(self.H2L_pool(high))
        high_filter = self.L2H_up(self.L2H(low)) + self.H2H(high)
        return low_filter, high_filter

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(OctaveConvolution, self).get_config()
        config.update({
            'c_out': self.c_out,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'first': self.first,
            'last': self.last,
        })
        return config

In [80]:
# Build a small functional graph. The first layer splits a 224x224 RGB input
# into (low, high) branches. Two octave layers then refine the pair (the
# second with stride 2). The final layer merges it into one 10-channel map.
inputs = Input(shape=(224, 224, 3))
x_low, x_high = OctaveConvolution(48, kernel_size=3, strides=1, first=True, last=False)(inputs)
x_low, x_high = OctaveConvolution(64)(x_low, x_high)
x_low, x_high = OctaveConvolution(96, strides=2)(x_low, x_high)
outputs = OctaveConvolution(10, kernel_size=3, strides=1, first=False, last=True)(x_low, x_high)

(None, 224, 224, 24) (None, 112, 112, 24)
(None, 224, 224, 32) (None, 112, 112, 32)
(None, 112, 112, 48) (None, 56, 56, 48)


In [74]:
# Wrap the graph built above into a Keras Functional model.
model = tf.keras.Model(inputs=inputs, outputs=outputs)

In [75]:
# Print each layer's output shape and parameter count.
model.summary()

Model: "model_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_23 (InputLayer)        [(None, 32, 32, 3)]       0         
_________________________________________________________________
octave_convolution_44 (Octav ((None, 16, 16, 24), (Non 1296      
_________________________________________________________________
octave_convolution_45 (Octav ((None, 16, 16, 32), (Non 27648     
_________________________________________________________________
octave_convolution_46 (Octav ((None, 8, 8, 48), (None, 55296     
_________________________________________________________________
octave_convolution_47 (Octav (None, 16, 16, 10)        8640      
Total params: 92,880
Trainable params: 92,880
Non-trainable params: 0
_________________________________________________________________
