# Fundus Super-Resolution Project

This notebook contains two main parts: training and inference.

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models import vgg19
from tqdm import tqdm
from PIL import Image
import glob
import piq

# ===============================
# CONFIG
# ===============================
# Colab-only step: mount Google Drive, where the dataset folders below live.
from google.colab import drive
drive.mount('/content/drive')

# Input folders on the mounted drive (passed to FundusDataset as hr_dir / mask_dir).
HR_DIR   = "/content/drive/MyDrive/superRes/SR/SRimage"
MASK_DIR = "/content/drive/MyDrive/superRes/SR/Ground truth"

# Local working directory for checkpoints; created if it does not exist yet.
WORK_DIR = "/content/work_fundus_sr"
os.makedirs(WORK_DIR, exist_ok=True)

# --- data / batching ---
UPSCALE = 2          # super-resolution factor handed to the model and the dataset
PATCH_HR = 128       # HR patch size handed to the dataset
VAL_SPLIT = 0.1      # fraction of the sorted file list reserved for validation
BATCH_TRAIN = 2
BATCH_VAL   = 2
ACCUM_STEPS = 2      # presumably the gradient-accumulation factor
NUM_WORKERS = 2      # DataLoader worker processes
IMG_EXTS = (
    '.png', '.jpg', '.jpeg',
    '.bmp', '.tif', '.tiff',
)

# --- perceptual branch ---
PERC_RESIZE = 224    # square size inputs are resized to before the VGG features
PERC_DEVICE = 'cpu'  # device the perceptual network is placed on

# --- schedule ---
EPOCHS = 50
EARLY_STOP = 10      # presumably early-stopping patience, in epochs
WARMUP_EPOCHS = 5    # presumably learning-rate warm-up length, in epochs

USE_AMP = True
# Train on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Scale x{UPSCALE} | HRpatch={PATCH_HR} | batch={BATCH_TRAIN} | perc@{PERC_RESIZE} on {PERC_DEVICE}")

