<a href="https://colab.research.google.com/github/janbanot/msc-cs-code/blob/main/sem3/DL/DL_2025_Task3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Report 3 - Jan Banot
# Install the extra packages used below: torchinfo (model summaries) and
# kagglehub (dataset download). The leading `!` is IPython shell-escape
# syntax, so this line only runs inside a notebook, not as plain Python.
!uv pip install torchinfo kagglehub

In [None]:
import os
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# 1. Device and hyper-parameter configuration
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 64
IMG_SIZE = 28

print(f"Praca na urządzeniu: {DEVICE}")

# 2. Download the data from Kaggle (CelebA-HQ resized to 256x256)
try:
    import kagglehub
    path = kagglehub.dataset_download("badasstechie/celebahq-resized-256x256")
except Exception as e:
    print(f"Błąd pobierania: {e}. Upewnij się, że masz zainstalowane kagglehub: !pip install kagglehub")
    # Re-raise: without the download `data_dir` is never defined and the
    # dataset construction further down would fail with an unrelated NameError.
    raise

# The images usually sit in a 'celeba_hq_256' sub-folder; fall back to the
# download root when that folder does not exist.
data_dir = os.path.join(path, "celeba_hq_256")
if not os.path.isdir(data_dir):
    data_dir = path
print(f"Dane pobrane do: {data_dir}")

# 3. Preprocessing applied to every image
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # convert to a single grayscale channel
    transforms.Resize((IMG_SIZE, IMG_SIZE)),      # downscale to IMG_SIZE x IMG_SIZE (28x28)
    transforms.ToTensor(),                        # float tensor in [0, 1]
])

class CelebADataset(torch.utils.data.Dataset):
    """Unlabelled image dataset read from a single flat folder.

    Every regular file directly inside ``root_dir`` whose name ends in
    ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive) is one sample. Items are
    returned as ``(image, 0)``; the constant ``0`` is a placeholder label.

    Args:
        root_dir: Directory containing the image files.
        transform: Optional callable applied to the RGB ``PIL.Image``.

    Raises:
        RuntimeError: If ``root_dir`` contains no matching image files.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping (and any seeded split built on it) reproducible.
        # Sub-directories that merely look like image names are skipped.
        self.image_names = sorted(
            name for name in os.listdir(root_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(root_dir, name))
        )

        if not self.image_names:
            raise RuntimeError(f"Nie znaleziono zdjęć w: {root_dir}")

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.image_names[idx])
        # The context manager closes the file handle deterministically;
        # convert('RGB') loads the pixels before the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, 0

# 4. Load the dataset
dataset = CelebADataset(root_dir=data_dir, transform=transform)

# 90% / 10% train/test split. A seeded generator makes the split repeatable
# across kernel restarts (given the same file order), so the held-out images
# do not change between runs.
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42)
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


# 5. Visualise the preprocessed data
def imshow(img):
    """Display a (C, H, W) image tensor (e.g. a make_grid result) with a fixed title."""
    img = img.numpy().transpose((1, 2, 0))
    if img.shape[-1] == 1:
        # matplotlib needs (H, W) for single-channel data; dropping the channel
        # axis also lets cmap='gray' take effect.
        img = img[..., 0]
    plt.figure(figsize=(8, 8))
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.title("Podgląd przetworzonych danych (CelebA 28x28 Grayscale)")
    plt.show()


dataiter = iter(train_loader)
images, _ = next(dataiter)
imshow(torchvision.utils.make_grid(images[:16]))

print(f"Liczba obrazów treningowych: {len(train_dataset)}")
print(f"Kształt pojedynczego obrazu: {images[0].shape}")

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

# --- 1. Model VAE ---
class CelebVAE(nn.Module):
    """Convolutional variational auto-encoder for 1x28x28 grayscale images.

    The encoder halves the spatial size twice (28 -> 14 -> 7), and two linear
    heads map the flattened 128x7x7 features to the mean and log-variance of
    the approximate posterior. The decoder mirrors it with transposed
    convolutions and ends in a sigmoid, so reconstructions lie in [0, 1].

    Args:
        latent_dim: Dimensionality of the latent code ``z``.
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim
        flat_dim = 128 * 7 * 7  # flattened encoder output for a 28x28 input

        # Encoder: (B, 1, 28, 28) -> (B, 64, 14, 14) -> (B, 128, 7, 7) -> (B, flat_dim)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)

        # Decoder: (B, latent_dim) -> (B, 128, 7, 7) -> (B, 64, 14, 14) -> (B, 1, 28, 28)
        self.decoder_input = nn.Linear(latent_dim, flat_dim)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior q(z | x)."""
        features = self.encoder(x)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + sigma * eps with eps ~ N(0, I) (reparameterisation trick)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes of shape (B, latent_dim) to images of shape (B, 1, 28, 28)."""
        feature_map = self.decoder_input(z).view(-1, 128, 7, 7)
        return self.decoder(feature_map)

    def forward(self, x):
        """Encode, sample and decode; returns ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        reconstruction = self.decode(self.reparameterize(mu, logvar))
        return reconstruction, mu, logvar

# --- 2. Dyskryminator (dla VAE-GAN) ---
class Discriminator(nn.Module):
    """Convolutional critic mapping a 1x28x28 image to one raw logit.

    Presumably meant for a VAE-GAN (per its section heading); the final Linear
    layer has no activation, so outputs are unbounded scores of shape (B, 1).
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # (B, 1, 28, 28) -> (B, 32, 14, 14)
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2),
            # (B, 32, 14, 14) -> (B, 64, 7, 7)
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 1),
        )

    def forward(self, x):
        """Return one logit per input image, shape (B, 1)."""
        logits = self.main(x)
        return logits

