In [103]:
from keras.layers import Conv2D, Conv2DTranspose, Cropping2D
from keras.layers import Input
from keras.models import Model

import keras.backend as K
from keras.applications.vgg19 import VGG19
from keras.models import Model
from keras.layers import Input, Lambda, Add, Concatenate
from keras.optimizers import Adam
from keras.layers.core import Activation
from keras.layers.convolutional import UpSampling2D, Conv2D, Conv2DTranspose
from keras.layers.advanced_activations import LeakyReLU, PReLU
from keras.layers.normalization import BatchNormalization
from keras.models import load_model

import math

In [71]:
upscale_factor = 8
output_image_shape = (1080, 1920, input_image_shape[2])
input_image_shape = (output_image_shape[0] // upscale_factor, output_image_shape[1] // upscale_factor, output_image_shape[2])

downscale_times = int(math.log(upscale_factor,2)) - 2
upscale_times = int(math.log(upscale_factor,2)) + 1

In [72]:
input_layer = Input(shape = input_image_shape)

In [73]:
layers = Conv2D(filters = 1, kernel_size = 3, strides = 2, padding = "same")(input_layer)
#layers = Conv2D(filters = 1, kernel_size = 3, strides = 2, padding = "same")(layers)
#layers = Conv2D(filters = 1, kernel_size = 3, strides = 2, padding = "same")(layers)

In [74]:
layers = Conv2DTranspose(filters = 1, kernel_size = 3, strides = 2, padding = "same")(layers)
layers = Conv2DTranspose(filters = 1, kernel_size = 3, strides = 2, padding = "same")(layers)
layers = Conv2DTranspose(filters = 1, kernel_size = 3, strides = 2, padding = "same")(layers)
layers = Conv2DTranspose(filters = 1, kernel_size = 3, strides = 2, padding = "same")(layers)

In [75]:
model = Model(inputs = input_layer, outputs = layers)

In [76]:
output_shape = model.output_shape

In [78]:
height_crop = (output_shape[1] - output_image_shape[0]) // 2
width_crop  = (output_shape[2] - output_image_shape[1]) // 2

In [79]:
(height_crop, width_crop)

(4, 0)

In [80]:
layers = Cropping2D(cropping=(height_crop, width_crop))(layers)

In [81]:
model = Model(inputs = input_layer, outputs = layers)

In [82]:
model.output_shape

(None, 1080, 1920, 1)

In [83]:
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         (None, 135, 240, 3)       0         
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 68, 120, 1)        28        
_________________________________________________________________
conv2d_transpose_17 (Conv2DT (None, 136, 240, 1)       10        
_________________________________________________________________
conv2d_transpose_18 (Conv2DT (None, 272, 480, 1)       10        
_________________________________________________________________
conv2d_transpose_19 (Conv2DT (None, 544, 960, 1)       10        
_________________________________________________________________
conv2d_transpose_20 (Conv2DT (None, 1088, 1920, 1)     10        
_________________________________________________________________
cropping2d_4 (Cropping2D)    (None, 1080, 1920, 1)     0         
Total para

In [69]:
output_image_shape

(1080, 1920, 3)

In [70]:
downscale_times

1

In [247]:
def same_size_unetish_block(model, kernal_size, filters, strides, name):
    
    model = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same", name = name+'/Conv2D')(model)
    model = PReLU(alpha_initializer='zeros', alpha_regularizer=None, alpha_constraint=None, shared_axes=[1,2], name = name+'/PReLU')(model)
    
    return model

def downsampling_unetish_block(model, kernel_size, filters, strides, name):
    
    model = Conv2D(filters = filters, kernel_size = kernel_size, strides = strides, padding = "same", name = name+'/Conv2D')(model)
    model = PReLU(alpha_initializer='zeros', alpha_regularizer=None, alpha_constraint=None, shared_axes=[1,2], name = name+'/PReLU')(model)
    
    return model

def upsampling_unetish_block(model, kernel_size, filters, strides, name):
    
    model = Conv2DTranspose(filters = filters, kernel_size = kernel_size, strides = strides, padding = "same", name = name+'/Conv2DTrans')(model)
    model = PReLU(alpha_initializer='zeros', alpha_regularizer=None, alpha_constraint=None, shared_axes=[1,2], name = name+'/PReLU')(model)
    
    return model

def find_crop_shape(input_layer, output_down, output_up):
    # this way the shapes of the output is calcualted
    down_shape = Model(inputs = input_layer, outputs = output_down).output_shape
    up_shape   = Model(inputs = input_layer, outputs = output_up).output_shape
    
    height_diff = up_shape[1] - down_shape[1]
    width_diff  = up_shape[2] - down_shape[2]
    
    top_crop = height_diff // 2
    left_crop = width_diff // 2
    
    crop_shapes = ((top_crop, height_diff - top_crop),(left_crop, width_diff - left_crop))
    
    return crop_shapes
    
    
def concatenate_layers(input_layer, output_down, output_up, name):
    
    crop_shapes = find_crop_shape(input_layer, output_down, output_up)
    
    model = Cropping2D(cropping=crop_shapes, name = name+"/Cropping2D")(output_up)
    model = Concatenate(axis = 3, name = name+"/Concatenate")([output_down, model])
    
    return model

In [257]:
def make_upscaler_unetish(output_image_shape, upscale_factor = 4, step_size = 2, downscale_times = 2, initial_step_filter_count = 128): 

    upscale_times = int(math.log(upscale_factor,2)) + downscale_times
    input_image_shape = (output_image_shape[0] // upscale_factor, output_image_shape[1] // upscale_factor, output_image_shape[2])
    
    upscaler_input = Input(shape = input_image_shape, name = 'input')
    
    model = Conv2D(filters = initial_step_filter_count, kernel_size = 9, strides = 1, padding = "same", name = 'initial/Conv2D')(upscaler_input)
    model = PReLU(shared_axes=[1,2], name = 'initial/PReLU')(model)

    upsc_model = model
    
    
    outputs = []
    step_filter_count = initial_step_filter_count
    
    # downsampling steps
    for step in range(downscale_times):
        
        for index in range(step_size):
            model = same_size_unetish_block(model, 3, step_filter_count, 1, "down/"+str(step)+"/same/"+str(index))
        
        outputs.append(model)
        model = downsampling_unetish_block(model, 3, step_filter_count, 2, "down/"+str(step)+"/down")
        step_filter_count = step_filter_count * 2
    
    
    # steps at the bottom of U
    for index in range(step_size):
        model = same_size_unetish_block(model, 3, step_filter_count, 1, "bottom/"+str(step)+"/same"+str(index))
    
    
    down_outputs_len = len(outputs)
    
    # upsampling steps
    for step in range(upscale_times):
        model = upsampling_unetish_block(model, 3, step_filter_count, 2, "up/"+str(step)+"/up")
        
        if step < down_outputs_len:
            model = concatenate_layers(upscaler_input, outputs[down_outputs_len - step - 1], model, "up/"+str(step)+"/concat")
            step_filter_count = step_filter_count // 2
            
        for index in range(step_size):
            model = same_size_unetish_block(model, 3, step_filter_count, 1, "up/"+str(step)+"/same/"+str(index))
    
    
    model = Conv2D(filters = 3, kernel_size = 9, strides = 1, padding = "same", name = 'final/Conv2D')(model)
    model = Activation('tanh', name = 'final/tanh')(model)
    
    
    # making sure the output is of the right shape
    
    # to extract the output shape
    output_shape = Model(inputs = upscaler_input, outputs = model).output_shape

    height_diff = output_shape[1] - output_image_shape[0]
    width_diff  = output_shape[2] - output_image_shape[1]
    
    top_crop = height_diff // 2
    left_crop = width_diff // 2
    
    crop_shapes = ((top_crop, height_diff - top_crop),(left_crop, width_diff - left_crop))
    
    model = Cropping2D(cropping=crop_shapes, name = 'final/Cropping2D')(model)
    
    upscaler_model = Model(inputs = upscaler_input, outputs = model)

    return upscaler_model

In [258]:
m = make_upscaler_unetish((1080,1920,3), 8)

In [259]:
m.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input (InputLayer)              (None, 135, 240, 3)  0                                            
__________________________________________________________________________________________________
initial/Conv2D (Conv2D)         (None, 135, 240, 128 31232       input[0][0]                      
__________________________________________________________________________________________________
initial/PReLU (PReLU)           (None, 135, 240, 128 128         initial/Conv2D[0][0]             
__________________________________________________________________________________________________
down/0/same/0/Conv2D (Conv2D)   (None, 135, 240, 128 147584      initial/PReLU[0][0]              
__________________________________________________________________________________________________
down/0/sam

In [260]:
len([1,2,3])

3