In [None]:
import math

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import copy

from torch.distributions.multivariate_normal import MultivariateNormal

import numpy as np 
import matplotlib.pyplot as plt  
%matplotlib inline

# Fix the seed of PyTorch's global random number generator.
torch.manual_seed(1337)

## Hyperparameters & Data Loading

In [None]:
# HYPERPARAMETERS
lr = 1e-4          # learning rate handed to the optimizer
batch_size = 128   # samples per minibatch (also the number of latents drawn when sampling)
num_epochs = 1000  # number of full passes over the training loader

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Anomaly detection is a debugging aid: it adds extra checks to every backward pass.
# Keep it off for real training runs and flip the flag only while hunting NaNs or
# in-place-modification errors.
detect_anomaly = False
torch.autograd.set_detect_anomaly(detect_anomaly)

In [None]:
# Per-image preprocessing pipeline, applied in order:
#   1. shrink the image to 8x8 to keep the flows cheap to train
#   2. convert to a tensor (ToTensor rescales pixel values to the [0, 1] range)
#   3. flatten the tensor into a single 1-D vector
preprocessing_steps = [
    transforms.Resize((8, 8)),
    transforms.ToTensor(),
    transforms.Lambda(torch.flatten),
]
transform = transforms.Compose(preprocessing_steps)

# MNIST training split; downloaded into ./data if it is not already there.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled minibatches; the last incomplete batch is dropped so every batch has batch_size rows.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

## Helper Classes

In [None]:
class MaskedLinear(nn.Linear):
    '''
    - linear layer whose weights are multiplied element-wise by a fixed binary mask
        - matrix = W_{ij}, where [i][j] represents the path from input unit j -> output unit i
        - mask[i][j] = 0 removes that path, mask[i][j] = 1 keeps it
    - inherits from nn.Linear to achieve this
    - ARGS
        - in_features, out_features, bias: same meaning as for nn.Linear
        - mask: optional [out_features, in_features] tensor of 0's and 1's
            - when omitted, defaults to a lower triangular matrix with 0's on the diagonal,
            i.e., mask_{ij} = {1 if j < i, 0 otherwise}
    - RAISES: ValueError if mask does not have shape [out_features, in_features]
    - the mask is stored as a non-persistent buffer, so it follows the module through
    .to(...) / .cuda() but does not add a key to the state_dict
    '''
    def __init__(self, in_features, out_features, bias=True, mask=None):
        super().__init__(in_features, out_features, bias)

        if mask is None:
            # diagonal = -1 makes it a lower tril with 0's on the diagonal
            mask = torch.tril(torch.ones(out_features, in_features), diagonal=-1)
        else:
            mask = torch.as_tensor(mask, dtype=self.weight.dtype).detach().clone()
            if tuple(mask.shape) != (out_features, in_features):
                raise ValueError(
                    f'mask must have shape ({out_features}, {in_features}), got {tuple(mask.shape)}'
                )

        self.register_buffer('mask', mask, persistent=False)

    def forward(self, x):
        # zero out the forbidden connections on every call so they can never be learned
        return F.linear(x, self.weight * self.mask, self.bias)


class MADE(nn.Module):
    '''
    - for MAF, an autoregressive layer (AR) = MADE, and they also use MADE to compute f_\alpha, f_\mu
    - one hidden layer: MaskedLinear -> ReLU -> MaskedLinear, input [BS, input_dim] -> [BS, output_dim]
    - the autoregressive property is enforced with the degree-based masks from the MADE paper
        - input unit j gets degree j + 1 (natural ordering, no resampling of the order)
        - hidden unit k gets degree (k mod (input_dim - 1)) + 1, i.e., cycles through 1..input_dim-1
        - output unit o gets degree (o mod input_dim) + 1, so when output_dim = 2 * input_dim and the
        output is split in half, both halves are indexed by the same dimension
        - input -> hidden connection kept iff hidden degree >= input degree
        - hidden -> output connection kept iff output degree > hidden degree
        - net effect: the outputs for dimension i only depend on inputs x_{<i}; the outputs for
        dimension 0 depend on no input at all (bias only)
    - a plain lower tril mask on both layers is NOT enough here: with output_dim > input_dim the
    extra outputs would be connected to hidden units that already see (almost) the whole input
    - RAISES: ValueError if input_dim < 1
    '''
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()

        if input_dim < 1:
            raise ValueError(f'input_dim must be >= 1, got {input_dim}')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        input_degrees = torch.arange(1, input_dim + 1)                       # [D]
        # max(..., 1) keeps the modulo well defined for the degenerate input_dim == 1 case
        hidden_degrees = torch.arange(hidden_dim) % max(input_dim - 1, 1) + 1  # [H]
        output_degrees = torch.arange(output_dim) % input_dim + 1            # [O]

        hidden_mask = (hidden_degrees[:, None] >= input_degrees[None, :]).float()  # [H, D]
        output_mask = (output_degrees[:, None] > hidden_degrees[None, :]).float()  # [O, H]

        self.net = nn.Sequential(
            MaskedLinear(in_features=input_dim, out_features=hidden_dim, mask=hidden_mask),
            nn.ReLU(),
            MaskedLinear(in_features=hidden_dim, out_features=output_dim, mask=output_mask)
        )

    def forward(self, x):
        return self.net(x) # [BS, output_dim]