# --- 3. Wspólna architektura U-Net dla FM i DDPM ---
class TimeEmbedding(nn.Module):
    """Lift a scalar time value to a ``dim``-wide embedding with a 2-layer MLP.

    Args:
        dim: Width of the output embedding.
    """

    def __init__(self, dim):
        super().__init__()
        layers = [nn.Linear(1, dim), nn.LeakyReLU(), nn.Linear(dim, dim)]
        self.mlp = nn.Sequential(*layers)

    def forward(self, t):
        """Map ``t`` whose last dimension is 1 (e.g. (B, 1)) to (B, dim)."""
        embedding = self.mlp(t)
        return embedding

class UNetBlock(nn.Module):
    """Stride-2 resampling conv, additive time conditioning, BatchNorm, LeakyReLU.

    With ``up=False`` a 4x4 / stride-2 / padding-1 ``Conv2d`` halves H and W.
    With ``up=True`` a ``ConvTranspose2d`` with the same hyper-parameters
    doubles them. The time embedding is projected to ``out_ch`` values and
    added as a per-channel bias before normalisation.

    Args:
        in_ch: Input channels.
        out_ch: Output channels.
        time_dim: Width of the time embedding passed to ``forward``.
        up: Upsample (transposed conv) instead of downsample.
    """

    def __init__(self, in_ch, out_ch, time_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_dim, out_ch)
        conv_type = nn.ConvTranspose2d if up else nn.Conv2d
        self.conv = conv_type(in_ch, out_ch, 4, 2, 1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.LeakyReLU()

    def forward(self, x, t):
        """Apply the block to feature map ``x`` conditioned on embedding ``t`` (B, time_dim)."""
        time_bias = self.time_mlp(t)[..., None, None]  # broadcast over H and W
        h = self.conv(x) + time_bias
        return self.relu(self.bn(h))

class SimpleUNet(nn.Module):
    """Small time-conditioned U-Net shared by Flow Matching and DDPM.

    ``forward(x, t)`` maps images ``x`` of shape (B, 1, 28, 28) and times ``t``
    of shape (B, 1) to a tensor shaped like ``x``: the predicted velocity
    (Flow Matching) or the predicted noise (DDPM).

    Args:
        time_dim: Width of the time embedding passed to every UNetBlock.
    """

    def __init__(self, time_dim=64):
        super().__init__()
        self.time_embed = TimeEmbedding(time_dim)

        # Downsampling path
        self.down1 = UNetBlock(1, 64, time_dim)    # 28 -> 14
        self.down2 = UNetBlock(64, 128, time_dim)  # 14 -> 7

        # Bottleneck (not time-conditioned)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU()
        )

        # Upsampling path; up2 receives concat(up1 output, down1 skip) = 64 + 64 channels
        self.up1 = UNetBlock(128, 64, time_dim, up=True)  # 7 -> 14
        self.up2 = UNetBlock(128, 64, time_dim, up=True)  # 14 -> 28

        # Linear output head. Every UNetBlock ends in BatchNorm + LeakyReLU, so
        # ending the network with one squashes negative outputs by the LeakyReLU
        # slope. Both regression targets (FM's x1 - x0 and DDPM's Gaussian
        # noise) are frequently negative, so the last layer must be linear.
        self.head = nn.Conv2d(64, 1, kernel_size=3, padding=1)

    def forward(self, x, t):
        t_emb = self.time_embed(t)
        d1 = self.down1(x, t_emb)   # (B, 64, 14, 14) for 28x28 input
        d2 = self.down2(d1, t_emb)  # (B, 128, 7, 7)
        b = self.bottleneck(d2)
        u1 = self.up1(b, t_emb)     # (B, 64, 14, 14)
        u2 = self.up2(torch.cat([u1, d1], dim=1), t_emb)  # (B, 64, 28, 28)
        return self.head(u2)        # (B, 1, 28, 28), unbounded

