<a href="https://colab.research.google.com/github/IM2COLD/MachineLearningWithKaggle/blob/main/llie_mirnet_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the packages this notebook imports; torch/cv2 runtime is presumably preinstalled on Colab (torch is not listed here) — TODO confirm for other runtimes.
!pip3 install numpy opencv-python torchmetrics torchvision tqdm

Collecting torchmetrics
  Downloading torchmetrics-1.2.0-py3-none-any.whl (805 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m805.2/805.2 kB[0m [31m17.7 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.9.0-py3-none-any.whl (23 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.9.0 torchmetrics-1.2.0


In [None]:
# Fetch the LOL (v1) archive (lol_dataset.zip) and the LOL-v2 archive (LOV-v2.zip) from Google Drive,
# then unpack both into ./dataset (-o: overwrite without prompting, -q: quiet).
!gdown https://drive.google.com/uc?id=1DdGIJ4PZPlF2ikl8mNM9V-PdVxVLbQi6
!gdown https://drive.google.com/uc?id=1dzuLCk9_gE2bFF222n3-7GVUlSVHpMYC
!mkdir -p dataset
!unzip -oq -d dataset lol_dataset.zip
!unzip -oq -d dataset LOL-v2.zip

Downloading...
From: https://drive.google.com/uc?id=1DdGIJ4PZPlF2ikl8mNM9V-PdVxVLbQi6
To: /content/lol_dataset.zip
100% 347M/347M [00:01<00:00, 181MB/s]
Downloading...
From: https://drive.google.com/uc?id=1dzuLCk9_gE2bFF222n3-7GVUlSVHpMYC
To: /content/LOL-v2.zip
100% 1.05G/1.05G [00:04<00:00, 222MB/s]


In [None]:
import cmath
import cv2
import google
#import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torchmetrics
import torchvision
from tqdm import tqdm

In [None]:
## Learning Enriched Features for Fast Image Restoration and Enhancement
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao
## IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)
## https://www.waqaszamir.com/publication/zamir-2022-mirnetv2/

##########################################################################
##---------- Selective Kernel Feature Fusion (SKFF) ----------
class SKFF(torch.nn.Module):
    """Selective Kernel Feature Fusion.

    Fuses ``height`` same-shaped feature maps into one, using per-channel
    attention weights that are softmax-normalised across the branches.

    Args:
        in_channels: channel count shared by every input branch.
        height: number of branches that ``forward`` receives.
        reduction: squeeze ratio of the bottleneck (never below 4 channels).
        bias: whether the 1x1 convolutions carry a bias term.
    """

    def __init__(self, in_channels, height=3, reduction=8, bias=False):
        super(SKFF, self).__init__()

        self.height = height
        squeezed = max(int(in_channels / reduction), 4)

        self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)
        self.conv_du = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, squeezed, 1, padding=0, bias=bias),
            torch.nn.LeakyReLU(0.2),
        )
        # One 1x1 "excitation" conv per branch, mapping back to in_channels.
        self.fcs = torch.nn.ModuleList(
            torch.nn.Conv2d(squeezed, in_channels, kernel_size=1, stride=1, bias=bias)
            for _ in range(self.height)
        )
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, inp_feats):
        first = inp_feats[0]
        n, c = first.shape[0], first.shape[1]

        stacked = torch.cat(inp_feats, dim=1)
        # (N, height, C, H, W): one slot per branch.
        stacked = stacked.view(n, self.height, c, stacked.shape[2], stacked.shape[3])

        # Squeeze: sum the branches, pool globally, compress channels.
        summary = self.conv_du(self.avg_pool(stacked.sum(dim=1)))

        # Excite: one logit vector per branch, normalised across branches.
        weights = torch.cat([fc(summary) for fc in self.fcs], dim=1)
        weights = self.softmax(weights.view(n, self.height, c, 1, 1))

        return (stacked * weights).sum(dim=1)

class ContextBlock(torch.nn.Module):
    """Global-context block: attention-pooled context added back to every position.

    A 1x1 conv produces one logit per spatial position; a softmax over all
    positions turns these into pooling weights, the weighted sum gives a
    per-channel context vector, and a small 1x1-conv MLP transforms it before
    it is broadcast-added to the input.
    """

    def __init__(self, n_feat, bias=False):
        super(ContextBlock, self).__init__()

        self.conv_mask = torch.nn.Conv2d(n_feat, 1, kernel_size=1, bias=bias)
        self.softmax = torch.nn.Softmax(dim=2)

        self.channel_add_conv = torch.nn.Sequential(
            torch.nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias),
        )

    def modeling(self, x):
        """Return the attention-pooled context of ``x`` with shape (N, C, 1, 1)."""
        n, c, h, w = x.size()
        flat = x.view(n, c, h * w).unsqueeze(1)                       # (N, 1, C, HW)
        attn = self.softmax(self.conv_mask(x).view(n, 1, h * w))      # (N, 1, HW)
        pooled = torch.matmul(flat, attn.unsqueeze(3))                # (N, 1, C, 1)
        return pooled.view(n, c, 1, 1)

    def forward(self, x):
        # The (N, C, 1, 1) term broadcasts over every spatial position.
        return x + self.channel_add_conv(self.modeling(x))

