# Module 5.4: Loss Functions

**Goal**: Understand loss functions for language modeling

**Time**: 60 minutes

**Concepts Covered**:
- Cross-entropy implementation
- Causal LM loss (shifted targets)
- Perplexity calculation
- Connection to maximum likelihood
- Loss landscape visualization
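
For reference, the first four concepts above can be written out as follows. The notation is ours, not taken from a specific source: a sequence of token ids $x_1, \dots, x_T$, a vocabulary of size $V$, and a logit vector $z_t \in \mathbb{R}^V$ produced at position $t$. Logarithms are natural, so the loss is in nats, matching `torch.log_softmax`. With a batch, the average also runs over every sequence.

$$
p_\theta(v \mid x_{\le t}) = \frac{\exp(z_{t,v})}{\sum_{u=1}^{V} \exp(z_{t,u})}
$$

$$
\mathcal{L}(\theta) = -\frac{1}{T-1} \sum_{t=1}^{T-1} \log p_\theta(x_{t+1} \mid x_{\le t}),
\qquad
\mathrm{PPL} = \exp\big(\mathcal{L}(\theta)\big)
$$

$$
\arg\min_\theta \mathcal{L}(\theta)
= \arg\max_\theta \sum_{t=1}^{T-1} \log p_\theta(x_{t+1} \mid x_{\le t})
= \arg\max_\theta \prod_{t=1}^{T-1} p_\theta(x_{t+1} \mid x_{\le t})
$$

The index shift ($z_t$ is scored against $x_{t+1}$) is the "shifted targets" of causal LM loss, and the last line is the maximum-likelihood connection: by the chain rule, the product is the probability the model assigns to $x_2, \dots, x_T$ given $x_1$. As a sanity anchor, a model that predicts $p = 1/V$ everywhere has $\mathcal{L} = \ln V$ and $\mathrm{PPL} = V$.
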

## Setup

In [None]:
!pip install torch transformers accelerate matplotlib seaborn numpy -q

In [None]:
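import torch
import torch.nn.functional as F
import numpy as np

def cross_entropy_loss(logits, targets):
    """Cross-entropy loss from scratch.

    logits:  (..., vocab_size) unnormalized scores
    targets: (...) integer token ids with the same leading shape as logits
    """
    # log_softmax uses the log-sum-exp trick, so it is numerically stable
    log_probs = F.log_softmax(logits, dim=-1)

    # Negative log-likelihood: pick out log p(target) at every position, then average
    target_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return -target_log_probs.mean()

def causal_lm_loss(logits, input_ids):
    """Causal language modeling loss (shifted targets).

    The logits at position t are scored against the token at position t + 1.
    """
    # Shift: drop the last prediction (no token follows it)
    # and the first label (no position predicts it)
    shift_logits = logits[..., :-1, :]
    shift_labels = input_ids[..., 1:]

    # Flatten (batch, seq_len - 1) into a single axis for cross-entropy
    shift_logits = shift_logits.reshape(-1, shift_logits.size(-1))
    shift_labels = shift_labels.reshape(-1)

    return F.cross_entropy(shift_logits, shift_labels)

def perplexity(loss):
    """Perplexity = exp(mean cross-entropy), with the loss measured in nats."""
    return torch.exp(loss).item()

# Example: random logits for a batch of 2 sequences
torch.manual_seed(0)
batch_size = 2
vocab_size = 1000
seq_len = 10
logits = torch.randn(batch_size, seq_len, vocab_size)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

loss = causal_lm_loss(logits, input_ids)
ppl = perplexity(loss)

# The from-scratch version should agree with F.cross_entropy on the shifted tensors
scratch = cross_entropy_loss(logits[:, :-1], input_ids[:, 1:])
print(f"From-scratch loss matches F.cross_entropy: {torch.allclose(scratch, loss)}")

print(f"Loss: {loss.item():.4f}")
print(f"Perplexity: {ppl:.2f}")
print(f"Uniform-guess baseline: loss = ln({vocab_size}) = {np.log(vocab_size):.4f}, perplexity = {vocab_size}")
print("\nPerplexity = e^loss (loss measured in nats)")
print("Lower perplexity = the model assigns higher probability to the actual next tokens")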
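
A few quick checks can make the code cell above easier to trust. The sketch below is self-contained (it recomputes everything inline rather than calling the cell's functions) and, under the assumption of the same random-logit setup: (1) compares a hand-written `gather`-based cross-entropy with `F.cross_entropy`, (2) confirms that all-equal logits give a loss of ln V and a perplexity of V, and (3) builds hypothetical "oracle" logits that put nearly all probability on the next token, which should drive the shifted loss to about 0. Names such as `manual`, `uniform`, and `oracle` are illustrative only.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, vocab_size = 2, 10, 1000
logits = torch.randn(batch, seq_len, vocab_size)
input_ids = torch.randint(0, vocab_size, (batch, seq_len))

# 1) Cross-entropy by hand: -log p(target) via gather, averaged over positions
log_probs = F.log_softmax(logits, dim=-1)  # (B, T, V)
manual = -log_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1).mean()
builtin = F.cross_entropy(logits.reshape(-1, vocab_size), input_ids.reshape(-1))
print("manual == builtin:", torch.allclose(manual, builtin))

# 2) Uniform baseline: equal logits give p = 1/V for every token,
#    so the shifted loss is ln V and the perplexity is V
uniform = torch.zeros(batch, seq_len, vocab_size)
uniform_loss = F.cross_entropy(uniform[:, :-1].reshape(-1, vocab_size),
                               input_ids[:, 1:].reshape(-1))
print(f"uniform loss {uniform_loss.item():.4f} vs ln V {math.log(vocab_size):.4f}, "
      f"perplexity {uniform_loss.exp().item():.1f}")

# 3) Shift check: at each position t, put a huge logit on token t + 1.
#    If the shift is right, the causal LM loss is (numerically) zero.
oracle = torch.full((batch, seq_len, vocab_size), -1e4)
oracle[:, :-1].scatter_(-1, input_ids[:, 1:].unsqueeze(-1), 1e4)
oracle_loss = F.cross_entropy(oracle[:, :-1].reshape(-1, vocab_size),
                              input_ids[:, 1:].reshape(-1))
print(f"oracle loss {oracle_loss.item():.6f}")
```

If check 3 ever reports a large loss, the labels are most likely misaligned with the logits by one position, which is the most common bug in causal LM loss code.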

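To make the connection to maximum likelihood concrete: the mean cross-entropy over observed tokens is their mean negative log-likelihood, so minimizing it maximizes the likelihood of the data. For the simplest possible model, a unigram distribution with one free logit per token (a toy assumption, not a real language model), the maximum-likelihood answer is known in closed form: the empirical token frequencies. The sketch below fits those logits by plain gradient descent on `F.cross_entropy` over a made-up corpus and should land on the frequencies, with a final loss equal to the entropy of the empirical distribution.

```python
import torch
import torch.nn.functional as F

# Hypothetical toy corpus of token ids (vocabulary of 4 tokens)
vocab_size = 4
tokens = torch.tensor([0, 1, 1, 2, 2, 2, 3, 3, 3, 3])

# Closed-form MLE for a categorical distribution: empirical frequencies
counts = torch.bincount(tokens, minlength=vocab_size).float()
empirical = counts / counts.sum()  # [0.1, 0.2, 0.3, 0.4]

# Unigram "model": one learnable logit per token, shared by every position
logits = torch.zeros(vocab_size, requires_grad=True)
optimizer = torch.optim.SGD([logits], lr=1.0)

for step in range(500):
    optimizer.zero_grad()
    # Mean cross-entropy over the corpus = mean negative log-likelihood
    loss = F.cross_entropy(logits.expand(len(tokens), -1), tokens)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    learned = F.softmax(logits, dim=-1)
    final_loss = F.cross_entropy(logits.expand(len(tokens), -1), tokens)
    # The lowest achievable cross-entropy is the entropy of the data distribution
    entropy = -(empirical * empirical.log()).sum()

print("learned:  ", learned)
print("empirical:", empirical)
print(f"final loss {final_loss.item():.4f}, entropy {entropy.item():.4f}")
```

Swapping the toy unigram for a Transformer changes the model, not the objective: the training loss is still the negative log-likelihood of the observed next tokens.
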
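## Key Takeaways

- Cross-entropy is the negative log-probability the model assigns to the correct token, averaged over positions; computing it through `log_softmax` keeps it numerically stable.
- Causal LM loss shifts the targets by one: the logits at position *t* are scored against token *t + 1*, so a sequence of length T yields T - 1 predictions.
- Perplexity is `exp(loss)` when the loss is in nats. A model that guesses uniformly over V tokens has perplexity V, and lower is better.
- Minimizing cross-entropy on the training tokens is the same as maximizing their likelihood.

✅ **Module Complete**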

## Next Steps

Continue to the next module in the course.