In [1]:
import os
import time
import math
import copy
import numpy as np
from tqdm import tqdm
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch.nn.functional as F

import random
from typing import List, Tuple
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt

from common import *
from dataset import *
from model import *

In [2]:
steps = 100000
batch_size = 128
scale = 4

In [3]:
DATASET_PATH = f"{os.getcwd()}/ambient-cg-images"
DATASET_PATH

'C:\\Projects\\Texture-SuperResolution/ambient-cg-images'

In [4]:
def _list_images(root: str, exts=(".png", ".jpg", ".jpeg", ".bmp", ".webp")) -> List[str]:
    return [os.path.join(root, f) for f in os.listdir(root) if f.lower().endswith(exts)]

def _ensure_divisible_by(img: Image.Image, s: int) -> Image.Image:
    w, h = img.size
    w2, h2 = (w // s) * s, (h // s) * s
    return img if (w2, h2) == (w, h) else img.crop((0, 0, w2, h2))

def _to_tensor(img: Image.Image) -> torch.Tensor:
    arr = np.asarray(img).astype(np.float32) / 255.0  # H,W,3
    arr = np.transpose(arr, (2, 0, 1))               # 3,H,W
    return torch.from_numpy(arr)

class AmbientCG_FSRCNN_Dataset(Dataset):
    """
    FSRCNN dataset producing (LR, HR, name)
    - ORIGINAL (≈2K) --(crop to divisible by 2*scale)-->  downscale by 2  --> HR (≈1K)
    - LR = downsample(HR) by `scale`
    - mode='patch': random HR patch (size divisible by scale) from HR, then LR from that patch
    - mode='full' : full HR image
    """
    def __init__(
        self,
        file_list: List[str],
        scale: int = 4,
        mode: str = "patch",
        hr_patch_size: int = 128,
        patches_per_image: int = 16,
        augment: bool = True,
        seed: int = 42,
    ):
        assert scale in (2,3,4)
        assert mode in ("patch", "full")
        if mode == "patch":
            assert hr_patch_size % scale == 0, "hr_patch_size must be divisible by scale"
        self.files = file_list
        self.scale = scale
        self.mode = mode
        self.hr_patch_size = hr_patch_size
        self.patches_per_image = patches_per_image
        self.augment = augment

        # deterministic index map for patch mode
        self.idx_map: List[Tuple[int,int]] = []
        if mode == "patch":
            for i in range(len(self.files)):
                for k in range(self.patches_per_image):
                    self.idx_map.append((i, k))
            random.Random(seed).shuffle(self.idx_map)

    def __len__(self):
        return len(self.idx_map) if self.mode == "patch" else len(self.files)

    @staticmethod
    def _rand_patch(img: Image.Image, size: int) -> Image.Image:
        w, h = img.size
        if w < size or h < size:
            img = img.resize((max(w, size), max(h, size)), Image.Resampling.BICUBIC)
            w, h = img.size
        x = random.randint(0, w - size); y = random.randint(0, h - size)
        return img.crop((x, y, x + size, y + size))

    @staticmethod
    def _augment_flip_rot(img: Image.Image) -> Image.Image:
        if random.random() < 0.5: img = img.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
        if random.random() < 0.5: img = img.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
        k = random.randint(0,3)
        if k: img = img.rotate(90*k, expand=False)
        return img

    def __getitem__(self, idx):
        # ---- 1) Load ORIGINAL (~2K) and ensure divisibility by 2*scale ----
        if self.mode == "patch":
            img_idx, _ = self.idx_map[idx]
            path = self.files[img_idx]
        else:
            path = self.files[idx]

        orig = Image.open(path).convert("RGB")
        orig = _ensure_divisible_by(orig, 2 * self.scale)  # so HR (orig/2) is divisible by scale

        # ---- 2) Downscale ORIGINAL by 2 to form HR (~1K) ----
        ow, oh = orig.size
        hr_size = (ow // 2, oh // 2)
        hr = orig.resize(hr_size, Image.Resampling.BICUBIC)

        # (safety) ensure HR divisible by scale
        hr = _ensure_divisible_by(hr, self.scale)

        # ---- 3) Patch/augment ON HR (not on the original) ----
        if self.mode == "patch":
            hr = self._rand_patch(hr, self.hr_patch_size)  # size divisible by scale
            if self.augment:
                hr = self._augment_flip_rot(hr)

        # ---- 4) Create LR by downsampling HR by `scale` ----
        hr_w, hr_h = hr.size
        lr = hr.resize((hr_w // self.scale, hr_h // self.scale), Image.Resampling.BICUBIC)

        return _to_tensor(lr), _to_tensor(hr), os.path.basename(path)                   

all_files = _list_images(DATASET_PATH)
all_files = all_files[0:200]
print(f"Found {len(all_files)} images")

# train/val split (disjoint by file)
train_files, val_files = train_test_split(all_files, test_size=0.2, random_state=42, shuffle=True)

Found 200 images


In [5]:
all_files[0]

'C:\\Projects\\Texture-SuperResolution/ambient-cg-images\\Asphalt004_2K-JPG_Color.jpg'

In [6]:
train_ds = AmbientCG_FSRCNN_Dataset(
    file_list=train_files,
    scale=scale,
    mode="patch",
    hr_patch_size=128,       # divisible by SCALE
    patches_per_image=16,
    augment=True,
)

val_ds = AmbientCG_FSRCNN_Dataset(
    file_list=val_files,
    scale=scale,
    mode="full",             # full frames for validation; batch_size must be 1
    augment=False,
)


In [7]:
len(train_ds), len(val_ds)

(2560, 40)

In [8]:
#-----------------------
# Config
# -----------------------
SCALE        = 4                                # MUST match your dataset/model
BATCH_SIZE   = 16
NUM_EPOCHS   = 40
LR           = 1e-4
STEP_EVERY   = 15
GAMMA        = 0.5
NUM_WORKERS  = 0                                # Windows/Jupyter-safe
PIN_MEMORY   = False
BORDER_CROP  = 4                                # shave borders for PSNR(Y)
CKPT_PATH    = f"fsrcnn_x{SCALE}_best.pth"
PRETRAINED   = f"pretrained_fsrcnn_x{SCALE}.pth" # put the repo weights here if you have them

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# -----------------------
# DataLoaders (use your already-created datasets)
#   train_ds: patch mode
#   val_ds:   full mode (batch 1)
# -----------------------
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds,   batch_size=1,           shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# -----------------------
# Helpers
# -----------------------
def rgb_to_y(t: torch.Tensor) -> torch.Tensor:
    # [B,3,H,W] -> [B,1,H,W]
    r, g, b = t[:,0:1], t[:,1:2], t[:,2:3]
    return 0.299*r + 0.587*g + 0.114*b

@torch.no_grad()
def batch_psnr_y(sr: torch.Tensor, hr: torch.Tensor, shave: int = 0) -> float:
    # sr, hr: [B,3,H,W] in [0,1]; compute PSNR on Y channel with border crop
    sr_y = rgb_to_y(sr)
    hr_y = rgb_to_y(hr)
    if shave > 0:
        sr_y = sr_y[..., shave:-shave, shave:-shave]
        hr_y = hr_y[..., shave:-shave, shave:-shave]
    mse = torch.mean((sr_y - hr_y) ** 2, dim=(1,2,3))  # per-image
    psnr = 20.0 * torch.log10(torch.tensor(1.0, device=sr.device)) - 10.0 * torch.log10(mse.clamp_min(1e-10))
    return psnr.mean().item()

@torch.no_grad()
def batch_psnr_y_bicubic(lr: torch.Tensor, hr: torch.Tensor, shave: int = 0, scale: int = SCALE) -> float:
    # upsample LR to HR size with bicubic for baseline PSNR(Y)
    lr_up = F.interpolate(lr, size=hr.shape[-2:], mode="bicubic", align_corners=False)
    return batch_psnr_y(lr_up, hr, shave)

# -----------------------
# Model / Optim / Loss
# -----------------------
model = FSRCNN_model(scale=SCALE).to(device)

# (Optional) load pretrained repo weights if available
if os.path.exists(PRETRAINED):
    ckpt = torch.load(PRETRAINED, map_location=device)
    state = ckpt.get("model_state_dict", ckpt)  # support plain state_dict too
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f"Loaded pretrained: {PRETRAINED} | missing:{len(missing)} unexpected:{len(unexpected)}")
else:
    print("No pretrained weights provided; training from scratch.")

criterion = nn.L1Loss()                          # L1 works well for textures
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_EVERY, gamma=GAMMA)
use_amp = (device.type == "cuda")
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

best_val_psnr = -1e9

# -----------------------
# Train / Val loops
# -----------------------
for epoch in range(1, NUM_EPOCHS + 1):
    print("------------------------------------------------------------")
    model.train()
    t0 = time.time()
    train_loss = 0.0

    for lr, hr, _ in train_loader:
        lr = lr.to(device, non_blocking=True)   # [B,3,h,w]
        hr = hr.to(device, non_blocking=True)   # [B,3,H,W]

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            sr = model(lr)                      # -> [B,3,H,W] (FSRCNN upsamples internally)
            loss = criterion(sr, hr)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()

    scheduler.step()
    train_loss /= max(1, len(train_loader))

    # ---------- Validation ----------
    model.eval()
    val_loss = 0.0
    val_psnr_sr  = 0.0
    val_psnr_bic = 0.0
    with torch.no_grad():
        for lr, hr, _ in val_loader:
            lr = lr.to(device, non_blocking=True)
            hr = hr.to(device, non_blocking=True)
            sr = model(lr)

            val_loss += criterion(sr, hr).item()
            val_psnr_sr  += batch_psnr_y(sr, hr, shave=BORDER_CROP)
            val_psnr_bic += batch_psnr_y_bicubic(lr, hr, shave=BORDER_CROP, scale=SCALE)

    n_val = max(1, len(val_loader))
    val_loss   /= n_val
    val_psnr_sr  /= n_val
    val_psnr_bic /= n_val

    # Save best
    if val_psnr_sr > best_val_psnr:
        best_val_psnr = val_psnr_sr
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "val_psnr_y": best_val_psnr,
            "scale": SCALE
        }, CKPT_PATH)

    dt = time.time() - t0
    print(f"Epoch {epoch:03d}/{NUM_EPOCHS} | "
          f"train_loss={train_loss:.6f} | "
          f"val_loss={val_loss:.6f} | "
          f"PSNR(Y) srcnn={val_psnr_sr:.2f} dB | "
          f"PSNR(Y) bic={val_psnr_bic:.2f} dB | "
          f"lr={scheduler.get_last_lr()[0]:.1e} | {dt:.1f}s")

