In [1]:
import tensorflow as tf
from tensorflow import layers

# Operations

In [18]:
OPS = {
    'none': lambda C, stride: Zero(stride),
    'avg_pool_3x3': lambda C, stride: layers.AveragePooling2D(3, strides=stride, padding='same'),
    'max_pool_3x3': lambda C, stride: layers.MaxPooling2D(3, strides=stride, padding='same'),
    'skip_connect': lambda C, stride: Identity() if stride == 1 else FactorizedReduce(C, C),
    'sep_conv_3x3': lambda C, stride: SepConv(C, C, 3, stride, 'same'),
    'sep_conv_5x5': lambda C, stride: SepConv(C, C, 5, stride, 'same'),
    'sep_conv_7x7': lambda C, stride: SepConv(C, C, 7, stride, 'same'),
    'dil_conv_3x3': lambda C, stride: DilConv(C, C, 3, stride, 'same', 2),
    'dil_conv_5x5': lambda C, stride: DilConv(C, C, 5, stride, 'same', 2),
    'conv_7x1_1x7': lambda C, stride: Conv_7x1_1x7(C, stride)
}

In [3]:
class ReLUConvBN(tf.keras.layers.Layer):
    """Applies ReLU, Conv and BatchNormalisation operation
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding='same'):
        """Initializes the operation

        Args:
            C_in (int): no of kernels in
            C_out (int): no of kernels out
            kernel_size (int): size of kernel
            stride (int): stride
            padding (int): padding
            affine (bool), optional): Defaults to True.
        """
        super(ReLUConvBN, self).__init__()
        self.relu = tf.nn.relu
        self.conv = layers.Conv2D(filters=C_out,
                                  kernel_size=kernel_size,
                                  strides=stride,
                                  padding='same',
                                  use_bias=False
                                 )
        self.bn = layers.BatchNormalization()
        
    def call(self, x):
        """Applies the ReLU, Conv, BN to input

        Args:
            x (tensor): array or tensor (can be image)

        Returns:
            tensor: array or tensor with operations applied on it
        """
        x = self.relu(x)
        x = self.conv(x)
        x = self.bn(x)
        return x

In [4]:
class DilConv(tf.keras.layers.Layer):
    """Applies ReLU, Conv with dilation and BatchNormalisation operation
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super(DilConv, self).__init__()
        self.relu = tf.nn.relu
        self.dil_conv = layers.Conv2D(filters=C_out, 
                                      kernel_size=kernel_size, 
                                      strides=stride, 
                                      padding=padding,
                                      dilation_rate=dilation,
                                      use_bias=False
                                     )
        self.conv = layers.Conv2D(filters=C_out,
                                  kernel_size=1,
                                  strides=1,
                                  padding='same',
                                  use_bias=False
                                 )
        self.bn = layers.BatchNormalization()

    def call(self, x):
        """Applies the ReLU, Conv, BN to input

        Args:
            x (tensor): array or tensor (can be image)

        Returns:
            tensor: array or tensor with operations applied on it
        """
        x = self.relu(x)
        x = self.dil_conv(x)
        x = self.conv(x)
        x = self.bn(x)
        return x

