In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from typing import Optional
import time

# Plot styling: matplotlib's stock style plus seaborn's evenly spaced "husl" colour cycle.
plt.style.use('default')
sns.set_palette("husl")

# Reproducibility: seed the numpy and torch random number generators used in this notebook.
np.random.seed(42)
torch.manual_seed(42)

# Prefer a CUDA device when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")


In [None]:
def relu(x):
    """Rectified Linear Unit: passes positive inputs through and zeroes the rest.

    Args:
        x: input tensor.

    Returns:
        Tensor shaped like ``x`` with every negative entry replaced by 0.
    """
    return x.clamp(min=0)

def gelu(x):
    """GELU (Gaussian Error Linear Unit), tanh approximation.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``, the
    approximation PyTorch exposes as ``F.gelu(x, approximate='tanh')``.

    Args:
        x: input tensor.

    Returns:
        Tensor shaped like ``x``.
    """
    cubic_term = x + 0.044715 * x**3
    gate = 1 + torch.tanh(np.sqrt(2/np.pi) * cubic_term)
    return x * 0.5 * gate

def swish_silu(x):
    """Swish / SiLU activation: the input scaled by its own sigmoid.

    Args:
        x: input tensor.

    Returns:
        ``x * sigmoid(x)``, shaped like ``x``.
    """
    return torch.sigmoid(x) * x

def leaky_relu(x, alpha=0.01):
    """Leaky ReLU: identity for positive inputs, a small slope ``alpha`` otherwise.

    Keeping a non-zero slope for negative inputs means those units still receive
    gradient (the "dying ReLU" fix).

    Args:
        x: input tensor.
        alpha: slope applied where ``x <= 0``.

    Returns:
        Tensor shaped like ``x``.
    """
    negative_part = alpha * x
    return torch.where(x <= 0, negative_part, x)

# Evaluate each hand-written activation on a dense grid over [-5, 5].
x = torch.linspace(-5, 5, 1000)

activations = {
    'ReLU': relu(x),
    'Leaky ReLU': leaky_relu(x),
    'GELU': gelu(x),
    'Swish/SiLU': swish_silu(x)
}

# One panel per activation on a 2x2 grid (axes.flat walks the grid row by row).
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
for ax, (name, y) in zip(axes.flat, activations.items()):
    ax.plot(x.numpy(), y.numpy(), linewidth=2, label=name)
    ax.grid(True, alpha=0.3)
    ax.set_title(f'{name} Activation', fontsize=14, fontweight='bold')
    ax.set_xlabel('Input (x)')
    ax.set_ylabel('Output f(x)')
    ax.legend()

fig.tight_layout()
fig.suptitle('Evolution of Activation Functions', fontsize=16, fontweight='bold', y=1.02)
plt.show()

# All four curves overlaid on a single axis for direct comparison.
fig, ax = plt.subplots(figsize=(12, 8))
for name, y in activations.items():
    ax.plot(x.numpy(), y.numpy(), linewidth=2.5, label=name)

ax.grid(True, alpha=0.3)
ax.set_title('Activation Functions Comparison', fontsize=16, fontweight='bold')
ax.set_xlabel('Input (x)', fontsize=12)
ax.set_ylabel('Output f(x)', fontsize=12)
ax.legend(fontsize=12)
ax.set_xlim(-4, 4)
ax.set_ylim(-1, 4)
plt.show()


