In [None]:

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models
from PIL import Image
from tqdm import tqdm

# ============================
# 1) DEVICE
# ============================
# Use the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("DEVICE:", DEVICE)

# ============================
# 2) PATHS
# ============================
# Training pairs: degraded inputs and their clean counterparts.
DEG_DIR = "../data/train/degraded_images"
HR_DIR  = "../data/train/images"

# Output locations, created up-front if they do not exist yet.
CACHE_DIR = "cache_color"
CHECKPOINT_DIR = "checkpoints_color"
GRID_DIR = "samples_colorclean"

for _out_dir in (CACHE_DIR, CHECKPOINT_DIR, GRID_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# ============================
# 3) TRANSFORMS
# ============================
# Resize to 128x128, then convert to a float tensor with values in [0, 1].
_rgb_steps = [transforms.Resize((128, 128)), transforms.ToTensor()]
transform_rgb = transforms.Compose(_rgb_steps)

# ============================
# 4) LUMA (pseudo grayscale)
# ============================
def rgb_to_luma(rgb01):
    """Return the luma (pseudo-grayscale) channel of an RGB tensor.

    Uses the Rec. 601 weights ``Y = 0.299 R + 0.587 G + 0.114 B``.

    Args:
        rgb01: tensor whose third-from-last axis holds the R, G, B channels,
            e.g. ``[3, H, W]`` or ``[B, 3, H, W]``. Values are expected in
            ``[0, 1]`` (the output of ``ToTensor``).

    Returns:
        The luma with the channel axis kept as size 1: ``[1, H, W]`` for a
        single image, ``[B, 1, H, W]`` for a batch.
    """
    # Index the channel axis from the end so single images and batches
    # (any number of leading dims) share one code path.
    r = rgb01[..., 0, :, :]
    g = rgb01[..., 1, :, :]
    b = rgb01[..., 2, :, :]
    y = 0.299 * r + 0.587 * g + 0.114 * b
    return y.unsqueeze(-3)

# ============================
# 5) DATASET + CACHE
# ============================
class ColorDataset(Dataset):
    """Paired colorization dataset backed by an on-disk tensor cache.

    Each item is ``(L, target)``:

    * ``L``      -- luma of the *degraded* image, shape ``[1, H, W]``;
    * ``target`` -- clean (HR) RGB image, shape ``[3, H, W]``, values in ``[0, 1]``.

    Pairs are matched by name: ``degraded_<name>`` in ``deg_dir`` with
    ``<name>`` in ``hr_dir``. At most ``max_items`` pairs are kept, in sorted
    order of the degraded file names.

    The cache speeds training up a lot; otherwise every epoch re-reads,
    resizes and recomputes each image. Cache files are named by index
    (``000000.pt``, ...), so a manifest describing the source directories, the
    preprocessing and the pair list is stored next to them. When the manifest
    does not match, the whole cache is rebuilt instead of silently serving
    tensors that belong to other images. A cache created before the manifest
    existed is therefore rebuilt once.
    """

    MANIFEST_NAME = "manifest.txt"

    def __init__(self, deg_dir, hr_dir, cache_dir, max_items=8000):
        self.cache_dir = cache_dir

        deg_files = sorted(f for f in os.listdir(deg_dir) if f.startswith("degraded_"))
        hr_set = set(os.listdir(hr_dir))

        self.pairs = []
        for f in deg_files:
            clean = f.replace("degraded_", "")
            if clean in hr_set:
                self.pairs.append((f, clean))

        self.pairs = self.pairs[:max_items]
        print("Aligned pairs:", len(self.pairs))

        self._build_cache(deg_dir, hr_dir)

    def _cache_path(self, idx):
        """Return the path of the cached tensors for item ``idx``."""
        return os.path.join(self.cache_dir, f"{idx:06d}.pt")

    def _manifest_text(self, deg_dir, hr_dir):
        """Describe everything the cached tensors depend on, one entry per line."""
        header = [
            f"deg_dir\t{os.path.abspath(deg_dir)}",
            f"hr_dir\t{os.path.abspath(hr_dir)}",
            # The Compose repr includes the resize size and interpolation.
            "transform\t" + repr(transform_rgb).replace("\n", " "),
        ]
        return "\n".join(header + [f"{d}\t{h}" for d, h in self.pairs]) + "\n"

    def _build_cache(self, deg_dir, hr_dir):
        """Write one ``{"L", "target"}`` .pt file per pair, reusing a valid cache."""
        manifest_path = os.path.join(self.cache_dir, self.MANIFEST_NAME)
        manifest = self._manifest_text(deg_dir, hr_dir)
        try:
            with open(manifest_path, encoding="utf-8") as fh:
                cache_valid = fh.read() == manifest
        except FileNotFoundError:
            cache_valid = False

        print("Building cache (only once)...")
        if not cache_valid:
            print("Cache manifest missing or outdated: rebuilding every cached item.")

        for i, (d, h) in enumerate(tqdm(self.pairs)):
            path = self._cache_path(i)
            # Existing files are only trusted when the manifest matches.
            if cache_valid and os.path.exists(path):
                continue

            # `with` closes each file handle as soon as the pixels are decoded.
            with Image.open(os.path.join(deg_dir, d)) as deg_img:
                deg = transform_rgb(deg_img.convert("RGB"))
            with Image.open(os.path.join(hr_dir, h)) as hr_img:
                hr = transform_rgb(hr_img.convert("RGB"))

            torch.save({"L": rgb_to_luma(deg), "target": hr}, path)

        # Written last: an interrupted rebuild leaves the old (or no) manifest,
        # so the next run rebuilds again instead of trusting a mixed cache.
        with open(manifest_path, "w", encoding="utf-8") as fh:
            fh.write(manifest)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load item ``idx`` from the cache and return ``(L, target)``.

        Raises:
            KeyError: if the cache file has no ``"L"`` entry, or none of the
                known target keys.
        """
        data = torch.load(self._cache_path(idx), map_location="cpu")

        # "L" is mandatory.
        L = data["L"]

        # Older scripts stored the target under different names.
        for key in ("target", "target_rgb", "hr"):
            if key in data:
                return L, data[key]
        raise KeyError(f"Unknown keys in cache file: {data.keys()}")

# ============================
# 6) UNET (light)
# ============================
class DoubleConv(nn.Module):
    """Two (3x3 conv -> ReLU) stages: ``in_c -> out_c -> out_c`` channels.

    Padding 1 keeps the spatial size. ``self.net`` must remain an
    ``nn.Sequential`` with the convolutions at indices 0 and 2, because saved
    checkpoints use keys such as ``net.0.weight``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        stages = []
        for src_channels in (in_c, out_c):
            stages.append(nn.Conv2d(src_channels, out_c, 3, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out

class UNetRGB(nn.Module):
    """Light two-level U-Net mapping a 1-channel luma image to RGB in [0, 1].

    The encoder halves the resolution twice (128 -> 64 -> 32 for a 128x128
    input) with ``base``, ``2*base`` and ``4*base`` channels. The decoder
    upsamples bilinearly and concatenates the encoder features of the same
    size. H and W must be divisible by 4, otherwise the skip concatenations
    see mismatched sizes. Attribute names are the state_dict keys and must
    not change.
    """

    def __init__(self, base=32):
        super().__init__()
        c1, c2, c4 = base, base * 2, base * 4

        # Encoder
        self.inc = DoubleConv(1, c1)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(c1, c2))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(c2, c4))

        # Bottleneck
        self.mid = DoubleConv(c4, c4)

        # Decoder: upsample x2, then fuse with the skip connection of that size.
        self.up1 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.conv1 = DoubleConv(c4 + c2, c2)
        self.up2 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.conv2 = DoubleConv(c2 + c1, c1)

        # 1x1 projection to the three RGB channels.
        self.out = nn.Conv2d(c1, 3, 1)

    def forward(self, x):
        skip_full = self.inc(x)
        skip_half = self.down1(skip_full)
        bottleneck = self.mid(self.down2(skip_half))

        y = self.conv1(torch.cat((self.up1(bottleneck), skip_half), dim=1))
        y = self.conv2(torch.cat((self.up2(y), skip_full), dim=1))

        # Sigmoid keeps the prediction in [0, 1], like the targets.
        return self.out(y).sigmoid()

# ============================
# 7) TRAIN CONFIG
# ============================
# A bigger batch on GPU; a single sample per step on CPU.
if DEVICE.type == "cuda":
    BATCH_SIZE = 8
else:
    BATCH_SIZE = 1
EPOCHS = 6
LR = 2e-4
MAX_ITEMS = 8000

dataset = ColorDataset(DEG_DIR, HR_DIR, CACHE_DIR, max_items=MAX_ITEMS)
loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=0)

