In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
from sklearn.datasets import make_swiss_roll
from sklearn.utils import shuffle
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from IPython.display import clear_output

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Dataset

In [None]:
# Which toy 2-D dataset to build: "swiss" (two swiss rolls, one negated) or
# "gaussians" (two isotropic Gaussian blobs).
DS_NAME = "gaussians"
# Seed for the Gaussian draws so that Restart & Run All reproduces the same data
# (the swiss branch is already seeded through `random_state`).
SEED = 0


def make_swiss_dataset(num_samples):
    """Build a two-class 2-D dataset out of two swiss rolls.

    Each class gets ``num_samples // 2`` points (so an odd ``num_samples``
    yields ``num_samples - 1`` rows). Columns 0 and 2 of each roll are kept,
    the second roll is negated, the rows are shuffled with a fixed seed and
    every column is standardised to zero mean / unit standard deviation.

    Parameters
    ----------
    num_samples : int
        Total number of points requested across both classes.

    Returns
    -------
    X : np.ndarray
        float32 array with two columns.
    y : np.ndarray
        Labels, 0.0 for the first roll and 1.0 for the negated roll.
    """
    half = num_samples // 2
    # Different seeds per class: with the same seed the second roll would be an
    # exact negation of the first and carry no independent noise.
    X0, _ = make_swiss_roll(half, noise=0.3, random_state=0)
    X1, _ = make_swiss_roll(half, noise=0.3, random_state=1)
    X0 = X0[:, [0, 2]]
    X1 = -X1[:, [0, 2]]
    X, y = shuffle(
        np.concatenate([X0, X1], axis=0),
        np.concatenate([np.zeros(len(X0)), np.ones(len(X1))], axis=0),
        random_state=0)
    X = (X - X.mean(axis=0)) / X.std(axis=0)

    return X.astype(np.float32), y


if DS_NAME == "swiss":
    X, Y = make_swiss_dataset(2000)
