In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import torch.utils.data as data
import torch
import h5py

class DatasetFromHdf5(data.Dataset):
    def __init__(self, file_path):
        super(DatasetFromHdf5, self).__init__()
        hf = h5py.File(file_path)
        self.data = hf.get('data')
        self.target = hf.get('label')

    def __getitem__(self, index):
        return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float()
        
    def __len__(self):
        return self.data.shape[0]

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
#dtype = torch.FloatTensor
dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

class MemNet(nn.Module):
    def __init__(self, in_channels, channels, num_memblock, num_resblock):
        super(MemNet, self).__init__()
        self.feature_extractor = BNReLUConv(in_channels, channels, True)  #FENet: staic(bn)+relu+conv1
        self.reconstructor = BNReLUConv(channels, in_channels, True)      #ReconNet: static(bn)+relu+conv 
        self.dense_memory = nn.ModuleList(
            [MemoryBlock(channels, num_resblock, i+1) for i in range(num_memblock)]
        )
        #ModuleList can be indexed like a regular Python list, but modules it contains are 
        #properly registered, and will be visible by all Module methods.
        
        
        self.weights = nn.Parameter((torch.ones(1, num_memblock)/num_memblock), requires_grad=True)  
        #output1,...,outputn corresponding w1,...,w2


    #Multi-supervised MemNet architecture
    def forward(self, x):
        residual = x
        out = self.feature_extractor(x)
        w_sum=self.weights.sum(1)  
        mid_feat=[]   # A lsit contains the output of each memblock
        ys = [out]  #A list contains previous memblock output(long-term memory)  and the output of FENet
        for memory_block in self.dense_memory:
            out = memory_block(out, ys)  #out is the output of GateUnit  channels=64
            mid_feat.append(out);
        #pred = Variable(torch.zeros(x.shape).type(dtype),requires_grad=False)
        pred = (self.reconstructor(mid_feat[0])+residual)*self.weights.data[0][0]/w_sum
        for i in range(1,len(mid_feat)):
            pred = pred + (self.reconstructor(mid_feat[i])+residual)*self.weights.data[0][i]/w_sum

        return pred

    #Base MemNet architecture
    '''
    def forward(self, x):
        residual = x   #input data 1 channel
        out = self.feature_extractor(x)
        ys = [out]  #A list contains previous memblock output and the output of FENet
        for memory_block in self.dense_memory:
            out = memory_block(out, ys)
        out = self.reconstructor(out)
        out = out + residual
        
        return out
    '''


class MemoryBlock(nn.Module):
    """Note: num_memblock denotes the number of MemoryBlock currently"""
    def __init__(self, channels, num_resblock, num_memblock):
        super(MemoryBlock, self).__init__()
        self.recursive_unit = nn.ModuleList(
            [ResidualBlock(channels) for i in range(num_resblock)]
        )
        #self.gate_unit = BNReLUConv((num_resblock+num_memblock) * channels, channels, True)  #kernel 3x3
        self.gate_unit = GateUnit((num_resblock+num_memblock) * channels, channels, True)   #kernel 1x1

    def forward(self, x, ys):
        """ys is a list which contains long-term memory coming from previous memory block
        xs denotes the short-term memory coming from recursive unit
        """
        xs = []
        residual = x
        for layer in self.recursive_unit:
            x = layer(x)
            xs.append(x)
       
        
        #gate_out = self.gate_unit(torch.cat([xs,ys], dim=1))
        gate_out = self.gate_unit(torch.cat(xs+ys, 1))  #where xs and ys are list, so concat operation is xs+ys
        ys.append(gate_out)
        return gate_out


