# Serbian Dish Image Generator - GAN Implementation

This notebook implements a conditional GAN that generates Serbian dish images from text prompts, conditioned on precomputed CLIP text embeddings stored as `.npy` files. It combines the functionality of the individual scripts (`dataset.py`, `models.py`, `train_with_plotting.py`, etc.) into a single interactive environment.

## 1. Imports and Configuration

In [1]:
import os
import glob
import math
import json
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from torchvision.utils import save_image, make_grid

# Pick the compute backend: Apple MPS if present, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: mps


## 2. Dataset Loading (`dataset.py`)

In [2]:
IMG_EXTS = ('.jpg', '.jpeg', '.png', '.webp')


class CaptionImageSet(Dataset):
    """Paired (image, text-embedding) dataset.

    Expects ``<root>/images/<id>.<ext>`` (``ext`` in ``IMG_EXTS``, matched
    case-insensitively) and ``<root>/<embeddings_dir>/<id>.npy``. Only ids that
    have both files are kept, in sorted order.

    ``__getitem__`` returns ``(x, e)``:
        x: float tensor ``(3, size, size)``; bicubic resize, center crop, then
           normalized to [-1, 1].
        e: 1-D float32 tensor; the flattened embedding scaled to unit L2 norm.

    Missing directories or no matching pairs only print a warning (the dataset
    is then empty). Per-item failures raise ``RuntimeError``.
    """

    def __init__(self, root="data/processed", size=128, embeddings_dir="embedds"):
        img_dir = os.path.join(root, "images")
        # Allow custom embeddings directory name
        emb_dir = os.path.join(root, embeddings_dir)

        # id -> (extension preference rank, path). The exact file found here is
        # remembered so __getitem__ does not have to re-guess it. Re-guessing only
        # tried lowercase extensions (so "dish.JPG" failed on case-sensitive
        # filesystems), and an id present in two formats used to be listed twice.
        found = {}
        if os.path.exists(img_dir) and os.path.exists(emb_dir):
            for p in glob.glob(os.path.join(img_dir, "*")):
                base, ext = os.path.splitext(os.path.basename(p))
                ext = ext.lower()
                if ext not in IMG_EXTS or not os.path.exists(os.path.join(emb_dir, base + ".npy")):
                    continue
                rank = IMG_EXTS.index(ext)  # prefer .jpg, then the IMG_EXTS order
                if base not in found or (rank, p) < found[base]:
                    found[base] = (rank, p)
        else:
            print(f"Warning: directories {img_dir} or {emb_dir} not found.")

        ids = sorted(found)
        if not ids and os.path.exists(img_dir):
            print(f"Warning: No image/embedding pairs found in {root}")

        self.ids = ids
        self._paths = {id_: path for id_, (_, path) in found.items()}
        self.img_dir, self.emb_dir = img_dir, emb_dir

        if ids:
            first_emb_path = os.path.join(emb_dir, ids[0] + ".npy")
            try:
                first_emb = np.load(first_emb_path)
                print(f"Dataset info: {len(ids)} samples, embedding dim: {first_emb.shape}")
            except Exception:  # best-effort informational line only
                print(f"Dataset info: {len(ids)} samples")

        self.tf = T.Compose([
            T.Resize(size, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(size),
            T.ToTensor(),
            T.Normalize([0.5]*3, [0.5]*3)  # [-1,1]
        ])

    def __len__(self):
        return len(self.ids)

    def _image_path(self, id_):
        """Return the image path for ``id_``: the file found at init, else the first existing ``id_ + ext``."""
        path = getattr(self, "_paths", {}).get(id_)
        if path is not None:
            return path
        for ext in IMG_EXTS:
            candidate = os.path.join(self.img_dir, id_ + ext)
            if os.path.exists(candidate):
                return candidate
        raise FileNotFoundError(f"No image found for ID '{id_}' in {self.img_dir}")

    def __getitem__(self, i):
        id_ = self.ids[i]

        try:
            # The context manager closes the file handle once the pixels are decoded.
            with Image.open(self._image_path(id_)) as im:
                img = im.convert("RGB")
        except Exception as exc:
            raise RuntimeError(f"Failed to load image for ID '{id_}': {str(exc)}") from exc

        emb_path = os.path.join(self.emb_dir, id_ + ".npy")
        try:
            if not os.path.exists(emb_path):
                raise FileNotFoundError(f"No embedding found for ID '{id_}' at {emb_path}")

            # Flatten so the returned vector is always 1-D; train() sizes the model from .numel().
            emb = np.load(emb_path).astype("float32").reshape(-1)

            if emb.size == 0:
                raise ValueError(f"Empty embedding file for ID '{id_}'")
            # NaN/Inf would slip past the norm threshold below (NaN < 1e-8 is False).
            if not np.isfinite(emb).all():
                raise ValueError(f"Non-finite values in embedding for ID '{id_}'")

            norm = np.linalg.norm(emb)
            if norm < 1e-8:
                raise ValueError(f"Zero or near-zero embedding norm for ID '{id_}'")

            emb /= norm
            e = torch.from_numpy(emb)
        except Exception as exc:
            raise RuntimeError(f"Failed to load embedding for ID '{id_}': {str(exc)}") from exc

        try:
            x = self.tf(img)
        except Exception as exc:
            raise RuntimeError(f"Failed to transform image for ID '{id_}': {str(exc)}") from exc

        return x, e

## 3. Differentiable Augmentation (`diffaug.py`)

In [3]:
# Default augmentation order used by diff_augment(): color jitter, then shift, then cutout.
AUG_POLICIES = "color translation cutout".split()

def rand_brightness(x):
    """Add one random brightness offset per sample, drawn uniformly from [-0.5, 0.5)."""
    offset = torch.rand(x.size(0), 1, 1, 1, device=x.device)
    return x + (offset - 0.5)

def rand_saturation(x):
    """Scale each pixel's deviation from its per-pixel channel mean by a per-sample factor in [0, 2)."""
    gray = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, device=x.device) * 2
    return gray + (x - gray) * factor

