In [17]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
from PIL import Image

# --- Utility modules ---

# Residual Block met skip-verbinding (1x1 convolution indien nodig)
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        # Indien in_channels != out_channels, pas dan een 1x1 convolution toe op de skip-verbinding.
        if in_channels != out_channels:
            self.skip_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.skip_conv = None

    def forward(self, x):
        identity = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        if self.skip_conv is not None:
            identity = self.skip_conv(identity)
        out += identity
        return self.relu(out)

# Downsample laag (conv met stride 2)
class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        return self.relu(self.conv(x))

# Upsample laag (transposed conv)
class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Upsample, self).__init__()
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        return self.relu(self.conv(x))

# --- UNet Generator ---

# De UNet is hier aangepast zodat deze conditionele informatie kan verwerken. 
# De input bestaat uit de concatenatie van:
# - een "noisy" high quality–afbeelding (die met DDPM-ruis is aangetast)
# - de low quality–afbeelding als condition
# Daardoor krijgt de UNet 6 kanalen als input en voorspelt zij de ruis (3 kanalen) die moet worden verwijderd.
class UNet(nn.Module):
    def __init__(self, in_channels, out_channels, base_channels=64):
        super(UNet, self).__init__()
        # Encoder
        self.enc1 = ResidualBlock(in_channels, base_channels)           # in_channels = 6 (3 noisy + 3 condition)
        self.down1 = Downsample(base_channels, base_channels * 2)         # 64 -> 128
        self.enc2 = ResidualBlock(base_channels * 2, base_channels * 2)     # 128 -> 128
        self.down2 = Downsample(base_channels * 2, base_channels * 4)       # 128 -> 256
        self.enc3 = ResidualBlock(base_channels * 4, base_channels * 4)     # 256 -> 256

        # Bottleneck
        self.bottleneck = ResidualBlock(base_channels * 4, base_channels * 4) # 256 -> 256

        # Decoder
        self.up2 = Upsample(base_channels * 4, base_channels * 2)           # 256 -> 128
        self.dec2 = ResidualBlock(base_channels * 4, base_channels * 2)       # na concat: 128+128=256 -> 128
        self.up1 = Upsample(base_channels * 2, base_channels)               # 128 -> 64
        self.dec1 = ResidualBlock(base_channels * 2, base_channels)           # na concat: 64+64=128 -> 64

        # Output: voorspelt 3 kanalen (ruis)
        self.output_conv = nn.Conv2d(base_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x, time_embedding=None):
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.down1(enc1))
        enc3 = self.enc3(self.down2(enc2))
        # Bottleneck
        bottleneck = self.bottleneck(enc3)
        # Decoder
        up2 = self.up2(bottleneck)
        dec2 = self.dec2(torch.cat([up2, enc2], dim=1))
        up1 = self.up1(dec2)
        dec1 = self.dec1(torch.cat([up1, enc1], dim=1))
        return self.output_conv(dec1)

# --- DDPM Noise Scheduler en Sampler ---

