Run this notebook after installing the CUDA 11.8 PyTorch wheels: `pip install --force-reinstall --index-url https://download.pytorch.org/whl/cu118 torch torchvision torchaudio`

In [1]:
import os
# Ask PyTorch's CUDA caching allocator to use expandable segments.
# NOTE(review): this is set before `import torch` in the next cell, presumably so
# the setting is already in the environment when the allocator initializes — confirm
# against the PyTorch CUDA memory-management notes for the installed version.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [2]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad
from torchdiffeq import odeint

In [3]:
# Seed torch's default RNG so data sampling, weight init and Hutchinson noise are repeatable.
torch.manual_seed(0)

<torch._C.Generator at 0x7f081cbcf7f0>

In [4]:
# -----------------------
# Toy dataset: two moons
# -----------------------
def two_moons(n=2048, noise=0.06, device="cpu"):
    """Draw ``n`` noisy 2-D points from a two-moons toy distribution.

    The first ``n // 2`` rows lie on the upper half of the unit circle; the
    remaining rows lie on a lower half-circle centred at (1, -0.5). Gaussian
    jitter with standard deviation ``noise`` is then added to every coordinate.

    Args:
        n: total number of points to generate.
        noise: standard deviation of the additive Gaussian jitter.
        device: device on which all tensors are created.

    Returns:
        Tensor of shape (n, 2).
    """
    upper_count = n // 2
    lower_count = n - upper_count

    # Angles are drawn upper-first so the RNG stream is consumed in a fixed order.
    theta_up = math.pi * torch.rand(upper_count, device=device)
    upper = torch.stack((theta_up.cos(), theta_up.sin()), dim=1)

    theta_lo = math.pi * torch.rand(lower_count, device=device)
    lower = torch.stack((1 - theta_lo.cos(), -theta_lo.sin() - 0.5), dim=1)

    clean = torch.cat((upper, lower), dim=0)
    return clean + noise * torch.randn_like(clean)

In [5]:
# -----------------------
# Utilities
# -----------------------
def standard_normal_logprob(z):
    """Return the log-density of each row of ``z`` under a standard normal.

    Every coordinate contributes ``-0.5 * log(2*pi) - 0.5 * z**2``; the
    contributions are summed over dim 1, mapping (B, D) -> (B,).
    """
    log_norm_const = -0.5 * math.log(2 * math.pi)
    per_coord = log_norm_const - 0.5 * z.pow(2)
    return per_coord.sum(dim=1)

def hutchinson_trace(df_dz_v, v):
    """Per-sample Hutchinson trace estimate from a Jacobian-vector product.

    Given the product of the Jacobian with probe ``v`` (``df_dz_v``), the
    estimate is the row-wise dot product with ``v``: (B, D), (B, D) -> (B,).
    """
    elementwise = v * df_dz_v
    return torch.sum(elementwise, dim=1)

In [6]:
# -----------------------
# ODE function f_theta(t, z)
# Time is concatenated as a feature (Neural ODE with time conditioning)
# -----------------------
class ODEfunc(nn.Module):
    """Time-conditioned MLP vector field ``f_theta(t, z)``.

    The scalar time is appended to the state as one extra input feature, so
    the network maps ``dim + 1`` inputs to ``dim`` outputs through two tanh
    hidden layers of width ``hidden``.
    """

    def __init__(self, dim, hidden=64):
        super().__init__()
        layers = [
            nn.Linear(dim + 1, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, dim),
        ]
        self.net = nn.Sequential(*layers)

        # Orthogonal weights with gain < 1 and zero biases; intended to keep
        # the dynamics well-behaved early in training.
        for layer in layers:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.orthogonal_(layer.weight, gain=0.8)
            nn.init.zeros_(layer.bias)

    def forward(self, t, z):
        """Evaluate the vector field at time ``t`` for states ``z`` of shape (B, D)."""
        batch = z.shape[0]
        # Broadcast the scalar time into a (B, 1) column on z's device.
        time_col = t * torch.ones(batch, 1, device=z.device)
        return self.net(torch.cat((z, time_col), dim=1))