# ===============================
# MODEL (EDSR Lite)
# ===============================
class ResBlock(nn.Module):
    """Residual block: conv3x3 -> ReLU -> conv3x3, added back to the input.

    The residual branch is multiplied by ``res_scale`` before the addition.

    Args:
        c: number of input and output channels (the block keeps the channel
            count and, thanks to ``padding=1``, the spatial size).
        res_scale: multiplier applied to the residual branch. Defaults to
            0.1, the value that was previously hard-coded.
    """

    def __init__(self, c, res_scale=0.1):
        super().__init__()
        self.res_scale = res_scale
        # Kept as an nn.Sequential named `body` so state-dict keys
        # (body.0.*, body.2.*) stay compatible with existing checkpoints.
        self.body = nn.Sequential(
            nn.Conv2d(c, c, 3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(c, c, 3, padding=1),
        )

    def forward(self, x):
        """Return ``x + res_scale * body(x)``."""
        return x + self.body(x) * self.res_scale

class EDSR_Lite(nn.Module):
    """Small EDSR-style super-resolution network.

    Layout: a 3x3 head convolution (3 -> ``c`` channels), ``n_res`` residual
    blocks wrapped by a skip connection, then a tail that expands channels by
    ``scale * scale``, rearranges them with ``PixelShuffle(scale)`` and
    projects back to 3 channels.

    Args:
        scale: upsampling factor passed to ``nn.PixelShuffle``.
        n_res: number of ``ResBlock`` modules in the body.
        c: feature channel count used throughout the network.
    """

    def __init__(self, scale=2, n_res=8, c=64):
        super().__init__()
        # Head: lift the 3-channel input into feature space.
        self.head = nn.Conv2d(3, c, kernel_size=3, padding=1)

        # Body: a plain stack of residual blocks.
        blocks = []
        for _ in range(n_res):
            blocks.append(ResBlock(c))
        self.body = nn.Sequential(*blocks)

        # Tail: channel expansion -> pixel shuffle -> 3-channel projection.
        expand = nn.Conv2d(c, c * scale * scale, kernel_size=3, padding=1)
        shuffle = nn.PixelShuffle(scale)
        project = nn.Conv2d(c, 3, kernel_size=3, padding=1)
        self.tail = nn.Sequential(expand, shuffle, project)

    def forward(self, x):
        """Run head, body (with a skip over the whole body), then tail."""
        feat = self.head(x)
        feat = feat + self.body(feat)
        return self.tail(feat)

# ===============================
# DATASET
# ===============================
class FundusDataset(Dataset):
    """Paired image / mask dataset that synthesises the LR input on the fly.

    Files in ``hr_dir`` and ``mask_dir`` are each sorted by path and paired
    by position. The pairing is NOT verified by file name, so both folders
    are assumed to list the same samples in the same order.

    The first ``int(N * split)`` pairs form the validation subset; the
    remaining pairs form the training subset.

    Args:
        hr_dir: folder containing the high-resolution images.
        mask_dir: folder containing one mask per image.
        patch_size: side length of the square HR patch that is randomly
            cropped in training mode (rounded down to a multiple of
            ``upscale``). A falsy value disables patch cropping. Images
            smaller than the patch are used whole.
        upscale: integer factor between the LR and HR resolutions.
        train: select the training subset (and random patch cropping) when
            True, the validation subset (full images) when False.
        split: fraction of the pairs reserved for validation.

    Raises:
        ValueError: if the two folders hold a different number of image
            files, or (on item access) if an image and its mask differ in
            size.
    """

    def __init__(self, hr_dir, mask_dir, patch_size=128, upscale=2, train=True, split=0.1):
        # Case-insensitive extension match so that e.g. ".JPG" is not skipped.
        self.hr_files = sorted(
            f for f in glob.glob(os.path.join(hr_dir, "*")) if f.lower().endswith(IMG_EXTS)
        )
        self.mask_files = sorted(
            f for f in glob.glob(os.path.join(mask_dir, "*")) if f.lower().endswith(IMG_EXTS)
        )
        # Positional pairing silently mis-aligns if the counts differ.
        if len(self.hr_files) != len(self.mask_files):
            raise ValueError(
                f"Found {len(self.hr_files)} images in {hr_dir!r} but "
                f"{len(self.mask_files)} masks in {mask_dir!r}"
            )
        n_val = int(len(self.hr_files) * split)
        if train:
            self.hr_files = self.hr_files[n_val:]
            self.mask_files = self.mask_files[n_val:]
        else:
            self.hr_files = self.hr_files[:n_val]
            self.mask_files = self.mask_files[:n_val]
        self.patch_size = patch_size
        self.upscale = upscale
        self.train = train
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.hr_files)

    def __getitem__(self, idx):
        """Return ``(lr, hr, mask, file_name)`` for sample ``idx``.

        ``lr`` is ``hr`` downsampled bicubically by ``upscale``; ``file_name``
        is the base name of the HR image file.
        """
        hr_path = self.hr_files[idx]
        mask_path = self.mask_files[idx]
        # convert() returns an independent image, so the files can be closed.
        with Image.open(hr_path) as img:
            hr = img.convert("RGB")
        with Image.open(mask_path) as img:
            mask = img.convert("L")
        if mask.size != hr.size:
            raise ValueError(
                f"Size mismatch: {hr_path} is {hr.size} but {mask_path} is {mask.size}"
            )

        s = self.upscale
        width, height = hr.size
        # Trim to a multiple of the scale so that LR * upscale == HR exactly;
        # otherwise the network output and the HR target differ in shape.
        crop_w = width - width % s
        crop_h = height - height % s
        left = top = 0
        if self.train and self.patch_size:
            patch = self.patch_size - self.patch_size % s
            crop_w = min(patch, crop_w)
            crop_h = min(patch, crop_h)
            # Random patch position; the same box is applied to image and mask.
            left = int(torch.randint(0, width - crop_w + 1, (1,)).item())
            top = int(torch.randint(0, height - crop_h + 1, (1,)).item())
        box = (left, top, left + crop_w, top + crop_h)

        hr = self.to_tensor(hr.crop(box))
        mask = self.to_tensor(mask.crop(box))
        # An explicit integer target size avoids float rounding of the scale.
        lr = F.interpolate(
            hr.unsqueeze(0),
            size=(crop_h // s, crop_w // s),
            mode="bicubic",
            align_corners=False,
        ).squeeze(0)
        return lr, hr, mask, os.path.basename(hr_path)

# Both subsets read the same folders; the `train` flag picks the side of the split.
_ds_common = dict(patch_size=PATCH_HR, upscale=UPSCALE, split=VAL_SPLIT)
train_ds = FundusDataset(HR_DIR, MASK_DIR, train=True, **_ds_common)
val_ds = FundusDataset(HR_DIR, MASK_DIR, train=False, **_ds_common)

# Only the training loader shuffles; worker count and pinning are shared.
_dl_common = dict(num_workers=NUM_WORKERS, pin_memory=True)
train_dl = DataLoader(train_ds, batch_size=BATCH_TRAIN, shuffle=True, **_dl_common)
val_dl = DataLoader(val_ds, batch_size=BATCH_VAL, shuffle=False, **_dl_common)

# ===============================
# PERCEPTUAL NETWORK
# ===============================
class VGGPerceptual(nn.Module):
    """Frozen VGG-19 feature extractor for the perceptual loss.

    Uses the first 16 layers of torchvision's VGG-19 ``features`` stack with
    ImageNet weights, in eval mode and with gradients disabled for its
    parameters (gradients still flow to the input).

    Inputs are assumed to be RGB batches in [0, 1] (the training loop clamps
    SR output to that range). They are normalised with the ImageNet mean/std
    the pretrained weights expect before being passed to VGG.

    Args:
        resize: when True, inputs are bilinearly resized to
            ``PERC_RESIZE x PERC_RESIZE`` first.
    """

    def __init__(self, resize=True):
        super().__init__()
        # Prefer the `weights=` API; fall back to the legacy flag on
        # torchvision versions that do not provide VGG19_Weights.
        try:
            from torchvision.models import VGG19_Weights
        except ImportError:
            backbone = vgg19(pretrained=True)
        else:
            backbone = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
        vgg = backbone.features[:16].eval()
        for p in vgg.parameters():
            p.requires_grad = False
        self.vgg = vgg
        self.resize = resize
        # Non-persistent buffers: they follow .to(device) but do not add keys
        # to the state dict.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )

    def forward(self, x):
        """Return VGG features of ``x`` (optionally resized, then normalised)."""
        if self.resize:
            x = F.interpolate(x, size=(PERC_RESIZE, PERC_RESIZE), mode="bilinear", align_corners=False)
        x = (x - self.mean) / self.std
        return self.vgg(x)

# Perceptual feature extractor, placed on PERC_DEVICE (which may differ from
# the training device).
perc_net = VGGPerceptual()
perc_net = perc_net.to(PERC_DEVICE)

def perceptual_loss(net, sr, hr):
    """Mean absolute difference between ``net``'s features of ``sr`` and ``hr``.

    ``net`` is any callable mapping a tensor to a feature tensor; it is
    applied to ``sr`` first, then to ``hr``.
    """
    feats_sr, feats_hr = net(sr), net(hr)
    return F.l1_loss(feats_sr, feats_hr, reduction="mean")

def mask_weighted_l1(sr, hr, mask):
    """L1 error between ``sr`` and ``hr``, weighted element-wise by ``mask``.

    The mean is taken over every element of the (broadcast) product, not
    over the mask's non-zero area.
    """
    abs_err = (sr - hr).abs()
    return torch.mean(abs_err * mask)

def tv_loss(x):
    """Total-variation penalty for a 4-D tensor.

    Sum of the mean absolute difference between neighbours along the last
    axis and the mean absolute difference between neighbours along the
    second-to-last axis.
    """
    along_last = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean()
    along_prev = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean()
    return along_last + along_prev

# NOTE: despite the variable name, piq.SSIMLoss is a loss module; per the piq
# documentation it returns 1 - SSIM, so lower is better.
ssim_metric = piq.SSIMLoss(data_range=1.0)
ssim_metric = ssim_metric.to(device)

# ===============================
# TRAINING
# ===============================
model = EDSR_Lite(scale=UPSCALE, n_res=8, c=64).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
# NOTE(review): WARMUP_EPOCHS is defined in the config but no warm-up is
# applied here; the schedule is plain cosine annealing over EPOCHS.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# torch.cuda.amp autocast / GradScaler only apply on CUDA, so mixed precision
# is switched off when training falls back to the CPU.
amp_enabled = USE_AMP and device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)

