# Simple Energy-Based Model on MNIST, trained with SGLD and a replay buffer (experimental)

In [None]:
import math, torch, torch.nn as nn, torch.nn.functional as F
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader

# Pick the accelerator and seed PyTorch's RNGs for reproducible runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
use_cuda = device == 'cuda'
torch.manual_seed(0)
if use_cuda:
    torch.cuda.manual_seed_all(0)

# ToTensor yields pixels in [0, 1]; Normalize with mean 0.5 / std 0.5 maps
# them to [-1, 1], the same range the SGLD sampler clamps its samples to.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# Training split only; drop_last keeps every batch at exactly 128 images.
ds = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
loader = DataLoader(
    ds,
    batch_size=128,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    pin_memory=use_cuda,
)

class EBM(nn.Module):
    """Convolutional network that maps a batch of 1x28x28 images to scalar energies.

    Three stride-2 conv + SiLU stages shrink the spatial size 28 -> 14 -> 7 -> 4
    while widening channels 1 -> 64 -> 128 -> 256. A linear head then reduces
    the flattened features to one number per image. The training code in this
    notebook treats lower energy as "more data-like".

    The head width is fixed for 28x28 inputs; other image sizes change the
    flattened feature count and will not fit ``head``.
    """

    def __init__(self):
        super().__init__()
        width = 64
        stages = []
        in_channels = 1
        # Layers are created in the same order as a hand-written Sequential,
        # so parameter initialisation and state_dict keys (f.0, f.2, f.4) match.
        for out_channels in (width, width * 2, width * 4):
            stages += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
                nn.SiLU(),
            ]
            in_channels = out_channels
        self.f = nn.Sequential(*stages)
        # 256 channels on a 4x4 feature map after three stride-2 convolutions.
        self.head = nn.Linear(in_channels * 4 * 4, 1)

    def forward(self, x):
        """Return a 1-D tensor with one energy value per input image."""
        features = torch.flatten(self.f(x), start_dim=1)
        energies = self.head(features)
        return energies.squeeze(-1)

# Energy model and its optimizer.
E = EBM().to(device)
opt = torch.optim.AdamW(params=E.parameters(), lr=1e-4, betas=(0.9, 0.999))

# Persistent pool of negative samples, sized to the training set and seeded
# with standard Gaussian noise. The training loop picks slots with randint,
# so a slot is not tied to any particular dataset image.
N = len(ds)
replay_shape = (N, 1, 28, 28)
replay = torch.randn(replay_shape, device=device)

def sgld(init, steps=120, eps=1e-3, energy_fn=None, noise_scale=None):
    """Run Langevin dynamics that moves samples downhill in energy.

    Each step applies
    ``x <- clamp(x - eps * dE/dx + noise_scale * N(0, I), -1, 1)``.
    The batch energies are summed so that one backward call yields every
    sample's gradient. This is valid when ``energy_fn`` scores samples
    independently, as ``EBM`` does.

    Args:
        init: Starting points. They are detached and copied, so the caller's
            tensor is never modified.
        steps: Number of Langevin updates.
        eps: Step size applied to the energy gradient.
        energy_fn: Callable mapping a batch to per-sample energies. Defaults
            to the notebook's global model ``E``.
        noise_scale: Standard deviation of the injected Gaussian noise.
            Defaults to ``sqrt(2 * eps)``, the textbook Langevin scaling.

    Returns:
        Detached samples clamped to ``[-1, 1]``, with the same shape as ``init``.
    """
    energy = E if energy_fn is None else energy_fn
    if noise_scale is None:
        noise_scale = (2.0 * eps) ** 0.5
    x = init.detach().clone().requires_grad_(True)
    # Input gradients are required even when the caller is inside
    # torch.no_grad() (e.g. a sampling cell), so re-enable autograd locally.
    with torch.enable_grad():
        for _ in range(steps):
            (g,) = torch.autograd.grad(energy(x).sum(), x)
            x = x - eps * g + noise_scale * torch.randn_like(x)
            # Project back into the data range and cut the graph so memory
            # does not grow with the number of steps.
            x = x.detach().clamp(-1, 1).requires_grad_(True)
    return x.detach().clamp(-1, 1)


# Train E by pushing energy down on real digits and up on SGLD samples.
# The SGLD chains are seeded from a blend of replay-buffer entries and
# real images.
E.train()
log_every = 100
global_step = 0

for epoch in range(10):
    for x_pos, _ in loader:
        # SGLD uses autograd.grad, which never writes .grad, so clearing
        # gradients up front is equivalent to clearing them before backward.
        opt.zero_grad(set_to_none=True)
        x_pos = x_pos.to(device, non_blocking=True)
        batch = x_pos.size(0)

        # Random buffer slots, each blended with a real image using a
        # per-sample weight drawn from U[0, 1).
        idx = torch.randint(0, N, (batch,), device=device)
        x_neg0 = replay[idx]
        mix = torch.rand(batch, 1, 1, 1, device=device)
        x_start = (mix * x_pos + (1 - mix) * x_neg0).clamp(-1, 1)

        x_neg = sgld(x_start, steps=120)
        replay[idx] = x_neg  # write the refined chains back into the buffer

        e_pos = E(x_pos)
        e_neg = E(x_neg)
        loss = e_pos.mean() - e_neg.mean()
        # Small squared-energy penalty discourages large energy magnitudes.
        reg_loss = (e_pos ** 2 + e_neg ** 2).mean()
        total_loss = loss + 1e-4 * reg_loss
        total_loss.backward()
        nn.utils.clip_grad_norm_(E.parameters(), 1.0)
        opt.step()

        global_step += 1
        if global_step % log_every:
            continue
        print(
            f"step {global_step:6d}  loss {loss.item():.4f}  "
            f"e_pos {e_pos.mean().item():.3f}  e_neg {e_neg.mean().item():.3f}"
        )


In [None]:
from torchvision.utils import make_grid
import matplotlib.pyplot as plt


# Draw 16 samples starting from pure Gaussian noise, using a much longer
# chain than in training, and show them as a 2x8 grid.
x_samp = sgld(torch.randn(16, 1, 28, 28, device=device), steps=1000)
images = (x_samp + 1) / 2  # map the [-1, 1] sample range back to [0, 1]
# make_grid replicates single-channel images to 3 channels, giving a (3, H, W) grid.
grid = make_grid(images, nrow=8)
fig = plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(grid.cpu().permute(1, 2, 0).numpy())
plt.show()
