In [None]:
# gan_pytorch.py
# A complete DCGAN training script from scratch (PyTorch).
# - Trains on MNIST (default), Fashion-MNIST, or CIFAR-10
# - Saves sample grids per epoch to ./samples
# - Saves checkpoints to ./checkpoints
#
# Usage:
#   python gan_pytorch.py --dataset mnist --epochs 20
#   python gan_pytorch.py --dataset cifar10 --epochs 50 --batch_size 128 --lr 2e-4
# Generate images from a trained checkpoint:
#   python gan_pytorch.py --generate 64 --checkpoint ./checkpoints/G_epoch_20.pt --dataset mnist
#
# Tip: For CPU only, it'll still run (just slower).

import os
import math
import random
import argparse
from pathlib import Path
from typing import Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils as vutils
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import glob

import torchvision.transforms as T



def seed_all(seed=42):
    """Seed every random number generator the scripts in this file draw from.

    Covers Python's ``random`` module (used e.g. by UnpairedImageFolder to pick
    domain-B images), NumPy's global RNG, and PyTorch on CPU and all CUDA devices.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    # numpy is imported by this file but was previously left unseeded
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def make_dirs(*dirs):
    """Create output directories; existing directories are left untouched.

    Args:
        *dirs: directory paths to create (intermediate parents included).
            When omitted, every output folder written by the trainers in this
            file is created:
              - ``samples`` / ``checkpoints``                  (gan_train)
              - ``samples_cyclegan`` / ``checkpoints_cyclegan`` (train_cyclegan)
              - ``samples_wgan`` / ``checkpoints_wgan``         (train_wgan)
    """
    targets = dirs or (
        "samples", "checkpoints",
        "samples_cyclegan", "checkpoints_cyclegan",
        "samples_wgan", "checkpoints_wgan",
    )
    for d in targets:
        Path(d).mkdir(parents=True, exist_ok=True)



def weights_init(m):
    """Initialize weights the DCGAN way (from the original paper).

    Intended for ``model.apply(weights_init)``:
      - modules whose class name contains "Conv" (Conv*/ConvTranspose*):
        weight ~ N(0, 0.02); biases are left untouched.
      - modules whose class name contains "BatchNorm": weight ~ N(1, 0.02),
        bias = 0.

    Modules that match by name but carry no such parameter (BatchNorm with
    ``affine=False``, user containers like ``ConvBlock``) are skipped instead of
    raising AttributeError; ``apply`` still reaches their children.
    """
    classname = type(m).__name__
    # "ConvTranspose" already contains "Conv", so one substring test covers both
    if "Conv" in classname:
        weight = getattr(m, "weight", None)
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        weight = getattr(m, "weight", None)
        bias = getattr(m, "bias", None)
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

def cosine_sim(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Cosine similarity between ``a`` and ``b`` along their last dimension.

    Each input is scaled by ``1 / (||v|| + eps)`` before the dot product, so
    all-zero vectors give 0 instead of NaN. Leading dimensions broadcast as in
    elementwise multiplication.
    """
    def _unit(v: torch.Tensor) -> torch.Tensor:
        return v / (v.norm(dim=-1, keepdim=True) + eps)

    return torch.sum(_unit(a) * _unit(b), dim=-1)

def pairwise_dot(x: torch.Tensor) -> torch.Tensor:
    """Gram matrix of the rows of ``x``.

    For a 2-D ``x`` of shape (n, d) the result is (n, n) with entry (i, j)
    equal to the dot product of rows i and j.
    """
    return torch.matmul(x, x.t())


# -----------------------------
# Data
# -----------------------------
def get_data(
    dataset: str,
    image_size: int,
    batch_size: int,
    num_workers: int = 4,
) -> Tuple[torch.utils.data.DataLoader, int]:
    """Build a shuffled training DataLoader for a torchvision dataset.

    Args:
        dataset: "mnist", "fashion-mnist" or "cifar10" (case-insensitive).
        image_size: passed to ``transforms.Resize``.
        batch_size: samples per batch.
        num_workers: DataLoader worker processes.

    Returns:
        (loader, nc) where nc is the channel count (1 for the MNIST variants,
        3 for CIFAR-10). Pixels are normalized with mean 0.5 / std 0.5 per
        channel. Data is downloaded to ./data if missing.

    Raises:
        ValueError: for an unsupported dataset name.
    """
    name = dataset.lower()
    if name not in {"mnist", "fashion-mnist", "cifar10"}:
        raise ValueError("dataset must be one of: mnist, fashion-mnist, cifar10")

    nc = 3 if name == "cifar10" else 1
    stats = (0.5,) * nc  # identical mean and std for every channel
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(stats, stats),
    ])

    dataset_cls = {
        "mnist": datasets.MNIST,
        "fashion-mnist": datasets.FashionMNIST,
        "cifar10": datasets.CIFAR10,
    }[name]
    ds = dataset_cls(root="./data", train=True, download=True, transform=transform)

    loader = torch.utils.data.DataLoader(
        ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True
    )
    return loader, nc



# DCGAN