model = UNetRGB(base=32).to(DEVICE)
criterion = nn.L1Loss()
optimizer = optim.Adam(params=model.parameters(), lr=LR)

# ============================
# 8) TRAIN LOOP
# ============================
# Fail early with a clear message instead of a ZeroDivisionError on the average loss.
if len(loader) == 0:
    raise RuntimeError(
        f"No training pairs found in {DEG_DIR!r} / {HR_DIR!r}: nothing to train on."
    )

# Fixed preview samples, reused every epoch so the saved grids are comparable.
# They come straight from the dataset, so the preview keeps up to 4 columns
# even when BATCH_SIZE is 1 (the CPU setting). A random loader batch would
# then hold a single image.
N_PREVIEW = min(4, len(dataset))
preview_items = [dataset[i] for i in range(N_PREVIEW)]
preview_L = torch.stack([lum for lum, _ in preview_items])        # [N,1,H,W]
preview_target = torch.stack([rgb for _, rgb in preview_items])   # [N,3,H,W]

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    loop = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for L, target in loop:
        L = L.to(DEVICE)
        target = target.to(DEVICE)

        pred = model(L)
        loss = criterion(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    avg = total_loss / len(loader)
    print(f"Epoch {epoch+1} | Avg Loss: {avg:.4f}")

    torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, f"color_epoch_{epoch+1}.pth"))

    # Preview grid, one row each: input luma (repeated to 3 channels so it
    # can be shown as RGB), prediction, ground truth.
    model.eval()
    with torch.no_grad():
        preview_pred = model(preview_L.to(DEVICE)).cpu()
    grid = torch.cat([preview_L.repeat(1, 3, 1, 1), preview_pred, preview_target], dim=0)
    utils.save_image(grid, os.path.join(GRID_DIR, f"epoch_{epoch+1}.png"), nrow=N_PREVIEW)

