In [1]:
import os
import math
import glob
import torch
import random
import numpy as np
import pandas as pd
from PIL import Image
import torch.nn as nn
from tqdm.auto import tqdm
from torch.optim import Adam
from torchvision.transforms import v2
from torch.amp import autocast, GradScaler
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image import StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
HR_train_paths = sorted(glob.glob("../data/DIV2K_train_HR/*.png"))
X2_train_paths = sorted(glob.glob("../data/DIV2K_train_LR_bicubic/X2/*.png"))
X4_train_paths = sorted(glob.glob("../data/DIV2K_train_LR_bicubic/X4/*.png"))
X8_train_paths = sorted(glob.glob("../data/DIV2K_train_LR_bicubic/X8/*.png"))
X16_train_paths = sorted(glob.glob("../data/DIV2K_train_LR_bicubic/X16/*.png"))
X32_train_paths = sorted(glob.glob("../data/DIV2K_train_LR_bicubic/X32/*.png"))
X64_train_paths = sorted(glob.glob("../data/DIV2K_train_LR_bicubic/X64/*.png"))

HR_valid_paths = sorted(glob.glob("../data/DIV2K_valid_HR/*.png"))
X2_valid_paths = sorted(glob.glob("../data/DIV2K_valid_LR_bicubic/X2/*.png"))
X4_valid_paths = sorted(glob.glob("../data/DIV2K_valid_LR_bicubic/X4/*.png"))
X8_valid_paths = sorted(glob.glob("../data/DIV2K_valid_LR_bicubic/X8/*.png"))
X16_valid_paths = sorted(glob.glob("../data/DIV2K_valid_LR_bicubic/X16/*.png"))
X32_valid_paths = sorted(glob.glob("../data/DIV2K_valid_LR_bicubic/X32/*.png"))
X64_valid_paths = sorted(glob.glob("../data/DIV2K_valid_LR_bicubic/X64/*.png"))

In [4]:
class DenseBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.Sequential()
        for i in range(1,5):
            self.blocks.append(
                nn.Sequential(
                    nn.Conv2d(i * 64, 64, 3, stride=1, padding='same'),
                    nn.LeakyReLU()
                )
            )

        self.blocks.append(nn.Conv2d(5 * 64, 64, 3, stride=1, padding='same'))

    def forward(self, x):
        x1 = self.blocks[0](x)
        x = torch.cat([x, x1], dim=1)
        x2 = self.blocks[1](x)
        x = torch.cat([x, x2], dim=1)
        x3 = self.blocks[2](x)
        x = torch.cat([x, x3], dim=1)
        x4 = self.blocks[3](x)
        x = torch.cat([x, x4], dim=1)
        return self.blocks[4](x)
        
class ESRGAN(nn.Module):
    def __init__(self, n: int):
        """
        Enhanced Deep Residual Network with Residual in Residual Dense Block a.k.a.
        Enhanced Super Resolution Generative Adversarial Networks
        Args:
            n: scaling factor
        """
        super().__init__()
        self.DIV2K_RGB = torch.tensor([0.44882884613943946, 0.43713809810624193, 0.4040371984052683]).view(1, 3, 1, 1).to(device)
        
        self.expand = nn.Sequential(
            nn.Conv2d(3, 64, 9, stride=1, padding='same'),
            nn.PReLU()
        )

        self.residual_blocks = nn.Sequential()
        for _ in range(23):
            self.residual_blocks.append(DenseBlock())

        self.residual_blocks.append(nn.Conv2d(64, 64, 3, stride=1, padding='same'))

        self.upscaling_head = nn.Sequential()
        for _ in range(int(math.log2(n))):
            self.upscaling_head.append(nn.Conv2d(64, 256, 3, stride=1, padding='same'))
            self.upscaling_head.append(nn.PixelShuffle(2))
            self.upscaling_head.append(nn.PReLU())
            
        self.upscaling_head.append(nn.Conv2d(64, 64, 9, stride=1, padding='same'))
        self.upscaling_head.append(nn.Conv2d(64, 3, 9, stride=1, padding='same'))

    def forward(self, x):
        x = self.expand(x-self.DIV2K_RGB)
        xp = x.clone()
        for i in range(23):
            xp = xp + 0.2 * self.residual_blocks[i](xp)

        x = x + 0.2 * self.residual_blocks[23](xp)
        return self.upscaling_head(x) + self.DIV2K_RGB

