# `016` Transformers

Requirements: 014 Attention and dropout, 015 Residual connections

☢️☢️ WIP ☢️☢️

Attention mechanisms were proposed as a way to give RNN inputs information about the other elements in the sequence, so that every element can be properly contextualized with information from the other tokens. However, an architecture called the transformer, proposed by [Vaswani et al., 2017](https://arxiv.org/pdf/1706.03762), took the world by surprise.

Basically, the authors removed the RNN layers and used just linear layers, attention mechanisms, residual connections and normalization. Applied to English-to-German translation, the architecture achieved a better BLEU score than any previously published model. Furthermore, training it on general text corpora produced an impressive level of generalization that scales with model size.
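
To make those ingredients concrete, here is a minimal sketch of single-head scaled dot-product self-attention followed by a residual connection and a normalization step. The sizes, the random weight matrices and the helper name `self_attention` are assumptions chosen for illustration, not the exact formulation used later in this notebook.

```python
import math
import torch

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over x of shape (seq, dim)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / math.sqrt(q.size(-1))  # (seq, seq) similarity between positions
    weights = torch.softmax(scores, dim=-1)   # each row sums to 1
    return weights @ v, weights

torch.manual_seed(0)
seq_len, dim = 5, 8  # hypothetical sizes
x = torch.randn(seq_len, dim)
w_q, w_k, w_v = (torch.randn(dim, dim) / math.sqrt(dim) for _ in range(3))

attended, weights = self_attention(x, w_q, w_k, w_v)
residual = x + attended  # residual connection keeps the original signal
# normalization: rescale each position to unit root-mean-square (RMSNorm without a learned gain)
normed = residual / residual.pow(2).mean(dim=-1, keepdim=True).sqrt()
print(attended.shape, weights.sum(dim=-1))
```

Every output row mixes information from all positions, with no recurrence involved, which is what lets the RNN layers be dropped.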

In this notebook I will define a transformer block, build a model with many of them, and train it over a corpus of educational content called FineWeb-EDU.

In [1]:
import torch

In [None]:
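class TransformerBlock(torch.nn.Module):
	"""Self-attention and a linear layer, each wrapped in a residual connection and followed by RMSNorm."""
	def __init__(self, embed_dim, num_heads, dropout):
		super().__init__()
		# batch_first=True makes the layer expect (batch, sequence, embed_dim) inputs
		self.attention_heads = torch.nn.MultiheadAttention(embed_dim, num_heads, dropout, batch_first=True)
		self.norm1 = torch.nn.RMSNorm(embed_dim)
		self.ff = torch.nn.Linear(embed_dim, embed_dim)
		self.norm2 = torch.nn.RMSNorm(embed_dim)

	def forward(self, x):
		# self-attention: x acts as query, key and value; the layer returns (output, weights)
		attended, _ = self.attention_heads(x, x, x, need_weights=False)
		x = x + attended
		x = self.norm1(x)
		x = x + self.ff(x)
		x = torch.relu(x)
		x = self.norm2(x)
		return x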
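
If the model is trained to predict the next token (an assumption, since the training objective is not defined yet in this notebook), each position should only attend to itself and to earlier positions. A minimal sketch of how a causal mask can be passed to `torch.nn.MultiheadAttention` follows; the sizes are hypothetical and dropout is set to 0 so the check is deterministic.

```python
import torch

torch.manual_seed(0)
embed_dim, num_heads, seq_len = 16, 4, 6  # hypothetical sizes
attention = torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, batch_first=True)

# True marks pairs that must NOT attend: everything above the diagonal (the future)
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

x = torch.randn(2, seq_len, embed_dim)
out, _ = attention(x, x, x, attn_mask=causal_mask, need_weights=False)

# changing the last token must leave every earlier position untouched
x_changed = x.clone()
x_changed[:, -1] = torch.randn(2, embed_dim)
out_changed, _ = attention(x_changed, x_changed, x_changed, attn_mask=causal_mask, need_weights=False)

out, out_changed = out.detach(), out_changed.detach()
print(torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-5))
```

Without the mask, the earlier positions would also change, because they could read the modified last token.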

In [None]:
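class Transformer(torch.nn.Module):
	def __init__(self, num_embeddings, output_dim, embed_dim=32, num_blocks=4, num_heads=4, dropout=.2):
		super().__init__()
		self.tok_embed = torch.nn.Embedding(num_embeddings, embed_dim)
		# NOTE: the position table is also sized by num_embeddings, so sequences cannot be longer than that
		self.pos_embed = torch.nn.Embedding(num_embeddings, embed_dim)
		# Sequential takes the blocks as separate arguments, hence the unpacking
		self.blocks = torch.nn.Sequential(*[
			TransformerBlock(embed_dim, num_heads, dropout)
			for _ in range(num_blocks)
		])
		self.output = torch.nn.Linear(embed_dim, output_dim)

	def forward(self, x):
		# x holds token ids with shape (batch, sequence); positions are 0..sequence-1 on the same device
		positions = torch.arange(x.size(-1), device=x.device)
		x = self.tok_embed(x) + self.pos_embed(positions)
		x = self.blocks(x)
		x = self.output(x)
		return x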
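
The sum of token and positional embeddings in `forward` relies on broadcasting. The sketch below traces the shapes with standalone embedding tables; `vocab_size`, `max_len` and the other sizes are hypothetical, and sizing the position table by a maximum sequence length is an assumption that differs from the class above.

```python
import torch

torch.manual_seed(0)
vocab_size, max_len, embed_dim = 100, 16, 8  # hypothetical sizes
tok_embed = torch.nn.Embedding(vocab_size, embed_dim)
pos_embed = torch.nn.Embedding(max_len, embed_dim)  # one row per position, not per token

tokens = torch.randint(0, vocab_size, (2, 5))  # (batch, sequence) of token ids
positions = torch.arange(tokens.size(-1))      # indices 0 .. sequence-1

tok = tok_embed(tokens)     # (batch, sequence, embed_dim)
pos = pos_embed(positions)  # (sequence, embed_dim), broadcast over the batch dimension
x = tok + pos
print(tok.shape, pos.shape, x.shape)
```

Every sequence in the batch receives the same positional vectors, so the model can tell positions apart even though attention itself is order-agnostic.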