In [None]:
class SwiGLU(nn.Module):
    """Gated linear unit whose gate is a Swish (SiLU) non-linearity.

    Computes ``SiLU(x W1 + b1) ⊙ (x W2 + b2)``, where ``SiLU(z) = z * sigmoid(z)``
    and ⊙ is the element-wise product.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        bias: whether both projections carry a bias term.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.w1 = nn.Linear(in_features, out_features, bias=bias)  # gate branch
        self.w2 = nn.Linear(in_features, out_features, bias=bias)  # value branch

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return gate * value


# Let's test our implementation
input_dim = 4
batch_size = 3
test_input = torch.randn(batch_size, input_dim)

print("🧪 Testing SwiGLU Implementation")
print(f"Input shape: {test_input.shape}")
print(f"Input values:\n{test_input}")

swiglu = SwiGLU(input_dim, input_dim)
# Inference only: no_grad avoids building an autograd graph (and keeps grad_fn out of the repr).
with torch.no_grad():
    output = swiglu(test_input)

# Verify before claiming success below.
assert output.shape == (batch_size, input_dim), f"unexpected output shape {tuple(output.shape)}"

print(f"\nOutput shape: {output.shape}")
print(f"Output values:\n{output}")
print("\n✅ SwiGLU working correctly!")

print("\n💡 Exercise 1: Try replacing F.silu with F.gelu or torch.sigmoid!")


In [None]:
# Compare point-wise activations first: plot each library activation over [-5, 5].
z = torch.linspace(-5, 5, 1000)

# The functions themselves are stored; each is called with the grid below.
activations = {
    "ReLU": F.relu,
    "GELU": F.gelu,
    "SiLU (Swish)": F.silu,
    "Sigmoid": torch.sigmoid,
}

fig, ax = plt.subplots(figsize=(12, 4))
for name, fn in activations.items():
    ax.plot(z, fn(z).detach(), label=name, linewidth=2.5)

ax.set_title("Point-wise Activation Functions", fontsize=14, fontweight='bold')
ax.set_xlabel("Input (z)")
ax.set_ylabel("Output")
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

# Now let's create GLU variants for comparison
class GLUVariants(nn.Module):
    """Collection of GLU variants for comparison.

    Computes ``gate(x W1) * (x W2)`` (no biases), where the gate non-linearity is
    chosen by ``variant``:

    - ``'glu'``: sigmoid (original GLU)
    - ``'swiglu'``: SiLU / Swish
    - ``'geglu'``: GELU
    - ``'reglu'``: ReLU

    Args:
        in_features: size of the last input dimension.
        out_features: size of the last output dimension.
        variant: variant name, case-insensitive.

    Raises:
        ValueError: if ``variant`` is not one of the names above. This is checked at
            construction and again in ``forward`` (in case ``self.variant`` is reassigned).
    """

    # Gate non-linearity applied to the first projection, keyed by variant name.
    _GATES = {
        'glu': torch.sigmoid,
        'swiglu': F.silu,
        'geglu': F.gelu,
        'reglu': F.relu,
    }

    def __init__(self, in_features: int, out_features: int, variant: str = 'swiglu'):
        super().__init__()
        self.variant = variant.lower()
        # Fail fast: a typo should surface when the model is built, not on the first batch.
        if self.variant not in self._GATES:
            raise ValueError(f"Unknown variant: {self.variant}")
        self.w1 = nn.Linear(in_features, out_features, bias=False)
        self.w2 = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        # Looked up per call so that reassigning self.variant after construction still works.
        gate_fn = self._GATES.get(self.variant)
        if gate_fn is None:
            raise ValueError(f"Unknown variant: {self.variant}")
        return gate_fn(self.w1(x)) * self.w2(x)


# Push one shared random batch through every gated variant and summarise the outputs.
variants = ['glu', 'swiglu', 'geglu', 'reglu']
test_input = torch.randn(8, 16)

print("🎭 Testing GLU Variants")
print("=" * 40)

for variant in variants:
    model = GLUVariants(16, 16, variant)
    with torch.no_grad():
        output = model(test_input)
    stats = f"Mean: {output.mean():.4f}, Std: {output.std():.4f}"
    print(f"{variant.upper():<8}: Output shape {output.shape}, {stats}")

print("\n✅ All GLU variants working correctly!")
print("\n💡 Challenge 2: Extend the visualization to show gated outputs!")


In [None]:
class FeedForwardSwiGLU(nn.Module):
    """Transformer feed-forward layer with a SwiGLU non-linearity.

    Computes ``fc2(dropout(SiLU(x W_gate) * (x W_up)))`` without biases. The gate and
    up projections are fused into one ``fc1`` layer with ``2 * d_ff`` outputs, which
    ``forward`` splits in half. The gated hidden activation therefore has width ``d_ff``.
    It takes the same constructor arguments as ``FeedForwardGELU``, so the two are
    interchangeable. This SwiGLU FFN form is the one used in LLaMA and PaLM.

    Args:
        d_model: input/output feature size.
        d_ff: width of the gated hidden activation.
        dropout: dropout probability applied to the hidden activation.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Fused projection: first d_ff output features are the gate, last d_ff are "up".
        self.fc1 = nn.Linear(d_model, 2 * d_ff, bias=False)
        self.fc2 = nn.Linear(d_ff, d_model, bias=False)  # down projection back to d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        fused = self.fc1(x)
        gate, up = torch.chunk(fused, 2, dim=-1)
        return self.fc2(self.dropout(F.silu(gate) * up))