In [None]:
# -----------------------------
# Models: DCGAN-style Generator / Discriminator
# -----------------------------
class Generator(nn.Module):
    """DCGAN generator mapping latent codes (N, nz, 1, 1) to images (N, nc, 64, 64).

    Four ConvTranspose2d -> BatchNorm2d -> ReLU stages grow the spatial size
    1 -> 4 -> 8 -> 16 -> 32 while channels go 8*ngf -> 4*ngf -> 2*ngf -> ngf.
    A final ConvTranspose2d + Tanh emits nc channels at 64x64 with values in [-1, 1].
    """

    def __init__(self, nz: int, ngf: int, nc: int):
        """
        nz  = latent dim
        ngf = generator feature multiplier
        nc  = number of output channels (1 for MNIST, 3 for CIFAR-10)
        """
        super().__init__()
        widths = [nz, ngf * 8, ngf * 4, ngf * 2, ngf]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # first stage projects 1x1 -> 4x4 (stride 1, no padding); later stages double H and W
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        layers.extend([
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        self.main = nn.Sequential(*layers)

    def forward(self, z):
        return self.main(z)


class Discriminator(nn.Module):
    """DCGAN discriminator mapping (N, nc, 64, 64) images to (N,) raw logits.

    No Sigmoid is applied; the training code pairs these logits with
    BCEWithLogitsLoss.
    """

    def __init__(self, nc: int, ndf: int):
        """
        nc  = number of channels
        ndf = discriminator feature multiplier
        """
        super().__init__()
        # first stride-2 stage works on raw pixels, so it has no BatchNorm
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # three more stride-2 stages: ndf -> 2ndf -> 4ndf -> 8ndf, 32x32 -> 4x4
        for mult in (1, 2, 4):
            c_in, c_out = ndf * mult, ndf * mult * 2
            layers.extend([
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # 4x4 kernel, no padding: collapses the 4x4 map to one logit per sample
        layers.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return torch.flatten(self.main(x))  # logits shape: (N,)

# -----------------------------
# Training
# -----------------------------
def _save_dcgan_grid(netG, noise, path):
    """Run ``netG`` on ``noise`` without gradients and save an 8-column image grid to ``path``.

    The generator's train/eval mode is not changed. Outputs are rescaled from
    [-1, 1] for saving (``normalize=True, value_range=(-1, 1)``).
    """
    with torch.no_grad():
        imgs = netG(noise).cpu()
    grid = vutils.make_grid(imgs, nrow=8, normalize=True, value_range=(-1, 1))
    vutils.save_image(grid, path)


def gan_train(args):
    """Train the DCGAN Generator/Discriminator pair on ``args.dataset``.

    Objective: BCEWithLogitsLoss with one-sided label smoothing (real target 0.9).
    Both networks use Adam(lr=args.lr, betas=(args.beta1, 0.999)).
    Each batch gets one D step followed by one G step that uses the
    non-saturating loss BCE(D(G(z)), 1).

    Side effects:
        - a grid of 64 fixed-noise samples every ``args.sample_interval`` steps
          and at the end of each epoch -> ./samples/
        - generator and discriminator state dicts every epoch -> ./checkpoints/

    Raises:
        ValueError: if ``args.image_size`` is not 64. Both networks are
            hard-wired to 64x64, and any other size would fail later with a
            shape mismatch.
    """
    if args.image_size != 64:
        raise ValueError(
            f"DCGAN Generator/Discriminator are built for 64x64 images; got --image_size {args.image_size}"
        )

    make_dirs()
    # the folders this trainer writes into
    Path("samples").mkdir(exist_ok=True)
    Path("checkpoints").mkdir(exist_ok=True)
    seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    print(f"Device: {device}")

    # Data
    dataloader, nc = get_data(args.dataset, args.image_size, args.batch_size, args.num_workers)
    print(f"Dataset: {args.dataset} | nc={nc} | batches/epoch={len(dataloader)}")

    # Models
    netG = Generator(args.nz, args.ngf, nc).to(device)
    netD = Discriminator(nc, args.ndf).to(device)
    netG.apply(weights_init)
    netD.apply(weights_init)
    # nothing below switches to eval(), so train mode is set once
    netG.train()
    netD.train()

    # Loss and Optim
    criterion = nn.BCEWithLogitsLoss()
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=args.lr, betas=(args.beta1, 0.999))

    fixed_noise = torch.randn(64, args.nz, 1, 1, device=device)  # same z every time -> comparable samples

    step = 0
    for epoch in range(1, args.epochs + 1):
        for i, (real, _) in enumerate(dataloader):
            real = real.to(device)
            bsz = real.size(0)

            # ---- Update D: maximize log(D(x)) + log(1 - D(G(z)))
            optimizerD.zero_grad(set_to_none=True)

            logits_real = netD(real)
            # one-sided label smoothing: real target 0.9 instead of 1.0
            real_labels = torch.full((bsz,), 0.9, device=device)
            d_loss_real = criterion(logits_real, real_labels)

            noise = torch.randn(bsz, args.nz, 1, 1, device=device)
            fake = netG(noise).detach()  # detach: D's loss must not reach G
            logits_fake = netD(fake)
            fake_labels = torch.zeros(bsz, device=device)
            d_loss_fake = criterion(logits_fake, fake_labels)

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizerD.step()

            # ---- Update G: maximize log(D(G(z))) <=> minimize BCEWithLogitsLoss(logits, 1)
            optimizerG.zero_grad(set_to_none=True)
            noise = torch.randn(bsz, args.nz, 1, 1, device=device)
            logits_gen = netD(netG(noise))
            g_labels = torch.ones(bsz, device=device)
            g_loss = criterion(logits_gen, g_labels)
            g_loss.backward()
            optimizerG.step()

            # Logging
            if (i + 1) % args.log_interval == 0:
                print(
                    f"Epoch [{epoch}/{args.epochs}] "
                    f"Step [{i+1}/{len(dataloader)}] "
                    f"D_loss: {d_loss.item():.4f} | G_loss: {g_loss.item():.4f}"
                )

            # quick sample grid every sample_interval steps
            if step % args.sample_interval == 0:
                _save_dcgan_grid(netG, fixed_noise, f"./samples/epoch{epoch:03d}_step{step:06d}.png")
            step += 1

        # End of epoch: checkpoints + a clean epoch sample
        torch.save(netG.state_dict(), f"./checkpoints/G_epoch_{epoch}.pt")
        torch.save(netD.state_dict(), f"./checkpoints/D_epoch_{epoch}.pt")
        _save_dcgan_grid(netG, fixed_noise, f"./samples/epoch{epoch:03d}.png")
        print(f"[Epoch {epoch}] checkpoints saved and sample grid written.")


# -----------------------------
# Image Generation (inference)
# -----------------------------
def gan_generate(args):
    """Sample images from a trained DCGAN generator and save them as one grid.

    Reads from args: cpu, dataset, nz, ngf, checkpoint, generate (number of
    images). The generator is rebuilt as ``Generator(args.nz, args.ngf, nc)``,
    so nz/ngf must match the values used for training.
    Output: ./generated/gen_grid.png, 8 images per row.

    Raises:
        ValueError: unknown dataset name, or ``args.generate`` < 1.
        FileNotFoundError: ``--checkpoint`` is empty or does not exist.
    """
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")

    dataset = args.dataset.lower()
    if dataset not in {"mnist", "fashion-mnist", "cifar10"}:
        raise ValueError("dataset must be one of: mnist, fashion-mnist, cifar10")
    # Same channel count get_data() reports (1 for (Fashion-)MNIST, 3 for CIFAR-10),
    # derived directly so generation never loads or downloads the training set.
    nc = 3 if dataset == "cifar10" else 1

    n = args.generate
    if n < 1:
        raise ValueError(f"--generate must be >= 1, got {n}")
    if not args.checkpoint or not os.path.exists(args.checkpoint):
        raise FileNotFoundError("--checkpoint path is required for --generate mode.")

    netG = Generator(args.nz, args.ngf, nc).to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    netG.load_state_dict(torch.load(args.checkpoint, map_location=device))
    netG.eval()

    chunk = 64  # images generated per forward pass
    outdir = Path("generated")
    outdir.mkdir(exist_ok=True)
    with torch.no_grad():
        batches = [
            netG(torch.randn(min(chunk, n - start), args.nz, 1, 1, device=device)).cpu()
            for start in range(0, n, chunk)
        ]
    imgs = torch.cat(batches, dim=0)
    grid = vutils.make_grid(imgs, nrow=8, normalize=True, value_range=(-1, 1))
    out_path = outdir / "gen_grid.png"
    vutils.save_image(grid, out_path)
    print(f"Generated {n} images -> {out_path}")


# -----------------------------
# Main / CLI
# -----------------------------
def parse_gan_args(argv=None):
    """Parse command-line options for DCGAN training / generation.

    Args:
        argv: list of argument strings to parse. ``None`` (default) parses
            ``sys.argv[1:]`` as before. Pass an explicit list (``[]`` for all
            defaults) when calling from a notebook: inside a Jupyter kernel,
            ``sys.argv`` carries the kernel's own flags and would be rejected.

    Returns:
        argparse.Namespace with the options below.
    """
    p = argparse.ArgumentParser(description="Train a DCGAN from scratch in PyTorch.")
    p.add_argument("--dataset", type=str, default="mnist", choices=["mnist", "fashion-mnist", "cifar10"],
                   help="which dataset to use")
    p.add_argument("--image_size", type=int, default=64, help="input/output image size (DCGAN uses 64)")
    p.add_argument("--epochs", type=int, default=20)
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--beta1", type=float, default=0.5, help="Adam beta1")
    p.add_argument("--nz", type=int, default=128, help="latent dim")
    p.add_argument("--ngf", type=int, default=64, help="generator feature maps")
    p.add_argument("--ndf", type=int, default=64, help="discriminator feature maps")
    p.add_argument("--log_interval", type=int, default=100)
    p.add_argument("--sample_interval", type=int, default=300, help="steps between quick sample grids")
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--cpu", action="store_true", help="force CPU even if CUDA is available")

    # generation mode
    p.add_argument("--generate", type=int, default=0, help=">0 to generate N images instead of training")
    p.add_argument("--checkpoint", type=str, default="", help="path to G checkpoint for --generate")
    return p.parse_args(argv)



# args = parse_gan_args()
# if args.generate and args.generate > 0:
#     gan_generate(args)
# else:
#     gan_train(args)


# CycleGAN

In [None]:
# cyclegan_pytorch.py
# Minimal but complete CycleGAN training on two unpaired folders (domain A, domain B).
# Uses ResNet generators + 70x70 PatchGAN discriminators, LSGAN loss, cycle + identity losses.
#
# Folder structure (examples):
#   --data_a/
#       a1.jpg, a2.png, ...   # searched recursively: images may also sit in subfolders
#   --data_b/
#       b1.jpg, b2.png, ...
#
# Only .jpg/.jpeg/.png files (extension matched case-insensitively) are used;
# every image is converted to RGB and resized to img_size x img_size.
#
# Usage:
#   python cyclegan_pytorch.py --data_a ./horses --data_b ./zebras --epochs 50 --img_size 256
#   (Images can be placed directly under data_a/ and data_b/, or in any subfolders.)
#
# Outputs:
#   ./samples_cyclegan/   -> periodic translated samples
#   ./checkpoints_cyclegan/ -> G_AB/G_BA/D_A/D_B checkpoints per epoch



# ---------------------------
# Data
# ---------------------------
class UnpairedImageFolder(Dataset):
    """Dataset of unpaired (A, B) RGB image tensors drawn from two folder trees.

    Item ``idx`` pairs the A image at ``idx % len_a`` with a domain-B image
    chosen uniformly at random (Python ``random``), so A/B pairings change
    between epochs. Images are resized to (img_size, img_size) with bicubic
    interpolation and normalized with mean/std 0.5 per channel.
    """

    # lower-case file extensions accepted as images (override in a subclass to extend)
    IMG_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root_a: str, root_b: str, img_size: int):
        self.files_a = self._list_images(root_a)
        self.files_b = self._list_images(root_b)
        if not self.files_a or not self.files_b:
            raise RuntimeError("Both domain folders must contain images.")
        self.len_a = len(self.files_a)
        self.len_b = len(self.files_b)

        self.tf = T.Compose([
            T.Resize((img_size, img_size), interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    @classmethod
    def _list_images(cls, root: str):
        """Sorted paths under ``root`` (any depth) whose extension is in IMG_EXTENSIONS."""
        pattern = os.path.join(root, "**", "*.*")
        return sorted(p for p in glob.glob(pattern, recursive=True)
                      if p.lower().endswith(cls.IMG_EXTENSIONS))

    def __len__(self):
        # define epoch size as max of both for better mixing
        return max(self.len_a, self.len_b)

    def _load(self, path: str):
        """Open ``path``, convert to RGB and apply the transform; the file handle is closed on exit."""
        with Image.open(path) as img:
            return self.tf(img.convert("RGB"))

    def __getitem__(self, idx):
        a_path = self.files_a[idx % self.len_a]
        b_path = self.files_b[random.randint(0, self.len_b - 1)]
        return self._load(a_path), self._load(b_path)


# ---------------------------
# Models (CycleGAN)
# ---------------------------
class ResidualBlock(nn.Module):
    """Residual block: x + [pad-conv3x3-IN-ReLU-pad-conv3x3-IN](x); channels and size preserved."""

    def __init__(self, dim):
        super().__init__()

        def padded_conv_norm():
            # reflection padding keeps H/W with a padding-free 3x3 conv;
            # InstanceNorm has no affine params and no running statistics
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=False),
                nn.InstanceNorm2d(dim, affine=False, track_running_stats=False),
            ]

        self.block = nn.Sequential(
            *padded_conv_norm(),
            nn.ReLU(inplace=True),
            *padded_conv_norm(),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual


class GeneratorResNet(nn.Module):
    """CycleGAN ResNet generator: c7s1-ngf, d(2ngf), d(4ngf), R x n_res, u(2ngf), u(ngf), c7s1-out_ch.

    Two stride-2 convolutions downsample by 4 in total, ``n_res`` ResidualBlocks
    run at 4*ngf channels, two transposed convolutions upsample back, and a
    final 7x7 conv + Tanh produces ``out_ch`` channels in [-1, 1].
    """

    def __init__(self, in_ch=3, out_ch=3, n_res=9, ngf=64):
        super().__init__()

        def norm_relu(ch):
            return [
                nn.InstanceNorm2d(ch, affine=False, track_running_stats=False),
                nn.ReLU(inplace=True),
            ]

        stem = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_ch, ngf, kernel_size=7, bias=False),
            *norm_relu(ngf),
        ]

        down = []
        for mult in (1, 2):  # ngf -> 2ngf -> 4ngf, halving H/W each time
            ch = ngf * mult
            down += [
                nn.Conv2d(ch, ch * 2, kernel_size=3, stride=2, padding=1, bias=False),
                *norm_relu(ch * 2),
            ]

        body = [ResidualBlock(ngf * 4) for _ in range(n_res)]

        up = []
        for mult in (4, 2):  # 4ngf -> 2ngf -> ngf, doubling H/W each time
            ch = ngf * mult
            up += [
                nn.ConvTranspose2d(ch, ch // 2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
                *norm_relu(ch // 2),
            ]

        head = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, out_ch, kernel_size=7),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*stem, *down, *body, *up, *head)

    def forward(self, x):
        return self.model(x)


class PatchDiscriminator(nn.Module):
    """70x70 PatchGAN discriminator returning a (N, 1, H', W') map of per-patch logits.

    Three stride-2 4x4 conv stages (no normalization on the first), one stride-1
    stage, then a stride-1 conv down to a single channel. InstanceNorm has no
    affine parameters and keeps no running statistics.
    """

    def __init__(self, in_ch=3, ndf=64):
        super().__init__()

        def stage(c_in, c_out, stride, use_norm=True):
            mods = [nn.Conv2d(c_in, c_out, 4, stride, 1)]
            if use_norm:
                mods.append(nn.InstanceNorm2d(c_out, affine=False, track_running_stats=False))
            mods.append(nn.LeakyReLU(0.2, inplace=True))
            return mods

        self.model = nn.Sequential(
            *stage(in_ch, ndf, 2, use_norm=False),
            *stage(ndf, ndf * 2, 2),
            *stage(ndf * 2, ndf * 4, 2),
            # stride 1 from here on keeps a spatial map of patch scores
            *stage(ndf * 4, ndf * 8, 1),
            nn.Conv2d(ndf * 8, 1, 4, 1, 1),  # patch logits
        )

    def forward(self, x):
        return self.model(x)


def init_weights(m):
    """Weight initializer for CycleGAN modules, meant for ``model.apply(init_weights)``.

    - Conv2d / ConvTranspose2d: weight ~ N(0, 0.02), bias (if any) = 0.
    - InstanceNorm2d with affine params: weight = 1, bias = 0.
    All other modules are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
        return

    if isinstance(m, nn.InstanceNorm2d):
        # weight/bias are None when affine=False
        for param, fill in ((m.weight, nn.init.ones_), (m.bias, nn.init.zeros_)):
            if param is not None:
                fill(param)


# ---------------------------
# Training
# ---------------------------
def train_cyclegan(args):
    make_dirs()
    seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    print("Device:", device)

    ds = UnpairedImageFolder(args.data_a, args.data_b, args.img_size)
    loader = DataLoader(ds, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    # models
    n_res = 9 if args.img_size >= 256 else 6
    G_AB = GeneratorResNet(3, 3, n_res=n_res, ngf=args.ngf).to(device)
    G_BA = GeneratorResNet(3, 3, n_res=n_res, ngf=args.ngf).to(device)
    D_A = PatchDiscriminator(3, ndf=args.ndf).to(device)  # distinguishes real A vs fake A (from B->A)
    D_B = PatchDiscriminator(3, ndf=args.ndf).to(device)

    G_AB.apply(init_weights)
    G_BA.apply(init_weights)
    D_A.apply(init_weights)
    D_B.apply(init_weights)

    # losses
    adv_criterion = nn.MSELoss()      # LSGAN
    cycle_criterion = nn.L1Loss()
    id_criterion = nn.L1Loss()

    # opt
    opt_G = optim.Adam(list(G_AB.parameters()) + list(G_BA.parameters()), lr=args.lr, betas=(0.5, 0.999))
    opt_D_A = optim.Adam(D_A.parameters(), lr=args.lr, betas=(0.5, 0.999))
    opt_D_B = optim.Adam(D_B.parameters(), lr=args.lr, betas=(0.5, 0.999))

    # targets for PatchGAN
    def real_like(x): return torch.ones_like(x, device=device)
    def fake_like(x): return torch.zeros_like(x, device=device)

    step = 0
    for epoch in range(1, args.epochs + 1):
        for i, (real_a, real_b) in enumerate(loader):
            real_a = real_a.to(device)
            real_b = real_b.to(device)

            # -----------------------
            #  Train Generators G_AB & G_BA
            # -----------------------
            opt_G.zero_grad(set_to_none=True)

            fake_b = G_AB(real_a)
            pred_fake_b = D_B(fake_b)
            loss_G_AB_adv = adv_criterion(pred_fake_b, real_like(pred_fake_b))

            fake_a = G_BA(real_b)
            pred_fake_a = D_A(fake_a)
            loss_G_BA_adv = adv_criterion(pred_fake_a, real_like(pred_fake_a))

            # Cycle
            rec_a = G_BA(fake_b)
            rec_b = G_AB(fake_a)
            loss_cycle = cycle_criterion(rec_a, real_a) + cycle_criterion(rec_b, real_b)
            loss_cycle = args.lambda_cyc * loss_cycle

            # Identity (optional but stabilizing)
            idt_a = G_BA(real_a)
            idt_b = G_AB(real_b)
            loss_id = id_criterion(idt_a, real_a) + id_criterion(idt_b, real_b)
            loss_id = args.lambda_id * loss_id

            loss_G = loss_G_AB_adv + loss_G_BA_adv + loss_cycle + loss_id
            loss_G.backward()
            opt_G.step()

            # -----------------------
            #  Train D_A
            # -----------------------
            opt_D_A.zero_grad(set_to_none=True)
            pred_real_a = D_A(real_a)
            loss_D_A_real = adv_criterion(pred_real_a, real_like(pred_real_a))

            with torch.no_grad():
                fake_a_det = fake_a
            pred_fake_a = D_A(fake_a_det)
            loss_D_A_fake = adv_criterion(pred_fake_a, fake_like(pred_fake_a))

            loss_D_A = 0.5 * (loss_D_A_real + loss_D_A_fake)
            loss_D_A.backward()
            opt_D_A.step()

            # -----------------------
            #  Train D_B
            # -----------------------
            opt_D_B.zero_grad(set_to_none=True)
            pred_real_b = D_B(real_b)
            loss_D_B_real = adv_criterion(pred_real_b, real_like(pred_real_b))

            with torch.no_grad():
                fake_b_det = fake_b
            pred_fake_b = D_B(fake_b_det)
            loss_D_B_fake = adv_criterion(pred_fake_b, fake_like(pred_fake_b))

            loss_D_B = 0.5 * (loss_D_B_real + loss_D_B_fake)
            loss_D_B.backward()
            opt_D_B.step()

            if (i + 1) % args.log_interval == 0:
                print(
                    f"[E{epoch}/{args.epochs}] [B{i+1}/{len(loader)}] "
                    f"G: {loss_G.item():.3f} (adv: {(loss_G_AB_adv+loss_G_BA_adv).item():.3f}, "
                    f"cyc: {loss_cycle.item():.3f}, id: {loss_id.item():.3f}) | "
                    f"D_A: {loss_D_A.item():.3f} D_B: {loss_D_B.item():.3f}"
                )

            if step % args.sample_interval == 0:
                with torch.no_grad():
                    a2b = G_AB(real_a[:4])
                    b2a = G_BA(real_b[:4])
                grid = vutils.make_grid(
                    torch.cat([real_a[:4], a2b, real_b[:4], b2a], dim=0),
                    nrow=4, normalize=True, value_range=(-1, 1)
                )
                vutils.save_image(grid, f"samples_cyclegan/e{epoch:03d}_s{step:06d}.png")
            step += 1

        # save per epoch
        torch.save(G_AB.state_dict(), f"checkpoints_cyclegan/G_AB_e{epoch}.pt")
        torch.save(G_BA.state_dict(), f"checkpoints_cyclegan/G_BA_e{epoch}.pt")
        torch.save(D_A.state_dict(),  f"checkpoints_cyclegan/D_A_e{epoch}.pt")
        torch.save(D_B.state_dict(),  f"checkpoints_cyclegan/D_B_e{epoch}.pt")
        print(f"[Epoch {epoch}] checkpoints saved.")


def parse_cyclegan_args(argv=None):
    """Parse command-line options for CycleGAN training.

    Args:
        argv: list of argument strings to parse. ``None`` (default) parses
            ``sys.argv[1:]`` as before. Pass an explicit list from a notebook,
            where ``sys.argv`` holds the Jupyter kernel's own flags.

    Returns:
        argparse.Namespace; ``--data_a`` and ``--data_b`` are required.
    """
    p = argparse.ArgumentParser(description="CycleGAN from scratch (PyTorch).")
    p.add_argument("--data_a", type=str, required=True, help="path to domain A images (root)")
    p.add_argument("--data_b", type=str, required=True, help="path to domain B images (root)")
    p.add_argument("--img_size", type=int, default=256)
    p.add_argument("--batch_size", type=int, default=2)
    p.add_argument("--epochs", type=int, default=50)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--lambda_cyc", type=float, default=10.0)
    p.add_argument("--lambda_id", type=float, default=5.0)
    p.add_argument("--ngf", type=int, default=64)
    p.add_argument("--ndf", type=int, default=64)
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--log_interval", type=int, default=100)
    p.add_argument("--sample_interval", type=int, default=300)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--cpu", action="store_true")
    return p.parse_args(argv)


# WGAN

In [None]:
# wgan_gp_pytorch.py
# WGAN-GP on MNIST/FashionMNIST/CIFAR10. DCGAN-ish generator, critic without BN, gradient penalty.
#
# Usage:
#   python wgan_gp_pytorch.py --dataset cifar10 --epochs 50 --batch_size 128
#
# Outputs:
#   ./samples_wgan/ (images)
#   ./checkpoints_wgan/ (G/C state dicts)


class Gen(nn.Module):
    """WGAN-GP generator (DCGAN layout): (N, nz, 1, 1) latents -> (N, nc, 64, 64) images in [-1, 1]."""

    def __init__(self, nz, ngf, nc):
        super().__init__()
        stages = []
        c_in = nz
        for idx, mult in enumerate((8, 4, 2, 1)):
            c_out = ngf * mult
            # first stage: 1x1 -> 4x4 (stride 1, pad 0); then each stage doubles H/W
            stride, pad = (1, 0) if idx == 0 else (2, 1)
            stages += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
            c_in = c_out
        stages += [
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, z):
        return self.net(z)


class Critic(nn.Module):
    """WGAN-GP critic: (N, nc, 64, 64) images -> (N,) unbounded scores.

    Uses no BatchNorm, because the per-sample gradient penalty assumes each
    score depends only on its own input.
    """

    def __init__(self, nc, ndf):
        super().__init__()
        widths = [nc, ndf, ndf * 2, ndf * 4, ndf * 8]
        body = []
        # four stride-2 stages: 64 -> 32 -> 16 -> 8 -> 4
        for c_in, c_out in zip(widths, widths[1:]):
            body.append(nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False))
            body.append(nn.LeakyReLU(0.2, inplace=True))
        # 4x4 -> 1x1 score map
        body.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
        self.net = nn.Sequential(*body)

    def forward(self, x):
        scores = self.net(x)
        return scores.view(-1)  # scalar per sample


def gradient_penalty(critic, real, fake, device=None):
    """WGAN-GP penalty: mean over the batch of (||grad_x critic(x_hat)||_2 - 1)^2.

    x_hat = eps * real + (1 - eps) * fake, with one eps ~ U[0, 1) per sample.

    Args:
        critic: callable mapping a batch to scores; scores are summed via
            ``grad_outputs=ones`` before differentiation.
        real, fake: tensors of identical shape (N, ...). Any rank works:
            images (N, C, H, W) as well as flat vectors (N, D).
        device: where to draw eps. Defaults to ``real.device`` (added default;
            existing callers passing it positionally are unaffected).

    Returns:
        Scalar tensor, built with ``create_graph=True`` so it can be
        back-propagated into the critic's parameters.
    """
    if device is None:
        device = real.device
    bsz = real.size(0)
    # one mixing coefficient per sample, broadcast over all remaining dims
    eps = torch.rand(bsz, *([1] * (real.dim() - 1)), device=device)
    interp = eps * real + (1 - eps) * fake
    interp.requires_grad_(True)
    scores = critic(interp)
    grads = torch.autograd.grad(
        outputs=scores, inputs=interp,
        grad_outputs=torch.ones_like(scores),
        create_graph=True, retain_graph=True,
    )[0]
    # reshape (not view): gradients are not guaranteed to be contiguous
    grads = grads.reshape(bsz, -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()


def get_wgan_data(name, image_size, batch_size, workers=4):
    """Build a shuffled training DataLoader for the WGAN-GP script.

    Args:
        name: "mnist", "fashion-mnist" or "cifar10" (case-insensitive).
        image_size: passed to ``T.Resize``.
        batch_size: samples per batch.
        workers: DataLoader worker processes.

    Returns:
        (loader, nc); nc is 1 for the MNIST variants and 3 for CIFAR-10.
        Pixels are normalized with mean 0.5 / std 0.5 per channel. Data is
        downloaded to ./data if missing.

    Raises:
        ValueError: for an unsupported dataset name.
    """
    key = name.lower()
    channels = {"mnist": 1, "fashion-mnist": 1, "cifar10": 3}.get(key)
    if channels is None:
        raise ValueError("dataset must be mnist | fashion-mnist | cifar10")

    stats = (0.5,) * channels
    tf = T.Compose([T.Resize(image_size), T.ToTensor(), T.Normalize(stats, stats)])
    ds_cls = {
        "mnist": datasets.MNIST,
        "fashion-mnist": datasets.FashionMNIST,
        "cifar10": datasets.CIFAR10,
    }[key]
    train = ds_cls(root="./data", train=True, download=True, transform=tf)

    loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True,
                                         num_workers=workers, pin_memory=True)
    return loader, channels


def train_wgan(args):
    """Train Gen/Critic with the WGAN-GP objective.

    Per batch:
      - ``args.n_critic`` critic updates on the same real batch, each with
        fresh z, minimizing mean C(fake) - mean C(real) + args.lambda_gp * GP;
      - one generator update minimizing -mean C(G(z)).
    Both optimizers are Adam(lr=args.lr, betas=(0.0, 0.9)).

    Side effects: fixed-noise sample grids -> samples_wgan/; per-epoch G/C
    state dicts -> checkpoints_wgan/.

    Raises:
        ValueError: if ``args.n_critic`` < 1 (the critic would never train and
            the log line would reference undefined losses), or if
            ``args.image_size`` != 64 (Gen always emits 64x64, so real and fake
            batches would not match in the gradient penalty).
    """
    if args.n_critic < 1:
        raise ValueError(f"--n_critic must be >= 1, got {args.n_critic}")
    if args.image_size != 64:
        raise ValueError(f"WGAN-GP Gen/Critic are built for 64x64 images; got --image_size {args.image_size}")

    make_dirs()
    # the folders this trainer writes into
    Path("samples_wgan").mkdir(exist_ok=True)
    Path("checkpoints_wgan").mkdir(exist_ok=True)
    seed_all(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    print("Device:", device)

    loader, nc = get_wgan_data(args.dataset, args.image_size, args.batch_size, args.num_workers)
    G = Gen(args.nz, args.ngf, nc).to(device)
    C = Critic(nc, args.ndf).to(device)

    optG = optim.Adam(G.parameters(), lr=args.lr, betas=(0.0, 0.9))
    optC = optim.Adam(C.parameters(), lr=args.lr, betas=(0.0, 0.9))

    fixed = torch.randn(64, args.nz, 1, 1, device=device)

    step = 0
    for epoch in range(1, args.epochs + 1):
        for i, (real, _) in enumerate(loader):
            real = real.to(device)
            bsz = real.size(0)

            # update critic n_critic times (same real batch, fresh fakes each time)
            for _ in range(args.n_critic):
                z = torch.randn(bsz, args.nz, 1, 1, device=device)
                fake = G(z).detach()
                optC.zero_grad(set_to_none=True)
                loss_c = C(fake).mean() - C(real).mean()
                gp = gradient_penalty(C, real, fake, device)
                total_c = loss_c + args.lambda_gp * gp
                total_c.backward()
                optC.step()

            # update generator: maximize critic(fake) -> minimize -mean
            z = torch.randn(bsz, args.nz, 1, 1, device=device)
            optG.zero_grad(set_to_none=True)
            g_loss = -C(G(z)).mean()
            g_loss.backward()
            optG.step()

            if (i + 1) % args.log_interval == 0:
                print(f"[E{epoch}/{args.epochs}] [B{i+1}/{len(loader)}] "
                      f"C: {total_c.item():.3f} (gp {gp.item():.3f}) | G: {g_loss.item():.3f}")

            if step % args.sample_interval == 0:
                with torch.no_grad():
                    imgs = G(fixed).cpu()
                grid = vutils.make_grid(imgs, nrow=8, normalize=True, value_range=(-1, 1))
                vutils.save_image(grid, f"samples_wgan/e{epoch:03d}_s{step:06d}.png")
            step += 1

        torch.save(G.state_dict(), f"checkpoints_wgan/G_e{epoch}.pt")
        torch.save(C.state_dict(), f"checkpoints_wgan/C_e{epoch}.pt")
        print(f"[Epoch {epoch}] checkpoints saved.")


def parse_wgan_args(argv=None):
    """Parse command-line options for WGAN-GP training.

    Args:
        argv: list of argument strings to parse. ``None`` (default) parses
            ``sys.argv[1:]`` as before. Pass an explicit list (``[]`` for all
            defaults) from a notebook, where ``sys.argv`` holds the Jupyter
            kernel's own flags.

    Returns:
        argparse.Namespace with the options below.
    """
    p = argparse.ArgumentParser(description="WGAN-GP from scratch (PyTorch).")
    p.add_argument("--dataset", type=str, default="mnist", choices=["mnist", "fashion-mnist", "cifar10"])
    p.add_argument("--image_size", type=int, default=64)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--nz", type=int, default=128)
    p.add_argument("--ngf", type=int, default=64)
    p.add_argument("--ndf", type=int, default=64)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--lambda_gp", type=float, default=10.0)
    p.add_argument("--n_critic", type=int, default=5)
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--log_interval", type=int, default=100)
    p.add_argument("--sample_interval", type=int, default=300)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--cpu", action="store_true")
    return p.parse_args(argv)

# Vec2Vec GAN

In [None]:
# vec2vec_gan.py
# Unpaired vector-to-vector translation with:
# - Output-level adversarial (M1->M2 and M2->M1)
# - Latent-level adversarial (A1 vs A2 after shared backbone T)
# - Reconstruction (R1, R2)
# - Cycle-consistency (F2(F1(x)) ~ x and F1(F2(y)) ~ y)
# - Vector Space Preservation (VSP) on pairwise dot-products
#
# Works with:
#   (A) real embeddings from .npy files (unpaired):  --data_m1 pathA.npy --data_m2 pathB.npy
#   (B) synthetic toy data (default): mixtures of Gaussians in two spaces with an affine warp.
#
# Usage (toy):
#   python vec2vec_gan.py --epochs 30 --batch_size 128
#
# Usage (real .npy):
#   python vec2vec_gan.py --data_m1 m1.npy --data_m2 m2.npy --epochs 50 --dim 768
#
# Inference:
#   After training, it prints a small demo translating a handful of vectors,
#   along with cosine similarities before/after to illustrate geometry preservation.


# --------------------------
# Dataset (unpaired)
# --------------------------
class UnpairedEmbeddingDataset(torch.utils.data.Dataset):
    """Unpaired embeddings from two spaces: M1 (``self.X``) and M2 (``self.Y``).

    Args:
        dim: embedding dimension for synthetic data. Ignored when loading files;
            ``self.dim`` is then taken from the arrays.
        n1, n2: number of synthetic samples to generate for M1 / M2.
        data_m1, data_m2: optional paths to ``.npy`` arrays of shape (N, d) and
            (M, d). Both must be given to use real data; if neither is given,
            synthetic data is generated.
        mixtures: number of Gaussian mixture centres for synthetic data.
        std: standard deviation around each centre for synthetic data.

    Synthetic M2 samples are drawn independently of M1 (same centres) and
    warped by a shared scaled rotation + bias, with ``tanh`` on the first
    ``dim // 6`` coordinates. A ground-truth mapping therefore exists, but no
    index pairing does. Rows of both spaces are L2-normalised.

    Raises:
        ValueError: if only one of ``data_m1``/``data_m2`` is given, if a loaded
            array is not 2-D, or if the two arrays' dimensions differ.
    """
    def __init__(self, dim: int, n1: int, n2: int,
                 data_m1: Optional[str] = None,
                 data_m2: Optional[str] = None,
                 mixtures: int = 4,
                 std: float = 0.6):
        # A single path would otherwise silently fall back to synthetic data.
        if bool(data_m1) != bool(data_m2):
            raise ValueError("Provide both data_m1 and data_m2, or neither.")

        if data_m1 and data_m2:
            X: np.ndarray = np.load(data_m1)  # expected shape (N, d)
            Y: np.ndarray = np.load(data_m2)  # expected shape (M, d)
            if X.ndim != 2 or Y.ndim != 2:
                raise ValueError("Loaded arrays must be 2D.")
            if X.shape[1] != Y.shape[1]:
                raise ValueError("Both spaces must have same dimension.")
            self.X = torch.from_numpy(X).float()
            self.Y = torch.from_numpy(Y).float()
            self.dim = X.shape[1]
        else:
            # Synthetic toy data. The order of RNG draws below is kept stable so
            # results for a given seed stay reproducible.
            self.dim = dim
            centers = torch.randn(mixtures, dim) * 3.0

            # M1 samples
            idxs1 = torch.randint(0, mixtures, (n1,))
            self.X = centers[idxs1] + std * torch.randn(n1, dim)

            # Shared warp defining the M2 domain: scaled orthogonal map + bias.
            W = torch.randn(dim, dim)
            U, _, Vt = torch.linalg.svd(W, full_matrices=False)
            R = U @ Vt  # orthogonal (rotation/reflection)
            scale = 1.2
            b = torch.randn(dim) * 0.2

            # M2 samples are re-drawn independently (unpaired), then warped, with
            # a slight tanh nonlinearity on the first dim // 6 coordinates.
            idxs2 = torch.randint(0, mixtures, (n2,))
            base2 = centers[idxs2] + std * torch.randn(n2, dim)
            self.Y = base2 @ (scale * R).T + b
            k = dim // 6
            self.Y[:, :k] = torch.tanh(self.Y[:, :k])

        # Normalise to unit norm (helps numerics for dot/cos losses).
        self.X = nn.functional.normalize(self.X, dim=-1)
        self.Y = nn.functional.normalize(self.Y, dim=-1)

    def __len__(self):
        # Epoch length = the larger of the two spaces (like unpaired image datasets).
        return max(len(self.X), len(self.Y))

    def __getitem__(self, idx):
        # X is walked cyclically; Y is sampled uniformly at random (Python ``random``),
        # so returned pairs carry no correspondence.
        x = self.X[idx % len(self.X)]
        y = self.Y[random.randint(0, len(self.Y) - 1)]
        return x, y


# --------------------------
# Models
# --------------------------
class MLP(nn.Module):
    """Feed-forward stack with one ``nn.Linear`` per consecutive pair in ``dims``.

    Every Linear except the last is followed by an optional LayerNorm, then
    ``act()``, then an optional Dropout. The final Linear is left bare, so the
    module outputs raw (unnormalised, un-activated) values.
    """
    def __init__(self, dims, act=nn.SiLU, layernorm=True, dropout=0.0):
        super().__init__()
        last_index = len(dims) - 2
        modules = []
        for idx, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            modules.append(nn.Linear(d_in, d_out))
            if idx == last_index:
                continue  # output layer: no norm / activation / dropout
            if layernorm:
                modules.append(nn.LayerNorm(d_out))
            modules.append(act())
            if dropout > 0:
                modules.append(nn.Dropout(dropout))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        out = self.net(x)
        return out


class V2VResidualBlock(nn.Module):
    """Post-norm residual MLP block computing ``LayerNorm(x + f(x))``.

    ``f`` expands to ``dim * hidden_mult`` features, applies LayerNorm and
    SiLU (plus Dropout when ``dropout > 0``, otherwise an Identity
    placeholder), then projects back to ``dim``.
    """
    def __init__(self, dim, hidden_mult=2, dropout=0.0):
        super().__init__()
        width = dim * hidden_mult
        # Identity keeps the Sequential indices (and state_dict keys) fixed
        # regardless of whether dropout is enabled.
        maybe_drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.f = nn.Sequential(
            nn.Linear(dim, width),
            nn.LayerNorm(width),
            nn.SiLU(),
            maybe_drop,
            nn.Linear(width, dim),
        )
        self.n = nn.LayerNorm(dim)

    def forward(self, x):
        update = self.f(x)
        return self.n(x + update)


class V2VResidualBackbone(nn.Module):
    """Shared latent backbone: ``depth`` V2VResidualBlocks applied one after another."""
    def __init__(self, dim, depth=4, hidden_mult=2, dropout=0.0):
        super().__init__()
        stack = [V2VResidualBlock(dim, hidden_mult, dropout) for _ in range(depth)]
        self.blocks = nn.ModuleList(stack)

    def forward(self, x):
        h = x
        for block in self.blocks:
            h = block(h)
        return h


class AdapterIn(nn.Module):
    """Input adapter (A1/A2): maps a d-dim embedding to the z-dim latent.

    Implemented as ``MLP([d, 2z, z])`` with LayerNorm + SiLU on the hidden layer.
    """
    def __init__(self, d, z):
        super().__init__()
        hidden = 2 * z
        self.f = MLP([d, hidden, z], act=nn.SiLU, layernorm=True)

    def forward(self, x):
        latent = self.f(x)
        return latent


class AdapterOut(nn.Module):
    """Output adapter (B1/B2): maps a z-dim latent back to a d-dim embedding.

    Implemented as ``MLP([z, 2z, d])`` with LayerNorm + SiLU on the hidden layer.
    """
    def __init__(self, z, d):
        super().__init__()
        hidden = 2 * z
        self.f = MLP([z, hidden, d], act=nn.SiLU, layernorm=True)

    def forward(self, x):
        embedding = self.f(x)
        return embedding


class DV2Viscriminator(nn.Module):
    """MLP discriminator over vectors; returns one logit per input row.

    Structure: ``depth - 1`` hidden stages of Linear(width) + LeakyReLU(0.2),
    then a Linear(1) head. Every Linear, including the head, is wrapped in
    ``nn.utils.spectral_norm`` for training stability.
    """
    def __init__(self, in_dim, width=512, depth=3):
        super().__init__()
        sn = nn.utils.spectral_norm
        stages = []
        fan_in = in_dim
        for _ in range(depth - 1):
            stages.extend([sn(nn.Linear(fan_in, width)), nn.LeakyReLU(0.2, inplace=True)])
            fan_in = width
        stages.append(sn(nn.Linear(fan_in, 1)))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        logits = self.net(x)
        return logits.squeeze(-1)  # (..., 1) -> (...)


# --------------------------
# Training helpers (losses)
# --------------------------
class Losses:
    """Loss terms for the vec2vec GAN.

    Adversarial terms use BCE-with-logits, reconstruction and cycle terms use
    MSE, and VSP matches pairwise dot products. The ``lambda_*`` weights are
    only stored here; the training loop applies them.
    """
    def __init__(self, lambda_rec=1.0, lambda_cc=10.0, lambda_vsp=1.0):
        self.bce = nn.BCEWithLogitsLoss()
        self.mse = nn.MSELoss()
        self.lambda_rec = lambda_rec
        self.lambda_cc = lambda_cc
        self.lambda_vsp = lambda_vsp

    @staticmethod
    def _label_template(t):
        # One label per row for 2-D (B, d) inputs; otherwise one per element.
        return t[:, 0] if t.ndim == 2 else t

    def adv_d(self, D, real, fake):
        """Discriminator loss: BCE pushing D(real) -> 1 and D(fake) -> 0 (``fake`` is detached)."""
        real_term = self.bce(D(real), torch.ones_like(self._label_template(real)).float())
        fake_term = self.bce(D(fake.detach()), torch.zeros_like(self._label_template(fake)).float())
        return real_term + fake_term

    def adv_g(self, D, fake):
        """Translator loss: BCE pushing D(fake) -> 1; gradients flow into ``fake``."""
        target = torch.ones_like(self._label_template(fake)).float()
        return self.bce(D(fake), target)

    def reconstruction(self, x, r1, y, r2):
        """MSE(r1, x) + MSE(r2, y)."""
        loss_m1 = self.mse(r1, x)
        loss_m2 = self.mse(r2, y)
        return loss_m1 + loss_m2

    def cycle(self, x, x_cycled, y, y_cycled):
        """MSE(x_cycled, x) + MSE(y_cycled, y)."""
        loss_m1 = self.mse(x_cycled, x)
        loss_m2 = self.mse(y_cycled, y)
        return loss_m1 + loss_m2

    def vsp(self, X, FX, Y, FY, max_subset: int = 64):
        """
        Vector-space preservation: mean squared gap between the Gram matrices
        (pairwise dot products) of each source batch and its translation.
        Batches larger than ``max_subset`` use a random row subset (same rows
        for source and translation) to bound the O(B^2) cost.
        """
        def gram_gap(A, FA):
            n = A.size(0)
            if n > max_subset:
                keep = torch.randperm(n, device=A.device)[:max_subset]
                A, FA = A[keep], FA[keep]
            diff = A @ A.t() - FA @ FA.t()
            return diff.pow(2).mean()

        return gram_gap(X, FX) + gram_gap(Y, FY)


# --------------------------
# Full Model container
# --------------------------
class Vec2VecGAN(nn.Module):
    """
    Vec2VecGAN model: per-space adapters around a shared latent backbone, plus
    output-space and latent-space discriminators.

    Translation M1 -> M2 is ``B2(T(A1(x)))`` and M2 -> M1 is ``B1(T(A2(y)))``.
    Reconstructions reuse the same-space output adapter.

    Args:
        d: int, dimension of the input (embedding) space.
        z: int, dimension of the latent space.
        back_depth: int, number of residual blocks in the backbone.
        disc_width: int, hidden width of the discriminators.
        disc_depth: int, depth of the discriminators.
    """
    def __init__(self,
                 d: int,
                 z: int,
                 back_depth=4,
                 disc_width=512,
                 disc_depth=3):
        super().__init__()
        # Adapters and backbone. Creation order is kept, because it fixes RNG use
        # at init and the state_dict key order.
        self.A1 = AdapterIn(d, z)
        self.A2 = AdapterIn(d, z)
        self.T = V2VResidualBackbone(z, depth=back_depth, hidden_mult=2)

        self.B1 = AdapterOut(z, d)
        self.B2 = AdapterOut(z, d)

        # Discriminators use DV2Viscriminator, the (in_dim, width, depth) vector
        # discriminator defined for this model. The bare name ``Discriminator``
        # is not defined in this section.
        # Output-space discriminators
        self.D_M2 = DV2Viscriminator(d, width=disc_width, depth=disc_depth)  # judges F1 outputs against real M2
        self.D_M1 = DV2Viscriminator(d, width=disc_width, depth=disc_depth)  # judges F2 outputs against real M1

        # Latent discriminators (two symmetric ones)
        self.DL_1 = DV2Viscriminator(z, width=disc_width, depth=disc_depth)  # real=z2, fake=z1
        self.DL_2 = DV2Viscriminator(z, width=disc_width, depth=disc_depth)  # real=z1, fake=z2

    # convenience
    def encode1(self, x):  # M1 -> latent
        return self.T(self.A1(x))

    def encode2(self, y):  # M2 -> latent
        return self.T(self.A2(y))

    def F1(self, x):  # M1 -> M2
        return self.B2(self.encode1(x))

    def F2(self, y):  # M2 -> M1
        return self.B1(self.encode2(y))

    def R1(self, x):  # M1 -> latent -> M1
        return self.B1(self.encode1(x))

    def R2(self, y):  # M2 -> latent -> M2
        return self.B2(self.encode2(y))


# --------------------------
# Training loop
# --------------------------
def train_v2c_gan(args):
    """Train a Vec2VecGAN on unpaired embeddings and return ``(model, dataset)``.

    Each step alternates:
      1. a discriminator update on the output-space (D_M1, D_M2) and
         latent-space (DL_1, DL_2) discriminators. Their inputs are computed
         without gradient tracking, and
      2. a translator update of (A1, A2, T, B1, B2) on the adversarial,
         reconstruction, cycle and VSP losses.

    When ``args.save_ckpt`` is set, the state_dict is saved to
    ``checkpoints/vec2vec_e{epoch}.pt`` after every epoch.

    Args:
        args: namespace from ``parse_args()``. Reads dim, n_m1, n_m2, data_m1,
            data_m2, mixtures, std, z, back_depth, disc_width, disc_depth,
            lr_g, lr_d, lambda_rec, lambda_cc, lambda_vsp, vsp_subset, epochs,
            batch_size, log_interval, save_ckpt, cpu and seed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    seed_all(args.seed)
    ds = UnpairedEmbeddingDataset(
        dim=args.dim, n1=args.n_m1, n2=args.n_m2,
        data_m1=args.data_m1, data_m2=args.data_m2,
        mixtures=args.mixtures, std=args.std
    )
    loader = torch.utils.data.DataLoader(
        ds, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=0
    )

    model = Vec2VecGAN(d=ds.dim, z=args.z, back_depth=args.back_depth,
                       disc_width=args.disc_width, disc_depth=args.disc_depth).to(device)

    # Two optimizers: one for the discriminators, one for (A1, A2, T, B1, B2).
    params_D = list(model.D_M1.parameters()) + list(model.D_M2.parameters()) + \
               list(model.DL_1.parameters()) + list(model.DL_2.parameters())
    params_G = list(model.A1.parameters()) + list(model.A2.parameters()) + list(model.T.parameters()) + \
               list(model.B1.parameters()) + list(model.B2.parameters())

    optD = optim.Adam(params_D, lr=args.lr_d, betas=(0.5, 0.999))
    optG = optim.Adam(params_G, lr=args.lr_g, betas=(0.5, 0.999))

    losses = Losses(lambda_rec=args.lambda_rec, lambda_cc=args.lambda_cc, lambda_vsp=args.lambda_vsp)

    model.train()
    for epoch in range(1, args.epochs + 1):
        for i, (x, y) in enumerate(loader):
            x = x.to(device)  # M1 batch
            y = y.to(device)  # M2 batch

            # ---- 1) Discriminators
            # Translator outputs are only inputs for this step. Building them
            # without a graph has two effects:
            #   - it skips the reconstruction/cycle passes this step never used;
            #   - d_total.backward() no longer reaches translator parameters
            #     through the non-detached "real" latents (z1/z2).
            # Those gradients were discarded by optG.zero_grad anyway.
            with torch.no_grad():
                z1 = model.encode1(x)       # M1 -> latent
                z2 = model.encode2(y)       # M2 -> latent
                f1 = model.B2(z1)           # M1 -> M2
                f2 = model.B1(z2)           # M2 -> M1

            optD.zero_grad(set_to_none=True)

            # Output-level D: D_M2 (real=y, fake=f1), D_M1 (real=x, fake=f2)
            d_m2 = losses.adv_d(model.D_M2, y, f1)
            d_m1 = losses.adv_d(model.D_M1, x, f2)

            # Latent-level D: DL_1 (real=z2, fake=z1), DL_2 (real=z1, fake=z2)
            d_l1 = losses.adv_d(model.DL_1, z2, z1)
            d_l2 = losses.adv_d(model.DL_2, z1, z2)

            d_total = d_m1 + d_m2 + d_l1 + d_l2
            d_total.backward()
            optD.step()

            # ---- 2) Generators (adapters + backbone)
            optG.zero_grad(set_to_none=True)

            # Fresh forward pass with gradients for the translator update.
            z1 = model.encode1(x)
            z2 = model.encode2(y)
            f1 = model.B2(z1)
            f2 = model.B1(z2)

            # Adversarial G-side: make fakes look real; make each latent "look like" the other's.
            g_adv = losses.adv_g(model.D_M2, f1) + losses.adv_g(model.D_M1, f2) + \
                    losses.adv_g(model.DL_1, z1) + losses.adv_g(model.DL_2, z2)

            # Reconstruction and cycle
            r1 = model.B1(z1)
            r2 = model.B2(z2)
            x_cyc = model.F2(f1)
            y_cyc = model.F1(f2)

            g_rec = losses.reconstruction(x, r1, y, r2)
            g_cc = losses.cycle(x, x_cyc, y, y_cyc)
            g_vsp = losses.vsp(x, f1, y, f2, max_subset=args.vsp_subset)

            g_total = g_adv + losses.lambda_rec * g_rec + losses.lambda_cc * g_cc + losses.lambda_vsp * g_vsp
            g_total.backward()
            optG.step()

            if (i + 1) % args.log_interval == 0:
                print(f"[E{epoch:03d} B{i+1:04d}] "
                      f"D: {d_total.item():.3f} | "
                      f"G_adv: {g_adv.item():.3f} | "
                      f"Rec: {g_rec.item():.3f} | "
                      f"CC: {g_cc.item():.3f} | "
                      f"VSP: {g_vsp.item():.3f}")

        # Save a light checkpoint each epoch.
        if args.save_ckpt:
            Path("checkpoints").mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), f"checkpoints/vec2vec_e{epoch}.pt")

    return model, ds


