## README

This program tests video super-resolution: it upscales the frames of a video with an ESPCN model.

I tested it on a single short video clip, which can be obtained from this website (https://www.pexels.com/videos/).

The pretrained weights file should be downloaded from (https://www.dropbox.com/s/2fl5jz5nw9oiw1f/espcn_x3.pth?dl=0).

The resulting upscaled video will be located in the `results` folder (under `results/SRF_3/` for the default upscale factor of 3).


In [None]:
# Mount Google Drive at /content/drive; the weights file, input videos and
# results folder used below all live under /content/drive/MyDrive.
# force_remount=True is passed so the mount is redone even if one already exists.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
import argparse
import math
import cv2
import numpy as np
import os
from os import listdir
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
import PIL.Image as pil_image


def is_video_file(filename):
    """Return True when *filename* ends with a supported video extension.

    The comparison is case-insensitive, so ``clip.MP4`` is accepted just like
    ``clip.mp4``.

    Args:
        filename: File name (or path) to test.

    Returns:
        bool: True if the name ends with one of ``.mp4``, ``.avi``, ``.mpg``,
        ``.mkv``, ``.wmv`` or ``.flv``; False otherwise.
    """
    video_extensions = ('.mp4', '.avi', '.mpg', '.mkv', '.wmv', '.flv')
    # str.endswith accepts a tuple, so one call covers every extension.
    return filename.lower().endswith(video_extensions)
class ESPCN(nn.Module):
    """Sub-pixel convolutional network that upscales an image tensor.

    Two Tanh-activated convolutions extract features (``first_part``); a final
    convolution followed by ``nn.PixelShuffle`` rearranges channels into an
    output enlarged by ``scale_factor`` along height and width (``last_part``).

    Args:
        scale_factor: Integer upscaling factor passed to ``nn.PixelShuffle``.
        num_channels: Number of channels in the input and output tensors.
    """

    def __init__(self, scale_factor, num_channels=1):
        super().__init__()

        feature_layers = [
            nn.Conv2d(num_channels, 64, kernel_size=5, padding=2),
            nn.Tanh(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.Tanh(),
        ]
        self.first_part = nn.Sequential(*feature_layers)

        # PixelShuffle consumes scale_factor**2 channels per output channel.
        shuffle_channels = num_channels * (scale_factor ** 2)
        self.last_part = nn.Sequential(
            nn.Conv2d(32, shuffle_channels, kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise every convolution: normal weights, zero biases.

        Convolutions with 32 input channels (the final layer when
        ``num_channels`` is not 32) use a fixed std of 0.001; all others use
        ``sqrt(2 / (out_channels * kernel_elements))``.
        """
        for layer in self.modules():
            if not isinstance(layer, nn.Conv2d):
                continue
            if layer.in_channels == 32:
                std = 0.001
            else:
                kernel_elements = layer.weight.data[0][0].numel()
                std = math.sqrt(2 / (layer.out_channels * kernel_elements))
            nn.init.normal_(layer.weight.data, mean=0.0, std=std)
            nn.init.zeros_(layer.bias.data)

    def forward(self, x):
        """Run feature extraction, then the sub-pixel upscaling stage."""
        features = self.first_part(x)
        return self.last_part(features)

def _super_resolve_frame(model, frame, device):
    """Upscale one BGR video frame with *model* and return it as a BGR array.

    Only the luma (Y) channel goes through the network; the chroma channels
    (Cb, Cr) are enlarged with bicubic interpolation to the network's output
    size and merged back.

    Args:
        model: Network called on a ``(1, 1, H, W)`` float tensor.
        frame: BGR image array as returned by ``cv2.VideoCapture.read``.
        device: ``torch.device`` on which the model lives.

    Returns:
        numpy.ndarray: The upscaled frame in BGR channel order.
    """
    ycbcr = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).convert('YCbCr')
    y, cb, cr = ycbcr.split()

    # ToTensor gives a (1, H, W) tensor for the single-channel Y image;
    # unsqueeze adds the batch dimension the network expects.
    image = ToTensor()(y).unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        out = model(image)

    out_img_y = out[0, 0].cpu().numpy() * 255.0
    out_img_y = Image.fromarray(np.uint8(out_img_y.clip(0, 255)))
    out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
    out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
    out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB')
    return cv2.cvtColor(np.asarray(out_img), cv2.COLOR_RGB2BGR)


if __name__ == "__main__":

    UPSCALE_FACTOR = 3
    IS_REAL_TIME = False
    DELAY_TIME = 1
    weights_file = '/content/drive/MyDrive/ECE570Project/ESPCNV1/ESPCN_x3.pb'
    path = '/content/drive/MyDrive/ECE570Project/ESPCNV1/'
    videos_name = [x for x in listdir(path) if is_video_file(x)]
    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Build the model from UPSCALE_FACTOR so the network and the output frame
    # size computed below cannot drift apart.
    model = ESPCN(scale_factor=UPSCALE_FACTOR).to(device)

    # Copy the pretrained parameters into the model; an unexpected key in the
    # checkpoint is treated as an error.
    # NOTE(review): torch.load unpickles the file - only load weights obtained
    # from a trusted source.
    state_dict = model.state_dict()
    for n, p in torch.load(weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict:
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)
    model.eval()

    out_path = '/content/drive/MyDrive/ECE570Project/ESPCNV1/results/SRF_' + str(UPSCALE_FACTOR) + '/'
    os.makedirs(out_path, exist_ok=True)

    for video_name in tqdm(videos_name, desc='convert LR videos to HR videos'):
        videoCapture = cv2.VideoCapture(os.path.join(path, video_name))
        videoWriter = None
        try:
            if not IS_REAL_TIME:
                fps = videoCapture.get(cv2.CAP_PROP_FPS)
                size = (int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH)) * UPSCALE_FACTOR,
                        int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)) * UPSCALE_FACTOR)
                # splitext keeps dots inside the base name; '.avi' is a video
                # container matching the MPEG fourcc (the previous '.mp3'
                # extension denotes an audio file).
                base_name = os.path.splitext(video_name)[0]
                output_name = os.path.join(out_path, base_name + '.avi')
                videoWriter = cv2.VideoWriter(output_name, cv2.VideoWriter_fourcc(*'MPEG'), fps, size)
                if not videoWriter.isOpened():
                    # Without this check every write() would be dropped silently.
                    tqdm.write('Could not open video writer for ' + output_name + '; skipping ' + video_name)
                    continue

            # read frame
            success, frame = videoCapture.read()
            while success:
                out_img = _super_resolve_frame(model, frame, device)

                if IS_REAL_TIME:
                    cv2.imshow('LR Video ', frame)
                    cv2.imshow('SR Video ', out_img)
                    cv2.waitKey(DELAY_TIME)
                else:
                    # save video
                    videoWriter.write(out_img)
                # next frame
                success, frame = videoCapture.read()
        finally:
            # Release handles so the output file is finalised and the input
            # is closed even if a frame fails to process.
            videoCapture.release()
            if videoWriter is not None:
                videoWriter.release()

    if IS_REAL_TIME:
        cv2.destroyAllWindows()

convert LR videos to HR videos: 100%|██████████| 1/1 [22:53<00:00, 1373.17s/it]
