# Module 1.1: Neural Networks - The Basics

**Goal**: Understand how information flows through networks

**Time**: 60 minutes

**Concepts Covered**:
- Forward pass visualization
- Loss calculation
- Backward pass (gradients)
- Single-head attention
- Multi-head attention

## Setup
Install required packages (run once)

In [None]:
!pip install torch numpy matplotlib seaborn -q

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Seed both RNGs (NumPy and PyTorch) so repeated runs produce the same numbers
np.random.seed(seed=42)
torch.manual_seed(seed=42)

# Plot appearance: matplotlib style sheet first, then the seaborn colour palette
plt.style.use(style='seaborn-v0_8-darkgrid')
sns.set_palette(palette="husl")

## Lesson 1: Build Your First Neural Network (15 mins)

We'll build a simple 2-layer network from scratch to learn the XOR function.

**XOR Truth Table**:
```
Input1  Input2  Output
  0       0       0
  0       1       1
  1       0       1
  1       1       0
```

In [None]:
# XOR dataset: every pair of binary inputs, with the XOR of each pair as target
X = torch.tensor(
    [
        [0, 0],
        [0, 1],
        [1, 0],
        [1, 1],
    ],
    dtype=torch.float32,
)
# Targets are stored as a column vector, one row per sample
y = torch.tensor([0, 1, 1, 0], dtype=torch.float32).unsqueeze(1)

print("Input shape:", X.shape)
print("Output shape:", y.shape)
print("\nDataset:")
# Walk inputs and targets in lockstep instead of indexing by position
for inputs, target in zip(X, y):
    print(f"  {inputs.numpy()} -> {target.item()}")

