<h1> RBPN-VSR </h1>

**Index:**

1) Imports
2) Hyperparameters
3) Helper Functions
4) Base Networks
5) Derivative Networks
6) DBPNS and RBPN
7) Image Processing Helper Functions
8) Dataset and Dataloader
9) (Optional) Visualisation of Dataset
10) Model and Training Loop
11) Upscale function
12) Statistics


TODO:
Fix the problem with the bottlenecked ResnetBlock variant (`useBottlenecked = True`). Ignore for now and keep it disabled.

Imports:

In [None]:
import torch
import torch.nn as nn
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch.utils.data as data
from PIL import Image, ImageOps
import torch.optim as optim

import os
from os.path import join
from math import log2
import random
from torchvision.models import vgg19, VGG19_Weights

# Reaching this line means every import in this cell resolved.
print("Succesfully loaded imports")

Hyperparameters:

In [None]:
# --- Training schedule ---
epochs = 2
lr = 1e-4
# --- Network size (forwarded to RBPN when the model is built) ---
base_filter = 256
feat = 64
n_resblock = 5
nFrames = 7
upscale_factor = 2
# --- Data loading ---
patch_size = 64 #0 to use original frame size
batch_size = 5
test_batch_size = 1
# --- Bookkeeping ---
snapshot = 1  # save a checkpoint every `snapshot` epochs
record = 50   # every `record` training iterations: print/record the loss and show examples

# --- Feature switches ---
residual = True  # add the bicubic-upscaled input to the network output during training
future_frame = False #Upscale function assumes the training has been done using future_frame=True. Will modify it later...
data_augmentation = True
# NOTE: the next three flags are captured as default argument values when the
# ResnetBlock / UpBlock / DownBlock classes are defined, so changing them
# afterwards only takes effect once those class cells are re-run.
useBottlenecked = False
useInitialLayer = False
usePixelShuffle = True  # Upsampler (PixelShuffle) when True, DeconvBlock when False
pretrained = False  # load `pretrained_params` into the model before training
testing = False     # when True the training loop cell is skipped
doScaleFlow = False  # rescale the optical flow to [0, 1] inside get_flow

# --- Paths ---
data_dir_T = "/kaggle/input/vsr-7-frame-videos-dataset/vimeo_test_clean/S_Train" #1 folder above where the all the subfolders containg 7-image sequences are kept
file_list_T = "../S_Train_List.txt" #the txt file containing the list of subfolders
data_dir_V = "/kaggle/input/vsr-7-frame-videos-dataset/vimeo_test_clean/S_Test"
file_list_V = "../S_Test_List.txt"
save_folder = "./"
logfile_name = "VSR_P_log.txt"
param_filename = "VSR_P_"
pretrained_params = "./VSR_P_1.pth"

# Loss histories filled in by the training cells.
epoch_loss_hist = []
iter_loss_hist = []

In [None]:
# Echo the active configuration. The backslash continuations sit inside the
# f-string, so the leading spaces of each continued line are part of the output.
print(f"original_lr:{lr}, batch_size:{batch_size},\
               upscale_factor:{upscale_factor}, future_frame:{future_frame}, nFrames:{nFrames}, residual:{residual}, base_filter:{base_filter},\
                feat:{feat}, pretrained:{pretrained}, pretrained_params:{pretrained_params if pretrained else None}, n_resblock:{n_resblock},\
                data_augmentation:{data_augmentation}, useBottlenecked:{useBottlenecked}, useInitialLayer:{useInitialLayer}, usePixelShuffle:{usePixelShuffle}, doScaleFlow:{doScaleFlow}")

Helper Functions:

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class vggL(nn.Module):
    """Perceptual loss: MSE between VGG-19 feature maps of two image batches.

    The feature extractor is the first 25 layers of the ``features`` stack of
    torchvision's pretrained VGG-19 (``VGG19_Weights.DEFAULT``), put in eval
    mode and moved to the module-level ``device``.
    """

    def __init__(self):
        super().__init__()
        self.vgg = vgg19(weights=VGG19_Weights.DEFAULT).features[:25].eval().to(device)
        # The extractor is a fixed feature space, not a trainable part of the
        # model. Freezing it stops autograd from computing and storing weight
        # gradients for VGG on every backward pass; gradients still flow
        # through to the input images, so the loss value is unchanged.
        for param in self.vgg.parameters():
            param.requires_grad_(False)
        self.MSE = nn.MSELoss()

    def forward(self, first, second):
        """Return the MSE between the VGG features of ``first`` and ``second``.

        NOTE(review): the inputs are fed to VGG as-is; no ImageNet mean/std
        normalisation is applied here - confirm that this is intended.
        """
        vgg_first = self.vgg(first)
        vgg_second = self.vgg(second)
        return self.MSE(vgg_first, vgg_second)

def GetLayer(nameOfLayer, *num_params):
    """Build a normalisation or activation layer from its short name.

    Args:
        nameOfLayer: one of 'batch', 'instance', 'relu', 'lrelu', 'prelu',
            'tanh', 'sigmoid'. Any other value (including None) yields
            ``nn.Identity``.
        *num_params: forwarded to BatchNorm2d / InstanceNorm2d / PReLU
            (typically the channel count); ignored by the other layers.

    Returns:
        A freshly constructed ``nn.Module``.
    """
    # Factories are lazy so that only the requested layer is ever constructed.
    factories = {
        'batch': lambda: nn.BatchNorm2d(*num_params),
        'instance': lambda: nn.InstanceNorm2d(*num_params),
        'relu': lambda: nn.ReLU(True),
        'lrelu': lambda: nn.LeakyReLU(0.2, True),
        'prelu': lambda: nn.PReLU(*num_params),  # a single shared slope when no count is given
        'tanh': lambda: nn.Tanh(),
        'sigmoid': lambda: nn.Sigmoid(),
    }
    build = factories.get(nameOfLayer, nn.Identity)
    return build()

