In [19]:
# 1. MOUNT GOOGLE DRIVE (Colab only) so that BASE_DRIVE_PATH below is readable.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [20]:
import os
import random
import math
from pathlib import Path
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch.optim.lr_scheduler import ReduceLROnPlateau
from PIL import Image
from tqdm.auto import tqdm

# --- Configuration ---
SEED = 1337
BASE_DRIVE_PATH = '/content/drive/MyDrive/NSIN'  # dataset root on the mounted Drive; change if needed

# Prefer the GPU whenever torch can see one.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(BASE_DRIVE_PATH)

def set_seed(seed: int = 1337, deterministic: bool = False):
    """
    Seed Python's `random` module and torch (CPU and all CUDA devices).

    Args:
        seed: value passed to every RNG.
        deterministic: when False (the default, matching previous behaviour)
            cuDNN autotuning is enabled (`benchmark=True`). That is faster, but
            cuDNN may then choose different / non-deterministic kernels, so GPU
            runs are not bit-reproducible. Pass True to turn autotuning off and
            request deterministic cuDNN kernels instead.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Speed vs. reproducibility trade-off for cuDNN convolutions.
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = deterministic

set_seed(seed=SEED)  # seed before any model weights are created
print(f"Running on device: {DEVICE}")

/content/drive/MyDrive/NSIN
Running on device: cuda


In [21]:
class PairedNISN(Dataset):
    """
    Loads paired noisy and clean images from a directory structure.
    Expected structure:
      root/
        noisy images/*.jpg|*.jpeg|*.png
        ground truth/*.jpg|*.jpeg|*.png

    Extensions are matched case-insensitively (so ``.JPG`` files are found too).
    A noisy file is paired with the ground-truth file whose name equals the
    noisy name with everything up to and including the first '_' removed
    (``gauss_123_abc.jpg`` -> ``123_abc.jpg``); names without '_' are matched
    verbatim. Noisy files with no counterpart are skipped and counted in a
    printed warning.

    Args:
        split_root: directory that contains the two sub-folders above.
        resize_to: optional size forwarded to ``transforms.Resize``; applied to
            both images before ``ToTensor``.

    Items are ``(noisy_tensor, clean_tensor)`` produced by ``ToTensor`` from
    RGB-converted images.
    """

    IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")

    def __init__(self, split_root: str, resize_to=None):
        self.root = Path(split_root)
        self.noisy_dir = self.root / "noisy images"
        self.clean_dir = self.root / "ground truth"

        # Basic validation
        if not self.noisy_dir.is_dir() or not self.clean_dir.is_dir():
            print(f"Warning: Directories not found in {split_root}")

        # Transforms
        t = []
        if resize_to is not None:
            t.append(transforms.Resize(resize_to))
        t.append(transforms.ToTensor())
        self.transform = transforms.Compose(t)

        # Index clean images by filename for O(1) lookup.
        clean_index = {p.name: p for p in self._list_images(self.clean_dir)}

        # Match noisy images (in sorted order) to clean images.
        self.pairs = []
        unmatched = 0
        for n_path in sorted(self._list_images(self.noisy_dir)):
            c_path = clean_index.get(self._clean_name_for(n_path.name))
            if c_path is None:
                unmatched += 1
            else:
                self.pairs.append((n_path, c_path))

        if unmatched:
            print(f"Warning: {unmatched} noisy image(s) in {self.noisy_dir} "
                  f"had no matching ground-truth file and were skipped")

    @classmethod
    def _list_images(cls, folder: Path):
        """Return image files directly inside `folder` (suffix matched case-insensitively); [] if it is not a directory."""
        if not folder.is_dir():
            return []
        return [p for p in folder.iterdir()
                if p.is_file() and p.suffix.lower() in cls.IMAGE_SUFFIXES]

    @staticmethod
    def _clean_name_for(noisy_name: str) -> str:
        """Map a noisy filename to its ground-truth filename by dropping the prefix up to the first '_'."""
        # Adjust this logic if your naming convention differs.
        _, sep, rest = noisy_name.partition('_')
        return rest if sep else noisy_name

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        n_path, c_path = self.pairs[idx]
        return self._load(n_path), self._load(c_path)

    def _load(self, path):
        """Open `path`, convert to RGB and apply self.transform; the file handle is closed deterministically."""
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))

In [22]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 Conv -> BatchNorm -> ReLU stages; padding=1 keeps H and W unchanged."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second out_ch -> out_ch. The build order
        # (and hence `net.<i>` state_dict keys) matches a hand-written Sequential.
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class Down(nn.Module):
    """Encoder step: halve H and W with 2x2 max-pooling, then apply DoubleConv(in_ch -> out_ch)."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names feed the state_dict keys (e.g. `conv.net.0.weight`) — keep them.
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(self.pool(x))

class Up(nn.Module):
    """
    Decoder step: 2x transposed-conv upsampling (in_ch -> in_ch // 2), zero-pad to
    the skip tensor's H/W if it came out smaller, concatenate [skip, upsampled]
    along channels, then DoubleConv(in_ch -> out_ch).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        half_ch = in_ch // 2
        self.up = nn.ConvTranspose2d(in_ch, half_ch, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x, skip):
        upsampled = self.up(x)
        # Odd encoder sizes leave the upsampled map 1px short; pad it symmetrically.
        pad_h = skip.shape[2] - upsampled.shape[2]
        pad_w = skip.shape[3] - upsampled.shape[3]
        if max(pad_h, pad_w) > 0:
            left, top = pad_w // 2, pad_h // 2
            upsampled = F.pad(upsampled, [left, pad_w - left, top, pad_h - top])
        merged = torch.cat((skip, upsampled), dim=1)
        return self.conv(merged)

class UNet(nn.Module):
    """
    Four-level U-Net mapping an `in_ch`-channel image to an `out_ch`-channel image in [0, 1].

    Encoder widths are base, 2*base, 4*base, 8*base, 16*base; each decoder stage
    consumes the matching encoder feature map as a skip connection.
    """

    def __init__(self, in_ch=3, out_ch=3, base=64):
        super().__init__()
        widths = [base * 2 ** level for level in range(5)]  # base .. 16*base
        # Submodules are created in the original order and with the original
        # attribute names, so parameter initialisation and state_dict keys match.
        self.inc = DoubleConv(in_ch, widths[0])
        self.down1 = Down(widths[0], widths[1])
        self.down2 = Down(widths[1], widths[2])
        self.down3 = Down(widths[2], widths[3])
        self.down4 = Down(widths[3], widths[4])
        self.up1 = Up(widths[4], widths[3])
        self.up2 = Up(widths[3], widths[2])
        self.up3 = Up(widths[2], widths[1])
        self.up4 = Up(widths[1], widths[0])
        self.outc = nn.Conv2d(widths[0], out_ch, kernel_size=1)
        self.act = nn.Sigmoid()  # squashes the output into [0, 1]

    def forward(self, x):
        # Encoder: keep every level's output; the deepest one is the bottleneck.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))
        out = skips.pop()

        # Decoder: pair up1..up4 with the encoder maps from deepest to shallowest.
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            out = up(out, skip)

        return self.act(self.outc(out))

In [23]:
class CombinedLoss(nn.Module):
    """
    Weighted sum of a pixel-wise L1 loss and a VGG19 perceptual (feature-L1) loss:

        total = l1_weight * L1(pred, target)
              + vgg_weight * sum_k L1(phi_k(pred), phi_k(target))

    phi_k are outputs of a frozen, ImageNet-pretrained VGG19 ``features`` stack
    at the indices in ``feature_layers``. Inputs are ImageNet-normalised
    internally, so they are expected in [0, 1] (presumably (N, 3, H, W) RGB,
    like the sigmoid output of UNet — confirm for other callers).

    Args:
        l1_weight: weight of the pixel L1 term.
        vgg_weight: weight of the perceptual term.
        device: device that the VGG sub-network and normalisation buffers are moved to.
        feature_layers: indices into ``vgg19().features`` whose outputs are
            compared; defaults to DEFAULT_FEATURE_LAYERS (relu1_2 .. relu5_2).

    ``forward`` returns ``(total_loss_tensor, l1_float, vgg_float)``.
    """

    # Indices into torchvision's vgg19().features: relu1_2=3, relu2_2=8,
    # relu3_2=13, relu4_2=22, relu5_2=31 (each conv sits one index earlier).
    DEFAULT_FEATURE_LAYERS = (3, 8, 13, 22, 31)

    def __init__(self, l1_weight=1.0, vgg_weight=0.1, device='cuda', feature_layers=None):
        super().__init__()
        self.l1_weight = l1_weight
        self.vgg_weight = vgg_weight

        # Standard L1 Loss
        self.l1_loss = nn.L1Loss()

        # VGG perceptual-loss network, truncated after the deepest layer we read.
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        self.feature_layers = self._validate_layers(feature_layers, len(vgg))
        self.vgg_submodules = self._truncate_vgg(vgg, self.feature_layers, device)

        # ImageNet normalization stats
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1))

    @classmethod
    def _validate_layers(cls, feature_layers, num_layers):
        """Return sorted unique layer indices (default DEFAULT_FEATURE_LAYERS); raise ValueError if empty or out of range."""
        layers = sorted(set(cls.DEFAULT_FEATURE_LAYERS if feature_layers is None else feature_layers))
        if not layers:
            raise ValueError("feature_layers must contain at least one layer index")
        if layers[0] < 0 or layers[-1] >= num_layers:
            raise ValueError(f"feature_layers must lie in [0, {num_layers - 1}], got {layers}")
        return layers

    @staticmethod
    def _truncate_vgg(features, feature_layers, device):
        """Keep `features` up to the deepest requested index, on `device`, frozen and in eval mode."""
        sub = features[:max(feature_layers) + 1].to(device).eval()
        for module in sub:
            # torchvision's VGG uses ReLU(inplace=True), which would overwrite a
            # collected conv output with its activation; make indices literal.
            if isinstance(module, nn.ReLU):
                module.inplace = False
        for param in sub.parameters():
            param.requires_grad = False
        return sub

    def get_vgg_features(self, x):
        """Normalise `x` and return the VGG outputs at self.feature_layers, shallowest first."""
        wanted = set(self.feature_layers)
        features = []
        x = (x - self.mean) / self.std
        for i, layer in enumerate(self.vgg_submodules):
            x = layer(x)
            if i in wanted:
                features.append(x)
        return features

    def forward(self, pred, target):
        # 1. L1 Pixel Loss
        loss_l1 = self.l1_loss(pred, target)

        # 2. VGG Perceptual Loss (sum of per-layer feature L1 distances)
        pred_feats = self.get_vgg_features(pred)
        target_feats = self.get_vgg_features(target)
        loss_vgg = sum(F.l1_loss(pf, tf) for pf, tf in zip(pred_feats, target_feats))

        # Combine
        total_loss = (self.l1_weight * loss_l1) + (self.vgg_weight * loss_vgg)
        return total_loss, loss_l1.item(), loss_vgg.item()

In [24]:
def _train_one_epoch(model, loader, criterion, optimizer, scaler, use_amp, desc):
    """Run one optimisation pass over `loader` and return the mean total loss per batch."""
    model.train()
    running_loss = 0.0
    pbar = tqdm(loader, desc=desc)
    for noisy, clean in pbar:
        noisy, clean = noisy.to(DEVICE), clean.to(DEVICE)
        optimizer.zero_grad()

        # Mixed precision only where it is enabled (CUDA); otherwise a plain fp32 pass.
        with torch.amp.autocast(device_type=DEVICE.type, enabled=use_amp):
            pred = model(noisy)
            loss, l1_val, vgg_val = criterion(pred, clean)

        # With enabled=False the scaler just calls backward()/step() unscaled.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_val = loss.item()
        running_loss += loss_val
        pbar.set_postfix({"L1": f"{l1_val:.4f}", "VGG": f"{vgg_val:.4f}", "Total": f"{loss_val:.4f}"})
    return running_loss / len(loader)


def _evaluate(model, loader, criterion):
    """Return the mean total loss per batch of `model` over `loader`, without gradients (fp32)."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for v_noisy, v_clean in loader:
            v_noisy, v_clean = v_noisy.to(DEVICE), v_clean.to(DEVICE)
            v_pred = model(v_noisy)
            v_loss, _, _ = criterion(v_pred, v_clean)
            total += v_loss.item()
    return total / len(loader)


def train_model(model, train_loader, val_loader, epochs=20, lr=1e-4, save_dir="./checkpoints"):
    """
    Train `model` with L1 + VGG perceptual loss (AdamW, ReduceLROnPlateau on val loss).

    Every epoch overwrites ``<save_dir>/last.pth``. When `val_loader` is given
    and non-empty, its mean loss drives the LR scheduler and the weights with
    the lowest validation loss so far are written to ``<save_dir>/best.pth``.
    Without validation the scheduler is never stepped (it needs a metric).
    Mixed precision is used only when DEVICE is CUDA.

    Args:
        model: network expected to already live on DEVICE (batches are moved there).
        train_loader: yields (noisy, clean) batches; must be non-empty.
        val_loader: same format, or None / empty to skip validation.
        epochs: number of training epochs.
        lr: initial AdamW learning rate.
        save_dir: checkpoint directory, created if missing.
    """
    # vgg_weight=0.1 is a common starting point; adjust if artifacts appear
    criterion = CombinedLoss(l1_weight=1.0, vgg_weight=0.1, device=DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    use_amp = DEVICE.type == 'cuda'
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    has_val = val_loader is not None and len(val_loader) > 0
    best_val_loss = math.inf

    for epoch in range(1, epochs + 1):
        avg_loss = _train_one_epoch(model, train_loader, criterion, optimizer, scaler,
                                    use_amp, desc=f"Epoch {epoch}/{epochs}")
        print(f"Epoch {epoch} Train Loss: {avg_loss:.5f}")
        torch.save(model.state_dict(), save_dir / "last.pth")

        if not has_val:
            continue

        avg_val_loss = _evaluate(model, val_loader, criterion)
        print(f"Epoch {epoch} Val Loss:   {avg_val_loss:.5f}")
        scheduler.step(avg_val_loss)

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_dir / "best.pth")
            print(f"Epoch {epoch} New best val loss; saved {save_dir / 'best.pth'}")


In [26]:
# --- Execution ---
BATCH_SIZE = 16
NUM_WORKERS = 2

# 1. Create Datasets
# Ensure these paths exist in your Google Drive
train_ds = PairedNISN(f"{BASE_DRIVE_PATH}/train/train")
val_ds   = PairedNISN(f"{BASE_DRIVE_PATH}/validate/validate")
print(f"Paired samples -> train: {len(train_ds)}, val: {len(val_ds)}")

# Stop "Run All" here when there is nothing to train on: the training cell uses
# `model` / `train_loader`, which would otherwise be undefined or, worse,
# stale objects left in the kernel by an earlier run.
if len(train_ds) == 0:
    raise RuntimeError("No training data found. Check BASE_DRIVE_PATH.")
if len(val_ds) == 0:
    print("Warning: no validation pairs found; training will run without validation.")

# 2. Create Loaders
# NOTE(review): the default collate stacks tensors, so every image must share the
# same H x W unless PairedNISN is given a (h, w) resize_to — confirm for this data.
pin = DEVICE.type == "cuda"  # pinned host memory only helps host->GPU copies
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=pin)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=pin)

# 3. Initialize Model
model = UNet().to(DEVICE)

In [None]:
    # 4. Run Training with the L1 + VGG perceptual loss.
    # NOTE(review): this cell is indented as a continuation of the previous cell's
    # `else:` branch; it presumably runs only because IPython strips the common
    # leading indent. It needs `model`, `train_loader` and `val_loader` from that cell.
    train_model(model, train_loader, val_loader, lr=2e-4, epochs=20)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


100%|██████████| 548M/548M [00:02<00:00, 244MB/s]


Epoch 1/20:   0%|          | 0/132 [00:00<?, ?it/s]