# Transformer Extensions

- 📺 **Video:** [https://youtu.be/DPvDL8L4Dqo](https://youtu.be/DPvDL8L4Dqo)

## Overview
- Survey improvements to the vanilla transformer: layer scaling, relative positions, adapters, and sparse attention.
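- Understand which problem each tweak targets: training stability in deep stacks, long-sequence modeling, cheap fine-tuning, or attention cost on long contexts.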

## Key ideas
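- **Relative position bias:** a bias on attention scores that depends on the offset between query and key positions; improves extrapolation and modeling of long sequences.
- **Adapters:** lightweight bottleneck modules enable parameter-efficient fine-tuning, since only the adapter weights need to be trained.
- **Layer scaling:** residual scaling or pre-norm stabilizes training of deep stacks.
- **Sparse patterns:** restricting which positions attend to each other reduces the quadratic cost of attention for long contexts.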
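
As a rough illustration of the sparse-pattern idea, the sketch below builds a sliding-window (local) attention mask and passes it to `nn.MultiheadAttention`. The helper name `sliding_window_mask` and the window size are assumptions made for this example. A dense mask like this only shows the pattern; actual savings need a kernel that skips the masked entries.

```python
import torch
from torch import nn

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    """Boolean mask where True marks query/key pairs that may NOT attend.

    Each position attends only to keys at most `window` steps away.
    """
    pos = torch.arange(seq_len)
    distance = (pos[:, None] - pos[None, :]).abs()
    return distance > window

mask = sliding_window_mask(seq_len=8, window=2)
allowed = (~mask).sum().item()
print("allowed pairs:", allowed, "of", mask.numel())

# nn.MultiheadAttention treats True entries in a boolean attn_mask as blocked.
attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
x = torch.randn(1, 8, 16)
out, weights = attn(x, x, x, attn_mask=mask)
print("weights on blocked pairs are zero:", bool(torch.all(weights[0][mask] == 0)))
```

With a fixed window, the number of allowed pairs grows linearly in sequence length rather than quadratically, which is what the sparse patterns exploit.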

## Demo
This demo augments a transformer block with a relative-position bias and a bottleneck adapter, built from standard PyTorch modules, to reflect techniques highlighted in the lecture (https://youtu.be/rZMAM19aP84).

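```python
import torch
from torch import nn

class Adapter(nn.Module):
    """Bottleneck adapter: down-project, apply a nonlinearity, up-project."""
    def __init__(self, dim, bottleneck=8):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck)
        self.up = nn.Linear(bottleneck, dim)
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.up(self.activation(self.down(x)))

class RelPosBlock(nn.Module):
    def __init__(self, dim, nhead, max_len=32):
        super().__init__()
        self.max_len = max_len
        self.attn = nn.MultiheadAttention(dim, nhead, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(dim, 2*dim), nn.ReLU(), nn.Linear(2*dim, dim))
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.adapter = Adapter(dim)
        # One learnable scalar per head and per relative offset in [-(max_len-1), max_len-1].
        self.rel_bias = nn.Parameter(torch.zeros(nhead, 2 * max_len - 1))

    def forward(self, x):
        batch, seq_len, _ = x.shape  # requires seq_len <= max_len
        pos = torch.arange(seq_len, device=x.device)
        # Shift key-minus-query offsets so they index into rel_bias: shape (seq_len, seq_len).
        offsets = pos[None, :] - pos[:, None] + self.max_len - 1
        # A float attn_mask is added to the attention scores; its expected shape
        # is (batch * nhead, seq_len, seq_len), ordered batch-major.
        bias = self.rel_bias[:, offsets].repeat(batch, 1, 1)
        attn_output, weights = self.attn(x, x, x, attn_mask=bias)
        x = self.norm1(x + attn_output)  # post-norm residual
        residual = x
        x = self.ff(x) + self.adapter(x)  # adapter runs in parallel with the feed-forward
        x = self.norm2(x + residual)
        return x, weights

block = RelPosBlock(dim=16, nhead=4)
inputs = torch.randn(1, 5, 16)
output, weights = block(inputs)
print('Output shape:', output.shape)
print('Attention weights shape:', weights.shape)
```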


```
Output shape: torch.Size([1, 5, 16])
Attention weights shape: torch.Size([1, 5, 5])
```
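
The demo defines an adapter but never trains it. The sketch below illustrates the parameter-efficient fine-tuning idea from Key ideas: freeze every weight except the adapter's, then take one optimizer step. `FeedForwardWithAdapter` is a hypothetical stand-in for a pretrained layer, and the inputs and targets are random placeholders rather than a real dataset.

```python
import torch
from torch import nn

class Adapter(nn.Module):
    """Same bottleneck adapter as in the demo."""
    def __init__(self, dim, bottleneck=8):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck)
        self.up = nn.Linear(bottleneck, dim)
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.up(self.activation(self.down(x)))

class FeedForwardWithAdapter(nn.Module):
    """Hypothetical stand-in for a pretrained layer with an adapter attached."""
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 2 * dim), nn.ReLU(), nn.Linear(2 * dim, dim))
        self.adapter = Adapter(dim)

    def forward(self, x):
        return x + self.ff(x) + self.adapter(x)

torch.manual_seed(0)
model = FeedForwardWithAdapter(dim=16)

# Freeze everything except the adapter weights.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("adapter.")

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {trainable} of {total}")

# One optimizer step on random stand-in data (not a real dataset).
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)
frozen_before = model.ff[0].weight.clone()
adapter_before = model.adapter.up.weight.clone()
x, target = torch.randn(4, 5, 16), torch.randn(4, 5, 16)
loss = nn.functional.mse_loss(model(x), target)
loss.backward()
optimizer.step()
print("frozen weights unchanged:", torch.equal(frozen_before, model.ff[0].weight))
print("adapter weights changed:", not torch.equal(adapter_before, model.adapter.up.weight))
```

With these sizes, 280 of the 1352 parameters remain trainable; the gap widens as the backbone grows while the bottleneck stays small.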


## Try it
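- Modify the demo, for example by changing the adapter `bottleneck` width or `nhead`, and compare parameter counts and output shapes.
- Add a tiny dataset or counter-example to probe where the block helps and where it breaks down.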
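
As one possible modification, the sketch below swaps the demo's post-norm layout for pre-norm and adds a learnable per-channel scale on each residual branch, which is one form of the residual scaling mentioned under Key ideas. The class name `PreNormScaledBlock` and the `init_scale` value are assumptions chosen for illustration, not settings taken from the lecture.

```python
import torch
from torch import nn

class PreNormScaledBlock(nn.Module):
    """Pre-norm transformer block with a learnable scale on each residual branch."""
    def __init__(self, dim, nhead, init_scale=1e-2):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, nhead, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(dim, 2 * dim), nn.ReLU(), nn.Linear(2 * dim, dim))
        # Small initial scales keep each block close to the identity at the start of training.
        self.scale_attn = nn.Parameter(torch.full((dim,), init_scale))
        self.scale_ff = nn.Parameter(torch.full((dim,), init_scale))

    def forward(self, x):
        # Pre-norm: normalize the input of each sublayer, not the sum after the residual.
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.scale_attn * attn_out
        x = x + self.scale_ff * self.ff(self.norm2(x))
        return x

torch.manual_seed(0)
scaled_block = PreNormScaledBlock(dim=16, nhead=4)
scaled_inputs = torch.randn(1, 5, 16)
scaled_outputs = scaled_block(scaled_inputs)
print("output shape:", scaled_outputs.shape)
print("max change from input:", (scaled_outputs - scaled_inputs).abs().max().item())
```

Because the residual path is left unnormalized and the branch scales start small, the block begins as a near-identity map, which is the intuition behind why such variants are easier to stack deeply.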


## References
- [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf)
- [Scaling Laws for Neural Language Models](https://arxiv.org/abs/2001.08361)
- [Efficient Transformers: A Survey](https://arxiv.org/abs/2009.06732)
- [Rethinking Attention with Performers](https://arxiv.org/abs/2009.14794)
- [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150)
- [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)


*Links only; we do not redistribute slides or papers.*