# Tutorial 06: Probability Concepts in ML

Interactive visualizations for:
- Joint, marginal, conditional probability
- Bayes' theorem in action
- Probability vs likelihood
- Prior → Posterior updates

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from scipy.special import comb

# Apply matplotlib's seaborn-derived white-grid style to every figure drawn below.
plt.style.use('seaborn-v0_8-whitegrid')
# IPython/Jupyter magic (not plain Python): render figures inline in the notebook.
%matplotlib inline

## Part 1: Joint, Marginal, and Conditional Probability

Let's create a joint distribution and visualize all the relationships.

In [None]:
# Create a joint probability table P(X, Y) for a toy diagnostic test.
#   X (rows):    0 = healthy,       1 = has the disease
#   Y (columns): 0 = test negative, 1 = test positive

# Joint probabilities (must be non-negative and sum to 1)
P_joint = np.array([
    [0.85, 0.05],   # X=0: P(X=0,Y=0)=0.85, P(X=0,Y=1)=0.05
    [0.02, 0.08],   # X=1: P(X=1,Y=0)=0.02, P(X=1,Y=1)=0.08
])

# Fail loudly if the table above is edited into something that is not a
# probability distribution; every later cell silently assumes it is one.
if np.any(P_joint < 0) or not np.isclose(P_joint.sum(), 1.0):
    raise ValueError(f"P_joint is not a valid distribution (sum={P_joint.sum()})")

# Marginals: sum out the other variable.
P_X = P_joint.sum(axis=1)  # P(X=x) = sum_y P(X=x, Y=y)  -> row sums
P_Y = P_joint.sum(axis=0)  # P(Y=y) = sum_x P(X=x, Y=y)  -> column sums

print("Joint Probability Table P(X, Y):")
print("=" * 40)
print("              Y=0      Y=1     | P(X)")
print(f"X=0          {P_joint[0, 0]:.2f}     {P_joint[0, 1]:.2f}    | {P_X[0]:.2f}")
print(f"X=1          {P_joint[1, 0]:.2f}     {P_joint[1, 1]:.2f}    | {P_X[1]:.2f}")
print("-" * 40)
print(f"P(Y)         {P_Y[0]:.2f}     {P_Y[1]:.2f}    | {P_joint.sum():.2f}")

print(f"\nMarginal P(X): {P_X}")
print(f"Marginal P(Y): {P_Y}")

In [None]:
# Conditioning = dividing the joint table by a marginal (renormalising a slice):
#   P(Y|X) = P(X,Y) / P(X)  -> each row of the table is rescaled to sum to 1
#   P(X|Y) = P(X,Y) / P(Y)  -> each column of the table is rescaled to sum to 1
P_Y_given_X = P_joint / P_X.reshape(-1, 1)  # P(X) as a column, so it divides row-wise
P_X_given_Y = P_joint / P_Y.reshape(1, -1)  # P(Y) as a row, so it divides column-wise

print("Conditional P(Y|X):")
print("\n".join(
    f"P(Y=0|X={i}) = {P_Y_given_X[i, 0]:.4f}    P(Y=1|X={i}) = {P_Y_given_X[i, 1]:.4f}"
    f"  (sums to {P_Y_given_X[i, :].sum():.1f})"
    for i in range(2)
))

print("\nConditional P(X|Y):")
print("\n".join(
    f"P(X=0|Y={j}) = {P_X_given_Y[0, j]:.4f}    P(X=1|Y={j}) = {P_X_given_Y[1, j]:.4f}"
    f"  (sums to {P_X_given_Y[:, j].sum():.1f})"
    for j in range(2)
))

In [None]:
# Visualize joint, marginal, conditional


def _draw_prob_heatmap(ax, table, cmap, title):
    """Draw a 2x2 probability table on `ax` as an annotated heatmap with a colorbar.

    Rows are labelled X=0/X=1 and columns Y=0/Y=1. The colour scale is pinned to
    [0, 1] so the heatmap panels are directly comparable. Cell text switches to
    white on dark (value > 0.5) cells so it stays readable.
    Returns the image created by `imshow`.
    """
    im = ax.imshow(table, cmap=cmap, vmin=0, vmax=1)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Y=0', 'Y=1'])
    ax.set_yticks([0, 1])
    ax.set_yticklabels(['X=0', 'X=1'])
    for (i, j), value in np.ndenumerate(table):
        # imshow puts row i at y=i and column j at x=j
        ax.text(j, i, f'{value:.2f}', ha='center', va='center', fontsize=14,
                color='white' if value > 0.5 else 'black')
    ax.set_title(title, fontsize=14)
    plt.colorbar(im, ax=ax)
    return im


fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Joint probability
_draw_prob_heatmap(axes[0, 0], P_joint, 'Blues', 'Joint P(X, Y)')

# Marginals: side-by-side bars for P(X) and P(Y)
x_pos = np.array([0, 1])
width = 0.35
axes[0, 1].bar(x_pos - width/2, P_X, width, label='P(X)', color='steelblue')
axes[0, 1].bar(x_pos + width/2, P_Y, width, label='P(Y)', color='coral')
axes[0, 1].set_xticks([0, 1])
axes[0, 1].set_xticklabels(['0', '1'])
axes[0, 1].set_ylabel('Probability', fontsize=12)
axes[0, 1].set_title('Marginal Probabilities', fontsize=14)
axes[0, 1].legend(fontsize=12)
axes[0, 1].set_ylim(0, 1)

# P(Y|X): rows are normalised
_draw_prob_heatmap(axes[1, 0], P_Y_given_X, 'Greens',
                   'Conditional P(Y|X)\n(each row sums to 1)')

# P(X|Y): columns are normalised
_draw_prob_heatmap(axes[1, 1], P_X_given_Y, 'Oranges',
                   'Conditional P(X|Y)\n(each column sums to 1)')

plt.tight_layout()
plt.show()

## Part 2: Bayes' Theorem in Action

The classic medical diagnosis example.

In [None]:
def bayes_medical_diagnosis(prevalence, sensitivity, specificity):
    """
    Compute P(Disease | Positive Test) using Bayes' theorem.

    Args:
        prevalence: P(D) - prior probability of disease.
        sensitivity: P(T+|D) - true positive rate.
        specificity: P(T-|not D) - true negative rate.
        Each argument may be a scalar or a NumPy array; arrays broadcast
        element-wise. Every value must lie in [0, 1].

    Returns:
        (P_D_given_pos, P_pos): the posterior P(D|T+) and the evidence P(T+)
        obtained from the law of total probability.

    Raises:
        ValueError: if any input value lies outside [0, 1].

    Note:
        If P(T+) is 0 (e.g. prevalence=0 and specificity=1), the posterior is
        undefined and the final division divides by zero.
    """
    for name, value in (("prevalence", prevalence),
                        ("sensitivity", sensitivity),
                        ("specificity", specificity)):
        # np.asarray makes the range check work for scalars and arrays alike
        arr = np.asarray(value, dtype=float)
        if np.any((arr < 0) | (arr > 1)):
            raise ValueError(f"{name} must be a probability in [0, 1], got {value!r}")

    P_D = prevalence
    P_not_D = 1 - prevalence
    P_pos_given_D = sensitivity
    P_pos_given_not_D = 1 - specificity  # False positive rate

    # Law of total probability: P(T+) = P(T+|D)P(D) + P(T+|not D)P(not D)
    P_pos = P_pos_given_D * P_D + P_pos_given_not_D * P_not_D

    # Bayes' theorem: P(D|T+) = P(T+|D)P(D) / P(T+)
    P_D_given_pos = (P_pos_given_D * P_D) / P_pos

    return P_D_given_pos, P_pos

# Example: Rare disease with a highly sensitive test
prevalence = 0.01  # 1% have disease
sensitivity = 0.99  # 99% true positive rate
specificity = 0.95  # 95% true negative rate (5% false positive)

P_D_given_pos, P_pos = bayes_medical_diagnosis(prevalence, sensitivity, specificity)

print("Medical Diagnosis with Bayes' Theorem")
print("=" * 50)
print(f"Prior P(Disease) = {prevalence:.2%}")
print(f"Sensitivity P(T+|D) = {sensitivity:.2%}")
print(f"Specificity P(T-|¬D) = {specificity:.2%}")
print(f"\nP(Positive Test) = {P_pos:.2%}")
print(f"\n>>> P(Disease | Positive Test) = {P_D_given_pos:.2%} <<<")
# Derive the quoted sensitivity from the variable so the message stays true
# if the parameters above are edited.
print(f"\nSurprising! Despite {sensitivity:.0%} sensitivity, only {P_D_given_pos:.1%} chance of disease!")

In [None]:
# Visualize how the posterior varies with prevalence.
# Reuses `prevalence`, `sensitivity` and `specificity` from the previous cell,
# so the curve, the highlighted prevalence, the titles and the population
# breakdown all describe the same test.
prevalences = np.linspace(0.001, 0.5, 100)
# bayes_medical_diagnosis is plain arithmetic, so one call broadcasts over the array.
posteriors = np.asarray(bayes_medical_diagnosis(prevalences, sensitivity, specificity)[0])

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Posterior vs prevalence
axes[0].plot(prevalences * 100, posteriors * 100, 'b-', linewidth=2)
axes[0].axhline(y=50, color='gray', linestyle='--', alpha=0.5, label='50% threshold')
axes[0].axvline(x=prevalence * 100, color='red', linestyle='--', alpha=0.5,
                label=f'{prevalence * 100:g}% prevalence')
axes[0].set_xlabel('Disease Prevalence (%)', fontsize=12)
axes[0].set_ylabel('P(Disease | Positive Test) (%)', fontsize=12)
axes[0].set_title('Posterior Probability vs Disease Prevalence\n'
                  f'(Sensitivity={sensitivity * 100:g}%, Specificity={specificity * 100:g}%)',
                  fontsize=14)
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Population breakdown for a concrete population of N people.
N = 10000
# round() rather than int(): int() truncates, and a floating-point product such
# as 0.29 * 100 == 28.999999999999996 would silently drop a person from a group.
n_disease = round(N * prevalence)
n_healthy = N - n_disease

true_positive = round(n_disease * sensitivity)
false_negative = n_disease - true_positive
false_positive = round(n_healthy * (1 - specificity))
true_negative = n_healthy - false_positive

categories = ['True Positive\n(Disease, T+)', 'False Negative\n(Disease, T-)',
              'False Positive\n(Healthy, T+)', 'True Negative\n(Healthy, T-)']
counts = [true_positive, false_negative, false_positive, true_negative]
colors = ['darkred', 'salmon', 'orange', 'lightgreen']

bars = axes[1].bar(categories, counts, color=colors)
axes[1].set_ylabel('Number of People', fontsize=12)
axes[1].set_title(f'Population of {N:,} (Prevalence = {prevalence:.1%})', fontsize=14)

# Highlight the positive tests: outline the T+ bars (TP at index 0, FP at index 2).
for pos_idx in (0, 2):
    bars[pos_idx].set_edgecolor('black')
    bars[pos_idx].set_linewidth(2)

# Count labels, offset relative to the tallest bar so they sit just above each bar.
label_offset = 0.01 * max(counts)
for i, count in enumerate(counts):
    axes[1].text(i, count + label_offset, f'{count:,}', ha='center', fontsize=11)

total_positive = true_positive + false_positive
axes[1].axhline(y=0, color='black', linewidth=0.5)

plt.tight_layout()
plt.show()

print(f"\nOf {total_positive} positive tests:")
if total_positive:  # avoid dividing by zero when nobody tests positive
    print(f"  {true_positive} actually have disease ({true_positive/total_positive:.1%})")
    print(f"  {false_positive} are false positives ({false_positive/total_positive:.1%})")

## Part 3: Probability vs Likelihood

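The same formula $P(k \mid n, \theta)$, read two ways: as a function of the data $k$ with $\theta$ fixed (probability), or as a function of $\theta$ with the observed $k$ fixed (likelihood).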

In [None]:
def binomial_prob(k, n, theta):
    """Binomial PMF: probability of exactly k successes in n trials, P(X=k | n, theta).

    Evaluates C(n, k) * theta**k * (1 - theta)**(n - k), using an exact
    integer binomial coefficient.
    """
    n_ways = comb(n, k, exact=True)
    p_successes = theta ** k
    p_failures = (1 - theta) ** (n - k)
    return n_ways * p_successes * p_failures

n = 10  # number of trials

# PROBABILITY view: fixed theta, varying k
theta_fixed = 0.7
k_values = np.arange(0, n+1)
probs = [binomial_prob(k, n, theta_fixed) for k in k_values]

# LIKELIHOOD view: fixed k (observed data), varying theta
k_observed = 7
theta_values = np.linspace(0.01, 0.99, 100)
likelihoods = [binomial_prob(k_observed, n, theta) for theta in theta_values]

# The binomial MLE is the observed success fraction k/n. It is computed before
# plotting so the "MLE" marker is drawn at the MLE itself, not at theta_fixed
# (the two only coincide here because k_observed/n happens to equal 0.7).
mle_theta = k_observed / n
mle_likelihood = binomial_prob(k_observed, n, mle_theta)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Probability view
axes[0].bar(k_values, probs, color='steelblue', alpha=0.7, edgecolor='black')
axes[0].axvline(x=k_observed, color='red', linestyle='--', linewidth=2, label=f'Observed k={k_observed}')
axes[0].set_xlabel('k (number of successes)', fontsize=12)
axes[0].set_ylabel(f'P(X=k | θ={theta_fixed})', fontsize=12)
axes[0].set_title(f'PROBABILITY View\nθ={theta_fixed} fixed, k varies\n(sums to {sum(probs):.1f})', fontsize=14)
axes[0].legend(fontsize=11)

# Likelihood view
axes[1].plot(theta_values, likelihoods, 'b-', linewidth=2)
axes[1].fill_between(theta_values, likelihoods, alpha=0.3)
axes[1].axvline(x=mle_theta, color='red', linestyle='--', linewidth=2, label=f'MLE θ={mle_theta}')
axes[1].set_xlabel('θ (probability parameter)', fontsize=12)
axes[1].set_ylabel(f'L(θ | k={k_observed})', fontsize=12)
axes[1].set_title(f'LIKELIHOOD View\nk={k_observed} fixed (observed), θ varies\n(does NOT sum to 1!)', fontsize=14)
axes[1].legend(fontsize=11)

# Mark MLE
axes[1].plot(mle_theta, mle_likelihood, 'ro', markersize=10)
axes[1].annotate(f'MLE: θ={mle_theta}', xy=(mle_theta, mle_likelihood),
                 xytext=(mle_theta+0.1, mle_likelihood+0.02), fontsize=11)

plt.tight_layout()
plt.show()

# Show they DON'T integrate to 1.
# np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use whichever exists.
trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
approx_integral = trapezoid(likelihoods, theta_values)
print("\nKey difference:")
print(f"  Probability (over k): sums to {sum(probs):.4f}")
print(f"  Likelihood (over θ): integrates to {approx_integral:.4f} (NOT 1!)")
# ∫_0^1 C(n,k) θ^k (1-θ)^(n-k) dθ = C(n,k)·B(k+1, n-k+1) = 1/(n+1) for every k.
print(f"  (exact integral over [0, 1] is 1/(n+1) = {1 / (n + 1):.4f})")

## Part 4: Prior → Posterior Updates (Bayesian Inference)

In [None]:
# Beta-Binomial conjugate prior example
# Prior: Beta(a, b)
# Likelihood: Binomial(n, theta)
# Posterior: Beta(a + k, b + n - k)


def _beta_mode(a, b):
    """Return the mode (MAP point) of a Beta(a, b) density on [0, 1].

    - a > 1 and b > 1: interior mode (a - 1) / (a + b - 2).
    - a <= 1 < b: the density is non-increasing, so the mode is at 0.
    - b <= 1 < a: the density is non-decreasing, so the mode is at 1.
    - a <= 1 and b <= 1: no unique mode (uniform or U-shaped); return 0.5
      as a conventional placeholder.
    """
    if a > 1 and b > 1:
        return (a - 1) / (a + b - 2)
    if a <= 1 < b:
        return 0.0
    if b <= 1 < a:
        return 1.0
    return 0.5


def plot_bayesian_update(prior_a, prior_b, n_trials, n_successes):
    """Visualize a Beta-Binomial Bayesian update from prior to posterior.

    Args:
        prior_a, prior_b: parameters of the Beta(prior_a, prior_b) prior.
        n_trials: number of Bernoulli trials observed.
        n_successes: number of successes among those trials.

    Plots the prior, the likelihood (rescaled for display) and the posterior
    Beta(prior_a + n_successes, prior_b + n_trials - n_successes). Vertical
    lines mark the MLE and MAP estimates. Then prints the MLE, MAP and
    posterior mean.
    """
    from scipy.special import xlogy  # x*log(y) with 0*log(0) == 0

    theta = np.linspace(0, 1, 1000)

    # Prior
    prior = stats.beta.pdf(theta, prior_a, prior_b)

    # Likelihood (unnormalized), computed in log space: theta**k * (1-theta)**(n-k)
    # underflows to 0 everywhere for large n, which made the old max-normalisation 0/0.
    log_lik = xlogy(n_successes, theta) + xlogy(n_trials - n_successes, 1 - theta)
    likelihood = np.exp(log_lik - log_lik.max())  # peak normalised to 1 for visualization

    # Posterior
    post_a = prior_a + n_successes
    post_b = prior_b + n_trials - n_successes
    posterior = stats.beta.pdf(theta, post_a, post_b)

    # MLE and MAP estimates
    mle = n_successes / n_trials if n_trials > 0 else 0.5  # MLE undefined without data
    map_estimate = _beta_mode(post_a, post_b)

    fig, ax = plt.subplots(figsize=(12, 6))

    # Scale the likelihood to the prior's peak so it is visible on the same axes.
    # Beta priors with a < 1 or b < 1 have infinite density at 0/1, so ignore those.
    prior_peak = prior[np.isfinite(prior)].max()

    ax.plot(theta, prior, 'b-', linewidth=2, label=f'Prior: Beta({prior_a}, {prior_b})')
    ax.plot(theta, likelihood * prior_peak, 'g--', linewidth=2,
            label=f'Likelihood (scaled): {n_successes}/{n_trials} successes')
    ax.plot(theta, posterior, 'r-', linewidth=3, label=f'Posterior: Beta({post_a}, {post_b})')

    ax.fill_between(theta, prior, alpha=0.2, color='blue')
    ax.fill_between(theta, posterior, alpha=0.2, color='red')

    ax.axvline(x=mle, color='green', linestyle=':', linewidth=2, alpha=0.7,
               label=f'MLE = {mle:.3f}')
    ax.axvline(x=map_estimate, color='red', linestyle=':', linewidth=2, alpha=0.7,
               label=f'MAP = {map_estimate:.3f}')

    ax.set_xlabel('θ (probability parameter)', fontsize=12)
    ax.set_ylabel('Density', fontsize=12)
    ax.set_title('Bayesian Update: Prior × Likelihood ∝ Posterior', fontsize=14)
    ax.legend(fontsize=11)  # after every labelled artist, so the MLE/MAP lines are listed too
    ax.set_xlim(0, 1)

    plt.tight_layout()
    plt.show()

    print(f"MLE estimate: {mle:.3f}")
    print(f"MAP estimate: {map_estimate:.3f}")
    print(f"Posterior mean: {post_a / (post_a + post_b):.3f}")

# Example 1: a flat Beta(1, 1) prior (no prior knowledge), then 7 successes in 10 trials.
print("Example 1: Uniform Prior (no prior knowledge)", "=" * 50, sep="\n")
plot_bayesian_update(prior_a=1, prior_b=1, n_trials=10, n_successes=7)

In [None]:
# Example 2: a Beta(20, 20) prior concentrated near 0.5 (strong belief the coin is fair),
# updated with the same 7/10 successes as Example 1.
print("Example 2: Strong Prior (believe coin is fair)", "=" * 50, sep="\n")
plot_bayesian_update(prior_a=20, prior_b=20, n_trials=10, n_successes=7)

In [None]:
# Example 3: the same strong Beta(20, 20) prior, but ten times the data (70/100 successes).
print("Example 3: More Data (prior gets overwhelmed)", "=" * 50, sep="\n")
plot_bayesian_update(prior_a=20, prior_b=20, n_trials=100, n_successes=70)

In [None]:
# Sequential Bayesian updates
print("Sequential Bayesian Updates")
print("=" * 50)

# Start with uniform prior Beta(a, b) = Beta(1, 1). Every panel below derives its
# posterior from these two numbers, so editing them changes the whole figure.
a, b = 1, 1
theta = np.linspace(0, 1, 1000)

# Simulate coin flips from a coin with true θ=0.65 (seeded for reproducibility)
np.random.seed(42)
true_theta = 0.65
flips = np.random.binomial(1, true_theta, 50)

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

checkpoints = [0, 1, 5, 10, 25, 50]  # After how many flips

# Conjugacy: Beta(a, b) prior + (heads, tails) observed -> Beta(a + heads, b + tails).
# The loop variable is `n_seen`, not `n`, so it does not overwrite the `n`
# (number of trials) defined in Part 3.
for ax, n_seen in zip(axes, checkpoints):
    heads = int(flips[:n_seen].sum())  # flips[:0] is empty -> 0 heads, i.e. the prior itself
    tails = n_seen - heads
    current_a, current_b = a + heads, b + tails

    posterior = stats.beta.pdf(theta, current_a, current_b)

    ax.plot(theta, posterior, 'b-', linewidth=2)
    ax.fill_between(theta, posterior, alpha=0.3)
    ax.axvline(x=true_theta, color='red', linestyle='--', linewidth=2, label=f'True θ={true_theta}')

    if n_seen > 0:
        mean = current_a / (current_a + current_b)
        ax.axvline(x=mean, color='green', linestyle=':', linewidth=2, label=f'Mean={mean:.2f}')

    ax.set_xlabel('θ', fontsize=11)
    ax.set_ylabel('Density', fontsize=11)
    ax.set_title(f'After {n_seen} flips ({heads}H, {tails}T)\nBeta({current_a}, {current_b})', fontsize=12)
    ax.legend(fontsize=9)
    ax.set_xlim(0, 1)

plt.tight_layout()
plt.show()

print(f"\nWith more data, posterior concentrates around true θ={true_theta}")

## Part 5: ML Losses as Negative Log-Likelihoods

In [None]:
# Show that MSE and the Gaussian NLL share the same minimiser.
# If y|x ~ N(f(x), σ²), then
#   -log p(y|x) = (y - f(x))² / (2σ²) + ½·log(2πσ²),
# an affine function of the squared error: same argmin, different scale and offset.

y_true = 3.0
predictions = np.linspace(0, 6, 100)
sigma = 1.0

# MSE loss
mse = (predictions - y_true) ** 2

# Negative log-likelihood (Gaussian)
nll = 0.5 * (predictions - y_true)**2 / sigma**2 + 0.5 * np.log(2 * np.pi * sigma**2)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].plot(predictions, mse, 'b-', linewidth=2, label='MSE Loss')
axes[0].plot(predictions, nll, 'r--', linewidth=2, label='NLL (Gaussian)')
axes[0].axvline(x=y_true, color='green', linestyle=':', linewidth=2, label=f'y_true={y_true}')
axes[0].set_xlabel('Prediction', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('MSE vs Negative Log-Likelihood (Gaussian noise)\n'
                  'NLL = MSE / 2σ² + const  →  same minimum', fontsize=14)
axes[0].legend(fontsize=11)

# Cross-entropy for binary classification
y_true_class = 1  # True class (0 or 1)
p_pred = np.linspace(0.01, 0.99, 100)  # Predicted P(class=1)

# Binary cross-entropy: -[y*log(p) + (1-y)*log(1-p)]; for y=1 this reduces to -log(p)
ce_loss = -(y_true_class * np.log(p_pred) + (1 - y_true_class) * np.log(1 - p_pred))

axes[1].plot(p_pred, ce_loss, 'b-', linewidth=2)
axes[1].set_xlabel('Predicted P(y=1)', fontsize=12)
axes[1].set_ylabel('Cross-Entropy Loss', fontsize=12)
axes[1].set_title(f'Cross-Entropy = NLL for Bernoulli\n(True class = {y_true_class})', fontsize=14)
# The loss is minimised by predicting P(y=1) equal to the true label.
axes[1].axvline(x=y_true_class, color='green', linestyle=':', linewidth=2,
                label=f'Optimal: predict p={y_true_class}')
axes[1].legend(fontsize=11)

plt.tight_layout()
plt.show()

print("Key insight: Common loss functions ARE negative log-likelihoods!")
print("  • MSE Loss ↔ Gaussian noise assumption")
print("  • Cross-Entropy ↔ Bernoulli/Categorical assumption")

## Summary

| Concept | Formula | Interpretation |
|---------|---------|----------------|
| **Joint** | $P(x,y)$ | Probability of x AND y |
| **Marginal** | $P(x) = \sum_y P(x,y)$ | Sum out other variables |
| **Conditional** | $P(x \mid y) = P(x,y)/P(y)$ | Restrict to the world where y happened |
| **Bayes** | $P(A \mid B) = P(B \mid A)P(A)/P(B)$ | Flip the conditioning |
| **Probability** | $P(x \mid \theta)$ as function of $x$ | What data might we see? |
| **Likelihood** | $P(x \mid \theta)$ as function of $\theta$ | What parameters explain data? |
| **Prior** | $P(\theta)$ | Belief before seeing data |
| **Posterior** | $P(\theta \mid x) \propto P(x \mid \theta)P(\theta)$ | Belief after seeing data |