# Proof of concept notebook for the Frame Booster project
- Author: Kamil Barszczak
- Contact: kamilbarszczak62@gmail.com
- Project: https://github.com/kbarszczak/Frame_booster

In [241]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os

import torch.utils.data as data
import torch.nn.functional as F
import torch.nn as nn
import torch
import torchsummary
import torchvision

#### Notebook parameters

In [2]:
# Location of the preprocessed dataset and the files listing the sample ids of each split
base_path = 'E:/Data/Video_Frame_Interpolation/processed/vimeo90k_pytorch'
data_subdir = 'data'
train_ids, test_ids, valid_ids = 'train.txt', 'test.txt', 'valid.txt'

# Frame size in pixels
width = 256
height = 144

# Training schedule: number of epochs and samples per batch
epochs = 5
batch = 5

#### Setup device

In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

#### Load datasets

In [3]:
class ByteImageDataset(data.Dataset):
    """Dataset of frame triplets stored as raw float32 byte dumps.

    Every sample lives in the directory ``<path>/<subdir>/<id>`` and consists of
    the three files ``im1``, ``im2`` and ``im3``. The ids are read from the split
    file ``<path>/<split_filename>`` (one id per line).

    Args:
        path: Root directory of the dataset.
        subdir: Sub-directory of ``path`` that holds the sample directories.
        split_filename: Name of the file (inside ``path``) listing the sample ids.
        shape: Shape the flat float32 buffer of a frame is reshaped to before the
            last axis is moved to the front (the notebook passes
            ``(height, width, 3)``).
    """

    def __init__(self, path, subdir, split_filename, shape):
        self.path = path
        self.subdir = subdir
        self.shape = shape
        # dtype=str keeps every id exactly as written in the split file; without it
        # pandas would parse numeric-looking ids and drop e.g. leading zeros.
        self.ids = pd.read_csv(os.path.join(path, split_filename), names=["ids"], dtype=str)

    def __len__(self):
        """Return the number of samples listed in the split file."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``([im1, im3], im2)`` for the sample at position ``idx``.

        ``im1`` and ``im3`` are the model inputs, ``im2`` is the target frame.
        """
        img_path = os.path.join(self.path, self.subdir, str(self.ids.iloc[idx, 0]))

        imgs = [
            self._read_bytes_to_tensor(os.path.join(img_path, 'im1')),
            self._read_bytes_to_tensor(os.path.join(img_path, 'im3'))
        ]
        true = self._read_bytes_to_tensor(os.path.join(img_path, 'im2'))

        return imgs, true

    def _read_bytes_to_tensor(self, path):
        """Read a raw float32 file and return it as a tensor with axes (2, 0, 1) of ``shape``."""
        with open(path, 'rb') as bf:
            # torch.frombuffer shares memory with the buffer it is given. A bytes
            # object is immutable, so wrap the data in a writable bytearray.
            buffer = bytearray(bf.read())
        frame = torch.frombuffer(buffer, dtype=torch.float).reshape(self.shape)
        return frame.permute(2, 0, 1)

In [4]:
def _build_loader(split_filename, shuffle=False):
    """Create a DataLoader over the split listed in ``split_filename``.

    All loaders read frames of shape (height, width, 3), use the notebook-wide
    batch size and drop the last incomplete batch.
    """
    split_dataset = ByteImageDataset(
        path=base_path,
        subdir=data_subdir,
        split_filename=split_filename,
        shape=(height, width, 3),
    )
    return data.DataLoader(
        dataset=split_dataset,
        shuffle=shuffle,
        batch_size=batch,
        drop_last=True,
    )


# Only the training split is shuffled
train_dataloader = _build_loader(train_ids, shuffle=True)
test_dataloader = _build_loader(test_ids)
valid_dataloader = _build_loader(valid_ids)

In [5]:
# Report how many batches each split yields per epoch
for split_label, split_loader in (
    ('Training', train_dataloader),
    ('Testing', test_dataloader),
    ('Validating', valid_dataloader),
):
    print(f'{split_label} batches: {len(split_loader)}')

Training batches: 10000
Testing batches: 600
Validating batches: 200


#### Create the model

