In [1]:
import tensorflow as tf
from tensorflow.keras.layers import DepthwiseConv2D, Conv2D, BatchNormalization, Activation, Add
from tensorflow.keras.models import Model

In [2]:
class InvertedResidualBlock(Model):
    """Inverted residual block: 1x1 expand -> 3x3 depthwise -> 1x1 projection.

    Args:
        input_channel: Channel count of the block input; the expansion conv
            produces ``input_channel * factor`` channels.
        output_channel: Channel count produced by the projection conv.
        factor: Expansion ratio applied to ``input_channel``.
        strides: Stride of the depthwise conv (int, or tuple/list of ints).

    The identity shortcut (``output + x``) is applied only when the block
    keeps both the spatial size (every stride equals 1) and the channel
    count (``input_channel == output_channel``). In any other configuration
    the output and input shapes differ and ``Add`` would raise.
    """

    def __init__(self, input_channel, output_channel, factor, strides):
        super().__init__()

        # 1x1 pointwise expansion to input_channel * factor channels, then BN + ReLU6.
        self.expand = Conv2D(filters=input_channel * factor, kernel_size=1, strides=1, padding='valid')
        self.expand_bn = BatchNormalization()
        self.expand_av = Activation(tf.nn.relu6)

        # 3x3 depthwise conv; this is the only layer that applies the block's stride.
        self.depthwise = DepthwiseConv2D(kernel_size=3, strides=strides, padding='same')
        self.depthwise_bn = BatchNormalization()
        self.depthwise_av = Activation(tf.nn.relu6)

        # 1x1 projection to output_channel; intentionally no activation after its BN.
        self.projection = Conv2D(filters=output_channel, kernel_size=1, strides=1, padding='valid')
        self.projection_bn = BatchNormalization()

        self.add = Add()

        # Accept both `1` and `(1, 1)` style strides.
        stride_values = strides if isinstance(strides, (tuple, list)) else (strides,)
        keeps_spatial_size = all(s == 1 for s in stride_values)
        # The residual add also needs matching channel counts, otherwise the
        # shapes of `output` and `x` differ. NOTE(review): assumes x really has
        # `input_channel` channels, as the constructor argument implies.
        self.sc = keeps_spatial_size and input_channel == output_channel

    def call(self, x, training=None):
        """Run the block on ``x``.

        Args:
            x: Input feature map.
            training: Optional flag forwarded to the BatchNormalization layers.
                With the default ``None``, Keras infers the mode from the call
                context, which matches the original behaviour.

        Returns:
            The projected feature map, plus ``x`` when the shortcut is enabled.
        """
        expand = self.expand_av(self.expand_bn(self.expand(x), training=training))
        depthwise = self.depthwise_av(self.depthwise_bn(self.depthwise(expand), training=training))
        output = self.projection_bn(self.projection(depthwise), training=training)
        if self.sc:
            output = self.add([output, x])

        return output