In [1]:
import torch
import math
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR
from torchmetrics.functional import accuracy, precision
import torchmetrics.functional as tf

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import TQDMProgressBar

from positional_encodings.torch_encodings import PositionalEncodingPermute2D

from data.cifar100 import CIFAR100DataModule
from data.cifar10 import CIFAR10DataModule
from data.mnist import MNISTDataModule

In [2]:
# Set this to whatever folder you wish CIFAR-100 to be downloaded into
CIFAR = "/media/curttigges/project-files/datasets/cifar-10/"
MNIST = "/media/curttigges/project-files/datasets/mnist/"

In [3]:
class PositionalImageEmbedding(nn.Module):
    """Reshapes images and concatenates position encoding
    
    Initializes position encoding,
    
    Args:

    Returns:
    
    """
    def __init__(self, input_shape, input_channels, embed_dim, bands=4):
        super().__init__()

        self.ff = self.fourier_features(
            shape=input_shape, bands=bands)
        self.conv = nn.Conv1d(input_channels + self.ff.shape[0], embed_dim, 1)


    def forward(self, x):
        # initial x of shape [BATCH_SIZE x CHANNELS x HEIGHT x WIDTH]
        

        # create position encoding of the same shape as x
        enc = self.ff.unsqueeze(0).expand(
            (x.shape[0],) + self.ff.shape)
        enc = enc.type_as(x)
        #print(enc.shape)
        x = torch.cat([x, enc], dim=1)
        # concatenate position encoding along the channel dimension
        # shape is now [BATCH_SIZE x COLOR_CHANNELS + POS_ENC_CHANNELS x HEIGHT x WIDTH]

        x = x.flatten(2)
        # reshape to [BATCH_SIZE x CHANNELS x HEIGHT*WIDTH]

        x = self.conv(x)
        # shape is now [BATCH_SIZE x EMBED_DIM x HEIGHT*WIDTH]

        x = x.permute(2, 0, 1)
        # shape is now [HEIGHT*WIDTH x BATCH_SIZE x EMBED_DIM]

        return x

    #def learnable_pos_embed(self):
    #    pos_embed = nn.Parameter(torch.randn(1, 1 + num_patches, embed_size))

    def fourier_features(self, shape, bands):
        # This first "shape" refers to the shape of the input data, not the output of this function
        dims = len(shape)

        # Every tensor we make has shape: (bands, dimension, x, y, etc...)

        # Pos is computed for the second tensor dimension
        # (aptly named "dimension"), with respect to all
        # following tensor-dimensions ("x", "y", "z", etc.)
        pos = torch.stack(list(torch.meshgrid(
            *(torch.linspace(-1.0, 1.0, steps=n) for n in list(shape))
        )))
        pos = pos.unsqueeze(0).expand((bands,) + pos.shape)

        # Band frequencies are computed for the first
        # tensor-dimension (aptly named "bands") with
        # respect to the index in that dimension
        band_frequencies = (torch.logspace(
            math.log(1.0),
            math.log(shape[0]/2),
            steps=bands,
            base=math.e
        )).view((bands,) + tuple(1 for _ in pos.shape[1:])).expand(pos.shape)

        # For every single value in the tensor, let's compute:
        #             freq[band] * pi * pos[d]

        # We can easily do that because every tensor is the
        # same shape, and repeated in the dimensions where
        # it's not relevant (e.g. "bands" dimension for the "pos" tensor)
        result = (band_frequencies * math.pi * pos).view((dims * bands,) + shape)

        # Use both sin & cos for each band, and then add raw position as well
        # TODO: raw position
        result = torch.cat([
            torch.sin(result),
            torch.cos(result),
        ], dim=0)

        return result


In [4]:
class PerceiverAttention(nn.Module):
    """Basic decoder block used both for cross-attention and the latent transformer
    """
    def __init__(self, embed_dim, mlp_dim, n_heads, dropout=0.0):
        super().__init__()

        self.lnorm1 = nn.LayerNorm(embed_dim)
        self.lnormq = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=n_heads)

        self.lnorm2 = nn.LayerNorm(embed_dim)
        self.linear1 = nn.Linear(embed_dim, mlp_dim)
        self.act = nn.GELU()
        self.linear2 = nn.Linear(mlp_dim, embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, q):
        # x will be of shape [PIXELS x BATCH_SIZE x EMBED_DIM]
        # q will be of shape [LATENT_DIM x BATCH_SIZE x EMBED_DIM] when this is
        # used for cross-attention; otherwise same as x

        # attention block
        out = self.lnorm1(x)
        #q = self.lnormq(q)
        out, _ = self.attn(query=q, key=x, value=x)
        # out will be of shape [LATENT_DIM x BATCH_SIZE x EMBED_DIM] after matmul
        # when used for cross-attention; otherwise same as x
        
        # first residual connection
        resid = out + q

        # dense block
        out = self.lnorm2(resid)
        out = self.linear1(out)
        out = self.act(out)
        out = self.linear2(out)
        out = self.drop(out)

        # second residual connection
        out = out + resid

        return out


