# Proof of concept notebook for the Frame Booster project
- Author: Kamil Barszczak
- Contact: kamilbarszczak62@gmail.com
- Project: https://github.com/kbarszczak/Frame_booster

In [None]:
import matplotlib.pyplot as plt
import tensorflow_addons.image as tfa_image
import tensorflow as tf
import numpy as np
import pickle
import time
import cv2
import os

from tqdm import tqdm
from tensorflow import keras
from tensorflow.keras.utils import plot_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import preprocessing
from tensorflow.keras import regularizers
from tensorflow.keras import activations
from tensorflow.keras import optimizers
from tensorflow.keras import callbacks
from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import models
from tensorflow.keras import backend as K

In [None]:
# model parameters
# Directory the custom training loop writes its .h5 checkpoints into.
model_base_path = 'E:/OneDrive - Akademia Górniczo-Hutnicza im. Stanisława Staszica w Krakowie/Programming/Labs/Frame_booster/models/model_v3'
# File-name prefix of every saved checkpoint.
model_name = 'frame_booster'
# Frame size in pixels; records are reshaped to (height, width, 3).
width, height = 256, 144

# training data parameters
# The sizes and the creation time below are part of the .tfrecords file names.
data_base_path = 'E:/Data/Video_Frame_Interpolation/processed/vimeo90k'
data_creation_time = 1682372054
data_train_size = 19000
data_test_size = 1000
data_valid_size = 1000
batch_size = 1 # TODO: change to normal value
epochs = 10

# visualization data parameters
# The prefix and the creation time below are part of the .pickle file names.
# NOTE(review): this timestamp is a string while data_creation_time is an int;
# both are only interpolated into file names, so the difference is cosmetic.
vis_base_path = 'E:/Data/Video_Frame_Interpolation/processed/low_motion'
vis_creation_time = '1682179015'
vis_prefix = 'test'

## Load generators for learning

In [None]:
# Feature schema used to parse one serialized record: three frames, each
# stored as a single byte string (decoded to float32 in parse_decode_record).
name_to_features = {
    'image_1': tf.io.FixedLenFeature([], tf.string),
    'image_2': tf.io.FixedLenFeature([], tf.string),
    'image_3': tf.io.FixedLenFeature([], tf.string),
}

In [None]:
def parse_decode_record(record):
    """Parse one serialized record into a ``((image_1, image_3), image_2)`` sample.

    Each of the three features is decoded from raw little-endian float32 bytes
    and reshaped to ``(height, width, 3)``. The first and third frames form the
    model input, the second frame is the target.
    """
    example = tf.io.parse_single_example(record, name_to_features)

    def decode_frame(key):
        # Raw bytes -> flat float32 vector -> (height, width, 3) frame.
        flat = tf.io.decode_raw(
            example[key], out_type='float32', little_endian=True, fixed_length=None, name=None
        )
        return tf.reshape(flat, (height, width, 3))

    first_frame = decode_frame('image_1')
    middle_frame = decode_frame('image_2')
    last_frame = decode_frame('image_3')

    return (first_frame, last_frame), middle_frame

In [None]:
def load_generator(base_path, prefix, creation_time, width, height, size):
    """Build the batched ``tf.data`` pipeline for one dataset split.

    The records are read from
    ``<base_path>/<prefix>_<size>_<height>x<width>_<creation_time>.tfrecords``.

    Args:
        base_path: Directory containing the .tfrecords file.
        prefix: Split name used in the file name (e.g. 'train').
        creation_time: Timestamp used in the file name.
        width: Frame width used in the file name.
        height: Frame height used in the file name.
        size: Number of samples, used in the file name.

    Returns:
        A dataset yielding ``((image_1, image_3), image_2)`` batches of the
        module-level ``batch_size``, repeated ``epochs`` times.
    """
    path = os.path.join(base_path, f'{prefix}_{size}_{height}x{width}_{creation_time}.tfrecords')
    dataset = tf.data.TFRecordDataset(path)

    dataset = dataset.map(parse_decode_record)
    # Shuffle before repeating so that samples of different passes are not mixed.
    dataset = dataset.shuffle(buffer_size=5 * batch_size)
    dataset = dataset.repeat(epochs)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # Prefetch is the last stage so that ready-made batches are prepared while
    # the consumer is busy; placed mid-pipeline it cannot overlap the later stages.
    dataset = dataset.prefetch(5)

    return dataset

In [None]:
# One dataset per split; only the file-name prefix and the sample count differ.
train_generator = load_generator(
    base_path=data_base_path, prefix='train', creation_time=str(data_creation_time),
    width=width, height=height, size=str(data_train_size),
)
test_generator = load_generator(
    base_path=data_base_path, prefix='test', creation_time=str(data_creation_time),
    width=width, height=height, size=str(data_test_size),
)
valid_generator = load_generator(
    base_path=data_base_path, prefix='valid', creation_time=str(data_creation_time),
    width=width, height=height, size=str(data_valid_size),
)

## Load data for test visualization

In [None]:
def load_data(base_path, prefix, creation_time, width, height):
    """Load the pickled visualization inputs and targets as float32 in [0, 1].

    Reads ``x_<prefix>_<height>x<width>_<creation_time>.pickle`` and the
    matching ``y_...`` file from ``base_path`` and divides both by 255.

    NOTE: unpickling executes arbitrary code, so only load trusted files.

    Returns:
        An ``(x, y)`` tuple of float32 numpy arrays.
    """
    file_suffix = f'{prefix}_{height}x{width}_{creation_time}.pickle'
    arrays = []
    for kind in ('x', 'y'):
        with open(os.path.join(base_path, f'{kind}_{file_suffix}'), 'rb') as handle:
            arrays.append(pickle.load(handle))

    x_scaled, y_scaled = ((np.array(raw) / 255.0).astype('float32') for raw in arrays)
    return x_scaled, y_scaled

In [None]:
# Load the visualization inputs/targets scaled to [0, 1].
x_vis, y_vis = load_data(
    base_path = vis_base_path,
    prefix = vis_prefix,
    creation_time = vis_creation_time,
    width = width, 
    height = height
)
# Drop the trailing samples so that the sample count is a multiple of batch_size.
if len(x_vis) % batch_size != 0:
    x_vis = x_vis[0:len(x_vis) - (len(x_vis)%batch_size)]
    # len(x_vis) is already the trimmed length here, so y_vis ends up equally long.
    y_vis = y_vis[0:len(x_vis)]

In [None]:
# Non-shuffled iterator over the visualization data with two inputs per sample.
# NOTE(review): assumes axis 1 of x_vis indexes the two neighbouring frames — confirm.
vis_generator = ImageDataGenerator().flow(
    x = [x_vis[:, 0, :, :], x_vis[:, 1, :, :]],
    y = y_vis,
    batch_size = batch_size,
    shuffle = False,
)

