# Proof of concept notebook for the Frame Booster project
- Author: Kamil Barszczak
- Contact: kamilbarszczak62@gmail.com
- Project: https://github.com/kbarszczak/Frame_booster

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import pickle
import tqdm
import time
import cv2
import os

import torch.utils.data as data
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torchsummary
import torchvision
import torch

#### Notebook parameters

In [None]:
# Dataset root and the sub-directories / split files inside it
base_path = 'D:/Data/Video_Frame_Interpolation/vimeo90k_pytorch'
data_subdir, vis_subdir = 'data', 'vis'
train_ids, test_ids, valid_ids, vis_ids = 'train.txt', 'test.txt', 'valid.txt', 'vis.txt'

# Frame size in pixels and the training schedule
width = 256
height = 144
epochs, batch = 4, 2

#### Set up the device

In [None]:
# Use the first CUDA GPU when present, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

#### Load datasets

In [None]:
class ByteImageDataset(data.Dataset):
    """Frame-triplet dataset stored as raw float32 byte files.

    ``split_filename`` (relative to ``path``) lists one sample id per line, with no
    header. Each id names a directory ``<path>/<subdir>/<id>`` that holds the files
    ``im1``, ``im2`` and ``im3``. Each file is a raw float32 buffer reshaped to
    ``shape``. It is then transposed with axes (2, 0, 1), so ``shape`` is presumably
    (H, W, C) and the tensors come out channels-first.

    An item is ``([im1, im3], im2)``: the two outer frames are the inputs and the
    middle frame is the target.
    """

    def __init__(self, path, subdir, split_filename, shape):
        self.path, self.subdir, self.shape = path, subdir, shape
        split_file = os.path.join(path, split_filename)
        self.ids = pd.read_csv(split_file, names=["ids"])

    def __len__(self):
        return self.ids.shape[0]

    def __getitem__(self, idx):
        sample_dir = os.path.join(self.path, self.subdir, str(self.ids.iloc[idx, 0]))
        first, last = (
            self._read_bytes_to_tensor(os.path.join(sample_dir, name))
            for name in ('im1', 'im3')
        )
        middle = self._read_bytes_to_tensor(os.path.join(sample_dir, 'im2'))
        return [first, last], middle

    def _read_bytes_to_tensor(self, path):
        """Read a raw float32 file, reshape it to ``self.shape`` and move the last axis first."""
        with open(path, 'rb') as handle:
            raw = handle.read()
        pixels = np.frombuffer(raw, dtype='float32').reshape(self.shape)
        # copy(): frombuffer views are read-only, and torch.from_numpy needs a writable array.
        channels_first = pixels.transpose(2, 0, 1).copy()
        return torch.from_numpy(channels_first)

In [None]:
# One DataLoader per split. Every dataset stores frames as (height, width, 3)
# float32 buffers. Only the training split is shuffled, and incomplete final
# batches are dropped everywhere.
train_dataloader = data.DataLoader(
    ByteImageDataset(base_path, data_subdir, train_ids, (height, width, 3)),
    batch_size=batch,
    shuffle=True,
    drop_last=True,
)

test_dataloader = data.DataLoader(
    ByteImageDataset(base_path, data_subdir, test_ids, (height, width, 3)),
    batch_size=batch,
    drop_last=True,
)

valid_dataloader = data.DataLoader(
    ByteImageDataset(base_path, data_subdir, valid_ids, (height, width, 3)),
    batch_size=batch,
    drop_last=True,
)

# Visualisation samples live in their own sub-directory and keep a fixed order.
vis_dataloader = data.DataLoader(
    ByteImageDataset(base_path, vis_subdir, vis_ids, (height, width, 3)),
    batch_size=batch,
    shuffle=False,
    drop_last=True,
)

In [None]:
# Report how many full batches each split yields.
print('\n'.join([
    f'Training batches: {len(train_dataloader)}',
    f'Testing batches: {len(test_dataloader)}',
    f'Validating batches: {len(valid_dataloader)}',
    f'Visualizing batches: {len(vis_dataloader)}',
]))

#### Create the model

In [None]:
class VGGPerceptualLoss(torch.nn.Module):
    """Perceptual + style loss on frozen, pretrained VGG16 features.

    The first 23 layers of ``torchvision.models.vgg16(weights='DEFAULT').features``
    are split into four consecutive blocks: slices [0:4], [4:9], [9:16] and [16:23].
    Both images are passed through the blocks in sequence. For every block index in
    ``feature_layers``, the L1 distance between the two block outputs is added to the
    loss. For every index in ``style_layers``, the L1 distance between their
    (unnormalised) Gram matrices is added as well.

    Inputs are normalised with the ImageNet mean/std before entering VGG. This
    presumably expects RGB images in [0, 1] (TODO confirm for this dataset).

    Args:
        resize: if True, both images are bilinearly resized to 224x224 before VGG.
    """

    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        # Function-like name: the training helpers key their history by ``__name__``.
        self.__name__ = "perceptual"
        # Build the pretrained network once and slice it. The previous version
        # instantiated (and loaded weights for) the full VGG16 four times.
        features = torchvision.models.vgg16(weights='DEFAULT').features
        blocks = [features[:4], features[4:9], features[9:16], features[16:23]]
        for bl in blocks:
            bl.eval()
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = nn.ModuleList(blocks).to(device)
        self.transform = F.interpolate
        self.resize = resize
        # ImageNet channel statistics, shaped to broadcast over (N, 3, H, W).
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1))

    def forward(self, input, target, feature_layers=(0, 1, 2, 3), style_layers=(0, 1, 2, 3)):
        """Return the scalar perceptual/style distance between ``input`` and ``target``.

        ``feature_layers`` / ``style_layers`` are collections of block indices (0-3).
        Tuples are used as defaults so the defaults cannot be mutated across calls.
        """
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += F.l1_loss(x, y)
            if i in style_layers:
                # Gram matrices over flattened spatial positions: (N, C, C).
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += F.l1_loss(gram_x, gram_y)
        return loss

    
