In [1]:
from google.colab import drive

# Mount Google Drive so the zipped dataset under MyDrive/DATA can be copied into the VM.
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!cp /content/drive/MyDrive/DATA/VHT-DATA/data.zip /content/data.zip

!unzip /content/data.zip

Archive:  /content/data.zip
   creating: data/
   creating: data/train/
   creating: data/train/HD/
  inflating: data/train/HD/img20.bmp  
  inflating: data/train/HD/img15.bmp  
  inflating: data/train/HD/img12.bmp  
  inflating: data/train/HD/img54.bmp  
  inflating: data/train/HD/img11.bmp  
  inflating: data/train/HD/img95.bmp  
  inflating: data/train/HD/img3.bmp  
  inflating: data/train/HD/img86.bmp  
  inflating: data/train/HD/img67.bmp  
  inflating: data/train/HD/img29.bmp  
  inflating: data/train/HD/img58.bmp  
  inflating: data/train/HD/img27.bmp  
  inflating: data/train/HD/img45.bmp  
  inflating: data/train/HD/img90.bmp  
  inflating: data/train/HD/img43.bmp  
  inflating: data/train/HD/img40.bmp  
  inflating: data/train/HD/img7.bmp  
  inflating: data/train/HD/img77.bmp  
  inflating: data/train/HD/img74.bmp  
  inflating: data/train/HD/img62.bmp  
  inflating: data/train/HD/img88.bmp  
  inflating: data/train/HD/img4.bmp  
  inflating: data/train/HD/img81.bmp  
  infl

In [3]:
# Work from /content so relative paths used later (e.g. `results`) resolve next to the extracted data.
%cd /content

/content


In [4]:
!pip install super-image

Collecting super-image
  Downloading super_image-0.1.7-py3-none-any.whl.metadata (14 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.9.0->super-image)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.9.0->super-image)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.9.0->super-image)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.9.0->super-image)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=1.9.0->super-image)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch>=1.9.0->

### Pipeline 1: Classical image processing (Lanczos upscaling + sharpening, scored with SSIM)

In [5]:
!mkdir results
!mkdir /content/data/train/SD-up
!mkdir /content/data/val/SD-up

In [22]:
import os
from PIL import Image
from skimage.metrics import structural_similarity as ssim
import numpy as np
import cv2
from tqdm import tqdm


def resize_image(input_path, output_path, scale=2):
    """
    Upscale an image by ``scale``, sharpen it, and save the result as grayscale.

    The image is resized with Lanczos resampling and converted to grayscale. It is
    then sharpened with a 3x3 kernel and passed through a bilateral filter. All
    processing happens in memory, so the output file is written exactly once.
    Saving and re-reading an intermediate file would add a lossy step for ``.jpg``.

    NOTE(review): the Pipeline 2 cell later redefines ``resize_image`` with a
    different signature. Re-running ``process_images`` after that cell would call
    the wrong function.

    :param input_path: Path to the input image.
    :param output_path: Path to save the processed image (format from extension).
    :param scale: Upscaling factor; non-integer factors are rounded to whole pixels.
    :raises IOError: If OpenCV cannot write ``output_path``.
    """
    with Image.open(input_path) as img:
        width, height = img.size
        new_size = (int(round(width * scale)), int(round(height * scale)))
        upscaled = img.resize(new_size, Image.LANCZOS)

    # Use OpenCV's luma conversion (what the previous imread(..., IMREAD_GRAYSCALE)
    # round-trip effectively applied) directly on the in-memory pixels.
    rgb = np.asarray(upscaled.convert('RGB'))
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)

    sharpen_kernel = np.array([[0, -1, 0],
                               [-1, 5, -1],
                               [0, -1, 0]], dtype=np.float32)
    sharpened = cv2.filter2D(gray, -1, sharpen_kernel)
    smoothed = cv2.bilateralFilter(sharpened, d=1, sigmaColor=5, sigmaSpace=10)

    # cv2.imwrite reports failure via its return value instead of raising.
    if not cv2.imwrite(output_path, smoothed):
        raise IOError(f"Could not write processed image to {output_path}")

def compute_ssim(original_path, resized_path):
    """
    Compute SSIM between two images after converting both to 8-bit grayscale.

    :param original_path: Path to the reference (high-resolution) image.
    :param resized_path: Path to the upscaled image to evaluate.
    :return: SSIM value.
    :raises ValueError: If the two images do not have the same dimensions.
    """
    # Context managers close the underlying files instead of leaving them to the GC.
    with Image.open(original_path) as original, Image.open(resized_path) as resized:
        original_np = np.array(original.convert('L'))
        resized_np = np.array(resized.convert('L'))

    if original_np.shape != resized_np.shape:
        raise ValueError(
            f"Cannot compare {original_path} {original_np.shape} with "
            f"{resized_path} {resized_np.shape}: SSIM needs equal dimensions."
        )
    # Both arrays are uint8 ('L' mode), so the data range is 255.
    return ssim(original_np, resized_np, data_range=255)

def process_images(input_dir, output_dir, size=(1280, 1024), scale=2, hd_dir=None):
    """
    Upscale every image in ``input_dir`` and report mean SSIM against HD references.

    Each ``.png``/``.jpg``/``.bmp`` file (case-insensitive) is processed with
    ``resize_image`` into ``output_dir``. The result is compared with the same-named
    file in ``hd_dir``.

    :param input_dir: Directory with the low-resolution (SD) images.
    :param output_dir: Directory to save the upscaled images; created if missing.
    :param size: Unused; kept for backward compatibility. The output size is
        ``scale`` times the input size.
    :param scale: Upscaling factor forwarded to ``resize_image``.
    :param hd_dir: Directory with the reference images. Defaults to ``input_dir``
        with "SD" replaced by "HD" in its final path component only. "SD"
        elsewhere in the path or in file names is left untouched.
    :return: Mean SSIM over the processed images, or ``nan`` if none matched.
    """
    os.makedirs(output_dir, exist_ok=True)
    if hd_dir is None:
        parent, leaf = os.path.split(os.path.normpath(input_dir))
        hd_dir = os.path.join(parent, leaf.replace("SD", "HD"))

    scores = []
    # Sorted for a deterministic processing order (os.listdir order is arbitrary).
    for filename in tqdm(sorted(os.listdir(input_dir))):
        if not filename.lower().endswith(('.png', '.jpg', '.bmp')):
            continue
        resized_path = os.path.join(output_dir, filename)
        resize_image(os.path.join(input_dir, filename), resized_path, scale=scale)
        scores.append(compute_ssim(os.path.join(hd_dir, filename), resized_path))

    # Avoid ZeroDivisionError when the directory contains no matching images.
    average = sum(scores) / len(scores) if scores else float('nan')
    print(f"\nTotal Image: {len(scores)} --- AVG: {average}\n")
    return average


# Run the classical pipeline on both splits; each call prints its own SSIM summary.
splits = (
    ("RESULT TRAINING DATASET:", '/content/data/train/SD', '/content/data/train/SD-up'),
    ("RESULT VALIDATION DATASET:", '/content/data/val/SD', '/content/data/val/SD-up'),
)
for header, input_directory, output_dir in splits:
    print(header)
    process_images(input_directory, output_dir)

RESULT TRAINING DATASET:


100%|██████████| 99/99 [00:21<00:00,  4.63it/s]



Total Image: 99 --- AVG: 0.49287948635564754

RESULT VALIDATION DATASET:


100%|██████████| 22/22 [00:04<00:00,  4.63it/s]


Total Image: 22 --- AVG: 0.5641593445634667






### Pipeline 2: Deep neural network super-resolution (DRLN)

In [1]:
from super_image.data import EvalDataset, TrainDataset, augment_five_crop
from super_image import Trainer, TrainingArguments, DrlnConfig
from super_image.modeling_utils import (
    BamBlock,
    PreTrainedModel
)
import h5py
import random
import numpy as np
from PIL import Image
from pathlib import Path
import os
from tqdm import tqdm

from torch.utils.data import Dataset
from torchvision.transforms import transforms


def get_scale(lr, hr):
    """
    Infer the integer upscaling factor between a low- and a high-resolution image.

    Each axis ratio (HR / LR) is rounded to the nearest integer, and the larger of
    the two is returned.

    :param lr: Low-resolution image (any object with ``width`` and ``height``).
    :param hr: High-resolution image (any object with ``width`` and ``height``).
    :return: The inferred scale factor.
    """
    axis_ratios = (hr.width / lr.width, hr.height / lr.height)
    return max(round(ratio) for ratio in axis_ratios)


