In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, Dataset

import normflows as nf

from tqdm.notebook import tqdm

from scipy.stats import special_ortho_group

# imports from own modules
from src.flows import SXAffineCouplingBlock, Permute
from src.bases import SXBase, CategoricalBase, GaussianBase
from src.models import SXNormalizingFlow
from src.datasets import SXDataset
from src.utils import check_mem, ConditionalMLP
from src.samplers import Gaussian, Uniform, LocationScatterSampler, LocationScatterBenchmark
from src.metrics import get_L2_UVP, get_BW2_UVP

In [None]:
# Pick the compute device: the first CUDA GPU when one is visible, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)
check_mem(device)  # project helper from src.utils, called with the chosen device

In [None]:
# Config

# Random seed for reproducibility. The random rotations in the data-generation
# cell come from scipy's special_ortho_group without a random_state, i.e. from
# NumPy's global RNG; model initialisation draws from torch's global RNG.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# type of base distribution: either 'gaussian', or 'uniform'
dataset_type = 'gaussian'

# number of input distributions
n = 4

# input dimension
d = 2

# distribution weights: one non-negative weight per input distribution, summing to 1
alphas = torch.tensor([0.1, 0.2, 0.3, 0.4])
assert alphas.numel() == n, f'expected {n} weights, got {alphas.numel()}'
assert (alphas >= 0).all(), 'distribution weights must be non-negative'
assert torch.isclose(alphas.sum(), torch.tensor(1.0)), 'distribution weights must sum to 1'

# Data generation

In [None]:
# Generate datasets

if dataset_type == 'gaussian':
    base_sampler = Gaussian(d, device=device)
elif dataset_type == 'uniform':
    base_sampler = Uniform(d, device=device)
else:
    raise ValueError(f'Unknown dataset_type: {dataset_type}')

# scipy's special_ortho_group only supports dimensions > 1, and the spectrum
# below divides by d - 1, so reject d < 2 with a clear message.
if d < 2:
    raise ValueError(f'd must be at least 2, got {d}')

# Eigenvalues shared by every scatter matrix: 0.5 * ratio**i for i in 0..d-1,
# with ratio**(d-1) == 4, i.e. geometrically spaced from 0.5 to 2.
# Loop-invariant, so computed once. (Named `spectrum` rather than `L` so it does
# not clash with the number of scales `L` used by the model cell.)
ratio = pow(4, (1 / (d - 1)))
spectrum = torch.diag(torch.tensor([0.5 * (ratio**i) for i in range(d)]))

samplers = []
for _ in range(n):
    # Random rotation R from SO(d); the scatter matrix R^T diag(spectrum) R is
    # symmetric with the eigenvalues above.
    rotation = torch.tensor(special_ortho_group.rvs(d), dtype=torch.float)
    weight = rotation.T @ spectrum @ rotation
    bias = torch.zeros(d)
    samplers.append(LocationScatterSampler(base_sampler, weight, bias, device=device))

benchmark = LocationScatterBenchmark(base_sampler, samplers, alphas, device=device)

# Model implementation

In [None]:
### Model implementation

# number of scales (a.k.a. levels)
L = int(np.log2(d))

# flows per scale: fewer flows per scale as the number of scales grows
if L == 1:
    K = 32
elif 2 <= L <= 4:
    K = 16
else:
    K = 8

split_mode = 'channel'
scale = True
scale_map = 'exp'
hidden_units = [64, 64]  # hidden layer widths of each conditioner MLP

# Latent space dimensions at each scale; together they must cover all d inputs.
latent_dims = [2] + [2**i for i in range(1, L)]
assert sum(latent_dims) == d, f'latent dims {latent_dims} do not sum to d={d}; d must be a power of 2 >= 2'

z_bases = []
merges = []
flows = []
for i in range(L):
    # Number of channels the flows at this scale act on (all scales merged so far).
    n_channels = sum(latent_dims[:i+1])
    # Conditioner input/output sizes for a channel split into two halves.
    # Same sizes torch.chunk(2, dim=1) produces: the first half gets the ceiling.
    dim1 = (n_channels + 1) // 2
    dim2 = 2 * (n_channels - dim1)  # mean and scale parameter for each channel in z2

    flows_ = []
    for _ in range(K):
        # Add neural network conditioner
        param_map = ConditionalMLP([dim1, *hidden_units, dim2], context_dim=benchmark.num, init_zeros=True)
        # Add flow layer (uses the config values above rather than repeating literals)
        flows_.append(SXAffineCouplingBlock(param_map, scale=scale, scale_map=scale_map, split_mode=split_mode))
        # Randomly permute dimensions
        flows_.append(Permute(n_channels, mode='shuffle'))

    flows.append(flows_)

    z_bases.append(nf.distributions.DiagGaussian(latent_dims[i], trainable=False))
    if i > 0:
        merges.append(nf.flows.Merge())

# Instantiate model
s_base = CategoricalBase(alphas, one_hot_encoded=True)
model = SXNormalizingFlow(z_bases, s_base, flows, merges)

# Training

In [None]:
epochs = 3
batch_size = 10000

# Learning rate
lr = 1e-3 if dataset_type == 'gaussian' else 1e-4

