# Transformer Variants

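This notebook explores the main transformer architecture families: BERT, GPT, T5, and Vision Transformers (ViT). For each variant we implement its defining component and compare their characteristics:
- BERT: embeddings and masked language modeling.
- GPT: causal self-attention and sampling.
- T5: relative position bias.
- ViT: patch embedding and a small end-to-end classifier.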

## Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Tuple, Dict, List
import math
from dataclasses import dataclass

# Set style. The palette call comes after the style sheet so the "husl" colour
# cycle is not replaced by the style sheet's own cycle.
# NOTE(review): 'seaborn-v0_8-*' is the newer matplotlib spelling of the seaborn
# styles — confirm it exists in the installed matplotlib version.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Set random seeds so random tensors, MLM masks and initial weights in the
# cells below are reproducible on Restart & Run All.
torch.manual_seed(42)
np.random.seed(42)

# NOTE(review): `device` is only reported here; the cells below create their
# tensors and modules on the default device and never move them to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Understanding the Three Paradigms

Let's visualize the three main transformer paradigms: Encoder-only, Decoder-only, and Encoder-Decoder.

In [None]:
# Visualize the three paradigms: one schematic panel per architecture family.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))


def _draw_stack_panel(panel, title, input_text, block_text, block_color,
                      arrow_color, output_text, use_text):
    """Draw a single-stack schematic: input -> attention block -> output -> typical uses."""
    panel.text(0.5, 0.9, title, ha='center', fontsize=14, weight='bold')
    panel.text(0.5, 0.7, input_text, ha='center', fontsize=10)
    panel.arrow(0.5, 0.65, 0, -0.1, width=0.02, head_width=0.05, fc=arrow_color, ec=arrow_color)
    panel.text(0.5, 0.5, block_text, ha='center', fontsize=10,
               bbox=dict(boxstyle="round,pad=0.3", facecolor=block_color))
    panel.arrow(0.5, 0.35, 0, -0.1, width=0.02, head_width=0.05, fc=arrow_color, ec=arrow_color)
    panel.text(0.5, 0.2, output_text, ha='center', fontsize=10)
    panel.text(0.5, 0.05, use_text, ha='center', fontsize=9, style='italic')


# Encoder-only (BERT): a single bidirectional stack
_draw_stack_panel(axes[0], 'Encoder-only (BERT)', 'Input: [CLS] The cat sat [SEP]',
                  'Bidirectional\nSelf-Attention', "lightblue", 'blue',
                  'Output: Contextualized\nRepresentations', 'Use: Classification,\nNER, QA')

# Decoder-only (GPT): a single causal stack
_draw_stack_panel(axes[1], 'Decoder-only (GPT)', 'Input: The cat',
                  'Causal\nSelf-Attention', "lightgreen", 'green',
                  'Output: Next token\npredictions', 'Use: Generation,\nCompletion')

# Encoder-Decoder (T5): two stacks joined by cross-attention
ax = axes[2]
ax.text(0.5, 0.9, 'Encoder-Decoder (T5)', ha='center', fontsize=14, weight='bold')
ax.text(0.25, 0.7, 'Input:\nTranslate:', ha='center', fontsize=9)
ax.text(0.75, 0.7, 'Target:\nTraduire:', ha='center', fontsize=9)
ax.arrow(0.25, 0.6, 0, -0.08, width=0.015, head_width=0.04, fc='red', ec='red')
ax.arrow(0.75, 0.6, 0, -0.08, width=0.015, head_width=0.04, fc='orange', ec='orange')
ax.text(0.25, 0.45, 'Encoder', ha='center', fontsize=10,
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral"))
ax.text(0.75, 0.45, 'Decoder', ha='center', fontsize=10,
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightsalmon"))
ax.arrow(0.35, 0.45, 0.3, 0, width=0.01, head_width=0.03, fc='purple', ec='purple')
ax.text(0.5, 0.48, 'Cross-Attention', ha='center', fontsize=8)
ax.arrow(0.25, 0.35, 0, -0.08, width=0.015, head_width=0.04, fc='red', ec='red')
ax.arrow(0.75, 0.35, 0, -0.08, width=0.015, head_width=0.04, fc='orange', ec='orange')
ax.text(0.5, 0.2, 'Output: Sequence-to-sequence', ha='center', fontsize=10)
ax.text(0.5, 0.05, 'Use: Translation,\nSummarization', ha='center', fontsize=9, style='italic')

# Every panel is a bare unit-square canvas
for panel in axes:
    panel.set_xlim(0, 1)
    panel.set_ylim(0, 1)
    panel.axis('off')

plt.tight_layout()
plt.show()

## 2. BERT Implementation

Let's implement BERT (Bidirectional Encoder Representations from Transformers) step by step.

In [None]:
@dataclass
class BERTConfig:
    """Hyper-parameters for the BERT components in this notebook (BERT-Base sizes by default)."""
    vocab_size: int = 30522
    hidden_size: int = 768
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    intermediate_size: int = 3072
    hidden_dropout_prob: float = 0.1
    attention_probs_dropout_prob: float = 0.1
    max_position_embeddings: int = 512
    type_vocab_size: int = 2  # For segment embeddings
    layer_norm_eps: float = 1e-12  # BERT's LayerNorm epsilon (PyTorch's default is 1e-5)

class BERTEmbeddings(nn.Module):
    """BERT embeddings: token + position + segment, followed by LayerNorm and dropout.

    forward(input_ids, token_type_ids=None):
        input_ids: LongTensor (batch, seq_len) with seq_len <= max_position_embeddings.
        token_type_ids: optional LongTensor like input_ids (segment A = 0, B = 1);
            defaults to all zeros.
        Returns a float tensor (batch, seq_len, hidden_size).
    Raises ValueError if the sequence is longer than the position table.
    """
    
    def __init__(self, config: BERTConfig):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        
    def forward(self, input_ids, token_type_ids=None):
        seq_length = input_ids.size(1)
        max_positions = self.position_embeddings.num_embeddings
        if seq_length > max_positions:
            # Without this check nn.Embedding fails with an opaque "index out of range".
            raise ValueError(
                f"sequence length {seq_length} exceeds max_position_embeddings={max_positions}"
            )
        # Absolute positions 0..seq_len-1, shared by every sequence in the batch.
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
            
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        
        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        
        return embeddings

# Test BERT embeddings.
# The example ids go up to 7592, so the embedding table must keep the full
# 30522-entry vocabulary: with vocab_size=1000, nn.Embedding raises
# "index out of range". num_attention_heads=4 keeps hidden_size divisible by heads.
config = BERTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
embeddings = BERTEmbeddings(config)

# Example: two sentences ([CLS] sentence A [SEP] sentence B [SEP])
input_ids = torch.tensor([[101, 7592, 1010, 2045, 102, 2129, 2024, 2017, 102]])
token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1]])
assert input_ids.max().item() < config.vocab_size, "token id outside the embedding table"

embedded = embeddings(input_ids, token_type_ids)
print(f"Input shape: {input_ids.shape}")
print(f"Embedded shape: {embedded.shape}")
print(f"\nFirst 5 dimensions of first token embedding: {embedded[0, 0, :5]}")

### BERT's Masked Language Modeling (MLM)

In [None]:
def create_mlm_masks(input_ids, vocab_size, mask_prob=0.15, mask_token_id=103,
                     special_token_ids=(101, 102, 0)):
    """
    Create masked inputs and labels for BERT-style MLM training.

    Each non-special position is selected with probability ``mask_prob``.
    Of the selected positions:
    - 80% of the time: Replace with [MASK]
    - 10% of the time: Replace with random token
    - 10% of the time: Keep original

    Args:
        input_ids: LongTensor of token ids. It is NOT modified; a copy is masked.
        vocab_size: random replacement ids are drawn uniformly from [0, vocab_size).
        mask_prob: probability that a non-special position is selected.
        mask_token_id: id written at positions replaced by [MASK].
        special_token_ids: ids that are never selected (defaults assume
            [CLS]=101, [SEP]=102, [PAD]=0).

    Returns:
        (masked_input_ids, labels, masked_indices). ``labels`` holds the original
        id at selected positions and -100 elsewhere, so only selected positions
        contribute to a cross-entropy loss with ignore_index=-100.
    """
    # Work on a copy so the caller's tensor is left untouched.
    input_ids = input_ids.clone()
    labels = input_ids.clone()
    device = input_ids.device
    
    # Select candidate positions
    probability_matrix = torch.full(labels.shape, mask_prob, device=device)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    
    # Never select special tokens
    special_tokens_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for token_id in special_token_ids:
        special_tokens_mask |= input_ids == token_id
    masked_indices &= ~special_tokens_mask
    
    # 80% of selected positions -> [MASK]
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8, device=device)).bool() & masked_indices
    input_ids[indices_replaced] = mask_token_id
    
    # Half of the remaining 20% -> random token (10% overall)
    indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5, device=device)).bool()
                      & masked_indices & ~indices_replaced)
    random_words = torch.randint(vocab_size, labels.shape, dtype=torch.long, device=device)
    input_ids[indices_random] = random_words[indices_random]
    
    # The last 10% keep their original token (nothing to do)
    
    # Only compute loss on selected tokens
    labels[~masked_indices] = -100
    
    return input_ids, labels, masked_indices

# Demonstrate MLM
original_text = "The quick brown fox jumps over the lazy dog"
tokens = original_text.split()
# dict.fromkeys keeps first-seen order, so the ids are the same on every run
# (iteration order of a set of strings depends on PYTHONHASHSEED).
vocab = {word: i + 104 for i, word in enumerate(dict.fromkeys(tokens))}  # Start from 104 to avoid special tokens
vocab.update({'[CLS]': 101, '[SEP]': 102, '[MASK]': 103, '[PAD]': 0})
id_to_token = {idx: tok for tok, idx in vocab.items()}

# Random replacements are drawn from [0, vocab_size), so vocab_size has to cover
# the largest id in use; len(vocab) (13 here) would only ever draw ids below
# every word id.
vocab_size = max(vocab.values()) + 1

# Convert to IDs
input_ids = torch.tensor([[vocab['[CLS]']] + [vocab[word] for word in tokens] + [vocab['[SEP]']]])
masked_input, labels, mask_indices = create_mlm_masks(input_ids.clone(), vocab_size)

print("Original tokens:", ['[CLS]'] + tokens + ['[SEP]'])
print("\nMasked input IDs:", masked_input[0].tolist())
print("Masked input tokens:", [id_to_token.get(i, f'<id {i}>') for i in masked_input[0].tolist()])
print("\nLabels (-100 means ignore):", labels[0].tolist())
print("\nMasked positions:", mask_indices[0].tolist())

## 3. GPT Implementation

Now let's implement GPT (Generative Pre-trained Transformer) with causal attention.

In [None]:
class CausalSelfAttention(nn.Module):
    """GPT-style causal (masked) multi-head self-attention.

    Args:
        n_embd: model width C; must be divisible by ``n_head``.
        n_head: number of attention heads.
        n_positions: longest sequence the precomputed causal mask supports.
        attn_pdrop: dropout probability applied to the attention weights.

    ``forward(x)`` takes x of shape (B, T, C) with T <= n_positions and returns
    ``(y, att)``: the projected output (B, T, C) and the attention weights
    (B, n_head, T, T). ``att`` is taken AFTER dropout, so it only equals the
    softmax probabilities in eval mode.
    """
    
    def __init__(self, n_embd, n_head, n_positions=1024, attn_pdrop=0.1):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        
        self.n_head = n_head
        self.n_embd = n_embd
        self.n_positions = n_positions
        
        # Query, Key, Value projections fused into a single linear layer
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.c_proj = nn.Linear(n_embd, n_embd)
        
        self.attn_dropout = nn.Dropout(attn_pdrop)
        
        # Lower-triangular causal mask: query i may attend to keys 0..i
        self.register_buffer("bias", torch.tril(torch.ones(n_positions, n_positions))
                                     .view(1, 1, n_positions, n_positions))
        
    def forward(self, x):
        B, T, C = x.size()
        if T > self.n_positions:
            # Slicing the mask would silently stop at n_positions and then fail
            # with an opaque broadcast error in masked_fill.
            raise ValueError(f"sequence length {T} exceeds n_positions={self.n_positions}")
        head_dim = C // self.n_head
        
        # Calculate query, key, values for all heads: (B, n_head, T, head_dim)
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)
        
        # Scaled dot-product scores; future keys are set to -inf so softmax gives them 0
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        
        # Merge heads back to (B, T, C)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        
        return y, att

# Demonstrate causal attention
seq_len = 8
n_embd = 64
n_head = 4

causal_attn = CausalSelfAttention(n_embd, n_head, seq_len)
# Modules start in training mode, where the returned attention weights have
# dropout applied (random zeros, survivors rescaled). Eval mode makes the
# heatmap below show the real softmax attention pattern.
causal_attn.eval()
x = torch.randn(1, seq_len, n_embd)
with torch.no_grad():
    output, attention = causal_attn(x)

# Visualize causal mask: 1 = may attend, 0 = masked out
plt.figure(figsize=(8, 6))
mask = causal_attn.bias[0, 0, :seq_len, :seq_len]
plt.imshow(mask, cmap='Blues', aspect='auto')
plt.colorbar()
plt.title('Causal Attention Mask')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
for i in range(seq_len):
    for j in range(seq_len):
        plt.text(j, i, int(mask[i, j].item()), ha='center', va='center')
plt.show()

# Show actual attention pattern (each row sums to 1; the upper triangle is exactly 0)
plt.figure(figsize=(8, 6))
plt.imshow(attention[0, 0].detach().numpy(), cmap='hot', aspect='auto')
plt.colorbar()
plt.title('Causal Attention Pattern (Head 1)')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.show()

### GPT Text Generation

In [None]:
def generate_text(model, input_ids, max_new_tokens=20, temperature=1.0, top_k=None):
    """
    Generate tokens with GPT-style autoregressive sampling.

    Args:
        model: module mapping token ids (B, T) to logits (B, T, vocab). It is
            switched to eval mode and left in eval mode.
        input_ids: LongTensor (B, T) holding the prompt.
        max_new_tokens: number of tokens to append.
        temperature: logits are divided by this before sampling. ``temperature <= 0``
            means greedy decoding (argmax) instead of a division by zero.
        top_k: if given, sample only among the ``top_k`` highest-scoring tokens
            (clamped to the vocabulary size).

    Returns:
        LongTensor (B, T + max_new_tokens): the prompt followed by generated tokens.
    """
    model.eval()
    
    for _ in range(max_new_tokens):
        # Forward pass over the whole sequence so far
        with torch.no_grad():
            logits = model(input_ids)
            
        # Only the last position predicts the next token
        logits = logits[:, -1, :]
        
        if temperature <= 0:
            # Greedy decoding: the temperature -> 0 limit of sampling
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / temperature
            
            # Optional top-k filtering; torch.topk raises if k > vocab size
            if top_k is not None:
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                logits = logits.masked_fill(logits < v[:, [-1]], -float('Inf'))
                
            # Sample from the distribution
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        
        # Append to sequence
        input_ids = torch.cat((input_ids, next_token), dim=1)
        
    return input_ids

# Walk through the autoregressive sampling loop in words
generation_steps = (
    "GPT Generation Process:",
    "1. Start with prompt: 'The cat'",
    "2. Model predicts next token probabilities",
    "3. Sample from distribution (e.g., 'sat' with p=0.3)",
    "4. Append to sequence: 'The cat sat'",
    "5. Repeat until desired length",
)
for step in generation_steps:
    print(step)

# Illustrative next-token distribution over a toy vocabulary (sums to 1)
vocab_example = ['the', 'cat', 'sat', 'on', 'mat', 'jumped', 'ran', 'slept']
probs_example = [0.05, 0.02, 0.3, 0.2, 0.15, 0.1, 0.08, 0.1]

plt.figure(figsize=(10, 6))
plt.bar(vocab_example, probs_example)
plt.title('Next Token Probabilities (Example)')
plt.xlabel('Tokens')
plt.ylabel('Probability')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

## 4. T5 Implementation

T5 (Text-to-Text Transfer Transformer) treats every NLP task as text generation.

In [None]:
class T5RelativePositionBias(nn.Module):
    """T5-style relative position bias (used instead of absolute position embeddings).

    Each relative offset ``key_pos - query_pos`` is mapped to one of ``num_buckets``
    buckets. Small distances get exact buckets; larger ones share logarithmically
    sized buckets, with everything beyond ``max_distance`` in the last bucket.
    Each bucket holds a learned scalar per head.

    Args:
        num_buckets: number of distance buckets.
        max_distance: distance from which all offsets share the last bucket.
        n_heads: number of attention heads (one bias value per head and bucket).
        bidirectional: True (default) splits the buckets between keys before and
            after the query, which suits encoder self-attention. False uses all
            buckets for keys at or before the query and sends every future key to
            bucket 0, which suits causal decoder self-attention.
    """
    
    def __init__(self, num_buckets=32, max_distance=128, n_heads=8, bidirectional=True):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.n_heads = n_heads
        self.bidirectional = bidirectional
        self.relative_attention_bias = nn.Embedding(num_buckets, n_heads)
        
    def _relative_position_bucket(self, relative_position):
        """Map a LongTensor of ``key_pos - query_pos`` offsets to bucket ids in [0, num_buckets)."""
        ret = 0
        n = -relative_position  # positive when the key lies before the query
        
        num_buckets = self.num_buckets
        max_distance = self.max_distance
        
        if self.bidirectional:
            # Half of the buckets are for keys after the query (n < 0)
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            # Causal: all future keys collapse onto distance 0
            n = torch.clamp(n, min=0)
        
        # Exact buckets for small distances
        max_exact = num_buckets // 2
        is_small = n < max_exact
        
        # Logarithmic buckets for large distances (log(0) for small n is discarded by torch.where)
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
        
        ret += torch.where(is_small, n, val_if_large)
        return ret
    
    def forward(self, query_length, key_length):
        """Return the bias to add to attention logits, shape (1, n_heads, query_length, key_length)."""
        # Build positions on the same device as the bias table so the module also works on GPU.
        device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_position_bucket(relative_position)
        
        values = self.relative_attention_bias(relative_position_bucket)  # (q, k, n_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)
        
        return values

# Demonstrate T5 relative position bias on an 8x8 grid of query/key positions
rel_pos_bias = T5RelativePositionBias()
bias = rel_pos_bias(8, 8)  # (1, n_heads, 8, 8)

# Only the first four heads are plotted
plt.figure(figsize=(10, 8))
for head in range(4):
    plt.subplot(2, 2, head + 1)
    head_bias = bias[0, head].detach().numpy()
    plt.imshow(head_bias, cmap='coolwarm', aspect='auto')
    plt.colorbar()
    plt.title(f'Relative Position Bias (Head {head + 1})')
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
plt.tight_layout()
plt.show()

# T5 task examples: every task is phrased as "prefix: input" -> target text
print("\nT5 Text-to-Text Examples:")
t5_examples = [
    ("translate English to German: The house is wonderful.", "Das Haus ist wunderbar."),
    ("summarize: <long article text>", "<summary>"),
    ("sentiment: This movie is terrible.", "negative"),
    ("question: What is the capital? context: Paris is the capital of France.", "Paris")
]

for source_text, target_text in t5_examples:
    print(f"\nInput:  {source_text}")
    print(f"Output: {target_text}")

## 5. Vision Transformer (ViT)

Let's implement Vision Transformer which applies transformers to image patches.

In [None]:
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping patches and linearly embed each one.

    A Conv2d whose kernel size equals its stride applies the same linear map to
    every patch independently. Input (B, C, H, W) -> output
    (B, (H // patch_size) * (W // patch_size), embed_dim), with patches in
    row-major order.
    """
    
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side * patches_per_side
        
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        
    def forward(self, x):
        # (B, embed_dim, H/P, W/P) -> (B, embed_dim, N) -> (B, N, embed_dim)
        return self.proj(x).flatten(start_dim=2).transpose(1, 2)

# Visualize patch extraction
def visualize_patches(img_size=224, patch_size=16):
    """Draw a toy image split into patches next to the first patches laid out as a sequence.

    Args:
        img_size: side length of the square toy image, in pixels.
        patch_size: side length of each square patch, in pixels.

    Returns:
        Total number of patches, ``(img_size // patch_size) ** 2``.
    """
    n_patches_per_dim = img_size // patch_size
    total_patches = n_patches_per_dim * n_patches_per_dim
    max_shown = 10  # the sequence strip has room for 10 boxes
    
    # Create a sample image with grid pattern
    img = np.zeros((img_size, img_size, 3))
    
    # Red 2-pixel grid lines on each patch's top/left edge, plus the outer
    # bottom/right border that the stepped range() never reaches.
    for i in range(0, img_size, patch_size):
        img[i:i+2, :] = [1, 0, 0]
        img[:, i:i+2] = [1, 0, 0]
    img[-2:, :] = [1, 0, 0]
    img[:, -2:] = [1, 0, 0]
    
    # One colour per patch, dimmed so the red grid stands out. The same dimmed
    # colours are reused in the sequence strip so patch k looks identical in both panels.
    patch_colors = plt.cm.hsv(np.linspace(0, 1, total_patches))[:, :3] * 0.5
    
    for i in range(n_patches_per_dim):
        for j in range(n_patches_per_dim):
            patch_idx = i * n_patches_per_dim + j
            y_start = i * patch_size + 2
            y_end = (i + 1) * patch_size - 2
            x_start = j * patch_size + 2
            x_end = (j + 1) * patch_size - 2
            img[y_start:y_end, x_start:x_end] = patch_colors[patch_idx]
    
    # Plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    # Original image
    ax1.imshow(img)
    ax1.set_title(f'Image divided into {n_patches_per_dim}x{n_patches_per_dim} patches')
    ax1.axis('off')
    
    # Patch sequence
    ax2.text(0.5, 0.9, 'Patches as Sequence', ha='center', fontsize=14, weight='bold', transform=ax2.transAxes)
    
    n_shown = min(total_patches, max_shown)
    for i in range(n_shown):
        x = i * 0.08 + 0.1
        rect = plt.Rectangle((x, 0.4), 0.06, 0.2, 
                           facecolor=patch_colors[i], edgecolor='black', linewidth=2)
        ax2.add_patch(rect)
        ax2.text(x + 0.03, 0.3, str(i), ha='center', fontsize=10)
    
    # Only hint at more patches when some were actually left out
    if total_patches > n_shown:
        ax2.text(0.92, 0.5, '...', ha='center', fontsize=16, transform=ax2.transAxes)
    ax2.set_xlim(0, 1)
    ax2.set_ylim(0, 1)
    ax2.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    return total_patches

# Visualize a 224x224 image cut into 16x16 patches (the ViT-Base/16 setup)
vis_img_size, vis_patch_size = 224, 16
total_patches = visualize_patches(vis_img_size, vis_patch_size)
print(f"\nTotal patches: {total_patches}")
print(f"Each patch: {vis_patch_size}x{vis_patch_size} pixels")
print(f"Sequence length: {total_patches} + 1 [CLS] token = {total_patches + 1}")

### Full ViT Model with Classification Head

In [None]:
class VisionTransformer(nn.Module):
    """Simplified Vision Transformer (ViT) for image classification.

    Pipeline: patch embedding -> prepend a learnable [CLS] token -> add learned
    position embeddings -> pre-norm Transformer encoder -> LayerNorm + linear
    head on the [CLS] output.

    Args:
        img_size: expected square input resolution. The position table has
            ``(img_size // patch_size) ** 2 + 1`` rows, so inputs must match it.
        patch_size, in_channels, embed_dim: forwarded to PatchEmbedding.
        num_classes: size of the output logits.
        depth: number of encoder layers.
        num_heads: attention heads per layer (must divide embed_dim).

    forward(x): x of shape (B, in_channels, img_size, img_size) -> logits (B, num_classes).
    Raises ValueError if the spatial size differs from ``img_size``.
    """
    
    def __init__(self, img_size=224, patch_size=16, in_channels=3, 
                 num_classes=1000, embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.n_patches
        
        # Learnable [CLS] token and position embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + num_patches, embed_dim))
        # Small random init (std 0.02, as is common in ViT implementations): an
        # all-zero table would start training with no positional signal at all.
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        
        # Transformer encoder. ViT uses pre-norm blocks (LayerNorm before
        # attention/MLP), which is also why the final self.norm below is needed.
        # batch_first keeps tensors as (batch, seq, dim) throughout.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, 
            dim_feedforward=embed_dim * 4, activation='gelu',
            batch_first=True, norm_first=True,
        )
        # Nested-tensor fast path is unsupported with norm_first; disable it explicitly
        # instead of triggering a warning on construction.
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth,
                                                 enable_nested_tensor=False)
        
        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        
    def forward(self, x):
        B = x.shape[0]
        expected = self.patch_embed.img_size
        if tuple(x.shape[-2:]) != (expected, expected):
            # Otherwise the patch count would not match pos_embed and fail on the addition.
            raise ValueError(
                f"expected images of size {expected}x{expected}, got {tuple(x.shape[-2:])}"
            )
        
        # Patch embedding: (B, N, D)
        x = self.patch_embed(x)
        
        # Add [CLS] token: (B, N + 1, D)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        # Add positional embedding
        x = x + self.pos_embed
        
        # Transformer encoding (batch-first)
        x = self.transformer(x)
        
        # Classification: use [CLS] token
        x = self.norm(x[:, 0])
        x = self.head(x)
        
        return x

# Create a mini ViT: 32x32 inputs cut into 8x8 patches -> 16 patch tokens + [CLS]
vit = VisionTransformer(
    img_size=32, patch_size=8, num_classes=10,
    embed_dim=128, depth=4, num_heads=4,
)

# Run a random batch of two images through it
img = torch.randn(2, 3, 32, 32)
output = vit(img)
n_params = sum(param.numel() for param in vit.parameters())
print(f"Input shape: {img.shape}")
print(f"Output shape: {output.shape}")
print(f"\nModel parameters: {n_params:,}")

## 6. Comparing Architectures

Let's compare the different transformer variants side by side.

In [None]:
# Architecture comparison table
comparison_data = {
    'Architecture': ['BERT', 'GPT', 'T5', 'ViT'],
    'Type': ['Encoder-only', 'Decoder-only', 'Encoder-Decoder', 'Encoder-only'],
    'Attention': ['Bidirectional', 'Causal', 'Bidirectional + Causal', 'Bidirectional'],
    'Position': ['Learned', 'Learned', 'Relative bias', 'Learned'],
    'Pre-training': ['MLM + NSP', 'Next token', 'Span corruption', 'Supervised'],
    'Best for': ['Understanding', 'Generation', 'Any seq2seq', 'Vision'],
}

import pandas as pd
df = pd.DataFrame(comparison_data)

# Style the dataframe: centred cells and a slightly larger header font
cell_props = {'text-align': 'center', 'font-size': '12px'}
header_styles = [
    {'selector': 'th', 'props': [('font-size', '14px'), ('text-align', 'center')]}
]
styled_df = df.style.set_properties(**cell_props).set_table_styles(header_styles)

print("Transformer Architecture Comparison:")
display(styled_df)

# Visualize attention patterns: all-ones = every query sees every key,
# lower triangle = each query sees only itself and earlier keys.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
full_pattern = np.ones((8, 8))
causal_pattern = np.tril(np.ones((8, 8)))
panel_specs = [
    (full_pattern, 'Blues', 'BERT: Bidirectional Attention'),
    (causal_pattern, 'Greens', 'GPT: Causal Attention'),
    (full_pattern, 'Reds', 'T5 Encoder: Bidirectional'),
    (causal_pattern, 'Oranges', 'T5 Decoder: Causal Self-Attention'),
]
for ax, (pattern, cmap_name, panel_title) in zip(axes.flat, panel_specs):
    ax.imshow(pattern, cmap=cmap_name, alpha=0.8)
    ax.set_title(panel_title, fontsize=14)
    ax.set_xlabel('Keys')
    ax.set_ylabel('Queries')

plt.tight_layout()
plt.show()

## 7. Choosing the Right Architecture

Let's create a decision tree for selecting the appropriate transformer variant.

In [None]:
def recommend_architecture(task, requirements):
    """Suggest transformer variants for a task category plus optional requirement flags.

    Args:
        task: 'classification', 'generation', 'translation' or
            'image_classification'; any other value yields no task-based entry.
        requirements: container checked (with ``in``) for 'bidirectional',
            'autoregressive' and 'efficiency'.

    Returns:
        List of (architecture, reason) tuples: the task-based suggestion first
        (if any), then one entry per matched requirement, in the order above.
    """
    task_table = (
        ('classification', ('BERT', 'Best for understanding context')),
        ('generation', ('GPT', 'Designed for text generation')),
        ('translation', ('T5', 'Excellent for seq2seq tasks')),
        ('image_classification', ('ViT', 'State-of-the-art for vision')),
    )
    requirement_table = (
        ('bidirectional', ('BERT/T5-Encoder', 'Full context awareness')),
        ('autoregressive', ('GPT/T5-Decoder', 'Sequential generation')),
        ('efficiency', ('DistilBERT/ALBERT', 'Compressed models')),
    )

    suggestions = []
    # At most one task-based suggestion: the first matching category wins
    for category, suggestion in task_table:
        if task == category:
            suggestions.append(suggestion)
            break

    # Requirement flags each add their own suggestion independently
    suggestions.extend(
        suggestion for flag, suggestion in requirement_table if flag in requirements
    )
    return suggestions

# Map each example task to the category recommend_architecture understands
tasks = {
    'Text Classification': 'classification',
    'Named Entity Recognition': 'classification',
    'Question Answering': 'classification',
    'Text Generation': 'generation',
    'Story Completion': 'generation',
    'Translation': 'translation',
    'Summarization': 'translation',
    'Image Classification': 'image_classification'
}

print("Task-Architecture Recommendations:\n")
for task_name, task_type in tasks.items():
    report_lines = [f"{task_name}:"]
    report_lines.extend(
        f"  → {arch}: {reason}" for arch, reason in recommend_architecture(task_type, [])
    )
    print("\n".join(report_lines))
    print()

## 8. Performance and Efficiency Trade-offs

In [None]:
# Model size and performance comparison
model_stats = {
    'Model': ['BERT-Base', 'BERT-Large', 'GPT-2', 'GPT-3', 'T5-Small', 'T5-Large', 'ViT-Base', 'ViT-Large'],
    'Parameters': ['110M', '340M', '1.5B', '175B', '60M', '770M', '86M', '307M'],
    'Layers': [12, 24, 48, 96, 6, 24, 12, 24],
    'Hidden Size': [768, 1024, 1600, 12288, 512, 1024, 768, 1024],
    'Attention Heads': [12, 16, 25, 96, 8, 16, 12, 16],
}

df_models = pd.DataFrame(model_stats)


def _params_in_millions(label: str) -> float:
    """Convert a size label such as '110M' or '1.5B' into millions of parameters."""
    multiplier = {'M': 1.0, 'B': 1000.0}[label[-1]]
    return float(label[:-1]) * multiplier


# Create subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Parameters comparison: bar heights are derived from the 'Parameters' column so
# the chart cannot drift out of sync with the table.
params_numeric = [_params_in_millions(label) for label in df_models['Parameters']]
colors = ['blue', 'blue', 'green', 'green', 'red', 'red', 'purple', 'purple']  # one colour per family

ax1.bar(range(len(df_models)), params_numeric, color=colors, alpha=0.7)
ax1.set_yscale('log')
ax1.set_xticks(range(len(df_models)))
ax1.set_xticklabels(df_models['Model'], rotation=45, ha='right')
ax1.set_ylabel('Parameters (Millions)')
ax1.set_title('Model Size Comparison (Log Scale)')
ax1.grid(True, alpha=0.3)

# Architecture comparison
architectures = ['BERT', 'GPT', 'T5', 'ViT']

x = np.arange(len(architectures))
width = 0.25

# Simple comparison metrics (illustrative only, not measured)
understanding = [0.9, 0.6, 0.8, 0.0]
generation = [0.3, 0.95, 0.85, 0.0]
efficiency = [0.7, 0.8, 0.6, 0.85]

ax2.bar(x - width, understanding, width, label='Understanding', alpha=0.8)
ax2.bar(x, generation, width, label='Generation', alpha=0.8)
ax2.bar(x + width, efficiency, width, label='Efficiency', alpha=0.8)

ax2.set_xlabel('Architecture')
ax2.set_ylabel('Relative Score')
ax2.set_title('Architecture Capabilities')
ax2.set_xticks(x)
ax2.set_xticklabels(architectures)
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nKey Insights:")
print("• BERT excels at understanding tasks but cannot generate text autoregressively")
print("• GPT is optimized for generation but lacks bidirectional context")
print("• T5 is versatile but requires more parameters for similar performance")
print("• ViT shows that transformers can excel beyond NLP")

## Summary and Best Practices

### Key Takeaways:

1. **BERT (Encoder-only)**:
   - Bidirectional attention sees full context
   - Best for: Classification, NER, Question Answering
   - Cannot generate text naturally

2. **GPT (Decoder-only)**:
   - Causal attention for autoregressive generation
   - Best for: Text generation, completion, few-shot learning
   - Limited understanding without bidirectional context

3. **T5 (Encoder-Decoder)**:
   - Flexible text-to-text framework
   - Best for: Translation, summarization, any seq2seq task
   - More parameters but very versatile

4. **ViT (Vision Transformer)**:
   - Applies transformers to image patches
   - Best for: Image classification, vision tasks
   - Shows transformers work beyond NLP

### Selection Guide:

```
if task == "understanding":
    use BERT or RoBERTa
elif task == "generation":
    use GPT-2, GPT-3, or GPT-4
elif task in ("translation", "summarization"):
    use T5 or BART
elif task == "vision":
    use ViT or DeiT
else:
    consider task-specific variants
```