In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import transforms, models, utils
from tqdm import tqdm
import os
from PIL import Image

# =========================
# DEVICE
# =========================
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# =========================
# TRANSFORMS (MODEL SPACE: [-1, 1])
# =========================
# Resize every image to a fixed 128x128, convert it to a float tensor in
# [0, 1], then apply (t - 0.5) / 0.5 per channel so the model works on
# values in [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# =========================
# DATASET
# =========================
class RestorationDataset(torch.utils.data.Dataset):
    """Paired (degraded, clean) image dataset for restoration training.

    A file ``degraded_<name>`` in ``degraded_dir`` is paired with ``<name>``
    in ``clean_dir``; degraded files with no clean counterpart are skipped.
    Pairs are ordered by degraded filename so the dataset (and any
    ``Subset(range(k))`` taken from it) is the same on every machine.

    Args:
        degraded_dir: Directory holding the degraded input images.
        clean_dir: Directory holding the clean target images.
        transform: Optional callable applied to both PIL images of a pair.
            The same RNG seed is replayed for both calls so any random
            augmentation is applied identically to input and target.

    Items are ``(degraded, clean)``; both are RGB PIL images when
    ``transform`` is None, otherwise whatever ``transform`` returns.
    """

    _PREFIX = "degraded_"

    def __init__(self, degraded_dir, clean_dir, transform=None):
        self.degraded_dir = degraded_dir
        self.clean_dir = clean_dir
        self.transform = transform

        # os.listdir order is arbitrary; sort for a reproducible pair order.
        degraded_files = sorted(os.listdir(degraded_dir))
        clean_set = set(os.listdir(clean_dir))

        self.filenames = [
            f for f in degraded_files
            if self._clean_name(f) in clean_set
        ]

    @classmethod
    def _clean_name(cls, deg_name):
        """Return the clean-image filename paired with ``deg_name``."""
        # Strip only a leading prefix: str.replace would also remove a
        # "degraded_" occurring elsewhere in the name.
        if deg_name.startswith(cls._PREFIX):
            return deg_name[len(cls._PREFIX):]
        return deg_name

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file handle."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        deg_name = self.filenames[idx]
        cln_name = self._clean_name(deg_name)

        deg = self._load_rgb(os.path.join(self.degraded_dir, deg_name))
        cln = self._load_rgb(os.path.join(self.clean_dir, cln_name))

        if self.transform is not None:
            # Draw the seed from the global RNG (advancing it reproducibly)
            # instead of torch.seed(), which re-seeds the global generator
            # from OS entropy and silently breaks torch.manual_seed().
            seed = torch.randint(0, 2**31 - 1, (1,)).item()
            # Replay the seed for both images inside a forked CPU RNG state
            # so these manual_seed calls do not leak into the global stream.
            with torch.random.fork_rng(devices=[]):
                torch.manual_seed(seed)
                deg = self.transform(deg)
                torch.manual_seed(seed)
                cln = self.transform(cln)

        return deg, cln

# =========================
# PERCEPTUAL LOSS
# =========================
class VGGPerceptualLoss(nn.Module):
    """L1 distance between frozen VGG16 feature maps of two image batches.

    Inputs are expected in the model's [-1, 1] range (presumably shaped
    (N, 3, H, W), matching the (1, 3, 1, 1) normalization buffers). Each is
    mapped to [0, 1], normalized with the ImageNet mean/std, and passed
    through the first 16 layers of torchvision's pretrained VGG16
    ``features`` (up to and including relu3_3). The VGG weights are frozen,
    so gradients flow only back to the inputs.
    """

    def __init__(self):
        super().__init__()
        features = models.vgg16(weights="IMAGENET1K_V1").features[:16]
        features.eval()
        features.requires_grad_(False)
        self.vgg = features

        imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        # Buffers follow the module across .to(device) and into state_dict.
        self.register_buffer("mean", imagenet_mean)
        self.register_buffer("std", imagenet_std)

    def _to_vgg_input(self, img):
        """Map a [-1, 1] batch to ImageNet-normalized VGG input."""
        unit = (img + 1) / 2
        return (unit - self.mean) / self.std

    def forward(self, x, y):
        feats_x = self.vgg(self._to_vgg_input(x))
        feats_y = self.vgg(self._to_vgg_input(y))
        return nn.functional.l1_loss(feats_x, feats_y)

