## Config

Configure the notebook (autoreload), the torch device, data paths, and the model name

In [None]:
# Reload edited project modules automatically so changes in nb_util / flow.* are picked up
%load_ext autoreload
%autoreload 2
from nb_util import data_path, device

# Prefix used for checkpoint files under models/ and for generated sample output
model_name = 'celeba-small-glow'

## Load data

In [None]:
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
import flow.data as data

batch_size_manifold = 32
batch_size_density = 32
channels = 3
height = 64
width = 64
num_workers = 30  # DataLoader worker processes, shared by all loaders below

# Resize using the configured image size instead of repeating the literal,
# so height/width above are the single source of truth.
transform = transforms.Compose([
    transforms.Resize((height, width)),
    transforms.ToTensor(),
])

train_data = data.CelebA(root=data_path, split='train', transform=transform)
test_data = data.CelebA(root=data_path, split='test', transform=transform)

# Two independently shuffled views of the training set, one per training phase.
manifold_loader = DataLoader(train_data, batch_size=batch_size_manifold, shuffle=True, num_workers=num_workers)
density_loader = DataLoader(train_data, batch_size=batch_size_density, shuffle=True, num_workers=num_workers)
# Evaluation only: a fixed iteration order makes test metrics reproducible.
test_loader = DataLoader(test_data, batch_size=batch_size_manifold, shuffle=False, num_workers=num_workers)

## Manifold model

Construct the manifold-learning component and initialize its weights on a first batch of training images

In [None]:
import torch
import flow.components as comp

# Manifold component. GlowNet(..., out_shape=(8, 8)) is presumably reshaping the
# 64x64 input to an 8x8 spatial grid with channels*64 channels, and
# Pad(channels*64, channels*8) presumably maps between that and a channels*8-channel
# manifold latent — confirm both in flow.components.
manifold_model = comp.Sequential(
    comp.GlowNet(channels, k=3, l=2, additive_coupling=True, hidden_size=8, out_shape=(8, 8)),
    comp.Invertible1x1Conv(channels*64),
    comp.Pad(channels*64, channels*8),
)
manifold_model.to(device)

# Check for runtime errors and initialize weights with the first batch.
# Slice before moving to the device so only the num_recons images that are
# used for initialization/previews are copied to (and kept on) the device.
num_recons = 8
sample_x = next(iter(manifold_loader))[0][:num_recons].to(device)
with torch.no_grad():
    sample_mid_latent = manifold_model.initialize(sample_x)

# Per-sample dimension of the latent space; divide by the actual number of
# samples in case the first batch held fewer than num_recons images.
m = sample_mid_latent.numel() // sample_x.shape[0]
with torch.no_grad():
    manifold_model.data_to_latent(sample_x, m)
    manifold_model.latent_to_data(sample_mid_latent, m)

f'Parameters: {sum(w.numel() for w in manifold_model.parameters() if w.requires_grad)}'

Train the manifold model with a mean-squared reconstruction loss

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as opt
from nb_util import compare_batches

epochs = 10
manifold_opt = opt.Adam(manifold_model.parameters(), lr=0.0001)
manifold_model.train()

# Initial preview of reconstructions. Run under no_grad (like the in-loop
# preview below) so no autograd graph is built and kept alive by the globals.
with torch.no_grad():
    sample_mid_latent = manifold_model(sample_x, inverse=True)
    sample_recons = manifold_model(sample_mid_latent)
fig, display_id = compare_batches(sample_x, sample_recons)


for epoch in range(epochs):
    for batch, (image, _) in enumerate(manifold_loader):
        image = image.to(device)

        # Compute reconstruction error: data -> mid latent (inverse=True) -> data
        manifold_opt.zero_grad()

        mid_latent = manifold_model(image, inverse=True)
        reconstruction = manifold_model(mid_latent)
        reconstruction_error = torch.mean((image - reconstruction)**2)

        # Training step
        loss = reconstruction_error
        loss.backward()
        manifold_opt.step()

        # Display results
        print(f'[E{epoch} B{batch}] | Reconstruction: {reconstruction_error:6.5f}', end='\r')
        if batch % 20 == 0:
            with torch.no_grad():
                sample_mid_latent = manifold_model(sample_x, inverse=True)
                sample_recons = manifold_model(sample_mid_latent)
            compare_batches(sample_x, sample_recons, fig, display_id)

Compute the mean reconstruction error on the test set

In [None]:
# Mean squared reconstruction error over the whole test set.
rec_errors = []    # per-batch mean squared error
batch_sizes = []   # images per batch, used as averaging weights

with torch.no_grad():
    for image, _ in test_loader:
        image = image.to(device)
        mid_latent = manifold_model(image, inverse=True)
        reconstruction = manifold_model(mid_latent)
        reconstruction_error = torch.mean((image - reconstruction)**2)

        rec_errors.append(float(reconstruction_error))
        batch_sizes.append(image.shape[0])

# Weight each batch mean by its size: an unweighted mean of batch means would
# over-weight a smaller final batch.
f'Reconstruction error: {np.average(rec_errors, weights=batch_sizes)}'

Save the manifold model weights

In [None]:
import os

# torch.save does not create parent directories; make sure models/ exists.
os.makedirs('models', exist_ok=True)
torch.save(manifold_model.state_dict(), f'models/{model_name}-manifold.pt')

## Density model

Construct the density-learning component, which models the manifold model's latent space (the two models are concatenated later, before generating images)

In [None]:
# Density component: a Glow flow over the manifold latent, whose channel count
# (channels*8) matches the output width of the manifold model's Pad layer.
density_model = comp.Sequential(comp.GlowNet(channels*8, k=3, l=3))
density_model.to(device)

# Initialize the density weights on the first batch's mid latents, then run a
# data -> latent -> data round trip to surface shape/runtime errors early.
with torch.no_grad():
    sample_mid_latent = manifold_model(sample_x, inverse=True)
    sample_z = density_model.initialize(sample_mid_latent)
    density_model.data_to_latent(sample_mid_latent, m)
    density_model.latent_to_data(sample_z, m)

f'Parameters: {sum(p.numel() for p in density_model.parameters() if p.requires_grad)}'

Sample a fixed batch of latents to visualize during training

In [None]:
# Generate 16 samples with reduced temperature for visualization.
# A dedicated, seeded generator makes the preview latents reproducible across
# runs without disturbing the global RNG (used e.g. for data shuffling).
temp = 0.75
num_samples = 16
preview_seed = 0

with torch.no_grad():
    latent_shape = density_model(sample_mid_latent, inverse=True).shape[1:]

# N(0, temp^2) samples: standard normal draws scaled by the temperature.
preview_rng = torch.Generator().manual_seed(preview_seed)
latent_samples = torch.randn(num_samples, *latent_shape, generator=preview_rng) * temp
latent_samples = latent_samples.to(device)

Train the density model by maximizing log-likelihood, keeping the manifold model fixed

In [None]:
import os
import torch.nn.functional as f
from nb_util import display_batch, update_displayed_batch

epochs = 300
density_opt = opt.Adam(density_model.parameters(), lr=0.00001)

# Log-normalizer of an m-dimensional standard Gaussian
const = -(m/2) * np.log(2*np.pi) # Constant for log likelihood

# The manifold model is fixed in this phase; only density_model is optimized.
manifold_model.eval()
density_model.train()

# Initial preview; no_grad so no autograd graph is built (matches the in-loop preview).
with torch.no_grad():
    mid_latent = density_model(latent_samples)
    gen_samples = manifold_model(mid_latent)
fig, display_id = display_batch(gen_samples)

# Per-epoch checkpoints are written to models/; torch.save won't create it.
os.makedirs('models', exist_ok=True)

for epoch in range(epochs):
    for batch, (image, _) in enumerate(density_loader):
        image = image.to(device)
        density_opt.zero_grad()

        # Compute log likelihood of the (frozen) manifold latent under the density model
        with torch.no_grad():
            mid_latent, _ = manifold_model.data_to_latent(image, m)
        z, density_log_det = density_model.data_to_latent(mid_latent, m)
        # assumes z is flattened to (batch, m) so axis=1 covers the whole latent — TODO confirm
        log_pz = const - torch.sum(z**2, axis=1)/2
        half_log_det = density_log_det
        log_likelihood = torch.mean(log_pz + half_log_det)

        # Training step
        loss = -log_likelihood
        loss.backward()
        density_opt.step()

        # Display results
        print(f'[E{epoch} B{batch}] | Log-likelihood: {log_likelihood:6.2f} '
              f'| Logp(z): {torch.mean(log_pz):6.2f} '
              f'| Log det: {torch.mean(half_log_det):6.2f}', end='\r')
        if batch % 10 == 0:
            with torch.no_grad():
                mid_latent = density_model(latent_samples)
                gen_samples = manifold_model(mid_latent)
            update_displayed_batch(gen_samples, fig, display_id)

    torch.save(density_model.state_dict(), f'models/{model_name}-density-e{epoch}.pt')

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as opt

# Held-out split. Order does not matter for the (weighted) average, so the
# loader is not shuffled: evaluation iterates deterministically.
test_data = data.CelebA(root=data_path, split='test', transform=transform)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=30)


likelihoods = []   # per-batch mean log-likelihood
batch_sizes = []   # images per batch, used as averaging weights

with torch.no_grad():
    for image, _ in test_loader:
        image = image.to(device)

        # Compute log likelihood: manifold log-det term plus density term
        mid_latent, _ = manifold_model.data_to_latent(image, m)
        _, manifold_log_det = manifold_model.latent_to_data(mid_latent, m)
        z, density_log_det = density_model.data_to_latent(mid_latent, m)
        log_pz = const - torch.sum(z**2, axis=1)/2
        half_log_det = manifold_log_det + density_log_det
        log_likelihood = torch.mean(log_pz + half_log_det)

        likelihoods.append(float(log_likelihood))
        batch_sizes.append(image.shape[0])

# Weight each batch mean by its size so a smaller final batch does not skew the result
np.average(likelihoods, weights=batch_sizes)

In [None]:
# Decode the fixed preview latents: density latent -> manifold latent -> image.
with torch.no_grad():
    mid_latent = density_model(latent_samples)
    gen_samples = manifold_model(mid_latent)

fig, display_id = display_batch(gen_samples)

## Generate Images

In [None]:
from nb_util import generate_image_samples

# Restore the saved manifold weights and the density checkpoint of eval_epoch
# (the density training loop writes one checkpoint per epoch). map_location
# places the tensors on the configured device regardless of where they were saved.
eval_epoch = 100
manifold_model.load_state_dict(torch.load(f'models/{model_name}-manifold.pt', map_location=device))
density_model.load_state_dict(torch.load(f'models/{model_name}-density-e{eval_epoch}.pt', map_location=device))
full_model = manifold_model + density_model

# Generate 30000 samples at the preview temperature; presumably written out under
# a name derived from model_name + '2' — see nb_util.generate_image_samples.
generate_image_samples(
    30000, full_model, model_name + '2', latent_shape=latent_shape, batch_size=16, temp=temp)

In [None]:
ls data/celeba/list_eval_partition.csv

In [None]:
import pandas as pd

# CelebA split assignment file; has a 'partition' column (presumably
# 0/1/2 = train/valid/test per the CelebA release — confirm against the file).
df = pd.read_csv(filepath_or_buffer='data/celeba/list_eval_partition.csv')

In [None]:
# Rows assigned to partition 2 (presumably the test split — confirm).
df.loc[df['partition'] == 2]

## Debug model

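Print summary statistics for every trainable parameter, then check invertibility. The model should be left invertible: latent → data → latent recovers the latent. It is not necessarily right invertible: data → latent → data need not reproduce images that lie off the learned manifold.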

In [None]:
model = full_model

with torch.no_grad():
    # Summary statistics for every trainable parameter, grouped by component
    for component in model.components:
        print(component.__class__.__name__)

        for parameter in component.parameters():
            if parameter.requires_grad:
                print(f'\tParam shape: {parameter.shape}')
                print(f'\t\tmin:  {torch.min(parameter):6.3f}')
                print(f'\t\tmax:  {torch.max(parameter):6.3f}')
                print(f'\t\tmean:  {torch.mean(parameter):6.3f}')
                print(f'\t\tnorm: {torch.linalg.norm(parameter):6.3f}')


    # Errors are reported as the max ABSOLUTE deviation: a signed max can hide
    # large negative errors and can even be negative itself.
    print('Invertibility check')
    debug_latent = model(sample_x, inverse=True)  # data -> latent, reused by both checks

    # Right inverse: data -> latent -> data (need not hold off the learned manifold)
    right_invertibility = torch.max(torch.abs(model(debug_latent) - sample_x))
    print(f'\tRight invertibility: {right_invertibility:6.5f}')

    # Left inverse: latent -> data -> latent (the model should satisfy this)
    left_invertibility = torch.max(torch.abs(model(model(debug_latent), inverse=True)
                                             - debug_latent))
    print(f'\tLeft invertibility: {left_invertibility:6.5f}')