print("✅ DONE. Check:", GRID_DIR)


DEVICE: cpu
Aligned pairs: 8000
Building cache (only once)...


100%|██████████| 8000/8000 [00:00<00:00, 18965.50it/s]
Epoch 1/6: 100%|██████████| 8000/8000 [1:07:05<00:00,  1.99it/s, loss=0.0474]   


Epoch 1 | Avg Loss: 0.0772


Epoch 2/6: 100%|██████████| 8000/8000 [34:46<00:00,  3.84it/s, loss=0.0742]   


Epoch 2 | Avg Loss: 0.0713


Epoch 3/6: 100%|██████████| 8000/8000 [30:19<00:00,  4.40it/s, loss=0.0835]


Epoch 3 | Avg Loss: 0.0704


Epoch 4/6: 100%|██████████| 8000/8000 [30:38<00:00,  4.35it/s, loss=0.171] 


Epoch 4 | Avg Loss: 0.0699


Epoch 5/6: 100%|██████████| 8000/8000 [32:48<00:00,  4.07it/s, loss=0.0512]


Epoch 5 | Avg Loss: 0.0697


Epoch 6/6: 100%|██████████| 8000/8000 [30:12<00:00,  4.41it/s, loss=0.0431]

Epoch 6 | Avg Loss: 0.0695
✅ DONE. Check: samples_colorclean





In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, utils
from PIL import Image

