In [30]:
import tensorflow as tf
from tensorflow import layers

# Operations

In [44]:
OPS = {
    'none': lambda C, stride: Zero(stride),
    'avg_pool_3x3': lambda C, stride: layers.AveragePooling2D(3, strides=stride, padding='same'),
    'max_pool_3x3': lambda C, stride: layers.MaxPooling2D(3, strides=stride, padding='same'),
    'skip_connect': lambda C, stride: Identity() if stride == 1 else FactorizedReduce(C, C),
    'sep_conv_3x3': lambda C, stride: SepConv(C, C, 3, stride, 'same'),
    'sep_conv_5x5': lambda C, stride: SepConv(C, C, 5, stride, 'same'),
    'sep_conv_7x7': lambda C, stride: SepConv(C, C, 7, stride, 'same'),
    'dil_conv_3x3': lambda C, stride: DilConv(C, C, 3, stride, 'same', 2),
    'dil_conv_5x5': lambda C, stride: DilConv(C, C, 5, stride, 'same', 2),
    'conv_7x1_1x7': lambda C, stride: Conv_7x1_1x7(C, stride)
}

In [32]:
class ReLUConvBN(tf.keras.layers.Layer):
    """Applies ReLU, Conv and BatchNormalisation operation
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding='same'):
        """Initializes the operation

        Args:
            C_in (int): no of kernels in
            C_out (int): no of kernels out
            kernel_size (int): size of kernel
            stride (int): stride
            padding (int): padding
            affine (bool), optional): Defaults to True.
        """
        super(ReLUConvBN, self).__init__()
        self.relu = tf.nn.relu
        self.conv = layers.Conv2D(filters=C_out,
                                  kernel_size=kernel_size,
                                  strides=stride,
                                  padding='same',
                                  use_bias=False
                                 )
        self.bn = layers.BatchNormalization()
        
    def call(self, x):
        """Applies the ReLU, Conv, BN to input

        Args:
            x (tensor): array or tensor (can be image)

        Returns:
            tensor: array or tensor with operations applied on it
        """
        x = self.relu(x)
        x = self.conv(x)
        x = self.bn(x)
        return x

In [33]:
class DilConv(tf.keras.layers.Layer):
    """Applies ReLU, Conv with dilation and BatchNormalisation operation
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super(DilConv, self).__init__()
        self.relu = tf.nn.relu
        self.dil_conv = layers.Conv2D(filters=C_out, 
                                      kernel_size=kernel_size, 
                                      strides=stride, 
                                      padding=padding,
                                      dilation_rate=dilation,
                                      use_bias=False
                                     )
        self.conv = layers.Conv2D(filters=C_out,
                                  kernel_size=1,
                                  strides=1,
                                  padding='same',
                                  use_bias=False
                                 )
        self.bn = layers.BatchNormalization()

    def call(self, x):
        """Applies the ReLU, Conv, BN to input

        Args:
            x (tensor): array or tensor (can be image)

        Returns:
            tensor: array or tensor with operations applied on it
        """
        x = self.relu(x)
        x = self.dil_conv(x)
        x = self.conv(x)
        x = self.bn(x)
        return x

