# Dataset & DataLoader

Use the h5py library to read the HDF5 data files.

In [1]:
!pip install h5py



In [6]:
!pip3 install tqdm



In [2]:
import h5py
import numpy as np
from torch.utils.data import Dataset


class TrainDataset(Dataset):
    """Training-patch dataset backed by an HDF5 file with ``lr`` and ``hr`` datasets.

    Item ``idx`` is the pair ``(lr, hr)``, where each patch is divided by 255
    (presumably mapping 8-bit values into [0, 1] -- TODO confirm the stored
    range) and given a new leading channel axis.

    The file is opened and closed on every access rather than kept open on the
    instance.
    """

    def __init__(self, h5_file):
        super().__init__()
        # Path to the HDF5 file; it is reopened lazily on each access.
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as handle:
            low_res = np.expand_dims(handle['lr'][idx] / 255., 0)
            high_res = np.expand_dims(handle['hr'][idx] / 255., 0)
        return low_res, high_res

    def __len__(self):
        # The number of samples is the length of the ``lr`` dataset.
        with h5py.File(self.h5_file, 'r') as handle:
            return len(handle['lr'])


class EvalDataset(Dataset):
    """Evaluation dataset stored in an HDF5 file.

    ``lr`` and ``hr`` are indexed by the stringified sample index ('0', '1',
    ...), so they are presumably HDF5 groups holding one image per key.

    Each item is an ``(lr, hr)`` pair. Each image is read with ``[:, :]``,
    divided by 255 and given a new leading channel axis.
    """

    def __init__(self, h5_file):
        super().__init__()
        # Path to the HDF5 file; it is reopened on each access.
        self.h5_file = h5_file

    def __getitem__(self, idx):
        # Members are keyed by the index as a string, not as an integer.
        key = str(idx)
        with h5py.File(self.h5_file, 'r') as handle:
            low_res = handle['lr'][key][:, :] / 255.
            high_res = handle['hr'][key][:, :] / 255.
        return np.expand_dims(low_res, 0), np.expand_dims(high_res, 0)

    def __len__(self):
        # The number of members under ``lr`` is the sample count.
        with h5py.File(self.h5_file, 'r') as handle:
            return len(handle['lr'])

# Model definition

## TODO: Define the model following the architecture given below
### SRCNN

conv1: (in_channel) num_channels, (out_channel) 64, (kernel_size) 9, (stride) 1, (padding) 4

conv2: (in_channel) 64, (out_channel) 32, (kernel_size) 5, (stride) 1, (padding) 2

conv3: (in_channel) 32, (out_channel) num_channels, (kernel_size) 5, (stride) 1, (padding) 2

Each conv layer should be followed by a ReLU activation, except the last conv layer, which should have no activation.

In [10]:
from torch import nn


class SRCNN(nn.Module):
    """Three-layer SRCNN super-resolution network.

    Layers:
        conv1: num_channels -> 64 channels, 9x9 kernel, followed by ReLU.
        conv2: 64 -> 32 channels, 5x5 kernel, followed by ReLU.
        conv3: 32 -> num_channels channels, 5x5 kernel, with no activation.

    Every conv uses stride 1 and padding ``kernel_size // 2``. With odd kernels
    this keeps the spatial size unchanged from input to output.

    Args:
        num_channels: number of input and output image channels (default 1).
    """

    def __init__(self, num_channels=1):
        super().__init__()
        # The attribute names conv1/conv2/conv3 are part of the state_dict keys
        # and are referenced by optimizer parameter groups, so keep them.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, stride=1, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, stride=1, padding=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        hidden = self.relu(self.conv2(hidden))
        # The reconstruction layer's output is returned without an activation.
        return self.conv3(hidden)