# --- 4. Instantiate the models and sanity-check their shapes ---
vae_model = CelebVAE().to(DEVICE)
unet_model = SimpleUNet().to(DEVICE)
disc_model = Discriminator().to(DEVICE)

# torchinfo.summary returns a statistics object. Only a cell's last expression
# is rendered in a notebook, and torchinfo may not print on its own there, so
# the VAE summary (not the last expression) could be silently dropped.
# Disable internal printing and print each result explicitly: both summaries
# then appear exactly once in any environment.
print("Podsumowanie VAE:")
print(summary(vae_model, input_size=(1, 1, 28, 28), verbose=0))
print("\nPodsumowanie U-Net (FM/DDPM):")
example_x = torch.randn(1, 1, 28, 28).to(DEVICE)
example_t = torch.randn(1, 1).to(DEVICE)
print(summary(unet_model, input_data=[example_x, example_t], verbose=0))

In [None]:
# Shared training hyper-parameters
EPOCHS = 10  # epochs of the first comparison run
LR = 1e-3    # AdamW learning rate used for every model

# DDPM configuration: linear beta schedule over T steps
T = 500  # number of diffusion time steps
betas = torch.linspace(1e-4, 0.02, T).to(DEVICE)
alphas = 1.0 - betas                    # alpha_t = 1 - beta_t
alphas_cumprod = alphas.cumprod(dim=0)  # alpha_bar_t = prod_{s <= t} alpha_s

def get_loss_vae(model, x, beta=1.0):
    """Return the (beta-weighted) negative ELBO averaged over the batch.

    Reconstruction uses summed binary cross-entropy, so ``x`` and the model's
    reconstruction must lie in [0, 1]. The KL term is the closed form for a
    diagonal Gaussian posterior against N(0, I), summed over batch and latent
    dimensions.

    Args:
        model: Callable returning ``(recon_x, mu, logvar)`` for input ``x``.
        x: Input batch; dimension 0 is the batch size.
        beta: Weight of the KL term (1.0 gives the standard VAE objective).

    Returns:
        Scalar loss tensor: (BCE + beta * KL) / batch_size.
    """
    recon_x, mu, logvar = model(x)
    batch_size = x.size(0)
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_divergence = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return (reconstruction + beta * kl_divergence) / batch_size

def get_loss_fm(model, x1):
    """Conditional Flow Matching loss on the straight path from noise to data.

    Samples x0 ~ N(0, I) and one t ~ U[0, 1) per example, forms
    x_t = (1 - t) * x0 + t * x1, and regresses the model's prediction onto the
    constant path velocity x1 - x0 with mean squared error.

    Args:
        model: Callable ``model(x_t, t)`` with ``t`` of shape (B, 1), returning
            a tensor shaped like ``x1``.
        x1: Batch of data samples of shape (B, ...) (here (B, 1, 28, 28)).

    Returns:
        Scalar MSE loss tensor.
    """
    x0 = torch.randn_like(x1)
    # Draw t on the batch's own device/dtype instead of the global DEVICE, so
    # the loss works wherever x1 lives.
    t = torch.rand(x1.size(0), 1, device=x1.device, dtype=x1.dtype)
    # Broadcast t over every non-batch dimension of x1 (any rank).
    t_b = t.view(-1, *([1] * (x1.dim() - 1)))
    # Linear path: x_t = (1 - t) * x0 + t * x1
    xt = (1 - t_b) * x0 + t_b * x1
    v_pred = model(xt, t)
    target_v = x1 - x0
    return F.mse_loss(v_pred, target_v)