class ResidualBlock(torch.nn.Module):
    """ResidualBlock
    introduced in: https://arxiv.org/abs/1512.03385
    x - Relu - Conv - Relu - Conv - x
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.relu_conv1 = BNReLUConv(channels, channels, True)
        self.relu_conv2 = BNReLUConv(channels, channels, True)
        
    def forward(self, x):
        residual = x
        out = self.relu_conv1(x)
        out = self.relu_conv2(out)
        out = out + residual
        return out


class BNReLUConv(nn.Sequential):
    def __init__(self, in_channels, channels, inplace=True):
        super(BNReLUConv, self).__init__()
        self.add_module('bn', nn.BatchNorm2d(in_channels))
        self.add_module('relu', nn.ReLU(inplace=inplace))  #tureL: direct modified x, false: new object and the modified
        self.add_module('conv', nn.Conv2d(in_channels, channels, 3, 1, 1))  #bias: defautl: ture on pytorch, learnable bias

class GateUnit(nn.Sequential):
    def __init__(self, in_channels, channels, inplace=True):
        super(GateUnit, self).__init__()
        self.add_module('bn',nn.BatchNorm2d(in_channels))
        self.add_module('relu', nn.ReLU(inplace=inplace))
        self.add_module('conv', nn.Conv2d(in_channels, channels,1,1,0))

In [4]:
import os
import torch
import numpy as np

from collections import OrderedDict


def convert_state_dict(state_dict):
    """Converts a state dict saved from a dataParallel module to normal 
       module state_dict inplace
       :param state_dict is the loaded DataParallel model_state
       You probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it 
       without DataParallel. You can either add a nn.DataParallel temporarily in your network for loading purposes, or you can 
       load the weights file, create a new ordered dict without the module prefix, and load it back 
    """
    state_dict_new = OrderedDict()
    #print(type(state_dict))
    for k, v in state_dict.items():
        #print(k)
        name = k[7:] # remove the prefix module.
        # My heart is borken, the pytorch have no ability to do with the problem.
        state_dict_new[name] = v
    return state_dict_new



def calc_patch_size(func):
    def wrapper(args):
        if args.scale == 2:
            args.patch_size = 10
        elif args.scale == 3:
            args.patch_size = 7
        elif args.scale == 4:
            args.patch_size = 6
        else:
            raise Exception('Scale Error', args.scale)
        return func(args)
    return wrapper


def convert_rgb_to_y(img, dim_order='hwc'):
    if dim_order == 'hwc':
        return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
    else:
        return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.


def convert_rgb_to_ycbcr(img, dim_order='hwc'):
    if dim_order == 'hwc':
        y = 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
        cb = 128. + (-37.945 * img[..., 0] - 74.494 * img[..., 1] + 112.439 * img[..., 2]) / 256.
        cr = 128. + (112.439 * img[..., 0] - 94.154 * img[..., 1] - 18.285 * img[..., 2]) / 256.
    else:
        y = 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.
        cb = 128. + (-37.945 * img[0] - 74.494 * img[1] + 112.439 * img[2]) / 256.
        cr = 128. + (112.439 * img[0] - 94.154 * img[1] - 18.285 * img[2]) / 256.
    return np.array([y, cb, cr]).transpose([1, 2, 0])


def convert_ycbcr_to_rgb(img, dim_order='hwc'):
    if dim_order == 'hwc':
        r = 298.082 * img[..., 0] / 256. + 408.583 * img[..., 2] / 256. - 222.921
        g = 298.082 * img[..., 0] / 256. - 100.291 * img[..., 1] / 256. - 208.120 * img[..., 2] / 256. + 135.576
        b = 298.082 * img[..., 0] / 256. + 516.412 * img[..., 1] / 256. - 276.836
    else:
        r = 298.082 * img[0] / 256. + 408.583 * img[2] / 256. - 222.921
        g = 298.082 * img[0] / 256. - 100.291 * img[1] / 256. - 208.120 * img[2] / 256. + 135.576
        b = 298.082 * img[0] / 256. + 516.412 * img[1] / 256. - 276.836
    return np.array([r, g, b]).transpose([1, 2, 0])


def preprocess(img, device):
    img = np.array(img).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(img)
    x = ycbcr[..., 0]
    x /= 255.
    x = torch.from_numpy(x).to(device)
    x = x.unsqueeze(0).unsqueeze(0)
    return x, ycbcr


def calc_psnr(img1, img2):
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))


from scipy.ndimage import gaussian_filter

def calc_ssim(img1, img2, sd=1.5, C1=0.01**2, C2=0.03**2):
    img1 = img1.cpu()
    img2 = img2.cpu()
    mu1 = gaussian_filter(img1, sd)
    mu2 = gaussian_filter(img2, sd)
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = gaussian_filter(img1 * img1, sd) - mu1_sq
    sigma2_sq = gaussian_filter(img2 * img2, sd) - mu2_sq
    sigma12 = gaussian_filter(img1 * img2, sd) - mu1_mu2
    
    ssim_num = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2))
    ssim_den = ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    ssim_map = ssim_num / ssim_den
    mssim = np.mean(ssim_map)
    
    return mssim


class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


In [5]:
import os
import torch
import random
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader


def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every num(step) epochs"""
    lr = lr * (0.1 ** (epoch // step))
    return lr

def train(training_data_loader, optimizer, model, criterion, epoch):
    lr = adjust_learning_rate(optimizer, epoch-1)

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    print("Epoch = {}, lr = {}".format(epoch, optimizer.param_groups[0]["lr"]))

    model.train()

    for iteration, batch in enumerate(training_data_loader, 1):
        input, target = Variable(batch[0]), Variable(batch[1], requires_grad=False)

        if cuda:
            input = input.cuda()
            target = target.cuda()

        #loss = criterion(model(input), target)
        prediction = model(input)
        loss = criterion(prediction, target)
        #print("Outside: input size", input.size(),"prediction_size", prediction.size())
        optimizer.zero_grad()
        loss.backward() 
        nn.utils.clip_grad_norm(model.parameters(),clip) 
        optimizer.step()

        if iteration%100 == 0:
            print("===> Epoch[{}]({}/{}): Loss: {:.10f}".format(epoch, iteration, len(training_data_loader), loss.data[0]))

def save_checkpoint(model, epoch):
    model_out_path = "checkpoint1/" + "model_epoch_{}.pth".format(epoch)
    #state = {"epoch": epoch ,"model": model.state_dict}
    state = {"epoch": epoch ,"model": model.state_dict()}
    if not os.path.exists("checkpoint1/"):
        os.makedirs("checkpoint1/")

    torch.save(state, model_out_path)  # save weights and network architecture

    print("Checkpoint saved to {}".format(model_out_path))

def train_main(batchSize=512, nEpochs=50, lr=0.1, step=20, resume='', seed=123, start_epoch=1, clip=0.4, threads=1, momentum=0.9, weight_decay=1e-4, pretrained='', gpus='0,1,2,3'):
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    print("===> Loading datasets")
    train_set = DatasetFromHdf5("data/SuperResolution/train_291_31_x234.h5")
    training_data_loader = DataLoader(dataset=train_set, num_workers=threads, batch_size=batchSize, shuffle=True)

    print("===> Building model")
    model = MemNet(1, 64, 6, 6)
    criterion = nn.MSELoss(size_average=False)

    print("===> Setting GPU")
    if cuda:
        #model = model.cuda()
        model = torch.nn.DataParallel(model).cuda()  #multi-card data parallel
        criterion = criterion.cuda()

    # optionally resume from a checkpoint
    if resume:
        if os.path.isfile(resume):
            print("=> loading checkpoint '{}'".format(resume))
            checkpoint = torch.load(resume)
            start_epoch = checkpoint["epoch"] + 1
            model.load_state_dict(checkpoint["model"])
        else:
            print("=> no checkpoint found at '{}'".format(resume))

    # optionally copy weights from a checkpoint
    if pretrained:
        if os.path.isfile(pretrained):
            print("=> loading model '{}'".format(pretrained))
            weights = torch.load(pretrained)
            model.load_state_dict(weights['model'].state_dict())
            model.load_state_dict(weights['model'].state_dict())
        else:
            print("=> no model found at '{}'".format(pretrained))  

    print("===> Setting Optimizer")
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

    print("===> Training")
    for epoch in range(start_epoch, nEpochs + 1):
        train(training_data_loader, optimizer, model, criterion, epoch)
        save_checkpoint(model, epoch)


In [6]:
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image


import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image

from torch.autograd import Variable

def test(weights_file, image_file, scale, save=False, debug=False, B=1, U=9, num_features=128):
    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = MemNet(in_channels=1, channels=64, num_memblock=6, num_resblock=6)
    model.load_state_dict(convert_state_dict(torch.load(weights_file)["model"]))

    model.eval()
    model.to(device)

    image = pil_image.open(image_file).convert('RGB')
    image_file = os.path.basename(image_file)

    image_width = (image.width // scale) * scale
    image_height = (image.height // scale) * scale

    hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    lr = hr.resize((hr.width // scale, hr.height // scale), resample=pil_image.BICUBIC)
    bicubic = lr.resize((lr.width * scale, lr.height * scale), resample=pil_image.BICUBIC)

    lr, _ = preprocess(lr, device)
    hr, _ = preprocess(hr, device)
    bicubic, ycbcr = preprocess(bicubic, device)

    with torch.no_grad():
        # Pre-upsampling
        preds = model(bicubic).clamp(0.0, 1.0)

    psnr = calc_psnr(hr, preds)
    ssim = calc_ssim(hr, preds)

    if debug:
        print(f'PSNR/SSIM: {psnr:.2f}/{ssim:.4f}')

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    if save:
        save_path = f'/content/drive/Shareddrives/BTP Meets/results/Set5/{scale}x/{image_file}'
        output.save(save_path.replace('.', '_memnet.'))
    return float(psnr), float(ssim)


In [7]:
import os

def do_test(psnr, ssim, BASE_DIR, save=False, debug=False):
    scales = [2, 3, 4]

    for file in os.listdir(BASE_DIR):
        if file.endswith(".png"):
            image_file_path = os.path.join(BASE_DIR, file)
            if debug:
                print(file)
            for scale in scales:
                if debug:
                    print(f"Scale: {scale}")
                result = test(f'/content/drive/Shareddrives/BTP Meets/models/memnet.pth', image_file_path, scale, save, debug)
                if scale not in psnr:
                    psnr[scale] = []
                if scale not in ssim:
                    ssim[scale] = []
                psnr[scale].append(result[0])
                ssim[scale].append(result[1])
            if debug:
                print()


In [8]:
psnr = {}
ssim = {}
do_test(psnr, ssim, '/content/drive/Shareddrives/BTP Meets/datasets/test/Set5/', True, True)

head.png
Scale: 2
PSNR/SSIM: 35.52/0.8857
Scale: 3
PSNR/SSIM: 34.87/0.8717
Scale: 4
PSNR/SSIM: 32.49/0.7855

butterfly.png
Scale: 2
PSNR/SSIM: 33.28/0.9693
Scale: 3
PSNR/SSIM: 29.20/0.9356
Scale: 4
PSNR/SSIM: 26.52/0.8870

bird.png
Scale: 2
PSNR/SSIM: 41.01/0.9847
Scale: 3
PSNR/SSIM: 36.06/0.9598
Scale: 4
PSNR/SSIM: 32.17/0.9148

baby.png
Scale: 2
PSNR/SSIM: 38.23/0.9624
Scale: 3
PSNR/SSIM: 35.75/0.9358
Scale: 4
PSNR/SSIM: 32.99/0.8859

woman.png
Scale: 2
PSNR/SSIM: 35.57/0.9687
Scale: 3
PSNR/SSIM: 32.33/0.9402
Scale: 4
PSNR/SSIM: 29.48/0.8980



In [9]:
import statistics

scales = [2, 3, 4]
for scale in scales:
    print(f'Avg PSNR/SSIM {scale}x: {statistics.mean(psnr[scale]):.2f}/{statistics.mean(ssim[scale]):.4f}')

Avg PSNR/SSIM 2x: 36.72/0.9542
Avg PSNR/SSIM 3x: 33.64/0.9286
Avg PSNR/SSIM 4x: 30.73/0.8742


In [10]:
scales = [2, 3, 4]

def calc_result(dataset):
    print()
    print(dataset)
    psnr = {}
    ssim = {}
    do_test(psnr, ssim, f'/content/drive/Shareddrives/BTP Meets/datasets/test/{dataset}/')
    for scale in scales:
        print(f'Avg PSNR/SSIM {scale}x: {statistics.mean(psnr[scale]):.2f}/{statistics.mean(ssim[scale]):.4f}')

# calc_result('Set14')
# calc_result('BSDS100')
# calc_result('Manga109')
calc_result('Urban100')

'''
Set14
Avg PSNR/SSIM 2x: 32.54/0.9110
Avg PSNR/SSIM 3x: 29.66/0.8468
Avg PSNR/SSIM 4x: 27.41/0.7648

BSDS100
Avg PSNR/SSIM 2x: 33.10/0.9184
Avg PSNR/SSIM 3x: 29.02/0.8139
Avg PSNR/SSIM 4x: 27.79/0.7489

Manga109
Avg PSNR/SSIM 2x: 36.16/0.9699
Avg PSNR/SSIM 3x: 31.08/0.9312
Avg PSNR/SSIM 4x: 28.28/0.8857

Urban100
Avg PSNR/SSIM 2x: 30.69/0.9153
Avg PSNR/SSIM 3x: 27.66/0.8508
Avg PSNR/SSIM 4x: 24.74/0.7500
'''


Urban100
Avg PSNR/SSIM 2x: 30.69/0.9153
Avg PSNR/SSIM 3x: 27.66/0.8508
Avg PSNR/SSIM 4x: 24.74/0.7500


'\nSet14\nAvg PSNR/SSIM 2x: 32.54/0.9110\nAvg PSNR/SSIM 3x: 29.66/0.8468\nAvg PSNR/SSIM 4x: 27.41/0.7648\n\nBSDS100\nAvg PSNR/SSIM 2x: 33.10/0.9184\nAvg PSNR/SSIM 3x: 29.02/0.8139\nAvg PSNR/SSIM 4x: 27.79/0.7489\n\nManga109\nAvg PSNR/SSIM 2x: 36.16/0.9699\nAvg PSNR/SSIM 3x: 31.08/0.9312\nAvg PSNR/SSIM 4x: 28.28/0.8857\n'