In [3]:
from torchvision.models import vgg19
# Feature extractor built on an ImageNet-pretrained VGG19.
# The classifier head is not used: the network is cut partway through and only
# the intermediate convolutional features are returned.
class FeatureExtractor(nn.Module):
    """Return the activations after layers 0..19 of VGG19's ``features`` trunk.

    NOTE(review): the wrapped VGG19 is neither frozen (``requires_grad``) nor
    switched to eval mode here. If it is meant as a fixed perceptual-loss
    network, the caller must do that.
    """

    def __init__(self):
        super().__init__()
        # Stored under ``model`` so the state_dict keys stay "model.features.*".
        self.model = vgg19(pretrained=True)

    def forward(self, x):
        # Apply the first 20 layers of the convolutional trunk in order.
        for layer in self.model.features[:20]:
            x = layer(x)
        return x


# Quick inspection: print the third layer of the pretrained VGG19 trunk.
model = vgg19(pretrained=True)
print(model.features[2])

Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


# Util functions

In [11]:
import torch
import numpy as np


def convert_rgb_to_y(img):
    """Return the ITU-R BT.601 luma (Y) plane of an RGB image in the 0..255 range.

    Y = 16 + (65.738 R + 129.057 G + 25.064 B) / 256, which maps RGB values in
    [0, 255] to the studio range [16, 235].

    Args:
        img: either a NumPy array laid out channels-last as (H, W, 3), or a
            torch tensor laid out channels-first as (3, H, W) or (1, 3, H, W).

    Returns:
        The Y plane, shape (H, W), of the same array/tensor type as ``img``.

    Raises:
        TypeError: if ``img`` is neither a NumPy array nor a torch tensor.
            This is an ``Exception`` subclass, so existing handlers still catch it.
    """
    if isinstance(img, np.ndarray):
        r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    elif isinstance(img, torch.Tensor):
        if img.dim() == 4:
            # Drop a leading batch dimension of size 1.
            img = img.squeeze(0)
        r, g, b = img[0, :, :], img[1, :, :], img[2, :, :]
    else:
        raise TypeError('Unknown Type', type(img))
    # The red coefficient used to be 64.738, a typo for BT.601's 65.738. With
    # the typo, white mapped to ~234 instead of 235 and the round trip through
    # convert_ycbcr_to_rgb broke. NOTE(review): any .h5 files generated with the
    # old coefficient will differ slightly; regenerate them if consistency matters.
    return 16. + (65.738 * r + 129.057 * g + 25.064 * b) / 256.


def convert_rgb_to_ycbcr(img):
    """Convert an RGB image in the 0..255 range to ITU-R BT.601 YCbCr.

    Y  = 16  + ( 65.738 R + 129.057 G +  25.064 B) / 256
    Cb = 128 + (-37.945 R -  74.494 G + 112.439 B) / 256
    Cr = 128 + (112.439 R -  94.154 G -  18.285 B) / 256

    Args:
        img: either a NumPy array laid out channels-last as (H, W, 3), or a
            torch tensor laid out channels-first as (3, H, W) or (1, 3, H, W).

    Returns:
        A channels-last (H, W, 3) array/tensor holding (Y, Cb, Cr), of the
        same type as ``img``.

    Raises:
        TypeError: if ``img`` is neither a NumPy array nor a torch tensor.
            This is an ``Exception`` subclass, so existing handlers still catch it.
    """
    if isinstance(img, np.ndarray):
        r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    elif isinstance(img, torch.Tensor):
        if img.dim() == 4:
            # Drop a leading batch dimension of size 1.
            img = img.squeeze(0)
        r, g, b = img[0, :, :], img[1, :, :], img[2, :, :]
    else:
        raise TypeError('Unknown Type', type(img))

    # 65.738 replaces the previous typo 64.738; the Y coefficients must sum to
    # 219 * 256 / 255 so that white maps to 235.
    y = 16. + (65.738 * r + 129.057 * g + 25.064 * b) / 256.
    cb = 128. + (-37.945 * r - 74.494 * g + 112.439 * b) / 256.
    cr = 128. + (112.439 * r - 94.154 * g - 18.285 * b) / 256.

    if isinstance(img, np.ndarray):
        return np.stack([y, cb, cr], axis=-1)
    # stack (not cat) creates the new channel axis. The old
    # torch.cat(...).permute(1, 2, 0) called permute on a 2-D tensor and raised.
    return torch.stack([y, cb, cr], dim=-1)