In [5]:
class ESRGAN_Dataset(Dataset):
    def __init__(self, target_paths: list[str], scale: int, ram_limit_gb: float = 2.0):
        self.crop_size = scale * 48
        self.scale = scale

        self.rotations = [0, 90, 180, 270]
        self.transforms = v2.Compose([
            v2.PILToTensor(),
            v2.Lambda(lambda x: (x / 255.0))
        ])

        self.preloaded = {}
        self.paths = target_paths

        total_ram_used = 0
        for i, path in enumerate(tqdm(target_paths, desc="Preloading images")):
            img = Image.open(path).convert("RGB")
            total_ram_used += img.width * img.height * 3 / (1024 ** 3)  # ~size in GB

            if total_ram_used < ram_limit_gb:
                self.preloaded[i] = img
            else:
                break

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        if idx in self.preloaded:
            target = self.preloaded[idx]
        else:
            target = Image.open(self.paths[idx]).convert("RGB")

        target = self.random_crop(target, self.crop_size)
        inp = target.resize((target.width // self.scale, target.height // self.scale), Image.BICUBIC)
        
        rotation =  random.choice(self.rotations)
        if rotation != 0:
            inp = v2.functional.rotate(inp, rotation)
            target = v2.functional.rotate(target, rotation)
        if random.randint(0, 1):
            inp = v2.functional.horizontal_flip(inp)
            target = v2.functional.horizontal_flip(target)
            
        return self.transforms(inp), self.transforms(target)

    def random_crop(self, img, size):
        w, h = img.size
        if w < size or h < size:
            img = img.resize((size, size), Image.BICUBIC)
        x = random.randint(0, w - size)
        y = random.randint(0, h - size)
        return img.crop((x, y, x + size, y + size))

    def set_scale(self, scale: int):
        self.scale = scale

    def set_crop_size(self, crop_size: int):
        self.crop_size = crop_size

In [6]:
psnr = PeakSignalNoiseRatio(data_range=1.0).to(device)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to(device)

In [7]:
transform = v2.Compose([
    v2.PILToTensor(),
    v2.Lambda(lambda x: x / 255.0)
])

In [8]:
def calc_metrics(model: nn.Module, target_ds: list[str], scale: int):
    transform_target = v2.Compose([
        v2.PILToTensor(),
        v2.Lambda(lambda x: x/255.0)
    ])

    transform_input = v2.Compose([
        v2.PILToTensor(),
        v2.Lambda(lambda x: (x / 255.0))
    ])

    psnr_acc = 0
    ssim_acc = 0
    lpips_acc = 0
    failed_lpips = 0

    for i in tqdm(range(len(target_ds)), leave=False):
        target_image = Image.open(target_ds[i]).convert("RGB")
        w, h = target_image.size

        w -= w % scale
        h -= h % scale
        target_image = target_image.crop((0, 0, w, h))
        
        lowres = target_image.resize((w // scale, h // scale), resample=Image.BICUBIC)
        input_tensor = transform_input(lowres).unsqueeze(0).to(device)
        target_tensor = transform_target(target_image).unsqueeze(0).to(device)

        with torch.inference_mode():
            sr = model(input_tensor).clamp(0, 1)

        psnr_acc += psnr(sr, target_tensor).item()
        ssim_acc += ssim(sr, target_tensor).item()
        
        # There are 2 images that cause lpips to fail
        try:
            x = lpips(sr, target_tensor).cpu().item()
            if np.isnan(x):
                failed_lpips += 1
                continue
                
            lpips_acc += x
        except:
            failed_lpips += 1

    lpips_acc /= len(target_ds) - failed_lpips
    psnr_acc /= len(target_ds)
    ssim_acc /= len(target_ds)
    return psnr_acc, ssim_acc, lpips_acc

In [9]:
def train_step(model, dataloader, optimizer, loss_fn, scaler):
    avg_psnr = 0
    avg_ssim = 0
    model.train()

    for batch, target in dataloader:
        optimizer.zero_grad(set_to_none=True)
        
        batch, target = batch.to(device), target.to(device)

        with autocast('cuda'):
            logits = model(batch)
            loss = loss_fn(logits, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        logits = logits.clamp(0.0, 1.0)
        target = target.clamp(0.0, 1.0)
        
        avg_psnr += psnr(logits, target).item()
        avg_ssim += ssim(logits, target).item()
        
    avg_psnr /= len(dataloader)
    avg_ssim /= len(dataloader)
    return avg_psnr, avg_ssim

def valid_step(model, dataloader, loss_fn):
    avg_psnr = 0
    avg_ssim = 0
    avg_lpips = 0
    model.eval()

    with torch.inference_mode():
        for batch, target in dataloader:
            batch, target = batch.to(device), target.to(device)

            logits = model(batch)
            
            logits = logits.clamp(0.0, 1.0)
            target = target.clamp(0.0, 1.0)
            
            avg_psnr += psnr(logits, target).item()
            avg_ssim += ssim(logits, target).item()
            avg_lpips += lpips(logits, target).item()

    avg_psnr /= len(dataloader)
    avg_ssim /= len(dataloader)
    avg_lpips /= len(dataloader)

        
    return avg_psnr, avg_ssim, avg_lpips

In [10]:
def train(model, train_dl, valid_dl, optimizer, scheduler: StepLR, loss_fn, epochs, start_checkpoint=None):
    os.makedirs('../tmp_model_checkpoints', exist_ok=True)
    counter = 0 # count epochs without printing training stats
    scaler = GradScaler('cuda')
    
    if start_checkpoint:
        start_epoch = start_checkpoint['epoch']
        best_psnr = start_checkpoint['best_psnr']
        best_ssim = start_checkpoint['best_ssim']
        best_lpips = start_checkpoint['best_lpips']
        scaler.load_state_dict(start_checkpoint['scaler_state_dict'])
    else:
        start_epoch = 0
        best_psnr = 0
        best_ssim = 0
        best_lpips = float('inf')
        
    log_freq = (epochs - start_epoch) // 30 # how often to print stats when no progress is made
    
    for epoch in tqdm(range(start_epoch, epochs), desc="Epochs"):
        counter += 1
        train_psnr, train_ssim = train_step(
            model,
            train_dl,
            optimizer,
            loss_fn,
            scaler
        )

        valid_psnr, valid_ssim, valid_lpips = valid_step(
            model,
            valid_dl,
            loss_fn,
        )

        scheduler.step()

        progress = False
        
        if valid_psnr > best_psnr:
            progress = True
            best_psnr = valid_psnr
            checkpoint = {
                'epoch': epoch,
                'best_psnr': best_psnr,
                'best_ssim': best_ssim,
                'best_lpips': best_lpips,
                'model_state_dict': model._orig_mod.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict()
            }
            torch.save(checkpoint, f'../tmp_model_checkpoints/best_psnr.pth')

        if valid_ssim > best_ssim:
            progress = True
            best_ssim = valid_ssim
            checkpoint = {
                'epoch': epoch,
                'best_psnr': best_psnr,
                'best_ssim': best_ssim,
                'best_lpips': best_lpips,
                'model_state_dict': model._orig_mod.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict()
            }
            torch.save(checkpoint, f'../tmp_model_checkpoints/best_ssim.pth')

        if valid_lpips < best_lpips:
            progress = True
            best_lpips = valid_lpips
            checkpoint = {
                'epoch': epoch,
                'best_psnr': best_psnr,
                'best_ssim': best_ssim,
                'best_lpips': best_lpips,
                'model_state_dict': model._orig_mod.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict()
            }
            torch.save(checkpoint, f'../tmp_model_checkpoints/best_lpips.pth')

        if epoch == epochs-1:
            checkpoint = {
                'epoch': epoch,
                'best_psnr': best_psnr,
                'best_ssim': best_ssim,
                'best_lpips': best_lpips,
                'model_state_dict': model._orig_mod.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict()
            }
            torch.save(checkpoint, f'../tmp_model_checkpoints/last.pth')
            
        if progress or counter >= log_freq:
            counter = 0
            print(
                f"Epoch: {epoch+1} | "
                f"learning rate: {scheduler.get_last_lr()[0]:.6f} | "
                f"[train] PSNR: {train_psnr:.4f} | "
                f"[train] SSIM: {train_ssim:.4f} | "
                f"[valid] PSNR: {valid_psnr:.4f} | "
                f"[valid] SSIM: {valid_ssim:.4f} | "
                f"[valid] LPIPS: {valid_lpips:.4f}"
            )

## Training

In [11]:
valid_ds = ESRGAN_Dataset(HR_valid_paths, 2, ram_limit_gb=1)

Preloading images:   0%|          | 0/100 [00:00<?, ?it/s]

In [12]:
train_ds = ESRGAN_Dataset(HR_train_paths, 2, ram_limit_gb=8)

Preloading images:   0%|          | 0/800 [00:00<?, ?it/s]

### Day 1

In [13]:
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=os.cpu_count()-1)
valid_dl = DataLoader(valid_ds, batch_size=16, shuffle=False, num_workers=os.cpu_count()-1)

model = torch.compile(ESRGAN(2).to(device))
loss_fn = nn.L1Loss()
optimizer = Adam(model.parameters(), lr=2e-4)
scheduler = StepLR(optimizer, step_size=4000, gamma=0.5)

train(model, train_dl, valid_dl, optimizer, scheduler, loss_fn, 3333)

Epochs:   0%|          | 0/3333 [00:00<?, ?it/s]

W0523 07:50:34.369000 2685 torch/_inductor/utils.py:1250] [0/0] Not enough SMs to use max_autotune_gemm mode


Epoch: 1 | learning rate: 0.000200 | [train] PSNR: 19.2574 | [train] SSIM: 0.5180 | [valid] PSNR: 20.9305 | [valid] SSIM: 0.5610 | [valid] LPIPS: 0.6127
Epoch: 2 | learning rate: 0.000200 | [train] PSNR: 21.7832 | [train] SSIM: 0.6175 | [valid] PSNR: 21.1425 | [valid] SSIM: 0.6097 | [valid] LPIPS: 0.4966
Epoch: 3 | learning rate: 0.000200 | [train] PSNR: 23.4297 | [train] SSIM: 0.6826 | [valid] PSNR: 24.3567 | [valid] SSIM: 0.7129 | [valid] LPIPS: 0.4039
Epoch: 4 | learning rate: 0.000200 | [train] PSNR: 24.3505 | [train] SSIM: 0.7334 | [valid] PSNR: 24.9800 | [valid] SSIM: 0.7403 | [valid] LPIPS: 0.3672
Epoch: 5 | learning rate: 0.000200 | [train] PSNR: 25.4929 | [train] SSIM: 0.7706 | [valid] PSNR: 26.1496 | [valid] SSIM: 0.7861 | [valid] LPIPS: 0.3219
Epoch: 6 | learning rate: 0.000200 | [train] PSNR: 26.0246 | [train] SSIM: 0.8002 | [valid] PSNR: 26.3174 | [valid] SSIM: 0.8096 | [valid] LPIPS: 0.2705
Epoch: 7 | learning rate: 0.000200 | [train] PSNR: 26.3079 | [train] SSIM: 0.8113 

In [14]:
checkpoint = torch.load('../tmp_model_checkpoints/last.pth')
model = ESRGAN(2).to(device)
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [15]:
model.eval();

In [16]:
targets = [X32_valid_paths, X16_valid_paths, X8_valid_paths, X4_valid_paths, X2_valid_paths, HR_valid_paths]

metrics_x2 = pd.DataFrame(columns=["PSNR↑", "SSIM↑", "LPIPS↓"])
for target_ds in tqdm(targets, total=6):
    metrics_x2.loc[len(metrics_x2)] = calc_metrics(model, target_ds, 2)

metrics_x2.index = [
    "31px -> 62px", "63px -> 126px", "127px -> 254px", 
    "255px -> 510px", "510px -> 1020px", "1020px -> 2040px"
]
metrics_x2

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Unnamed: 0,PSNR↑,SSIM↑,LPIPS↓
31px -> 62px,26.749013,0.881142,0.050005
63px -> 126px,28.033528,0.897445,0.074818
127px -> 254px,29.897095,0.920161,0.076518
255px -> 510px,31.428937,0.929166,0.075169
510px -> 1020px,33.283587,0.939109,0.07385
1020px -> 2040px,34.498609,0.936519,0.087584


### Day 2

In [13]:
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=os.cpu_count()-1)
valid_dl = DataLoader(valid_ds, batch_size=16, shuffle=False, num_workers=os.cpu_count()-1)

checkpoint = torch.load('../tmp_model_checkpoints/last.pth')
model = ESRGAN(2).to(device)
loss_fn = nn.L1Loss()
optimizer = Adam(model.parameters(), lr=2e-4)
scheduler = StepLR(optimizer, step_size=4000, gamma=0.5)

model.load_state_dict(checkpoint['model_state_dict'])
model = torch.compile(model)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

train(model, train_dl, valid_dl, optimizer, scheduler, loss_fn, 6666, checkpoint)

Epochs:   0%|          | 0/3334 [00:00<?, ?it/s]

W0527 08:09:18.368000 2936 torch/_inductor/utils.py:1250] [0/0] Not enough SMs to use max_autotune_gemm mode


Epoch: 3443 | learning rate: 0.000200 | [train] PSNR: 29.4277 | [train] SSIM: 0.8936 | [valid] PSNR: 30.0798 | [valid] SSIM: 0.9026 | [valid] LPIPS: 0.0968
Epoch: 3554 | learning rate: 0.000200 | [train] PSNR: 31.3929 | [train] SSIM: 0.9250 | [valid] PSNR: 31.2760 | [valid] SSIM: 0.9239 | [valid] LPIPS: 0.0607
Epoch: 3665 | learning rate: 0.000200 | [train] PSNR: 31.8970 | [train] SSIM: 0.9350 | [valid] PSNR: 32.0994 | [valid] SSIM: 0.9364 | [valid] LPIPS: 0.0564
Epoch: 3744 | learning rate: 0.000200 | [train] PSNR: 31.8033 | [train] SSIM: 0.9363 | [valid] PSNR: 34.1744 | [valid] SSIM: 0.9437 | [valid] LPIPS: 0.0503
Epoch: 3855 | learning rate: 0.000200 | [train] PSNR: 32.1193 | [train] SSIM: 0.9369 | [valid] PSNR: 31.7876 | [valid] SSIM: 0.9334 | [valid] LPIPS: 0.0552
Epoch: 3966 | learning rate: 0.000200 | [train] PSNR: 32.1072 | [train] SSIM: 0.9368 | [valid] PSNR: 31.9686 | [valid] SSIM: 0.9380 | [valid] LPIPS: 0.0522
Epoch: 4077 | learning rate: 0.000100 | [train] PSNR: 32.3261 | 

In [14]:
checkpoint = torch.load('../tmp_model_checkpoints/last.pth')
model = ESRGAN(2).to(device)
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [15]:
model.eval();

In [16]:
targets = [X32_valid_paths, X16_valid_paths, X8_valid_paths, X4_valid_paths, X2_valid_paths, HR_valid_paths]

metrics_x2 = pd.DataFrame(columns=["PSNR↑", "SSIM↑", "LPIPS↓"])
for target_ds in tqdm(targets, total=6):
    metrics_x2.loc[len(metrics_x2)] = calc_metrics(model, target_ds, 2)

metrics_x2.index = [
    "31px -> 62px", "63px -> 126px", "127px -> 254px", 
    "255px -> 510px", "510px -> 1020px", "1020px -> 2040px"
]
metrics_x2

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Unnamed: 0,PSNR↑,SSIM↑,LPIPS↓
31px -> 62px,26.885568,0.884635,0.049951
63px -> 126px,28.206117,0.901167,0.071395
127px -> 254px,30.141289,0.923697,0.073097
255px -> 510px,31.703115,0.93239,0.071856
510px -> 1020px,33.611977,0.94233,0.070376
1020px -> 2040px,34.819135,0.939396,0.08471


### Day 3

In [13]:
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=os.cpu_count()-1)
valid_dl = DataLoader(valid_ds, batch_size=16, shuffle=False, num_workers=os.cpu_count()-1)

