# Information Theory for Machine Learning
## Entropy-Based Metrics and the Information Bottleneck Principle

Welcome to the **mathematical science of information**! Information theory provides a framework for quantifying how much information data contains and for bounding how efficiently that information can be compressed, transmitted, and processed.

### What You'll Master
By the end of this notebook, you'll understand:
1. **Entropy and information** - Quantifying uncertainty and surprise
2. **Mutual information** - Measuring statistical dependence
3. **KL divergence** - Comparing probability distributions
4. **Information bottleneck** - The principle behind representation learning
5. **Channel capacity** - Limits of information transmission
6. **Applications in ML** - From feature selection to deep learning

### Why This is Revolutionary
- **Deep learning** can be analyzed through information bottleneck theory, which views a representation as compressing the input while keeping what is relevant to the target
- **Feature selection** uses mutual information to find relevant variables
- **Model compression** applies information theory to reduce model size
- **Generalization** can be studied in terms of how much information a model retains about its training data

### Real-World Applications
- **Data compression**: JPEG, MP3, ZIP files
- **Cryptography**: Measuring randomness and security
- **Neural networks**: Understanding what layers learn
- **Reinforcement learning**: Information-theoretic exploration

Let's decode the mathematics of information! 📡

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import entropy, multivariate_normal
from scipy.special import rel_entr
from sklearn.datasets import make_classification, load_digits
from sklearn.feature_selection import mutual_info_classif
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score
import pandas as pd
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# Set style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
np.random.seed(42)

# Mathematical constants
LOG2 = np.log(2)
NATS_TO_BITS = 1 / LOG2

print("📡 Information Theory toolkit loaded!")
print("Ready to measure information and uncertainty!")

## 1. Entropy: The Fundamental Measure of Information

### What is Entropy?
**Entropy** measures the average amount of information (or uncertainty) in a random variable.

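**Shannon Entropy Formula**:
```
H(X) = -∑_x P(x) log P(x)
```
The sum runs over all outcomes x of X. The base of the logarithm sets the unit, and outcomes with P(x) = 0 contribute 0 by convention.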

### Intuitive Understanding
- **High entropy**: Outcome is unpredictable (fair coin flip)
- **Low entropy**: Outcome is predictable (biased coin)
- **Zero entropy**: Outcome is certain (deterministic)
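
The three cases above can be checked directly from the formula. This is a minimal sketch in plain Python; the helper name `shannon_entropy` and the 90/10 bias are illustrative choices, not part of the notebook's own code.

```python
import math

def shannon_entropy(probs):
    """Shannon entropy in bits.

    Uses the equivalent form sum p * log2(1/p) and skips zero-probability
    outcomes, which contribute nothing by convention.
    """
    return sum(p * math.log2(1.0 / p) for p in probs if p > 0)

fair = shannon_entropy([0.5, 0.5])      # unpredictable: highest entropy for two outcomes
biased = shannon_entropy([0.9, 0.1])    # mostly predictable: lower entropy
certain = shannon_entropy([1.0, 0.0])   # deterministic: no uncertainty at all

print(f"fair: {fair:.3f} bits, biased: {biased:.3f} bits, certain: {certain:.3f} bits")
```

The ordering `certain < biased < fair` is the point of the example: entropy falls as the outcome becomes easier to predict.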

### Units of Measurement
- **Bits**: log₂ (most common in CS)
- **Nats**: ln (natural logarithm)
- **Dits**: log₁₀ (decimal digits)
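
Changing the logarithm base only rescales the result, so the three units are interchangeable. The sketch below, which uses a fair six-sided die as an assumed example and a hypothetical helper `entropy_with`, computes the same entropy in each unit and converts between them.

```python
import math

def entropy_with(probs, log_fn):
    """Entropy of a discrete distribution using the given logarithm function."""
    return sum(p * log_fn(1.0 / p) for p in probs if p > 0)

die = [1 / 6] * 6  # fair six-sided die

h_bits = entropy_with(die, math.log2)    # base 2
h_nats = entropy_with(die, math.log)     # base e
h_dits = entropy_with(die, math.log10)   # base 10

# Converting units is a division by the log of the new base:
# bits = nats / ln(2) = dits / log10(2)
bits_from_nats = h_nats / math.log(2)
bits_from_dits = h_dits / math.log10(2)

print(f"{h_bits:.4f} bits = {h_nats:.4f} nats = {h_dits:.4f} dits")
```

The factor `1 / math.log(2)` is the same quantity the setup cell stores as `NATS_TO_BITS`.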

### Properties of Entropy
1. **Non-negative**: H(X) ≥ 0
2. **Maximum**: H(X) ≤ log|X| (uniform distribution)
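3. **Concave**: H is a concave function of the distribution, so mixing two distributions gives at least the average of their entropies (more spread out → higher entropy)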
4. **Additive**: H(X,Y) = H(X) + H(Y) if X ⊥ Y
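
Each of the four properties can be spot-checked numerically. The distributions below are arbitrary illustrative choices, and a check on one example is a sanity test, not a proof.

```python
import math
from itertools import product

def entropy_bits(probs):
    """Shannon entropy in bits; zero-probability outcomes contribute nothing."""
    return sum(p * math.log2(1.0 / p) for p in probs if p > 0)

p_x = [0.7, 0.2, 0.1]   # illustrative distribution over 3 outcomes
p_y = [0.5, 0.5]        # an independent fair coin

# Properties 1 and 2: 0 <= H(X) <= log2(number of outcomes)
h_x = entropy_bits(p_x)
h_max = math.log2(len(p_x))

# Property 3 (concavity): entropy of a 50/50 mixture of two distributions
# is at least the average of their entropies
q_x = [0.1, 0.2, 0.7]
mixture = [0.5 * a + 0.5 * b for a, b in zip(p_x, q_x)]
h_mixture = entropy_bits(mixture)
h_average = 0.5 * entropy_bits(p_x) + 0.5 * entropy_bits(q_x)

# Property 4 (additivity): for independent X and Y the joint distribution
# is the product of the marginals, and H(X,Y) = H(X) + H(Y)
joint = [a * b for a, b in product(p_x, p_y)]
h_joint = entropy_bits(joint)

print(f"H(X) = {h_x:.3f}, upper bound = {h_max:.3f}")
print(f"H(mixture) = {h_mixture:.3f}, average of entropies = {h_average:.3f}")
print(f"H(X,Y) = {h_joint:.3f}, H(X) + H(Y) = {h_x + entropy_bits(p_y):.3f}")
```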

### Real-World Analogy
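Think of the information in a single outcome as its **"surprise level"**, -log P(x). Entropy is the *average* surprise over all outcomes:
- Seeing the sun rise: Low surprise (a highly probable event)
- Winning the lottery: High surprise (a very improbable event), although the lottery as a whole has low entropy because you almost always lose
- Weather in desert: Low entropy (predictable)
- Weather in England: High entropy (unpredictable)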
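
Surprise belongs to a single outcome, while entropy averages it over the whole distribution. A minimal sketch, assuming a hypothetical lottery with a one-in-a-million chance of winning (the probability and function names are illustrative):

```python
import math

def surprisal_bits(p):
    """Self-information of one outcome with probability p, in bits."""
    return math.log2(1.0 / p)

def entropy_bits(probs):
    """Entropy = probability-weighted average of the surprisal."""
    return sum(p * surprisal_bits(p) for p in probs if p > 0)

p_win = 1e-6  # hypothetical win probability

win_surprise = surprisal_bits(p_win)        # rare event: very surprising
lose_surprise = surprisal_bits(1 - p_win)   # near-certain event: barely surprising
lottery_entropy = entropy_bits([p_win, 1 - p_win])
coin_entropy = entropy_bits([0.5, 0.5])

print(f"surprise of winning: {win_surprise:.2f} bits")
print(f"surprise of losing:  {lose_surprise:.2e} bits")
print(f"entropy of lottery:  {lottery_entropy:.2e} bits")
print(f"entropy of a coin:   {coin_entropy:.2f} bits")
```

Winning carries many bits of surprise, but it happens so rarely that the lottery's entropy is far below that of a fair coin.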

In [None]:
def demonstrate_entropy_concepts():
    """Explore entropy with interactive examples"""
    
    print("📊 Entropy: Measuring Information and Uncertainty")
    print("=" * 50)
    
    # Helper function to calculate entropy
    def calculate_entropy(probabilities, base=2):
        """Calculate entropy with specified base"""
        probabilities = np.array(probabilities)
        # Remove zero probabilities to avoid log(0)
        probabilities = probabilities[probabilities > 0]
        if base == 2:
            return -np.sum(probabilities * np.log2(probabilities))
        elif base == np.e:
            return -np.sum(probabilities * np.log(probabilities))
        else:
            return -np.sum(probabilities * np.log(probabilities) / np.log(base))
    
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # 1. Coin flip entropy
    print("\n1. Coin Flip Entropy")
    print("   Fair coin: Maximum entropy")
    print("   Biased coin: Lower entropy")
    
    p_values = np.linspace(0.01, 0.99, 100)
    entropies = [calculate_entropy([p, 1-p]) for p in p_values]
    
    axes[0, 0].plot(p_values, entropies, 'b-', linewidth=2)
    axes[0, 0].axvline(x=0.5, color='r', linestyle='--', alpha=0.7, label='Fair coin (max entropy)')
    axes[0, 0].scatter([0.5], [1.0], color='red', s=100, zorder=5)
    axes[0, 0].set_xlabel('P(Heads)')
    axes[0, 0].set_ylabel('Entropy (bits)')
    axes[0, 0].set_title('Coin Flip Entropy')
    axes[0, 0].grid(True, alpha=0.3)
    axes[0, 0].legend()
    
    # Show specific examples
    examples = [(0.1, 'Biased'), (0.5, 'Fair'), (0.9, 'Biased')]
    for p, label in examples:
        h = calculate_entropy([p, 1-p])
        print(f"   P(H)={p}: H = {h:.3f} bits")
    
    # 2. Dice entropy
    print("\n2. Dice Roll Entropy")
    print("   Fair die vs loaded die")
    
    # Fair die
    fair_die = [1/6] * 6
    fair_entropy = calculate_entropy(fair_die)
    
    # Loaded die
    loaded_die = [0.5, 0.1, 0.1, 0.1, 0.1, 0.1]
    loaded_entropy = calculate_entropy(loaded_die)
    
    x_pos = np.arange(6)
    width = 0.35
    
    axes[0, 1].bar(x_pos - width/2, fair_die, width, label=f'Fair die (H={fair_entropy:.2f})', alpha=0.8)
    axes[0, 1].bar(x_pos + width/2, loaded_die, width, label=f'Loaded die (H={loaded_entropy:.2f})', alpha=0.8)
    axes[0, 1].set_xlabel('Dice Face')
    axes[0, 1].set_ylabel('Probability')
    axes[0, 1].set_title('Fair vs Loaded Dice')
    axes[0, 1].set_xticks(x_pos)
    axes[0, 1].set_xticklabels([f'{i+1}' for i in range(6)])
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    print(f"   Fair die: H = {fair_entropy:.3f} bits")
    print(f"   Loaded die: H = {loaded_entropy:.3f} bits")
    
    # 3. Text entropy
    print("\n3. Text Entropy (English vs Random)")
    
    # English letter frequencies A-Z (approximate; the values sum to ~1)
    english_freq = [0.08167, 0.01492, 0.02782, 0.04253, 0.12702, 0.02228, 0.02015, 0.06094, 0.06966, 0.00153,
                   0.00772, 0.04025, 0.02406, 0.06749, 0.07507, 0.01929, 0.00095, 0.05987, 0.06327, 0.09056,
                   0.02758, 0.00978, 0.02360, 0.00150, 0.01974, 0.00074]
    
    # Random text (uniform distribution)
    random_freq = [1/26] * 26
    
    english_entropy = calculate_entropy(english_freq)
    random_entropy = calculate_entropy(random_freq)
    
    letters = [chr(ord('A') + i) for i in range(26)]
    x_letters = np.arange(26)
    
    axes[0, 2].bar(x_letters, english_freq, alpha=0.7, label=f'English (H={english_entropy:.2f})')
    axes[0, 2].axhline(y=1/26, color='red', linestyle='--', alpha=0.7, label=f'Random (H={random_entropy:.2f})')
    axes[0, 2].set_xlabel('Letter')
    axes[0, 2].set_ylabel('Frequency')
    axes[0, 2].set_title('Letter Frequency: English vs Random')
    axes[0, 2].set_xticks(x_letters[::3])
    axes[0, 2].set_xticklabels(letters[::3])
    axes[0, 2].legend()
    axes[0, 2].grid(True, alpha=0.3)
    
    print(f"   English text: H = {english_entropy:.3f} bits per letter")
    print(f"   Random text: H = {random_entropy:.3f} bits per letter")
    
    # 4. Conditional entropy
    print("\n4. Conditional Entropy: H(Y|X)")
    print("   Entropy of Y given knowledge of X")
    
    # Create a simple example: weather prediction
    # X = season, Y = temperature
    seasons = ['Spring', 'Summer', 'Fall', 'Winter']
    temp_ranges = ['Cold', 'Mild', 'Warm', 'Hot']
    
    # Joint probability P(season, temperature)
    joint_prob = np.array([
        [0.05, 0.15, 0.05, 0.00],  # Spring
        [0.00, 0.05, 0.15, 0.05],  # Summer
        [0.05, 0.15, 0.05, 0.00],  # Fall
        [0.15, 0.10, 0.00, 0.00]   # Winter
    ])
    
    # Marginal probabilities
    p_season = joint_prob.sum(axis=1)
    p_temp = joint_prob.sum(axis=0)
    
    # Calculate entropies
    H_temp = calculate_entropy(p_temp)
    
    # Conditional entropy H(Temp|Season)
    H_temp_given_season = 0
    for i, season_prob in enumerate(p_season):
        if season_prob > 0:
            conditional_probs = joint_prob[i] / season_prob
            H_temp_given_season += season_prob * calculate_entropy(conditional_probs)
    
    # Visualize joint distribution
    im = axes[1, 0].imshow(joint_prob, cmap='Blues', aspect='auto')
    axes[1, 0].set_xticks(range(4))
    axes[1, 0].set_yticks(range(4))
    axes[1, 0].set_xticklabels(temp_ranges)
    axes[1, 0].set_yticklabels(seasons)
    axes[1, 0].set_xlabel('Temperature')
    axes[1, 0].set_ylabel('Season')
    axes[1, 0].set_title('Joint Distribution P(Season, Temp)')
    
    # Add probability values to heatmap
    for i in range(4):
        for j in range(4):
            axes[1, 0].text(j, i, f'{joint_prob[i, j]:.2f}', 
                           ha='center', va='center', color='white' if joint_prob[i, j] > 0.1 else 'black')
    
    print(f"   H(Temperature) = {H_temp:.3f} bits")
    print(f"   H(Temperature|Season) = {H_temp_given_season:.3f} bits")
    print(f"   Information gain = {H_temp - H_temp_given_season:.3f} bits")
    
    # 5. Cross-entropy and KL divergence
    print("\n5. Cross-Entropy and KL Divergence")
    print("   Measuring difference between distributions")
    
    # True distribution (uniform)
    true_dist = np.array([0.25, 0.25, 0.25, 0.25])
    
    # Different predicted distributions
    pred_dists = {
        'Perfect': np.array([0.25, 0.25, 0.25, 0.25]),
        'Close': np.array([0.3, 0.3, 0.2, 0.2]),
        'Poor': np.array([0.7, 0.1, 0.1, 0.1]),
        'Terrible': np.array([0.95, 0.02, 0.02, 0.01])
    }
    
    x_cat = np.arange(4)
    categories = ['A', 'B', 'C', 'D']
    
    axes[1, 1].bar(x_cat - 0.3, true_dist, 0.2, label='True', alpha=0.8)
    
    colors = ['green', 'yellow', 'orange', 'red']
    for i, (name, pred_dist) in enumerate(pred_dists.items()):
        offset = -0.1 + i * 0.2
        axes[1, 1].bar(x_cat + offset, pred_dist, 0.15, label=name, alpha=0.7, color=colors[i])
        
        # Calculate KL divergence and cross-entropy in bits (log base 2),
        # matching the units used for the entropies above
        kl_div = np.sum(true_dist * np.log2(true_dist / pred_dist))
        cross_entropy = -np.sum(true_dist * np.log2(pred_dist))
        print(f"   {name}: KL(true||pred) = {kl_div:.3f} bits, Cross-entropy = {cross_entropy:.3f} bits")
    
    axes[1, 1].set_xlabel('Category')
    axes[1, 1].set_ylabel('Probability')
    axes[1, 1].set_title('Cross-Entropy: True vs Predicted Distributions')
    axes[1, 1].set_xticks(x_cat)
    axes[1, 1].set_xticklabels(categories)
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    
    # 6. Information-theoretic quantities relationships
    print("\n6. Information-Theoretic Relationships")
    
    # Draw a Venn-style diagram of the relationships in the last subplot
    ax_venn = axes[1, 2]
    
    # Draw circles for X and Y
    circle1 = plt.Circle((0.4, 0.5), 0.3, fill=False, linewidth=2, color='blue')
    circle2 = plt.Circle((0.6, 0.5), 0.3, fill=False, linewidth=2, color='red')
    ax_venn.add_patch(circle1)
    ax_venn.add_patch(circle2)
    
    # Add labels
    ax_venn.text(0.25, 0.5, 'H(X|Y)', fontsize=12, ha='center', va='center')
    ax_venn.text(0.75, 0.5, 'H(Y|X)', fontsize=12, ha='center', va='center')
    ax_venn.text(0.5, 0.5, 'I(X;Y)', fontsize=12, ha='center', va='center', weight='bold')
    ax_venn.text(0.15, 0.8, 'H(X)', fontsize=14, ha='center', color='blue', weight='bold')
    ax_venn.text(0.85, 0.8, 'H(Y)', fontsize=14, ha='center', color='red', weight='bold')
    ax_venn.text(0.5, 0.15, 'H(X,Y) = H(X) + H(Y|X) = H(Y) + H(X|Y)', fontsize=12, ha='center')
    ax_venn.text(0.5, 0.05, 'I(X;Y) = H(X) - H(X|Y) = H(Y) - H(Y|X)', fontsize=12, ha='center', weight='bold')
    
    ax_venn.set_xlim(0, 1)
    ax_venn.set_ylim(0, 1)
    ax_venn.set_aspect('equal')
    ax_venn.axis('off')
    ax_venn.set_title('Information-Theoretic Quantities', fontsize=16, weight='bold')
    
    axes[1, 2].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    
    print("\n🎯 Key Information Theory Concepts:")
    print("• Entropy measures uncertainty/information content")
    print("• Conditional entropy: uncertainty remaining after observing another variable")
    print("• Mutual information: reduction in uncertainty due to another variable")
    print("• KL divergence: asymmetric measure of how one distribution differs from another (not a true distance)")
    print("• Cross-entropy: expected message length when coding with the wrong distribution")

demonstrate_entropy_concepts()
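
The identities H(X,Y) = H(X) + H(Y|X) and I(X;Y) = H(Y) - H(Y|X) can be checked numerically on the season/temperature table used in the cell above. The sketch below copies that table and computes the mutual information two ways; the helper name `entropy_bits` is an illustrative choice.

```python
import numpy as np

def entropy_bits(p):
    """Shannon entropy in bits of an array of probabilities (any shape)."""
    p = np.asarray(p, dtype=float).ravel()
    p = p[p > 0]  # zero-probability cells contribute nothing
    return float(-np.sum(p * np.log2(p)))

# Joint distribution P(season, temperature), copied from the cell above
joint = np.array([
    [0.05, 0.15, 0.05, 0.00],  # Spring
    [0.00, 0.05, 0.15, 0.05],  # Summer
    [0.05, 0.15, 0.05, 0.00],  # Fall
    [0.15, 0.10, 0.00, 0.00],  # Winter
])

p_season = joint.sum(axis=1)  # marginal over temperature
p_temp = joint.sum(axis=0)    # marginal over season

H_season = entropy_bits(p_season)
H_temp = entropy_bits(p_temp)
H_joint = entropy_bits(joint)

# Chain rule: H(Temp|Season) = H(Season, Temp) - H(Season)
H_temp_given_season = H_joint - H_season

# Mutual information from entropies: I = H(X) + H(Y) - H(X,Y)
mi_from_entropies = H_season + H_temp - H_joint

# Mutual information from its definition: sum p(x,y) log2[p(x,y) / (p(x) p(y))]
mask = joint > 0
independent = np.outer(p_season, p_temp)
mi_direct = float(np.sum(joint[mask] * np.log2(joint[mask] / independent[mask])))

print(f"H(Temp) - H(Temp|Season) = {H_temp - H_temp_given_season:.3f} bits")
print(f"I from entropies         = {mi_from_entropies:.3f} bits")
print(f"I from definition        = {mi_direct:.3f} bits")
```

All three printed values should agree up to floating-point error, which is the "information gain" reported in example 4.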