# Building a Transformer from Scratch

Welcome to Topic 5! In this notebook, we'll build a complete transformer model from scratch, implementing every component ourselves. This hands-on approach will solidify your understanding of how transformers work.

## Learning Objectives

By the end of this notebook, you will:
- Implement a transformer from the ground up
- Understand every component in detail
- Train a small transformer on a toy task
- Debug and visualize the training process
- Gain practical implementation experience

## Setup and Imports

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Tuple, Dict, List
import math
import time
from tqdm import tqdm

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Plotting configuration
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 12

## 1. Building Blocks: Scaled Dot-Product Attention

Let's start with the fundamental building block.
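The function below implements the formula from the original Transformer paper. Here $Q$, $K$ and $V$ are the query, key and value matrices and $d_k$ is the key dimension:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
$$

The $\sqrt{d_k}$ factor keeps the scores in a reasonable range. If the components of a query and a key are independent with mean 0 and variance 1, their dot product has variance $d_k$. Dividing by $\sqrt{d_k}$ brings that variance back to 1, so the softmax does not saturate as $d_k$ grows.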

In [None]:
def scaled_dot_product_attention(query: torch.Tensor, 
                               key: torch.Tensor, 
                               value: torch.Tensor,
                               mask: Optional[torch.Tensor] = None,
                               dropout: Optional[nn.Dropout] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute scaled dot-product attention.
    
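    Args:
        query: [batch_size, n_heads, seq_len_q, d_k]
        key: [batch_size, n_heads, seq_len_k, d_k]
        value: [batch_size, n_heads, seq_len_k, d_v]
        mask: broadcastable to [batch_size, n_heads, seq_len_q, seq_len_k],
            e.g. [batch_size, 1, 1, seq_len_k] (padding) or
            [batch_size, 1, seq_len_q, seq_len_k] (causal); positions where
            mask == 0 are blocked
        dropout: optional dropout layer applied to the attention weights
    
    Returns:
        output: [batch_size, n_heads, seq_len_q, d_v]
        attention_weights: [batch_size, n_heads, seq_len_q, seq_len_k]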
    """
    d_k = query.size(-1)
    
    # Compute attention scores
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Apply mask if provided
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # Apply softmax
    attention_weights = F.softmax(scores, dim=-1)
    
    # Apply dropout if provided
    if dropout is not None:
        attention_weights = dropout(attention_weights)
    
    # Apply attention to values
    output = torch.matmul(attention_weights, value)
    
    return output, attention_weights

# Test scaled dot-product attention
batch_size = 2
n_heads = 8
seq_len = 10
d_k = 64

# Create random tensors
query = torch.randn(batch_size, n_heads, seq_len, d_k)
key = torch.randn(batch_size, n_heads, seq_len, d_k)
value = torch.randn(batch_size, n_heads, seq_len, d_k)

# Compute attention
output, attention_weights = scaled_dot_product_attention(query, key, value)

print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")

# Visualize attention weights
plt.figure(figsize=(8, 6))
sns.heatmap(attention_weights[0, 0].detach().numpy(), cmap='Blues', cbar=True)
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title('Attention Weights Visualization')
plt.tight_layout()
plt.show()
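As a sanity check, the same computation can be compared against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, which requires PyTorch 2.0 or later. The sketch below re-states the notebook's function without dropout as `manual_attention`, a hypothetical name chosen so the cell is self-contained. It uses a causal mask and checks three things: the outputs agree, every row of weights sums to 1, and future positions receive zero weight. Note that the built-in treats `True` in a boolean `attn_mask` as "may attend", which matches this notebook's convention that `mask == 0` means blocked.

```python
import math
import torch
import torch.nn.functional as F

def manual_attention(q, k, v, mask=None):
    # Same math as scaled_dot_product_attention in this notebook, without dropout;
    # entries where mask == 0 (False) are blocked
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = scores.softmax(dim=-1)
    return weights @ v, weights

torch.manual_seed(0)
q = torch.randn(2, 4, 6, 16)  # [batch, heads, seq_len, d_k]
k = torch.randn(2, 4, 6, 16)
v = torch.randn(2, 4, 6, 16)

# Causal mask: query i may attend to keys 0..i (True = allowed)
causal = torch.ones(6, 6).tril().bool()

out, weights = manual_attention(q, k, v, causal)
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)

print(torch.allclose(out, ref, atol=1e-5))                   # outputs agree
print(torch.allclose(weights.sum(-1), torch.ones(2, 4, 6)))  # each row is a distribution
print(weights[0, 0].triu(diagonal=1).abs().max().item())     # weight on future positions
```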

## 2. Multi-Head Attention Implementation

Now let's build the multi-head attention module.
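Each head works on its own slice of `d_k = d_model // n_heads` projected features. The `view` + `transpose` reshape used in `forward()` can be checked in isolation. In the sketch below, `split_heads` and `merge_heads` are hypothetical helper names. The sketch shows that merging undoes the split, and that head `h` sees features `h*d_k : (h+1)*d_k` of every token.

```python
import torch

def split_heads(x: torch.Tensor, n_heads: int) -> torch.Tensor:
    """[batch, seq_len, d_model] -> [batch, n_heads, seq_len, d_model // n_heads]"""
    batch, seq_len, d_model = x.shape
    return x.view(batch, seq_len, n_heads, d_model // n_heads).transpose(1, 2)

def merge_heads(x: torch.Tensor) -> torch.Tensor:
    """[batch, n_heads, seq_len, d_k] -> [batch, seq_len, n_heads * d_k]"""
    batch, n_heads, seq_len, d_k = x.shape
    # transpose() makes the tensor non-contiguous, so contiguous() is needed before view()
    return x.transpose(1, 2).contiguous().view(batch, seq_len, n_heads * d_k)

x = torch.randn(2, 10, 512)
heads = split_heads(x, n_heads=8)
print(heads.shape)                                 # torch.Size([2, 8, 10, 64])
print(torch.equal(merge_heads(heads), x))          # True: merging undoes the split
print(torch.equal(heads[:, 3], x[..., 192:256]))   # True: head 3 sees features 192..255
```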

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
        # Initialize weights
        self._init_weights()
        
    def _init_weights(self):
        # Xavier initialization
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
                
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = query.size(0)
        seq_len = query.size(1)
        
        # Project and reshape to [batch_size, n_heads, seq_len, d_k].
        # The key/value length can differ from the query length (e.g. in cross-attention),
        # so view() infers each sequence length with -1.
        Q = self.W_q(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        
        # Apply scaled dot-product attention
        attention_output, attention_weights = scaled_dot_product_attention(
            Q, K, V, mask, self.dropout
        )
        
        # Concatenate heads
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )
        
        # Final linear projection
        output = self.W_o(attention_output)
        
        return output, attention_weights

# Test multi-head attention
d_model = 512
n_heads = 8
mha = MultiHeadAttention(d_model, n_heads)

# Create input
x = torch.randn(2, 10, d_model)
output, attn_weights = mha(x, x, x)  # Self-attention

print(f"Multi-Head Attention:")
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of parameters: {sum(p.numel() for p in mha.parameters()):,}")
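The printed count follows directly from the layer shapes. The Q, K and V projections are bias-free `d_model x d_model` matrices, and the output projection adds a bias. That gives 4·d_model² + d_model parameters, independent of `n_heads`. The sketch below checks this formula for d_model = 512 by counting the parameters of the same four `nn.Linear` layers. `mha_param_count` is a hypothetical helper.

```python
import torch.nn as nn

def mha_param_count(d_model: int) -> int:
    # W_q, W_k, W_v: weights only (bias=False); W_o: weights plus a bias vector
    return 3 * d_model * d_model + (d_model * d_model + d_model)

d_model = 512
layers = [nn.Linear(d_model, d_model, bias=False) for _ in range(3)] + [nn.Linear(d_model, d_model)]
counted = sum(p.numel() for layer in layers for p in layer.parameters())
print(mha_param_count(d_model), counted)  # 1049088 1049088
```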

## 3. Position-wise Feed-Forward Network

In [None]:
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()  # GELU, as in BERT and GPT; the original Transformer used ReLU
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Test feed-forward network
d_model = 512
d_ff = 2048
ffn = PositionwiseFeedForward(d_model, d_ff)

x = torch.randn(2, 10, d_model)
output = ffn(x)

print(f"Feed-Forward Network:")
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of parameters: {sum(p.numel() for p in ffn.parameters()):,}")
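GELU is defined as x·Φ(x), where Φ is the standard normal CDF. PyTorch's `nn.GELU()` computes this exact erf-based form by default (`approximate='none'`). The sketch below checks the formula against the module. It also shows the main qualitative difference from ReLU: small negative inputs produce small negative outputs instead of exactly zero.

```python
import math
import torch
import torch.nn as nn

x = torch.linspace(-4, 4, steps=101)

# Exact GELU: x * Phi(x), with Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
gelu_manual = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
print(torch.allclose(nn.GELU()(x), gelu_manual, atol=1e-6))  # True

# GELU(-1) is slightly negative, while ReLU(-1) is exactly 0
print(nn.GELU()(torch.tensor([-1.0])).item(), nn.ReLU()(torch.tensor([-1.0])).item())
```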

## 4. Positional Encoding
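The cell below builds the fixed sinusoidal encoding from the original Transformer paper. For position $pos$ and pair index $i$ (dimensions $2i$ and $2i+1$), the encoding is:

$$
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
$$

In the code, `torch.arange(0, d_model, 2)` enumerates the values $2i$. The `div_term` tensor is $\exp\!\left(-2i \ln(10000)/d_{\text{model}}\right) = 10000^{-2i/d_{\text{model}}}$, written with `exp` and `log`.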

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        
        # Create div_term for sinusoidal pattern
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           -(math.log(10000.0) / d_model))
        
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        
        # Apply cosine to odd indices  
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Add batch dimension
        pe = pe.unsqueeze(0)
        
        # Register as buffer (not a parameter)
        self.register_buffer('pe', pe)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Add positional encoding
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

# Visualize positional encoding
d_model = 128
max_len = 100
pe_layer = PositionalEncoding(d_model, max_len, dropout=0)

# Get positional encodings
pe_matrix = pe_layer.pe[0, :, :].numpy()

plt.figure(figsize=(12, 8))
plt.imshow(pe_matrix.T, aspect='auto', cmap='RdBu')
plt.colorbar(label='Value')
plt.xlabel('Position')
plt.ylabel('Dimension')
plt.title('Positional Encoding Visualization')
plt.tight_layout()
plt.show()

# Plot specific dimensions
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
for ax, dim in zip(axes.flat, [0, 1, 20, 21]):
    ax.plot(pe_matrix[:, dim])
    ax.set_title(f'Dimension {dim}')
    ax.set_xlabel('Position')
    ax.set_ylabel('Encoding Value')
    ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
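A property often cited for this design is that relative offsets are easy to express. For every frequency $w$, the (sin, cos) pair at position $pos + k$ is the pair at $pos$ rotated by the angle $k \cdot w$. The rotation depends only on the offset $k$, not on $pos$. The sketch below re-creates the encoding in NumPy and checks the identity for all frequency pairs at once. `sinusoidal_pe` is a hypothetical helper that mirrors the cell above.

```python
import math
import numpy as np

def sinusoidal_pe(max_len: int, d_model: int):
    """Same construction as PositionalEncoding above, in NumPy; also returns the frequencies."""
    position = np.arange(max_len)[:, None]
    div_term = np.exp(np.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe, div_term

pe, freqs = sinusoidal_pe(max_len=100, d_model=128)
pos, k = 10, 7

# Rotate each (sin, cos) pair at `pos` by k * w:
# sin(a + b) = sin a cos b + cos a sin b,  cos(a + b) = cos a cos b - sin a sin b
angles = k * freqs
s, c = pe[pos, 0::2], pe[pos, 1::2]
predicted_sin = s * np.cos(angles) + c * np.sin(angles)
predicted_cos = c * np.cos(angles) - s * np.sin(angles)

print(np.allclose(predicted_sin, pe[pos + k, 0::2]), np.allclose(predicted_cos, pe[pos + k, 1::2]))
```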

## 5. Encoder Layer

In [None]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Self-attention block
        attn_output, _ = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Feed-forward block
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x
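`EncoderLayer` uses the "post-LN" arrangement of the original paper: `x = LayerNorm(x + Sublayer(x))`. PyTorch's built-in `nn.TransformerEncoderLayer` with `norm_first=False` has the same structure and can serve as a cross-check of shapes and masking conventions. Its weights are independent of ours, so only shapes and masking behavior are comparable, not output values.

One convention difference is worth spelling out. The built-in's `src_key_padding_mask` uses `True` for positions to ignore, whereas this notebook's masks use 0/`False` for blocked positions. The sketch below checks both points. It confirms that the output shape matches the input, and that changing padded positions leaves the outputs at real positions unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_heads, d_ff = 128, 4, 512

# Post-LN with GELU, like EncoderLayer: x = LN(x + SelfAttn(x)); x = LN(x + FFN(x))
builtin = nn.TransformerEncoderLayer(
    d_model, n_heads, dim_feedforward=d_ff, dropout=0.0,
    activation='gelu', batch_first=True, norm_first=False,
)

x = torch.randn(2, 10, d_model)
tokens = torch.randint(1, 50, (2, 10))
tokens[1, 7:] = 0  # pad the tail of the second sequence

key_padding_mask = tokens == 0  # PyTorch convention: True = ignore this key
out = builtin(x, src_key_padding_mask=key_padding_mask)
print(out.shape)  # torch.Size([2, 10, 128])

# Changing only the padded inputs must not affect the real positions
x2 = x.clone()
x2[1, 7:] = torch.randn(3, d_model)
out2 = builtin(x2, src_key_padding_mask=key_padding_mask)
print(torch.allclose(out[1, :7], out2[1, :7], atol=1e-6))
```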

## 6. Decoder Layer

In [None]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor,
                tgt_mask: Optional[torch.Tensor] = None,
                src_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Self-attention block (masked)
        attn_output, _ = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Cross-attention block
        attn_output, _ = self.cross_attention(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        
        # Feed-forward block
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        
        return x
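The decoder's masked self-attention depends on the `tgt_mask` that `Transformer.generate_mask` builds in the next section. That mask combines two parts: a padding mask over query rows and a causal ("no peeking") mask. The sketch below re-creates that logic as a hypothetical standalone helper, `make_tgt_mask`, and prints the result for a short sequence with one padding token.

```python
import torch

def make_tgt_mask(tgt: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    """Padding AND causal mask as in Transformer.generate_mask; True = may attend.

    Returns a bool tensor of shape [batch, 1, tgt_len, tgt_len].
    """
    tgt_len = tgt.size(1)
    pad_rows = (tgt != pad_idx).unsqueeze(1).unsqueeze(3)  # [batch, 1, tgt_len, 1]
    # tril(ones) is the same as 1 - triu(ones, diagonal=1)
    causal = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device))
    return pad_rows & causal  # broadcasts to [batch, 1, tgt_len, tgt_len]

tgt = torch.tensor([[5, 9, 2, 0]])  # last position is padding
mask = make_tgt_mask(tgt)
print(mask.shape)       # torch.Size([1, 1, 4, 4])
print(mask[0, 0].int())  # row i allows keys 0..i; the padded query row is all zeros
```

A fully blocked row (the padded query) ends up with all scores at -1e9, so its softmax is uniform. This is harmless here because the loss ignores padded targets.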

## 7. Complete Transformer Model

In [None]:
class Transformer(nn.Module):
    def __init__(self,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 d_model: int = 512,
                 n_heads: int = 8,
                 num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6,
                 d_ff: int = 2048,
                 max_seq_length: int = 100,
                 dropout: float = 0.1):
        super().__init__()
        
        # Token embeddings
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        
        # Positional encoding
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length, dropout)
        
        # Encoder layers
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(num_encoder_layers)
        ])
        
        # Decoder layers
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ])
        
        # Final linear layer
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        
        # Initialize parameters
        self._init_parameters()
        
    def _init_parameters(self):
        # Xavier-initialize every weight matrix; 1-D parameters (biases, LayerNorm) are left as-is
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
                
    def generate_mask(self, src: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Source padding mask (token id 0 = padding): [batch, 1, 1, src_len], blocks padded keys
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        # Target mask: padding mask on query rows [batch, 1, tgt_len, 1]
        # AND causal mask [1, tgt_len, tgt_len] -> [batch, 1, tgt_len, tgt_len]
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask.to(tgt.device)
        return src_mask, tgt_mask
    
    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        
        # Encoder: scale embeddings by sqrt(d_model), then add positional encoding.
        # PositionalEncoding already applies dropout, so no extra dropout is needed here.
        src_embedded = self.encoder_embedding(src) * math.sqrt(self.encoder_embedding.embedding_dim)
        src_embedded = self.positional_encoding(src_embedded)
        
        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)
            
        # Decoder (same embedding scaling; dropout again comes from PositionalEncoding)
        tgt_embedded = self.decoder_embedding(tgt) * math.sqrt(self.decoder_embedding.embedding_dim)
        tgt_embedded = self.positional_encoding(tgt_embedded)
        
        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, tgt_mask, src_mask)
            
        # Final projection
        output = self.fc_out(dec_output)
        
        return output

# Create model
model = Transformer(
    src_vocab_size=1000,
    tgt_vocab_size=1000,
    d_model=256,
    n_heads=8,
    num_encoder_layers=3,
    num_decoder_layers=3,
    d_ff=1024,
    max_seq_length=100,
    dropout=0.1
).to(device)

print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")

# Test forward pass
src = torch.randint(1, 1000, (2, 10)).to(device)
tgt = torch.randint(1, 1000, (2, 8)).to(device)
output = model(src, tgt)
print(f"\nForward pass:")
print(f"Source shape: {src.shape}")
print(f"Target shape: {tgt.shape}")
print(f"Output shape: {output.shape}")

## 8. Training on a Toy Task: Number Sorting

Let's train our transformer on a simple task to verify it works.

In [None]:
# Create a toy dataset: sorting sequences of numbers
class SortingDataset(Dataset):
    def __init__(self, num_samples: int, seq_length: int, vocab_size: int):
        self.num_samples = num_samples
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        
        # Generate data
        self.data = []
        for _ in range(num_samples):
            # Generate random sequence (excluding 0 which is padding)
            seq = torch.randint(1, vocab_size, (seq_length,))
            sorted_seq = torch.sort(seq)[0]
            self.data.append((seq, sorted_seq))
            
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        return self.data[idx]

# Create dataset
train_dataset = SortingDataset(1000, seq_length=8, vocab_size=50)
val_dataset = SortingDataset(200, seq_length=8, vocab_size=50)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Show example
src_example, tgt_example = train_dataset[0]
print(f"Example:")
print(f"Input:  {src_example.tolist()}")
print(f"Target: {tgt_example.tolist()}")
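Training with teacher forcing pairs each target sequence with a decoder input shifted one position to the right, behind a start-of-sequence (BOS) token. Because of the causal mask, position `i` sees only `BOS, y_1, ..., y_i` and is trained to predict `y_{i+1}`. If the target were fed unshifted, each position could simply copy its own label. The sketch below assumes a hypothetical BOS id of 50, one past the data tokens 1..49 that `SortingDataset` produces. `shift_right` is likewise a hypothetical helper name.

```python
import torch

def shift_right(tgt: torch.Tensor, bos_idx: int) -> torch.Tensor:
    """Decoder input for teacher forcing: [BOS, y_1, ..., y_{n-1}] for target [y_1, ..., y_n]."""
    bos = torch.full((tgt.size(0), 1), bos_idx, dtype=tgt.dtype, device=tgt.device)
    return torch.cat([bos, tgt[:, :-1]], dim=1)

target = torch.tensor([[3, 8, 15, 42]])  # the sorted output the model should produce
decoder_input = shift_right(target, bos_idx=50)
print(decoder_input.tolist())  # [[50, 3, 8, 15]]
```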

In [None]:
# Create smaller model for toy task.
# Special token ids: 0 is padding (the masks in Transformer assume this), and the
# decoder gets a dedicated start-of-sequence (BOS) token just past the data tokens
# 1..49 produced by SortingDataset.
PAD_IDX = 0
BOS_IDX = 50
model = Transformer(
    src_vocab_size=50,
    tgt_vocab_size=51,  # 50 ids (padding + data tokens) plus BOS
    d_model=128,
    n_heads=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    d_ff=512,
    max_seq_length=20,
    dropout=0.1
).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)

# Learning rate scheduler: the "Noam" schedule (linear warmup, then decay
# proportional to 1/sqrt(step))
def get_lr(step, d_model, warmup_steps):
    step = max(step, 1)
    return d_model ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5))

# LambdaLR multiplies the optimizer's base lr by the returned factor. Dividing by the
# schedule's value at the end of warmup makes the peak learning rate equal to the base
# lr (1e-3); multiplying the raw schedule by lr=1e-3 would cap it near 1e-5.
warmup_steps = 100
toy_d_model = model.encoder_embedding.embedding_dim
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda step: get_lr(step, toy_d_model, warmup_steps) / get_lr(warmup_steps, toy_d_model, warmup_steps)
)

In [None]:
def shift_right(tgt, bos_idx=BOS_IDX):
    """Decoder input for teacher forcing: [BOS, y_1, ..., y_{n-1}] for target [y_1, ..., y_n]."""
    bos = torch.full((tgt.size(0), 1), bos_idx, dtype=tgt.dtype, device=tgt.device)
    return torch.cat([bos, tgt[:, :-1]], dim=1)

# Training function
def train_epoch(model, data_loader, criterion, optimizer, scheduler):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_tokens = 0
    
    progress_bar = tqdm(data_loader, desc='Training')
    for src, tgt in progress_bar:
        src = src.to(device)
        tgt = tgt.to(device)
        
        # Teacher forcing: the decoder sees the ground truth shifted right by one, so
        # position i predicts tgt[:, i] from BOS, tgt[:, 0], ..., tgt[:, i-1].
        # Feeding tgt unshifted would let each position copy its own label.
        tgt_input = shift_right(tgt)
        
        # Forward pass
        output = model(src, tgt_input)
        
        # Calculate loss against the unshifted target
        loss = criterion(output.reshape(-1, output.size(-1)), tgt.reshape(-1))
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        
        # Token-level accuracy, ignoring padding positions
        non_pad = tgt != PAD_IDX
        predictions = output.argmax(dim=-1)
        correct = ((predictions == tgt) & non_pad).sum().item()
        tokens = non_pad.sum().item()
        
        total_loss += loss.item() * tokens
        total_correct += correct
        total_tokens += tokens
        
        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{correct / max(tokens, 1):.2%}'
        })
    
    return total_loss / total_tokens, total_correct / total_tokens

# Validation function (teacher-forced, like training)
@torch.no_grad()
def validate(model, data_loader, criterion):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_tokens = 0
    
    for src, tgt in data_loader:
        src = src.to(device)
        tgt = tgt.to(device)
        
        output = model(src, shift_right(tgt))
        loss = criterion(output.reshape(-1, output.size(-1)), tgt.reshape(-1))
        
        non_pad = tgt != PAD_IDX
        predictions = output.argmax(dim=-1)
        correct = ((predictions == tgt) & non_pad).sum().item()
        tokens = non_pad.sum().item()
        
        total_loss += loss.item() * tokens
        total_correct += correct
        total_tokens += tokens
    
    return total_loss / total_tokens, total_correct / total_tokens

In [None]:
# Training loop
num_epochs = 10
train_losses = []
train_accs = []
val_losses = []
val_accs = []

print("Starting training...")
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, scheduler)
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    
    # Validate
    val_loss, val_acc = validate(model, val_loader, criterion)
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2%}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2%}")

# Plot training history
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Loss plot
ax1.plot(train_losses, label='Train Loss')
ax1.plot(val_losses, label='Val Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy plot
ax2.plot(train_accs, label='Train Acc')
ax2.plot(val_accs, label='Val Acc')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Training and Validation Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 9. Inference and Visualization

In [None]:
@torch.no_grad()
def greedy_decode(model, src, max_length=20):
    """Greedy decoding for a single source sequence (batch size 1)."""
    model.eval()
    
    # Decoder buffer: position 0 holds BOS, generated tokens fill positions 1..max_length
    tgt = torch.full((1, max_length + 1), PAD_IDX, dtype=torch.long, device=device)
    tgt[0, 0] = BOS_IDX
    
    # Encode the source once
    src_mask = (src != PAD_IDX).unsqueeze(1).unsqueeze(2)
    src_embedded = model.encoder_embedding(src) * math.sqrt(model.encoder_embedding.embedding_dim)
    src_embedded = model.positional_encoding(src_embedded)
    
    enc_output = src_embedded
    for enc_layer in model.encoder_layers:
        enc_output = enc_layer(enc_output, src_mask)
    
    # Generate tokens one by one
    length = 0
    for i in range(max_length):
        # Causal mask for the i+1 tokens decoded so far: [1, 1, i+1, i+1]
        tgt_mask = ~torch.triu(torch.ones(i + 1, i + 1, device=device), diagonal=1).bool()
        tgt_mask = tgt_mask.unsqueeze(0).unsqueeze(0)
        
        # Decode
        tgt_embedded = model.decoder_embedding(tgt[:, :i + 1]) * math.sqrt(model.decoder_embedding.embedding_dim)
        tgt_embedded = model.positional_encoding(tgt_embedded)
        
        dec_output = tgt_embedded
        for dec_layer in model.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, tgt_mask, src_mask)
        
        # The last position's logits give the next token
        next_token = model.fc_out(dec_output[:, -1]).argmax(dim=-1).item()
        
        # Stop if we predict padding
        if next_token == PAD_IDX:
            break
        tgt[0, i + 1] = next_token
        length = i + 1
    
    return tgt[0, 1:length + 1]  # Return without the BOS token

# Test on some examples
print("Testing model predictions:")
print("=" * 50)

for i in range(5):
    src, tgt = val_dataset[i]
    src_tensor = src.unsqueeze(0).to(device)
    
    # Get prediction
    pred = greedy_decode(model, src_tensor, max_length=len(tgt))
    
    print(f"\nExample {i+1}:")
    print(f"Input:      {src.tolist()}")
    print(f"Target:     {tgt.tolist()}")
    print(f"Prediction: {pred.tolist()}")
    print(f"Correct:    {torch.equal(pred[:len(tgt)], tgt.to(device))}")
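The token-level accuracy reported during training can look high even when few sequences are fully correct, because a single misplaced element spoils a whole sort. The sketch below computes both metrics over lists of 1-D tensors, such as the greedy predictions and targets above. `sequence_accuracy` and `token_accuracy` are hypothetical helper names, and the toy inputs are made up for illustration.

```python
import torch

def sequence_accuracy(predictions, targets) -> float:
    """Fraction of sequences whose prediction matches the target exactly (length included)."""
    matches = [p.shape == t.shape and torch.equal(p, t) for p, t in zip(predictions, targets)]
    return sum(matches) / len(matches)

def token_accuracy(predictions, targets) -> float:
    """Fraction of target positions predicted correctly; missing positions count as wrong."""
    correct = sum((p[:len(t)] == t[:len(p)]).sum().item() for p, t in zip(predictions, targets))
    total = sum(len(t) for t in targets)
    return correct / total

preds = [torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 4, 3]), torch.tensor([5, 6, 7])]
targets = [torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 3, 4]), torch.tensor([5, 6, 7, 8])]
print(token_accuracy(preds, targets))     # 0.75
print(sequence_accuracy(preds, targets))  # 0.333... (only the first sequence is fully right)
```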

## 10. Attention Visualization
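Forward hooks let you read intermediate outputs without changing a module's code. A hook registered with `register_forward_hook` is called as `hook(module, inputs, output)` after each forward pass. Calling `.remove()` on the returned handle detaches it. The sketch below shows the mechanism on PyTorch's own `nn.MultiheadAttention`, which returns `(attn_output, attn_weights)`. By default, those weights are averaged over heads, with shape `[batch, query_len, key_len]`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
captured = []

def save_weights(module, inputs, output):
    # output is the (attn_output, attn_weights) tuple returned by forward()
    captured.append(output[1].detach())

handle = attn.register_forward_hook(save_weights)
x = torch.randn(2, 5, 32)
attn(x, x, x)
handle.remove()  # later calls are no longer recorded
attn(x, x, x)

print(len(captured), captured[0].shape)  # 1 torch.Size([2, 5, 5])
```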

In [None]:
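# Capture attention weights from the model with forward hooks
def get_attention_weights(model, src, tgt):
    """Collect attention weights from every encoder and decoder layer.

    MultiHeadAttention.forward returns (output, attention_weights), so a forward
    hook receives that tuple as `output` and can store output[1]. Each list holds
    one [batch, n_heads, query_len, key_len] tensor per layer.
    """
    model.eval()
    attention_weights = {'encoder': [], 'decoder': [], 'cross': []}
    
    def make_hook(name):
        def hook(module, inputs, output):
            attention_weights[name].append(output[1].detach().cpu())
        return hook
    
    # Register hooks on every attention module
    handles = []
    for layer in model.encoder_layers:
        handles.append(layer.self_attention.register_forward_hook(make_hook('encoder')))
    for layer in model.decoder_layers:
        handles.append(layer.self_attention.register_forward_hook(make_hook('decoder')))
        handles.append(layer.cross_attention.register_forward_hook(make_hook('cross')))
    
    try:
        with torch.no_grad():
            model(src, tgt)
    finally:
        # Remove hooks so later forward passes are not recorded
        for handle in handles:
            handle.remove()
    
    return attention_weights

# Plot the last layer's first head for one validation example.
# For illustration, the ground-truth target is fed to the decoder directly.
src, tgt = val_dataset[0]
weights = get_attention_weights(model, src.unsqueeze(0).to(device), tgt.unsqueeze(0).to(device))

titles = {'encoder': 'Encoder self-attention',
          'decoder': 'Decoder masked self-attention',
          'cross': 'Encoder-decoder cross-attention'}
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
for ax, name in zip(axes, ['encoder', 'decoder', 'cross']):
    sns.heatmap(weights[name][-1][0, 0].numpy(), cmap='Blues', ax=ax)
    ax.set_title(f'{titles[name]} (last layer, head 0)')
    ax.set_xlabel('Key position')
    ax.set_ylabel('Query position')
plt.tight_layout()
plt.show()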

## Summary

In this notebook, we've built a complete transformer from scratch:

1. **Scaled Dot-Product Attention**: The core attention mechanism
2. **Multi-Head Attention**: Parallel attention with different representations
3. **Feed-Forward Networks**: Position-wise transformations
4. **Positional Encoding**: Adding position information
5. **Encoder/Decoder Layers**: Complete transformer blocks
6. **Full Model**: Putting everything together
7. **Training**: Trained the model on a toy sorting task with teacher forcing
8. **Inference**: Generated sequences using the trained model

Key takeaways:
- Transformers are built from simple, composable blocks
- Attention is the key mechanism enabling long-range dependencies
- Residual connections and layer normalization are crucial for training
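- There is no recurrence, so during training all positions of a sequence are processed in parallel, which makes the architecture highly parallelizable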

Next, we'll explore tokenization and embeddings in detail!