In [1]:
import tensorflow as tf

In [2]:
class Conv(tf.keras.layers.Layer):
    def __init__(self,
                 filters:int,
                 kernel_size,
                 strides = (1,1),
                 padding = 'same',
                 activation = 'relu',
                 use_bias = False
                ):
        super(Conv, self).__init__()
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.activation = activation
        self.use_bias = use_bias
        
        self.C = tf.keras.layers.Conv2D(filters = self.filters,
                                        kernel_size = self.kernel_size,
                                        strides = self.strides,
                                        padding = self.padding,
                                        use_bias = self.use_bias
                                       )
        if not self.use_bias:
            self.BN = tf.keras.layers.BatchNormalization()
        
    def call(self, X):
        y = self.C(X)
        if not self.use_bias:
            y = self.BN(y)
        if self.activation == 'relu':
            y = tf.nn.relu(y)
        return y

    
class ScaledActivation(tf.keras.layers.Layer):
    def __init__(self, scale:float):
        super(ScaledActivation, self).__init__()
        self.scale = scale
        
    def call(self, X):
        y = X * self.scale
        return y

In [3]:
class Stem(tf.keras.layers.Layer):
    def __init__(self):
        super(Stem, self).__init__()
        
        self.C1 = Conv(filters = 32,
                       kernel_size = (3,3),
                       strides = (2,2),
                       padding = 'valid'
                      )
        self.C2 = Conv(filters = 32,
                       kernel_size = (3,3),
                       padding = 'valid'
                      )
        self.C3 = Conv(filters = 64,
                       kernel_size = (3,3),
                      )
        self.MP4_1 = tf.keras.layers.MaxPool2D(pool_size = (3,3),
                                               strides = (2,2),
                                               padding = 'valid'
                                              )
        self.C4_2 = Conv(filters = 96,
                         kernel_size = (3,3),
                         strides = (2,2),
                         padding = 'valid'
                        )
        
        self.C5_1 = Conv(filters = 64,
                         kernel_size = (1,1)
                        )
        self.C5_2 = Conv(filters = 64,
                         kernel_size = (1,1),
                        )
        self.C6_1 = Conv(filters = 96,
                         kernel_size = (3,3),
                         padding = 'valid'
                        )
        self.C6_2 = Conv(filters = 64,
                         kernel_size = (7,1)
                        )
        self.C7_2 = Conv(filters = 64,
                         kernel_size = (1,7),
                        )
        self.C8_2 = Conv(filters = 96,
                         kernel_size = (3,3),
                         padding = 'valid'
                        )
        
        self.C9_1 = Conv(filters = 192,
                         kernel_size = (3,3),
                         strides = (2,2),
                         padding = 'valid'
                        )
        self.MP9_2 = tf.keras.layers.MaxPool2D(pool_size = (3,3), ###
                                               strides = (2,2),
                                               padding = 'valid'
                                              )
        
    def call(self, X):
        y = self.C1(X)
        y = self.C2(y)
        y = self.C3(y)
        y1 = self.MP4_1(y)
        y2 = self.C4_2(y)
        y = tf.concat([y1,y2], axis = -1)
        
        y1 = self.C5_1(y)
        y1 = self.C6_1(y1)
        y2 = self.C5_2(y)
        y2 = self.C6_2(y2)
        y2 = self.C7_2(y2)
        y2 = self.C8_2(y2)
        y = tf.concat([y1, y2], axis = -1)
        
        y1 = self.C9_1(y)
        y2 = self.MP9_2(y)
        y = tf.concat([y1,y2], axis = -1)
        return y

    
