In [None]:
# Let PyTorch run operators that lack an MPS kernel on the CPU instead of raising.
# Set in the first cell, before `torch` is imported by a later cell.
%env PYTORCH_ENABLE_MPS_FALLBACK=1

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
import numpy as np
from matplotlib import pyplot as plt
import math
import lightning as L
from sklearn.datasets import make_moons
from torch.utils.data import DataLoader, Dataset, random_split
import sys
import torch
from torch.distributions import MultivariateNormal
sys.path.append("..")
from src.utils import MoonsDataset
from src.model import RealNVP

In [None]:
# Device used for density evaluation and sampling below. Fall back to the CPU when
# the Apple MPS backend is unavailable so the notebook still runs on other machines.
mps_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Seed the global NumPy and PyTorch RNGs so that Restart & Run All reproduces the
# torch-driven randomness (random_split / DataLoader shuffling / presumably weight
# init) and the np.random draws of the conditional-moons simulation.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

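### Unconditional Moons

Fit a RealNVP normalizing flow to samples from the two-moons dataset, then inspect the learned density on a grid and draw samples from the trained model.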

In [None]:
# Unconditional two-moons data with an 80/20 train/validation split.
N = 2000
n_train = int(N * 0.8)
n_val = N - n_train
dataset = MoonsDataset(n_sample=N, random_state=3)
# A dedicated, seeded generator keeps the split identical across runs regardless of
# how much of the global torch RNG earlier cells consumed.
train_data, val_data = random_split(
    dataset, [n_train, n_val], generator=torch.Generator().manual_seed(3)
)
# Reshuffle the training batches every epoch; validate on the whole split in one batch.
train_loader = DataLoader(train_data, batch_size=200, shuffle=True)
val_loader = DataLoader(val_data, batch_size=n_val)

In [None]:
# Visual sanity check of the unsplit moons dataset before training.
# Assumes `dataset.data` is an (n, 2) array-like of points — TODO confirm in src.utils.
moons_xy = np.array(dataset.data)
fig, ax = plt.subplots()
ax.scatter(moons_xy[:, 0], moons_xy[:, 1], alpha=0.1)
ax.set(title=f"Two-moons data (n = {len(dataset)})", xlabel="x", ylabel="y")
plt.show()

In [None]:
# Fit a RealNVP flow on the 2-D moons data for 200 epochs, logging every step.
# d_model / n_layers are presumably the hidden width and number of coupling
# layers — TODO confirm against src.model.RealNVP.
model = RealNVP(d_model=64, n_layers=4, d_x=2)
trainer = L.Trainer(max_epochs=200, log_every_n_steps=1)
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

In [None]:
# Evaluation grid with spacing 0.1: x starts at -1.2 and y at -1.25; the upper
# bounds 2.5 and 1.5 are exclusive, as usual for arange.
x_axis = torch.arange(-1.2, 2.5, 0.1)
y_axis = torch.arange(-1.25, 1.5, 0.1)
# indexing="xy" gives x and y the shape (len(y_axis), len(x_axis)).
x, y = torch.meshgrid(x_axis, y_axis, indexing="xy")
# (rows, cols, 2) -> (rows * cols, 2): one (x, y) query point per row.
grid = torch.stack([x, y], dim=-1).reshape(-1, 2)

In [None]:
# Evaluate the trained flow's log density at every grid point (inference only).
# `model` may live on the CPU or on `mps_device` (the sampling cell further down moves
# it in place with `.to(mps_device)`), so send the grid to the model's parameter device
# to keep this cell re-runnable after sampling.
grid_device = next(model.parameters()).device
with torch.no_grad():
    Z, log_det = model(grid.to(grid_device))
    # Inputs are moved to `mps_device` as in the original analysis; presumably log_prob
    # expects them there — TODO confirm against src.model.RealNVP.
    log_probs = model.log_prob(Z.to(mps_device), log_det.to(mps_device))

In [None]:
# Fold the flat per-point log densities back onto the meshgrid shape.
# detach() keeps the later numpy conversion safe even if log_prob tracked gradients.
# NOTE: relies on `x` still being the meshgrid; the conditional-moons cells rebind `x`.
z = log_probs.detach().cpu().reshape(x.shape)

In [None]:
# Heat map of the model density exp(log p) over the evaluation grid.
density = z.detach().exp()
fig, ax = plt.subplots()
mesh = ax.pcolormesh(x, y, density, shading="auto")
fig.colorbar(mesh, ax=ax, label="density")
ax.set(title="RealNVP density on the moons grid", xlabel="x", ylabel="y")
plt.show()

In [None]:
# Move the flow to `mps_device` and draw 400 samples without tracking gradients.
# nn.Module.to moves the module in place, so later cells see `model` on `mps_device`.
model = model.to(mps_device)
with torch.no_grad():
    sample = model.sample((400,))

In [None]:
# Scatter of the flow samples; copied to the CPU so matplotlib can read them.
sample_xy = sample.detach().cpu()
fig, ax = plt.subplots()
ax.scatter(sample_xy[:, 0], sample_xy[:, 1], alpha=0.5)
ax.set(title="Samples from the trained RealNVP", xlabel="x", ylabel="y")
plt.show()

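### Conditional Moons

Simulate two-moons data conditioned on latent parameters: a noisy half-ring $p$ (angle $a \sim \mathcal{U}(-\pi/2, \pi/2)$, radius $r \sim \mathcal{N}(0.1, 0.01^2)$, centred at $(0.25, 0)$) is shifted by an offset that depends on $\theta \sim \mathcal{U}(-1, 1)^2$.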

In [None]:
# Noisy half-ring: angle ~ U(-pi/2, pi/2), radius ~ N(0.1, 0.01^2), centred at (0.25, 0).
# Descriptive names avoid clobbering `a` from the unconditional section.
# NOTE: `N` is deliberately rebound here (2000 -> 200); re-running the unconditional
# data cell after this one would build a 200-point dataset.
N = 200
angle = np.random.uniform(low=-math.pi / 2, high=math.pi / 2, size=N)
radius = np.random.normal(loc=0.1, scale=0.01, size=N)
# p has shape (2, N): row 0 holds the x-coordinates, row 1 the y-coordinates.
p = np.stack([radius * np.cos(angle) + 0.25, radius * np.sin(angle)])

In [None]:
# Quick look at the noisy half-ring p before any parameter-dependent shift.
fig, ax = plt.subplots()
ax.scatter(p[0], p[1])
ax.set(title="Half-ring p (no shift)", xlabel="p[0]", ylabel="p[1]")
plt.show()

In [None]:
# Same points as the previous scatter, but with equal axis scaling so the true
# half-circle geometry is visible (auto-scaled axes distort the ring's shape).
fig, ax = plt.subplots()
ax.scatter(p[0], p[1])
ax.set_aspect("equal")
ax.set(title="Half-ring p (equal aspect)", xlabel="p[0]", ylabel="p[1]")
plt.show()

In [None]:
# Half-ring under (u, v) -> (u - v, u + v), i.e. a 45-degree rotation scaled by sqrt(2),
# shown together with its point reflection through the origin.
rotated = np.stack([p[0] - p[1], p[0] + p[1]])
fig, ax = plt.subplots()
ax.scatter(rotated[0], rotated[1], label="(p0 - p1, p0 + p1)")
ax.scatter(-rotated[0], -rotated[1], label="(-p0 + p1, -p0 - p1)")
ax.set(title="Rotated half-ring and its reflection", xlabel="x", ylabel="y")
ax.legend()
plt.show()

In [None]:
# Latent parameters: each column theta[:, i] ~ Uniform(-1, 1)^2, one per sample.
theta = np.random.uniform(low=-1, high=1, size=(2, N))
# Shift for the first coordinate: (theta_1 - theta_0) / sqrt(2).
b0 = (theta[1] - theta[0]) / math.sqrt(2)
# Shift for the second coordinate: -|theta_0 + theta_1| / sqrt(2), never positive.
# NOTE(review): the SBI "two moons" benchmark adds -|theta_0 + theta_1| / sqrt(2) to the
# FIRST coordinate and (theta_1 - theta_0) / sqrt(2) to the second; here the roles look
# swapped — confirm this variant is intended.
b1 = -np.abs(theta[0] + theta[1]) / math.sqrt(2)

In [None]:
# Observations: shift each half-ring point by its parameter-dependent offset (b0, b1).
# Result has shape (2, N), matching p.
# NOTE: this rebinds `x`, which the density cells above use as the meshgrid; re-run
# the grid cell before re-running those.
x = p + np.stack([b0, b1])

In [None]:
# Simulated conditional two-moons observations, one point per latent theta.
fig, ax = plt.subplots()
ax.scatter(x[0], x[1])
ax.set(title="Conditional two-moons observations", xlabel="x[0]", ylabel="x[1]")
plt.show()