In [5]:
class LatentTransformer(nn.Module):
    """Latent transformer module with n_layers count of decoders
    """
    def __init__(self, embed_dim, mlp_dim, n_heads, dropout, n_layers):
        super().__init__()
        self.transformer = nn.ModuleList([
            PerceiverAttention(
                embed_dim=embed_dim, 
                mlp_dim=mlp_dim, 
                n_heads=n_heads, 
                dropout=dropout) 
            for l in range(n_layers)])

    def forward(self, l):
        
        for trnfr in self.transformer:
            l = trnfr(l, l)
        
        return l

In [6]:
class PerceiverIOBlock(nn.Module):
    """Block consisting of one latent transformer, preceded by an optional cross-attention
    """
    def __init__(self, embed_dim, attn_mlp_dim, trnfr_mlp_dim, trnfr_heads, dropout, trnfr_layers, inner_ca=False):
        super().__init__()
        self.inner_ca = inner_ca
        
        # Optional cross-attention. Can be omitted
        if self.inner_ca:
            self.cross_attention = PerceiverAttention(
                embed_dim, attn_mlp_dim, n_heads=1, dropout=dropout)

        self.latent_transformer = LatentTransformer(
            embed_dim, trnfr_mlp_dim, trnfr_heads, dropout, trnfr_layers)

    def forward(self, x, l):
        if self.inner_ca:
            l = self.cross_attention(x, l)

        l = self.latent_transformer(l)

        return l

In [7]:
class ClassifierIO(nn.Module):
    """Perceiver IO classification calculation
    """
    def __init__(self, embed_dim, output_dim, output_heads, n_classes, dropout=0.0):
        super().__init__()
        # learnable label embedding
        self.n_classes = n_classes
        self.output_dim = output_dim
        
        self.label_emb = nn.Parameter(torch.rand(n_classes, 1, output_dim))
        
        self.output_attn = PerceiverAttention(
            embed_dim, embed_dim * 4, output_heads, dropout=dropout)
        
        self.fc1 = nn.Linear(output_dim, output_dim)
        self.fc2 = nn.Linear(self.n_classes * self.output_dim, n_classes)

    def forward(self, x):
        # latent, batch, embed
        L, B, E = x.shape
        #print(f"Latent shape: {x.shape}")
        #print(f"Output emb shape: {self.label_emb.shape}")
        output_emb = self.label_emb.repeat(1, B, 1)
        #print(f"Output array shape: {output_emb.shape}")

        x = self.output_attn(x, output_emb)
        #print(x.shape)

        x = self.fc1(x)
        x = F.gelu(x)
        x = torch.reshape(x,(B, self.n_classes * self.output_dim))
        #x = x.mean(dim=0)
        x = self.fc2(x)

        return x

In [8]:
class Classifier(nn.Module):
    """Original Perceiver classification calculation
    """
    def __init__(self, embed_dim, latent_dim, batch_size, n_classes):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, embed_dim)
        self.fc2 = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        # latent, batch, embed

        x = self.fc1(x)
        x = x.mean(dim=0)
        x = self.fc2(x)

        return x

In [9]:
class PerceiverIO(nn.Module):
    """Full Perceiver IO model
    """
    def __init__(
        self, input_shape, latent_dim, embed_dim, output_dim, attn_mlp_dim, trnfr_mlp_dim, trnfr_heads, 
        dropout, trnfr_layers, n_blocks, n_classes, batch_size, inner_ca=False):
        super().__init__()

        # Initialize latent array
        self.latent = nn.Parameter(
            torch.nn.init.trunc_normal_(
                torch.zeros((latent_dim, 1, embed_dim)), mean=0, std=0.02, a=-2, b=2))

        # Initialize embedding with position encoding
        self.embed = PositionalImageEmbedding(input_shape, 1, embed_dim)

        # Initialize initial block with cross-attention
        self.initial_perceiver_block = PerceiverIOBlock(
            embed_dim=embed_dim, 
            attn_mlp_dim=attn_mlp_dim, 
            trnfr_mlp_dim=trnfr_mlp_dim, 
            trnfr_heads=trnfr_heads, 
            dropout = dropout, 
            trnfr_layers = trnfr_layers,
            inner_ca=True)

        # Initialize arbitrary number of Perceiver blocks; will be transformer
        # blocks unless inner_ca (inner cross-attention) is enabled
        self.perceiver_blocks = nn.ModuleList([
            PerceiverIOBlock(
                embed_dim=embed_dim, 
                attn_mlp_dim=attn_mlp_dim, 
                trnfr_mlp_dim=trnfr_mlp_dim, 
                trnfr_heads=trnfr_heads, 
                dropout = dropout, 
                trnfr_layers = trnfr_layers,
                inner_ca=inner_ca)
            for b in range(n_blocks)])

        # PerceiverIO classification layer
        self.classifier = ClassifierIO(
            embed_dim=embed_dim, output_dim=output_dim, output_heads=8, n_classes=n_classes)

        # Original classification layer
        #self.classifier = Classifier(embed_dim=embed_dim, latent_dim=latent_dim, batch_size=batch_size, n_classes=n_classes)

    def forward(self, x):
        # First we expand our latent query matrix to size of batch
        batch_size = x.shape[0]
        latent = self.latent.expand(-1, batch_size, -1)

        # Next, we pass the image through the embedding module to get flattened input
        x = self.embed(x)

        # 
        latent = self.initial_perceiver_block(x, latent)

        # Next, we iteratively pass the latent matrix and image embedding through
        # perceiver blocks
        for pb in self.perceiver_blocks:
            latent = pb(x, latent)
    
        # Finally, we project the output to the number of target classes
        latent = self.classifier(latent)

        return latent