# --------------------------
# Inference / demo
# --------------------------
def demo_inference(model: Vec2VecGAN, ds: UnpairedEmbeddingDataset, k: int = 5, device: str = "cpu"):
    """Print quick M1 -> M2 diagnostics on ``k`` randomly chosen M1 rows (with replacement).

    Reports three things:
      - the mean cosine(x, R1(x));
      - the Pearson correlation between the flattened k x k pairwise-dot-product
        matrices of x and F1(x), diagonal included;
      - per-sample cosine(x, F2(F1(x))).

    Runs under ``torch.no_grad()`` and leaves ``model`` in eval mode.
    """
    model.eval()
    with torch.no_grad():
        picks = torch.randint(0, len(ds.X), (k,))
        x = ds.X[picks].to(device)
        translated = model.F1(x)

        # Reconstruction quality in M1.
        recon = model.R1(x)
        mean_rec_cos = cosine_sim(x, recon).mean().item()

        # Geometry preservation: Pearson correlation of pairwise dot products.
        gram_src = (x @ x.t()).flatten()
        gram_dst = (translated @ translated.t()).flatten()
        gram_corr = torch.corrcoef(torch.stack([gram_src, gram_dst]))[0, 1].item()

        print("\n--- Inference Demo (M1 -> M2) ---")
        print(f"Mean cosine(x, R1(x)) ≈ {mean_rec_cos:.3f}  (reconstruction quality in M1)")
        print(f"Correlation of pairwise dots (x vs F1(x)) ≈ {gram_corr:.3f}  (geometry preservation)")

        # Per-sample cycle fidelity.
        cycled = model.F2(translated)
        per_sample = cosine_sim(x, cycled)
        print("Cosine(x, F2(F1(x))) for a few samples:", [f"{c.item():.3f}" for c in per_sample])


