# Proof of concept notebook for the Frame Booster project
- Author: Kamil Barszczak
- Contact: kamilbarszczak62@gmail.com
- Project: https://github.com/kbarszczak/Frame_booster

In [1]:
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
import tensorflow as tf
import numpy as np
import pickle
import keras
import time
import cv2
import os

from keras import preprocessing
from keras import regularizers
from keras import activations
from keras import optimizers
from keras import callbacks
from keras import layers
from keras import losses
from keras import models

from tensorflow.keras.utils import plot_model
from tensorflow.keras import backend as K
from tensorflow.keras.preprocessing.image import ImageDataGenerator


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [2]:
# model parameters
# Directory where checkpoints and the final model are written, and the base file name used for them
model_base_path = 'E:/OneDrive - Akademia Górniczo-Hutnicza im. Stanisława Staszica w Krakowie/Programming/Labs/Frame_booster/models/model_v3'
model_name = 'frame_booster'
# Frame resolution; tensors are reshaped to (height, width, 3) when records are decoded
width, height = 256, 144

# training data parameters
# The creation time and the sizes are kept as strings because they are interpolated
# into the .tfrecords file names; the sizes are converted with int() when computing steps per epoch
data_base_path = 'E:/Data/Video_Frame_Interpolation/processed/vimeo90k'
data_creation_time ='1682372054'
data_train_size = '19000'
data_test_size = '1000'
data_valid_size = '1000'
# batch_size is also baked into the model graph (the initial zero flow is created with this batch dimension)
batch_size = 5
epochs = 10

# visualization data parameters
# Identify the pickled x/y files used only for the qualitative visualization
vis_base_path = 'E:/Data/Video_Frame_Interpolation/processed/low_motion'
vis_creation_time = '1682179015'
vis_prefix = 'test'

## Load generators

In [3]:
# Schema of one serialized example: three frames, each stored as a single byte string
name_to_features = {
    f'image_{frame_number}': tf.io.FixedLenFeature([], tf.string)
    for frame_number in (1, 2, 3)
}

In [4]:
def parse_decode_record(record):
    """
    Parse one serialized example and decode its three frames.

    Each of the 'image_1', 'image_2' and 'image_3' features is decoded from raw
    little-endian float32 bytes and reshaped to (height, width, 3), where `height`
    and `width` are the module-level frame dimensions.

    Returns ((image_1, image_3), image_2): the two outer frames as the model input
    and the middle frame as the target.
    """
    example = tf.io.parse_single_example(record, name_to_features)

    def decode_frame(key):
        # raw bytes -> flat float32 vector -> (height, width, 3) image
        flat = tf.io.decode_raw(
            example[key], out_type='float32', little_endian=True, fixed_length=None, name=None
        )
        return tf.reshape(flat, (height, width, 3))

    first_frame = decode_frame('image_1')
    middle_frame = decode_frame('image_2')
    last_frame = decode_frame('image_3')

    return (first_frame, last_frame), middle_frame

In [5]:
def load_generator(base_path, prefix, creation_time, width, height, size):
    """
    Build the tf.data input pipeline for one dataset split.

    The records are read from
    `<base_path>/<prefix>_<size>_<height>x<width>_<creation_time>.tfrecords`.

    Parameters:
        base_path: directory that contains the .tfrecords file
        prefix: split name used at the start of the file name (e.g. 'train')
        creation_time: timestamp part of the file name
        width, height: frame dimensions, used here only to build the file name
        size: number of samples, used here only to build the file name

    Returns:
        A dataset yielding batches of ((image_1, image_3), image_2). The data is
        repeated `epochs` times and batched with the module-level `batch_size`
        (incomplete batches are dropped).
    """
    path = os.path.join(base_path, f'{prefix}_{size}_{height}x{width}_{creation_time}.tfrecords')
    generator = tf.data.TFRecordDataset(path)

    # Decode records in parallel; element order stays deterministic by default
    generator = generator.map(parse_decode_record, num_parallel_calls=tf.data.AUTOTUNE)
    generator = generator.repeat(epochs)
    generator = generator.shuffle(buffer_size=5 * batch_size)
    generator = generator.batch(batch_size, drop_remainder=True)
    # Prefetch must be the last step so that whole batches are prepared while the
    # previous one is being consumed (it used to sit before shuffle/batch, where it
    # could only buffer single decoded samples)
    generator = generator.prefetch(tf.data.AUTOTUNE)

    return generator

In [6]:
# Arguments that are identical for the train, test and validation splits
_split_args = dict(
    base_path=data_base_path,
    creation_time=data_creation_time,
    width=width,
    height=height,
)