# Pipeline applied to a raw frame tensor: bicubic resize to 256x448 (given as
# (height, width)), then conversion to a PIL image.
ToResizedPIL = transforms.Compose(
    [
        transforms.Resize(size=(256, 448), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToPILImage(),
    ]
)

def PrintNetwork(net):
    """Print the module structure of ``net`` followed by its total parameter count."""
    total = sum(tensor.numel() for tensor in net.parameters())
    print(net)
    print(f'Total number of parameters: {total}')

def Checkpoint(model, save_folder, curr_epoch, loss, file_name="", paramLogFilename=""):
    """Save the model weights for ``curr_epoch`` and append the run settings to a log.

    Args:
        model: module whose ``state_dict()`` is saved.
        save_folder: prefix prepended (by plain string concatenation) to both
            the checkpoint name and the log file name, so it must already end
            with a path separator if it names a directory.
        curr_epoch: epoch number embedded in the checkpoint file name.
        loss: loss value recorded in the log entry.
        file_name: checkpoint name prefix; the file is ``{file_name}{curr_epoch}.pth``.
        paramLogFilename: log file that the settings line is appended to.

    The log entry also records the module-level hyperparameters (epochs, lr,
    batch_size, ...) that are in effect when this is called.
    """
    torch.save(model.state_dict(), save_folder+file_name+f"{curr_epoch}.pth")
    print(f"Checkpoint saved to {file_name}{curr_epoch}.pth")

    # The backslash continuations are inside the f-string, so the indentation of
    # the continued lines is part of the logged text.
    entry = f"Version:{file_name}{curr_epoch}.pth, Epoch:{curr_epoch}/{epochs}, original_lr:{lr}, batch_size:{batch_size}, gen_loss:{loss}, \
               upscale_factor:{upscale_factor}, future_frame:{future_frame}, nFrames:{nFrames}, residual:{residual}, base_filter:{base_filter},\
                feat:{feat}, pretrained:{pretrained}, pretrained_params:{pretrained_params if pretrained else None}, n_resblock:{n_resblock},\
                data_augmentation:{data_augmentation}, useBottlenecked:{useBottlenecked}, useInitialLayer:{useInitialLayer}, usePixelShuffle:{usePixelShuffle}, doScaleFlow:{doScaleFlow}"

    # 'with' closes the handle even if the write raises. Each entry ends with a
    # newline so that successive checkpoints do not run together on one line.
    with open(save_folder+paramLogFilename, 'a') as log:
        log.write(entry + "\n")
    print(f"Checkpoint details updated to {save_folder+paramLogFilename}")

def custom_stretch(img_mat):
    """Linearly stretch an array to the full 0-255 range and return it as uint8.

    The minimum of ``img_mat`` maps to 0 and the maximum to 255; values in
    between are scaled linearly and floored.

    Args:
        img_mat: numpy array of any numeric dtype.

    Returns:
        uint8 array with the same shape as ``img_mat``. A constant input (no
        contrast to stretch) returns all zeros instead of dividing by zero.
    """
    values = img_mat.astype(np.float32)
    highest = np.max(img_mat)
    lowest = np.min(img_mat)
    if highest == lowest:
        # Avoid 0/0 -> NaN -> undefined uint8 cast for flat inputs
        # (e.g. the flow between two identical frames).
        return np.zeros(values.shape, dtype=np.uint8)
    return np.floor((255*(values-lowest)/(highest-lowest))).astype(np.uint8)


Base Networks:

In [None]:
#Conv2d -> BN -> Activation ->
#[B, input_size, H, W] -> [B, output_size, H, W] (Assuming default KSP)
class ConvBlock(nn.Module):
    """Conv2d followed by an optional normalisation layer and an optional activation.

    ``norm`` and ``activation`` are layer names resolved by ``GetLayer``; a
    name it does not recognise (including None) becomes a pass-through layer.
    """

    def __init__(self, input_size, output_size, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm=None):
        super().__init__()

        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        self.conv = nn.Conv2d(input_size, output_size, **conv_options)
        # Both layers are sized for the convolution's output channels.
        self.bn = GetLayer(norm, output_size)
        self.act = GetLayer(activation, output_size)

    def forward(self, x):
        convolved = self.conv(x)
        normalised = self.bn(convolved)
        return self.act(normalised)

#ConvTranspose -> BN -> Activation ->
#[B, input_size, H, W] -> [B, output_size, 2*H, 2*W] (Assuming default KSP)
class DeconvBlock(nn.Module):
    """ConvTranspose2d followed by an optional normalisation layer and activation.

    ``norm`` and ``activation`` are layer names resolved by ``GetLayer``; a
    name it does not recognise (including None) becomes a pass-through layer.
    """

    def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True, activation='prelu', norm=None):
        super().__init__()

        self.deconv = nn.ConvTranspose2d(
            input_size,
            output_size,
            kernel_size,
            stride,
            padding,
            bias=bias,
        )
        # Both layers are sized for the transposed convolution's output channels.
        self.bn = GetLayer(norm, output_size)
        self.act = GetLayer(activation, output_size)

    def forward(self, x):
        upsampled = self.deconv(x)
        normalised = self.bn(upsampled)
        return self.act(normalised)


#Upsampler is a sequence of (ConvBlock -> PixelShuffle) pairs. Each ConvBlock keeps the image size but expands the
#n_feat input channels to 4*n_feat. PixelShuffle(2) then redistributes that factor of 4 between width and height,
#doubling both dimensions and bringing the channel count back to n_feat.
#The pair is repeated k times where scale = 2**k, so if scale=8 there are 3 such pairs and the final result is an
#image with the same number of channels as the input but 8x the height and width.
#Optional batch norm after each pair.

# (Conv -> PS -> BN) -> (Conv -> PS -> BN) -> ... -> Activation ->
#[B, n_feat, H, W] -> [B, n_feat, scale*H, scale*W] (for scale = 2**k; int(log2(scale)) stages are built)
class Upsampler(nn.Module):
    """PixelShuffle-based upsampler built from x2 stages, with one final activation."""

    def __init__(self, scale, n_feat, bn=False, activation='prelu', bias=True):
        super().__init__()

        def build_stage():
            # The conv expands channels x4; PixelShuffle(2) trades them for 2x height and width.
            layers = [
                ConvBlock(n_feat, 4 * n_feat, 3, 1, 1, bias, activation=None, norm=None),
                nn.PixelShuffle(2),
            ]
            if bn:
                layers.append(nn.BatchNorm2d(n_feat))
            return layers

        num_stages = int(log2(scale))
        stage_layers = [layer for _ in range(num_stages) for layer in build_stage()]
        self.up = nn.Sequential(*stage_layers)
        self.act = GetLayer(activation, n_feat)

    def forward(self, x):
        return self.act(self.up(x))

Derivative Networks:

In [None]:
#None of the Conv2d layers change the size of the input images using default K,S,P
#[B, num_filter, H, W] -> [B, num_filter, H, W] (Assuming default KSP)
class ResnetBlock(nn.Module):
    """Residual block: conv -> norm -> act -> conv -> norm, add the skip, then activate.

    The same ``self.bn`` instance is applied after both convolutions.

    With ``useBottlenecked=True`` the block first projects ``num_filter``
    channels down to ``num_filter // 2`` (``bottleneck_in``), runs the residual
    body at that width, and projects back up (``bottleneck_out``). The skip
    connection is taken after ``bottleneck_in``.

    Args:
        num_filter: number of input and output channels.
        kernel_size, stride, padding, bias: settings of the two body convolutions.
        activation, norm: layer names resolved by ``GetLayer``.
        useBottlenecked: enable the channel bottleneck described above. The
            default is the value of the module-level flag when the class is defined.
    """

    def __init__(self, num_filter, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm='batch', useBottlenecked=useBottlenecked):
        super(ResnetBlock, self).__init__()

        self.bottleneck_in = nn.Identity()
        self.bottleneck_out = nn.Identity()
        orig_filter = num_filter
        if useBottlenecked:
            print("Changing filter sizes:")
            num_filter = num_filter//2
            print(f"Orig_filter:{orig_filter}, num_filter:{num_filter}")
            self.bottleneck_in = nn.Conv2d(orig_filter, num_filter, kernel_size=3, stride=1, padding=1, bias=True)
            self.bottleneck_out = nn.Conv2d(num_filter, orig_filter, kernel_size=3, stride=1, padding=1, bias=True)

        self.conv1 = nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding, bias=bias)
        self.conv2 = nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding, bias=bias)
        self.bn = GetLayer(norm, num_filter)
        self.act = GetLayer(activation, num_filter)

        # The final activation runs AFTER bottleneck_out, i.e. on orig_filter
        # channels. self.act is sized for the (halved) body width, and a
        # per-channel activation such as PReLU(num_filter) rejects an input with
        # orig_filter channels, so the bottlenecked variant needs its own output
        # activation. Without the bottleneck this stays None and self.act is
        # reused, which keeps the module's state_dict keys unchanged.
        self.act_out = GetLayer(activation, orig_filter) if useBottlenecked else None

    def forward(self, x):
        x = self.bottleneck_in(x)
        residual = x
        out = self.bn(self.conv1(x))
        out = self.act(out)
        out = self.bn(self.conv2(out))
        out = torch.add(out, residual)
        out = self.bottleneck_out(out)
        final_act = self.act if self.act_out is None else self.act_out
        return final_act(out)

