<a href="https://colab.research.google.com/github/pixelsandpointers/annotated-transformer/blob/main/Spline_based_Transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implementation of Spline-based Transformers
See the paper: https://la.disneyresearch.com/publication/spline-based-transformers/  
There is no official implementation, so this one is based on the pseudocode in the paper's supplementary material.

In [None]:
import einops
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader


In [None]:
class Spline(nn.Module):
    """Cubic Bezier spline evaluated in matrix form.

    For every time step ``t`` in ``[0, 1]`` the curve is computed as
    ``[1, t, t^2, t^3] @ N @ control_points``, where ``N`` is the
    characteristic matrix below. The matrix is stored as ``self.param`` and
    can optionally be trained.
    """

    # Characteristic matrix of the cubic Bezier curve. It must be a floating
    # point tensor: integer tensors cannot be turned into parameters that
    # require gradients.
    N = torch.tensor([
        [ 1.,  0.,  0.,  0.],
        [-3.,  3.,  0.,  0.],
        [ 3., -6.,  3.,  0.],
        [-1.,  3., -3.,  1.]
    ])

    def __init__(self, trainable: bool = True):
        """Register the characteristic matrix as a parameter.

        Args:
            trainable: Whether the matrix receives gradients.
        """
        super().__init__()
        # Clone so that every instance owns its storage; otherwise in-place
        # optimizer updates would silently modify the class-level constant.
        self.param = nn.Parameter(self.N.clone(), requires_grad=trainable)

    def forward(self, control_points: torch.Tensor, seq_len: int) -> torch.Tensor:
        """Sample the spline at ``seq_len`` evenly spaced points in ``[0, 1]``.

        Args:
            control_points: Tensor of shape ``(batch_size, 4, emb_dim)``.
            seq_len: Number of samples to draw along the curve.

        Returns:
            Tensor of shape ``(batch_size, seq_len, emb_dim)``.

        Raises:
            AssertionError: If the number of control points is not 4.
        """
        _, n_points, _ = control_points.shape
        assert n_points == 4, "Only implements the cubic spline"
        device = control_points.device

        # Domain [0, 1] on which the spline is evaluated.
        t = torch.linspace(0, 1, steps=seq_len, device=device,
                           dtype=self.param.dtype)
        # Exponents for the power terms [1, t, t^2, t^3].
        powers = torch.arange(n_points, device=device)
        times = t.unsqueeze(-1) ** powers                # (seq_len, 4)
        basis = times @ self.param                       # (seq_len, 4)
        # Matmul broadcasts the basis over the batch dimension:
        # (seq_len, 4) @ (batch_size, 4, emb_dim) -> (batch_size, seq_len, emb_dim)
        return basis @ control_points


In [None]:
class SBT(nn.Module):
    """Spline-based Transformer autoencoder.

    Learnable control-point tokens are prepended to the embedded input and
    passed through a Transformer encoder. The encoded control points define a
    spline that is sampled once per input position; those samples are the
    decoder's target sequence, with the full encoder output as memory. A
    linear head maps the decoder output to ``n_out`` features.
    """

    def __init__(self,
                 n_out: int,
                 n_dim: int = 512,
                 n_enc: int = 6,
                 n_dec: int = 6,
                 n_control_points: int = 4,
                 n_heads: int = 8):
        """Build the embedding, encoder, decoder, spline and output head.

        Args:
            n_out: Feature size produced by the output head.
            n_dim: Model (and input) feature size.
            n_enc: Number of encoder layers.
            n_dec: Number of decoder layers.
            n_control_points: Number of learnable control-point tokens.
                ``Spline`` only implements the cubic case, so ``forward``
                currently requires this to be 4.
            n_heads: Number of attention heads in every Transformer layer.
        """
        super().__init__()
        self.n_dim = n_dim
        self.n_enc = n_enc
        self.n_dec = n_dec
        self.n_control_points = n_control_points
        self.control_points = nn.Parameter(torch.zeros(n_control_points, n_dim))

        # layers
        self.emb = nn.Sequential(
            nn.Linear(n_dim, n_dim * 2),
            nn.GELU(),
            nn.Linear(n_dim * 2, n_dim),
        )

        # setup Transformer Encoder (T5 was used in the paper)
        t_enc_layer = nn.TransformerEncoderLayer(n_dim, n_heads, activation='gelu', batch_first=True)
        self.t_enc = nn.TransformerEncoder(t_enc_layer, num_layers=n_enc)

        # setup Transformer Decoder
        t_dec_layer = nn.TransformerDecoderLayer(n_dim, n_heads, activation='gelu', batch_first=True)
        self.t_dec = nn.TransformerDecoder(t_dec_layer, num_layers=n_dec)

        # setup Spline
        self.spline = Spline(trainable=True)

        # setup output head
        self.head = nn.Linear(n_dim, n_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` into control points and decode along the spline.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, n_dim)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, n_out)``.
        """
        batch_size, seq_len, *_ = x.shape
        emb = self.emb(x)

        # One copy of the learnable control-point tokens per batch element.
        control_points = einops.repeat(self.control_points, 'p d -> b p d',
                                       b=batch_size)

        # Control-point tokens come first so they can be sliced off afterwards.
        enc_in = torch.cat((control_points, emb), dim=1)
        enc_out = self.t_enc(enc_in)

        control_points = enc_out[:, :self.n_control_points, :]
        latent = self.spline(control_points, seq_len)
        dec_out = self.t_dec(latent, memory=enc_out)

        # Apply the output head; previously it was constructed but never
        # used, so n_out had no effect and the layer was never trained.
        return self.head(dec_out)

In [None]:
# Hyperparameters
input_dim = 28 * 28  # For MNIST images
latent_dim = 20
batch_size = 64
lr = 1e-3
epochs = 10

# Data: each MNIST image is converted to a tensor and flattened to a vector.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))  # Flatten
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Model, optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SBT(input_dim, n_dim=input_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training: reconstruct the input and minimise the mean squared error.
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        x, _ = batch  # labels are not needed for reconstruction
        x = x.to(device)
        n_samples = x.size(0)
        # NOTE(review): adding the axis in front yields (1, n_samples, input_dim),
        # i.e. the images of one mini-batch are fed to the model as a single
        # sequence. Confirm this is the intended layout.
        x = x.unsqueeze(0)

        optimizer.zero_grad()
        recon_x = model(x)
        loss = nn.functional.mse_loss(recon_x, x)
        loss.backward()
        optimizer.step()

        # mse_loss already averages over the batch, so weight it by the number
        # of samples before dividing by the dataset size below. Dividing the
        # plain sum of batch means by the dataset size under-reports the loss.
        total_loss += loss.item() * n_samples

    avg_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")

print("Training complete!")