# One pipeline per split; only the file-name prefix and the sample count differ
train_generator = load_generator(prefix='train', size=data_train_size, **_split_args)
test_generator = load_generator(prefix='test', size=data_test_size, **_split_args)
valid_generator = load_generator(prefix='valid', size=data_valid_size, **_split_args)

## Load data for test visualization

In [7]:
def load_data(base_path, prefix, creation_time, width, height):
    """
    Load the pickled visualization inputs and targets.

    Reads `x_<prefix>_<height>x<width>_<creation_time>.pickle` and the matching
    `y_...` file from `base_path`, converts both to numpy arrays, divides them by
    255.0 and casts them to float32.

    Returns:
        (x, y) tuple of float32 numpy arrays.
    """
    file_suffix = f'{prefix}_{height}x{width}_{creation_time}.pickle'

    def read_scaled(kind):
        # NOTE: pickle.load can execute arbitrary code - only open trusted files
        with open(os.path.join(base_path, f'{kind}_{file_suffix}'), 'rb') as file:
            frames = pickle.load(file)
        return (np.array(frames) / 255.0).astype('float32')

    inputs = read_scaled('x')
    targets = read_scaled('y')
    return inputs, targets

In [8]:
x_vis, y_vis = load_data(vis_base_path, vis_prefix, vis_creation_time, width, height)

# Keep only whole batches: drop the trailing samples that would not fill a batch
leftover = len(x_vis) % batch_size
if leftover != 0:
    x_vis = x_vis[0:len(x_vis) - leftover]
    y_vis = y_vis[0:len(x_vis)]

In [9]:
# Iterator over the visualization samples. Index 1 of x_vis selects one of the two
# stored neighbour frames, so the model receives them as two separate inputs.
# shuffle is disabled so that batch k of the iterator lines up with rows
# [k*batch_size, (k+1)*batch_size) of the predictions (visualize() relies on this).
vis_generator = ImageDataGenerator().flow(
    x = [x_vis[:, 0, :, :], x_vis[:, 1, :, :]],
    y = y_vis,
    batch_size = batch_size,
    shuffle = False,
)

## Create loss functions

In [10]:
"""
The L1 reconstruction loss function makes the final image colors 
look the same as the colors in the ground-truth image
"""
def l1(y_true, y_pred):
    """Mean absolute error between the ground-truth and the predicted frames (a scalar)."""
    absolute_error = K.abs(y_true - y_pred)
    return K.mean(absolute_error)


"""
The L2 reconstruction loss function is similar to L1
"""
def l2(y_true, y_pred):
    """Mean squared error between the ground-truth and the predicted frames (a scalar)."""
    squared_error = K.square(y_true - y_pred)
    return K.mean(squared_error)


"""
The PSNR loss function is responsible for boosting the overall quality of the image by reducing its noise 
(The higher the PSNR the better, so we return 1 - PSNR / 40 because the loss function tries to minimize it)
"""
def psnr(y_true, y_pred):
    """
    PSNR-based loss term: 1 - PSNR / 40, computed with max_val = 1.0.

    The value is not reduced here, so it has whatever shape tf.image.psnr returns.
    NOTE(review): presumably PSNR is unbounded when both frames are identical,
    which would make this term non-finite - confirm if that can happen in training.
    """
    peak_signal_to_noise = tf.image.psnr(y_true, y_pred, max_val = 1.0)
    scaled = peak_signal_to_noise / 40.0
    return 1 - scaled


"""
The SSIM loss function keeps the result image structure
(The more significant the SSIM the more similar the final image is)
"""
def ssim(y_true, y_pred):
    """SSIM-based loss term: 1 - mean SSIM (max_val = 1.0) over the given frames (a scalar)."""
    similarity = tf.image.ssim(y_true, y_pred, 1.0)
    mean_similarity = tf.reduce_mean(similarity)
    return 1 - mean_similarity

def output_activation(x):
    """Output activation of the network: clamps every value of x to the [0, 1] range."""
    non_negative = tf.math.maximum(x, 0)
    return tf.math.minimum(non_negative, 1)

## Create custom layers for the model

