In [19]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, BatchNormalization, Lambda, Multiply, Add, LeakyReLU, Dropout, Conv2D, MaxPool2D, GlobalAveragePooling2D, Flatten, Conv2DTranspose
import numpy as np

In [2]:
class block(tf.keras.layers.Layer):
    """Two-convolution residual block with an optional stride-2 downsampling path.

    Main path: Conv2D -> BatchNormalization -> ReLU -> Conv2D -> BatchNormalization.
    The main-path result is added to the shortcut and passed through a final ReLU.

    Args:
        channels: Number of filters used by every convolution in the block.
        downsample: If True, the first convolution uses stride 2 and the
            shortcut is projected with a stride-2 1x1 convolution followed by
            batch normalization, so both branches end up with the same shape.
        kernal_size: Kernel size of the two main-path convolutions. The
            spelling is kept so existing keyword callers keep working.

    Note:
        With ``downsample=False`` the shortcut is the unmodified input, so the
        input must already have ``channels`` channels for the addition to be
        shape-compatible.
    """

    def __init__(self, channels, downsample = False, kernal_size = 3):
        super(block, self).__init__()
        self._channels = channels
        # [stride of conv_1, stride of conv_2]; only the first conv downsamples.
        self._strides = [2, 1] if downsample else [1, 1]
        self._down_sample = downsample
        # Every layer gets its OWN He-normal initializer (requested by name).
        # A single unseeded initializer instance shared between layers returns
        # the same values on each call in recent Keras releases, which would
        # start conv_1 and conv_2 with identical kernels when shapes match.
        self.conv_1 = Conv2D(self._channels, kernal_size, strides=self._strides[0], padding="same", kernel_initializer="he_normal")
        self.bn_1 = BatchNormalization()
        self.conv_2 = Conv2D(self._channels, kernal_size, strides=self._strides[1], padding="same", kernel_initializer="he_normal")
        self.bn_2 = BatchNormalization()
        # A real Add layer, created once, instead of a Lambda that built a new
        # Add() on every call. The layer name 'z' is preserved.
        self.merge = Add(name='z')
        if self._down_sample:
            # Projection shortcut: matches the stride-2 main path spatially
            # and maps the input to `channels` filters.
            self.conv_ds = Conv2D(self._channels, (1,1), strides=2, padding="same", kernel_initializer="he_normal")
            self.bn_ds = BatchNormalization()

    def call(self, inputs):
        """Apply the residual block to `inputs` and return the activated sum."""
        res = inputs
        x = self.conv_1(inputs)
        x = self.bn_1(x)
        x = tf.nn.relu(x)
        x = self.conv_2(x)
        x = self.bn_2(x)
        if self._down_sample:
            res = self.conv_ds(res)
            res = self.bn_ds(res)
        # Skip connection: helps avoid vanishing and exploding gradients.
        x = self.merge([x, res])
        out = tf.nn.relu(x)
        return out

In [5]:
class ConvLeaky(tf.keras.layers.Layer):
    """Two stacked 3x3 'same' convolutions, each followed by BatchNormalization and leaky ReLU.

    Args:
        out_channels: Number of filters for both convolutions.

    Both convolutions use stride 1 with 'same' padding, so they do not change
    the spatial size of the input.
    """

    def __init__(self, out_channels):
        super(ConvLeaky, self).__init__()
        kernel_size = 3
        # Each convolution gets its own He-normal initializer (requested by
        # name). Reusing one unseeded initializer instance across layers makes
        # recent Keras releases return the same values on every call.
        self.conv1 = Conv2D(out_channels, kernel_size, strides = 1, padding = 'same', kernel_initializer="he_normal")
        self.bn_1 = BatchNormalization()
        self.conv2 = Conv2D(out_channels, kernel_size, strides = 1, padding = 'same', kernel_initializer="he_normal")
        self.bn_2 = BatchNormalization()

    def call(self, inputs):
        """Run conv -> batch norm -> leaky ReLU twice and return the result."""
        x = self.conv1(inputs)
        x = self.bn_1(x)
        x = tf.nn.leaky_relu(x)
        x = self.conv2(x)
        x = self.bn_2(x)
        x = tf.nn.leaky_relu(x)
        return x

In [147]:
class FNetBlock(tf.keras.layers.Layer):
    """A ConvLeaky stage followed by a 2x spatial resize.

    Args:
        out_channels: Number of filters passed to the ConvLeaky stage.
        typ: Resize mode applied after the convolutions:
            ``"maxpool"``  -> 2x2 max pooling with stride 2 ('same' padding),
            ``"bilinear"`` -> 2x bilinear upsampling.

    Raises:
        ValueError: If `typ` is neither ``"maxpool"`` nor ``"bilinear"``.
    """

    def __init__(self, out_channels, typ):
        super(FNetBlock, self).__init__()
        self.convleaky = ConvLeaky(out_channels)
        if typ == "maxpool":
            # Built-in pooling layer instead of a Lambda around
            # tf.nn.max_pool2d: same result, but serializable.
            self.out = tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="same")
        elif typ == "bilinear":
            self.out = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')
        else:
            # ValueError is a subclass of Exception, so existing
            # `except Exception` handlers still catch it.
            raise ValueError('typ does not match')

    def call(self, inputs):
        """Apply the convolution stage, then the configured resize layer."""
        x = self.convleaky(inputs)
        x = self.out(x)
        return x


In [24]:
#Testing FNetBlock
# Smoke test: a bilinear FNetBlock with 10 filters applied to a dummy
# (1, 20, 20, 3) batch of ones; the displayed value is the output shape.
# The class is spelled `FNetBlock`; `FnetBlock` raised a NameError.
l = FNetBlock(10, "bilinear")
a = np.ones((1, 20, 20, 3))
l(a).shape

