In [3]:
import numpy as np
import cv2
from keras.layers import Conv2D, Input, Layer, BatchNormalization, add, Activation
from keras.models import Model
import tensorflow as tf

In [4]:
class InterpolateTensor(Layer):
    """Weight-free Keras layer that bicubically resizes its input to a fixed size.

    Used by RecCNN to upsample the low-resolution batch to the high-resolution
    target size before the residual branch is added.

    Args:
        x: target height of the output (first spatial dimension).
        y: target width of the output (second spatial dimension).
        c: number of channels; only used to report the static output shape.
        **kwargs: forwarded to ``Layer`` (e.g. ``name``). ``trainable``
            defaults to False because the layer has no weights.
    """

    # (x, y, c) are the dimensions of the HR image: (height, width, channels)
    def __init__(self, x, y, c, **kwargs):
        # Keras' Layer.__init__ sets `trainable` itself (from kwargs, default
        # True), so assigning it before super() would be overwritten; pass it
        # through kwargs instead.
        kwargs.setdefault('trainable', False)
        super(InterpolateTensor, self).__init__(**kwargs)
        self.x = x
        self.y = y
        self.channels = c

    def call(self, inputs):
        """Resize a (batch, height, width, channels) tensor to (self.x, self.y)."""
        # TF1 exposes tf.image.resize_bicubic; TF2 only keeps the identical
        # op under tf.compat.v1 (tf.image.resize uses different pixel centers).
        resize_bicubic = getattr(tf.image, 'resize_bicubic', None)
        if resize_bicubic is None:
            resize_bicubic = tf.compat.v1.image.resize_bicubic
        return resize_bicubic(inputs, (self.x, self.y))

    def compute_output_shape(self, input_shape):
        # Batch dimension is passed through; spatial/channel dims are fixed.
        return (input_shape[0], self.x, self.y, self.channels)

    def get_config(self):
        """Return constructor arguments so models using this layer can be saved/reloaded."""
        config = super(InterpolateTensor, self).get_config()
        config.update({'x': self.x, 'y': self.y, 'c': self.channels})
        return config
    

In [13]:
class RecCNN:
    """Builder for the RecCNN reconstruction network.

    The low-resolution input is bicubically upsampled with ``InterpolateTensor``.
    A stack of 3x3 convolutions then predicts a residual, which is added back
    to the upsampled image.
    """

    def __init__(self, c):
        # Number of image channels (e.g. 3 for RGB).
        self.channels = c

    def model(self, filters, scale, h, w, inp, depth=20, out_activation=None):
        """Build the RecCNN graph on top of ``inp`` and return its output tensor.

        Args:
            filters: number of feature maps in every hidden convolution.
            scale: integer upsampling factor; the output is (h*scale, w*scale).
            h: height of the low-resolution input.
            w: width of the low-resolution input.
            inp: Keras input tensor, presumably (batch, h, w, channels) — confirm
                against the caller.
            depth: total number of convolutional layers: the first conv+ReLU,
                ``depth - 2`` conv/BN/ReLU blocks and the final conv. The default
                of 20 gives the original 18 middle blocks.
            out_activation: activation of the final (residual) convolution.
                Defaults to linear so the residual can be negative. 'sigmoid'
                limits the residual to (0, 1), so the output can never be darker
                than the upsampled image, and a near-zero pre-activation still
                adds about 0.5.

        Returns:
            Tensor of the upsampled input plus the predicted residual,
            spatial size (h*scale, w*scale) with ``self.channels`` channels.

        Raises:
            ValueError: if ``depth`` is smaller than 2.
        """
        if depth < 2:
            raise ValueError('depth must be >= 2, got %r' % (depth,))

        upsampled = InterpolateTensor(h * scale, w * scale, self.channels)(inp)
        features = Conv2D(filters, kernel_size=(3, 3), padding='same',
                          activation='relu')(upsampled)
        for _ in range(depth - 2):
            features = self.block(features, filters)
        residual = Conv2D(self.channels, kernel_size=(3, 3), padding='same',
                          activation=out_activation)(features)
        # Residual learning: the network only has to predict the correction.
        return add([upsampled, residual])

    def block(self, x, filters):
        """Apply one 3x3 Conv2D -> BatchNormalization -> ReLU stage to ``x``."""
        conv = Conv2D(filters, kernel_size=(3, 3), padding='same')(x)
        normalized = BatchNormalization()(conv)
        return Activation('relu')(normalized)


In [15]:
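# Example usage (disabled): builds a RecCNN that upsamples 64x64 RGB inputs by 4x.
# The model summary shown below was printed by an earlier run of this cell.
# inp = Input(shape=(64, 64,3))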
# rec_cnn = RecCNN(c=3)

# out = rec_cnn.model(filters=64, scale=4, h=64, w=64, inp=inp)
# model_reccnn = Model(inp, out)
# model_reccnn.compile(optimizer='adam', loss='mean_squared_error')
# model_reccnn.summary()


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            (None, 64, 64, 3)    0                                            
__________________________________________________________________________________________________
interpolate_tensor_4 (Interpola (None, 256, 256, 3)  0           input_4[0][0]                    
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, 256, 256, 64) 1792        interpolate_tensor_4[0][0]       
__________________________________________________________________________________________________
conv2d_24 (Conv2D)              (None, 256, 256, 64) 36928       conv2d_23[0][0]                  
__________________________________________________________________________________________________
batch_norm