#[B, num_filter, H, W] -> [B, num_filter, stride*H, stride*W]
#If you want to affect size of output, change the stride to be 2**n
class UpBlock(nn.Module):
    """Up-projection unit: upsample, project back down, and correct with the residual.

    forward computes ``h0 = up1(x)``, ``l0 = up2(h0)``, ``h1 = up3(l0 - x)``
    and returns ``h1 + h0``.

    Args:
        num_filter: channel count kept throughout the block.
        kernel_size, stride, padding: settings of the strided convolution
            (``up2``) and, when ``usePS`` is False, of the transposed convolutions.
        num_stages: only used with ``useInitialLayer``; the 1x1 initial layer
            then expects ``num_stages * num_filter`` input channels.
        activation, norm: layer names resolved by ``GetLayer``.
        useInitialLayer: prepend the 1x1 channel-reducing ConvBlock.
        usePS: upsample with ``Upsampler`` (PixelShuffle) instead of ``DeconvBlock``.
    """

    def __init__(self, num_filter, kernel_size=8, stride=4, padding=2, num_stages=1, activation='prelu', norm=None, useInitialLayer=useInitialLayer, usePS=usePixelShuffle):
        super(UpBlock, self).__init__()

        self.initial_layer = nn.Identity()
        if useInitialLayer:
            self.initial_layer = ConvBlock(num_stages*num_filter, num_filter, kernel_size=1, stride=1, padding=0, activation=activation, norm=norm)

        self.up1 = Upsampler(stride, num_filter, activation=activation) if usePS else DeconvBlock(num_filter, num_filter, kernel_size, stride, padding, activation=activation, norm=norm)
        # `activation` must be passed by keyword: ConvBlock's sixth positional
        # parameter is `bias`, so passing it positionally set bias='prelu' and
        # silently left the activation at ConvBlock's own default.
        self.up2 = ConvBlock(num_filter, num_filter, kernel_size, stride, padding, activation=activation, norm=None)
        # NOTE(review): without PixelShuffle, up3 always uses norm='batch' while
        # up1 follows `norm` - confirm this asymmetry is intended.
        self.up3 = Upsampler(stride, num_filter, activation=activation) if usePS else DeconvBlock(num_filter, num_filter, kernel_size, stride, padding, activation=activation, norm='batch')

    def forward(self, x):
        x = self.initial_layer(x)
        h0 = self.up1(x)
        l0 = self.up2(h0)
        h1 = self.up3(l0 - x)
        return h1 + h0

#Scale=Stride
#[B, num_filter, H, W] -> [B, num_filter, H/4, W/4] (Assuming default KSP)
class DownBlock(nn.Module):
    """Down-projection unit: downsample, project back up, and correct with the residual.

    forward computes ``l0 = down1(x)``, ``h0 = down2(l0)``, ``l1 = down3(h0 - x)``
    and returns ``l0 + l1``.
    """

    def __init__(self, num_filter, kernel_size=8, stride=4, padding=2, num_stages=1, activation='prelu', norm=None, useInitialLayer=useInitialLayer, usePS=usePixelShuffle):
        super().__init__()

        # Optional 1x1 conv that reduces num_stages*num_filter channels to num_filter.
        if useInitialLayer:
            self.initial_layer = ConvBlock(num_stages*num_filter, num_filter, kernel_size=1, stride=1, padding=0, activation=activation, norm=norm)
        else:
            self.initial_layer = nn.Identity()

        def strided_conv():
            # Downsampling convolution; down1 and down3 use identical settings.
            return ConvBlock(num_filter, num_filter, kernel_size, stride, padding, activation=activation, norm=norm)

        self.down1 = strided_conv()
        if usePS:
            self.down2 = Upsampler(stride, num_filter, activation=activation)
        else:
            self.down2 = DeconvBlock(num_filter, num_filter, kernel_size, stride, padding, activation=activation, norm=norm)
        self.down3 = strided_conv()

    def forward(self, x):
        x = self.initial_layer(x)
        low = self.down1(x)
        restored = self.down2(low)
        correction = self.down3(restored - x)
        return low + correction

DBPNS and RBPN:

In [None]:
#A chain of alternating UpBlock / DownBlock back-projection stages (the SISR branch used by RBPN).
#A 1x1 ConvBlock first maps base_filter -> feat channels. Every stage then produces one upscaled feature map.
#All of them are concatenated along the channel axis and reduced back to feat channels by a final 1x1 ConvBlock.
#The two 1x1 ConvBlocks do not change the spatial size, only the number of channels.
#Output: feat channels at scale_factor x the input size.
class DBPNS(nn.Module):
    """Back-projection network with ``num_stages`` up-projection stages.

    Args:
        base_filter: number of input channels.
        feat: number of channels used inside the stages and returned.
        num_stages: number of UpBlocks; stage i (i >= 2) is preceded by a
            DownBlock. Must be >= 1. With 3 stages the submodules are named
            up1, down1, up2, down2, up3, exactly as in the fixed 3-stage layout.
        scale_factor: upscaling factor; used as the stride of every block.

    Raises:
        ValueError: if ``num_stages`` is smaller than 1.
    """

    def __init__(self, base_filter, feat, num_stages, scale_factor):
        super(DBPNS, self).__init__()

        if num_stages < 1:
            raise ValueError(f"num_stages must be >= 1, got {num_stages}")
        self.num_stages = num_stages

        #Mostly defined for scale_factors of 2,4,8, ..., 2**n
        stride = scale_factor
        kernel = stride + 4
        padding = 2

        self.feat1 = ConvBlock(base_filter, feat, 1, 1, 0, activation='prelu', norm='batch')
        #Back-projection stages. They are registered under the names up1, down1, up2, down2, ...
        #so construction order and state_dict keys match the hand-written 3-stage version.
        self.up1 = UpBlock(num_filter=feat, kernel_size=kernel, stride=stride, padding=padding)
        for stage in range(2, num_stages + 1):
            setattr(self, f"down{stage - 1}", DownBlock(num_filter=feat, kernel_size=kernel, stride=stride, padding=padding))
            setattr(self, f"up{stage}", UpBlock(num_filter=feat, kernel_size=kernel, stride=stride, padding=padding))
        #Reconstruction: one feat-channel map per stage is concatenated in forward()
        self.output = ConvBlock(num_stages*feat, feat, 1, 1, 0, activation=None, norm='batch')

        #No weight initialisation is done here. RBPN.__init__ initialises every Conv2d / ConvTranspose2d
        #in its module tree, which includes this network when it is used as RBPN.DBPN.

    def forward(self, x):
        x = self.feat1(x)

        high_res = [self.up1(x)]
        for stage in range(2, self.num_stages + 1):
            down = getattr(self, f"down{stage - 1}")
            up = getattr(self, f"up{stage}")
            high_res.append(up(down(high_res[-1])))

        #Each entry has feat channels; concatenating along dim 1, newest stage first (h_n, ..., h2, h1),
        #gives num_stages*feat channels, which is what self.output expects.
        return self.output(torch.cat(high_res[::-1], 1))