##########################################################################
### --------- Residual Context Block (RCB) ----------
class RCB(torch.nn.Module):
    """Residual Context Block.

    Two (optionally grouped) convolutions, a global-context block followed by
    LeakyReLU, and an identity skip connection.

    Args:
        n_feat: number of input and output channels.
        kernel_size: spatial size of both body convolutions. Padding is
            ``kernel_size // 2``, so an odd value is required to keep the
            spatial size (and therefore the residual add) valid. Default 3.
        reduction: accepted for interface compatibility; not used.
        bias: bias flag for the body convolutions and the context block.
        groups: channel groups for the body convolutions.
    """

    def __init__(self, n_feat, kernel_size=3, reduction=8, bias=False, groups=1):
        super(RCB, self).__init__()

        act = torch.nn.LeakyReLU(0.2)
        padding = kernel_size // 2

        self.body = torch.nn.Sequential(
            torch.nn.Conv2d(n_feat, n_feat, kernel_size=kernel_size, stride=1, padding=padding, bias=bias, groups=groups),
            act,
            torch.nn.Conv2d(n_feat, n_feat, kernel_size=kernel_size, stride=1, padding=padding, bias=bias, groups=groups),
        )

        self.act = act

        self.gcnet = ContextBlock(n_feat, bias=bias)

    def forward(self, x):
        res = self.body(x)
        res = self.act(self.gcnet(res))
        return res + x


##########################################################################
##---------- Resizing Modules ----------
class Down(torch.nn.Module):
    """Halve the spatial size with 2x2 average pooling, then rescale channels.

    Pooling uses ``ceil_mode=True`` so odd sizes round up; a 1x1 conv maps
    ``in_channels`` to ``int(in_channels * chan_factor)`` channels.
    """

    def __init__(self, in_channels, chan_factor, bias=False):
        super(Down, self).__init__()

        out_channels = int(in_channels * chan_factor)
        pool = torch.nn.AvgPool2d(2, ceil_mode=True, count_include_pad=False)
        project = torch.nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=bias)
        self.bot = torch.nn.Sequential(pool, project)

    def forward(self, x):
        out = self.bot(x)
        return out

class DownSample(torch.nn.Module):
    """Downsample by ``scale_factor`` (expected to be a power of two).

    Stacks ``log2(scale_factor)`` ``Down`` stages; each halves H and W and
    multiplies the channel count by ``chan_factor``. ``kernel_size`` is
    accepted but unused.
    """

    def __init__(self, in_channels, scale_factor, chan_factor=2, kernel_size=3):
        super(DownSample, self).__init__()
        self.scale_factor = int(np.log2(scale_factor))

        stages = []
        channels = in_channels
        for _ in range(self.scale_factor):
            stages.append(Down(channels, chan_factor))
            channels = int(channels * chan_factor)

        self.body = torch.nn.Sequential(*stages)

    def forward(self, x):
        return self.body(x)

class Up(torch.nn.Module):
    """Rescale channels with a 1x1 conv, then double H and W by bilinear upsampling.

    Args:
        in_channels: input channel count.
        chan_factor: output channels are ``int(in_channels // chan_factor)``.
        bias: bias flag for the 1x1 convolution only.
        align_corners: forwarded to ``torch.nn.Upsample``. It used to be tied
            to ``bias``; it is now independent. The default False matches the
            old behaviour for the default ``bias=False``.
    """

    def __init__(self, in_channels, chan_factor, bias=False, align_corners=False):
        super(Up, self).__init__()

        self.bot = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, int(in_channels // chan_factor), 1, stride=1, padding=0, bias=bias),
            torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=align_corners),
        )

    def forward(self, x):
        return self.bot(x)