# --------------------------
# CLI
# --------------------------
def parse_args():
    """Parse CLI options for the unpaired vec2vec GAN trainer.

    Options are grouped as data, model, training and misc. Returns the
    resulting ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Unpaired vec2vec GAN with Reconstruction, Cycle and VSP.")
    add = parser.add_argument

    # --- data
    add("--data_m1", type=str, default="", help=".npy file with M1 embeddings (N,d)")
    add("--data_m2", type=str, default="", help=".npy file with M2 embeddings (M,d)")
    add("--dim", type=int, default=128, help="embedding dimension (toy data)")
    add("--n_m1", type=int, default=20000)
    add("--n_m2", type=int, default=20000)
    add("--mixtures", type=int, default=6, help="num mixture centers for toy data")
    add("--std", type=float, default=0.6, help="cluster std for toy data")

    # --- model
    add("--z", type=int, default=128, help="shared latent dim")
    add("--back_depth", type=int, default=4)
    add("--disc_width", type=int, default=512)
    add("--disc_depth", type=int, default=3)

    # --- training
    add("--epochs", type=int, default=20)
    add("--batch_size", type=int, default=128)
    add("--lr_g", type=float, default=2e-4)
    add("--lr_d", type=float, default=2e-4)
    add("--lambda_rec", type=float, default=1.0)
    add("--lambda_cc", type=float, default=10.0)
    add("--lambda_vsp", type=float, default=1.0)
    add("--vsp_subset", type=int, default=64, help="subset size for VSP BxB dot loss")
    add("--log_interval", type=int, default=100)
    add("--save_ckpt", action="store_true")

    # --- misc
    add("--cpu", action="store_true")
    add("--seed", type=int, default=42)

    return parser.parse_args()


# args = parse_args()
# # if real data paths are provided, trust their dim
# if args.data_m1 and args.data_m2:
#     assert os.path.exists(args.data_m1) and os.path.exists(args.data_m2), "npy paths not found."
# model, ds = train_v2c_gan(args)
# dev = next(model.parameters()).device
# demo_inference(model, ds, k=8, device=dev)


## TO-vec2vec-gan


In [None]:
# File 1 — `to_v2vgan_train.py`
from typing import Optional

import torch
import torch.nn as nn


"""
Task-Operator vec2vec GAN (TO-v2vGAN)
-------------------------------------
Self-contained PyTorch training script that implements:
  • Output-level adversarial losses (X->Y, Y->X)
  • Latent-level adversarial losses (push shared latent alignment)
  • Reconstruction, Cycle-consistency, Vector-Space Preservation (VSP)
  • Operator prior: T(z) = P z + O(z) with near-orthogonal P and low-rank O (via bottleneck)
