### Walkthrough of Creating a Basic ARM



In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
# Digits dataset
from sklearn.datasets import load_digits
from sklearn import datasets
# Pytorch
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

from pytorch_model_summary import summary

#### Dataset

**Scipy Digits:**
- $N \approx 1500$
- Each image is $8 \times 8$
- $\mathcal{X} = \{0,1,2, \dots 16\}$

Defining the dataset class:

In [2]:
class Digits(Dataset):
    """
    SciKit Learn Digits dataset wrapper for PyTorch
    """
    def __init__(self, mode='train', transforms=None):
        
        self.digits = load_digits()
        
        if mode=='train':
            self.data = self.digits.data[:1000].astype(np.float32)
        elif mode=='val':
            self.data = self.digits.data[1000:1350].astype(np.float32)
        else:
            self.data = self.digits.data[1350:].astype(np.float32)
        
        
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transforms:
            sample = self.transforms(sample)
        return sample

### Building the ARM
#### Causal Convolutional Layers

In [3]:
class CausalConv1d(nn.Module):
    """
    Causal 1D convolution layer
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            dilation, # Dilation factor
            A=False, # NOTE: Determines whether to use "A" mode or not
            **kwargs
    ):
        super(CausalConv1d, self).__init__()
        
        # Basic attributes
        self.A = A
        self.kernel_size = kernel_size 
        self.padding = (kernel_size - 1) * dilation + A*1 # NOTE: 1 added for "A" mode (initial layer)
        self.dilation = dilation


        self.conv1d = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=dilation,
            **kwargs 
        )

    def forward(self, x):
        x = nn.functional.pad(x, (self.padding, 0))
        conv1d_out = self.conv1d(x)
        
        if self.A:
            return conv1d_out[:, :, :-1]
        else:
            return conv1d_out

#### Log Categorical

In [4]:
EPS = 1.e-5

def log_categorical(
    x, 
    p, 
    num_classes=256, # Number of pixel values
    reduction=None, 
    dim=None,
):
    """
    Log categorical distribution
    """
    x_one_hot = F.one_hot(x.long(), num_classes=num_classes) # Basically implements the iverson bracket
    log_p = x_one_hot * torch.log(torch.clamp(p, EPS, 1. - EPS))
    if reduction == 'sum':
        return torch.sum(log_p, dim)
    elif reduction == 'avg':
        return torch.mean(log_p, dim)
    else:
        return log_p 



#### ARM!

In [5]:
class ARM(nn.Module):
    def __init__(self, net, D=2, num_vals=256):
        super(ARM, self).__init__()
        self.net = net
        self.D = D
        self.num_vals = num_vals
        
    def f(self, x):
        
        h = self.net(x.unsqueeze(1))
        h = h.permute(0, 2, 1)
        p = torch.softmax(h, 2)
        
        return p

    def forward(self, x, reduction='avg'):
        if reduction == 'avg':
            return -(self.log_prob(x).mean())
        elif reduction == 'sum':
            return -(self.log_prob(x).sum())
        else:
            raise NotImplementedError
        
    def log_prob(self, x):
        mu_d = self.f(x)
        log_p = log_categorical(x, mu_d, num_classes=self.num_vals, reduction='sum', dim=2).sum(-1)

        return log_p
    
    def sample(self, batch_size):
        x_new = torch.zeros(batch_size, self.D)
        for d in range(self.D):
            p = self.f(x_new)
            x_new_d = torch.multinomial(p[:, d, :], num_samples=1)
            x_new[:, d] = x_new_d.squeeze() # NOTE: Squeeze to remove the last dimension
            
        return x_new


### Auxiliary Functions

#### Training, Evaluation

In [6]:
def evaluation(
        val_loader,
        name = None,
        model_best = None,
        epoch = None,
):
    if model_best is None:
        print(f"Loading model from: {name + '.model'}")
        model_best = torch.load(name + '.model')
    
    loss = 0
    N = 0

    for _, test_batch in enumerate(val_loader):
        
        loss += model_best.forward(test_batch, reduction='sum').item()
        N += test_batch.size(0)

    loss /= N

    if epoch is None:
        print('Test Loss: {:.4f}'.format(loss))
    else:
        print('Epoch: {}, Val Loss: {:.4f}'.format(epoch, loss))
    
    return loss

def samples_real(name, test_loader):

    num_x, num_y = 4, 4
    x = next(iter(test_loader)).detach().numpy()

    fig, ax = plt.subplots(num_x, num_y, figsize=(10, 10))
    for i, ax in enumerate(ax.flatten()):
        
        ax.imshow(x[i].reshape(8, 8), cmap='gray')
        ax.axis('off')
    
    plt.savefig(name + '_real_images.png', bbox_inches='tight')
    plt.close()

def samples_generated(name, data_loader, extra_name=''):
    
    x = next(iter(data_loader)).detach().numpy()

    # Generate
    model_best = torch.load(name + '.model', weights_only=False)
    model_best.eval()

    num_x, num_y = 4, 4
    x = model_best.sample(num_x*num_y).detach().numpy()

    _, ax = plt.subplots(num_x, num_y, figsize=(10, 10))
    for i, ax in enumerate(ax.flatten()):
        
        ax.imshow(x[i].reshape(8, 8), cmap='gray')
        ax.axis('off')

    plt.savefig(name + '_generated_images' + extra_name + '.png', bbox_inches='tight')
    plt.close()

def plot_curve(name, train_loss, val_loss):
    
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    ax.plot(train_loss, label='Train')
    ax.plot(val_loss, label='Val')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    plt.savefig(name + '_loss_curve.png', bbox_inches='tight')
    plt.close()

### Training Function

In [7]:
def training(
        name, 
        max_patience,
        num_epochs,
        model,
        optimizer,
        train_loader,
        val_loader,
):
    val_loss = []
    train_loss = []
    best_val_loss = 1000
    patience = 0

    for epoch in range(num_epochs):

        # Training
        
        model.train()
        train_loss_ep = 0
        N = 0

        for _, train_batch in enumerate(train_loader):

            if hasattr(model, 'dequantization'):
                if model.dequantization:
                    batch = batch + torch.rand_like(batch)
                

            loss = model.forward(train_batch)            
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()
            
            train_loss_ep += loss.item()
            N += train_batch.size(0)


        train_loss_ep /= N
        train_loss.append(train_loss_ep)

        # Validation
        val_loss_ep = evaluation(
            val_loader,
            name=name,
            model_best=model,
            epoch=epoch
        )
        val_loss.append(val_loss_ep)
        print('Epoch: {}, Train Loss: {:.4f}, Val Loss: {:.4f}'.format(epoch, train_loss_ep, val_loss_ep))

        if epoch == 0:
            print('saved')
            torch.save(model, name + '.model')
            best_val_loss = val_loss_ep
        else:  
            if val_loss_ep < best_val_loss:
                print('saved')
                best_val_loss = val_loss_ep
                patience = 0
                torch.save(model, name + '.model')
                samples_generated(name, val_loader, extra_name='_epoch_' + str(epoch))
            else:
                patience += 1

        if patience > max_patience:
            break

    val_loss = np.array(val_loss)
    train_loss = np.array(train_loss)
    
    return train_loss, val_loss


### Initialize Dataloaders

In [8]:
train_data = Digits(mode='train')
val_data = Digits(mode='val')
test_data = Digits(mode='test')

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

results_dir = 'results/'
os.makedirs(results_dir, exist_ok=True)
name = 'arm'

In [9]:
train_data

<__main__.Digits at 0x1203e7cb0>

#### Hyperparameters

In [10]:
D = 64
M = 256

lr = 1e-3
num_epochs = 100
max_patience = 20


### Initialize ARM

In [11]:
likelihood_type = 'categorical'
num_classes = 17
kernel = 7

net = nn.Sequential(
    CausalConv1d(in_channels=1, out_channels=M, dilation=1, kernel_size=kernel, A=True, bias=True), # <- NOTE: A=True for input layer
    nn.LeakyReLU(),
    CausalConv1d(in_channels=M, out_channels=M, dilation=1, kernel_size=kernel, A=False, bias=True),
    nn.LeakyReLU(),
    CausalConv1d(in_channels=M, out_channels=M, dilation=1, kernel_size=kernel, A=False, bias=True),
    nn.LeakyReLU(),
    CausalConv1d(in_channels=M, out_channels=num_classes, dilation=1, kernel_size=kernel, A=False, bias=True)
)

model = ARM(net, D=D, num_vals=num_classes)
print(summary(model, torch.zeros((1, 64)), show_input=False, show_hierarchical=False))


-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
    CausalConv1d-1        [1, 256, 64]           2,048           2,048
       LeakyReLU-2        [1, 256, 64]               0               0
    CausalConv1d-3        [1, 256, 64]         459,008         459,008
       LeakyReLU-4        [1, 256, 64]               0               0
    CausalConv1d-5        [1, 256, 64]         459,008         459,008
       LeakyReLU-6        [1, 256, 64]               0               0
    CausalConv1d-7         [1, 17, 64]          30,481          30,481
Total params: 950,545
Trainable params: 950,545
Non-trainable params: 0
-----------------------------------------------------------------------


### Training

In [12]:
optimizer = torch.optim.Adamax([p for p in model.parameters() if p.requires_grad == True], lr=lr)

In [13]:
result_dir = 'results/'
os.makedirs(result_dir, exist_ok=True)
train_loss, val_loss = training(
    name=result_dir + name, 
    max_patience=max_patience, 
    num_epochs=num_epochs, 
    model=model, 
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    )

Epoch: 0, Val Loss: 117.9166
Epoch: 0, Train Loss: 2.2520, Val Loss: 117.9166
saved
Epoch: 1, Val Loss: 112.4990
Epoch: 1, Train Loss: 1.8165, Val Loss: 112.4990
saved
Epoch: 2, Val Loss: 110.3184
Epoch: 2, Train Loss: 1.7582, Val Loss: 110.3184
saved
Epoch: 3, Val Loss: 108.7038
Epoch: 3, Train Loss: 1.7247, Val Loss: 108.7038
saved
Epoch: 4, Val Loss: 107.3021
Epoch: 4, Train Loss: 1.7001, Val Loss: 107.3021
saved
Epoch: 5, Val Loss: 105.7430
Epoch: 5, Train Loss: 1.6746, Val Loss: 105.7430
saved
Epoch: 6, Val Loss: 103.9772
Epoch: 6, Train Loss: 1.6456, Val Loss: 103.9772
saved
Epoch: 7, Val Loss: 102.0827
Epoch: 7, Train Loss: 1.6142, Val Loss: 102.0827
saved
Epoch: 8, Val Loss: 100.2063
Epoch: 8, Train Loss: 1.5767, Val Loss: 100.2063
saved
Epoch: 9, Val Loss: 98.2842
Epoch: 9, Train Loss: 1.5448, Val Loss: 98.2842
saved
Epoch: 10, Val Loss: 97.1546
Epoch: 10, Train Loss: 1.5166, Val Loss: 97.1546
saved
Epoch: 11, Val Loss: 96.0614
Epoch: 11, Train Loss: 1.4926, Val Loss: 96.0614