checkpoint = torch.load('../tmp_model_checkpoints/last.pth')
model = ESRGAN(2).to(device)
loss_fn = nn.L1Loss()
optimizer = Adam(model.parameters(), lr=2e-4)
scheduler = StepLR(optimizer, step_size=4000, gamma=0.5)

model.load_state_dict(checkpoint['model_state_dict'])
model = torch.compile(model)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

train(model, train_dl, valid_dl, optimizer, scheduler, loss_fn, 10000, checkpoint)

Epochs:   0%|          | 0/3335 [00:00<?, ?it/s]

W0528 07:51:48.187000 2264 torch/_inductor/utils.py:1250] [0/0] Not enough SMs to use max_autotune_gemm mode


Epoch: 6776 | learning rate: 0.000100 | [train] PSNR: 32.4359 | [train] SSIM: 0.9396 | [valid] PSNR: 32.8723 | [valid] SSIM: 0.9387 | [valid] LPIPS: 0.0488
Epoch: 6887 | learning rate: 0.000100 | [train] PSNR: 31.8479 | [train] SSIM: 0.9370 | [valid] PSNR: 32.2791 | [valid] SSIM: 0.9304 | [valid] LPIPS: 0.0606
Epoch: 6998 | learning rate: 0.000100 | [train] PSNR: 32.2204 | [train] SSIM: 0.9399 | [valid] PSNR: 31.9740 | [valid] SSIM: 0.9344 | [valid] LPIPS: 0.0564
Epoch: 7109 | learning rate: 0.000100 | [train] PSNR: 32.5184 | [train] SSIM: 0.9394 | [valid] PSNR: 32.2937 | [valid] SSIM: 0.9400 | [valid] LPIPS: 0.0465
Epoch: 7220 | learning rate: 0.000100 | [train] PSNR: 32.4449 | [train] SSIM: 0.9411 | [valid] PSNR: 32.0985 | [valid] SSIM: 0.9342 | [valid] LPIPS: 0.0580
Epoch: 7331 | learning rate: 0.000100 | [train] PSNR: 32.4014 | [train] SSIM: 0.9390 | [valid] PSNR: 32.8533 | [valid] SSIM: 0.9384 | [valid] LPIPS: 0.0558
Epoch: 7442 | learning rate: 0.000100 | [train] PSNR: 32.3572 | 

