

Eitan Spivak
311391866




1️⃣ Install and Import Packages

In [None]:
# Install necessary packages (only for Colab/first run)
!pip install --upgrade --force-reinstall torchmetrics[image] torch-fidelity scikit-image

import os
import random
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt


2️⃣ Set Up Paths and Mount Drive

In [None]:
# Mount Google Drive when running on Colab. Outside Colab the import fails, the
# mount is skipped, and BASE_DIR can point at a local copy via SR_BASE_DIR.
try:
    from google.colab import drive
except ImportError:  # not running on Colab
    drive = None
if drive is not None:
    drive.mount('/content/drive')

# Define data/model/result directories
BASE_DIR = os.environ.get("SR_BASE_DIR", "/content/drive/MyDrive/super_resolution")
TRAIN_HR = os.path.join(BASE_DIR, "DIV2K_train_HR_cropped")
TRAIN_LR = os.path.join(BASE_DIR, "DIV2K_train_LR_bicubic_X4")
VALID_HR = os.path.join(BASE_DIR, "DIV2K_valid_HR_cropped")
VALID_LR = os.path.join(BASE_DIR, "DIV2K_valid_LR_bicubic_X4")
MODEL_DIR = os.path.join(BASE_DIR, "models")
RESULTS_DIR = os.path.join(BASE_DIR, "results")
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Seed Python, NumPy and PyTorch RNGs so weight initialisation and DataLoader
# shuffling repeat across runs (GPU/cuDNN kernels may still be non-deterministic).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


3️⃣ Dataset and DataLoader

In [None]:
class PairedImageDataset(Dataset):
    """Low-/high-resolution image pairs matched by identical filename.

    Every ``.png`` in ``lr_dir`` must have a file with the same name in
    ``hr_dir``. Each item is a 3-tuple produced by ``transform``:

    * the LR image as loaded,
    * the LR image bicubic-resized to the HR image's size,
    * the HR image (target).

    Args:
        lr_dir: directory of low-resolution PNG images.
        hr_dir: directory of high-resolution images with matching filenames.
        transform: callable applied to each PIL image; defaults to
            ``transforms.ToTensor()``.

    Raises:
        FileNotFoundError: if any LR file has no same-named HR file. This is
            checked once at construction instead of failing mid-epoch.
    """

    def __init__(self, lr_dir, hr_dir, transform=None):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.filenames = sorted(f for f in os.listdir(lr_dir) if f.endswith('.png'))
        missing = [f for f in self.filenames
                   if not os.path.isfile(os.path.join(hr_dir, f))]
        if missing:
            raise FileNotFoundError(
                f"{len(missing)} LR image(s) in {lr_dir!r} have no same-named HR "
                f"image in {hr_dir!r}, e.g. {missing[:3]}"
            )
        self.transform = transform or transforms.ToTensor()

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file handle."""
        # convert() returns a fully loaded image that does not depend on the file.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        fname = self.filenames[idx]
        lr = self._load_rgb(os.path.join(self.lr_dir, fname))
        hr = self._load_rgb(os.path.join(self.hr_dir, fname))
        lr_up = lr.resize(hr.size, Image.BICUBIC)
        return self.transform(lr), self.transform(lr_up), self.transform(hr)

# Training batches hold BATCH_SIZE shuffled pairs; validation walks the
# images in a fixed order, one image per batch.
BATCH_SIZE = 4

train_dataset = PairedImageDataset(TRAIN_LR, TRAIN_HR)
valid_dataset = PairedImageDataset(VALID_LR, VALID_HR)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
)


4️⃣ Model Definitions: SRCNN and ImprovedSRCNN

In [None]:
class SRCNN(nn.Module):
    """Three-layer SRCNN: 9x9, 5x5 and 5x5 convolutions with 64 and 32 feature maps.

    Every convolution is padded to preserve height and width, so the input
    must already be at the target resolution. ReLU follows the first two
    layers; the last layer's output is returned without an activation.
    """

    def __init__(self):
        super().__init__()
        # These attribute names become the state_dict keys of saved checkpoints.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=5, padding=2)

    def forward(self, x):
        features = torch.relu(self.conv1(x))
        mapped = torch.relu(self.conv2(features))
        return self.conv3(mapped)

class ImprovedSRCNN(nn.Module):
    """Residual SRCNN variant that upsamples its low-resolution input itself.

    The input is bicubic-upsampled by ``scale_factor``. A five-layer conv body
    with PReLU activations predicts a residual, which is added back to the
    upsampled image.

    Args:
        scale_factor: spatial upscaling factor. Defaults to 4, matching the X4
            data used in this notebook. ``nn.Upsample`` has no learnable
            parameters, so the value does not affect state_dict keys.
    """

    def __init__(self, scale_factor=4):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bicubic', align_corners=False)
        # Layer order defines the checkpoint keys ("body.0.weight", ...); keep it.
        self.body = nn.Sequential(
            nn.Conv2d(3, 64, 9, padding=4), nn.PReLU(),
            nn.Conv2d(64, 64, 5, padding=2), nn.PReLU(),
            nn.Conv2d(64, 32, 5, padding=2), nn.PReLU(),
            nn.Conv2d(32, 16, 3, padding=1), nn.PReLU(),
            nn.Conv2d(16, 3, 3, padding=1)
        )

    def forward(self, x):
        upsampled = self.upsample(x)
        # Global residual: the body only has to learn the detail missing from bicubic.
        return upsampled + self.body(upsampled)


5️⃣ Training Functions

In [None]:
def calculate_psnr(mse, max_val=1.0, perfect_psnr=100.0):
    """Peak signal-to-noise ratio in dB: ``10 * log10(max_val**2 / mse)``.

    Args:
        mse: mean squared error between prediction and target.
        max_val: peak signal value; 1.0 for images scaled to [0, 1].
        perfect_psnr: value returned when ``mse <= 0``, where PSNR would be
            infinite. A finite value keeps averages over images finite.

    Returns:
        PSNR as a float. A NaN ``mse`` (e.g. from diverged training) yields
        NaN instead of being reported as a perfect score.
    """
    mse = float(mse)
    if np.isnan(mse):
        return float('nan')
    if mse <= 0:
        return float(perfect_psnr)
    return float(10.0 * np.log10(max_val ** 2 / mse))

def train_model(model, train_loader, valid_loader, save_path, epochs=5, lr=1e-4,
                device='cuda', vanilla=True, weight_decay=0.0):
    """Train ``model`` with MSE loss and Adam, reporting validation PSNR each epoch.

    Args:
        model: network to train; moved to ``device``.
        train_loader, valid_loader: yield ``(lr_img, lr_up_img, hr_img)`` batches.
        save_path: file the final ``state_dict`` is written to after the last epoch.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: torch device string.
        vanilla: if True, feed the pre-upsampled ``lr_up_img``; otherwise feed
            the raw LR image (for models that upsample internally).
        weight_decay: Adam ``weight_decay`` (L2 penalty); 0 disables it.

    Returns:
        ``(train_losses, valid_psnrs)``: per-epoch mean training loss and mean
        validation PSNR in dB. Both curves are also plotted.

    Raises:
        ValueError: if either loader is empty, since the per-epoch averages
            would otherwise divide by zero.
    """
    if len(train_loader) == 0 or len(valid_loader) == 0:
        raise ValueError("train_loader and valid_loader must be non-empty")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.MSELoss()
    tag = 'Vanilla' if vanilla else 'Improved'
    train_losses, valid_psnrs = [], []
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for lr_img, lr_up_img, hr_img in train_loader:
            inp = (lr_up_img if vanilla else lr_img).to(device)
            hr_img = hr_img.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inp), hr_img)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_loss = running_loss / len(train_loader)
        train_losses.append(avg_loss)

        # Validation: clamp outputs to [0, 1] before scoring, as the final
        # evaluation does, so both PSNR numbers are comparable.
        model.eval()
        val_psnr = 0.0
        with torch.no_grad():
            for lr_img, lr_up_img, hr_img in valid_loader:
                inp = (lr_up_img if vanilla else lr_img).to(device)
                hr_img = hr_img.to(device)
                sr = model(inp).clamp(0, 1)
                mse = ((sr - hr_img) ** 2).mean().item()
                val_psnr += calculate_psnr(mse)
        val_psnr /= len(valid_loader)
        valid_psnrs.append(val_psnr)
        print(f"[{tag}] Epoch {epoch+1}: Loss={avg_loss:.6f}, Val PSNR={val_psnr:.2f}")

    torch.save(model.state_dict(), save_path)

    fig, (ax_loss, ax_psnr) = plt.subplots(1, 2, figsize=(12, 4))
    epoch_axis = range(1, len(train_losses) + 1)
    ax_loss.plot(epoch_axis, train_losses, label='Train Loss')
    ax_loss.set(title=f'{tag}: training loss', xlabel='Epoch', ylabel='MSE')
    ax_loss.legend()
    ax_psnr.plot(epoch_axis, valid_psnrs, label='Val PSNR')
    ax_psnr.set(title=f'{tag}: validation PSNR', xlabel='Epoch', ylabel='PSNR (dB)')
    ax_psnr.legend()
    plt.show()
    return train_losses, valid_psnrs


6️⃣ Train or Load Models

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vanilla_model = SRCNN()
improved_model = ImprovedSRCNN()
vanilla_path = os.path.join(MODEL_DIR, "vanilla_srcnn.pth")
improved_path = os.path.join(MODEL_DIR, "improved_srcnn.pth")


def _load_or_train(model, path, name, vanilla):
    """Load cached weights from ``path`` if present; otherwise train ``model`` and save there.

    Caching lets Restart & Run All skip training once checkpoints exist.
    ``weights_only=True`` restricts ``torch.load`` to tensors and plain
    containers, so it does not unpickle arbitrary objects.
    """
    if os.path.exists(path):
        print(f"Loading {name} weights...")
        model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
    else:
        train_model(model, train_loader, valid_loader, path, epochs=5, lr=1e-4,
                    device=device, vanilla=vanilla)


_load_or_train(vanilla_model, vanilla_path, "Vanilla SRCNN", vanilla=True)
_load_or_train(improved_model, improved_path, "Improved SRCNN", vanilla=False)


7️⃣ Evaluation: Quantitative and Qualitative

In [None]:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from torchmetrics.image.fid import FrechetInceptionDistance

def evaluate_model(model, dataloader, device='cuda', vanilla=True):
    """Print and return PSNR and SSIM (mean ± std over images) and FID for ``model``.

    Every image of every batch is scored, so the metrics are correct for any
    batch size. Model outputs and targets are clamped to [0, 1] before scoring.

    Args:
        model: super-resolution network; moved to ``device`` and set to eval mode.
        dataloader: yields ``(lr, lr_up, hr)`` batches.
        device: torch device string.
        vanilla: feed ``lr_up`` (True) or the raw ``lr`` (False) to the model.

    Returns:
        dict with keys ``psnr_mean``, ``psnr_std``, ``ssim_mean``, ``ssim_std``
        and ``fid``.

    Raises:
        ValueError: if ``dataloader`` yields no images.
    """
    model = model.to(device)
    model.eval()
    psnrs, ssims = [], []
    # normalize=True: images are passed as floats in [0, 1].
    fid = FrechetInceptionDistance(feature=2048, normalize=True).to(device)
    with torch.no_grad():
        for lr_batch, lr_up_batch, hr_batch in dataloader:
            inp = (lr_up_batch if vanilla else lr_batch).to(device)
            sr = torch.clamp(model(inp), 0, 1)
            hr = torch.clamp(hr_batch.to(device), 0, 1)
            fid.update(sr, real=False)
            fid.update(hr, real=True)
            # (B, C, H, W) -> (B, H, W, C) arrays, then score each image separately.
            sr_np = sr.cpu().permute(0, 2, 3, 1).numpy()
            hr_np = hr.cpu().permute(0, 2, 3, 1).numpy()
            for sr_img, hr_img in zip(sr_np, hr_np):
                psnrs.append(compare_psnr(hr_img, sr_img, data_range=1.0))
                ssims.append(compare_ssim(hr_img, sr_img, channel_axis=2, data_range=1.0))
    if not psnrs:
        raise ValueError("dataloader yielded no images")
    metrics = {
        "psnr_mean": float(np.mean(psnrs)),
        "psnr_std": float(np.std(psnrs)),
        "ssim_mean": float(np.mean(ssims)),
        "ssim_std": float(np.std(ssims)),
        "fid": fid.compute().item(),
    }
    print(f"Avg PSNR: {metrics['psnr_mean']:.2f} ± {metrics['psnr_std']:.2f}")
    print(f"Avg SSIM: {metrics['ssim_mean']:.4f} ± {metrics['ssim_std']:.4f}")
    print(f"FID: {metrics['fid']:.4f}")
    return metrics

# Evaluate both models on the validation set. evaluate_model's printed metrics
# do not name the model, so print a header line before each report.
for model_name, model_to_eval, uses_upsampled_input in (
    ("Vanilla SRCNN", vanilla_model, True),
    ("Improved SRCNN", improved_model, False),
):
    print(f"=== {model_name} ===")
    evaluate_model(model_to_eval, valid_loader, device=device, vanilla=uses_upsampled_input)


8️⃣ Visualize Best/Worst Examples

In [None]:
def visualize_examples(model, dataloader, device='cuda', vanilla=True, num_examples=2):
    """Plot the ``num_examples`` highest- and lowest-PSNR reconstructions.

    Each figure shows the LR input, the clamped SR output and the HR target.
    Every image of every batch is scored. Only the current best and worst
    candidates are kept in memory (two bounded heaps), not the whole dataset.

    Args:
        model: super-resolution network; moved to ``device``, set to eval mode.
        dataloader: yields ``(lr, lr_up, hr)`` batches.
        device: torch device string.
        vanilla: feed ``lr_up`` (True) or the raw ``lr`` (False) to the model.
        num_examples: how many best and how many worst examples to plot;
            values <= 0 plot nothing.

    Returns:
        list of ``(label, psnr)`` for the plotted examples, best first, then
        worst first.
    """
    import heapq

    if num_examples <= 0:
        # Guard explicitly: a `[-k:]`-style slice with k == 0 would select every image.
        return []

    def _keep(heap, entry):
        # Retain the num_examples largest entries (min-heap evicts the smallest).
        if len(heap) < num_examples:
            heapq.heappush(heap, entry)
        else:
            heapq.heappushpop(heap, entry)

    model = model.to(device)
    model.eval()
    best, worst = [], []  # entries: (psnr or -psnr, unique index, image arrays)
    seen = 0
    with torch.no_grad():
        for lr_batch, lr_up_batch, hr_batch in dataloader:
            inp = (lr_up_batch if vanilla else lr_batch).to(device)
            sr_batch = model(inp).cpu()
            for lr_t, sr_t, hr_t in zip(lr_batch, sr_batch, hr_batch):
                lr_np = np.clip(lr_t.permute(1, 2, 0).numpy(), 0, 1)
                sr_np = np.clip(sr_t.permute(1, 2, 0).numpy(), 0, 1)
                hr_np = np.clip(hr_t.permute(1, 2, 0).numpy(), 0, 1)
                psnr = compare_psnr(hr_np, sr_np, data_range=1.0)
                images = (lr_np, sr_np, hr_np)
                # The unique index breaks PSNR ties so image arrays are never compared.
                _keep(best, (psnr, seen, images))
                _keep(worst, (-psnr, seen, images))
                seen += 1

    selected = [("Best", entry[0], entry[2]) for entry in sorted(best, reverse=True)]
    selected += [("Worst", -entry[0], entry[2]) for entry in sorted(worst, reverse=True)]
    for label, psnr, (lr_img, sr_img, hr_img) in selected:
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = (
            (lr_img, "LR Input"),
            (sr_img, f"SR Output ({label})\nPSNR={psnr:.2f}"),
            (hr_img, "HR Ground Truth"),
        )
        for ax, (img, title) in zip(axes, panels):
            ax.imshow(img)
            ax.set_title(title)
            ax.axis('off')
        fig.tight_layout()
        plt.show()
    return [(label, psnr) for label, psnr, _ in selected]


9️⃣ Ablation Study: Weight Decay Example

In [None]:
# Self-contained training loop for the weight-decay ablation (no plotting).
def train_model_with_reg(model, train_loader, valid_loader, save_path, epochs=5, lr=1e-4, device='cuda', weight_decay=0, vanilla=False):
    """Train ``model`` with MSE loss and Adam(weight_decay=...), validating each epoch.

    Args:
        model: network to train; moved to ``device``.
        train_loader, valid_loader: yield ``(lr_img, lr_up_img, hr_img)`` batches.
        save_path: file the final ``state_dict`` is written to after the last epoch.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: torch device string.
        weight_decay: Adam ``weight_decay`` (L2 penalty); 0 disables regularization.
        vanilla: if True, feed the pre-upsampled ``lr_up_img``; otherwise feed
            the raw LR image (the default, for models that upsample internally).

    Returns:
        ``(train_losses, valid_psnrs)``: per-epoch mean training loss and mean
        validation PSNR in dB. Nothing is plotted.

    Raises:
        ValueError: if either loader is empty, since the per-epoch averages
            would otherwise divide by zero.
    """
    if len(train_loader) == 0 or len(valid_loader) == 0:
        raise ValueError("train_loader and valid_loader must be non-empty")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.MSELoss()
    tag = f"{'Vanilla' if vanilla else 'Improved'} | WD={weight_decay}"
    train_losses, valid_psnrs = [], []
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for lr_img, lr_up_img, hr_img in train_loader:
            inp = (lr_up_img if vanilla else lr_img).to(device)
            hr_img = hr_img.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inp), hr_img)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_loss = running_loss / len(train_loader)
        train_losses.append(avg_loss)

        # Validation: clamp outputs to [0, 1] before scoring, as the final
        # evaluation does, so both PSNR numbers are comparable.
        model.eval()
        val_psnr = 0.0
        with torch.no_grad():
            for lr_img, lr_up_img, hr_img in valid_loader:
                inp = (lr_up_img if vanilla else lr_img).to(device)
                hr_img = hr_img.to(device)
                sr = model(inp).clamp(0, 1)
                mse = ((sr - hr_img) ** 2).mean().item()
                val_psnr += calculate_psnr(mse)
        val_psnr /= len(valid_loader)
        valid_psnrs.append(val_psnr)
        print(f"[{tag}] Epoch {epoch+1}: Loss={avg_loss:.6f}, Val PSNR={val_psnr:.2f}")
    torch.save(model.state_dict(), save_path)
    return train_losses, valid_psnrs

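# TODO: run the ablation. Train ImprovedSRCNN with train_model_with_reg twice, once with
# weight_decay=0 and once with a nonzero value. The function returns the loss/PSNR curves
# but does not plot them, so plot both runs side by side here.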