TensorShape([1, 40, 40, 10])

In [125]:
class SRNet(tf.keras.Model):
    """Convolutional model: stem conv, ten residual blocks, two transposed convs, output conv.

    Layer sequence:
      1. 3x3 'same' convolution to 64 channels (He-normal initialised).
      2. A Sequential stack of ten ``block(64)`` residual blocks.
      3. Two stride-2 transposed convolutions, each followed by ReLU.
      4. An unpadded 3x3 convolution producing 3 channels.

    The second transposed convolution uses 'valid' padding with
    ``output_padding=1`` and the final convolution is unpadded; the notebook's
    test cell shows a (1, 20, 20, 3) input giving a (1, 80, 80, 3) output.
    """

    def __init__(self):
        super().__init__()
        ksize = 3
        stem_init = tf.keras.initializers.he_normal()

        # Feature extraction at the input resolution.
        self.conv_in = Conv2D(
            64,
            ksize,
            strides=1,
            padding='same',
            kernel_initializer=stem_init,
        )
        residual_stack = [block(64) for _ in range(10)]
        self.resBlock = tf.keras.Sequential(residual_stack)

        # Upsampling path.
        self.deconv1 = Conv2DTranspose(64, ksize, strides=(2, 2), padding='same')
        self.deconv2 = Conv2DTranspose(
            64,
            ksize,
            strides=(2, 2),
            padding='valid',
            output_padding=1,
        )

        # Projection to the 3 output channels.
        self.out_conv = Conv2D(3, ksize)

    def call(self, inputs):
        """Run the full pipeline on `inputs` and return the 3-channel result."""
        features = self.resBlock(self.conv_in(inputs))
        upsampled = tf.nn.relu(self.deconv1(features))
        upsampled = tf.nn.relu(self.deconv2(upsampled))
        return self.out_conv(upsampled)

In [127]:
#SRNet testing
# Smoke test: build a fresh SRNet, feed it a dummy (1, 20, 20, 3) batch of
# ones, and display the shape of the output.
m = SRNet()
i = np.ones((1, 20, 20, 3))
m(i).shape

TensorShape([1, 80, 80, 3])

In [128]:
# Print the layer-by-layer summary (output shapes and parameter counts) of
# the SRNet instance `m`; the model must have been called once beforehand.
m.summary()

Model: "sr_net_34"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_866 (Conv2D)         multiple                  1792      
                                                                 
 sequential_37 (Sequential)  (1, 20, 20, 64)           743680    
                                                                 
 conv2d_transpose_65 (Conv2D  multiple                 36928     
 Transpose)                                                      
                                                                 
 conv2d_transpose_66 (Conv2D  multiple                 36928     
 Transpose)                                                      
                                                                 
 conv2d_887 (Conv2D)         multiple                  1731      
                                                                 
Total params: 821,059
Trainable params: 818,499
Non-train

In [157]:
class FNet(tf.keras.Model):
    """Encoder-decoder built from FNetBlocks that ends in a 2-channel tanh output.

    Layer sequence:
      1. Three ``"maxpool"`` FNetBlocks (32, 64, 128 filters), each halving
         the spatial size.
      2. Three ``"bilinear"`` FNetBlocks (256, 128, 64 filters), each doubling
         the spatial size.
      3. A 3x3 'same' convolution to 32 channels with leaky ReLU.
      4. A 3x3 'same' convolution to 2 channels with tanh, so outputs lie in
         [-1, 1].

    The notebook's test cell feeds a 6-channel input ("input dim = 6") and
    shows a (1, 40, 40, 6) input giving a (1, 40, 40, 2) output.
    NOTE(review): presumably this is a flow-estimation network taking two
    stacked RGB frames -- confirm against the caller.
    """

    def __init__(self):
        super(FNet, self).__init__()
        #input dim = 6
        kernel_size = 3
        self.convpool_1 = FNetBlock(32, "maxpool")
        self.convpool_2 = FNetBlock(64, "maxpool")
        self.convpool_3 = FNetBlock(128, "maxpool")
        self.convbin_1 = FNetBlock(256, "bilinear")
        self.convbin_2 = FNetBlock(128, "bilinear")
        self.convbin_3 = FNetBlock(64, "bilinear")
        # Each convolution gets its own He-normal initializer (requested by
        # name). Reusing one unseeded initializer instance across layers makes
        # recent Keras releases return the same values on every call.
        self.conv1 = Conv2D(32, kernel_size = kernel_size, strides = 1, padding = 'same', kernel_initializer="he_normal")
        self.conv2 = Conv2D(2, kernel_size = kernel_size, strides = 1, padding = 'same', kernel_initializer="he_normal")

    def call(self, inputs):
        """Run the encoder-decoder on `inputs` and return the tanh-activated 2-channel map."""
        x = self.convpool_1(inputs)
        x = self.convpool_2(x)
        x = self.convpool_3(x)
        x = self.convbin_1(x)
        x = self.convbin_2(x)
        x = self.convbin_3(x)
        x = self.conv1(x)
        x = tf.nn.leaky_relu(x)
        x = self.conv2(x)
        x = tf.nn.tanh(x)
        # The most recent output is also stored on the instance; kept because
        # code outside this class may read `self.x`.
        self.x = x
        return x

In [159]:
#FNet test
# Smoke test: build a fresh FNet, feed it a dummy (1, 40, 40, 6) batch of
# ones, and display the shape of the output.
f = FNet()
x = np.ones((1, 40, 40, 6))
x = f(x)  # x is rebound here from the input array to the network output
x.shape

TensorShape([1, 40, 40, 2])

In [140]:
class SpaceToDepth(tf.keras.Model)

TensorShape([1, 16, 16, 2])