# print metrics every `print_epochs` epochs
print_epochs = 10

# decreasing weights schedule for the L2 loss term: log-spaced from 1 down to 0.01
temp = 10**torch.linspace(0, -2, epochs)

# Move the model to the device before constructing its optimizer (torch.optim recommendation).
model.to(device)
alphas = alphas.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.)

# `verbose` is deliberately not passed: it is deprecated, and on older torch
# versions the string 'deprecated' is truthy and would turn verbose logging on.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       mode='min',
                                                       factor=0.8,
                                                       patience=1000,
                                                       threshold=0.0001,
                                                       threshold_mode='abs',
                                                       min_lr=1e-8,
                                                       eps=1e-08)

kld_losses = []
l2_losses = []
losses = []
scores = []

# Make sure we train in train mode even if an earlier cell left the model in eval mode.
model.train()
# Most recent L2-UVP score; displayed in the progress bar between evaluations.
score = float('nan')

train_loop = tqdm(range(epochs))

for epoch in train_loop:

    # Compute KLD loss on a freshly sampled batch
    s, x = SXDataset(benchmark.samplers, model.s_base, num_samples=batch_size)[:]
    kld_loss = model.forward_kld(x, s=s)

    # Compute L2 loss between x and model.bar applied to its latent z
    z, _ = model.inverse_and_log_det(x, s)
    l2_loss = torch.mean(torch.sum((x - model.bar(z))**2, dim=1))

    loss = temp[epoch]*l2_loss + kld_loss

    # Only update parameters when the loss is finite (neither NaN nor +-inf)
    if torch.isfinite(loss):
        optimizer.zero_grad()
        loss.backward()
        # foreach left at its default (None) so clipping falls back gracefully on any device
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2., norm_type=2.0, error_if_nonfinite=False)
        optimizer.step()

    l2_losses.append(l2_loss.item())
    kld_losses.append(kld_loss.item())
    losses.append(loss.item())

    if epoch % print_epochs == 0:
        model.eval()
        score = get_L2_UVP(benchmark, model, alphas, batch_size=1000, device=device)
        model.train()
        scores.append(score)

    train_loop.set_postfix(l2_loss=l2_losses[-1],
                           kld_loss=kld_losses[-1],
                           lr=optimizer.param_groups[0]["lr"],
                           temp=temp[epoch].item(),
                           score=score)

    # Plateau detection is driven by the KL divergence term only
    scheduler.step(kld_losses[-1])

# Evaluation

In [None]:
### Plot training losses

fig, ax = plt.subplots()
ax.plot(kld_losses, 'b.-', label='KL div losses')
ax.plot(l2_losses, 'y.-', label='L2 losses')
ax.set(xlabel='epoch', ylabel='loss', title='Training losses')
ax.legend()
plt.show()

fig, ax = plt.subplots()
ax.plot(losses, 'r.-', label='Total losses')
ax.set(xlabel='epoch', ylabel='loss', title='Total training loss')
ax.legend()
plt.show()

# One score is recorded every `print_epochs` epochs, so plot each against the
# epoch it was computed at. float() accepts Python floats and 0-d tensors alike
# (including CUDA tensors, which np.log cannot handle directly).
score_epochs = [k * print_epochs for k in range(len(scores))]
fig, ax = plt.subplots()
ax.plot(score_epochs, np.log([float(s) for s in scores]), 'r.-')
ax.set(xlabel='epoch', ylabel='log L2-UVP', title='L2_UVP score (log-%)')
plt.show()

In [None]:
# Compute metrics

# number of trials for computing standard deviation of metrics
n_estimators = 10


def summarize_trials(estimate, n_trials):
    """Evaluate a stochastic metric estimator repeatedly and summarise the results.

    Args:
        estimate: zero-argument callable returning one scalar estimate
            (anything ``float()`` accepts, e.g. a Python float or a 0-d tensor).
        n_trials: number of independent evaluations.

    Returns:
        dict with keys 'mean' and 'std' holding Python floats. With a single
        trial the sample standard deviation is undefined and 'std' is NaN.
    """
    trials = torch.tensor([float(estimate()) for _ in range(n_trials)])
    return {'mean': torch.mean(trials).item(), 'std': torch.std(trials).item()}


# Evaluate in eval mode, as the training loop does before calling get_L2_UVP,
# and restore the previous mode afterwards so later cells are unaffected.
# The results go into dedicated names so the training-history `scores` list is not clobbered.
was_training = model.training
model.eval()
try:
    L2_UVPs = summarize_trials(
        lambda: get_L2_UVP(benchmark, model, alphas, batch_size=10000, device=device),
        n_estimators)
    BW2_UVPs = summarize_trials(lambda: get_BW2_UVP(benchmark, model), n_estimators)
finally:
    model.train(was_training)

# Print metrics. The outer quotes are double quotes because reusing ' inside the
# {...} fields of a '...' f-string is a SyntaxError before Python 3.12.
print(f"L2_UVP (mean +- std): {L2_UVPs['mean']} +- {L2_UVPs['std']}\n")
print(f"BW2_UVP (mean +- std): {BW2_UVPs['mean']} +- {BW2_UVPs['std']}\n")