### Training Setup

In [10]:
class PerceiverIOTrainingModule(pl.LightningModule):
    '''Classic Perceiver

    Args:
        input_shape (tuple of ints): Dimensions of input images
        latent_dim (int): Size of latent array
        embed_dim (int): Size of embedding output from linear projection layer
        attn_mlp_dim (int): Size of MLP
        trnfr_mlp_dim (int): Size transformer MLP
        trnfr_heads (int): Number of self-attention heads in the latent transformer 
        dropout (float): dropout for network
        trnfr_layers (int): Number of decoders in the transformers
        n_blocks (int): Number of Perceiver blocks
        n_classes (int): Number of target classes
        batch_size (int): Batch size
        learning_rate (float): Learning Rate
    '''
    def __init__(
        self,
        input_shape, 
        latent_dim, 
        embed_dim,
        output_dim,
        attn_mlp_dim, 
        trnfr_mlp_dim, 
        trnfr_heads, 
        dropout, 
        trnfr_layers, 
        n_blocks, 
        n_classes,
        batch_size, 
        learning_rate,
        inner_ca):
        super().__init__()

        # Key parameters
        self.save_hyperparameters()

        # Transformer with arbitrary number of encoders, heads, and hidden size
        
        self.model = PerceiverIO(
            input_shape=input_shape,
            latent_dim=latent_dim, 
            embed_dim=embed_dim,
            output_dim=output_dim,
            attn_mlp_dim=attn_mlp_dim, 
            trnfr_mlp_dim=trnfr_mlp_dim, 
            trnfr_heads=trnfr_heads, 
            dropout=dropout, 
            trnfr_layers=trnfr_layers, 
            n_blocks=n_blocks, 
            n_classes=n_classes,
            batch_size=batch_size,
            inner_ca=inner_ca
        )

    def forward(self, x):
        x = self.model(x)        
        return x

    def evaluate(self, batch, stage=None):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), 
            lr=self.hparams.learning_rate
        )
        gamma = 0.1 ** 0.5
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=3, gamma=gamma, last_epoch=-1, verbose=False)
            
        return {"optimizer":optimizer, "lr_scheduler":lr_scheduler}

In [11]:
model_kwargs = {
    "input_shape":(28, 28),
    "latent_dim":8,
    "embed_dim":16,
    "output_dim":16,
    "attn_mlp_dim":16, 
    "trnfr_mlp_dim":16, 
    "trnfr_heads":8, 
    "dropout":0.1, 
    "trnfr_layers":6, 
    "n_blocks":6, 
    "n_classes":10,
    "batch_size":64,
    "learning_rate":0.003,
    "inner_ca":True
}

In [12]:
cifar10 = CIFAR10DataModule(
    batch_size=model_kwargs["batch_size"], 
    num_workers=12, 
    data_dir=CIFAR)

mnist_dm = MNISTDataModule(
    download_dir=MNIST,
    batch_size=model_kwargs["batch_size"],
    num_workers=12
)

pl.seed_everything(42)
model = PerceiverIOTrainingModule(**model_kwargs)

Global seed set to 42


### Start Training

In [13]:
# Comment out if not using wandb
wandb_logger = WandbLogger(
    project="perceiver", 
    save_dir="training/logs/",
    log_model=True)
wandb_logger.watch(model, log="all")

trainer = pl.Trainer(
    max_epochs=24,
    devices=1,
    accelerator='gpu',
    logger=wandb_logger, #comment out if not using WandB
    callbacks=[TQDMProgressBar(refresh_rate=10)])
    
trainer.fit(model, datamodule=mnist_dm)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcurt-tigges[0m ([33mascendant[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type        | Params
--------------------------------------
0 | model | PerceiverIO | 85.5 K
--------------------------------------
85.5 K    Trainable params
0         Non-trainable params
85.5 K    Total params
0.342     Total estimated model params size (MB)


Epoch 23: 100%|██████████| 939/939 [23:34<00:00,  1.51s/it, loss=0.338, v_num=8zmg, val_loss=0.400, val_acc=0.875]   