In [None]:
class DiagonalScalingLayer(nn.Module):
    '''
    - learnable per-dimension rescaling, parameterised by log-scales s of shape [D,]
    - s starts at zero, so the layer is the identity map at initialisation
    - both directions return the transformed tensor and a scalar log-determinant
    '''
    def __init__(self, dim):
        super().__init__()
        initial_log_scale = torch.zeros(dim, device=device) # [D,]
        self.s = nn.Parameter(initial_log_scale)

    def forward(self, z):
        ''' f: Z -> X
        x = exp(s) * z  (element-wise, the diagonal matrix is stored as a vector)
        det(J) = \Pi_i exp(s_i)
        log|det(J)| = \Sigma_i s_i
        '''
        scale = torch.exp(self.s) # [D,], broadcast against [BS, D]
        return z * scale, self.s.sum()

    def inverse(self, x):
        ''' f^-1: X -> Z
        z = x / exp(s)
        det(J^-1) = \Pi_i exp(-s_i)
        log|det(J^-1)| = -\Sigma_i s_i
        '''
        scale = torch.exp(self.s)
        return x / scale, -self.s.sum()

## Autoregressive + Coupling Norm Flow Architectures

In [None]:
class MAF(nn.Module):
    '''
    - THIS IS A MODULE DEFINING ONE TRANSFORMATION
    - For MNIST, MAF stacks 5 transformations with a BatchNorm and order reversal after each one
    - forward(self, z): SAMPLING, f: Z -> X
        - I'm defining forward to be the pass through the Norm Flow model that maps f: Z -> X
        - for MAF, this is transforming z -> x done sequentially as alpha and mu depend on x
        - returns log_det here in order to compute the density of the generated sample x under p_X
    - inverse(self, x): TRAINING, f: X -> Z
        - I'm defining inverse to be the pass through the Norm Flow model that maps f: X -> Z
        - for MAF, this means transforming external or internal x -> z
        - z = f^-1(x) represents the inverted bijection function mapping the real data, x, to the
        z that generated it. Use change of vars formula to express learned dist. p_X in terms of
        this transformation, and get p_X ~ p_data by MLE under the real dataset
    - ARGS
        - dim: number of features D of x and z
        - hidden_dim: hidden width of the MADE conditioner
        - reverse_order: if True, the feature order is flipped on the X side of the layer
    - all intermediate tensors are created on the device of the input tensor
    '''
    def __init__(self, dim, hidden_dim, reverse_order=True):
        super().__init__()

        self.dim = dim
        self.hidden_dim = hidden_dim
        self.reverse_order = reverse_order

        # instead of having separate nets for alpha and mu,
        # one net outputs a 2*dim tensor that is split in half: first half mu, second half alpha
        self.f_params = MADE(dim, hidden_dim, dim*2)

    def forward(self, z):
        ''' f: Z -> X
        - for MAF, sampling is slow since each alpha_i and mu_i depend on x_{<i}
        - thus, done sequentially and need to go pixel by pixel (one conditioner pass per feature)
        - FORWARD EQUATION: x_i = z_i * exp(alpha_i) + mu_i
        - LOG DET OF f: \Sigma_{i} \alpha_{i}
        - RETURNS: x (where x = f(z)) and log_det of f
        '''
        x = torch.zeros_like(z) # [BS, D], features not generated yet stay at zero
        log_det = z.new_zeros(z.size(0)) # [BS,]

        for i in range(self.dim):
            params = self.f_params(x)
            mu, alpha = torch.chunk(params, 2, dim=1)

            x_i = z[:, i] * torch.exp(alpha[:, i]) + mu[:, i] # [BS,]
            # rebuild x instead of writing into it: x was just fed through the conditioner,
            # and an in-place write would invalidate the tensor autograd saved for backward
            x = torch.cat((x[:, :i], x_i.unsqueeze(1), x[:, i + 1:]), dim=1) # sequential buildup
            log_det = log_det + alpha[:, i] # [BS,]

        x = x.flip(dims=(1,)) if self.reverse_order else x # flip order after every AR layer

        return x, log_det # [BS, D], [BS,]

    def inverse(self, x):
        ''' f^-1: X -> Z
        - for MAF, alpha and mu depend on x, and the autoregressive property is enforced by MADE
        - thus, we can vectorize our computation of p_X for fast training done in parallel
        - this is done since we have all x, parallel-y computed mu and alpha, so can compute all of z
        - INVERSE EQUATION: z_i = (x_i - mu_i) * exp(-alpha_i) -> z = (x - mu) * exp(-alpha)
        - LOG DET OF f^-1: -(\Sigma_{i} \alpha_{i})
        - RETURNS: z (where z = f^-1(x)) and log_det of f^-1
        '''
        # undo the flip that forward applies as its last step
        x = x.flip(dims=(1,)) if self.reverse_order else x

        params = self.f_params(x)
        mu, alpha = torch.chunk(params, 2, dim=1)

        z = (x - mu) * torch.exp(-alpha)
        log_det = -torch.sum(alpha, dim=1) # [BS,]

        return z, log_det # [BS, D], [BS,]
        