In [14]:
checkpoint = torch.load('../tmp_model_checkpoints/last.pth')
model = ESRGAN(2).to(device)
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [15]:
model.eval();

In [16]:
targets = [X32_valid_paths, X16_valid_paths, X8_valid_paths, X4_valid_paths, X2_valid_paths, HR_valid_paths]

metrics_x2 = pd.DataFrame(columns=["PSNR↑", "SSIM↑", "LPIPS↓"])
for target_ds in tqdm(targets, total=6):
    metrics_x2.loc[len(metrics_x2)] = calc_metrics(model, target_ds, 2)

metrics_x2.index = [
    "31px -> 62px", "63px -> 126px", "127px -> 254px", 
    "255px -> 510px", "510px -> 1020px", "1020px -> 2040px"
]
metrics_x2

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Unnamed: 0,PSNR↑,SSIM↑,LPIPS↓
31px -> 62px,26.94176,0.885824,0.04886
63px -> 126px,28.287441,0.902701,0.069783
127px -> 254px,30.239467,0.92501,0.070816
255px -> 510px,31.809355,0.933386,0.069737
510px -> 1020px,33.714115,0.94309,0.068426
1020px -> 2040px,34.907089,0.939921,0.083195


### Day 4

In [13]:
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=os.cpu_count()-1)
valid_dl = DataLoader(valid_ds, batch_size=16, shuffle=False, num_workers=os.cpu_count()-1)