In [11]:
"""
BidirectionalFlowEstimation is a layer that warps the features extracted at the given level trying 
to modify them to fit the target image. It predicts the flow between features of both the input_1 
and the input_2 and warps them to get the final output. The output shape is the same as the input shape.
"""
class BidirectionalFlowEstimation(layers.Layer):
    """
    Refines a pair of flows (1 -> 2 and 2 -> 1) and warps both feature maps with them.

    For each direction the layer warps the source features with the incoming flow,
    predicts a 2-channel flow correction from the concatenation of the other input
    and the warped source, adds the correction to the incoming flow and warps the
    source again with the corrected flow. The corrected flows are additionally
    upsampled by a factor of 2 so that they can feed the next level.

    Parameters:
        filter_count: filters of the four hidden convolutions of each flow predictor
        filter_size: kernel sizes of the four hidden convolutions
        activation: activation of the hidden convolutions
        regularizer: kernel regularizer of all convolutions
        interpolation: interpolation mode of the flow upsampling
    """
    def __init__(self, filter_count=[32, 64, 64, 16], filter_size=[(3, 3), (3, 3), (1, 1), (1, 1)], activation='relu', regularizer=None, interpolation='bilinear', **kwargs):
        super().__init__(**kwargs)

        def build_flow_predictor():
            # four hidden convolutions followed by a linear 1x1 convolution with 2 outputs (the flow correction)
            hidden_layers = [
                layers.Conv2D(filter_count[level], filter_size[level], activation=activation, kernel_regularizer=regularizer, padding='same')
                for level in range(4)
            ]
            flow_head = layers.Conv2D(2, (1, 1), kernel_regularizer=regularizer, padding='same')
            return keras.Sequential(hidden_layers + [flow_head])

        # flow 1 -> 2 (sublayers are created and assigned in a fixed order, direction 1 -> 2 first)
        self.flow_add_1_2 = layers.Add()
        self.flow_upsample_1_2 = layers.UpSampling2D((2, 2), interpolation=interpolation)
        self.flow_1_2_concat = layers.Concatenate(axis=3)
        self.flow_prediction_1_2 = build_flow_predictor()

        # flow 2 -> 1
        self.flow_add_2_1 = layers.Add()
        self.flow_upsample_2_1 = layers.UpSampling2D((2, 2), interpolation=interpolation)
        self.flow_2_1_concat = layers.Concatenate(axis=3)
        self.flow_prediction_2_1 = build_flow_predictor()

        # constructor arguments, kept for get_config
        self.filter_count = filter_count
        self.filter_size = filter_size
        self.activation = activation
        self.regularizer = regularizer
        self.interpolation = interpolation

    def get_config(self):
        """Return the base layer config extended with this layer's constructor arguments."""
        config = super().get_config()
        config["filter_count"] = self.filter_count
        config["filter_size"] = self.filter_size
        config["activation"] = self.activation
        config["regularizer"] = self.regularizer
        config["interpolation"] = self.interpolation
        return config

    def call(self, inputs):
        """
        Apply the layer to [input_1, input_2, flow_1_2, flow_2_1].

        Returns (input_1 warped, input_2 warped, upsampled flow 1 -> 2, upsampled flow 2 -> 1).
        """
        input_1, input_2 = inputs[0], inputs[1]
        flow_1_2, flow_2_1 = inputs[2], inputs[3]

        def refine(source, target, flow, concat, predictor, add, upsample):
            # warp with the incoming flow, predict a correction, then warp with the corrected flow
            coarse_warp = tfa.image.dense_image_warp(source, flow)
            flow_change = predictor(concat([target, coarse_warp]))
            flow_changed = add([flow, flow_change])
            warped = tfa.image.dense_image_warp(source, flow_changed)
            return warped, upsample(flow_changed)

        # input_1 to input_2 flow prediction
        input_1_warped, flow_1_2_upsampled = refine(
            input_1, input_2, flow_1_2,
            self.flow_1_2_concat, self.flow_prediction_1_2, self.flow_add_1_2, self.flow_upsample_1_2,
        )

        # input_2 to input_1 flow prediction
        input_2_warped, flow_2_1_upsampled = refine(
            input_2, input_1, flow_2_1,
            self.flow_2_1_concat, self.flow_prediction_2_1, self.flow_add_2_1, self.flow_upsample_2_1,
        )

        return input_1_warped, input_2_warped, flow_1_2_upsampled, flow_2_1_upsampled

In [15]:
# Hyperparameters of the network built in this cell.
# filter_count[i] is the number of filters of the (i+1)-th feature extraction level;
# flow_filter_count / flow_filter_size configure the BidirectionalFlowEstimation layers.
filter_count=[16, 16, 16]
filter_size=(3, 3)
flow_filter_count=[16, 24, 24, 8] 
flow_filter_size=[(3, 3), (3, 3), (1, 1), (1, 1)]
interpolation='bilinear'
activation= 'relu'
regularizer=None

