In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchdiffeq import odeint_adjoint as odeint
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# === 1. Toy dataset ===
def get_data(n_samples=1024):
    x, _ = make_moons(n_samples=n_samples, noise=0.1)
    return torch.tensor(x, dtype=torch.float32)

# === 2. CNF dynamics: small time-dependent neural net ===
class CNFDynamics(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, dim)
        )
        self.net[-1].weight.data *= 0.01  # stabilize init

    def forward(self, t, z):
        t_vec = torch.ones(z.shape[0], 1, device=z.device) * t  # (B, 1)
        zt = torch.cat([z, t_vec], dim=1)  # (B, d+1)
        return self.net(zt)

# === 3. Divergence estimation using Hutchinson's trick ===
def divergence_approx(f, z):
    e = torch.randn_like(z)
    grad = torch.autograd.grad(
        outputs=f, inputs=z, grad_outputs=e,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    return torch.sum(grad * e, dim=1)

# === 4. ODE function module (for adjoint tracking) ===
class ODEFunc(nn.Module):
    def __init__(self, dynamics):
        super().__init__()
        self.dynamics = dynamics

    def forward(self, t, state):
        z, logp = state
        z.requires_grad_(True)
        dzdt = self.dynamics(t, z)
        dlogpdt = -divergence_approx(dzdt, z)
        return dzdt, dlogpdt

# === 5. CNF model ===
class CNF(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dynamics = CNFDynamics(dim)
        self.odefunc = ODEFunc(self.dynamics)

    def forward(self, z0):
        logp0 = standard_normal_logprob(z0).sum(1)
        t = torch.tensor([0.0, 1.0]).to(z0)

        zt, logpt = odeint(self.odefunc, (z0, logp0), t, rtol=1e-5, atol=1e-5)
        return zt[-1], logpt[-1]

# === 6. Standard Gaussian logprob ===
def standard_normal_logprob(z):
    log_z = -0.5 * np.log(2 * np.pi)
    return log_z - z.pow(2) / 2

# === 7. Train CNF ===
data = get_data().to(device)
model = CNF(dim=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1000):
    idx = torch.randint(0, data.shape[0], (256,))
    z0 = data[idx].to(device)
    zT, logpT = model(z0)
    loss = -logpT.mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}, NLL: {loss.item():.4f}")



RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn