## Config

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from experiment import data_path

# Experiment identifier; also used as the leaf directory name in `gen_path`.
model_name = 'sphere-cef-joint'
# <data_path>/generated/<model_name>. Not referenced by the cells visible here;
# presumably where generated artifacts for this model are meant to be stored.
gen_path = data_path / 'generated' / model_name

## Generate data non-uniformly on sphere

In [None]:
from torch.utils.data import DataLoader, random_split
import data
import numpy as np

# Number of points requested from the dataset (passed as `size=` below) and
# reused later as the number of samples drawn from the model for plotting.
num_samples = 1000
batch_size = 100

# Distribution parameters forwarded to `data.Sphere`.
# NOTE(review): presumably the mean and covariance of an ambient 3-D Gaussian
# that is projected onto the sphere -- confirm against `data.Sphere`.
mu = [-1, -1, 0.0]
sigma = [[1,0,0], [0,1,0], [0,0,1]]

# NOTE: this rebinds the name `data` from the imported module to the dataset
# instance; later cells rely on `data` being the dataset (`data.points`, DataLoader).
data = data.Sphere(
    manifold_dim=2, 
    ambient_dim=3, 
    size=num_samples, 
    mu=mu, 
    sigma=sigma)

In [None]:
from nflows import cef_models

# Build the full model, then keep handles to its two parts:
#   conf_embedding -- used below via .forward()/.inverse() to map points to a
#                     latent and back (the inverse also returns a log-det term)
#   backbone       -- used below via .log_prob() on that latent
flow = cef_models.SphereCEFlow()
conf_embedding = flow.embedding
backbone = flow.distribution

## Train

Schedule training

In [None]:
import torch.optim as opt

# Mini-batch size for the training DataLoader (same value as set in the data cell).
batch_size = 100
# Adam over every parameter of the flow (embedding and backbone together).
optim = opt.Adam(flow.parameters(), lr=0.005)
# Multiply the learning rate by 0.5 once 40 scheduler steps have elapsed;
# `schedule()` below steps the scheduler once per epoch.
scheduler = opt.lr_scheduler.MultiStepLR(optim, milestones=[40], gamma=0.5)

def schedule():
    """Generate the per-epoch loss weights for a 45-epoch training run.

    Each item is a pair ``(likelihood_weight, reconstruction_weight)``,
    fixed at ``(10, 10000)`` for every epoch. When the consumer asks for the
    next epoch (i.e. after it has finished the current one), the module-level
    learning-rate ``scheduler`` is stepped once.
    """
    epoch_weights = (10, 10000)
    remaining = 45
    while remaining > 0:
        yield epoch_weights
        # Reached only once the caller resumes the generator after an epoch.
        scheduler.step()
        remaining -= 1
        
# Shuffled mini-batches over the sphere dataset, loaded by 6 worker processes.
loader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=6)

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch


# First `num_samples` dataset points; reused by the plotting cells further down.
points = data.points[:num_samples]

# Initialize model
# Runs the untrained model once without tracking gradients: draws samples and
# round-trips the data through the embedding (forward, then inverse).
# NOTE(review): `gen_samples` is not used in this cell; the call is presumably
# kept so that any lazily-initialised layers are set up -- confirm.
with torch.no_grad():
    gen_samples = flow.sample(num_samples)
    sample_mid_latent, _ = flow.embedding.forward(points)
    sample_recons, _ =  flow.embedding.inverse(sample_mid_latent)

# Plot data and recons before training
# Data points are drawn in orange (#faab36), reconstructions in teal (#249ea0).
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(projection='3d')
point_plot = ax.scatter(points[:,0].cpu(), points[:,1].cpu(), points[:,2].cpu(), 
                        color='#faab36')
recon_plot = ax.scatter(sample_recons[:,0].cpu(), sample_recons[:,1].cpu(),
                        sample_recons[:,2].cpu(), color='#249ea0')
ax.auto_scale_xyz([-1.3, 1.3], [-1.3, 1.3], [-1, 1]) # Correct aspect ratio manually
ax.view_init(elev=20, azim=260)

In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm

# Joint training: every step minimises
#     -alpha * mean log-likelihood + beta * mean squared reconstruction error
# where (alpha, beta) are the per-epoch weights yielded by `schedule()`.
for epoch, (alpha, beta) in enumerate(schedule()):
    
    # Train for one epoch
    flow.train()
    # NOTE: `enumerate(loader)` has no length, so the bar shows a running count only.
    progress_bar = tqdm(enumerate(loader))
    
    for batch, point in progress_bar:
        optim.zero_grad()

        # Compute reconstruction error
        # Gradient tracking is switched off for this pass when beta == 0.
        # NOTE: `mid_latent` and `log_conf_det` computed here are reused by the
        # likelihood term, so with beta == 0 that term would not back-propagate
        # through the embedding either.
        with torch.set_grad_enabled(beta > 0):
            mid_latent, _ = conf_embedding.forward(point)
            reconstruction, log_conf_det = conf_embedding.inverse(mid_latent)
            reconstruction_error = torch.mean((point - reconstruction)**2)

        # Compute log likelihood
        # Backbone log-density of the latent minus the log-det term returned by
        # the embedding's inverse, averaged over the batch.
        with torch.set_grad_enabled(alpha > 0):
            log_pu = backbone.log_prob(mid_latent)
            log_likelihood = torch.mean(log_pu - log_conf_det)

        # Training step
        loss = - alpha*log_likelihood + beta*reconstruction_error
        loss.backward()
        optim.step()

        # Display results
        progress_bar.set_description(f'[E{epoch} B{batch}] | loss: {loss: 6.2f} | LL: {log_likelihood:6.2f} '
                                     f'| recon: {reconstruction_error:6.5f} ')

In [None]:
# Plot data and recons
# Same round trip (embedding forward, then inverse) as before training, now
# with the trained embedding; no gradients are needed for plotting.
with torch.no_grad():
    sample_mid_latent, _ = conf_embedding.forward(points)
    sample_recons, _ =  conf_embedding.inverse(sample_mid_latent)

# Data points in orange (#faab36), reconstructions in teal (#249ea0).
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(projection='3d')
point_plot = ax.scatter(points[:,0], points[:,1], points[:,2], color='#faab36')
recon_plot = ax.scatter(sample_recons[:,0], sample_recons[:,1], sample_recons[:,2], 
                        color='#249ea0')
ax.auto_scale_xyz([-1.3, 1.3], [-1.3, 1.3], [-1, 1]) # Correct aspect ratio manually
ax.view_init(elev=20, azim=260)

In [None]:
# Plot generated samples to gauge density
# Draw `num_samples` points from the trained flow; detach() drops the autograd
# graph so the tensor can be handed to matplotlib.
gen_samples = flow.sample(num_samples).detach()

fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(projection='3d')
point_plot = ax.scatter(gen_samples[:,0], gen_samples[:,1], gen_samples[:,2], color='#faab36')
ax.auto_scale_xyz([-1.3, 1.3], [-1.3, 1.3], [-1, 1]) # Correct aspect ratio manually
ax.view_init(elev=20, azim=260)

## Plot Densities and Samples

In [None]:
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

# Five RGB anchor colours (components scaled to 0-1). The first and last match
# the orange (#faab36) and teal (#249ea0) used for the scatter plots.
rgbs = [(250/255,171/255,54/255),(223/255,220/255,119/255),(217/255,255/255,200/255),
         (129/255,208/255,177/255), (36/255,158/255,160/255)] # Custom color scheme
# Interpolate between the anchors and quantise the result to 21 colour levels.
custom_cm = LinearSegmentedColormap.from_list("CEF_colors", rgbs, N=21)

In [None]:
mkdir figures

In [None]:
# Plot the density of data distribution
from scipy.special import erf

# Point-independent factor of the data density, computed once:
#     exp(-|mu|^2 / 2) / (2^(5/2) * pi^(3/2))
# `data_likelihood` multiplies its direction-dependent term by this constant.
mu_norm = np.linalg.norm(mu)
const = np.exp(-mu_norm**2 / 2) / (2**(5/2) * np.pi**(3/2))

def data_likelihood(x, y, z):
    """Density of the 2d Sphere dataset at the point(s) (x, y, z).

    With t = <(x, y, z), mu>, the value is

        const * ( 2t + sqrt(2*pi) * (t^2 + 1) * exp(t^2 / 2) * (1 + erf(t / sqrt(2))) )

    which is the closed form for a unit-covariance 3-D Gaussian centred at
    ``mu`` projected radially onto the unit sphere (NOTE(review): assumes that
    is how the dataset is generated -- confirm against ``data.Sphere``).
    Reads the module-level ``mu`` and ``const``. Only NumPy/SciPy ufuncs and
    arithmetic are applied, so scalars and arrays are both accepted.
    """
    # Projection of the point onto the mean direction.
    t = x*mu[0] + y*mu[1] + z*mu[2]

    erf_factor = 1 + erf(t / np.sqrt(2))
    gaussian_part = np.sqrt(2*np.pi) * (t**2 + 1) * np.exp(t**2 / 2) * erf_factor
    unnormalised = (2 * t) + gaussian_part
    return unnormalised * const