def mae(y_true, y_pred):
    """Mean absolute error over all elements (a 0-dim tensor)."""
    difference = y_true - y_pred
    return difference.abs().mean()


def mse(y_true, y_pred):
    """Mean squared error over all elements (a 0-dim tensor)."""
    difference = y_true - y_pred
    return difference.pow(2).mean()


def psnr(y_true, y_pred, eps=1e-10):
    """Return the PSNR-derived loss term ``1 - PSNR / 40``.

    PSNR is computed in dB for a peak value of 1.0, as ``20 * log10(1 / sqrt(MSE))``.
    The result is NOT the raw PSNR: higher image quality gives a smaller value,
    which is what lets it be added to the training loss.

    Args:
        y_true, y_pred: tensors of the same shape.
        eps: lower bound applied to the MSE. Without it, identical images give
            log10(1/0) = inf, which turns the loss into -inf/NaN. With the default,
            the PSNR is capped at 100 dB.
    """
    mse_value = torch.mean((y_true - y_pred) ** 2)
    mse_value = torch.clamp(mse_value, min=eps)
    psnr_db = 20 * torch.log10(1 / torch.sqrt(mse_value))
    return 1 - psnr_db / 40.0


# Shared VGG16 perceptual/style criterion used inside ``loss``.
perceptual_loss = VGGPerceptualLoss()


def loss(y_true, y_pred):
    """Training objective: a weighted sum of four terms.

    0.5 * perceptual/style loss + 1.0 * psnr term (``1 - PSNR/40``)
    + 5.0 * MAE + 10.0 * MSE.
    """
    weighted_terms = (
        0.5 * perceptual_loss(y_true, y_pred),
        psnr(y_true, y_pred),
        5.0 * mae(y_true, y_pred),
        10.0 * mse(y_true, y_pred),
    )
    return sum(weighted_terms)

In [None]:
class TReLU(nn.Module):
    """Truncated ReLU: clamps every element into ``[lower, upper]`` (default [0, 1])."""

    def __init__(self, lower=0.0, upper=1.0, **kwargs):
        super(TReLU, self).__init__(**kwargs)
        self.lower, self.upper = lower, upper

    def forward(self, x):
        return x.clamp(min=self.lower, max=self.upper)

