In [1]:
import cv2, os , random
import matplotlib.pyplot as plt 
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from time import sleep

h,w = (256,256)


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [2]:
class FrameGenerator:
    def __init__(self, path , shape):
        self.path = path
        self.shape = shape
        self.folder_paths = []
    def __call__(self):
        if not self.folder_paths:
            for root, _, _ in os.walk(self.path):
                self.folder_paths.append(root)
        
        random.shuffle(self.folder_paths)
        
        for folder_path in self.folder_paths:
            _, _, files = next(os.walk(folder_path))
            files.sort()  # Sort files in ascending order
            for i in range(1, len(files) - 1):
                frame_t_minus_1 = os.path.join(folder_path, files[i-1])
                frame_t = os.path.join(folder_path, files[i])
                frame_t_plus_1 = os.path.join(folder_path, files[i+1])
                
                ft_minus = cv2.imread(frame_t_minus_1)
                ft_minus = cv2.resize(ft_minus, (self.shape[1],self.shape[0]))
                fi = cv2.imread(frame_t)
                fi = cv2.resize(fi, (self.shape[1],self.shape[0]))
                ft_plus = cv2.imread(frame_t_plus_1)
                ft_plus = cv2.resize(ft_plus, (self.shape[1],self.shape[0]))
                fs = random_translation(fi)
                yield ft_minus, fi, fs, ft_plus


def random_translation(img):
    (h,w) = img.shape[:-1]
    dx = np.random.randint(-w//8,w//8)
    dy = np.random.randint(-h//8,h//8)
    mat = np.array([[1,0,dx],[0,1,dy]],dtype=np.float32)
    return cv2.warpAffine(img, mat, (w,h))

![Alt text](Screenshot_1.png)

In [3]:
from models.models import PWC, UNet, ResNet
from tensorflow.keras.layers import Input ,Concatenate
#sub_models
pwc = PWC()
pwc(tf.zeros((1,h,w,6))) # pass a dummy input to build and load weights 
#freeze paramaters of pwc
for layer in pwc.layers:
    layer.trainable = False
unet = UNet()
resnet = ResNet()
#inputs
f1 = Input((None,None,3)) #previous frame
fs = Input((None,None,3)) #pseudo ground truth
fi = Input((None,None,3)) #original frame
f2 = Input((None,None,3)) #next frame
#preparing input for pwc
f1s = Concatenate(axis=-1)([f1,fs]) #concatenated f1,fs resulting in 6 channels
f2s = Concatenate(axis=-1)([f2,fs]) #concatenated f2,fs resulting in 6 channels
#forward through pwc
fw_minus = pwc(f1s)  #3 channels
fw_plus = pwc(f2s)   #3 channels
#preparing input for unet
fw = Concatenate(axis=-1)([fw_minus,fw_plus])
#forward through unet
fint = unet(fw)
#preparing input for pwc (warping fi to fint)
fiint = Concatenate(axis=-1)([fi,fint])
warped = pwc(fiint)
#concatenating the warped and interpolated frame for input to resnet
fr = Concatenate(axis=-1)([warped,fint])
fout = resnet(fr) 

#defining the entire model
difrint = tf.keras.Model([f1,fs,fi,f2],[fint,fout])
difrint.summary()

TensorFlow Version:  2.12.0
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, None, None,  0           []                               
                                 3)]                                                              
                                                                                                  
 input_2 (InputLayer)           [(None, None, None,  0           []                               
                                 3)]                                                              
                                                                                                  
 input_4 (InputLayer)           [(None, None, None,  0           []                               
                                 3)]                              

In [10]:
from tensorflow.keras.applications import VGG19
vgg = VGG19(include_top = False, weights='imagenet')

#optimizers
optimizer = tf.keras.optimizers.Adam(beta_1=0.9, beta_2=0.999,learning_rate=1e-3)
u_optimizer = tf.keras.optimizers.Adam(beta_1=0.9, beta_2=0.999,learning_rate=1e-3)
#metrics
difrint_loss_tracker = tf.keras.metrics.Mean(name="difrint_loss")
unet_loss_tracker = tf.keras.metrics.Mean(name="unet_loss")

def loss_function(f1,f2):
        l1 = tf.reduce_mean(tf.abs(f1 - f2))
        #vgg loss
        features_1 = vgg(f1)
        features_2 = vgg(f2)
        features_1 = tf.reshape(features_1,[-1]) #flatten tensor
        features_2 = tf.reshape(features_2,[-1]) #flatten tensor
        vgg_loss = tf.sqrt(tf.reduce_sum(tf.square(features_1 - features_2)))
        return l1 + vgg_loss

@tf.function
def train_step(data):
    f1,fs,fi,f2 = data
    with tf.GradientTape() as tape, tf.GradientTape() as u_tape:
        fint , fout = difrint([f1,fs,fi,f2])
        loss = loss_function(fs,fout)
        u_loss = loss_function(fs,fint)
    grads = tape.gradient(loss, difrint.trainable_weights)
    u_grads = u_tape.gradient(u_loss, unet.trainable_weights)
    
    optimizer.apply_gradients(
        zip(grads,difrint.trainable_weights)
    )
    u_optimizer.apply_gradients(
        zip(u_grads,unet.trainable_weights)
    )

    #update trackers
    difrint_loss_tracker.update_state(loss)
    unet_loss_tracker.update_state(u_loss)
    return{
        'difrint_loss': difrint_loss_tracker.result(),
        'unet_loss': unet_loss_tracker.result() 
        }

In [6]:
data_gen = FrameGenerator('D:/Files/Datasets/DAVIS-data/DAVIS/JPEGImages/480p/',(256,256))
for [f1,fi,fs,f2] in data_gen():
    cv2.imshow('window',cv2.hconcat([f1,fs]))
    sleep(1/30)
    if cv2.waitKey(1) & 0xFF == ord(' '):
        break
cv2.destroyAllWindows()

In [8]:
data_gen = FrameGenerator('D:/Files/Datasets/DAVIS-data/DAVIS/JPEGImages/480p/',(256,256))
output_signature = (tf.TensorSpec(shape = (None, None, 3), dtype = tf.float32),
                    tf.TensorSpec(shape = (None, None, 3), dtype = tf.float32),
                    tf.TensorSpec(shape = (None, None, 3), dtype = tf.float32),
                    tf.TensorSpec(shape = (None, None, 3), dtype = tf.float32)
                    )
train_ds = tf.data.Dataset.from_generator(data_gen,
                                          output_signature = output_signature)
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.batch(1)
train_ds = train_ds.cache('./Difrint Cache/').prefetch(buffer_size = AUTOTUNE)

In [None]:
for epoch in range(200):
    for idx,batch in enumerate(train_ds):
        loss_dict = train_step(batch)
        if epoch >= 100:
            optimizer.learning_rate *= 0.1
            u_optimizer.learning_rate *- 0.1
        print(f"\rbatch: {idx}\tdifrint_loss: {loss_dict['difrint_loss']:.4f}\tunet_loss: {loss_dict['unet_loss']:.4f}", end="")


batch: 6	difrint_loss: 1792.6521	unet_loss: 1793.8043