#Output has num_channels output channels with x scale factor size
class RBPN(nn.Module):
    """Recurrent back-projection network for video super-resolution.

    For every neighbouring frame the network runs one projection step that
    combines a single-image estimate (``DBPN`` on the current features) with a
    multi-image estimate (``res_feat1`` on the frame/neighbour/flow features).
    The per-neighbour results are concatenated and mapped to the output image.

    Args:
        num_channels: channels of the input and output images.
        base_filter: channel width of the low-resolution feature maps.
        feat: channel width of the high-resolution feature maps.
        num_stages: forwarded to ``DBPNS``.
        n_resblock: number of ResnetBlocks in each of the three residual stacks.
        nFrames: frames per sample; ``forward`` must receive ``nFrames - 1``
            neighbours because the output layer expects ``(nFrames-1)*feat`` channels.
        scale_factor: upscaling factor.
    """

    def __init__(self, num_channels, base_filter, feat, num_stages, n_resblock, nFrames, scale_factor):
        super(RBPN, self).__init__()
        self.nFrames = nFrames

        #Mostly defined for scale_factors of 2,4,8, ..., 2**n
        stride = scale_factor
        kernel = stride + 4
        padding = 2

        #Initial feature extraction (these 2 layers are not consecutive during the forward pass).
        #feat1 sees the input frame, one neighbouring frame and the 2-channel (x,y) flow between them,
        #i.e. 2*num_channels + 2 channels (8 for RGB input).
        flow_channels = 2
        self.feat0 = ConvBlock(num_channels, base_filter, 3, 1, 1, activation='prelu', norm='batch')
        self.feat1 = ConvBlock(2*num_channels + flow_channels, base_filter, 3, 1, 1, activation='prelu', norm=None)

        # --- START OF ENCODER --- #
        ###DBPNS (Output is x scale_factor) (SISR Block)
        self.DBPN = DBPNS(base_filter, feat, num_stages, scale_factor)

        #Res-Block1 (Resnet MISR Block) (Output is times scale_factor)
        modules_body1 = [ResnetBlock(base_filter, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm='batch', useBottlenecked=useBottlenecked) for _ in range(n_resblock)]
        modules_body1.append(DeconvBlock(base_filter, feat, kernel, stride, padding, activation='prelu', norm='batch')) # x scale_factor
        self.res_feat1 = nn.Sequential(*modules_body1)

        #Res-Block2 (String of residual blocks)(Output_size = Input_size and number of channels is constant)(As per Fig 4. a) in the original paper)
        modules_body2 = [ResnetBlock(feat, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm='batch', useBottlenecked=useBottlenecked) for _ in range(n_resblock)]
        modules_body2.append(ConvBlock(feat, feat, 3, 1, 1, activation='prelu', norm=None)) #No change to sizes and channels
        self.res_feat2 = nn.Sequential(*modules_body2)

        # --- END OF ENCODER --- #

        #Res-Block3 (Decoder block as per diagram 4 b) of the original paper)(Output size is input size divided by scale_factor)(Output channels = base_filter always)
        modules_body3 = [ResnetBlock(feat, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm='batch', useBottlenecked=useBottlenecked) for _ in range(n_resblock)]
        modules_body3.append(ConvBlock(feat, base_filter, kernel, stride, padding, activation='prelu', norm='batch')) #Divides size by scale_factor
        self.res_feat3 = nn.Sequential(*modules_body3)

        #Reconstruction. Doesn't change the spatial size. One feat-channel map per neighbour is
        #concatenated in forward(), hence (nFrames-1)*feat input channels. nFrames<=7
        self.output = ConvBlock((nFrames-1)*feat, num_channels, 3, 1, 1, activation=None, norm=None)

        #Initialization of convolution-related layers (covers every submodule, DBPN included).
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x, neighbour, flow):
        """Super-resolve ``x`` using its neighbouring frames.

        Args:
            x: low-resolution target frame batch.
            neighbour: sequence of neighbouring frame batches.
            flow: sequence of flow batches, one per entry of ``neighbour``.

        Returns:
            Tensor with ``num_channels`` channels at scale_factor x the input size.
        """
        #Per-neighbour features from [frame, neighbour, flow] concatenated along channels
        feat_frame = [self.feat1(torch.cat((x, frame, frame_flow), 1)) for frame, frame_flow in zip(neighbour, flow)]
        ### initial feature extraction
        feat_input = self.feat0(x)

        ####Projection
        Ht = []
        for frame_feat in feat_frame:
            h0 = self.DBPN(feat_input) #SISR using DBPNS, output is x scale_factor
            h1 = self.res_feat1(frame_feat) #MISR using Resnet, output is x scale_factor
            e = self.res_feat2(h0 - h1) #Residual between SISR and MISR; does not affect size
            h = h0 + e #As per diagram 4 a) of original paper
            Ht.append(h) #Each h has (feat) channels; all of them are concatenated below
            feat_input = self.res_feat3(h) #Input to the next projection step - this is the recurrent part of the network
        ####Reconstruction
        out = torch.cat(Ht, 1) #Concatenating along channels
        return self.output(out) #No change in size; num_channels output channels

Image Processing Helper Functions:

In [None]:
def modcrop(img, modulo):
    """Crop ``img`` from its top-left corner so both dimensions are multiples of ``modulo``.

    ``img`` must expose ``size`` as a 2-tuple and ``crop(box)`` (PIL reports
    size as (width, height) and takes boxes as (left, upper, right, lower)).
    """
    width, height = img.size
    box = (0, 0, width - (width % modulo), height - (height % modulo))
    return img.crop(box)

def is_image_file(filename):
    """Return True when ``filename`` ends with .png, .jpg or .jpeg (case-sensitive)."""
    image_extensions = (".png", ".jpg", ".jpeg")
    return filename.endswith(image_extensions)

def rescale_flow(x,min_range,max_range):
    """Linearly map the values of ``x`` onto ``[min_range, max_range]``.

    Args:
        x: numpy array.
        min_range: value that the minimum of ``x`` is mapped to.
        max_range: value that the maximum of ``x`` is mapped to.

    Returns:
        The rescaled array, or ``x`` unchanged when all its values are equal
        (there is no range to stretch, and dividing by it would give NaN/inf).
    """
    max_val = np.max(x)
    min_val = np.min(x)
    value_span = max_val - min_val
    # numpy reports a zero division as a warning, not an exception, so it has
    # to be detected explicitly instead of via try/except.
    if value_span == 0:
        return x
    return ((max_range-min_range)/value_span)*(x-min_val)+min_range

def get_flow(im1, im2, returnTensor=True, rescale=False):
    """Compute dense Farneback optical flow from ``im1`` to ``im2``.

    Args:
        im1, im2: PIL images of the same size; both are converted to 8-bit grayscale.
        returnTensor: return a torch tensor instead of a numpy array.
        rescale: rescale the flow values to [0, 1] with ``rescale_flow``.

    Returns:
        Flow in channel-first layout ``(2, h, w)``: index 0 holds the x
        displacement and index 1 the y displacement.
    """
    im1_mat = np.array(im1.convert('L'), dtype=np.uint8)
    im2_mat = np.array(im2.convert('L'), dtype=np.uint8)

    h,w = im1_mat.shape
    im1_mat = im1_mat.reshape(h, w, 1)
    im2_mat = im2_mat.reshape(h, w, 1)

    # Positional arguments after the images: flow=None, pyr_scale=0.5, levels=3,
    # winsize=15, iterations=3, poly_n=5, poly_sigma=1.1, flags=0.
    flow = cv2.calcOpticalFlowFarneback(im1_mat, im2_mat, None, 0.5, 3, 15, 3, 5, 1.1, 0)
    if rescale:
        flow = rescale_flow(flow, 0, 1)
    # OpenCV returns the flow as (h, w, 2). Moving the component axis to the
    # front needs a transpose; a plain reshape(2, h, w) would reinterpret the
    # interleaved (x, y) pairs and scramble both components.
    flow = np.ascontiguousarray(flow.transpose(2, 0, 1))
    if returnTensor:
        flow = torch.from_numpy(flow)

    return flow

def rescale_img(img_in, scale):
    """Return ``img_in`` resized by ``scale`` in both dimensions with bicubic resampling.

    Each new dimension is the old one times ``scale``, truncated to an int.
    """
    scaled_size = tuple(int(dimension * scale) for dimension in img_in.size)
    return img_in.resize(scaled_size, resample=Image.BICUBIC)


