In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR
from torchmetrics.functional import accuracy, precision
import torchmetrics.functional as tf

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import TQDMProgressBar

from positional_encodings.torch_encodings import PositionalEncodingPermute2D

from data.cifar100 import CIFAR100DataModule
from data.cifar10 import CIFAR10DataModule

In [2]:
# Set this to whatever folder you wish CIFAR-100 to be downloaded into
CIFAR = "/media/curttigges/project-files/datasets/cifar-100/"

In [3]:
#class FourierPositionEncoding(nn.Module):

In [4]:
class PositionalImageEmbedding(nn.Module):
    """Reshapes images and concatenates position encoding
    
    Initializes position encoding,
    
    Args:

    Returns:
    
    """
    def __init__(self, input_channels, embed_dim):
        super().__init__()
        self.p_enc = PositionalEncodingPermute2D(input_channels)
        self.conv = nn.Conv1d(input_channels*2, embed_dim, 1)


    def forward(self, x):
        # initial x of shape [BATCH_SIZE x CHANNELS x HEIGHT x WIDTH]
        
        enc = self.p_enc(x)
        # create position encoding of the same shape as x

        x = torch.cat([x, enc], dim=1)
        # concatenate position encoding along the channel dimension
        # shape is now [BATCH_SIZE x COLOR_CHANNELS + POS_ENC_CHANNELS x HEIGHT x WIDTH]

        x = x.flatten(2)
        # reshape to [BATCH_SIZE x CHANNELS x HEIGHT*WIDTH]

        x = self.conv(x)
        # shape is now [BATCH_SIZE x EMBED_DIM x HEIGHT*WIDTH]

        x = x.permute(2, 0, 1)
        # shape is now [HEIGHT*WIDTH x BATCH_SIZE x EMBED_DIM]

        return x

In [5]:
class PerceiverAttention(nn.Module):
    """Basic decoder block used both for cross-attention and the latent transformer
    """
    def __init__(self, embed_dim, mlp_dim, n_heads, dropout=0.0):
        super().__init__()

        self.lnorm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=n_heads)

        self.lnorm2 = nn.LayerNorm(embed_dim)
        self.linear1 = nn.Linear(embed_dim, mlp_dim)
        self.act = nn.GELU()
        self.linear2 = nn.Linear(mlp_dim, embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, q):
        # x will be of shape [PIXELS x BATCH_SIZE x EMBED_DIM]
        # q will be of shape [LATENT_DIM x BATCH_SIZE x EMBED_DIM] when this is
        # used for cross-attention; otherwise same as x

        # attention block
        out = self.lnorm1(x)
        out, _ = self.attn(query=q, key=x, value=x)
        # out will be of shape [LATENT_DIM x BATCH_SIZE x EMBED_DIM] after matmul
        # when used for cross-attention; otherwise same as x
        
        # first residual connection
        resid = out + q

        # dense block
        out = self.lnorm2(resid)
        out = self.linear1(out)
        out = self.act(out)
        out = self.linear2(out)
        out = self.drop(out)

        # second residual connection
        out = out + resid

        return out


In [6]:
class LatentTransformer(nn.Module):
    def __init__(self, embed_dim, mlp_dim, n_heads, dropout, n_layers):
        super().__init__()
        self.transformer = nn.ModuleList([
            PerceiverAttention(
                embed_dim=embed_dim, 
                mlp_dim=mlp_dim, 
                n_heads=n_heads, 
                dropout=dropout) 
            for l in range(n_layers)])

    def forward(self, l):
        
        for trnfr in self.transformer:
            l = trnfr(l, l)
        
        return l

In [7]:
class PerceiverBlock(nn.Module):
    def __init__(self, embed_dim, attn_mlp_dim, trnfr_mlp_dim, trnfr_heads, dropout, trnfr_layers):
        super().__init__()
        
        self.cross_attention = PerceiverAttention(
            embed_dim, attn_mlp_dim, n_heads=1, dropout=dropout)

        self.latent_transformer = LatentTransformer(
            embed_dim, trnfr_mlp_dim, trnfr_heads, dropout, trnfr_layers)

    def forward(self, x, l):
        l = self.cross_attention(x, l)

        l = self.latent_transformer(l)

        return l

In [20]:
class Classifier(nn.Module):
    def __init__(self, embed_dim, latent_dim, batch_size, n_classes):
        super().__init__()
        self.batch_size = batch_size
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.fc = nn.Linear(self.embed_dim*self.latent_dim, n_classes)

    def forward(self, x):
        # latent, batch, embed
        L, B, E = x.shape

        #x = x.mean(dim=0)
        x = x.transpose(0, 1)
        #print(x.shape)
        x = torch.reshape(x, (B, E*L))
        #print(x.shape)
        x = self.fc(x)

        return x