class InceptionResnetA(tf.keras.layers.Layer):
    def __init__(self):
        super(InceptionResnetA, self).__init__()
        
        #Path1
        self.P1 = Conv(filters = 32,
                       kernel_size = (1,1)
                      )
        #Path2
        self.P2 = tf.keras.Sequential([
            Conv(filters =  32,
                 kernel_size = (1,1)
                ),
            Conv(filters = 32,
                 kernel_size = (3,3)
                )
        ])
        #Path3
        self.P3 = tf.keras.Sequential([
            Conv(filters = 32,
                 kernel_size = (1,1)
                ),
            Conv(filters = 48,
                 kernel_size = (3,3)
                ),
            Conv(filters = 64,
                 kernel_size = (3,3)
                )
        ])
        self.Linear = Conv(filters = 384,
                           kernel_size = (1,1),
                           activation = 'linear',
                           use_bias = True
                          )
        self.Scale = ScaledActivation(.1)
    
    def call(self, X):
        y1 = self.P1(X)
        y2 = self.P2(X)
        y3 = self.P3(X)
        y = self.Linear(tf.concat([y1,y2,y3], axis = -1))
        y = self.Scale(y)
        y = tf.nn.relu(X + y)
        return y
    

class InceptionResnetB(tf.keras.layers.Layer):
    def __init__(self, Identity=True):
        super(InceptionResnetB, self).__init__()
        self.Identity = Identity
        
        #Path1
        self.P1 = Conv(filters = 192,
                       kernel_size = (1,1)
                      )
        #Path2
        self.P2 = tf.keras.Sequential([
            Conv(filters = 128,
                 kernel_size = (1,1)
                ),
            Conv(filters = 160,
                 kernel_size = (1,7)
                ),
            Conv(filters = 192,
                 kernel_size = (7,1)
                )
        ])
        self.Act = Conv(filters = 1154,
                        kernel_size = (1,1),
                        activation = 'linear',
                        use_bias = True
                       )
        self.Scale = ScaledActivation(.1)
        if not self.Identity:
            self.Res = Conv(filters = 1154,
                            kernel_size = (1,1),
                            padding = 'valid',
                            activation = 'linear'
                           )
        
    def call(self, X):
        y1 = self.P1(X)
        y2 = self.P2(X)
        y = self.Act(tf.concat([y1,y2], axis = -1))
        y = self.Scale(y)
        if not self.Identity:
            X = self.Res(X)
        y = tf.nn.relu(X + y)
        return y
    
class InceptionResnetC(tf.keras.layers.Layer):
    def __init__(self, Identity = True):
        super(InceptionResnetC, self).__init__()
        self.Identity = Identity
        
        #Path1
        self.P1 = Conv(filters = 192,
                       kernel_size = (1,1)
                      )
        #Path2
        self.P2 = tf.keras.Sequential([
            Conv(filters = 192,
                 kernel_size = (1,1)
                ),
            Conv(filters = 224,
                 kernel_size = (1,3)
                ),
            Conv(filters = 256,
                 kernel_size = (3,1)
                )
        ])
        self.Act = Conv(filters = 2048,
                        kernel_size = (1,1),
                        activation = 'linear',
                        use_bias = False
                       )
        self.Scale = ScaledActivation(.1)
        if not self.Identity:
            self.Res = Conv(filters = 2048,
                            kernel_size = (1,1),
                            padding = 'valid',
                            activation = 'linear'
                           )
            
    def call(self, X):
        y1 = self.P1(X)
        y2 = self.P2(X)
        y = self.Act(tf.concat([y1, y2], axis = -1))
        y = self.Scale(y)
        if not self.Identity:
            X = self.Res(X)
        y = tf.nn.relu(X + y)
        return y
    

class ReductionA(tf.keras.layers.Layer):
    def __init__(self):
        super(ReductionA, self).__init__()
        
        #Path1
        self.P1 = tf.keras.layers.MaxPool2D(pool_size = (3,3),
                                            strides = (2,2)
                                           )
        #Path2
        self.P2 = Conv(filters = 384,
                       kernel_size = (3,3),
                       strides = (2,2),
                       padding = 'valid'
                      )
        #Path3
        self.P3 = tf.keras.Sequential([
            Conv(filters = 256,
                 kernel_size = (1,1)
                ),
            Conv(filters = 256,
                 kernel_size = (3,3)
                ),
            Conv(filters = 384,
                 kernel_size = (3,3),
                 strides = (2,2),
                 padding = 'valid'
                )
        ])
        
    def call(self, X):
        y1 = self.P1(X)
        y2 = self.P2(X)
        y3 = self.P3(X)
        y = tf.concat([y1,y2,y3], axis = -1)
        return y
    
    