## Create activation and loss functions

In [None]:
"""
The output activation returns linear values cropped to range from 0 to 1
"""
def output_activation(x):
    """Linear activation with the result limited to the range [0, 1]."""
    floored = tf.math.maximum(x, 0)
    return tf.math.minimum(floored, 1)

"""
The L1 reconstruction loss function makes the final image colors 
look the same as the colors in the ground-truth image
"""
def l1(y_true, y_pred):
    """Mean absolute difference between the target and the prediction."""
    absolute_error = K.abs(y_true - y_pred)
    return K.mean(absolute_error)


"""
The L2 reconstruction loss function is similar to L1
"""
def l2(y_true, y_pred):
    """Mean squared difference between the target and the prediction."""
    squared_error = K.square(y_true - y_pred)
    return K.mean(squared_error)


"""
The PSNR loss function is responsible for boosting the overall quality of the image by reducing its noise 
(The higher the PSNR the better, so we return 1 - PSNR / 40 because the loss function tries to minimize it)
"""
def psnr(y_true, y_pred):
    """PSNR-based loss: ``1 - mean(PSNR) / 40`` for images with max value 1.0."""
    per_image_psnr = tf.image.psnr(y_true, y_pred, max_val = 1.0)
    mean_psnr = tf.math.reduce_mean(per_image_psnr)
    return 1 - mean_psnr / 40.0


"""
The SSIM loss function keeps the result image structure
(The more significant the SSIM the more similar the final image is)
"""
def ssim(y_true, y_pred):
    """SSIM-based loss: ``1 - mean(SSIM)`` for images with max value 1.0."""
    per_image_ssim = tf.image.ssim(y_true, y_pred, 1.0)
    mean_ssim = tf.reduce_mean(per_image_ssim)
    return 1 - mean_ssim

"""
The final loss is linear combination of ssim, psnr, l1, l2 losses
"""
def loss(y_true, y_pred):
    """Combined loss: ``ssim + psnr + 5 * l1 + 10 * l2``."""
    structure_term = ssim(y_true, y_pred)
    noise_term = psnr(y_true, y_pred)
    l1_term = 5.0 * l1(y_true, y_pred)
    l2_term = 10.0 * l2(y_true, y_pred)

    return structure_term + noise_term + l1_term + l2_term

## Create and build the model

In [None]:
def warp(image: tf.Tensor, flow: tf.Tensor) -> tf.Tensor:
    """Warp ``image`` with ``tfa_image.dense_image_warp`` using the negated ``flow``.

    The result is reshaped to the dynamic shape of ``image``.
    """
    negated_flow = -flow
    warp_layer = tf.keras.layers.Lambda(lambda pair: tfa_image.dense_image_warp(*pair))
    warped_image = warp_layer((image, negated_flow))
    return tf.reshape(warped_image, shape=tf.shape(image))


In [None]:
class BidirectionalFlowEstimation(layers.Layer):
    """Refine two flow fields (1 -> 2 and 2 -> 1) between a pair of feature maps.

    For each direction the layer warps one input with the incoming flow,
    concatenates it with the other input, predicts a 2-channel flow change with
    a small convolutional stack, adds that change to the incoming flow, warps
    the input again with the updated flow and upsamples the updated flow 2x.

    Args:
        filter_count: Filter counts of the four hidden convolutions.
        filter_size: Kernel sizes of the four hidden convolutions.
        activation: Activation of the hidden convolutions.
        regularizer: Kernel regularizer of all convolutions.
        interpolation: Interpolation mode of the flow upsampling.
    """

    def __init__(self, filter_count=(32, 64, 64, 16), filter_size=((3, 3), (3, 3), (1, 1), (1, 1)), activation='relu', regularizer=None, interpolation='bilinear', **kwargs):
        super().__init__(**kwargs)

        # Immutable defaults plus a private copy: the layer never shares a
        # mutable list with the caller or with other instances.
        filter_count = list(filter_count)
        filter_size = list(filter_size)

        # flow 1 -> 2
        self.flow_add_1_2 = layers.Add()
        self.flow_upsample_1_2 = layers.UpSampling2D((2, 2), interpolation=interpolation)
        self.flow_1_2_concat = layers.Concatenate(axis=3)
        self.flow_prediction_1_2 = self._build_flow_predictor(filter_count, filter_size, activation, regularizer)

        # flow 2 -> 1
        self.flow_add_2_1 = layers.Add()
        self.flow_upsample_2_1 = layers.UpSampling2D((2, 2), interpolation=interpolation)
        self.flow_2_1_concat = layers.Concatenate(axis=3)
        self.flow_prediction_2_1 = self._build_flow_predictor(filter_count, filter_size, activation, regularizer)

        self.filter_count = filter_count
        self.filter_size = filter_size
        self.activation = activation
        self.regularizer = regularizer
        self.interpolation = interpolation

    @staticmethod
    def _build_flow_predictor(filter_count, filter_size, activation, regularizer):
        """Create the convolutional stack that predicts a 2-channel flow change."""
        hidden = [
            layers.Conv2D(count, size, activation=activation, kernel_regularizer=regularizer, padding='same')
            for count, size in zip(filter_count[:4], filter_size[:4])
        ]
        # The last convolution is linear: a flow change may have any sign.
        return keras.Sequential(hidden + [layers.Conv2D(2, (1, 1), kernel_regularizer=regularizer, padding='same')])

    def get_config(self):
        config = super().get_config()
        config.update({
            "filter_count": self.filter_count,
            "filter_size": self.filter_size,
            "activation": self.activation,
            "regularizer": self.regularizer,
            "interpolation": self.interpolation,
        })
        return config

    def call(self, inputs):
        """Run one refinement step.

        Args:
            inputs: ``[input_1, input_2, flow_1_2, flow_2_1]``. Passing ``None``
                or an empty list/tuple as ``flow_1_2`` starts from zero flows.

        Returns:
            ``(input_1 warped, input_2 warped, upsampled flow 1 -> 2,
            upsampled flow 2 -> 1)``.
        """
        input_1 = inputs[0]
        input_2 = inputs[1]
        flow_1_2 = inputs[2]
        flow_2_1 = inputs[3]

        if flow_1_2 is None or isinstance(flow_1_2, (list, tuple)):
            # Zero flow shaped like the input (2 channels); derived from the
            # input itself so that it does not depend on a global batch size.
            zero_flow = tf.tile(tf.zeros_like(input_2[..., :1]), [1, 1, 1, 2])
            flow_1_2 = zero_flow
            flow_2_1 = zero_flow

        # input_1 to input_2 flow prediction
        input_1_warped_1 = warp(input_1, flow_1_2)

        flow_change_1_2_concat = self.flow_1_2_concat([input_2, input_1_warped_1])
        flow_change_1_2 = self.flow_prediction_1_2(flow_change_1_2_concat)

        flow_1_2_changed = self.flow_add_1_2([flow_1_2, flow_change_1_2])
        input_1_warped_2 = warp(input_1, flow_1_2_changed)
        flow_1_2_changed_upsampled = self.flow_upsample_1_2(flow_1_2_changed)

        # input_2 to input_1 flow prediction
        input_2_warped_1 = warp(input_2, flow_2_1)

        flow_change_2_1_concat = self.flow_2_1_concat([input_1, input_2_warped_1])
        flow_change_2_1 = self.flow_prediction_2_1(flow_change_2_1_concat)

        flow_2_1_changed = self.flow_add_2_1([flow_2_1, flow_change_2_1])
        input_2_warped_2 = warp(input_2, flow_2_1_changed)
        flow_2_1_changed_upsampled = self.flow_upsample_2_1(flow_2_1_changed)

        return input_1_warped_2, input_2_warped_2, flow_1_2_changed_upsampled, flow_2_1_changed_upsampled