def rand_contrast(x):
    """Scale each sample's deviation from its global mean by a random factor in [0.5, 1.5)."""
    center = x.mean(dim=(1, 2, 3), keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, device=x.device) + 0.5
    return center + (x - center) * factor

def color(x):
    """Colour jitter: random brightness, then saturation, then contrast."""
    x = rand_brightness(x)
    x = rand_saturation(x)
    return rand_contrast(x)

def translation(x, ratio=0.125):
    """Randomly shift each sample by an integer pixel offset.

    Horizontal offsets are drawn from ``[-int(W*ratio+0.5), +int(W*ratio+0.5)]``
    and vertical ones from the same range computed from H. Pixels uncovered by
    the shift repeat the border (``padding_mode='border'``). Resampling uses
    ``grid_sample`` on integer coordinates, so it stays differentiable with
    respect to ``x``.

    Args:
        x: image batch of shape (B, C, H, W).
        ratio: maximum shift as a fraction of the corresponding image side.

    Returns:
        A tensor with the same shape as ``x``.
    """
    B, C, H, W = x.shape
    # x-shifts move along the width axis and y-shifts along the height axis.
    # They were previously derived from H and W respectively, which is swapped
    # for non-square inputs.
    shift_x = int(W * ratio + 0.5)
    shift_y = int(H * ratio + 0.5)

    translation_x = torch.randint(-shift_x, shift_x + 1, size=(B,), device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=(B,), device=x.device)

    grid_x = torch.arange(W, dtype=torch.float32, device=x.device).view(1, 1, W).expand(B, H, W)
    grid_y = torch.arange(H, dtype=torch.float32, device=x.device).view(1, H, 1).expand(B, H, W)

    grid_x = grid_x + translation_x.view(B, 1, 1)
    grid_y = grid_y + translation_y.view(B, 1, 1)

    # With align_corners=True, pixel index i maps to 2*i/(size-1) - 1. Clamp the
    # denominator to 1 so size-1 axes do not divide by zero; every coordinate
    # then resolves to the only pixel.
    grid_x = 2.0 * grid_x / max(W - 1, 1) - 1.0
    grid_y = 2.0 * grid_y / max(H - 1, 1) - 1.0

    grid = torch.stack([grid_x, grid_y], dim=-1)
    return F.grid_sample(x, grid, mode='bilinear', padding_mode='border', align_corners=True)

