# Transformer Implementation
This notebook contains an example implementation of a transformer model. Some things are omitted, for example:
- I did not provide a loss calculation or training loop
- I did not include dropout, to keep the code from getting any more complicated
- I assume that inputs are already tokenized and embedded

For a more complete, but still accessible, implementation, I recommend looking at:
- https://github.com/karpathy/nanoGPT (early version)
- https://github.com/karpathy/nanochat (newer release)

See also the original video here: https://www.youtube.com/watch?v=kCc8FmEb1nY.

## Implement Attention

In [None]:
import torch, torch.nn as nn, torch.nn.functional as F
from matplotlib import pyplot as plt

In [None]:
#As a simple function
def attention(q,k,v):
    '''
    Causal scaled dot-product attention.

    q : (..., T_q, d_k) queries
    k : (..., T_k, d_k) keys
    v : (..., T_k, d_v) values
    Query i may only attend to keys 0..i (top-left-aligned causal mask).
    Returns a (..., T_q, d_v) tensor.
    '''
    #build the mask on q's device so GPU inputs work, and size it (T_q, T_k)
    #so that the query and key lengths are allowed to differ
    mask    = torch.ones(q.shape[-2], k.shape[-2], dtype=torch.bool, device=q.device).tril()
    #scale by sqrt(d_k) to keep the softmax inputs at roughly unit variance
    qkT     = torch.matmul( q, k.transpose(-2,-1) ) / (q.shape[-1]**0.5)
    #disallowed positions become -inf so the softmax gives them zero weight
    qkT_msk = qkT.masked_fill(~mask, float("-inf"))
    weights = F.softmax(qkT_msk,-1)
    return torch.matmul(weights,v)

In [None]:
#As a PyTorch module
class SelfAttentionHead(nn.Module):
    '''
    Defines a single causal (masked) self-attention head.

    Args:
        d_embed : width of the input embeddings (last dim of x)
        d_cntxt : maximum context length; sizes the stored causal mask
        d_k     : width of the queries and keys
        d_v     : width of the values (and of the head's output)
    '''
    def __init__(self, d_embed, d_cntxt, d_k, d_v):
        super().__init__()

        self.d_k   = d_k
        #trainable maps from embeddings to queries/keys/values
        self.query = nn.Linear(d_embed,d_k,bias=False)
        self.key   = nn.Linear(d_embed,d_k,bias=False)
        self.value = nn.Linear(d_embed,d_v,bias=False)
        #not trainable lower-triangular mask; registered as a buffer so it moves with .to(device)
        self.register_buffer('mask', torch.tril( torch.ones(d_cntxt,d_cntxt) ).bool())

    def attention(self,q,k,v,inspect=False):
        '''
        Scaled dot-product attention with a causal mask.

        q, k, v : tensors whose last two dims are (sequence, features); the
                  sequence length may be shorter than d_cntxt
        inspect : if True, return a dict of the intermediate tensors
                  ('qkT', 'qkT_scl', 'qkT_msk', 'weights', 'attention')
                  instead of only the output
        '''
        qkT     = torch.matmul( q, k.transpose(-2,-1) )
        qkT_scl = qkT / self.d_k**0.5
        #slice the stored mask so that sequences shorter than d_cntxt also work
        mask    = self.mask[:q.shape[-2], :k.shape[-2]]
        #out-of-place masked_fill leaves qkT_scl untouched for inspection
        qkT_msk = qkT_scl.masked_fill(~mask, float("-inf"))
        weights = F.softmax(qkT_msk,-1)
        att     = torch.matmul(weights,v)
        if inspect:
            return { 'qkT':qkT, 'qkT_scl':qkT_scl, 'qkT_msk':qkT_msk, 'weights':weights, 'attention':att }
        return att

    def forward(self,x):
        #use this head's own trainable projections (not module-level globals)
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        return self.attention(q,k,v)

## Experiment

### Explore inputs

In [None]:
#tensor dimensions used throughout the experiments
d_batch, d_cntxt = 2, 8   #batch size, context length
d_embed          = 300    #embedding width
d_k = d_v        = 4      #query/key width, value width

#architecture sizes
n_head, n_block = 4, 6    #parallel attention heads, transformer blocks

In [None]:
#linear maps (no bias) taking embeddings to queries, keys and values, created in that order
query, key, value = (nn.Linear(d_embed, width, bias=False) for width in (d_k, d_k, d_v))

In [None]:
#uniform random input of shape (batch, context, embedding)
x = torch.rand(d_batch, d_cntxt, d_embed)
x.size()

In [None]:
#project the input into queries, keys and values, then check their shapes
q, k, v = query(x), key(x), value(x)
tuple(t.shape for t in (q, k, v))

### Now run attention

In [None]:
#build one causal self-attention head from the dimensions above
attnHead = SelfAttentionHead(d_embed=d_embed, d_cntxt=d_cntxt, d_k=d_k, d_v=d_v)

In [None]:
#inspect the mask
#- positions marked True are kept (the query may attend to that key)
#- positions marked False are set to -inf before the softmax, so they receive zero weight
attnHead.mask