In [5]:
class SepConv(tf.keras.layers.Layer):
    """Applies ReLU, Sep Conv with dilation and BatchNormalisation operation
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding):
        super(SepConv, self).__init__()
        self.relu = tf.nn.relu
        self.conv1 = layers.Conv2D(filters=C_in,
                                  kernel_size=kernel_size,
                                  strides=stride,
                                  padding='same',
                                  use_bias=False
                                 )
        self.conv2 = layers.Conv2D(filters=C_in,
                                  kernel_size=1,
                                  strides=stride,
                                  padding='same',
                                  use_bias=False
                                 )
        self.bn1 = layers.BatchNormalization()
        self.conv3 = layers.Conv2D(filters=C_in,
                                  kernel_size=kernel_size,
                                  strides=1,
                                  padding='same',
                                  use_bias=False
                                 )
        self.conv4 = layers.Conv2D(filters=C_out,
                                  kernel_size=1,
                                  strides=1,
                                  padding='same',
                                  use_bias=False
                                 )
        self.bn2 = layers.BatchNormalization()

    def call(self, x):
        """Applies the ReLU, Conv, BN to input

        Args:
            x (tensor): array or tensor (can be image)

        Returns:
            tensor: array or tensor with operations applied on it
        """
        x = self.relu(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.bn2(x)
        return x


In [6]:
class Identity(tf.keras.layers.Layer):
    """Apply the identity operation
    """

    def __init__(self):
        super(Identity, self).__init__()
        pass

    def call(self, x):
        return x

In [7]:
class Zero(tf.keras.layers.Layer):
    """Makes array element zero with given stride
    """

    def __init__(self, stride):
        super(Zero, self).__init__()
        self.stride = stride

    def call(self, x):
        if self.stride == 1:
            return tf.multiply(x, 0)
        return tf.multiply(x[:, ::self.stride, ::self.stride, :], 0)

In [8]:
class Conv_7x1_1x7(tf.keras.layers.Layer):

  def __init__(self, C, stride):
    super(Conv_7x1_1x7, self).__init__()
    self.relu = tf.nn.relu
    self.conv1 = layers.Conv2D(filters=C,
                             kernel_size=(1,7),
                             strides=(1, stride),
                             padding='same',
                             use_bias=False)
    self.conv2 = layers.Conv2D(filters=C,
                             kernel_size=(7,1),
                             strides=(stride, 1),
                             padding='same',
                             use_bias=False)
    self.bn = layers.BatchNormalization()

  def call(self, x):
    x = self.relu(x)
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.bn(x)
    return x 

In [9]:
class FactorizedReduce(tf.keras.layers.Layer):
    """Applies ReLU, conv with stride=2 and c_out/2 
    """

    def __init__(self, C_in, C_out):
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0
        self.relu = tf.nn.relu
        self.conv_1 = layers.Conv2D(filters=C_out//2, 
                                    kernel_size=1,
                                    strides=2, 
                                    padding='same', 
                                    use_bias=False)
        self.conv_2 = layers.Conv2D(filters=C_out//2, 
                                    kernel_size=1,
                                    strides=2, 
                                    padding='same', 
                                    use_bias=False)
        self.bn = layers.BatchNormalization()

    def call(self, x):
        """concats conv and Batch normalise them

        Args:
            x (tensor): array or tensor (can be image)

        Returns:
            tensor: tensor of operations on input
        """
        x = self.relu(x)
        out = tf.concat([self.conv_1(x), self.conv_2(x[:, 1:, 1:, :])], axis=3)
        out = self.bn(out)
        return out


In [10]:
class FactorizedUp(tf.keras.layers.Layer):

  def __init__(self, C_in, C_out):
    super(FactorizedUp, self).__init__()
    self.relu = tf.nn.relu
    self.trans_conv1 = layers.Conv2DTranspose(filters=C_out, 
                                             kernel_size=3,
                                             strides=2,
                                             padding='same',
                                             )
    self.trans_conv2 = layers.Conv2DTranspose(filters=C_out,
                                             kernel_size=3,
                                             strides=2,
                                             padding='same',
                                             )
    
    self.bn = layers.BatchNormalization()

  def call(self, x):
    x = self.relu(x)
    out = (self.trans_conv1(x) + self.trans_conv2(x)) * 0.5
    out = self.bn(out)
    return out

In [11]:
class SkipConnection(tf.keras.layers.Layer):

  def __init__(self, C):    
    super(SkipConnection, self).__init__()
    self.relu = tf.nn.relu
    self.conv = layers.Conv2D(filters=C,
                             kernel_size=3,
                             strides=1,
                             padding='same',
                             use_bias=False)

  def call(self, s0, s1):
    s0 = self.relu(s0)
    s1 = self.relu(s1)
    x = tf.concat([s1, s0], axis=3)
    x = self.conv(x)
    return out 

# Test Here

In [17]:
for op in OPS:
    model = OPS[op](6, 1)

    image = tf.random_uniform([1, 16, 16, 3], 0, 255, seed=0, dtype=tf.bfloat16)

    ip = tf.placeholder(tf.bfloat16, shape=[None, 16, 16, 3], name = "input")
    out = model(ip)
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        image = sess.run(image)
        out = sess.run(out, {ip: image})
        print(op, out.shape)
print('Pass')

('sep_conv_7x7', (1, 16, 16, 6))
('sep_conv_3x3', (1, 16, 16, 6))
('none', (1, 16, 16, 3))
('dil_conv_5x5', (1, 8, 8, 6))
('skip_connect', (1, 16, 16, 3))
('max_pool_3x3', (1, 14, 14, 3))
('dil_conv_3x3', (1, 12, 12, 6))
('avg_pool_3x3', (1, 14, 14, 3))
('conv_7x1_1x7', (1, 16, 16, 6))
('sep_conv_5x5', (1, 16, 16, 6))
Pass


In [19]:
for op in OPS:
    model = OPS[op](6, 1)

    image = tf.random_uniform([1, 16, 16, 3], 0, 255, seed=0, dtype=tf.bfloat16)

    ip = tf.placeholder(tf.bfloat16, shape=[None, 16, 16, 3], name = "input")
    out = model(ip)
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        image = sess.run(image)
        out = sess.run(out, {ip: image})
        print(op, out.shape)
print('Pass')

('sep_conv_7x7', (1, 16, 16, 6))
('sep_conv_3x3', (1, 16, 16, 6))
('none', (1, 16, 16, 3))
('dil_conv_5x5', (1, 16, 16, 6))
('skip_connect', (1, 16, 16, 3))
('max_pool_3x3', (1, 16, 16, 3))
('dil_conv_3x3', (1, 16, 16, 6))
('avg_pool_3x3', (1, 16, 16, 3))
('conv_7x1_1x7', (1, 16, 16, 6))
('sep_conv_5x5', (1, 16, 16, 6))
Pass


Test Skip Connection Here

In [13]:
model = SkipConnection(3)

In [14]:
t1 = tf.random_uniform([1, 16, 16, 6], 0, 255, seed=0, dtype=tf.int32)
t2 = tf.random_uniform([1, 16, 16, 6], 0, 255, seed=0, dtype=tf.int32)

s0 = tf.placeholder(tf.float32, shape=[None, 16, 16, 6], name = "state0")
s1 = tf.placeholder(tf.float32, shape=[None, 16, 16, 6], name = "state1")
out = model(s0, s1)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    t1 = sess.run(t1)
    t2 = sess.run(t2)
    out = sess.run(out, {s0: t1, s1: t2})
    print(out.shape)

TypeError: Fetch argument array([[[[-2.57018547e+01,  7.61903229e+01, -3.45278358e+01,
          -5.72116356e+01,  1.26228523e+01,  8.30161095e+00],
         [-4.53565216e+01,  5.90059395e+01, -7.78427734e+01,
          -4.51087723e+01,  1.96594200e+01, -2.17728271e+01],
         [-3.78554611e+01,  6.83958206e+01, -8.30463867e+01,
          -4.79010162e+01, -7.03880835e+00, -3.36997681e+01],
         ...,
         [-7.99613094e+00,  6.40714874e+01, -3.66161537e+01,
          -2.66654663e+01,  2.62819099e+01, -1.66042209e+00],
         [-2.34229183e+01,  5.17971725e+01, -3.66127586e+01,
          -4.10278931e+01, -9.38137323e-02, -7.03476143e+00],
         [-4.30943031e+01,  6.05036640e+00, -4.04830284e+01,
          -5.66074753e+01, -2.16268654e+01,  4.25936174e+00]],

        [[-4.94909554e+01,  9.09436493e+01, -8.67251892e+01,
          -6.36064949e+01,  2.13529282e+01, -1.88705063e+01],
         [-7.74965134e+01,  9.92491226e+01, -1.10210274e+02,
          -8.35157318e+01,  3.90716057e+01, -2.11505527e+01],
         [-4.00234680e+01, -1.84381218e+01, -8.12903442e+01,
           3.23257232e+00,  6.44881248e+00, -4.53211784e+01],
         ...,
         [-6.22253227e+01,  9.67921753e+01, -1.46833572e+02,
          -2.01705742e+01,  8.15413132e+01, -6.76078110e+01],
         [-5.81512833e+01, -2.11156578e+01, -1.36577301e+02,
           3.11246929e+01,  4.04988518e+01, -8.95357819e+01],
         [-3.54642487e+01, -2.52454448e+00, -4.85342903e+01,
          -4.43812561e+01, -1.30654621e+01,  8.24523747e-01]],

        [[-6.77391968e+01,  5.96715202e+01, -7.74788513e+01,
          -7.69253006e+01,  1.01868165e+00, -1.32284241e+01],
         [-7.40155487e+01,  7.78080750e+01, -1.37188217e+02,
          -3.57185059e+01,  4.85170708e+01, -6.43924713e+01],
         [-9.59733276e+01,  3.96537132e+01, -1.80275833e+02,
          -1.92790070e+01,  6.16730995e+01, -8.68394394e+01],
         ...,
         [-7.42173462e+01,  6.64554138e+01, -1.32984802e+02,
          -2.27691288e+01,  5.10788956e+01, -7.56631851e+01],
         [ 1.49093294e+00,  2.96701527e+00, -5.97532310e+01,
           6.32925339e+01,  3.69947319e+01, -7.26561584e+01],
         [-1.75588169e+01, -3.59368944e+00, -7.10392838e+01,
          -1.64102539e-02,  7.88503551e+00, -2.55195541e+01]],

        ...,

        [[-1.76744385e+01,  5.27106047e-01, -2.63586502e+01,
          -1.53687239e+00, -5.40139580e+00, -2.21775360e+01],
         [-6.89619827e+01, -2.81733761e+01, -8.27778778e+01,
          -1.11997318e+00,  3.14137821e+01, -4.65441971e+01],
         [-8.76259766e+01, -4.28141098e+01, -1.25212433e+02,
           1.96293602e+01,  3.43364067e+01, -8.77583694e+01],
         ...,
         [-9.16067352e+01, -7.26044846e+01, -1.64549179e+02,
           4.91880074e+01,  4.79318504e+01, -1.15561310e+02],
         [-5.19634895e+01,  4.32963943e+01, -1.08599693e+02,
           1.81007957e+01,  6.13245316e+01, -8.18364258e+01],
         [-5.99000931e+01,  1.64973888e+01, -1.03957756e+02,
           2.38437867e+00,  4.57824783e+01, -6.13046608e+01]],

        [[-5.03090820e+01,  7.92624331e+00, -3.39663506e+01,
          -5.61375732e+01, -2.18432922e+01,  1.02051270e+00],
         [-5.36571236e+01, -1.69657879e+01, -6.37024193e+01,
          -2.99647856e+00,  8.96926212e+00, -4.55297852e+01],
         [-6.35462570e+01,  1.03312826e+01, -1.09852364e+02,
           1.12363577e+01,  3.88848343e+01, -7.48301697e+01],
         ...,
         [-8.32152328e+01, -3.77999992e+01, -1.59636032e+02,
           5.76926308e+01,  7.34345551e+01, -1.17269714e+02],
         [-4.79142265e+01,  2.88896503e+01, -8.50314026e+01,
           2.72671967e+01,  8.50743179e+01, -6.45131531e+01],
         [-5.58761101e+01, -1.49972343e+00, -1.13074829e+02,
           4.55414009e+01,  9.10814362e+01, -7.66313095e+01]],

        [[-1.82450161e+01, -3.03277045e-01, -2.67801018e+01,
           1.34086618e+01,  1.81960869e+01, -2.76990852e+01],
         [-6.84590769e+00, -3.85046959e+01, -1.86253510e+01,
           4.57233276e+01, -1.30617580e+01, -5.34230042e+01],
         [-5.31964417e+01,  2.74249020e+01, -6.41893539e+01,
          -1.09134007e+00,  3.58701630e+01, -5.07112617e+01],
         ...,
         [-6.10478172e+01, -3.10197296e+01, -1.02551376e+02,
           4.27776604e+01,  5.05676079e+01, -8.35764618e+01],
         [-2.46421013e+01,  3.74993668e+01, -3.15012741e+01,
           2.46936398e+01,  6.65503922e+01, -4.21986656e+01],
         [-1.94295921e+01, -4.30718384e+01, -5.79539986e+01,
           5.87252884e+01,  4.99326782e+01, -4.89989662e+01]]]],
      dtype=float32) has invalid type <type 'numpy.ndarray'>, must be a string or Tensor. (Can not convert a ndarray into a Tensor or Operation.)