def cutout(x, ratio=0.5):
    """Zero out one random rectangle per sample (applied to all channels).

    The rectangle is ``int(H*ratio+0.5)`` by ``int(W*ratio+0.5)`` pixels. Its
    center is drawn uniformly over the image. The top/left edge is clamped to 0
    (the hole keeps its size near the top/left border); the bottom/right edge is
    cut off by the image border.

    Args:
        x: image batch of shape (B, C, H, W).
        ratio: hole side length as a fraction of the image side.

    Returns:
        ``x`` multiplied by a {0, 1} mask of the same dtype.
    """
    B, _, H, W = x.shape
    cut_h, cut_w = int(H * ratio + 0.5), int(W * ratio + 0.5)

    # Built for the whole batch at once: the former per-sample loop called
    # .item() twice per image, forcing a host/device sync each time on GPU/MPS.
    cy = torch.randint(0, H, (B, 1, 1), device=x.device)
    cx = torch.randint(0, W, (B, 1, 1), device=x.device)
    top = (cy - cut_h // 2).clamp(min=0)
    left = (cx - cut_w // 2).clamp(min=0)

    rows = torch.arange(H, device=x.device).view(1, H, 1)
    cols = torch.arange(W, device=x.device).view(1, 1, W)
    hole = (rows >= top) & (rows < top + cut_h) & (cols >= left) & (cols < left + cut_w)

    mask = (~hole).unsqueeze(1).to(x.dtype)  # (B, 1, H, W), broadcast over channels
    return x * mask

def diff_augment(x, policies=AUG_POLICIES):
    """Apply the named differentiable augmentations to ``x`` in the order listed.

    Recognised names are "color", "translation" and "cutout"; any other name is
    silently skipped.
    """
    ops = {"color": color, "translation": translation, "cutout": cutout}
    for name in policies:
        op = ops.get(name)
        if op is not None:
            x = op(x)
    return x

## 4. Exponential Moving Average (`ema.py`)

In [4]:
class EMA:
    """Exponential moving average ("shadow" weights) of a model's trainable parameters.

    Only parameters with ``requires_grad=True`` are tracked. Buffers (for example,
    BatchNorm running statistics) are not averaged. ``shadow`` maps each
    parameter name to its averaged tensor.

    Args:
        model: module whose parameters are snapshotted into ``shadow`` at construction.
        decay: fraction of the previous average kept on each ``update``.
    """

    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {}
        self.backup = {}  # live weights saved by apply_shadow(), consumed by restore()
        for name, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[name] = p.detach().clone()

    @torch.no_grad()
    def update(self, model):
        """Blend current weights into the shadow: ``shadow = decay*shadow + (1-decay)*param``.

        Raises:
            KeyError: if ``model`` has a trainable parameter that is not in ``shadow``.
        """
        keep = self.decay
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if name not in self.shadow:
                raise KeyError(f"EMA has no shadow entry for parameter '{name}'")
            # In place: avoids allocating two fresh tensors per parameter per step.
            self.shadow[name].mul_(keep).add_(p.detach(), alpha=1.0 - keep)

    @torch.no_grad()
    def apply_shadow(self, model):
        """Save the live weights into ``backup``, then load the shadow weights into ``model``."""
        self.backup = {}
        for name, p in model.named_parameters():
            if p.requires_grad:
                self.backup[name] = p.detach().clone()
                # copy_ keeps the parameter's own device/dtype instead of rebinding .data.
                p.copy_(self.shadow[name])

    @torch.no_grad()
    def copy_to(self, model):
        """Copy shadow weights to model without creating backup."""
        for name, p in model.named_parameters():
            if p.requires_grad:
                p.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model):
        """Put back the weights saved by the last ``apply_shadow`` and clear ``backup``.

        Raises:
            KeyError: if no matching backup exists (e.g. ``apply_shadow`` was not called).
        """
        for name, p in model.named_parameters():
            if p.requires_grad:
                p.copy_(self.backup[name])
        self.backup = {}

## 5. Model Architectures (`models.py`)

In [None]:
class CondBN(nn.Module):
    """Conditional BatchNorm: non-affine BatchNorm2d modulated by a condition vector.

    ``out = bn(x) * (1 + gam(y)) + bet(y)``. The per-channel scale and shift come
    from linear maps of ``y``, expected as (B, cond_dim).
    """

    def __init__(self, ch, cond_dim):
        super().__init__()
        self.bn = nn.BatchNorm2d(ch, affine=False)
        self.gam = nn.Linear(cond_dim, ch)
        self.bet = nn.Linear(cond_dim, ch)

    def forward(self, x, y):
        normed = self.bn(x)
        scale = self.gam(y)[..., None, None]  # (B, ch) -> (B, ch, 1, 1)
        shift = self.bet(y)[..., None, None]
        return normed * (1 + scale) + shift

class ResBlockG(nn.Module):
    """Generator residual block with conditional BatchNorm and optional 2x nearest upsampling.

    Main path: CondBN -> ReLU -> (upsample) -> 3x3 conv -> CondBN -> ReLU -> 3x3 conv.
    Shortcut: (upsample) -> 1x1 conv when channel counts differ, identity otherwise.
    """

    def __init__(self, in_c, out_c, cond_dim, up=True):
        super().__init__()
        self.up = up
        self.cbn1 = CondBN(in_c, cond_dim)
        self.cbn2 = CondBN(out_c, cond_dim)
        self.conv1 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1)
        self.skip = nn.Identity() if in_c == out_c else nn.Conv2d(in_c, out_c, 1)

    def _maybe_upsample(self, t):
        return F.interpolate(t, scale_factor=2, mode="nearest") if self.up else t

    def forward(self, x, y):
        residual = self.conv1(self._maybe_upsample(F.relu(self.cbn1(x, y))))
        residual = self.conv2(F.relu(self.cbn2(residual, y)))
        shortcut = self.skip(self._maybe_upsample(x))
        return residual + shortcut

