In [None]:
import torch
import torchvision as tv
import normflows as nf

from IPython.display import clear_output

from matplotlib import pyplot as plt
from tqdm.notebook import tqdm
import numpy as np

from torch.utils.data import Dataset, DataLoader

# own imports
from src.bases import CategoricalBase
from src.datasets import MNISTSampler, SXDataset
from src.utils import check_mem
from src.flows import SXGlowBlock, Squeeze, Sigmoid
from src.models import SXNormalizingFlow

In [None]:
# Select the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)

# Project helper from src.utils, called with the selected device.
check_mem(device)

In [None]:
def imshow(img):
    """Display a channel-first image tensor with matplotlib.

    Args:
        img: 3-D tensor whose axes are reordered with ``(1, 2, 0)`` before
            plotting, i.e. laid out as (channels, height, width).

    The tensor is detached from the autograd graph and copied to host memory
    before conversion, so tensors that live on a GPU or require gradients can
    be passed directly. Tensors that are already detached CPU tensors behave
    exactly as before.
    """
    npimg = img.detach().cpu().numpy()
    # Move the channel axis last, which is the layout plt.imshow expects.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# MNIST data

In [None]:
img_size = 32
dataset_path = 'data'

# Preprocessing applied to every image, in order:
#   1. resize to (img_size, img_size),
#   2. convert to a tensor,
#   3. clip the values into [0, 1].
preprocessing_steps = [
    tv.transforms.Resize((img_size, img_size)),
    tv.transforms.ToTensor(),
    tv.transforms.Lambda(lambda x: torch.clip(x, 0, 1)),
]
transform = tv.transforms.Compose(preprocessing_steps)

# MNIST stored under `dataset_path` (download enabled), with the
# preprocessing above attached as its transform.
dataset = tv.datasets.MNIST(
    root=dataset_path,
    download=True,
    transform=transform,
)

In [None]:
# Indices of digit distributions, must be a subset of {0, ..., 9}
# Choose one of the following:
# 1. For the all digits experiment.
digit_idxs = list(range(10))
# 2. For the {0, 1} experiment:
# digit_idxs = [0, 1]

# input distributions weights (uniform over the selected digits)
alphas = torch.ones(len(digit_idxs))/len(digit_idxs)

# One sampler per selected digit, each built from the MNIST `dataset`.
x_samplers = [MNISTSampler(dataset, digit_idx, device=device)
              for digit_idx in digit_idxs]

# Categorical base over the digit index, constructed with one_hot_encoded=True.
s_base = CategoricalBase(alphas, one_hot_encoded=True)

# Bound to its own name so that this cell does not overwrite the MNIST
# `dataset` it reads from when building the samplers above.
sx_dataset = SXDataset(x_samplers, s_base, num_samples=64)

# Model implementation

In [None]:
# Number of scales (a.k.a. levels)
L = 4

# Number of flows per scale
K = 16

# Tied to `img_size` so the model shape follows the resized input images
# instead of repeating the literal value.
input_shape = (1, img_size, img_size)
n_dims = np.prod(input_shape)
channels = input_shape[0]
hidden_channels = 256
split_mode = 'channel'
scale = True

# The latent shapes below use floor division by powers of two up to 2**L,
# which would silently truncate for sizes that are not exact multiples.
if input_shape[1] % 2 ** L or input_shape[2] % 2 ** L:
    raise ValueError(
        f'Spatial size {input_shape[1:]} must be divisible by 2**L = {2 ** L}.'
    )