def plot_data_density():
    """Evaluate the analytic data density on an angular grid covering the unit sphere.

    Despite the name, nothing is drawn here: the arrays are returned so the
    caller can pass them to ``plot_surface``.

    The grid has 240 azimuthal angles in [0, 2*pi] (axis 0) and 120 polar
    angles in [0, pi] (axis 1).

    Returns:
        XX, YY, ZZ: (240, 120) arrays with the Cartesian coordinates of the
            grid points on the unit sphere.
        density_grid_2: ``data_likelihood`` evaluated at every grid point.
        heatmap: ``density_grid_2`` divided by its maximum, so the largest
            value is 1 and it can be fed straight to a colormap.
    """
    # create grid of points on spherical surface
    u = np.linspace(0, 2 * np.pi, 240) # azimuthal angle
    v = np.linspace(0, np.pi, 120) # polar angle

    # create the sphere surface in xyz coordinates
    XX = np.outer(np.cos(u), np.sin(v))
    YY = np.outer(np.sin(u), np.sin(v))
    ZZ = np.outer(np.ones(np.size(u)), np.cos(v))

    # data_likelihood is built from NumPy/SciPy ufuncs only, so it can be
    # evaluated on the whole coordinate grid in one call rather than point by
    # point; this also guarantees the density is computed at exactly the
    # coordinates that get plotted.
    density_grid_2 = data_likelihood(XX, YY, ZZ)

    # plot density as heatmap. for coloration values should fill (0,1)
    heatmap = density_grid_2 / np.max(density_grid_2)

    return XX, YY, ZZ, density_grid_2, heatmap
        
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')

# `density_grid_data` is kept under its own name because the model-density cell
# reads it again to put both figures on the same colour scale.
XX, YY, ZZ, density_grid_data, heatmap = plot_data_density()

# Stand-alone mappable used only to draw a colorbar for the density values;
# the surface itself is coloured explicitly through `facecolors` below.
# NOTE(review): the mappable is not attached to any Axes and no `ax=` is passed
# to `plt.colorbar`; newer matplotlib releases may require one -- verify.
colorbar = cm.ScalarMappable(cmap=custom_cm)
colorbar.set_array(density_grid_data)
plt.colorbar(colorbar, pad=-0.02, fraction=0.026, format='%.2f')
ax.view_init(elev=20, azim=260)
# `heatmap` is the density rescaled so its maximum is 1, ready for the colormap.
ax.plot_surface(XX, YY, ZZ, cstride=1, rstride=1, facecolors=custom_cm(heatmap))
ax.auto_scale_xyz([-1.15, 1.15], [-1.15, 1.15], [-1, 1]) # Correct aspect ratio manually
ax.set_xticks([-1.0, -0.5, 0.0, 0.5, 1.0])
ax.set_yticks([-1.0, -0.5, 0.0, 0.5, 1.0])
ax.set_zticks([-1.0, -0.5, 0.0, 0.5, 1.0])
plt.tight_layout(pad=0, w_pad=0)
# Written into the `figures` directory created by the earlier `mkdir` cell.
plt.savefig("figures/sphere-data-density.png", bbox_inches='tight', dpi=300)
plt.show()

In [None]:
# Scatter plot of the original data points. The analytic density surface
# plotted in the previous cell should show a similar distribution to these samples.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(projection='3d')
point_plot = ax.scatter(points[:,0], points[:,1], points[:,2], color='#faab36')
ax.view_init(elev=20, azim=260)
# Fixed axis limits and ticks so this figure lines up with the other sample plots.
ax.set_xlim(-1.3, 1.3)  
ax.set_ylim(-1.3, 1.3) 
ax.set_zlim(-1.0, 1.0) 
ax.set_xticks([-1.0, -0.5, 0.0, 0.5, 1.0])
ax.set_yticks([-1.0, -0.5, 0.0, 0.5, 1.0])
ax.set_zticks([-1.0, -0.5, 0.0, 0.5, 1.0])
plt.savefig("figures/sphere-data-samples.png", bbox_inches='tight', dpi=300)

In [None]:
def likelihood_of_point(arr, manifold_model, density_model):
    """Return the model density at each point in ``arr`` as a NumPy array.

    The points are mapped to the latent with ``manifold_model.forward``; the
    log-density is ``density_model.log_prob(latent)`` minus the log-det term
    returned as the second output of ``manifold_model.inverse(latent)``. The
    result is exponentiated, so densities (not log-densities) are returned.

    Args:
        arr: NumPy array of points (wrapped with ``torch.from_numpy``, so its
            dtype is what the models receive).
        manifold_model: object providing ``forward(x) -> (latent, _)`` and
            ``inverse(latent) -> (_, log_det)``.
        density_model: object providing ``log_prob(latent)``.
    """
    batch = torch.from_numpy(arr)

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        latent, _ = manifold_model.forward(batch)
        _, log_det = manifold_model.inverse(latent)
        log_density = density_model.log_prob(latent) - log_det
        density = torch.exp(log_density)

    return density.numpy()