In [127]:
class TruncateActivation(nn.Module):
    """Activation that clips its input element-wise to ``[lower, upper]``.

    Args:
        lower: Smallest value the output may take (default 0.0).
        upper: Largest value the output may take (default 1.0).
    """

    def __init__(self, lower=0.0, upper=1.0, **kwargs):
        super(TruncateActivation, self).__init__(**kwargs)
        self.a = lower
        self.b = upper

    def forward(self, x):
        # torch.max / torch.min have no (Tensor, float) overload, so the bounds
        # are applied with torch.clamp, which accepts plain numbers.
        return torch.clamp(x, min=self.a, max=self.b)

    
def l1(y_true, y_pred):
    """Return the sum of the absolute element-wise differences."""
    residual = y_true - y_pred
    return residual.abs().sum()


def l2(y_true, y_pred):
    """Return the Euclidean norm of the element-wise difference."""
    residual = y_true - y_pred
    return residual.pow(2).sum().sqrt()


def mae(y_true, y_pred):
    """Return the mean absolute error over all elements."""
    residual = y_true - y_pred
    return residual.abs().mean()


def mse(y_true, y_pred):
    """Return the mean squared error over all elements."""
    residual = y_true - y_pred
    return residual.pow(2).mean()


def psnr(y_true, y_pred):
    """Return a PSNR-based loss term: ``1 - PSNR / 40``.

    The PSNR is computed in dB with a peak signal value of 1, so a PSNR of
    40 dB maps to 0 and lower PSNR values map to larger results.

    The mean squared error is floored at 1e-10 (a PSNR of 100 dB). Without the
    floor, identical inputs give an infinite PSNR and the function returns
    ``-inf``, which makes any loss built from it non-finite.
    """
    squared_error = torch.clamp(torch.mean((y_true - y_pred) ** 2), min=1e-10)
    psnr_db = 20 * torch.log10(1 / torch.sqrt(squared_error))
    return 1 - psnr_db / 40.0


# def ssim(y_true, y_pred):
#     ssim = tf.reduce_mean(tf.image.ssim(y_true, y_pred, 1.0))
#     return 1 - ssim

    
def loss(y_true, y_pred):
    """Return the combined training loss: ``psnr + 5 * mae + 10 * mse``.

    An SSIM term is not part of the sum; no SSIM implementation is active in
    this notebook.
    """
    psnr_term = psnr(y_true, y_pred)
    mse_term = 10.0 * mse(y_true, y_pred)
    mae_term = 5.0 * mae(y_true, y_pred)

    # Summed in the order psnr, mae, mse
    return psnr_term + mae_term + mse_term