z_bases = []
merges = []
flows = []
for i in range(L):
    print(f'Scale {i}')

    # K conditional Glow blocks followed by one Squeeze at every scale
    block_channels = channels * 2 ** (L + 1 - i)
    flows_ = [SXGlowBlock(block_channels,
                          hidden_channels,
                          context_dim=s_base.dim,
                          split_mode=split_mode,
                          scale=scale)
              for _ in range(K)]
    flows_.append(Squeeze())

    # Add (stretched) sigmoid transformation as last flow
    if i == L - 1:
        flows_.append(Sigmoid(eps=1e-5))

    flows.append(flows_)

    if i > 0:
        # Every scale after the first gets a Merge and its own latent shape
        merges.append(nf.flows.Merge())
        latent_shape = (input_shape[0] * 2 ** (L - i),
                        input_shape[1] // 2 ** (L - i),
                        input_shape[2] // 2 ** (L - i))
    else:
        latent_shape = (input_shape[0] * 2 ** (L + 1),
                        input_shape[1] // 2 ** L,
                        input_shape[2] // 2 ** L)

    # Fixed (non-trainable) diagonal Gaussian base for this scale
    z_bases.append(nf.distributions.DiagGaussian(latent_shape, trainable=False))


# Construct flow model with the multiscale architecture
model = SXNormalizingFlow(z_bases, s_base, flows, merges)

# Training

In [None]:
epochs = 50000

batch_size = 32
lr = 1e-4

# print metrics every `print_epochs` epochs
print_epochs = 1000

# NOTE(review): `save_epochs` is defined but never read in this cell, so no
# checkpoint is written during training -- confirm whether saving was intended.
save_epochs = 5000
milestones = []

# Settings for the periodic sample plots (constant, so set once up front)
temperature = 1.
num_samples = 32

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

model.to(device)
alphas = alphas.to(device)

kld_losses = []
l2_losses = []
losses = []

# annealing schedule: weight of the L2 term, log-spaced from 10**1 to 10**-2
temp = 10**torch.linspace(1, -2, epochs)

train_loop = tqdm(range(epochs))
for epoch in train_loop:

    # Draw a fresh batch of (s, x) pairs. The sampling container is not bound
    # to `dataset`, so that name is not overwritten by this cell.
    with torch.no_grad():
        s, x = SXDataset(x_samplers, s_base, num_samples=batch_size)[:]

    # Compute KLD loss
    kld_loss = model.forward_kld(x, s=s)

    # Compute L2 loss between x and model.bar(z), with z the inverse image of x
    # NOTE(review): only dim=1 is summed before the mean; if x is laid out as
    # (batch, C, H, W) the remaining axes are averaged rather than summed --
    # confirm this scaling is intended.
    z, _ = model.inverse_and_log_det(x, s)
    l2_loss = torch.mean(torch.sum((x - model.bar(z))**2, dim=1))

    # Fall back to the KLD term alone when the L2 term is NaN
    if torch.isnan(l2_loss):
        loss = kld_loss
    else:
        loss = temp[epoch]*l2_loss + kld_loss

    # Skip the parameter update entirely when the total loss is NaN or +/-inf
    if torch.isfinite(loss):
        optimizer.zero_grad()
        loss.backward()

        # gradient clipping (total L2 norm capped at 1)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1., norm_type=2.0, error_if_nonfinite=False, foreach=True)

        optimizer.step()

    # Read each scalar once and reuse it for logging
    l2_value = l2_loss.item()
    kld_value = kld_loss.item()
    loss_value = loss.item()
    current_lr = optimizer.param_groups[0]["lr"]

    l2_losses.append(l2_value)
    kld_losses.append(kld_value)
    losses.append(loss_value)

    train_loop.set_postfix(l2_loss=l2_value,
                           kld_loss=kld_value,
                           loss=loss_value,
                           lr=current_lr,
                           temp=temp[epoch].item())

    # Check learning at training time
    if (epoch + 1) % print_epochs == 0:
        clear_output(wait=True)
        print(f'Epoch: {epoch}')
        print(f'KLD loss: {kld_value}')
        print(f'lr: {current_lr}')

        # Plot learned marginals: one grid of samples per selected digit index
        for i in range(len(digit_idxs)):
            with torch.no_grad():
                s = model.s_base.encode(i*torch.ones((num_samples, 1), dtype=torch.int64, device=device))
                images, _ = model.sample(num_samples=num_samples, s=s, temperature=temperature)
                imshow(tv.utils.make_grid(images.detach().cpu()))

        # Plot sample from learned barycenter
        print('Sample from barycenter.')
        with torch.no_grad():
            images = model.bar_sample(num_samples=num_samples, temperature=temperature)
            imshow(tv.utils.make_grid(images.detach().cpu()))

    torch.cuda.empty_cache()
    scheduler.step()

In [None]:
# Plot losses
# The first two curves are drawn on a transformed scale, so the legend and
# axis labels name the quantity that is actually plotted.
plt.plot(-np.log(np.abs(kld_losses)), 'b.-', label='-log|KLD loss|')
plt.xlabel('training iteration')
plt.ylabel('-log|KLD loss|')
plt.legend()
plt.show()

plt.plot(np.log(l2_losses), 'y.-', label='log(L2 loss)')
plt.xlabel('training iteration')
plt.ylabel('log(L2 loss)')
plt.legend()
plt.show()

plt.plot(losses, 'r.-', label='Total loss')
plt.xlabel('training iteration')
plt.ylabel('total loss')
plt.legend()
plt.show()