In [None]:
class IAF(nn.Module):
    '''
    - THIS IS A MODULE DEFINING ONE TRANSFORMATION OF IAF
    - Core difference between IAF and MAF is that alpha and mu depend on z instead of x now
    - ARGS
        - dim: number of features D of x and z
        - hidden_dim: hidden width of the MADE conditioner
        - reverse_order: if True, the feature order is flipped on the X side of the layer
    - all intermediate tensors are created on the device of the input tensor
    '''
    def __init__(self, dim, hidden_dim, reverse_order=True):
        super().__init__()

        self.dim = dim
        self.hidden_dim = hidden_dim
        self.reverse_order = reverse_order

        # one net outputs a 2*dim tensor that is split in half: first half mu, second half alpha
        self.f_params = MADE(dim, hidden_dim, dim*2)

    def forward(self, z):
        ''' f: Z -> X
        - for IAF, sampling is fast since alpha and mu depend on z now, which we have
        - thus, we can compute samples quickly in parallel
        - FORWARD EQUATION: x = z * exp(alpha) + mu
        - LOG DET OF f: \Sigma_{i} \alpha_{i}
        - RETURNS: x (where x = f(z)) and log_det of f
        '''
        params = self.f_params(z)
        mu, alpha = torch.chunk(params, 2, dim=1)

        x = z * torch.exp(alpha) + mu
        log_det = torch.sum(alpha, dim=1)

        x = x.flip(dims=(1,)) if self.reverse_order else x # flip after transforming in fpass

        return x, log_det # [BS, D], [BS,]

    def inverse(self, x):
        ''' f^-1: X -> Z
        - for IAF, alpha and mu depend on z, and the autoregressive property is enforced by MADE
        - thus, computing p_X is slow now since we need to recover z sequentially
        - INVERSE EQUATION: z_i = (x_i - mu_i) * exp(-alpha_i)
        - LOG DET OF f^-1: -(\Sigma_{i} \alpha_{i})
        - RETURNS: z (where z = f^-1(x)) and log_det of f^-1
        '''
        # undo the flip that forward applies as its last step
        x = x.flip(dims=(1,)) if self.reverse_order else x

        z = torch.zeros_like(x) # [BS, D], features not recovered yet stay at zero
        log_det = x.new_zeros(x.size(0)) # [BS,]

        for i in range(self.dim):
            params = self.f_params(z)
            mu, alpha = torch.chunk(params, 2, dim=1)

            z_i = (x[:, i] - mu[:, i]) * torch.exp(-alpha[:, i]) # [BS,]
            # rebuild z instead of writing into it: z was just fed through the conditioner,
            # and an in-place write makes loss.backward() fail during training
            z = torch.cat((z[:, :i], z_i.unsqueeze(1), z[:, i + 1:]), dim=1)
            log_det = log_det - alpha[:, i] # [BS,]

        return z, log_det # [BS, D], [BS,]