# ============================
# DEVICE
# ============================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", DEVICE)

# ============================
# PATHS
# ============================
DEG_DIR = "../data/test/degraded_images"   # prends TEST pour voir un vrai rendu
HR_DIR  = "../data/test/images"

CNN1_CKPT = "checkpoints/epoch_8.pth"
CNN2_CKPT = "checkpoints_color/color_epoch_8.pth"  # ou ton dernier

OUT_GRID = "grid_compare.png"

# ============================
# TRANSFORM
# ============================
# Must match the training preprocessing: resize to 128x128, then a float
# tensor with values in [0, 1].
_rgb_steps = [transforms.Resize((128, 128)), transforms.ToTensor()]
transform_rgb = transforms.Compose(_rgb_steps)

# ============================
# CNN1 ARCHI
# ============================
class ResidualBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm) layers with an identity skip and a final ReLU.

    Keeps both the channel count and the spatial size. The layer order inside
    ``self.block`` (indices 0-4) is part of the checkpoint keys and must not
    change.
    """

    def __init__(self, ch):
        super().__init__()
        layers = [
            nn.Conv2d(ch, ch, 3, padding=1),
            nn.BatchNorm2d(ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch, 3, padding=1),
            nn.BatchNorm2d(ch),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return (x + residual).relu()

class ElegantCNN(nn.Module):
    """Residual restoration CNN (CNN1): predicts a correction added to its input.

    ``conv -> ReLU -> n_blocks x ResidualBlock -> conv``. The result is added
    to the input and squashed with tanh, so the output lies in [-1, 1]. The
    calling code feeds inputs already rescaled to [-1, 1].
    """

    def __init__(self, in_ch=3, base_ch=64, n_blocks=6):
        super().__init__()
        self.inp = nn.Conv2d(in_ch, base_ch, 3, padding=1)
        self.body = nn.Sequential(*(ResidualBlock(base_ch) for _ in range(n_blocks)))
        self.out = nn.Conv2d(base_ch, in_ch, 3, padding=1)

    def forward(self, x):
        features = self.body(self.inp(x).relu())
        correction = self.out(features)
        return torch.tanh(correction + x)

# ============================
# CNN2 ARCHI
# ============================
class DoubleConv(nn.Module):
    """Two 3x3 conv + ReLU stages (``in_c -> out_c -> out_c``), same H x W.

    Must stay identical in layout to the training notebook's version: the
    checkpoint keys are ``net.0.*`` and ``net.2.*``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        conv_a = nn.Conv2d(in_c, out_c, 3, padding=1)
        conv_b = nn.Conv2d(out_c, out_c, 3, padding=1)
        self.net = nn.Sequential(conv_a, nn.ReLU(inplace=True), conv_b, nn.ReLU(inplace=True))

    def forward(self, x):
        result = self.net(x)
        return result

class UNetRGB(nn.Module):
    """Colorization U-Net (CNN2): 1-channel luma in, RGB in [0, 1] out.

    Same layout as in the training cell, so its checkpoints load here. Two
    max-pool downsamplings and two x2 upsamplings mean H and W must be
    divisible by 4.
    """

    def __init__(self, base=32):
        super().__init__()
        # Encoder + bottleneck
        self.inc = DoubleConv(1, base)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(base, base * 2))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(base * 2, base * 4))
        self.mid = DoubleConv(base * 4, base * 4)

        # Decoder
        self.up1 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.conv1 = DoubleConv(base * 4 + base * 2, base * 2)
        self.up2 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.conv2 = DoubleConv(base * 2 + base, base)

        # Head
        self.out = nn.Conv2d(base, 3, 1)

    @staticmethod
    def _fuse(upsample, conv, x, skip):
        """Upsample ``x``, concatenate ``skip`` on the channel axis, then convolve."""
        return conv(torch.cat([upsample(x), skip], dim=1))

    def forward(self, x):
        e1 = self.inc(x)
        e2 = self.down1(e1)
        bottleneck = self.mid(self.down2(e2))

        d = self._fuse(self.up1, self.conv1, bottleneck, e2)
        d = self._fuse(self.up2, self.conv2, d, e1)
        return torch.sigmoid(self.out(d))