# Deze klasse beheert het toevoegen van ruis en het omkeren van het diffusieproces.
class DDPM(nn.Module):
    def __init__(self, model, timesteps=1000, beta_start=1e-4, beta_end=0.02, device="cuda"):
        super(DDPM, self).__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.timesteps = timesteps
        self.betas = torch.linspace(beta_start, beta_end, timesteps, device=self.device)
        self.alphas = 1.0 - self.betas
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas = self.alphas.to(self.device)
        self.alpha_cumprod = self.alpha_cumprod.to(self.device)
        
    def add_noise(self, x, t):
        """
        Voeg ruis toe aan x volgens het DDPM schema op timestep t.
        x: clean (high quality) image, t: tensor met timesteps per sample.
        """
        noise = torch.randn_like(x)
        # t heeft vorm (batch,); haal voor elke sample de juiste cumulatieve alpha op.
        sqrt_alpha_cumprod = torch.sqrt(self.alpha_cumprod[t]).view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - self.alpha_cumprod[t]).view(-1, 1, 1, 1)
        noisy = sqrt_alpha_cumprod * x + sqrt_one_minus_alpha_cumprod * noise
        return noisy, noise

    def forward(self, x, cond, t):
        """
        Forward pass tijdens training: voorspel de toegevoegde ruis.
        x: noisy image (afgeleid van de high quality image)
        cond: condition (low quality image)
        t: timestep (per sample)
        """
        # Concateneer noisy image en condition langs de channel-dimensie.
        x_in = torch.cat([x, cond], dim=1)
        return self.model(x_in, t)

    def p_sample(self, x, t, cond):
        """
        Voer één reverse diffusion stap uit.
        x: huidige x_t (batch van afbeeldingen)
        t: huidige timestep (gehele getal)
        cond: condition (low quality image)
        """
        # Maak een tensor van t met de batchgrootte
        t_tensor = torch.tensor([t], device=self.device).long().expand(x.size(0))
        model_out = self.model(torch.cat([x, cond], dim=1), t_tensor)
        alpha_cumprod_t = self.alpha_cumprod[t]
        sqrt_recip_alpha_t = 1.0 / torch.sqrt(self.alphas[t])
        sqrt_one_minus_alpha_cumprod_t = torch.sqrt(1 - alpha_cumprod_t)
        # Bereken voorspelling van de clean image x0
        x0_pred = (x - sqrt_one_minus_alpha_cumprod_t * model_out) * sqrt_recip_alpha_t
        if t == 0:
            return x0_pred
        beta_t = self.betas[t]
        alpha_cumprod_prev = self.alpha_cumprod[t-1]
        # Bereken de posterior variantie volgens het DDPM-schema
        posterior_variance = beta_t * (1 - alpha_cumprod_prev) / (1 - alpha_cumprod_t)
        noise = torch.randn_like(x)
        mean = torch.sqrt(alpha_cumprod_prev) * x0_pred
        x_prev = mean + torch.sqrt(posterior_variance) * noise
        return x_prev

    def sample(self, cond):
        """
        Voer het volledige reverse diffusieproces uit om een herstelde afbeelding (x0) te genereren,
        gegeven de condition (low quality image).
        """
        # Begin met pure ruis (vorm: (batch, 3, H, W))
        x = torch.randn_like(cond).to(self.device)
        for t in reversed(range(1, self.timesteps)):
            x = self.p_sample(x, t, cond)
        # Laatste stap bij t = 0
        t0 = 0
        t_tensor = torch.tensor([t0], device=self.device).long().expand(x.size(0))
        model_out = self.model(torch.cat([x, cond], dim=1), t_tensor)
        sqrt_recip_alpha0 = 1.0 / torch.sqrt(self.alphas[t0])
        sqrt_one_minus_alpha_cumprod0 = torch.sqrt(1 - self.alpha_cumprod[t0])
        x0_pred = (x - sqrt_one_minus_alpha_cumprod0 * model_out) * sqrt_recip_alpha0
        return x0_pred

# --- Dataset ---

class ImageDataset(Dataset):
    def __init__(self, lq_dir, hq_dir, transform=None):
        self.lq_paths = sorted([os.path.join(lq_dir, f) for f in os.listdir(lq_dir) 
                                  if f.lower().endswith((".png", ".jpg", ".jpeg"))])
        self.hq_paths = sorted([os.path.join(hq_dir, f) for f in os.listdir(hq_dir) 
                                  if f.lower().endswith((".png", ".jpg", ".jpeg"))])
        if len(self.lq_paths) == 0 or len(self.hq_paths) == 0:
            raise ValueError(f"Dataset is leeg! Controleer paden: {lq_dir} en {hq_dir}")
        self.transform = transform

    def __len__(self):
        return len(self.lq_paths)

    def __getitem__(self, index):
        lq_image = self._load_image(self.lq_paths[index])
        hq_image = self._load_image(self.hq_paths[index])
        if self.transform:
            lq_image = self.transform(lq_image)
            hq_image = self.transform(hq_image)
        return lq_image, hq_image

    def _load_image(self, path):
        return Image.open(path).convert("RGB")

# --- Training en Test Framework ---