#Crops the same square patch out of the low-res input and all neighbours, plus the matching (scale x larger) patch
#out of the target. The patch position is random unless (ix, iy) is given. Returns PIL images and the crop details.
def get_patch(img_in, img_tar, img_nn, patch_size, scale, ix=-1, iy=-1):
    """Crop aligned patches from the input, the target and the neighbour frames.

    Args:
        img_in: low-resolution input image.
        img_tar: target image, ``scale`` times larger than ``img_in``.
        img_nn: list of low-resolution neighbour images.
        patch_size: side length of the low-resolution patch.
        scale: size ratio between target and input.
        ix: top offset of the patch in the input, or -1 to draw it at random.
        iy: left offset of the patch in the input, or -1 to draw it at random.

    Returns:
        (input patch, target patch, list of neighbour patches, info dict with
        the keys ix, iy, ip, tx, ty, tp).
    """
    width, height = img_in.size

    tp = scale * patch_size  # target patch side
    ip = tp // scale         # input patch side; equals patch_size for an integer scale

    # Draw order matters for reproducibility: ix first, then iy.
    if ix == -1:
        ix = torch.randint(0, height - ip + 1, (1,)).item()
    if iy == -1:
        iy = torch.randint(0, width - ip + 1, (1,)).item()

    tx, ty = scale * ix, scale * iy

    # Crop boxes are (left, upper, right, lower).
    low_res_box = (iy, ix, iy + ip, ix + ip)
    target_box = (ty, tx, ty + tp, tx + tp)

    patch_in = img_in.crop(low_res_box)
    patch_tar = img_tar.crop(target_box)
    patch_nn = [frame.crop(low_res_box) for frame in img_nn]

    info_patch = dict(ix=ix, iy=iy, ip=ip, tx=tx, ty=ty, tp=tp)
    return patch_in, patch_tar, patch_nn, info_patch

def augment(img_in, img_tar, img_nn, flip_h=True, rot=True):
    """Randomly flip/mirror/rotate the input, target and neighbours together.

    Each transform is applied with probability 0.5 and always to all images at
    once so they stay aligned.

    Args:
        img_in, img_tar: PIL images.
        img_nn: list of PIL images.
        flip_h: allow the ``ImageOps.flip`` step.
        rot: allow the ``ImageOps.mirror`` and 180-degree rotation steps.

    Returns:
        (img_in, img_tar, img_nn, info_aug). ``info_aug`` is a debugging record:
        'flip_h' is True when ImageOps.flip ran, 'flip_v' when ImageOps.mirror
        ran and 'trans' when the 180-degree rotation ran.
    """
    info_aug = dict.fromkeys(('flip_h', 'flip_v', 'trans'), False)

    def transform_all(operation, frames):
        # Apply one image operation to the input, the target and every neighbour.
        low_res, target, neighbours = frames
        return operation(low_res), operation(target), [operation(frame) for frame in neighbours]

    frames = (img_in, img_tar, img_nn)

    # The random draw happens even when flip_h is False, keeping the RNG stream stable.
    flip_roll = random.random()
    if flip_h and flip_roll < 0.5:
        frames = transform_all(ImageOps.flip, frames)
        info_aug['flip_h'] = True

    if rot:
        if random.random() < 0.5:
            frames = transform_all(ImageOps.mirror, frames)
            info_aug['flip_v'] = True
        if random.random() < 0.5:
            frames = transform_all(lambda image: image.rotate(180), frames)
            info_aug['trans'] = True

    img_in, img_tar, img_nn = frames
    return img_in, img_tar, img_nn, info_aug

Dataset and Dataloaders:

In [None]:
#nFrames = 7 default (Can be reduced)
#filepath is the path to the folder containing the image sequence im1.png ... im7.png (relative to the cwd)
def load_img(filepath, nFrames, scale, useModcrop=True):
    """Load the last frame as target plus all earlier frames as neighbours.

    Args:
        filepath: folder holding the frames ``im1.png`` ... ``im{nFrames}.png``.
        nFrames: number of frames to use; frame ``nFrames`` is the target.
        scale: downscaling factor between the target and the network input.
        useModcrop: crop the target to a multiple of ``scale`` before downscaling.

    Returns:
        (target, input, neighbour): the full-size target, its bicubic
        downscale, and the frames nFrames-1, ..., 1 (in that order), each
        modcropped and then resized to the size of the input.
    """
    target = Image.open(join(filepath,'im'+str(nFrames)+'.png')).convert('RGB')
    if useModcrop:
        target = modcrop(target, scale)

    # Every low-resolution image shares this size, derived from the target.
    low_res_size = (int(target.size[0]/scale),int(target.size[1]/scale))
    low_res_input = target.resize(low_res_size, Image.BICUBIC)

    neighbour = []
    for frame_index in range(nFrames - 1, 0, -1):
        # Neighbours are always modcropped, independently of useModcrop.
        frame = modcrop(Image.open(filepath+'/im'+str(frame_index)+'.png').convert('RGB'), scale)
        neighbour.append(frame.resize(low_res_size, Image.BICUBIC))

    return target, low_res_input, neighbour

#does what load_img does but for the 4th (centre) frame, not the last one: neighbours come from both sides of it
def load_img_future(filepath, nFrames, scale, useModcrop=True):
    """Load the centre frame (im4) as target and the frames around it as neighbours.

    Args:
        filepath: folder holding the frames ``im1.png`` ... ``im7.png``.
        nFrames: size of the frame window; ``int(nFrames/2)`` frames are taken
            on each side of frame 4.
        scale: downscaling factor between the target and the network input.
        useModcrop: crop the target to a multiple of ``scale`` before downscaling.

    Returns:
        (target, inp, neighbour): the full-size centre frame, its bicubic
        downscale, and the surrounding frames in ascending order, each
        modcropped and resized to the size of ``inp``.
    """
    centre = 4  # middle frame of a 7-frame sequence
    tt = int(nFrames/2)
    # The target must be the centre frame. Loading im{nFrames} here made the
    # target (im7 for nFrames=7) appear in its own neighbour list as well.
    target = Image.open(join(filepath,'im'+str(centre)+'.png')).convert('RGB')
    if useModcrop:
        target = modcrop(target, scale)
    low_res_size = (int(target.size[0]/scale),int(target.size[1]/scale))
    inp = target.resize(low_res_size, Image.BICUBIC)

    # NOTE(review): for an even nFrames this yields nFrames neighbours, not
    # nFrames-1 - confirm that only odd values are used with this loader.
    seq = [x for x in range(centre-tt,centre+1+tt) if x!=centre]

    neighbour = [
        modcrop(Image.open(filepath+'/im'+str(j)+'.png').convert('RGB'), scale).resize(low_res_size, Image.BICUBIC)
        for j in seq
    ]
    return target, inp, neighbour

