In [1]:
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
from dataclasses import dataclass, field, InitVar
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np


@dataclass
class CheckBoardDataset(Dataset):
    """Synthetic dataset of (noise, checkerboard point, time) triples.

    Every ``__getitem__`` call draws fresh random values from NumPy's global
    RNG; the index is ignored and ``samples`` only determines ``len()``.

    NOTE(review): sampling relies on ``np.random``'s global state. Whether
    DataLoader worker processes receive distinct NumPy seeds depends on the
    installed PyTorch version -- confirm before relying on ``num_workers > 0``
    for statistically independent samples.
    """

    # Number of items reported by ``__len__`` (i.e. draws per epoch).
    samples: int
    # Vertical extent of the board: y is drawn from [-rows // 2, rows // 2).
    rows: int = 4
    # Horizontal extent of the board: x is drawn from [-cols // 2, cols // 2).
    cols: int = 4

    def __len__(self):
        """Return the configured number of samples."""
        return self.samples

    def sample_x0(self):
        """Draw a source point: a 2-vector of independent standard normals."""
        return np.random.randn(2)

    def sample_x1(self):
        """Draw a target point lying on an even-parity cell of the board.

        ``-n // 2`` parses as ``(-n) // 2``, so for odd ``rows``/``cols`` the
        lower bound is rounded away from zero (e.g. 5 -> [-3, 2)).
        """
        x = np.random.uniform(-self.cols // 2, self.cols // 2)
        y = np.random.uniform(-self.rows // 2, self.rows // 2)
        # Keep the point if floor(x) + floor(y) is even; otherwise mirror y.
        # For non-integer y, floor(-y) == -floor(y) - 1, so mirroring flips the
        # parity and moves the point onto an even-parity cell.
        y = np.where((np.floor(x) + np.floor(y)) % 2 == 0, y, -y)
        return np.array([x, y])

    def sample_t(self):
        """Draw a time value uniformly from [0, 1) as a length-1 array."""
        return np.random.rand(1)

    def __getitem__(self, idx):
        """Return a fresh ``(x0, x1, t)`` triple of float32 arrays.

        ``idx`` is unused: items are sampled on the fly rather than stored.
        """
        x0 = self.sample_x0().astype(np.float32)
        x1 = self.sample_x1().astype(np.float32)
        t = self.sample_t().astype(np.float32)
        return x0, x1, t

    def loader(self, batch_size: int, num_workers: int = 8):
        """Wrap this dataset in a ``DataLoader``.

        Args:
            batch_size: Number of triples per batch.
            num_workers: Worker processes used for sampling. Defaults to 8
                (the previously hard-coded value); pass 0 to sample in the
                calling process, e.g. when only a single batch is needed.
        """
        return DataLoader(self, batch_size=batch_size, num_workers=num_workers)

In [3]:
import torch
from torch import nn
from lightning import LightningModule, Trainer
from lightning.pytorch import callbacks


class ResBlock(nn.Module):
    """Conditionally modulated MLP block with a skip connection.

    The input is layer-normalised, scaled and shifted by two bias-free linear
    projections of the conditioning tensor ``c``, and then passed through a
    two-layer MLP whose output is added to the modulated activations.

    Args:
        dim: Feature width; the last dimension of both ``x`` and ``c`` must be
            ``dim`` because both are fed to ``dim``-wide layers.
        act: Activation module applied between the two MLP layers.
    """

    def __init__(self, dim, act:nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # Conditioning projections: shift (mu) and log-scale (logsigma).
        self.mu = nn.Linear(dim, dim, bias=False)
        self.logsigma = nn.Linear(dim, dim, bias=False)
        self.lin1 = nn.Linear(dim, dim)
        self.act = act
        self.lin2 = nn.Linear(dim, dim)

    def forward(self, x, c):
        """Apply the block to ``x`` conditioned on ``c``.

        Implemented as ``forward`` (not ``__call__``) so that
        ``nn.Module.__call__`` stays in charge of dispatch and registered
        hooks run; ``block(x, c)`` keeps working unchanged.

        Returns:
            Tensor with the same trailing dimension as ``x``.
        """
        x = self.norm(x)
        # Scale by exp(logsigma(c)) so the multiplier is always positive.
        x = x*self.logsigma(c).exp() + self.mu(c)
        h = self.lin2(self.act(self.lin1(x)))
        # The skip connection is taken from the modulated activations, not
        # from the raw block input.
        return x + h


class CNF(LightningModule):
    """Time-conditioned network that predicts a velocity for a point ``x``.

    The point and the scalar time are each projected to ``hidden`` features;
    the time embedding conditions a stack of ``ResBlock`` layers and the
    result is projected back to ``dim`` outputs.

    Args:
        dim: Dimensionality of the data points (input and output width).
        hidden: Width of the hidden representation.
        num_blocks: Number of ``ResBlock`` layers.
        act: Activation module handed to every block. The default instance is
            created once at definition time and therefore shared by all
            blocks and all models that rely on the default.
    """

    def __init__(self, dim, hidden=256, num_blocks=4, act=nn.GELU()):
        super().__init__()
        self.in_proj = nn.Linear(dim, hidden)
        self.t_proj = nn.Linear(1, hidden)
        self.out_proj = nn.Linear(hidden, dim)
        self.blocks = nn.ModuleList(ResBlock(hidden, act) for _ in range(num_blocks))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr=1e-5)."""
        return torch.optim.Adam(self.parameters(), lr=1e-5)

    def forward(self, t, x):
        """Predict the velocity at point ``x`` and time ``t``.

        Implemented as ``forward`` rather than ``__call__`` so that module
        hooks and wrappers that dispatch through ``forward`` are not
        bypassed; ``model(t, x)`` keeps working unchanged.

        Args:
            t: Time tensor whose last dimension is 1 (fed to ``Linear(1, hidden)``).
            x: Point tensor whose last dimension is ``dim``.

        Returns:
            Tensor whose last dimension is ``dim``.
        """
        x = self.in_proj(x)
        c = self.t_proj(t)
        for block in self.blocks:
            x = block(x, c)
        x = self.out_proj(x)
        return x

    def training_step(self, batch, batch_idx):
        """Compute the regression loss for one ``(x0, x1, t)`` batch.

        The network is evaluated at the linear interpolation between ``x0``
        and ``x1`` and regressed (MSE) onto the constant displacement
        ``x1 - x0``.
        """
        x0, x1, t = batch
        xt = t * x1 + (1 - t) * x0 # + sigma_min * t * x0

        flow = self(t, xt)
        target_flow = (x1 - x0) # +sigma_min * x0
        loss = torch.nn.functional.mse_loss(flow, target_flow)
        return loss

In [4]:
# Build the 2-D model and a checkerboard dataset reporting 128 * 1024 items.
model = CNF(2)
dataset = CheckBoardDataset(samples=128 * 1024)

# Rich console progress/summary plus stochastic weight averaging; 0.01 is
# passed as the first positional argument of StochasticWeightAveraging
# (presumably the SWA learning rate -- confirm against the installed Lightning).
training_callbacks = [
    callbacks.RichProgressBar(),
    callbacks.RichModelSummary(),
    callbacks.StochasticWeightAveraging(0.01),
]
trainer = Trainer(max_epochs=100, callbacks=training_callbacks)

# Train for up to 100 epochs on batches of 1024 freshly sampled triples.
train_loader = dataset.loader(batch_size=1024)
trainer.fit(model, train_loader)

Trainer will use only 1 of 3 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=3)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISI

Output()

In [None]:
def push_foward(model, x, n_steps=10):
    """Integrate ``x`` from t=0 to t=1 with explicit Euler steps.

    At each step the update ``x <- x + dt * model(t, x)`` is applied with
    ``dt = 1 / n_steps`` and ``t = step * dt``.

    Args:
        model: Callable ``model(t, x)`` returning a velocity shaped like ``x``.
        x: Starting points; ``x.shape[0]`` is used as the batch size.
        n_steps: Number of Euler steps.

    Returns:
        List of ``n_steps + 1`` tensors: the input ``x`` itself followed by
        the state after every step.
    """
    dt = 1 / n_steps
    xt = [x]
    # Time tensors follow x's device and (floating) dtype instead of always
    # being created as CPU float32.
    t_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    # Only detached states are returned, so run without autograd; otherwise x
    # would carry a graph through every step and keep all activations alive.
    with torch.no_grad():
        # An integer step counter guarantees exactly n_steps iterations; a
        # float-step arange is subject to rounding when compared against its end.
        for step in tqdm(range(n_steps)):
            t = (step * dt) * torch.ones(x.shape[0], 1, dtype=t_dtype, device=x.device)
            x = x + dt * model(t, x)
            xt.append(x.detach())
    return xt

In [None]:
# Draw a single large batch (64 * 1024 triples) and integrate its source
# points through the model in 10 steps.
eval_loader = dataset.loader(batch_size=64*1024)
x0, x1, t = next(iter(eval_loader))
xt = push_foward(model, x0, n_steps=10)

In [None]:
# Figure 1: source points, final integrated points (titled 'x1') and the
# sampled target points, each as a 100-bin 2-D histogram.
panels = [
    ('x0', x0),
    ('x1', xt[-1]),
    ('target', x1),
]
plt.figure(figsize=(9, 3))
for pos, (panel_title, points) in enumerate(panels, start=1):
    plt.subplot(1, 3, pos)
    plt.title(panel_title)
    plt.hist2d(*points.T, bins=100)
plt.show()

# Figure 2: one histogram per stored integration state, titled with the
# fraction of the trajectory it corresponds to.
n_frames = len(xt)
plt.figure(figsize=(n_frames*2, 2))
for i, x in enumerate(xt):
    plt.subplot(1, n_frames, i+1)
    plt.title(f"t={i/(len(xt)-1):.2f}")
    plt.hist2d(x[:, 0], x[:, 1], bins=100)
plt.show()