# Paper 13: Attention Is All You Need
## Vaswani et al. (2017)

### The Transformer: Pure Attention Architecture

Revolutionary architecture that replaced RNNs with self-attention, enabling modern LLMs.

In [None]:
import torch
import matplotlib.pyplot as plt

# Fix the global RNG seed so the random tensors drawn below are reproducible
torch.manual_seed(42)
# Torch-Numpy compatibility helpers
# Give tensors a NumPy-style ``.copy()`` by aliasing it to ``clone``; the guard
# leaves any attribute of that name that already exists untouched.
if not hasattr(torch.Tensor, 'copy'):
    torch.Tensor.copy = torch.Tensor.clone

def _astype(self, dtype):
    """
    NumPy-style ``astype`` for tensors.

    Args:
        dtype: Target type. The Python builtins ``float`` and ``int`` are
            routed to ``Tensor.float()`` and ``Tensor.int()``; anything else
            is handed to ``Tensor.to`` unchanged.
    Returns:
        The converted tensor.
    """
    # Identity comparison on purpose: only the exact builtin types take the
    # shortcut, everything else falls through to ``Tensor.to``.
    for builtin_type, method_name in ((float, "float"), (int, "int")):
        if dtype is builtin_type:
            return getattr(self, method_name)()
    return self.to(dtype)


# Install the helper only when the attribute is not already present.
if not hasattr(torch.Tensor, "astype"):
    setattr(torch.Tensor, "astype", _astype)


## Scaled Dot-Product Attention

The fundamental building block:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