# ============================
# UTILS
# ============================
def rgb_to_luma(rgb01):
    """Return the luma (pseudo-grayscale) channel of an RGB tensor.

    Uses the Rec. 601 weights ``Y = 0.299 R + 0.587 G + 0.114 B``, the same as
    the training cell.

    Args:
        rgb01: tensor whose third-from-last axis holds R, G, B, e.g.
            ``[B, 3, H, W]`` (batch) or ``[3, H, W]`` (single image).

    Returns:
        ``[B, 1, H, W]`` for a batch, ``[1, H, W]`` for a single image.
    """
    # Indexing from the end works for batched and unbatched inputs alike;
    # the old `[:, 0]` form picked image rows instead of channels for 3-D input.
    r = rgb01[..., 0, :, :]
    g = rgb01[..., 1, :, :]
    b = rgb01[..., 2, :, :]
    y = 0.299 * r + 0.587 * g + 0.114 * b
    return y.unsqueeze(-3)

# ============================
# LOAD MODELS
# ============================
def _load_eval(net, ckpt_path):
    """Load the state_dict stored at ckpt_path into net and switch it to eval mode."""
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    net.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
    return net.eval()


cnn1 = _load_eval(ElegantCNN().to(DEVICE), CNN1_CKPT)
cnn2 = _load_eval(UNetRGB(base=32).to(DEVICE), CNN2_CKPT)

print("✅ Models loaded")

# ============================
# BUILD MINI DATASET (just for grid)
# ============================
N_SAMPLES = 8

# Keep only degraded files whose clean counterpart exists, *then* take the
# first N_SAMPLES, so a single missing HR file no longer aborts the cell.
hr_names = set(os.listdir(HR_DIR))
deg_files = [
    f for f in sorted(os.listdir(DEG_DIR))
    if f.startswith("degraded_") and f.replace("degraded_", "") in hr_names
][:N_SAMPLES]
if not deg_files:
    raise FileNotFoundError(
        f"No degraded/clean image pairs found in {DEG_DIR!r} and {HR_DIR!r}"
    )

degs, hrs = [], []
for f in deg_files:
    clean_name = f.replace("degraded_", "")
    # `with` closes each file handle once the image has been decoded.
    with Image.open(os.path.join(DEG_DIR, f)) as deg_img:
        degs.append(transform_rgb(deg_img.convert("RGB")))
    with Image.open(os.path.join(HR_DIR, clean_name)) as hr_img:
        hrs.append(transform_rgb(hr_img.convert("RGB")))

deg_batch = torch.stack(degs).to(DEVICE)   # [B,3,128,128] in [0,1]
hr_batch  = torch.stack(hrs).to(DEVICE)

# ============================
# FORWARD
# ============================
with torch.no_grad():
    # Stage 1: CNN1 restores the degraded image. It works in [-1, 1] (its
    # output goes through tanh), so rescale into and out of [0, 1].
    deg_in = deg_batch.mul(2).sub(1)
    restored = cnn1(deg_in)
    restored01 = restored.add(1).div(2)

    # Stage 2: CNN2 colorizes from the luma of the restored image.
    # NOTE(review): the training cell fits CNN2 on the luma of the *degraded*
    # image, so this input may be distributed differently. Confirm intended.
    L = rgb_to_luma(restored01)
    pred_rgb = cnn2(L)                     # [0,1]

# ============================
# SAVE GRID
# ============================
# One row per stage, one column per sample: degraded / restored / colorized / ground truth.
grid_rows = [deg_batch, restored01, pred_rgb, hr_batch]
grid = torch.cat([row.cpu() for row in grid_rows], dim=0)

utils.save_image(grid, OUT_GRID, nrow=len(deg_files))
print("✅ Grid saved:", OUT_GRID)


DEVICE: cpu


FileNotFoundError: [Errno 2] No such file or directory: 'checkpoints_color/color_restored_epoch_8.pth'