elif DS_NAME == "gaussians":
    # Two blobs with the same standard deviation, centred at -factor and +factor
    # on both axes.
    factor, sigma = 0.5, 0.3
    mu1, sigma1 = factor * np.array([-1, -1]), sigma
    mu2, sigma2 = factor * np.array([1, 1]), sigma
    ds_size = 10_000

    rng = np.random.default_rng(SEED)
    X1 = rng.standard_normal((ds_size // 2, len(mu1))) * sigma1 + mu1
    Y1 = np.zeros(len(X1))
    X2 = rng.standard_normal((ds_size // 2, len(mu2))) * sigma2 + mu2
    Y2 = np.ones(len(X2))

    X = np.concatenate([X1, X2]).astype(np.float32)
    Y = np.concatenate([Y1, Y2])
else:
    # Fail here rather than with a NameError on X further down the notebook.
    raise ValueError(
        f"Unknown DS_NAME {DS_NAME!r}; expected 'swiss' or 'gaussians'"
    )

# Per-column statistics of the raw data; the rest of the notebook uses them to
# standardise inputs and to map generated samples back to data space.
MEAN = X.mean(axis=0)
STD = X.std(axis=0)

# Quick look at the two classes in the original (un-normalised) coordinates.
fig, ax = plt.subplots(figsize=(4, 4))
sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=Y, ax=ax)
plt.show()

## Construct Loader

In [None]:
class LabeledDataset(Dataset):
    """Map-style dataset of standardised feature rows paired with integer labels.

    Parameters
    ----------
    X : np.ndarray
        Feature array; the statistics are broadcast along its first axis.
    Y : np.ndarray
        One label per row of ``X``; stored as 64-bit integers.
    mean, std : array-like, optional
        Per-feature statistics used for standardisation. When omitted, the
        notebook-level ``MEAN`` / ``STD`` are used.

    Raises
    ------
    ValueError
        If ``X`` and ``Y`` do not have the same number of rows.
    """

    def __init__(self, X, Y, mean=None, std=None):
        if len(X) != len(Y):
            raise ValueError(
                f"X and Y must have the same length, got {len(X)} and {len(Y)}"
            )
        mean = MEAN if mean is None else np.asarray(mean)
        std = STD if std is None else np.asarray(std)

        self.X = (X - mean[None]) / std[None]
        # np.int64 instead of np.long: np.long was removed in NumPy 1.24 and,
        # where it exists again (NumPy 2), aliases the platform C long, which
        # is not guaranteed to be 64-bit.
        self.Y = Y.astype(np.int64)

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        return self.X[idx], self.Y[idx]

In [None]:
# Number of samples per mini-batch.
BATCH_SIZE = 256

# Wrap the raw arrays in the dataset defined above, then serve it in
# reshuffled mini-batches.
train_dataset = LabeledDataset(X=X, Y=Y)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

## Unconditional

In [None]:
%load_ext autoreload
%autoreload 2
    
from pfgmpp import PFGMPP
from models import BaseField

model = BaseField(dim=2, hidden_dim=64, n_layers=3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
D = 1024

pfgmpp = PFGMPP(
    data_dim=2,
    model=model,
    optimizer=optimizer,
    device=device,
    sigma_min=0.002,
    sigma_max=80.,
)

In [None]:
# Alternate between plotting samples from the current model and training it.
n_epochs = 20
sample_every = 10_000  # passed to pfgmpp.train as n_iters between two plots
n_gens = 5000          # number of points requested from pfgmpp.sample per plot

for _ in range(n_epochs):
    clear_output()

    # Sample, then undo the dataset standardisation so the plot is in data space.
    samples = pfgmpp.sample(n_gens, num_steps=32)
    gens = samples * STD[None] + MEAN[None]

    fig, ax = plt.subplots(figsize=(4, 4))
    sns.scatterplot(x=gens[:, 0], y=gens[:, 1], ax=ax)
    ax.set_xlabel("")
    ax.set_ylabel("")
    plt.show()

    pfgmpp.train(train_loader, n_iters=sample_every, log_every=5_000)

## Class-conditional

In [None]:
%load_ext autoreload
%autoreload 2
    
from pfgmpp import PFGMPP
from models import BaseField

N_CLASSES = 2

model = BaseField(dim=2, hidden_dim=64, n_layers=3, n_classes=N_CLASSES).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
D = 1024

pfgmpp = PFGMPP(
    data_dim=2,
    model=model,
    optimizer=optimizer,
    device=device,
    sigma_min=0.002,
    sigma_max=80.,
)

In [None]:
# Alternate between plotting class-conditional samples and training the model.
sample_every = 25_000  # passed to pfgmpp.train as n_iters between two plots
n_gens = 5000          # number of points requested from pfgmpp.sample per plot
n_rounds = 200         # number of plot/train rounds

# First half of the requested samples is conditioned on class 0, second half on class 1.
label = torch.concatenate([
    torch.zeros(n_gens // 2), torch.ones(n_gens // 2),
]).to(device).to(torch.long)
# seaborn needs a host-side array for `hue`; a tensor living on a CUDA device
# cannot be converted to NumPy implicitly.
label_np = label.cpu().numpy()

# `_` instead of `n_epochs` so the loop does not overwrite the epoch count
# defined in the unconditional section.
for _ in range(n_rounds):
    clear_output()
    gens = pfgmpp.sample(n_gens, num_steps=32, label=label) * STD[None] + MEAN[None]
    plt.figure(figsize=(4, 4))
    sns.scatterplot(x=gens[:, 0], y=gens[:, 1], hue=label_np)
    plt.xlabel(""); plt.ylabel("")
    # Render now, as the unconditional loop does; otherwise every figure stays
    # open and they are only drawn once the whole cell has finished.
    plt.show()

    pfgmpp.train(train_loader, n_iters=sample_every, log_every=5_000)

In [None]:
# Final, larger class-conditional sample from the trained model.
n_gens = 20_000
# First half conditioned on class 0, second half on class 1.
label = torch.concatenate([
    torch.zeros(n_gens // 2), torch.ones(n_gens // 2),
]).to(device).to(torch.long)
# Undo the dataset standardisation so the plot is in data space.
gens = pfgmpp.sample(n_gens, num_steps=32, label=label) * STD[None] + MEAN[None]
plt.figure(figsize=(4, 4))
# Pass a host-side NumPy array as `hue`: a tensor on a CUDA device cannot be
# converted to NumPy implicitly by seaborn.
sns.scatterplot(x=gens[:, 0], y=gens[:, 1], hue=label.cpu().numpy())
plt.show()