In [None]:
class NICE(nn.Module):
    '''
    - this is just for the coupling layer implementation, diagonal scaling matrix implemented later
    - alternate partitioning handled in upper vs lower halves via reverse_order, not evens vs odds
    - ARGS
        - dim: number of features D; the first D // 2 features pass through unchanged and the
        remaining D - D // 2 features are shifted, so odd D is supported as well
        - hidden_dim: hidden width of the coupling network
        - reverse_order: if True, the feature order is flipped on the X side of the layer
    - log_det tensors are created on the device (and with the dtype) of the input tensor
    '''
    def __init__(self, dim, hidden_dim, reverse_order=True):
        super().__init__()

        self.d = dim // 2
        self.reverse_order = reverse_order

        # maps the untouched partition (size d) to a shift for the other partition (size dim - d)
        self.net = nn.Sequential(
            nn.Linear(self.d, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim - self.d)
        )

    def forward(self, z):
        ''' f: Z -> X
        - log det of forward pass is actually equal to 0 since
            - J = [[I, 0], [df/dz1, I]], identities on the diagonal and 0 in the upper right block
            - det(J) = 1
            - log(det(J)) = log(1) = 0
        - RETURNS: x (where x = f(z)) and log_det of f (all zeros)
        '''
        z1, z2 = z[:, :self.d], z[:, self.d:] # z1 = z1:d, z2 = zd+1:D

        x1, x2 = z1, z2 + self.net(z1) # x1 = unchanged z1, x2 = additive coupling w transformed z1
        x = torch.cat((x1, x2), dim=1) # [BS, D]

        x = x.flip(dims=(1,)) if self.reverse_order else x

        log_det = x.new_zeros(x.size(0)) # [BS,]
        return x, log_det # [BS, D], [BS,]

    def inverse(self, x):
        ''' f^-1: X -> Z
        - inv log det is also 0
        - RETURNS: z (where z = f^-1(x)) and log_det of f^-1 (all zeros)
        '''
        # undo the flip that forward applies as its last step
        x = x.flip(dims=(1,)) if self.reverse_order else x

        x1, x2 = x[:, :self.d], x[:, self.d:] # x1 = x1:d, x2 = xd+1:D

        z1 = x1 # z1 = unchanged x1
        z2 = x2 - self.net(z1) # z2 = depends on z1, additive coupling w x2 and transformed z1
        z = torch.cat((z1, z2), dim=1) # [BS, D]

        log_det = x.new_zeros(x.size(0)) # [BS,]
        return z, log_det # [BS, D], [BS,]

## Normalizing Flow Model

In [None]:
class NormFlow(nn.Module):
    '''
    - For MNIST
        - MAF stacks 5 transformations with a BatchNorm and order reversal after each one
        - NICE uses 4 coupling layers with no BatchNorm and evens/odds, but I'm just keeping
        the MAF params (5 layers, order reversal) and adding the diag scaling layer
    - To get final log_det, just add up the log_dets from each transformation layer
    - To get final inverse, keep passing x through inverse methods of layers in reverse order
    - I didn't implement batch norm as I didn't want to implement an inverse batch norm
    - ARGS
        - dim: number of features D
        - hidden_dim: hidden width handed to every transformation layer
        - model_name: one of 'maf', 'iaf', 'nice'
        - num_layers: number of MAF / IAF / NICE transformations to stack; for 'nice' one
        DiagonalScalingLayer is appended on top of these
    - RAISES: ValueError if model_name is not one of the supported names
    - forward and inverse walk over EVERY module in self.layers, so the scaling layer that
    'nice' appends is part of the flow (and receives gradients) as well
    '''
    def __init__(self, dim, hidden_dim, model_name, num_layers=5):
        super().__init__()

        self.dim = dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        if model_name == 'maf':
            layer_cls = MAF
        elif model_name == 'iaf':
            layer_cls = IAF
        elif model_name == 'nice':
            layer_cls = NICE
        else:
            raise ValueError(
                f"unknown model_name {model_name!r}; expected one of 'maf', 'iaf', 'nice'"
            )

        self.layers = nn.ModuleList(
            [layer_cls(dim, hidden_dim, reverse_order=True) for _ in range(num_layers)]
        )

        if model_name == 'nice':
            self.layers.append(DiagonalScalingLayer(dim))

    def forward(self, z):
        ''' f: Z -> X = SAMPLING
        - RETURNS: x and the log_det of the whole transformation
        '''
        log_det_total = z.new_zeros(z.size(0)) # [BS,]

        for layer in self.layers:
            z, log_det = layer(z) # output of one layer passed into next
            # flipping handled internally after each layer
            # a layer may return a scalar log_det (diag scaling); it broadcasts over the batch
            log_det_total = log_det_total + log_det

        # this final z after being transformed through all the layers = x
        return z, log_det_total # [BS, D], [BS,]

    def inverse(self, x):
        ''' f^-1: X -> Z = TRAINING
        - RETURNS: z and the log_det of the whole inverse transformation
        '''
        log_det_total = x.new_zeros(x.size(0)) # [BS,]

        # undo the layers last-to-first
        for layer in reversed(list(self.layers)):
            x, log_det = layer.inverse(x)
            log_det_total = log_det_total + log_det

        # this final x after being transformed through all the layers = z
        return x, log_det_total # [BS, D], [BS,]