Supports: synthetic toy data or .npy embedding files for X (descriptions) and Y (models).

Usage examples
--------------
# synthetic quick run
python to_v2vgan_train.py --epochs 20 --batch_size 128 --z 128 --dim 128

# real embedding files (unpaired)
python to_v2vgan_train.py --data_x desc.npy --data_y model.npy --epochs 50 --z 256

Outputs
-------
checkpoints/to_v2v_e{epoch}.pt  (model state_dict; written each epoch only when --save_ckpt is set)
logs/                             (created by ensure_dirs(); train() itself writes no CSV log)

Notes
-----
• This script is a clean consolidation of the components discussed in the paper template.
• For brevity, we use BCE-GAN (you can flip to WGAN-GP if preferred).
• Low-rank prior on O is encouraged via a bottleneck r << z plus an L2 energy penalty on O(z)
  (lambda_energy); swap in spectral/nuclear-norm surrogates if you need stronger control.
"""

# ---------------------
# Utils
# ---------------------

def ensure_dirs():
    """Create ``checkpoints/`` and ``logs/`` in the current working directory if they are missing."""
    for folder in ("checkpoints", "logs"):
        Path(folder).mkdir(exist_ok=True)



def info_nce(anchor: torch.Tensor, positives: torch.Tensor, tau: float = 0.07) -> torch.Tensor:
    """InfoNCE loss in which row i of ``positives`` is the positive for row i of ``anchor``.

    All other rows in the batch act as negatives. Both inputs are L2-normalised
    with a 1e-8 guard, so the logits are cosine similarities divided by ``tau``.
    Expects 2-D inputs of shape (B, d) with matching B.
    """
    def unit(t: torch.Tensor) -> torch.Tensor:
        return t / (t.norm(dim=-1, keepdim=True) + 1e-8)

    similarity = unit(anchor) @ unit(positives).t()  # (B, B)
    labels = torch.arange(anchor.size(0), device=anchor.device)
    return nn.functional.cross_entropy(similarity / tau, labels)


# -----------------------------
# Datasets
# -----------------------------
class UnpairedEmbeddingDataset(torch.utils.data.Dataset):
    """
    Unpaired X and Y embeddings.

    - If both ``data_x`` and ``data_y`` are given, they are loaded as 2-D
      ``.npy`` arrays (N, d) and (M, d). ``dim`` is ignored, and ``self.dim``
      is taken from the arrays.
    - If neither is given, synthetic data is generated:
      - X is drawn from a Gaussian mixture.
      - Y is drawn from independently re-sampled mixture points, warped by a
        shared scaled rotation + bias, with ``tanh`` on the first
        ``dim // 6`` coordinates.
      - X and Y are therefore not aligned index-wise.

    Rows of both spaces are L2-normalised. ``__getitem__(i)`` returns
    ``(X[i % len(X)], Y[random index])``, and the length is
    ``max(len(X), len(Y))``.

    Raises:
        ValueError: if only one of ``data_x``/``data_y`` is given, if a loaded
            array is not 2-D, or if the two arrays' dimensions differ.
    """
    def __init__(self,
                 dim: int,
                 n_x: int,
                 n_y: int,
                 data_x: Optional[str] = None,
                 data_y: Optional[str] = None,
                 mixtures: int = 6,
                 std: float = 0.6):
        # A single path would otherwise silently fall back to synthetic data.
        if bool(data_x) != bool(data_y):
            raise ValueError("Provide both data_x and data_y, or neither.")

        if data_x and data_y:
            X = np.load(data_x)
            Y = np.load(data_y)
            if X.ndim != 2 or Y.ndim != 2:
                raise ValueError("npy arrays must be 2D (N,d)")
            # The former `... or dim == X.shape[1] == Y.shape[1]` clause was
            # implied by the first one; the effective check is equal widths.
            if X.shape[1] != Y.shape[1]:
                raise ValueError("dim mismatch")
            self.dim = X.shape[1]
            self.X = torch.from_numpy(X).float()
            self.Y = torch.from_numpy(Y).float()
        else:
            # The order of RNG draws below is kept stable for seed reproducibility.
            self.dim = dim
            centers = torch.randn(mixtures, dim) * 3.0
            # X
            idxs_x = torch.randint(0, mixtures, (n_x,))
            X = centers[idxs_x] + std * torch.randn(n_x, dim)

            # Y transform: scaled orthogonal map + bias
            W = torch.randn(dim, dim)
            U, _, Vt = torch.linalg.svd(W, full_matrices=False)
            R = U @ Vt  # orthogonal
            scale = 1.15
            b = torch.randn(dim) * 0.2

            # Unpaired resample for Y, then warp + partial tanh
            idxs_y = torch.randint(0, mixtures, (n_y,))
            Y = centers[idxs_y] + std * torch.randn(n_y, dim)
            Y = Y @ (scale * R).T + b
            k = dim // 6
            Y[:, :k] = torch.tanh(Y[:, :k])

            self.X = X
            self.Y = Y

        # Normalise (helps dot/cos stability)
        self.X = nn.functional.normalize(self.X, dim=-1)
        self.Y = nn.functional.normalize(self.Y, dim=-1)

    def __len__(self):
        return max(len(self.X), len(self.Y))

    def __getitem__(self, idx):
        # X is walked cyclically; Y is sampled uniformly at random (Python ``random``).
        x = self.X[idx % len(self.X)]
        y = self.Y[random.randint(0, len(self.Y) - 1)]
        return x, y


class PairedEmbeddingDataset(torch.utils.data.Dataset):
    """
    Optional paired anchors: row i of ``paired_x`` corresponds to row i of ``paired_y``.

    Both ``.npy`` files must hold 2-D arrays with the same number of rows and
    columns. Rows are L2-normalised, and items are ``(x_i, y_i)`` tensor pairs.

    Raises:
        ValueError: if either array is not 2-D, or if their row or column
            counts differ. Validation uses exceptions rather than ``assert`` so
            it still runs under ``python -O``.
    """
    def __init__(self, paired_x: str, paired_y: str):
        Xp = np.load(paired_x)
        Yp = np.load(paired_y)
        if Xp.ndim != 2 or Yp.ndim != 2 or Xp.shape[0] != Yp.shape[0]:
            raise ValueError("paired arrays must match")
        if Xp.shape[1] != Yp.shape[1]:
            raise ValueError("dim mismatch in paired arrays")
        self.Xp = nn.functional.normalize(torch.from_numpy(Xp).float(), dim=-1)
        self.Yp = nn.functional.normalize(torch.from_numpy(Yp).float(), dim=-1)

    def __len__(self):
        return self.Xp.size(0)

    def __getitem__(self, idx):
        return self.Xp[idx], self.Yp[idx]


# -----------------------------
# Modules
# -----------------------------
class TOVMLP(nn.Module):
    """Feed-forward stack for the TO-v2vGAN section.

    ``dims`` lists layer widths, and each consecutive pair becomes an
    ``nn.Linear``. Every Linear except the last is followed by an optional
    LayerNorm, then ``act()``, then an optional Dropout.
    """
    def __init__(self, dims, act=nn.SiLU, layernorm=True, dropout=0.0):
        super().__init__()
        final_pos = len(dims) - 2
        seq = []
        for pos in range(len(dims) - 1):
            width_out = dims[pos + 1]
            seq.append(nn.Linear(dims[pos], width_out))
            if pos == final_pos:
                break  # output layer stays bare
            if layernorm:
                seq.append(nn.LayerNorm(width_out))
            seq.append(act())
            if dropout > 0:
                seq.append(nn.Dropout(dropout))
        self.net = nn.Sequential(*seq)

    def forward(self, x):
        return self.net(x)


class TOVAdapterIn(nn.Module):
    """Input adapter (A_x / A_y): maps a d-dim embedding to the z-dim latent.

    Uses this section's TOVMLP (d -> 2z -> z, LayerNorm + SiLU on the hidden
    layer) instead of the ``MLP`` helper from the previous section. The TO-v2vGAN
    script then does not depend on another cell's definitions. Both helpers
    build identical layer stacks, so parameter names and shapes are unchanged.
    """
    def __init__(self, d, z):
        super().__init__()
        self.f = TOVMLP([d, 2 * z, z], act=nn.SiLU, layernorm=True)

    def forward(self, x):
        return self.f(x)


class TOVAdapterOut(nn.Module):
    """Output adapter (B_x / B_y): maps a z-dim latent back to a d-dim embedding.

    Uses this section's TOVMLP (z -> 2z -> d, LayerNorm + SiLU on the hidden
    layer) instead of the ``MLP`` helper from the previous section. The TO-v2vGAN
    script then does not depend on another cell's definitions. Both helpers
    build identical layer stacks, so parameter names and shapes are unchanged.
    """
    def __init__(self, z, d):
        super().__init__()
        self.f = TOVMLP([z, 2 * z, d], act=nn.SiLU, layernorm=True)

    def forward(self, z):
        return self.f(z)


class TaskOperatorResidual(nn.Module):
    """Low-rank residual operator O(z): Linear(zdim->r), SiLU, Linear(r->zdim), no biases.

    The bottleneck width is ``r = max(1, min(rank, zdim))``.
    """
    def __init__(self, zdim: int, rank: int):
        super().__init__()
        bottleneck = max(1, min(rank, zdim))
        down = nn.Linear(zdim, bottleneck, bias=False)
        up = nn.Linear(bottleneck, zdim, bias=False)
        self.f = nn.Sequential(down, nn.SiLU(), up)

    def forward(self, z):
        return self.f(z)


class BackboneTO(nn.Module):
    """Shared latent map ``T(z) = LayerNorm(P z + O(z))``.

    ``P`` is a bias-free square Linear initialised to the identity; the
    training loss regularises it toward orthogonality. ``O`` is a low-rank
    TaskOperatorResidual (the task operator).
    """
    def __init__(self, zdim: int, rank: int):
        super().__init__()
        # Creation order (P, O, norm) fixes RNG use at init and state_dict order.
        self.P = nn.Linear(zdim, zdim, bias=False)
        nn.init.eye_(self.P.weight)  # start at the identity map
        self.O = TaskOperatorResidual(zdim, rank)
        self.norm = nn.LayerNorm(zdim)

    def forward(self, z):
        linear_part = self.P(z)
        residual_part = self.O(z)
        return self.norm(linear_part + residual_part)


class TOVDiscriminator(nn.Module):
    """Spectral-normalised MLP critic returning one logit per input row.

    Structure: ``depth - 1`` hidden Linear(width) + LeakyReLU(0.2) stages, then
    a Linear(1) head. Every Linear is wrapped in ``nn.utils.spectral_norm``.
    """
    def __init__(self, in_dim: int, width: int = 512, depth: int = 3):
        super().__init__()
        # Layer input sizes: in_dim, then `width` for each hidden stage.
        sizes = [in_dim] + [width] * (depth - 1)
        layers = []
        for fan_in, fan_out in zip(sizes, sizes[1:]):
            layers.append(nn.utils.spectral_norm(nn.Linear(fan_in, fan_out)))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.utils.spectral_norm(nn.Linear(sizes[-1], 1)))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.net(x)
        return logits.squeeze(-1)


class TOv2vGAN(nn.Module):
    """
    Full TO-v2vGAN container: per-space adapters, the shared P+O backbone, and
    four discriminators (two on outputs, two on latents).

    X -> Y translation is ``By(T(Ax(x)))`` and Y -> X is ``Bx(T(Ay(y)))``.
    Reconstructions reuse the same-space output adapter.
    """
    def __init__(self, d: int, z: int, rank_o: int,
                 disc_width=512, disc_depth=3):
        super().__init__()
        # Creation order is kept: it fixes RNG use at init and state_dict order.
        # Input adapters
        self.Ax = TOVAdapterIn(d, z)
        self.Ay = TOVAdapterIn(d, z)
        # Shared backbone T(z) = P z + O(z)
        self.T = BackboneTO(zdim=z, rank=rank_o)
        # Output adapters
        self.Bx = TOVAdapterOut(z, d)
        self.By = TOVAdapterOut(z, d)
        # Discriminators
        self.Dy = TOVDiscriminator(d, width=disc_width, depth=disc_depth)   # real y vs F(x)
        self.Dx = TOVDiscriminator(d, width=disc_width, depth=disc_depth)   # real x vs G(y)
        self.Dl1 = TOVDiscriminator(z, width=disc_width, depth=disc_depth)  # real z_y vs fake z_x
        self.Dl2 = TOVDiscriminator(z, width=disc_width, depth=disc_depth)  # real z_x vs fake z_y

    # --- encoders into the shared latent
    def enc_x(self, x):
        """Latent code of X-space inputs: T(Ax(x))."""
        return self.T(self.Ax(x))

    def enc_y(self, y):
        """Latent code of Y-space inputs: T(Ay(y))."""
        return self.T(self.Ay(y))

    # --- cross-space translators
    def F_xy(self, x):
        """Translate X -> Y: By(enc_x(x))."""
        return self.By(self.enc_x(x))

    def F_yx(self, y):
        """Translate Y -> X: Bx(enc_y(y))."""
        return self.Bx(self.enc_y(y))

    # --- same-space reconstructions
    def R_x(self, x):
        """Reconstruct X: Bx(enc_x(x))."""
        return self.Bx(self.enc_x(x))

    def R_y(self, y):
        """Reconstruct Y: By(enc_y(y))."""
        return self.By(self.enc_y(y))


# -----------------------------
# Losses
# -----------------------------
class LossPack:
    """Loss terms and weights for TO-v2vGAN training.

    Stores the weight of each term (``lambda_*``) and the VSP subset size.
    Exposes:
      - BCE adversarial losses; real targets are smoothed to 0.9 for the
        discriminator;
      - MSE reconstruction and cycle losses;
      - vector-space preservation (VSP);
      - operator regularisers on the backbone ``T = P + O``;
      - optional paired tethers (InfoNCE and smooth-L1).
    """
    def __init__(self,
                 lambda_rec=1.0,
                 lambda_cyc=10.0,
                 lambda_vsp=1.0,
                 lambda_ortho=0.1,
                 lambda_energy=0.01,
                 lambda_decomp=0.0,
                 lambda_nce=0.0,
                 lambda_l2=0.0,
                 vsp_subset=64):
        self.bce = nn.BCEWithLogitsLoss()
        self.mse = nn.MSELoss()
        self.huber = nn.SmoothL1Loss(beta=1.0)
        self.lambda_rec = lambda_rec
        self.lambda_cyc = lambda_cyc
        self.lambda_vsp = lambda_vsp
        self.lambda_ortho = lambda_ortho
        self.lambda_energy = lambda_energy
        self.lambda_decomp = lambda_decomp
        self.lambda_nce = lambda_nce
        self.lambda_l2 = lambda_l2
        self.vsp_subset = vsp_subset

    # ---- adversarial
    def adv_d(self, D, real, fake):
        """Discriminator BCE: real rows target 0.9 (one-sided label smoothing), fakes (detached) target 0."""
        real_targets = torch.full((real.size(0),), 0.9, device=real.device)
        fake_targets = torch.zeros(fake.size(0), device=fake.device)
        loss_real = self.bce(D(real), real_targets)
        loss_fake = self.bce(D(fake.detach()), fake_targets)
        return loss_real + loss_fake

    def adv_g(self, D, fake):
        """Generator BCE: push D(fake) toward 1; gradients flow into ``fake``."""
        targets = torch.ones(fake.size(0), device=fake.device)
        return self.bce(D(fake), targets)

    # ---- structural
    def rec(self, x, rx, y, ry):
        """MSE(rx, x) + MSE(ry, y)."""
        loss_x = self.mse(rx, x)
        loss_y = self.mse(ry, y)
        return loss_x + loss_y

    def cyc(self, x, x_cyc, y, y_cyc):
        """MSE(x_cyc, x) + MSE(y_cyc, y)."""
        loss_x = self.mse(x_cyc, x)
        loss_y = self.mse(y_cyc, y)
        return loss_x + loss_y

    def vsp(self, X, FX, Y, FY):
        """Mean squared gap between source and translated pairwise-dot matrices.

        Batches larger than ``self.vsp_subset`` use a random row subset (the
        same rows for source and translation) to bound the O(B^2) cost.
        """
        def gram_gap(A, FA):
            n = A.size(0)
            if n > self.vsp_subset:
                keep = torch.randperm(n, device=A.device)[:self.vsp_subset]
                A, FA = A[keep], FA[keep]
            return ((pairwise_dot(A) - pairwise_dot(FA)) ** 2).mean()

        return gram_gap(X, FX) + gram_gap(Y, FY)

    # ---- operator regularisers
    def op_losses(self, model: TOv2vGAN, Ax_x: torch.Tensor, Ay_y: torch.Tensor,
                  ByPxAx: torch.Tensor, ByOA: torch.Tensor):
        """Weighted sum of the operator regularisers on ``model.T``.

        Terms:
          - ortho: squared Frobenius distance of P^T P from the identity;
          - energy: mean squared norm of O on the given latents Ax_x and Ay_y;
          - decomp (only when ``lambda_decomp > 0``): MSE between
            By(norm(P a + O a)) and ``ByPxAx + ByOA``.

        ``ByPxAx`` and ``ByOA`` are not read when ``lambda_decomp == 0``.
        """
        P = model.T.P.weight  # (z, z)
        eye = torch.eye(P.size(0), device=P.device)

        # 1) P near-orthogonal (near-isometry)
        L_ortho = torch.norm(P.T @ P - eye, p='fro') ** 2

        # 2) O should carry little energy on typical latents
        def energy(t):
            return t.pow(2).sum(dim=-1).mean()

        o_x = model.T.O(Ax_x)
        o_y = model.T.O(Ay_y)
        L_energy = energy(o_x) + energy(o_y)

        # 3) Optional additive decomposition in Y-space:
        #    F_xy(x) ≈ By(P(Ax(x))) + By(O(Ax(x))), using T.norm to match T.forward.
        L_decomp = torch.tensor(0.0, device=Ax_x.device)
        if self.lambda_decomp > 0:
            z_sum = model.T.norm(model.T.P(Ax_x) + model.T.O(Ax_x))
            y_full = model.By(z_sum)
            L_decomp = self.mse(y_full, ByPxAx + ByOA)

        return (
            self.lambda_ortho * L_ortho
            + self.lambda_energy * L_energy
            + self.lambda_decomp * L_decomp
        )

    # ---- paired tethers
    def paired_losses(self, x_p: torch.Tensor, y_p: torch.Tensor, yhat_p: torch.Tensor):
        """InfoNCE and/or smooth-L1 between ``yhat_p`` and ``y_p``, each included only if its weight is > 0.

        Returns the float ``0.0`` when both weights are zero, otherwise a tensor.
        ``x_p`` is accepted for API symmetry and is not used.
        """
        total = 0.0
        if self.lambda_nce > 0:
            total += self.lambda_nce * info_nce(yhat_p, y_p)
        if self.lambda_l2 > 0:
            total += self.lambda_l2 * self.huber(yhat_p, y_p)
        return total


# -----------------------------
# Training
# -----------------------------
def make_loader(dataset, batch_size):
    """Return a shuffled, single-process DataLoader that drops the last incomplete batch."""
    loader_kwargs = dict(batch_size=batch_size, shuffle=True, drop_last=True, num_workers=0)
    return torch.utils.data.DataLoader(dataset, **loader_kwargs)


def next_batch(it, loader):
    """Return ``(batch, iterator)``, restarting ``loader`` once ``it`` is exhausted.

    This lets a training loop draw batches indefinitely from a finite loader.
    Callers must keep the returned iterator. If the restarted loader yields
    nothing, ``StopIteration`` propagates.
    """
    exhausted = object()
    batch = next(it, exhausted)
    if batch is exhausted:
        it = iter(loader)
        batch = next(it)
    return batch, it


def train(args):
    """Train TO-v2vGAN and return ``(model, unpaired_dataset)``.

    Per step:
      1. Discriminator update: Dy and Dx on outputs, Dl1 and Dl2 on latents.
         Their inputs are computed without gradient tracking.
      2. Generator update of (Ax, Ay, T, Bx, By) on:
         - adversarial, reconstruction, cycle and VSP losses;
         - the operator regularisers;
         - optional paired tethers.
         The gradient norm is clipped to 5.0.
      3. If ``args.reortho_every > 0``, every ``args.reortho_every`` steps P is
         replaced by its closest orthogonal matrix (U @ Vt from its SVD).

    When ``args.save_ckpt`` is set, the state_dict is saved to
    ``checkpoints/to_v2v_e{epoch}.pt`` after every epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    seed_all(args.seed)

    # Data
    ds_u = UnpairedEmbeddingDataset(
        dim=args.dim, n_x=args.n_x, n_y=args.n_y,
        data_x=args.data_x if args.data_x else None,
        data_y=args.data_y if args.data_y else None,
        mixtures=args.mixtures, std=args.std
    )
    loader_u = make_loader(ds_u, args.batch_size)
    it_u = iter(loader_u)

    # Optional paired anchors, drawn cyclically alongside the unpaired batches.
    ds_p = loader_p = it_p = None
    if args.paired_x and args.paired_y:
        ds_p = PairedEmbeddingDataset(args.paired_x, args.paired_y)
        loader_p = make_loader(ds_p, min(args.batch_size, len(ds_p)))
        it_p = iter(loader_p)

    d = ds_u.X.shape[1]  # embedding dim (from the data, so .npy inputs override args.dim)
    model = TOv2vGAN(d=d, z=args.z, rank_o=args.rank_o, disc_width=args.disc_width, disc_depth=args.disc_depth).to(device)

    # Params
    params_D = list(model.Dx.parameters()) + list(model.Dy.parameters()) + list(model.Dl1.parameters()) + list(model.Dl2.parameters())
    params_G = list(model.Ax.parameters()) + list(model.Ay.parameters()) + list(model.T.parameters()) + list(model.Bx.parameters()) + list(model.By.parameters())

    optD = optim.Adam(params_D, lr=args.lr_d, betas=(0.5, 0.999))
    optG = optim.Adam(params_G, lr=args.lr_g, betas=(0.5, 0.999))

    # lambda_decomp is fixed at 0: the additive-decomposition term is disabled.
    losses = LossPack(lambda_rec=args.lambda_rec, lambda_cyc=args.lambda_cyc, lambda_vsp=args.lambda_vsp,
                      lambda_ortho=args.lambda_ortho, lambda_energy=args.lambda_energy,
                      lambda_decomp=0.0, lambda_nce=args.lambda_nce, lambda_l2=args.lambda_l2,
                      vsp_subset=args.vsp_subset)
    use_pairs = ds_p is not None and (losses.lambda_nce > 0 or losses.lambda_l2 > 0)

    global_step = 0
    for epoch in range(1, args.epochs + 1):
        for _ in range(len(loader_u)):
            (x_u, y_u), it_u = next_batch(it_u, loader_u)
            x_u = x_u.to(device)
            y_u = y_u.to(device)

            # ---- 1) Discriminators
            # Translator outputs are only D inputs here. Building them without a
            # graph has two effects:
            #   - it skips the reconstruction/cycle passes this step never used;
            #   - d_total.backward() no longer reaches generator parameters
            #     through the non-detached "real" latents.
            # Those gradients were discarded by optG.zero_grad anyway.
            with torch.no_grad():
                zx = model.enc_x(x_u)   # z_x
                zy = model.enc_y(y_u)   # z_y
                y_hat = model.By(zx)    # F_{x->y}(x)
                x_hat = model.Bx(zy)    # F_{y->x}(y)

            optD.zero_grad(set_to_none=True)
            d_y = losses.adv_d(model.Dy, y_u, y_hat)  # real y vs fake y_hat
            d_x = losses.adv_d(model.Dx, x_u, x_hat)  # real x vs fake x_hat
            d_l1 = losses.adv_d(model.Dl1, zy, zx)    # latent: real=zy, fake=zx
            d_l2 = losses.adv_d(model.Dl2, zx, zy)    # latent: real=zx, fake=zy
            d_total = d_y + d_x + d_l1 + d_l2
            d_total.backward()
            optD.step()

            # ---- 2) Generators (adapters/backbone)
            optG.zero_grad(set_to_none=True)
            # Adapter outputs are computed once and shared by the encoders and
            # the operator regularisers (enc_x(x) == T(Ax(x))).
            ax = model.Ax(x_u)
            ay = model.Ay(y_u)
            zx = model.T(ax)
            zy = model.T(ay)
            y_hat = model.By(zx)
            x_hat = model.Bx(zy)
            rx = model.Bx(zx)
            ry = model.By(zy)
            x_cyc = model.F_yx(y_hat)
            y_cyc = model.F_xy(x_hat)

            # adversarial
            g_adv = (losses.adv_g(model.Dy, y_hat) + losses.adv_g(model.Dx, x_hat) +
                     losses.adv_g(model.Dl1, zx) + losses.adv_g(model.Dl2, zy))
            # structural
            g_rec = losses.rec(x_u, rx, y_u, ry)
            g_cyc = losses.cyc(x_u, x_cyc, y_u, y_cyc)
            g_vsp = losses.vsp(x_u, y_hat, y_u, x_hat)

            # Operator regularisers. The decomposition inputs are read only when
            # lambda_decomp > 0, so skip the two extra By passes otherwise.
            if losses.lambda_decomp > 0:
                by_p = model.By(model.T.P(ax))
                by_o = model.By(model.T.O(ax))
            else:
                by_p = by_o = None
            g_op = losses.op_losses(model, ax, ay, by_p, by_o)

            # paired tethers (optional)
            g_pair = 0.0
            if use_pairs:
                (xp, yp), it_p = next_batch(it_p, loader_p)
                xp = xp.to(device)
                yp = yp.to(device)
                yhat_p = model.F_xy(xp)
                g_pair = losses.paired_losses(xp, yp, yhat_p)

            g_total = g_adv + losses.lambda_rec * g_rec + losses.lambda_cyc * g_cyc + losses.lambda_vsp * g_vsp + g_op + g_pair
            g_total.backward()
            nn.utils.clip_grad_norm_(params_G, max_norm=5.0)
            optG.step()

            # Occasional re-orthogonalisation of P (projection onto the closest orthogonal matrix).
            if args.reortho_every > 0 and (global_step + 1) % args.reortho_every == 0:
                with torch.no_grad():
                    U, _, Vt = torch.linalg.svd(model.T.P.weight, full_matrices=False)
                    model.T.P.weight.copy_(U @ Vt)

            # logs (g_pair may be a float or a 0-dim tensor; float() handles both)
            if (global_step + 1) % args.log_interval == 0:
                print(f"[E{epoch:03d} S{global_step+1:06d}] "
                      f"D:{d_total.item():.3f} | G_adv:{g_adv.item():.3f} "
                      f"| Rec:{g_rec.item():.3f} Cyc:{g_cyc.item():.3f} VSP:{g_vsp.item():.3f} "
                      f"| Op:{g_op.item():.3f} Pair:{float(g_pair):.3f}")

            global_step += 1

        # save ckpt
        if args.save_ckpt:
            Path("checkpoints").mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), f"checkpoints/to_v2v_e{epoch}.pt")

    return model, ds_u