checkpoint = torch.load('../tmp_model_checkpoints/last.pth')
model = ESRGAN(2).to(device)
loss_fn = nn.L1Loss()
optimizer = Adam(model.parameters(), lr=2e-4)
scheduler = StepLR(optimizer, step_size=4000, gamma=0.5)

model.load_state_dict(checkpoint['model_state_dict'])
model = torch.compile(model)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

train(model, train_dl, valid_dl, optimizer, scheduler, loss_fn, 13333, checkpoint)

Epochs:   0%|          | 0/3334 [00:00<?, ?it/s]

W0614 07:38:17.542000 2613 torch/_inductor/utils.py:1250] [0/0] Not enough SMs to use max_autotune_gemm mode


Epoch: 10110 | learning rate: 0.000050 | [train] PSNR: 32.5288 | [train] SSIM: 0.9416 | [valid] PSNR: 32.3015 | [valid] SSIM: 0.9384 | [valid] LPIPS: 0.0536
Epoch: 10221 | learning rate: 0.000050 | [train] PSNR: 32.3393 | [train] SSIM: 0.9391 | [valid] PSNR: 32.9201 | [valid] SSIM: 0.9396 | [valid] LPIPS: 0.0560
Epoch: 10332 | learning rate: 0.000050 | [train] PSNR: 32.2097 | [train] SSIM: 0.9382 | [valid] PSNR: 32.0657 | [valid] SSIM: 0.9380 | [valid] LPIPS: 0.0518
Epoch: 10443 | learning rate: 0.000050 | [train] PSNR: 32.6383 | [train] SSIM: 0.9444 | [valid] PSNR: 33.2284 | [valid] SSIM: 0.9435 | [valid] LPIPS: 0.0506
Epoch: 10554 | learning rate: 0.000050 | [train] PSNR: 32.7045 | [train] SSIM: 0.9409 | [valid] PSNR: 32.9060 | [valid] SSIM: 0.9393 | [valid] LPIPS: 0.0492
Epoch: 10665 | learning rate: 0.000050 | [train] PSNR: 32.2890 | [train] SSIM: 0.9407 | [valid] PSNR: 32.5733 | [valid] SSIM: 0.9385 | [valid] LPIPS: 0.0518
Epoch: 10776 | learning rate: 0.000050 | [train] PSNR: 32.