class ReductionB(tf.keras.layers.Layer):
    def __init__(self):
        super(ReductionB, self).__init__()
        
        #Path1
        self.P1 = tf.keras.layers.MaxPool2D(pool_size = (3,3),
                                            strides = (2,2)
                                           )
        #Path2
        self.P2 = tf.keras.Sequential([
            Conv(filters = 256,
                 kernel_size = (1,1)
                ),
            Conv(filters = 384,
                 kernel_size = (3,3),
                 strides = (2,2),
                 padding = 'valid'
                )
        ])
        #Path3
        self.P3 = tf.keras.Sequential([
            Conv(filters = 256,
                 kernel_size = (1,1)
                ),
            Conv(filters = 288,
                 kernel_size = (3,3),
                 strides = (2,2),
                 padding = 'valid'
                ) 
        ])
        #Path4
        self.P4 = tf.keras.Sequential([
            Conv(filters = 256,
                 kernel_size = (1,1)
                ),
            Conv(filters = 288,
                 kernel_size = (3,3)
                ),
            Conv(filters = 320,
                 kernel_size = (3,3),
                 strides = (2,2),
                 padding = 'valid'
                )
        ])
        
    def call(self, X):
        y1 = self.P1(X)
        y2 = self.P2(X)
        y3 = self.P3(X)
        y4 = self.P4(X)
        y = tf.concat([y1,y2,y3,y4], axis = -1)
        return y

In [4]:
class InceptionResnetV2(tf.keras.models.Model):
    def __init__(self, n_labels:int, last_activation:str = 'softmax'):
        super(InceptionResnetV2, self).__init__()
        self.n_labels = n_labels
        self.last_activation = last_activation
        
        self.Stem = Stem()
        self.InceptionABlocks = tf.keras.Sequential([
            InceptionResnetA() for x in range(5)
        ])
        self.ReductionA = ReductionA()
        self.InceptionBBlocks = tf.keras.Sequential([
            InceptionResnetB(Identity = False)
        ] + [
            InceptionResnetB() for x in range(9)
        ])
        self.ReductionB = ReductionB()
        self.InceptionCBlocks = tf.keras.Sequential([
            InceptionResnetC(Identity = False)
        ] + [
            InceptionResnetC() for x in range(4)
        ])
        self.classifier = tf.keras.Sequential([
            tf.keras.layers.AveragePooling2D(pool_size = (8,8),
                                             strides = (1,1)
                                            ),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dropout(.2),
            tf.keras.layers.Dense(self.n_labels, activation = self.last_activation)
        ])
        
    def call(self, X):
        y = self.Stem(X)
        y = self.InceptionABlocks(y)
        y = self.ReductionA(y)
        y = self.InceptionBBlocks(y)
        y = self.ReductionB(y)
        y = self.InceptionCBlocks(y)
        y = self.classifier(y)
        return y

In [5]:
inceptionrv2 = InceptionResnetV2(1000)

In [6]:
inceptionrv2.build([16,299,299,3])

In [7]:
inceptionrv2.summary()

Model: "inception_resnet_v2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
stem (Stem)                  multiple                  607456    
_________________________________________________________________
sequential_10 (Sequential)   (16, 35, 35, 384)         690240    
_________________________________________________________________
reduction_a (ReductionA)     multiple                  2905088   
_________________________________________________________________
sequential_22 (Sequential)   (16, 17, 17, 1154)        13079964  
_________________________________________________________________
reduction_b (ReductionB)     multiple                  3935744   
_________________________________________________________________
sequential_31 (Sequential)   (16, 8, 8, 2048)          14524032  
_________________________________________________________________
sequential_32 (Sequential)   (16, 1000)        