In [4]:
import numpy as np
from utils import g
from Layers import flowcomputation, imagecomputation
from keras.models import Model
from keras.models import load_model
from keras.optimizers import Adam
from keras.layers import Input, Conv2D, UpSampling2D, Dropout, LeakyReLU, AveragePooling2D, Lambda, ZeroPadding2D
from keras.layers.merge import Concatenate
from keras.applications import VGG16
from keras import backend as K
from DataGenerator import FrameGenerator

Using TensorFlow backend.


In [5]:
class Network(object):
    """Frame-interpolation network assembled from two chained U-Nets.

    Three Keras models are built and kept as attributes:

    * ``flowcomputation_model``   -- maps two frames ``[I0, I1]`` to a
      4-channel flow tensor plus the bottleneck encoding of its U-Net.
    * ``flowinterpolation_model`` -- refines the flow for a given time input
      and produces the intermediate frame via the project layers
      ``flowcomputation`` / ``imagecomputation``.
    * ``model``                   -- the end-to-end model with inputs
      ``[t, I0, I1]`` and the interpolated frame as output.

    A frozen VGG16 feature extractor (``self.vgg``) is used only inside the
    training loss.
    """

    def __init__(self, shape = (512, 512, 3), vgg_model = None):
        """Build and compile all sub-models.

        Args:
            shape: Shape of a single input frame, without the batch axis.
            vgg_model: Optional path to a VGG16 weight file. When falsy the
                ImageNet weights shipped with Keras are used instead.
        """
        self.shape = shape

        self.vgg = self.load_vgg(vgg_model)
        (self.flowcomputation_model,
         self.flowinterpolation_model,
         self.model,
         outputs) = self.SlowMo_network()

        # The two sub-models get a plain MSE loss; only the end-to-end model
        # is compiled with the composite loss that needs the extra tensors.
        self.compile_network(self.flowcomputation_model)
        self.compile_network(self.flowinterpolation_model)
        self.compile_network(self.model, outputs)

    def load_vgg(self, vgg_weights):
        """Return a frozen VGG16 feature extractor for inputs of ``self.shape``.

        Args:
            vgg_weights: Path to a weight file loaded by layer name, or a
                falsy value to use the ImageNet weights provided by Keras.

        Returns:
            A compiled, non-trainable ``Model`` mapping an image to the
            output of VGG16 layer index 13.
        """
        img = Input(self.shape)

        if vgg_weights:
            vgg = VGG16(weights = None, include_top = False)
            vgg.load_weights(vgg_weights, by_name = True)
        else:
            # The classifier head is never used here and would pin the input
            # to VGG's fixed training resolution, so build without it in
            # both branches.
            vgg = VGG16(weights = 'imagenet', include_top = False)

        # Truncate with an explicit sub-model instead of overwriting
        # ``vgg.outputs``, which relies on Keras internals.
        # NOTE(review): index 13 is presumably block4_conv3 in the stock
        # Keras VGG16 -- confirm against vgg.summary().
        feature_extractor = Model(inputs = vgg.input, outputs = vgg.layers[13].output)
        out = feature_extractor(img)

        model = Model(inputs = img, outputs = out)
        model.trainable = False
        model.compile(loss='mse', optimizer='adam')

        return model

    def UNet(self, I, output_channels, kernel_size = (7, 5, 3, 3, 3, 3), decoder_extra_input = None, alpha = 0.1, return_activations = False):
        """Build a 6-stage encoder / 5-stage decoder U-Net on tensor ``I``.

        Args:
            I: Input tensor (channels-last; concatenations use axis 3).
            output_channels: Number of channels of the final 1x1 convolution.
            kernel_size: Six kernel sizes, one per encoder stage. Decoder
                stages reuse entries 4..0 in that order.
            decoder_extra_input: Optional tensor concatenated to the
                bottleneck before decoding.
            alpha: Negative slope of every LeakyReLU.
            return_activations: When true, also return the bottleneck tensor.

        Returns:
            The output tensor, or ``[bottleneck, output]`` when
            ``return_activations`` is true.
        """

        def encode_layer(I, filters, kernel_size, downsampling = True):
            # Two stacked convolutions, one LeakyReLU, optional 2x average pool.
            conv1  = Conv2D(filters, kernel_size, strides = 1, padding = 'same')(I)
            conv2  = Conv2D(filters, kernel_size, strides = 1, padding = 'same')(conv1)
            out = LeakyReLU(alpha)(conv2)

            if downsampling:
                out = AveragePooling2D((2,2), strides = 2)(out)
            return out

        def decode_layer(I, concat_I, filters, kernel_size):
            # 2x upsample, concatenate the skip connection, then two
            # convolutions followed by one LeakyReLU.
            upsampled_I = UpSampling2D(size = (2,2))(I)
            concat_I = Concatenate(axis = 3)([concat_I, upsampled_I])
            conv1  = Conv2D(filters, kernel_size, strides = 1, padding = 'same')(concat_I)
            conv2  = Conv2D(filters, kernel_size, strides = 1, padding = 'same')(conv1)
            out = LeakyReLU(alpha)(conv2)

            return out

        # encodings[0] is the raw input; encodings[k] is the output of stage k.
        encodings = [I]

        encodings.append(encode_layer(encodings[0], 32,  kernel_size[0]))
        encodings.append(encode_layer(encodings[1], 64,  kernel_size[1]))
        encodings.append(encode_layer(encodings[2], 128, kernel_size[2]))
        encodings.append(encode_layer(encodings[3], 256, kernel_size[3]))
        encodings.append(encode_layer(encodings[4], 512, kernel_size[4]))
        # Last stage keeps the spatial size (no pooling).
        encodings.append(encode_layer(encodings[5], 512, kernel_size[5], False))

        decodings = encodings[6]

        if decoder_extra_input is not None:
            decodings = Concatenate(axis = 3)([decodings, decoder_extra_input])

        decodings = decode_layer(decodings, encodings[4], 512, kernel_size[4])
        decodings = decode_layer(decodings, encodings[3], 256, kernel_size[3])
        decodings = decode_layer(decodings, encodings[2], 128, kernel_size[2])
        decodings = decode_layer(decodings, encodings[1], 64,  kernel_size[1])
        decodings = decode_layer(decodings, encodings[0], 32,  kernel_size[0])

        # NOTE(review): ReLU clamps the output to be non-negative; if these
        # channels are meant to be signed flow vectors this looks wrong.
        # Left unchanged so existing weight files keep their meaning -- confirm.
        out = Conv2D(output_channels, 1, activation = 'relu', padding = 'same')(decodings)

        if return_activations:
            return [encodings[6], out]
        return out

    def SlowMo_network(self, t = 0.5):
        """Build the flow-computation, flow-interpolation and end-to-end models.

        Args:
            t: Unused. The time value is a model input created below; the
                parameter is kept only so existing callers keep working.

        Returns:
            ``[flow_computation_model, flow_interpolation_model, model, supp]``
            where ``supp`` is the tuple ``(I0, I1, F01, F10, Ft0, Ft1)`` of
            graph tensors consumed by ``loss_total``.
        """
        I0 = Input(self.shape)
        I1 = Input(self.shape)

        # Optical flow computation model
        flow_computation_input = Concatenate(axis = 3)([I0, I1])
        encoding, Optical_flow = self.UNet(I = flow_computation_input, output_channels = 4,
                                           return_activations = True)

        flow_computation_model = Model(inputs = [I0, I1], outputs = [Optical_flow, encoding])

        # Optical flow interpolation model
        t_input = Input((1,))
        flow = Input(tensor = Optical_flow)
        extra_input = Input(tensor = encoding)

        flow_interpolation_output = self.UNet(I = flowcomputation()([t_input, I0, I1, flow]),
                                              decoder_extra_input = extra_input, output_channels = 5)
        It = imagecomputation()([t_input, I0, I1, flow_interpolation_output])

        flow_interpolation_model = Model(inputs = [t_input, I0, I1, flow, extra_input], outputs = It)

        # Complete model
        Optical_flow, encoding = flow_computation_model([I0, I1])
        It = flow_interpolation_model([t_input, I0, I1, Optical_flow, encoding])

        model = Model(inputs = [t_input, I0, I1], outputs = It)

        # Extract the tensors used by the loss function. Channel 4 of the
        # interpolation output is not needed by the loss and is not sliced.
        Ft0 = flow_interpolation_output[:, :, :, :2]
        Ft1 = flow_interpolation_output[:, :, :, 2:4]
        F01 = Optical_flow[:, :, :, :2]
        F10 = Optical_flow[:, :, :, 2:]

        return [flow_computation_model, flow_interpolation_model, model, (I0, I1, F01, F10, Ft0, Ft1)]

    def compile_network(self, model, supp = None, lr = 0.0001):
        """Compile ``model`` with Adam.

        Args:
            model: Keras model to compile.
            supp: Tuple returned as the last element of ``SlowMo_network``.
                When given, the composite loss and the PSNR metric are used;
                otherwise plain MSE.
            lr: Learning rate for Adam.
        """
        if supp is not None:
            model.compile(optimizer = Adam(lr=lr), loss = self.loss_total(supp), metrics=[self.PSNR])
        else:
            model.compile(optimizer = Adam(lr=lr), loss = 'mse')

    def loss_total(self, supp):
        """Create the composite training loss closed over the graph tensors.

        Args:
            supp: ``(I0, I1, F01, F10, Ft0, Ft1)`` as returned by
                ``SlowMo_network``.

        Returns:
            A Keras loss ``loss(y_true, y_pred)`` equal to
            ``0.8 * reconstruction + 0.005 * perceptual + 0.4 * warping
            + 1 * smoothness``, one value per sample.
        """
        I0, I1, F01, F10, Ft0, Ft1 = supp

        def l1(a, b):
            return K.mean(K.abs(b - a), axis = [1, 2, 3])

        def l2(a, b):
            return K.mean(K.square(b - a), axis = [1, 2, 3])

        def warping_loss(I0, I1, It, F01, F10, Ft0, Ft1):
            # ``g`` comes from utils; presumably it warps an image with a
            # flow field -- confirm against utils.g.
            return l1(I0, g(I1, F01)) + l1(I1, g(I0, F10)) + l1(It, g(I0, Ft0)) + l1(It, g(I1, Ft1))

        def perceptual_loss(y_true, y_pred):
            # Squared distance between the frozen VGG features of both images.
            vgg_true = self.vgg(y_true)
            vgg_pred = self.vgg(y_pred)
            return l2(vgg_true, vgg_pred)

        def reconstruction_loss(y_true, y_pred):
            return l1(y_true, y_pred)

        def smoothness_loss(F01, F10):
            # Mean absolute first difference along axes 1 and 2 of each flow.
            deltaF01 = K.mean(K.abs(F01[:, 1:, :, :] - F01[:, :-1, :, :]), axis=[1, 2, 3]) + K.mean(
                K.abs(F01[:, :, 1:, :] - F01[:, :, :-1, :]), axis=[1, 2, 3])

            deltaF10 = K.mean(K.abs(F10[:, 1:, :, :] - F10[:, :-1, :, :]), axis=[1, 2, 3]) + K.mean(
                K.abs(F10[:, :, 1:, :] - F10[:, :, :-1, :]), axis=[1, 2, 3])

            return 0.5 * (deltaF01 + deltaF10)

        def loss(y_true, y_pred):
            l_rec = reconstruction_loss(y_true, y_pred)
            l_warp = warping_loss(I0, I1, y_pred, F01, F10, Ft0, Ft1)
            l_perc = perceptual_loss(y_true, y_pred)
            l_smooth = smoothness_loss(F01, F10)
            return 0.8 * l_rec + 0.005 * l_perc + 0.4 * l_warp + 1 * l_smooth

        return loss

    def PSNR(self, y_true, y_pred):
        """Return ``-10 * log10(mean squared error)`` over the whole batch.

        NOTE(review): this equals PSNR only if pixel values peak at 1.0 --
        confirm against the data generator.
        """
        return - 10.0 * K.log(K.mean(K.square(y_pred - y_true))) / K.log(10.0)

    def summary(self, model = 0):
        """Print the Keras summary of one of the three models.

        Args:
            model: 0 for the end-to-end model, 1 for the flow-computation
                model, 2 for the flow-interpolation model.

        Raises:
            ValueError: If ``model`` is not 0, 1 or 2.
        """
        # Keras' summary() prints by itself and returns None, so it is not
        # wrapped in print().
        if model == 0:
            self.model.summary()
        elif model == 1:
            self.flowcomputation_model.summary()
        elif model == 2:
            self.flowinterpolation_model.summary()
        else:
            raise ValueError("model must be 0, 1 or 2, got %r" % (model,))

    def predict(self, frames, timestamps, **kwargs):
        """Interpolate one frame per timestamp between two input frames.

        Args:
            frames: Sequence whose first two items are the batched start and
                end frames; both are fed to the end-to-end model as given.
            timestamps: Iterable of scalar time values, one per requested
                intermediate frame.
            **kwargs: Forwarded to ``self.model.predict``.

        Returns:
            A list with one prediction of the end-to-end model per timestamp,
            in the order of ``timestamps``.
        """
        first, second = frames[0], frames[1]
        batch_size = np.shape(first)[0]

        interpolated = []
        for timestamp in timestamps:
            # The time input has shape (batch, 1); every sample in the batch
            # shares the same timestamp.
            t = np.full((batch_size, 1), timestamp, dtype = np.float32)
            interpolated.append(self.model.predict([t, first, second], **kwargs))

        return interpolated

    def load(self, filepath, lr = 0.0001):
        """Rebuild the three models, compile them and load saved weights.

        Args:
            filepath: Weight file for the end-to-end model.
            lr: Learning rate used when re-compiling.
        """
        self.flowcomputation_model, self.flowinterpolation_model, self.model, supplementary = self.SlowMo_network()
        self.compile_network(self.model, supplementary, lr = lr)
        self.compile_network(self.flowcomputation_model, lr = lr)
        self.compile_network(self.flowinterpolation_model, lr = lr)
        self.model.load_weights(filepath)

    def fit_generator(self, generator, *args, **kwargs):
        """Train the end-to-end model from ``generator``.

        All extra arguments are forwarded to Keras' ``fit_generator``.

        Returns:
            Whatever ``self.model.fit_generator`` returns, so callers can
            inspect the training history.
        """
        return self.model.fit_generator(generator, *args, **kwargs)