In [None]:
import torch.nn as nn
from torchdyn.core import NeuralODE
from torchdyn.datasets import ToyDataset
from torchdyn.models import CNF
from torchdyn.nn import Augmenter, DataControl, DepthCat

# from torchdyn.utils import *

In [None]:
# Quick-run switch intended for automated notebook validation.
# Leave as False for a normal, full-length interactive run.
dry_run = False

In [None]:
import matplotlib.pyplot as plt

# Draw noisy samples from torchdyn's "diffeqml" toy dataset.
n_samples = 2**14
n_gaussians = 7
data = ToyDataset()
X, yn = data.generate(n_samples, "diffeqml", noise=5e-2)

# Standardise using the mean and standard deviation taken over all entries.
X = (X - X.mean()) / X.std()

# Small preview of the first two columns of the standardised samples.
fig, ax = plt.subplots(figsize=(3, 3))
ax.scatter(X[:, 0], X[:, 1], c="olive", alpha=0.3, s=1)

In [None]:
import torch

# NOTE: this import rebinds the name `data`, which an earlier cell uses for the
# ToyDataset instance; re-run that cell before calling `data.generate` again.
import torch.utils.data as data

# Training device. CPU is the default; set `use_cuda = True` to opt into the
# first CUDA device when one is available.
use_cuda = False
device = torch.device("cuda:0" if use_cuda and torch.cuda.is_available() else "cpu")

# Move the samples to the chosen device and serve them in shuffled
# mini-batches of 1024.
X_train = torch.Tensor(X).to(device)
train = data.TensorDataset(X_train)
trainloader = data.DataLoader(train, batch_size=1024, shuffle=True)

In [None]:
def hutch_trace(x_out, x_in, noise=None, **kwargs):
    """Hutchinson-style stochastic estimate of a Jacobian trace.

    A single reverse-mode autograd call contracts ``noise`` with the Jacobian
    of ``x_out`` with respect to ``x_in``; the result is then dotted with
    ``noise`` again, row by row.

    Args:
        x_out: Tensor connected to ``x_in`` through an autograd graph.
        x_in: Tensor to differentiate with respect to.
        noise: Probe tensor handed to autograd as ``grad_outputs``. Despite the
            ``None`` default, the final contraction needs an actual 2-D tensor.
        **kwargs: Accepted and ignored.

    Returns:
        1-D tensor holding one estimate per row (the ``b`` index).
    """
    grads = torch.autograd.grad(
        outputs=x_out,
        inputs=x_in,
        grad_outputs=noise,
        create_graph=True,  # keep the graph so the estimate stays differentiable
    )
    vjp = grads[0]
    return torch.einsum("bi,bi->b", vjp, noise)

In [None]:
from torch.distributions import (
    Categorical,
    MultivariateNormal,
    SigmoidTransform,
    TransformedDistribution,
    Uniform,
)

# Network wrapped by the CNF: 2 -> 64 -> 64 -> 64 -> 2 with Softplus between
# the linear layers. Layers are created in the same order as they appear in
# the resulting Sequential.
f = nn.Sequential(
    nn.Linear(2, 64),
    nn.Softplus(),
    *(layer for _ in range(2) for layer in (nn.Linear(64, 64), nn.Softplus())),
    nn.Linear(64, 2),
)

# Base density used in the training loss: zero mean, identity covariance.
prior = MultivariateNormal(
    torch.zeros(2, device=device),
    covariance_matrix=torch.eye(2, device=device),
)

# Stochastic trace estimators need a distribution from which the "noise"
# vectors are sampled; it is built the same way as the prior.
noise_dist = MultivariateNormal(
    torch.zeros(2, device=device),
    covariance_matrix=torch.eye(2, device=device),
)

# CNF bundles the network with the trace estimator and its noise source;
# NeuralODE then wraps the CNF with the dopri5 solver and adjoint sensitivity.
cnf = CNF(f, trace_estimator=hutch_trace, noise_dist=noise_dist)
nde = NeuralODE(
    cnf,
    solver="dopri5",
    sensitivity="adjoint",
    atol=1e-4,
    rtol=1e-4,
)

In [None]:
# Full model: an Augmenter (augment_idx=1, augment_dims=1) followed by the
# neural ODE, moved to `device`. Learner.training_step reads column 0 of the
# output as the accumulated trace term and the remaining columns as the
# transformed sample — presumably column 0 is the dimension added by the
# Augmenter (confirm against torchdyn's Augmenter documentation).
model = nn.Sequential(Augmenter(augment_idx=1, augment_dims=1), nde).to(device)

In [None]:
import pytorch_lightning as pl


class Learner(pl.LightningModule):
    """Lightning wrapper that fits ``model`` by minimising negative log-likelihood.

    Relies on the notebook-level globals ``prior``, ``nde`` and ``trainloader``.
    """

    def __init__(self, model: nn.Module):
        """Store the model to train and start the step counter at zero."""
        super().__init__()
        self.iters = 0  # number of training_step calls so far
        self.model = model

    def forward(self, x):
        """Run ``x`` through the wrapped model."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Return the negative mean log-probability of one batch as the loss.

        Column 0 of the model output is used as the integrated trace term and
        the remaining columns are scored under the global ``prior``.
        NOTE(review): this assumes the model returns a single 2-D tensor;
        confirm against the installed torchdyn version.
        """
        self.iters += 1
        samples = batch[0]
        flow_out = self.model(samples)
        trace_integral = flow_out[:, 0]
        z = flow_out[:, 1:]
        # logp(z_S) = logp(z_0) - \int_0^S trJ
        log_prob = prior.log_prob(z).to(samples) - trace_integral
        nll = -log_prob.mean()
        # Reset the nfe counter on the global NeuralODE after every step.
        nde.nfe = 0
        return {"loss": nll}

    def configure_optimizers(self):
        """Create the AdamW optimizer over the wrapped model's parameters."""
        params = self.model.parameters()
        optimizer = torch.optim.AdamW(params, lr=2e-3, weight_decay=1e-5)
        return optimizer

    def train_dataloader(self):
        """Return the notebook-level training DataLoader."""
        return trainloader

In [None]:
# Fit the flow. `dry_run` (set in the configuration cell near the top) caps
# training at a single epoch so automated notebook validation finishes
# quickly; a normal run trains for up to 600 epochs.
learn = Learner(model)
trainer = pl.Trainer(max_epochs=1 if dry_run else 600)
trainer.fit(learn);