def convert_ycbcr_to_rgb(img):
    """Convert an ITU-R BT.601 YCbCr image (studio range) back to RGB in 0..255.

    R = 298.082 Y / 256 + 408.583 Cr / 256 - 222.921
    G = 298.082 Y / 256 - 100.291 Cb / 256 - 208.120 Cr / 256 + 135.576
    B = 298.082 Y / 256 + 516.412 Cb / 256 - 276.836

    Args:
        img: either a NumPy array laid out channels-last as (H, W, 3), or a
            torch tensor laid out channels-first as (3, H, W) or (1, 3, H, W),
            with channels ordered (Y, Cb, Cr).

    Returns:
        A channels-last (H, W, 3) array/tensor holding (R, G, B), of the same
        type as ``img``. Values are not clipped.

    Raises:
        TypeError: if ``img`` is neither a NumPy array nor a torch tensor.
            This is an ``Exception`` subclass, so existing handlers still catch it.
    """
    if isinstance(img, np.ndarray):
        y, cb, cr = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    elif isinstance(img, torch.Tensor):
        if img.dim() == 4:
            # Drop a leading batch dimension of size 1.
            img = img.squeeze(0)
        y, cb, cr = img[0, :, :], img[1, :, :], img[2, :, :]
    else:
        raise TypeError('Unknown Type', type(img))

    r = 298.082 * y / 256. + 408.583 * cr / 256. - 222.921
    g = 298.082 * y / 256. - 100.291 * cb / 256. - 208.120 * cr / 256. + 135.576
    b = 298.082 * y / 256. + 516.412 * cb / 256. - 276.836

    if isinstance(img, np.ndarray):
        return np.stack([r, g, b], axis=-1)
    # stack (not cat) creates the new channel axis. The old
    # torch.cat(...).permute(1, 2, 0) called permute on a 2-D tensor and raised.
    return torch.stack([r, g, b], dim=-1)


def calc_psnr(img1, img2, max_val=1.):
    """Return the peak signal-to-noise ratio between two tensors, in dB.

    PSNR = 10 * log10(max_val ** 2 / MSE), where MSE is the mean squared
    difference over all elements.

    Args:
        img1, img2: tensors of broadcast-compatible shapes.
        max_val: peak signal value. The default 1.0 suits images scaled to
            [0, 1]; pass 255.0 for data in the 0..255 range.

    Returns:
        A 0-dim tensor. It is ``inf`` when the inputs are identical (MSE == 0).
    """
    mse = torch.mean((img1 - img2) ** 2)
    return 10. * torch.log10(max_val ** 2 / mse)


class AverageMeter:
    """Track the latest value of a metric and its running, count-weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset the last value, running average, sum and sample count to 0."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples and refresh ``avg``.

        Raises ZeroDivisionError if the accumulated count is still 0
        (e.g. the first call passes ``n=0``).
        """
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

# Training loop

In [12]:
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm


# --- Training configuration ---
train_file = "data/91-image_x3.h5"
eval_file = "data/Set5_x3.h5"
outputs_dir = "outputs/"

scale = 3
lr = 1e-4
batch_size = 16
num_epochs = 400
num_workers = 8
seed = 123

# Checkpoints are written to a per-scale subdirectory (e.g. outputs/x3).
outputs_dir = os.path.join(outputs_dir, 'x{}'.format(scale))
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(outputs_dir, exist_ok=True)

cudnn.benchmark = True
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

torch.manual_seed(seed)

model = SRCNN().to(device)
criterion = nn.MSELoss()
# conv1/conv2 train at `lr`; the final reconstruction layer conv3 uses lr * 0.1.
optimizer = optim.Adam([
    {'params': model.conv1.parameters()},
    {'params': model.conv2.parameters()},
    {'params': model.conv3.parameters(), 'lr': lr * 0.1}
], lr=lr)

train_dataset = TrainDataset(train_file)
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              pin_memory=True,
                              drop_last=True)
eval_dataset = EvalDataset(eval_file)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