def resize_image(lr_image, hr_image, scale=None):
    """
    Make ``hr_image`` exactly ``scale`` times the size of ``lr_image``.

    NOTE(review): this redefines the Pipeline 1 ``resize_image`` (different
    signature and purpose). Re-running Pipeline 1 after this cell would call
    this version.

    :param lr_image: Low-resolution PIL image.
    :param hr_image: High-resolution PIL image.
    :param scale: Integer factor; inferred with ``get_scale`` when None.
    :return: ``hr_image`` itself if it already has the target size, otherwise a
        bicubic-resized copy.
    """
    factor = get_scale(lr_image, hr_image) if scale is None else scale
    target_size = (lr_image.width * factor, lr_image.height * factor)
    if (hr_image.width, hr_image.height) == target_size:
        return hr_image
    return hr_image.resize(target_size, resample=Image.BICUBIC)

def get_scale_from_dataset(dataset):
    """
    Infer the upscaling factor from the first (lr, hr) pair of a dataset.

    :param dataset: Indexable sequence whose items are mappings with ``'lr'`` and
        ``'hr'`` entries accepted by ``PIL.Image.open`` (presumably file paths;
        verify against the caller).
    :return: The integer scale (see ``get_scale``), or ``None`` if the dataset is empty.
    """
    if len(dataset) == 0:
        return None
    first = dataset[0]
    # Only the image sizes are needed. The context managers close the files,
    # which bare Image.open calls would leave open until garbage collection.
    with Image.open(first['lr']) as lr, Image.open(first['hr']) as hr:
        return get_scale(lr, hr)

class EvalDataset(Dataset):
    """
    Evaluation dataset returning full-size (lr, hr) pairs as CHW float32 arrays in [0, 1].

    The HR image is resized (if needed) to exactly ``scale`` x the LR size. This
    class shadows the ``super_image.data.EvalDataset`` imported at the top of the cell.

    :param dataset: Indexable sequence of mappings with ``'lr'``/``'hr'`` entries
        accepted by ``PIL.Image.open``.
    """

    def __init__(self, dataset):
        super(EvalDataset, self).__init__()
        self.dataset = dataset
        self.scale = get_scale_from_dataset(dataset)

    @staticmethod
    def _to_chw_float(image):
        """Convert an RGB image to a (C, H, W) float32 array scaled by 1/255."""
        return np.array(image).astype(np.float32).transpose([2, 0, 1]) / 255.0

    def __getitem__(self, idx):
        record = self.dataset[idx]
        lr_image = Image.open(record['lr']).convert('RGB')
        hr_image = Image.open(record['hr']).convert('RGB')
        hr_image = resize_image(lr_image, hr_image, scale=self.scale)
        return self._to_chw_float(lr_image), self._to_chw_float(hr_image)

    def __len__(self):
        return len(self.dataset)

class TrainDataset(Dataset):
    """
    Training dataset yielding random, aligned (lr, hr) patches with flip/rotate augmentation.

    Each item is a pair of CHW float32 arrays scaled by 1/255. The HR patch is
    ``scale`` times larger than the LR patch on each side. This class shadows the
    ``super_image.data.TrainDataset`` imported at the top of the cell.

    :param dataset: Indexable sequence of mappings with ``'lr'``/``'hr'`` entries
        accepted by ``PIL.Image.open`` (presumably file paths; verify with caller).
    :param patch_size: Side length of the square LR patch.
    """

    def __init__(self, dataset, patch_size=64):
        super(TrainDataset, self).__init__()
        self.dataset = dataset
        self.patch_size = patch_size
        self.scale = get_scale_from_dataset(dataset)

    @staticmethod
    def random_crop(lr, hr, size, scale):
        """
        Crop a random ``size`` x ``size`` LR patch and the spatially matching HR patch.

        :param lr: LR image array indexed as (row, column, channel).
        :param hr: HR image array, ``scale`` times larger than ``lr`` on both axes.
        :param size: LR patch side length.
        :param scale: Integer upscaling factor between ``lr`` and ``hr``.
        :return: ``(lr_patch, hr_patch)`` slices of the inputs.
        :raises ValueError: If ``size`` exceeds the LR height or width.
        """
        lr_height, lr_width = lr.shape[:2]
        if size > lr_height or size > lr_width:
            # Without this check random.randint fails with an opaque "empty range" error.
            raise ValueError(
                f"patch_size={size} is larger than the LR image ({lr_height}x{lr_width})"
            )
        # The column offset is drawn before the row offset; keep this order so
        # seeded runs stay reproducible.
        lr_left = random.randint(0, lr_width - size)
        lr_top = random.randint(0, lr_height - size)
        hr_left, hr_top, hr_size = lr_left * scale, lr_top * scale, size * scale
        lr_patch = lr[lr_top:lr_top + size, lr_left:lr_left + size]
        hr_patch = hr[hr_top:hr_top + hr_size, hr_left:hr_left + hr_size]
        return lr_patch, hr_patch

    @staticmethod
    def random_horizontal_flip(lr, hr):
        """With probability 0.5, mirror both patches left-right."""
        if random.random() < 0.5:
            lr = lr[:, ::-1, :].copy()
            hr = hr[:, ::-1, :].copy()
        return lr, hr

    @staticmethod
    def random_vertical_flip(lr, hr):
        """With probability 0.5, mirror both patches top-bottom."""
        if random.random() < 0.5:
            lr = lr[::-1, :, :].copy()
            hr = hr[::-1, :, :].copy()
        return lr, hr

    @staticmethod
    def random_rotate_90(lr, hr):
        """With probability 0.5, rotate both patches by 90 degrees in the spatial plane."""
        if random.random() < 0.5:
            lr = np.rot90(lr, axes=(1, 0)).copy()
            hr = np.rot90(hr, axes=(1, 0)).copy()
        return lr, hr

    def __getitem__(self, idx):
        """Return one augmented ``(lr, hr)`` patch pair as CHW float32 arrays."""
        lr_image = Image.open(self.dataset[idx]['lr']).convert('RGB')
        # Force HR to exactly scale x LR so the crop coordinates line up.
        hr_image = resize_image(lr_image, Image.open(self.dataset[idx]['hr']).convert('RGB'), scale=self.scale)
        lr = np.array(lr_image)
        hr = np.array(hr_image)
        lr, hr = self.random_crop(lr, hr, self.patch_size, self.scale)
        lr, hr = self.random_horizontal_flip(lr, hr)
        lr, hr = self.random_vertical_flip(lr, hr)
        lr, hr = self.random_rotate_90(lr, hr)
        lr = lr.astype(np.float32).transpose([2, 0, 1]) / 255.0
        hr = hr.astype(np.float32).transpose([2, 0, 1]) / 255.0
        return lr, hr

    def __len__(self):
        return len(self.dataset)

In [2]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F



def init_weights(modules):
    """
    Weight-initialization hook that intentionally does nothing.

    The blocks below call it with ``self.modules`` (the bound method, not its
    result). Because nothing is modified, every layer keeps the initialization
    PyTorch gave it at construction.

    :param modules: Ignored.
    """


class UpsampleBlock(nn.Module):
    """
    Upsampling head built from ``_UpsampleBlock`` sub-pixel stages.

    :param n_channels: Number of feature channels (preserved by the upsampler).
    :param scale: Upscaling factor used when ``multi_scale`` is False.
    :param multi_scale: If True, build separate x2/x3/x4 branches and pick one per call.
    :param group: Group count for the convolutions.
    """

    def __init__(self,
                 n_channels, scale, multi_scale,
                 group=1):
        super(UpsampleBlock, self).__init__()

        # Attribute names (up / up2 / up3 / up4) are part of the state_dict keys.
        if multi_scale:
            self.up2 = _UpsampleBlock(n_channels, scale=2, group=group)
            self.up3 = _UpsampleBlock(n_channels, scale=3, group=group)
            self.up4 = _UpsampleBlock(n_channels, scale=4, group=group)
        else:
            self.up = _UpsampleBlock(n_channels, scale=scale, group=group)

        self.multi_scale = multi_scale

    def forward(self, x, scale):
        """
        Upsample ``x``.

        :param x: Feature tensor.
        :param scale: Branch selector; ignored when ``multi_scale`` is False.
        :return: Upsampled feature tensor.
        :raises ValueError: If ``multi_scale`` is True and ``scale`` is not 2, 3 or 4
            (previously this silently returned None).
        """
        if not self.multi_scale:
            # Single-scale block: the factor was fixed at construction time.
            return self.up(x)
        if scale == 2:
            return self.up2(x)
        if scale == 3:
            return self.up3(x)
        if scale == 4:
            return self.up4(x)
        raise ValueError(f"multi-scale UpsampleBlock supports scales 2, 3 and 4, got {scale}")