class UpSample(torch.nn.Module):
    """Upsample by ``scale_factor`` (expected to be a power of two).

    Stacks ``log2(scale_factor)`` ``Up`` stages; each doubles H and W and
    divides the channel count by ``chan_factor``. ``kernel_size`` is accepted
    but unused.
    """

    def __init__(self, in_channels, scale_factor, chan_factor=2, kernel_size=3):
        super(UpSample, self).__init__()
        self.scale_factor = int(np.log2(scale_factor))

        stages = []
        channels = in_channels
        while len(stages) < self.scale_factor:
            stages.append(Up(channels, chan_factor))
            channels = int(channels // chan_factor)

        self.body = torch.nn.Sequential(*stages)

    def forward(self, x):
        out = self.body(x)
        return out


##########################################################################
##---------- Multi-Scale Residual Block (MRB) ----------
class MRB(torch.nn.Module):
    """Multi-Scale Residual Block.

    The input is processed at three resolutions (full, 1/2, 1/4), each by its
    own RCB. It is then fused coarse-to-fine with SKFF, in two rounds, and a
    1x1-projected result is added to the input. Each ``dau_*`` RCB is applied
    once per round, so its weights are shared between the two rounds. The
    upsamplers are separate per round.

    ``height`` and ``width`` are stored on the module but not used by the
    computation.
    """

    def __init__(self, n_feat, height, width, chan_factor, bias, groups):
        super(MRB, self).__init__()

        self.n_feat, self.height, self.width = n_feat, height, width

        ch_top = int(n_feat * chan_factor ** 0)
        ch_mid = int(n_feat * chan_factor ** 1)
        ch_bot = int(n_feat * chan_factor ** 2)

        # Submodule creation order is kept fixed: it decides the order in
        # which parameters are randomly initialised under a seed.
        self.dau_top = RCB(ch_top, bias=bias, groups=groups)
        self.dau_mid = RCB(ch_mid, bias=bias, groups=groups)
        self.dau_bot = RCB(ch_bot, bias=bias, groups=groups)

        self.down2 = DownSample(ch_top, 2, chan_factor)
        self.down4 = torch.nn.Sequential(
            DownSample(ch_top, 2, chan_factor),
            DownSample(ch_mid, 2, chan_factor),
        )

        self.up21_1 = UpSample(ch_mid, 2, chan_factor)
        self.up21_2 = UpSample(ch_mid, 2, chan_factor)
        self.up32_1 = UpSample(ch_bot, 2, chan_factor)
        self.up32_2 = UpSample(ch_bot, 2, chan_factor)

        self.conv_out = torch.nn.Conv2d(n_feat, n_feat, kernel_size=1, padding=0, bias=bias)

        # Each SKFF fuses exactly two inputs.
        self.skff_top = SKFF(ch_top, 2)
        self.skff_mid = SKFF(ch_mid, 2)

    def forward(self, x):
        top = x
        mid = self.down2(top)
        bot = self.down4(top)

        rounds = (
            (self.up32_1, self.up21_1),
            (self.up32_2, self.up21_2),
        )
        for bot_to_mid, mid_to_top in rounds:
            top, mid, bot = self.dau_top(top), self.dau_mid(mid), self.dau_bot(bot)
            mid = self.skff_mid([mid, bot_to_mid(bot)])
            top = self.skff_top([top, mid_to_top(mid)])

        return self.conv_out(top) + x

##########################################################################
##---------- Recursive Residual Group (RRG) ----------
class RRG(torch.nn.Module):
    """Recursive Residual Group.

    Runs ``n_MRB`` multi-scale residual blocks and a 3x3 conv, then adds the
    group input back (long skip connection).
    """

    def __init__(self, n_feat, n_MRB, height, width, chan_factor, bias=False, groups=1):
        super(RRG, self).__init__()
        blocks = [MRB(n_feat, height, width, chan_factor, bias, groups) for _ in range(n_MRB)]
        tail = torch.nn.Conv2d(n_feat, n_feat, kernel_size=3, stride=1, padding=1, bias=bias)
        self.body = torch.nn.Sequential(*blocks, tail)

    def forward(self, x):
        return self.body(x) + x


##########################################################################
##---------- MIRNet  -----------------------
class MIRNet_v2(torch.nn.Module):
    """MIRNet-v2 image restoration / enhancement network.

    Layout: 3x3 conv in -> ``n_RRG`` recursive residual groups -> 3x3 conv out.
    The RRGs use channel groups 1, 2, 4, 4, ... (doubling, capped at 4). For
    the default ``n_RRG=4`` this is exactly the original fixed 1/2/4/4 layout.
    ``n_feat`` and its ``chan_factor``-scaled widths must be divisible by the
    group count used.

    Args:
        inp_channels / out_channels: image channels in and out.
        n_feat: base feature width.
        chan_factor: channel multiplier per scale inside each MRB.
        n_RRG: number of recursive residual groups.
        n_MRB: multi-scale residual blocks per group.
        height, width: forwarded to the MRBs (not used by their computation).
        scale: accepted for interface compatibility; not used.
        bias: bias flag for the convolutions.
        task: ``'defocus_deblurring'`` adds the residual in feature space.
            Any other value predicts a residual that is added to the input image.
    """

    def __init__(self,
        inp_channels=3,
        out_channels=3,
        n_feat=80,
        chan_factor=1.5,
        n_RRG=4,
        n_MRB=2,
        height=3,
        width=2,
        scale=1,
        bias=False,
        task= None
    ):
        super(MIRNet_v2, self).__init__()

        self.task = task

        self.conv_in = torch.nn.Conv2d(inp_channels, n_feat, kernel_size=3, padding=1, bias=bias)

        # Built in order, so seeded initialisation matches the fixed 4-group layout.
        self.body = torch.nn.Sequential(*[
            RRG(n_feat, n_MRB, height, width, chan_factor, bias, groups=min(2 ** i, 4))
            for i in range(n_RRG)
        ])
        self.conv_out = torch.nn.Conv2d(n_feat, out_channels, kernel_size=3, padding=1, bias=bias)

    def forward(self, inp_img):
        shallow_feats = self.conv_in(inp_img)
        deep_feats = self.body(shallow_feats)

        if self.task == 'defocus_deblurring':
            # Feature-level residual.
            return self.conv_out(deep_feats + shallow_feats)

        # Image-level residual: the network predicts a correction to the input.
        return self.conv_out(deep_feats) + inp_img


In [None]:
class CharbonnierLoss(torch.nn.Module):
    """Charbonnier loss: a smooth, differentiable-everywhere approximation of L1.

    ``loss = mean(sqrt((x - y)^2 + eps^2))``, averaged over every element.
    """

    def __init__(self, eps: float = 1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        squared_error = (x - y) ** 2
        per_element = torch.sqrt(squared_error + self.eps ** 2)
        return per_element.mean()

In [None]:
class PairedRandomCrop(torch.nn.Module):
    """Crop every tensor in a list at the same random location.

    Intended for (input, target) image pairs, so both crops stay aligned. The
    offset is drawn from the spatial size of ``data_list[0]`` (last two
    dimensions), and all tensors are expected to share that size.

    Args:
        size: crop ``(height, width)``; an IntTensor or any length-2 int sequence.
    """

    def __init__(self, size: torch.IntTensor):
        super(PairedRandomCrop, self).__init__()
        self.size: torch.IntTensor = size

    def forward(self, data_list: list[torch.Tensor]):
        crop_h, crop_w = int(self.size[0]), int(self.size[1])
        img_h, img_w = int(data_list[0].shape[-2]), int(data_list[0].shape[-1])
        # randint's upper bound is exclusive: the +1 makes the bottom/right-most
        # placement reachable. Images smaller than the crop get offset 0 and
        # are zero-padded on the bottom/right in _crop.
        top = int(torch.randint(0, max(img_h - crop_h, 0) + 1, (1,)))
        left = int(torch.randint(0, max(img_w - crop_w, 0) + 1, (1,)))
        return [self._crop(data, top, left, crop_h, crop_w) for data in data_list]

    @staticmethod
    def _crop(data: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
        """Slice a (height, width) window, zero-padding bottom/right if it overruns."""
        patch = data[..., top:top + height, left:left + width]
        pad_h = height - patch.shape[-2]
        pad_w = width - patch.shape[-1]
        if pad_h > 0 or pad_w > 0:
            patch = torch.nn.functional.pad(patch, (0, pad_w, 0, pad_h))
        return patch

In [None]:
class LOLDataset(torch.utils.data.Dataset):
    """Paired image dataset laid out as ``root_dir/<sub_dir[i]>/<prefix[i]><name>``.

    Sample names are the ``.png`` files in ``sub_dir[0]`` (sorted by full file
    name) with ``prefix[0]`` stripped. Item ``idx`` is a list with one image
    per sub-directory: ``prefix[i] + name`` read from ``sub_dir[i]``.

    Each image is returned as a contiguous float32 tensor in (C, H, W) layout,
    as read by ``cv2.imread`` (BGR order) and not rescaled. For 8-bit PNGs
    the values are therefore in [0, 255]. ``transform``, if given, receives
    the whole list and must return the list to use.

    Raises:
        ValueError: if ``sub_dir`` and ``prefix`` differ in length.
        FileNotFoundError: from ``__getitem__`` when an image cannot be read.
    """

    def __init__(self, root_dir: str, sub_dir: list[str], prefix: "list[str] | tuple[str, ...]" = ("", ""), transform=None) -> None:
        super().__init__()
        self.root_dir: str = root_dir
        self.sub_dir: list[str] = sub_dir
        self.prefix: list[str] = list(prefix)
        self.transform = transform
        if len(self.sub_dir) != len(self.prefix):
            # zip() in __getitem__ would otherwise silently drop images.
            raise ValueError(
                f"sub_dir and prefix must have the same length, got {len(self.sub_dir)} and {len(self.prefix)}"
            )
        names = sorted(os.listdir(os.path.join(self.root_dir, self.sub_dir[0])))
        self.file_name: list[str] = [
            fn.removeprefix(self.prefix[0]) for fn in names if fn.endswith(".png")
        ]

    def __len__(self) -> int:
        return len(self.file_name)

    def __getitem__(self, idx: int) -> list[torch.Tensor]:
        data: list[torch.Tensor] = []
        for sd, pf in zip(self.sub_dir, self.prefix):
            path = os.path.join(self.root_dir, sd, pf + self.file_name[idx])
            img = cv2.imread(path)
            if img is None:
                # cv2.imread signals failure by returning None rather than raising.
                raise FileNotFoundError(f"could not read image: {path}")
            # HWC -> CHW, made contiguous before wrapping so downstream .view() calls work.
            chw = np.ascontiguousarray(np.transpose(img, (2, 0, 1)))
            data.append(torch.from_numpy(chw).float())
        if self.transform is not None:
            data = self.transform(data)
        return data

In [None]:
class BGR2LAB(torch.nn.Module):
    """Convert a batch of BGR images (N, C, H, W) to CIE Lab with OpenCV.

    The conversion runs on the CPU in NumPy and the input is detached, so no
    gradient flows through this module. Using it on model predictions inside
    the loss would leave the loss without a graph. The result is returned on
    the input's device, with the dtype produced by ``cv2.cvtColor`` (float32
    for float32 input).

    NOTE(review): for float input, OpenCV's Lab conversion expects BGR values
    scaled to [0, 1]. Images cast from uint8 in [0, 255] should be divided by
    255 first — confirm before use.
    """

    def __init__(self) -> None:
        super(BGR2LAB, self).__init__()

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        images = data.permute(0, 2, 3, 1).detach().cpu().numpy()  # (N, H, W, C)
        lab = np.stack([
            cv2.cvtColor(np.ascontiguousarray(img), cv2.COLOR_BGR2LAB) for img in images
        ])
        return torch.as_tensor(np.transpose(lab, (0, 3, 1, 2)), device=data.device)

In [None]:
def seed_everything(seed: int = 42):
    """Seed every random number generator the notebook depends on.

    Covers Python's ``random``, NumPy's global RNG, and torch on CPU and all
    CUDA devices. It also forces deterministic cuDNN algorithms and disables
    the cuDNN autotuner.

    Note: ``PYTHONHASHSEED`` only affects interpreters started after this
    call. The running process's hash seed is fixed at startup.
    """
    import random

    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU
    torch.backends.cudnn.deterministic = True  # type: ignore
    torch.backends.cudnn.benchmark = False  # type: ignore

In [None]:
def train(dataset: list[torch.utils.data.Dataset], device: str):
    """Train a fresh MIRNet_v2 and checkpoint it whenever validation loss improves.

    Args:
        dataset: ``[train_set, val_set]``. Each item is expected to be an
            ``[input, target]`` pair of (C, H, W) image tensors.
        device: torch device string, e.g. ``"cuda"``.

    Side effects:
        Prints one summary line per epoch. Each time the mean validation loss
        improves, writes ``drive/My Drive/llie/weight/<epoch:03d>.pth``; the
        directory is created if missing. A checkpoint holds the epoch, the
        model/optimizer/scheduler state dicts, and the validation loss as a
        float.
    """
    epochs = 64
    batch_size = 2
    weight_dir = os.path.join("drive", "My Drive", "llie", "weight")
    os.makedirs(weight_dir, exist_ok=True)

    loader: list[torch.utils.data.DataLoader] = [
        torch.utils.data.DataLoader(dataset[0], batch_size=batch_size, shuffle=True),
        torch.utils.data.DataLoader(dataset[1], batch_size=1, shuffle=False),
    ]
    model: torch.nn.Module = MIRNet_v2().to(device=device)
    criterion: torch.nn.Module = CharbonnierLoss().to(device=device)
    optimizer: torch.optim.Optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
    # T_max must span the whole run. With T_max=20 over 64 epochs the cosine
    # curve climbs back to the base LR at epoch 40 instead of decaying.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)
    # Pixel values are assumed to be in [0, 255] (8-bit images cast to float).
    # Pinning data_range keeps PSNR/SSIM independent of per-batch range estimates.
    metric: list[torchmetrics.metric.Metric] = [
        torchmetrics.image.PeakSignalNoiseRatio(data_range=255.0).to(device=device),
        torchmetrics.image.StructuralSimilarityIndexMeasure(data_range=255.0).to(device=device),
    ]

    best_val_loss = float("inf")

    for epoch in range(1, epochs + 1):
        # ---- training ----
        model.train()
        train_loss = 0.0
        for data in tqdm(loader[0], total=len(loader[0]), leave=False, desc=f"Epoch {epoch}/{epochs}"):
            inputs, targets = (d.to(device=device) for d in data)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        scheduler.step()

        # ---- validation ----
        model.eval()
        val_loss = 0.0
        metric_sum = [0.0 for _ in metric]
        with torch.no_grad():
            for data in tqdm(loader[1], total=len(loader[1]), leave=False, desc="valid"):
                inputs, targets = (d.to(device=device) for d in data)
                preds = model(inputs)
                val_loss += criterion(preds, targets).item()
                for i, m in enumerate(metric):
                    # Calling the metric returns this batch's value.
                    metric_sum[i] += m(preds, targets).item()
        # The per-batch values are already collected; drop the accumulated state.
        for m in metric:
            m.reset()

        avg_train = train_loss / len(loader[0])
        avg_val = val_loss / len(loader[1])
        avg_metric = [s / len(loader[1]) for s in metric_sum]

        print(f"Epoch: {epoch}\ttrain_loss: {avg_train:.4f}\tval_loss: {avg_val:.4f}\tpsnr: {avg_metric[0]:.4f}\tssim: {avg_metric[1]:.4f}")

        if best_val_loss > avg_val:
            print("=" * 80)
            print(f"val_loss is improved from {best_val_loss:.4f} to {avg_val:.4f}\t saved current weight")
            print("=" * 80)
            best_val_loss = avg_val

            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    # A plain float rather than the loss module, so the file
                    # unpickles without needing the CharbonnierLoss class.
                    "loss": best_val_loss,
                },
                os.path.join(weight_dir, f"{epoch:03d}.pth"),
            )

In [None]:
def test(dataset: torch.utils.data.Dataset, device: str):
    """Evaluate the newest checkpoint on ``dataset`` and dump every image.

    The checkpoint is the lexicographically last ``*.pth`` file in
    ``drive/My Drive/llie/weight``. Since ``train`` names files
    ``<epoch:03d>.pth``, this is the latest saved epoch, provided the folder
    holds only one run's checkpoints.

    Per-sample loss, PSNR and SSIM are printed. Images are written to
    ``drive/My Drive/llie/img/<k>/<idx:03d>.png`` with k = 0 (input),
    1 (target) and 2 (prediction). They are clipped to [0, 255], rounded, and
    saved as 8-bit.

    Raises:
        FileNotFoundError: if no checkpoint is found.
        OSError: if an image cannot be written.
    """
    root = os.path.join("drive", "My Drive", "llie")
    weight_dir = os.path.join(root, "weight")
    img_dir = os.path.join(root, "img")

    loader: torch.utils.data.DataLoader = torch.utils.data.DataLoader(dataset, batch_size=1)
    checkpoints = sorted(f for f in os.listdir(weight_dir) if f.endswith(".pth"))
    if not checkpoints:
        raise FileNotFoundError(f"no .pth checkpoint found in {weight_dir}")
    # map_location lets a GPU-trained checkpoint load on a CPU-only runtime.
    cp: dict = torch.load(os.path.join(weight_dir, checkpoints[-1]), map_location=device)
    model: torch.nn.Module = MIRNet_v2().to(device=device)
    criterion: torch.nn.Module = CharbonnierLoss().to(device=device)
    # Pixel values are assumed to be in [0, 255]; pin data_range accordingly.
    metric: list[torchmetrics.metric.Metric] = [
        torchmetrics.image.PeakSignalNoiseRatio(data_range=255.0).to(device=device),
        torchmetrics.image.StructuralSimilarityIndexMeasure(data_range=255.0).to(device=device),
    ]
    model.load_state_dict(cp["model_state_dict"])
    model.eval()

    with torch.no_grad():
        loop = tqdm(enumerate(loader), total=len(loader), leave=False, desc="test")
        for idx, data in loop:
            data = [d.to(device=device) for d in data]
            data.append(model(data[0]))  # data = [input, target, prediction]
            loss = criterion(data[2], data[1])
            metric_val = [m(data[2], data[1]).item() for m in metric]
            print(f"idx: {idx}\tloss: {loss.item():.4f}\tpsnr: {metric_val[0]:.4f}\tssim: {metric_val[1]:.4f}")

            for i, d in enumerate(data):
                out_dir = os.path.join(img_dir, str(i))
                # makedirs also creates missing parents; the old bare
                # `except: pass` around os.mkdir hid that failure.
                os.makedirs(out_dir, exist_ok=True)
                for img in d.permute(0, 2, 3, 1).detach().cpu().numpy():
                    img8 = np.clip(img, 0, 255).round().astype(np.uint8)
                    path = os.path.join(out_dir, f"{idx:03d}.png")
                    # cv2.imwrite reports failure by returning False, not by raising.
                    if not cv2.imwrite(path, img8):
                        raise OSError(f"failed to write image: {path}")

    for m in metric:
        m.reset()

In [None]:
if __name__ == "__main__":
    # Colab-only dependency; import the submodule explicitly instead of relying
    # on `google.colab.drive` having been preloaded behind a bare `import google`.
    from google.colab import drive

    device: str = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    # NOTE(review): this transform is never passed to a dataset, so training
    # uses full-size images. If random paired crops were intended, pass
    # `transform=transform` to the training LOLDataset below.
    transform: torch.nn.Module = PairedRandomCrop(torch.IntTensor([360, 540]))
    # [LOL-v2 real train, LOL-v2 real test, LOL v1 train (our485), LOL v1 eval (eval15)]
    dataset: list[torch.utils.data.Dataset] = [
        LOLDataset(root_dir=os.path.join("dataset", "LOL-v2", "Real_captured", "Train"), sub_dir=["Low", "Normal"], prefix=["low", "normal"]),
        LOLDataset(root_dir=os.path.join("dataset", "LOL-v2", "Real_captured", "Test"), sub_dir=["Low", "Normal"], prefix=["low", "normal"]),
        LOLDataset(root_dir=os.path.join("dataset", "lol_dataset", "our485"), sub_dir=["low", "high"]),
        LOLDataset(root_dir=os.path.join("dataset", "lol_dataset", "eval15"), sub_dir=["low", "high"]),
    ]
    seed_everything(42)
    drive.mount("drive")
    try:
        # Trains on LOL-v2 (real) and evaluates on the LOL v1 eval15 split.
        train(dataset[0:2], device)
        test(dataset[3], device)
    finally:
        # Flush pending Drive writes (checkpoints, images) even if a step fails.
        drive.flush_and_unmount()

Mounted at drive




Epoch: 1	train_loss: 46.7869	val_loss: 24.9756	psnr: 18.2437	ssim: 0.8292
val_loss is improved from inf to 24.9756	 saved current weight




Epoch: 2	train_loss: 33.1583	val_loss: 21.8736	psnr: 19.9133	ssim: 0.8255
val_loss is improved from 24.9756 to 21.8736	 saved current weight




Epoch: 3	train_loss: 30.0386	val_loss: 24.9160	psnr: 19.2515	ssim: 0.8164




Epoch: 4	train_loss: 26.4952	val_loss: 30.6292	psnr: 17.9175	ssim: 0.8026




Epoch: 5	train_loss: 25.2449	val_loss: 22.8609	psnr: 19.4839	ssim: 0.8352




Epoch: 6	train_loss: 24.2368	val_loss: 20.9830	psnr: 20.7243	ssim: 0.8145
val_loss is improved from 21.8736 to 20.9830	 saved current weight




Epoch: 7	train_loss: 23.8592	val_loss: 22.2224	psnr: 20.2595	ssim: 0.8172




Epoch: 8	train_loss: 22.8332	val_loss: 24.8187	psnr: 19.1638	ssim: 0.8211




Epoch: 9	train_loss: 22.0199	val_loss: 26.5376	psnr: 19.1195	ssim: 0.8118




Epoch: 10	train_loss: 21.7486	val_loss: 34.4776	psnr: 17.1187	ssim: 0.8019




Epoch: 11	train_loss: 21.2831	val_loss: 33.5307	psnr: 17.2777	ssim: 0.7892




Epoch: 12	train_loss: 20.0867	val_loss: 26.9410	psnr: 18.9551	ssim: 0.8101




Epoch: 13	train_loss: 19.8137	val_loss: 23.3677	psnr: 19.8868	ssim: 0.8160




Epoch: 14	train_loss: 19.4813	val_loss: 32.6198	psnr: 17.4236	ssim: 0.7998




Epoch: 15	train_loss: 18.7568	val_loss: 29.5387	psnr: 18.2496	ssim: 0.8091




Epoch: 16	train_loss: 17.9549	val_loss: 29.7639	psnr: 18.2283	ssim: 0.8075




Epoch: 17	train_loss: 17.5312	val_loss: 25.2586	psnr: 19.5286	ssim: 0.8128




Epoch: 18	train_loss: 17.1731	val_loss: 30.2222	psnr: 18.2282	ssim: 0.8018




Epoch: 19	train_loss: 16.7244	val_loss: 30.1618	psnr: 18.1610	ssim: 0.8048




Epoch: 20	train_loss: 16.4144	val_loss: 32.0599	psnr: 17.7914	ssim: 0.7994




Epoch: 21	train_loss: 16.0689	val_loss: 29.0136	psnr: 18.4239	ssim: 0.8068




Epoch: 22	train_loss: 16.0317	val_loss: 28.7893	psnr: 18.5743	ssim: 0.8064




Epoch: 23	train_loss: 15.8840	val_loss: 26.7087	psnr: 19.0018	ssim: 0.8103




Epoch: 24	train_loss: 16.0391	val_loss: 29.0897	psnr: 18.4854	ssim: 0.8077




Epoch: 25	train_loss: 16.1740	val_loss: 32.1292	psnr: 17.5532	ssim: 0.8116




Epoch: 26	train_loss: 16.5958	val_loss: 27.3190	psnr: 18.8392	ssim: 0.8206




Epoch: 27	train_loss: 16.5612	val_loss: 31.4084	psnr: 17.9790	ssim: 0.8050




Epoch: 28	train_loss: 16.9960	val_loss: 28.8321	psnr: 18.4636	ssim: 0.8099




Epoch: 29	train_loss: 18.5042	val_loss: 25.7986	psnr: 19.2760	ssim: 0.8236




Epoch: 30	train_loss: 18.4142	val_loss: 22.4122	psnr: 20.2304	ssim: 0.8198




Epoch: 31	train_loss: 17.9985	val_loss: 26.8263	psnr: 18.7401	ssim: 0.8277




Epoch: 32	train_loss: 18.3190	val_loss: 29.3072	psnr: 17.9094	ssim: 0.8220




Epoch: 33	train_loss: 19.4375	val_loss: 41.1338	psnr: 15.8863	ssim: 0.8035




Epoch: 34	train_loss: 19.8489	val_loss: 34.3915	psnr: 17.1667	ssim: 0.8167




Epoch: 35	train_loss: 19.2656	val_loss: 24.2638	psnr: 19.4849	ssim: 0.8228




Epoch: 36	train_loss: 18.6905	val_loss: 27.1025	psnr: 18.9980	ssim: 0.8309




Epoch: 37	train_loss: 18.8041	val_loss: 44.1437	psnr: 14.8765	ssim: 0.7959




Epoch: 38	train_loss: 19.9794	val_loss: 27.5953	psnr: 18.8250	ssim: 0.8368




Epoch: 39	train_loss: 18.9343	val_loss: 50.5377	psnr: 13.6543	ssim: 0.7975




Epoch: 40	train_loss: 19.7745	val_loss: 24.4772	psnr: 19.9606	ssim: 0.8404




Epoch: 41	train_loss: 19.3383	val_loss: 35.3898	psnr: 16.6653	ssim: 0.8324




Epoch: 42	train_loss: 18.5507	val_loss: 29.3197	psnr: 18.4078	ssim: 0.8417




Epoch: 43	train_loss: 18.1443	val_loss: 35.8039	psnr: 16.8033	ssim: 0.8310




Epoch: 44	train_loss: 18.1595	val_loss: 25.0321	psnr: 19.4553	ssim: 0.8507




Epoch: 45	train_loss: 17.3140	val_loss: 24.0380	psnr: 19.8845	ssim: 0.8520




Epoch: 46	train_loss: 16.8682	val_loss: 22.3632	psnr: 20.0751	ssim: 0.8568




Epoch: 47	train_loss: 17.3697	val_loss: 40.9654	psnr: 15.7659	ssim: 0.8187




Epoch: 48	train_loss: 16.8915	val_loss: 29.0070	psnr: 18.5475	ssim: 0.8466




Epoch: 49	train_loss: 15.8908	val_loss: 29.3142	psnr: 18.4426	ssim: 0.8469




Epoch: 50	train_loss: 15.0803	val_loss: 31.0573	psnr: 18.2532	ssim: 0.8415




Epoch: 51	train_loss: 14.5420	val_loss: 31.6003	psnr: 17.8345	ssim: 0.8389




Epoch: 52	train_loss: 14.2139	val_loss: 37.4318	psnr: 16.7113	ssim: 0.8236




Epoch: 53	train_loss: 13.5282	val_loss: 35.1991	psnr: 17.2876	ssim: 0.8302




Epoch: 54	train_loss: 12.9536	val_loss: 29.0627	psnr: 18.5047	ssim: 0.8474




Epoch: 55	train_loss: 12.4214	val_loss: 28.5319	psnr: 18.7491	ssim: 0.8470




Epoch: 56	train_loss: 11.9058	val_loss: 30.9447	psnr: 18.2134	ssim: 0.8423




Epoch: 57	train_loss: 11.5353	val_loss: 28.6053	psnr: 18.6781	ssim: 0.8480




Epoch: 58	train_loss: 11.2014	val_loss: 30.5442	psnr: 18.2410	ssim: 0.8430




Epoch: 59	train_loss: 10.8474	val_loss: 35.0705	psnr: 17.1921	ssim: 0.8323




Epoch: 60	train_loss: 10.6846	val_loss: 29.4905	psnr: 18.5270	ssim: 0.8449




Epoch: 61	train_loss: 10.5062	val_loss: 29.8860	psnr: 18.4167	ssim: 0.8431




Epoch: 62	train_loss: 10.3966	val_loss: 30.8389	psnr: 18.1506	ssim: 0.8428




Epoch: 63	train_loss: 10.4137	val_loss: 30.7269	psnr: 18.2358	ssim: 0.8414




Epoch: 64	train_loss: 10.5634	val_loss: 30.7882	psnr: 18.1954	ssim: 0.8414


  0%|          | 0/15 [00:00<?, ?it/s]

idx: 0	loss: 30.3400	psnr: 17.2380	ssim: 0.7952


test:   7%|▋         | 1/15 [00:06<01:35,  6.82s/it]

idx: 1	loss: 47.7165	psnr: 14.3101	ssim: 0.8326


test:  13%|█▎        | 2/15 [00:09<00:58,  4.49s/it]

idx: 2	loss: 32.2606	psnr: 16.9138	ssim: 0.7688


test:  20%|██        | 3/15 [00:12<00:44,  3.75s/it]

idx: 3	loss: 19.1968	psnr: 19.9606	ssim: 0.7279


test:  27%|██▋       | 4/15 [00:14<00:34,  3.13s/it]

idx: 4	loss: 29.6693	psnr: 17.8647	ssim: 0.8088


test:  33%|███▎      | 5/15 [00:17<00:28,  2.83s/it]

idx: 5	loss: 40.5535	psnr: 15.2372	ssim: 0.7464


test:  40%|████      | 6/15 [00:20<00:27,  3.07s/it]

idx: 6	loss: 34.7827	psnr: 14.6903	ssim: 0.5421


test:  47%|████▋     | 7/15 [00:23<00:24,  3.01s/it]

idx: 7	loss: 12.2357	psnr: 23.0183	ssim: 0.8255


test:  53%|█████▎    | 8/15 [00:26<00:20,  2.89s/it]

idx: 8	loss: 43.5788	psnr: 14.5650	ssim: 0.7500


test:  60%|██████    | 9/15 [00:28<00:16,  2.82s/it]

idx: 9	loss: 17.8467	psnr: 21.4774	ssim: 0.8105


test:  67%|██████▋   | 10/15 [00:31<00:14,  2.87s/it]

idx: 10	loss: 27.5355	psnr: 18.2994	ssim: 0.8261


test:  73%|███████▎  | 11/15 [00:34<00:11,  2.87s/it]

idx: 11	loss: 18.3210	psnr: 21.0312	ssim: 0.8695


test:  80%|████████  | 12/15 [00:37<00:08,  2.99s/it]

idx: 12	loss: 18.8952	psnr: 20.3908	ssim: 0.6994


test:  87%|████████▋ | 13/15 [00:40<00:05,  2.87s/it]

idx: 13	loss: 17.3078	psnr: 21.4421	ssim: 0.7760


test:  93%|█████████▎| 14/15 [00:43<00:02,  2.82s/it]

idx: 14	loss: 51.2711	psnr: 13.6379	ssim: 0.8602