best_weights = copy.deepcopy(model.state_dict())
best_epoch = 0
best_psnr = 0.0

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    epoch_losses = AverageMeter()

    # drop_last=True discards the final partial batch, so the progress-bar
    # total excludes those samples.
    with tqdm(total=(len(train_dataset) - len(train_dataset) % batch_size)) as t:
        t.set_description('epoch: {}/{}'.format(epoch, num_epochs - 1))

        for inputs, labels in train_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = model(inputs)
            loss = criterion(preds, labels)

            epoch_losses.update(loss.item(), len(inputs))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
            t.update(len(inputs))

    torch.save(model.state_dict(), os.path.join(outputs_dir, 'epoch_{}.pth'.format(epoch)))

    # ---- evaluation pass ----
    model.eval()
    epoch_psnr = AverageMeter()

    with torch.no_grad():
        for inputs, labels in eval_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = model(inputs).clamp(0.0, 1.0)
            # .item() stores a plain float, so epoch_psnr.avg and best_psnr do
            # not become 0-dim (possibly CUDA) tensors.
            epoch_psnr.update(calc_psnr(preds, labels).item(), len(inputs))

    print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

    if epoch_psnr.avg > best_psnr:
        best_epoch = epoch
        best_psnr = epoch_psnr.avg
        best_weights = copy.deepcopy(model.state_dict())

print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
torch.save(best_weights, os.path.join(outputs_dir, 'best.pth'))

epoch: 0/399: 100%|██████████| 21872/21872 [00:15<00:00, 1414.41it/s, loss=0.004404]
epoch: 1/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 31.84


epoch: 1/399: 100%|██████████| 21872/21872 [00:14<00:00, 1467.18it/s, loss=0.001455]
epoch: 2/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.11


epoch: 2/399: 100%|██████████| 21872/21872 [00:14<00:00, 1509.66it/s, loss=0.001399]
epoch: 3/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.26


epoch: 3/399: 100%|██████████| 21872/21872 [00:15<00:00, 1427.55it/s, loss=0.001372]
epoch: 4/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.28


epoch: 4/399: 100%|██████████| 21872/21872 [00:16<00:00, 1331.20it/s, loss=0.001356]
epoch: 5/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.40


epoch: 5/399: 100%|██████████| 21872/21872 [00:14<00:00, 1483.66it/s, loss=0.001344]
epoch: 6/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.47


epoch: 6/399: 100%|██████████| 21872/21872 [00:15<00:00, 1409.95it/s, loss=0.001333]
epoch: 7/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.47


epoch: 7/399: 100%|██████████| 21872/21872 [00:14<00:00, 1460.64it/s, loss=0.001322]
epoch: 8/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.55


epoch: 8/399: 100%|██████████| 21872/21872 [00:14<00:00, 1532.06it/s, loss=0.001314]
epoch: 9/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.58


epoch: 9/399: 100%|██████████| 21872/21872 [00:14<00:00, 1494.59it/s, loss=0.001306]
epoch: 10/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.59


epoch: 10/399: 100%|██████████| 21872/21872 [00:14<00:00, 1498.95it/s, loss=0.001298]
epoch: 11/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.63


epoch: 11/399: 100%|██████████| 21872/21872 [00:14<00:00, 1508.21it/s, loss=0.001290]
epoch: 12/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.64


epoch: 12/399: 100%|██████████| 21872/21872 [00:14<00:00, 1538.11it/s, loss=0.001285]
epoch: 13/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.72


epoch: 13/399: 100%|██████████| 21872/21872 [00:14<00:00, 1549.92it/s, loss=0.001279]
epoch: 14/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.71


epoch: 14/399: 100%|██████████| 21872/21872 [00:14<00:00, 1557.09it/s, loss=0.001275]
epoch: 15/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.73


epoch: 15/399: 100%|██████████| 21872/21872 [00:14<00:00, 1513.28it/s, loss=0.001270]
epoch: 16/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.74


epoch: 16/399: 100%|██████████| 21872/21872 [00:14<00:00, 1518.18it/s, loss=0.001266]
epoch: 17/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.76


epoch: 17/399: 100%|██████████| 21872/21872 [00:15<00:00, 1420.78it/s, loss=0.001263]
epoch: 18/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.82


epoch: 18/399: 100%|██████████| 21872/21872 [00:15<00:00, 1443.66it/s, loss=0.001260]
epoch: 19/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.79