def get_loss_ddpm(model, x0):
    """Simplified DDPM objective: MSE between true and predicted noise.

    A random integer step t in [0, T) is drawn per example. ``x0`` is noised as
    x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise, and the model
    receives t / T with shape (B, 1).

    Args:
        model: Callable ``model(x_t, t_normalized)`` predicting the noise.
        x0: Clean data batch (assumed 4-D, (B, C, H, W)).

    Returns:
        Scalar MSE loss tensor.
    """
    batch_size = x0.size(0)
    t = torch.randint(0, T, (batch_size,)).to(DEVICE)
    noise = torch.randn_like(x0)
    alpha_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    xt = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * noise
    t_normalized = t.float().view(-1, 1) / T
    return F.mse_loss(model(xt, t_normalized), noise)

In [None]:
from tqdm.auto import tqdm

history = {'vae': [], 'fm': [], 'ddpm': []}

# Two separate U-Nets (one for FM, one for DDPM) so the methods do not share
# weights. All three models are updated side by side on the same batches.
unet_fm = SimpleUNet().to(DEVICE)
unet_ddpm = SimpleUNet().to(DEVICE)

# Optimizers
opt_vae = torch.optim.AdamW(vae_model.parameters(), lr=LR)
opt_fm = torch.optim.AdamW(unet_fm.parameters(), lr=LR)
opt_ddpm = torch.optim.AdamW(unet_ddpm.parameters(), lr=LR)

print(f"Rozpoczynam trening modeli ({EPOCHS} epok każdy)...")

for epoch in range(1, EPOCHS + 1):
    # Explicitly (re-)enter training mode in case a model was switched to
    # eval mode elsewhere (e.g. while sampling).
    for net in (vae_model, unet_fm, unet_ddpm):
        net.train()
    l_vae, l_fm, l_ddpm = 0, 0, 0

    for batch, _ in tqdm(train_loader, desc=f"Epoch {epoch}"):
        batch = batch.to(DEVICE)

        # --- VAE step ---
        opt_vae.zero_grad()
        loss_v = get_loss_vae(vae_model, batch)
        loss_v.backward()
        opt_vae.step()
        l_vae += loss_v.item()

        # --- Flow Matching step ---
        opt_fm.zero_grad()
        loss_f = get_loss_fm(unet_fm, batch)
        loss_f.backward()
        opt_fm.step()
        l_fm += loss_f.item()

        # --- DDPM step ---
        opt_ddpm.zero_grad()
        loss_d = get_loss_ddpm(unet_ddpm, batch)
        loss_d.backward()
        opt_ddpm.step()
        l_ddpm += loss_d.item()

    # Log the mean per-batch loss of this epoch
    history['vae'].append(l_vae / len(train_loader))
    history['fm'].append(l_fm / len(train_loader))
    history['ddpm'].append(l_ddpm / len(train_loader))

    print(f"[{epoch}/{EPOCHS}] VAE: {history['vae'][-1]:.4f} | FM: {history['fm'][-1]:.4f} | DDPM: {history['ddpm'][-1]:.4f}")

In [None]:
@torch.no_grad()
def sample_vae(model, n=16):
    """Generate ``n`` images by decoding latent codes drawn from N(0, I).

    The model is switched to eval mode while sampling. BatchNorm layers then
    use their running statistics instead of normalising across this batch of
    random codes, and the running statistics are left untouched. The caller's
    train/eval mode is restored afterwards, even if decoding raises.

    Args:
        model: Module exposing ``latent_dim`` and ``decode(z)``.
        n: Number of samples to draw.

    Returns:
        ``model.decode(z)`` moved to the CPU.
    """
    # Sample on the device holding the model's weights (CPU if it has none).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    was_training = model.training
    model.eval()
    try:
        z = torch.randn(n, model.latent_dim).to(device)
        return model.decode(z).cpu()
    finally:
        model.train(was_training)