In [None]:
def softmax(x, dim=-1):
    """
    Compute numerically stable softmax probabilities along one dimension.
    
    Args:
        x: Tensor of unnormalized scores.
        dim: Dimension along which the probabilities are normalized.
    Returns:
        Tensor with the same shape as ``x`` whose entries along ``dim``
        sum to 1.
    """
    # Subtracting the per-slice maximum does not change the result but keeps
    # exp() from overflowing on large scores.
    shifted = x - torch.amax(x, dim=dim, keepdim=True)
    exp_x = torch.exp(shifted)
    return exp_x / torch.sum(exp_x, dim=dim, keepdim=True)

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.
    
    Args:
        Q: Query tensor of shape (..., seq_q, d_k).
        K: Key tensor of shape (..., seq_k, d_k).
        V: Value tensor of shape (..., seq_k, d_v).
        mask: Optional tensor broadcastable to the (..., seq_q, seq_k) score
            matrix. Entries equal to 1 mark positions that must NOT be
            attended to; entries equal to 0 are left untouched.
    Returns:
        output: Attention-weighted sum of the values, shape (..., seq_q, d_v).
        attention_weights: Softmax-normalized scores, shape (..., seq_q, seq_k).
    """
    d_k = Q.shape[-1]
    
    # d_k is a plain Python int, so scale with a Python power rather than
    # torch.sqrt (which only accepts tensors). Transposing the last two
    # dimensions explicitly also keeps leading batch/head dimensions intact.
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    
    # Masked positions get a large negative score so that their softmax
    # weight is effectively zero (used for causality or padding).
    if mask is not None:
        scores = scores + (mask * -1e9)
    
    # Normalize over the key dimension
    attention_weights = softmax(scores, dim=-1)
    
    # Weighted sum of values
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights

# Test scaled dot-product attention
# Demo on random inputs: 5 positions, 8 features per position.
seq_len = 5
d_model = 8

Q = torch.randn(seq_len, d_model)
K = torch.randn(seq_len, d_model)
V = torch.randn(seq_len, d_model)

output, attn_weights = scaled_dot_product_attention(Q, K, V)

print(f"Attention output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
# Summing over dim 1 (the key axis) gives one total per query row
print(f"Attention weights sum (should be 1): {attn_weights.sum(dim=1)}")

# Visualize attention pattern
# Rows are query positions, columns are key positions.
plt.figure(figsize=(8, 6))
plt.imshow(attn_weights, cmap='viridis', aspect='auto')
plt.colorbar(label='Attention Weight')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title('Attention Weights Matrix')
plt.show()

## Multi-Head Attention

Multiple attention "heads" attend to different aspects of the input:
$$\text{MultiHead}(Q,K,V) = \text{Concat}(head_1, ..., head_h)W^O$$

In [None]:
class MultiHeadAttention:
    """
    Multi-head attention over unbatched (seq_len, d_model) inputs.

    The projected queries, keys and values are split into ``num_heads``
    slices of width ``d_k = d_model // num_heads``; attention runs on each
    slice independently and the results are concatenated and projected.
    """

    def __init__(self, d_model, num_heads):
        """
        Create randomly initialized projection matrices.
        
        Args:
            d_model: Feature dimension of the inputs and outputs.
            num_heads: Number of attention heads; must divide ``d_model``.
        Raises:
            AssertionError: If ``d_model`` is not divisible by ``num_heads``.
        """
        assert d_model % num_heads == 0, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projections for Q, K, V for all heads (parallelized)
        self.W_q = torch.randn(d_model, d_model) * 0.1
        self.W_k = torch.randn(d_model, d_model) * 0.1
        self.W_v = torch.randn(d_model, d_model) * 0.1
        
        # Output projection
        self.W_o = torch.randn(d_model, d_model) * 0.1
    
    def split_heads(self, x):
        """
        Split the feature dimension into per-head slices.
        
        Args:
            x: Tensor of shape (seq_len, d_model).
        Returns:
            Tensor of shape (num_heads, seq_len, d_k); head ``h`` holds
            columns ``h * d_k : (h + 1) * d_k`` of ``x``.
        """
        seq_len = x.shape[0]
        x = x.reshape(seq_len, self.num_heads, self.d_k)
        # Tensor.transpose only swaps two dims; permute is the torch
        # equivalent of NumPy's multi-axis transpose.
        return x.permute(1, 0, 2)
    
    def combine_heads(self, x):
        """
        Inverse of ``split_heads``: concatenate the heads back together.
        
        Args:
            x: Tensor of shape (num_heads, seq_len, d_k).
        Returns:
            Tensor of shape (seq_len, d_model).
        """
        seq_len = x.shape[1]
        x = x.permute(1, 0, 2)
        # reshape (unlike view) copies when the permuted tensor is not contiguous
        return x.reshape(seq_len, self.d_model)
    
    def forward(self, Q, K, V, mask=None):
        """
        Run multi-head attention.
        
        Args:
            Q: Query input of shape (seq_len, d_model).
            K: Key input of shape (seq_len, d_model).
            V: Value input of shape (seq_len, d_model).
            mask: Optional mask forwarded unchanged to
                ``scaled_dot_product_attention`` for every head.
        Returns:
            Tensor of shape (seq_len, d_model). The per-head attention
            weights of this call are stored in ``self.attention_weights``
            (a list with one entry per head).
        """
        # Linear projections
        Q = torch.matmul(Q, self.W_q.T)
        K = torch.matmul(K, self.W_k.T)
        V = torch.matmul(V, self.W_v.T)
        
        # Split into multiple heads
        Q = self.split_heads(Q)  # (num_heads, seq_len, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        # Apply attention to each head
        head_outputs = []
        self.attention_weights = []
        
        # Iterating a tensor walks its first dimension, i.e. one head at a time
        for q_head, k_head, v_head in zip(Q, K, V):
            head_out, head_attn = scaled_dot_product_attention(
                q_head, k_head, v_head, mask
            )
            head_outputs.append(head_out)
            self.attention_weights.append(head_attn)
        
        # Stack heads
        heads = torch.stack(head_outputs, dim=0)  # (num_heads, seq_len, d_k)
        
        # Combine heads
        combined = self.combine_heads(heads)  # (seq_len, d_model)
        
        # Final linear projection
        output = torch.matmul(combined, self.W_o.T)
        
        return output

# Test multi-head attention
# These assignments rebind the module-level d_model / seq_len used by later cells.
d_model = 64
num_heads = 8
seq_len = 10

mha = MultiHeadAttention(d_model, num_heads)

X = torch.randn(seq_len, d_model)
# Passing the same tensor as Q, K and V makes this self-attention
output = mha.forward(X, X, X)  # Self-attention

print(f"\nMulti-Head Attention:")
print(f"Input shape: {X.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of heads: {num_heads}")
print(f"Dimension per head: {mha.d_k}")

## Positional Encoding

Since Transformers have no recurrence, we add position information:
$$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})$$
$$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})$$

In [None]:
def positional_encoding(seq_len, d_model):
    """
    Build the sinusoidal positional-encoding table.
    
    Even columns hold sin(pos / 10000^(2i/d_model)) and odd columns hold the
    cosine of the same angle.
    
    Args:
        seq_len: Number of positions (rows).
        d_model: Encoding width (columns); may be even or odd.
    Returns:
        pe: Float tensor of shape (seq_len, d_model).
    """
    pe = torch.zeros((seq_len, d_model))
    
    # Column vector of positions so it broadcasts against the frequencies
    position = torch.arange(0, seq_len, dtype=torch.float32).unsqueeze(1)
    
    # 10000^(-2i/d_model), computed as exp(-2i * ln(10000) / d_model).
    # torch.log needs a tensor argument, so wrap the constant.
    log_base = torch.log(torch.tensor(10000.0))
    even_dims = torch.arange(0, d_model, 2, dtype=torch.float32)
    div_term = torch.exp(even_dims * -(log_base / d_model))
    
    angles = position * div_term  # (seq_len, ceil(d_model / 2))
    
    # Apply sin to even indices
    pe[:, 0::2] = torch.sin(angles)
    
    # Apply cos to odd indices. For odd d_model there is one fewer odd column
    # than even columns, so drop the surplus frequency.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    
    return pe

# Generate positional encodings
seq_len = 50
d_model = 64
pe = positional_encoding(seq_len, d_model)

# Visualize positional encodings
plt.figure(figsize=(12, 8))

# Top panel: the full table, transposed so positions run along the x-axis
plt.subplot(2, 1, 1)
plt.imshow(pe.T, cmap='RdBu', aspect='auto')
plt.colorbar(label='Encoding Value')
plt.xlabel('Position')
plt.ylabel('Dimension')
plt.title('Positional Encoding (All Dimensions)')

# Bottom panel: a handful of individual columns plotted against position
plt.subplot(2, 1, 2)
# Plot first few dimensions
for i in [0, 1, 2, 3, 10, 20]:
    plt.plot(pe[:, i], label=f'Dim {i}')
plt.xlabel('Position')
plt.ylabel('Encoding Value')
plt.title('Positional Encoding (Selected Dimensions)')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Positional encoding shape: {pe.shape}")
print(f"Different frequencies encode position at different scales")

## Feed-Forward Network

Applied to each position independently:
$$FFN(x) = \max(0, xW_1 + b_1)W_2 + b_2$$

In [None]:
class FeedForward:
    """Position-wise feed-forward network: relu(x W1 + b1) W2 + b2."""

    def __init__(self, d_model, d_ff):
        """
        Create randomly initialized weights and zero biases.
        
        Args:
            d_model: Input and output feature dimension.
            d_ff: Width of the hidden layer.
        """
        self.W1 = torch.randn(d_model, d_ff) * 0.1
        self.b1 = torch.zeros(d_ff)
        self.W2 = torch.randn(d_ff, d_model) * 0.1
        self.b2 = torch.zeros(d_model)
    
    def forward(self, x):
        """
        Apply the two-layer network to every row of ``x``.
        
        Args:
            x: Tensor whose last dimension is ``d_model``.
        Returns:
            Tensor with the same shape as ``x``.
        """
        # First layer with ReLU. torch.maximum requires two tensors, so the
        # NumPy idiom maximum(0, x) does not work; relu is the equivalent.
        hidden = torch.relu(torch.matmul(x, self.W1) + self.b1)
        
        # Second layer
        output = torch.matmul(hidden, self.W2) + self.b2
        
        return output

# Test feed-forward
d_model = 64
d_ff = 256  # Usually 4x larger

ff = FeedForward(d_model, d_ff)
x = torch.randn(10, d_model)
output = ff.forward(x)

print(f"\nFeed-Forward Network:")
print(f"Input: {x.shape}")
# The hidden activation is not returned by forward(), so its shape is
# reported from the known dimensions instead.
print(f"Hidden: ({x.shape[0]}, {d_ff})")
print(f"Output: {output.shape}")

## Layer Normalization

Normalizes each position across its feature dimension (unlike BatchNorm, which normalizes each feature across the batch).

In [None]:
class LayerNorm:
    """Layer normalization over the last (feature) dimension."""

    def __init__(self, d_model, eps=1e-6):
        """
        Create the scale and shift parameters.
        
        Args:
            d_model: Size of the feature dimension being normalized.
            eps: Small constant added to the standard deviation to avoid
                division by zero.
        """
        self.gamma = torch.ones(d_model)
        self.beta = torch.zeros(d_model)
        self.eps = eps
    
    def forward(self, x):
        """
        Normalize each feature vector to zero mean and unit variance.
        
        Args:
            x: Tensor whose last dimension is ``d_model``.
        Returns:
            Tensor with the same shape as ``x``: ``gamma * x_hat + beta``.
        """
        mean = x.mean(dim=-1, keepdim=True)
        centered = x - mean
        
        # Layer norm uses the population (biased) variance. Tensor.std()
        # defaults to the Bessel-corrected estimator, which under-scales the
        # output and yields NaN when d_model == 1, so compute it explicitly.
        std = centered.pow(2).mean(dim=-1, keepdim=True).sqrt()
        
        normalized = centered / (std + self.eps)
        output = self.gamma * normalized + self.beta
        
        return output

# d_model here is the module-level value assigned in an earlier cell
ln = LayerNorm(d_model)
# Scale and shift standard-normal samples so the input is visibly unnormalized
x = torch.randn(10, d_model) * 3 + 5  # Unnormalized
normalized = ln.forward(x)

# The printed statistics are taken over the whole tensor, not per row
print(f"\nLayer Normalization:")
print(f"Input mean: {x.mean():.4f}, std: {x.std():.4f}")
print(f"Output mean: {normalized.mean():.4f}, std: {normalized.std():.4f}")

## Complete Transformer Block

In [None]:
class TransformerBlock:
    """
    One encoder-style layer: self-attention followed by a feed-forward
    network, each wrapped in a residual connection and then layer-normalized.
    """

    def __init__(self, d_model, num_heads, d_ff):
        """
        Build the sub-layers.
        
        Args:
            d_model: Feature dimension shared by every sub-layer.
            num_heads: Number of attention heads.
            d_ff: Hidden width of the feed-forward network.
        """
        # The randomly initialized sub-layers are created in a fixed order
        # (attention, then feed-forward) so RNG consumption stays the same.
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        
        # One normalization per sub-layer
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
    
    def forward(self, x, mask=None):
        """
        Apply the block to a sequence.
        
        Args:
            x: Input tensor passed as query, key and value to the attention.
            mask: Optional attention mask forwarded to the attention layer.
        Returns:
            Tensor produced by the second normalization.
        """
        # Sub-layer 1: self-attention, residual add, normalize
        attended = self.attention.forward(x, x, x, mask)
        after_attention = self.norm1.forward(x + attended)
        
        # Sub-layer 2: feed-forward, residual add, normalize
        transformed = self.ff.forward(after_attention)
        return self.norm2.forward(after_attention + transformed)

# Test transformer block
# Run one block on a random sequence of 10 positions with 64 features each
block = TransformerBlock(d_model=64, num_heads=8, d_ff=256)
x = torch.randn(10, 64)
output = block.forward(x)

print(f"\nTransformer Block:")
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nBlock contains:")
print(f"  1. Multi-Head Self-Attention")
print(f"  2. Layer Normalization")
print(f"  3. Feed-Forward Network")
print(f"  4. Residual Connections")

## Visualize Multi-Head Attention Patterns

In [None]:
# Create attention with interpretable input
seq_len = 8
d_model = 64
num_heads = 4

mha = MultiHeadAttention(d_model, num_heads)
X = torch.randn(seq_len, d_model)
# forward() is called for its side effect of filling mha.attention_weights
output = mha.forward(X, X, X)

# Plot attention patterns for each head
fig, axes = plt.subplots(1, num_heads, figsize=(16, 4))

for i, ax in enumerate(axes):
    attn = mha.attention_weights[i]
    # Fixed vmin/vmax put every panel on the same color scale
    im = ax.imshow(attn, cmap='viridis', aspect='auto', vmin=0, vmax=1)
    ax.set_title(f'Head {i+1}')
    ax.set_xlabel('Key')
    ax.set_ylabel('Query')
    
# A single colorbar, built from the last image, is attached to all panels
plt.colorbar(im, ax=axes, label='Attention Weight', fraction=0.046, pad=0.04)
plt.suptitle('Multi-Head Attention Patterns', fontsize=14, y=1.05)
plt.tight_layout()
plt.show()

print("\nEach head learns to attend to different patterns!")
print("Different heads capture different relationships in the data.")

## Causal (Masked) Self-Attention for Autoregressive Models

In [None]:
def create_causal_mask(seq_len):
    """
    Create a causal (look-ahead) mask.
    
    Args:
        seq_len: Number of positions in the sequence.
    Returns:
        mask: Float tensor of shape (seq_len, seq_len) with 1 strictly above
            the main diagonal (key position after the query position) and 0
            elsewhere.
    """
    # torch.triu takes ``diagonal`` (NumPy's name for it is ``k``);
    # diagonal=1 excludes the main diagonal so a position can see itself.
    mask = torch.triu(torch.ones((seq_len, seq_len)), diagonal=1)
    return mask

# Test causal attention
seq_len = 8
causal_mask = create_causal_mask(seq_len)

# d_model is not reassigned in this cell; the module-level value from an
# earlier cell is used.
Q = torch.randn(seq_len, d_model)
K = torch.randn(seq_len, d_model)
V = torch.randn(seq_len, d_model)

# Without mask (bidirectional)
output_bi, attn_bi = scaled_dot_product_attention(Q, K, V)

# With causal mask (unidirectional)
output_causal, attn_causal = scaled_dot_product_attention(Q, K, V, mask=causal_mask)

# Visualize difference
# Three panels side by side: the mask, unmasked weights, masked weights
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 5))

# Causal mask
ax1.imshow(causal_mask, cmap='Reds', aspect='auto')
ax1.set_title('Causal Mask\n(1 = masked/not allowed)')
ax1.set_xlabel('Key Position')
ax1.set_ylabel('Query Position')

# Bidirectional attention
im2 = ax2.imshow(attn_bi, cmap='viridis', aspect='auto', vmin=0, vmax=1)
ax2.set_title('Bidirectional Attention\n(can see future)')
ax2.set_xlabel('Key Position')
ax2.set_ylabel('Query Position')

# Causal attention
im3 = ax3.imshow(attn_causal, cmap='viridis', aspect='auto', vmin=0, vmax=1)
ax3.set_title('Causal Attention\n(cannot see future)')
ax3.set_xlabel('Key Position')
ax3.set_ylabel('Query Position')

# Both attention panels share vmin/vmax, so one colorbar serves the pair
plt.colorbar(im3, ax=[ax2, ax3], label='Attention Weight')
plt.tight_layout()
plt.show()

print("\nCausal masking is crucial for:")
print("  - Autoregressive generation (GPT, language models)")
print("  - Prevents information leakage from future tokens")
print("  - Each position can only attend to itself and previous positions")

## Key Takeaways

### Why "Attention Is All You Need"?
- **No recurrence**: Processes entire sequence in parallel
- **No convolution**: Pure attention mechanism
- **Scales better**: O(1) sequential operations per layer (at O(n²·d) compute) vs O(n) sequential operations in RNNs
- **Long-range dependencies**: Direct connections between any positions

### Core Components:
1. **Scaled Dot-Product Attention**: Efficient attention computation
2. **Multi-Head Attention**: Multiple representation subspaces
3. **Positional Encoding**: Inject position information
4. **Feed-Forward Networks**: Position-wise transformations
5. **Layer Normalization**: Stabilize training
6. **Residual Connections**: Enable deep networks

### Architecture Variants:
- **Encoder-Decoder**: Original Transformer (translation)
- **Encoder-only**: BERT (bidirectional understanding)
- **Decoder-only**: GPT (autoregressive generation)

### Advantages:
- Parallelizable training (unlike RNNs)
- Better long-range dependencies
- Interpretable attention patterns
- State-of-the-art on many tasks

### Impact:
- Foundation of modern NLP: GPT, BERT, T5, etc.
- Extended to vision: Vision Transformer (ViT)
- Multi-modal models: CLIP, Flamingo
- Enabled LLMs with billions of parameters