In [None]:
# Hyper-parameters of the network built below.
# filter_count / filter_size: feature-extraction and fusion convolutions.
# flow_filter_count / flow_filter_size: convolutions inside the flow-estimation layers.
filter_count=[32, 40, 48]
filter_size=(3, 3)
flow_filter_count=[40, 80, 80, 24] 
flow_filter_size=[(7, 7), (7, 7), (1, 1), (1, 1)]
interpolation='bilinear'
# One LeakyReLU instance is passed as the activation of every convolution below.
activation= layers.LeakyReLU(0.2)
regularizer=None

# ------------- shared layers
# Each layer object below is applied to both the left and the right input,
# so the two feature-extraction branches share their weights.
cnn_1st_level_1 = layers.Conv2D(filter_count[0], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_1_1")
cnn_1st_level_2 = layers.Conv2D(filter_count[0], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_1_2")
cnn_1st_level_3 = layers.Conv2D(filter_count[0], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_1_3")
cnn_1st_level_4 = layers.Conv2D(filter_count[0], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_1_4")

cnn_2nd_level_1 = layers.Conv2D(filter_count[1], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_2_1")
cnn_2nd_level_2 = layers.Conv2D(filter_count[1], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_2_2")
cnn_2nd_level_3 = layers.Conv2D(filter_count[1], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_2_3")

cnn_3rd_level_1 = layers.Conv2D(filter_count[2], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_3_1")
cnn_3rd_level_2 = layers.Conv2D(filter_count[2], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_3_2")

# ------------- feature extraction left side
# Input pyramid: full resolution plus three successive 2x average poolings.
input_1_left = layers.Input(shape=(height, width, 3), name="input_left")
input_2_left = layers.AveragePooling2D((2, 2), name="avg_input_left_1/2")(input_1_left)
input_3_left = layers.AveragePooling2D((2, 2), name="avg_input_left_1/4")(input_2_left)
input_4_left = layers.AveragePooling2D((2, 2), name="avg_input_left_1/8")(input_3_left)

# feature extraction for layer 1
input_1_column_1_row_1_left = cnn_1st_level_1(input_1_left)
input_2_column_1_row_2_left = cnn_1st_level_2(input_2_left)
input_3_column_1_row_3_left = cnn_1st_level_3(input_3_left)
input_4_column_1_row_4_left = cnn_1st_level_4(input_4_left)

# downsample layer 1
input_1_column_2_row_2_left = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_left_1_1/2")(input_1_column_1_row_1_left)
input_2_column_2_row_3_left = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_left_1_1/4")(input_2_column_1_row_2_left)
input_3_column_2_row_4_left = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_left_1_1/8")(input_3_column_1_row_3_left)

# feature extraction for layer 2
input_1_column_2_row_2_left = cnn_2nd_level_1(input_1_column_2_row_2_left)
input_2_column_2_row_3_left = cnn_2nd_level_2(input_2_column_2_row_3_left)
input_3_column_2_row_4_left = cnn_2nd_level_3(input_3_column_2_row_4_left)

# downsample layer 2
input_1_column_3_row_3_left = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_left_2_1/4")(input_1_column_2_row_2_left)
input_2_column_3_row_4_left = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_left_2_1/8")(input_2_column_2_row_3_left)

# feature extraction for layer 3
input_1_column_3_row_3_left = cnn_3rd_level_1(input_1_column_3_row_3_left)
input_2_column_3_row_4_left = cnn_3rd_level_2(input_2_column_3_row_4_left)

# concatenate the features that share the same resolution (same "row")
concat_2nd_left = layers.Concatenate(name="con_left_2")([input_2_column_1_row_2_left, input_1_column_2_row_2_left])
concat_3rd_left = layers.Concatenate(name="con_left_3")([input_3_column_1_row_3_left, input_2_column_2_row_3_left, input_1_column_3_row_3_left])
concat_4th_left = layers.Concatenate(name="con_left_4")([input_4_column_1_row_4_left, input_3_column_2_row_4_left, input_2_column_3_row_4_left])

# output from feature extraction left side: input_1_column_1_row_1_left, concat_2nd_left, concat_3rd_left, concat_4th_left

# ------------- feature extraction right side
# Same structure as the left side, reusing the shared convolution layers.
input_1_right = layers.Input(shape=(height, width, 3), name="input_right")
input_2_right = layers.AveragePooling2D((2, 2), name="avg_input_right_1/2")(input_1_right)
input_3_right = layers.AveragePooling2D((2, 2), name="avg_input_right_1/4")(input_2_right)
input_4_right = layers.AveragePooling2D((2, 2), name="avg_input_right_1/8")(input_3_right)

# feature extraction for layer 1
input_1_column_1_row_1_right = cnn_1st_level_1(input_1_right)
input_2_column_1_row_2_right = cnn_1st_level_2(input_2_right)
input_3_column_1_row_3_right = cnn_1st_level_3(input_3_right)
input_4_column_1_row_4_right = cnn_1st_level_4(input_4_right)

# downsample layer 1
input_1_column_2_row_2_right = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_right_1_1/2")(input_1_column_1_row_1_right)
input_2_column_2_row_3_right = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_right_1_1/4")(input_2_column_1_row_2_right)
input_3_column_2_row_4_right = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_right_1_1/8")(input_3_column_1_row_3_right)

# feature extraction for layer 2
input_1_column_2_row_2_right = cnn_2nd_level_1(input_1_column_2_row_2_right)
input_2_column_2_row_3_right = cnn_2nd_level_2(input_2_column_2_row_3_right)
input_3_column_2_row_4_right = cnn_2nd_level_3(input_3_column_2_row_4_right)

# downsample layer 2
input_1_column_3_row_3_right = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_right_2_1/4")(input_1_column_2_row_2_right)
input_2_column_3_row_4_right = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_right_2_1/8")(input_2_column_2_row_3_right)

# feature extraction for layer 3
input_1_column_3_row_3_right = cnn_3rd_level_1(input_1_column_3_row_3_right)
input_2_column_3_row_4_right = cnn_3rd_level_2(input_2_column_3_row_4_right)

# concatenate the features that share the same resolution (same "row")
concat_2nd_right = layers.Concatenate(name="con_right_2")([input_2_column_1_row_2_right, input_1_column_2_row_2_right])
concat_3rd_right = layers.Concatenate(name="con_right_3")([input_3_column_1_row_3_right, input_2_column_2_row_3_right, input_1_column_3_row_3_right])
concat_4th_right = layers.Concatenate(name="con_right_4")([input_4_column_1_row_4_right, input_3_column_2_row_4_right, input_2_column_3_row_4_right])

# output from feature extraction right side: input_1_column_1_row_1_right, concat_2nd_right, concat_3rd_right, concat_4th_right

# ------------- warping features at each level     
# for flow estimation
bidirectional_flow_estimation_1 = BidirectionalFlowEstimation(
    filter_count=flow_filter_count, 
    filter_size=flow_filter_size, 
    activation=activation, 
    regularizer=regularizer, 
    interpolation=interpolation,
    name="bi_flow_1st"
)
bidirectional_flow_estimation_2 = BidirectionalFlowEstimation(
    filter_count=flow_filter_count, 
    filter_size=flow_filter_size, 
    activation=activation, 
    regularizer=regularizer, 
    interpolation=interpolation,
    name="bi_flow_2nd"
)
bidirectional_flow_estimation_3 = BidirectionalFlowEstimation(
    filter_count=flow_filter_count, 
    filter_size=flow_filter_size, 
    activation=activation, 
    regularizer=regularizer, 
    interpolation=interpolation,
    name="bi_flow_above_3rd"
)

# create empty flow for the coarsest level
# NOTE(review): the zero flows are constants with a fixed leading dimension of
# batch_size, which presumably ties the built graph to that batch size — confirm.
empty_flow_1 = tf.zeros(shape=(batch_size, height//8, width//8, 2))
empty_flow_2 = tf.zeros(shape=(batch_size, height//8, width//8, 2))

# calculate the flow for each level using the input of current level and the upsampled flow from the level + 1
# (the same bidirectional_flow_estimation_3 instance is applied to both the 4th and the 3rd level)
bfe_4_i1, bfe_4_i2, bfe_4_f_1_2, bfe_4_f_2_1 = bidirectional_flow_estimation_3([concat_4th_left, concat_4th_right, empty_flow_1, empty_flow_2])
bfe_3_i1, bfe_3_i2, bfe_3_f_1_2, bfe_3_f_2_1 = bidirectional_flow_estimation_3([concat_3rd_left, concat_3rd_right, bfe_4_f_1_2, bfe_4_f_2_1])
bfe_2_i1, bfe_2_i2, bfe_2_f_1_2, bfe_2_f_2_1 = bidirectional_flow_estimation_2([concat_2nd_left, concat_2nd_right, bfe_3_f_1_2, bfe_3_f_2_1])
bfe_1_i1, bfe_1_i2, _, _ = bidirectional_flow_estimation_1([input_1_column_1_row_1_left, input_1_column_1_row_1_right, bfe_2_f_1_2, bfe_2_f_2_1])

# returned by warping: (bfe_1_i1, bfe_2_i1, bfe_3_i1, bfe_4_i1), (bfe_1_i2, bfe_2_i2, bfe_3_i2, bfe_4_i2)

# ------------- warped features fusion   
# Coarse-to-fine fusion: at every level the two warped feature maps are added
# (together with the upsampled result of the coarser level), passed through
# two convolutions and upsampled 2x for the next finer level.
        
# merge 4th level
added_4th_level = layers.Add()([bfe_4_i1, bfe_4_i2])
cnn_4th_1 = layers.Conv2D(filter_count[0] + filter_count[1] + filter_count[2], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same')(added_4th_level)
cnn_4th_2 = layers.Conv2D(filter_count[0] + filter_count[1] + filter_count[2], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same')(cnn_4th_1)
up_4th = layers.UpSampling2D((2, 2), interpolation=interpolation)(cnn_4th_2)

# merge 3rd level
added_3rd_level = layers.Add()([bfe_3_i1, bfe_3_i2, up_4th])
cnn_3rd_1 = layers.Conv2D(filter_count[0] + filter_count[1], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same')(added_3rd_level)
cnn_3rd_2 = layers.Conv2D(filter_count[0] + filter_count[1], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same')(cnn_3rd_1)
up_3rd = layers.UpSampling2D((2, 2), interpolation=interpolation)(cnn_3rd_2)

# merge 2nd level
added_2nd_level = layers.Add()([bfe_2_i1, bfe_2_i2, up_3rd])
cnn_2nd_1 = layers.Conv2D(filter_count[0], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same')(added_2nd_level)
cnn_2nd_2 = layers.Conv2D(filter_count[0], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same')(cnn_2nd_1)
up_2nd = layers.UpSampling2D((2, 2), interpolation=interpolation)(cnn_2nd_2)

# merge 1st level
added_1st_level = layers.Add()([bfe_1_i1, bfe_1_i2, up_2nd])
x = layers.Conv2D(filter_count[0], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same')(added_1st_level)
x = layers.Conv2D(filter_count[0], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same')(x)

# net output: 3 channels limited to [0, 1] by output_activation
outputs = layers.Conv2D(3, (1, 1), activation=output_activation, padding='same')(x)

model = keras.Model(inputs=[input_1_left, input_1_right], outputs=outputs)

# Compile with the combined loss; the individual loss terms double as metrics.
model.compile(
    loss = loss,
    optimizer = optimizers.Nadam(0.0005, clipnorm=0.5),
    metrics = [l1, l2, psnr, ssim]
)

model.summary()

## Train the model

In [None]:
def fit(model, train_generator, train_size, valid_generator, valid_size, optimizer, loss, metrics, epochs, batch_size, save_freq=50, log_freq=10, bad_input_limit=5, mode="all"):
    """Train ``model`` with a custom loop that skips batches producing a NaN loss.

    Args:
        model: Model to train; called as ``model(x, training=...)``.
        train_generator: Iterable of ``(x, y)`` training batches.
        train_size: Number of training samples; one epoch processes
            ``train_size // batch_size`` batches.
        valid_generator: Iterable of ``(x, y)`` validation batches.
        valid_size: Number of validation samples.
        optimizer: Optimizer used to apply the gradients.
        loss: Loss function ``loss(y_true, y_pred)``.
        metrics: List of metric functions ``metric(y_true, y_pred)``.
        epochs: Number of epochs.
        batch_size: Batch size of both generators.
        save_freq: A checkpoint is considered every ``save_freq`` steps.
        log_freq: The running averages are printed every ``log_freq`` steps.
        bad_input_limit: Training stops once this many NaN batches were seen.
        mode: ``"all"`` saves on every checkpoint step, ``"best"`` only when the
            running average loss improved; any other value never saves.

    Returns:
        ``(history, bad_input)``: ``history`` maps each loss/metric name (and its
        ``val_`` counterpart) to a list with one float per finished epoch;
        ``bad_input`` holds the ``(x, y, y_pred)`` triples that produced NaN.
    """
    @tf.function
    def train_step(x, y):
        with tf.GradientTape() as tape:
            y_pred = model(x, training=True)
            loss_value = loss(y, y_pred)

        # Do not touch the weights when the loss is NaN.
        if not tf.math.is_nan(loss_value):
            grads = tape.gradient(loss_value, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))

        return loss_value, y_pred

    @tf.function
    def calc_metrics(x, y):
        y_pred = model(x, training=False)
        return [metric(y, y_pred) for metric in metrics]

    @tf.function
    def valid_step(x, y):
        y_pred = model(x, training=False)
        return loss(y, y_pred)

    def get_loss_metrics_str(loss_value, metrics_values, sep=' '):
        parts = ['loss=' + '{:.5f}'.format(loss_value)]
        parts.extend(
            f'{metric.__name__}=' + '{:.5f}'.format(metric_value)
            for metric_value, metric in zip(metrics_values, metrics)
        )
        return sep.join(parts)

    # create dict for a history and a list for bad input
    history = {metric.__name__: [] for metric in metrics}
    history.update({"val_" + metric.__name__: [] for metric in metrics})
    history[loss.__name__] = []
    history["val_" + loss.__name__] = []
    bad_input = []
    best_loss = None

    train_steps = train_size // batch_size
    valid_steps = valid_size // batch_size

    try:
        # loop over epochs
        for epoch in range(1, epochs+1):
            print(f"Epoch: {epoch}/{epochs}")

            # process the full training dataset
            total_metrics = np.zeros(len(metrics))
            total_loss = 0
            good_batches = 0  # batches that contributed to the running totals
            for step, record in enumerate(train_generator):
                # extract the data
                x = record[0]
                y = record[1]

                # calculate the loss and apply the gradient change if loss is not NaN
                loss_value, y_pred = train_step(x, y)

                if tf.math.is_nan(loss_value):
                    # remember the bad input and leave the running totals untouched
                    print(f"Loss NaN detected at epoch {epoch} in step {(step+1)}. Wrong data saved to bad_input list")
                    bad_input.append((x, y, y_pred))
                    if len(bad_input) >= bad_input_limit:
                        raise OverflowError(f"The bad_input limit of {bad_input_limit} was reached")
                else:
                    # save the loss & metrics values
                    total_loss += loss_value
                    total_metrics += np.array(calc_metrics(x, y))
                    good_batches += 1

                    # save the model
                    if step % save_freq == 0:
                        loss_avg = float(total_loss / good_batches)
                        if mode == "all" or (mode == "best" and (best_loss is None or best_loss > loss_avg)):
                            print("Saving model with loss " + '{:.5f}'.format(loss_avg))
                            # `epoch` is already 1-based, so it is written to the file name as is
                            model.save(os.path.join(model_base_path, f'{model_name}_{get_loss_metrics_str(loss_avg, total_metrics/good_batches, sep="_")}_e={epoch}_s={(step+1)}_t={int(time.time())}.h5'))
                            best_loss = loss_avg

                    # log the loss
                    if step % log_freq == 0:
                        prefix = f'Step {(step+1)}/{train_steps}: '.ljust(15)
                        print(f'{prefix}{get_loss_metrics_str(total_loss/good_batches, total_metrics/good_batches)}')

                # break the learning if the generator is over; checked for NaN
                # batches too, otherwise a NaN on the last step would run on
                # into the next pass of a repeated dataset
                if step >= (train_steps - 1):
                    break

            # save the loss value (averaged over the batches that were counted)
            divisor = max(good_batches, 1)
            history[loss.__name__].append(float(total_loss / divisor))
            for index, metric in enumerate(metrics):
                history[metric.__name__].append(float(total_metrics[index] / divisor))

            # process the full validating dataset
            total_loss = 0
            total_metrics = np.zeros(len(metrics))
            valid_batches = 0
            for step, record in enumerate(valid_generator):
                x = record[0]
                y = record[1]

                total_loss += valid_step(x, y)
                total_metrics += np.array(calc_metrics(x, y))
                valid_batches += 1

                if step >= (valid_steps - 1):
                    break

            divisor = max(valid_batches, 1)

            # log the validation score
            print(f'Validation for epoch {epoch}: {get_loss_metrics_str(total_loss/divisor, total_metrics/divisor)}')

            # save the validation score
            history["val_" + loss.__name__].append(float(total_loss / divisor))
            for index, metric in enumerate(metrics):
                history["val_" + metric.__name__].append(float(total_metrics[index] / divisor))
    except (OverflowError, KeyboardInterrupt) as e:
        print(f"Learning interrupted. Details: '{e}'")

    return history, bad_input

In [None]:
# Train with the custom loop. The optimizer passed here is the one that applies
# the gradients; the optimizer given to model.compile is not used by fit().
history, bad_input = fit(
    model=model, 
    train_generator=train_generator,
    train_size=data_train_size, 
    valid_generator=valid_generator,
    valid_size=data_valid_size,
    optimizer=optimizers.Nadam(0.0005, clipnorm=0.5), 
    loss=loss, 
    metrics=[l1, l2, psnr, ssim],
    epochs=epochs, 
    batch_size=batch_size, 
    save_freq=50,
    log_freq=10,
    bad_input_limit=5,
    mode="best"
)

## Evaluate the training

In [None]:
def norm_0_1(data):
    """Rescale ``data`` linearly so that its minimum maps to 0 and its maximum to 1.

    Empty input yields an empty array and constant input yields zeros (instead
    of the NaN values a division by a zero range would produce).
    """
    values = np.asarray(data)
    if values.size == 0:
        return values.astype(float)

    lowest = np.min(values)
    value_range = np.max(values) - lowest
    if value_range == 0:
        return np.zeros_like(values, dtype=float)

    return (values - lowest) / value_range

def norm(data):
    """Standardize ``data`` to zero mean and unit standard deviation.

    Empty input yields an empty array and constant input yields zeros (instead
    of the NaN values a division by a zero deviation would produce).
    """
    values = np.asarray(data)
    if values.size == 0:
        return values.astype(float)

    deviation = np.std(values)
    if deviation == 0:
        return np.zeros_like(values, dtype=float)

    return (values - np.mean(values)) / deviation

In [None]:
def plot_history(history, normalize_method=None, metrics_restrictions=None):
    """Plot the training and validation curves stored in ``history``.

    Args:
        history: Dict mapping a metric name and its ``'val_'``-prefixed
            counterpart to lists of per-epoch values.
        normalize_method: Optional callable applied to the concatenated
            training + validation values of each metric before plotting.
        metrics_restrictions: Optional collection of metric names to plot.

    Raises:
        ValueError: If no metric is left to plot.
    """
    # Only an exact 'val_' prefix marks a validation series.
    metrics = [name for name in history if not name.startswith("val_")]
    if metrics_restrictions is not None:
        metrics = [name for name in metrics if name in metrics_restrictions]
    if not metrics:
        raise ValueError("plot_history: there are no metrics to plot")

    # 'w' is left out on purpose: a white line is invisible on a white background.
    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']

    plt.figure(figsize=(25,10))

    longest = 0
    for index, metric in enumerate(metrics):
        value = list(history[metric])
        val_value = list(history['val_' + metric])

        if normalize_method is not None and (value or val_value):
            # Normalize both series together so they stay comparable.
            split = len(value)
            buffer = normalize_method(value + val_value)
            value = buffer[:split]
            val_value = buffer[split:]

        # Every series gets its own x axis: after an interrupted training the
        # validation series may be shorter than the training series.
        color = colors[index % len(colors)]
        plt.plot(range(1, len(value) + 1), value, color, label=f"Training {metric}")
        plt.plot(range(1, len(val_value) + 1), val_value, color+'--', label=f"Validation {metric}")
        longest = max(longest, len(value), len(val_value))

    plt.xticks(range(1, longest + 1), size=17)
    plt.yticks(size=17)
    plt.title("Comparison of Training and Validation metrics", size=20)
    plt.xlabel('Epochs', size=17)
    plt.ylabel("Metric values", size=17)
    plt.legend(loc='upper right', fontsize=14)
    plt.show()

In [None]:
def evaluate(model, generator, verbose=0):
    """Evaluate ``model`` on ``generator`` and print one line per metric.

    The number of evaluation steps is ``data_test_size // batch_size``.

    Returns:
        Dict mapping each metric name of ``model.metrics`` to its value.
    """
    # model.evaluate returns a bare scalar when there is a single output value.
    result = np.atleast_1d(model.evaluate(generator, verbose=verbose, steps=data_test_size//batch_size))

    result_dict = {}
    for index, metric in enumerate(model.metrics):
        result_dict[metric.name] = result[index]
        # rjust keeps names containing the character '0' intact
        print(f'{metric.name.rjust(13)}: {np.round(result[index], 4)}')

    return result_dict

In [None]:
# All curves standardized (zero mean, unit deviation) to share one scale.
plot_history(history, normalize_method=norm)

In [None]:
# All curves rescaled to the range [0, 1].
plot_history(history, normalize_method=norm_0_1)

In [None]:
# Raw (not normalized) training and validation loss only.
plot_history(history, normalize_method=None, metrics_restrictions=['loss'])

In [None]:
# Evaluate on the test split; the returned dict is discarded, the values are printed.
_ = evaluate(model, test_generator, verbose=1)

## Visualize generated frames

In [None]:
# Predict the middle frame for every sample of the (non-shuffled) visualization data.
predictions = model.predict(vis_generator)

In [None]:
def visualize(generator, predictions, batch, index):
    """Plot one sample: both neighbour frames, prediction, ground truth and two diagnostics.

    Args:
        generator: Non-shuffled batch iterator yielding ``(inputs, targets)``.
        predictions: Model outputs for the whole generator, in generator order.
        batch: Index of the batch to show.
        index: Index of the sample within that batch.
    """
    # verify arguments
    batch_size = generator.batch_size
    assert generator.shuffle == False
    assert 0 <= batch < len(generator)
    assert 0 <= index < batch_size

    # neighbour frames, ground truth and prediction of the selected sample
    batch_inputs, batch_targets = generator[batch]
    neighbours = np.array(batch_inputs)[:, index, :, :, :]
    true = np.array(batch_targets)[index]
    predicted = predictions[batch_size*batch + index]

    # detect the edges of the ground truth and paint them white on a copy of the prediction
    gray = cv2.cvtColor(true, cv2.COLOR_RGB2GRAY)
    smoothed = cv2.medianBlur(cv2.GaussianBlur(gray, (3, 3), 1), 3)
    true_edges = cv2.Canny((smoothed*255).astype('uint8'), 50, 100)
    predicted_marked = predicted.copy()
    predicted_marked[true_edges != 0] = (1, 1, 1)

    # plot images: the left column holds the frames, the right column the comparisons
    figure, axes = plt.subplots(3, 2)
    figure.set_size_inches(20, 20)

    panels = [
        (axes[0][0], "First frame", neighbours[0]),
        (axes[1][0], "Predicted frame", predicted),
        (axes[2][0], "Second frame", neighbours[1]),
        (axes[0][1], "Predicted and Ground-truth difference", cv2.absdiff(predicted, true)),
        (axes[1][1], "Ground-truth frame", true),
        (axes[2][1], "Edge shift", predicted_marked),
    ]
    for axis, title, image in panels:
        axis.set_title(title)
        axis.set_xticks([])
        axis.set_yticks([])
        axis.imshow(image)

In [None]:
# batch 0, sample 3 of that batch
# NOTE(review): visualize() asserts index < batch_size, so this needs batch_size > 3.
visualize(vis_generator, predictions, 0, 3)

In [None]:
# batch 2, sample 0 of that batch
visualize(vis_generator, predictions, 2, 0)

In [None]:
# batch 2, sample 2 of that batch
# NOTE(review): visualize() asserts index < batch_size, so this needs batch_size > 2.
visualize(vis_generator, predictions, 2, 2)

In [None]:
# batch 4, sample 2 of that batch
# NOTE(review): visualize() asserts index < batch_size, so this needs batch_size > 2.
visualize(vis_generator, predictions, 4, 2)

In [None]:
# batch 5, sample 3 of that batch
# NOTE(review): visualize() asserts index < batch_size, so this needs batch_size > 3.
visualize(vis_generator, predictions, 5, 3)

In [None]:
# batch 12, sample 3 of that batch
# NOTE(review): visualize() asserts index < batch_size, so this needs batch_size > 3.
visualize(vis_generator, predictions, 12, 3)

## Visualize the filters of each Conv2D layer in the model

In [None]:
def append_image(image, append_image, row, col, margin):
    """Paste the tile ``append_image`` into cell ``(row, col)`` of the canvas ``image``.

    Cells are laid out on a grid whose pitch is the tile size plus ``margin``
    pixels. The tile size is taken from the tile itself, so the function does
    not depend on the module-level ``height`` / ``width``.

    Args:
        image: Canvas array of shape (H, W, C); modified in place.
        append_image: Tile array of shape (h, w, C).
        row: Zero-based grid row of the cell.
        col: Zero-based grid column of the cell.
        margin: Gap in pixels between neighbouring cells.

    Returns:
        The canvas ``image`` (the same object that was passed in).
    """
    tile_height, tile_width = append_image.shape[0], append_image.shape[1]
    top = row * (tile_height + margin)
    left = col * (tile_width + margin)
    image[top : top + tile_height, left : left + tile_width, : ] = append_image
    return image

In [None]:
@tf.function
def gradient_ascent_step(model, img, filter_index, learning_rate):
    """Move ``img`` one step towards a higher mean activation of one filter.

    ``img`` is fed to both model inputs; the loss is the mean activation of the
    filter ``filter_index`` with a 2-pixel border cropped away.

    Returns:
        ``(loss, updated img)``.
    """
    with tf.GradientTape() as tape:
        tape.watch(img)
        feature_maps = model([img, img])
        loss = tf.reduce_mean(feature_maps[:, 2:-2, 2:-2, filter_index])

    # L2-normalized gradient of the loss with respect to the image
    direction = tf.math.l2_normalize(tape.gradient(loss, img))
    updated_img = img + learning_rate * direction

    return loss, updated_img


def visualize_filter(model, layer_name, filter_index, size, iterations=20, learning_rate=10.0):
    """Generate an image that strongly activates one filter of one layer.

    Args:
        model: Model whose layer is inspected.
        layer_name: Name of the layer inside ``model``.
        filter_index: Index of the filter within the layer output.
        size: ``(height, width)`` of the generated image.
        iterations: Number of gradient-ascent steps.
        learning_rate: Step size of the gradient ascent.

    Returns:
        The first image of the batch, converted by ``deprocess_image``.
    """
    # Start from uniform noise; `size` is honoured instead of a hard-coded 144x256.
    img = tf.random.uniform((batch_size, size[0], size[1], 3))
    pat_model = keras.Model(inputs=model.inputs, outputs=model.get_layer(name=layer_name).output) 

    for _ in range(iterations):
        _, img = gradient_ascent_step(pat_model, img, filter_index, learning_rate)

    return deprocess_image(img[0].numpy())


def deprocess_image(img):
    """Convert an arbitrary-range array into a displayable uint8 image.

    The values are standardised (zero mean, unit standard deviation), scaled
    to a standard deviation of 0.15, centred on 0.5, clipped to [0, 1] and
    finally mapped to [0, 255].

    The computation runs on a private copy, so the caller's array is left
    untouched, and non-float input is promoted to float32 first (in-place
    float arithmetic on an integer array would raise).

    Args:
        img: array-like with the raw image values.

    Returns:
        ``uint8`` NumPy array with the same shape as ``img``.
    """
    img = np.array(img, copy=True)
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float32)

    img -= img.mean()
    # the small constant avoids dividing by zero for a constant image
    img /= img.std() + 1e-5
    img *= 0.15

    img += 0.5
    img = np.clip(img, 0, 1)

    # values are already inside [0, 1], so after scaling no second clip is needed
    img *= 255
    return img.astype("uint8")


def rows_cols(value):
    """Split ``value`` into the most square-like ``rows x cols`` grid.

    Returns the factor pair of ``value`` whose two members are closest to each
    other, with ``rows <= cols``. A prime (or 1) yields ``(1, value)``.

    Only integer arithmetic is used and the search stops at the square root
    of ``value``: the best pair always has its smaller member there.

    Args:
        value: positive integer number of cells.

    Returns:
        Tuple ``(rows, cols)`` with ``rows * cols == value``.

    Raises:
        AssertionError: if ``value`` is smaller than 1.
    """
    assert value >= 1

    rows = 1
    candidate = 2
    while candidate * candidate <= value:
        if value % candidate == 0:
            # the largest divisor not above sqrt(value) gives the closest pair
            rows = candidate
        candidate += 1

    return rows, value // rows

In [None]:
def visualize_filters(model, margin=3):
    """Plot a tile grid of filter visualisations for each Conv2D layer.

    For every layer whose type is exactly ``tf.keras.layers.Conv2D`` (taken in
    alphabetical order of layer name) one image per filter is generated with
    ``visualize_filter``, the images are arranged on a grid chosen by
    ``rows_cols`` and the grid is shown with matplotlib.

    Args:
        model: Keras model to inspect.
        margin: gap, in pixels, between neighbouring tiles.

    Returns:
        List with one uint8 grid image per visualised layer.
    """
    # exact type comparison: subclasses of Conv2D are not matched
    conv_names = sorted(
        layer.name for layer in model.layers if type(layer) == tf.keras.layers.Conv2D
    )

    results = []
    for layer_name in conv_names:
        print("Layer name: " + layer_name)
        rows, cols = rows_cols(model.get_layer(layer_name).filters)
        canvas = np.zeros(
            (rows * height + (rows - 1) * margin, cols * width + (cols - 1) * margin, 3),
            dtype=np.uint8,
        )

        # the running index is both the filter number and its row-major grid position
        for filter_index in tqdm(range(rows * cols)):
            grid_row, grid_col = divmod(filter_index, cols)
            tile = visualize_filter(model, layer_name, filter_index, size=(height, width))
            canvas = append_image(canvas, tile, grid_row, grid_col, margin)

        # NOTE(review): matplotlib's figsize is (width, height) in inches, but the
        # height-derived value is passed first here - confirm this is intended.
        plt.figure(figsize=(height / 2, width / 2))
        plt.xticks([])
        plt.yticks([])
        plt.imshow(canvas)
        plt.show()
        results.append(canvas)

    return results

In [None]:
# Generate and plot the filter grids for the current model; the returned list of grid images is kept in `filters`.
filters = visualize_filters(model)

## Visualize the predicted flow at each level

In [None]:
def deprocess_flow(u, v):
    """Render a two-component flow field as an RGB image via HSV colours.

    Hue encodes the flow direction, saturation encodes the magnitude relative
    to the largest magnitude in the field, and value is always 255. A field
    with zero magnitude everywhere gets saturation 0 instead of dividing by
    zero.

    Args:
        u: 2-D float array; passed to ``cv2.cartToPolar`` as the y component.
        v: 2-D float array of the same shape; passed as the x component.

    Returns:
        ``uint8`` array of shape ``(u.shape[0], u.shape[1], 3)`` in RGB order.
    """
    # calculate both the magnitude and the angle (angle is returned in radians)
    magnitude, angle = cv2.cartToPolar(v, u)
    n_rows, n_cols = u.shape[0], u.shape[1]

    # create array for the result image (filled as HSV, converted at the end)
    hsv = np.zeros((n_rows, n_cols, 3), dtype='uint8')
    # radians in [0, 2*pi) -> [0, 180), the hue range OpenCV uses for 8-bit images
    hsv[:, :, 0] = angle * 90 / np.pi
    peak = np.max(magnitude)
    if peak > 0:
        hsv[:, :, 1] = magnitude / peak * 255
    # otherwise the saturation channel stays 0 (no motion anywhere)
    hsv[:, :, 2] = 255

    return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

def visualize_flow(model, x, size=(width, height)):
    """Run a flow sub-model and turn its two flow outputs into colour images.

    Args:
        model: callable whose result unpacks into four values; the last two
            are used as the flow tensors (first-to-second and second-to-first).
        x: input passed to ``model`` unchanged.
        size: target size handed to ``cv2.resize`` as ``dsize``; defaults to
            the notebook's ``(width, height)``.

    Returns:
        Tuple ``(flow_1_2, flow_2_1)`` of colour-coded flow images resized to
        ``size``.
    """
    def _render(flow_tensor):
        # drop singleton axes, colour-code the two flow channels, then resize
        field = np.squeeze(flow_tensor.numpy())
        picture = deprocess_flow(field[:, :, 0], field[:, :, 1])
        return cv2.resize(picture, size, interpolation=cv2.INTER_CUBIC)

    # the first two outputs are not needed for the visualisation
    _, _, forward_flow, backward_flow = model(x)
    return _render(forward_flow), _render(backward_flow)

In [None]:
def visualize_flows(model, x, margin=3):
    """Plot the two input frames and the flows predicted by each flow layer.

    The figure is a fixed 5 x 2 tile grid: the first row holds the two input
    frames, every following row holds the pair of flow images produced by one
    ``BidirectionalFlowEstimation`` layer (in alphabetical order of layer
    name).

    Args:
        model: Keras model containing the flow-estimation layers.
        x: pair of input frames; each is multiplied by 255 and cast to uint8
            for display, and the pair is fed to the sub-models as is.
        margin: gap, in pixels, between neighbouring tiles.

    Returns:
        The assembled uint8 grid image.
    """
    # "bi_flow_above_3rd" is listed one extra time and read through get_output_at
    # NOTE(review): presumably that layer is applied twice in the model - confirm
    flow_layer_names = sorted(
        [layer.name for layer in model.layers if type(layer) == BidirectionalFlowEstimation]
        + ["bi_flow_above_3rd"]
    )

    n_rows, n_cols = 5, 2
    canvas = np.zeros(
        (n_rows * height + (n_rows - 1) * margin, n_cols * width + (n_cols - 1) * margin, 3),
        dtype=np.uint8,
    )
    canvas = append_image(canvas, (x[0] * 255).astype('uint8'), 0, 0, margin)
    canvas = append_image(canvas, (x[1] * 255).astype('uint8'), 0, 1, margin)

    # output node of the shared layer to read next: 1 first, then 0
    shared_output_index = 1
    for grid_row, layer_name in enumerate(flow_layer_names, start=1):
        flow_layer = model.get_layer(layer_name)
        if layer_name == "bi_flow_above_3rd":
            layer_output = flow_layer.get_output_at(shared_output_index)
            shared_output_index -= 1
        else:
            layer_output = flow_layer.output

        pat_model = keras.Model(inputs=model.inputs, outputs=layer_output)
        left_flow, right_flow = visualize_flow(pat_model, x)
        canvas = append_image(canvas, left_flow, grid_row, 0, margin)
        canvas = append_image(canvas, right_flow, grid_row, 1, margin)

    # NOTE(review): matplotlib's figsize is (width, height) in inches, but the
    # height-derived value is passed first here - confirm this is intended.
    plt.figure(figsize=(height / 2, width / 2))
    plt.xticks([])
    plt.yticks([])
    plt.imshow(canvas)
    plt.show()

    return canvas

def visualize_flows_generator(model, generator, batch, index, margin=3):
    """Pick one sample from a generator and plot its flows.

    Args:
        model: Keras model forwarded to ``visualize_flows``.
        generator: indexable data source; ``generator[batch][0]`` is converted
            to an array and indexed as ``[:, index, :, :, :]``.
        batch: index of the batch to read.
        index: position of the sample on the second axis of that array.
        margin: gap, in pixels, between tiles in the resulting figure.

    Returns:
        The grid image returned by ``visualize_flows``.
    """
    batch_inputs = np.array(generator[batch][0])
    frame_pair = batch_inputs[:, index, :, :, :]
    # give each frame a leading axis of length 1 again before passing it to the model
    first = frame_pair[0, :, :, :][np.newaxis, ...]
    second = frame_pair[1, :, :, :][np.newaxis, ...]
    return visualize_flows(model, [first, second], margin=margin)

In [None]:
# Plot the flows for batch 3, sample 0 of the visualization generator; the returned grid image is kept in `flows`.
flows = visualize_flows_generator(model, vis_generator, 3, 0)

## Save or load the model

In [None]:
# Timestamp (whole seconds since the epoch) taken when this cell runs.
# NOTE(review): it is not part of the file name used below - confirm whether it should be.
model_creation_time = int(time.time())
# Save the current model as an .h5 file inside model_base_path; the "0.3576" suffix is hard-coded.
model.save(os.path.join(model_base_path, f'{model_name}_0.3576.h5'))

In [None]:
# Load a previously saved model from model_base_path and rebind `model` to it.
# Note that the file name differs from the one written by the save cell above.
# custom_objects maps the names of the custom layer, activation, loss and metrics
# to the Python objects of the same name, which must already exist in the session.
model = keras.models.load_model(
    os.path.join(model_base_path, 'FBNet_0.3574352562427521_5x5.h5'),
    custom_objects = {
        'BidirectionalFlowEstimation': BidirectionalFlowEstimation,
        'output_activation': output_activation,
        'loss': loss,
        'ssim': ssim,
        'psnr': psnr,
        "l2": l2,
        'l1': l1,
    }
)