# -----------------------------
# Inference / Demo
# -----------------------------
def demo(model: TOv2vGAN, ds: UnpairedEmbeddingDataset, k: int = 8, device: "str | torch.device" = "cpu"):
    """Print a small qualitative X -> Y report for a trained TO-v2vGAN.

    Draws ``k`` distinct rows of ``ds.X`` (sampled without replacement, so a
    repeated query cannot inflate the statistics), translates them with
    ``model.F_xy`` and reports:
      * mean cosine between each query and ``model.R_x(query)``;
      * mean cosine between each query and its cycle ``F_yx(F_xy(query))``;
      * Pearson correlation between the flattened k x k pairwise-dot matrices
        of the queries and of their translations (NaN when k < 2, since a
        correlation needs at least two observations);
      * the index of the cosine-nearest row of ``ds.Y`` for the first few
        translated queries.

    Args:
        model: trained model; it is switched to eval mode and left there.
        ds: dataset exposing ``X`` and ``Y`` tensors whose rows are embeddings.
        k: requested number of queries; clamped to the number of rows of ``ds.X``.
        device: device the model lives on (``str`` or ``torch.device`` -- the
            ``__main__`` entry point passes a ``torch.device``); the queries and
            ``ds.Y`` are moved there.

    Raises:
        ValueError: if ``k < 1`` or ``ds.X`` has no rows.
    """
    n_x = ds.X.size(0)
    if k < 1:
        raise ValueError(f"demo needs k >= 1 queries, got k={k}")
    if n_x == 0:
        raise ValueError("demo needs a non-empty ds.X")
    k = min(k, n_x)

    model.eval()
    with torch.no_grad():
        # sample k distinct queries from X, translate to Y
        idx = torch.randperm(n_x)[:k]
        xq = ds.X[idx].to(device)
        y_hat = model.F_xy(xq)

        # recon / cycle quality (in X)
        rx = model.R_x(xq)
        xcyc = model.F_yx(y_hat)
        rec_cos = cosine_sim(xq, rx).mean().item()
        cyc_cos = cosine_sim(xq, xcyc).mean().item()

        # neighborhood preservation: corr of pairwise dot matrices (k x k).
        # With a single query each matrix has one entry and corrcoef is undefined.
        if k >= 2:
            x_dots = pairwise_dot(xq).flatten()
            yh_dots = pairwise_dot(y_hat).flatten()
            corr = torch.corrcoef(torch.stack([x_dots, yh_dots]))[0, 1].item()
        else:
            corr = float("nan")

        # retrieval demo: find nearest Y for each translated y_hat (cosine top-1)
        Yall = ds.Y.to(device)
        yhat_n = y_hat / (y_hat.norm(dim=-1, keepdim=True) + 1e-8)
        Yn = Yall / (Yall.norm(dim=-1, keepdim=True) + 1e-8)
        sims = yhat_n @ Yn.t()  # (k, |Y|)
        top1 = sims.argmax(dim=1).tolist()

        print("\n=== Inference Demo (X -> Y) ===")
        print(f"Mean cosine(x, R_x(x))      : {rec_cos:.3f}")
        print(f"Mean cosine(x, cyc(x))      : {cyc_cos:.3f}")
        print(f"VSP corr (pairwise dots)    : {corr:.3f}")
        # slicing already stops at len(top1); no min() needed
        print(f"Top-1 retrieved indices in Y: {top1[:5]} ... (k={k})")