In [None]:
#visualize the causal mask and save it to disk
import os
os.makedirs('figures', exist_ok=True)  #savefig does not create missing directories
plt.figure(figsize=(4,4))
plt.imshow(attnHead.mask)
plt.tight_layout()
plt.savefig('figures/mask.png')

In [None]:
#forward pass through the single head; expected shape is (d_batch, d_cntxt, d_v)
out = attnHead(x)
out.size()

### Inspect steps
The `attention` method of the PyTorch module has an `inspect` option that returns a dictionary of its intermediate steps.

In [None]:
#queries, keys and values drawn i.i.d. from the standard normal distribution
q, k, v = (torch.randn(d_batch, d_cntxt, width) for width in (d_k, d_k, d_v))

#run attention with inspection enabled and list the recorded steps
d = attnHead.attention(q, k, v, inspect=True)
d.keys()

In [None]:
#for unit-normal q and k, unscaled QK^T entries have variance ~d_k (std ~sqrt(d_k))
d['qkT'].mean(-1), d['qkT'].var(-1)

In [None]:
#after dividing by sqrt(d_k), the variance (and so the std) is close to 1
d['qkT_scl'].mean(dim=-1), d['qkT_scl'].var(dim=-1)

In [None]:
#softmax weights used to mix the values; entries above the diagonal are zero (masked)
d['weights']

### Compare implementations

In [None]:
#the module method and the standalone function should agree
at1, at2 = attnHead.attention(q, k, v), attention(q, k, v)
torch.allclose(at1, at2)

In [None]:
#module and PyTorch implementations produce the same output
# note: F.scaled_dot_product_attention() requires PyTorch >= 2.0;
# a boolean attn_mask marks with True the positions allowed to take part in attention
at3 = F.scaled_dot_product_attention(q,k,v,attn_mask=attnHead.mask)
torch.allclose(at1,at3)

## Finishing the Transformer
With attention in place, we now build out the rest of the transformer.

In [None]:
class MultiHeadAttentionLayer(nn.Module):
    '''
    Defines a layer of multi-head attention.

    Runs n_head SelfAttentionHeads side by side on the same input, concatenates
    their outputs along the last dimension (width n_head * d_v) and projects
    the result back to d_embed with a linear layer.
    '''
    def __init__(self, n_head, d_embed, d_cntxt, d_k, d_v):
        super().__init__()

        #heads held in an nn.ModuleList (not a plain list) so that their parameters
        #are registered: they appear in .parameters()/state_dict and follow .to(device)
        self.attnHeads = nn.ModuleList(
            SelfAttentionHead(d_embed,d_cntxt,d_k,d_v) for _ in range(n_head)
        )
        #projection back to the embedding dimension
        self.toEmbed   = nn.Linear(d_v*n_head,d_embed)

    def forward(self,x):
        #run all heads and concatenate along the feature dimension
        out = torch.cat([ head(x) for head in self.attnHeads ],-1)
        #project back to the embedding dimension
        return self.toEmbed(out)


class TransformerBlock(nn.Module):
    '''
    Defines one (post-norm) transformer block:

        h   = LayerNorm(x + MultiHeadAttention(x))
        out = LayerNorm(h + FeedForward(h))

    The feedforward layer widens to d_embed * ffw_multiplier, applies
    ffw_activation, and projects back to d_embed.
    '''
    def __init__(self, n_head, d_embed, d_cntxt, d_k, d_v, ffw_multiplier=4, ffw_activation=nn.ReLU()):
        super().__init__()

        d_hidden = d_embed * ffw_multiplier

        self.multiHead = MultiHeadAttentionLayer(n_head, d_embed, d_cntxt, d_k, d_v)
        self.ffw = nn.Sequential(
            nn.Linear(d_embed, d_hidden),
            ffw_activation,
            nn.Linear(d_hidden, d_embed),
        )
        self.norm1, self.norm2 = nn.LayerNorm(d_embed), nn.LayerNorm(d_embed)

    def forward(self,x):
        #attention sub-layer with residual connection, then layer norm
        h = self.norm1(x + self.multiHead(x))
        #feedforward sub-layer with residual connection, then layer norm
        return self.norm2(h + self.ffw(h))