## Training & Sampling

In [None]:
def view_samples(model, input_dim):
    '''
    - draws batch_size latents from a standard normal, pushes them through the model
    (Z -> X) and plots the first 10 results as square grayscale images in a 2x5 grid
    - model: called as model(z) and expected to return (x, log_det)
    - input_dim: number of features per sample; images have side length int(sqrt(input_dim))
    - switches the model to eval mode and does not switch it back
    '''
    model.eval()

    side = int(np.sqrt(input_dim))
    latents = torch.randn((batch_size, input_dim), device=device) # [BS, D]

    # sampling only, no gradients needed
    with torch.no_grad():
        samples, _ = model(latents) # [BS, D]

    images = samples.view(-1, side, side).cpu().numpy() # [BS, side, side]

    fig, axes = plt.subplots(2, 5, figsize=(10, 4))  # Adjust figsize as needed
    for idx, ax in enumerate(axes.ravel()):
        ax.imshow(images[idx], cmap='gray')
        ax.axis('off')
    plt.show()

In [None]:
def trainNormFlow(model, train_loader, optimizer, num_epochs, input_dim):
    '''
    - maximum-likelihood training of a flow through its inverse (X -> Z) direction
    - model: must expose inverse(x) returning (z, log_det) with log_det of f^-1
    - train_loader: yields (x, label) batches; the labels are ignored
    - optimizer: optimizer over the model parameters
    - num_epochs: number of passes over train_loader
    - input_dim: number of features per sample, D
    - prints the mean batch loss after every epoch and shows samples every 10th epoch
    (epochs 10, 20, ... counted from 0)
    '''
    # base distribution p_Z = N(0, I)
    base_dist = MultivariateNormal(
        loc=torch.zeros(input_dim, device=device),
        covariance_matrix=torch.eye(input_dim, device=device),
    )

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        for x, _ in train_loader: # x = [BS, D]
            x = x.to(device)

            optimizer.zero_grad()

            # change of variables with z = f^-1(x):
            #   p_X(x) = p_Z(z) * |det(df^-1/dx)|
            #   log p_X(x) = log p_Z(z) + log|det(df^-1/dx)|
            z, log_det = model.inverse(x) # [BS, D], [BS,]

            # minimizing KL(p_data || p_X) over p_X is the same as minimizing the
            # negative log-likelihood -log p_X(x) averaged over the real data
            nll = -(base_dist.log_prob(z) + log_det).mean() # scalar, batch average

            nll.backward()
            optimizer.step()

            train_loss += nll.item()

        print(f"Epoch {epoch+1}, Loss: {train_loss / len(train_loader)}")
        if epoch > 0 and epoch % 10 == 0:
            view_samples(model, input_dim)

In [None]:
input_dim = 64    # features per sample; matches the flattened 8x8 images
hidden_dim = 128  # hidden width used inside every flow layer
num_layers = 5    # transformations stacked per flow

# one trained flow per architecture, filled in by the loop below
models_q1 = {'maf': None, 'nice': None}

for model_name in list(models_q1):
    # build the flow on the target device, train it, then keep the trained model
    model = NormFlow(input_dim, hidden_dim, model_name, num_layers).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    trainNormFlow(model, train_loader, optimizer, num_epochs, input_dim)

    models_q1[model_name] = model

In [None]:
# Show a grid of samples from each trained flow.
for trained_name, trained_model in models_q1.items():
    print(f'Results after {num_epochs} epochs of training for {trained_name}')
    view_samples(trained_model, input_dim)