@torch.no_grad()
def sample_fm(model, n=16, steps=50):
    """Generate ``n`` images by integrating the learned velocity field.

    Starts from x ~ N(0, I) at t = 0 and takes ``steps`` explicit Euler steps
    of size 1/steps, evaluating the model at t = i/steps. This assumes the
    model was trained with noise at t = 0 and data at t = 1.

    The model is switched to eval mode while sampling, so BatchNorm uses its
    running statistics and its running statistics are not updated. The
    caller's train/eval mode is restored afterwards.

    Args:
        model: Callable module ``model(x, t)`` with ``t`` of shape (n, 1).
        n: Number of images.
        steps: Number of Euler steps.

    Returns:
        CPU tensor of shape (n, 1, 28, 28); values are not clamped.
    """
    # Integrate on the device holding the model's weights (CPU if it has none).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    was_training = model.training
    model.eval()
    try:
        x = torch.randn(n, 1, 28, 28).to(device)
        dt = 1.0 / steps
        for i in range(steps):
            t = torch.ones(n, 1).to(device) * (i / steps)
            v = model(x, t)
            x = x + v * dt
        return x.cpu()
    finally:
        model.train(was_training)

@torch.no_grad()
def sample_ddpm(model, n=16):
    """Generate ``n`` images with ancestral DDPM sampling over all ``T`` steps.

    Starting from x ~ N(0, I), each step i = T-1, ..., 0 applies
        x <- 1/sqrt(alpha_i) * (x - (1 - alpha_i)/sqrt(1 - alpha_bar_i) * eps_pred)
             + sqrt(beta_i) * z
    with z ~ N(0, I) for i > 0 and z = 0 at the final step. The model receives
    the normalised time i / T, matching get_loss_ddpm.

    The model is switched to eval mode while sampling. BatchNorm therefore uses
    its running statistics rather than those of the current single-timestep
    batch, and its running statistics are not overwritten by the T sampling
    forward passes. The caller's train/eval mode is restored afterwards.

    Args:
        model: Callable module ``model(x, t)`` predicting the noise.
        n: Number of images.

    Returns:
        CPU tensor of shape (n, 1, 28, 28); values are not clamped.
    """
    was_training = model.training
    model.eval()
    try:
        x = torch.randn(n, 1, 28, 28).to(DEVICE)
        for i in reversed(range(T)):
            t = (torch.ones(n, 1) * i / T).to(DEVICE)
            epsilon_pred = model(x, t)

            alpha = alphas[i]
            alpha_bar = alphas_cumprod[i]
            beta = betas[i]

            # No fresh noise on the last step so the output is the posterior mean.
            noise = torch.randn_like(x) if i > 0 else 0

            # DDPM reverse (denoising) update
            x = (1 / torch.sqrt(alpha)) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * epsilon_pred) + torch.sqrt(beta) * noise
        return x.cpu()
    finally:
        model.train(was_training)

In [None]:
import matplotlib.pyplot as plt

def plot_results(history, vae_imgs, fm_imgs, ddpm_imgs):
    """Plot per-epoch loss curves and a grid of generated samples.

    The first figure has one panel per model showing ``history['vae']``,
    ``history['fm']`` and ``history['ddpm']`` against the epoch number. The
    second figure is a 3 x 8 grid with the first 8 images of each sample
    batch, one row per model.

    Args:
        history: Dict with keys 'vae', 'fm', 'ddpm' mapping to per-epoch losses.
        vae_imgs: Indexable image batch (at least 8 items) from the VAE.
        fm_imgs: Indexable image batch (at least 8 items) from Flow Matching.
        ddpm_imgs: Indexable image batch (at least 8 items) from DDPM.
    """
    # --- Loss curves ---
    curve_specs = [
        ('vae', 'blue', "VAE Loss (BCE + KLD)"),
        ('fm', 'green', "Flow Matching Loss (MSE)"),
        ('ddpm', 'red', "DDPM Loss (MSE)"),
    ]
    _, axes = plt.subplots(1, 3, figsize=(18, 5))
    epochs_range = range(1, len(history['vae']) + 1)
    for axis, (key, color, title) in zip(axes, curve_specs):
        axis.plot(epochs_range, history[key], color=color, marker='o')
        axis.set_title(title)
        axis.set_xlabel("Epoch")
        axis.grid(True)
    plt.tight_layout()
    plt.show()

    # --- Side-by-side comparison of generated images ---
    cols = 8
    sample_rows = [
        (vae_imgs, "VAE (Blurry)"),
        (fm_imgs, "Flow Matching"),
        (ddpm_imgs, "DDPM (Detailed)"),
    ]
    plt.figure(figsize=(15, 7))
    for row_idx, (imgs, label) in enumerate(sample_rows):
        for col in range(cols):
            plt.subplot(len(sample_rows), cols, row_idx * cols + col + 1)
            plt.imshow(imgs[col].squeeze(), cmap='gray')
            plt.axis('off')
            if col == 0:
                plt.title(label, loc='left', pad=10)
    plt.suptitle("Porównanie Modelu: VAE vs FM vs DDPM (CelebA 28x28)", fontsize=16)
    plt.tight_layout()
    plt.show()

