In [4]:
# ===== 1. Check GPU & Install Dependencies =====
#@title **1.1 Check GPU**
!nvidia-smi

#@title **1.2 Install Python Packages**
!pip install torch torchvision pillow scikit-image tqdm


/bin/bash: line 1: nvidia-smi: command not found


In [5]:
# ===== 2. Download & Prepare DIV2K HR =====

#@title **2.1 Download & Flatten HR images**
import os, glob, shutil

# Cleanup old data
!rm -rf data/DIV2K_train_HR data/DIV2K_train_HR.zip temp_DIV2K
os.makedirs('data/DIV2K_train_HR', exist_ok=True)

# Download DIV2K HR zip
!wget -q https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip -O data/DIV2K_train_HR.zip

# Unzip into temp and flatten all images into data/DIV2K_train_HR
!unzip -q data/DIV2K_train_HR.zip -d temp_DIV2K
for pattern in ('*.png','*.jpg','*.jpeg'):
    for fp in glob.glob(f"temp_DIV2K/{pattern}") + glob.glob(f"temp_DIV2K/*/{pattern}"):
        shutil.move(fp, 'data/DIV2K_train_HR/')
# Cleanup
!rm -rf data/DIV2K_train_HR.zip temp_DIV2K

# Verify
hr_files = glob.glob("data/DIV2K_train_HR/*")
print(f"Found {len(hr_files)} HR images in data/DIV2K_train_HR")


Found 800 HR images in data/DIV2K_train_HR


In [6]:
# ===== 3. Define Models =====

#@title **3.1 Import Libraries**
import torch
import torch.nn as nn
from torchvision.models import vgg19


In [7]:
#@title **3.2 ResidualBlock & Generator**
class ResidualBlock(nn.Module):
    def __init__(self, n_feats=64):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, 3, 1, 1),
            nn.BatchNorm2d(n_feats),
            nn.PReLU(),
            nn.Conv2d(n_feats, n_feats, 3, 1, 1),
            nn.BatchNorm2d(n_feats),
        )
    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    def __init__(self, n_res_blocks=16, n_feats=64, scale=4):
        super().__init__()
        self.conv_in = nn.Conv2d(3, n_feats, 9, 1, 4)
        self.prelu = nn.PReLU()
        self.res_blocks = nn.Sequential(*[ResidualBlock(n_feats) for _ in range(n_res_blocks)])
        self.conv_mid = nn.Sequential(nn.Conv2d(n_feats, n_feats, 3, 1, 1), nn.BatchNorm2d(n_feats))
        upsample = []
        for _ in range(int(scale/2)):
            upsample += [
                nn.Conv2d(n_feats, n_feats*4, 3, 1, 1),
                nn.PixelShuffle(2),
                nn.PReLU()
            ]
        self.upsample = nn.Sequential(*upsample)
        self.conv_out = nn.Conv2d(n_feats, 3, 9, 1, 4)
    def forward(self, x):
        x1 = self.prelu(self.conv_in(x))
        res = self.res_blocks(x1)
        res = self.conv_mid(res)
        x2 = x1 + res
        out = self.upsample(x2)
        return self.conv_out(out)

In [8]:
#@title **3.3 Discriminator**
def conv_block(in_c, out_c, s):
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, 3, s, 1),
        nn.BatchNorm2d(out_c),
        nn.LeakyReLU(0.2, inplace=True),
    )

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1), nn.LeakyReLU(0.2, inplace=True),
            conv_block(64, 64, 2), conv_block(64, 128, 1),
            conv_block(128, 128, 2), conv_block(128, 256, 1),
            conv_block(256, 256, 2), conv_block(256, 512, 1),
            conv_block(512, 512, 2), nn.AdaptiveAvgPool2d(1),
            nn.Flatten(), nn.Linear(512, 1024), nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1), nn.Sigmoid(),
        )
    def forward(self, x):
        return self.net(x)

In [9]:
#@title **3.4 VGG Feature Extractor**
class VGGFeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = vgg19(pretrained=True).features
        self.slice = nn.Sequential(*list(vgg)[:36])
        for p in self.slice.parameters():
            p.requires_grad = False
    def forward(self, x):
        return self.slice(x)

In [10]:
# ===== 4. Patch-based Dataset & DataLoader =====

#@title **4.1 Define SRPatchDataset & DataLoader**
import random
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor

class SRPatchDataset(Dataset):
    def __init__(self, hr_dir, patch_size=96, scale=4):
        super().__init__()
        self.hr_paths = sorted(glob.glob(f"{hr_dir}/*"))
        self.patch_size = patch_size
        self.scale = scale
        self.to_tensor = ToTensor()

    def __len__(self):
        return len(self.hr_paths)

    def __getitem__(self, idx):
        hr = Image.open(self.hr_paths[idx]).convert("RGB")
        w, h = hr.size
        # ensure random patch fits
        ps = self.patch_size
        if w < ps or h < ps:
            hr = hr.resize((max(ps,w), max(ps,h)), Image.BICUBIC)
            w, h = hr.size
        left = random.randint(0, w - ps)
        top  = random.randint(0, h - ps)
        hr_patch = hr.crop((left, top, left + ps, top + ps))
        lr_patch = hr_patch.resize((ps // self.scale, ps // self.scale), Image.BICUBIC)
        return self.to_tensor(lr_patch), self.to_tensor(hr_patch)

# instantiate dataset & loader
dataset = SRPatchDataset('data/DIV2K_train_HR', patch_size=96, scale=4)
loader  = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, pin_memory=True)
print(f"Dataset size: {len(dataset)} patches")

Dataset size: 800 patches


In [11]:
# ===== 5. Training Setup =====

#@title **5.1 Instantiate Models, Losses & Optimizers**
device = 'cuda' if torch.cuda.is_available() else 'cpu'
G = Generator().to(device)
D = Discriminator().to(device)
VGG = VGGFeatureExtractor().to(device)

import torch.optim as optim
mse = nn.MSELoss()
bce = nn.BCELoss()
optG = optim.Adam(G.parameters(), lr=1e-4)
optD = optim.Adam(D.parameters(), lr=1e-4)
t_real = lambda n: torch.ones((n,1), device=device)
t_fake = lambda n: torch.zeros((n,1), device=device)

In [12]:
# ===== 6. Phase 1: MSE Pre-training =====

#@title **6.1 MSE Pre-training Loop**
from tqdm import tqdm
import torch

epochs_pre = 10
for epoch in range(epochs_pre):
    loop = tqdm(loader, desc=f"Pretrain {epoch+1}/{epochs_pre}")
    for lr_img, hr_img in loop:
        lr_img, hr_img = lr_img.to(device), hr_img.to(device)
        optG.zero_grad()
        sr = G(lr_img)
        loss = mse(sr, hr_img)
        loss.backward()
        optG.step()
        loop.set_postfix(mse=loss.item())
os.makedirs('checkpoints', exist_ok=True)
torch.save(G.state_dict(), 'checkpoints/srgan_pretrained.pth')

Pretrain 1/10: 100%|██████████| 50/50 [04:28<00:00,  5.37s/it, mse=0.0318]
Pretrain 2/10: 100%|██████████| 50/50 [04:17<00:00,  5.14s/it, mse=0.0115]
Pretrain 3/10: 100%|██████████| 50/50 [04:18<00:00,  5.17s/it, mse=0.00912]
Pretrain 4/10: 100%|██████████| 50/50 [04:17<00:00,  5.14s/it, mse=0.00785]
Pretrain 5/10: 100%|██████████| 50/50 [04:18<00:00,  5.17s/it, mse=0.00554]
Pretrain 6/10: 100%|██████████| 50/50 [04:17<00:00,  5.16s/it, mse=0.00936]
Pretrain 7/10: 100%|██████████| 50/50 [04:19<00:00,  5.19s/it, mse=0.00519]
Pretrain 8/10: 100%|██████████| 50/50 [04:17<00:00,  5.15s/it, mse=0.00606]
Pretrain 9/10: 100%|██████████| 50/50 [04:17<00:00,  5.16s/it, mse=0.00716]
Pretrain 10/10: 100%|██████████| 50/50 [04:18<00:00,  5.18s/it, mse=0.00791]


In [13]:
# ===== 7. Phase 2: Adversarial Training =====

#@title **7.1 GAN Training Loop**
epochs_gan = 20
for epoch in range(epochs_gan):
    loop = tqdm(loader, desc=f"GAN {epoch+1}/{epochs_gan}")
    for lr_img, hr_img in loop:
        lr_img, hr_img = lr_img.to(device), hr_img.to(device)
        # Discriminator step
        optD.zero_grad()
        sr_det = G(lr_img).detach()
        lossD = 0.5 * (bce(D(hr_img), t_real(lr_img.size(0))) +
                       bce(D(sr_det), t_fake(lr_img.size(0))))
        lossD.backward()
        optD.step()
        # Generator step
        optG.zero_grad()
        sr = G(lr_img)
        content_loss = mse(VGG(sr), VGG(hr_img))
        adv_loss     = bce(D(sr), t_real(lr_img.size(0)))
        lossG = content_loss + 1e-3 * adv_loss
        lossG.backward()
        optG.step()
        loop.set_postfix(D=lossD.item(), G=lossG.item())
    torch.save(G.state_dict(), f'checkpoints/srgan_GAN_epoch{epoch+1}.pth')

GAN 1/20: 100%|██████████| 50/50 [16:09<00:00, 19.39s/it, D=0.0495, G=0.149]
GAN 2/20: 100%|██████████| 50/50 [16:09<00:00, 19.39s/it, D=0.0129, G=0.148]
GAN 3/20: 100%|██████████| 50/50 [16:11<00:00, 19.43s/it, D=0.0421, G=0.182]
GAN 4/20: 100%|██████████| 50/50 [16:10<00:00, 19.42s/it, D=0.0265, G=0.138]
GAN 5/20: 100%|██████████| 50/50 [16:09<00:00, 19.40s/it, D=0.0065, G=0.169]
GAN 6/20: 100%|██████████| 50/50 [16:21<00:00, 19.63s/it, D=0.013, G=0.156]
GAN 7/20: 100%|██████████| 50/50 [16:27<00:00, 19.74s/it, D=0.0077, G=0.149]
GAN 8/20: 100%|██████████| 50/50 [16:26<00:00, 19.73s/it, D=0.00407, G=0.205]
GAN 9/20: 100%|██████████| 50/50 [16:27<00:00, 19.76s/it, D=0.00785, G=0.148]
GAN 10/20: 100%|██████████| 50/50 [16:24<00:00, 19.69s/it, D=0.123, G=0.173]
GAN 11/20: 100%|██████████| 50/50 [16:11<00:00, 19.43s/it, D=0.00285, G=0.107]
GAN 12/20: 100%|██████████| 50/50 [16:16<00:00, 19.53s/it, D=0.000622, G=0.119]
GAN 13/20: 100%|██████████| 50/50 [16:13<00:00, 19.47s/it, D=0.00247, 

In [17]:
# ===== 8. Evaluation & Metrics =====
#@title **8. Evaluation & Metrics**
# 8.1 Identify & Load Checkpoint
import os, re
from torchvision.transforms import ToPILImage

device = 'cuda' if torch.cuda.is_available() else 'cpu'
ckpts = os.listdir('checkpoints')
print("Available checkpoints:", ckpts)

gan_ckpts = [f for f in ckpts if 'srgan_GAN_epoch' in f]
if gan_ckpts:
    # pick highest‐numbered epoch
    epochs = {int(re.search(r'epoch(\d+)', f).group(1)): f for f in gan_ckpts}
    best = epochs[max(epochs)]
    ckpt_path = os.path.join('checkpoints', best)
else:
    ckpt_path = os.path.join('checkpoints', 'srgan_pretrained.pth')

print("Loading checkpoint:", ckpt_path)
G.load_state_dict(torch.load(ckpt_path, map_location=device))
G.eval()
to_pil = ToPILImage()
os.makedirs('results', exist_ok=True)

# 8.2 Super-Resolve Sample Patches
for i in range(10):
    lr, _ = dataset[i]
    with torch.no_grad():
        sr = G(lr.unsqueeze(0).to(device))
    to_pil(sr.squeeze(0).cpu()).save(f'results/sample_{i}.png')
print("Saved samples to results/")

# 8.3 Compute PSNR & SSIM (with a fallback for small images)
import numpy as np
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import skimage.color as sc

psnr_vals, ssim_vals = [], []
for i in range(10):
    # load SR and HR
    sr = np.array(Image.open(f'results/sample_{i}.png')) / 255.0
    _, hr_tensor = dataset[i]
    hr = hr_tensor.permute(1, 2, 0).cpu().numpy()

    # PSNR
    psnr_vals.append(peak_signal_noise_ratio(hr, sr, data_range=1.0))

    # SSIM—prefer color, but fallback to grayscale if window too large
    try:
        ssim_vals.append(structural_similarity(hr, sr,
                                               channel_axis=2,
                                               data_range=1.0))
    except ValueError:
        hr_gray = sc.rgb2gray(hr)
        sr_gray = sc.rgb2gray(sr)
        ssim_vals.append(structural_similarity(hr_gray,
                                               sr_gray,
                                               data_range=1.0))

print(f"Avg PSNR: {np.mean(psnr_vals):.2f}, Avg SSIM: {np.mean(ssim_vals):.4f}")


Available checkpoints: ['srgan_GAN_epoch19.pth', 'srgan_GAN_epoch5.pth', 'srgan_GAN_epoch6.pth', 'srgan_GAN_epoch2.pth', 'srgan_GAN_epoch7.pth', 'srgan_GAN_epoch15.pth', 'srgan_GAN_epoch12.pth', 'srgan_GAN_epoch17.pth', 'srgan_GAN_epoch8.pth', 'srgan_GAN_epoch20.pth', 'srgan_GAN_epoch9.pth', 'srgan_GAN_epoch16.pth', 'srgan_GAN_epoch11.pth', 'srgan_GAN_epoch3.pth', 'srgan_GAN_epoch1.pth', 'srgan_pretrained.pth', 'srgan_GAN_epoch14.pth', 'srgan_GAN_epoch18.pth', 'srgan_GAN_epoch13.pth', 'srgan_GAN_epoch10.pth', 'srgan_GAN_epoch4.pth']
Loading checkpoint: checkpoints/srgan_GAN_epoch20.pth
Saved samples to results/
Avg PSNR: 6.76, Avg SSIM: 0.0334