epoch: 19/399: 100%|██████████| 21872/21872 [00:14<00:00, 1472.33it/s, loss=0.001257]
epoch: 20/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.82


epoch: 20/399: 100%|██████████| 21872/21872 [00:14<00:00, 1458.32it/s, loss=0.001253]
epoch: 21/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.82


epoch: 21/399: 100%|██████████| 21872/21872 [00:15<00:00, 1438.03it/s, loss=0.001250]
epoch: 22/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.83


epoch: 22/399: 100%|██████████| 21872/21872 [00:15<00:00, 1438.35it/s, loss=0.001247]
epoch: 23/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.88


epoch: 23/399: 100%|██████████| 21872/21872 [00:15<00:00, 1457.25it/s, loss=0.001244]
epoch: 24/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.85


epoch: 24/399: 100%|██████████| 21872/21872 [00:15<00:00, 1450.11it/s, loss=0.001242]
epoch: 25/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.86


epoch: 25/399: 100%|██████████| 21872/21872 [00:14<00:00, 1459.32it/s, loss=0.001239]
epoch: 26/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.87


epoch: 26/399: 100%|██████████| 21872/21872 [00:15<00:00, 1450.79it/s, loss=0.001238]
epoch: 27/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.89


epoch: 27/399: 100%|██████████| 21872/21872 [00:15<00:00, 1380.60it/s, loss=0.001235]
epoch: 28/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.91


epoch: 28/399: 100%|██████████| 21872/21872 [00:14<00:00, 1494.73it/s, loss=0.001234]
epoch: 29/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.91


epoch: 29/399: 100%|██████████| 21872/21872 [00:14<00:00, 1515.42it/s, loss=0.001231]
epoch: 30/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.93


epoch: 30/399: 100%|██████████| 21872/21872 [00:13<00:00, 1573.64it/s, loss=0.001229]
epoch: 31/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.90


epoch: 31/399: 100%|██████████| 21872/21872 [00:13<00:00, 1602.31it/s, loss=0.001227]
epoch: 32/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.93


epoch: 32/399: 100%|██████████| 21872/21872 [00:14<00:00, 1539.55it/s, loss=0.001226]
epoch: 33/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.93


epoch: 33/399: 100%|██████████| 21872/21872 [00:13<00:00, 1627.41it/s, loss=0.001225]
epoch: 34/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.95


epoch: 34/399: 100%|██████████| 21872/21872 [00:13<00:00, 1626.22it/s, loss=0.001223]
epoch: 35/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.94


epoch: 35/399: 100%|██████████| 21872/21872 [00:13<00:00, 1628.30it/s, loss=0.001221]
epoch: 36/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.87


epoch: 36/399: 100%|██████████| 21872/21872 [00:13<00:00, 1628.19it/s, loss=0.001220]
epoch: 37/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.94


epoch: 37/399: 100%|██████████| 21872/21872 [00:13<00:00, 1630.66it/s, loss=0.001218]
epoch: 38/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 38/399: 100%|██████████| 21872/21872 [00:13<00:00, 1621.77it/s, loss=0.001218]
epoch: 39/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.98


epoch: 39/399: 100%|██████████| 21872/21872 [00:14<00:00, 1497.31it/s, loss=0.001216]
epoch: 40/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.96


epoch: 40/399: 100%|██████████| 21872/21872 [00:13<00:00, 1604.27it/s, loss=0.001214]
epoch: 41/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.00


epoch: 41/399: 100%|██████████| 21872/21872 [00:14<00:00, 1514.09it/s, loss=0.001212]
epoch: 42/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.96


epoch: 42/399: 100%|██████████| 21872/21872 [00:14<00:00, 1522.89it/s, loss=0.001210]
epoch: 43/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.01


epoch: 43/399: 100%|██████████| 21872/21872 [00:14<00:00, 1521.50it/s, loss=0.001210]
epoch: 44/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.00


epoch: 44/399: 100%|██████████| 21872/21872 [00:14<00:00, 1538.42it/s, loss=0.001209]
epoch: 45/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.97