class TransformerModel(nn.Module):
    '''
    Defines a transformer model.

    Adds a fixed positional embedding to already-embedded inputs, runs n_layer
    TransformerBlocks, and projects to a probability distribution over d_vocab.

    Args:
        n_layer        : number of TransformerBlocks
        n_head         : attention heads per block
        d_vocab        : size of the output vocabulary
        d_embed        : embedding width of the inputs
        d_cntxt        : maximum context length
        d_k, d_v       : query/key and value widths of each head
        ffw_multiplier : hidden-width multiplier of each feedforward layer
        ffw_activation : activation module used inside the feedforward layers
        pe_orig        : True for the original sinusoidal positional embedding,
                         False for a regular Fourier basis
    '''
    def __init__(self, n_layer, n_head, d_vocab, d_embed, d_cntxt, d_k, d_v, ffw_multiplier=4, ffw_activation=nn.ReLU(), pe_orig=True):
        super().__init__()

        #blocks (each multi-head attention + feedforward)
        self.blocks  = nn.Sequential(*[
            TransformerBlock(n_head,d_embed,d_cntxt,d_k,d_v,ffw_multiplier=ffw_multiplier,ffw_activation=ffw_activation)
            for _ in range(n_layer)
        ])

        #linear layer to return to vocab size
        self.toVocab = nn.Linear(d_embed,d_vocab)

        #fixed (not trainable) positional embedding, kept as a buffer so it follows .to(device)
        self.register_buffer('p_embed', self.pos_embed(d_cntxt,d_embed,original=pe_orig) )

    def pos_embed(self,d_cntxt,d_embed,original=True):
        '''
        Build a (d_cntxt, d_embed) sinusoidal positional embedding.

        Even columns hold sines and odd columns hold cosines of pos * den[i].
        d_embed may be odd; the last (even) column is then a sine with no cosine partner.
        '''
        p_embed = torch.zeros(d_cntxt,d_embed)
        if original:
            #PE(pos,2i  ) = sin(pos / 10000^(2i/d_embed))
            #PE(pos,2i+1) = cos(pos / 10000^(2i/d_embed))
            den = torch.pow(10000,-torch.arange(0,d_embed,2)/d_embed)
        else:
            #regular fourier basis seems more sane? frequencies 2*pi*(2i)/d_embed
            den = 2*torch.pi*torch.arange(0,d_embed,2)/d_embed
        pos = torch.arange(d_cntxt, dtype=den.dtype)
        ang = torch.outer(pos,den)
        p_embed[:, ::2] = torch.sin(ang)
        #there are d_embed//2 cosine columns (one fewer than sine columns when d_embed is odd)
        p_embed[:,1::2] = torch.cos(ang[:, :d_embed//2])
        return p_embed

    def forward(self,x):
        #add positional embedding for the first T positions (T = sequence length of x),
        #so that inputs shorter than d_cntxt are accepted
        x = x + self.p_embed[:x.shape[-2]]
        #run blocks
        attn_out = self.blocks(x)
        #run linear layer
        logits = self.toVocab(attn_out)
        #convert to probabilities
        probs  = F.softmax(logits,-1)
        return probs #for training may want to add loss or return logits here

    def predict(self,x):
        #get probabilities associated with last entry in context
        probs = self(x)[:,-1,:]
        #draw to get a prediction (probabilistic!)
        idx = torch.multinomial(probs, num_samples=1)
        return idx

In [None]:
#multi-head attention: heads are concatenated and projected back to d_embed
multiHead = MultiHeadAttentionLayer(n_head, d_embed, d_cntxt, d_k, d_v)
multiHead(x).size()

In [None]:
#a single transformer block, here with a GELU feedforward activation
trBlock = TransformerBlock(n_head, d_embed, d_cntxt, d_k, d_v, ffw_activation=nn.GELU())
trBlock(x).size()

In [None]:
#full transformer model (n_layer blocks followed by a projection to the vocabulary)
n_layer = 6
d_vocab = 10000
transformer = TransformerModel(n_layer,n_head,d_vocab,d_embed,d_cntxt,d_k,d_v,ffw_activation=nn.GELU())
out = transformer(x)
#the output is a softmax over the vocabulary, so every position should sum to ~1
out.shape, out.sum(-1)

In [None]:
#sample one next-token index per batch element (stochastic, via torch.multinomial)
transformer.predict(x)

## View Position Embeddings
This is a bit of an aside, but since we implemented the position embeddings above, let's visualize them.

In [None]:
def plot_pe(p_embed,separate=True):
    '''
    Plot each row (position) of a positional-embedding matrix.

    p_embed  : 2-D tensor/array with one row per position
    separate : if True, plot the even (sin) and odd (cos) columns as separate
               curves against the frequency index i; otherwise plot each full
               row against the embedding dimension
    Returns the matplotlib Figure.
    '''
    fig,ax = plt.subplots(1,1,figsize=(10,4))
    for i, row in enumerate(p_embed):
        if separate:
            ax.plot(row[ ::2],label=f'Position {i} (sin)')
            ax.plot(row[1::2],label=f'Position {i} (cos)')
        else:
            ax.plot(row,label=f'Position {i}')
    #with separate=True the x-axis indexes sin/cos pairs, not raw embedding columns
    ax.set_xlabel("Frequency Index i" if separate else "Embedding Dimension")
    ax.set_ylabel("Positional Embedding Value")
    #legend outside the axes so it does not cover the curves
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    return fig

In [None]:
#original sinusoidal encoding - many columns look redundant?
plot_pe(transformer.pos_embed(d_cntxt, d_embed, original=True));

In [None]:
#regular Fourier-basis encoding
plot_pe(transformer.pos_embed(d_cntxt, d_embed, original=False));