In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/epochx2/checkpoint.pth_x2.tar


In [2]:
import torch
import os
import random
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms.functional as F
import torch
from torch import nn
import time
from datasets import load_dataset

In [3]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from progressbar import ProgressBar

In [4]:
try:
    import accimage
except:
    accimage = None


def _is_pil_image(img):
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not(_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
    elif pic.mode == '1':
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img


def normalize(tensor, mean, std):
    """Normalize a tensor image with mean and standard deviation.

    See ``Normalize`` for more details.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channely.

    Returns:
        Tensor: Normalized Tensor image.
    """
    if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')
    # TODO: make efficient
    for t, m, s in zip(tensor, mean, std):
        t.sub_(m).div_(s)
    return tensor

In [5]:
import h5py
import random
import numpy as np
from PIL import Image
from pathlib import Path

from torch.utils.data import Dataset
from torchvision.transforms import transforms

## Utils 

In [6]:
def get_scale_from_dataset(dataset):
    scale = None
    if len(dataset) > 0:
        lr = Image.open(dataset[0]['lr'])
        hr = Image.open(dataset[0]['hr'])
        dim1 = round(hr.width / lr.width)
        dim2 = round(hr.height / lr.height)
        scale = max(dim1, dim2)
    return scale


def get_scale(lr, hr):
    dim1 = round(hr.width / lr.width)
    dim2 = round(hr.height / lr.height)
    scale = max(dim1, dim2)
    return scale


def resize_image(lr_image, hr_image, scale=None):
    if scale is None:
        scale = get_scale(lr_image, hr_image)
    if lr_image.width * scale != hr_image.width or lr_image.height * scale != hr_image.height:
        hr_width = lr_image.width * scale
        hr_height = lr_image.height * scale
        return hr_image.resize((hr_width, hr_height), resample=Image.BICUBIC)
    return hr_image

## Dataset

In [7]:
class EvalDataset(Dataset):
    def __init__(self, dataset, transform = None):
        super(EvalDataset, self).__init__()
        self.dataset = dataset
        self.scale = get_scale_from_dataset(dataset)
        self.transform = transform

    def __getitem__(self, idx):
        lr_image = Image.open(self.dataset[idx]['lr']).convert('RGB')
        hr_image = resize_image(lr_image, Image.open(self.dataset[idx]['hr']).convert('RGB'), scale=self.scale)
        lr = np.array(lr_image)
        hr = np.array(hr_image)
        lr = lr.astype(np.float32).transpose([2, 0, 1]) / 255
        hr = hr.astype(np.float32).transpose([2, 0, 1]) / 255
        if self.transform:
            lr, hr = self.transform(lr, hr)
            # label = self.transform(label)

        return lr, hr

    def __len__(self):
        return len(self.dataset)
    
    
class TrainDataset(Dataset):
    def __init__(self, dataset, transform = None, patch_size = 64):
        super(TrainDataset, self).__init__()
        self.dataset = dataset
        self.patch_size = patch_size
        self.scale = get_scale_from_dataset(dataset)
        self.transform = transform
    
    @staticmethod
    def random_crop(lr, hr, size, scale):
        lr_left = random.randint(0, lr.shape[1] - size)
        lr_right = lr_left + size
        lr_top = random.randint(0, lr.shape[0] - size)
        lr_bottom = lr_top + size
        hr_left = lr_left * scale
        hr_right = lr_right * scale
        hr_top = lr_top * scale
        hr_bottom = lr_bottom * scale
        lr = lr[lr_top:lr_bottom, lr_left:lr_right]
        hr = hr[hr_top:hr_bottom, hr_left:hr_right]
        return lr, hr

    @staticmethod
    def random_horizontal_flip(lr, hr):
        if random.random() < 0.5:
            lr = lr[:, ::-1, :].copy()
            hr = hr[:, ::-1, :].copy()
        return lr, hr

    @staticmethod
    def random_vertical_flip(lr, hr):
        if random.random() < 0.5:
            lr = lr[::-1, :, :].copy()
            hr = hr[::-1, :, :].copy()
        return lr, hr

    @staticmethod
    def random_rotate_90(lr, hr):
        if random.random() < 0.5:
            lr = np.rot90(lr, axes=(1, 0)).copy()
            hr = np.rot90(hr, axes=(1, 0)).copy()
        return lr, hr

    def __getitem__(self, idx):
        lr_image = Image.open(self.dataset[idx]['lr']).convert('RGB')
        hr_image = resize_image(lr_image, Image.open(self.dataset[idx]['hr']).convert('RGB'), scale=self.scale)
        lr = np.array(lr_image)
        hr = np.array(hr_image)
        lr, hr = self.random_crop(lr, hr, self.patch_size, self.scale)
        lr, hr = self.random_horizontal_flip(lr, hr)
        lr, hr = self.random_vertical_flip(lr, hr)
        lr, hr = self.random_rotate_90(lr, hr)
        lr = lr.astype(np.float32).transpose([2, 0, 1]) / 255
        hr = hr.astype(np.float32).transpose([2, 0, 1]) / 255
        
        if self.transform:
            lr, hr = self.transform(lr, hr)

        return lr, hr

    def __len__(self):
        return len(self.dataset)

In [8]:
!pip install super_image

Collecting super_image
  Downloading super_image-0.1.7-py3-none-any.whl (91 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.0/91.0 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: super_image
Successfully installed super_image-0.1.7
[0m

In [9]:
from super_image.models.edsr.configuration_edsr import EdsrConfig
from super_image.modeling_utils import (
    default_conv,
    MeanShift,
    Upsampler,PreTrainedModel
)

## Model 

In [10]:
class ResBlock(nn.Module):
    def __init__(
            self, conv, n_feats, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res


class edsr(PreTrainedModel):
    
#     config_class = EdsrConfig
    
    def __init__(self, args, conv=default_conv):
        super(edsr, self).__init__(args)

        self.args = args
        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        n_colors = args.n_colors
        kernel_size = 3
        scale = args.scale
        rgb_range = args.rgb_range
        act = nn.ReLU(True)
        self.sub_mean = MeanShift(rgb_range, rgb_mean=args.rgb_mean, rgb_std=args.rgb_std)  # standardize input
        self.add_mean = MeanShift(rgb_range, sign=1, rgb_mean=args.rgb_mean, rgb_std=args.rgb_std)  # restore output

        # define head module, channels: 3->64
        m_head = [conv(n_colors, n_feats, kernel_size)]

        # define body module, channels: 64->64
        m_body = [
            ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
            ) for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)

        if args.no_upsampling:
            self.out_dim = n_feats
        else:
            self.out_dim = args.n_colors
            # define tail module
            m_tail = [
                Upsampler(conv, scale, n_feats, act=False),
                conv(n_feats, n_colors, kernel_size)
            ]
            self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        x = self.head(x)

        res = self.body(x)
        res += x

        if self.args.no_upsampling:
            x = res
        else:
            x = self.tail(res)

        return x

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load data

In [12]:
# train_dir = 'data/train'
# val_dir = 'data/validation'

# train_transform = Compose([
#                             Normalize([0.449, 0.438, 0.404],
#                                       [1.0, 1.0, 1.0])])

# valid_transform = Compose([
#                             Normalize([0.440, 0.435, 0.403],
#                                       [1.0, 1.0, 1.0])])

t_set = load_dataset('eugenesiow/Div2k', 'bicubic_x2', split='train')
e_set = load_dataset('eugenesiow/Div2k', 'bicubic_x2', split='validation')

# trainset = TrainDataset(t_set, transform=train_transform)
# validset = EvalDataset(e_set, transform=valid_transform)


# trainset = DIV2K_x2(root_dir=train_dir, im_size=40, scale=2, transform=train_transforms)
# validset = DIV2K_x2(root_dir=val_dir, im_size=40, scale=2, transform=valid_transforms)

Downloading builder script:   0%|          | 0.00/6.23k [00:00<?, ?B/s]

Downloading and preparing dataset div2k/bicubic_x2 to /root/.cache/huggingface/datasets/eugenesiow___div2k/bicubic_x2/2.0.0/d7599f94c7e662a3eed3547efc7efa52b2ed71082b40fc2e42a693870e35b677...


Downloading data files:   0%|          | 0/4 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/925M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/118M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.53G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/449M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/4 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Dataset div2k downloaded and prepared to /root/.cache/huggingface/datasets/eugenesiow___div2k/bicubic_x2/2.0.0/d7599f94c7e662a3eed3547efc7efa52b2ed71082b40fc2e42a693870e35b677. Subsequent calls will reuse this data.


In [13]:
trainset = TrainDataset(t_set)
validset = EvalDataset(e_set)

trainloader = DataLoader(trainset, batch_size=8, shuffle=True)
validloader = DataLoader(validset, batch_size=1, shuffle=True)

In [14]:
def save_checkpoint(epoch, model, optimizer):
    """
    Save model checkpoint.
    :param epoch: epoch number
    :param model: model
    :param optimizer: optimizer
    """
    state = {'epoch': epoch,
             'model': model,
             'optimizer': optimizer}
    filename = 'checkpoint.pth.tar'
    torch.save(state, filename)
    
def adjust_learning_rate(optimizer, scale):
    """
    Scale learning rate by a specified factor.
    :param optimizer: optimizer whose learning rate must be shrunk.
    :param scale: factor to multiply learning rate with.
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * scale
    print("DECAYING learning rate.\n The new LR is %f\n" % (optimizer.param_groups[0]['lr'],))

In [15]:
class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

## Training

In [21]:
criterion = nn.L1Loss()
scale = 2
epochs = 25
print_every = 5
train_loss = 0
batch_num = 0
decay_lr_at = 11, 15   # decay learning rate after these many iterations
decay_lr_to = 0.1

In [22]:
from super_image.trainer_utils import EvalPrediction
from super_image.utils.metrics import compute_metrics
import gc

In [25]:
checkpoint = "/kaggle/working/checkpoint.pth.tar"

best_metric = 0
best_epoch = 0

def train(train_loader,valid_loader, model, criterion, optimizer, epoch):
    
    global best_metric, best_epoch
    losses = AverageMeter()
    
    for i, (img, label) in enumerate(train_loader):
        
        start = time.time()

        img, label = img.to(device), label.to(device)
        pred = model(img)
        # print(pred.shape, label.shape)
        loss = criterion(pred, label)
        
        losses.update(loss.item(), img.size(0))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Print status
        if i % print_every == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Training Time {3:.3f} \t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, i, len(train_loader),
                                                                  (time.time()-start)*print_every, loss=losses))

    with torch.no_grad():
        
        model.eval()
        val_losses = AverageMeter()
        epoch_psnr = AverageMeter()
        epoch_ssim = AverageMeter()
        
        for i, (val_inputs, val_labels) in enumerate(valid_loader):
            
            val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
            val_preds = model(val_inputs)
            val_loss = criterion(val_preds, val_labels)
            
            val_losses.update(val_loss.item(), val_inputs.size(0))
            
            metrics = compute_metrics(EvalPrediction(predictions=val_preds, labels=val_labels), scale=scale)

            epoch_psnr.update(metrics['psnr'], val_inputs.size(0))
            epoch_ssim.update(metrics['ssim'], val_inputs.size(0))

        print(f'Validation Loss:{val_losses.avg:.2f}      eval psnr: {epoch_psnr.avg:.2f}     ssim: {epoch_ssim.avg:.4f}')

        if epoch_psnr.avg > best_metric:
            best_epoch = epoch
            best_metric = epoch_psnr.avg

            print(f'best epoch: {epoch}, psnr: {epoch_psnr.avg:.6f}, ssim: {epoch_ssim.avg:.6f}')
            
            # Save checkpoint
            print("Saving checkpoint epoch:", epoch)
            save_checkpoint(epoch, model, optimizer)

#         print('Epoch : {}/{}'.format(epoch_num, epochs))
#         print('Training Loss : {:.4f}'.format(losses.avg))
#         print('Validation Loss: {:.4f}'.format(val_losses.avg))

In [26]:
def main():
    """
    Training.
    """
    global start_epoch, label_map, epoch, checkpoint, decay_lr_at, optimizer, criterion, scale

    # Initialize model or load checkpoint
    if checkpoint is None:
        config = EdsrConfig(
        scale=scale,                               
        n_resblocks=32,
        n_feats=256
    )
        start_epoch = 0
        model = edsr(config)
        optimizer = optim.Adam(model.parameters(), lr=0.0001)
        
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        print('\nLoaded checkpoint from epoch %d.\n' % start_epoch)
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    # Move to default device
    model = model.to(device)
    criterion = criterion

    print("Number of epochs: ", epochs)
    
    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate at particular epochs
        if epoch in decay_lr_at:
            adjust_learning_rate(optimizer, decay_lr_to)

        # One epoch's training
        train(train_loader=trainloader,
              valid_loader = validloader,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              epoch=epoch)

In [27]:
main()


Loaded checkpoint from epoch 22.

Number of epochs:  25
DECAYING learning rate.
 The new LR is 0.000003

Epoch: [22][0/100]	Training Time 1.264 	Loss 0.0070 (0.0070)	
Epoch: [22][5/100]	Training Time 1.211 	Loss 0.0229 (0.0164)	
Epoch: [22][10/100]	Training Time 1.210 	Loss 0.0195 (0.0155)	
Epoch: [22][15/100]	Training Time 1.211 	Loss 0.0192 (0.0158)	
Epoch: [22][20/100]	Training Time 1.215 	Loss 0.0079 (0.0156)	
Epoch: [22][25/100]	Training Time 1.211 	Loss 0.0098 (0.0154)	
Epoch: [22][30/100]	Training Time 1.212 	Loss 0.0113 (0.0151)	
Epoch: [22][35/100]	Training Time 1.214 	Loss 0.0078 (0.0149)	
Epoch: [22][40/100]	Training Time 1.208 	Loss 0.0122 (0.0149)	
Epoch: [22][45/100]	Training Time 1.210 	Loss 0.0117 (0.0147)	
Epoch: [22][50/100]	Training Time 1.210 	Loss 0.0278 (0.0153)	
Epoch: [22][55/100]	Training Time 1.216 	Loss 0.0171 (0.0154)	
Epoch: [22][60/100]	Training Time 1.210 	Loss 0.0143 (0.0155)	
Epoch: [22][65/100]	Training Time 1.210 	Loss 0.0145 (0.0154)	
Epoch: [22][70

In [28]:
torch.cuda.empty_cache()
gc.collect()

0

In [29]:
import os
os.chdir(r'/kaggle/working')

!tar -czf checkpoint.pth.tar 

from IPython.display import FileLink

FileLink(r'checkpoint.pth.tar')

tar: Cowardly refusing to create an empty archive
Try 'tar --help' or 'tar --usage' for more information.