In [22]:
class Perceiver(nn.Module):
    def __init__(
        self, latent_dim, embed_dim, attn_mlp_dim, trnfr_mlp_dim, trnfr_heads, 
        dropout, trnfr_layers, n_blocks, n_classes, batch_size):
        super().__init__()

        # Initialize latent array
        self.latent = nn.Parameter(
            torch.nn.init.trunc_normal_(
                torch.zeros((latent_dim, 1, embed_dim)), 
                mean=0, 
                std=0.02, 
                a=-2, 
                b=2))
        # In the paper, a truncated normal distribution was used for initialization, 
        # so I used this hidden torch function to create it.

        # Initialize embedding with position encoding
        self.embed = PositionalImageEmbedding(3, embed_dim)

        # Initialize arbitrary number of Perceiver blocks
        self.perceiver_blocks = nn.ModuleList([
            PerceiverBlock(
                embed_dim=embed_dim, 
                attn_mlp_dim=attn_mlp_dim, 
                trnfr_mlp_dim=trnfr_mlp_dim, 
                trnfr_heads=trnfr_heads, 
                dropout = dropout, 
                trnfr_layers = trnfr_layers)
            for b in range(n_blocks)])

        # Initialize classification layer
        self.classifier = Classifier(embed_dim=embed_dim, latent_dim=latent_dim, batch_size=batch_size, n_classes=n_classes)

    def forward(self, x):
        # First we expand our latent query matrix to size of batch
        batch_size = x.shape[0]
        latent = self.latent.expand(-1, batch_size, -1)

        # Next, we pass the image through the embedding module to get flattened input
        x = self.embed(x) 

        # Next, we iteratively pass the latent matrix and image embedding through
        # perceiver blocks
        for pb in self.perceiver_blocks:
            latent = pb(x, latent)
        #print(latent.shape)
    
        # Finally, we project the output to the number of target classes
        latent = self.classifier(latent)

        

        return latent

### Training Setup

In [23]:
class PerceiverTrainingModule(pl.LightningModule):
    '''Classic Perceiver

    Args:
        embed_dim (int): Size of embedding output from linear projection layer
        hidden_dim (int): Size of MLP head
        class_head_dim (int): Size of classification head
        num_encoders (int): Number of encoder layers
        num_heads (int): Number of self-attention heads
        patch_size (int): Size of patches
        num_patches (int): Total count of patches (patch sequence size) 
        dropout (float): Probability of dropout
        batch_size (int): Batch size (used for OneCycleLR)
        learning_rate (float): Maximum learning rate
        weight_decay (float): Optimizer weight decay
    '''
    def __init__(
        self, 
        latent_dim, 
        embed_dim,
        attn_mlp_dim, 
        trnfr_mlp_dim, 
        trnfr_heads, 
        dropout, 
        trnfr_layers, 
        n_blocks, 
        n_classes,
        batch_size, 
        learning_rate):
        super().__init__()

        # Key parameters
        self.save_hyperparameters()

        # Transformer with arbitrary number of encoders, heads, and hidden size
        self.model = Perceiver(
            latent_dim=latent_dim, 
            embed_dim=embed_dim,
            attn_mlp_dim=attn_mlp_dim, 
            trnfr_mlp_dim=trnfr_mlp_dim, 
            trnfr_heads=trnfr_heads, 
            dropout=dropout, 
            trnfr_layers=trnfr_layers, 
            n_blocks=n_blocks, 
            n_classes=n_classes,
            batch_size=batch_size
        )

    def forward(self, x):
        x = self.model(x)        
        return x

    def evaluate(self, batch, stage=None):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), 
            lr=self.hparams.learning_rate
        )
        
        steps_per_epoch = 60000 // self.hparams.batch_size
   
        lr_scheduler_dict = {
            "scheduler":OneCycleLR(
                optimizer,
                self.hparams.learning_rate,
                epochs=self.trainer.max_epochs,
                steps_per_epoch=steps_per_epoch,
                anneal_strategy='cos'
            ),
            "interval":"step",
        }
        return {"optimizer":optimizer, "lr_scheduler":lr_scheduler_dict}

In [24]:
model_kwargs = {
    "latent_dim":128,
    "embed_dim":32,
    "attn_mlp_dim":128, 
    "trnfr_mlp_dim":128, 
    "trnfr_heads":8, 
    "dropout":0.0, 
    "trnfr_layers":24, 
    "n_blocks":4, 
    "n_classes":100,
    "batch_size":64,
    "learning_rate":0.01
}

In [25]:
cifar100 = CIFAR100DataModule(
    batch_size=model_kwargs["batch_size"], 
    num_workers=12, 
    data_dir=CIFAR)

pl.seed_everything(42)
model = PerceiverTrainingModule(**model_kwargs)

Global seed set to 42


### Start Training

In [26]:
trainer = pl.Trainer(
    max_epochs=180,
    accelerator='gpu', 
    devices=1,
    #logger=wandb_logger, #comment out if not using WandB
    callbacks=[TQDMProgressBar(refresh_rate=10)])
    
trainer.fit(model, datamodule=cifar100)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Global seed set to 42
Global seed set to 42


Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type      | Params
------------------------------------
0 | model | Perceiver | 1.7 M 
------------------------------------
1.7 M     Trainable params
0         Non-trainable params
1.7 M     Total params
6.738     Total estimated model params size (MB)


Epoch 0:   1%|▏         | 10/783 [06:18<8:08:08, 37.89s/it, loss=458, v_num=22] val_loss=4.290, val_acc=0.0446]
Epoch 5:  37%|███▋      | 290/783 [13:44<23:21,  2.84s/it, loss=4.06, v_num=23, val_loss=4.090, val_acc=0.0776]  

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