In [240]:
class FlowEstimation(nn.Module):
    """Stack of ``Conv2d`` + activation layers followed by a 1x1 convolution to 2 channels.

    Args:
        flow_input_chanels: Number of channels of the tensor passed to ``forward``.
        flow_info: Dict describing the convolution stack with the keys
            ``filter_counts``, ``filter_sizes``, ``filter_strides``,
            ``filter_padding`` and ``activations``; entry ``i`` of every list
            configures layer ``i``. When ``None`` a default six-layer stack with
            freshly created ``nn.PReLU`` activations is used.
    """

    def __init__(self, flow_input_chanels, flow_info=None, **kwargs):
        super(FlowEstimation, self).__init__(**kwargs)

        if flow_info is None:
            # Build the default per instance. Modules stored in a default argument
            # are created once at definition time, so every estimator would share
            # (and jointly train) the same PReLU parameters.
            flow_info = {
                "filter_counts": [32, 64, 64, 16, 12, 12],
                "filter_sizes": [(3, 3), (3, 3), (3, 3), (3, 3), (1, 1), (1, 1)],
                "filter_strides": [1, 1, 1, 1, 1, 1],
                "filter_padding": [1, 1, 1, 1, 0, 0],
                "activations": [nn.PReLU() for _ in range(6)],
            }

        modules = []
        last_output_size = flow_input_chanels
        layer_specs = zip(
            flow_info['filter_counts'],
            flow_info['filter_sizes'],
            flow_info['filter_strides'],
            flow_info['filter_padding'],
            flow_info['activations'],
        )
        for fcount, fsize, fstride, fpad, fact in layer_specs:
            modules.append(nn.Conv2d(last_output_size, fcount, fsize, fstride, fpad))
            modules.append(fact)
            last_output_size = fcount

        # Final 1x1 convolution producing the two output channels
        modules.append(nn.Conv2d(last_output_size, 2, 1))
        self.flow = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the convolution stack to ``x`` and return the 2-channel result."""
        return self.flow(x)

In [233]:
class BidirectionalFeatureWarp(nn.Module):
    """Estimate the flow between two feature maps in both directions and warp them.

    Args:
        flow_prediction: Module mapping the channel-wise concatenation of two
            feature maps to a 2-channel flow. The same module is used for both
            directions.
        interpolation: Mode of the ``nn.Upsample`` layers that double the
            spatial size of the returned flows.
    """

    def __init__(self, flow_prediction, interpolation='bilinear', **kwargs):
        super(BidirectionalFeatureWarp, self).__init__(**kwargs)

        self.flow_prediction = flow_prediction
        self.flow_upsample_1_2 = nn.Upsample(scale_factor=(2, 2), mode=interpolation)
        self.flow_upsample_2_1 = nn.Upsample(scale_factor=(2, 2), mode=interpolation)

    def forward(self, input_1, input_2, flow_1_2, flow_2_1):
        """Refine the flows and warp both inputs with them.

        Args:
            input_1: First feature map.
            input_2: Second feature map.
            flow_1_2: Previous flow applied to ``input_1``, or a non-tensor
                (e.g. ``None``) when there is none.
            flow_2_1: Previous flow applied to ``input_2``, or a non-tensor.

        Returns:
            Tuple ``(input_1_warped, input_2_warped, flow_1_2_up, flow_2_1_up)``:
            both inputs warped with the refined flows, and the refined flows
            upsampled by a factor of 2.
        """
        # The previous flows are only used when both of them are tensors
        has_previous_flow = torch.is_tensor(flow_1_2) and torch.is_tensor(flow_2_1)

        if has_previous_flow:
            input_1_warped_1 = self.warp(input_1, flow_1_2)
            input_2_warped_1 = self.warp(input_2, flow_2_1)
        else:
            input_1_warped_1 = input_1
            input_2_warped_1 = input_2

        # Predict the flow change from the other input and the pre-warped input
        flow_change_1_2 = self.flow_prediction(torch.cat([input_2, input_1_warped_1], dim=1))
        flow_change_2_1 = self.flow_prediction(torch.cat([input_1, input_2_warped_1], dim=1))

        if has_previous_flow:
            flow_1_2_changed = flow_1_2 + flow_change_1_2
            flow_2_1_changed = flow_2_1 + flow_change_2_1
        else:
            flow_1_2_changed = flow_change_1_2
            flow_2_1_changed = flow_change_2_1

        input_1_warped_2 = self.warp(input_1, flow_1_2_changed)
        input_2_warped_2 = self.warp(input_2, flow_2_1_changed)
        flow_1_2_changed_upsampled = self.flow_upsample_1_2(flow_1_2_changed)
        flow_2_1_changed_upsampled = self.flow_upsample_2_1(flow_2_1_changed)

        return input_1_warped_2, input_2_warped_2, flow_1_2_changed_upsampled, flow_2_1_changed_upsampled

    @staticmethod
    def warp(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
        """Sample ``image`` at the pixel positions shifted by ``flow``.

        Args:
            image: Tensor of shape (B, C, H, W).
            flow: Tensor of shape (B, 2, H, W); channel 0 is the horizontal and
                channel 1 the vertical offset in pixels.

        Returns:
            Tensor of shape (B, C, H, W); output position (x, y) holds ``image``
            sampled at ``(x + flow_x, y + flow_y)``.
        """
        B, _, H, W = image.size()

        # Pixel coordinate grids, created directly on the device of the flow
        xx = torch.arange(W, device=flow.device, dtype=flow.dtype).view(1, 1, 1, W).expand(B, 1, H, W)
        yy = torch.arange(H, device=flow.device, dtype=flow.dtype).view(1, 1, H, 1).expand(B, 1, H, W)

        # Map pixel 0 to -1 and pixel (size - 1) to +1
        x_normalized = 2.0 * (xx + flow[:, 0:1]) / max(W - 1, 1) - 1.0
        y_normalized = 2.0 * (yy + flow[:, 1:2]) / max(H - 1, 1) - 1.0
        vgrid = torch.cat((x_normalized, y_normalized), dim=1).permute(0, 2, 3, 1)

        # The normalisation above places -1 / +1 on the centres of the corner
        # pixels, which is the align_corners=True convention. With the default
        # (False) a zero flow would not reproduce the input.
        return F.grid_sample(image, vgrid, align_corners=True)

In [None]:
class FBNet(nn.Module):
    """Frame interpolation network built from a four-level feature pyramid.

    Both input frames go through the same (shared) encoder, which produces one
    feature map per pyramid row (full, 1/2, 1/4 and 1/8 resolution). Starting at
    the coarsest row, the features of the two frames are warped towards each
    other with a flow that is refined row by row, and the warped features are
    fused from coarse to fine into a 3-channel output clipped to [0, 1].

    Args:
        input_shape: ``(channels, height, width)`` of a single input frame.
            Height and width should be divisible by 8 so that the resized,
            pooled and upsampled feature maps of each row have the same size.
        encoder_filters: Output channel counts of the encoder convolutions, one
            list per encoder column. ``None`` selects
            ``[[64, 48, 32, 32], [48, 32, 32], [32, 32]]``.
        decoder_filters: Output channel counts of the two convolutions of the
            full-resolution decoder row. ``None`` selects ``[32, 24]``.
        flow_info: Three ``FlowEstimation`` configurations, for rows 1, 2 and 3.
            The row-3 estimator is also used for row 4. ``None`` selects the
            default configuration with freshly created activations.
        interpolation: Mode of all ``nn.Upsample`` layers.

    Raises:
        ValueError: If rows 3 and 4 have different channel counts; they share
            one flow estimator, so their channel counts must match.
    """

    def __init__(self,
                 input_shape,
                 encoder_filters=None,
                 decoder_filters=None,
                 flow_info=None,
                 interpolation="bilinear", **kwargs):
        super(FBNet, self).__init__(**kwargs)

        # Defaults are built per instance: default arguments holding lists and
        # nn.Module objects would be shared by every FBNet created with them.
        if encoder_filters is None:
            encoder_filters = [
                [64, 48, 32, 32],  # encoder_filters_col_1
                [48, 32, 32],  # encoder_filters_col_2
                [32, 32]  # encoder_filters_col_3
            ]
        if decoder_filters is None:
            decoder_filters = [32, 24]
        if flow_info is None:
            flow_info = self._default_flow_info()

        self.c = input_shape[0]
        self.h = input_shape[1]
        self.w = input_shape[2]

        # Channel count of the concatenated encoder features of every row
        row_1_channels = encoder_filters[0][0]
        row_2_channels = encoder_filters[0][1] + encoder_filters[1][0]
        row_3_channels = encoder_filters[0][2] + encoder_filters[1][1] + encoder_filters[2][0]
        row_4_channels = encoder_filters[0][3] + encoder_filters[1][2] + encoder_filters[2][1]
        if row_4_channels != row_3_channels:
            raise ValueError(
                f"Rows 3 and 4 share one flow estimator and need equal channel counts, "
                f"got {row_3_channels} and {row_4_channels}"
            )

        # ------------- Shared Conv2d, AvgPool2d & Resize encoding layers
        # The resize targets come from input_shape, not from notebook globals
        self.resize_1_2 = torchvision.transforms.Resize(size=(self.h // 2, self.w // 2))
        self.resize_1_4 = torchvision.transforms.Resize(size=(self.h // 4, self.w // 4))
        self.resize_1_8 = torchvision.transforms.Resize(size=(self.h // 8, self.w // 8))

        self.cnn_r1_c1 = nn.Conv2d(self.c, encoder_filters[0][0], 3, 1, 1)
        self.cnn_r2_c1 = nn.Conv2d(self.c, encoder_filters[0][1], 3, 1, 1)
        self.cnn_r3_c1 = nn.Conv2d(self.c, encoder_filters[0][2], 3, 1, 1)
        self.cnn_r4_c1 = nn.Conv2d(self.c, encoder_filters[0][3], 3, 1, 1)
        self.act_r1_c1 = nn.PReLU()
        self.act_r2_c1 = nn.PReLU()
        self.act_r3_c1 = nn.PReLU()
        self.act_r4_c1 = nn.PReLU()

        self.cnn_r2_c2 = nn.Conv2d(encoder_filters[0][0], encoder_filters[1][0], 3, 1, 1)
        self.cnn_r3_c2 = nn.Conv2d(encoder_filters[0][1], encoder_filters[1][1], 3, 1, 1)
        self.cnn_r4_c2 = nn.Conv2d(encoder_filters[0][2], encoder_filters[1][2], 3, 1, 1)
        self.act_r2_c2 = nn.PReLU()
        self.act_r3_c2 = nn.PReLU()
        self.act_r4_c2 = nn.PReLU()

        self.cnn_r3_c3 = nn.Conv2d(encoder_filters[1][0], encoder_filters[2][0], 3, 1, 1)
        self.cnn_r4_c3 = nn.Conv2d(encoder_filters[1][1], encoder_filters[2][1], 3, 1, 1)
        self.act_r3_c3 = nn.PReLU()
        self.act_r4_c3 = nn.PReLU()

        self.avg_r2_c1 = nn.AvgPool2d(2)
        self.avg_r3_c1 = nn.AvgPool2d(2)
        self.avg_r4_c1 = nn.AvgPool2d(2)

        self.avg_r3_c2 = nn.AvgPool2d(2)
        self.avg_r4_c2 = nn.AvgPool2d(2)

        # ------------- Feature warping layers
        # Each flow estimator receives the features of BOTH frames concatenated
        # along the channel axis, hence twice the channel count of its row.
        self.bidirectional_warp_row_1 = BidirectionalFeatureWarp(
            flow_prediction = FlowEstimation(
                flow_input_chanels = 2 * row_1_channels,
                flow_info = flow_info[0]
            ),
            interpolation = interpolation
        )
        self.bidirectional_warp_row_2 = BidirectionalFeatureWarp(
            flow_prediction = FlowEstimation(
                flow_input_chanels = 2 * row_2_channels,
                flow_info = flow_info[1]
            ),
            interpolation = interpolation
        )
        self.bidirectional_warp_row_3 = BidirectionalFeatureWarp(
            flow_prediction = FlowEstimation(
                flow_input_chanels = 2 * row_3_channels,
                flow_info = flow_info[2]
            ),
            interpolation = interpolation
        )

        # ------------- Decoding Conv2d layers
        self.cnn_r4_1 = nn.Conv2d(row_4_channels, row_3_channels, 3, 1, 1)
        self.act_r4_1 = nn.PReLU()
        self.up_r4 = nn.Upsample(scale_factor=(2, 2), mode=interpolation)

        self.cnn_r3_1 = nn.Conv2d(row_3_channels, row_2_channels, 3, 1, 1)
        self.act_r3_1 = nn.PReLU()
        self.up_r3 = nn.Upsample(scale_factor=(2, 2), mode=interpolation)

        self.cnn_r2_1 = nn.Conv2d(row_2_channels, row_1_channels, 3, 1, 1)
        self.cnn_r2_2 = nn.Conv2d(row_1_channels, row_1_channels, 3, 1, 1)
        self.act_r2_1 = nn.PReLU()
        self.act_r2_2 = nn.PReLU()
        self.up_r2 = nn.Upsample(scale_factor=(2, 2), mode=interpolation)

        self.cnn_r1_1 = nn.Conv2d(row_1_channels, decoder_filters[0], 3, 1, 1)
        self.cnn_r1_2 = nn.Conv2d(decoder_filters[0], decoder_filters[1], 3, 1, 1)
        self.act_r1_1 = nn.PReLU()
        self.act_r1_2 = nn.PReLU()

        self.cnn_out = nn.Conv2d(decoder_filters[1], 3, 1, 1, 0)
        self.act_out = TruncateActivation()

    @staticmethod
    def _default_flow_info():
        """Return the default flow configurations with newly created activations."""
        def prelus(count):
            return [nn.PReLU() for _ in range(count)]

        return [
            {  # flow_1
                "filter_counts": [32, 48, 64, 80, 80, 48],
                "filter_sizes": [7, 5, 5, 3, 1, 1],
                "filter_strides": [1, 1, 1, 1, 1, 1],
                "filter_padding": [3, 2, 2, 1, 0, 0],
                "activations": prelus(6)
            },
            {  # flow_2
                "filter_counts": [24, 32, 64, 64, 32],
                "filter_sizes": [5, 3, 3, 1, 1],
                "filter_strides": [1, 1, 1, 1, 1],
                "filter_padding": [2, 1, 1, 0, 0],
                "activations": prelus(5)
            },
            {  # flow_3
                "filter_counts": [24, 48, 48, 16],
                "filter_sizes": [3, 3, 1, 1],
                "filter_strides": [1, 1, 1, 1],
                "filter_padding": [1, 1, 0, 0],
                "activations": prelus(4)
            },
        ]

    def _extract_features(self, input_1):
        """Run the shared encoder on one frame and return the features of rows 1 to 4."""
        # Build the image pyramid
        input_2 = self.resize_1_2(input_1)
        input_3 = self.resize_1_4(input_2)
        input_4 = self.resize_1_8(input_3)

        # Feature extraction for layer 1
        input_1_cnn_r1_c1 = self.act_r1_c1(self.cnn_r1_c1(input_1))
        input_2_cnn_r2_c1 = self.act_r2_c1(self.cnn_r2_c1(input_2))
        input_3_cnn_r3_c1 = self.act_r3_c1(self.cnn_r3_c1(input_3))
        input_4_cnn_r4_c1 = self.act_r4_c1(self.cnn_r4_c1(input_4))

        # Downsample layer 1
        input_1_cnn_r2_c1 = self.avg_r2_c1(input_1_cnn_r1_c1)
        input_2_cnn_r3_c1 = self.avg_r3_c1(input_2_cnn_r2_c1)
        input_3_cnn_r4_c1 = self.avg_r4_c1(input_3_cnn_r3_c1)

        # Feature extraction for layer 2
        input_1_cnn_r2_c2 = self.act_r2_c2(self.cnn_r2_c2(input_1_cnn_r2_c1))
        input_2_cnn_r3_c2 = self.act_r3_c2(self.cnn_r3_c2(input_2_cnn_r3_c1))
        input_3_cnn_r4_c2 = self.act_r4_c2(self.cnn_r4_c2(input_3_cnn_r4_c1))

        # Downsample layer 2
        input_1_cnn_r3_c2 = self.avg_r3_c2(input_1_cnn_r2_c2)
        input_2_cnn_r4_c2 = self.avg_r4_c2(input_2_cnn_r3_c2)

        # Feature extraction for layer 3
        input_1_cnn_r3_c3 = self.act_r3_c3(self.cnn_r3_c3(input_1_cnn_r3_c2))
        input_2_cnn_r4_c3 = self.act_r4_c3(self.cnn_r4_c3(input_2_cnn_r4_c2))

        # Concatenate the features that ended up at the same resolution
        concat_row_2 = torch.cat([input_2_cnn_r2_c1, input_1_cnn_r2_c2], dim=1)
        concat_row_3 = torch.cat([input_3_cnn_r3_c1, input_2_cnn_r3_c2, input_1_cnn_r3_c3], dim=1)
        concat_row_4 = torch.cat([input_4_cnn_r4_c1, input_3_cnn_r4_c2, input_2_cnn_r4_c3], dim=1)

        return input_1_cnn_r1_c1, concat_row_2, concat_row_3, concat_row_4

    def forward(self, input_1_left, input_1_right):
        """Predict a frame from the two input frames.

        Args:
            input_1_left: First frame, shape (B, channels, height, width).
            input_1_right: Second frame, same shape.

        Returns:
            Tensor of shape (B, 3, height, width) with values in [0, 1].
        """
        # ------------- Feature extraction (shared weights for both frames)
        left_row_1, left_row_2, left_row_3, left_row_4 = self._extract_features(input_1_left)
        right_row_1, right_row_2, right_row_3, right_row_4 = self._extract_features(input_1_right)

        # ------------- Warping features at each level
        # Calculate the flow for each level using the input of current level and the upsampled flow from the level + 1
        bfe_4_i1, bfe_4_i2, bfe_4_f_1_2, bfe_4_f_2_1 = self.bidirectional_warp_row_3(left_row_4, right_row_4, None, None)
        bfe_3_i1, bfe_3_i2, bfe_3_f_1_2, bfe_3_f_2_1 = self.bidirectional_warp_row_3(left_row_3, right_row_3, bfe_4_f_1_2, bfe_4_f_2_1)
        bfe_2_i1, bfe_2_i2, bfe_2_f_1_2, bfe_2_f_2_1 = self.bidirectional_warp_row_2(left_row_2, right_row_2, bfe_3_f_1_2, bfe_3_f_2_1)
        bfe_1_i1, bfe_1_i2, _, _ = self.bidirectional_warp_row_1(left_row_1, right_row_1, bfe_2_f_1_2, bfe_2_f_2_1)

        # ------------- Warped features fusion
        # Merge row 4
        add_row_4 = bfe_4_i1 + bfe_4_i2
        cnn_row_4_1 = self.act_r4_1(self.cnn_r4_1(add_row_4))
        upsample_row_4 = self.up_r4(cnn_row_4_1)

        # Merge row 3
        add_row_3 = bfe_3_i1 + bfe_3_i2 + upsample_row_4
        cnn_row_3_1 = self.act_r3_1(self.cnn_r3_1(add_row_3))
        upsample_row_3 = self.up_r3(cnn_row_3_1)

        # Merge row 2
        add_row_2 = bfe_2_i1 + bfe_2_i2 + upsample_row_3
        cnn_row_2_1 = self.act_r2_1(self.cnn_r2_1(add_row_2))
        cnn_row_2_2 = self.act_r2_2(self.cnn_r2_2(cnn_row_2_1))
        upsample_row_2 = self.up_r2(cnn_row_2_2)

        # Merge row 1
        add_row_1 = bfe_1_i1 + bfe_1_i2 + upsample_row_2
        cnn_row_1_1 = self.act_r1_1(self.cnn_r1_1(add_row_1))
        cnn_row_1_2 = self.act_r1_2(self.cnn_r1_2(cnn_row_1_1))

        # Create the output layer
        fus_conv2d_outputs = self.act_out(self.cnn_out(cnn_row_1_2))

        return fus_conv2d_outputs

In [None]:
# Build the PyTorch model for frames of shape (3, height, width) and move it to
# the selected device. The loss and the metrics are passed to the training loop
# directly, so there is no separate compile step.
model = FBNet(input_shape=(3, height, width)).to(device)

# Show the layer structure and the number of trainable parameters
print(model)
print(f'Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')

#### Train the model

In [None]:
def fit(model, train_generator, train_size, valid_generator, valid_size, optimizer, loss, metrics, epochs, batch_size, save_freq=50, log_freq=10, bad_input_limit=5, mode="all"):
    """Train ``model`` with PyTorch and validate it after every epoch.

    Args:
        model: ``nn.Module`` to train. When a batch input is a list or tuple it
            is unpacked into positional arguments of the model.
        train_generator: Iterable yielding ``(x, y)`` training batches.
        train_size: Number of training samples; ``train_size // batch_size``
            steps are run per epoch.
        valid_generator: Iterable yielding ``(x, y)`` validation batches.
        valid_size: Number of validation samples.
        optimizer: PyTorch optimizer holding the parameters of ``model``.
        loss: Callable ``loss(y_true, y_pred)`` returning a scalar tensor.
        metrics: List of callables with the same signature as ``loss``.
        epochs: Number of epochs.
        batch_size: Number of samples per batch.
        save_freq: A checkpoint is considered every ``save_freq`` steps.
        log_freq: Running averages are printed every ``log_freq`` steps.
        bad_input_limit: Training stops once this many batches produced a
            non-finite loss.
        mode: ``"all"`` saves at every checkpoint step, ``"best"`` only when the
            running average loss improved on the last saved one.

    Returns:
        Tuple ``(history, bad_input)``: ``history`` maps the names of the loss
        and of every metric (plain and ``val_``-prefixed) to one float per
        finished epoch; ``bad_input`` lists the ``(x, y)`` batches whose loss was
        not finite.

    Checkpoints are the model ``state_dict`` written with ``torch.save``. The
    target directory and file prefix are read from the global names
    ``model_base_path`` and ``model_name``, which are not defined in this
    function.
    """
    import math
    import time

    # Batches are moved to the device the model lives on
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")

    def to_device(item):
        if torch.is_tensor(item):
            return item.to(device)
        if isinstance(item, (list, tuple)):
            return [to_device(part) for part in item]
        return item

    def predict(x):
        return model(*x) if isinstance(x, (list, tuple)) else model(x)

    def train_step(x, y):
        optimizer.zero_grad()
        y_pred = predict(x)
        loss_tensor = loss(y, y_pred)
        with torch.no_grad():
            metrics_values = [float(metric(y, y_pred)) for metric in metrics]

        # Only update the weights when the loss is finite
        loss_value = float(loss_tensor)
        if math.isfinite(loss_value):
            loss_tensor.backward()
            optimizer.step()

        return loss_value, metrics_values

    def valid_step(x, y):
        with torch.no_grad():
            y_pred = predict(x)
            return float(loss(y, y_pred)), [float(metric(y, y_pred)) for metric in metrics]

    def get_loss_metrics_str(loss_value, metrics_values, sep=' '):
        result = 'loss=' + '{:.5f}'.format(loss_value)
        for metric_value, metric in zip(metrics_values, metrics):
            result += f'{sep}{metric.__name__}='+'{:.5f}'.format(metric_value)
        return result

    def averaged(values, count):
        return [value / count for value in values]

    # create dict for a history and a list for bad input
    history = {metric.__name__: [] for metric in metrics}
    history.update({"val_" + metric.__name__: [] for metric in metrics})
    history[loss.__name__] = []
    history["val_" + loss.__name__] = []
    bad_input = []
    best_loss = None

    train_steps = train_size // batch_size
    valid_steps = valid_size // batch_size

    try:
        # loop over epochs
        for epoch in range(1, epochs+1):
            print(f"Epoch: {epoch}/{epochs}")

            # process the full training dataset
            model.train()
            total_metrics = [0.0] * len(metrics)
            total_loss = 0.0
            batch_index = 1.0
            for step, record in enumerate(train_generator):
                # extract the data
                x = to_device(record[0])
                y = to_device(record[1])

                # calculate metrics values, the loss and then apply the gradient change if loss is finite
                start = time.time()
                loss_value, metrics_values = train_step(x, y)
                end = time.time()

                # if the loss was not finite save the bad input and go to the next iteration
                if not math.isfinite(loss_value):
                    print(f"Loss NaN detected at epoch {epoch} in step {(step+1)}. Wrong data saved to bad_input list")
                    bad_input.append((record[0], record[1]))
                    if len(bad_input) >= bad_input_limit:
                        raise OverflowError(f"The bad_input limit of {bad_input_limit} was reached")
                    continue

                # save the loss & metrics values
                total_loss += loss_value
                total_metrics = [total + value for total, value in zip(total_metrics, metrics_values)]

                # save the model
                if step % save_freq == 0 and step > 0:
                    loss_avg = total_loss / batch_index
                    if mode == "all" or (mode == "best" and (best_loss is None or best_loss > loss_avg)):
                        print("Saving model with loss " + '{:.5f}'.format(loss_avg))
                        checkpoint_name = f'{model_name}_{get_loss_metrics_str(loss_avg, averaged(total_metrics, batch_index), sep="_")}_e={(epoch)}_s={(step+1)}_t={int(time.time())}.pt'
                        torch.save(model.state_dict(), os.path.join(model_base_path, checkpoint_name))
                        best_loss = loss_avg

                # log the loss
                if step % log_freq == 0:
                    # estimate the remaining time from the duration of the last step
                    time_left = (end - start) * (train_steps - (step+1))
                    time_formatted = time.strftime("%Hh %Mm %Ss", time.gmtime(time_left))
                    ljust_length = len(str(train_steps)) * 2 + 1
                    prefix = 'Step ' + f'{(step+1)}/{train_steps}'.rjust(ljust_length) + f" (eta: {time_formatted}): "
                    print(f'{prefix}{get_loss_metrics_str(total_loss/batch_index, averaged(total_metrics, batch_index))}')

                # break the learning if the generator is over
                if step >= (train_steps - 1):
                    break

                batch_index += 1.0

            # save the loss value
            history[loss.__name__].append(total_loss / batch_index)
            for index, metric in enumerate(metrics):
                history[metric.__name__].append(total_metrics[index] / batch_index)

            # process the full validating dataset
            model.eval()
            total_loss = 0.0
            total_metrics = [0.0] * len(metrics)
            batch_index = 1.0
            for step, record in enumerate(valid_generator):
                x = to_device(record[0])
                y = to_device(record[1])

                loss_value, metrics_values = valid_step(x, y)
                total_loss += loss_value
                total_metrics = [total + value for total, value in zip(total_metrics, metrics_values)]

                if step >= (valid_steps - 1):
                    break

                batch_index += 1.0

            # log the validation score
            print(f'Validation for epoch {epoch}: {get_loss_metrics_str(total_loss/batch_index, averaged(total_metrics, batch_index))}')

            # save the validation score
            history["val_" + loss.__name__].append(total_loss/batch_index)
            for index, metric in enumerate(metrics):
                history["val_" + metric.__name__].append(total_metrics[index]/batch_index)
    except (OverflowError, KeyboardInterrupt) as e:
        print(f"Learning interrupted. Details: '{e}'")

    return history, bad_input

In [None]:
# Train on the PyTorch dataloaders. fit returns the per-epoch history together
# with the list of batches whose loss was not finite.
history, bad_input = fit(
    model=model,
    train_generator=train_dataloader,
    train_size=len(train_dataloader.dataset),
    valid_generator=valid_dataloader,
    valid_size=len(valid_dataloader.dataset),
    optimizer=torch.optim.NAdam(model.parameters(), lr=0.0001),
    loss=loss,
    # ssim is left out: it has no active implementation in this notebook
    metrics=[l1, l2, psnr],
    epochs=epochs,
    batch_size=batch,
    save_freq=100,
    log_freq=100,
    bad_input_limit=50,
    mode="best"
)