class VideoSequenceSet(data.Dataset):
    """Dataset of frame sequences listed in a text file.

    Each line of ``image_list`` names one sequence folder relative to
    ``image_dir``. One item is a low-resolution frame, its high-resolution
    target, the neighbouring low-resolution frames, the optical flow from the
    frame to each neighbour, and a bicubic upscale of the frame.

    Args:
        image_dir: root folder of the sequences.
        image_list: path of the list file, relative to ``image_dir``.
        nFrames: number of frames used per sequence.
        scale_factor: super-resolution factor.
        patch_size: side of the random training patch; 0 keeps whole frames.
        useFuture: load with ``load_img_future`` instead of ``load_img``.
        doAugmentation: apply ``augment`` (training only).
        train: enable patch cropping and augmentation.
        useModcrop: forwarded to the loader functions.
    """

    def __init__(self, image_dir, image_list, nFrames=7, scale_factor=2, patch_size=64, useFuture=True, doAugmentation = True, train=True, useModcrop=True):
        super(VideoSequenceSet, self).__init__()
        # Close the list file deterministically, and skip blank lines: a blank
        # line would otherwise become a sample whose "folder" is image_dir itself.
        with open(join(image_dir,image_list)) as list_file:
            listOfVideos = [line.rstrip() for line in list_file if line.strip()]
        self.pathToVideos = [join(image_dir,x) for x in listOfVideos]

        self.num_frames = nFrames
        self.upscale_factor = scale_factor
        self.patch_size = patch_size

        self.useFuture = useFuture
        self.train = train
        self.doAugmentation = doAugmentation
        self.useModcrop = useModcrop

    def __getitem__(self, index):
        """Return ``(inp, target, neighbour, flow_list, bicubic)`` for one sequence."""
        if self.useFuture:
            target, inp, neighbour = load_img_future(self.pathToVideos[index], self.num_frames, self.upscale_factor, useModcrop=self.useModcrop)
        else:
            target, inp, neighbour = load_img(self.pathToVideos[index], self.num_frames, self.upscale_factor, useModcrop=self.useModcrop)

        if self.train:
            if self.patch_size != 0:
                inp, target, neighbour, _ = get_patch(inp,target,neighbour,self.patch_size, self.upscale_factor)

            if self.doAugmentation: #default True
                inp, target, neighbour, _ = augment(inp, target, neighbour, flip_h=True, rot=True)

        # Flow is computed after cropping/augmentation so it matches the final
        # images. Rescaling is controlled by the module-level doScaleFlow flag.
        flow_list = [get_flow(inp,j, returnTensor=True, rescale=doScaleFlow) for j in neighbour]
        bicubic = rescale_img(inp, self.upscale_factor)

        T = transforms.ToTensor()
        target = T(target)
        inp = T(inp)
        bicubic = T(bicubic)
        neighbour = [T(j) for j in neighbour]

        return inp, target, neighbour, flow_list, bicubic

    def __len__(self):
        return len(self.pathToVideos)

# Training set: random patch_size x patch_size crops with optional augmentation.
train_set = VideoSequenceSet(data_dir_T, file_list_T, nFrames=nFrames, scale_factor=upscale_factor, patch_size=patch_size, \
                             useFuture=future_frame, train=True, doAugmentation=data_augmentation, useModcrop=True)
# Validation set: whole frames (patch_size=0), no augmentation, no modcrop of the target.
val_set = VideoSequenceSet(data_dir_V, file_list_V, nFrames=nFrames, scale_factor=upscale_factor, patch_size=0, \
                             useFuture=future_frame, train=False, doAugmentation=False, useModcrop=False)

videoT_loader = data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=2)
# NOTE(review): the validation loader is shuffled as well - confirm this is intended.
videoV_loader = data.DataLoader(dataset=val_set, batch_size=test_batch_size, shuffle=True, num_workers=2)

(Optional) Visualisation of dataset:

[Assumes that future_frame=True]

In [None]:
# Plot the first validation batch only: the loop breaks as soon as idx != 0.
for idx, x in enumerate(videoV_loader):
    if idx==0:
        # inp: input frame, t: target, n: neighbours, fl: flows, b: bicubic upscale
        inp,t,n,fl,b = x

        # Figure 1: input, target and bicubic upscale of the first sample in the batch.
        fig, axs = plt.subplots(1,3, figsize=(16,4))
        axs[0].imshow(inp[0].permute(1,2,0).numpy())
        axs[0].set_title("Input")
        axs[1].imshow(t[0].permute(1,2,0).numpy())
        axs[1].set_title("Target")
        axs[2].imshow(b[0].permute(1,2,0).numpy())
        axs[2].set_title("Bicubic")

        plt.show()

        # Figure 2: one column per neighbour. The grid is fixed at 5 rows x 6 columns,
        # so it can show at most 6 neighbours (nFrames <= 7).
        step=3  # plot every 3rd flow vector
        fig, axs = plt.subplots(5,6, figsize=(16,9))
        for i in range(nFrames-1):
            axs[0][i].imshow(inp[0].permute(1,2,0).numpy())
            axs[0][i].set_title(f"Input")

            # The label skips frame 4, i.e. it assumes the neighbours are frames
            # 1,2,3,5,6,7 around a centre target (future_frame=True).
            axs[1][i].imshow(n[i][0].permute(1,2,0).numpy())
            axs[1][i].set_title(f"n_0({i+1 if (i+1)<4 else i+2})")

            # Flow is stored channel-first; move the 2 components last for plotting.
            flow = fl[i][0].permute(1,2,0).numpy()
            # The y coordinates run from the image height downwards so the quiver plot is not upside down.
            axs[2][i].quiver(np.arange(0, flow.shape[1], step), np.arange(flow.shape[0], -1, -step), flow[::step, ::step, 0], flow[::step, ::step, 1])
            axs[2][i].set_title(f"Quiver Diag for n_0({i+1 if (i+1)<4 else i+2})")

            # Each flow component shown as a contrast-stretched grayscale image.
            axs[3][i].imshow(custom_stretch(fl[i][0][0].numpy()), cmap='gray')
            axs[3][i].set_title(f"Flow(x) for n_0({i+1 if (i+1)<4 else i+2})")

            axs[4][i].imshow(custom_stretch(fl[i][0][1].numpy()), cmap='gray')
            axs[4][i].set_title(f"Flow(y) for n_0({i+1 if (i+1)<4 else i+2})")

        # Hide the axes of every cell, including unused ones.
        for i in range(5):
            for j in range(6):
                axs[i][j].axis('off')
        plt.show()
    else:
        break

Model and Training loop:

In [None]:
# Build the network for 3-channel images with 3 back-projection stages and move it to `device`.
model = RBPN(num_channels=3, base_filter=base_filter,  feat = feat, num_stages=3, n_resblock=n_resblock, nFrames=nFrames, scale_factor=upscale_factor).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-8)
# The training loss is the sum of this L1 term and the VGG perceptual term.
criterion = nn.L1Loss()
vgg_loss = vggL()

PrintNetwork(model)

In [None]:
def show_examples(model, loader):
    """Plot input, prediction, target and bicubic upscale for one random sample.

    A random batch of ``loader`` is chosen, then a random sample inside it. The
    model is evaluated on the CPU; afterwards it is moved back to ``device``
    and its previous train/eval mode is restored, even if plotting fails.

    Args:
        model: network called as ``model(inp, neighbours, flows)``.
        loader: iterable of ``(inp, target, neighbours, flows, bicubic)`` batches
            that supports ``len()``.
    """
    num_batches = len(loader)
    if num_batches == 0:
        print("show_examples: loader is empty, nothing to display")
        return

    was_training = model.training
    model.eval()
    model = model.cpu()
    try:
        # Batches are numbered from 1 (see enumerate below). The upper bound is
        # inclusive so the last batch can be picked and a single-batch loader works.
        chosen_batch = random.randint(1, num_batches)
        fig, axs = plt.subplots(1, 4, figsize=(14, 10))

        for idx, batch in enumerate(loader, 1):
            if idx != chosen_batch:
                continue

            inp, target, neigbor, flow, bicubic = batch[0].cpu(), batch[1].cpu(), [x.cpu() for x in batch[2]], [x.cpu() for x in batch[3]], batch[4].cpu()
            chosen = random.randint(0, len(inp)-1)
            with torch.no_grad():
                prediction = model(inp, neigbor, flow).cpu()
            # Mirror the training step: with `residual` the network predicts a
            # correction on top of the bicubic upscale.
            if residual:
                prediction = prediction + bicubic
            # Keep the values in the range imshow expects for float images.
            prediction = prediction.clamp(0, 1)

            # All four panels show the same sample of the batch.
            panels = [
                ("Input", inp[chosen]),
                ("Predicted", prediction[chosen]),
                ("Target", target[chosen]),
                ("Bicubic", bicubic[chosen]),
            ]
            for axis, (title, image) in zip(axs, panels):
                axis.set_axis_off()
                axis.imshow(image.detach().permute(1, 2, 0).numpy())
                axis.set_title(title)
            # Nothing else to plot; do not load the remaining batches.
            break

        plt.show()
    finally:
        model.train(was_training)
        model = model.to(device)
    