# For comparison, let's also implement a standard GELU FFN
class FeedForwardGELU(nn.Module):
    """Classic Transformer feed-forward layer: Linear -> GELU -> Dropout -> Linear.

    Both projections are bias-free.

    Args:
        d_model: input/output feature size.
        d_ff: hidden width.
        dropout: dropout probability applied to the hidden activation.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff, bias=False)  # expand to d_ff
        self.fc2 = nn.Linear(d_ff, d_model, bias=False)  # contract back to d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.gelu(self.fc1(x))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)


# Compare parameter counts.
# A SwiGLU FFN has three weight matrices (gate, up, down) versus two for a GELU FFN, so its
# hidden width is scaled by 2/3 to keep the parameter budget roughly equal.
d_model = 512
d_ff_gelu = 2048                # standard 4 * d_model hidden size for the GELU FFN
d_ff = int(2 * d_ff_gelu / 3)   # 1365: SwiGLU hidden size with a matching parameter budget

swiglu_ffn = FeedForwardSwiGLU(d_model, d_ff)
gelu_ffn = FeedForwardGELU(d_model, d_ff_gelu)

swiglu_params = sum(p.numel() for p in swiglu_ffn.parameters())
gelu_params = sum(p.numel() for p in gelu_ffn.parameters())

print("📊 Feed-Forward Layer Comparison")
print("=" * 40)
print(f"SwiGLU FFN (d_ff={d_ff}):     {swiglu_params:,} parameters")
print(f"GELU FFN (d_ff={d_ff_gelu}):         {gelu_params:,} parameters")
print(f"Parameter ratio:              {swiglu_params/gelu_params:.3f}x")

# Test with sample input
batch_size, seq_len = 4, 64
test_input = torch.randn(batch_size, seq_len, d_model)

print(f"\n🧪 Testing with input shape: {test_input.shape}")

# eval() disables dropout: these forward passes are inference, not training.
swiglu_ffn.eval()
gelu_ffn.eval()
with torch.no_grad():
    swiglu_out = swiglu_ffn(test_input)
    gelu_out = gelu_ffn(test_input)

# Both FFNs must map (batch, seq, d_model) back to the same shape.
assert swiglu_out.shape == test_input.shape, f"SwiGLU shape {tuple(swiglu_out.shape)}"
assert gelu_out.shape == test_input.shape, f"GELU shape {tuple(gelu_out.shape)}"

print(f"SwiGLU output shape: {swiglu_out.shape}")
print(f"GELU output shape:   {gelu_out.shape}")

print("\n✅ Both feed-forward layers working correctly!")
print("\n📝 Note: SwiGLU uses 2/3 the hidden dimension to keep parameter count similar")


In [None]:
class TinyTransformerBlock(nn.Module):
    """Pre-norm Transformer block: self-attention then a SwiGLU FFN, each with a residual.

    Computes::

        x = x + SelfAttn(LN1(x), attn_mask=mask)
        x = x + FFN(LN2(x))

    Args:
        d_model: residual-stream width.
        n_heads: number of attention heads (``nn.MultiheadAttention`` requires it to divide
            ``d_model``).
        d_ff: hidden width of the SwiGLU feed-forward layer.
        dropout: dropout probability used inside the attention and the FFN.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()

        # Layer norms (pre-norm style like modern LLMs)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        # Multi-head attention; batch_first=True means inputs are (batch, seq, d_model).
        self.self_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )

        # SwiGLU feed-forward
        self.ffn = FeedForwardSwiGLU(d_model, d_ff, dropout)

    def forward(self, x, mask=None):
        """Apply the block.

        Args:
            x: activations of shape (batch, seq, d_model).
            mask: optional ``attn_mask`` passed straight to ``nn.MultiheadAttention``.
                An example is a boolean (seq, seq) mask where True marks key positions a
                query may NOT attend to.

        Returns:
            Tensor with the same shape as ``x``.
        """
        normed = self.ln1(x)
        # need_weights=False: the attention map was discarded anyway, and skipping it lets
        # PyTorch take its fused scaled-dot-product-attention path.
        attn_out, _ = self.self_attn(normed, normed, normed, attn_mask=mask, need_weights=False)
        x = x + attn_out  # residual around attention

        x = x + self.ffn(self.ln2(x))  # residual around the feed-forward layer
        return x