In [7]:
# -----------------------
# CNF wrapper: integrates (z, logp) jointly
# d/dt logp = -Tr(df/dz) (Hutchinson)
# -----------------------
class CNF(nn.Module):
    """Continuous normalizing flow that integrates (z, logp, eps) jointly.

    The state z follows dz/dt = f_theta(t, z) and the log-density accumulator
    follows d(logp)/dt = -Tr(df/dz), with the trace estimated by Hutchinson's
    method using a probe ``eps`` that is held fixed along each trajectory.

    Args:
        dim: dimensionality of the data / base space.
        hidden: hidden width of the underlying ``ODEfunc`` network.
        t0: integration time associated with data space.
        t1: integration time associated with the base (standard normal) space.
    """

    def __init__(self, dim, hidden=64, t0=0.0, t1=1.0):
        super().__init__()
        # Kept so sample() can draw base noise of the correct width.
        self.dim = dim
        self.f = ODEfunc(dim, hidden)
        self.register_buffer("t_span", torch.tensor([t0, t1], dtype=torch.float32))

    def _aug_dynamics(self, t, state):
        """Right-hand side of the augmented ODE: returns (dz/dt, dlogp/dt, deps/dt)."""
        z, logp, eps = state  # z: (B,D), logp: (B,), eps: (B,D)
        z = z.requires_grad_(True)

        # Gradients must be enabled here even when the caller is under no_grad
        # (e.g. sampling), because the trace estimate itself needs autograd.
        with torch.enable_grad():
            f = self.f(t, z)                             # (B,D)
            # grad(f, z, eps) is the vector-Jacobian product eps^T J. Dotting it
            # with the same eps gives eps^T J eps, whose expectation is Tr(J)
            # for probes with identity covariance.
            Jt_eps = grad(f, z, eps, retain_graph=True, create_graph=True)[0]  # (B,D)
            trace_est = hutchinson_trace(Jt_eps, eps)     # (B,)
            dlogp = -trace_est
        return (f, dlogp, torch.zeros_like(eps))          # eps is constant in time

    def transform(self, x, reverse=False):
        """Push a batch through the flow.

        If reverse=False: x -> z (to base) integrating t0->t1
        If reverse=True:  z -> x (for sampling) integrating t1->t0

        Args:
            x: tensor of shape (B, D).
            reverse: integration direction, see above.

        Returns:
            (z_or_x, logp_delta), where logp_delta has shape (B,) and is the
            integral of -Tr(df/dz) along the integration path, i.e. the change
            in log-density from the start to the end of the trajectory.
        """
        B, _ = x.shape
        device = x.device
        # Rademacher noise usually works well; Gaussian also fine.
        eps = torch.randint(0, 2, x.shape, device=device, dtype=torch.float32) * 2 - 1

        if reverse:
            t_span = torch.flip(self.t_span, dims=[0])
        else:
            t_span = self.t_span

        logp0 = torch.zeros(B, device=device)
        state0 = (x, logp0, eps)
        zT, logp_T, _ = odeint(self._aug_dynamics, state0, t_span, atol=1e-5, rtol=1e-5)
        # odeint returns one entry per requested time; keep the final one.
        z1 = zT[-1]
        logp1 = logp_T[-1]
        return z1, logp1

    def log_prob(self, x):
        """Return log p_x(x) for a batch ``x`` of shape (B, D) as a (B,) tensor."""
        z, delta_logp = self.transform(x, reverse=False)
        logpz = standard_normal_logprob(z)
        # delta_logp = log p(z at t1) - log p(x at t0) along the trajectory, so
        # the data log-density is the base log-density MINUS the accumulated
        # change: log p_x(x) = log p_z(z) - delta_logp = log p_z(z) + ∫ Tr(df/dz) dt.
        return logpz - delta_logp

    def sample(self, n, device="cpu"):
        """Draw ``n`` samples by integrating standard-normal noise from t1 back to t0."""
        z = torch.randn(n, self.dim, device=device)
        x, _ = self.transform(z, reverse=True)
        return x

In [8]:
# -----------------------
# Train loop
# -----------------------
def train(device="cpu", *, steps=4000, batch_size=512, lr=1e-3,
          log_every=200, n_samples=4096):
    """Fit a 2-D CNF to the two-moons data by maximum likelihood and sample from it.

    Args:
        device: device on which the model and data live.
        steps: number of optimisation steps.
        batch_size: number of fresh two-moons points drawn per step.
        lr: Adam learning rate.
        log_every: print the loss every this many steps; 0 disables logging.
        n_samples: number of samples drawn from the trained model.

    Returns:
        (model, samples): the trained ``CNF`` and a CPU tensor of
        ``n_samples`` points drawn from it.
    """
    dim = 2
    model = CNF(dim=dim, hidden=64, t0=0.0, t1=1.0).to(device)
    opt = optim.Adam(model.parameters(), lr=lr)
    # Cache flushing is only meaningful when training on a CUDA device.
    on_cuda = torch.device(device).type == "cuda"

    for step in range(steps):
        x = two_moons(n=batch_size, device=device)
        # Maximize log-likelihood (minimize NLL)
        logpx = model.log_prob(x)
        loss = -logpx.mean()

        opt.zero_grad(set_to_none=True)
        loss.backward()
        # Optional: gradient clipping for stability
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
        opt.step()

        # Periodically return cached, unused blocks to the driver.
        if on_cuda and (step + 1) % 50 == 0:
            torch.cuda.empty_cache()

        if log_every and (step + 1) % log_every == 0:
            print(f"step {step+1:4d}  NLL: {loss.item():.3f}")

    # Sample a few points after training
    with torch.no_grad():
        samples = model.sample(n_samples, device=device).cpu()
    return model, samples

In [None]:
# This cell trains on the first CUDA device. Fail with an explicit error rather
# than an `assert`, which is stripped when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required to run this training cell, but no CUDA device is available.")

device = torch.device("cuda:0")
model, samples = train(device=device)
# Save samples if you want to plot later
torch.save(samples, "cnf_samples.pt")
print("Saved samples to cnf_samples.pt")