print(f"Best val PSNR(Y): {best_val_psnr:.2f} dB  (saved → {CKPT_PATH})")

Device: cuda
No pretrained weights provided; training from scratch.
------------------------------------------------------------


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
  with torch.cuda.amp.autocast(enabled=use_amp):


Epoch 001/40 | train_loss=0.110317 | val_loss=0.048822 | PSNR(Y) srcnn=26.17 dB | PSNR(Y) bic=30.72 dB | lr=1.0e-04 | 290.6s
------------------------------------------------------------
Epoch 002/40 | train_loss=0.052258 | val_loss=0.040825 | PSNR(Y) srcnn=27.22 dB | PSNR(Y) bic=30.72 dB | lr=1.0e-04 | 294.9s
------------------------------------------------------------
Epoch 003/40 | train_loss=0.044468 | val_loss=0.039207 | PSNR(Y) srcnn=27.63 dB | PSNR(Y) bic=30.72 dB | lr=1.0e-04 | 301.3s
------------------------------------------------------------
Epoch 004/40 | train_loss=0.042275 | val_loss=0.036197 | PSNR(Y) srcnn=28.10 dB | PSNR(Y) bic=30.72 dB | lr=1.0e-04 | 941.2s
------------------------------------------------------------
Epoch 005/40 | train_loss=0.040292 | val_loss=0.034240 | PSNR(Y) srcnn=28.51 dB | PSNR(Y) bic=30.72 dB | lr=1.0e-04 | 320.9s
------------------------------------------------------------
Epoch 006/40 | train_loss=0.038307 | val_loss=0.032124 | PSNR(Y) srcnn

KeyboardInterrupt: 