## Config

In [None]:
%load_ext autoreload
%autoreload 2

import torch
from pathlib import Path
from experiment import data_path

# Experiment identifier; also names the sub-directory that holds generated artifacts.
model_name = 'mixture-plane-cef-joint'
gen_path = data_path / 'generated' / model_name

# Everything runs on the CPU (the plotting cell converts tensors straight to NumPy).
device = torch.device('cpu')

## Data

In [None]:
import data

# Both splits use the same dataset class and differ only in the requested size.
# Construction order is kept (train first, then validation).
train_data = data.PaperAffineSubspace(size=64_000)
val_data = data.PaperAffineSubspace(size=12_800)

## Model

In [None]:
from nflows import cef_models, flows

# Instantiate the mixture-plane CE flow with SimpleGlow as its base flow class,
# then place it on the configured device.
flow = cef_models.MixturePlaneCEFlow(base_flow_class=flows.SimpleGlow)
flow = flow.to(device)

## Train

Schedule training

In [None]:
import torch.optim as opt

# Number of training epochs; read by schedule() below.
epochs = 10

# Adam over every parameter of the flow.
optim = opt.Adam(flow.parameters(), lr=1e-3)

def schedule(num_epochs=None, likelihood_weight=0.01, recon_weight=100000):
    '''Yield epoch weights for likelihood and recon loss, respectively.

    Args:
        num_epochs: How many epochs to schedule. ``None`` (the default) falls
            back to the notebook-level ``epochs`` setting, which is looked up
            lazily when the generator is first advanced.
        likelihood_weight: Weight yielded for the log-likelihood term.
        recon_weight: Weight yielded for the reconstruction-error term.

    Yields:
        A ``(likelihood_weight, recon_weight)`` tuple, once per epoch. The
        weights are constant across epochs.
    '''
    if num_epochs is None:
        num_epochs = epochs
    for _ in range(num_epochs):
        yield likelihood_weight, recon_weight

Create dataloaders

In [None]:
from torch.utils.data import DataLoader

# Mini-batch size used for training.
train_batch_size = 64
# Only referenced by the disabled loader at the bottom of this cell.
test_batch_size = 512

# Reshuffle the training set each epoch; batches are loaded by 30 worker processes.
train_loader = DataLoader(
    train_data,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=30,
)
# NOTE(review): disabled loader; `test_data` is not defined in the visible cells
# (the data cell creates `val_data`) -- confirm the intended dataset before enabling.
#test_loader = DataLoader(test_data, batch_size=test_batch_size, shuffle=True, num_workers=30)

Train the flow

In [None]:
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm


# Joint training: every batch takes one optimizer step on a weighted sum of the
# negative mean log-likelihood (weight alpha) and the mean squared
# reconstruction error (weight beta). The weights come from schedule().
for epoch, (alpha, beta) in enumerate(schedule()):

    # Train for one epoch
    flow.train()
    # enumerate() hides the loader's length from tqdm, so pass the batch count
    # explicitly; without it the bar cannot show a total, percentage or ETA.
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))

    for batch, point in progress_bar:
        point = point.to(device)
        optim.zero_grad()

        # Compute reconstruction error
        # NOTE(review): mid_latent and log_conf_det are produced inside this
        # context, so with beta == 0 they would carry no gradient and the
        # likelihood term below could not backpropagate into the embedding.
        # The current schedule always yields beta > 0 -- confirm before
        # scheduling beta = 0.
        with torch.set_grad_enabled(beta > 0):
            mid_latent, _ = flow.embedding.forward(point)
            reconstruction, log_conf_det = flow.embedding.inverse(mid_latent)
            reconstruction_error = torch.mean((point - reconstruction)**2)

        # Compute log likelihood
        with torch.set_grad_enabled(alpha > 0):
            log_pu = flow.distribution.log_prob(mid_latent)
            log_likelihood = torch.mean(log_pu + log_conf_det)

        # Training step
        loss = - alpha*log_likelihood + beta*reconstruction_error
        loss.backward()
        optim.step()

        # Display results
        progress_bar.set_description(f'[E{epoch} B{batch}] | loss: {loss: 6.2f} | LL: {log_likelihood:6.2f} '
                                     f'| recon: {reconstruction_error:6.7f} ')

## Plot the Learned Distribution

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import cm
plt.style.use('default')

# Five anchor colours for the colormap, written as 0-255 RGB triples and
# rescaled to the 0-1 floats that matplotlib expects.
rgbs = [
    tuple(channel / 255 for channel in colour)
    for colour in (
        (250, 171, 54),
        (223, 220, 119),
        (217, 255, 200),
        (129, 208, 177),
        (36, 158, 160),
    )
]
# Interpolate between the anchors, quantised to 21 colour levels.
custom_cm = LinearSegmentedColormap.from_list("CEF_colors", rgbs, N=21)

# Number of lattice points per axis, and half-width of the plotted region.
lattice_num = 120
extent = 1.5


def model_likelihood_grid(x, y, z):
    '''Evaluate the flow's likelihood (not log-likelihood) at lattice points.

    Args:
        x, y, z: Equal-shaped coordinate arrays holding
            ``lattice_num * lattice_num`` points in total.

    Returns:
        A ``(lattice_num, lattice_num)`` NumPy array of ``exp(flow.log_prob(...))``
        values, one per lattice point.
    '''
    stacked = np.stack((x, y, z))
    # Flatten each coordinate, then reverse the axes so the batch has shape
    # (num_points, 3, 1, 1).
    batch = torch.Tensor(stacked.reshape(1, 1, 3, -1).T)
    density = flow.log_prob(batch).exp()
    density = density.reshape(lattice_num, lattice_num)
    return density.detach().numpy()

# One 3-D axes filling a square figure.
fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(1, 1, 1, projection='3d')

# Regular lattice over [-extent, extent]^2 in the x-y plane.
axis_ticks = np.linspace(-extent, extent, lattice_num)
x, y = np.meshgrid(axis_ticks, axis_ticks)

# Map the two latent basis vectors through the embedding's inverse; after the
# squeeze/transpose each column of learned_mat is the image of one basis vector
# (presumably a (3, 2) matrix -- the solve and reshape below only work for that shape).
learned_mat, _ = flow.embedding.inverse(torch.eye(2))
learned_mat = learned_mat.squeeze().T

# For every lattice (x, y), solve for the latent coordinates from the first two
# rows, then use the remaining row to get the matching height z.
xy_columns = torch.Tensor(np.stack((x, y), axis=2).reshape(-1, 2, 1))
plane_coords = torch.linalg.solve(learned_mat[:2, :], xy_columns)
heights = learned_mat[2:, :] @ plane_coords
z = heights.reshape(lattice_num, lattice_num).detach().numpy()
likelihoods = model_likelihood_grid(x, y, z)

# NOTE(review): the colormap is applied to raw likelihood values; matplotlib
# colormaps map floats in [0, 1], so larger values saturate -- confirm that no
# normalisation is intended.
ax.plot_surface(x, y, z, facecolors=custom_cm(likelihoods))
ax.set_xlim(-extent, extent)
ax.set_ylim(-extent, extent)
ax.set_zlim(-extent, extent)