# KL Divergence Loss Implementation

## Problem Analysis: Implement KL Divergence Loss

### Problem Statement

You are tasked with implementing the **Kullback-Leibler (KL) Divergence Loss** as a custom loss function in PyTorch. KL divergence is a key component in training LLMs for tasks such as knowledge distillation or aligning model outputs with a target distribution. It measures how much a predicted distribution diverges from a target distribution. It is always non-negative and equals zero only when the two distributions match. It is not symmetric, so it is not a true distance. As a loss, it pushes a model's predicted probabilities toward the target distribution.

### Mathematical Definition

KL Divergence between two discrete probability distributions P (target) and Q (predicted) is defined as:

```
D_KL(P || Q) = ∑_i P(i) log(P(i)/Q(i))
```
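A small worked example makes the definition concrete. The sketch below uses plain Python and the natural logarithm, which is the same base as `torch.log`. The two-class distributions are arbitrary illustration values. It also shows that swapping P and Q changes the result, so the divergence is not symmetric.

```python
import math

def kl_divergence(p, q):
    """D_KL(P || Q) for two discrete distributions given as lists (natural log)."""
    # Terms with P(i) = 0 are skipped, following the convention 0 * log(0 / q) = 0.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

P = [0.5, 0.5]
Q = [0.9, 0.1]
print(round(kl_divergence(P, Q), 4))  # 0.5108
print(round(kl_divergence(Q, P), 4))  # 0.3681, a different value: KL is not symmetric
```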

For a batch of n samples over c classes, the per-sample divergences are averaged:

```
L = (1/n) ∑_{j=1}^n ∑_{i=1}^c P_j(i) log(P_j(i)/Q_j(i))
```

where:
- P_j(i): Target probability for sample j, class i
- Q_j(i): Predicted probability for sample j, class i
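To check the batched formula, the sketch below computes it twice. The first version is vectorized with PyTorch tensor operations: it sums over classes and then takes the mean over samples. The second uses explicit loops over j and i. The function names `batch_kl` and `batch_kl_loop` are illustrative. The inputs are assumed to be strictly positive, so no ε is needed here.

```python
import math
import torch

def batch_kl(P, Q):
    """Vectorized batch KL: P and Q are [n, c] tensors of strictly positive probabilities."""
    per_sample = (P * (P.log() - Q.log())).sum(dim=1)  # inner sum over classes i -> [n]
    return per_sample.mean()                           # (1/n) * sum over samples j -> scalar

def batch_kl_loop(P, Q):
    """The same quantity written as explicit loops over samples j and classes i."""
    n, c = P.shape
    total = 0.0
    for j in range(n):
        for i in range(c):
            p, q = P[j, i].item(), Q[j, i].item()
            total += p * math.log(p / q)
    return total / n

torch.manual_seed(0)
P = torch.softmax(torch.randn(4, 3), dim=1)
Q = torch.softmax(torch.randn(4, 3), dim=1)
print(batch_kl(P, Q).item(), batch_kl_loop(P, Q))  # the two values agree up to float32 rounding
```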

## Requirements

- Implement a `KLDivergenceLoss` class inheriting from `torch.nn.Module`
- Define the `forward` method to compute the KL Divergence Loss
- Handle numerical stability (e.g., avoid division by zero or log of zero)
- Integrate into a simple training pipeline with a synthetic dataset of probability distributions
- Use PyTorch for tensor operations and autograd

## Constraints

- Use only PyTorch (no scikit-learn or other ML libraries)
- Handle batch inputs (P, Q ∈ R^{n×c})
- Ensure the loss is a scalar for optimization
- Add a small constant (ε) to prevent numerical issues

## Synthetic Dataset

- Generate n=100 samples with c=10 classes
- **Target (P)**: Softmax of random logits to simulate true probabilities
- **Predicted (Q)**: Softmax of model outputs to simulate predicted probabilities
- Test on a small batch to verify the loss
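One way to generate this dataset is sketched below. It assumes standard-normal logits (`torch.randn`), as in the usage example further down. The seed and the 4-row small batch are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(42)
n, c = 100, 10

# Target P: softmax of random "teacher" logits, one distribution per row.
target_probs = torch.softmax(torch.randn(n, c), dim=1)

# Predicted Q: softmax of logits; random logits stand in for a model's outputs here.
predicted_probs = torch.softmax(torch.randn(n, c), dim=1)

# Small batch for a quick sanity check of the loss.
small_P, small_Q = target_probs[:4], predicted_probs[:4]

print(target_probs.shape, target_probs.sum(dim=1)[:3])  # each row sums to 1
```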

## Implementation Guidelines

### Key Considerations

1. **Numerical Stability**: Add epsilon (ε = 1e-8) to prevent log(0) and division by 0
2. **Probability Validation**: Ensure inputs are valid probability distributions
3. **Batch Processing**: Handle multiple samples simultaneously
4. **Gradient Flow**: Maintain differentiability for backpropagation
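The epsilon in point 1 matters most when a target entry is exactly 0, for example with a one-hot label. In that case `0 * log(0)` evaluates to `0 * (-inf) = NaN` in floating point. The sketch below uses made-up two-class values to compare three versions: the naive expression, the ε-shifted version, and `torch.xlogy`. `torch.xlogy` is an alternative not used in the original; it defines `xlogy(0, y) = 0`.

```python
import torch

p = torch.tensor([1.0, 0.0])  # one-hot target: contains an exact zero
q = torch.tensor([0.7, 0.3])  # predicted distribution

# Naive: log(0) = -inf, and 0 * (-inf) = NaN, which poisons the sum.
naive = p * (torch.log(p) - torch.log(q))

# Epsilon shift: log(0 + eps) is finite, so the zero-probability term becomes 0 * finite = 0.
eps = 1e-8
with_eps = p * (torch.log(p + eps) - torch.log(q))

# xlogy(x, y) = x * log(y), with the result defined as 0 wherever x == 0.
via_xlogy = torch.xlogy(p, p) - torch.xlogy(p, q)

print(naive.sum(), with_eps.sum(), via_xlogy.sum())
```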

### Expected Output Structure

```python
import torch

class KLDivergenceLoss(torch.nn.Module):
    def __init__(self, epsilon=1e-8):
        super().__init__()
        # Initialize with numerical stability constant
        self.epsilon = epsilon

    def forward(self, target_probs, predicted_probs):
        # Compute KL divergence loss
        # Return scalar loss value
        raise NotImplementedError
```
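Below is a minimal sketch of one way to fill in this skeleton. It assumes both arguments are already probability distributions of shape [n, c], and it keeps the target-first argument order used here. Clamping both inputs at ε is one stability choice among several. The notebook implementation later in this document takes raw logits instead.

```python
import torch

class KLDivergenceLoss(torch.nn.Module):
    """Sketch: D_KL(P || Q) averaged over the batch; both inputs are probabilities [n, c]."""

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon  # numerical stability constant

    def forward(self, target_probs, predicted_probs):
        # Clamp both inputs away from zero so that neither log can return -inf.
        p = target_probs.clamp(min=self.epsilon)
        q = predicted_probs.clamp(min=self.epsilon)
        # Weight by the unclamped target so zero-probability classes contribute exactly 0.
        per_sample = (target_probs * (p.log() - q.log())).sum(dim=1)  # [n]
        return per_sample.mean()  # scalar, suitable for loss.backward()
```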

### Testing Requirements

- Verify loss computation on synthetic data
- Test gradient computation
- Validate numerical stability
- Compare with PyTorch's built-in `torch.nn.KLDivLoss`
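When comparing with the built-in loss, note its input conventions. `torch.nn.KLDivLoss` expects its input as log-probabilities and, by default, its target as plain probabilities. With `reduction="batchmean"`, it sums over classes and divides by the batch size, which matches the averaged formula above. With `reduction="mean"`, it would instead divide by n·c. The sketch below runs the comparison on arbitrary random data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8, 5)                        # unnormalized model outputs
target = torch.softmax(torch.randn(8, 5), dim=1)  # target probabilities P

# Manual: mean over samples of sum_i P(i) * (log P(i) - log Q(i)), with Q = softmax(logits).
log_q = torch.log_softmax(logits, dim=1)
manual = (target * (target.log() - log_q)).sum(dim=1).mean()

# Built-in: input is log Q, target is P (log_target=False by default).
builtin = nn.KLDivLoss(reduction="batchmean")(log_q, target)

print(manual.item(), builtin.item())  # should agree up to float32 rounding
```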

## Usage Example

```python
# Create loss function
kl_loss = KLDivergenceLoss()

# Generate synthetic data
target_probs = torch.softmax(torch.randn(100, 10), dim=1)
predicted_probs = torch.softmax(torch.randn(100, 10), dim=1)

# Compute loss
loss = kl_loss(target_probs, predicted_probs)
```

## Evaluation Metrics

- Loss value should be non-negative
- Loss should be 0 when P = Q (up to ε-level rounding)
- Gradient should flow properly for optimization
- Should handle edge cases, such as exact zeros in P or Q, without producing NaN or inf
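These metrics can be checked directly. The sketch below re-creates the notebook's loss computation (softmax, clamp, log) as a standalone function; the name `kl_loss_from_logits` is hypothetical. It checks three things: the loss is non-negative, it is near zero when Q = P, and autograd returns the correct gradient. For the batch-mean loss, the gradient with respect to the logits z works out to (softmax(z) − P)/n, because each target row sums to 1. That gives an exact value to compare autograd against.

```python
import torch

def kl_loss_from_logits(logits, target, eps=1e-8):
    # Mirrors the notebook loss: softmax -> clamp -> log; target is already a distribution.
    q = torch.softmax(logits, dim=1).clamp(min=eps)
    return (target * (torch.log(target + eps) - torch.log(q))).sum(dim=1).mean()

torch.manual_seed(0)
n, c = 16, 10
target = torch.softmax(torch.randn(n, c), dim=1)
logits = torch.randn(n, c, requires_grad=True)

loss = kl_loss_from_logits(logits, target)
loss.backward()

# 1) Non-negative (Gibbs' inequality).
print(loss.item() >= 0)

# 2) Near zero when Q = P: softmax(log P) recovers P, so feed log P as the logits.
print(kl_loss_from_logits(target.log(), target).item())

# 3) Gradient flows and matches the analytic value (softmax(z) - P) / n.
expected = (torch.softmax(logits.detach(), dim=1) - target) / n
print(torch.allclose(logits.grad, expected, atol=1e-6))
```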

## Expected Deliverables

1. Complete `KLDivergenceLoss` class implementation
2. Synthetic dataset generation code
3. Training pipeline integration
4. Test cases and validation
5. Documentation and comments

import torch
# Purpose: Import PyTorch for tensor operations and neural network functionality.
# Theory: PyTorch provides tensors with GPU support and autograd for automatic differentiation, essential for custom loss functions and model training.

import torch.nn as nn
# Purpose: Import neural network modules, including nn.Module for defining custom loss and model classes.
# Theory: nn.Module enables custom loss functions like KL Divergence to integrate with PyTorch’s autograd system.

import torch.optim as optim
# Purpose: Import optimization algorithms like Adam for updating model parameters.
# Theory: Adam combines momentum with per-parameter adaptive step sizes, which works well for optimizing models with custom losses.

# Set random seed for reproducibility
torch.manual_seed(42)
# Purpose: Fix the random seed to ensure consistent random number generation.
# Theory: Ensures reproducibility of synthetic data and model initialization, aligning with previous TorchLeet problems (e.g., DNN regression).

# Generate synthetic data
n_samples, n_classes = 100, 10
# Purpose: Define the number of samples (100) and classes (10) for the synthetic dataset.
# Theory: A dataset with multiple classes simulates probability distributions, suitable for testing KL Divergence.

X = torch.rand(n_samples, n_classes)
# Purpose: Generate random input features for the model, shape [100, 10].
# Theory: X is the model's input, not a distribution; the linear layer maps it to logits, which the loss converts to probabilities via softmax.

y_true = torch.softmax(torch.rand(n_samples, n_classes), dim=1)
# Purpose: Generate target probability distributions using softmax, shape [100, 10].
# Theory: Softmax ensures valid probabilities (\sum_i P(i) = 1). The logits are uniform on [0, 1), so the targets are fairly flat; they stand in for soft labels from a teacher model.

# Define the KL Divergence Loss
class KLDivergenceLoss(nn.Module):
    # Purpose: Define a custom KL Divergence Loss by subclassing nn.Module.
    # Theory: nn.Module integrates the loss with PyTorch’s autograd, enabling gradient computation for optimization.
    
    def __init__(self, epsilon=1e-8):
        # Purpose: Initialize the loss with a small constant for numerical stability.
        # Theory: epsilon prevents division by zero or log(0) in KL Divergence computation.
        
        super(KLDivergenceLoss, self).__init__()
        # Purpose: Call the parent nn.Module constructor to set up the module.
        # Theory: Ensures proper initialization for autograd integration.
        
        self.epsilon = epsilon
        # Purpose: Store epsilon as an instance variable.
        # Theory: Used to stabilize log computations, a common practice in probabilistic losses.
    
    def forward(self, y_pred, y_true):
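        # Purpose: Compute the KL Divergence Loss between predicted and true probability distributions.
        # Theory: Computes D_KL(P || Q) = \sum P(i) \log(P(i)/Q(i)), averaged over samples, with P = y_true and Q = softmax(y_pred).
        # y_pred holds raw logits [batch_size, n_classes]; y_true holds target probabilities [batch_size, n_classes].
        # Note: this logits-first argument order differs from the skeleton's forward(target_probs, predicted_probs).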
        
        y_pred = torch.softmax(y_pred, dim=1)
        # Purpose: Convert predicted logits to probabilities using softmax.
        # Theory: Softmax ensures \sum_i Q(i) = 1, making y_pred a valid probability distribution.
        
        y_pred = torch.clamp(y_pred, min=self.epsilon)
        # Purpose: Clip predicted probabilities from below so that log(Q) stays finite.
        # Theory: Softmax can underflow to exactly 0 for very negative logits; clamping at epsilon keeps every value positive.
        
        kl_div = y_true * (torch.log(y_true + self.epsilon) - torch.log(y_pred))
        # Purpose: Compute the KL Divergence term P(i) \log(P(i)/Q(i)) element-wise.
        # Theory: Expands to P(i) (\log P(i) - \log Q(i)). Adding epsilon to y_true keeps log(P(i)) finite when a target entry is exactly 0, avoiding 0 * log(0) = NaN.
        
        kl_div = torch.sum(kl_div, dim=1)
        # Purpose: Sum KL Divergence over classes for each sample.
        # Theory: Aggregates the divergence across all classes, producing a per-sample loss [batch_size].
        
        return torch.mean(kl_div)
        # Purpose: Average the loss over the batch to produce a scalar.
        # Theory: Scalar loss is required for optimization, enabling gradient backpropagation.

# Define a simple model to generate predicted probabilities
class SimpleModel(nn.Module):
    # Purpose: Define a simple neural network to output logits for KL Divergence.
    # Theory: A linear layer maps inputs to logits, which are converted to probabilities via softmax.
    
    def __init__(self):
        super(SimpleModel, self).__init__()
        # Purpose: Initialize the parent nn.Module class.
        # Theory: Ensures proper parameter registration and autograd setup.
        
        self.linear = nn.Linear(n_classes, n_classes)
        # Purpose: Create a linear layer mapping the input features to output logits.
        # Theory: Computes z = xW^T + b, where W is [n_classes, n_classes] and b is [n_classes]. Outputs logits for softmax.
    
    def forward(self, x):
        # Purpose: Define the forward pass to compute logits.
        # Theory: Builds the computational graph for autograd, producing logits for KL Divergence.
        
        return self.linear(x)
        # Purpose: Apply the linear transformation.
        # Theory: Outputs logits [batch_size, n_classes], which KLDivergenceLoss converts to probabilities.

# Initialize the model, loss function, and optimizer
model = SimpleModel()
# Purpose: Create an instance of the model.
# Theory: Initializes weights and biases with PyTorch's default nn.Linear scheme (uniform in ±1/√in_features), tracked by autograd.

criterion = KLDivergenceLoss()
# Purpose: Initialize the KL Divergence Loss with default epsilon.
# Theory: Prepares the custom loss for training, handling numerical stability.

optimizer = optim.Adam(model.parameters(), lr=0.01)
# Purpose: Initialize Adam optimizer with learning rate 0.01.
# Theory: Adam keeps running averages of gradients (β1=0.9) and squared gradients (β2=0.999) to give each parameter an adaptive step size.

# Training loop
epochs = 1000
# Purpose: Set the number of training iterations to 1000 epochs.
# Theory: Multiple epochs allow the model to learn to match the target distribution, consistent with previous TorchLeet problems.

for epoch in range(epochs):
    # Purpose: Iterate over the dataset for training.
    # Theory: Each epoch processes the entire dataset, updating parameters to minimize KL Divergence.
    
    # Forward pass
    predictions = model(X)
    # Purpose: Compute model logits by passing input X through the model.
    # Theory: X [100, 10] produces logits [100, 10], which are converted to probabilities in the loss.
    
    loss = criterion(predictions, y_true)
    # Purpose: Calculate the KL Divergence Loss between predicted and true distributions.
    # Theory: Computes D_KL(P || Q), averaging over samples. Both tensors are [100, 10].
    
    # Backward pass and optimization
    optimizer.zero_grad()
    # Purpose: Reset gradients of all model parameters.
    # Theory: Prevents gradient accumulation from previous iterations, ensuring correct updates.
    
    loss.backward()
    # Purpose: Compute gradients of the loss with respect to model parameters.
    # Theory: Autograd backpropagates through the loss (KL Divergence) → softmax → linear layer.
    
    optimizer.step()
    # Purpose: Update model parameters using gradients.
    # Theory: Adam applies adaptive updates to minimize the loss.
    
    # Log progress every 100 epochs
    if (epoch + 1) % 100 == 0:
        # Purpose: Print training progress.
        # Theory: Monitoring loss helps assess convergence and detect issues like numerical instability.
        
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")
        # Purpose: Display epoch and loss value.
        # Theory: loss.item() extracts the scalar loss for readable output.

# Testing on new data
X_test = torch.rand(2, n_classes)
# Purpose: Generate a small test set with 2 samples.
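# Theory: Runs the trained model on unseen inputs, shape [2, 10]. Since X and y_true were independent random tensors, this checks the output format rather than generalization.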

with torch.no_grad():
    # Purpose: Disable gradient tracking for inference.
    # Theory: Saves memory and computation during evaluation.
    
    predictions = torch.softmax(model(X_test), dim=1)
    # Purpose: Compute predicted probabilities for test inputs.
    # Theory: Softmax converts logits to probabilities, shape [2, 10].
    
    print(f"Test Predictions: {predictions.tolist()}")
    # Purpose: Print predicted probabilities.
    # Theory: Each row should be a valid distribution (non-negative, summing to 1); X_test has no targets to compare against.
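The loss above computes `softmax`, clamps at ε, and then takes the `log`. For very confident logits, softmax can underflow to exactly 0 in float32. The clamp then caps log Q at ln(ε), so the loss is silently underestimated. A common alternative, swapped in here and not part of the original, is `torch.log_softmax`, which computes log Q directly. The class name `LogSoftmaxKLLoss` and the helper `clamped_kl` are hypothetical. The class keeps the same `(y_pred, y_true)` interface, so it could replace `criterion = KLDivergenceLoss()` in the training loop.

```python
import torch
import torch.nn as nn

class LogSoftmaxKLLoss(nn.Module):
    """Hypothetical variant: same (logits, target) interface, but computes log Q with log_softmax."""

    def forward(self, y_pred, y_true):
        log_q = torch.log_softmax(y_pred, dim=1)  # no exp underflow, so no clamp is needed
        # xlogy(p, p) = p * log(p), defined as 0 where p == 0
        kl = torch.xlogy(y_true, y_true) - y_true * log_q
        return kl.sum(dim=1).mean()

def clamped_kl(y_pred, y_true, eps=1e-8):
    # Same computation as the notebook's KLDivergenceLoss.forward (softmax -> clamp -> log).
    q = torch.clamp(torch.softmax(y_pred, dim=1), min=eps)
    return (y_true * (torch.log(y_true + eps) - torch.log(q))).sum(dim=1).mean()

# A confident, wrong prediction: class 1's logit is 200 below class 0's.
logits = torch.tensor([[0.0, -200.0]])
target = torch.tensor([[0.5, 0.5]])

print(LogSoftmaxKLLoss()(logits, target).item())  # ≈ 99.31 (exact value: 100 - ln 2)
print(clamped_kl(logits, target).item())          # ≈ 8.52: exp(-200) underflows and the clamp caps log Q
```

For ordinary logits, the two versions agree to within ε-level rounding. They only diverge once softmax underflows.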