In [None]:
from __future__ import print_function
import os
import torch
import argparse
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import time

from math import log10
import numpy as np
from torch.autograd import Variable
from torch.utils.data import DataLoader
# from data_utils import DatasetFromH5_SFSR
# from model import Net_SRCNN
!pip install tensorboard_logger
from tensorboard_logger import configure, log_value

# from data_utils import DatasetFromH5_MFSR
# from model import Net_VSRNet

import torch.nn.functional as F
import torch.nn.init as init
import matplotlib.pyplot as plt


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorboard_logger
  Downloading tensorboard_logger-0.1.0-py2.py3-none-any.whl (17 kB)
Installing collected packages: tensorboard_logger
Successfully installed tensorboard_logger-0.1.0


## ESRGAN

In [None]:
from torch.utils.data.dataset import Dataset
import h5py

class DatasetFromH5_MFSR(Dataset):
    """Paired multi-frame input / single-frame target samples read from two HDF5 files.

    Each file must expose a dataset under the key ``'data'``. Sample ``i`` of
    the input file is paired with sample ``i`` of the target file, and the
    length of the dataset is the first dimension of the input data.
    """

    def __init__(self, image_dataset_dir, target_dataset_dir, upscale_factor, input_transform=None, target_transform=None):
        """Open both HDF5 files read-only and keep references to their ``'data'`` entries.

        Args:
            image_dataset_dir: path of the HDF5 file holding the network inputs.
            target_dataset_dir: path of the HDF5 file holding the targets.
            upscale_factor: accepted for interface compatibility; not used here.
            input_transform: stored on the instance; not applied by ``__getitem__``.
            target_transform: stored on the instance; not applied by ``__getitem__``.
        """
        super(DatasetFromH5_MFSR, self).__init__()

        # The handles are never closed in this class; samples are indexed
        # from the datasets one at a time in __getitem__.
        lr_file = h5py.File(image_dataset_dir, 'r')
        hr_file = h5py.File(target_dataset_dir, 'r')

        self.image_datasets = lr_file['data']
        self.target_datasets = hr_file['data']
        self.total_count = self.image_datasets.shape[0]

        self.input_transform, self.target_transform = input_transform, target_transform

    def __getitem__(self, index):
        """Return ``(image, target)`` as float32 tensors for sample ``index``.

        ``image`` is the full stack stored at ``index``; ``target`` is channel 2
        only (kept as a length-1 axis) divided by 255.
        """
        # Per the original author's note the input patch is already a float
        # in [0, 1] (bicubic-upscaled LR), so it is only cast, not rescaled.
        frames = self.image_datasets[index, :, :, :].astype(np.float32)

        # Per the original author's note the HR target is stored as uint8 in
        # [0, 255]; list-indexing with [2] keeps the channel axis.
        center = self.target_datasets[index, [2], :, :].astype(np.float32) / 255.0

        return torch.from_numpy(frames), torch.from_numpy(center)

    def __len__(self):
        """Number of samples (first dimension of the input dataset)."""
        return self.total_count

In [None]:
data_dir = "./data"

downloads_dir = data_dir + '/downloads'
datasets_dir = data_dir + '/datasets'
models_dir = data_dir + '/models'
pretrained_models = data_dir + '/pretrained_models'

os.makedirs(downloads_dir, exist_ok=True)
os.makedirs(datasets_dir, exist_ok=True)
os.makedirs(models_dir, exist_ok=True)
os.makedirs(pretrained_models, exist_ok=True)

uf4_train_dir = datasets_dir + '/uf4_train'
uf4_val_dir = datasets_dir + '/uf4_val'

srrnet_train_lr = uf4_train_dir + '/srrnet_train_lr.h5'
srrnet_train_hr = uf4_train_dir + '/srrnet_train_hr.h5'

srrnet_val_lr = uf4_val_dir + '/srrnet_val_lr.h5'
srrnet_val_hr = uf4_val_dir + '/srrnet_val_hr.h5'

!wget https://www.dropbox.com/sh/1jz9zeer9wxetx2/AACDZmHK7d2JQi0ADaoliM04a/uf_4/train/Data_CDVL_LR_Bic_MC_uf_4_ps_72_fn_5_tpn_225000.h5 -O srrnet_train_lr
!wget https://www.dropbox.com/sh/1jz9zeer9wxetx2/AACmrvoqkXXnZTXUFsWvNDCsa/uf_4/train/Data_CDVL_HR_uf_4_ps_72_fn_5_tpn_225000.h5 -O srrnet_train_hr

!wget https://www.dropbox.com/sh/1jz9zeer9wxetx2/AADJnJmRvFxmf7sxEk5G0Uuma/uf_4/val/Data_CDVL_LR_Bic_MC_uf_4_ps_72_fn_5_tpn_45000.h5 -O srrnet_val_lr
!wget https://www.dropbox.com/sh/1jz9zeer9wxetx2/AAChoVG4fLqdpsmSuq9wrEvFa/uf_4/val/Data_CDVL_HR_uf_4_ps_72_fn_5_tpn_45000.h5 -O srrnet_val_hr



In [None]:
#import pickle