class _UpsampleBlock(nn.Module):
    """
    Sub-pixel upsampler: (conv -> ReLU -> PixelShuffle) stages.

    For scale 2/4/8 the stage uses a x2 shuffle and is repeated log2(scale) times.
    For scale 3 there is a single x3 stage. The channel count is preserved.

    :param n_channels: Number of feature channels.
    :param scale: Upscaling factor; must be 2, 3, 4 or 8.
    :param group: Group count for the convolutions.
    :raises ValueError: For any other scale (previously an empty, identity body
        was built silently, so the output was never upscaled).
    """

    def __init__(self,
                 n_channels, scale,
                 group=1):
        super(_UpsampleBlock, self).__init__()

        modules = []
        if scale == 2 or scale == 4 or scale == 8:
            # math.log2 is exact for powers of two (e.g. 8 -> 3 stages of x2).
            for _ in range(int(math.log2(scale))):
                # Conv expands channels 4x, and PixelShuffle(2) folds them back into 2x2 pixels.
                modules += [nn.Conv2d(n_channels, 4 * n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
                modules += [nn.PixelShuffle(2)]
        elif scale == 3:
            modules += [nn.Conv2d(n_channels, 9 * n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
            modules += [nn.PixelShuffle(3)]
        else:
            raise ValueError(f"_UpsampleBlock supports scales 2, 3, 4 and 8, got {scale}")

        self.body = nn.Sequential(*modules)
        init_weights(self.modules)

    def forward(self, x):
        out = self.body(x)
        return out


class MeanShift(nn.Module):
    """
    Frozen 1x1 convolution that subtracts (``sub=True``) or adds back a per-channel RGB mean.

    The weight is the 3x3 identity and the bias holds the (signed) mean, so each
    channel is shifted by a constant. Neither parameter is trainable.

    :param mean_rgb: Sequence of three per-channel means.
    :param sub: Subtract the mean if True, add it if False.
    """

    def __init__(self, mean_rgb, sub):
        super(MeanShift, self).__init__()

        direction = -1 if sub else 1
        channel_offsets = [mean_rgb[channel] * direction for channel in range(3)]

        self.shifter = nn.Conv2d(3, 3, 1, 1, 0)
        self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.shifter.bias.data = torch.Tensor(channel_offsets)

        # The shift is a fixed preprocessing step, not something to learn.
        self.shifter.requires_grad_(False)

    def forward(self, x):
        return self.shifter(x)


class BasicBlock(nn.Module):
    """
    Conv2d followed by an in-place ReLU.

    :param in_channels: Input channel count.
    :param out_channels: Output channel count.
    :param ksize: Kernel size.
    :param stride: Convolution stride.
    :param pad: Zero padding on each side.
    :param dilation: Kernel dilation.
    """

    def __init__(self,
                 in_channels, out_channels,
                 ksize=3, stride=1, pad=1, dilation=1):
        super(BasicBlock, self).__init__()

        conv = nn.Conv2d(in_channels, out_channels, ksize, stride, pad, dilation)
        self.body = nn.Sequential(conv, nn.ReLU(inplace=True))

        init_weights(self.modules)

    def forward(self, x):
        return self.body(x)


class BasicBlockSig(nn.Module):
    """
    Conv2d followed by a sigmoid, producing values in (0, 1) (used as attention weights).

    :param in_channels: Input channel count.
    :param out_channels: Output channel count.
    :param ksize: Kernel size.
    :param stride: Convolution stride.
    :param pad: Zero padding on each side.
    """

    def __init__(self,
                 in_channels, out_channels,
                 ksize=3, stride=1, pad=1):
        super(BasicBlockSig, self).__init__()

        conv = nn.Conv2d(in_channels, out_channels, ksize, stride, pad)
        self.body = nn.Sequential(conv, nn.Sigmoid())

        init_weights(self.modules)

    def forward(self, x):
        return self.body(x)


class ResidualBlock(nn.Module):
    """
    Two 3x3 convolutions with a ReLU in between, plus an identity skip and a final ReLU.

    The skip adds the input to the conv output, so ``in_channels`` must equal
    ``out_channels``.

    :param in_channels: Input channel count.
    :param out_channels: Output channel count.
    """

    def __init__(self,
                 in_channels, out_channels):
        super(ResidualBlock, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
        )

        init_weights(self.modules)

    def forward(self, x):
        return F.relu(self.body(x) + x)


class CALayer(nn.Module):
    """
    Channel attention with multi-dilation branches.

    The input is globally average-pooled to 1x1. Three 3x3 conv branches
    (dilation 3, 5 and 7) each reduce the channels by ``reduction``. Their
    concatenation is mapped by a sigmoid conv back to ``channel`` weights, which
    rescale ``x`` per channel.

    :param channel: Number of input/output channels.
    :param reduction: Channel reduction factor inside each branch.
    """

    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        reduced = channel // reduction
        self.c1 = BasicBlock(channel, reduced, 3, 1, 3, 3)
        self.c2 = BasicBlock(channel, reduced, 3, 1, 5, 5)
        self.c3 = BasicBlock(channel, reduced, 3, 1, 7, 7)
        self.c4 = BasicBlockSig(reduced * 3, channel, 3, 1, 1)

    def forward(self, x):
        pooled = self.avg_pool(x)
        branch_outputs = [self.c1(pooled), self.c2(pooled), self.c3(pooled)]
        attention = self.c4(torch.cat(branch_outputs, dim=1))
        return x * attention


class Block(nn.Module):
    """
    DRLN unit with channel attention.

    Three ResidualBlocks are applied densely: each sees the concatenation of all
    previous features, growing the width 1x -> 2x -> 4x -> 8x. A 1x1 conv then
    fuses the result back to ``out_channels``, followed by ``CALayer``.

    :param in_channels: Input channel count (must equal ``out_channels``).
    :param out_channels: Output channel count.
    :param group: Unused; kept for signature compatibility.
    """

    def __init__(self, in_channels, out_channels, group=1):
        super(Block, self).__init__()

        self.r1 = ResidualBlock(in_channels, out_channels)
        self.r2 = ResidualBlock(in_channels * 2, out_channels * 2)
        self.r3 = ResidualBlock(in_channels * 4, out_channels * 4)
        self.g = BasicBlock(in_channels * 8, out_channels, 1, 1, 0)
        self.ca = CALayer(in_channels)

    def forward(self, x):
        dense = x
        for residual in (self.r1, self.r2, self.r3):
            dense = torch.cat([dense, residual(dense)], dim=1)
        fused = self.g(dense)
        return self.ca(fused)


class Bam(nn.Module):
    """
    DRLN unit identical to ``Block`` except the attention is ``BamBlock``.

    Three densely concatenated ResidualBlocks (width 1x -> 8x) are followed by a 1x1
    fusion conv and the ``BamBlock`` attention from ``super_image.modeling_utils``.

    :param in_channels: Input channel count (must equal ``out_channels``).
    :param out_channels: Output channel count.
    :param group: Unused; kept for signature compatibility.
    """

    def __init__(self, in_channels, out_channels, group=1):
        super(Bam, self).__init__()

        self.r1 = ResidualBlock(in_channels, out_channels)
        self.r2 = ResidualBlock(in_channels * 2, out_channels * 2)
        self.r3 = ResidualBlock(in_channels * 4, out_channels * 4)
        self.g = BasicBlock(in_channels * 8, out_channels, 1, 1, 0)
        self.ca = BamBlock(in_channels)

    def forward(self, x):
        stacked = x
        stacked = torch.cat([stacked, self.r1(stacked)], dim=1)
        stacked = torch.cat([stacked, self.r2(stacked)], dim=1)
        stacked = torch.cat([stacked, self.r3(stacked)], dim=1)
        return self.ca(self.g(stacked))


class DrlnModel(PreTrainedModel):
    """
    DRLN super-resolution network.

    Data flow: RGB mean subtraction -> 3x3 head conv (3 -> 64 channels) -> 20
    units (``Bam`` if ``args.bam`` else ``Block``) in six densely connected groups
    (3, 3, 3, 3, 4, 4 units), each group closed by a residual addition -> global
    residual with the head features -> single-scale sub-pixel upsampler -> 3x3
    tail conv (64 -> 3 channels) -> mean add-back.

    :param args: Config object (presumably a ``DrlnConfig``, per ``config_class``);
        this class reads ``args.scale`` and ``args.bam``.
    """
    config_class = DrlnConfig

    def __init__(self, args):
        super(DrlnModel, self).__init__(args)

        # n_resgroups = args.n_resgroups
        # n_resblocks = args.n_resblocks
        # n_feats = args.n_feats
        # kernel_size = 3
        # reduction = args.reduction
        # scale = args.scale[0]
        # act = nn.ReLU(True)

        # Upscaling factor; it fixes the single-scale upsampler built below.
        self.scale = args.scale
        chs = 64

        # Fixed per-channel RGB mean, subtracted at the input and added back at the output.
        self.sub_mean = MeanShift((0.4488, 0.4371, 0.4040), sub=True)
        self.add_mean = MeanShift((0.4488, 0.4371, 0.4040), sub=False)

        self.head = nn.Conv2d(3, chs, 3, 1, 1)

        # Twenty units of the same width; attribute names b1..b20 are state_dict keys.
        if args.bam:
            self.b1 = Bam(chs, chs)
            self.b2 = Bam(chs, chs)
            self.b3 = Bam(chs, chs)
            self.b4 = Bam(chs, chs)
            self.b5 = Bam(chs, chs)
            self.b6 = Bam(chs, chs)
            self.b7 = Bam(chs, chs)
            self.b8 = Bam(chs, chs)
            self.b9 = Bam(chs, chs)
            self.b10 = Bam(chs, chs)
            self.b11 = Bam(chs, chs)
            self.b12 = Bam(chs, chs)
            self.b13 = Bam(chs, chs)
            self.b14 = Bam(chs, chs)
            self.b15 = Bam(chs, chs)
            self.b16 = Bam(chs, chs)
            self.b17 = Bam(chs, chs)
            self.b18 = Bam(chs, chs)
            self.b19 = Bam(chs, chs)
            self.b20 = Bam(chs, chs)
        else:
            self.b1 = Block(chs, chs)
            self.b2 = Block(chs, chs)
            self.b3 = Block(chs, chs)
            self.b4 = Block(chs, chs)
            self.b5 = Block(chs, chs)
            self.b6 = Block(chs, chs)
            self.b7 = Block(chs, chs)
            self.b8 = Block(chs, chs)
            self.b9 = Block(chs, chs)
            self.b10 = Block(chs, chs)
            self.b11 = Block(chs, chs)
            self.b12 = Block(chs, chs)
            self.b13 = Block(chs, chs)
            self.b14 = Block(chs, chs)
            self.b15 = Block(chs, chs)
            self.b16 = Block(chs, chs)
            self.b17 = Block(chs, chs)
            self.b18 = Block(chs, chs)
            self.b19 = Block(chs, chs)
            self.b20 = Block(chs, chs)

        # Fusion convs after each unit. The input width is chs * (number of tensors
        # concatenated so far within the group): 2, 3, 4 for the 3-unit groups and
        # 2, 3, 4, 5 for the two 4-unit groups (units 13-16 and 17-20).
        self.c1 = BasicBlock(chs * 2, chs, 3, 1, 1)
        self.c2 = BasicBlock(chs * 3, chs, 3, 1, 1)
        self.c3 = BasicBlock(chs * 4, chs, 3, 1, 1)
        self.c4 = BasicBlock(chs * 2, chs, 3, 1, 1)
        self.c5 = BasicBlock(chs * 3, chs, 3, 1, 1)
        self.c6 = BasicBlock(chs * 4, chs, 3, 1, 1)
        self.c7 = BasicBlock(chs * 2, chs, 3, 1, 1)
        self.c8 = BasicBlock(chs * 3, chs, 3, 1, 1)
        self.c9 = BasicBlock(chs * 4, chs, 3, 1, 1)
        self.c10 = BasicBlock(chs * 2, chs, 3, 1, 1)
        self.c11 = BasicBlock(chs * 3, chs, 3, 1, 1)
        self.c12 = BasicBlock(chs * 4, chs, 3, 1, 1)
        self.c13 = BasicBlock(chs * 2, chs, 3, 1, 1)
        self.c14 = BasicBlock(chs * 3, chs, 3, 1, 1)
        self.c15 = BasicBlock(chs * 4, chs, 3, 1, 1)
        self.c16 = BasicBlock(chs * 5, chs, 3, 1, 1)
        self.c17 = BasicBlock(chs * 2, chs, 3, 1, 1)
        self.c18 = BasicBlock(chs * 3, chs, 3, 1, 1)
        self.c19 = BasicBlock(chs * 4, chs, 3, 1, 1)
        self.c20 = BasicBlock(chs * 5, chs, 3, 1, 1)

        # multi_scale=False: the upsampler is built for self.scale only.
        self.upsample = UpsampleBlock(chs, self.scale, multi_scale=False)
        # self.convert = ConvertBlock(chs, chs, 20)
        self.tail = nn.Conv2d(chs, 3, 3, 1, 1)

    def forward(self, x):
        """
        Super-resolve a batch of RGB images.

        :param x: Input tensor with 3 channels (NCHW as used by nn.Conv2d).
        :return: Upscaled 3-channel tensor, ``self.scale`` times larger spatially.
        """
        x = self.sub_mean(x)
        x = self.head(x)
        c0 = o0 = x

        # Group 1 (units 1-3): dense concatenation c_k, fusion o_k, residual a1.
        b1 = self.b1(o0)
        c1 = torch.cat([c0, b1], dim=1)
        o1 = self.c1(c1)

        b2 = self.b2(o1)
        c2 = torch.cat([c1, b2], dim=1)
        o2 = self.c2(c2)

        b3 = self.b3(o2)
        c3 = torch.cat([c2, b3], dim=1)
        o3 = self.c3(c3)
        a1 = o3 + c0

        # Group 2 (units 4-6).
        b4 = self.b4(a1)
        c4 = torch.cat([o3, b4], dim=1)
        o4 = self.c4(c4)

        # NOTE(review): b5 is fed a1, not o4, so o4 is never used. Every other
        # unit consumes the previous fusion output. Confirm against the reference
        # DRLN implementation before changing; pretrained weights depend on the
        # exact graph.
        b5 = self.b5(a1)
        c5 = torch.cat([c4, b5], dim=1)
        o5 = self.c5(c5)

        b6 = self.b6(o5)
        c6 = torch.cat([c5, b6], dim=1)
        o6 = self.c6(c6)
        a2 = o6 + a1

        # Group 3 (units 7-9).
        b7 = self.b7(a2)
        c7 = torch.cat([o6, b7], dim=1)
        o7 = self.c7(c7)

        b8 = self.b8(o7)
        c8 = torch.cat([c7, b8], dim=1)
        o8 = self.c8(c8)

        b9 = self.b9(o8)
        c9 = torch.cat([c8, b9], dim=1)
        o9 = self.c9(c9)
        a3 = o9 + a2

        # Group 4 (units 10-12).
        b10 = self.b10(a3)
        c10 = torch.cat([o9, b10], dim=1)
        o10 = self.c10(c10)

        b11 = self.b11(o10)
        c11 = torch.cat([c10, b11], dim=1)
        o11 = self.c11(c11)

        b12 = self.b12(o11)
        c12 = torch.cat([c11, b12], dim=1)
        o12 = self.c12(c12)
        a4 = o12 + a3

        # Group 5 (units 13-16).
        b13 = self.b13(a4)
        c13 = torch.cat([o12, b13], dim=1)
        o13 = self.c13(c13)

        b14 = self.b14(o13)
        c14 = torch.cat([c13, b14], dim=1)
        o14 = self.c14(c14)

        b15 = self.b15(o14)
        c15 = torch.cat([c14, b15], dim=1)
        o15 = self.c15(c15)

        b16 = self.b16(o15)
        c16 = torch.cat([c15, b16], dim=1)
        o16 = self.c16(c16)
        a5 = o16 + a4

        # Group 6 (units 17-20).
        b17 = self.b17(a5)
        c17 = torch.cat([o16, b17], dim=1)
        o17 = self.c17(c17)

        b18 = self.b18(o17)
        c18 = torch.cat([c17, b18], dim=1)
        o18 = self.c18(c18)

        b19 = self.b19(o18)
        c19 = torch.cat([c18, b19], dim=1)
        o19 = self.c19(c19)

        b20 = self.b20(o19)
        c20 = torch.cat([c19, b20], dim=1)
        o20 = self.c20(c20)
        a6 = o20 + a5

        # c_out = torch.cat([b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15, b16, b17, b18, b19, b20], dim=1)

        # b = self.convert(c_out)
        # Global residual: x is the head output here (reassigned above).
        b_out = a6 + x
        # The scale argument is ignored by a single-scale UpsampleBlock.
        out = self.upsample(b_out, scale=self.scale)

        out = self.tail(out)
        f_out = self.add_mean(out)

        return f_out

    def load_state_dict(self, state_dict, strict=True):
        """
        Copy matching tensors from ``state_dict`` into this model's state in place.

        This is a lenient replacement for ``nn.Module.load_state_dict``:
        - if copying a tensor fails (e.g. shape mismatch), a ``RuntimeError`` is
          raised unless the key contains ``'tail'``, in which case it is skipped;
        - an unknown key raises ``KeyError`` only when ``strict`` is True and the
          key does not contain ``'tail'``;
        - keys present in the model but missing from ``state_dict`` are not
          reported, and nothing is returned.

        :param state_dict: Mapping of names to tensors or ``nn.Parameter`` objects.
        :param strict: Whether unexpected (non-``'tail'``) keys are an error.
        :raises RuntimeError: On a failed copy for a non-``'tail'`` key.
        :raises KeyError: On an unexpected non-``'tail'`` key when ``strict`` is True.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    # 'tail' mismatches are tolerated (e.g. a checkpoint with a different tail layer).
                    if name.find('tail') == -1:
                        raise RuntimeError(f'While copying the parameter named {name}, '
                                           f'whose dimensions in the model are {own_state[name].size()} and '
                                           f'whose dimensions in the checkpoint are {param.size()}.')
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError(f'unexpected key "{name}" in state_dict')

In [16]:
"""
The Trainer class, to easily train a super-image model from scratch.
The design is inspired by the HuggingFace transformers library at
https://github.com/huggingface/transformers/.
"""

import os
import copy
import logging
from typing import Optional, Union, Dict, Callable

from tqdm.auto import tqdm

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.optim import lr_scheduler, Adam

from super_image.modeling_utils import PreTrainedModel
from super_image.configuration_utils import PretrainedConfig
from super_image.file_utils import (
    WEIGHTS_NAME,
    WEIGHTS_NAME_SCALE,
    CONFIG_NAME
)
from super_image.trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    EvalPrediction,
    set_seed
)
from super_image.utils.metrics import AverageMeter, compute_metrics

# Module-level logger; the Trainer below uses it for informational messages.
logger = logging.getLogger(__name__)


class Trainer:
    """
    Trainer is a simple class implementing the training and eval loop for PyTorch to train a super-image model.
    Args:
        model (:class:`~super_image.PreTrainedModel` or :obj:`torch.nn.Module`):
            The model to train, evaluate or use for predictions. Required: a ``RuntimeError`` is raised
            when it is ``None``. It must expose ``config.scale``, which :meth:`eval` and
            :meth:`save_model` read.
            .. note::
                :class:`~super_image.Trainer` is optimized to work with the :class:`~super_image.PreTrainedModel`
                provided by the library. You can still use your own models defined as :obj:`torch.nn.Module` as long as
                they work the same way as the super_image models.
        args (:class:`~super_image.TrainingArguments`, `optional`):
            The arguments to tweak for training. Will default to a basic instance of
            :class:`~super_image.TrainingArguments` with the ``output_dir`` set to a directory named `tmp_trainer` in
            the current directory if not provided.
        train_dataset (:obj:`torch.utils.data.dataset.Dataset` or :obj:`torch.utils.data.dataset.IterableDataset`):
            The dataset to use for training.
        eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
             The dataset to use for evaluation. Falls back to ``train_dataset`` when not provided.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        train_dataset: Dataset = None,
        eval_dataset: Optional[Dataset] = None,
    ):
        if args is None:
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = TrainingArguments(output_dir=output_dir)
        self.args = args
        # NOTE(review): `model` was already constructed by the caller, so this seed cannot
        # affect its initial weights, only randomness that happens afterwards
        # (presumably DataLoader shuffling etc.; depends on what set_seed seeds).
        set_seed(self.args.seed)

        if model is None:
            raise RuntimeError("`Trainer` requires a `model`")

        if torch.cuda.is_available():
            self.model = model.cuda()
        else:
            self.model = model
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.best_epoch = 0
        self.best_metric = 0.0
        # Optimizer steps taken so far; used to name checkpoint folders.
        self.global_step = 0

    def _unwrapped_model(self) -> nn.Module:
        """Return the underlying model, stripping an ``nn.DataParallel`` wrapper if present."""
        if isinstance(self.model, nn.DataParallel):
            return self.model.module
        return self.model

    def train(
            self,
            resume_from_checkpoint: Optional[Union[str, bool]] = None,
            **kwargs,
    ):
        """
        Main training entry point.

        Runs ``args.num_train_epochs`` epochs of L1-loss training with Adam. After every epoch it
        calls :meth:`eval`, which saves the model whenever validation PSNR improves.

        Args:
            resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`):
                Accepted for API compatibility but currently ignored. Training always starts at
                epoch 0 with fresh optimizer/scheduler state.
            kwargs:
                Additional keyword arguments used to hide deprecated arguments (ignored).
        """
        args = self.args

        epochs_trained = 0
        device = args.device
        num_train_epochs = args.num_train_epochs
        learning_rate = args.learning_rate
        train_batch_size = args.train_batch_size
        train_dataset = self.train_dataset
        train_dataloader = self.get_train_dataloader()
        # StepLR is stepped once per batch, so this is roughly 200 epochs' worth of batches.
        # It is clamped to 1 because StepLR takes `last_epoch % step_size`, which fails for 0.
        step_size = max(1, int(len(train_dataset) / train_batch_size * 200))
        # Epochs between 10x learning-rate drops. Clamped to 1 so that runs shorter than
        # 2 epochs do not evaluate `epoch // 0`.
        lr_decay_epochs = max(1, int(num_train_epochs * 0.8))

        # Wrap only once, so calling train() again does not nest DataParallel wrappers.
        if args.n_gpu > 1 and not isinstance(self.model, nn.DataParallel):
            self.model = nn.DataParallel(self.model)

        optimizer = Adam(self.model.parameters(), lr=learning_rate)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=self.args.gamma)
        criterion = nn.L1Loss()

        for epoch in range(epochs_trained, num_train_epochs):
            # NOTE(review): this manual reset overrides any change StepLR made to the lr
            # during the previous epoch.
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate * (0.1 ** (epoch // lr_decay_epochs))

            self.model.train()
            epoch_losses = AverageMeter()

            # The DataLoader keeps the last partial batch, so every sample is counted.
            with tqdm(total=len(train_dataset)) as t:
                t.set_description(f'epoch: {epoch}/{num_train_epochs - 1}')

                for inputs, labels in train_dataloader:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    preds = self.model(inputs)
                    loss = criterion(preds, labels)

                    epoch_losses.update(loss.item(), len(inputs))

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    scheduler.step()
                    self.global_step += 1

                    t.set_postfix(loss=f'{epoch_losses.avg:.6f}')
                    t.update(len(inputs))

            self.eval(epoch)

    def eval(self, epoch):
        """
        Evaluate on the eval set and keep the best model.

        Averages the ``psnr``/``ssim`` values from ``compute_metrics`` over
        :meth:`get_eval_dataloader` (batch size 1) at the model's ``config.scale`` and prints
        them. It calls :meth:`save_model` whenever the average PSNR beats ``self.best_metric``.

        Args:
            epoch: index of the epoch that just finished; stored in ``self.best_epoch`` on improvement.
        """
        args = self.args

        scale = self._unwrapped_model().config.scale
        device = args.device
        eval_dataloader = self.get_eval_dataloader()
        epoch_psnr = AverageMeter()
        epoch_ssim = AverageMeter()

        self.model.eval()

        for inputs, labels in eval_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = self.model(inputs)

            metrics = compute_metrics(EvalPrediction(predictions=preds, labels=labels), scale=scale)

            epoch_psnr.update(metrics['psnr'], len(inputs))
            epoch_ssim.update(metrics['ssim'], len(inputs))

        print(f'scale:{str(scale)}      eval psnr: {epoch_psnr.avg:.2f}     ssim: {epoch_ssim.avg:.4f}')

        if epoch_psnr.avg > self.best_metric:
            self.best_epoch = epoch
            self.best_metric = epoch_psnr.avg

            print(f'best epoch: {epoch}, psnr: {epoch_psnr.avg:.6f}, ssim: {epoch_ssim.avg:.6f}')
            self.save_model()

    def _load_state_dict_in_model(self, state_dict):
        """Load ``state_dict`` non-strictly into the (unwrapped) model and return the model's result.

        Loading into the unwrapped module means checkpoint keys do not need a ``module.`` prefix
        when the model is wrapped in ``nn.DataParallel``.
        """
        return self._unwrapped_model().load_state_dict(state_dict, strict=False)

    def _save_checkpoint(self, model, trial, metrics=None):
        """Save the model into ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>``.

        ``model``, ``trial`` and ``metrics`` are accepted for signature compatibility but are
        unused. The trainer's own model is saved.
        """
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
        run_dir = self.args.output_dir
        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir)

    def save_model(self, output_dir: Optional[str] = None):
        """
        Will save the model, so you can reload it using :obj:`from_pretrained()`.

        A ``PreTrainedModel`` is saved with ``save_pretrained``. Any other module has its
        ``state_dict`` written to ``WEIGHTS_NAME_SCALE`` (when ``config.scale`` is set) or to
        ``WEIGHTS_NAME``. An ``nn.DataParallel`` wrapper is stripped first, so ``config`` is
        reachable and the saved keys carry no ``module.`` prefix.

        Args:
            output_dir: target directory; defaults to ``args.output_dir``. Created if missing.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)

        model = self._unwrapped_model()
        if isinstance(model, PreTrainedModel):
            model.save_pretrained(output_dir)
            return

        scale = model.config.scale
        if scale is not None:
            weights_name = WEIGHTS_NAME_SCALE.format(scale=scale)
        else:
            weights_name = WEIGHTS_NAME
        # torch.save serializes the tensors itself. A defensive deepcopy would only duplicate
        # every weight in device memory for the duration of the save.
        torch.save(model.state_dict(), os.path.join(output_dir, weights_name))

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training :class:`~torch.utils.data.DataLoader` (shuffled, ``args.train_batch_size``).

        Raises:
            ValueError: no ``train_dataset`` was given.
        """

        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset

        return DataLoader(
            dataset=train_dataset,
            batch_size=self.args.train_batch_size,
            shuffle=True,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def get_eval_dataloader(self) -> DataLoader:
        """
        Returns the evaluation :class:`~torch.utils.data.DataLoader` with batch size 1.

        Falls back to the training dataset when no ``eval_dataset`` was given.
        """

        eval_dataset = self.eval_dataset
        if eval_dataset is None:
            eval_dataset = self.train_dataset

        return DataLoader(
            dataset=eval_dataset,
            batch_size=1,
        )

### Training on our data from 640 to 1280

In [18]:
def build_lr_hr_pairs(lr_dir, hr_dir_name="HD"):
    """Pair every file in ``lr_dir`` with the same-named file in a sibling ``hr_dir_name`` folder.

    Example: ``/content/data/train/SD/img1.bmp`` -> ``/content/data/train/HD/img1.bmp``.
    Only the directory component is swapped. A blanket string replace of "SD" over the whole
    path would also rewrite filenames or parent folders that happen to contain "SD".

    :param lr_dir: directory holding the low-resolution inputs.
    :param hr_dir_name: name of the sibling directory holding the high-resolution targets.
    :return: list of ``{"lr": path, "hr": path}`` dicts, sorted by filename for reproducibility.
    """
    # normpath strips a trailing slash so dirname() really yields the parent folder.
    lr_dir = os.path.normpath(lr_dir)
    hr_dir = os.path.join(os.path.dirname(lr_dir), hr_dir_name)
    return [
        {"lr": os.path.join(lr_dir, name), "hr": os.path.join(hr_dir, name)}
        for name in tqdm(sorted(os.listdir(lr_dir)))
    ]


train_dict = build_lr_hr_pairs("/content/data/train/SD")
val_dict = build_lr_hr_pairs("/content/data/val/SD")

  0%|          | 0/99 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

In [19]:

# Wrap the SD/HD path pairs built above as training / evaluation datasets.
train_dataset, eval_dataset = TrainDataset(train_dict), EvalDataset(val_dict)

# 100 epochs; the Trainer's Adam optimizer starts from lr 5e-4.
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=100,
    learning_rate=5e-4,
)

config = DrlnConfig(scale=2, bam=True)
model = DrlnModel(config)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 25.53     ssim: 0.3674
best epoch: 0, psnr: 25.527969, ssim: 0.367359


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 26.87     ssim: 0.4170
best epoch: 1, psnr: 26.865194, ssim: 0.416983


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 27.18     ssim: 0.4286
best epoch: 2, psnr: 27.184301, ssim: 0.428633


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 27.21     ssim: 0.4434
best epoch: 3, psnr: 27.208982, ssim: 0.443379


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 27.56     ssim: 0.4556
best epoch: 4, psnr: 27.564941, ssim: 0.455599


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 27.69     ssim: 0.4658
best epoch: 5, psnr: 27.691807, ssim: 0.465843


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 27.83     ssim: 0.4786
best epoch: 6, psnr: 27.831310, ssim: 0.478590


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 27.86     ssim: 0.4897
best epoch: 7, psnr: 27.864536, ssim: 0.489660


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.06     ssim: 0.5017
best epoch: 8, psnr: 28.056131, ssim: 0.501683


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.13     ssim: 0.5107
best epoch: 9, psnr: 28.126221, ssim: 0.510682


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.17     ssim: 0.5175
best epoch: 10, psnr: 28.169565, ssim: 0.517501


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 27.96     ssim: 0.5125


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.21     ssim: 0.5278
best epoch: 12, psnr: 28.205048, ssim: 0.527818


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.29     ssim: 0.5386
best epoch: 13, psnr: 28.287563, ssim: 0.538613


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.32     ssim: 0.5425
best epoch: 14, psnr: 28.315165, ssim: 0.542475


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.29     ssim: 0.5454


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.37     ssim: 0.5490
best epoch: 16, psnr: 28.366848, ssim: 0.549037


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.36     ssim: 0.5497


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.38     ssim: 0.5523
best epoch: 18, psnr: 28.380636, ssim: 0.552309


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.37     ssim: 0.5540


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.37     ssim: 0.5567


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.41     ssim: 0.5569
best epoch: 21, psnr: 28.408871, ssim: 0.556950


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.41     ssim: 0.5584
best epoch: 22, psnr: 28.413900, ssim: 0.558420


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.40     ssim: 0.5579


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.42     ssim: 0.5610
best epoch: 24, psnr: 28.422123, ssim: 0.560966


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.43     ssim: 0.5608
best epoch: 25, psnr: 28.433563, ssim: 0.560825


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.44     ssim: 0.5611
best epoch: 26, psnr: 28.437231, ssim: 0.561058


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.44     ssim: 0.5616
best epoch: 27, psnr: 28.441385, ssim: 0.561615


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.43     ssim: 0.5620


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.42     ssim: 0.5631


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.34     ssim: 0.5629


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.40     ssim: 0.5622


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.28     ssim: 0.5626


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.06     ssim: 0.5574


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 27.81     ssim: 0.5587


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.24     ssim: 0.5627


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.17     ssim: 0.5590


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.35     ssim: 0.5629


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.42     ssim: 0.5622


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.44     ssim: 0.5637
best epoch: 39, psnr: 28.441875, ssim: 0.563683


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.44     ssim: 0.5643


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.45     ssim: 0.5641
best epoch: 41, psnr: 28.448162, ssim: 0.564133


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.45     ssim: 0.5641


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.45     ssim: 0.5647
best epoch: 43, psnr: 28.453985, ssim: 0.564714


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.46     ssim: 0.5647
best epoch: 44, psnr: 28.461340, ssim: 0.564672


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.46     ssim: 0.5645
best epoch: 45, psnr: 28.464664, ssim: 0.564515


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.45     ssim: 0.5643


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.45     ssim: 0.5648


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.46     ssim: 0.5643


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.46     ssim: 0.5645


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5643
best epoch: 50, psnr: 28.467215, ssim: 0.564324


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5646


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5647
best epoch: 52, psnr: 28.468346, ssim: 0.564727


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5647
best epoch: 53, psnr: 28.468918, ssim: 0.564711


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5646


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5645


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5649


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5647
best epoch: 57, psnr: 28.469727, ssim: 0.564659


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5643


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5647
best epoch: 59, psnr: 28.469902, ssim: 0.564743


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5646


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5645
best epoch: 61, psnr: 28.470127, ssim: 0.564519


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.46     ssim: 0.5648


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.46     ssim: 0.5637


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5641


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5644
best epoch: 65, psnr: 28.470398, ssim: 0.564369


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.46     ssim: 0.5646


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.46     ssim: 0.5634


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5640


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.46     ssim: 0.5643


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.46     ssim: 0.5636


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.46     ssim: 0.5637


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.45     ssim: 0.5628


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5638


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.46     ssim: 0.5647


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.46     ssim: 0.5638


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.46     ssim: 0.5647


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.45     ssim: 0.5635


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5643


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5646


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5647
best epoch: 80, psnr: 28.474920, ssim: 0.564701


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.47     ssim: 0.5648


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5646
best epoch: 82, psnr: 28.477341, ssim: 0.564592


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5645


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5647
best epoch: 84, psnr: 28.477667, ssim: 0.564692


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5648


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5648


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5647


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5647
best epoch: 88, psnr: 28.478025, ssim: 0.564720


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5647
best epoch: 89, psnr: 28.478251, ssim: 0.564696


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5646


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5646
best epoch: 91, psnr: 28.478634, ssim: 0.564646


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5647


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5647


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5646


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5646
best epoch: 95, psnr: 28.478741, ssim: 0.564638


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5646


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5647


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5647


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.48     ssim: 0.5646


### Trying to train the model on 4000 images from DIV2K

In [9]:
!pip install datasets==2.15

Collecting datasets==2.15
  Downloading datasets-2.15.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets==2.15)
  Downloading dill-0.3.7-py3-none-any.whl.metadata (9.9 kB)
Collecting xxhash (from datasets==2.15)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets==2.15)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2023.10.0,>=2023.1.0 (from fsspec[http]<=2023.10.0,>=2023.1.0->datasets==2.15)
  Downloading fsspec-2023.10.0-py3-none-any.whl.metadata (6.8 kB)
INFO: pip is looking at multiple versions of multiprocess to determine which version is compatible with other requirements. This could take a while.
Collecting multiprocess (from datasets==2.15)
  Downloading multiprocess-0.70.15-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [5]:
from datasets import load_dataset

# DIV2K from the Hugging Face Hub, 'bicubic_x2' configuration.
DIV2K_REPO, DIV2K_CONFIG = 'eugenesiow/Div2k', 'bicubic_x2'

augmented_dataset = load_dataset(DIV2K_REPO, DIV2K_CONFIG, split='train')
train_dataset = TrainDataset(augmented_dataset)
eval_dataset = EvalDataset(load_dataset(DIV2K_REPO, DIV2K_CONFIG, split='validation'))

# No learning_rate given: the Trainer's Adam uses the TrainingArguments default.
training_args = TrainingArguments(output_dir='./results', num_train_epochs=100)

config = DrlnConfig(scale=2, bam=True)
model = DrlnModel(config)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 27.04     ssim: 0.7495
best epoch: 0, psnr: 27.035873, ssim: 0.749477


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 29.63     ssim: 0.8270
best epoch: 1, psnr: 29.630690, ssim: 0.826990


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 30.71     ssim: 0.8598
best epoch: 2, psnr: 30.705715, ssim: 0.859779


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 31.33     ssim: 0.8833
best epoch: 3, psnr: 31.330223, ssim: 0.883273


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 31.97     ssim: 0.8910
best epoch: 4, psnr: 31.965254, ssim: 0.890954


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 32.31     ssim: 0.8986
best epoch: 5, psnr: 32.313732, ssim: 0.898591


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 32.55     ssim: 0.9043
best epoch: 6, psnr: 32.549282, ssim: 0.904265


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 32.69     ssim: 0.9059
best epoch: 7, psnr: 32.692554, ssim: 0.905928


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 32.64     ssim: 0.9113


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 32.80     ssim: 0.9116
best epoch: 9, psnr: 32.802639, ssim: 0.911642


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 32.99     ssim: 0.9116
best epoch: 10, psnr: 32.988308, ssim: 0.911629


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 33.15     ssim: 0.9154
best epoch: 11, psnr: 33.148254, ssim: 0.915377


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 33.20     ssim: 0.9166
best epoch: 12, psnr: 33.201290, ssim: 0.916618


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 33.19     ssim: 0.9177


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 33.32     ssim: 0.9209
best epoch: 14, psnr: 33.320168, ssim: 0.920853


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 33.15     ssim: 0.9209


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 33.59     ssim: 0.9218
best epoch: 16, psnr: 33.589844, ssim: 0.921764


  0%|          | 0/800 [00:00<?, ?it/s]

scale:2      eval psnr: 33.60     ssim: 0.9227
best epoch: 17, psnr: 33.598000, ssim: 0.922681


  0%|          | 0/800 [00:00<?, ?it/s]

KeyboardInterrupt: 

### Training on 100 random DIV2K samples

In [21]:
import random

from datasets import load_dataset

# Subset sizes and a fixed seed, so the same images are drawn on every Restart & Run All.
N_TRAIN_SAMPLES = 100
N_VAL_SAMPLES = 20
SUBSET_SEED = 42
subset_rng = random.Random(SUBSET_SEED)

augmented_dataset = load_dataset('eugenesiow/Div2k', 'bicubic_x2', split='train')
train_index = subset_rng.sample(range(len(augmented_dataset)), min(N_TRAIN_SAMPLES, len(augmented_dataset)))
training_set = [augmented_dataset[i] for i in train_index]

# Raw HF split, kept under its own name; `eval_dataset` is reserved for the EvalDataset wrapper.
div2k_val = load_dataset('eugenesiow/Div2k', 'bicubic_x2', split='validation')
val_index = subset_rng.sample(range(len(div2k_val)), min(N_VAL_SAMPLES, len(div2k_val)))
val_set = [div2k_val[i] for i in val_index]

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [24]:
# Wrap the sampled DIV2K records as training / evaluation datasets.
train_dataset, eval_dataset = TrainDataset(training_set), EvalDataset(val_set)

# 100 epochs with a small Adam learning rate (1e-5).
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=100,
    learning_rate=1e-5,
)