# ------------- shared layers
# Each layer object below is called once on the left and once on the right branch,
# so both input frames are processed with the same weights.
# 1st level: one convolution per input scale (full, 1/2, 1/4, 1/8)
cnn_1st_level_1 = layers.Conv2D(filter_count[0], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_1_1")
cnn_1st_level_2 = layers.Conv2D(filter_count[0], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_1_2")
cnn_1st_level_3 = layers.Conv2D(filter_count[0], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_1_3")
cnn_1st_level_4 = layers.Conv2D(filter_count[0], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_1_4")

# 2nd level: applied to the pooled outputs of the first three 1st-level convolutions
cnn_2nd_level_1 = layers.Conv2D(filter_count[1], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_2_1")
cnn_2nd_level_2 = layers.Conv2D(filter_count[1], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_2_2")
cnn_2nd_level_3 = layers.Conv2D(filter_count[1], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_2_3")

# 3rd level: applied to the pooled outputs of the first two 2nd-level convolutions
cnn_3rd_level_1 = layers.Conv2D(filter_count[2], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_3_1")
cnn_3rd_level_2 = layers.Conv2D(filter_count[2], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same', name="cnn_fe_3_2")

# ------------- feature extraction left side
# Input pyramid: the frame is halved three times with 2x2 average pooling (1/2, 1/4 and 1/8 resolution)
input_1_left = layers.Input(shape=(height, width, 3), name="input_left")
input_2_left = layers.AveragePooling2D((2, 2), name="avg_input_left_1/2")(input_1_left)
input_3_left = layers.AveragePooling2D((2, 2), name="avg_input_left_1/4")(input_2_left)
input_4_left = layers.AveragePooling2D((2, 2), name="avg_input_left_1/8")(input_3_left)

# feature extraction for layer 1
# Naming: input_<k>_column_<c>_row_<r> - column is the extraction level, row is the resolution (1 = full, 4 = 1/8)
input_1_column_1_row_1_left = cnn_1st_level_1(input_1_left)
input_2_column_1_row_2_left = cnn_1st_level_2(input_2_left)
input_3_column_1_row_3_left = cnn_1st_level_3(input_3_left)
input_4_column_1_row_4_left = cnn_1st_level_4(input_4_left)

# downsample layer 1
input_1_column_2_row_2_left = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_left_1_1/2")(input_1_column_1_row_1_left)
input_2_column_2_row_3_left = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_left_1_1/4")(input_2_column_1_row_2_left)
input_3_column_2_row_4_left = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_left_1_1/8")(input_3_column_1_row_3_left)

# feature extraction for layer 2
# (the pooled tensors are overwritten in place by their convolved versions)
input_1_column_2_row_2_left = cnn_2nd_level_1(input_1_column_2_row_2_left)
input_2_column_2_row_3_left = cnn_2nd_level_2(input_2_column_2_row_3_left)
input_3_column_2_row_4_left = cnn_2nd_level_3(input_3_column_2_row_4_left)

# downsample layer 2
input_1_column_3_row_3_left = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_left_2_1/4")(input_1_column_2_row_2_left)
input_2_column_3_row_4_left = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_left_2_1/8")(input_2_column_2_row_3_left)

# feature extraction for layer 3
input_1_column_3_row_3_left = cnn_3rd_level_1(input_1_column_3_row_3_left)
input_2_column_3_row_4_left = cnn_3rd_level_2(input_2_column_3_row_4_left)

# concatenate
# All features that share a resolution (same row) are joined along the channel axis
concat_2nd_left = layers.Concatenate(name="con_left_2")([input_2_column_1_row_2_left, input_1_column_2_row_2_left])
concat_3rd_left = layers.Concatenate(name="con_left_3")([input_3_column_1_row_3_left, input_2_column_2_row_3_left, input_1_column_3_row_3_left])
concat_4th_left = layers.Concatenate(name="con_left_4")([input_4_column_1_row_4_left, input_3_column_2_row_4_left, input_2_column_3_row_4_left])

# output from feature extraction left side: input_1_column_1_row_1_left, concat_2nd_left, concat_3rd_left, concat_4th_left

# ------------- feature extraction right side
# Same structure as the left branch; the convolution layers are the shared cnn_* objects
input_1_right = layers.Input(shape=(height, width, 3), name="input_right")
input_2_right = layers.AveragePooling2D((2, 2), name="avg_input_right_1/2")(input_1_right)
input_3_right = layers.AveragePooling2D((2, 2), name="avg_input_right_1/4")(input_2_right)
input_4_right = layers.AveragePooling2D((2, 2), name="avg_input_right_1/8")(input_3_right)

# feature extraction for layer 1
input_1_column_1_row_1_right = cnn_1st_level_1(input_1_right)
input_2_column_1_row_2_right = cnn_1st_level_2(input_2_right)
input_3_column_1_row_3_right = cnn_1st_level_3(input_3_right)
input_4_column_1_row_4_right = cnn_1st_level_4(input_4_right)

# downsample layer 1
input_1_column_2_row_2_right = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_right_1_1/2")(input_1_column_1_row_1_right)
input_2_column_2_row_3_right = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_right_1_1/4")(input_2_column_1_row_2_right)
input_3_column_2_row_4_right = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_right_1_1/8")(input_3_column_1_row_3_right)

# feature extraction for layer 2
input_1_column_2_row_2_right = cnn_2nd_level_1(input_1_column_2_row_2_right)
input_2_column_2_row_3_right = cnn_2nd_level_2(input_2_column_2_row_3_right)
input_3_column_2_row_4_right = cnn_2nd_level_3(input_3_column_2_row_4_right)

# downsample layer 2
input_1_column_3_row_3_right = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_right_2_1/4")(input_1_column_2_row_2_right)
input_2_column_3_row_4_right = layers.AveragePooling2D((2, 2), name="avg_cnn_fe_right_2_1/8")(input_2_column_2_row_3_right)

# feature extraction for layer 3
input_1_column_3_row_3_right = cnn_3rd_level_1(input_1_column_3_row_3_right)
input_2_column_3_row_4_right = cnn_3rd_level_2(input_2_column_3_row_4_right)

# concatenate
concat_2nd_right = layers.Concatenate(name="con_right_2")([input_2_column_1_row_2_right, input_1_column_2_row_2_right])
concat_3rd_right = layers.Concatenate(name="con_right_3")([input_3_column_1_row_3_right, input_2_column_2_row_3_right, input_1_column_3_row_3_right])
concat_4th_right = layers.Concatenate(name="con_right_4")([input_4_column_1_row_4_right, input_3_column_2_row_4_right, input_2_column_3_row_4_right])

# output from feature extraction right side: input_1_column_1_row_1_right, concat_2nd_right, concat_3rd_right, concat_4th_right

# ------------- warping features at each level     
# for flow estimation
# Three estimators are created for four levels: bidirectional_flow_estimation_3
# ("bi_blow_above_3rd") is applied to both the 4th and the 3rd level below, so those
# two levels share its weights.
bidirectional_flow_estimation_1 = BidirectionalFlowEstimation(
    filter_count=flow_filter_count, 
    filter_size=flow_filter_size, 
    activation=activation, 
    regularizer=regularizer, 
    interpolation=interpolation,
    name="bi_blow_1st"
)
bidirectional_flow_estimation_2 = BidirectionalFlowEstimation(
    filter_count=flow_filter_count, 
    filter_size=flow_filter_size, 
    activation=activation, 
    regularizer=regularizer, 
    interpolation=interpolation,
    name="bi_blow_2nd"
)
bidirectional_flow_estimation_3 = BidirectionalFlowEstimation(
    filter_count=flow_filter_count, 
    filter_size=flow_filter_size, 
    activation=activation, 
    regularizer=regularizer, 
    interpolation=interpolation,
    name="bi_blow_above_3rd"
)

# create empty flow for the coarsest level
# NOTE: the zero flows are created with a fixed batch dimension (batch_size), so the
# graph contains constants of exactly that batch size
empty_flow_1 = tf.zeros(shape=(batch_size, height//8, width//8, 2))
empty_flow_2 = tf.zeros(shape=(batch_size, height//8, width//8, 2))

# calculate the flow for each level using the input of current level and the upsampled flow from the level + 1
bfe_4_i1, bfe_4_i2, bfe_4_f_1_2, bfe_4_f_2_1 = bidirectional_flow_estimation_3([concat_4th_left, concat_4th_right, empty_flow_1, empty_flow_2])
bfe_3_i1, bfe_3_i2, bfe_3_f_1_2, bfe_3_f_2_1 = bidirectional_flow_estimation_3([concat_3rd_left, concat_3rd_right, bfe_4_f_1_2, bfe_4_f_2_1])
bfe_2_i1, bfe_2_i2, bfe_2_f_1_2, bfe_2_f_2_1 = bidirectional_flow_estimation_2([concat_2nd_left, concat_2nd_right, bfe_3_f_1_2, bfe_3_f_2_1])
# the upsampled flows of the finest level are not needed any more
bfe_1_i1, bfe_1_i2, _, _ = bidirectional_flow_estimation_1([input_1_column_1_row_1_left, input_1_column_1_row_1_right, bfe_2_f_1_2, bfe_2_f_2_1])

# returned by warping: (bfe_1_i1, bfe_2_i1, bfe_3_i1, bfe_4_i1), (bfe_1_i2, bfe_2_i2, bfe_3_i2, bfe_4_i2)

# ------------- warped features fusion   
# Coarse-to-fine decoder: at every level the two warped feature maps are summed
# (together with the upsampled result of the coarser level), passed through two
# convolutions and upsampled by 2 for the next finer level.
        
# merge 4th level
added_4th_level = layers.Add()([bfe_4_i1, bfe_4_i2])
cnn_4th_1 = layers.Conv2D(filter_count[0] + filter_count[1] + filter_count[2], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same')(added_4th_level)
cnn_4th_2 = layers.Conv2D(filter_count[0] + filter_count[1] + filter_count[2], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same')(cnn_4th_1)
up_4th = layers.UpSampling2D((2, 2), interpolation=interpolation)(cnn_4th_2)

# merge 3rd level
added_3rd_level = layers.Add()([bfe_3_i1, bfe_3_i2, up_4th])
cnn_3rd_1 = layers.Conv2D(filter_count[0] + filter_count[1], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same')(added_3rd_level)
cnn_3rd_2 = layers.Conv2D(filter_count[0] + filter_count[1], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same')(cnn_3rd_1)
up_3rd = layers.UpSampling2D((2, 2), interpolation=interpolation)(cnn_3rd_2)

# merge 2nd level
added_2nd_level = layers.Add()([bfe_2_i1, bfe_2_i2, up_3rd])
cnn_2nd_1 = layers.Conv2D(filter_count[0], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same')(added_2nd_level)
cnn_2nd_2 = layers.Conv2D(filter_count[0], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same')(cnn_2nd_1)
up_2nd = layers.UpSampling2D((2, 2), interpolation=interpolation)(cnn_2nd_2)

# merge 1st level
added_1st_level = layers.Add()([bfe_1_i1, bfe_1_i2, up_2nd])
x = layers.Conv2D(filter_count[0], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same')(added_1st_level)
x = layers.Conv2D(filter_count[0], filter_size, activation=activation, kernel_regularizer=regularizer, padding='same')(x)

# net output
# 1x1 convolution to 3 channels; output_activation clamps the result to [0, 1]
outputs = layers.Conv2D(3, (1, 1), activation=output_activation, padding='same')(x)

## Create the final custom loss function & build the model

In [16]:
def loss(y_true, y_pred):
    """
    Training loss: ssim + psnr + 5 * l1 + 10 * l2.

    Combines the structural (SSIM) and noise (PSNR) terms with the weighted
    L1 and L2 reconstruction errors.
    """
    structure_term = ssim(y_true, y_pred)
    noise_term = psnr(y_true, y_pred)
    absolute_term = 5.0 * l1(y_true, y_pred)
    squared_term = 10.0 * l2(y_true, y_pred)
    return structure_term + noise_term + absolute_term + squared_term

In [17]:
# Build the model from the two frame inputs to the interpolated frame
model = keras.Model(
    inputs=[input_1_left, input_1_right],
    outputs=outputs,
)

# Train on the combined loss; the individual terms are tracked as metrics
model.compile(
    optimizer=optimizers.Nadam(0.0001, clipnorm=0.5),
    loss=loss,
    metrics=[l1, l2, psnr, ssim],
)

model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_left (InputLayer)        [(None, 144, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 input_right (InputLayer)       [(None, 144, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 avg_input_left_1/2 (AveragePoo  (None, 72, 128, 3)  0           ['input_left[0][0]']             
 ling2D)                                                                                    

 bi_blow_above_3rd (Bidirection  multiple            36276       ['con_left_4[0][0]',             
 alFlowEstimation)                                                'con_right_4[0][0]',            
                                                                  'con_left_3[0][0]',             
                                                                  'con_right_3[0][0]',            
                                                                  'bi_blow_above_3rd[0][2]',      
                                                                  'bi_blow_above_3rd[0][3]']      
                                                                                                  
 avg_cnn_fe_left_2_1/4 (Average  (None, 36, 64, 16)  0           ['cnn_fe_2_1[0][0]']             
 Pooling2D)                                                                                       
                                                                                                  
 avg_cnn_f

Total params: 170,879
Trainable params: 170,879
Non-trainable params: 0
__________________________________________________________________________________________________


## Train the model

In [None]:
# Checkpoint file name template: {loss} and {epoch} are filled in by Keras at save
# time, the timestamp is fixed once when this cell runs
checkpoint_path = os.path.join(
    model_base_path,
    model_name + '_' + '{loss:.4f}_{epoch:02d}_' + str(int(time.time())) + '.h5',
)

# Save the full model whenever the training loss improves; save_freq is an integer,
# i.e. the check happens every 200 batches rather than once per epoch
checkpoint_callback = callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='loss',
    mode='min',
    save_best_only=True,
    save_weights_only=False,
    save_freq=200,
)

history = model.fit(
    train_generator,
    epochs=epochs,
    validation_data=valid_generator,
    steps_per_epoch=int(data_train_size) // batch_size,
    validation_steps=int(data_valid_size) // batch_size,
    callbacks=[checkpoint_callback],
)

Epoch 1/10
 458/3800 [==>...........................] - ETA: 12:57 - loss: 1.2403 - l1: 0.0636 - l2: 0.0144 - psnr: 0.4475 - ssim: 0.3312

## Evaluate the training

In [None]:
def norm_0_1(data):
    """
    Min-max scale `data` to the [0, 1] range.

    Parameters:
        data: non-empty sequence or numpy array of numbers

    Returns:
        numpy array with (data - min) / (max - min). When all values are equal the
        range is zero, so an array of zeros is returned instead of dividing by zero
        (which used to produce NaN values).
    """
    lowest = np.min(data)
    centered = data - lowest
    value_range = np.max(data) - lowest
    if value_range == 0:
        # constant input: `centered` is all zeros; dividing by 1.0 keeps the float result type
        return centered / 1.0
    return centered / value_range

def norm(data):
    """
    Standardize `data` to zero mean and unit standard deviation.

    Parameters:
        data: non-empty sequence or numpy array of numbers

    Returns:
        numpy array with (data - mean) / std. When the standard deviation is zero
        (all values equal) an array of zeros is returned instead of dividing by zero
        (which used to produce NaN values).
    """
    centered = data - np.mean(data)
    spread = np.std(data)
    if spread == 0:
        # constant input: `centered` is all zeros; dividing by 1.0 keeps the float result type
        return centered / 1.0
    return centered / spread

In [None]:
def plot_history(history, normalize_method=None, metrics_restrictions=None):
    """
    Plot the training and validation curves of the recorded metrics.

    Parameters:
        history: either the object returned by model.fit (its `.history` dict is
            used) or directly a dict mapping metric names to per-epoch values
        normalize_method: optional function applied to the concatenated training and
            validation values of each metric before plotting (e.g. norm, norm_0_1)
        metrics_restrictions: optional collection of metric names; when given, only
            these metrics are plotted

    Raises:
        ValueError: when no metric is left to plot
    """
    # model.fit returns a History object; the per-epoch values live in its `.history` dict
    history = getattr(history, 'history', history)

    # validation entries are the ones prefixed with 'val_' (a plain substring test
    # would also drop training metrics that merely contain "val" in their name)
    metrics = [metric for metric in history.keys() if not metric.startswith('val_')]
    if metrics_restrictions is not None:
        metrics = [metric for metric in metrics if metric in metrics_restrictions]
    if not metrics:
        raise ValueError("There are no metrics to plot")

    # no plt.clf() here: without a current figure it would only create an extra empty figure
    plt.figure(figsize=(25,10))

    # white is left out because it is invisible on the default background
    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    epochs = range(1, len(history[metrics[0]]) + 1)

    for index, metric in enumerate(metrics):
        value = list(history[metric])
        # metrics without a validation counterpart are plotted without the dashed line
        val_value = history.get('val_' + metric)
        val_value = list(val_value) if val_value is not None else []

        if normalize_method is not None:
            # normalize both curves together so that they stay comparable
            train_length = len(value)
            buffer = normalize_method(value + val_value)
            value = buffer[0:train_length]
            val_value = buffer[train_length::]

        color = colors[index % len(colors)]
        plt.plot(epochs, value, color, label=f"Training {metric}")
        if len(val_value) > 0:
            plt.plot(epochs, val_value, color+'--', label=f"Validation {metric}")

    plt.xticks(epochs, size=17)
    plt.yticks(size=17)
    plt.title("Comparison of Training and Validation metrics", size=20)
    plt.xlabel('Epochs', size=17)
    plt.ylabel("Metric values", size=17)
    plt.legend(loc='upper right', fontsize=14)
    plt.show()

In [None]:
def evaluate(model, generator, verbose=0):
    """
    Evaluate the model on the generator and print one line per metric.

    Parameters:
        model: object with an `evaluate(generator, verbose=...)` method and a
            `metrics` list whose items have a `name`
        generator: data passed to `model.evaluate`
        verbose: verbosity forwarded to `model.evaluate`

    Returns:
        dict mapping each metric name to its evaluated value
    """
    result = model.evaluate(generator, verbose=verbose)
    # a single value is returned when there is only one quantity to report
    if not isinstance(result, (list, tuple)):
        result = [result]

    result_dict = {}
    for index, metric in enumerate(model.metrics):
        result_dict[metric.name] = result[index]
        # right-align the name to 13 characters (zero-filling and then replacing "0"
        # with spaces also erased genuine zeros inside metric names)
        print(f'{metric.name.rjust(13)}: {np.round(result[index], 4)}')

    return result_dict

In [None]:
# model.fit returns a History object; the per-epoch metric values are in its `.history` dict
plot_history(history.history, normalize_method=norm)

In [None]:
# model.fit returns a History object; the per-epoch metric values are in its `.history` dict
plot_history(history.history, normalize_method=norm_0_1)

In [None]:
# model.fit returns a History object; the per-epoch metric values are in its `.history` dict
plot_history(history.history, normalize_method=None, metrics_restrictions=['loss'])

In [None]:
# Report the loss and metrics on the test split; the returned dict is discarded.
# NOTE(review): test_generator is built by load_generator, which repeats the data
# `epochs` times, so this presumably iterates over the test set that many times - confirm
_ = evaluate(model, test_generator, verbose=1)

## Visualize generated frames

In [None]:
# Predict the middle frame for every visualization sample (vis_generator is not shuffled,
# so the predictions keep the order of the samples)
predictions = model.predict(vis_generator)

In [None]:
def visualize(generator, predictions, batch, index):
    """
    Show one visualization sample as a 3x2 grid of images.

    Left column: first frame, predicted frame, second frame.
    Right column: absolute difference between prediction and ground truth, the
    ground-truth frame, and the prediction with the ground-truth edges drawn on it.

    Parameters:
        generator: unshuffled batch iterator yielding ([first, second], true) batches
        predictions: model outputs for all samples of the generator, in order
        batch: index of the batch inside the generator
        index: index of the sample inside that batch
    """
    # verify arguments
    batch_size = generator.batch_size
    assert generator.shuffle == False
    assert 0 <= batch < len(generator)
    assert 0 <= index < batch_size

    # pick the two neighbouring frames, the ground truth and the prediction of the sample
    frame_pairs, targets = generator[batch]
    neighbours = np.array(frame_pairs)[:, index, :, :, :]
    true = np.array(targets)[index]
    predicted = predictions[batch_size*batch + index]

    # detect the edges of the ground-truth frame (grayscale -> blur -> Canny) ...
    # presumably the frames hold values in [0, 1], hence the scaling by 255 - verify against the data
    gray = cv2.cvtColor(true, cv2.COLOR_RGB2GRAY)
    smoothed = cv2.medianBlur(cv2.GaussianBlur(gray, (3, 3), 1), 3)
    edge_mask = cv2.Canny((smoothed*255).astype('uint8'), 50, 100) != 0

    # ... and paint them on a copy of the predicted frame
    predicted_marked = predicted.copy()
    predicted_marked[edge_mask] = (1, 1, 1)

    # (row, column, title, image) of every panel
    panels = [
        (0, 0, "First frame", neighbours[0]),
        (1, 0, "Predicted frame", predicted),
        (2, 0, "Second frame", neighbours[1]),
        (0, 1, "Predicted and Ground-truth difference", cv2.absdiff(predicted, true)),
        (1, 1, "Ground-truth frame", true),
        (2, 1, "Edge shift", predicted_marked),
    ]

    # plot images
    figure, axes = plt.subplots(3, 2)
    figure.set_size_inches(20, 20)
    for row, column, title, image in panels:
        axis = axes[row][column]
        axis.set_title(title)
        axis.set_xticks([])
        axis.set_yticks([])
        axis.imshow(image)

In [None]:
# Sample 3 of batch 0 of the visualization set
visualize(vis_generator, predictions, 0, 3)

In [None]:
# Sample 0 of batch 2 of the visualization set
visualize(vis_generator, predictions, 2, 0)

In [None]:
# Sample 2 of batch 2 of the visualization set
visualize(vis_generator, predictions, 2, 2)

In [None]:
# Sample 2 of batch 4 of the visualization set
visualize(vis_generator, predictions, 4, 2)

In [None]:
# Sample 3 of batch 5 of the visualization set
visualize(vis_generator, predictions, 5, 3)

In [None]:
# Sample 3 of batch 12 of the visualization set
visualize(vis_generator, predictions, 12, 3)

## Save or load the model

In [None]:
# NOTE(review): model_creation_time is assigned but is not part of the file name
# below, so every run overwrites <model_name>.h5 - confirm whether the timestamp
# was meant to be included
model_creation_time = int(time.time())
# Save the full model (architecture and weights) as a single .h5 file
model.save(os.path.join(model_base_path, f'{model_name}.h5'))

In [None]:
# current best model
# The custom layer, activation, loss and metrics are stored in the file only by name,
# so they have to be supplied again when the model is loaded
model_best = keras.models.load_model(
    os.path.join(model_base_path, 'frame_booster_0.5210_02_1683730177.h5'),
    custom_objects=dict(
        BidirectionalFlowEstimation=BidirectionalFlowEstimation,
        output_activation=output_activation,
        loss=loss,
        l1=l1,
        ssim=ssim,
        psnr=psnr,
        l2=l2,
    ),
)