class Generator(nn.Module):
    """Conditional ResNet generator mapping (noise ``z``, embedding ``e``) to an image in [-1, 1].

    ``e`` is projected by a 2-layer MLP to a ``cond_hidden`` vector ``y``. ``[z, y]``
    seeds a 4x4 map with ``base_ch*16`` channels, and ``y`` also conditions every
    CondBN. Each ResBlockG doubles the resolution and halves the channel count
    until ``out_size`` is reached.

    Raises:
        ValueError: if ``out_size`` is not 4 * 2**k, or if halving ``base_ch*16``
            that many times would leave no channels.
    """

    def __init__(self, z_dim=128, cond_in=768, cond_hidden=256, base_ch=64, out_size=128):
        super().__init__()
        # Count the 2x upsampling stages from the 4x4 seed. Sizes that are not
        # 4*2**k used to be silently rounded up (e.g. 100 -> 128), producing fakes
        # whose size differs from the real images; reject them explicitly.
        n_up, size = 0, 4
        while size < out_size:
            size *= 2
            n_up += 1
        if size != out_size:
            raise ValueError(f"out_size must be 4 * 2**k (4, 8, 16, ..., 128, ...); got {out_size}")
        ch = base_ch * 16
        if ch >> n_up < 1:
            raise ValueError(
                f"base_ch={base_ch} is too small for out_size={out_size}: "
                f"{n_up} channel halvings of {ch} leave no channels"
            )

        self.cond = nn.Sequential(
            nn.Linear(cond_in, 512), nn.ReLU(True),
            nn.Linear(512, cond_hidden)
        )
        self.fc = nn.Linear(z_dim + cond_hidden, 4*4*ch)
        blocks = []
        for _ in range(n_up):
            blocks.append(ResBlockG(ch, ch // 2, cond_dim=cond_hidden, up=True))
            ch //= 2
        self.blocks = nn.ModuleList(blocks)
        self.bn = nn.BatchNorm2d(ch, affine=True)
        self.conv_out = nn.Conv2d(ch, 3, 3, 1, 1)

    def forward(self, z, e):
        """Generate images of shape (B, 3, out_size, out_size) from z (B, z_dim) and e (B, cond_in)."""
        y = self.cond(e)
        h = self.fc(torch.cat([z, y], dim=1))
        h = h.view(h.size(0), -1, 4, 4)  # (B, base_ch*16, 4, 4)
        for block in self.blocks:
            h = block(h, y)
        h = F.relu(self.bn(h))
        return torch.tanh(self.conv_out(h))

class ResBlockD(nn.Module):
    """Pre-activation residual block for the discriminator with optional 2x average-pool downsampling.

    Main path: ReLU -> 3x3 conv -> ReLU -> 3x3 conv (-> avg-pool).
    Shortcut: (avg-pool ->) 1x1 conv when channel counts differ, identity otherwise.
    """

    def __init__(self, in_c, out_c, down=True):
        super().__init__()
        self.down = down
        self.conv1 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1)
        self.skip = nn.Identity() if in_c == out_c else nn.Conv2d(in_c, out_c, 1)

    def forward(self, x):
        residual = self.conv2(F.relu(self.conv1(F.relu(x))))
        shortcut = x
        if self.down:
            residual = F.avg_pool2d(residual, 2)
            shortcut = F.avg_pool2d(shortcut, 2)
        return residual + self.skip(shortcut)

class Discriminator(nn.Module):
    """Projection discriminator scoring (image, embedding) pairs.

    A stack of downsampling ResBlockD blocks is followed by ReLU(conv) and a sum
    over spatial positions, giving a feature vector ``h``. The score is
    ``lin(h) + <embed(e), h>``.

    NOTE(review): the first block already halves the resolution, but the ``res``
    counter below starts at ``in_size``. For in_size=128 this builds 6
    downsampling blocks and a 2x2 (not 4x4) map before the spatial sum.
    Changing it would alter the architecture and break existing checkpoints.
    """

    def __init__(self, cond_in=768, base_ch=64, in_size=128):
        super().__init__()
        layers = [ResBlockD(3, base_ch, down=True)]
        width, res = base_ch, in_size
        while res > 4:
            layers.append(ResBlockD(width, width * 2, down=True))
            width *= 2
            res //= 2
        self.blocks = nn.ModuleList(layers)
        self.conv_out = nn.Conv2d(width, width, 3, 1, 1)
        self.lin = nn.Linear(width, 1)
        self.embed = nn.Linear(cond_in, width)

    def forward(self, x, e):
        feats = x
        for block in self.blocks:
            feats = block(feats)
        pooled = F.relu(self.conv_out(feats)).sum(dim=(2, 3))
        unconditional = self.lin(pooled)
        projection = (self.embed(e) * pooled).sum(dim=1, keepdim=True)
        return unconditional + projection

## 6. Training Logic (`train_with_plotting.py`)

In [5]:
def hinge_d(real_logits, fake_logits, mis_logits=None, mis_weight=0.5):
    """Discriminator hinge loss: ``mean(relu(1 - real)) + mean(relu(1 + fake))``.

    If ``mis_logits`` is given, ``mis_weight * mean(relu(1 + mis))`` is added.
    In ``train`` these are scores of real images paired with mismatched
    captions, which are pushed toward the "fake" side.
    """
    real_term = F.relu(1 - real_logits).mean()
    fake_term = F.relu(1 + fake_logits).mean()
    total = real_term + fake_term
    if mis_logits is None:
        return total
    return total + mis_weight * F.relu(1 + mis_logits).mean()

def hinge_g(fake_logits):
    """Generator hinge loss: the negated mean discriminator score of generated samples."""
    return fake_logits.mean().neg()

def r1_penalty(real_x, real_logits):
    """R1 gradient penalty: batch mean of the squared L2 norm of d(sum logits)/d(real_x).

    ``real_x`` must require grad and ``real_logits`` must be computed from it.
    The graph is kept (``create_graph=True``), so the penalty can be backpropagated.
    """
    (grad_x,) = torch.autograd.grad(
        outputs=real_logits.sum(),
        inputs=real_x,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    per_sample = grad_x.square().reshape(grad_x.size(0), -1).sum(dim=1)
    return per_sample.mean()

def save_samples(path, imgs):
    """Write a batch of [-1, 1] images to ``path`` as one grid, about sqrt(N) images per row."""
    per_row = int(math.sqrt(imgs.size(0)) + 0.5)
    unit_range = (imgs.clamp(-1, 1) + 1) / 2  # [-1, 1] -> [0, 1]
    save_image(make_grid(unit_range, nrow=per_row), path)

def plot_losses(log_data, output_dir):
    """Plot and save training curves.

    Draws a 2x2 dashboard: both losses, D loss alone, G loss alone, and the mean
    real/fake discriminator logits (missing logit keys count as 0). The figure
    is saved to ``<output_dir>/training_curves.png`` at 150 dpi, shown, then
    closed. Does nothing when ``log_data`` is empty.
    """
    if not log_data:
        return

    steps = [row['step'] for row in log_data]
    d_curve = [row['d_loss'] for row in log_data]
    g_curve = [row['g_loss'] for row in log_data]
    real_curve = [row.get('real_logits_mean', 0) for row in log_data]
    fake_curve = [row.get('fake_logits_mean', 0) for row in log_data]

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    ax_losses, ax_d = axes[0]
    ax_g, ax_logits = axes[1]

    def decorate(ax, ylabel, title, with_legend):
        ax.set_xlabel('Steps')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if with_legend:
            ax.legend()
        ax.grid(True, alpha=0.3)

    ax_losses.plot(steps, d_curve, label='D Loss', alpha=0.7)
    ax_losses.plot(steps, g_curve, label='G Loss', alpha=0.7)
    decorate(ax_losses, 'Loss', 'Training Losses', True)

    ax_d.plot(steps, d_curve, color='blue', alpha=0.7)
    decorate(ax_d, 'D Loss', 'Discriminator Loss', False)

    ax_g.plot(steps, g_curve, color='orange', alpha=0.7)
    decorate(ax_g, 'G Loss', 'Generator Loss', False)

    ax_logits.plot(steps, real_curve, label='Real Logits', alpha=0.7)
    ax_logits.plot(steps, fake_curve, label='Fake Logits', alpha=0.7)
    decorate(ax_logits, 'Logits', 'Discriminator Outputs', True)
    ax_logits.axhline(y=0, color='k', linestyle='--', alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_curves.png'), dpi=150, bbox_inches='tight')
    plt.show()
    plt.close()

In [6]:
def _config_to_dict(args):
    """Return the public, non-callable attributes of ``args`` (instance *and* class level) as a dict.

    ``vars(args)`` sees only instance attributes. For a config whose values are
    class attributes (like this notebook's ``Args``), it returns ``{}``, so
    checkpoints recorded no hyper-parameters. ``dir()`` sees both kinds and also
    works for ``argparse.Namespace``.
    """
    return {
        key: getattr(args, key)
        for key in dir(args)
        if not key.startswith("_") and not callable(getattr(args, key))
    }


def _mismatch_index(batch_size, device):
    """Random permutation of ``range(batch_size)`` with no fixed points when ``batch_size > 1``.

    The indices are shuffled, then each one is sent to its predecessor in the
    shuffled order (a single random cycle). With only "not the identity" as the
    requirement, about one pair per batch was left matched with its own caption
    while being labelled as mismatched. For ``batch_size == 1``, the identity is returned.
    """
    order = torch.randperm(batch_size)
    idx = torch.empty_like(order)
    idx[order] = order.roll(1)
    return idx.to(device)


def _optimizer_step(loss, opt, model, scaler, use_amp, grad_clip):
    """Zero grads, backprop ``loss``, optionally clip ``model``'s gradients, and step ``opt``.

    When ``use_amp`` is true, the loss is scaled through ``scaler`` and the
    gradients are unscaled before clipping. ``grad_clip <= 0`` disables clipping.
    """
    opt.zero_grad(set_to_none=True)
    if use_amp:
        scaler.scale(loss).backward()
        if grad_clip > 0:
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(opt)
        scaler.update()
    else:
        loss.backward()
        if grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        opt.step()


def train(args):
    """Train the conditional GAN configured by ``args``.

    Builds the dataset and loader, the Generator and Discriminator, Adam
    optimizers, and an EMA copy of G. It optionally resumes from
    ``args.resume_from``. Each iteration runs ``args.n_disc`` hinge-loss D
    updates followed by one G update, until ``args.iters`` iterations have run.
    The D updates may use DiffAugment, a mismatched-caption term, and an R1
    penalty on steps where ``step % args.r1_every == 0``.

    Writes into ``args.out_dir``:
        - ``training_log.json``, rewritten every ``args.log_every`` steps.
        - ``training_curves.png``.
        - EMA sample grids ``samples_XXXXXXX.png``.
        - Checkpoints ``ckpt_XXXXXXX.pt``, holding the state after iteration XXXXXXX.

    AMP (autocast + GradScaler) and pinned host memory are used only when
    ``device == "cuda"``.
    """
    print(f"Using device: {device}")
    on_cuda = device == "cuda"

    # ----- Data
    ds = CaptionImageSet(args.data_root, size=args.img_size, embeddings_dir=args.embeddings_dir)
    if len(ds) == 0:
        print("No data found. Exiting training.")
        return

    emb_dim = ds[0][1].numel()
    # pin_memory only helps (and is only supported without warnings) for CUDA transfers.
    dl = DataLoader(ds, batch_size=args.batch, shuffle=True,
                    num_workers=args.num_workers, drop_last=True, pin_memory=on_cuda)
    if len(dl) == 0:
        # With drop_last=True and fewer samples than one batch, the loader is empty;
        # next(iter(dl)) would raise StopIteration and the loop below would never advance.
        print(f"Only {len(ds)} samples for batch size {args.batch}; no full batch. Exiting training.")
        return

    # ----- Models
    G = Generator(z_dim=args.z_dim, cond_in=emb_dim, cond_hidden=args.cond_dim,
                  base_ch=args.base_ch, out_size=args.img_size).to(device)
    D = Discriminator(cond_in=emb_dim, base_ch=args.base_ch, in_size=args.img_size).to(device)

    optG = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.0, 0.9))
    optD = torch.optim.Adam(D.parameters(), lr=args.d_lr, betas=(0.0, 0.9))

    ema = EMA(G, decay=args.ema)

    # Use AMP if on CUDA
    scaler = torch.amp.GradScaler('cuda', enabled=on_cuda)

    # Resume from checkpoint if specified
    start_step = 0
    if args.resume_from and os.path.exists(args.resume_from):
        print(f"Loading checkpoint from {args.resume_from}")
        checkpoint = torch.load(args.resume_from, map_location=device)
        G.load_state_dict(checkpoint["G"])
        D.load_state_dict(checkpoint["D"])
        optG.load_state_dict(checkpoint["optG"])
        optD.load_state_dict(checkpoint["optD"])
        ema.shadow = checkpoint["ema"]
        # Checkpoints are written after iteration `step` has completed, so continue
        # with the next one instead of repeating it.
        start_step = checkpoint["step"] + 1
        print(f"Resumed from step {start_step}")

    os.makedirs(args.out_dir, exist_ok=True)

    # Fixed samples for visualization
    fixed = next(iter(dl))
    actual_n_sample = min(args.n_sample, args.batch)
    fixed_e = fixed[1][:actual_n_sample].to(device)
    fixed_z = torch.randn(actual_n_sample, args.z_dim, device=device)
    print(f"Using {actual_n_sample} samples for visualization")

    # Logging
    log_data = []
    log_file = os.path.join(args.out_dir, 'training_log.json')

    # Load existing log data if resuming
    if args.resume_from and os.path.exists(log_file):
        try:
            with open(log_file, 'r') as f:
                loaded = json.load(f)
            if not isinstance(loaded, list):
                raise ValueError("log file does not contain a list")
            # Steps at or after the resume point will be re-run and re-logged.
            log_data = [row for row in loaded if row.get('step', 0) < start_step]
            print(f"Loaded {len(log_data)} existing log entries")
        except (OSError, ValueError, AttributeError):
            print("Could not load existing log data, starting fresh")
            log_data = []

    step = start_step
    pbar = tqdm(total=args.iters, desc="training", initial=start_step)
    try:
        while step < args.iters:
            for x, e in dl:
                x, e = x.to(device, non_blocking=True), e.to(device, non_blocking=True)
                B = x.size(0)

                # Mismatched captions: every image is paired with a different sample's embedding.
                e_mis = e[_mismatch_index(B, e.device)]

                # ----------------- D update -----------------
                for _ in range(args.n_disc):
                    z = torch.randn(B, args.z_dim, device=device)
                    with torch.no_grad():
                        x_fake = G(z, e)

                    xr, xf = x, x_fake
                    if args.diffaugment:
                        xr = diff_augment(xr)
                        xf = diff_augment(xf)

                    xr.requires_grad_(True)  # needed by the R1 penalty
                    with torch.amp.autocast('cuda', enabled=on_cuda):
                        real_logits = D(xr, e)
                        fake_logits = D(xf, e)
                        mis_logits = D(xr, e_mis) if args.use_mismatch else None
                        d_loss = hinge_d(real_logits, fake_logits, mis_logits, mis_weight=args.mismatch_w)

                    # Lazy R1: applied every r1_every steps (gamma is not rescaled by r1_every).
                    total_d_loss = d_loss
                    if (step % args.r1_every) == 0:
                        with torch.amp.autocast('cuda', enabled=False):
                            r1 = r1_penalty(xr, real_logits)
                        total_d_loss = d_loss + (args.r1_gamma / 2) * r1

                    _optimizer_step(total_d_loss, optD, D, scaler, on_cuda, args.grad_clip)

                # ----------------- G update -----------------
                # Freeze D so g_loss.backward() does not compute unused D parameter gradients.
                D.requires_grad_(False)
                z = torch.randn(B, args.z_dim, device=device)
                with torch.amp.autocast('cuda', enabled=on_cuda):
                    x_fake = G(z, e)
                    xf = diff_augment(x_fake) if args.diffaugment else x_fake
                    g_loss = hinge_g(D(xf, e))
                _optimizer_step(g_loss, optG, G, scaler, on_cuda, args.grad_clip)
                D.requires_grad_(True)

                # EMA update
                ema.update(G)

                # ----------------- Logging -----------------
                if step % args.log_every == 0:
                    log_data.append({
                        'step': step,
                        'd_loss': d_loss.item(),
                        'g_loss': g_loss.item(),
                        'real_logits_mean': real_logits.mean().item(),
                        'fake_logits_mean': fake_logits.mean().item(),
                    })
                    with open(log_file, 'w') as f:
                        json.dump(log_data, f, indent=2)
                    if step > 0 and step % args.plot_every == 0:
                        plot_losses(log_data, args.out_dir)

                # ----------------- Samples -----------------
                if (step % args.sample_every) == 0:
                    ema.apply_shadow(G)
                    try:
                        with torch.no_grad():
                            imgs = G(fixed_z, fixed_e)
                        save_samples(os.path.join(args.out_dir, f"samples_{step:07d}.png"), imgs)
                    finally:
                        # Always put the live (non-EMA) weights back, even if sampling fails.
                        ema.restore(G)

                # ----------------- Checkpoints -----------------
                if (step % args.ckpt_every) == 0 and step > 0:
                    torch.save({
                        "G": G.state_dict(), "D": D.state_dict(),
                        "optG": optG.state_dict(), "optD": optD.state_dict(),
                        "ema": ema.shadow, "step": step,
                        "args": _config_to_dict(args)
                    }, os.path.join(args.out_dir, f"ckpt_{step:07d}.pt"))

                step += 1
                pbar.update(1)
                pbar.set_postfix({
                    "d_loss": f"{d_loss.item():.4f}",
                    "g_loss": f"{g_loss.item():.4f}"
                })

                if step >= args.iters:
                    break
    finally:
        pbar.close()

    # Final plot
    plot_losses(log_data, args.out_dir)
    print(f"Training completed! Check {args.out_dir} for results.")

## 7. Run Training

In [None]:
class Args:
    """Hyper-parameters for the notebook training run.

    All values are class attributes; ``train`` reads them by attribute access.
    """

    # --- data / output paths
    data_root = "data/processed"
    embeddings_dir = "enhanced_768d_embeds"
    out_dir = "runs/cgan_notebook"
    num_workers = 0

    # --- model sizes
    img_size = 128
    z_dim = 128
    cond_dim = 512
    base_ch = 32

    # --- optimisation
    batch = 4
    iters = 2000
    n_disc = 1          # D updates per G update
    g_lr = 5e-5
    d_lr = 1.5e-4
    grad_clip = 1.0     # <= 0 disables gradient clipping
    ema = 0.999         # EMA decay for the generator weights

    # --- regularisation / augmentation
    diffaugment = True
    use_mismatch = True
    mismatch_w = 0.5
    r1_every = 16
    r1_gamma = 2.0

    # --- logging / checkpointing (intervals in iterations)
    sample_every = 200
    ckpt_every = 500
    n_sample = 8        # capped at `batch` by train()
    log_every = 10
    plot_every = 100
    resume_from = ""    # checkpoint path; empty string starts from scratch

# Build the configuration and launch training; all outputs go to args.out_dir.
args = Args()
train(args)