config = DrlnConfig(scale=2, bam=True)
model = DrlnModel(config)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 26.10     ssim: 0.3266
best epoch: 0, psnr: 26.100000, ssim: 0.326600


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 26.27     ssim: 0.3423
best epoch: 1, psnr: 26.270000, ssim: 0.342300


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 26.57     ssim: 0.3617
best epoch: 2, psnr: 26.570000, ssim: 0.361700


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 26.62     ssim: 0.3752
best epoch: 3, psnr: 26.620000, ssim: 0.375200


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 26.71     ssim: 0.4019
best epoch: 4, psnr: 26.710000, ssim: 0.401900


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 26.83     ssim: 0.4256
best epoch: 5, psnr: 26.830000, ssim: 0.425600


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 26.98     ssim: 0.4264
best epoch: 6, psnr: 26.980000, ssim: 0.426400


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 27.09     ssim: 0.4489
best epoch: 7, psnr: 27.090000, ssim: 0.448900


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 27.33     ssim: 0.4529
best epoch: 8, psnr: 27.330000, ssim: 0.452900


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 27.51     ssim: 0.4712
best epoch: 9, psnr: 27.510000, ssim: 0.471200


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 27.62     ssim: 0.4957
best epoch: 10, psnr: 27.620000, ssim: 0.495700


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 27.87     ssim: 0.5102
best epoch: 11, psnr: 27.870000, ssim: 0.510200


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.08     ssim: 0.5210
best epoch: 12, psnr: 28.080000, ssim: 0.521000


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.12     ssim: 0.5389
best epoch: 13, psnr: 28.120000, ssim: 0.538900


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.21     ssim: 0.5430
best epoch: 14, psnr: 28.210000, ssim: 0.543000


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.29     ssim: 0.5750
best epoch: 15, psnr: 28.290000, ssim: 0.575000


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.32     ssim: 0.5420
best epoch: 16, psnr: 28.320000, ssim: 0.542000


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.45     ssim: 0.5490
best epoch: 17, psnr: 28.450000, ssim: 0.549000


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.51     ssim: 0.5620
best epoch: 18, psnr: 28.510000, ssim: 0.562000


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.53     ssim: 0.5740
best epoch: 19, psnr: 28.530000, ssim: 0.574000


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.49     ssim: 0.5600


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.68     ssim: 0.5790
best epoch: 21, psnr: 28.680000, ssim: 0.579000


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.82     ssim: 0.5910
best epoch: 22, psnr: 28.820000, ssim: 0.591000


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.84     ssim: 0.5950
best epoch: 23, psnr: 28.840000, ssim: 0.595000


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.70     ssim: 0.5850


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.80     ssim: 0.5790


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.97     ssim: 0.5860
best epoch: 26, psnr: 28.970000, ssim: 0.586000


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.65     ssim: 0.5980


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.86     ssim: 0.5770


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.81     ssim: 0.6010


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.96     ssim: 0.5820


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.99     ssim: 0.5910
best epoch: 31, psnr: 28.990000, ssim: 0.591000


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.75     ssim: 0.5810


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.66     ssim: 0.5840


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.65     ssim: 0.5760


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.92     ssim: 0.5880


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.76     ssim: 0.5980


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.89     ssim: 0.5820


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.75     ssim: 0.6050


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.67     ssim: 0.5780


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.72     ssim: 0.5940


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.95     ssim: 0.5820


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.66     ssim: 0.5830


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.76     ssim: 0.5990


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.77     ssim: 0.5860


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.74     ssim: 0.6030


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.95     ssim: 0.6120


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 29.01     ssim: 0.5890
best epoch: 47, psnr: 29.010000, ssim: 0.589000


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.86     ssim: 0.6010


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.88     ssim: 0.5940


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.83     ssim: 0.6100


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.95     ssim: 0.6050


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.99     ssim: 0.5940


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 29.00     ssim: 0.6030


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.89     ssim: 0.6020


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.94     ssim: 0.5750


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.93     ssim: 0.6120


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.75     ssim: 0.6020


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.92     ssim: 0.5980


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.76     ssim: 0.5790


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 28.88     ssim: 0.6110


  0%|          | 0/96 [00:00<?, ?it/s]