In [None]:
class Conv2dBlock(nn.Module):
    """``nn.Conv2d`` followed by PReLU, or by TReLU (clamp to [0, 1]) when ``output`` is True."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output=False, **kwargs):
        super(Conv2dBlock, self).__init__(**kwargs)
        self.cnn = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        if output:
            self.act = TReLU()
        else:
            self.act = nn.PReLU()

    def forward(self, x):
        features = self.cnn(x)
        return self.act(features)

In [None]:
class FlowEstimator(nn.Module):
    """Residual CNN that predicts a 2-channel flow field from a pair of feature maps.

    Without a prior (``flow`` not a tensor), the flow is predicted from
    ``cat(source, target)``. With a prior flow tensor, ``source`` is first warped by
    it (``FlowPyramid.warp``), and the prediction is added to the prior as a
    residual refinement.

    Args:
        filters: channels of each of ``source`` and ``target``.
        fcount: width of the hidden convolutions.
        fsize: kernel size of every convolution ("same" padding).
    """

    def __init__(self, filters, fcount=32, fsize=5, **kwargs):
        super(FlowEstimator, self).__init__(**kwargs)
        pad = fsize // 2
        # Attribute names are kept as-is: they are the state_dict keys.
        self.cnn_1 = nn.Conv2d(2 * filters, fcount, fsize, 1, pad)
        self.cnn_2 = nn.Conv2d(fcount, fcount, fsize, 1, pad)
        self.cnn_3 = nn.Conv2d(fcount, fcount, fsize, 1, pad)
        self.cnn_4 = nn.Conv2d(fcount, fcount, fsize, 1, pad)
        self.cnn_5 = nn.Conv2d(fcount, fcount, fsize, 1, pad)
        self.cnn_6 = nn.Conv2d(fcount, 2, fsize, 1, pad)

    def forward(self, source, target, flow):
        has_prior = torch.is_tensor(flow)
        if has_prior:
            source = FlowPyramid.warp(source, flow)

        hidden = self.cnn_1(torch.cat([source, target], dim=1))
        for conv in (self.cnn_2, self.cnn_3, self.cnn_4, self.cnn_5):
            hidden = conv(hidden) + hidden
        delta = self.cnn_6(hidden)

        return flow + delta if has_prior else delta

In [None]:
class FlowPyramid(nn.Module):
    """Coarse-to-fine optical-flow estimator between two feature maps.

    Both inputs are encoded into four levels. Levels 2-4 are produced by stride-2
    convolutions, so each has roughly half the resolution of the previous one, and
    they carry ``in_channels``, ``+filters``, ``+2*filters`` and ``+3*filters``
    channels. The flow is estimated at the coarsest level, then repeatedly upsampled
    and refined down to level 1, which has the input resolution.

    Flow channel 0 is the x (width) displacement and channel 1 the y (height)
    displacement, both in pixels (see ``warp``).

    Args:
        in_channels: channels of ``source`` and ``target``.
        filters: channel growth per pyramid level.
        fsizes: kernel sizes of the four ``FlowEstimator``s, finest first.
    """

    def __init__(self, in_channels, filters, fsizes, **kwargs):
        super(FlowPyramid, self).__init__(**kwargs)

        # ------------- Unet type encoding layers (stride-2 convs halve H and W)
        self.cnn_1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.cnn_2 = nn.Conv2d(in_channels, in_channels, 3, 2, 1)
        self.cnn_3 = nn.Conv2d(in_channels, in_channels + filters, 3, 1, 1)
        self.cnn_4 = nn.Conv2d(in_channels + filters, in_channels + filters, 3, 2, 1)
        self.cnn_5 = nn.Conv2d(in_channels + filters, in_channels + 2*filters, 3, 1, 1)
        self.cnn_6 = nn.Conv2d(in_channels + 2*filters, in_channels + 2*filters, 3, 2, 1)
        self.cnn_7 = nn.Conv2d(in_channels + 2*filters, in_channels + 3*filters, 3, 1, 1)

        # ------------- Flow estimation layers (flow_1 = finest, flow_4 = coarsest)
        self.flow_1 = FlowEstimator(filters=in_channels, fcount=in_channels, fsize=fsizes[0])
        self.flow_2 = FlowEstimator(filters=in_channels + filters, fcount=in_channels + filters, fsize=fsizes[1])
        self.flow_3 = FlowEstimator(filters=in_channels + 2*filters, fcount=in_channels + 2*filters, fsize=fsizes[2])
        self.flow_4 = FlowEstimator(filters=in_channels + 3*filters, fcount=in_channels + 3*filters, fsize=fsizes[3])

        # ------------- Upsample layer (exact 2x case)
        self.upsample = nn.Upsample(scale_factor=(2, 2), mode='bilinear', align_corners=True)

    def _encode(self, x):
        """Return the four pyramid levels of ``x``, finest first."""
        out1 = self.cnn_1(x)
        out2 = self.cnn_3(self.cnn_2(out1))
        out3 = self.cnn_5(self.cnn_4(out2))
        out4 = self.cnn_7(self.cnn_6(out3))
        return out1, out2, out3, out4

    def _process_flow(self, flow, size=None):
        """Upsample a flow field to the next finer level and rescale its vectors.

        When ``size`` is omitted, or is exactly twice the current size, this is the
        fixed 2x upsample with vectors doubled. Otherwise the flow is resized to
        ``size`` and each component is scaled by its own axis ratio. The stride-2
        encoder rounds odd sizes up, so level sizes are not always exact multiples
        of 2, and a fixed 2x upsample would not match the next level.
        """
        h_in, w_in = flow.shape[-2:]
        if size is None or (int(size[0]) == 2 * h_in and int(size[1]) == 2 * w_in):
            return self.upsample(flow) * 2
        h_out, w_out = int(size[0]), int(size[1])
        upsampled = F.interpolate(flow, size=(h_out, w_out), mode='bilinear', align_corners=True)
        scale = flow.new_tensor([w_out / w_in, h_out / h_in]).view(1, 2, 1, 1)
        return upsampled * scale

    def forward(self, source, target):
        """Return the flow (N, 2, H, W) that aligns ``source`` to ``target``."""
        # make unet type encoded features
        s1, s2, s3, s4 = self._encode(source)
        t1, t2, t3, t4 = self._encode(target)

        # coarse-to-fine flow estimation; each flow is upsampled to the next level's size
        f4 = self._process_flow(self.flow_4(s4, t4, None), s3.shape[-2:])
        f3 = self._process_flow(self.flow_3(s3, t3, f4), s2.shape[-2:])
        f2 = self._process_flow(self.flow_2(s2, t2, f3), s1.shape[-2:])
        f1 = self.flow_1(s1, t1, f2)

        return f1

    @staticmethod
    def warp(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
        """Backward-warp ``image`` (N, C, H, W) by the pixel flow ``flow`` (N, 2, H, W).

        Output pixel (y, x) samples ``image`` bilinearly at
        (y + flow[:, 1], x + flow[:, 0]). Samples that fall outside the image read
        zeros. A zero flow reproduces ``image``.
        """
        B, C, H, W = image.size()

        # Pixel-coordinate grid, built directly on the image's device (the old
        # code moved it only for CUDA, which broke other accelerators).
        xx = torch.arange(W, device=image.device, dtype=flow.dtype).view(1, 1, 1, W).expand(B, 1, H, W)
        yy = torch.arange(H, device=image.device, dtype=flow.dtype).view(1, 1, H, 1).expand(B, 1, H, W)
        vgrid = torch.cat((xx, yy), dim=1) + flow

        # Map pixel indices 0..W-1 / 0..H-1 onto [-1, 1]. This mapping corresponds
        # to align_corners=True, so grid_sample must use the same convention. The
        # previous align_corners=False shifted and blurred even a zero-flow warp.
        norm_x = 2.0 * vgrid[:, 0] / max(W - 1, 1) - 1.0
        norm_y = 2.0 * vgrid[:, 1] / max(H - 1, 1) - 1.0
        sample_grid = torch.stack((norm_x, norm_y), dim=-1)  # (B, H, W, 2)

        return F.grid_sample(image, sample_grid, mode='bilinear', padding_mode='zeros', align_corners=True)

In [None]:
class FBNet(nn.Module):
    """Frame-interpolation network: predicts the middle frame from its two neighbours.

    The forward pass has four stages:
    1. Densely connected feature encoding of each frame (``_encode``).
    2. Bidirectional flow estimation with one shared ``FlowPyramid`` (``_flow``).
    3. Warping of each frame's features by its own flow.
    4. A small residual decoder whose last layer is clamped to [0, 1] by ``TReLU``.

    Args:
        input_shape: (batch, channels, height, width), stored as ``b/c/h/w``.
            Only ``c`` is used to size the layers.
        device: accepted for call compatibility but unused. Move the module with
            ``.to(device)``.
        filters: growth rate of the encoder and base width of the decoder.
        fsizes: kernel sizes of the four flow estimators, finest first. A tuple
            default is used so the default cannot be mutated across instances.
    """

    def __init__(self, input_shape, device, filters=16, fsizes=(3, 3, 3, 3), **kwargs):
        super(FBNet, self).__init__(**kwargs)

        self.b = input_shape[0]
        self.c = input_shape[1]
        self.h = input_shape[2]
        self.w = input_shape[3]

        # ------------- Feature encoding layers (each adds `filters` channels)
        self.cnn_enblock_1 = Conv2dBlock(self.c, filters, 3, 1, 1)
        self.cnn_enblock_2 = Conv2dBlock(self.c + filters, filters, 3, 1, 1)
        self.cnn_enblock_3 = Conv2dBlock(self.c + 2*filters, filters, 3, 1, 1)
        self.cnn_enblock_4 = Conv2dBlock(self.c + 3*filters, filters, 3, 1, 1)

        # ------------- Feature decoding layers
        self.cnn_decblock_1 = Conv2dBlock(2*(self.c + 4*filters), 4*filters, 3, 1, 1)
        self.cnn_decblock_2 = Conv2dBlock(4*filters, 4*filters, 3, 1, 1)
        self.cnn_decblock_3 = Conv2dBlock(4*filters, 2*filters, 3, 1, 1)
        self.cnn_decblock_4 = Conv2dBlock(2*filters, 2*filters, 3, 1, 1)
        self.cnn_decblock_5 = Conv2dBlock(2*filters, 3, 3, 1, 1, output=True)

        # ------------- Flow pyramid layer (shared by both flow directions)
        self.flow_pyramid = FlowPyramid(in_channels=self.c + 4*filters, filters=filters, fsizes=fsizes)

    def _encode(self, x):
        """Dense encoding: each block's output is concatenated in front of its input."""
        x = torch.cat([self.cnn_enblock_1(x), x], dim=1)
        x = torch.cat([self.cnn_enblock_2(x), x], dim=1)
        x = torch.cat([self.cnn_enblock_3(x), x], dim=1)
        x = torch.cat([self.cnn_enblock_4(x), x], dim=1)
        return x

    def _flow(self, left, right):
        """Return (left->right, right->left) flows from the shared pyramid."""
        flow_left = self.flow_pyramid(left, right)
        flow_right = self.flow_pyramid(right, left)
        return flow_left, flow_right

    def _decode(self, left, right):
        """Decode the concatenated warped features into a 3-channel image in [0, 1]."""
        x0 = torch.cat([left, right], dim=1)
        x1 = self.cnn_decblock_1(x0)
        x2 = self.cnn_decblock_2(x1) + x1
        x3 = self.cnn_decblock_3(x2)
        x4 = self.cnn_decblock_4(x3) + x3
        x5 = self.cnn_decblock_5(x4)
        return x5

    def forward(self, left, right):
        # encode features
        left = self._encode(left)
        right = self._encode(right)

        # calculate flow & warp features
        flow_left, flow_right = self._flow(left, right)
        left = FlowPyramid.warp(left, flow_left)
        right = FlowPyramid.warp(right, flow_right)

        # decode features into the interpolated frame
        return self._decode(left, right)