# Sample every trained model, then render the loss curves and sample grid
print("Generuję próbki...")
vae_samples = sample_vae(vae_model, n=16)
fm_samples = sample_fm(unet_fm, n=16)
ddpm_samples = sample_ddpm(unet_ddpm, n=16)

plot_results(
    history,
    vae_imgs=vae_samples,
    fm_imgs=fm_samples,
    ddpm_imgs=ddpm_samples,
)

In [None]:
# --- STEP: longer training (50 epochs) and verification ---
# NOTE: the models, optimizers and history are re-created below, so this run
# trains all three models from scratch for NEW_EPOCHS epochs. It does not
# continue the 10-epoch run above.

NEW_EPOCHS = 50
PREVIEW_EVERY = 10  # show a sample preview every PREVIEW_EVERY epochs (and after the last one)

vae_model = CelebVAE().to(DEVICE)
unet_fm = SimpleUNet().to(DEVICE)
unet_ddpm = SimpleUNet().to(DEVICE)
opt_vae = torch.optim.AdamW(vae_model.parameters(), lr=LR)
opt_fm = torch.optim.AdamW(unet_fm.parameters(), lr=LR)
opt_ddpm = torch.optim.AdamW(unet_ddpm.parameters(), lr=LR)
history = {'vae': [], 'fm': [], 'ddpm': []}

print(f"Rozpoczynam trening do {NEW_EPOCHS} epok...")

# Resume point derived from history; history was just reset, so this is 0.
current_epoch = len(history['vae'])

for epoch in range(current_epoch + 1, NEW_EPOCHS + 1):
    # Train with batch statistics, whatever mode a previous cell or a
    # preview left the models in.
    for net in (vae_model, unet_fm, unet_ddpm):
        net.train()
    l_vae, l_fm, l_ddpm = 0, 0, 0

    for batch, _ in tqdm(train_loader, desc=f"Epoch {epoch}/{NEW_EPOCHS}"):
        batch = batch.to(DEVICE)

        # 1. VAE
        opt_vae.zero_grad()
        loss_v = get_loss_vae(vae_model, batch)
        loss_v.backward()
        opt_vae.step()
        l_vae += loss_v.item()

        # 2. Flow Matching
        opt_fm.zero_grad()
        loss_f = get_loss_fm(unet_fm, batch)
        loss_f.backward()
        opt_fm.step()
        l_fm += loss_f.item()

        # 3. DDPM
        opt_ddpm.zero_grad()
        loss_d = get_loss_ddpm(unet_ddpm, batch)
        loss_d.backward()
        opt_ddpm.step()
        l_ddpm += loss_d.item()

    # Record the mean per-batch loss of this epoch
    history['vae'].append(l_vae / len(train_loader))
    history['fm'].append(l_fm / len(train_loader))
    history['ddpm'].append(l_ddpm / len(train_loader))

    # Periodic qualitative check of generation progress
    if epoch % PREVIEW_EVERY == 0 or epoch == NEW_EPOCHS:
        print(f"Podgląd po {epoch} epokach:")
        v_s = sample_vae(vae_model, n=4)
        f_s = sample_fm(unet_fm, n=4)
        d_s = sample_ddpm(unet_ddpm, n=4)

        # One row per model (VAE, FM, DDPM). FM/DDPM samples are not bounded
        # to [0, 1], while matplotlib expects float image data in [0, 1], so
        # clamp explicitly before display.
        combined = torch.cat([v_s, f_s, d_s], dim=0).clamp(0, 1)
        grid = torchvision.utils.make_grid(combined, nrow=4)
        plt.imshow(grid.permute(1, 2, 0).cpu(), cmap='gray')
        plt.title(f"Progress at epoch {epoch}")
        plt.axis('off')
        plt.show()

# --- FINAL VERIFICATION ---
print("Trening zakończony. Generuję ostateczne zestawienie...")
final_vae = sample_vae(vae_model, n=16)
final_fm = sample_fm(unet_fm, n=16, steps=100)  # more Euler steps for a finer integration
final_ddpm = sample_ddpm(unet_ddpm, n=16)

plot_results(history, final_vae, final_fm, final_ddpm)