### Quick note on Transformer Block Unification

A fun aside: the MLP blocks of a Transformer are nearly identical to its self-attention blocks, up to a few changes:

- the query projection is missing; the data itself is the query
- the key and value are data-independent parameters (i.e. the MLP block is really just a cross-attention "soft lookup" into a fixed {key:value} table)
- the Softmax (map/reduce non-linearity) is replaced with GeLU (map-only non-linearity)
- there is no separate final Linear projection back to the residual pathway (the values already have the residual dimension)
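
Written out for a single head, the correspondence in the list above looks roughly like this. The symbols are just shorthand introduced for this comparison: $x \in \mathbb{R}^{T \times C}$ is the input, $W_q, W_k, W_v \in \mathbb{R}^{C \times n_h}$ and $W_o \in \mathbb{R}^{n_h \times C}$ are the attention projections, and the rows of $K, V \in \mathbb{R}^{4C \times C}$ are the fixed keys and values (assuming the usual $4C$ hidden width). The softmax is taken row-wise, and scaling and masking are left out.

$$
\begin{aligned}
\text{self-attention:}\quad y &= \mathrm{softmax}\!\big((x W_q)(x W_k)^{\top}\big)\,(x W_v)\,W_o \\
\text{MLP:}\quad y &= \mathrm{GeLU}\!\big(x K^{\top}\big)\,V
\end{aligned}
$$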

This immediately suggests unifying these blocks into a more general Transformer "superblock" that is wired up to the residual pathway either in parallel (as with the heads of multi-headed self-attention) or in series (as is usually done from block to block). It also suggests in-between generalizations: for example, multi-headed attention has a natural counterpart in the use of "groups" in Linear (or Conv) layers. Alternatively, attention could be done over two pools of nodes simultaneously, those whose keys and values are data-dependent and those whose keys and values are not, dispensing with the need for a distinction.
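
The "two pools" idea can be sketched in a few lines. This is only one possible reading of it, not a tested architecture: the class name `TwoPoolAttention`, the `n_fixed` parameter, and the 0.02 init scale are all made up for illustration, and the causal mask is omitted.

```python
import torch
import torch.nn as nn


class TwoPoolAttention(nn.Module):
    """Hypothetical single head: one softmax over data-dependent and fixed keys."""

    def __init__(self, n_embd: int, head_size: int, n_fixed: int):
        super().__init__()
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.proj = nn.Linear(head_size, n_embd, bias=False)
        # Data-independent pool: a learned {key: value} table, as in the MLP view.
        self.fixed_k = nn.Parameter(torch.randn(n_fixed, head_size) * 0.02)
        self.fixed_v = nn.Parameter(torch.randn(n_fixed, head_size) * 0.02)
        self.scale = head_size ** -0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B = x.shape[0]
        q = self.query(x)  # (B, T, hs)
        # Append the fixed table to the keys/values computed from the data.
        k = torch.cat([self.key(x), self.fixed_k.expand(B, -1, -1)], dim=1)
        v = torch.cat([self.value(x), self.fixed_v.expand(B, -1, -1)], dim=1)
        att = torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)  # (B, T, T + n_fixed)
        return self.proj(att @ v)  # back to the residual pathway, (B, T, n_embd)


block = TwoPoolAttention(n_embd=32, head_size=8, n_fixed=64)
out = block(torch.randn(2, 16, 32))
print(out.shape)
```

Each token here distributes its attention over the other tokens and the learned table at once, so no separate MLP-style block is needed for the fixed lookup.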

**TLDR**: A much simpler Transformer, with a single type of block wired up to a residual pathway both in parallel and in series, is possible, but to my knowledge it has not yet been convincingly achieved.
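
As a rough sketch of what "a single type of block" could mean, one function can serve both roles if the non-linearity and the source of the keys and values are left as arguments. The helper name `soft_lookup` is hypothetical, and scaling and masking are left out for brevity.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F


def soft_lookup(q, k, v, act):
    """Score queries against keys, apply a non-linearity, mix the values."""
    return act(q @ k.transpose(-2, -1)) @ v


torch.manual_seed(0)
B, T, C, hs = 2, 16, 32, 8
x = torch.randn(B, T, C)

# Attention flavour: q, k, v are all computed from the data.
wq, wk, wv = (nn.Linear(C, hs, bias=False) for _ in range(3))
proj = nn.Linear(hs, C, bias=False)
attn_out = proj(soft_lookup(wq(x), wk(x), wv(x), lambda s: torch.softmax(s, dim=-1)))

# MLP flavour: the data is the query; keys and values are weight matrices.
layer1 = nn.Linear(C, 4 * C, bias=False)
layer2 = nn.Linear(4 * C, C, bias=False)
mlp_out = soft_lookup(x, layer1.weight, layer2.weight.T, F.gelu)

# Reference: the MLP written the usual way.
mlp_ref = layer2(F.gelu(layer1(x)))
print(torch.allclose(mlp_out, mlp_ref, atol=1e-6))
```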

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange

In [2]:
# toy example
torch.manual_seed(1337)
B,T,C = 8,512,128 # batch size, sequence length, channels (embedding dim)
n_head = 4
n_embd = C
x = torch.randn(B, T, C)
x.shape

torch.Size([8, 512, 128])

In [3]:
# self-attention block of a single Head
nh = C//n_head # head size, 128/4 = 32
key = nn.Linear(C, nh, bias=False)
query = nn.Linear(C, nh, bias=False)
value = nn.Linear(C, nh, bias=False)
proj = nn.Linear(nh, C, bias=False)
k = key(x)
q = query(x)
v = value(x)
att = torch.softmax(q @ k.transpose(-2,-1), dim=-1) # toy version: no 1/sqrt(nh) scaling and no causal mask
y = att @ v
r = proj(y) # standard self-attention blocks apply one more Linear to project back to the residual pathway
r.shape

torch.Size([8, 512, 128])
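
The "in parallel" wiring of heads can be made concrete. Concatenating the heads and applying one output projection gives the same result as letting each head project back to the residual pathway with its own slice of the weights and summing the contributions. A minimal sketch, using random tensors as stand-ins for the per-head outputs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, C, n_head = 2, 16, 32, 4
hs = C // n_head

# Stand-ins for the per-head outputs (att @ v) of n_head attention heads.
heads = [torch.randn(B, T, hs) for _ in range(n_head)]
proj = nn.Linear(C, C, bias=False)

# Usual formulation: concatenate along channels, then one shared projection.
concat_out = proj(torch.cat(heads, dim=-1))

# Equivalent view: every head owns a (C, hs) slice of the projection weight
# and adds its own contribution to the residual pathway.
w_slices = proj.weight.split(hs, dim=1)  # n_head tensors of shape (C, hs)
parallel_out = sum(h @ w.T for h, w in zip(heads, w_slices))

print(torch.allclose(concat_out, parallel_out, atol=1e-5))
```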

In [4]:
# typical MLP (feed-forward) block of a Transformer
layer1 = nn.Linear(C, C*4, bias=False)
layer2 = nn.Linear(C*4, C, bias=False)
l1 = F.gelu(layer1(x))
l2 = layer2(l1) # projects back to residual pathway
l2.shape

torch.Size([8, 512, 128])

In [5]:
# the MLP block is actually attention over a fixed (not data-dependent) {k:v} dict
q = x # change 1: query is simply the input
k = layer1.weight # key and value are data-independent learnable parameters
v = layer2.weight.T
att = F.gelu(q @ k.transpose(-2,-1)) # change 2: using gelu instead of softmax
y = att @ v
y.shape

torch.Size([8, 512, 128])

In [6]:
(l2 == y).all() # cool; exact match here, though torch.allclose(l2, y) is the more robust check in general

tensor(True)
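
The "soft lookup" reading is easiest to see for a single token. In the sketch below (small random tables, names chosen for illustration), the token is scored against every key, and the output is a weighted sum of the value rows. Unlike softmax weights, the GeLU weights are computed per entry, so they are not normalized and can be negative.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
C = 8
K = torch.randn(4 * C, C)  # 4C keys, one per row (plays the role of layer1.weight)
V = torch.randn(4 * C, C)  # 4C values, one per row (plays the role of layer2.weight.T)
x_t = torch.randn(C)       # a single token acting as its own query

w = F.gelu(K @ x_t)                              # one weight per table entry, shape (4C,)
y_loop = sum(w[i] * V[i] for i in range(4 * C))  # explicit weighted sum of value rows
y_mat = F.gelu(x_t @ K.T) @ V                    # the same thing as two matmuls

soft = torch.softmax(K @ x_t, dim=0)             # softmax weights, for comparison
print(torch.allclose(y_loop, y_mat, atol=1e-5))
```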

In [7]:
q1a = F.gelu(layer1(q)) # first MLP layer + GeLU, computed the usual way
print(q1a.shape)
torch.allclose(q1a, att) # matches the "attention" weights from In [5]

torch.Size([8, 512, 512])


True

In [8]:
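q2a = layer2(q1a) # second MLP layer, i.e. the projection back to the residual pathway
y = att @ v # "attention" weights applied to the fixed values
print(q2a.shape)
torch.allclose(q2a, y)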

torch.Size([8, 512, 128])


True

In [9]:
(q2a == y).all()

tensor(True)