In [None]:
# Ask PyTorch to run operators that are missing on the MPS backend on the CPU
# instead of raising (documented meaning of PYTORCH_ENABLE_MPS_FALLBACK).
# NOTE(review): this cell runs before `import torch`, which is presumably
# required for the variable to take effect — confirm.
%env PYTORCH_ENABLE_MPS_FALLBACK=1

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
import numpy as np
from matplotlib import pyplot as plt
import math
import lightning as L
from sklearn.datasets import make_moons
from torch.utils.data import DataLoader, Dataset, random_split
import sys
import torch
from torch.distributions import MultivariateNormal
sys.path.append("..")
from src.utils import MoonsDataset
from src.model import RealNVP
from src.dataset import ConditionalMoonsDataset
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

In [None]:
# Pick the accelerator used for evaluation/sampling below: Apple MPS when it is
# available, otherwise CUDA, otherwise CPU, so the notebook also runs on
# machines without MPS. The name `mps_device` is kept because later cells use it.
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
elif torch.cuda.is_available():
    mps_device = torch.device("cuda")
else:
    mps_device = torch.device("cpu")

### Unconditional Moons

In [None]:
# Build the unconditional moons data and an 80/20 train/validation split.
# - The split uses a seeded generator so it is identical after a kernel restart.
# - The training loader reshuffles every epoch; without shuffle=True every
#   epoch would see exactly the same mini-batches in the same order.
# - The validation loader serves the whole validation set as one batch.
N = 2000
a = int(N * 0.8)
b = N - a
dataset = MoonsDataset(n_sample=N, random_state=3)
train_data, val_data = random_split(
    dataset, [a, b], generator=torch.Generator().manual_seed(3)
)
train_loader = DataLoader(train_data, batch_size=200, shuffle=True)
val_loader = DataLoader(val_data, batch_size=b)

In [None]:
# Visual sanity check: scatter the first two columns of the generated dataset.
m = np.array(dataset.data)
fig, ax = plt.subplots()
ax.scatter(m[:, 0], m[:, 1], alpha=0.1)
plt.show()

In [None]:
# Fit an unconditional RealNVP on the 2-D moons (d_x=2). d_model / n_layers are
# presumably the coupling-network width and the number of coupling layers —
# confirm against src.model.RealNVP.
model = RealNVP(d_x=2, d_model=128, n_layers=4)
trainer = L.Trainer(log_every_n_steps=1, max_epochs=200)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Regular evaluation grid covering the moons: x in [-1.2, 2.5), y in
# [-1.25, 1.5), step 0.1. With indexing='xy', x and y have shape
# (len(ys), len(xs)), the NumPy/matplotlib meshgrid convention.
xs = torch.arange(-1.2, 2.5, 0.1)
ys = torch.arange(-1.25, 1.5, 0.1)
x, y = torch.meshgrid(xs, ys, indexing='xy')
# One (x, y) row per grid node, in row-major order over the mesh.
grid = torch.stack([x, y], dim=-1).reshape(-1, 2)

In [None]:
# Evaluate the trained flow's log-density on every grid point.
# model.eval() disables train-only behaviour (batch-norm batch statistics,
# dropout) during inference; it is harmless if RealNVP has no such layers.
# NOTE(review): confirm which layers src.model.RealNVP contains.
model.eval()
with torch.no_grad():
    # The forward pass returns (Z, log_det); presumably the latent codes and
    # the log|det J| of the data->latent map — TODO confirm in src.model.
    Z, log_det = model(grid)
    # log_prob is also kept under no_grad so the result never carries an
    # autograd graph. Inputs are moved to `mps_device` as before.
    log_probs = model.log_prob(Z.to(mps_device), log_det.to(mps_device))

In [None]:
# Put the flat per-point log-densities back onto the (rows, cols) mesh of `x`.
# .detach() guarantees NumPy can consume the tensor later even if `log_probs`
# was computed with autograd enabled. reshape() also tolerates non-contiguous
# input, where view() would raise.
z = log_probs.detach().cpu().reshape(x.shape)

In [None]:
# Heat map of the learned density p(x) = exp(log p(x)) over the evaluation grid.
# The colorbar and labels make the figure readable on its own.
fig, ax = plt.subplots()
mesh = ax.pcolormesh(x, y, np.exp(z))
fig.colorbar(mesh, ax=ax, label="density")
ax.set(title="RealNVP density (unconditional moons)", xlabel="$x_1$", ylabel="$x_2$")
plt.show()

In [None]:
# Draw 400 samples from the trained unconditional flow on `mps_device`.
# model.eval() makes sure train-only layers (if any) behave deterministically.
# nn.Module.to() moves the module in place, so `model` stays on that device.
model.eval()
with torch.no_grad():
    sample = model.to(mps_device).sample((400,))

In [None]:
# Scatter of the flow samples (copied to host memory for matplotlib).
sample_cpu = sample.cpu()
fig, ax = plt.subplots()
ax.scatter(sample_cpu[:, 0], sample_cpu[:, 1], alpha=0.5)
plt.show()

### Conditional Moons

In [None]:
# Build the conditional moons data and an 80/20 train/validation split.
# - The split uses a seeded generator so it is identical after a kernel restart.
# - The training loader reshuffles every epoch; without shuffle=True every
#   epoch would see exactly the same mini-batches in the same order.
# - The validation loader serves the whole validation set as one batch.
N = 4000
a = int(N * 0.8)
b = N - a
dataset = ConditionalMoonsDataset(n_sample=N)
train_data, val_data = random_split(
    dataset, [a, b], generator=torch.Generator().manual_seed(0)
)
train_loader = DataLoader(train_data, batch_size=800, shuffle=True)
val_loader = DataLoader(val_data, batch_size=b)

In [None]:
# Fit a conditional RealNVP (d_x=2, d_theta=2 — presumably the dimension of the
# conditioning variable; confirm in src.model). Training stops once "val_loss"
# has not improved for 20 validation checks, or after 500 epochs.
early_stop = EarlyStopping(monitor="val_loss", mode="min", patience=20)
callbacks = [early_stop]
model = RealNVP(d_x=2, d_theta=2, d_model=32, n_layers=4, lr=1e-3, weight_decay=0)
trainer = L.Trainer(callbacks=callbacks, log_every_n_steps=1, max_epochs=500)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Draw M samples from the conditional flow, conditioned on the dataset's
# observed data. model.eval() makes sure train-only layers (if any) behave
# deterministically; .to() moves the module in place.
M = 400
model.eval()
with torch.no_grad():
    sample = model.to(mps_device).sample(M, dataset.get_observed_data())

In [None]:
# Scatter of the conditional flow samples (copied to host memory for matplotlib).
sample_cpu = sample.cpu()
fig, ax = plt.subplots()
ax.scatter(sample_cpu[:, 0], sample_cpu[:, 1], alpha=0.5)
plt.show()
# FIXME: unexpected result — this appears to learn the conditional distribution
# of X | \theta = 0 rather than \theta | X = 0 ...

In [None]:
# Fresh, larger draw used as a rejection-sampling reference in the next cell.
N_TEST = 20000
test = ConditionalMoonsDataset(N_TEST)

In [None]:
# Approximate rejection sampler: keep the second component e[1] of every item
# whose first component e[0] lies close to the origin. Note that `eps` bounds
# the *squared* Euclidean norm of e[0], not the norm itself.
eps = 0.005
samples = np.array([item[1] for item in test if (item[0] ** 2).sum() < eps])

In [None]:
# Reference scatter of the rejection-sampled points, for comparison with the flow samples.
plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)