This code is adapted from one of the [PyTorch Lightning examples](https://colab.research.google.com/drive/1Tr9dYlwBKk6-LgLKGO8KYZULnguVA992?usp=sharing#scrollTo=CxXtBfFrKYgA) and [this PonderNet implementation](https://nn.labml.ai/adaptive_computation/ponder_net/index.html).

# PonderNet: complexity in MNIST
[PonderNet](https://arxiv.org/pdf/2107.05407.pdf) is a new architecture that promises to adapt its computational budget to the complexity of the task at hand. In the original paper, the authors deal with problems that are either too simplistic (guessing the parity of the number of "1"s in a vector) or too obscure to draw meaningful conclusions from. Although their results show that PonderNet gives harder problems more computational resources than easier ones, it is hard to say whether this holds beyond the selected toy examples.

In order to test the validity of their argument, we train a version of PonderNet on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). Our solution uses an augmented CNN network that embeds the image so it can be adapted to the PonderNet framework, which requires the incorporation of a hidden state.

We propose two tasks, an _interpolation_ and an _extrapolation_ one, which are akin to the ones found in the parity experiments from the original paper. Our interpolation task consists of simply learning to correctly classify the unaltered MNIST dataset. Our extrapolation task, on the other hand, consists of training PonderNet on slightly rotated images and testing it on considerably rotated images; this mirrors how the parity extrapolation task is trained on vectors of size 1-48 and tested on sizes between 48-96.

Our results show that PonderNet is able to solve the interpolation task but struggles with the extrapolation task. This is expected since pronounced rotations can render an image impossible to classify. In terms of the expected number of steps, the extrapolation task uses fewer steps the more complex it is; this is a counter-intuitive result from which it is hard to draw meaningful conclusions.

# Setup and imports

We use `PyTorch Lightning` (wrapping `PyTorch`) as our main framework and `wandb` to track and log the experiments. We set all seeds through `PyTorch Lightning`'s dedicated function.

In [None]:
!pip install wandb
!pip install pytorch-lightning

In [None]:
# import Libraries

# torch imports
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
import torch.nn.functional as F
import torchmetrics

# pl imports
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

# remaining imports
import wandb
from math import floor

In [None]:
# set seeds
seed_everything(1234)

# log in to wandb
wandb.login()

# Constants and hyperparameters

We define the hyperparameters for our experiments. The choices for the underlying CNN are taken from the linked MNIST tutorial, and likewise for the PonderNet hyperparameters.

In [None]:
# TRAINER SETTINGS
BATCH_SIZE = 128
EPOCHS = 20

# OPTIMIZER SETTINGS
LR = 0.001
GRAD_NORM_CLIP = 0.5

# MODEL HPARAMS
N_HIDDEN = 64
N_HIDDEN_CNN = 64
N_HIDDEN_LIN = 64
KERNEL_SIZE = 5

MAX_STEPS = 20
LAMBDA_P = 0.2
BETA = 0.01

# MNIST

We wrap the MNIST dataset with `PyTorch Lightning`'s Data Module class, which allows for easier integration.


In [None]:
class MNIST_DataModule(pl.LightningDataModule):
    '''
        DataModule to hold the MNIST dataset. Accepts different transforms for train and test to
        allow for extrapolation experiments.

        Parameters
        ----------
        data_dir : str
            Directory where MNIST will be downloaded or taken from.

        train_transform : [transform] 
            List of transformations for the training dataset. The same
            transformations are also applied to the validation dataset.

        test_transform : [transform] or [[transform]]
            List of transformations for the test dataset. Also accepts a list of
            lists to validate on multiple datasets with different transforms.

        batch_size : int
            Batch size for all dataloaders.
    '''

    def __init__(self, data_dir='./', train_transform=None, test_transform=None, batch_size=256):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_transform = train_transform
        self.test_transform = test_transform

        self.default_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        '''called only once and on 1 GPU'''
        # download data (train/val and test sets)
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        '''
            Called on each GPU separately - stage defines if we are
            at fit, validate, test or predict step.
        '''
        # we set up only relevant datasets when stage is specified
        if stage in [None, 'fit', 'validate']:
            mnist_train = MNIST(self.data_dir, train=True,
                                transform=(self.train_transform or self.default_transform))
            self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
        if stage == 'test' or stage is None:
            if self.test_transform is None or isinstance(self.test_transform, transforms.Compose):
                self.mnist_test = MNIST(self.data_dir,
                                        train=False,
                                        transform=(self.test_transform or self.default_transform))
            else:
                self.mnist_test = [MNIST(self.data_dir,
                                         train=False,
                                         transform=test_transform)
                                   for test_transform in self.test_transform]

    def train_dataloader(self):
        '''returns training dataloader'''
        mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)
        return mnist_train

    def val_dataloader(self):
        '''returns validation dataloader'''
        mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
        return mnist_val

    def test_dataloader(self):
        '''returns test dataloader(s)'''
        if isinstance(self.mnist_test, MNIST):
            return DataLoader(self.mnist_test, batch_size=self.batch_size)

        mnist_test = [DataLoader(test_dataset,
                                 batch_size=self.batch_size)
                      for test_dataset in self.mnist_test]
        return mnist_test

# PonderNet implementation

## Auxiliary networks

The following section contains code that implements our particular version of PonderNet. For convenience, we define two small networks that will be used within PonderNet, a CNN and a multi-layer perceptron.

In [None]:
class MLP(nn.Module):
    '''
        Simple 3-layer multi layer perceptron.

        Parameters
        ----------
        n_input : int
            Size of the input.

        n_hidden : int
            Number of units of the hidden layer.

        n_output : int
            Size of the output.
    '''

    def __init__(self, n_input, n_hidden, n_output):
        super(MLP, self).__init__()
        self.i2h = nn.Linear(n_input, n_hidden)
        self.h2o = nn.Linear(n_hidden, n_output)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        '''forward pass'''
        x = F.relu(self.i2h(x))
        x = self.dropout(x)
        x = F.relu(self.h2o(x))
        return x


class CNN(nn.Module):
    '''
        Simple convolutional neural network.

        Parameters
        ----------
        n_input : int
            Size of the input image. We assume the image is a square,
            and `n_input` is the size of one side.

        n_output : int
            Size of the output.

        kernel_size : int
            Size of the kernel.
    '''

    def __init__(self, n_input=28, n_output=50, kernel_size=5):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=kernel_size)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=kernel_size)
        self.conv2_drop = nn.Dropout2d()

        # calculate size of convolution output
        self.lin_size = floor((floor((n_input - (kernel_size - 1)) / 2) - (kernel_size - 1)) / 2)
        self.fc1 = nn.Linear(self.lin_size ** 2 * 20, n_output)

    def forward(self, x):
        '''forward pass'''
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return x
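
As a sanity check on `lin_size`, the computation in `CNN.__init__` can be reproduced with plain integer arithmetic. The sketch below assumes stride-1 convolutions without padding, each followed by a 2x2 max-pool, as in `CNN.forward`; the helper name `conv_pool_output_side` is hypothetical and not part of the notebook.

```python
def conv_pool_output_side(n_input: int, kernel_size: int, n_blocks: int = 2) -> int:
    """Side length after `n_blocks` of (unpadded conv, stride 1) + (2x2 max-pool)."""
    side = n_input
    for _ in range(n_blocks):
        side = side - (kernel_size - 1)  # unpadded convolution shrinks each side by k - 1
        side = side // 2                 # 2x2 max-pooling halves it (floor division)
    return side


side = conv_pool_output_side(n_input=28, kernel_size=5)
flat_features = side ** 2 * 20  # 20 output channels in `conv2`
print(side, flat_features)
```

With the notebook's settings (28x28 images, `KERNEL_SIZE = 5`) this gives a 4x4 feature map, so `fc1` receives 320 input features.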

## Loss

Here we define the two terms in the loss, namely the reconstruction term and the regularization term. We create a class to wrap them.

In [None]:
class ReconstructionLoss(nn.Module):
    '''
        Computes the weighted average of the given loss across steps according to
        the probability of stopping at each step.

        Parameters
        ----------
        loss_func : callable
            Loss function accepting true and predicted labels. It should output
            a loss item for each element in the input batch.
    '''

    def __init__(self, loss_func: nn.Module):
        super().__init__()
        self.loss_func = loss_func

    def forward(self, p: torch.Tensor, y_pred: torch.Tensor, y: torch.Tensor):
        '''
            Compute the loss.

            Parameters
            ----------
            p : torch.Tensor
                Probability of halting at each step, of shape `(max_steps, batch_size)`.

            y_pred : torch.Tensor
                Predicted outputs, of shape `(max_steps, batch_size, n_classes)`.

            y : torch.Tensor
                True targets, of shape `(batch_size)`.

            Returns
            -------
            total_loss : torch.Tensor
                Scalar representing the reconstruction loss.
        '''
        total_loss = p.new_tensor(0.)

        for n in range(p.shape[0]):
            loss = (p[n] * self.loss_func(y_pred[n], y)).mean()
            total_loss = total_loss + loss

        return total_loss


class RegularizationLoss(nn.Module):
    '''
        Computes the KL-divergence between the halting distribution generated
        by the network and a geometric distribution with parameter `lambda_p`.

        Parameters
        ----------
        lambda_p : float
            Parameter determining our prior geometric distribution.

        max_steps : int
            Maximum number of allowed pondering steps.
    '''

    def __init__(self, lambda_p: float, max_steps: int = 1_000, device=None):
        super().__init__()

        p_g = torch.zeros((max_steps,), device=device)
        not_halted = 1.

        for k in range(max_steps):
            p_g[k] = not_halted * lambda_p
            not_halted = not_halted * (1 - lambda_p)

        self.p_g = nn.Parameter(p_g, requires_grad=False)
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, p: torch.Tensor):
        '''
            Compute the loss.

            Parameters
            ----------
            p : torch.Tensor
                Probability of halting at each step, representing our
                halting distribution.

            Returns
            -------
            loss : torch.Tensor
                Scalar representing the regularization loss.
        '''
        p = p.transpose(0, 1)
        p_g = self.p_g[None, :p.shape[1]].expand_as(p)
        return self.kl_div(p.log(), p_g)


class Loss:
    '''
        Class to group the losses together and calculate the total loss.

        Parameters
        ----------
        rec_loss : torch.Tensor
            Reconstruction loss obtained from running the network.

        reg_loss : torch.Tensor
            Regularization loss obtained from running the network.

        beta : float
            Hyperparameter to calculate the total loss.
    '''

    def __init__(self, rec_loss, reg_loss, beta):
        self.rec_loss = rec_loss
        self.reg_loss = reg_loss
        self.beta = beta

    def get_rec_loss(self):
        '''returns the reconstruction loss'''
        return self.rec_loss

    def get_reg_loss(self):
        '''returns the regularization loss'''
        return self.reg_loss

    def get_total_loss(self):
        '''returns the total loss'''
        return self.rec_loss + self.beta * self.reg_loss
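
To make the regularization term concrete, the following pure-Python sketch rebuilds the truncated geometric prior from `RegularizationLoss.__init__` and evaluates a KL divergence for a single sample. Given `p.log()` as input and `p_g` as target, `nn.KLDivLoss` sums `p_g * (log p_g - log p)`, i.e. it measures the divergence of the prior from the halting distribution. The function names and the uniform halting distribution below are illustrative assumptions, not part of the notebook.

```python
import math


def geometric_prior(lambda_p: float, max_steps: int) -> list:
    """Truncated geometric prior: p_g[k] = lambda_p * (1 - lambda_p) ** k."""
    prior, not_halted = [], 1.0
    for _ in range(max_steps):
        prior.append(not_halted * lambda_p)
        not_halted *= 1.0 - lambda_p
    return prior


def kl_divergence(target: list, predicted: list) -> float:
    """KL(target || predicted) = sum_k target[k] * log(target[k] / predicted[k])."""
    return sum(t * math.log(t / q) for t, q in zip(target, predicted) if t > 0)


p_g = geometric_prior(lambda_p=0.2, max_steps=20)
# the truncated prior does not sum to one: the tail mass (1 - lambda_p) ** max_steps is dropped
print(sum(p_g), 1 - 0.8 ** 20)

# hypothetical halting distribution that spreads its mass evenly over all 20 steps
p_uniform = [1 / 20] * 20
print(kl_divergence(p_g, p_uniform))
```

The total loss then combines both terms as `rec_loss + beta * reg_loss`, exactly as in `Loss.get_total_loss`.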

## PonderNet

Finally, we have PonderNet. We use a `PyTorch Lightning` module, which allows us to control all aspects of training, validation and testing in the same class. Of special importance is the forward pass; for the sake of simplicity, we implement a hardcoded maximum number of steps instead of a threshold on the cumulative probability of halting.
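
The fixed-budget approach can be illustrated without any network. Given the conditional halting probabilities `lambda_n` produced at each step, the unconditional probability of halting at step `n` is `p_n = lambda_n * prod_{j<n} (1 - lambda_j)`, and forcing `lambda_n = 1` at the last allowed step makes the `p_n` sum to one. The sketch below is a minimal pure-Python version of that bookkeeping; the function name and the example values are made up for illustration.

```python
def halting_distribution(lambdas: list) -> list:
    """Turn conditional halting probabilities into p_n, forcing a halt at the last step."""
    p, un_halted = [], 1.0
    for n, lambda_n in enumerate(lambdas, start=1):
        if n == len(lambdas):
            lambda_n = 1.0  # the last allowed step always halts
        p.append(un_halted * lambda_n)
        un_halted *= 1.0 - lambda_n
    return p


# illustrative values, not produced by the model
p = halting_distribution([0.1, 0.5, 0.3, 0.9])
expected_steps = sum(n * p_n for n, p_n in enumerate(p, start=1))
print(p, sum(p), expected_steps)
```

Note that the `forward` method in this notebook additionally samples a concrete halting step per sample with `torch.bernoulli`, and the logged `steps` metric is the mean of those sampled steps rather than the expectation computed here.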

In [None]:
class PonderMNIST(pl.LightningModule):
    '''
        PonderNet variant to perform image classification on MNIST. It is capable of
        adaptively choosing the number of steps for which to process an input.

        Parameters
        ----------
        n_hidden : int
            Hidden layer size of the propagated hidden state.

        n_hidden_lin : int
            Hidden layer size of the underlying MLP.

        n_hidden_cnn : int
            Hidden layer size of the output of the underlying CNN.

        kernel_size : int
            Size of the kernel of the underlying CNN.

        max_steps : int
            Maximum number of steps the network is allowed to "ponder" for.

        lambda_p : float 
            Parameter of the geometric prior. Must be between 0 and 1.

        beta : float
            Hyperparameter to calculate the total loss.

        lr : float
            Learning rate.

        Modules
        -------
        cnn : CNN
            Learnable convolutional neural network to embed the image into a vector.

        mlp : MLP
            Learnable multi-layer perceptron to combine the hidden state with
            the image embedding.

        outpt_layer : nn.Linear
            Linear module that serves as a multi-class classifier.

        lambda_layer : nn.Linear
            Linear module that generates the halting probability at each step.
    '''

    def __init__(self, n_hidden, n_hidden_lin, n_hidden_cnn, kernel_size, max_steps, lambda_p, beta, lr):
        super().__init__()

        # attributes
        self.n_classes = 10
        self.max_steps = max_steps
        self.lambda_p = lambda_p
        self.beta = beta
        self.n_hidden = n_hidden
        self.lr = lr

        # modules
        self.cnn = CNN(n_input=28, kernel_size=kernel_size, n_output=n_hidden_cnn)
        self.mlp = MLP(n_input=n_hidden_cnn + n_hidden, n_hidden=n_hidden_lin, n_output=n_hidden)
        self.outpt_layer = nn.Linear(n_hidden, self.n_classes)
        self.lambda_layer = nn.Linear(n_hidden, 1)

        # losses
        # per-sample losses are needed so that each sample is weighted by its own p_n
        self.loss_rec = ReconstructionLoss(nn.CrossEntropyLoss(reduction='none'))
        self.loss_reg = RegularizationLoss(self.lambda_p, max_steps=self.max_steps, device=self.device)

        # metrics
        self.accuracy = torchmetrics.Accuracy()

        # save hparams on W&B
        self.save_hyperparameters()

    def forward(self, x):
        '''
            Run the forward pass.

            Parameters
            ----------
            x : torch.Tensor
                Batch of input images of shape `(batch_size, 1, 28, 28)`.

            Returns
            -------
            y : torch.Tensor
                Tensor of shape `(max_steps, batch_size, n_classes)` with
                the logits for each step and each sample. In evaluation
                mode the loop may stop early, so the shape is
                `(steps, batch_size, n_classes)` where `1 <= steps <= max_steps`.

            p : torch.Tensor
                Tensor of shape `(max_steps, batch_size)` representing
                the halting probabilities. Sums over steps (fixing a sample)
                are 1 when all `max_steps` steps are run. In evaluation mode
                the shape is `(steps, batch_size)` where `1 <= steps <= max_steps`.

            halting_step : torch.Tensor
                An integer for each sample in the batch that corresponds to
                the step when it was halted. The shape is `(batch_size,)`. The
                minimal value is 1 because we always run at least one step.
        '''
        # extract batch size for QoL
        batch_size = x.shape[0]

        # propagate to get h_1
        h = x.new_zeros((batch_size, self.n_hidden))
        embedding = self.cnn(x)
        concat = torch.cat([embedding, h], 1)
        h = self.mlp(concat)

        # lists to save p_n, y_n
        p = []
        y = []

        # vectors to save intermediate values
        un_halted_prob = h.new_ones((batch_size,))  # unhalted probability till step n
        halting_step = h.new_zeros((batch_size,), dtype=torch.long)  # stopping step

        # main loop
        for n in range(1, self.max_steps + 1):
            # obtain lambda_n
            if n == self.max_steps:
                lambda_n = h.new_ones(batch_size)
            else:
                lambda_n = torch.sigmoid(self.lambda_layer(h)).squeeze(-1)

            # obtain output and p_n
            y_n = self.outpt_layer(h)
            p_n = un_halted_prob * lambda_n

            # append p_n, y_n
            p.append(p_n)
            y.append(y_n)

            # calculate halting step
            halting_step = torch.maximum(
                n
                * (halting_step == 0)
                * torch.bernoulli(lambda_n).to(torch.long),
                halting_step)

            # update the probability of not having halted up to step n
            un_halted_prob = un_halted_prob * (1 - lambda_n)

            # propagate to obtain h_n
            embedding = self.cnn(x)
            concat = torch.cat([embedding, h], 1)
            h = self.mlp(concat)

            # break if we are in inference and all elements have halting_step
            if not self.training and (halting_step > 0).sum() == batch_size:
                break

        return torch.stack(y), torch.stack(p), halting_step

    def training_step(self, batch, batch_idx):
        '''
            Perform the training step.

            Parameters
            ----------
            batch : (torch.Tensor, torch.Tensor)
                Current training batch to train on.

            Returns
            -------
            loss : torch.Tensor
                Loss value of the current batch.
        '''
        loss, _, acc, steps = self._get_loss_and_metrics(batch)

        # logging
        self.log('train/steps', steps)
        self.log('train/accuracy', acc)
        self.log('train/total_loss', loss.get_total_loss())
        self.log('train/reconstruction_loss', loss.get_rec_loss())
        self.log('train/regularization_loss', loss.get_reg_loss())

        return loss.get_total_loss()

    def validation_step(self, batch, batch_idx):
        '''
            Perform the validation step. Logs relevant metrics and returns
            the predictions to be used in a custom callback.

            Parameters
            ----------
            batch : (torch.Tensor, torch.Tensor)
                Current validation batch to evaluate.

            Returns
            -------
            preds : torch.Tensor
                Predictions for the current batch.
        '''
        loss, preds, acc, steps = self._get_loss_and_metrics(batch)

        # logging
        self.log('val/steps', steps)
        self.log('val/accuracy', acc)
        self.log('val/total_loss', loss.get_total_loss())
        self.log('val/reconstruction_loss', loss.get_rec_loss())
        self.log('val/regularization_loss', loss.get_reg_loss())

        # for custom callback
        return preds

    def test_step(self, batch, batch_idx, dataset_idx=0):
        '''
            Perform the test step. Logs the accuracy and the average number
            of steps for the current batch; nothing is returned.

            Parameters
            ----------
            batch : (torch.Tensor, torch.Tensor)
                Current test batch to evaluate.

            batch_idx : int
                Index of the current batch.

            dataset_idx : int
                Index of the test dataloader the batch comes from.
        '''
        _, _, acc, steps = self._get_loss_and_metrics(batch)

        # logging
        self.log(f'test_{dataset_idx}/steps', steps)
        self.log(f'test_{dataset_idx}/accuracy', acc)

    def configure_optimizers(self):
        '''
            Configure the optimizers and learning rate schedulers.

            Returns
            -------
            config : dict
                Dictionary with `optimizer` and `lr_scheduler` keys, with an
                optimizer and a learning scheduler respectively.
        '''
        optimizer = Adam(self.parameters(), lr=self.lr)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(optimizer, mode='max', verbose=True),
                "monitor": 'val/accuracy',
                "interval": 'epoch',
                "frequency": 1
            }
        }

    def configure_callbacks(self):
        '''returns a list of callbacks'''
        # we choose high patience since we validate 4 times per epoch to have nice graphs
        early_stopping = EarlyStopping(monitor='val/accuracy', mode='max', patience=6)
        model_checkpoint = ModelCheckpoint(monitor="val/accuracy", mode='max')
        return [early_stopping, model_checkpoint]

    def _get_loss_and_metrics(self, batch):
        '''
            Returns the losses, the predictions, the accuracy and the number of steps.

            Parameters
            ----------
            batch : (torch.Tensor, torch.Tensor)
                Batch to process.

            Returns
            -------
            loss : Loss
                Loss object from which all three losses can be retrieved.

            preds : torch.Tensor
                Predictions for the current batch.

            acc : torch.Tensor
                Accuracy obtained with the current batch.

            steps : torch.Tensor
                Average number of steps in the current batch.
        '''
        # extract the batch
        data, target = batch

        # forward pass
        y, p, halted_step = self(data)

        # remove elements with infinities (after taking the log)
        if torch.any(p == 0) and self.training:
            valid_indices = torch.all(p != 0, dim=0)
            p = p[:, valid_indices]
            y = y[:, valid_indices]
            halted_step = halted_step[valid_indices]
            target = target[valid_indices]

        # calculate the loss
        loss_rec_ = self.loss_rec(p, y, target)
        loss_reg_ = self.loss_reg(p)
        loss = Loss(loss_rec_, loss_reg_, self.beta)

        halted_index = (halted_step - 1).unsqueeze(0).unsqueeze(2).repeat(1, 1, self.n_classes)

        # calculate the accuracy
        logits = y.gather(dim=0, index=halted_index).squeeze(0)
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, target)

        # calculate the average number of steps
        steps = (halted_step * 1.0).mean()

        return loss, preds, acc, steps

# Run interpolation

Load the MNIST dataset with no rotations and train PonderNet on it. Make sure to edit the `WandbLogger` call so that you log the experiment on your account's desired project.

In [None]:
# initialize datamodule and model
mnist = MNIST_DataModule(batch_size=BATCH_SIZE)
model = PonderMNIST(n_hidden=N_HIDDEN,
                    n_hidden_cnn=N_HIDDEN_CNN,
                    n_hidden_lin=N_HIDDEN_LIN,
                    kernel_size=KERNEL_SIZE,
                    max_steps=MAX_STEPS,
                    lambda_p=LAMBDA_P,
                    beta=BETA,
                    lr=LR)

# setup logger
logger = WandbLogger(project='PonderNet', name='interpolation', offline=False)
logger.watch(model)

trainer = Trainer(
    logger=logger,                      # W&B integration
    gpus=-1,                            # use all available GPUs
    max_epochs=EPOCHS,                  # maximum number of epochs
    gradient_clip_val=GRAD_NORM_CLIP,   # gradient clipping
    val_check_interval=0.25,            # validate 4 times per epoch
    precision=16,                       # train in half precision
    deterministic=True)                 # for reproducibility

# fit the model
trainer.fit(model, datamodule=mnist)

# evaluate on the test set
trainer.test(model, datamodule=mnist)

wandb.finish()

# Run extrapolation
Train PonderNet on slightly rotated MNIST pictures, while testing on more pronounced rotations. As before, make sure to edit the `WandbLogger` accordingly.

In [None]:
def get_transforms():
    # define transformations
    transform_22 = transforms.Compose([
        transforms.RandomRotation(degrees=22.5),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    transform_45 = transforms.Compose([
        transforms.RandomRotation(degrees=45),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    transform_67 = transforms.Compose([
        transforms.RandomRotation(degrees=67.5),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    transform_90 = transforms.Compose([
        transforms.RandomRotation(degrees=90),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_transform = transform_22
    test_transform = [transform_22, transform_45, transform_67, transform_90]

    return train_transform, test_transform

In [None]:
train_transform, test_transform = get_transforms()

# initialize datamodule and model
mnist = MNIST_DataModule(batch_size=BATCH_SIZE,
                         train_transform=train_transform,
                         test_transform=test_transform)
model = PonderMNIST(n_hidden=N_HIDDEN,
                    n_hidden_cnn=N_HIDDEN_CNN,
                    n_hidden_lin=N_HIDDEN_LIN,
                    kernel_size=KERNEL_SIZE,
                    max_steps=MAX_STEPS,
                    lambda_p=LAMBDA_P,
                    beta=BETA,
                    lr=LR)

# setup logger
logger = WandbLogger(project='PonderNet', name='extrapolation', offline=False)
logger.watch(model)

trainer = Trainer(
    logger=logger,                      # W&B integration
    gpus=-1,                            # use all available GPUs
    max_epochs=EPOCHS,                  # maximum number of epochs
    gradient_clip_val=GRAD_NORM_CLIP,   # gradient clipping
    val_check_interval=0.25,            # validate 4 times per epoch
    precision=16,                       # train in half precision
    deterministic=True)                 # for reproducibility

# fit the model
trainer.fit(model, datamodule=mnist)

# evaluate on the test set
trainer.test(model, datamodule=mnist)