In [None]:
"""
HW4: Train DCGAN, WGAN / WGAN‑GP, and ACGAN on CIFAR‑10 (PyTorch)

"""

import argparse
import os
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
import matplotlib.pyplot as plt
import csv

# ----------------------------------------------------
# Config
# ----------------------------------------------------

@dataclass
class GANConfig:
    """Hyper-parameters and I/O settings consumed by the GAN trainers in this file."""
    model: str = "dcgan"              # which trainer to run: dcgan | wgan | wgan-gp | acgan
    data_root: str = "./data"         # dataset root handed to the CIFAR-10 loader
    epochs: int = 50                  # number of full passes over the training set
    batch_size: int = 128             # DataLoader batch size
    z_dim: int = 100                  # length of the latent noise vector
    g_lr: float = 2e-4                # generator Adam learning rate (DCGAN / ACGAN trainers)
    d_lr: float = 2e-4                # discriminator Adam learning rate (DCGAN / ACGAN trainers)
    beta1: float = 0.5                # Adam beta1 (DCGAN / ACGAN trainers)
    beta2: float = 0.999              # Adam beta2 (DCGAN / ACGAN trainers)
    img_size: int = 64                # Resize / CenterCrop target used by the data loader
    channels: int = 3                 # NOTE(review): not read by any trainer visible in this file
    critic_iters: int = 5            # WGAN(-GP): critic updates per generator update
    clip_value: float = 0.01         # WGAN: critic weights are clamped to [-clip_value, clip_value]
    gp_lambda: float = 10.0          # WGAN-GP: multiplier on the gradient penalty term
    num_workers: int = 4              # DataLoader worker processes
    out_dir: str = "."                # base directory for checkpoints/ and samples/
    seed: int = 42                    # value passed to set_seed
    device: Optional[str] = None      # torch device string; None -> "cuda" if available else "cpu"

# ----------------------------------------------------
# Utils
# ----------------------------------------------------

def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: integer seed handed unchanged to every generator.
    """
    import random
    import numpy as np

    # Same order as before: stdlib, NumPy, torch CPU, torch CUDA.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def make_dirs(run_name, out_dir):
    """Create the checkpoint and sample directories for one run.

    Args:
        run_name: sub-directory name identifying the run.
        out_dir: base output directory.

    Returns:
        ``(checkpoint_dir, sample_dir)`` i.e.
        ``<out_dir>/checkpoints/<run_name>`` and ``<out_dir>/samples/<run_name>``.
        Both exist on return.
    """
    targets = tuple(
        os.path.join(out_dir, category, run_name)
        for category in ("checkpoints", "samples")
    )
    for directory in targets:
        os.makedirs(directory, exist_ok=True)
    return targets


def _save_loss_plot(curves: dict, out_path: str, title: str):
    """Plot per-epoch loss curves and save the figure to ``out_path``.

    Args:
        curves: ``{label: [per-epoch values]}``; empty series are skipped.
        out_path: destination image path. Its parent directory is created when
            the path has one; a bare filename is saved into the current
            working directory.
        title: figure title.
    """
    fig, ax = plt.subplots()
    try:
        drew_any = False
        for label, values in curves.items():
            if values:
                # Epochs are reported 1-based on the x axis.
                ax.plot(range(1, len(values) + 1), values, label=label)
                drew_any = True
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_title(title)
        if drew_any:
            # Skip the legend when nothing was drawn (avoids the
            # "no artists with labels" warning).
            ax.legend()
        parent = os.path.dirname(out_path)
        if parent:
            # os.makedirs('') raises, so only create a directory when the
            # path actually names one.
            os.makedirs(parent, exist_ok=True)
        fig.savefig(out_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving failed.
        plt.close(fig)


def _append_csv_row(csv_file: str, headers: list, row: list):
    exists = os.path.exists(csv_file)
    with open(csv_file, 'a', newline='') as f:
        w = csv.writer(f)
        if not exists:
            w.writerow(headers)
        w.writerow(row)

# ----------------------------------------------------
# Models
# ----------------------------------------------------

class DCGANGenerator(nn.Module):
    """Transposed-convolution generator: latent vector -> ``channels`` x 64 x 64 image.

    Args:
        z_dim: latent dimensionality.
        channels: number of output image channels.
        base: channel width multiplier (feature maps go base*8 -> base).
    """

    def __init__(self, z_dim, channels=3, base=64):
        super().__init__()
        # (in_channels, out_channels, stride, padding) of each BN+ReLU stage.
        # Spatial size: 1 -> 4 -> 8 -> 16 -> 32.
        hidden_stages = (
            (z_dim, base * 8, 1, 0),
            (base * 8, base * 4, 2, 1),
            (base * 4, base * 2, 2, 1),
            (base * 2, base, 2, 1),
        )
        modules = []
        for c_in, c_out, stride, pad in hidden_stages:
            modules.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False))
            modules.append(nn.BatchNorm2d(c_out))
            modules.append(nn.ReLU(True))
        # Final upsampling 32 -> 64, squashed into [-1, 1] by tanh.
        modules.append(nn.ConvTranspose2d(base, channels, 4, 2, 1, bias=False))
        modules.append(nn.Tanh())
        self.net = nn.Sequential(*modules)

    def forward(self, z):
        """Map latent codes ``z`` of shape (N, z_dim) to images."""
        n_samples, n_latent = z.size(0), z.size(1)
        # Treat each latent vector as a 1x1 feature map.
        as_feature_map = z.view(n_samples, n_latent, 1, 1)
        return self.net(as_feature_map)


class DCGANDiscriminator(nn.Module):
    """Strided-convolution discriminator / critic for 64 x 64 images.

    Args:
        channels: number of input image channels.
        base: channel width multiplier (feature maps go base -> base*8).
        sigmoid_out: append a sigmoid so outputs are probabilities; pass
            ``False`` to get raw scores (used as the WGAN critic).
    """

    def __init__(self, channels=3, base=64, sigmoid_out=True):
        super().__init__()
        widths = (base, base * 2, base * 4, base * 8)
        # First stage has no batch norm.
        stack = [
            nn.Conv2d(channels, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, True),
        ]
        # Each further stage halves the resolution and doubles the width.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False))
            stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.LeakyReLU(0.2, True))
        # 4x4 valid convolution collapses the 4x4 map into one score.
        stack.append(nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False))
        if sigmoid_out:
            stack.append(nn.Sigmoid())
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return one score per input image as a flat 1-D tensor."""
        scores = self.net(x)
        return scores.view(-1)


class ACGANGenerator(nn.Module):
    """Class-conditional generator: (noise, label) -> image.

    The label is embedded, concatenated with the noise, projected back to
    ``z_dim`` by a linear layer and then fed to a ``DCGANGenerator``.

    Args:
        z_dim: latent dimensionality.
        n_classes: number of class labels.
        channels: number of output image channels.
        base: channel width multiplier of the wrapped generator.
        emb_dim: size of the label embedding.
    """

    def __init__(self, z_dim, n_classes=10, channels=3, base=64, emb_dim=50):
        super().__init__()
        self.embed = nn.Embedding(n_classes, emb_dim)
        self.proj = nn.Linear(z_dim + emb_dim, z_dim)
        self.g = DCGANGenerator(z_dim, channels, base)

    def forward(self, z, y):
        """Generate images for noise ``z`` (N, z_dim) and integer labels ``y`` (N,)."""
        label_vec = self.embed(y)
        joint = torch.cat((z, label_vec), dim=1)
        conditioned = self.proj(joint)
        return self.g(conditioned)


class ACGANDiscriminator(nn.Module):
    """Discriminator with a shared trunk and two heads (real/fake and class).

    Args:
        n_classes: number of class logits produced by the auxiliary head.
        channels: number of input image channels.
        base: channel width multiplier of the trunk.
    """

    def __init__(self, n_classes=10, channels=3, base=64):
        super().__init__()
        widths = (base, base * 2, base * 4, base * 8)
        trunk = [
            nn.Conv2d(channels, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, True),
        ]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            trunk.append(nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False))
            trunk.append(nn.BatchNorm2d(c_out))
            trunk.append(nn.LeakyReLU(0.2, True))
        self.f = nn.Sequential(*trunk)
        # Adversarial head: probability that the input is real.
        self.adv = nn.Sequential(
            nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )
        # Auxiliary head: unnormalised class logits.
        self.cls = nn.Conv2d(widths[-1], n_classes, 4, 1, 0, bias=False)

    def forward(self, x):
        """Return ``(realness, class_logits)`` with shapes (N,) and (N, -1)."""
        features = self.f(x)
        realness = self.adv(features).view(-1)
        class_logits = self.cls(features).view(x.size(0), -1)
        return realness, class_logits

# ----------------------------------------------------
# Loss helpers
# ----------------------------------------------------

def gradient_penalty(C, real, fake, device):
    """WGAN-GP penalty: mean of (||grad_xhat C(xhat)||_2 - 1)^2 over the batch.

    Args:
        C: critic callable mapping a batch to one score per sample.
        real: batch of real samples.
        fake: batch of generated samples, same shape as ``real``.
        device: device on which the interpolation weights are drawn.

    Returns:
        Scalar tensor; built with ``create_graph=True`` so it can be
        back-propagated through.
    """
    n = real.size(0)
    # One interpolation weight per sample, broadcast over C/H/W.
    mix = torch.rand(n, 1, 1, 1, device=device)
    xhat = mix * real + (1 - mix) * fake
    xhat.requires_grad_(True)
    scores = C(xhat)
    (grad_wrt_x,) = torch.autograd.grad(
        outputs=scores,
        inputs=xhat,
        grad_outputs=torch.ones_like(scores),
        retain_graph=True,
        create_graph=True,
    )
    per_sample_norm = grad_wrt_x.view(n, -1).norm(2, 1)
    deviation = per_sample_norm - 1
    return (deviation ** 2).mean()