epoch: 45/399: 100%|██████████| 21872/21872 [00:13<00:00, 1565.38it/s, loss=0.001207]
epoch: 46/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.00


epoch: 46/399: 100%|██████████| 21872/21872 [00:14<00:00, 1494.11it/s, loss=0.001206]
epoch: 47/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.97


epoch: 47/399: 100%|██████████| 21872/21872 [00:14<00:00, 1545.69it/s, loss=0.001204]
epoch: 48/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.01


epoch: 48/399: 100%|██████████| 21872/21872 [00:14<00:00, 1532.03it/s, loss=0.001204]
epoch: 49/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.02


epoch: 49/399: 100%|██████████| 21872/21872 [00:14<00:00, 1531.98it/s, loss=0.001202]
epoch: 50/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.97


epoch: 50/399: 100%|██████████| 21872/21872 [00:14<00:00, 1484.04it/s, loss=0.001201]
epoch: 51/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.03


epoch: 51/399: 100%|██████████| 21872/21872 [00:14<00:00, 1541.07it/s, loss=0.001200]
epoch: 52/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.01


epoch: 52/399: 100%|██████████| 21872/21872 [00:14<00:00, 1533.09it/s, loss=0.001198]
epoch: 53/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.04


epoch: 53/399: 100%|██████████| 21872/21872 [00:14<00:00, 1524.83it/s, loss=0.001197]
epoch: 54/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.01


epoch: 54/399: 100%|██████████| 21872/21872 [00:14<00:00, 1531.49it/s, loss=0.001197]
epoch: 55/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.05


epoch: 55/399: 100%|██████████| 21872/21872 [00:14<00:00, 1537.88it/s, loss=0.001196]
epoch: 56/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.02


epoch: 56/399: 100%|██████████| 21872/21872 [00:14<00:00, 1548.32it/s, loss=0.001194]
epoch: 57/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.01


epoch: 57/399: 100%|██████████| 21872/21872 [00:14<00:00, 1536.68it/s, loss=0.001194]
epoch: 58/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.04


epoch: 58/399: 100%|██████████| 21872/21872 [00:14<00:00, 1544.38it/s, loss=0.001193]
epoch: 59/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.90


epoch: 59/399: 100%|██████████| 21872/21872 [00:14<00:00, 1521.56it/s, loss=0.001191]
epoch: 60/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.03


epoch: 60/399: 100%|██████████| 21872/21872 [00:15<00:00, 1451.02it/s, loss=0.001191]
epoch: 61/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.07


epoch: 61/399: 100%|██████████| 21872/21872 [00:15<00:00, 1445.99it/s, loss=0.001190]
epoch: 62/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 32.99


epoch: 62/399: 100%|██████████| 21872/21872 [00:14<00:00, 1525.53it/s, loss=0.001189]
epoch: 63/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.02


epoch: 63/399: 100%|██████████| 21872/21872 [00:14<00:00, 1539.25it/s, loss=0.001187]
epoch: 64/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.06


epoch: 64/399: 100%|██████████| 21872/21872 [00:14<00:00, 1528.42it/s, loss=0.001187]
epoch: 65/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.05


epoch: 65/399: 100%|██████████| 21872/21872 [00:14<00:00, 1489.88it/s, loss=0.001186]
epoch: 66/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.07


epoch: 66/399: 100%|██████████| 21872/21872 [00:14<00:00, 1558.72it/s, loss=0.001186]
epoch: 67/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.06


epoch: 67/399: 100%|██████████| 21872/21872 [00:14<00:00, 1540.91it/s, loss=0.001184]
epoch: 68/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.05


epoch: 68/399: 100%|██████████| 21872/21872 [00:14<00:00, 1514.24it/s, loss=0.001184]
epoch: 69/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.07


epoch: 69/399: 100%|██████████| 21872/21872 [00:14<00:00, 1461.28it/s, loss=0.001183]
epoch: 70/399:   0%|          | 0/21872 [00:00<?, ?it/s]

eval psnr: 33.06


epoch: 70/399:  88%|████████▊ | 19216/21872 [00:13<00:01, 1374.32it/s, loss=0.001175]


KeyboardInterrupt: 