# Number of training batches; kept at module level for cells that still read it.
l = len(videoT_loader)
def train(model, optimizer, criterion, loader):
    """Run one training epoch and return the mean loss over its batches.

    Per batch: forward pass, optional bicubic residual, loss = ``criterion`` +
    VGG perceptual loss, backward pass and optimizer step. The learning rate of
    every parameter group is halved whenever ``(idx + 1)`` is a multiple of
    ``ceil(len(loader) / 4)``. Every ``record`` iterations the current batch
    loss is printed and appended to ``iter_loss_hist`` and validation examples
    are shown.

    Args:
        model: network called as ``model(inp, neighbours, flows)``.
        optimizer: optimizer over the model parameters.
        criterion: pixel loss taking ``(prediction, target)``.
        loader: iterable of ``(inp, target, neighbours, flows, bicubic)`` batches
            that supports ``len()``.

    Returns:
        float: mean of the per-batch losses (0.0 for an empty loader).
    """
    print("Started a new train function call!")
    model.train()

    num_batches = len(loader)
    decay_interval = np.ceil(num_batches/4)
    running_loss = 0.0

    for idx, batch in enumerate(loader, 1):
        inp, target, neigbor, flow, bicubic = batch[0].to(device), batch[1].to(device), [x.to(device) for x in batch[2]], [x.to(device) for x in batch[3]], batch[4].to(device)
        prediction = model(inp, neigbor, flow)
        if (idx+1) % decay_interval == 0:
            print("Learning rate now decaying by half...")
            for param_group in optimizer.param_groups:
                param_group['lr'] /= 2.0 #This could be implemented via lr scheduler. Learn about that
            print('Learning rate decay: lr={}'.format(optimizer.param_groups[0]['lr']))
        if residual:
            prediction = prediction + bicubic

        loss = criterion(prediction, target)+vgg_loss(prediction, target)
        # Accumulate plain floats: the epoch loss is an average over all batches,
        # not just the last one, and no tensors are kept alive.
        batch_loss = loss.item()
        running_loss += batch_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if idx%record == 0:
            print(f"Still training... Currently at idx={idx}/{num_batches} with current loss={batch_loss}")
            iter_loss_hist.append(batch_loss)
            show_examples(model, videoV_loader)

    return running_loss / num_batches if num_batches else 0.0

In [None]:
# Optionally warm-start the model from a saved state_dict.
if pretrained:
    if pretrained_params is not None and pretrained_params != "":
        # map_location lets a checkpoint saved on one device load on the current one.
        model.load_state_dict(torch.load(pretrained_params, map_location=device))
        print("Parameters successfully loaded to model")
    else:
        print("File path to pretrained parameters is empty...")

In [None]:
# Main training loop; skipped entirely when `testing` is set.
if not testing:
    for epoch in range(epochs):
        epoch_loss = train(model, optimizer, criterion, videoT_loader)
        print(f"Epoch:{epoch+1}/{epochs}, loss={epoch_loss}, Progress:")
        if (epoch+1) % (snapshot) == 0: #Every snapshot epochs, save params.
            Checkpoint(model, save_folder, epoch+1, epoch_loss, param_filename, logfile_name)
        epoch_loss_hist.append(epoch_loss)
        show_examples(model, videoV_loader)


        # learning rate is decayed by a factor of 10 every half of total epochs
        if (epoch+1) % (np.ceil(epochs/2)) == 0:
            # The message now matches the division by 10 below (it used to say "by half").
            print("Learning rate now decaying by a factor of 10...")
            for param_group in optimizer.param_groups:
                param_group['lr'] /= 10.0 #This could be implemented via lr scheduler. Learn about that
            print('Learning rate decay: lr={}'.format(optimizer.param_groups[0]['lr']))

Upscale function:

In [None]:
# Shared image-to-tensor transform for this cell.
# NOTE(review): presumably used by upscale() further down - not visible here, confirm.
T = transforms.ToTensor()

def rescale_flow(x,min_range,max_range):
    """Linearly rescale ``x`` so that its values span ``[min_range, max_range]``.

    The smallest element of ``x`` maps to ``min_range`` and the largest to
    ``max_range``.

    Args:
        x: NumPy array of values to rescale.
        min_range: Value the minimum of ``x`` is mapped to.
        max_range: Value the maximum of ``x`` is mapped to.

    Returns:
        The rescaled array. ``x`` itself is returned unchanged when it is empty
        or when all of its elements are equal (no spread to rescale).
    """
    if np.size(x) == 0:
        return x

    max_val = np.max(x)
    min_val = np.min(x)

    # NumPy reports division by zero as a warning, not an exception, so a
    # try/except around the division cannot catch it and the result would be
    # NaN. Test for the zero-spread case explicitly instead.
    if max_val == min_val:
        return x

    scale = (max_range - min_range) / (max_val - min_val)
    return scale * (x - min_val) + min_range

def get_flow(im1, im2, returnTensor=True, rescale=False):
    """Compute Farneback optical flow between two images.

    Both images are converted to 8-bit grayscale before the flow is computed.

    Args:
        im1: First image; must provide ``convert('L')`` (e.g. a PIL image).
        im2: Second image, with the same height and width as ``im1``.
        returnTensor: When true, the result is wrapped with ``torch.from_numpy``;
            otherwise the NumPy array is returned.
        rescale: When true, the flow values are passed through
            ``rescale_flow(flow, 0, 1)`` before reshaping.

    Returns:
        The flow field reshaped to ``(2, h, w)``, as a tensor or NumPy array.
    """
    gray_first = np.array(im1.convert('L'), dtype=np.uint8)
    gray_second = np.array(im2.convert('L'), dtype=np.uint8)

    # Both frames are given an explicit single-channel axis, using the first
    # frame's dimensions for both.
    rows, cols = gray_first.shape
    single_channel_shape = (rows, cols, 1)

    # Positional arguments after the two frames: initial flow (None),
    # pyr_scale, levels, winsize, iterations, poly_n, poly_sigma, flags.
    flow = cv2.calcOpticalFlowFarneback(
        gray_first.reshape(single_channel_shape),
        gray_second.reshape(single_channel_shape),
        None, 0.5, 3, 15, 3, 5, 1.1, 0,
    )
    flow = rescale_flow(flow, 0, 1) if rescale else flow

    # NOTE(review): OpenCV documents the flow as a 2-channel image the size of
    # the input, i.e. (h, w, 2). reshape() reinterprets the buffer rather than
    # moving the channel axis (transpose(2, 0, 1) would). Left as-is because it
    # must match whatever the training pipeline feeds the model -- confirm.
    flow = flow.reshape(2, rows, cols)

    return torch.from_numpy(flow) if returnTensor else flow

def _to_uint8_frame(frame):
    """Convert one H x W x C NumPy frame to the uint8 data handed to cv2.VideoWriter."""
    if np.issubdtype(frame.dtype, np.floating):
        # Floating-point frames are treated as normalised to [0, 1] (the
        # convention of transforms.ToTensor). A bare astype(uint8) would
        # truncate them to 0/1 and produce an almost black video.
        # NOTE(review): confirm the upscaler's output really is in [0, 1].
        frame = np.clip(np.rint(frame * 255.0), 0, 255)
    return np.ascontiguousarray(frame.astype(np.uint8))