# ----------------------------------------------------
# Data
# ----------------------------------------------------

def get_loader(root, size, batch, workers):
    """Build a shuffled DataLoader over the CIFAR-10 training split.

    Images are resized and centre-cropped to ``size`` and normalised with
    mean 0.5 / std 0.5 per channel.

    Args:
        root: dataset directory (the dataset is downloaded there if missing).
        size: target size for Resize / CenterCrop.
        batch: batch size.
        workers: number of DataLoader worker processes.
    """
    half = [0.5] * 3
    preprocessing = transforms.Compose(
        [
            transforms.Resize(size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize(half, [0.5] * 3),
        ]
    )
    train_set = datasets.CIFAR10(
        root=root, train=True, download=True, transform=preprocessing
    )
    return DataLoader(
        train_set,
        batch_size=batch,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )

# ----------------------------------------------------
# Training loops (DCGAN / WGAN / ACGAN)
# ----------------------------------------------------

def weights_init(m):
    """DCGAN-style initialiser intended for ``module.apply(weights_init)``.

    Dispatches on the class name of ``m``:
      * name contains ``Conv``      -> weight ~ N(0, 0.02), bias = 0
      * name contains ``BatchNorm`` -> weight ~ N(1, 0.02), bias = 0
      * name contains ``Linear``    -> weight ~ N(0, 0.02), bias = 0
    Any other module is left untouched. Parameters that are missing or
    ``None`` (``bias=False`` layers, ``affine=False`` batch norm) are skipped
    instead of raising.

    NOTE: ``weights_init`` is defined again further down in this cell; the
    later definition is the one in effect at call time.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        weight_mean = 0.0
    elif 'BatchNorm' in kind:
        weight_mean = 1.0
    elif 'Linear' in kind:
        weight_mean = 0.0
    else:
        return

    weight = getattr(m, 'weight', None)
    if weight is not None:
        nn.init.normal_(weight, weight_mean, 0.02)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.zeros_(bias)


def train_dcgan(cfg):
    """Train a DCGAN on CIFAR-10 with the BCE objective.

    Args:
        cfg: config object; this function reads ``device``, ``data_root``,
            ``img_size``, ``batch_size``, ``num_workers``, ``z_dim``, ``g_lr``,
            ``d_lr``, ``beta1``, ``beta2``, ``epochs`` and ``out_dir``.

    Returns:
        ``(run_name, curves)`` where ``curves`` is ``{'D': [...], 'G': [...]}``
        holding the mean discriminator / generator loss of every epoch.

    Side effects (once per epoch, under the run's samples/ and checkpoints/
    directories): appends a row to ``loss_log.csv``, saves a sample grid,
    rewrites ``loss_curve.png`` and saves G / D ``state_dict`` checkpoints.
    """
    # An explicit cfg.device wins; otherwise use CUDA when it is available.
    device = torch.device(cfg.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    loader = get_loader(cfg.data_root, cfg.img_size, cfg.batch_size, cfg.num_workers)
    # NOTE(review): both networks use their default channels/base and a fixed
    # layer stack, independent of cfg.img_size / cfg.channels — presumably only
    # img_size=64 produces matching shapes; confirm before changing img_size.
    G = DCGANGenerator(cfg.z_dim).to(device)
    D = DCGANDiscriminator().to(device)
    G.apply(weights_init); D.apply(weights_init)
    opt_g = optim.Adam(G.parameters(), lr=cfg.g_lr, betas=(cfg.beta1, cfg.beta2))
    opt_d = optim.Adam(D.parameters(), lr=cfg.d_lr, betas=(cfg.beta1, cfg.beta2))
    bce = nn.BCELoss()
    # Fixed noise so the per-epoch sample grids are comparable across epochs.
    fixed_z = torch.randn(64, cfg.z_dim, device=device)

    run_name = f"dcgan_{cfg.img_size}"
    ckpt, samp = make_dirs(run_name, cfg.out_dir)

    d_losses, g_losses = [], []
    # The CSV is opened in append mode by _append_csv_row, so re-running into
    # the same out_dir adds rows to an existing log.
    csv_path = os.path.join(samp, 'loss_log.csv')

    for ep in range(1, cfg.epochs+1):
        # Running sums used to report the epoch-mean losses.
        d_sum = 0.0; g_sum = 0.0; n_batches = 0
        for real, _ in loader:
            real = real.to(device)
            bs = real.size(0)

            # --- D step: real -> 1, fake -> 0. The fake batch is detached so
            # no gradient reaches G here.
            z = torch.randn(bs, cfg.z_dim, device=device)
            fake = G(z).detach()
            loss_d = (
                bce(D(real), torch.ones(bs, device=device))
                + bce(D(fake), torch.zeros(bs, device=device))
            )
            opt_d.zero_grad(); loss_d.backward(); opt_d.step()

            # --- G step: fresh noise, try to make D output 1 for generated images.
            z = torch.randn(bs, cfg.z_dim, device=device)
            gen = G(z)
            loss_g = bce(D(gen), torch.ones(bs, device=device))
            opt_g.zero_grad(); loss_g.backward(); opt_g.step()

            d_sum += loss_d.item(); g_sum += loss_g.item(); n_batches += 1

        # max(1, ...) guards the division when the loader yields no batches.
        d_epoch = d_sum / max(1, n_batches)
        g_epoch = g_sum / max(1, n_batches)
        d_losses.append(d_epoch); g_losses.append(g_epoch)
        _append_csv_row(csv_path, ['epoch','D_loss','G_loss'], [ep, d_epoch, g_epoch])

        print(f"[DCGAN] Ep{ep} D:{d_epoch:.4f} G:{g_epoch:.4f}")
        save_grid(G, fixed_z, samp, ep)
        # The loss plot is overwritten every epoch with the full history so far.
        _save_loss_plot({'D': d_losses, 'G': g_losses}, os.path.join(samp, 'loss_curve.png'), 'DCGAN Loss')
        torch.save(G.state_dict(), os.path.join(ckpt, f"G_{ep:04d}.pt"))
        torch.save(D.state_dict(), os.path.join(ckpt, f"D_{ep:04d}.pt"))

    return run_name, {'D': d_losses, 'G': g_losses}


def train_wgan(cfg, gp=False):
    """Train a WGAN (weight clipping) or WGAN-GP (gradient penalty) on CIFAR-10.

    Args:
        cfg: config object; this function reads ``device``, ``data_root``,
            ``img_size``, ``batch_size``, ``num_workers``, ``z_dim``,
            ``critic_iters``, ``clip_value``, ``gp_lambda``, ``epochs`` and
            ``out_dir``. The optimiser settings are hard-coded below, so
            ``g_lr`` / ``d_lr`` / ``beta1`` / ``beta2`` are not used here.
        gp: ``False`` -> RMSprop + weight clipping; ``True`` -> Adam + gradient
            penalty.

    Returns:
        ``(run_name, curves)`` where ``curves`` is
        ``{'Critic': [...], 'G': [...]}`` with one mean loss per epoch.

    Side effects (once per epoch): appends to ``loss_log.csv``, saves a sample
    grid, rewrites ``loss_curve.png`` and saves G / C ``state_dict`` checkpoints.
    """
    # An explicit cfg.device wins; otherwise use CUDA when it is available.
    device = torch.device(cfg.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    loader = get_loader(cfg.data_root, cfg.img_size, cfg.batch_size, cfg.num_workers)
    G = DCGANGenerator(cfg.z_dim).to(device)
    # The critic is the DCGAN discriminator without the final sigmoid (raw scores).
    # NOTE(review): it still contains BatchNorm layers, which are also active
    # in the gp=True path — confirm this is intended for the gradient penalty.
    C = DCGANDiscriminator(sigmoid_out=False).to(device)  # critic
    G.apply(weights_init); C.apply(weights_init)

    # Clipped WGAN: RMSprop lr=5e-5. WGAN-GP: Adam lr=1e-4, betas=(0.0, 0.9).
    opt_g = optim.RMSprop(G.parameters(), lr=5e-5) if not gp else optim.Adam(G.parameters(), lr=1e-4, betas=(0.0, 0.9))
    opt_c = optim.RMSprop(C.parameters(), lr=5e-5) if not gp else optim.Adam(C.parameters(), lr=1e-4, betas=(0.0, 0.9))

    # Fixed noise so the per-epoch sample grids are comparable across epochs.
    fixed_z = torch.randn(64, cfg.z_dim, device=device)
    run_name = "wgan-gp" if gp else "wgan"
    ckpt, samp = make_dirs(run_name, cfg.out_dir)

    c_losses, g_losses = [], []
    csv_path = os.path.join(samp, 'loss_log.csv')

    for ep in range(1, cfg.epochs+1):
        c_sum = 0.0; g_sum = 0.0; n_batches = 0
        for real, _ in loader:
            real = real.to(device)
            bs = real.size(0)

            # Critic updates: cfg.critic_iters steps on the SAME real batch,
            # each with freshly sampled noise.
            # NOTE(review): with critic_iters < 1 this loop never runs and
            # loss_c below is unbound on the first batch.
            for _ in range(cfg.critic_iters):
                z = torch.randn(bs, cfg.z_dim, device=device)
                fake = G(z).detach()
                C.zero_grad()
                # Minimise -(E[C(real)] - E[C(fake)]), i.e. maximise the score gap.
                loss_c = -(C(real).mean() - C(fake).mean())
                if gp:
                    loss_c = loss_c + cfg.gp_lambda * gradient_penalty(C, real, fake, device)
                loss_c.backward(); opt_c.step()
                if not gp:
                    # Weight clipping is applied to every critic parameter
                    # after each step.
                    for p in C.parameters():
                        p.data.clamp_(-cfg.clip_value, cfg.clip_value)

            # Generator update: raise the critic's score on generated images.
            z = torch.randn(bs, cfg.z_dim, device=device)
            G.zero_grad()
            loss_g = -C(G(z)).mean()
            loss_g.backward(); opt_g.step()

            # loss_c here is the value from the last critic iteration only.
            c_sum += loss_c.item(); g_sum += loss_g.item(); n_batches += 1

        # max(1, ...) guards the division when the loader yields no batches.
        c_epoch = c_sum / max(1, n_batches)
        g_epoch = g_sum / max(1, n_batches)
        c_losses.append(c_epoch); g_losses.append(g_epoch)
        _append_csv_row(csv_path, ['epoch','Critic_loss','G_loss'], [ep, c_epoch, g_epoch])

        print(f"[{'WGAN-GP' if gp else 'WGAN'}] Ep{ep} C:{c_epoch:.4f} G:{g_epoch:.4f}")
        save_grid(G, fixed_z, samp, ep)
        _save_loss_plot({'Critic': c_losses, 'G': g_losses}, os.path.join(samp, 'loss_curve.png'), f"{'WGAN-GP' if gp else 'WGAN'} Loss")
        torch.save(G.state_dict(), os.path.join(ckpt, f"G_{ep:04d}.pt"))
        torch.save(C.state_dict(), os.path.join(ckpt, f"C_{ep:04d}.pt"))

    return run_name, {'Critic': c_losses, 'G': g_losses}


def train_acgan(cfg):
    """Train an auxiliary-classifier GAN (ACGAN) on CIFAR-10.

    Args:
        cfg: config object; this function reads ``device``, ``data_root``,
            ``img_size``, ``batch_size``, ``num_workers``, ``z_dim``, ``g_lr``,
            ``d_lr``, ``beta1``, ``beta2``, ``epochs`` and ``out_dir``.

    Returns:
        ``(run_name, curves)`` where ``curves`` is ``{'D': [...], 'G': [...]}``
        holding the mean discriminator / generator loss of every epoch.

    Side effects (once per epoch): appends to ``loss_log.csv``, saves a
    class-conditioned sample grid, rewrites ``loss_curve.png`` and saves
    G / D ``state_dict`` checkpoints.
    """
    # An explicit cfg.device wins; otherwise use CUDA when it is available.
    device = torch.device(cfg.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    loader = get_loader(cfg.data_root, cfg.img_size, cfg.batch_size, cfg.num_workers)
    # Both networks use their default n_classes=10.
    G = ACGANGenerator(cfg.z_dim).to(device)
    D = ACGANDiscriminator().to(device)
    G.apply(weights_init); D.apply(weights_init)

    opt_g = optim.Adam(G.parameters(), lr=cfg.g_lr, betas=(cfg.beta1, cfg.beta2))
    opt_d = optim.Adam(D.parameters(), lr=cfg.d_lr, betas=(cfg.beta1, cfg.beta2))
    # BCE on the adversarial head, cross-entropy on the class-logit head.
    adv_loss = nn.BCELoss(); cls_loss = nn.CrossEntropyLoss()

    # Fixed noise and labels (0..9 cycling) so sample grids are comparable
    # across epochs.
    fixed_z = torch.randn(64, cfg.z_dim, device=device)
    fixed_y = torch.tensor([i % 10 for i in range(64)], dtype=torch.long, device=device)

    run_name = "acgan"
    ckpt, samp = make_dirs(run_name, cfg.out_dir)

    d_losses, g_losses = [], []
    csv_path = os.path.join(samp, 'loss_log.csv')

    for ep in range(1, cfg.epochs+1):
        d_sum = 0.0; g_sum = 0.0; n_batches = 0
        for real, labels in loader:
            real = real.to(device); labels = labels.to(device)
            bs = real.size(0)

            # --- D step: adversarial loss + auxiliary classification loss
            valid = torch.ones(bs, device=device); fakef = torch.zeros(bs, device=device)
            D.zero_grad()
            adv_r, cls_r = D(real)
            d_adv_r = adv_loss(adv_r, valid)
            d_cls_r = cls_loss(cls_r, labels)

            # Fake batch with uniformly sampled labels; detached so no
            # gradient reaches G in the D step.
            z = torch.randn(bs, cfg.z_dim, device=device)
            y_fake = torch.randint(0, 10, (bs,), device=device)
            x_fake = G(z, y_fake).detach()
            adv_f, cls_f = D(x_fake)
            d_adv_f = adv_loss(adv_f, fakef)
            d_cls_f = cls_loss(cls_f, y_fake)

            # The classification terms are weighted by 0.5 in the D objective.
            d_loss = d_adv_r + d_adv_f + 0.5 * (d_cls_r + d_cls_f)
            d_loss.backward(); opt_d.step()

            # --- G step: fool the adversarial head and be classified as the
            # requested class (classification term is unweighted here).
            G.zero_grad()
            z = torch.randn(bs, cfg.z_dim, device=device)
            y = torch.randint(0, 10, (bs,), device=device)
            gen = G(z, y)
            adv_p, cls_p = D(gen)
            g_loss = adv_loss(adv_p, valid) + cls_loss(cls_p, y)
            g_loss.backward(); opt_g.step()

            d_sum += d_loss.item(); g_sum += g_loss.item(); n_batches += 1

        # max(1, ...) guards the division when the loader yields no batches.
        d_epoch = d_sum / max(1, n_batches)
        g_epoch = g_sum / max(1, n_batches)
        d_losses.append(d_epoch); g_losses.append(g_epoch)
        _append_csv_row(csv_path, ['epoch','D_loss','G_loss'], [ep, d_epoch, g_epoch])

        print(f"[ACGAN] Ep{ep} D:{d_epoch:.4f} G:{g_epoch:.4f}")
        # Sampling is done inline because G takes (z, y) rather than z alone.
        with torch.no_grad():
            fake_samples = G(fixed_z, fixed_y)
            grid = utils.make_grid(fake_samples, nrow=8, normalize=True, value_range=(-1, 1))
            utils.save_image(grid, os.path.join(samp, f"epoch_{ep:04d}.png"))
        _save_loss_plot({'D': d_losses, 'G': g_losses}, os.path.join(samp, 'loss_curve.png'), 'ACGAN Loss')
        torch.save(G.state_dict(), os.path.join(ckpt, f"G_{ep:04d}.pt"))
        torch.save(D.state_dict(), os.path.join(ckpt, f"D_{ep:04d}.pt"))

    return run_name, {'D': d_losses, 'G': g_losses}

# ----------------------------------------------------
# Misc helpers & tests
# ----------------------------------------------------

def weights_init(m):
    """DCGAN-style initialiser intended for ``module.apply(weights_init)``.

    Dispatches on the class name of ``m``:
      * name contains ``Conv``      -> weight ~ N(0, 0.02), bias = 0
      * name contains ``BatchNorm`` -> weight ~ N(1, 0.02), bias = 0
      * name contains ``Linear``    -> weight ~ N(0, 0.02), bias = 0
    Any other module is left untouched. Parameters that are missing or
    ``None`` (``bias=False`` layers, ``affine=False`` batch norm) are skipped
    instead of raising.

    NOTE: this re-defines (and therefore replaces) the ``weights_init``
    declared earlier in this cell.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        weight_mean = 0.0
    elif 'BatchNorm' in kind:
        weight_mean = 1.0
    elif 'Linear' in kind:
        weight_mean = 0.0
    else:
        return

    weight = getattr(m, 'weight', None)
    if weight is not None:
        nn.init.normal_(weight, weight_mean, 0.02)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.zeros_(bias)


def save_grid(G, z, sample_dir, epoch):
    """Render ``G(z)`` as an 8-per-row image grid and save it as a PNG.

    Args:
        G: generator callable taking only the noise batch.
        z: noise batch to render.
        sample_dir: existing directory that receives the image.
        epoch: epoch number, used in the file name ``epoch_XXXX.png``.
    """
    target = os.path.join(sample_dir, f"epoch_{epoch:04d}.png")
    with torch.no_grad():
        images = G(z)
        # Map the generator's [-1, 1] output range to [0, 1] for saving.
        tiled = utils.make_grid(images, nrow=8, normalize=True, value_range=(-1, 1))
        utils.save_image(tiled, target)


def self_test():
    """Smoke-test the output shapes of every generator / discriminator pair.

    Builds each model with default settings, pushes a batch of 16 through it
    and asserts the resulting shapes. Raises ``AssertionError`` on a mismatch
    and prints a confirmation line otherwise.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    z_dim, bs = 100, 16
    image_shape = (bs, 3, 64, 64)
    score_shape = (bs,)

    # Unconditional generator with the sigmoid discriminator.
    plain_gen = DCGANGenerator(z_dim).to(device)
    plain_disc = DCGANDiscriminator().to(device)
    noise = torch.randn(bs, z_dim, device=device)
    generated = plain_gen(noise)
    assert generated.shape == image_shape
    assert plain_disc(generated).shape == score_shape

    # Same architecture used as a critic (no sigmoid).
    critic = DCGANDiscriminator(sigmoid_out=False).to(device)
    assert critic(generated).shape == score_shape

    # Class-conditional pair.
    cond_gen = ACGANGenerator(z_dim).to(device)
    cond_disc = ACGANDiscriminator().to(device)
    labels = torch.randint(0, 10, (bs,), device=device)
    cond_generated = cond_gen(noise, labels)
    realness, class_logits = cond_disc(cond_generated)
    assert cond_generated.shape == image_shape
    assert realness.shape == score_shape
    assert class_logits.shape == (bs, 10)

    print("[SELF-TEST] All shapes OK.")

# ----------------------------------------------------
# Main (Jupyter/Colab friendly)
# ----------------------------------------------------

import sys

def get_parser():
    """Build the command-line parser for the training script.

    Returns:
        ``argparse.ArgumentParser`` whose parsed namespace carries one
        attribute per training setting plus the boolean ``self_test``.
        ``--clip`` is stored as ``clip_value`` and ``--gp`` as ``gp_lambda``.
    """
    parser = argparse.ArgumentParser(add_help=True)
    parser.add_argument(
        '--model',
        type=str,
        choices=['dcgan', 'wgan', 'wgan-gp', 'acgan'],
        default='dcgan',
    )

    # (flag, value type, default, explicit dest or None to derive it from the flag)
    value_options = (
        ('--data-root', str, './data', None),
        ('--epochs', int, 50, None),
        ('--batch-size', int, 128, None),
        ('--z-dim', int, 100, None),
        ('--g-lr', float, 2e-4, None),
        ('--d-lr', float, 2e-4, None),
        ('--beta1', float, 0.5, None),
        ('--beta2', float, 0.999, None),
        ('--img-size', int, 64, None),
        ('--critic-iters', int, 5, None),
        ('--clip', float, 0.01, 'clip_value'),
        ('--gp', float, 10.0, 'gp_lambda'),
        ('--num-workers', int, 4, None),
        ('--out-dir', str, '.', None),
        ('--seed', int, 42, None),
        ('--device', str, None, None),
    )
    for flag, value_type, default, dest in value_options:
        extra = {} if dest is None else {'dest': dest}
        parser.add_argument(flag, type=value_type, default=default, **extra)

    parser.add_argument('--self-test', action='store_true')
    return parser


def _cfg_from_namespace(ns: argparse.Namespace) -> GANConfig:
    """Copy the training settings from a parsed namespace into a ``GANConfig``.

    Only the attributes listed below are read; anything else on ``ns``
    (e.g. ``self_test``) is ignored, and ``GANConfig`` fields not listed keep
    their defaults.
    """
    copied_fields = (
        'model',
        'data_root',
        'epochs',
        'batch_size',
        'z_dim',
        'g_lr',
        'd_lr',
        'beta1',
        'beta2',
        'img_size',
        'critic_iters',
        'clip_value',
        'gp_lambda',
        'num_workers',
        'out_dir',
        'seed',
        'device',
    )
    settings = {name: getattr(ns, name) for name in copied_fields}
    return GANConfig(**settings)


def run(model: str='dcgan', data_root: str='./data', epochs: int=50, batch_size: int=128,
        z_dim: int=100, g_lr: float=2e-4, d_lr: float=2e-4, beta1: float=0.5, beta2: float=0.999,
        img_size: int=64, critic_iters: int=5, clip_value: float=0.01, gp_lambda: float=10.0,
        num_workers: int=4, out_dir: str='.', seed: int=42, device: Optional[str]=None,
        self_test_flag: bool=False):
    """Entry point for notebooks: configure, seed, then train or self-test.

    Every keyword mirrors the command-line option of the same name.

    Returns:
        ``(run_name, curves)`` from the selected trainer, or ``(None, {})``
        when ``self_test_flag`` is set.

    Raises:
        ValueError: ``model`` is not one of dcgan / wgan / wgan-gp / acgan.

    Example:
        run('dcgan', epochs=1, batch_size=64)
        run('wgan', epochs=1, critic_iters=5, clip_value=0.01)
        run('acgan', epochs=1)
        run(self_test_flag=True)
    """
    settings = dict(
        model=model, data_root=data_root, epochs=epochs, batch_size=batch_size,
        z_dim=z_dim, g_lr=g_lr, d_lr=d_lr, beta1=beta1, beta2=beta2,
        img_size=img_size, critic_iters=critic_iters, clip_value=clip_value,
        gp_lambda=gp_lambda, num_workers=num_workers, out_dir=out_dir,
        seed=seed, device=device,
    )
    # Go through the same namespace -> config path as the CLI entry point.
    args = argparse.Namespace(self_test=self_test_flag, **settings)
    cfg = _cfg_from_namespace(args)
    set_seed(cfg.seed)

    if args.self_test:
        self_test()
        return None, {}

    if cfg.model == 'dcgan':
        return train_dcgan(cfg)
    if cfg.model == 'wgan':
        return train_wgan(cfg, gp=False)
    if cfg.model == 'wgan-gp':
        return train_wgan(cfg, gp=True)
    if cfg.model == 'acgan':
        return train_acgan(cfg)
    raise ValueError(cfg.model)


def main():
    """Command-line entry point: parse arguments, then self-test or train.

    Unknown arguments are tolerated because IPython/Jupyter injects extra
    flags such as ``-f <kernel.json>``. The self-test path seeds with 0;
    training seeds with ``cfg.seed``. The trainers save their own loss
    curves, so their return value is not used further here.

    Raises:
        ValueError: the selected model has no trainer.
    """
    args, _unknown = get_parser().parse_known_args()

    if args.self_test:
        set_seed(0)
        self_test()
        return

    cfg = _cfg_from_namespace(args)
    set_seed(cfg.seed)

    trainers = {
        'dcgan': lambda: train_dcgan(cfg),
        'wgan': lambda: train_wgan(cfg, gp=False),
        'wgan-gp': lambda: train_wgan(cfg, gp=True),
        'acgan': lambda: train_acgan(cfg),
    }
    if cfg.model not in trainers:
        raise ValueError(cfg.model)
    _run_name, _curves = trainers[cfg.model]()


# Removed auto-execution block to prevent automatic training when running the cell in Jupyter.
# if __name__ == '__main__':
#     if 'ipykernel' in sys.modules:
#         sys.argv = [sys.argv[0]]
#     main()

# -------- Notebook helpers for combined plots --------

def combined_loss_plot(models_curves: dict, out_path: str = 'samples/combined/loss_curve.png'):
    """Save a 1x2 comparison figure of per-epoch losses for several models.

    The left panel shows every model's ``'G'`` curve; the right panel shows
    its ``'D'`` curve, falling back to ``'Critic'`` when ``'D'`` is missing
    or empty. Missing / empty curves are skipped.

    Args:
        models_curves: mapping of model name to its curves, e.g.::

            {
                'dcgan':   {'G': [...], 'D': [...]},
                'wgan':    {'G': [...], 'Critic': [...]},
                'wgan-gp': {'G': [...], 'Critic': [...]},
                'acgan':   {'G': [...], 'D': [...]},
            }
        out_path: destination image path. Its parent directory is created
            when the path has one; a bare filename is saved into the current
            working directory.
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    try:
        panels = (
            (axes[0], 'Generator loss',
             lambda curves: curves.get('G', [])),
            (axes[1], 'Discriminator / Critic loss',
             lambda curves: curves.get('D', []) or curves.get('Critic', [])),
        )
        for ax, panel_title, pick in panels:
            drew_any = False
            for name, curves in models_curves.items():
                series = pick(curves)
                if series:
                    # Epochs are reported 1-based on the x axis.
                    ax.plot(range(1, len(series) + 1), series, label=name)
                    drew_any = True
            ax.set_title(panel_title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            if drew_any:
                # Skip the legend when nothing was drawn (avoids the
                # "no artists with labels" warning).
                ax.legend()

        fig.tight_layout()
        parent = os.path.dirname(out_path)
        if parent:
            # os.makedirs('') raises, so only create a directory when the
            # path actually names one.
            os.makedirs(parent, exist_ok=True)
        fig.savefig(out_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving failed.
        plt.close(fig)


def capture_run(model: str, **kwargs):
    """Train ``model`` via ``run`` and return just its per-epoch loss curves.

    Args:
        model: model name forwarded to ``run``.
        **kwargs: any other ``run`` keyword arguments.

    Example:
        dc = capture_run('dcgan', epochs=10, num_workers=1)
    """
    _run_name, loss_curves = run(model=model, **kwargs)
    return loss_curves


def run_all_and_plot(epochs=10, num_workers=1, out_dir='.'):
    """Train DCGAN, WGAN and ACGAN in sequence and save one combined loss plot.

    Args:
        epochs: epochs per model.
        num_workers: DataLoader workers (kept low by default to avoid
            dataloader freezes on shared systems).
        out_dir: base output directory; the plot goes to
            ``<out_dir>/samples/combined/loss_curve.png``.
    """
    shared = dict(epochs=epochs, num_workers=num_workers, out_dir=out_dir)
    # (model name, model-specific overrides), trained in this order.
    jobs = (
        ('dcgan', {}),
        ('wgan', {'critic_iters': 5, 'clip_value': 0.01}),
        ('acgan', {}),
    )
    curves = {
        name: capture_run(name, **shared, **overrides)
        for name, overrides in jobs
    }
    out_path = os.path.join(out_dir, 'samples', 'combined', 'loss_curve.png')
    combined_loss_plot(curves, out_path)
    print(f"Saved combined loss plot to: {out_path}")


In [None]:
# Train each model for 10 epochs and keep only its per-epoch loss curves.
# num_workers=1 keeps the DataLoader light for notebook use.
dc = capture_run("dcgan", epochs=10, num_workers=1)
wg = capture_run("wgan", epochs=10, num_workers=1)
ac = capture_run("acgan", epochs=10, num_workers=1)

# Overlay the three runs: generator losses on the left, discriminator /
# critic losses on the right.
combined_loss_plot(
    {"dcgan": dc, "wgan": wg, "acgan": ac},
    out_path="samples/combined/loss_curve.png"
)


In [None]:
# Optional: run this cell if you want to see the DCGAN epochs; otherwise the two cells above are enough.

In [None]:
import argparse
import os
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
import matplotlib.pyplot as plt
import csv

# ----------------------------------------------------
# Config
# ----------------------------------------------------

@dataclass
class GANConfig:
    """Hyper-parameters and I/O settings consumed by the GAN trainers in this file."""
    model: str = "dcgan"              # which trainer to run: dcgan | wgan | wgan-gp | acgan
    data_root: str = "./data"         # dataset root handed to the CIFAR-10 loader
    epochs: int = 50                  # number of full passes over the training set
    batch_size: int = 128             # DataLoader batch size
    z_dim: int = 100                  # length of the latent noise vector
    g_lr: float = 2e-4                # generator Adam learning rate (DCGAN / ACGAN trainers)
    d_lr: float = 2e-4                # discriminator Adam learning rate (DCGAN / ACGAN trainers)
    beta1: float = 0.5                # Adam beta1 (DCGAN / ACGAN trainers)
    beta2: float = 0.999              # Adam beta2 (DCGAN / ACGAN trainers)
    img_size: int = 64                # Resize / CenterCrop target used by the data loader
    channels: int = 3                 # NOTE(review): not read by any trainer visible in this file
    critic_iters: int = 5            # WGAN(-GP): critic updates per generator update
    clip_value: float = 0.01         # WGAN: critic weights are clamped to [-clip_value, clip_value]
    gp_lambda: float = 10.0          # WGAN-GP: multiplier on the gradient penalty term
    num_workers: int = 4              # DataLoader worker processes
    out_dir: str = "."                # base directory for checkpoints/ and samples/
    seed: int = 42                    # value passed to set_seed
    device: Optional[str] = None      # torch device string; None -> "cuda" if available else "cpu"

# ----------------------------------------------------
# Utils
# ----------------------------------------------------

def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: integer seed handed unchanged to every generator.
    """
    import random
    import numpy as np

    # Same order as before: stdlib, NumPy, torch CPU, torch CUDA.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def make_dirs(run_name, out_dir):
    """Create the checkpoint and sample directories for one run.

    Args:
        run_name: sub-directory name identifying the run.
        out_dir: base output directory.

    Returns:
        ``(checkpoint_dir, sample_dir)`` i.e.
        ``<out_dir>/checkpoints/<run_name>`` and ``<out_dir>/samples/<run_name>``.
        Both exist on return.
    """
    targets = tuple(
        os.path.join(out_dir, category, run_name)
        for category in ("checkpoints", "samples")
    )
    for directory in targets:
        os.makedirs(directory, exist_ok=True)
    return targets


def _save_loss_plot(curves: dict, out_path: str, title: str):
    """Plot per-epoch loss curves and save the figure to ``out_path``.

    Args:
        curves: ``{label: [per-epoch values]}``; empty series are skipped.
        out_path: destination image path. Its parent directory is created when
            the path has one; a bare filename is saved into the current
            working directory.
        title: figure title.
    """
    fig, ax = plt.subplots()
    try:
        drew_any = False
        for label, values in curves.items():
            if values:
                # Epochs are reported 1-based on the x axis.
                ax.plot(range(1, len(values) + 1), values, label=label)
                drew_any = True
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_title(title)
        if drew_any:
            # Skip the legend when nothing was drawn (avoids the
            # "no artists with labels" warning).
            ax.legend()
        parent = os.path.dirname(out_path)
        if parent:
            # os.makedirs('') raises, so only create a directory when the
            # path actually names one.
            os.makedirs(parent, exist_ok=True)
        fig.savefig(out_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving failed.
        plt.close(fig)


def _append_csv_row(csv_file: str, headers: list, row: list):
    exists = os.path.exists(csv_file)
    with open(csv_file, 'a', newline='') as f:
        w = csv.writer(f)
        if not exists:
            w.writerow(headers)
        w.writerow(row)

# ----------------------------------------------------
# Models
# ----------------------------------------------------

class DCGANGenerator(nn.Module):
    """Transposed-convolution generator: latent vector -> ``channels`` x 64 x 64 image.

    Args:
        z_dim: latent dimensionality.
        channels: number of output image channels.
        base: channel width multiplier (feature maps go base*8 -> base).
    """

    def __init__(self, z_dim, channels=3, base=64):
        super().__init__()
        # (in_channels, out_channels, stride, padding) of each BN+ReLU stage.
        # Spatial size: 1 -> 4 -> 8 -> 16 -> 32.
        hidden_stages = (
            (z_dim, base * 8, 1, 0),
            (base * 8, base * 4, 2, 1),
            (base * 4, base * 2, 2, 1),
            (base * 2, base, 2, 1),
        )
        modules = []
        for c_in, c_out, stride, pad in hidden_stages:
            modules.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False))
            modules.append(nn.BatchNorm2d(c_out))
            modules.append(nn.ReLU(True))
        # Final upsampling 32 -> 64, squashed into [-1, 1] by tanh.
        modules.append(nn.ConvTranspose2d(base, channels, 4, 2, 1, bias=False))
        modules.append(nn.Tanh())
        self.net = nn.Sequential(*modules)

    def forward(self, z):
        """Map latent codes ``z`` of shape (N, z_dim) to images."""
        n_samples, n_latent = z.size(0), z.size(1)
        # Treat each latent vector as a 1x1 feature map.
        as_feature_map = z.view(n_samples, n_latent, 1, 1)
        return self.net(as_feature_map)


class DCGANDiscriminator(nn.Module):
    """Strided-convolution discriminator / critic for 64 x 64 images.

    Args:
        channels: number of input image channels.
        base: channel width multiplier (feature maps go base -> base*8).
        sigmoid_out: append a sigmoid so outputs are probabilities; pass
            ``False`` to get raw scores (used as the WGAN critic).
    """

    def __init__(self, channels=3, base=64, sigmoid_out=True):
        super().__init__()
        widths = (base, base * 2, base * 4, base * 8)
        # First stage has no batch norm.
        stack = [
            nn.Conv2d(channels, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, True),
        ]
        # Each further stage halves the resolution and doubles the width.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False))
            stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.LeakyReLU(0.2, True))
        # 4x4 valid convolution collapses the 4x4 map into one score.
        stack.append(nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False))
        if sigmoid_out:
            stack.append(nn.Sigmoid())
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return one score per input image as a flat 1-D tensor."""
        scores = self.net(x)
        return scores.view(-1)


class ACGANGenerator(nn.Module):
    """Class-conditional generator: (noise, label) -> image.

    The label is embedded, concatenated with the noise, projected back to
    ``z_dim`` by a linear layer and then fed to a ``DCGANGenerator``.

    Args:
        z_dim: latent dimensionality.
        n_classes: number of class labels.
        channels: number of output image channels.
        base: channel width multiplier of the wrapped generator.
        emb_dim: size of the label embedding.
    """

    def __init__(self, z_dim, n_classes=10, channels=3, base=64, emb_dim=50):
        super().__init__()
        self.embed = nn.Embedding(n_classes, emb_dim)
        self.proj = nn.Linear(z_dim + emb_dim, z_dim)
        self.g = DCGANGenerator(z_dim, channels, base)

    def forward(self, z, y):
        """Generate images for noise ``z`` (N, z_dim) and integer labels ``y`` (N,)."""
        label_vec = self.embed(y)
        joint = torch.cat((z, label_vec), dim=1)
        conditioned = self.proj(joint)
        return self.g(conditioned)


class ACGANDiscriminator(nn.Module):
    """Discriminator with a shared trunk and two heads (real/fake and class).

    Args:
        n_classes: number of class logits produced by the auxiliary head.
        channels: number of input image channels.
        base: channel width multiplier of the trunk.
    """

    def __init__(self, n_classes=10, channels=3, base=64):
        super().__init__()
        widths = (base, base * 2, base * 4, base * 8)
        trunk = [
            nn.Conv2d(channels, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, True),
        ]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            trunk.append(nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False))
            trunk.append(nn.BatchNorm2d(c_out))
            trunk.append(nn.LeakyReLU(0.2, True))
        self.f = nn.Sequential(*trunk)
        # Adversarial head: probability that the input is real.
        self.adv = nn.Sequential(
            nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )
        # Auxiliary head: unnormalised class logits.
        self.cls = nn.Conv2d(widths[-1], n_classes, 4, 1, 0, bias=False)

    def forward(self, x):
        """Return ``(realness, class_logits)`` with shapes (N,) and (N, -1)."""
        features = self.f(x)
        realness = self.adv(features).view(-1)
        class_logits = self.cls(features).view(x.size(0), -1)
        return realness, class_logits

# ----------------------------------------------------
# Loss helpers
# ----------------------------------------------------

def gradient_penalty(C, real, fake, device):
    """WGAN-GP penalty: mean of (||grad_xhat C(xhat)||_2 - 1)^2 over the batch.

    Args:
        C: critic callable mapping a batch to one score per sample.
        real: batch of real samples.
        fake: batch of generated samples, same shape as ``real``.
        device: device on which the interpolation weights are drawn.

    Returns:
        Scalar tensor; built with ``create_graph=True`` so it can be
        back-propagated through.
    """
    n = real.size(0)
    # One interpolation weight per sample, broadcast over C/H/W.
    mix = torch.rand(n, 1, 1, 1, device=device)
    xhat = mix * real + (1 - mix) * fake
    xhat.requires_grad_(True)
    scores = C(xhat)
    (grad_wrt_x,) = torch.autograd.grad(
        outputs=scores,
        inputs=xhat,
        grad_outputs=torch.ones_like(scores),
        retain_graph=True,
        create_graph=True,
    )
    per_sample_norm = grad_wrt_x.view(n, -1).norm(2, 1)
    deviation = per_sample_norm - 1
    return (deviation ** 2).mean()

# ----------------------------------------------------
# Data
# ----------------------------------------------------

def get_loader(root, size, batch, workers):
    """Build a shuffled DataLoader over the CIFAR-10 training split.

    Images are resized and centre-cropped to ``size`` and normalised with
    mean 0.5 / std 0.5 per channel.

    Args:
        root: dataset directory (the dataset is downloaded there if missing).
        size: target size for Resize / CenterCrop.
        batch: batch size.
        workers: number of DataLoader worker processes.
    """
    half = [0.5] * 3
    preprocessing = transforms.Compose(
        [
            transforms.Resize(size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize(half, [0.5] * 3),
        ]
    )
    train_set = datasets.CIFAR10(
        root=root, train=True, download=True, transform=preprocessing
    )
    return DataLoader(
        train_set,
        batch_size=batch,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )

# ----------------------------------------------------
# Training loops (DCGAN / WGAN / ACGAN)
# ----------------------------------------------------

def weights_init(m):
    """DCGAN-style initialiser intended for ``module.apply(weights_init)``.

    Dispatches on the class name of ``m``:
      * name contains ``Conv``      -> weight ~ N(0, 0.02), bias = 0
      * name contains ``BatchNorm`` -> weight ~ N(1, 0.02), bias = 0
      * name contains ``Linear``    -> weight ~ N(0, 0.02), bias = 0
    Any other module is left untouched. Parameters that are missing or
    ``None`` (``bias=False`` layers, ``affine=False`` batch norm) are skipped
    instead of raising.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        weight_mean = 0.0
    elif 'BatchNorm' in kind:
        weight_mean = 1.0
    elif 'Linear' in kind:
        weight_mean = 0.0
    else:
        return

    weight = getattr(m, 'weight', None)
    if weight is not None:
        nn.init.normal_(weight, weight_mean, 0.02)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.zeros_(bias)


def train_dcgan(cfg):
    """Train a DCGAN on CIFAR-10 with the BCE objective.

    Args:
        cfg: config object; this function reads ``device``, ``data_root``,
            ``img_size``, ``batch_size``, ``num_workers``, ``z_dim``, ``g_lr``,
            ``d_lr``, ``beta1``, ``beta2``, ``epochs`` and ``out_dir``.

    Returns:
        ``(run_name, curves)`` where ``curves`` is ``{'D': [...], 'G': [...]}``
        holding the mean discriminator / generator loss of every epoch.

    Side effects (once per epoch, under the run's samples/ and checkpoints/
    directories): appends a row to ``loss_log.csv``, saves a sample grid,
    rewrites ``loss_curve.png`` and saves G / D ``state_dict`` checkpoints.
    """
    # An explicit cfg.device wins; otherwise use CUDA when it is available.
    device = torch.device(cfg.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    loader = get_loader(cfg.data_root, cfg.img_size, cfg.batch_size, cfg.num_workers)
    # NOTE(review): both networks use their default channels/base and a fixed
    # layer stack, independent of cfg.img_size / cfg.channels — presumably only
    # img_size=64 produces matching shapes; confirm before changing img_size.
    G = DCGANGenerator(cfg.z_dim).to(device)
    D = DCGANDiscriminator().to(device)
    G.apply(weights_init); D.apply(weights_init)
    opt_g = optim.Adam(G.parameters(), lr=cfg.g_lr, betas=(cfg.beta1, cfg.beta2))
    opt_d = optim.Adam(D.parameters(), lr=cfg.d_lr, betas=(cfg.beta1, cfg.beta2))
    bce = nn.BCELoss()
    # Fixed noise so the per-epoch sample grids are comparable across epochs.
    fixed_z = torch.randn(64, cfg.z_dim, device=device)

    run_name = f"dcgan_{cfg.img_size}"
    ckpt, samp = make_dirs(run_name, cfg.out_dir)

    d_losses, g_losses = [], []
    # The CSV is opened in append mode by _append_csv_row, so re-running into
    # the same out_dir adds rows to an existing log.
    csv_path = os.path.join(samp, 'loss_log.csv')

    for ep in range(1, cfg.epochs+1):
        # Running sums used to report the epoch-mean losses.
        d_sum = 0.0; g_sum = 0.0; n_batches = 0
        for real, _ in loader:
            real = real.to(device)
            bs = real.size(0)

            # --- D step: real -> 1, fake -> 0. The fake batch is detached so
            # no gradient reaches G here.
            z = torch.randn(bs, cfg.z_dim, device=device)
            fake = G(z).detach()
            loss_d = (
                bce(D(real), torch.ones(bs, device=device))
                + bce(D(fake), torch.zeros(bs, device=device))
            )
            opt_d.zero_grad(); loss_d.backward(); opt_d.step()

            # --- G step: fresh noise, try to make D output 1 for generated images.
            z = torch.randn(bs, cfg.z_dim, device=device)
            gen = G(z)
            loss_g = bce(D(gen), torch.ones(bs, device=device))
            opt_g.zero_grad(); loss_g.backward(); opt_g.step()

            d_sum += loss_d.item(); g_sum += loss_g.item(); n_batches += 1

        # max(1, ...) guards the division when the loader yields no batches.
        d_epoch = d_sum / max(1, n_batches)
        g_epoch = g_sum / max(1, n_batches)
        d_losses.append(d_epoch); g_losses.append(g_epoch)
        _append_csv_row(csv_path, ['epoch','D_loss','G_loss'], [ep, d_epoch, g_epoch])

        print(f"[DCGAN] Ep{ep} D:{d_epoch:.4f} G:{g_epoch:.4f}")
        save_grid(G, fixed_z, samp, ep)
        # The loss plot is overwritten every epoch with the full history so far.
        _save_loss_plot({'D': d_losses, 'G': g_losses}, os.path.join(samp, 'loss_curve.png'), 'DCGAN Loss')
        torch.save(G.state_dict(), os.path.join(ckpt, f"G_{ep:04d}.pt"))
        torch.save(D.state_dict(), os.path.join(ckpt, f"D_{ep:04d}.pt"))

    return run_name, {'D': d_losses, 'G': g_losses}


def train_wgan(cfg, gp=False):
    """Train a Wasserstein GAN with weight clipping or a gradient penalty.

    The critic is a ``DCGANDiscriminator`` built with ``sigmoid_out=False``.
    For every mini-batch the critic is stepped ``cfg.critic_iters`` times (each
    time on that same real batch with freshly sampled noise), then the
    generator is stepped once. After every epoch the mean losses are appended
    to ``loss_log.csv``, a sample grid and a loss plot are written to the
    sample directory, and both state dicts are checkpointed.

    Args:
        cfg: Run configuration. Fields read here: ``device``, ``data_root``,
            ``img_size``, ``batch_size``, ``num_workers``, ``z_dim``,
            ``critic_iters``, ``gp_lambda``, ``clip_value``, ``epochs`` and
            ``out_dir``. ``g_lr``, ``d_lr``, ``beta1`` and ``beta2`` are NOT
            used: the optimiser settings are hard-coded below.
        gp: When true, ``cfg.gp_lambda * gradient_penalty(...)`` is added to
            the critic loss and no clipping is done. Otherwise every critic
            parameter is clamped to ``[-cfg.clip_value, cfg.clip_value]`` after
            each critic step.

    Returns:
        ``(run_name, curves)`` where ``run_name`` is ``"wgan-gp"`` or
        ``"wgan"`` and ``curves`` is ``{'Critic': [...], 'G': [...]}`` holding
        one mean loss per epoch. The critic value logged for a batch is the
        loss of the LAST critic iteration on that batch.

    Raises:
        ValueError: If ``cfg.critic_iters`` is smaller than 1.
    """
    # Without at least one critic step `loss_c` is never bound, and the first
    # batch would fail with an UnboundLocalError after a wasted generator step.
    if cfg.critic_iters < 1:
        raise ValueError(f"critic_iters must be >= 1, got {cfg.critic_iters}")

    device = torch.device(cfg.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    loader = get_loader(cfg.data_root, cfg.img_size, cfg.batch_size, cfg.num_workers)
    G = DCGANGenerator(cfg.z_dim).to(device)
    C = DCGANDiscriminator(sigmoid_out=False).to(device)  # critic
    G.apply(weights_init)
    C.apply(weights_init)

    # Optimiser choice depends only on `gp`; cfg learning rates are ignored.
    if gp:
        opt_g = optim.Adam(G.parameters(), lr=1e-4, betas=(0.0, 0.9))
        opt_c = optim.Adam(C.parameters(), lr=1e-4, betas=(0.0, 0.9))
    else:
        opt_g = optim.RMSprop(G.parameters(), lr=5e-5)
        opt_c = optim.RMSprop(C.parameters(), lr=5e-5)

    fixed_z = torch.randn(64, cfg.z_dim, device=device)
    run_name = "wgan-gp" if gp else "wgan"
    tag = 'WGAN-GP' if gp else 'WGAN'
    ckpt, samp = make_dirs(run_name, cfg.out_dir)

    c_losses, g_losses = [], []
    csv_path = os.path.join(samp, 'loss_log.csv')

    for ep in range(1, cfg.epochs + 1):
        c_sum = 0.0
        g_sum = 0.0
        n_batches = 0
        for real, _ in loader:
            real = real.to(device)
            bs = real.size(0)

            # critic updates
            for _ in range(cfg.critic_iters):
                z = torch.randn(bs, cfg.z_dim, device=device)
                # detach: the critic step must not back-propagate into G
                fake = G(z).detach()
                # Also clears the critic gradients left by the previous
                # generator backward pass.
                C.zero_grad()
                # Negated so that minimising widens the real/fake score gap.
                loss_c = -(C(real).mean() - C(fake).mean())
                if gp:
                    loss_c = loss_c + cfg.gp_lambda * gradient_penalty(C, real, fake, device)
                loss_c.backward()
                opt_c.step()
                if not gp:
                    # Weight clipping, done in place outside of autograd.
                    with torch.no_grad():
                        for p in C.parameters():
                            p.clamp_(-cfg.clip_value, cfg.clip_value)

            # generator update
            z = torch.randn(bs, cfg.z_dim, device=device)
            G.zero_grad()
            loss_g = -C(G(z)).mean()
            loss_g.backward()
            opt_g.step()

            c_sum += loss_c.item()
            g_sum += loss_g.item()
            n_batches += 1

        # max(1, ...) keeps an empty loader from dividing by zero.
        c_epoch = c_sum / max(1, n_batches)
        g_epoch = g_sum / max(1, n_batches)
        c_losses.append(c_epoch)
        g_losses.append(g_epoch)
        _append_csv_row(csv_path, ['epoch','Critic_loss','G_loss'], [ep, c_epoch, g_epoch])

        print(f"[{tag}] Ep{ep} C:{c_epoch:.4f} G:{g_epoch:.4f}")
        save_grid(G, fixed_z, samp, ep)
        _save_loss_plot({'Critic': c_losses, 'G': g_losses}, os.path.join(samp, 'loss_curve.png'), f"{tag} Loss")
        torch.save(G.state_dict(), os.path.join(ckpt, f"G_{ep:04d}.pt"))
        torch.save(C.state_dict(), os.path.join(ckpt, f"C_{ep:04d}.pt"))

    return run_name, {'Critic': c_losses, 'G': g_losses}


def train_acgan(cfg):
    """Train an auxiliary-classifier GAN (10 class labels).

    The discriminator returns an ``(adversarial, class_logits)`` pair. Its loss
    is the BCE adversarial loss on real and generated images plus half of the
    summed cross-entropy class losses. The generator is trained both to be
    judged real and to be classified as the label it was conditioned on.

    After each epoch the mean losses go to ``loss_log.csv``, an 8-column sample
    grid and the loss plot are saved, and both networks are checkpointed.

    Args:
        cfg: Run configuration. Fields read here: ``device``, ``data_root``,
            ``img_size``, ``batch_size``, ``num_workers``, ``z_dim``, ``g_lr``,
            ``d_lr``, ``beta1``, ``beta2``, ``epochs`` and ``out_dir``.

    Returns:
        ``("acgan", {'D': [...], 'G': [...]})`` with one mean loss per epoch.
    """
    device = torch.device(cfg.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    loader = get_loader(cfg.data_root, cfg.img_size, cfg.batch_size, cfg.num_workers)

    G = ACGANGenerator(cfg.z_dim).to(device)
    D = ACGANDiscriminator().to(device)
    G.apply(weights_init)
    D.apply(weights_init)

    adam_betas = (cfg.beta1, cfg.beta2)
    opt_g = optim.Adam(G.parameters(), lr=cfg.g_lr, betas=adam_betas)
    opt_d = optim.Adam(D.parameters(), lr=cfg.d_lr, betas=adam_betas)
    adv_loss = nn.BCELoss()
    cls_loss = nn.CrossEntropyLoss()

    # Fixed noise and labels 0..9 repeated, so sample grids are comparable
    # from one epoch to the next.
    fixed_z = torch.randn(64, cfg.z_dim, device=device)
    fixed_y = torch.arange(64, device=device) % 10

    run_name = "acgan"
    ckpt, samp = make_dirs(run_name, cfg.out_dir)
    csv_path = os.path.join(samp, 'loss_log.csv')

    d_losses = []
    g_losses = []

    for ep in range(1, cfg.epochs + 1):
        d_sum = 0.0
        g_sum = 0.0
        n_batches = 0

        for real, labels in loader:
            real = real.to(device)
            labels = labels.to(device)
            bs = real.size(0)
            real_target = torch.ones(bs, device=device)
            fake_target = torch.zeros(bs, device=device)

            # --- D: adversarial + auxiliary classification, real then fake
            D.zero_grad()
            real_adv, real_cls = D(real)
            loss_adv_real = adv_loss(real_adv, real_target)
            loss_cls_real = cls_loss(real_cls, labels)

            noise = torch.randn(bs, cfg.z_dim, device=device)
            sampled_labels = torch.randint(0, 10, (bs,), device=device)
            # detach: the discriminator step must not reach into G
            generated = G(noise, sampled_labels).detach()
            fake_adv, fake_cls = D(generated)
            loss_adv_fake = adv_loss(fake_adv, fake_target)
            loss_cls_fake = cls_loss(fake_cls, sampled_labels)

            d_loss = loss_adv_real + loss_adv_fake + 0.5 * (loss_cls_real + loss_cls_fake)
            d_loss.backward()
            opt_d.step()

            # --- G: fool the discriminator and hit the requested class
            G.zero_grad()
            noise = torch.randn(bs, cfg.z_dim, device=device)
            wanted = torch.randint(0, 10, (bs,), device=device)
            pred_adv, pred_cls = D(G(noise, wanted))
            g_loss = adv_loss(pred_adv, real_target) + cls_loss(pred_cls, wanted)
            g_loss.backward()
            opt_g.step()

            d_sum += d_loss.item()
            g_sum += g_loss.item()
            n_batches += 1

        # max(1, ...) keeps an empty loader from dividing by zero.
        denom = max(1, n_batches)
        d_epoch = d_sum / denom
        g_epoch = g_sum / denom
        d_losses.append(d_epoch)
        g_losses.append(g_epoch)
        _append_csv_row(csv_path, ['epoch','D_loss','G_loss'], [ep, d_epoch, g_epoch])

        print(f"[ACGAN] Ep{ep} D:{d_epoch:.4f} G:{g_epoch:.4f}")

        # Sampling is inlined because this generator also needs labels.
        with torch.no_grad():
            preview = G(fixed_z, fixed_y)
            sheet = utils.make_grid(preview, nrow=8, normalize=True, value_range=(-1, 1))
            utils.save_image(sheet, os.path.join(samp, f"epoch_{ep:04d}.png"))

        _save_loss_plot({'D': d_losses, 'G': g_losses}, os.path.join(samp, 'loss_curve.png'), 'ACGAN Loss')
        torch.save(G.state_dict(), os.path.join(ckpt, f"G_{ep:04d}.pt"))
        torch.save(D.state_dict(), os.path.join(ckpt, f"D_{ep:04d}.pt"))

    return run_name, {'D': d_losses, 'G': g_losses}

# ----------------------------------------------------
# Misc helpers & tests
# ----------------------------------------------------

def weights_init(m):
    """Initialise one module DCGAN-style; meant for ``nn.Module.apply``.

    Modules are matched by substring of their class name:

    * ``Conv``      -> weight ~ N(0.0, 0.02), bias = 0
    * ``BatchNorm`` -> weight ~ N(1.0, 0.02), bias = 0
    * ``Linear``    -> weight ~ N(0.0, 0.02), bias = 0

    Any other module is left untouched. A missing or ``None`` weight or bias
    (for example ``bias=False`` or ``affine=False``) is skipped rather than
    raising.

    Args:
        m: The module handed in by ``apply``.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        weight_mean = 0.0
    elif 'BatchNorm' in classname:
        weight_mean = 1.0
    elif 'Linear' in classname:
        weight_mean = 0.0
    else:
        return

    # getattr with a default also covers container classes whose name merely
    # contains one of the substrings but which own no parameters themselves.
    weight = getattr(m, 'weight', None)
    if weight is not None:
        nn.init.normal_(weight.data, weight_mean, 0.02)

    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.zeros_(bias)


def save_grid(G, z, sample_dir, epoch):
    """Render ``G(z)`` as an 8-column image grid and save it as a PNG.

    The file is written to ``<sample_dir>/epoch_<epoch, 4 digits>.png``. The
    generator is called with gradient tracking disabled; its train/eval mode
    is left as the caller set it.

    Args:
        G: Generator, called as ``G(z)``.
        z: Latent batch passed to the generator.
        sample_dir: Existing directory that receives the image.
        epoch: Integer epoch number used in the file name.
    """
    with torch.no_grad():
        images = G(z)
        # value_range=(-1, 1) tells make_grid which range to rescale to [0, 1].
        tiled = utils.make_grid(images, nrow=8, normalize=True, value_range=(-1, 1))
        target = os.path.join(sample_dir, f"epoch_{epoch:04d}.png")
        utils.save_image(tiled, target)


def self_test():
    """Smoke-test the output shapes of every generator/discriminator pair.

    Builds untrained networks on CUDA when available (CPU otherwise), pushes a
    batch of 16 random latent vectors through them and asserts:

    * both generators emit ``(16, 3, 64, 64)`` images,
    * the DCGAN discriminator, the sigmoid-free critic and the ACGAN
      adversarial head emit one value per sample,
    * the ACGAN class head emits ``(16, 10)`` values.

    Prints a confirmation line when every assertion holds.
    """
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    z_dim = 100
    bs = 16
    image_shape = (bs, 3, 64, 64)
    score_shape = (bs,)

    # Unconditional DCGAN generator + discriminator.
    gen = DCGANGenerator(z_dim).to(dev)
    disc = DCGANDiscriminator().to(dev)
    noise = torch.randn(bs, z_dim, device=dev)
    fake = gen(noise)
    assert fake.shape == image_shape
    assert disc(fake).shape == score_shape

    # Same discriminator architecture used as a WGAN critic.
    critic = DCGANDiscriminator(sigmoid_out=False).to(dev)
    assert critic(fake).shape == score_shape

    # Label-conditioned ACGAN pair.
    cond_gen = ACGANGenerator(z_dim).to(dev)
    cond_disc = ACGANDiscriminator().to(dev)
    labels = torch.randint(0, 10, (bs,), device=dev)
    cond_fake = cond_gen(noise, labels)
    adv_out, cls_out = cond_disc(cond_fake)
    assert cond_fake.shape == image_shape
    assert adv_out.shape == score_shape
    assert cls_out.shape == (bs, 10)

    print("[SELF-TEST] All shapes OK.")

# ----------------------------------------------------
# Main (Jupyter/Colab friendly)
# ----------------------------------------------------

import sys

def get_parser():
    """Build the command-line parser for the training script.

    Options are registered from a table, in the order shown. ``--clip`` is
    stored as ``clip_value`` and ``--gp`` as ``gp_lambda``; every other option
    is stored under its own name with dashes turned into underscores.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    option_table = (
        ('--model', dict(type=str, choices=['dcgan', 'wgan', 'wgan-gp', 'acgan'], default='dcgan')),
        ('--data-root', dict(type=str, default='./data')),
        ('--epochs', dict(type=int, default=50)),
        ('--batch-size', dict(type=int, default=128)),
        ('--z-dim', dict(type=int, default=100)),
        ('--g-lr', dict(type=float, default=2e-4)),
        ('--d-lr', dict(type=float, default=2e-4)),
        ('--beta1', dict(type=float, default=0.5)),
        ('--beta2', dict(type=float, default=0.999)),
        ('--img-size', dict(type=int, default=64)),
        ('--critic-iters', dict(type=int, default=5)),
        ('--clip', dict(dest='clip_value', type=float, default=0.01)),
        ('--gp', dict(dest='gp_lambda', type=float, default=10.0)),
        ('--num-workers', dict(type=int, default=4)),
        ('--out-dir', dict(type=str, default='.')),
        ('--seed', dict(type=int, default=42)),
        ('--device', dict(type=str, default=None)),
        ('--self-test', dict(action='store_true')),
    )

    parser = argparse.ArgumentParser(add_help=True)
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser


def _cfg_from_namespace(ns: argparse.Namespace) -> GANConfig:
    """Copy the training options from a namespace into a ``GANConfig``.

    Only the attributes listed below are transferred. Any ``GANConfig`` field
    that is not listed keeps its dataclass default, and extra namespace
    attributes such as ``self_test`` are ignored.

    Args:
        ns: Namespace carrying every attribute named in ``copied_fields``.

    Returns:
        A new ``GANConfig`` populated from ``ns``.
    """
    copied_fields = (
        'model',
        'data_root',
        'epochs',
        'batch_size',
        'z_dim',
        'g_lr',
        'd_lr',
        'beta1',
        'beta2',
        'img_size',
        'critic_iters',
        'clip_value',
        'gp_lambda',
        'num_workers',
        'out_dir',
        'seed',
        'device',
    )
    return GANConfig(**{field: getattr(ns, field) for field in copied_fields})


def run(model: str='dcgan', data_root: str='./data', epochs: int=50, batch_size: int=128,
        z_dim: int=100, g_lr: float=2e-4, d_lr: float=2e-4, beta1: float=0.5, beta2: float=0.999,
        img_size: int=64, critic_iters: int=5, clip_value: float=0.01, gp_lambda: float=10.0,
        num_workers: int=4, out_dir: str='.', seed: int=42, device: Optional[str]=None,
        self_test_flag: bool=False):
    """Notebook-friendly entry point. Call this from a Jupyter cell.

    The keyword arguments mirror the command-line options. They are packed
    into a config, the RNGs are seeded with ``seed``, and the trainer selected
    by ``model`` is run.

    Example:
        run('dcgan', epochs=1, batch_size=64)
        run('wgan', epochs=1, critic_iters=5, clip_value=0.01)
        run('acgan', epochs=1)
        run(self_test_flag=True)

    Returns:
        ``(run_name, curves)`` from the selected trainer, or ``(None, {})``
        when ``self_test_flag`` is set (only the shape self-test runs then).

    Raises:
        ValueError: If ``model`` is not one of ``dcgan``, ``wgan``,
            ``wgan-gp`` or ``acgan``.
    """
    args = argparse.Namespace(
        model=model,
        data_root=data_root,
        epochs=epochs,
        batch_size=batch_size,
        z_dim=z_dim,
        g_lr=g_lr,
        d_lr=d_lr,
        beta1=beta1,
        beta2=beta2,
        img_size=img_size,
        critic_iters=critic_iters,
        clip_value=clip_value,
        gp_lambda=gp_lambda,
        num_workers=num_workers,
        out_dir=out_dir,
        seed=seed,
        device=device,
        self_test=self_test_flag,
    )
    cfg = _cfg_from_namespace(args)
    set_seed(cfg.seed)

    if args.self_test:
        self_test()
        return None, {}

    # Model name -> deferred trainer call; the first match wins.
    launchers = (
        ('dcgan', lambda: train_dcgan(cfg)),
        ('wgan', lambda: train_wgan(cfg, gp=False)),
        ('wgan-gp', lambda: train_wgan(cfg, gp=True)),
        ('acgan', lambda: train_acgan(cfg)),
    )
    for name, launch in launchers:
        if cfg.model == name:
            return launch()
    raise ValueError(cfg.model)


def main():
    """Command-line entry point: parse ``sys.argv`` and run one trainer.

    With ``--self-test`` the RNGs are seeded with 0, the shape self-test runs
    and nothing is trained. Otherwise the RNGs are seeded from ``--seed`` and
    the trainer named by ``--model`` is run; its return value is discarded
    because each trainer already saves its own loss curve.

    Raises:
        ValueError: If the model name matches no trainer.
    """
    # In Jupyter, IPython passes extra flags like -f <kernel.json>, so
    # unrecognised arguments are tolerated instead of aborting.
    args, _ignored = get_parser().parse_known_args()

    if args.self_test:
        set_seed(0)
        self_test()
        return

    cfg = _cfg_from_namespace(args)
    set_seed(cfg.seed)

    # Model name -> deferred trainer call; the first match wins.
    launchers = (
        ('dcgan', lambda: train_dcgan(cfg)),
        ('wgan', lambda: train_wgan(cfg, gp=False)),
        ('wgan-gp', lambda: train_wgan(cfg, gp=True)),
        ('acgan', lambda: train_acgan(cfg)),
    )
    for name, launch in launchers:
        if cfg.model == name:
            # Single-run curve already saved within trainer.
            launch()
            return
    raise ValueError(cfg.model)


# Script entry point: runs main() only when this module executes as __main__.
# NOTE: the notebook helpers defined below this guard are not bound yet while
# main() runs; main() does not call them.
if __name__ == '__main__':
    # When an IPython kernel module is loaded, keep only the program name so
    # main() parses an otherwise empty command line and uses parser defaults.
    if 'ipykernel' in sys.modules:
        sys.argv = [sys.argv[0]]  # ignore notebook args like -f kernel.json
    main()

# -------- Notebook helpers for combined plots --------

def combined_loss_plot(models_curves: dict, out_path: str = 'samples/combined/loss_curve.png'):
    """Save a 1x2 figure comparing the per-epoch losses of several runs.

    models_curves: {
        'dcgan': {'G': [...], 'D': [...]},
        'wgan':  {'G': [...], 'Critic': [...]},
        'wgan-gp': {'G': [...], 'Critic': [...]},
        'acgan': {'G': [...], 'D': [...]}
    }

    The left panel shows the generator losses. The right panel shows each
    model's ``'D'`` curve, or its ``'Critic'`` curve when ``'D'`` is missing or
    empty. Empty curves are skipped, and a legend is drawn only on a panel
    that received at least one curve.

    Args:
        models_curves: Mapping of model name to its loss curves, as above.
        out_path: Destination image file. Its parent directory is created
            when needed; a bare file name is saved in the current directory.
    """
    # dirname is '' for a bare file name, and os.makedirs('') would raise.
    parent = os.path.dirname(out_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    fig, (ax_gen, ax_disc) = plt.subplots(1, 2, figsize=(10, 4))
    try:
        gen_plotted = False
        disc_plotted = False
        for name, curves in models_curves.items():
            # Generator
            g = curves.get('G', [])
            if g:
                ax_gen.plot(range(1, len(g) + 1), g, label=name)
                gen_plotted = True

            # Discriminator / Critic ('D' wins when both are present)
            d = curves.get('D', [])
            c = curves.get('Critic', [])
            y = d if d else c
            if y:
                ax_disc.plot(range(1, len(y) + 1), y, label=name)
                disc_plotted = True

        ax_gen.set_title('Generator loss')
        ax_gen.set_xlabel('Epoch')
        ax_gen.set_ylabel('Loss')
        ax_disc.set_title('Discriminator / Critic loss')
        ax_disc.set_xlabel('Epoch')
        ax_disc.set_ylabel('Loss')

        # legend() on an axes without labelled lines only emits a warning.
        if gen_plotted:
            ax_gen.legend()
        if disc_plotted:
            ax_disc.legend()

        fig.tight_layout()
        # Save this figure explicitly rather than whatever pyplot considers
        # the current figure.
        fig.savefig(out_path, bbox_inches='tight')
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)


def capture_run(model: str, **kwargs):
    """Train via ``run`` and hand back just the per-epoch loss curves.

    Args:
        model: Model name forwarded to ``run``.
        **kwargs: Any further keyword arguments accepted by ``run``.

    Returns:
        The curves dictionary (second element of ``run``'s result); the run
        name is dropped.

    Example:
        dc = capture_run('dcgan', epochs=10, num_workers=1)
    """
    outcome = run(model=model, **kwargs)
    _run_name, loss_curves = outcome
    return loss_curves


def run_all_and_plot(epochs=10, num_workers=1, out_dir='.'):
    """Train DCGAN, WGAN and ACGAN in that order, then plot them together.

    Every run shares ``epochs``, ``num_workers`` and ``out_dir``; the WGAN run
    additionally pins ``critic_iters=5`` and ``clip_value=0.01``. The combined
    figure is written to ``<out_dir>/samples/combined/loss_curve.png`` and its
    path is printed.

    Args:
        epochs: Number of epochs for each model.
        num_workers: DataLoader workers; kept low by default to avoid
            dataloader freezes on shared systems.
        out_dir: Root directory for all outputs.
    """
    shared = {'epochs': epochs, 'num_workers': num_workers, 'out_dir': out_dir}
    jobs = (
        ('dcgan', {}),
        ('wgan', {'critic_iters': 5, 'clip_value': 0.01}),
        ('acgan', {}),
    )

    curves = {}
    for model_name, extra in jobs:
        curves[model_name] = capture_run(model_name, **shared, **extra)

    out_path = os.path.join(out_dir, 'samples', 'combined', 'loss_curve.png')
    combined_loss_plot(curves, out_path)
    print(f"Saved combined loss plot to: {out_path}")