In [None]:
class SimpleNN(nn.Module):
    """Two fully connected layers, each followed by a sigmoid.

    Default layout: Input(2) -> Hidden(4) -> Output(1)
    """

    def __init__(self, input_size=2, hidden_size=4, output_size=1):
        super().__init__()
        # Layers are created in the order fc1, fc2 so that parameter
        # initialisation consumes the RNG in the same sequence as before.
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Run the forward pass.

        Returns:
            (output, hidden): the sigmoid-activated output of the second
            layer and the sigmoid-activated hidden layer. The hidden
            activations are returned as well so they can be visualised.
        """
        hidden = self.fc1(x).sigmoid()
        output = self.fc2(hidden).sigmoid()
        return output, hidden

# Build the network with its default layer sizes and report what was created
model = SimpleNN()
print("Model architecture:", model, sep="\n")
print(f"\nTotal parameters: {sum(param.numel() for param in model.parameters())}")

In [None]:
# Optimisation setup: mean-squared-error loss minimised with plain SGD
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

# One loss value per epoch, kept for the learning-curve plot
losses = []

# Full-batch training: every epoch sees all four XOR samples
epochs = 1000
for epoch in range(epochs):
    # Clear the gradients left over from the previous update
    optimizer.zero_grad()

    # Forward pass (the model also returns its hidden activations)
    predictions, hidden = model(X)
    loss = criterion(predictions, y)

    # Backward pass, then one gradient-descent step
    loss.backward()
    optimizer.step()

    # Record the loss computed before this epoch's update
    losses.append(loss.item())

    # Progress report every 200 epochs
    if epoch % 200 == 199:
        print(f"Epoch {epoch+1}/{epochs}, Loss: {losses[-1]:.4f}")

print("\nTraining complete!")

In [None]:
# Learning curve: training loss recorded at each epoch
plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.title('Network Learning XOR Function')
plt.ylabel('Loss (MSE)')
plt.xlabel('Epoch')
plt.grid(True, alpha=0.3)
plt.show()

# Compare the trained model's output with the target for every XOR input
print("\nFinal predictions:")
with torch.no_grad():  # inference only, no gradient tracking needed
    predictions, hidden_states = model(X)
    for inputs, predicted, target in zip(X, predictions, y):
        print(f"  Input: {inputs.numpy()} -> Predicted: {predicted.item():.4f}, True: {target.item()}")

# Enhanced XOR visualization: Decision boundary
print("\n" + "="*50)
print("Enhanced Visualization: Decision Boundary")
print("="*50)

# Create a 100x100 grid of points covering [-0.5, 1.5] on both input axes,
# flattened to one (input1, input2) row per grid point
xx, yy = np.meshgrid(np.linspace(-0.5, 1.5, 100), np.linspace(-0.5, 1.5, 100))
grid_points = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)

# Evaluate the model on every grid point (hidden activations are not needed)
with torch.no_grad():
    grid_pred, _ = model(grid_points)
    grid_pred = grid_pred.numpy().reshape(xx.shape)

# Plot decision boundary: filled contours of the model output, with the four
# training points drawn on top and coloured by their true label
plt.figure(figsize=(10, 8))
contour = plt.contourf(xx, yy, grid_pred, levels=20, cmap='RdYlBu', alpha=0.6)
plt.colorbar(contour, label='Model Output')
plt.scatter(X[:, 0], X[:, 1], c=y.squeeze(), s=200, cmap='RdYlBu', edgecolors='black', linewidths=2, zorder=5)
plt.xlabel('Input 1')
plt.ylabel('Input 2')
plt.title('XOR Decision Boundary Visualization')
plt.grid(True, alpha=0.3)
# Label each training point with its coordinates. Each row of X unpacks into
# two scalars, so they are used directly (indexing them again would fail).
for x_val, y_val in X.tolist():
    plt.annotate(f'({int(x_val)},{int(y_val)})', (x_val, y_val),
                 xytext=(5, 5), textcoords='offset points', fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

## Lesson 2: Attention is All You Need - Visual Proof (20 mins)

Implement single-head attention from scratch and visualize how it works.

In [None]:
def single_head_attention(Q, K, V, mask=None):
    """
    Scaled Dot-Product Attention: softmax(Q @ K^T / sqrt(d_k)) @ V

    Args:
        Q: Query matrix (batch, seq_len, d_k)
        K: Key matrix (batch, seq_len, d_k)
        V: Value matrix (batch, seq_len, d_v)
        mask: Optional mask (batch, seq_len, seq_len); positions where the
            mask equals 0 are excluded from attention

    Returns:
        output: Attention output (batch, seq_len, d_v)
        attention_weights: Attention weights (batch, seq_len, seq_len)
    """
    d_k = Q.size(-1)

    # Steps 1-2: attention scores (Q @ K^T), scaled by sqrt(d_k).
    # A plain Python scalar is used for the scale so no extra tensor is
    # allocated and the scores keep the dtype of the inputs.
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

    # Step 3: Apply mask if provided.
    # The fill value is the most negative number the score dtype can hold;
    # a hard-coded -1e9 is not representable in float16.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    # Step 4: Softmax over the key dimension, so each query's weights sum to 1
    attention_weights = F.softmax(scores, dim=-1)

    # Step 5: Weighted sum of values
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

In [None]:
# Toy input for the attention demo: "The cat sat on the mat"
sentence = "The cat sat on the mat"
tokens = sentence.split()
seq_len = len(tokens)
d_model = 8  # kept small so the heatmap stays readable

# Random stand-in embeddings, one vector per token (seeded for reproducibility)
torch.manual_seed(42)
embeddings = torch.randn((1, seq_len, d_model))

# Queries, keys and values are all the raw embeddings here
# (in practice, these are learned projections)
Q = K = V = embeddings

# Run scaled dot-product attention
output, attention_weights = single_head_attention(Q, K, V)

print("Input shape:", embeddings.shape)
print("Output shape:", output.shape)
print("Attention weights shape:", attention_weights.shape)

In [None]:
# Heatmap of the attention matrix: rows are query tokens, columns are key tokens
plt.figure(figsize=(10, 8))
sns.heatmap(
    attention_weights[0].detach().numpy(),
    annot=True,
    fmt='.2f',
    cmap='YlOrRd',
    linewidths=0.5,
    xticklabels=tokens,
    yticklabels=tokens,
    cbar_kws=dict(label='Attention Weight'),
)
plt.xlabel('Key (attending to)', fontsize=12)
plt.ylabel('Query (attending from)', fontsize=12)
plt.title('Single-Head Attention Heatmap\n"The cat sat on the mat"', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Reading guide for the heatmap (the arrow is a real U+2192 character; the
# previous text contained a mis-decoded byte sequence in its place)
print("\nInterpretation:")
print("- Each row shows what a token attends to")
print("- Brighter colors = higher attention")
print("- Diagonal is often bright (self-attention)")
print("- Off-diagonal patterns show relationships (e.g., 'cat' → 'sat')")

# Second example: the same attention computation on a different sentence
print("\n" + "="*50)
print("Additional Example: Subject-Verb Relationship")
print("="*50)

sentence2 = "The dog chased the ball"
tokens2 = sentence2.split()
seq_len2 = len(tokens2)
# Fresh random embeddings, reused directly as queries, keys and values
embeddings2 = torch.randn((1, seq_len2, d_model))
Q2 = K2 = V2 = embeddings2
output2, attn2 = single_head_attention(Q2, K2, V2)

# Heatmap with query tokens on the rows and key tokens on the columns
plt.figure(figsize=(10, 8))
sns.heatmap(
    attn2[0].detach().numpy(),
    annot=True,
    fmt='.2f',
    cmap='YlOrRd',
    linewidths=0.5,
    xticklabels=tokens2,
    yticklabels=tokens2,
    cbar_kws=dict(label='Attention Weight'),
)
plt.xlabel('Key (attending to)', fontsize=12)
plt.ylabel('Query (attending from)', fontsize=12)
plt.title('Attention: "The dog chased the ball"\n(Notice subject-verb relationships)',
          fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Lesson 3: Multi-Head Attention (25 mins)

Scale to 8 attention heads to capture different relationships.

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of size ``d_model // num_heads``, attended
    independently per head, then concatenated and passed through a final
    output projection.
    """

    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model: Size of the input/output feature dimension
            num_heads: Number of attention heads; must divide d_model

        Raises:
            AssertionError: If d_model is not divisible by num_heads
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature size

        # Linear projections for Q, K, V and the output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        """Split (batch, seq_len, d_model) into (batch, num_heads, seq_len, d_k)"""
        batch_size, seq_len, _ = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input (batch, seq_len, d_model)
            mask: Optional mask; positions where it equals 0 are excluded.
                Accepts (seq_len, seq_len), (batch, seq_len, seq_len) or any
                shape broadcastable to (batch, num_heads, seq_len, seq_len).

        Returns:
            output: (batch, seq_len, d_model)
            attention_weights: (batch, num_heads, seq_len, seq_len)
        """
        batch_size, seq_len, d_model = x.size()

        # Linear projections, then split into heads: (batch, num_heads, seq_len, d_k)
        Q = self.split_heads(self.W_q(x))
        K = self.split_heads(self.W_k(x))
        V = self.split_heads(self.W_v(x))

        # Scaled dot-product scores for every head at once
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        if mask is not None:
            # A per-batch (batch, seq_len, seq_len) mask needs an explicit head
            # axis; without it the batch axis would be broadcast against the
            # head axis of the scores.
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
            # Most negative value of the score dtype; a hard-coded -1e9 is
            # not representable in float16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)

        # Concatenate heads back to (batch, seq_len, d_model);
        # contiguous() is required before view() after the transpose
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

        # Final linear projection
        output = self.W_o(output)

        return output, attention_weights

# Build a multi-head attention layer: 64 features split across 8 heads
d_model, num_heads = 64, 8
mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)

# Random input: a single sequence of six d_model-sized vectors
batch_size, seq_len = 1, 6
x = torch.randn((batch_size, seq_len, d_model))

# Forward pass returns the projected output and the per-head attention weights
output, attention_weights = mha(x)

print("Input shape:", x.shape)
print("Output shape:", output.shape)
print("Attention weights shape:", attention_weights.shape)
print("  (batch, num_heads, seq_len, seq_len)")

In [None]:
# One heatmap per attention head, laid out on a 2x4 grid
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(18, 9))
axes = axes.ravel()

# Tick labels for the six sequence positions
tokens = "The cat sat on the mat".split()

for head in range(num_heads):
    # Attention matrix of this head for the first (only) batch element
    attn_matrix = attention_weights[0, head].detach().numpy()
    sns.heatmap(
        attn_matrix,
        ax=axes[head],
        annot=True,
        fmt='.2f',
        cmap='YlOrRd',
        # Fixed colour range so the heads can be compared with each other
        vmin=0,
        vmax=1,
        xticklabels=tokens,
        yticklabels=tokens,
        cbar=True,
        cbar_kws=dict(label='Weight'),
    )
    axes[head].set_title(f'Head {head + 1}', fontsize=11, fontweight='bold')

plt.suptitle('Multi-Head Attention (8 Heads) - "The cat sat on the mat"',
             fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.show()

# Exercise prompt printed below the head heatmaps. The emoji and arrows are
# real Unicode characters; the previous text contained mis-decoded byte
# sequences in their place.
print("\n" + "="*60)
print("🎯 EXERCISE: Identify Head Specializations")
print("="*60)
print("\nAnalyze the attention patterns above and identify which heads focus on:")
print("\n1. SYNTAX (Grammatical relationships):")
print("   - Article-noun: 'The' → 'cat', 'the' → 'mat'")
print("   - Preposition-object: 'on' → 'mat'")
print("\n2. SEMANTICS (Meaning relationships):")
print("   - Subject-verb: 'cat' → 'sat'")
print("   - Verb-object: 'sat' → 'on'")
print("\n3. POSITION (Spatial relationships):")
print("   - Neighboring tokens (adjacent positions)")
print("   - Long-range dependencies")
print("\n4. SELF-ATTENTION:")
print("   - Strong diagonal patterns (token attending to itself)")
print("\n💡 Tip: Look for patterns where attention weights are consistently high")
print("   between specific token pairs across different query positions.")
print("\n" + "-"*60)
print("Write your observations:")
print("  Syntax-focused heads: _____")
print("  Semantics-focused heads: _____")
print("  Position-focused heads: _____")
print("-"*60)

## Key Takeaways

✅ **Neural Networks**: Information flows forward, gradients flow backward

✅ **Attention Mechanism**: Allows tokens to "look at" other tokens

✅ **Multi-Head Attention**: Multiple perspectives capture different relationships

✅ **Scaled Dot-Product**: Division by √d_k keeps the softmax from saturating, which prevents vanishing gradients

## Next Steps

Continue to **Module 1.2: Transformer Architecture Deep Dive** to learn about:
- Position encoding (sinusoidal and RoPE)
- Feed-forward networks
- Layer normalization
- Complete transformer blocks