In [14]:
checkpoint = torch.load('../tmp_model_checkpoints/last.pth')
model = ESRGAN(2).to(device)
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [15]:
model.eval();

In [16]:
targets = [X32_valid_paths, X16_valid_paths, X8_valid_paths, X4_valid_paths, X2_valid_paths, HR_valid_paths]

metrics_x2 = pd.DataFrame(columns=["PSNR↑", "SSIM↑", "LPIPS↓"])
for target_ds in tqdm(targets, total=6):
    metrics_x2.loc[len(metrics_x2)] = calc_metrics(model, target_ds, 2)

metrics_x2.index = [
    "31px -> 62px", "63px -> 126px", "127px -> 254px", 
    "255px -> 510px", "510px -> 1020px", "1020px -> 2040px"
]
metrics_x2

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Unnamed: 0,PSNR↑,SSIM↑,LPIPS↓
31px -> 62px,26.956819,0.886506,0.046618
63px -> 126px,28.303397,0.903393,0.067279
127px -> 254px,30.282097,0.925547,0.068266
255px -> 510px,31.855069,0.933928,0.06738
510px -> 1020px,33.757307,0.943486,0.066587
1020px -> 2040px,34.942267,0.940225,0.081736


### Day 5

In [13]:
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=os.cpu_count()-1)
valid_dl = DataLoader(valid_ds, batch_size=16, shuffle=False, num_workers=os.cpu_count()-1)