KeyboardInterrupt: 

### Training on our data from 320 to 640

In [8]:
import os

# Idempotent replacement for `!mkdir`, which errors on re-run once the folders exist.
for split in ('train', 'val'):
    os.makedirs(f'/content/data/{split}/320', exist_ok=True)

In [14]:
import os
from PIL import Image
from skimage.metrics import structural_similarity as ssim
import numpy as np
import cv2
from tqdm import tqdm


def scale_image(input_path, output_path, size=(512, 640), factor=2):
    """
    Downscale an image by an integer factor with Lanczos resampling and save it.

    The output dimensions are ``(width // factor, height // factor)`` of the
    input image. Each side is clamped to at least 1 pixel, so very small inputs
    do not produce an invalid zero-sized image.

    :param input_path: Path to the input image.
    :param output_path: Path to save the downscaled image.
    :param size: Unused. Kept only for backward compatibility with existing
        callers; the output size is derived from the input size and ``factor``.
    :param factor: Integer downscaling factor (default 2, i.e. halve each side).
    :raises ValueError: If ``factor`` is smaller than 1.
    """
    if factor < 1:
        raise ValueError(f"factor must be >= 1, got {factor!r}")
    with Image.open(input_path) as img:
        width, height = img.size
        # Floor division matches the previous int(x / 2) truncation for positive sizes.
        new_size = (max(1, width // factor), max(1, height // factor))
        img.resize(new_size, Image.LANCZOS).save(output_path)


def compute_ssim(original_path, resized_path):
    """
    Compute SSIM between two images after converting both to 8-bit grayscale.

    Both image files are closed before returning. The two images must have the
    same dimensions; a clear ``ValueError`` naming both files is raised
    otherwise. To compare an image with a downscaled copy, bring them to a
    common size first.

    :param original_path: Path to the reference image.
    :param resized_path: Path to the image compared against the reference.
    :return: SSIM value as returned by ``ssim`` (skimage ``structural_similarity``).
    :raises ValueError: If the two grayscale images differ in shape.
    """
    def _load_gray(path):
        # Open inside ``with`` so the file handle is released immediately
        # instead of leaking until garbage collection.
        with Image.open(path) as img:
            return np.array(img.convert('L'))

    original_np = _load_gray(original_path)
    resized_np = _load_gray(resized_path)
    if original_np.shape != resized_np.shape:
        raise ValueError(
            f"SSIM needs images of equal size, got {original_np.shape} "
            f"({original_path}) and {resized_np.shape} ({resized_path})"
        )
    return ssim(original_np, resized_np)

def process_images(input_dir, output_dir, size=(1280, 1024),
                   extensions=('.png', '.jpg', '.bmp')):
    """
    Downscale every image in ``input_dir`` into ``output_dir`` via ``scale_image``.

    Files are processed in sorted name order so runs are reproducible. Entries
    whose suffix is not in ``extensions`` (matched case-insensitively) are
    skipped, as are non-regular files such as sub-directories. Output files
    keep their original names.

    :param input_dir: Directory with the original images.
    :param output_dir: Directory to save the downscaled images; created
        (including parents) if missing.
    :param size: Forwarded unchanged to ``scale_image``.
    :param extensions: File suffixes to process, matched case-insensitively.
    """
    os.makedirs(output_dir, exist_ok=True)
    suffixes = tuple(ext.lower() for ext in extensions)
    for filename in tqdm(sorted(os.listdir(input_dir))):
        original_path = os.path.join(input_dir, filename)
        if not filename.lower().endswith(suffixes) or not os.path.isfile(original_path):
            continue
        resized_path = os.path.join(output_dir, filename)
        scale_image(original_path, resized_path, size)


# Produce the half-resolution "320" inputs from the SD images of each split.
# Each entry: (banner to print, source SD directory, destination directory).
_SPLITS = (
    ("RESULT TRAINING DATASET:", '/content/data/train/SD', '/content/data/train/320'),
    ("RESULT VALIDATION DATASET:", '/content/data/val/SD', '/content/data/val/320'),
)
for banner, input_directory, output_dir in _SPLITS:
    print(banner)
    process_images(input_directory, output_dir)



RESULT TRAINING DATASET:


100%|██████████| 99/99 [00:00<00:00, 268.86it/s]


RESULT VALIDATION DATASET:


100%|██████████| 22/22 [00:00<00:00, 267.14it/s]


In [10]:
def build_pairs(lr_dir, hr_dir):
    """
    Pair each file in ``lr_dir`` with the same-named file in ``hr_dir``.

    The high-resolution path is joined from ``hr_dir`` directly, rather than by
    string-replacing "320" in the low-resolution path. The replacement approach
    would also rewrite any "320" inside a file name (e.g. ``img320.bmp``).

    :param lr_dir: Directory with the downscaled (input) images.
    :param hr_dir: Directory with the original (target) images.
    :return: List of ``{"lr": path, "hr": path}`` dicts in sorted file-name order.
    :raises FileNotFoundError: If a file in ``lr_dir`` has no counterpart in ``hr_dir``.
    """
    pairs = []
    for filename in sorted(os.listdir(lr_dir)):
        hr_path = os.path.join(hr_dir, filename)
        # Fail here, with the file name, rather than later inside the data loader.
        if not os.path.isfile(hr_path):
            raise FileNotFoundError(
                f"No high-resolution counterpart for {filename!r} in {hr_dir}"
            )
        pairs.append({"lr": os.path.join(lr_dir, filename), "hr": hr_path})
    return pairs


train_dict = build_pairs("/content/data/train/320", "/content/data/train/SD")
val_dict = build_pairs("/content/data/val/320", "/content/data/val/SD")

100%|██████████| 99/99 [00:00<00:00, 181246.66it/s]
100%|██████████| 22/22 [00:00<00:00, 142619.30it/s]


In [17]:
import random

import numpy as np
import torch

# Seed the Python, NumPy and PyTorch RNGs so repeated runs are comparable.
# Seeding happens before the model is constructed, so any random weight
# initialization is covered too. GPU kernels may still be nondeterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Datasets of {"lr": ..., "hr": ...} path pairs built in the previous cell.
train_dataset = TrainDataset(train_dict)
eval_dataset = EvalDataset(val_dict)


training_args = TrainingArguments(
    learning_rate=5e-4,
    # NOTE(review): confirm no earlier experiment in this notebook also writes
    # to './results', otherwise its saved outputs may be overwritten.
    output_dir='./results',
    num_train_epochs=100,
)

# DRLN configured for x2 upscaling (320 -> 640 inputs/targets).
config = DrlnConfig(
    scale=2,
    bam=True
)
model = DrlnModel(config)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)

trainer.train()

  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 30.18     ssim: 0.7637
best epoch: 0, psnr: 30.183573, ssim: 0.763726


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 31.97     ssim: 0.8147
best epoch: 1, psnr: 31.970394, ssim: 0.814731


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 32.90     ssim: 0.8296
best epoch: 2, psnr: 32.898811, ssim: 0.829625


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 33.53     ssim: 0.8439
best epoch: 3, psnr: 33.526630, ssim: 0.843949


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 34.39     ssim: 0.8577
best epoch: 4, psnr: 34.391670, ssim: 0.857709


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 35.02     ssim: 0.8688
best epoch: 5, psnr: 35.019264, ssim: 0.868815


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 35.53     ssim: 0.8773
best epoch: 6, psnr: 35.534950, ssim: 0.877266


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 35.82     ssim: 0.8827
best epoch: 7, psnr: 35.824482, ssim: 0.882659


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 36.04     ssim: 0.8863
best epoch: 8, psnr: 36.039036, ssim: 0.886274


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 36.22     ssim: 0.8892
best epoch: 9, psnr: 36.216621, ssim: 0.889165


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 36.42     ssim: 0.8922
best epoch: 10, psnr: 36.420349, ssim: 0.892248


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 36.46     ssim: 0.8941
best epoch: 11, psnr: 36.455921, ssim: 0.894081


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 36.67     ssim: 0.8967
best epoch: 12, psnr: 36.671715, ssim: 0.896709


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 36.72     ssim: 0.8973
best epoch: 13, psnr: 36.718361, ssim: 0.897272


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 36.82     ssim: 0.8988
best epoch: 14, psnr: 36.823860, ssim: 0.898813


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 36.82     ssim: 0.9006
best epoch: 15, psnr: 36.824715, ssim: 0.900649


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.03     ssim: 0.9016
best epoch: 16, psnr: 37.029041, ssim: 0.901608


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.13     ssim: 0.9025
best epoch: 17, psnr: 37.127602, ssim: 0.902527


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.17     ssim: 0.9029
best epoch: 18, psnr: 37.169525, ssim: 0.902873


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.22     ssim: 0.9039
best epoch: 19, psnr: 37.216541, ssim: 0.903881


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 36.97     ssim: 0.9042


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.25     ssim: 0.9047
best epoch: 21, psnr: 37.247566, ssim: 0.904723


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.25     ssim: 0.9044
best epoch: 22, psnr: 37.249855, ssim: 0.904413


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.26     ssim: 0.9046
best epoch: 23, psnr: 37.259773, ssim: 0.904586


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.32     ssim: 0.9058
best epoch: 24, psnr: 37.315411, ssim: 0.905804


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.39     ssim: 0.9062
best epoch: 25, psnr: 37.387402, ssim: 0.906220


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.41     ssim: 0.9066
best epoch: 26, psnr: 37.414322, ssim: 0.906629


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 36.96     ssim: 0.9058


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.33     ssim: 0.9069


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 35.95     ssim: 0.9043


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 36.79     ssim: 0.9063


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.06     ssim: 0.9070


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.16     ssim: 0.9069


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.10     ssim: 0.9056


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 36.51     ssim: 0.9051


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.45     ssim: 0.9074
best epoch: 35, psnr: 37.449356, ssim: 0.907351


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.18     ssim: 0.9076


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.42     ssim: 0.9074


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.45     ssim: 0.9076
best epoch: 38, psnr: 37.451782, ssim: 0.907579


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.52     ssim: 0.9079
best epoch: 39, psnr: 37.517181, ssim: 0.907885


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.47     ssim: 0.9081


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.32     ssim: 0.9078


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.09     ssim: 0.9077


  0%|          | 0/96 [00:00<?, ?it/s]

scale:2      eval psnr: 37.05     ssim: 0.9078


  0%|          | 0/96 [00:00<?, ?it/s]

KeyboardInterrupt: 