In [34]:
class SepConv(tf.keras.layers.Layer):
    """Applies ReLU, Sep Conv with dilation and BatchNormalisation operation
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding):
        super(SepConv, self).__init__()
        self.relu = tf.nn.relu
        self.conv1 = layers.Conv2D(filters=C_in,
                                  kernel_size=kernel_size,
                                  strides=stride,
                                  padding='same',
                                  use_bias=False
                                 )
        self.conv2 = layers.Conv2D(filters=C_in,
                                  kernel_size=1,
                                  strides=stride,
                                  padding='same',
                                  use_bias=False
                                 )
        self.bn1 = layers.BatchNormalization()
        self.conv3 = layers.Conv2D(filters=C_in,
                                  kernel_size=kernel_size,
                                  strides=1,
                                  padding='same',
                                  use_bias=False
                                 )
        self.conv4 = layers.Conv2D(filters=C_out,
                                  kernel_size=1,
                                  strides=1,
                                  padding='same',
                                  use_bias=False
                                 )
        self.bn2 = layers.BatchNormalization()

    def call(self, x):
        """Applies the ReLU, Conv, BN to input

        Args:
            x (tensor): array or tensor (can be image)

        Returns:
            tensor: array or tensor with operations applied on it
        """
        x = self.relu(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.bn2(x)
        return x


In [35]:
class Identity(tf.keras.layers.Layer):
    """Apply the identity operation
    """

    def __init__(self):
        super(Identity, self).__init__()
        pass

    def call(self, x):
        return x

In [36]:
class Zero(tf.keras.layers.Layer):
    """Makes array element zero with given stride
    """

    def __init__(self, stride):
        super(Zero, self).__init__()
        self.stride = stride

    def call(self, x):
        if self.stride == 1:
            return tf.multiply(x, 0)
        return tf.multiply(x[:, ::self.stride, ::self.stride, :], 0)

In [37]:
class Conv_7x1_1x7(tf.keras.layers.Layer):

  def __init__(self, C, stride):
    super(Conv_7x1_1x7, self).__init__()
    self.relu = tf.nn.relu
    self.conv1 = layers.Conv2D(filters=C,
                             kernel_size=(1,7),
                             strides=(1, stride),
                             padding='same',
                             use_bias=False)
    self.conv2 = layers.Conv2D(filters=C,
                             kernel_size=(7,1),
                             strides=(stride, 1),
                             padding='same',
                             use_bias=False)
    self.bn = layers.BatchNormalization()

  def call(self, x):
    x = self.relu(x)
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.bn(x)
    return x 

In [38]:
class FactorizedReduce(tf.keras.layers.Layer):
    """Applies ReLU, conv with stride=2 and c_out/2 
    """

    def __init__(self, C_in, C_out):
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0
        self.relu = tf.nn.relu
        self.conv_1 = layers.Conv2D(filters=C_out//2, 
                                    kernel_size=1,
                                    strides=2, 
                                    padding='same', 
                                    use_bias=False)
        self.conv_2 = layers.Conv2D(filters=C_out//2, 
                                    kernel_size=1,
                                    strides=2, 
                                    padding='same', 
                                    use_bias=False)
        self.bn = layers.BatchNormalization()

    def call(self, x):
        """concats conv and Batch normalise them

        Args:
            x (tensor): array or tensor (can be image)

        Returns:
            tensor: tensor of operations on input
        """
        x = self.relu(x)
        out = tf.concat([self.conv_1(x), self.conv_2(x[:, 1:, 1:, :])], axis=3)
        out = self.bn(out)
        return out


In [39]:
class FactorizedUp(tf.keras.layers.Layer):

  def __init__(self, C_in, C_out):
    super(FactorizedUp, self).__init__()
    self.relu = tf.nn.relu
    self.trans_conv1 = layers.Conv2DTranspose(filters=C_out, 
                                             kernel_size=3,
                                             strides=2,
                                             padding='same',
                                             )
    self.trans_conv2 = layers.Conv2DTranspose(filters=C_out,
                                             kernel_size=3,
                                             strides=2,
                                             padding='same',
                                             )
    
    self.bn = layers.BatchNormalization()

  def call(self, x):
    x = self.relu(x)
    out = (self.trans_conv1(x) + self.trans_conv2(x)) * 0.5
    out = self.bn(out)
    return out

In [40]:
class SkipConnection(tf.keras.layers.Layer):

  def __init__(self, C):    
    super(SkipConnection, self).__init__()
    self.relu = tf.nn.relu
    self.conv = layers.Conv2D(filters=C,
                             kernel_size=3,
                             strides=1,
                             padding='same',
                             use_bias=False)
    self.bn = layers.BatchNormalization()

  def call(self, s0, s1):
    s0 = self.relu(s0)
    s1 = self.relu(s1)
    x = tf.concat([s1, s0], axis=3)
    x = self.conv(x)
    out = self.bn(x)
    return out 

# Test Here

In [48]:
for op in OPS:
    model = OPS[op](6, 1)

    image = tf.random_uniform([1, 16, 16, 3], 0, 255, seed=0, dtype=tf.int32)

    ip = tf.placeholder(tf.float32, shape=[None, 16, 16, 3], name = "input")
    out = model(ip)
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        image = sess.run(image)
        out = sess.run(out, {ip: image})
        print(op, out.shape)
print('Pass')

none (1, 16, 16, 3)
avg_pool_3x3 (1, 16, 16, 3)
max_pool_3x3 (1, 16, 16, 3)
skip_connect (1, 16, 16, 3)
sep_conv_3x3 (1, 16, 16, 6)
sep_conv_5x5 (1, 16, 16, 6)
sep_conv_7x7 (1, 16, 16, 6)
dil_conv_3x3 (1, 16, 16, 6)
dil_conv_5x5 (1, 16, 16, 6)
conv_7x1_1x7 (1, 16, 16, 6)
Pass


Test Skip Connection Here

In [65]:
model = SkipConnection(3)

In [36]:
t1 = tf.random_uniform([1, 16, 16, 6], 0, 255, seed=0, dtype=tf.int32)
t2 = tf.random_uniform([1, 16, 16, 6], 0, 255, seed=0, dtype=tf.int32)

s0 = tf.placeholder(tf.float32, shape=[None, 16, 16, 6], name = "state0")
s1 = tf.placeholder(tf.float32, shape=[None, 16, 16, 6], name = "state1")
out = model(s0, s1)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    t1 = sess.run(t1)
    t2 = sess.run(t2)
    out = sess.run(out, {s0: t1, s1: t2})
    print(out.shape)

TypeError: call() takes 2 positional arguments but 3 were given