#from google.colab import drive
#drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#with open('/content/drive/MyDrive/train_subset_12800.pkl', 'rb') as f:
   # subset_train = pickle.load(f)


In [None]:
#with open('/content/drive/MyDrive/val_subset_12800.pkl', 'rb') as f:
   # subset_val = pickle.load(f)

In [None]:
# Keep only the first 256 samples of each split by slicing the stored arrays
# along their first axis.
# NOTE(review): `subset_val` and `subset_train` are not defined by any active
# code visible in this notebook (the pickle-loading cells above are commented
# out), so on a fresh kernel this cell raises NameError — confirm where these
# objects are supposed to come from.
subset_val.image_datasets = subset_val.image_datasets[0:256,:,:,:]
subset_val.target_datasets = subset_val.target_datasets[0:256,:,:,:]

subset_train.image_datasets = subset_train.image_datasets[0:256,:,:,:]
subset_train.target_datasets = subset_train.target_datasets[0:256,:,:,:]




In [None]:
# Loader configuration.
upscale_factor = 4
threads = 1
batchSize = 64 #256

# Both splits share the same loader settings; neither is shuffled.
_loader_options = dict(num_workers=threads, batch_size=batchSize, shuffle=False)

train_loader = DataLoader(dataset=subset_train, **_loader_options)
val_loader = DataLoader(dataset=subset_val, **_loader_options)

## Get the pretrained ESRGAN (RRDBNet)

In [None]:

def make_layer(block, n_layers):
    """Chain ``n_layers`` freshly constructed blocks into one ``nn.Sequential``.

    Args:
        block: zero-argument callable returning a new module on every call.
        n_layers: how many times ``block`` is called.

    Returns:
        ``nn.Sequential`` holding the modules in construction order.
    """
    return nn.Sequential(*[block() for _ in range(n_layers)])