def plot_model_density(manifold_model, density_model):
    """Evaluate the learned density on an angular grid covering the unit sphere.

    Despite the name, nothing is drawn here: the arrays are returned so the
    caller can pass them to ``plot_surface``. The grid has 240 azimuthal
    angles in [0, 2*pi] (axis 0) and 120 polar angles in [0, pi] (axis 1).

    Args:
        manifold_model: embedding passed through to ``likelihood_of_point``.
        density_model: latent density passed through to ``likelihood_of_point``.

    Returns:
        XX, YY, ZZ: (240, 120) arrays with the Cartesian coordinates of the
            grid points on the unit sphere.
        density_grid: model density at every grid point.
        heatmap: ``density_grid`` divided by the maximum of the module-level
            ``density_grid_data`` (the analytic data density), so the model
            and data figures share one colour scale. The data-density cell
            must therefore have been run first, and values may exceed 1 where
            the model is more peaked than the data density.
    """
    # create grid of points on spherical surface
    u = np.linspace(0, 2 * np.pi, 240) # azimuthal angle
    v = np.linspace(0, np.pi, 120) # polar angle

    # create the sphere surface in xyz coordinates
    XX = np.outer(np.cos(u), np.sin(v))
    YY = np.outer(np.sin(u), np.sin(v))
    ZZ = np.outer(np.ones(np.size(u)), np.cos(v))

    density_grid = np.zeros_like(XX)
    # Evaluate the model one polar ring (fixed v, all 240 azimuths) at a time.
    for i in range(v.size):
        # Column i of XX/YY/ZZ already holds the ring's coordinates; stack them
        # into a (240, 3) float32 batch of (x, y, z) points for the model.
        ring_points = np.stack((XX[:, i], YY[:, i], ZZ[:, i]), axis=1).astype(np.float32)

        # Treat every point in grid as (x, y, z) data_point
        # Calculate likelihood from model in batches
        density_grid[:, i] = likelihood_of_point(ring_points, manifold_model, density_model)

    # plot density as heatmap. for coloration values should fill (0,1)
    # Normalised by the *data* density maximum on purpose (matching scales).
    heatmap = density_grid / np.max(density_grid_data)

    return XX, YY, ZZ, density_grid, heatmap

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')

# Learned density on the sphere grid; `heatmap` is already scaled by the data
# density maximum inside `plot_model_density`.
XX, YY, ZZ, density_grid, heatmap = plot_model_density(conf_embedding, backbone)

# Stand-alone mappable used only to draw the colorbar.
# NOTE(review): the mappable is not attached to any Axes and no `ax=` is passed
# to `plt.colorbar`; newer matplotlib releases may require one -- verify.
colorbar = cm.ScalarMappable(cmap=custom_cm)
colorbar.set_array(density_grid_data) # Setting to density_grid_data for matching scales
plt.colorbar(colorbar, pad=-0.02, fraction=0.026, format='%.2f')
ax.view_init(elev=20, azim=260)
ax.plot_surface(XX, YY,  ZZ, cstride=1, rstride=1, facecolors=custom_cm(heatmap))
ax.auto_scale_xyz([-1.15, 1.15], [-1.15, 1.15], [-1, 1]) # Correct aspect ratio manually
ax.set_xticks([-1.0, -0.5, 0.0, 0.5, 1.0])
ax.set_yticks([-1.0, -0.5, 0.0, 0.5, 1.0])
ax.set_zticks([-1.0, -0.5, 0.0, 0.5, 1.0])
plt.tight_layout(pad=0, w_pad=0)
# Written into the `figures` directory created by the earlier `mkdir` cell.
plt.savefig("figures/sphere-model-density.png", bbox_inches='tight', dpi=300)
plt.show()

In [None]:
# Replot using trained density model
# Fresh draw of `num_samples` points from the trained flow, detached from autograd.
gen_samples = flow.sample(num_samples).detach()

fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(projection='3d')
gen_plot = ax.scatter(gen_samples[:,0], gen_samples[:,1], gen_samples[:,2], color='#faab36')
ax.view_init(elev=20, azim=260)
# Same limits and ticks as the data-samples figure so the two are comparable.
ax.set_xlim(-1.3, 1.3)  
ax.set_ylim(-1.3, 1.3) 
ax.set_zlim(-1.0, 1.0) 
ax.set_xticks([-1.0, -0.5, 0.0, 0.5, 1.0])
ax.set_yticks([-1.0, -0.5, 0.0, 0.5, 1.0])
ax.set_zticks([-1.0, -0.5, 0.0, 0.5, 1.0])
# NOTE: unlike the other figures, this one is saved without bbox_inches='tight'.
plt.savefig("figures/sphere-generated-samples.png", dpi=300)