# Bayesian Learning: Updating Beliefs with Evidence

**How do beliefs change when you learn something new?**

This notebook explores Bayes' theorem through interactive examples!

**You'll discover:**
- 🧠 How to update beliefs with evidence
- 🚕 The famous Taxicab Problem
- 📊 Why base rates matter SO MUCH
- 🎮 Interactive belief updating

**Prepare for some surprising results!** 🤯

---

## 🚀 Setup

Let's get everything ready!

**Note**: After running the installation cell below, you may need to restart the runtime (Runtime → Restart runtime) before proceeding with the rest of the notebook.

In [None]:
!pip install genjax ipywidgets matplotlib seaborn -q

print("✅ Installation complete!")

In [None]:
import jax
import jax.numpy as jnp
from genjax import gen, flip, ChoiceMap
import matplotlib.pyplot as plt
import seaborn as sns
import ipywidgets as widgets
from IPython.display import display, clear_output
import numpy as np

jax.config.update('jax_enable_x64', True)
sns.set_style("whitegrid")

print("✅ Ready to learn Bayesian inference!")

## 🚕 The Taxicab Problem

**The Story**:

Chibany witnesses a hit-and-run accident at night. He says the taxi was **blue**.

**What we know**:
- 85% of taxis are **green**, 15% are **blue**
- Chibany identifies colors correctly 80% of the time

**The Question**: What's the probability it was actually a blue taxi?

**Your intuition**: Before we calculate, what do YOU think?
- 80% (matching Chibany's accuracy)?
- 15% (matching the base rate)?
- Something else?

Let's find out! 🔍

## 🎯 Method 1: Simulation

Let's simulate thousands of scenarios and count!
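Before the GenJAX version, here is a minimal plain-Python sketch of the same filter-and-count idea. It is not part of the original notebook: the function name `simulate_taxicab` and the use of the standard library's `random` module are assumptions made for illustration, and a Python loop like this is much slower than a vectorized simulation.

```python
import random

def simulate_taxicab(n_sims=100_000, base_rate_blue=0.15, accuracy=0.80, seed=0):
    """Estimate P(Blue | Says Blue) by counting simulated scenarios.

    Hypothetical helper for illustration; not part of GenJAX.
    """
    rng = random.Random(seed)
    n_says_blue = 0       # scenarios where the witness reports "blue"
    n_blue_and_says = 0   # of those, scenarios where the taxi really is blue

    for _ in range(n_sims):
        is_blue = rng.random() < base_rate_blue
        is_correct = rng.random() < accuracy
        # A correct witness reports the true color; a mistaken one reports the other.
        says_blue = is_blue if is_correct else not is_blue
        if says_blue:
            n_says_blue += 1
            if is_blue:
                n_blue_and_says += 1

    # Conditioning = keeping only the "says blue" scenarios, then counting.
    return n_blue_and_says / n_says_blue

estimate = simulate_taxicab()
print(f"P(Blue | Says Blue) is roughly {estimate:.3f}")
```

The estimate varies slightly with the seed and the number of scenarios, so treat it as an approximation rather than an exact answer.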

In [None]:
@gen
def taxicab_scenario(base_rate_blue=0.15, accuracy=0.80):
    """
    Simulate one taxicab scenario.
    
    Args:
        base_rate_blue: P(taxi is blue)
        accuracy: P(identifies correctly)
    
    Returns:
        True if taxi is blue, False if green
    """
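    # True color
    is_blue = flip(base_rate_blue) @ "is_blue"
    
    # What Chibany says depends on the true color: he is right with
    # probability `accuracy` and wrong with probability `1 - accuracy`.
    # jnp.where replaces a Python `if`, which cannot branch on a traced
    # value once this function runs under jax.vmap.
    p_says_blue = jnp.where(is_blue, accuracy, 1 - accuracy)
    says_blue = flip(p_says_blue) @ "says_blue"
    
    return is_blue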

# Simulate 100,000 scenarios
n_sims = 100000
key = jax.random.key(42)
keys = jax.random.split(key, n_sims)

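def run_scenario(k):
    trace = taxicab_scenario.simulate(k, (0.15, 0.80))
    choices = trace.get_choices()
    # Cast with jnp rather than int(): int() fails on traced values inside jax.vmap
    return (jnp.asarray(choices['is_blue'], dtype=jnp.int32),
            jnp.asarray(choices['says_blue'], dtype=jnp.int32))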

results = jax.vmap(run_scenario)(keys)
is_blue_results, says_blue_results = results

# Filter: Keep only cases where Chibany says "blue"
says_blue_mask = (says_blue_results == 1)
n_says_blue = jnp.sum(says_blue_mask)

# Among those, count actually blue
actually_blue_and_says_blue = jnp.logical_and(is_blue_results == 1, says_blue_results == 1)
n_actually_blue = jnp.sum(actually_blue_and_says_blue)

# Calculate posterior
prob_blue_given_says_blue = n_actually_blue / n_says_blue

print("🚕 Taxicab Problem: Simulation Results")
print("=" * 60)
print(f"Total scenarios: {n_sims:,}")
print(f"Chibany says 'blue': {int(n_says_blue):,} times")
print(f"Actually blue: {int(n_actually_blue):,} times")
print("\n" + "=" * 60)
print(f"P(Blue | Says Blue) = {prob_blue_given_says_blue:.4f}")
print("=" * 60)
print("\n🤔 Surprised? Only ~41%!")
print("   Even though Chibany is 80% accurate!")

## 🧮 Method 2: Bayes' Theorem

Let's verify with the formula:

$$P(\text{Blue} \mid \text{Says Blue}) = \frac{P(\text{Says Blue} \mid \text{Blue}) \cdot P(\text{Blue})}{P(\text{Says Blue})}$$
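The denominator is not given directly in the problem. It follows from the law of total probability, summing over both possible taxi colors. With the numbers from the problem statement:

$$P(\text{Says Blue}) = P(\text{Says Blue} \mid \text{Blue}) \cdot P(\text{Blue}) + P(\text{Says Blue} \mid \text{Green}) \cdot P(\text{Green}) = 0.80 \cdot 0.15 + 0.20 \cdot 0.85 = 0.29$$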

Breaking it down step by step:

In [None]:
# Prior probabilities
P_blue = 0.15
P_green = 0.85

# Likelihoods
P_says_blue_given_blue = 0.80  # Correct identification
P_says_blue_given_green = 0.20  # Mistake

# Evidence (total probability of saying "blue")
P_says_blue = (P_blue * P_says_blue_given_blue + 
               P_green * P_says_blue_given_green)

# Posterior (Bayes' theorem)
P_blue_given_says_blue = (P_says_blue_given_blue * P_blue) / P_says_blue

print("🧮 Bayes' Theorem Calculation")
print("=" * 60)
print("Step 1: Prior")
print(f"   P(Blue) = {P_blue}  (base rate)")
print(f"   P(Green) = {P_green}")
print("\nStep 2: Likelihood")
print(f"   P(Says Blue | Blue) = {P_says_blue_given_blue}  (accuracy)")
print(f"   P(Says Blue | Green) = {P_says_blue_given_green}  (mistake rate)")
print("\nStep 3: Evidence")
print(f"   P(Says Blue) = {P_blue} × {P_says_blue_given_blue} + {P_green} × {P_says_blue_given_green}")
print(f"                = {P_blue * P_says_blue_given_blue:.2f} + {P_green * P_says_blue_given_green:.2f}")
print(f"                = {P_says_blue:.2f}")
print("\nStep 4: Posterior (Bayes' Theorem)")
print(f"   P(Blue | Says Blue) = ({P_says_blue_given_blue} × {P_blue}) / {P_says_blue:.2f}")
print(f"                       = {P_says_blue_given_blue * P_blue:.2f} / {P_says_blue:.2f}")
print(f"                       = {P_blue_given_says_blue:.4f}")
print("\n" + "=" * 60)
print("✨ Matches simulation! Math works!")

## 💡 Why So Low?

**The insight**: Even with 80% accuracy, there are **more false positives than true positives**!
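One supplementary way to see this, which the notebook itself does not rely on, is the odds form of Bayes' theorem: posterior odds equal prior odds times the likelihood ratio.

$$\frac{P(\text{Blue} \mid \text{Says Blue})}{P(\text{Green} \mid \text{Says Blue})} = \frac{P(\text{Blue})}{P(\text{Green})} \cdot \frac{P(\text{Says Blue} \mid \text{Blue})}{P(\text{Says Blue} \mid \text{Green})} = \frac{0.15}{0.85} \cdot \frac{0.80}{0.20} = \frac{12}{17}$$

Odds of 12 to 17 correspond to a probability of 12/29 ≈ 0.41. Chibany's testimony multiplies the odds by 4, but the prior odds of 3 to 17 are too low for that to reach even odds.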

Let's break it down with numbers (out of 100 taxis):

In [None]:
# Imagine 100 taxis
total_taxis = 100

# True colors
n_blue = int(total_taxis * 0.15)  # 15 blue
n_green = int(total_taxis * 0.85)  # 85 green

# What Chibany identifies
blue_identified_correctly = n_blue * 0.80  # 12
green_misidentified_as_blue = n_green * 0.20  # 17

total_says_blue = blue_identified_correctly + green_misidentified_as_blue

# Visualize
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# True distribution
ax1.bar(['Blue Taxis', 'Green Taxis'], [n_blue, n_green], 
        color=['#3498db', '#2ecc71'], alpha=0.7, edgecolor='black')
ax1.set_ylabel('Count (out of 100)', fontsize=12, fontweight='bold')
ax1.set_title('TRUE Distribution', fontsize=14, fontweight='bold')
ax1.set_ylim([0, 100])
ax1.grid(axis='y', alpha=0.3)

for i, v in enumerate([n_blue, n_green]):
    ax1.text(i, v + 2, str(v), ha='center', fontsize=14, fontweight='bold')

# What Chibany says "blue"
categories = ['True Positives\n(Actually Blue)', 'False Positives\n(Actually Green)']
counts = [blue_identified_correctly, green_misidentified_as_blue]
colors_bar = ['#3498db', '#e74c3c']

bars = ax2.bar(categories, counts, color=colors_bar, alpha=0.7, edgecolor='black')
ax2.set_ylabel('Count', fontsize=12, fontweight='bold')
ax2.set_title('When Chibany Says "Blue"', fontsize=14, fontweight='bold')
ax2.set_ylim([0, max(counts) * 1.3])
ax2.grid(axis='y', alpha=0.3)

for bar in bars:
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height + 0.5,
            f'{height:.0f}',
            ha='center', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

print("💡 The Key Insight:")
print("=" * 60)
print(f"Blue taxis correctly identified: {blue_identified_correctly:.0f}")
print(f"Green taxis misidentified: {green_misidentified_as_blue:.0f}")
print(f"\nTotal 'says blue': {total_says_blue:.0f}")
print(f"\nProbability actually blue: {blue_identified_correctly:.0f}/{total_says_blue:.0f} ≈ {blue_identified_correctly/total_says_blue:.2%}")
print("\n🎯 MORE false positives than true positives!")
print("   This is why the probability is only ~41%!")

## 🎮 Interactive Exploration!

**Now it's your turn!** Explore how changing parameters affects the posterior.

**Try these scenarios:**
1. **Equal taxis**: Base rate = 0.50 → What happens?
2. **Mostly blue**: Base rate = 0.85 → Now what?
3. **Perfect witness**: Accuracy = 1.00 → As expected?
4. **Worse witness**: Accuracy = 0.60 → Still useful?

Watch how the **base rate** and **accuracy** interact!
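If the widgets are unavailable (for example, in a static rendering of this notebook), the same scenarios can be sketched without them. The helper name `posterior_blue` is hypothetical, and the sketch assumes each scenario changes only the parameter it names while the other keeps its original value (0.15 or 0.80).

```python
def posterior_blue(base_rate_blue, accuracy):
    """P(Blue | Says Blue) from Bayes' theorem (hypothetical helper name)."""
    true_pos = base_rate_blue * accuracy                # blue and reported blue
    false_pos = (1 - base_rate_blue) * (1 - accuracy)   # green but reported blue
    return true_pos / (true_pos + false_pos)

# (label, base rate, accuracy); the unnamed parameter keeps its original value
scenarios = [
    ("Original problem", 0.15, 0.80),
    ("Equal taxis",      0.50, 0.80),
    ("Mostly blue",      0.85, 0.80),
    ("Perfect witness",  0.15, 1.00),
    ("Worse witness",    0.15, 0.60),
]

posteriors = {label: posterior_blue(b, a) for label, b, a in scenarios}
for label, b, a in scenarios:
    print(f"{label:<17} base rate={b:.2f} accuracy={a:.2f} posterior={posteriors[label]:.3f}")
```

One thing to look for: when the two colors are equally common, the posterior equals the witness's accuracy, because the base rate cancels out of the formula.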

In [None]:
# Create interactive controls
base_rate_slider = widgets.FloatSlider(
    value=0.15, min=0.01, max=0.99, step=0.01,
    description='P(Blue):',
    style={'description_width': '120px'}
)

accuracy_slider = widgets.FloatSlider(
    value=0.80, min=0.50, max=1.00, step=0.01,
    description='Accuracy:',
    style={'description_width': '120px'}
)

output_widget = widgets.Output()

def explore_bayes(base_rate_blue, accuracy):
    """Explore Bayesian updating with different parameters."""
    with output_widget:
        clear_output(wait=True)
        
        # Calculate using Bayes' theorem
        P_blue = base_rate_blue
        P_green = 1 - base_rate_blue
        P_says_blue_given_blue = accuracy
        P_says_blue_given_green = 1 - accuracy
        
        P_says_blue = (P_blue * P_says_blue_given_blue + 
                      P_green * P_says_blue_given_green)
        
        P_blue_given_says_blue = (P_says_blue_given_blue * P_blue) / P_says_blue
        
        # Simulate to verify
        n_sims = 10000
        key = jax.random.key(42)
        keys = jax.random.split(key, n_sims)
        
        def run_scenario(k):
            trace = taxicab_scenario.simulate(k, (base_rate_blue, accuracy))
            choices = trace.get_choices()
            # Cast with jnp rather than int(): int() fails on traced values inside jax.vmap
            return (jnp.asarray(choices['is_blue'], dtype=jnp.int32),
                    jnp.asarray(choices['says_blue'], dtype=jnp.int32))
        
        results = jax.vmap(run_scenario)(keys)
        is_blue_sims, says_blue_sims = results
        
        says_blue_mask = (says_blue_sims == 1)
        both_mask = jnp.logical_and(is_blue_sims == 1, says_blue_sims == 1)
        P_blue_given_says_blue_sim = jnp.sum(both_mask) / jnp.sum(says_blue_mask)
        
        # Create visualization
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 10))
        
        # Prior vs Posterior
        categories = ['Prior\nP(Blue)', 'Posterior\nP(Blue|Says Blue)']
        probs = [P_blue, P_blue_given_says_blue]
        colors_comp = ['#95a5a6', '#3498db']
        
        bars1 = ax1.bar(categories, probs, color=colors_comp, alpha=0.7, edgecolor='black')
        ax1.set_ylabel('Probability', fontsize=12, fontweight='bold')
        ax1.set_title('Prior vs Posterior', fontsize=13, fontweight='bold')
        ax1.set_ylim([0, 1])
        ax1.grid(axis='y', alpha=0.3)
        ax1.axhline(y=0.5, color='red', linestyle='--', alpha=0.5, label='50%')
        ax1.legend()
        
        for bar in bars1:
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                    f'{height:.3f}',
                    ha='center', fontsize=12, fontweight='bold')
        
        # Belief update arrow
        change = P_blue_given_says_blue - P_blue
        arrow_color = '#2ecc71' if change > 0 else '#e74c3c'
        ax1.annotate('', xy=(1, P_blue_given_says_blue), xytext=(0, P_blue),
                    arrowprops=dict(arrowstyle='->', lw=2, color=arrow_color))
        
        # Break down of "says blue"
        true_pos = P_blue * P_says_blue_given_blue
        false_pos = P_green * P_says_blue_given_green
        
        labels_pie = ['True Positives\n(Blue → Says Blue)', 
                     'False Positives\n(Green → Says Blue)']
        sizes = [true_pos, false_pos]
        colors_pie = ['#3498db', '#e74c3c']
        explode = (0.05, 0.05)
        
        wedges, texts, autotexts = ax2.pie(sizes, labels=labels_pie, autopct='%1.1f%%',
                                           colors=colors_pie, explode=explode,
                                           shadow=True, startangle=90)
        for autotext in autotexts:
            autotext.set_color('white')
            autotext.set_fontsize(11)
            autotext.set_fontweight('bold')
        ax2.set_title('Breakdown of "Says Blue"', fontsize=13, fontweight='bold')
        
        # Comparison chart
        comparison_data = {
            'Prior': P_blue,
            'Likelihood\n(Accuracy)': P_says_blue_given_blue,
            'Posterior': P_blue_given_says_blue
        }
        
        bars3 = ax3.bar(comparison_data.keys(), comparison_data.values(),
                       color=['#95a5a6', '#f39c12', '#3498db'],
                       alpha=0.7, edgecolor='black')
        ax3.set_ylabel('Probability', fontsize=12, fontweight='bold')
        ax3.set_title('Three Key Probabilities', fontsize=13, fontweight='bold')
        ax3.set_ylim([0, 1])
        ax3.grid(axis='y', alpha=0.3)
        
        for bar in bars3:
            height = bar.get_height()
            ax3.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                    f'{height:.3f}',
                    ha='center', fontsize=11, fontweight='bold')
        
        # Theory vs Simulation
        comparison_vals = [P_blue_given_says_blue, float(P_blue_given_says_blue_sim)]
        comparison_labels = ['Theory\n(Bayes)', 'Simulation\n(10k runs)']
        
        bars4 = ax4.bar(comparison_labels, comparison_vals,
                       color=['#9b59b6', '#1abc9c'],
                       alpha=0.7, edgecolor='black')
        ax4.set_ylabel('P(Blue | Says Blue)', fontsize=12, fontweight='bold')
        ax4.set_title('Verification: Theory vs Simulation', fontsize=13, fontweight='bold')
        ax4.set_ylim([0, 1])
        ax4.grid(axis='y', alpha=0.3)
        
        for bar in bars4:
            height = bar.get_height()
            ax4.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                    f'{height:.4f}',
                    ha='center', fontsize=11, fontweight='bold')
        
        plt.tight_layout()
        plt.show()
        
        # Print summary
        print("📊 Bayesian Update Summary")
        print("=" * 60)
        print(f"Base rate (prior): P(Blue) = {P_blue:.3f}")
        print(f"Accuracy: {accuracy:.3f}")
        print(f"\nEvidence: P(Says Blue) = {P_says_blue:.3f}")
        print(f"  - True positives: {true_pos:.3f}")
        print(f"  - False positives: {false_pos:.3f}")
        print(f"\nPosterior: P(Blue | Says Blue) = {P_blue_given_says_blue:.4f}")
        print(f"Simulation: {float(P_blue_given_says_blue_sim):.4f}")
        print(f"\nBelief change: {P_blue:.3f} → {P_blue_given_says_blue:.3f}")
        print(f"Δ = {change:+.3f} ({change/P_blue*100:+.1f}%)")
        
        # Interpretation
        if abs(P_blue_given_says_blue - P_blue) < 0.01:
            print("\n💡 Evidence barely changed beliefs!")
        elif P_blue_given_says_blue > 0.5:
            print("\n✅ Posterior > 50%: More likely blue than green!")
        else:
            print("\n⚠️ Posterior < 50%: Still more likely green!")
        
        if false_pos > true_pos:
            print("   🔴 More false positives than true positives!")
            print("   This is why base rates matter SO MUCH.")

# Create interactive widget
interactive_bayes = widgets.interactive(
    explore_bayes,
    base_rate_blue=base_rate_slider,
    accuracy=accuracy_slider
)

display(interactive_bayes)
display(output_widget)

# Run initial
explore_bayes(0.15, 0.80)

## 🧪 Exercise: Medical Testing

**Scenario**: A disease affects 1% of the population. A test is 99% accurate, meaning it gives the correct result 99% of the time for sick and healthy people alike.

**You test positive**. What's the probability you have the disease?

Use the tools above to calculate!

In [None]:
# Medical test scenario
prevalence = 0.01  # 1% have disease
test_accuracy = 0.99  # 99% accurate

# TODO: Calculate P(Disease | Positive Test)
# Hint: Use Bayes' theorem like the taxicab problem!

# P_disease = ...
# P_positive_given_disease = ...
# P_positive_given_healthy = ...
# P_positive = ...
# P_disease_given_positive = ...

# print(f"P(Disease | Positive) = {P_disease_given_positive:.4f}")

<details>
<summary><b>💡 Click to see solution</b></summary>

```python
P_disease = 0.01
P_healthy = 0.99
P_positive_given_disease = 0.99  # True positive rate
P_positive_given_healthy = 0.01  # False positive rate (1 - accuracy)

# Total probability of testing positive
P_positive = (P_disease * P_positive_given_disease + 
              P_healthy * P_positive_given_healthy)

# Bayes' theorem
P_disease_given_positive = (P_positive_given_disease * P_disease) / P_positive

print(f"P(Disease | Positive) = {P_disease_given_positive:.4f}")
print("\nSurprising result: Only 50%!")
print("Even with a 99% accurate test!")
print("\nWhy? The disease is so rare (1%) that:")
print(f"  True positives: {P_disease * P_positive_given_disease:.4f}")
print(f"  False positives: {P_healthy * P_positive_given_healthy:.4f}")
print("\nEqual amounts of true and false positives!")
```

**The lesson**: Base rates dominate! A 99% accurate test on a 1% disease gives 50/50 odds.
</details>
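To see how strongly the base rate drives this result, here is a small sketch that repeats the calculation for several prevalences. The helper name `posterior_disease` and the prevalence values are illustrative assumptions. Following the solution above, "99% accurate" is read as 99% sensitivity and 99% specificity.

```python
def posterior_disease(prevalence, sensitivity=0.99, specificity=0.99):
    """P(Disease | Positive) via Bayes' theorem (hypothetical helper name)."""
    true_pos = prevalence * sensitivity                # sick and tests positive
    false_pos = (1 - prevalence) * (1 - specificity)   # healthy but tests positive
    return true_pos / (true_pos + false_pos)

# Illustrative prevalences; only 0.01 comes from the exercise itself
prevalences = [0.001, 0.01, 0.10, 0.50]
results = {p: posterior_disease(p) for p in prevalences}

for p, post in results.items():
    print(f"prevalence={p:.3f} -> P(Disease | Positive)={post:.3f}")
```

With the same test, a prevalence of 0.1% gives a posterior of about 9%, while a prevalence of 10% gives about 92%. The test has not changed; only the base rate has.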

## 🎓 What You've Learned

Congratulations! You now understand:

✅ **Bayes' theorem**: How to update beliefs with evidence  
✅ **Base rate importance**: Why prior probabilities matter enormously  
✅ **False positives**: How they can outnumber true positives  
✅ **Posterior probability**: Combining prior beliefs with new evidence  
✅ **Real-world applications**: Medical tests, witness testimony, and more  

**The KEY insight:**
> *Accuracy isn't everything! Base rates (how common something is) dramatically affect posterior probabilities. This is why rare diseases remain unlikely even after positive tests, and why common things stay common even with imperfect evidence.*

**Why this matters:**
- Medical testing: Don't panic from one positive result
- Justice: Witness testimony needs context
- Machine learning: Class imbalance affects predictions
- Daily life: Common explanations are usually correct

---

## 🚀 Next Steps

Ready to go deeper?
- **Tutorial 2**: Continuous distributions and Gaussian processes
- **GenJAX Tutorial**: Build sophisticated probabilistic models
- **Practice**: Find real-world scenarios to apply Bayes' theorem!

---

**Remember**: Your intuition about probabilities is often wrong. Trust the math, run the simulations, and always consider base rates! 🎯