# Number of micro-batches whose gradients are summed per optimizer step.
accum_steps = max(1, int(ACCUM_STEPS))

BEST_PSNR = -1.0
SAVE_PATH = os.path.join(WORK_DIR, "best_model.pt")


def _combined_loss(sr, hr, mask):
    """Return ``(total_loss, ssim_loss)`` for one batch.

    total = 0.6 * masked L1 + 0.3 * perceptual + 0.1 * SSIM loss + 0.05 * TV.
    ``ssim_loss`` is the value of ``ssim_metric`` (piq.SSIMLoss, i.e.
    1 - SSIM), so SSIM itself is ``1 - ssim_loss``.
    """
    l1 = mask_weighted_l1(sr, hr, mask)
    # The perceptual net lives on PERC_DEVICE; move its scalar result back so
    # every loss term sits on the training device.
    perc = perceptual_loss(
        perc_net, sr.to(PERC_DEVICE).float(), hr.to(PERC_DEVICE).float()
    ).to(sr.device)
    # SSIMLoss is already a loss (1 - SSIM): use it directly. Subtracting it
    # from 1 would reward *lower* structural similarity.
    ssim_l = ssim_metric(sr.float(), hr.float())
    tv = tv_loss(sr)
    total = 0.6 * l1 + 0.3 * perc + 0.1 * ssim_l + 0.05 * tv
    return total, ssim_l


print("==> Start Training on", device, "(AMP=", amp_enabled, ")")
epochs_without_gain = 0
n_train_batches = len(train_dl)
for ep in range(1, EPOCHS+1):
    model.train()
    tr_loss = 0.0
    optimizer.zero_grad(set_to_none=True)
    pbar = tqdm(train_dl, desc=f"Epoch {ep}/{EPOCHS}")
    for step, (lr_t, hr_t, m_t, _) in enumerate(pbar, start=1):
        lr_t, hr_t, m_t = lr_t.to(device), hr_t.to(device), m_t.to(device)
        with torch.cuda.amp.autocast(enabled=amp_enabled):
            sr = model(lr_t).clamp(0,1)
            loss, _ = _combined_loss(sr, hr_t, m_t)
        # Divide so the accumulated gradient is an average over micro-batches.
        scaler.scale(loss / accum_steps).backward()
        # Step every `accum_steps` micro-batches, and once more at the end of
        # the epoch so a trailing partial group is not dropped.
        if step % accum_steps == 0 or step == n_train_batches:
            # Gradients must be unscaled before clipping, otherwise the norm
            # threshold is compared against loss-scaled values.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
        tr_loss += loss.item()
        pbar.set_postfix(loss=f"{loss.item():.4f}")
    scheduler.step()
    print(f"[TRAIN] loss={tr_loss / max(1, n_train_batches):.4f}")

    # Validation
    model.eval()
    v_psnr, v_ssim, v_loss, n = 0., 0., 0., 0
    with torch.no_grad():
        for lr_t, hr_t, m_t, _ in val_dl:
            lr_t, hr_t, m_t = lr_t.to(device), hr_t.to(device), m_t.to(device)
            with torch.cuda.amp.autocast(enabled=amp_enabled):
                sr = model(lr_t).clamp(0,1)
                loss, ssim_l = _combined_loss(sr, hr_t, m_t)
            # Weight by batch size so a smaller last batch does not skew the
            # per-sample averages.
            bs = hr_t.size(0)
            v_loss += loss.item() * bs
            v_psnr += piq.psnr(sr.float(), hr_t.float(), data_range=1.).item() * bs
            v_ssim += (1.0 - ssim_l.item()) * bs
            n += bs
    if n == 0:
        # E.g. fewer than 1 / VAL_SPLIT images: nothing to validate on.
        print("[VAL] skipped: validation set is empty")
        continue
    v_loss/=n; v_psnr/=n; v_ssim/=n
    print(f"[VAL] loss={v_loss:.4f} | PSNR={v_psnr:.2f} | SSIM={v_ssim:.3f}")
    if v_psnr > BEST_PSNR:
        BEST_PSNR = v_psnr
        epochs_without_gain = 0
        torch.save({"model": model.state_dict(), "cfg": {"scale": UPSCALE}}, SAVE_PATH)
        print(f"  Saved best @ PSNR {BEST_PSNR:.2f}")
    else:
        epochs_without_gain += 1
        # EARLY_STOP is the patience in epochs; a falsy value disables it.
        if EARLY_STOP and epochs_without_gain >= EARLY_STOP:
            print(f"  Early stop: no PSNR improvement for {epochs_without_gain} epochs")
            break

# Save final model
FINAL_PATH = os.path.join(WORK_DIR, "final_model.pt")
torch.save({"model": model.state_dict(), "cfg": {"scale": UPSCALE}}, FINAL_PATH)
print("✅ Model saved to:", FINAL_PATH)
