In [None]:
# talagrand_binary_diffusion.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os
import imageio
from tqdm import tqdm

# ------------------ Config ------------------
# Hardware: run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Optimisation schedule.
batch_size, epochs, lr = 128, 45, 1e-4

# Noise process: gamma = 1 - exp(-tau) scales the per-pixel flip probability
# used by the noising function.
tau = 3.0
gamma = 1.0 - torch.tensor(-tau).exp()

# Sampling / guidance.
steps = 100        # default number of sampler iterations
cfg_w = 3.5        # default guidance weight passed to the sampler
drop_rate = 0.12   # probability that a training batch is run without labels

# Reproducibility.
seed = 42
torch.manual_seed(seed)

# Directory that receives the per-epoch sample grids.
save_dir = "talagrand_diffusion"
os.makedirs(save_dir, exist_ok=True)

# ------------------ Data ------------------
def _binarize_flat(img):
    """Threshold an image tensor at 0.5 and flatten it into a 1-D float vector."""
    bits = (img > 0.5).float()
    return bits.squeeze(0).view(-1)


# ToTensor followed by binarisation, so every sample is a flat vector of 0.0/1.0.
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(_binarize_flat)])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# drop_last keeps every batch at exactly batch_size samples.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# ------------------ Model ------------------
class ScoreNet(nn.Module):
    """Label-conditioned MLP over flattened 784-pixel inputs.

    The network embeds the class label into 64 dimensions, concatenates it
    with the input vector and the time value, and passes the result through
    three hidden ReLU layers to produce one output per pixel.

    The last embedding row (index ``n_classes - 1``) is reserved as the
    "no label" token used when ``y`` is ``None``.

    Args:
        hidden: Width of each hidden layer.
        n_classes: Number of embedding rows, including the "no label" row.
    """

    def __init__(self, hidden=448, n_classes=11):
        super().__init__()
        # Stored as a plain int (not a parameter or buffer) so the state_dict
        # layout is unchanged and existing checkpoints still load.
        self.null_class = n_classes - 1
        self.emb = nn.Embedding(n_classes, 64)
        self.net = nn.Sequential(
            nn.Linear(784 + 1 + 64, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 784)
        )

    def forward(self, x, t, y=None):
        """Return the per-pixel network output.

        Args:
            x: Tensor of shape (batch, 784).
            t: Tensor of shape (batch,); it is unsqueezed to a column here.
            y: Long tensor of shape (batch,) with class indices, or ``None``
                to use the reserved "no label" embedding for every row.

        Returns:
            Tensor of shape (batch, 784).
        """
        if y is None:
            # Derive the null index from the embedding size instead of
            # hard-coding 10, which is out of range whenever n_classes <= 10.
            y = torch.full((x.shape[0],), self.null_class, device=x.device, dtype=torch.long)
        e = self.emb(y)
        inp = torch.cat([x, t.unsqueeze(1), e], dim=1)
        return self.net(inp)

# Build the network, move it to the configured device, and hand all of its
# parameters to an Adam optimiser.
model = ScoreNet()
model.to(device)
opt = torch.optim.Adam(params=model.parameters(), lr=lr)

# ------------------ Forward noise ------------------
def forward(x0, t):
    """Corrupt clean binary images with time-dependent noise.

    Each pixel is flipped with probability ``gamma * t``; afterwards each
    pixel is independently replaced by a fair coin toss with one tenth of
    that probability.

    Args:
        x0: Clean binary tensor; it is reshaped to (-1, 784).
        t: Tensor of shape (batch,) giving the noise level of each row.

    Returns:
        Noised tensor of shape (batch, 784).
    """
    clean = x0.view(-1, 784)
    flip_prob = gamma * t.unsqueeze(1)

    # Stage 1: bit flips.
    flip_mask = torch.rand_like(clean) < flip_prob
    noisy = torch.where(flip_mask, 1 - clean, clean)

    # Stage 2: occasional full resampling from Bernoulli(0.5).
    resample_mask = torch.rand_like(clean) < flip_prob * 0.1
    coin = torch.bernoulli(torch.full_like(clean, 0.5))
    return torch.where(resample_mask, coin, noisy)

