In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

https://github.com/yjn870/DRRN-pytorch

In [2]:
!pip install datasets
!pip install super-image
!pip install torchsummary

[0mCollecting super-image
  Downloading super_image-0.1.7-py3-none-any.whl (91 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.0/91.0 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: super-image
Successfully installed super-image-0.1.7
[0mCollecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
[0m

In [3]:
from torch import nn


class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(ConvLayer, self).__init__()
        self.module = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2, bias=False)
        )

    def forward(self, x):
        return self.module(x)


class ResidualUnit(nn.Module):
    def __init__(self, num_features):
        super(ResidualUnit, self).__init__()
        self.module = nn.Sequential(
            ConvLayer(num_features, num_features),
            ConvLayer(num_features, num_features)
        )

    def forward(self, h0, x):
        return h0 + self.module(x)


class RecursiveBlock(nn.Module):
    def __init__(self, in_channels, out_channels, U):
        super(RecursiveBlock, self).__init__()
        self.U = U
        self.h0 = ConvLayer(in_channels, out_channels)
        self.ru = ResidualUnit(out_channels)

    def forward(self, x):
        h0 = self.h0(x)
        x = h0
        for i in range(self.U):
            x = self.ru(h0, x)
        return x


class DRRN(nn.Module):
    def __init__(self, B, U, num_channels=1, num_features=128):
        super(DRRN, self).__init__()
        self.rbs = nn.Sequential(*[RecursiveBlock(num_channels if i == 0 else num_features, num_features, U) for i in range(B)])
        self.rec = ConvLayer(num_features, num_channels)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        residual = x
        x = self.rbs(x)
        x = self.rec(x)
        x += residual
        return x

In [4]:
import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [5]:
from torchsummary import summary

In [6]:
summary(DRRN(1, 9, 3).to(device), (3, 64, 64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
              ReLU-1            [-1, 3, 64, 64]               0
            Conv2d-2          [-1, 128, 64, 64]           3,456
         ConvLayer-3          [-1, 128, 64, 64]               0
              ReLU-4          [-1, 128, 64, 64]               0
            Conv2d-5          [-1, 128, 64, 64]         147,456
         ConvLayer-6          [-1, 128, 64, 64]               0
              ReLU-7          [-1, 128, 64, 64]               0
            Conv2d-8          [-1, 128, 64, 64]         147,456
         ConvLayer-9          [-1, 128, 64, 64]               0
     ResidualUnit-10          [-1, 128, 64, 64]               0
             ReLU-11          [-1, 128, 64, 64]               0
           Conv2d-12          [-1, 128, 64, 64]         147,456
        ConvLayer-13          [-1, 128, 64, 64]               0
             ReLU-14          [-1, 128,

In [7]:
import PIL.Image as pil_image
import numpy as np


def load_image(path):
    return pil_image.open(path).convert('RGB')


def generate_lr(image, scale):
    image = image.resize((image.width // scale, image.height // scale), resample=pil_image.BICUBIC)
    image = image.resize((image.width * scale, image.height * scale), resample=pil_image.BICUBIC)
    return image


def modcrop(image, modulo):
    w = image.width - image.width % modulo
    h = image.height - image.height % modulo
    return image.crop((0, 0, w, h))


def generate_patch(image, patch_size, stride):
    for i in range(0, image.height - patch_size + 1, stride):
        for j in range(0, image.width - patch_size + 1, stride):
            yield image.crop((j, i, j + patch_size, i + patch_size))


def image_to_array(image):
    return np.array(image).transpose((2, 0, 1))


def normalize(x):
    return x / 255.0


def denormalize(x):
    if type(x) == torch.Tensor:
        return (x * 255.0).clamp(0.0, 255.0)
    elif type(x) == np.ndarray:
        return (x * 255.0).clip(0.0, 255.0)
    else:
        raise Exception('The denormalize function supports torch.Tensor or np.ndarray types.', type(x))


def rgb_to_y(img, dim_order='hwc'):
    if dim_order == 'hwc':
        return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
    else:
        return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.


def PSNR(a, b, max=255.0, shave_border=0):
    assert type(a) == type(b)
    assert (type(a) == torch.Tensor) or (type(a) == np.ndarray)

    a = a[shave_border:a.shape[0]-shave_border, shave_border:a.shape[1]-shave_border]
    b = b[shave_border:b.shape[0]-shave_border, shave_border:b.shape[1]-shave_border]

    if type(a) == torch.Tensor:
        return 10. * ((max ** 2) / ((a - b) ** 2).mean()).log10()
    elif type(a) == np.ndarray:
        return 10. * np.log10((max ** 2) / np.mean(((a - b) ** 2)))
    else:
        raise Exception('The PSNR function supports torch.Tensor or np.ndarray types.', type(a))


def load_weights(model, path):
    state_dict = model.state_dict()
    for n, p in torch.load(path, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)
    return model


class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [8]:
from datasets import load_dataset
from super_image.data import EvalDataset, TrainDataset, augment_five_crop

augmented_dataset = load_dataset('eugenesiow/Div2k', 'bicubic_x2', split='train')                          # download and augment the data with the five_crop method
train_dataset = TrainDataset(augmented_dataset)                                                     # prepare the train dataset for loading PyTorch DataLoader
eval_dataset = EvalDataset(load_dataset('eugenesiow/Div2k', 'bicubic_x2', split='validation'))      # prepare the eval dataset for the PyTorch DataLoader

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading and preparing dataset div2k/bicubic_x2 to /root/.cache/huggingface/datasets/eugenesiow___div2k/bicubic_x2/2.0.0/d7599f94c7e662a3eed3547efc7efa52b2ed71082b40fc2e42a693870e35b677...


Downloading data files:   0%|          | 0/4 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/925M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/118M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.53G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/449M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/4 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Dataset div2k downloaded and prepared to /root/.cache/huggingface/datasets/eugenesiow___div2k/bicubic_x2/2.0.0/d7599f94c7e662a3eed3547efc7efa52b2ed71082b40fc2e42a693870e35b677. Subsequent calls will reuse this data.


In [9]:
class Args:
    def __init__(self, B, U, num_features, 
                 weights_file, eval_scale, 
                 lr, batch_size, num_workers,
                 clip_grad,
                 num_epochs,
                 seed,
                 outputs_dir,
                    eval = True):
        self.B = B
        self.U = U
        self.num_features = num_features
        self.weights_file = weights_file
        self.lr = lr
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.outputs_dir = outputs_dir
        self.clip_grad = clip_grad
        self.num_epochs = num_epochs
        self.seed = seed
        self.eval = eval
        self.eval_scale = eval_scale

In [10]:
args = Args(B = 1, U = 9, num_features = 32, lr = 0.1, weights_file = None, clip_grad = 0.01, num_epochs = 50, num_workers=2,seed = 123, outputs_dir = "/kaggle/working/", batch_size=128, eval_scale = 4)

In [11]:
import PIL.Image as pil_image
import numpy as np
import torch
import cv2


def load_image(path):
    return pil_image.open(path).convert('RGB')


def generate_lr(image, scale):
    image = image.resize((image.width // scale, image.height // scale), resample=pil_image.BICUBIC)
    image = image.resize((image.width * scale, image.height * scale), resample=pil_image.BICUBIC)
    return image


def modcrop(image, modulo):
    w = image.width - image.width % modulo
    h = image.height - image.height % modulo
    return image.crop((0, 0, w, h))


def generate_patch(image, patch_size, stride):
    for i in range(0, image.height - patch_size + 1, stride):
        for j in range(0, image.width - patch_size + 1, stride):
            yield image.crop((j, i, j + patch_size, i + patch_size))


def image_to_array(image):
    return np.array(image).transpose((2, 0, 1))


def normalize(x):
    return x / 255.0


def denormalize(x):
    if type(x) == torch.Tensor:
        return (x * 255.0).clamp(0.0, 255.0)
    elif type(x) == np.ndarray:
        return (x * 255.0).clip(0.0, 255.0)
    else:
        raise Exception('The denormalize function supports torch.Tensor or np.ndarray types.', type(x))


def rgb_to_y(img, dim_order='hwc'):
    if dim_order == 'hwc':
        return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
    else:
        return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.


def PSNR(a, b, max=255.0, shave_border=0):
    assert type(a) == type(b)
    assert (type(a) == torch.Tensor) or (type(a) == np.ndarray)

    a = a[shave_border:a.shape[0]-shave_border, shave_border:a.shape[1]-shave_border]
    b = b[shave_border:b.shape[0]-shave_border, shave_border:b.shape[1]-shave_border]

    if type(a) == torch.Tensor:
        return 10. * ((max ** 2) / ((a - b) ** 2).mean()).log10()
    elif type(a) == np.ndarray:
        return 10. * np.log10((max ** 2) / np.mean(((a - b) ** 2)))
    else:
        raise Exception('The PSNR function supports torch.Tensor or np.ndarray types.', type(a))


def load_weights(model, path):
    state_dict = model.state_dict()
    for n, p in torch.load(path, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)
    return model


class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [12]:
def convert_rgb_to_y(img, dim_order='hwc'):
    if dim_order == 'hwc':
        return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.
    else:
        return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.


def denormalize(img):
    img = img.mul(255.0).clamp(0.0, 255.0)
    return img

def calc_psnr(img1, img2, max=255.0):
    return 10. * ((max ** 2) / ((img1 - img2) ** 2).mean()).log10()

def ssim(img1, img2):
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def calc_ssim(img1, img2):
    """calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]
    """
    img1=img1.detach().cpu().numpy()
    img2 = img2.detach().cpu().numpy()
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1, img2))
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')

def compute_metrics(eval_prediction, scale):
    preds = eval_prediction.predictions
    labels = eval_prediction.labels

    # from piq import ssim, psnr
    # print(psnr(denormalize(preds), denormalize(labels), data_range=255.),
    #       ssim(denormalize(preds), denormalize(labels), data_range=255.))

    # original = preds[0][0][0][0]

    preds = convert_rgb_to_y(denormalize(preds.squeeze(0)), dim_order='chw')
    labels = convert_rgb_to_y(denormalize(labels.squeeze(0)), dim_order='chw')

    # print(preds[0][0], original * 255.)

    preds = preds[scale:-scale, scale:-scale]
    labels = labels[scale:-scale, scale:-scale]

    # print(calc_psnr(preds, labels), calc_ssim(preds, labels))

    return {
        'psnr': calc_psnr(preds, labels),
        'ssim': calc_ssim(preds, labels)
    }

In [13]:
import random
import numpy as np
from typing import NamedTuple, Tuple, Union

class EvalPrediction(NamedTuple):
    """
    Evaluation output (always contains labels), to be used to compute metrics.
    Parameters:
        predictions (:obj:`np.ndarray`): Predictions of the model.
        labels (:obj:`np.ndarray`): Targets to be matched.
    """

    predictions: Union[np.ndarray, Tuple[np.ndarray]]
    labels: np.ndarray


In [14]:
import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from torchvision.transforms.functional import resize, InterpolationMode

import argparse
import os
import copy

model = DRRN(B=args.B, U=args.U, num_features=args.num_features, num_channels=3).to(device)

if args.weights_file is not None:
    model = load_weights(model, args.weights_file)

criterion = nn.MSELoss(reduction='sum')
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)

#     train_dataset = TrainDataset(args.train_file)
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)

if args.eval == True:
#     eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

best_weights = copy.deepcopy(model.state_dict())
best_epoch = 0
best_psnr = 0.0

for epoch in range(args.num_epochs):
    lr = args.lr * (0.5 ** ((epoch + 1) // 10))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    model.train()
    epoch_losses = AverageMeter()

    with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size), ncols=80) as t:
        t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))

        for data in train_dataloader:
            inputs, labels = data
            x = resize(inputs, size = (inputs.shape[2] * 2, inputs.shape[3] * 2), interpolation  =InterpolationMode.BICUBIC)
            x = x.to(device)
            
            labels = labels.to(device)

            preds = model(x)

            loss = criterion(preds, labels) / (2 * len(inputs))
            epoch_losses.update(loss.item(), len(inputs))

            optimizer.zero_grad()
            loss.backward()

            nn.utils.clip_grad.clip_grad_norm_(model.parameters(), args.clip_grad / lr)

            optimizer.step()

            t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg), lr=lr)
            t.update(len(inputs))

    torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

    if args.eval == True:
        model.eval()
        epoch_psnr = AverageMeter()
        epoch_ssim = AverageMeter()

        for data in eval_dataloader:
            inputs, labels = data
            x = resize(inputs, size = (inputs.shape[2] * 2, inputs.shape[3] * 2), interpolation  =InterpolationMode.BICUBIC)
            x = x.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(x)

#             preds = denormalize(preds.squeeze(0).squeeze(0))
#             labels = denormalize(labels.squeeze(0).squeeze(0))
            res = compute_metrics(EvalPrediction(predictions=preds, labels=labels), args.eval_scale)
            
#             epoch_psnr.update(PSNR(preds, labels, shave_border=args.eval_scale), len(inputs))
            epoch_psnr.update(res['psnr'], len(inputs))
            epoch_ssim.update(res['ssim'], len(inputs))

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

if args.eval ==True:
    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

epoch: 0/49: : 800it [01:31,  8.73it/s, loss=104222.614062, lr=0.1]


eval psnr: 17.54


epoch: 1/49: : 800it [01:24,  9.42it/s, loss=3216.577114, lr=0.1]


eval psnr: 23.24


epoch: 2/49: : 800it [01:23,  9.57it/s, loss=518.249072, lr=0.1]


eval psnr: 25.98


epoch: 3/49: : 800it [01:22,  9.65it/s, loss=223.586935, lr=0.1]


eval psnr: 28.60


epoch: 4/49: : 800it [01:22,  9.70it/s, loss=114.376580, lr=0.1]


eval psnr: 29.91


epoch: 5/49: : 800it [01:19, 10.02it/s, loss=77.490544, lr=0.1]


eval psnr: 31.05


epoch: 6/49: : 800it [01:20,  9.99it/s, loss=66.820327, lr=0.1]


eval psnr: 29.47


epoch: 7/49: : 800it [01:20,  9.94it/s, loss=82.140725, lr=0.1]


eval psnr: 30.38


epoch: 8/49: : 800it [01:20,  9.91it/s, loss=64.899587, lr=0.1]


eval psnr: 29.61


epoch: 9/49: : 800it [01:20,  9.93it/s, loss=61.210113, lr=0.05]


eval psnr: 30.06


epoch: 10/49: : 800it [01:21,  9.86it/s, loss=44.879840, lr=0.05]


eval psnr: 32.39


epoch: 11/49: : 800it [01:22,  9.74it/s, loss=36.160287, lr=0.05]


eval psnr: 32.81


epoch: 12/49: : 800it [01:21,  9.87it/s, loss=37.609413, lr=0.05]


eval psnr: 32.81


epoch: 13/49: : 800it [01:20,  9.96it/s, loss=34.548312, lr=0.05]


eval psnr: 32.51


epoch: 14/49: : 800it [01:20,  9.91it/s, loss=35.483236, lr=0.05]


eval psnr: 32.92


epoch: 15/49: : 800it [01:20,  9.94it/s, loss=32.507244, lr=0.05]


eval psnr: 33.00


epoch: 16/49: : 800it [01:20,  9.91it/s, loss=31.636803, lr=0.05]


eval psnr: 32.95


epoch: 17/49: : 800it [01:20,  9.88it/s, loss=30.941588, lr=0.05]


eval psnr: 33.09


epoch: 18/49: : 800it [01:20,  9.94it/s, loss=35.611694, lr=0.05]


eval psnr: 33.16


epoch: 19/49: : 800it [01:20,  9.89it/s, loss=31.517198, lr=0.025]


eval psnr: 32.96


epoch: 20/49: : 800it [01:20,  9.95it/s, loss=31.120093, lr=0.025]


eval psnr: 32.96


epoch: 21/49: : 800it [01:20,  9.95it/s, loss=32.720859, lr=0.025]


eval psnr: 33.20


epoch: 22/49: : 800it [01:19, 10.09it/s, loss=29.067436, lr=0.025]


eval psnr: 33.19


epoch: 23/49: : 800it [01:19, 10.02it/s, loss=29.325656, lr=0.025]


eval psnr: 33.14


epoch: 24/49: : 800it [01:19, 10.01it/s, loss=29.519595, lr=0.025]


eval psnr: 33.25


epoch: 25/49: : 800it [01:19, 10.05it/s, loss=30.491735, lr=0.025]


eval psnr: 33.35


epoch: 26/49: : 800it [01:19, 10.03it/s, loss=29.752593, lr=0.025]


eval psnr: 33.38


epoch: 27/49: : 800it [01:19, 10.09it/s, loss=28.953802, lr=0.025]


eval psnr: 33.42


epoch: 28/49: : 800it [01:19, 10.04it/s, loss=29.226464, lr=0.025]


eval psnr: 33.43


epoch: 29/49: : 800it [01:20,  9.92it/s, loss=26.253437, lr=0.0125]


eval psnr: 33.49


epoch: 30/49: : 800it [01:20,  9.93it/s, loss=29.988918, lr=0.0125]


eval psnr: 33.52


epoch: 31/49: : 800it [01:19, 10.01it/s, loss=28.832699, lr=0.0125]


eval psnr: 33.53


epoch: 32/49: : 800it [01:19, 10.01it/s, loss=28.198457, lr=0.0125]


eval psnr: 33.51


epoch: 33/49: : 800it [01:20,  9.98it/s, loss=29.387229, lr=0.0125]


eval psnr: 33.51


epoch: 34/49: : 800it [01:20,  9.98it/s, loss=29.492003, lr=0.0125]


eval psnr: 33.44


epoch: 35/49: : 800it [01:19, 10.05it/s, loss=27.961594, lr=0.0125]


eval psnr: 33.56


epoch: 36/49: : 800it [01:22,  9.73it/s, loss=28.446414, lr=0.0125]


eval psnr: 33.56


epoch: 37/49: : 800it [01:19, 10.03it/s, loss=29.815155, lr=0.0125]


eval psnr: 33.57


epoch: 38/49: : 800it [01:20,  9.98it/s, loss=26.145754, lr=0.0125]


eval psnr: 33.62


epoch: 39/49: : 800it [01:20,  9.95it/s, loss=28.067477, lr=0.00625]


eval psnr: 33.62


epoch: 40/49: : 800it [01:20,  9.90it/s, loss=29.923155, lr=0.00625]


eval psnr: 33.64


epoch: 41/49: : 800it [01:19, 10.05it/s, loss=29.110659, lr=0.00625]


eval psnr: 33.64


epoch: 42/49: : 800it [01:23,  9.60it/s, loss=26.097828, lr=0.00625]


eval psnr: 33.56


epoch: 43/49: : 800it [01:20,  9.89it/s, loss=29.476652, lr=0.00625]


eval psnr: 33.54


epoch: 44/49: : 800it [01:19, 10.04it/s, loss=26.502661, lr=0.00625]


eval psnr: 33.60


epoch: 45/49: : 800it [01:22,  9.64it/s, loss=27.460746, lr=0.00625]


eval psnr: 33.67


epoch: 46/49: : 800it [01:20,  9.89it/s, loss=29.545349, lr=0.00625]


eval psnr: 33.69


epoch: 47/49: : 800it [01:20, 10.00it/s, loss=29.946649, lr=0.00625]


eval psnr: 33.61


epoch: 48/49: : 800it [01:19, 10.00it/s, loss=27.045139, lr=0.00625]


eval psnr: 33.55


epoch: 49/49: : 800it [01:19, 10.03it/s, loss=27.913321, lr=0.00313]


eval psnr: 33.69
best epoch: 49, psnr: 33.69