# =========================
# RESIDUAL BLOCK
# =========================
class ResidualBlock(nn.Module):
    """Conv-BN-ReLU-Conv-BN branch added back onto its input, then ReLU.

    Channel count (``ch``) and spatial size are preserved (3x3 kernels with
    padding=1), so the branch output can be summed with the input directly.
    """

    def __init__(self, ch):
        super().__init__()
        layers = [
            nn.Conv2d(ch, ch, 3, padding=1),
            nn.BatchNorm2d(ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch, 3, padding=1),
            nn.BatchNorm2d(ch),
        ]
        # Attribute name and layer order define checkpoint keys
        # ("block.0.weight", "block.1.running_mean", ...): keep them stable.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return torch.relu(x + residual)

# =========================
# CNN (RESIDUAL RESTORATION)
# =========================
class ElegantCNN(nn.Module):
    """Residual image-restoration CNN.

    A 3x3 stem conv lifts ``in_ch`` channels to ``base_ch``, ``n_blocks``
    ResidualBlocks process the features, and a 3x3 conv projects back to
    ``in_ch``. That projection is added to the input (global skip) and
    squashed with tanh, so outputs lie in [-1, 1], the same range as the
    normalized inputs.
    """

    def __init__(self, in_ch=3, base_ch=64, n_blocks=6):
        super().__init__()
        self.inp = nn.Conv2d(in_ch, base_ch, 3, padding=1)
        blocks = [ResidualBlock(base_ch) for _ in range(n_blocks)]
        self.body = nn.Sequential(*blocks)
        self.out = nn.Conv2d(base_ch, in_ch, 3, padding=1)

    def forward(self, x):
        features = self.body(torch.relu(self.inp(x)))
        correction = self.out(features)
        return torch.tanh(correction + x)

# =========================
# HYPERPARAMS
# =========================
BATCH_SIZE = 16  # pairs per DataLoader batch
LR = 1e-3
EPOCHS = 8

# Output folders for per-epoch weights and per-epoch validation grids,
# created up front (no error if they already exist).
CHECKPOINT_DIR = "checkpoints"
SAMPLES_DIR = "samples_grids"
for _out_dir in (CHECKPOINT_DIR, SAMPLES_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# =========================
# DATA
# =========================
train_full = RestorationDataset(
    degraded_dir="../data/train/degraded_images",
    clean_dir="../data/train/images",
    transform=transform,
)
# Validation images come from the test split.
val_full = RestorationDataset(
    degraded_dir="../data/test/degraded_images",
    clean_dir="../data/test/images",
    transform=transform,
)

# Cap each split (first 10k training pairs, first 1k validation pairs, in
# dataset order) to keep epochs short.
train_ds = Subset(train_full, range(min(len(train_full), 10000)))
val_ds = Subset(val_full, range(min(len(val_full), 1000)))

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

for caption, split in (("Train pairs:", train_ds), ("Val pairs:", val_ds)):
    print(caption, len(split))

# =========================
# MODEL & LOSS
# =========================
model = ElegantCNN().to(DEVICE)
# Pixel-space reconstruction losses.
criterion_l1, criterion_mse = nn.L1Loss(), nn.MSELoss()
# Frozen VGG16 feature loss; .to(DEVICE) also moves its normalization buffers.
perceptual = VGGPerceptualLoss().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR)

# =========================
# TRAINING
# =========================
# Number of validation samples shown per row of the per-epoch comparison grid.
N_GRID_SAMPLES = 32

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    # Perceptual-loss weight: 0 at epoch index 0, ramping linearly to 0.1 at
    # epoch index 3 and capped there. It depends only on the epoch, so it is
    # computed once per epoch rather than once per batch.
    lambda_vgg = min(0.1, epoch / 3 * 0.1)

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for deg, cln in loop:
        deg, cln = deg.to(DEVICE), cln.to(DEVICE)
        optimizer.zero_grad()
        out = model(deg)
        loss = criterion_l1(out, cln) + 0.5 * criterion_mse(out, cln)
        # While the weight is 0 (the first epoch) the VGG term would only add
        # 0 * perceptual(...) yet still cost a full VGG forward and backward.
        if lambda_vgg > 0:
            loss = loss + lambda_vgg * perceptual(out, cln)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()  # one device->host sync per batch
        total_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    # max(..., 1) guards against an empty training loader.
    print(f"Epoch {epoch+1} | Avg loss: {total_loss / max(len(train_loader), 1):.4f}")

    # ===== CHECKPOINT =====
    torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, f"epoch_{epoch+1}.pth"))

    # ===== GRID SAVE =====
    # Three rows of up to N_GRID_SAMPLES images taken from the first
    # validation batches: degraded inputs, model outputs, clean targets.
    model.eval()
    with torch.no_grad():
        all_degs, all_restored, all_clean = [], [], []
        n_seen = 0
        for deg, cln in val_loader:
            deg, cln = deg.to(DEVICE), cln.to(DEVICE)
            all_degs.append(deg)
            all_restored.append(model(deg))
            all_clean.append(cln)
            # Count real images: the last batch can be smaller than BATCH_SIZE.
            n_seen += deg.size(0)
            if n_seen >= N_GRID_SAMPLES:
                break

        if all_degs:
            deg_grid = torch.cat(all_degs)[:N_GRID_SAMPLES]
            res_grid = torch.cat(all_restored)[:N_GRID_SAMPLES]
            cln_grid = torch.cat(all_clean)[:N_GRID_SAMPLES]

            grid = torch.cat([deg_grid, res_grid, cln_grid], dim=0)
            grid = (grid * 0.5 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]

            utils.save_image(
                grid,
                os.path.join(SAMPLES_DIR, f"epoch_{epoch+1}_grid.png"),
                # N images per row, so each image type fills its own row.
                nrow=deg_grid.size(0),
            )
        else:
            print("No validation images: skipping grid for epoch", epoch + 1)

    model.train()

