## Config

Configure notebook, torch, and file paths

In [None]:
%load_ext autoreload
%autoreload 2
from nb_util import data_path, device

# Experiment name: used as the prefix of the checkpoint files written during
# training and read back for evaluation, and passed to generate_image_samples.
model_name = 'celeba-conformal-glow-joint'

## Load data

In [None]:
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
import flow.data as data

batch_size = 32
channels = 3
height = 64
width = 64

# Resize every image to the (height, width) used to size the model's layers,
# then convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((height, width)),
    transforms.ToTensor(),
])

train_data = data.CelebA(root=data_path, split='train', transform=transform)
test_data = data.CelebA(root=data_path, split='test', transform=transform)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=30)
cluster_loader = DataLoader(train_data, batch_size=1024, shuffle=True, num_workers=30)
# Loader over the held-out split. The reconstruction-error cell iterates over
# `test_loader`; without this definition it fails with a NameError on a fresh kernel.
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=30)

## Cluster the Data

Use K-means

## Model

Construct manifold-learning component

In [None]:
import torch
import flow.components as comp

def _shift_scale(num_channels, rows, cols):
    """Return a Shift of the given activation shape followed by a Scale."""
    return [comp.Shift(shape=(num_channels, rows, cols)), comp.Scale()]


def _manifold_stage(in_channels, rows, cols):
    """Return the components of one stage of the manifold model.

    A stage is a kernel-size-2 HouseholderConv, a kernel-size-1 HouseholderConv
    and a ConditionalConv, each followed by Shift + Scale, and finally a Pad
    from ``4 * in_channels`` to ``2 * in_channels`` channels. ``rows``/``cols``
    give the spatial size used for the Shift shapes in this stage.
    """
    wide_channels = in_channels * 4
    return [
        comp.HouseholderConv(in_channels, kernel_size=2),
        *_shift_scale(wide_channels, rows, cols),
        comp.HouseholderConv(wide_channels, kernel_size=1),
        *_shift_scale(wide_channels, rows, cols),

        comp.ConditionalConv(wide_channels),
        *_shift_scale(wide_channels, rows, cols),

        comp.Pad(wide_channels, wide_channels // 2),
    ]


# Input Shift + Scale, then three stages. Each stage halves the spatial size
# used for its Shift shapes and starts from twice the channel count of the
# previous stage's input.
manifold_components = _shift_scale(channels, height, width)
for level in range(3):
    shrink = 2 ** (level + 1)
    manifold_components += _manifold_stage(
        channels * 2 ** level, height // shrink, width // shrink)

manifold_model = comp.Sequential(*manifold_components)
manifold_model.to(device)


# Run one batch through the model to initialize its weights and surface
# runtime errors early.
num_recons = 4
init_x, _ = next(iter(train_loader))
init_x = init_x.to(device)
with torch.no_grad():
    init_mid_latent = manifold_model.initialize(init_x)

# Keep a few inputs around for visualizing reconstructions later.
sample_x = init_x[:num_recons]

# Used below and during training as the latent-space dimension.
# NOTE(review): this divides the per-sample element count by num_recons;
# confirm that the division is intended.
m = init_mid_latent[0].numel() // num_recons
with torch.no_grad():
    # Results are discarded; these calls only check that both directions run.
    manifold_model.data_to_latent(init_x, m)
    manifold_model.latent_to_data(init_mid_latent, m)

Construct density-learning component and concatenate the two models

In [None]:
# The density model operates on the manifold model's latents, which have
# channels*8 channels after the final Pad.
glow = comp.GlowNet(channels*8, k=3, l=3)
density_model = comp.Sequential(glow)
density_model.to(device)

# Concatenate the manifold and density components into the full model.
full_model = manifold_model + density_model

# Initialize the density model on the manifold latents of the first batch and
# run both directions once to surface runtime errors.
with torch.no_grad():
    init_z = density_model.initialize(init_mid_latent)
    density_model.data_to_latent(init_mid_latent, m)
    density_model.latent_to_data(init_z, m)

num_trainable = sum(
    weight.numel() for weight in full_model.parameters() if weight.requires_grad)
f'Parameters: {num_trainable}'

## Training

Sample some latents to show during training

In [None]:
# Fixed latents that are decoded repeatedly during training to monitor sample
# quality. They are drawn from a zero-mean normal whose standard deviation is
# the temperature `temp`.
temp = 0.75
num_samples = 8

# Per-sample latent shape, taken from the density model's inverse pass on the
# initialization batch (batch dimension dropped).
with torch.no_grad():
    latent_shape = density_model(init_mid_latent, inverse=True).shape[1:]

sample_shape = (num_samples, *latent_shape)
latent_mean = torch.zeros(sample_shape)
latent_std = torch.ones(sample_shape) * temp
latent_samples = torch.normal(mean=latent_mean, std=latent_std).to(device)

Schedule training

In [None]:
def schedule():
    '''Yields weights for density and reconstruction respectively

    The first 10 items are ``(0, 10000)``: a warm-up in which the density
    weight is zero. Before the 11th item is produced, the density model is
    re-initialized once on the manifold latents of ``sample_x``. After that
    the generator yields ``(0.001, 10000)`` forever, so a loop over it only
    ends when it is interrupted.
    '''
    for _ in range(10):
        yield 0, 10000

    # After manifold warmup, re-initialize density model
    with torch.no_grad():
        sample_mid_latent = manifold_model(sample_x, inverse=True)
        # Called for its side effect only; the returned latents are not needed.
        density_model.initialize(sample_mid_latent)

    while True:
        yield 0.001, 10000

Train the full model on a weighted sum of negative log-likelihood and reconstruction error

In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as opt
from nb_util import compare_batches, display_batch, update_displayed_batch

full_opt = opt.Adam(full_model.parameters(), lr=0.00001)
const = -(m/2) * np.log(2*np.pi) # Constant for log likelihood

# A checkpoint is written at the end of every epoch. Create the target
# directory up front so a finished epoch is not lost to a missing folder.
os.makedirs('models', exist_ok=True)

full_model.train()

# Initial visualizations. Gradients are not needed here, so the forward passes
# run under no_grad, matching the periodic refresh inside the training loop.
with torch.no_grad():
    gen_samples = full_model(latent_samples.to(device))
    sample_recons = manifold_model(manifold_model(sample_x, inverse=True))
fig1, display_id1 = display_batch(gen_samples)
fig2, display_id2 = compare_batches(sample_x, sample_recons)

# schedule() yields one (alpha, beta) pair per epoch and never terminates,
# so this loop runs until the cell is interrupted.
for epoch, (alpha, beta) in enumerate(schedule()):
    for batch, (image, _) in enumerate(train_loader):
        image = image.to(device)
        full_opt.zero_grad()

        # Compute reconstruction error
        # (gradients are only tracked when the reconstruction weight is non-zero)
        with torch.set_grad_enabled(beta > 0):
            mid_latent, _ = manifold_model.data_to_latent(image, m)
            _, manifold_log_det = manifold_model.latent_to_data(mid_latent, m)
            reconstruction = manifold_model(mid_latent)
            reconstruction_error = torch.mean((image - reconstruction)**2)

        # Compute log likelihood
        # (gradients are only tracked when the density weight is non-zero)
        with torch.set_grad_enabled(alpha > 0):
            z, density_log_det = density_model.data_to_latent(mid_latent, m)
            log_pz = const - torch.sum(z**2, axis=1)/2
            half_log_det = manifold_log_det + density_log_det
            log_likelihood = torch.mean(log_pz + half_log_det)

        # Training step
        loss = - alpha*log_likelihood + beta*reconstruction_error
        loss.backward()
        full_opt.step()

        # Display results (end='\r' keeps the status on a single, overwritten line)
        print(f'[E{epoch} B{batch}] | '
              f'loss: {loss: 6.2f} '
              f'| LL: {log_likelihood:6.2f} '
              f'| logp(z): {torch.mean(log_pz):6.2f} '
              f'| logdet: {torch.mean(half_log_det):6.2f}'
              f'| manifold logdet: {torch.mean(manifold_log_det):6.2f}'
              f'| density logdet: {torch.mean(density_log_det):6.2f}'
              f'| recon: {reconstruction_error:6.5f}', end='\r')
        if batch % 10 == 0:
            # Refresh the generated samples and reconstructions every 10 batches
            with torch.no_grad():
                gen_samples = full_model(latent_samples)
                sample_recons = manifold_model(manifold_model(sample_x, inverse=True))
            update_displayed_batch(gen_samples, fig1, display_id1)
            compare_batches(sample_x, sample_recons, fig2, display_id2)

    # One checkpoint per epoch, keyed by model name and epoch index
    torch.save(full_model.state_dict(), f'models/{model_name}-e{epoch}.pt')

Compute the mean reconstruction error on the test set

In [None]:
# Per-batch mean squared reconstruction errors, with the size of each batch.
# The final batch can be smaller than the rest, so the overall error is the
# batch-size-weighted average rather than a plain mean of batch means.
rec_errors = []
batch_sizes = []

for image, _ in test_loader:
    image = image.to(device)

    with torch.no_grad():
        # Project onto the manifold model's latent space and map back
        mid_latent = manifold_model(image, inverse=True)
        reconstruction = manifold_model(mid_latent)
        reconstruction_error = torch.mean((image - reconstruction)**2)

    rec_errors.append(float(reconstruction_error))
    batch_sizes.append(len(image))

# An empty loader yields NaN instead of a division error inside np.average
mean_rec_error = np.average(rec_errors, weights=batch_sizes) if rec_errors else float('nan')
f'Reconstruction error: {mean_rec_error}'

## Generate Images

In [None]:
from nb_util import generate_image_samples

# Epoch index of the training checkpoint to evaluate; the file must have been
# written by the training loop.
eval_epoch = 20

# map_location=device lets a checkpoint saved on a different device (for
# example a GPU) be loaded onto the device this notebook is running on.
state_dict = torch.load(f'models/{model_name}-e{eval_epoch}.pt', map_location=device)
full_model.load_state_dict(state_dict)

generate_image_samples(
    100, full_model, model_name, latent_shape=latent_shape, batch_size=16, temp=temp)

## Debug model

Check stats for all the parameters. Check for invertibility (the model should be left invertible but not necessarily right invertible).

In [None]:
model = full_model

with torch.no_grad():
    # Summary statistics for every trainable parameter, grouped by component
    for component in model.components:
        print(component.__class__.__name__)

        for parameter in component.parameters():
            if parameter.requires_grad:
                print(f'\tParam shape: {parameter.shape}')
                print(f'\t\tmin:  {torch.min(parameter):6.3f}')
                print(f'\t\tmax:  {torch.max(parameter):6.3f}')
                print(f'\t\tmean:  {torch.mean(parameter):6.3f}')
                print(f'\t\tnorm: {torch.linalg.norm(parameter):6.3f}')


    print('Invertibility check')
    # Both checks report the largest absolute deviation. Taking the max of the
    # signed difference would hide errors where the result falls below the target.
    sample_latent = model(sample_x, inverse=True)
    sample_roundtrip = model(sample_latent)

    # data -> latent -> data compared with the original data
    right_invertibility = torch.max(torch.abs(sample_roundtrip - sample_x))
    print(f'\tRight invertibility: {right_invertibility:6.5f}')

    # latent -> data -> latent compared with the starting latent
    left_invertibility = torch.max(torch.abs(model(sample_roundtrip, inverse=True)
                                             - sample_latent))
    print(f'\tLeft invertibility: {left_invertibility:6.5f}')