In [None]:
fbnet = FBNet(input_shape=(batch, 3, height, width), device=device).to(device)
# torchsummary.summary defaults to device="cuda" and fails on CPU-only machines,
# so pass the type of the device the model actually lives on ("cuda" or "cpu").
torchsummary.summary(fbnet, [(3, height, width), (3, height, width)], device=device.type)

#### Train the model

In [None]:
def plot_triplet(left, right, y, y_pred, figsize=(20, 4)):
    """Plot one row per sample, with columns left | target | prediction | right.

    Every argument is an image batch (N, C, H, W). Each batch is rotated 90 degrees
    and the four batches are stacked into a grid with N images per row. The whole
    grid is then rotated back by 270 degrees, so each sample ends up as a row.
    """
    rotate = torchvision.transforms.functional.rotate
    plt.figure(figsize=figsize)
    stacked = torch.cat(
        [rotate(images, 90, expand=True) for images in (right, y_pred, y, left)],
        dim=0,
    )
    sheet = torchvision.utils.make_grid(stacked, nrow=left.shape[0])
    sheet = rotate(sheet, 270, expand=True)
    plt.imshow(sheet.permute(1, 2, 0).cpu())
    plt.axis('off')
    plt.show()

In [None]:
def fit(model, train, valid, optimizer, loss, metrics, epochs, batch, save_freq=500, log_freq=1, log_perf_freq=2500, mode="best", save_dir='../models/model_v6_1'):
    """Train ``model`` and validate it after every epoch.

    Each record yielded by ``train``/``valid`` is ``([left, right], y)``. The model is
    called as ``model(left, right)`` and scored with ``loss(y, y_pred)`` and every
    ``metric(y, y_pred)``. Tensors are moved to the device of the model's first
    parameter (the CPU for a parameter-less model).

    Args:
        model: network to train (updated in place).
        train, valid: iterables of records that support ``len()`` (e.g. DataLoaders).
        optimizer: optimizer over the model parameters.
        loss: differentiable objective. Its ``__name__`` is the history key.
        metrics: extra callables keyed by ``__name__``, evaluated without gradients.
        epochs: number of passes over ``train``.
        batch: unused; kept for call compatibility.
        save_freq: every ``save_freq`` training steps (step > 0) a checkpoint may be
            written. None disables saving.
        log_freq: progress-line frequency in steps. None disables it.
        log_perf_freq: every ``log_perf_freq`` steps (step > 0) a sample triplet is
            plotted with ``plot_triplet``. None disables it.
        mode: "all" saves at every save point. "best" saves only when the running
            epoch-average training loss beats the best value seen so far.
        save_dir: checkpoint directory, created if missing.

    Returns:
        A dict mapping each loss/metric name, and its ``val_`` counterpart, to the
        per-batch values over all epochs.
    """
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    names = [loss.__name__] + [metric.__name__ for metric in metrics]
    history = {name: [] for name in names} | {f"val_{name}": [] for name in names}
    best_loss = None

    # loop over epochs
    for epoch in range(epochs):
        print(f"Epoch: {epoch+1}/{epochs}")
        loss_metrics = {name: [] for name in names}

        # ---------------- training
        model.train(True)
        for step, record in enumerate(train):
            start = time.time()
            left = record[0][0].to(target_device)
            right = record[0][1].to(target_device)
            y = record[1].to(target_device)

            model.zero_grad()
            y_pred = model(left, right)
            loss_value = loss(y, y_pred)
            loss_value.backward()
            optimizer.step()

            # metrics are for logging only: no autograd graph needed
            y_pred_detached = y_pred.detach()
            with torch.no_grad():
                metrics_values = [metric(y, y_pred_detached) for metric in metrics]

            loss_metrics[loss.__name__].append(loss_value.item())
            for metric, value in zip(metrics, metrics_values):
                loss_metrics[metric.__name__].append(value.item())

            end = time.time()

            # save the model
            if save_freq is not None and step % save_freq == 0 and step > 0:
                loss_avg = np.mean(loss_metrics[loss.__name__])
                is_best = best_loss is None or loss_avg < best_loss
                if mode == "all" or (mode == "best" and is_best):
                    os.makedirs(save_dir, exist_ok=True)
                    filename = os.path.join(save_dir, f'fbnet_l={loss_avg}_e={epoch+1}_t={int(time.time())}.pth')
                    torch.save(model.state_dict(), filename)
                # Track the best value; previously best_loss was never updated, so
                # "best" mode saved a checkpoint at every save point.
                if is_best:
                    best_loss = loss_avg

            # log the model performance
            if log_perf_freq is not None and step % log_perf_freq == 0 and step > 0:
                plot_triplet(left, right, y, y_pred_detached)

            # log the state; ETA extrapolates the duration of the current step
            if log_freq is not None and step % log_freq == 0:
                time_left = (end - start) * (len(train) - (step + 1))
                print('\r[%5d/%5d] (eta: %s)' % ((step+1), len(train), time.strftime('%H:%M:%S', time.gmtime(time_left))), end='')
                for metric, values in loss_metrics.items():
                    print(f' {metric}=%.4f' % (np.mean(values)), end='')

        for metric, values in loss_metrics.items():
            history[metric].extend(values)

        # ---------------- validation (no gradients, so no autograd graphs are kept)
        loss_metrics = {name: [] for name in names}
        model.train(False)
        with torch.no_grad():
            for record in valid:
                left = record[0][0].to(target_device)
                right = record[0][1].to(target_device)
                y = record[1].to(target_device)
                y_pred = model(left, right)

                loss_metrics[loss.__name__].append(loss(y, y_pred).item())
                for metric in metrics:
                    loss_metrics[metric.__name__].append(metric(y, y_pred).item())

        for metric, values in loss_metrics.items():
            print(f' val_{metric}=%.4f' % (np.mean(values)), end='')
            history[f"val_{metric}"].extend(values)

        # restart state printer
        print()

    return history