checkpoint = torch.load('../tmp_model_checkpoints/last.pth')
model = ESRGAN(2).to(device)
loss_fn = nn.L1Loss()
optimizer = Adam(model.parameters(), lr=2e-4)
scheduler = StepLR(optimizer, step_size=4000, gamma=0.5)

model.load_state_dict(checkpoint['model_state_dict'])
model = torch.compile(model)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

train(model, train_dl, valid_dl, optimizer, scheduler, loss_fn, 16666, checkpoint)

Epochs:   0%|          | 0/3334 [00:00<?, ?it/s]

W0615 07:38:13.938000 2765 torch/_inductor/utils.py:1250] [0/0] Not enough SMs to use max_autotune_gemm mode


Epoch: 13443 | learning rate: 0.000025 | [train] PSNR: 32.4455 | [train] SSIM: 0.9410 | [valid] PSNR: 31.3775 | [valid] SSIM: 0.9356 | [valid] LPIPS: 0.0532
Epoch: 13554 | learning rate: 0.000025 | [train] PSNR: 32.8361 | [train] SSIM: 0.9433 | [valid] PSNR: 32.5786 | [valid] SSIM: 0.9379 | [valid] LPIPS: 0.0551
Epoch: 13665 | learning rate: 0.000025 | [train] PSNR: 32.9135 | [train] SSIM: 0.9420 | [valid] PSNR: 32.4226 | [valid] SSIM: 0.9406 | [valid] LPIPS: 0.0499
Epoch: 13776 | learning rate: 0.000025 | [train] PSNR: 32.4191 | [train] SSIM: 0.9417 | [valid] PSNR: 32.1183 | [valid] SSIM: 0.9398 | [valid] LPIPS: 0.0475
Epoch: 13782 | learning rate: 0.000025 | [train] PSNR: 32.7253 | [train] SSIM: 0.9431 | [valid] PSNR: 33.8164 | [valid] SSIM: 0.9517 | [valid] LPIPS: 0.0410
Epoch: 13893 | learning rate: 0.000025 | [train] PSNR: 32.4641 | [train] SSIM: 0.9408 | [valid] PSNR: 31.6650 | [valid] SSIM: 0.9337 | [valid] LPIPS: 0.0577
Epoch: 14004 | learning rate: 0.000025 | [train] PSNR: 32.

In [14]:
checkpoint = torch.load('../tmp_model_checkpoints/last.pth')
model = ESRGAN(2).to(device)
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [15]:
model.eval();

In [16]:
targets = [X32_valid_paths, X16_valid_paths, X8_valid_paths, X4_valid_paths, X2_valid_paths, HR_valid_paths]

metrics_x2 = pd.DataFrame(columns=["PSNR↑", "SSIM↑", "LPIPS↓"])
for target_ds in tqdm(targets, total=6):
    metrics_x2.loc[len(metrics_x2)] = calc_metrics(model, target_ds, 2)