class ResidualDenseBlock_5C(nn.Module):
    """Densely connected stack of five 3x3 convolutions with a scaled skip.

    Every convolution sees the block input concatenated (along channels) with
    the activations of all earlier convolutions. The fifth convolution maps
    back to ``nf`` channels; its output is scaled by 0.2 and added to the input.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        """
        Args:
            nf: number of channels entering and leaving the block.
            gc: growth channels, i.e. channels produced by convs 1-4.
            bias: whether the convolutions carry a bias term.
        """
        super(ResidualDenseBlock_5C, self).__init__()
        # All five layers are 3x3, stride 1, padding 1 (spatial size preserved).
        conv_opts = dict(kernel_size=3, stride=1, padding=1, bias=bias)

        self.conv1 = nn.Conv2d(nf, gc, **conv_opts)
        self.conv2 = nn.Conv2d(nf + gc, gc, **conv_opts)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, **conv_opts)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, **conv_opts)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, **conv_opts)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # No custom weight initialisation is applied; the layers keep the
        # PyTorch defaults until a state dict is loaded.

    def forward(self, x):
        """Apply the dense block; the result has the same shape as ``x``."""
        # `dense` grows by one activation per layer and is re-concatenated
        # before each subsequent convolution.
        dense = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            dense.append(self.lrelu(conv(torch.cat(dense, 1))))

        residual = self.conv5(torch.cat(dense, 1))
        return residual * 0.2 + x


class RRDB(nn.Module):
    '''Residual in Residual Dense Block.

    Three ``ResidualDenseBlock_5C`` modules applied in sequence; the result is
    scaled by 0.2 and added back to the block input.
    '''

    def __init__(self, nf, gc=32):
        """
        Args:
            nf: channel count of the block input / output.
            gc: growth channels forwarded to each dense block.
        """
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        """Run the three dense blocks, then add the scaled result to ``x``."""
        out = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            out = dense_block(out)
        return out * 0.2 + x


class RRDBNet(nn.Module):
    """x4 super-resolution network: RRDB trunk followed by two x2 upsampling stages.

    Layout: first conv -> ``nb`` RRDB blocks -> trunk conv (added to the
    first-conv features) -> two (nearest x2 upsample + conv + LeakyReLU)
    stages -> HR conv + LeakyReLU -> last conv.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        """
        Args:
            in_nc: input channels.
            out_nc: output channels.
            nf: feature channels used throughout the network.
            nb: number of RRDB blocks in the trunk.
            gc: growth channels inside each dense block.
        """
        super(RRDBNet, self).__init__()

        def conv3x3(c_in, c_out):
            # Every convolution in this network is 3x3 / stride 1 / pad 1 with bias.
            return nn.Conv2d(c_in, c_out, 3, 1, 1, bias=True)

        self.conv_first = conv3x3(in_nc, nf)
        self.RRDB_trunk = make_layer(functools.partial(RRDB, nf=nf, gc=gc), nb)
        self.trunk_conv = conv3x3(nf, nf)
        #### upsampling
        self.upconv1 = conv3x3(nf, nf)
        self.upconv2 = conv3x3(nf, nf)
        self.HRconv = conv3x3(nf, nf)
        self.conv_last = conv3x3(nf, out_nc)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Return the super-resolved output; both spatial dims grow by 4x."""
        shallow = self.conv_first(x)
        # Global residual: trunk output is added to the shallow features.
        fea = shallow + self.trunk_conv(self.RRDB_trunk(shallow))

        # Two identical stages, each doubling the resolution.
        for upconv in (self.upconv1, self.upconv2):
            fea = F.interpolate(fea, scale_factor=2, mode='nearest')
            fea = self.lrelu(upconv(fea))

        return self.conv_last(self.lrelu(self.HRconv(fea)))


In [None]:
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Load the checkpoint file from the current working directory.
# NOTE(review): the next cell repeats this exact load, so this cell looks
# redundant — confirm before removing.
pretrained_net = torch.load('RRDB_ESRGAN_x4.pth')

In [None]:
path = 'RRDB_ESRGAN_x4.pth'

pretrained_net = torch.load('RRDB_ESRGAN_x4.pth')

# A freshly built network and its state dict (neither is loaded into here).
crt_model = RRDBNet(3, 3, 64, 23, gc=32)
crt_net = crt_model.state_dict()

# Rebuild the checkpoint mapping with any leading 'module.' removed from the
# keys; entries without that prefix are copied through unchanged.
load_net_clean = {}
for k, v in pretrained_net.items():
    load_net_clean[k[7:] if k.startswith('module.') else k] = v

pretrained_net = load_net_clean


In [None]:
# The pretrained checkpoint (a state dict) and the file that will receive the
# fully pickled model are kept SEPARATE. Saving the model over the checkpoint
# would make this cell fail on re-run, because the reloaded object would be a
# module without `.items()`.
state_dict_path = 'RRDB_ESRGAN_x4.pth'
# `path` is what later cells hand to RRDBNet_video as the `esrgan` argument,
# so it must point at the pickled full model.
path = 'RRDB_ESRGAN_x4_full_model.pth'


srcnn = RRDBNet(3, 3, 64, 23, gc=32)

# NOTE(review): torch.load unpickles arbitrary objects — only use trusted files.
pretrained = torch.load(state_dict_path, map_location=lambda storage, loc: storage)
if isinstance(pretrained, nn.Module):
    # Tolerate a checkpoint file that an earlier run already replaced with a
    # whole pickled model.
    pretrained = pretrained.state_dict()

# Copy every pretrained tensor into the matching entry of the new network;
# an unexpected key is an error rather than being silently skipped.
state_dict = srcnn.state_dict()
for n, p in pretrained.items():
    if n in state_dict.keys():
        state_dict[n].copy_(p)
    else:
        raise KeyError(n)

torch.save(srcnn, path)

In [None]:
class RRDBNet_video(nn.Module):
    """Five-frame fusion network initialised from a pretrained ``RRDBNet``.

    Each of the five input frames (one channel each) is replicated to the
    channel count expected by the first convolution and passed through a
    per-frame first convolution. The weights are shared symmetrically around
    the centre frame (frames 0/4, 1/3, 2). The five feature maps are
    concatenated, fused by ``trunk_conv`` and added to the centre-frame
    features, then sent through the ESRGAN upsampling tail. The result is
    resized back to the input resolution and channel 1 is returned.
    """

    def __init__(self, in_nc, out_nc, nf, nb, esrgan, gc=32):
        """
        Args:
            in_nc: channels expected by the first convolutions.
            out_nc: channels produced by the last convolution (``forward``
                selects index 1, so this must be at least 2).
            nf: feature channels.
            nb: unused (the RRDB trunk is disabled); kept for signature compatibility.
            esrgan: path to a pickled, fully constructed pretrained model that
                exposes ``conv_first``, ``trunk_conv``, ``upconv1``,
                ``upconv2``, ``HRconv`` and ``conv_last``.
            gc: unused; kept for signature compatibility.
        """
        super(RRDBNet_video, self).__init__()

        self.conv_first_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.conv_first_1 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.conv_first_2 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)

        # The fusion layer consumes the five concatenated feature maps, so it
        # is declared with 5 * nf input channels. This matches the weight
        # tensor installed by _initialize_weights.
        self.trunk_conv = nn.Conv2d(5 * nf, nf, 3, 1, 1, bias=True)

        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.esrgan = esrgan

        self._initialize_weights()

    def forward(self, x):
        """Map an ``(N, 5, H, W)`` frame stack to an ``(N, 1, H, W)`` prediction."""
        # Remember the input resolution: the tail upsamples by 4x and the
        # output is brought back to this size. A fixed 72x72 target would
        # only be right for 72x72 training patches.
        in_size = tuple(x.shape[-2:])
        rep = self.conv_first_0.in_channels

        # Symmetric weight sharing around the centre frame (index 2).
        heads = (self.conv_first_0, self.conv_first_1, self.conv_first_2,
                 self.conv_first_1, self.conv_first_0)
        feats = [head(x[:, [i], :, :].repeat(1, rep, 1, 1))
                 for i, head in enumerate(heads)]
        center = feats[2]

        fused = F.relu(torch.cat(feats, 1))
        fea = center + self.trunk_conv(fused)

        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        out = F.interpolate(out, size=in_size, mode='bilinear', align_corners=True)
        # Keep a single output channel (index 1) with its channel axis.
        return out[:, [1], :, :]

    def _initialize_weights(self):
        """Copy the pretrained tensors from the model stored at ``self.esrgan``."""
        # NOTE(review): torch.load unpickles a whole model object, which can
        # execute arbitrary code — only load checkpoints from a trusted source.
        esrgan = torch.load(self.esrgan, map_location=lambda storage, loc: storage) # forcing to load to CPU

        # All three per-frame heads start from the same pretrained first conv.
        for head in (self.conv_first_0, self.conv_first_1, self.conv_first_2):
            head.weight.data = esrgan.conv_first.weight.data.clone()
            head.bias.data = esrgan.conv_first.bias.data.clone()

        # Fusion layer: the pretrained trunk kernel is tiled five times along
        # the input-channel axis and divided by 5.
        self.trunk_conv.bias.data = esrgan.trunk_conv.bias.data.clone()
        self.trunk_conv.weight.data = torch.cat([esrgan.trunk_conv.weight.data] * 5, 1) / 5.0

        # The upsampling tail is copied verbatim.
        for name in ('upconv1', 'upconv2', 'HRconv', 'conv_last'):
            src, dst = getattr(esrgan, name), getattr(self, name)
            dst.bias.data = src.bias.data.clone()
            dst.weight.data = src.weight.data.clone()

# Build the video network from the checkpoint at `path` and pair it with an
# MSE objective.
model = RRDBNet_video(3, 3, 64, 23, esrgan=path)
criterion = nn.MSELoss()

# Move both to the GPU when one is available.
if torch.cuda.is_available():
    model, criterion = model.cuda(), criterion.cuda()

In [None]:
# NOTE(review): train(), val() and checkpoint() are defined in cells further
# down this notebook, and the optimizer used by train() is also created
# later. On a fresh kernel this cell therefore raises NameError unless those
# cells have been run first.
nEpochs = 4
# NOTE(review): this module-level `lr` is not passed to train().
lr = 0.005

# Evaluate and checkpoint the untrained model as epoch 0, then alternate
# training, validation and checkpointing for each epoch.
val(0)
checkpoint(0)
for epoch in range(1, nEpochs + 1):
    train(epoch)
    val(epoch)
    checkpoint(epoch)

===> Epoch 0 Validation CDVL: Avg. Loss: 0.0016, Avg.PSNR:  29.5735 dB, Time: 8.0910
Checkpoint saved to epochs_VSRNet/model_epoch_0.pth
===> Epoch 1 Complete: lr: 0.0001, Avg. Loss: 0.0000, Avg.PSNR:  0.0252 dB, Time: 3.3703
===> Epoch 1 Validation CDVL: Avg. Loss: 0.0010, Avg.PSNR:  31.2599 dB, Time: 1.6148
===> Epoch 2 Complete: lr: 1.0000000000000003e-05, Avg. Loss: 0.0000, Avg.PSNR:  0.0263 dB, Time: 2.3980
===> Epoch 2 Validation CDVL: Avg. Loss: 0.0006, Avg.PSNR:  34.3938 dB, Time: 1.5874
===> Epoch 3 Complete: lr: 1.0000000000000002e-06, Avg. Loss: 0.0000, Avg.PSNR:  0.0273 dB, Time: 2.4134
===> Epoch 3 Validation CDVL: Avg. Loss: 0.0006, Avg.PSNR:  35.0972 dB, Time: 1.5931
===> Epoch 4 Complete: lr: 1.0000000000000002e-07, Avg. Loss: 0.0000, Avg.PSNR:  0.0274 dB, Time: 2.4413
===> Epoch 4 Validation CDVL: Avg. Loss: 0.0006, Avg.PSNR:  35.1694 dB, Time: 1.6272
Checkpoint saved to epochs_VSRNet/model_epoch_4.pth


In [None]:
import torch.nn.functional as F

In [None]:
lr = 0.001

# One parameter group per layer, in network order. Every group uses the
# default rate except the final convolution, which runs at a tenth of it.
_default_rate_layers = (
    model.conv_first_0,
    model.conv_first_1,
    model.conv_first_2,
    model.trunk_conv,
    model.upconv1,
    model.upconv2,
    model.HRconv,
)
_param_groups = [{'params': layer.parameters()} for layer in _default_rate_layers]
_param_groups.append({'params': model.conv_last.parameters(), 'lr': lr/10.0})

optimizer = optim.Adam(_param_groups, lr=lr)

In [None]:
# Initialise tensorboard_logger with the run directory that log_value() calls
# in train() / val() write to.
# NOTE(review): the run name mentions VSRNet and batch size 128, which does
# not match the model and batch size configured in this notebook — confirm.
configure("tensorBoardRuns/VSRNet-relu-mid-fusion-pretrain-sym-x4-batch-128-CDVL-225000x5x72x72-wd")

In [None]:

def train(epoch, max_batches=3):
    """Run one (truncated) training epoch and log the averaged loss / PSNR.

    Reads the notebook globals ``model``, ``criterion``, ``optimizer``,
    ``train_loader`` and ``nEpochs``.

    Args:
        epoch: 1-based epoch index; drives the step learning-rate decay and is
            used as the logging step.
        max_batches: number of batches processed before stopping early
            (default 3, the cap this notebook has always used). ``None``
            consumes the whole loader.
    """
    base_lr = 0.001
    epoch_loss = 0.0
    epoch_psnr = 0.0
    start = time.time()

    # Step decay: divide the rate by 10 every quarter of the schedule. The
    # step is clamped to at least 1 so that nEpochs < 4 does not divide by zero.
    decay_every = max(1, nEpochs // 4)
    lr = base_lr * (0.1 ** (epoch // decay_every))

    # Apply the rate to EVERY parameter group. The last group (the final
    # convolution, created with lr/10 in the optimizer cell) keeps its 10x
    # smaller rate.
    groups = optimizer.param_groups
    for group in groups[:-1]:
        group['lr'] = lr
    groups[-1]['lr'] = lr / 10.0

    batches_seen = 0
    for batch in train_loader:
        if max_batches is not None and batches_seen >= max_batches:
            break
        batches_seen += 1

        image, target = batch[0], batch[1]
        if torch.cuda.is_available():
            image = image.cuda()
            target = target.cuda()

        optimizer.zero_grad()
        loss = criterion(model(image), target)
        loss_value = loss.item()
        # A zero loss would make log10(1 / loss) blow up; report 100 dB then,
        # the same convention the psnr() helper in this notebook uses.
        epoch_psnr += 10 * log10(1 / loss_value) if loss_value > 0 else 100.0
        epoch_loss += loss_value
        loss.backward()
        optimizer.step()

    end = time.time()

    # Average over the batches actually processed, not over len(train_loader).
    denom = max(batches_seen, 1)
    print("===> Epoch {} Complete: lr: {}, Avg. Loss: {:.4f}, Avg.PSNR:  {:.4f} dB, Time: {:.4f}".format(epoch, lr, epoch_loss / denom, epoch_psnr / denom, (end-start)))

    log_value('train_loss', epoch_loss / denom, epoch)
    log_value('train_psnr', epoch_psnr / denom, epoch)

In [None]:
def val(epoch, max_batches=3):
    """Evaluate on (a prefix of) the validation loader and log per-frame averages.

    Reads the notebook globals ``model``, ``criterion`` and ``val_loader``.

    Args:
        epoch: value used as the logging step and in the printed summary.
        max_batches: number of batches evaluated before stopping early
            (default 3, the cap this notebook has always used). ``None``
            consumes the whole loader.
    """
    #   Validation on CDVL val set
    avg_psnr = 0.0
    avg_mse = 0.0
    frame_count = 0
    start = time.time()

    batches_seen = 0
    # No gradients are needed for evaluation; disabling autograd avoids
    # keeping the activation graph alive for every prediction.
    with torch.no_grad():
        for batch in val_loader:
            if max_batches is not None and batches_seen >= max_batches:
                break
            batches_seen += 1

            image, target = batch[0], batch[1]
            if torch.cuda.is_available():
                image = image.cuda()
                target = target.cuda()

            prediction = model(image)

            # Metrics are accumulated per frame, not per batch.
            for i in range(0, image.shape[0]):
                mse = criterion(prediction[i], target[i]).item()
                # Zero error would divide by zero; report 100 dB in that case.
                avg_psnr += 10 * log10(1 / mse) if mse > 0 else 100.0
                avg_mse += mse
                frame_count += 1

    end = time.time()

    # Guard against an empty loader.
    denom = max(frame_count, 1)
    print("===> Epoch {} Validation CDVL: Avg. Loss: {:.4f}, Avg.PSNR:  {:.4f} dB, Time: {:.4f}".format(epoch, avg_mse / denom, avg_psnr / denom, (end-start)))

    log_value('val_loss', avg_mse / denom, epoch)
    log_value('val_psnr', avg_psnr / denom, epoch)

In [None]:
def checkpoint(epoch):
    """Pickle the global ``model`` to ``epochs_VSRNet/`` on every fourth epoch.

    Nothing happens when ``epoch`` is not a multiple of 4 (0 counts as one).
    """
    if epoch % 4 != 0:
        return

    out_dir = "epochs_VSRNet"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    model_out_path = "epochs_VSRNet/" + "model_epoch_{}.pth".format(epoch)
    # The whole module object is saved, not just its state dict.
    torch.save(model, model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))

In [None]:
# Schedule length; train() reads this global when computing its rate decay.
nEpochs = 2
# NOTE(review): this module-level `lr` is not passed to train().
lr = 0.001

# Evaluate and checkpoint once before any training (logged as epoch 0), then
# alternate training, validation and checkpointing. checkpoint() only writes
# a file when the epoch index is a multiple of 4.
val(0)
checkpoint(0)
for epoch in range(1, nEpochs + 1):
    train(epoch)
    val(epoch)
    checkpoint(epoch)

### Let's test with a video!

In [None]:
# Intended locations of the x4 test HDF5 files.
uf4_test_dir = datasets_dir + '/uf4_test'

vsrnet_test_lr = uf4_test_dir + '/vsrnet_test_lr.h5'
vsrnet_test_hr = uf4_test_dir + '/vsrnet_test_hr.h5'

# NOTE(review): `-O` is given literal file names, so the downloads are written
# to ./vsrnet_test_lr and ./vsrnet_test_hr in the working directory — the
# path variables above are not used. The evaluation cell later opens exactly
# those literal names, so the two must be changed together if at all.
!wget https://www.dropbox.com/sh/1jz9zeer9wxetx2/AAADzBQ7iA492oQ26ag67ZsKa/uf_4/test/LR_Bic_MC/scene_30.h5 -O vsrnet_test_lr
!wget https://www.dropbox.com/sh/1jz9zeer9wxetx2/AADSka3PgSR5EuCt9ByugfY6a/uf_4/test/HR/scene_30.h5 -O vsrnet_test_hr


--2023-05-07 20:14:54--  https://www.dropbox.com/sh/1jz9zeer9wxetx2/AAADzBQ7iA492oQ26ag67ZsKa/uf_4/test/LR_Bic_MC/scene_30.h5
Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112
Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: /sh/raw/1jz9zeer9wxetx2/AAADzBQ7iA492oQ26ag67ZsKa/uf_4/test/LR_Bic_MC/scene_30.h5 [following]
--2023-05-07 20:14:54--  https://www.dropbox.com/sh/raw/1jz9zeer9wxetx2/AAADzBQ7iA492oQ26ag67ZsKa/uf_4/test/LR_Bic_MC/scene_30.h5
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://ucfb90e0bc5f994783b10cd5a537.dl.dropboxusercontent.com/cd/0/inline/B7mG7htFWiqRZrfWwijhSTkoLLzoY2WNv-5KtuIo-U982Mw_DOwm4X2RnBVaSN1C9ngt47S2mfVBOLsl2Y48x3fMPUw0yMoAtj0z_MQ7oMwa1pmjOpwVYAEJcH5nxVgbGbhpvSKOtcjELOJRy-l0W-cty5D5tRm1BJa51agx826X9Q/file# [following]
--2023-05-07 20:14:55--  https://ucfb90e0b

In [None]:
# Locations of the downloaded test files. The low-resolution path must point
# at the LR download (it used to duplicate the HR path).
path_LR_Bic_MC = './vsrnet_test_lr'
path_HR = './vsrnet_test_hr'
# NOTE(review): the file downloaded above is scene_30 — confirm this name.
# In this notebook the list is only used for its length.
videos_h5_name = ['scene_50.h5']
videos_h5_name.sort()

In [None]:
# One slot per test video for each reported metric, all starting at zero.
h5_len = len(videos_h5_name)
model_PSNR, model_SSIM, bicubic_PSNR, bicubic_SSIM, model_time = (
    np.zeros(h5_len) for _ in range(5)
)

In [None]:
# Folder that receives rendered videos. Created if missing; exist_ok avoids
# the check-then-create race and matches how the data folders are created at
# the top of the notebook.
out_path = './'
os.makedirs(out_path, exist_ok=True)

In [None]:
import numpy
import math

def psnr(img1, img2):
    """Peak signal-to-noise ratio in dB for images on a 0-255 scale.

    Returns the sentinel ``100`` when the two images are identical (zero
    mean squared error) instead of an infinite value.
    """
    PIXEL_MAX = 255.0
    mse = numpy.mean(numpy.square(img1 - img2))
    if mse == 0:
        return 100
    rmse = math.sqrt(mse)
    return 20 * math.log10(PIXEL_MAX / rmse)

In [None]:
from scipy.ndimage import gaussian_filter

from numpy.lib.stride_tricks import as_strided as ast

"""
Hat tip: http://stackoverflow.com/a/5078155/1828289
"""
def block_view(A, block=(3, 3)):
    """Return a 4-D strided view of the 2-D array ``A`` split into tiles.

    The result has shape ``(rows // bh, cols // bw, bh, bw)``, so
    ``view[i, j]`` is the ``(i, j)``-th non-overlapping ``block``-sized tile.
    No error checking is done; the view is only meaningful when ``block``
    is compatible with the shape of ``A``. ``block`` must be a tuple.
    """
    tile_h, tile_w = block[0], block[1]
    # Outer two axes step from tile to tile; inner two reuse A's own strides.
    grid_shape = (A.shape[0] // tile_h, A.shape[1] // tile_w)
    grid_strides = (tile_h * A.strides[0], tile_w * A.strides[1])
    return ast(A, shape=grid_shape + block, strides=grid_strides + A.strides)


def ssim(img1, img2, C1=0.01**2, C2=0.03**2):
    """Mean structural similarity over non-overlapping 4x4 blocks.

    For every block the statistics are computed from block MEANS:
    ``mu = sum / n``, ``var1 + var2 = (sum(x^2) + sum(y^2)) / n - mu1^2 - mu2^2``
    and ``cov = sum(x*y) / n - mu1 * mu2``. Using raw, un-normalised sums
    instead does not yield SSIM.

    Args:
        img1, img2: 2-D arrays of equal shape. Trailing rows / columns that
            do not fill a complete 4x4 block are ignored.
        C1, C2: stabilising constants of the SSIM formula. The defaults are
            the usual ``(0.01 * L)^2`` and ``(0.03 * L)^2`` for a data range
            ``L = 1``; pass larger values for 0-255 data if stabilisation
            matters.

    Returns:
        The mean of the per-block SSIM values.
    """
    BLOCK = 4
    n = float(BLOCK * BLOCK)

    def _tiles(img):
        # Float copy cropped to whole blocks, reshaped so that axes 1 and 3
        # index the pixels inside each block.
        arr = numpy.asarray(img, dtype=numpy.float64)
        rows = (arr.shape[0] // BLOCK) * BLOCK
        cols = (arr.shape[1] // BLOCK) * BLOCK
        return arr[:rows, :cols].reshape(rows // BLOCK, BLOCK, cols // BLOCK, BLOCK)

    b1, b2 = _tiles(img1), _tiles(img2)
    in_block = (1, 3)

    mu1 = b1.sum(axis=in_block) / n
    mu2 = b2.sum(axis=in_block) / n

    # Sum of the two block variances, and the block covariance.
    vari = (b1 * b1).sum(axis=in_block) / n + (b2 * b2).sum(axis=in_block) / n - mu1 * mu1 - mu2 * mu2
    covar = (b1 * b2).sum(axis=in_block) / n - mu1 * mu2

    ssim_map = (2 * mu1 * mu2 + C1) * (2 * covar + C2) / ((mu1 * mu1 + mu2 * mu2 + C1) * (vari + C2))
    return numpy.mean(ssim_map)

# NOTE(review): the original FIXME on this function did not say what was
# wrong. One confirmed problem is fixed below: with integer inputs,
# gaussian_filter returns integer arrays, so every local statistic was
# quantised. The defaults for C1 / C2 still assume a data range of 1 — confirm
# whether callers on 0-255 data should pass scaled constants.
def ssim_exact(img1, img2, sd=1.5, C1=0.01**2, C2=0.03**2):
    """Mean SSIM using Gaussian-weighted local statistics.

    Args:
        img1, img2: arrays of equal shape; converted to float64 internally.
        sd: standard deviation of the Gaussian window.
        C1, C2: stabilising constants of the SSIM formula.

    Returns:
        The mean of the per-pixel SSIM map.
    """
    # gaussian_filter produces output in the dtype of its input, so the
    # images are promoted to float before any filtering.
    img1 = numpy.asarray(img1, dtype=numpy.float64)
    img2 = numpy.asarray(img2, dtype=numpy.float64)

    mu1 = gaussian_filter(img1, sd)
    mu2 = gaussian_filter(img2, sd)
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2

    # Local (co)variances via E[xy] - E[x]E[y].
    sigma1_sq = gaussian_filter(img1 * img1, sd) - mu1_sq
    sigma2_sq = gaussian_filter(img2 * img2, sd) - mu2_sq
    sigma12 = gaussian_filter(img1 * img2, sd) - mu1_mu2

    ssim_num = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2))

    ssim_den = ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    ssim_map = ssim_num / ssim_den
    return numpy.mean(ssim_map)

In [None]:
from tqdm import tqdm
import cv2

video_idx = 0
#   Read both h5 files fully into memory; the handles are closed as soon as
#   the arrays have been copied out.
with h5py.File('./vsrnet_test_lr', 'r') as LR_Bic_MC_h5_file:
    LR_Bic_MC_h5_data = LR_Bic_MC_h5_file['data'][()]
with h5py.File('./vsrnet_test_hr', 'r') as HR_h5_file:
    HR_h5_data = HR_h5_file['data'][()]

# transpose to correct order (all four axes reversed, so frames come first)
HR_h5_data = np.transpose(HR_h5_data, (3, 2, 1, 0))
LR_Bic_MC_h5_data = np.transpose(LR_Bic_MC_h5_data, (3, 2, 1, 0))

frame_number = LR_Bic_MC_h5_data.shape[0]

# True: show each result inline with matplotlib. False: write an .avi file.
IS_REAL_TIME = True

# NOTE(review): the file downloaded above is scene_30 — confirm this name.
video_name = 'scene_40'

if not IS_REAL_TIME:
    fps = 30
    size = (LR_Bic_MC_h5_data.shape[3], LR_Bic_MC_h5_data.shape[2])
    output_name = out_path + video_name.split('.')[0] + '.avi'
    videoWriter = cv2.VideoWriter(output_name, cv2.VideoWriter_fourcc('M','J','P','G'), fps, size)
#            videoWriter = cv2.VideoWriter(output_name, cv2.VideoWriter_fourcc(*'XVID'), fps, size)

#   Prepare to save PSNR and SSIM of the current video
#   Each value corresponding to one test frame
model_PSNR_cur   = np.zeros(frame_number)
model_SSIM_cur   = np.zeros(frame_number)
bicubic_PSNR_cur = np.zeros(frame_number)
bicubic_SSIM_cur = np.zeros(frame_number)
model_time_cur   = np.zeros(frame_number)

for idx in tqdm(range(0, frame_number)):
    img_HR = HR_h5_data[idx, 0, :, :] #2D
    img_LR_Bic_MC = LR_Bic_MC_h5_data[idx, :, :, :] #3D (original note: 5x1080x1920)

    # Reshape to 4D (add the batch axis)
    img_LR_Bic_MC = img_LR_Bic_MC.reshape((1, img_LR_Bic_MC.shape[0], img_LR_Bic_MC.shape[1], img_LR_Bic_MC.shape[2]))

    img_LR_Bic_MC = img_LR_Bic_MC.astype(np.float32)

    img_LR_Bic_MC =  torch.from_numpy(img_LR_Bic_MC)

    if torch.cuda.is_available():
        img_LR_Bic_MC = img_LR_Bic_MC.cuda()

    start = time.time()
    if img_LR_Bic_MC.sum() != 0:
        # Inference only: without no_grad the autograd graph of every
        # full-resolution frame is kept alive, which exhausts memory.
        # (A stray `break` used to follow the model call and aborted the loop
        # after the first inference, so no frame was ever scored or shown.)
        with torch.no_grad():
            img_HR_net = model(img_LR_Bic_MC)
    else:
        # All-zero (black) input: pass the centre frame through unchanged.
        img_HR_net = img_LR_Bic_MC[:,2,:,:]
        img_HR_net = img_HR_net.reshape((1, 1, img_HR.shape[0], img_HR.shape[1])) # reshape to 1x1xHxW

    end = time.time() # measure the computation time

    # Network output -> uint8 image on a 0-255 scale
    img_HR_net = img_HR_net.cpu()
    img_HR_net = img_HR_net.data[0].numpy()
    img_HR_net *= 255.0
    img_HR_net = img_HR_net.clip(0, 255)
    img_HR_net = img_HR_net.astype(np.uint8)

    # Bicubic baseline: the centre frame of the input stack
    img_LR_Bic_MC = img_LR_Bic_MC.cpu()
    img_LR_Bic = img_LR_Bic_MC[:, 2, :, :] # center frame
    img_LR_Bic = img_LR_Bic.data[0].numpy()
    img_LR_Bic *= 255.0
    img_LR_Bic = img_LR_Bic.clip(0, 255)
    img_LR_Bic = img_LR_Bic.astype(np.uint8)

    img_HR = img_HR.reshape((1, img_HR.shape[0], img_HR.shape[1]))
    img_LR_Bic = img_LR_Bic.reshape((1, img_LR_Bic.shape[0], img_LR_Bic.shape[1]))


    # Metrics are computed on int arrays so uint8 differences cannot wrap around.
    model_PSNR_cur[idx]   = psnr((img_HR).reshape(img_HR.shape[1], img_HR.shape[2]).astype(int), (img_HR_net).reshape(img_HR_net.shape[1], img_HR_net.shape[2]).astype(int))
    #model_SSIM_cur[idx]   = ssim((img_HR).reshape(img_HR.shape[1], img_HR.shape[2]).astype(int), (img_HR_net).reshape(img_HR_net.shape[1], img_HR_net.shape[2]).astype(int))
    bicubic_PSNR_cur[idx] = psnr((img_HR).reshape(img_HR.shape[1], img_HR.shape[2]).astype(int), (img_LR_Bic).reshape(img_LR_Bic.shape[1], img_LR_Bic.shape[2]).astype(int))
    bicubic_SSIM_cur[idx] = ssim((img_HR).reshape(img_HR.shape[1], img_HR.shape[2]).astype(int), (img_LR_Bic).reshape(img_LR_Bic.shape[1], img_LR_Bic.shape[2]).astype(int))
    model_time_cur[idx]   = (end-start)

    # Repeat to 3 channels to save and display
    img_HR_net = np.repeat(img_HR_net, 3, axis=0)
    img_HR_net = np.transpose(img_HR_net, (1, 2, 0))

    if IS_REAL_TIME:
        plt.imshow(img_HR_net, cmap = 'gray')
        plt.show()

#                cv2.imshow('LR Video ', img_LR_Bic)
#                cv2.imshow('SR Video ', img_HR_net)
#                cv2.waitKey(DELAY_TIME)
    else:
        # save video
        videoWriter.write(img_HR_net)

# Done video writing. The writer only exists when a file is being produced.
if not IS_REAL_TIME:
    videoWriter.release()

# Save PSNR and SSIM
# Exclude PSNR = 100 cases (caused by black frames)
cal_flag = (model_PSNR_cur != 100)
model_PSNR[video_idx]   = np.mean(model_PSNR_cur[cal_flag])
model_SSIM[video_idx]   = np.mean(model_SSIM_cur[cal_flag])
bicubic_PSNR[video_idx] = np.mean(bicubic_PSNR_cur[cal_flag])
bicubic_SSIM[video_idx] = np.mean(bicubic_SSIM_cur[cal_flag])
model_time[video_idx]   = np.mean(model_time_cur[cal_flag])

print("===> Test on Video Idx: " + str(video_idx) +" Complete: Model PSNR: {:.4f} dB, Model SSIM: {:.4f} , Bicubic PSNR:  {:.4f} dB, Bicubic SSIM: {:.4f} , Average time: {:.4f}"
  .format(model_PSNR[video_idx], model_SSIM[video_idx], bicubic_PSNR[video_idx], bicubic_SSIM[video_idx], model_time[video_idx]*1000))
video_idx += 1

  0%|          | 0/14 [00:00<?, ?it/s]


OutOfMemoryError: ignored