print("✅ Training finished. Grids saved in:", SAMPLES_DIR)


In [4]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import os

# =========================
# DEVICE
# =========================
_use_cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda" if _use_cuda else "cpu")

# =========================
# TRANSFORM (same as training)
# =========================
# Must match the training preprocessing: fixed 128x128 resize, [0, 1] float
# tensor, then (t - 0.5) / 0.5 per channel into the model's [-1, 1] range.
_steps = [
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_steps)

# =========================
# MODEL (même architecture)
# =========================
class ResidualBlock(nn.Module):
    """Inference copy of the residual block: relu(x + BN(conv(relu(BN(conv(x)))))).

    Keeps ``ch`` channels and the spatial size. Submodule names and order
    must match the training definition so checkpoint keys ("block.0.weight",
    "block.4.running_var", ...) load cleanly.
    """

    def __init__(self, ch):
        super().__init__()
        conv_a = nn.Conv2d(ch, ch, 3, padding=1)
        bn_a = nn.BatchNorm2d(ch)
        act = nn.ReLU(inplace=True)
        conv_b = nn.Conv2d(ch, ch, 3, padding=1)
        bn_b = nn.BatchNorm2d(ch)
        self.block = nn.Sequential(conv_a, bn_a, act, conv_b, bn_b)

    def forward(self, x):
        summed = x + self.block(x)
        return torch.relu(summed)

class ElegantCNN(nn.Module):
    """Inference copy of the residual restoration network.

    Must stay layer-for-layer identical to the training definition so the
    saved state_dict keys (``inp.*``, ``body.<i>.block.*``, ``out.*``) match.
    The output is tanh(input + predicted correction), i.e. in [-1, 1].
    """

    def __init__(self, in_ch=3, base_ch=64, n_blocks=6):
        super().__init__()
        self.inp = nn.Conv2d(in_ch, base_ch, 3, padding=1)
        self.body = nn.Sequential(*(ResidualBlock(base_ch) for _ in range(n_blocks)))
        self.out = nn.Conv2d(base_ch, in_ch, 3, padding=1)

    def forward(self, x):
        h = torch.relu(self.inp(x))
        h = self.out(self.body(h))
        # Global skip: the network predicts a correction added to its input.
        return torch.tanh(h + x)

# =========================
# LOAD MODEL
# =========================
CHECKPOINT_PATH = "checkpoints/epoch_8.pth"

model = ElegantCNN().to(DEVICE)
# The checkpoint is a plain state_dict, so refuse to unpickle arbitrary
# objects from it (weights_only requires torch >= 1.13).
state_dict = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # BatchNorm layers use their running statistics

# =========================
# TEST IMAGE
# =========================
INPUT_IMAGE = "../data/test/degraded_images/degraded_image_000012.jpg"  # <-- change here
OUTPUT_IMAGE = "restored_output.png"

# Load image; the context manager closes the file once it is decoded to RGB.
with Image.open(INPUT_IMAGE) as src:
    img = src.convert("RGB")

# Preprocess exactly like training and add a batch dimension: (1, 3, 128, 128).
# NOTE: the transform resizes to 128x128, so the saved output is always
# 128x128 regardless of the input's original size.
x = transform(img).unsqueeze(0).to(DEVICE)

# Inference
with torch.no_grad():
    out = model(x)

# Back to [0, 1]
out = (out.squeeze(0) * 0.5 + 0.5).clamp(0, 1)

# Save output
out_img = transforms.ToPILImage()(out.cpu())
out_img.save(OUTPUT_IMAGE)

print(" Image restaurée sauvegardée :", OUTPUT_IMAGE)


 Image restaurée sauvegardée : restored_output.png