metrics_x2.index = [
    "31px -> 62px", "63px -> 126px", "127px -> 254px", 
    "255px -> 510px", "510px -> 1020px", "1020px -> 2040px"
]
metrics_x2

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Unnamed: 0,PSNR↑,SSIM↑,LPIPS↓
31px -> 62px,26.970628,0.886749,0.046358
63px -> 126px,28.32371,0.903706,0.066612
127px -> 254px,30.303386,0.925914,0.067781
255px -> 510px,31.880631,0.934283,0.067149
510px -> 1020px,33.790051,0.943804,0.066595
1020px -> 2040px,34.969289,0.940428,0.081991


### Day 6

In [13]:
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=os.cpu_count()-1)
valid_dl = DataLoader(valid_ds, batch_size=16, shuffle=False, num_workers=os.cpu_count()-1)

checkpoint = torch.load('../tmp_model_checkpoints/last.pth')
model = ESRGAN(2).to(device)
loss_fn = nn.L1Loss()
optimizer = Adam(model.parameters(), lr=2e-4)
scheduler = StepLR(optimizer, step_size=4000, gamma=0.5)

model.load_state_dict(checkpoint['model_state_dict'])
model = torch.compile(model)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

train(model, train_dl, valid_dl, optimizer, scheduler, loss_fn, 20000, checkpoint)

Epochs:   0%|          | 0/3335 [00:00<?, ?it/s]

W0616 07:18:01.169000 2842 torch/_inductor/utils.py:1250] [0/0] Not enough SMs to use max_autotune_gemm mode


Epoch: 16776 | learning rate: 0.000013 | [train] PSNR: 32.8523 | [train] SSIM: 0.9438 | [valid] PSNR: 32.1335 | [valid] SSIM: 0.9391 | [valid] LPIPS: 0.0494
Epoch: 16887 | learning rate: 0.000013 | [train] PSNR: 32.9673 | [train] SSIM: 0.9422 | [valid] PSNR: 32.4388 | [valid] SSIM: 0.9406 | [valid] LPIPS: 0.0500
Epoch: 16998 | learning rate: 0.000013 | [train] PSNR: 32.7729 | [train] SSIM: 0.9390 | [valid] PSNR: 32.8799 | [valid] SSIM: 0.9381 | [valid] LPIPS: 0.0523
Epoch: 17109 | learning rate: 0.000013 | [train] PSNR: 32.4975 | [train] SSIM: 0.9418 | [valid] PSNR: 32.5822 | [valid] SSIM: 0.9387 | [valid] LPIPS: 0.0504
Epoch: 17220 | learning rate: 0.000013 | [train] PSNR: 32.2789 | [train] SSIM: 0.9402 | [valid] PSNR: 33.5055 | [valid] SSIM: 0.9429 | [valid] LPIPS: 0.0515
Epoch: 17331 | learning rate: 0.000013 | [train] PSNR: 32.3594 | [train] SSIM: 0.9378 | [valid] PSNR: 31.8204 | [valid] SSIM: 0.9389 | [valid] LPIPS: 0.0505
Epoch: 17441 | learning rate: 0.000013 | [train] PSNR: 32.

In [14]:
checkpoint = torch.load('../tmp_model_checkpoints/last.pth')
model = ESRGAN(2).to(device)
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [15]:
model.eval();

In [16]:
targets = [X32_valid_paths, X16_valid_paths, X8_valid_paths, X4_valid_paths, X2_valid_paths, HR_valid_paths]

metrics_x2 = pd.DataFrame(columns=["PSNR↑", "SSIM↑", "LPIPS↓"])
for target_ds in tqdm(targets, total=6):
    metrics_x2.loc[len(metrics_x2)] = calc_metrics(model, target_ds, 2)

metrics_x2.index = [
    "31px -> 62px", "63px -> 126px", "127px -> 254px", 
    "255px -> 510px", "510px -> 1020px", "1020px -> 2040px"
]
metrics_x2

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Unnamed: 0,PSNR↑,SSIM↑,LPIPS↓
31px -> 62px,26.979871,0.887078,0.046642
63px -> 126px,28.339206,0.904002,0.066789
127px -> 254px,30.31764,0.926106,0.068074
255px -> 510px,31.894504,0.934373,0.067519
510px -> 1020px,33.804934,0.943922,0.066973
1020px -> 2040px,34.984146,0.940503,0.082304