# ------------------ Sampling ------------------
@torch.no_grad()
def sample(digit=None, steps=steps, w=cfg_w):
    """Draw one 28x28 binary image by iteratively flipping pixels.

    Starts from uniformly random bits and runs ``steps`` updates with the
    time input going from 1.0 down to ``1/steps``. Puts ``model`` into eval
    mode as a side effect.

    Args:
        digit: Class index to condition on, or ``None`` for an unconditional
            sample.
        steps: Number of update iterations.
        w: Guidance weight. The score used is
            ``uncond + w * (cond - uncond)``, so ``w=0`` is unconditional,
            ``w=1`` is the plain conditional score and ``w>1`` extrapolates
            towards the condition. Ignored when ``digit`` is ``None``.

    Returns:
        CPU tensor of shape (28, 28) holding 0.0/1.0 values.
    """
    model.eval()
    x = torch.bernoulli(torch.full((1, 784), 0.5, device=device))
    y_cond = None if digit is None else torch.tensor([digit], device=device, dtype=torch.long)

    for i in range(steps):
        t = torch.tensor([1.0 - i / steps], device=device)
        score = model(x, t, None)
        if y_cond is not None:
            # Apply the guidance formula for every w. Guarding on w > 1.0
            # made w == 1.0 silently fall back to the unconditional score,
            # so the requested digit was ignored.
            score_cond = model(x, t, y_cond)
            score = score + w * (score_cond - score)

        # exp(s) / (exp(s) + exp(-s)) == sigmoid(2s). The sigmoid form cannot
        # overflow to inf/inf = NaN when guidance produces large scores.
        p1 = torch.sigmoid(2.0 * score)
        p0 = torch.sigmoid(-2.0 * score)

        # Decide both flip directions from the same pre-update state, so a
        # pixel that just flipped 0 -> 1 cannot be flipped back in this step.
        to_one = (x == 0) & (torch.rand_like(x) < p1)
        to_zero = (x == 1) & (torch.rand_like(x) < p0)
        x = torch.where(to_one, torch.ones_like(x), x)
        x = torch.where(to_zero, torch.zeros_like(x), x)
    return x[0].view(28, 28).cpu()

# ------------------ Save 0-9 grid ------------------
def save_grid(epoch):
    """Sample one image per digit 0-9 and save them as a single-row PNG.

    The image is written to ``{save_dir}/epoch_{epoch:03d}.png`` and the file
    name is printed. Epochs below 10 are sampled with guidance weight 1.0;
    from epoch 10 onward ``cfg_w`` is used.

    Args:
        epoch: Epoch number, used in the figure title and in the file name.
    """
    w = cfg_w if epoch >= 10 else 1.0
    imgs = [sample(d, w=w) for d in range(10)]
    path = f"{save_dir}/epoch_{epoch:03d}.png"

    fig = plt.figure(figsize=(12, 2.4))
    try:
        for i, img in enumerate(imgs):
            plt.subplot(1, 10, i + 1)
            plt.imshow(img, cmap="gray", vmin=0, vmax=1)
            plt.title(str(i), fontsize=16)
            plt.axis("off")
        plt.suptitle(f"Epoch {epoch}", fontsize=18, y=0.98)
        plt.tight_layout()
        plt.savefig(path, dpi=200, bbox_inches="tight")
    finally:
        # Close this figure by handle even if drawing or saving fails, so
        # figures do not accumulate over the training run.
        plt.close(fig)

    # "\u2192" is the right arrow; the previous literal was the UTF-8 bytes of
    # that arrow mis-decoded as cp1252. The escape keeps the source ASCII-only.
    print(f"\u2192 {os.path.basename(path)}")

# ------------------ Training ------------------
model_path = "talagrand_diffusion_final.pt"

if os.path.exists(model_path):
    # A finished checkpoint is already on disk: reuse it instead of retraining.
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    print(f"Loading {model_path}")
    model.load_state_dict(torch.load(model_path, map_location=device))
else:
    print("Training from scratch")
    save_grid(0)  # grid produced by the untrained network

    for epoch in range(1, epochs + 1):
        model.train()
        running = 0.0
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            # One draw per batch: with probability drop_rate the whole batch
            # is trained without labels.
            cond = None if torch.rand(1) < drop_rate else labels
            t = torch.rand(images.shape[0], device=device)
            noisy = forward(images, t)

            # The regression target depends only on the clean image and t,
            # so it is built outside the autograd graph.
            with torch.no_grad():
                retain = 1 - gamma * t.unsqueeze(1)
                target = (images - 0.5) * 4.0 / (retain + 1e-6)

            loss = F.mse_loss(model(noisy, t, cond), target)

            opt.zero_grad()
            loss.backward()
            opt.step()
            running += loss.item()

        mean_loss = running / len(loader)
        print(f"Epoch {epoch:2d} | Loss {mean_loss:.6f}")
        save_grid(epoch)

        # Intermediate checkpoint every 15 epochs.
        if epoch % 15 == 0:
            torch.save(model.state_dict(), f"talagrand_diffusion_{epoch}.pt")

    torch.save(model.state_dict(), model_path)

# ------------------ Make GIF ------------------
# Stitch the saved per-epoch grids into an animation. Every epoch is used for
# short runs (<= 30 epochs), every second epoch otherwise.
print("Building GIF...")
stride = 1 if epochs <= 30 else 2
frame_epochs = list(range(0, epochs + 1, stride))
if epochs not in frame_epochs:
    # A stride of 2 never reaches an odd final epoch; always end the
    # animation on the grid of the finished model.
    frame_epochs.append(epochs)

frames = []
for ep in frame_epochs:
    p = f"{save_dir}/epoch_{ep:03d}.png"
    # Grids may be missing, e.g. when the model was loaded from a checkpoint
    # instead of being trained in this run.
    if os.path.exists(p):
        frames.append(imageio.imread(p))

if frames:
    imageio.mimsave("talagrand_training.gif", frames, fps=10)
    print("GIF saved: talagrand_training.gif")
else:
    # An empty frame list cannot produce a valid GIF; report it instead of
    # calling the writer with nothing to write.
    print(f"No epoch grids found in {save_dir}; skipping GIF")