# -----------------------------
# CLI
# -----------------------------
def parse_args_tov(argv: Optional[list] = None) -> argparse.Namespace:
    """Build the TO-v2vGAN command-line parser and parse the arguments.

    Args:
        argv: argument strings to parse. ``None`` (the default, used by the
            ``__main__`` entry point) makes argparse read ``sys.argv[1:]``.
            Passing an explicit list (e.g. ``[]`` for all defaults) lets the
            parser be used from notebooks, where ``sys.argv`` holds the
            kernel's own flags and would be rejected as unrecognized.

    Returns:
        The parsed ``argparse.Namespace`` (data, model, train and misc options).
    """
    p = argparse.ArgumentParser(description="Task-Operator vec2vec GAN (TO-v2vGAN)")
    # data
    p.add_argument("--data_x", type=str, default="", help=".npy file for X-space (descriptions)")
    p.add_argument("--data_y", type=str, default="", help=".npy file for Y-space (models)")
    p.add_argument("--paired_x", type=str, default="", help="optional paired X .npy")
    p.add_argument("--paired_y", type=str, default="", help="optional paired Y .npy")
    p.add_argument("--dim", type=int, default=128, help="embedding dimension (synthetic)")
    p.add_argument("--n_x", type=int, default=20000)
    p.add_argument("--n_y", type=int, default=20000)
    p.add_argument("--mixtures", type=int, default=6)
    p.add_argument("--std", type=float, default=0.6)
    # model
    p.add_argument("--z", type=int, default=128)
    p.add_argument("--rank_o", type=int, default=16, help="low-rank bottleneck for O(z)")
    p.add_argument("--disc_width", type=int, default=512)
    p.add_argument("--disc_depth", type=int, default=3)
    # train
    p.add_argument("--epochs", type=int, default=20)
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--lr_g", type=float, default=2e-4)
    p.add_argument("--lr_d", type=float, default=2e-4)
    p.add_argument("--lambda_rec", type=float, default=1.0)
    p.add_argument("--lambda_cyc", type=float, default=10.0)
    p.add_argument("--lambda_vsp", type=float, default=1.0)
    p.add_argument("--lambda_ortho", type=float, default=0.1)
    p.add_argument("--lambda_energy", type=float, default=0.01)
    p.add_argument("--lambda_nce", type=float, default=0.0)
    p.add_argument("--lambda_l2", type=float, default=0.0)
    p.add_argument("--vsp_subset", type=int, default=64)
    p.add_argument("--log_interval", type=int, default=200)
    p.add_argument("--reortho_every", type=int, default=1000, help="steps between P re-orthogonalization; 0=off")
    p.add_argument("--save_ckpt", action="store_true")
    # misc
    p.add_argument("--cpu", action="store_true")
    p.add_argument("--seed", type=int, default=42)
    return p.parse_args(argv)