class TinyTransformer(nn.Module):
    """A tiny causal (decoder-only) Transformer language model with SwiGLU FFNs.

    With the default arguments it has about 0.52M parameters.

    Args:
        vocab_size: number of token ids; also the width of the output logits.
        d_model: embedding / residual-stream width.
        n_layers: number of ``TinyTransformerBlock`` layers.
        n_heads: attention heads per block.
        d_ff: hidden width of each SwiGLU FFN.
        max_seq_len: number of learned positional embeddings, i.e. the longest sequence
            ``forward`` accepts (defaults to the previously hard-coded 512).
    """

    def __init__(self, vocab_size: int = 500, d_model: int = 128,
                 n_layers: int = 2, n_heads: int = 4, d_ff: int = 256,
                 max_seq_len: int = 512):
        super().__init__()

        self.d_model = d_model

        # Token and learned absolute position embeddings
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.randn(max_seq_len, d_model) * 0.02)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TinyTransformerBlock(d_model, n_heads, d_ff)
            for _ in range(n_layers)
        ])

        # Final layer norm and language model head
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        self._init_weights()

    def _init_weights(self):
        """Draw every Linear and Embedding weight from N(0, 0.02) and zero Linear biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, causal: bool = True):
        """Map token ids to next-token logits.

        Args:
            input_ids: integer tensor of shape (batch, seq_len).
            causal: if True (default), position t may only attend to positions <= t.
                Without this mask a next-token objective is trivially solvable, because
                position t can read token t+1 directly through attention.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).

        Raises:
            ValueError: if ``seq_len`` exceeds the number of positional embeddings.
        """
        _, seq_len = input_ids.shape
        max_len = self.pos_emb.shape[0]
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_seq_len={max_len}")

        # (batch, seq, d_model) + (seq, d_model) broadcasts over the batch dimension.
        x = self.token_emb(input_ids) + self.pos_emb[:seq_len]

        mask = None
        if causal:
            # Boolean mask: True strictly above the diagonal = "future key, do not attend".
            mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=input_ids.device),
                diagonal=1,
            )

        for block in self.blocks:
            x = block(x, mask=mask)

        # Final layer norm and projection to vocabulary
        return self.lm_head(self.ln_f(x))


# Create and test the model
tiny_model = TinyTransformer()
tiny_model.eval()  # disable dropout: the forward pass below is inference, not training
total_params = sum(p.numel() for p in tiny_model.parameters())

# Read the architecture back from the model instead of hard-coding it in the printout.
n_vocab = tiny_model.lm_head.out_features
n_layers = len(tiny_model.blocks)
n_heads = tiny_model.blocks[0].self_attn.num_heads

print("🤖 Tiny Transformer with SwiGLU")
print("=" * 40)
print(f"Model parameters: {total_params:,}")
print(f"Architecture: {n_layers} layers, {n_heads} heads, {tiny_model.d_model} d_model")

# Test with dummy input
batch_size, seq_len = 2, 32
input_ids = torch.randint(0, n_vocab, (batch_size, seq_len))

print(f"\n🧪 Testing with input shape: {input_ids.shape}")

with torch.no_grad():
    logits = tiny_model(input_ids)

print(f"Output logits shape: {logits.shape}")
print(f"Logits represent predictions for {n_vocab} vocabulary tokens")

# Show sample predictions for the last token
probs = F.softmax(logits[0, -1], dim=-1)
top_k = 5
top_probs, top_indices = probs.topk(top_k)

print(f"\n📊 Top {top_k} predicted tokens (for last position):")
for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
    print(f"  {i+1}. Token {idx.item():3d}: {prob.item():.4f} probability")

print("\n✅ Tiny transformer with SwiGLU working correctly!")
print("🎉 This architecture is similar to modern LLMs like LLaMA!")


In [None]:
# GELU counterpart of TinyTransformer, used below to compare against the SwiGLU model
class TinyTransformerGELU(TinyTransformer):
    """Same as TinyTransformer but each block uses a plain GELU FFN instead of SwiGLU.

    The default ``d_ff=384`` gives each GELU FFN the same weight count as the parent's
    SwiGLU FFN at ``d_ff=256`` when ``d_model=128`` (2*128*384 == 3*128*256 == 98,304).
    """

    def __init__(self, vocab_size: int = 500, d_model: int = 128,
                 n_layers: int = 2, n_heads: int = 4, d_ff: int = 384):  # Larger d_ff for GELU
        super().__init__(vocab_size, d_model, n_layers, n_heads, d_ff)

        # Replace SwiGLU with GELU in all blocks
        for block in self.blocks:
            block.ffn = FeedForwardGELU(d_model, d_ff)

        # The parent ran _init_weights() before this swap, so the new FFNs would otherwise keep
        # PyTorch's default nn.Linear initialisation instead of the N(0, 0.02) scheme every
        # other layer uses. That would skew the SwiGLU-vs-GELU comparison, so re-apply it.
        self._init_weights()


def train_model_steps(model, num_steps=100, lr=3e-4, batch_size: int = 8,
                      seq_len: int = 16, vocab_size: int = 500,
                      log_every: int = 25, seed: Optional[int] = None):
    """Train ``model`` for a few Adam steps on random next-token prediction.

    Each step samples a fresh (batch_size, seq_len) batch of uniformly random token ids.
    The loss is the cross-entropy of predicting token t+1 from the logits at position t.
    Because the targets are i.i.d. uniform, a causal model cannot push the expected loss
    below ln(vocab_size). This loop is a smoke test of the training mechanics, not a
    benchmark.

    Args:
        model: module mapping (batch, seq) token ids to (batch, seq, V) logits.
        num_steps: number of optimisation steps.
        lr: Adam learning rate.
        batch_size: sequences per batch.
        seq_len: tokens per sequence.
        vocab_size: upper bound (exclusive) for sampled token ids; must not exceed the
            model's output width.
        log_every: print the loss every this many steps; 0 disables logging.
        seed: if given, batches are drawn from a dedicated generator seeded with this value.
            This makes the data stream reproducible and independent of any other RNG use
            such as dropout.

    Returns:
        List of per-step loss values (Python floats), length ``num_steps``.
    """
    # Keep the data on whatever device the model lives on.
    device = next(model.parameters()).device
    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_curve = []

    model.train()
    for step in range(num_steps):
        # Generate random batch (next token prediction task)
        input_ids = torch.randint(
            0, vocab_size, (batch_size, seq_len), generator=generator
        ).to(device)

        logits = model(input_ids)

        # Predict tokens 1..seq_len-1 from positions 0..seq_len-2; the class dimension comes
        # from the model's own output width rather than a hard-coded vocabulary size.
        loss = F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)),  # predictions
            input_ids[:, 1:].reshape(-1),                 # targets
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        loss_curve.append(loss_value)

        if log_every and step % log_every == 0:
            print(f"Step {step:3d}: Loss = {loss_value:.4f}")

    return loss_curve


# Quick training comparison
print("🏁 Training Comparison: SwiGLU vs GELU")
print("=" * 50)

# Models with matched FFN budgets: with d_model=128, a SwiGLU FFN at d_ff=256 has
# 128*512 + 256*128 = 98,304 weights and a GELU FFN at d_ff=384 has 2*128*384 = 98,304.
swiglu_model = TinyTransformer(d_ff=256)  # SwiGLU with smaller d_ff
gelu_model = TinyTransformerGELU(d_ff=384)  # GELU with larger d_ff

swiglu_params = sum(p.numel() for p in swiglu_model.parameters())
gelu_params = sum(p.numel() for p in gelu_model.parameters())

print(f"SwiGLU model: {swiglu_params:,} parameters")
print(f"GELU model:   {gelu_params:,} parameters")
print(f"Parameter ratio: {swiglu_params/gelu_params:.3f}")

# Caveat: the training batches are uniformly random token ids, so a causal model cannot push
# the expected loss below ln(500) ≈ 6.21 nats; gaps between the two curves are mostly noise.
print(f"\n🚀 Training SwiGLU model...")
swiglu_losses = train_model_steps(swiglu_model, num_steps=100)

print(f"\n🚀 Training GELU model...")
gelu_losses = train_model_steps(gelu_model, num_steps=100)

# Plot raw and smoothed loss curves side by side.
fig, (ax_raw, ax_smooth) = plt.subplots(1, 2, figsize=(12, 5))

ax_raw.plot(swiglu_losses, label='SwiGLU', linewidth=2, color='blue')
ax_raw.plot(gelu_losses, label='GELU', linewidth=2, color='red')
ax_raw.set_title('Training Loss Comparison', fontweight='bold')
ax_raw.set_xlabel('Training Step')
ax_raw.set_ylabel('Cross-Entropy Loss')
ax_raw.legend()
ax_raw.grid(True, alpha=0.3)

# Moving average over `window` steps. np.convolve(mode='valid') returns
# len(losses) - window + 1 values, and value i averages steps i .. i + window - 1, so it is
# plotted at step i + window - 1 to keep the x-axis aligned with the raw curve.
window = 10
kernel = np.ones(window) / window
swiglu_smooth = np.convolve(swiglu_losses, kernel, mode='valid')
gelu_smooth = np.convolve(gelu_losses, kernel, mode='valid')
swiglu_steps = np.arange(window - 1, len(swiglu_losses))
gelu_steps = np.arange(window - 1, len(gelu_losses))

ax_smooth.plot(swiglu_steps, swiglu_smooth, label='SwiGLU (smoothed)', linewidth=2, color='blue')
ax_smooth.plot(gelu_steps, gelu_smooth, label='GELU (smoothed)', linewidth=2, color='red')
ax_smooth.set_title('Smoothed Training Curves', fontweight='bold')
ax_smooth.set_xlabel('Training Step')
ax_smooth.set_ylabel('Cross-Entropy Loss')
ax_smooth.legend()
ax_smooth.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

print(f"\n📊 Final Results:")
print(f"SwiGLU final loss: {swiglu_losses[-1]:.4f}")
print(f"GELU final loss:   {gelu_losses[-1]:.4f}")
print(f"Improvement:       {((gelu_losses[-1] - swiglu_losses[-1]) / gelu_losses[-1] * 100):+.2f}%")

print("\n💡 Note: Results may vary due to random initialization!")
print("📚 In real experiments, SwiGLU typically shows consistent improvements")
