In [1]:
from keras import backend as K
from keras.layers import Layer, Concatenate
import numpy as np
from utils import g

Using TensorFlow backend.


In [2]:
class flowcomputation(Layer):
    """Approximate intermediate flows at time ``t`` and warp both input frames.

    Called on ``[t, I0, I1, Optical_flow]``, the layer:

    * splits ``Optical_flow`` on axis 3 into ``F01`` (channels 0-1) and
      ``F10`` (channels 2 onward);
    * approximates the flows from the intermediate frame to each input frame::

          Fhat_t0 = -(1 - t) * t * F01 + t**2 * F10
          Fhat_t1 = (1 - t)**2 * F01 - t * (1 - t) * F10

    * returns the axis-3 concatenation
      ``[I0, I1, g(I0, Fhat_t0), g(I1, Fhat_t1), Fhat_t0, Fhat_t1]``.

    ``g`` comes from the project's ``utils`` module. It presumably
    backward-warps an image by a flow field; confirm against ``utils``.

    NOTE(review): the slicing treats tensors as 4-D with channels on axis 3,
    presumably channels-last ``(batch, height, width, channels)``. Confirm
    against the model definition.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # The layer has no weights; let the base class mark it as built.
        super().build(input_shape)

    def call(self, inputs):
        """Compute the concatenated frames and flows.

        Args:
            inputs: list/tuple ``[t, I0, I1, Optical_flow]``.

        Returns:
            A tensor that concatenates, along axis 3, the two input frames,
            the two warped frames and the two approximated flows.

        Raises:
            ValueError: if ``inputs`` is not a list/tuple of exactly four
                tensors.
        """
        if not isinstance(inputs, (list, tuple)) or len(inputs) != 4:
            raise ValueError(
                'FlowComputation must be called on a list of four tensors '
                '[t, I0, I1, Optical_flow]. Instead got: ' + str(inputs))

        # Only the first entry of ``t`` is taken, so one time step is applied
        # to every sample in the batch.
        # NOTE(review): assumes t is identical across the batch. Confirm.
        t = inputs[0][0]
        I0 = inputs[1]
        I1 = inputs[2]
        optical_flow = inputs[3]
        F01 = optical_flow[:, :, :, :2]
        F10 = optical_flow[:, :, :, 2:]

        # (t - 1) * t == -(1 - t) * t
        Fhat_t0 = (t - 1) * t * F01 + (t * t) * F10
        Fhat_t1 = (1 - t) * (1 - t) * F01 - t * (1 - t) * F10

        # Use the backend op directly instead of building a new Concatenate
        # layer on every call. Concatenate itself wraps K.concatenate, so the
        # result is the same.
        return K.concatenate(
            [I0, I1, g(I0, Fhat_t0), g(I1, Fhat_t1), Fhat_t0, Fhat_t1],
            axis=3)

    def compute_output_shape(self, input_shape):
        """Return the output shape: the image shape with widened channels.

        The output holds four image-sized tensors (``I0``, ``I1`` and the two
        warped frames) plus two 2-channel flows.

        Args:
            input_shape: list of the four input shapes. Entry 1 (``I0``)
                supplies batch, spatial dims and the image channel count.

        Returns:
            ``(batch, dim1, dim2, 4 * image_channels + 2 * 2)``.
        """
        image_shape = list(input_shape[1])
        image_channels = image_shape[3]
        if image_channels is None:
            # Unknown channel count: keep the original RGB assumption.
            image_channels = 3
        output_channels = 4 * image_channels + 2 * 2
        return (image_shape[0], image_shape[1], image_shape[2], output_channels)

In [3]:
class imagecomputation(Layer):
    """Blend two warped frames into the intermediate frame ``It``.

    Called on ``[t, I0, I1, interpolation_output]``. ``interpolation_output``
    must have exactly 5 channels on axis 3, laid out as:

    * ``Ft0`` (channels 0-1);
    * ``Ft1`` (channels 2-3);
    * ``Vt0`` (channel 4).

    With ``Vt1 = 1 - Vt0``, the output is::

        It = ((1 - t) * Vt0 * g(I0, Ft0) + t * Vt1 * g(I1, Ft1)) / Z
        Z  = (1 - t) * Vt0 + t * Vt1

    ``g`` comes from the project's ``utils`` module. It presumably
    backward-warps an image by a flow field; confirm against ``utils``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # The layer has no weights; let the base class mark it as built.
        super().build(input_shape)

    def call(self, inputs):
        """Compute the blended intermediate frame.

        Args:
            inputs: list/tuple ``[t, I0, I1, interpolation_output]``.

        Returns:
            The normalised weighted sum of the two warped frames.

        Raises:
            ValueError: if ``inputs`` is not a list/tuple of four tensors, or
                if ``interpolation_output`` does not have 5 channels on axis 3.
        """
        if not isinstance(inputs, (list, tuple)) or len(inputs) != 4:
            raise ValueError(
                'ImageComputation must be called on a list of four tensors '
                '[t, I0, I1, interpolation_output]. Instead got: ' + str(inputs))

        interp_channels = K.int_shape(inputs[3])[3]
        if interp_channels != 5:
            raise ValueError(
                'ImageComputation expects interpolation_output to have 5 '
                'channels on axis 3. Instead got: ' + str(interp_channels))

        # Only the first entry of ``t`` is taken, so one time step is applied
        # to every sample in the batch.
        # NOTE(review): assumes t is identical across the batch. Confirm.
        t = inputs[0][0]
        I0 = inputs[1]
        I1 = inputs[2]
        interpolation_output = inputs[3]
        Ft0 = interpolation_output[:, :, :, :2]
        Ft1 = interpolation_output[:, :, :, 2:4]
        Vt0 = interpolation_output[:, :, :, 4:]
        Vt1 = 1 - Vt0

        It = (1 - t) * Vt0 * g(I0, Ft0) + t * Vt1 * g(I1, Ft1)
        Z = (1 - t) * Vt0 + t * Vt1
        # Z is exactly 0 when t == 0 and Vt0 == 0, or when t == 1 and
        # Vt0 == 1. The epsilon keeps the division from producing NaN/inf.
        return It / (Z + K.epsilon())

    def compute_output_shape(self, input_shapes):
        """Return the output shape, which equals the shape of ``I0``."""
        return tuple(input_shapes[1])