class ImageRestorationModel:
    def __init__(self, lq_dir, hq_dir, test_lq_dir, test_hq_dir, device="cuda"):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        # Let op: de UNet krijgt 6 inputkanalen (noisy + condition) en voorspelt 3 outputkanalen (ruis)
        self.model = UNet(in_channels=6, out_channels=3).to(self.device)
        self.ddpm = DDPM(self.model, timesteps=1000, device=device).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)
        self.criterion = nn.MSELoss()

        transform = transforms.Compose([transforms.ToTensor()])
        self.train_dataset = ImageDataset(lq_dir, hq_dir, transform=transform)
        self.test_dataset = ImageDataset(test_lq_dir, test_hq_dir, transform=transform)
        self.train_loader = DataLoader(self.train_dataset, batch_size=8, shuffle=True)
        self.test_loader = DataLoader(self.test_dataset, batch_size=1, shuffle=False)

    def train(self, epochs=10):
        self.model.train()
        for epoch in range(epochs):
            epoch_loss = 0
            for lq, hq in tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
                lq, hq = lq.to(self.device), hq.to(self.device)
                # Kies voor elke sample een willekeurige timestep t
                t = torch.randint(0, self.ddpm.timesteps, (lq.size(0),), device=self.device)
                # Voeg ruis toe aan de high quality image
                noisy_hq, noise = self.ddpm.add_noise(hq, t)
                # Geef de noisy image samen met de low quality image (condition) aan het model
                pred_noise = self.ddpm(noisy_hq, lq, t)
                loss = self.criterion(pred_noise, noise)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()
            print(f"Epoch {epoch+1}: Gemiddelde loss = {epoch_loss/len(self.train_loader):.6f}")

    def test(self):
        self.model.eval()
        with torch.no_grad():
            for idx, (lq, hq) in enumerate(self.test_loader):
                lq = lq.to(self.device)
                # Genereer de herstelde afbeelding via reverse diffusion
                restored = self.ddpm.sample(lq)
                save_image(restored, f"restored_{idx}.png")
                save_image(hq.to(self.device), f"ground_truth_{idx}.png")
                print(f"Test sample {idx} opgeslagen.")

# --- Main ---

if __name__ == "__main__":
    # Stel de paden in voor de Low Quality (LQ) en High Quality (HQ) afbeeldingen
    lq_dir = "data/train/low"
    hq_dir = "data/train/high"
    test_lq_dir = "data/test/low"
    test_hq_dir = "data/test/high"

    # Initialiseer en train het model
    model = ImageRestorationModel(lq_dir, hq_dir, test_lq_dir, test_hq_dir, device="cuda")
    model.train(epochs=5)

    # Test het model
    model.test()


Epoch 1/5: 100%|██████████████████████████████████████████████████████████████████████| 270/270 [00:37<00:00,  7.12it/s]


Epoch 1: Gemiddelde loss = 0.273607


Epoch 2/5: 100%|██████████████████████████████████████████████████████████████████████| 270/270 [00:37<00:00,  7.19it/s]


Epoch 2: Gemiddelde loss = 0.081430


Epoch 3/5: 100%|██████████████████████████████████████████████████████████████████████| 270/270 [00:37<00:00,  7.21it/s]


Epoch 3: Gemiddelde loss = 0.065651


Epoch 4/5: 100%|██████████████████████████████████████████████████████████████████████| 270/270 [00:37<00:00,  7.19it/s]


Epoch 4: Gemiddelde loss = 0.050542


Epoch 5/5: 100%|██████████████████████████████████████████████████████████████████████| 270/270 [00:37<00:00,  7.21it/s]


Epoch 5: Gemiddelde loss = 0.047046
Test sample 0 opgeslagen.
Test sample 1 opgeslagen.
Test sample 2 opgeslagen.
Test sample 3 opgeslagen.
Test sample 4 opgeslagen.
Test sample 5 opgeslagen.
Test sample 6 opgeslagen.
Test sample 7 opgeslagen.
Test sample 8 opgeslagen.
Test sample 9 opgeslagen.
Test sample 10 opgeslagen.
Test sample 11 opgeslagen.
Test sample 12 opgeslagen.
Test sample 13 opgeslagen.
Test sample 14 opgeslagen.
Test sample 15 opgeslagen.
Test sample 16 opgeslagen.
Test sample 17 opgeslagen.
Test sample 18 opgeslagen.
Test sample 19 opgeslagen.
Test sample 20 opgeslagen.
Test sample 21 opgeslagen.
Test sample 22 opgeslagen.
Test sample 23 opgeslagen.
Test sample 24 opgeslagen.


In [9]:
import os

lq_dir = "data/train/low"
hq_dir = "data/train/high"

test_lq_dir = "data/test/low"
test_hq_dir = "data/test/high"

# Check if directories exist
for path in [lq_dir, hq_dir, test_lq_dir, test_hq_dir]:
    if not os.path.exists(path):
        print(f"Directory does not exist: {path}")
    else:
        print(f"{path} contains {len(os.listdir(path))} files")


data/train/low contains 2160 files
data/train/high contains 2160 files
data/test/low contains 25 files
data/test/high contains 25 files