In [None]:
# Train FBNet with NAdam (lr 1e-4). The psnr term is also logged as a metric.
history = fit(
    fbnet,
    train_dataloader,
    valid_dataloader,
    optim.NAdam(fbnet.parameters(), lr=1e-4),
    loss,
    [psnr],
    epochs,
    batch,
    save_freq=500,
    log_freq=1,
    log_perf_freq=2500,
    mode="best",
)

In [None]:
# Persist the final weights. Create the directory first so torch.save does not
# fail when no intermediate checkpoint has been written yet.
os.makedirs('../models/model_v6_1', exist_ok=True)
torch.save(fbnet.state_dict(), f'../models/model_v6_1/fbnet_e={epochs}_t={int(time.time())}.pth')

#### Evaluate the model

In [None]:
def norm_0_1(data):
    """Min-max scale ``data`` into [0, 1] and return a NumPy array.

    Constant input, which previously produced 0/0 = NaN, is mapped to zeros.
    """
    values = np.asarray(data, dtype=float)
    low, high = np.min(values), np.max(values)
    span = high - low
    if span == 0:
        return np.zeros_like(values)
    return (values - low) / span

def compress_loss(loss, steps):
    """Average ``loss`` over consecutive chunks of ``steps`` entries.

    Returns one mean per complete chunk; a trailing partial chunk is dropped.
    A ``steps`` value below 1 is treated as 1 (no averaging) instead of raising
    ZeroDivisionError.
    """
    chunk = max(int(steps), 1)
    return [np.average(loss[chunk * i:chunk * (i + 1)]) for i in range(len(loss) // chunk)]

def plot_history(history, norm=norm_0_1, figsize=(10,5), steps=None):
    """Plot training vs. validation curves for every metric in ``history``.

    ``history`` maps ``name`` to per-batch training values and ``val_<name>`` to
    per-batch validation values (the layout returned by ``fit``).

    Args:
        history: dict of metric series as described above.
        norm: optional callable applied jointly to a metric's train+val values
            (e.g. ``norm_0_1``). None plots raw values.
        figsize: matplotlib figure size.
        steps: if given, each series is averaged into roughly ``steps`` points with
            ``compress_loss``.

    Validation points are spread over the same x-range as the training points, so
    series of different lengths can be compared. This happens whenever the train
    and valid loaders have different numbers of batches.
    """
    plt.clf()
    plt.figure(figsize=figsize)

    # Only keys with the "val_" prefix are validation series (the previous
    # substring test also dropped metrics such as "interval").
    metrics = [metric for metric in history if not metric.startswith('val_')]
    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w']

    for index, metric in enumerate(metrics):
        value = list(history[metric])
        val_value = list(history['val_' + metric])
        if steps is not None:
            # len // steps is 0 for short series; a chunk of 0 divides by zero.
            value = compress_loss(value, max(len(value) // steps, 1))
            val_value = compress_loss(val_value, max(len(val_value) // steps, 1))

        if norm is not None:
            buffer = norm(value + val_value)
            value, val_value = buffer[:len(value)], buffer[len(value):]

        x = np.arange(1, len(value) + 1)
        val_x = np.linspace(1, max(len(value), 1), num=len(val_value))
        color = colors[index % len(colors)]
        plt.plot(x, value, color, label=f"train {metric}")
        plt.plot(val_x, val_value, color + '--', label=f"valid {metric}")

    plt.title("Comparison of training and validating scores")
    plt.xlabel('Steps')
    plt.ylabel("Values" if norm is None else "Values normalized")
    plt.legend(loc='upper right')
    plt.show()

In [None]:
def evaluate(model, data, loss, metrics):
    """Evaluate ``model`` on ``data`` and return the mean of the loss and each metric.

    Records are ``([left, right], y)``. Inputs are moved to the device of the
    model's first parameter (the CPU for a parameter-less model). The evaluation
    runs under ``torch.no_grad`` so no autograd graphs are kept. Progress is
    printed on a single line.

    Returns:
        A dict mapping the loss/metric ``__name__`` to its mean over all batches.
    """
    loss_metrics = {loss.__name__: []} | {metric.__name__: [] for metric in metrics}
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train(False)
    with torch.no_grad():
        for step, record in enumerate(data):
            left = record[0][0].to(target_device)
            right = record[0][1].to(target_device)
            y = record[1].to(target_device)

            y_pred = model(left, right)

            loss_metrics[loss.__name__].append(loss(y, y_pred).item())
            for metric in metrics:
                loss_metrics[metric.__name__].append(metric(y, y_pred).item())

            print('\rProgress [%4d/%4d]' % ((step+1), len(data)), end='')

    print()

    return {k: np.mean(v) for k, v in loss_metrics.items()}

In [None]:
plot_history(history, steps=200, norm=None)

In [None]:
plot_history(history=history, steps=200)

In [None]:
# Score the trained model on the test split.
results = evaluate(fbnet, test_dataloader, loss, [mae, mse, psnr])
for name, value in results.items():
    print(f'{name}: {value:.6f}')

#### Visualize net output

In [None]:
def display_grid(data, nrow, figsize):
    """Show an image batch (N, C, H, W) as a grid with ``nrow`` images per row.

    Args:
        data: image tensor, or None (nothing is shown).
        nrow: images per grid row (passed to ``torchvision.utils.make_grid``).
        figsize: matplotlib figure size, or 'auto' for 20 inches wide and 5 inches
            per grid row.
    """
    # Check for None first. The previous version evaluated data.shape for the
    # 'auto' size before this check and crashed when data was None.
    if data is None:
        return

    if figsize == 'auto':
        grid_rows = max(1, -(-data.shape[0] // nrow))  # ceil: a partial row still needs space
        figsize = (20, 5 * grid_rows)

    plt.figure(figsize=figsize)
    grid = torchvision.utils.make_grid(data, nrow=nrow)
    plt.imshow(torch.permute(grid, (1, 2, 0)).cpu())
    plt.axis('off')
    plt.show()

In [None]:
def visualise_output(model, batches, figsize='auto'):
    """Show, for every sample, the left frame, target, prediction and right frame.

    Args:
        model: network called as ``model(left, right)``.
        batches: list of records ``([left, right], y)``.
        figsize: forwarded to ``display_grid``.

    Relies on the global ``device``. Inference runs under ``torch.no_grad``.
    """
    frames = []
    model.train(False)
    with torch.no_grad():
        for record in batches:
            left, right, y = record[0][0].to(device), record[0][1].to(device), record[1].to(device)
            y_pred = model(left, right)
            # Use each batch's own size, so a smaller last batch is handled too.
            for index in range(left.shape[0]):
                frames.extend([
                    left[index:index + 1],
                    y[index:index + 1],
                    y_pred[index:index + 1],
                    right[index:index + 1],
                ])

    if not frames:
        return

    # A single concatenation instead of re-concatenating the growing tensor per sample.
    display_grid(torch.cat(frames, dim=0), 4, figsize)

In [None]:
# Show the first (up to) six visualisation batches.
vis_iterator = iter(vis_dataloader)
visualise_output(
    fbnet,
    batches=[next(vis_iterator) for _ in range(min(6, len(vis_dataloader)))],
)

#### Visualize conv2d filters

In [None]:
def vis_conv2d_weight(kernels, all_kernels=True, nrow=32, padding=1, ch=0):
    """Plot the kernels of a 4-D conv weight tensor as a normalised image grid.

    Args:
        kernels: presumably a Conv2d weight of shape (out_ch, in_ch, kH, kW).
        all_kernels: if True, every (out, in) 2-D kernel becomes its own
            single-channel tile.
        nrow: tiles per grid row.
        padding: pixels between tiles.
        ch: input channel shown when ``all_kernels`` is False and the weight does
            not have exactly 3 input channels (3-channel kernels are shown as RGB).
    """
    n_out, n_in, k_first, k_second = kernels.shape

    tiles = kernels
    if all_kernels:
        tiles = kernels.view(n_out * n_in, -1, k_first, k_second)
    elif n_in != 3:
        tiles = kernels[:, ch, :, :].unsqueeze(dim=1)

    # Figure height in inches: one per grid row, capped at 64.
    n_rows = np.min([tiles.shape[0] // nrow + 1, 64])
    sheet = torchvision.utils.make_grid(tiles, nrow=nrow, normalize=True, padding=padding)

    plt.figure(figsize=(nrow, n_rows))
    plt.imshow(sheet.cpu().numpy().transpose((1, 2, 0)))
    plt.axis('off')
    plt.show()
    
    
def deprocess_image(img):
    """Convert an image array into a displayable uint8 image.

    The values are standardised (zero mean, unit std), squeezed to std 0.15 around
    0.5, clipped to [0, 1] and scaled to [0, 255].

    The caller's array is left untouched. The previous in-place arithmetic mutated
    it, and failed outright on integer arrays.
    """
    arr = np.asarray(img)
    work_dtype = arr.dtype if np.issubdtype(arr.dtype, np.floating) else np.float32
    out = arr.astype(work_dtype, copy=True)

    out -= out.mean()
    out /= out.std() + 1e-5
    out *= 0.15

    out += 0.5
    out = np.clip(out, 0, 1)

    out *= 255
    return np.clip(out, 0, 255).astype("uint8")
    
    
def append_image(image, append_image, row, col, margin):
    """Paste the tile ``append_image`` (h, w, 3) into ``image`` at grid cell (row, col).

    Cells are laid out with ``margin`` pixels between them. The cell size now comes
    from the tile's own shape rather than the notebook globals ``height``/``width``,
    so tiles of any size work. ``image`` is modified in place and also returned.
    """
    tile_h, tile_w = append_image.shape[:2]
    top = row * (tile_h + margin)
    left = col * (tile_w + margin)
    image[top:top + tile_h, left:left + tile_w, :] = append_image
    return image

    
def rows_cols(value):
    """Factor ``value`` into the (rows, cols) pair that is closest to square.

    Among equally square pairs, the one with fewer rows wins. A prime gives
    (1, value).
    """
    assert value >= 1
    candidates = [(1, value)]
    candidates += [
        (divisor, int(value / divisor))
        for divisor in range(2, value // 2 + 1)
        if value % divisor == 0
    ]
    # min() keeps the first of equal minima, like the original strict '<' update.
    return min(candidates, key=lambda pair: abs(pair[0] - pair[1]))
    

def visualize_cnn_layers(model, margin=3, steps=10, lr=0.1, include_nested=True):
    """Render, for every convolution layer, one input image per filter that excites it.

    Layers are discovered breadth-first from ``model.named_children()``, recursing
    into nested children when ``include_nested`` is True. Every ``nn.Conv2d`` or
    ``nn.ConvTranspose2d`` found is processed.

    For each output channel, a batch of uniform noise in [0.4, 0.6) is optimised
    with NAdam for ``steps`` iterations to maximise that channel's mean activation.
    The first image of the batch is then placed into a near-square grid, one tile
    per filter.

    Relies on the notebook globals ``batch``, ``height``, ``width`` and ``device``, on
    ``model.c/h/w`` (as set by ``FBNet``) and on the helpers ``rows_cols``,
    ``deprocess_image`` and ``append_image``. The model is called as
    ``model(x, x)``. If a layer runs several times in one forward pass, the hook
    keeps the output of its last call.
    """
    assert margin >= 0, "Margin cannot be negative"
    assert steps > 0, "Steps has to be positive"

    activations = {}

    def hook_fun(module, inputs, output):
        activations['activation'] = output

    model.train(False)

    # Freeze the weights while the inputs are optimised. Otherwise every backward
    # pass below accumulates gradients into the model's .grad buffers.
    saved_flags = [(p, p.requires_grad) for p in model.parameters()]
    for p, _ in saved_flags:
        p.requires_grad_(False)

    try:
        queue = list(model.named_children())
        while queue:
            name, layer = queue.pop(0)
            if include_nested:
                queue.extend([(f'{name}_{n}', l) for n, l in layer.named_children()])

            if not isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                continue

            print(f"Layer name: {name}")
            rows, cols = rows_cols(layer.out_channels)
            result = np.zeros((rows * height + (rows-1) * margin, cols * width + (cols-1) * margin, 3), dtype='uint8')

            for filter_index in tqdm.tqdm(range(rows * cols)):
                i, j = divmod(filter_index, cols)

                noise = (np.random.rand(batch, model.c, model.h, model.w) * 0.2 + 0.4).astype('float32')
                tensor = torch.from_numpy(noise).to(device).requires_grad_(True)
                optimizer = optim.NAdam([tensor], lr=lr)

                hook = layer.register_forward_hook(hook_fun)
                try:
                    for _ in range(steps):
                        optimizer.zero_grad()
                        model(tensor, tensor)
                        activation = activations['activation'][:, filter_index, :, :]
                        # The optimizer minimises, so negate the response to
                        # *maximise* it (gradient ascent). The previous version
                        # minimised the filter's activation instead.
                        objective = -torch.mean(activation)
                        objective.backward()
                        optimizer.step()
                finally:
                    hook.remove()

                image = tensor.detach()[0].permute((1, 2, 0))
                result = append_image(result, deprocess_image(image.cpu().numpy()), i, j, margin)

            # matplotlib's figsize is (width, height) in inches.
            plt.figure(figsize=(model.w // 2, model.h // 2))
            plt.imshow(result)
            plt.axis('off')
            plt.show()
    finally:
        for p, flag in saved_flags:
            p.requires_grad_(flag)

In [None]:
visualize_cnn_layers(model=fbnet, include_nested=True, lr=0.1, steps=10, margin=3)

In [None]:
# FBNet has no `cnn_r1_1` attribute (presumably left over from an older architecture),
# so that call raised AttributeError. Show the first encoder convolution instead:
# a 3 -> 16 channel, 3x3 Conv2d wrapped in Conv2dBlock.
vis_conv2d_weight(fbnet.cnn_enblock_1.cnn.weight.detach(), nrow=16)

#### Visualize inner optical flows

In [None]:
def visualise_flow(model, batches, figsize='auto'):
    """Show, per sample, the left frame, the captured flow field(s) and the right frame.

    Forward hooks are placed on every direct child of ``model`` whose name contains
    "flow". Each captured flow is rendered with ``torchvision.utils.flow_to_image`` and
    resized to the global (height, width). If a hooked child runs several times per
    forward pass, only the output of its last call is kept. For example, FBNet calls
    ``flow_pyramid`` twice in ``_flow``, so the right->left flow is what gets shown.

    Relies on the globals ``device``, ``height``, ``width`` and ``display_grid``.
    """
    flows, names, hooks = {}, [], []

    def get_activation(name):
        def hook(module, inputs, output):
            # A plain tensor output is a (N, 2, H, W) flow batch. The previous
            # output[-1] picked the last *sample* of such a tensor, which broke the
            # per-sample indexing below. Sequence outputs keep their last entry.
            if isinstance(output, (tuple, list)):
                output = output[-1]
            flows[name] = output.detach()
        return hook

    # Hook the model that was passed in (the previous version hooked the global fbnet).
    for child_name, child in model.named_children():
        if "flow" in child_name:
            hooks.append(child.register_forward_hook(get_activation(child_name)))
            names.append(child_name)

    resize = torchvision.transforms.Resize((height, width), antialias=True)
    frames = []
    model.train(False)
    try:
        with torch.no_grad():
            for record in batches:
                left, right = record[0][0].to(device), record[0][1].to(device)
                model(left, right)

                for index in range(left.shape[0]):
                    frames.append(left[index:index + 1])
                    for name in names:
                        flow_img = torchvision.utils.flow_to_image(flows[name][index])
                        frames.append((resize(flow_img) / 255.0).unsqueeze(0))
                    frames.append(right[index:index + 1])
    finally:
        # Always detach the hooks, even if the forward pass fails.
        for hook in hooks:
            hook.remove()

    if frames:
        display_grid(torch.cat(frames, dim=0), 2 + len(names), figsize)

In [None]:
# Flow fields for the first (up to) six visualisation batches.
vis_iterator = iter(vis_dataloader)
visualise_flow(
    fbnet,
    batches=[next(vis_iterator) for _ in range(min(6, len(vis_dataloader)))],
)

#### Visualize attention masks

In [None]:
def visualise_attention(model, batches, figsize='auto', gate_type=None):
    """Show, per sample, the left frame, each attention-gate map and the right frame.

    Gates are the direct children of ``model`` that are instances of ``gate_type``.
    By default this is a class named ``AttentionGate`` looked up in this module's
    globals. No such class is defined in the visible notebook, so in that case a
    message is printed and the function returns instead of raising NameError.

    Each gate is assumed to expose ``ocnn``, ``out_act`` and ``upsample``, and the map
    shown is ``upsample(out_act(ocnn_output))`` (TODO confirm against the gate class).
    Single-channel maps are repeated to 3 channels.

    Relies on the globals ``device``, ``height``, ``width`` and ``display_grid``.
    """
    if gate_type is None:
        gate_type = globals().get('AttentionGate')
    if gate_type is None:
        print('No AttentionGate class is defined; nothing to visualise.')
        return

    attentions, names, hooks = {}, [], []

    def get_activation(name, act, upsample):
        def hook(module, inputs, output):
            attentions[name] = upsample(act(output.detach()))
        return hook

    for child_name, child in model.named_children():
        if isinstance(child, gate_type):
            hooks.append(child.ocnn.register_forward_hook(get_activation(child_name, child.out_act, child.upsample)))
            names.append(child_name)

    resize = torchvision.transforms.Resize((height, width), antialias=True)
    gray2rgb = torchvision.transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x)
    frames = []
    model.train(False)
    try:
        with torch.no_grad():
            for record in batches:
                left, right = record[0][0].to(device), record[0][1].to(device)
                model(left, right)

                for index in range(left.shape[0]):
                    frames.append(left[index:index + 1])
                    for name in names:
                        attention = gray2rgb(resize(attentions[name][index]))
                        frames.append(attention.unsqueeze(0))
                    frames.append(right[index:index + 1])
    finally:
        for hook in hooks:
            hook.remove()

    if frames:
        display_grid(torch.cat(frames, dim=0), 2 + len(names), figsize)

In [None]:
# Attention maps for the first (up to) six visualisation batches.
vis_iterator = iter(vis_dataloader)
visualise_attention(
    fbnet,
    batches=[next(vis_iterator) for _ in range(min(6, len(vis_dataloader)))],
)

#### Load the model

In [None]:
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
# NOTE(review): this path points at a model_v6_3 run, while training above saves
# into model_v6_1. Confirm it matches the current FBNet architecture.
fbnet.load_state_dict(torch.load('../tmp/model_v6_3/1686927783/models/fbnet_l=1.4013603990221608_e=1_s=12001_t=1686933041.pt', map_location=device))