def keyframes_to_video(image_list, video_path, fps):
    """Write a sequence of image tensors to ``video_path`` as a DIVX-encoded video.

    Args:
        image_list: Non-empty sequence of CPU tensors. ``image[0]`` of each entry
            is taken as a (C, H, W) frame, so entries carry a leading axis.
        video_path: Destination file path.
        fps: Frame rate of the written video.

    Raises:
        ValueError: If ``image_list`` is empty.

    Frames are written with the channel order they arrive in; no colour
    conversion is applied here.
    """
    if len(image_list) == 0:
        raise ValueError("image_list is empty; there are no frames to write")

    # (C, H, W) -> (H, W, C), the layout the writer expects.
    image_list_numpy = [image[0].permute(1,2,0).numpy() for image in image_list]
    h, w, _ = image_list_numpy[0].shape
    compiled_video = cv2.VideoWriter(video_path, fourcc = cv2.VideoWriter_fourcc('D','I','V','X'), fps = fps, frameSize = (w, h))
    print(image_list_numpy[0].shape)
    try:
        for image in image_list_numpy:
            compiled_video.write(_to_uint8_frame(image))
    finally:
        # Always finalise the file, even if a frame fails to convert or write.
        compiled_video.release()

def video_to_keyframes(video_path, target_fps = 15):
    """Read a video and return a subsampled list of its frames.

    Every ``skip_size``-th frame is kept, where ``skip_size`` is the source
    frame rate divided by ``target_fps``, rounded, and never less than 1.

    Args:
        video_path: Path of the video to read.
        target_fps: Approximate frame rate to subsample to; must be positive.

    Returns:
        List of the kept frames, each converted with ``ToResizedPIL``.
        The list is empty when no frame could be read.

    Raises:
        ValueError: If ``target_fps`` is not positive.
    """
    if target_fps <= 0:
        raise ValueError(f"target_fps must be positive, got {target_fps}")

    captured_video = cv2.VideoCapture(video_path)
    keyframes=[]
    try:
        fps = round(captured_video.get(cv2.CAP_PROP_FPS))
        # Clamp to 1: a source slower than about target_fps/2 would otherwise
        # round to 0 and make the modulo below raise ZeroDivisionError.
        skip_size = max(1, round(fps/target_fps))
        curr_frame = 0

        while True:
            ret, frame = captured_video.read()
            if not ret:
                break
            if curr_frame % skip_size == 0:
                # (H, W, C) array -> (C, H, W) tensor before conversion.
                # NOTE(review): OpenCV decodes frames in BGR order and no
                # channel swap is applied here -- confirm this is intended.
                keyframes.append(ToResizedPIL(torch.from_numpy(frame).permute(2,0,1)))
            curr_frame+=1
    finally:
        # Release the capture handle even if decoding or conversion fails.
        captured_video.release()

    return keyframes

def _neighbour_window(i, length, frame_interval):
    """Return (start, stop) of the frame_interval-wide window around frame i, clamped to the clip."""
    half = (frame_interval - 1) // 2
    start = min(max(i - half, 0), length - frame_interval)
    return start, start + frame_interval


def upscale(path_to_vid, upscaler, store_vid_path, frame_interval=nFrames, fps=15, device=device):
    """Upscale a video frame by frame and write the result to ``store_vid_path``.

    Steps:
      1. Convert the video to a list of keyframes (``video_to_keyframes``).
      2. For each frame, take a ``frame_interval``-wide window of frames around
         it (shifted inwards at the start and end of the clip); the other
         frames in the window are its neighbours.
      3. Compute the optical flow from the frame to each neighbour.
      4. Call ``upscaler(frame, neighbours, flows)`` and collect the output.
      5. Write the collected outputs as a video (``keyframes_to_video``).

    Args:
        path_to_vid: Path of the input video.
        upscaler: Model called as ``upscaler(inp, neighbours, flows)``.
        store_vid_path: Path the output video is written to.
        frame_interval: Window size, i.e. the frame itself plus its neighbours.
        fps: Frame rate of the written video. The input is always sampled with
            the default ``target_fps`` of ``video_to_keyframes``.
        device: Device the model and the inputs are moved to.

    Raises:
        ValueError: If the video yields fewer than ``frame_interval`` keyframes,
            since a full window cannot be formed.

    NOTE(review): the training loop adds a bicubic term to the prediction when
    ``residual`` is set; no such term is added here -- confirm whether it
    should be.
    """
    print("Upscaling your video now...")
    input_vid_keyframes=video_to_keyframes(path_to_vid)

    length=len(input_vid_keyframes)
    if length < frame_interval:
        # A shorter clip would produce a negative window start, which wraps
        # around when indexing and yields wrong or too few neighbours.
        raise ValueError(
            f"Video has {length} keyframes but frame_interval={frame_interval} are required"
        )
    width, height = input_vid_keyframes[0].size

    # Remember the current mode so it can be restored afterwards instead of
    # forcing the model into training mode.
    was_training = upscaler.training
    upscaler.to(device)
    upscaler.eval()

    output_vid_keyframes=[]
    try:
        # Inference only: no gradients are built, so none need to be cleared.
        with torch.no_grad():
            for i in range(length):
                if (i+1)%5 == 0:
                    print(f"{((i+1)/length)*100:.2f}% done...")

                l, h = _neighbour_window(i, length, frame_interval)
                neighbours = [input_vid_keyframes[j] for j in range(l, h) if j!=i]

                flows = [get_flow(input_vid_keyframes[i], n, returnTensor=True, rescale=True).reshape(1,2,height,width).to(device) for n in neighbours]
                inp = T(input_vid_keyframes[i]).reshape(1,3,height,width).to(device)
                neighbours = [T(n).reshape(1,3,height,width).to(device) for n in neighbours]

                # Move each result off the device immediately so that the
                # whole output video is not held in device memory.
                output_vid_keyframes.append(upscaler(inp, neighbours, flows).cpu())
    finally:
        upscaler.train(was_training)

    keyframes_to_video(output_vid_keyframes, store_vid_path, fps)
    print(f"Finished upscaling. Saved to {store_vid_path}")

In [None]:
#upscale("Video.mp4",model, "video_gen.avi")

In [None]:
# x = [x for x in range(150)]
# y = (nFrames-1)//2
# for i in range(len(x)):
#     if i<y:
#         l=0
#     elif i>=len(x)-y:
#         l=len(x)-nFrames
#     else:
#         l=i-y

#     if i>len(x)-1-y:
#         h=len(x)
#     elif i<=y:
#         h=nFrames
#     else:
#         h=i+y+1
#     print(f"i:{i}, l:{l}, h:{h}")
#     print([x[j] for j in range(l, h) if j!=i])
#     print("")

Statistics:

In [None]:
# Plot the recorded loss histories side by side: one point per epoch on the
# left, one sample every `record` iterations on the right.
fig, axs = plt.subplots(1,2)

# Detach the stored loss tensors and move them to the CPU before plotting.
epoch_losses = [loss.detach().cpu() for loss in epoch_loss_hist]
iter_losses = [loss.detach().cpu() for loss in iter_loss_hist]

epoch_panel = axs[0]
epoch_panel.scatter(list(range(len(epoch_losses))), epoch_losses)
epoch_panel.set_title("Epoch losses")

iter_panel = axs[1]
iter_panel.plot(list(range(len(iter_losses))), iter_losses)
iter_panel.set_title(f"Every {record} iteration losses")

plt.show()