if __name__ == "__main__":
    args = parse_args_tov()
    # every file flag must come together with its partner, or both be omitted
    for first, second, message in (
        (args.data_x, args.data_y, "Provide both --data_x and --data_y, or neither (synthetic)."),
        (args.paired_x, args.paired_y, "Provide both --paired_x and --paired_y, or neither."),
    ):
        if bool(first) != bool(second):
            raise SystemExit(message)
    model, ds = train(args)
    # run the demo on whatever device training placed the model
    dev = next(model.parameters()).device
    demo(model, ds, k=8, device=dev)

In [None]:
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F


r"""
Metamodel-Constrained Graph Decoder (Skeleton)
----------------------------------------------
Condition on a model embedding y (or \hat{y} = F_{x->y}(x)) and decode a typed graph
(e.g., UML/ArchiMate/OntoUML) that respects a given metamodel.

This is a minimal, extensible scaffold illustrating:
  • Metamodel schema (node/edge types, allowed incidence)
  • Validity checks during autoregressive generation (masking invalid actions)
  • Training with teacher forcing from gold graphs

You should replace Dataset, featurization, and training loop to fit your data.
"""

# ---------------- Metamodel schema ----------------
@dataclass
class MetaModel:
    """Schema of a typed graph language: node types, edge types and legal edge endpoints."""

    # ordered node-type names; a name's list position is its node-type index
    node_types: List[str]
    # ordered edge-type names; a name's list position is its edge-type index
    edge_types: List[str]
    # edge-type name -> (node types allowed as source, node types allowed as destination)
    constraints: Dict[str, Tuple[List[str], List[str]]]

    def node_type_idx(self) -> Dict[str, int]:
        """Return a name -> index lookup for ``node_types`` (last occurrence wins)."""
        names = self.node_types
        return dict(zip(names, range(len(names))))

    def edge_type_idx(self) -> Dict[str, int]:
        """Return a name -> index lookup for ``edge_types`` (last occurrence wins)."""
        names = self.edge_types
        return dict(zip(names, range(len(names))))

# Example tiny ArchiMate-like subset: four element types, three relationship
# types, and for each relationship the element types allowed at each end.
ARCHI_TINY = MetaModel(
    node_types=[
        "BusinessActor",
        "BusinessRole",
        "ApplicationComponent",
        "DataObject",
    ],
    edge_types=[
        "Assignment",
        "Access",
        "Serving",
    ],
    constraints={
        # edge type: (allowed source types, allowed destination types)
        "Assignment": (
            ["BusinessActor", "BusinessRole"],
            ["BusinessRole", "ApplicationComponent"],
        ),
        "Access": (
            ["ApplicationComponent"],
            ["DataObject"],
        ),
        "Serving": (
            ["ApplicationComponent"],
            ["BusinessRole", "BusinessActor"],
        ),
    },
)

# ---------------- Graph representation ----------------
class TypedGraph:
    """Lightweight container for one typed graph.

    ``node_types`` holds one node-type index per node (a node's id is its list
    position); ``edges`` holds (src_idx, dst_idx, edge_type_idx) triples that
    reference those node ids. Both lists are stored as given, without copying.
    """

    def __init__(self, node_types: List[int], edges: List[Tuple[int,int,int]]):
        self.node_types, self.edges = node_types, edges

# ---------------- Decoder ----------------
class GraphDecoder(nn.Module):
    r"""
    Autoregressive decoder with validity masking.

    At each step it either (a) adds a node with a type, or (b) adds an edge
    between existing nodes with an edge type that the metamodel allows for at
    least one (src, dst) pair not yet connected by that edge type.
    Conditioning: \hat{y} is projected and fed through one GRU tick to form the
    initial state; the GRU hidden state is then carried from step to step.

    Args:
        y_dim: dimensionality of the conditioning embedding.
        hidden: GRU hidden size (also the input width of every head).
        num_node_types: size of ``node_head``; expected to equal
            ``len(mm.node_types)`` for the metamodel given to ``forward``.
        num_edge_types: size of ``edge_head``; expected to equal
            ``len(mm.edge_types)``.
    """
    def __init__(self, y_dim: int, hidden: int, num_node_types: int, num_edge_types: int):
        super().__init__()
        self.y_proj = nn.Linear(y_dim, hidden)
        self.state = nn.GRU(input_size=hidden, hidden_size=hidden, batch_first=True)
        self.node_head = nn.Linear(hidden, num_node_types)
        self.edge_head = nn.Linear(hidden, num_edge_types)
        # scoring heads for picking src/dst among existing nodes.
        # NOTE(review): not used by forward() yet -- pair selection is greedy
        # (first valid unused pair). Kept so existing state_dicts still load.
        self.src_head = nn.Linear(hidden, 1)
        self.dst_head = nn.Linear(hidden, 1)

    @staticmethod
    def _pick(logits: torch.Tensor, tau: float) -> int:
        """Return the argmax index of ``logits``; with tau > 0, Gumbel-perturbed first.

        softmax is monotonic, so taking argmax of the logits directly selects
        the same index as argmax of the softmax probabilities.
        """
        if tau > 0:
            # clamp avoids log(0) -> inf when rand returns exactly 0
            u = torch.rand_like(logits).clamp(min=1e-10)
            g = -torch.log(-torch.log(u))
            logits = (logits + g) / max(1e-6, tau)
        return int(torch.argmax(logits))

    @staticmethod
    def _first_free_pair(names: List[str], src_ok: List[str], dst_ok: List[str],
                         e_idx: int, taken: set) -> Optional[Tuple[int, int]]:
        """First (src, dst) with s != d whose types fit the constraint and that
        is not already linked by edge type ``e_idx``; ``None`` if there is none."""
        for s, s_name in enumerate(names):
            if s_name not in src_ok:
                continue
            for d, d_name in enumerate(names):
                if s != d and d_name in dst_ok and (s, d, e_idx) not in taken:
                    return s, d
        return None

    def forward(self, y_emb: torch.Tensor, steps: int, mm: MetaModel,
                teacher: Optional[TypedGraph] = None, tau: float = 0.0) -> TypedGraph:
        """Greedily decode one typed graph conditioned on ``y_emb``.

        Args:
            y_emb: conditioning embedding of shape (y_dim,) or (1, y_dim); a
                single graph (batch size 1) is decoded.
            steps: number of actions after the first node is placed; each
                action adds exactly one node or one edge.
            mm: metamodel whose node/edge type lists index the heads.
            teacher: accepted for API compatibility but currently ignored --
                teacher forcing is not implemented in this method.
            tau: if > 0, Gumbel noise (temperature ``tau``) perturbs every
                node- and edge-type choice for exploration; 0 is pure greedy.

        Returns:
            TypedGraph of node-type indices and (src, dst, edge_type) triples,
            with no duplicate edges.

        Raises:
            ValueError: if ``y_emb`` holds more than one embedding.
        """
        if y_emb.dim() == 1:
            y_emb = y_emb.unsqueeze(0)
        if y_emb.dim() != 2 or y_emb.size(0) != 1:
            raise ValueError(
                f"GraphDecoder.forward decodes one graph; got y_emb of shape {tuple(y_emb.shape)}")

        et2i = mm.edge_type_idx()
        node_types: List[int] = []
        edges: List[Tuple[int, int, int]] = []
        edge_set = set()  # (src, dst, edge_type) already emitted

        # one GRU tick on the projected condition initialises the recurrent state
        h = torch.tanh(self.y_proj(y_emb)).unsqueeze(1)  # (1, 1, H)
        hs, h_n = self.state(h)
        hidden = hs[:, -1, :]  # (1, H)

        # add at least one node
        nt = self._pick(self.node_head(hidden), tau)
        node_types.append(nt)

        for _ in range(steps):
            # Advance the GRU. The input is an embedding of the last node type,
            # tied to that type's row of node_head.weight so its width equals
            # the GRU input_size (hidden); the previous h_n carries the state.
            type_emb = self.node_head.weight[nt].view(1, 1, -1)
            hs, h_n = self.state(type_emb, h_n)
            hidden = hs[:, -1, :]

            # heuristic: edges only become possible once two nodes exist
            if len(node_types) < 2:
                nt = self._pick(self.node_head(hidden), tau)
                node_types.append(nt)
                continue

            # every edge type with at least one legal, unused (src, dst) pair,
            # checked independently per edge type
            names = [mm.node_types[t] for t in node_types]
            candidates: Dict[int, Tuple[int, int]] = {}
            for e_name, (src_ok, dst_ok) in mm.constraints.items():
                e_idx = et2i[e_name]
                pair = self._first_free_pair(names, src_ok, dst_ok, e_idx, edge_set)
                if pair is not None:
                    candidates[e_idx] = pair

            if not candidates:
                # fallback: no legal new edge -> add another node
                nt = self._pick(self.node_head(hidden), tau)
                node_types.append(nt)
                continue

            # choose edge type with validity mask, then use its first free pair
            logits_et = self.edge_head(hidden).squeeze(0)
            allowed = torch.zeros_like(logits_et, dtype=torch.bool)
            allowed[list(candidates)] = True
            et = self._pick(logits_et.masked_fill(~allowed, float("-inf")), tau)
            s, d = candidates[et]
            edges.append((s, d, et))
            edge_set.add((s, d, et))

        return TypedGraph(node_types, edges)


# ---------------- Training skeleton ----------------
class GraphDataset(torch.utils.data.Dataset):
    """Placeholder dataset pairing row ``i`` of ``Y`` with ``graphs[i]``.

    Indexing yields ``(Y[i], graphs[i])``; the length is the number of graphs.
    """

    def __init__(self, Y: torch.Tensor, graphs: List[TypedGraph]):
        self.Y, self.graphs = Y, graphs

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, i):
        embedding = self.Y[i]
        return embedding, self.graphs[i]


def teacher_force_loss(dec: GraphDecoder, y: torch.Tensor, gold: TypedGraph, mm: MetaModel):
    """Placeholder teacher-forcing loss for one (embedding, gold graph) pair.

    Only the FIRST node type is supervised (cross-entropy on ``dec.node_head``);
    later nodes and edges are not trained yet. The logits are produced along
    the same path ``GraphDecoder.forward`` uses for its first node
    (y_proj -> tanh -> one GRU tick -> node_head), so the supervised head sees
    the same features at training and decoding time.

    Args:
        dec: decoder being trained.
        y: conditioning embedding, shape (y_dim,) or (1, y_dim).
        gold: gold graph; must contain at least one node.
        mm: metamodel (currently unused; kept for the future masked loss).

    Returns:
        Scalar cross-entropy loss tensor.

    Raises:
        ValueError: if ``gold`` has no nodes.
    """
    if not gold.node_types:
        raise ValueError("gold graph has no nodes; cannot supervise the first node type")
    if y.dim() == 1:
        y = y.unsqueeze(0)
    h = torch.tanh(dec.y_proj(y)).unsqueeze(1)  # (1, 1, H)
    hs, _ = dec.state(h)
    logits = dec.node_head(hs[:, -1, :])  # (1, num_node_types)
    target = torch.tensor([gold.node_types[0]], device=y.device)
    return F.cross_entropy(logits, target)


def train_decoder(dec: GraphDecoder, ds: GraphDataset, mm: MetaModel, epochs=5, lr=1e-3):
    """Train ``dec`` with Adam on ``teacher_force_loss``, one example per step.

    Prints the mean loss per epoch.

    Args:
        dec: decoder to train; put in training mode.
        ds: dataset yielding (embedding, gold graph) pairs; must be non-empty.
        mm: metamodel forwarded to ``teacher_force_loss``.
        epochs: number of passes over ``ds``.
        lr: Adam learning rate.

    Raises:
        ValueError: if ``ds`` is empty (the per-epoch mean would divide by zero).
    """
    n = len(ds)
    if n == 0:
        raise ValueError("train_decoder needs a non-empty dataset")
    device = next(dec.parameters()).device
    opt = torch.optim.Adam(dec.parameters(), lr=lr)
    dec.train()
    for e in range(1, epochs + 1):
        total = 0.0
        for y, g in ds:
            y = y.to(device)
            if y.dim() == 1:
                y = y.unsqueeze(0)  # (1, D)
            loss = teacher_force_loss(dec, y, g, mm)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += float(loss.item())
        print(f"[Decoder] Epoch {e} loss={total/n:.4f}")


if __name__ == "__main__":
    # Tiny smoke test: decode one graph from a random condition and print it
    # with type names instead of indices.
    D = 64
    y = torch.randn(1, D)
    dec = GraphDecoder(
        y_dim=D,
        hidden=128,
        num_node_types=len(ARCHI_TINY.node_types),
        num_edge_types=len(ARCHI_TINY.edge_types),
    )
    tg = dec(y, steps=5, mm=ARCHI_TINY)
    node_names = [ARCHI_TINY.node_types[t] for t in tg.node_types]
    edge_triples = [(s, d, ARCHI_TINY.edge_types[e]) for (s, d, e) in tg.edges]
    print("Nodes:", node_names)
    print("Edges:", edge_triples)

In [1]:
import pickle

# Load the precomputed EAModelSet NL2CM embedding file into `data`.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file --
# only load this pickle from a trusted source.
# `data` is indexed by key in the next cell (data['CM_Serialization_Emb']);
# presumably a dict or DataFrame -- confirm against the producer.
with open("datasets/eamodelset_nl2cm_embedding.pkl", "rb") as f:
    data = pickle.load(f)

In [None]:
import numpy as np

# Stack the 'CM_Serialization_Emb' entries along a new leading axis and show the
# resulting shape; the recorded output below is (978, 1536).
np.stack(data['CM_Serialization_Emb']).shape

(978, 1536)