# Chibany's Meals: Probability Through Simulation

This interactive notebook lets you explore probability by simulating Chibany's daily meals!

**What you'll do:**
- 🎲 Simulate thousands of random meal combinations
- 📊 See how simulation matches theory
- 🎮 Use interactive sliders to explore different probabilities
- 🧮 Calculate probabilities by counting

**No coding experience needed!** Just run each cell (Shift+Enter) and have fun!

---

## 🚀 Step 1: Setup (Run This First!)

This installs GenJAX and loads the tools we need. It takes about 1-2 minutes the first time.

In [None]:
# Install required packages (only needed once)
!pip install genjax ipywidgets matplotlib numpy -q

print("✅ Installation complete!")

In [None]:
# Import libraries
import jax
import jax.numpy as jnp
from genjax import gen, bernoulli
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import numpy as np

# Configure JAX
jax.config.update('jax_enable_x64', True)

print("✅ Libraries loaded! Ready to simulate Chibany's meals!")

## 🍱 Step 2: Define Chibany's Meal Model

Remember from the tutorial:
- **Outcome space**: $\Omega = \{HH, HT, TH, TT\}$
- Each meal is either **H**amburger or **T**onkatsu
- We'll simulate this process thousands of times!
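
When every outcome is equally likely (both meals at 50%), probability is just counting: the size of the event divided by the size of $\Omega$. Here is a minimal plain-Python sketch of that counting view, using only the standard library. The names `omega` and `prob_by_counting` are illustrative and not part of GenJAX.

```python
from fractions import Fraction
from itertools import product

# Outcome space: every (lunch, dinner) pair, written as strings like "HT"
omega = ["".join(day) for day in product("HT", repeat=2)]

def prob_by_counting(event):
    """P(event) = |event| / |omega|; valid only when all outcomes are equally likely."""
    return Fraction(len(event), len(omega))

# Event: at least one tonkatsu
at_least_one_t = [w for w in omega if "T" in w]

print(omega)                             # ['HH', 'HT', 'TH', 'TT']
print(at_least_one_t)                    # ['HT', 'TH', 'TT']
print(prob_by_counting(at_least_one_t))  # 3/4
```

Simulation estimates the same quantity as the fraction of simulated days that land in the event, and that approach keeps working when the outcomes are not equally likely.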

Here's the GenJAX code that simulates one day of meals:

In [None]:
@gen
def chibany_day(lunch_prob=0.5, dinner_prob=0.5):
    """
    Simulate one day of Chibany's meals.
    
    Returns:
        (lunch, dinner) where 0=Hamburger, 1=Tonkatsu
    """
    lunch = bernoulli(lunch_prob) @ "lunch"
    dinner = bernoulli(dinner_prob) @ "dinner"
    return (lunch, dinner)

# Map outcomes to readable names
OUTCOME_NAMES = {
    (0, 0): "HH (Hamburger, Hamburger)",
    (0, 1): "HT (Hamburger, Tonkatsu)",
    (1, 0): "TH (Tonkatsu, Hamburger)",
    (1, 1): "TT (Tonkatsu, Tonkatsu)"
}

print("✅ Meal model defined!")
print("\n💡 Try it once:")

# Generate one day
key = jax.random.key(42)
trace = chibany_day.simulate(key, (0.5, 0.5))
meals = trace.get_retval()
outcome = OUTCOME_NAMES[tuple(int(m) for m in meals)]
print(f"Today's outcome: {outcome}")

## 🎮 Step 3: Interactive Simulator!

**Now for the fun part!** Use the sliders to explore different scenarios:

**Try these experiments:**
1. **Equal probability**: Both at 0.50 → All outcomes equally likely (0.25 each)
2. **Chibany loves tonkatsu**: Both at 0.80 → TT becomes most common
3. **Different meals**: Lunch 0.80, Dinner 0.20 → TH becomes most common
4. **Extreme case**: Lunch 1.00, Dinner 0.00 → Only TH possible!
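
Before moving the sliders, you can work out what each experiment should show. This sketch uses the same product formula as the `theory` dictionary in the simulator cell below, assuming lunch and dinner are chosen independently; `outcome_probs` and `most_likely` are illustrative helper names, not part of GenJAX.

```python
import math

def outcome_probs(lunch_p, dinner_p):
    """Exact P(outcome) for independent meals, where each p is P(Tonkatsu)."""
    return {
        "HH": (1 - lunch_p) * (1 - dinner_p),
        "HT": (1 - lunch_p) * dinner_p,
        "TH": lunch_p * (1 - dinner_p),
        "TT": lunch_p * dinner_p,
    }

def most_likely(probs):
    """Return every outcome tied for the highest probability."""
    best = max(probs.values())
    return [o for o, p in probs.items() if math.isclose(p, best)]

experiments = {
    "Equal probability": (0.5, 0.5),
    "Chibany loves tonkatsu": (0.8, 0.8),
    "Different meals": (0.8, 0.2),
    "Extreme case": (1.0, 0.0),
}

for name, (p_lunch, p_dinner) in experiments.items():
    probs = outcome_probs(p_lunch, p_dinner)
    table = "  ".join(f"{o}={p:.2f}" for o, p in probs.items())
    print(f"{name:<24} {table}  most likely: {', '.join(most_likely(probs))}")
```

Returning all tied outcomes matters for the first experiment, where no single outcome is "most likely".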

Watch how the **observed frequencies** (the fraction of simulated days with each outcome) closely track the **theoretical probabilities** (what the math predicts), especially as you increase the number of simulations!
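
Why does the match improve with more simulations? The observed frequency of an outcome with probability $p$ typically differs from $p$ by about $\sqrt{p(1-p)/n}$ (the standard error of a proportion), so the gap shrinks roughly like $1/\sqrt{n}$. Below is a minimal sketch that uses Python's standard `random` module instead of GenJAX and assumes independent meals; `simulate_days` and `convergence_table` are hypothetical helpers for illustration.

```python
import math
import random

def simulate_days(lunch_p, dinner_p, n, seed=0):
    """Simulate n independent days; each entry is (lunch_is_T, dinner_is_T)."""
    rng = random.Random(seed)
    return [(rng.random() < lunch_p, rng.random() < dinner_p) for _ in range(n)]

def convergence_table(lunch_p, dinner_p, sizes=(100, 1_000, 10_000, 50_000)):
    """For each n, compare observed P(TT) with theory and with the standard error."""
    p_tt = lunch_p * dinner_p
    rows = []
    for n in sizes:
        days = simulate_days(lunch_p, dinner_p, n)
        observed = sum(lunch and dinner for lunch, dinner in days) / n
        std_err = math.sqrt(p_tt * (1 - p_tt) / n)
        rows.append((n, observed, abs(observed - p_tt), std_err))
    return rows

for n, observed, gap, std_err in convergence_table(0.5, 0.5):
    print(f"n={n:>6}  observed P(TT)={observed:.4f}  gap={gap:.4f}  typical gap≈{std_err:.4f}")
```

The same `1/√n` behavior is why the 100-simulation setting on the slider looks noisy while 50,000 looks almost identical to theory.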

In [None]:
# Create interactive controls
lunch_slider = widgets.FloatSlider(
    value=0.5, min=0.0, max=1.0, step=0.05,
    description='P(Lunch=T):',
    style={'description_width': '120px'},
    continuous_update=False
)

dinner_slider = widgets.FloatSlider(
    value=0.5, min=0.0, max=1.0, step=0.05,
    description='P(Dinner=T):',
    style={'description_width': '120px'},
    continuous_update=False
)

n_sims_slider = widgets.SelectionSlider(
    options=[100, 1000, 5000, 10000, 50000],
    value=10000,
    description='# Simulations:',
    style={'description_width': '120px'},
    continuous_update=False
)

output_widget = widgets.Output()

def simulate_and_visualize(lunch_prob, dinner_prob, n_sims):
    """Run simulations and create visualization."""
    with output_widget:
        clear_output(wait=True)
        
        # Generate simulations
        key = jax.random.key(42)
        keys = jax.random.split(key, n_sims)
        
        def run_one_day(k):
            trace = chibany_day.simulate(k, (lunch_prob, dinner_prob))
            return trace.get_retval()
        
        days = jax.vmap(run_one_day)(keys)
        lunch_results, dinner_results = days
        
        # Count outcomes
        counts = {
            'HH': int(jnp.sum((lunch_results == 0) & (dinner_results == 0))),
            'HT': int(jnp.sum((lunch_results == 0) & (dinner_results == 1))),
            'TH': int(jnp.sum((lunch_results == 1) & (dinner_results == 0))),
            'TT': int(jnp.sum((lunch_results == 1) & (dinner_results == 1)))
        }
        
        # Calculate theoretical probabilities
        theory = {
            'HH': (1-lunch_prob) * (1-dinner_prob),
            'HT': (1-lunch_prob) * dinner_prob,
            'TH': lunch_prob * (1-dinner_prob),
            'TT': lunch_prob * dinner_prob
        }
        
        # Create visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
        
        # Bar chart comparing observed vs theoretical
        outcomes = list(counts.keys())
        observed_probs = [counts[o] / n_sims for o in outcomes]
        theoretical_probs = [theory[o] for o in outcomes]
        
        x = np.arange(len(outcomes))
        width = 0.35
        
        bars1 = ax1.bar(x - width/2, observed_probs, width, 
                       label='Observed', color='#4ecdc4', alpha=0.8)
        bars2 = ax1.bar(x + width/2, theoretical_probs, width,
                       label='Theory', color='#ff6b6b', alpha=0.6)
        
        ax1.set_xlabel('Outcome', fontsize=12, fontweight='bold')
        ax1.set_ylabel('Probability', fontsize=12, fontweight='bold')
        ax1.set_title(f'Observed vs Theoretical Probabilities\n({n_sims:,} simulations)', 
                     fontsize=13, fontweight='bold')
        ax1.set_xticks(x)
        ax1.set_xticklabels(outcomes, fontsize=11)
        ax1.legend(fontsize=10)
        ax1.grid(axis='y', alpha=0.3)
        ax1.set_ylim([0, max(max(observed_probs), max(theoretical_probs)) * 1.15])
        
        # Add value labels
        for bars in [bars1, bars2]:
            for bar in bars:
                height = bar.get_height()
                ax1.text(bar.get_x() + bar.get_width()/2., height,
                        f'{height:.3f}',
                        ha='center', va='bottom', fontsize=9)
        
        # Pie chart of actual counts
        colors = ['#95a5a6', '#3498db', '#e74c3c', '#2ecc71']
        explode = [0.05] * 4
        
        wedges, texts, autotexts = ax2.pie(
            [counts[o] for o in outcomes],
            labels=outcomes,
            autopct='%1.1f%%',
            colors=colors,
            explode=explode,
            shadow=True,
            startangle=90
        )
        
        for autotext in autotexts:
            autotext.set_color('white')
            autotext.set_fontsize(11)
            autotext.set_fontweight('bold')
        
        for text in texts:
            text.set_fontsize(12)
            text.set_fontweight('bold')
        
        ax2.set_title(f'Distribution of Outcomes\n({n_sims:,} total)', 
                     fontsize=13, fontweight='bold')
        
        plt.tight_layout()
        plt.show()
        
        # Print statistics
        print("\nüìä Results Summary:")
        print("=" * 60)
        print(f"{'Outcome':<10} {'Observed':<15} {'Theoretical':<15} {'Difference'}")
        print("-" * 60)
        
        for outcome in outcomes:
            obs = counts[outcome] / n_sims
            theo = theory[outcome]
            diff = abs(obs - theo)
            print(f"{outcome:<10} {obs:<15.4f} {theo:<15.4f} {diff:.4f}")
        
        # Calculate P(at least one tonkatsu): the complement of HH
        has_tonkatsu = jnp.logical_or(lunch_results == 1, dinner_results == 1)
        obs_prob = float(jnp.mean(has_tonkatsu))
        theo_prob = 1 - (1-lunch_prob) * (1-dinner_prob)
        
        print("\n" + "=" * 60)
        print("🎯 P(At least one Tonkatsu):")
        print(f"   Observed:    {obs_prob:.4f}")
        print(f"   Theoretical: {theo_prob:.4f}")
        print(f"   Difference:  {abs(obs_prob - theo_prob):.4f}")
        print("\n💡 With more simulations, observed values tend to get closer to theory!")

# Create interactive widget
interactive_sim = widgets.interactive(
    simulate_and_visualize,
    lunch_prob=lunch_slider,
    dinner_prob=dinner_slider,
    n_sims=n_sims_slider
)

display(interactive_sim)
display(output_widget)

# Run initial simulation
simulate_and_visualize(0.5, 0.5, 10000)

## 🧮 Step 4: Calculate Specific Probabilities

Let's estimate the probability of specific events from the simulations: mark each simulated day on which the event happens, then take the fraction of days that are marked.

**Example question**: What's P(at least one tonkatsu)?
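
For comparison, the exact answer comes from adding up outcome probabilities, $P(A) = \sum_{\omega \in A} P(\omega)$, or from the complement rule, $P(\text{at least one T}) = 1 - P(HH)$. A standard-library sketch, assuming independent meals; `outcome_prob` and `event_prob` are illustrative names, and each event is written as a test on outcome strings like `"TH"`:

```python
from itertools import product

def outcome_prob(outcome, lunch_p, dinner_p):
    """P of one outcome such as "TH", with independent meals and p = P(Tonkatsu)."""
    p_lunch = lunch_p if outcome[0] == "T" else 1 - lunch_p
    p_dinner = dinner_p if outcome[1] == "T" else 1 - dinner_p
    return p_lunch * p_dinner

def event_prob(in_event, lunch_p, dinner_p):
    """Exact P(event): sum P(outcome) over the outcomes where in_event is True."""
    omega = ["".join(day) for day in product("HT", repeat=2)]
    return sum(outcome_prob(w, lunch_p, dinner_p) for w in omega if in_event(w))

p_l, p_d = 0.5, 0.5
print(event_prob(lambda w: "T" in w, p_l, p_d))           # at least one T: 0.75
print(1 - event_prob(lambda w: w == "HH", p_l, p_d))      # same event via complement: 0.75
print(event_prob(lambda w: w == "TT", p_l, p_d))          # both T: 0.25
print(event_prob(lambda w: w.count("T") == 1, p_l, p_d))  # exactly one T: 0.5
```

The simulated estimates in the next cell should land close to these exact values.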

In [None]:
# Generate 10,000 days with equal probabilities
n_simulations = 10000
lunch_p = 0.5
dinner_p = 0.5

key = jax.random.key(42)
keys = jax.random.split(key, n_simulations)

def run_one_day(k):
    trace = chibany_day.simulate(k, (lunch_p, dinner_p))
    return trace.get_retval()

days = jax.vmap(run_one_day)(keys)
lunch_results, dinner_results = days

# Event A: At least one tonkatsu
event_A = jnp.logical_or(lunch_results == 1, dinner_results == 1)
prob_A = jnp.mean(event_A)

# Event B: Both are tonkatsu
event_B = jnp.logical_and(lunch_results == 1, dinner_results == 1)
prob_B = jnp.mean(event_B)

# Event C: Exactly one tonkatsu
event_C = jnp.logical_xor(lunch_results == 1, dinner_results == 1)
prob_C = jnp.mean(event_C)

print("üìä Event Probabilities (simulated):")
print("=" * 50)
print(f"P(At least one T) = {prob_A:.4f} (Theory: {1 - 0.5*0.5:.4f})")
print(f"P(Both are T)     = {prob_B:.4f} (Theory: {0.5*0.5:.4f})")
print(f"P(Exactly one T)  = {prob_C:.4f} (Theory: {2*0.5*0.5:.4f})")
print("\n‚ú® Notice how simulation matches the theory!")

## 🎯 Exercise 1: Your Turn!

Calculate the probability that **lunch and dinner are the same** (either both H or both T).

**Hint**: Use the simulation results above. The event is: (lunch==0 AND dinner==0) OR (lunch==1 AND dinner==1)

In [None]:
# Your code here!
# event_same_meals = ...
# prob_same = ...

# Uncomment to test:
# print(f"P(Same meals) = {prob_same:.4f}")
# print(f"Theory: {0.5*0.5 + 0.5*0.5:.4f}  (HH or TT)")

<details>
<summary><b>üí° Click to see solution</b></summary>

```python
event_same_meals = jnp.logical_or(
    jnp.logical_and(lunch_results == 0, dinner_results == 0),  # Both H
    jnp.logical_and(lunch_results == 1, dinner_results == 1)   # Both T
)
prob_same = jnp.mean(event_same_meals)

print(f"P(Same meals) = {prob_same:.4f}")
print(f"Theory: {0.5*0.5 + 0.5*0.5:.4f}  (HH or TT)")
```

The answer should be around 0.50 (50%)!
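
In general, if $p_L$ and $p_D$ are the probabilities of tonkatsu at lunch and dinner (chosen independently), adding the two matching outcomes gives:

$$P(\text{same}) = P(HH) + P(TT) = (1 - p_L)(1 - p_D) + p_L\,p_D$$

With $p_L = p_D = 0.5$ this is $0.25 + 0.25 = 0.5$, consistent with the simulated estimate.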
</details>

## 🚀 Exercise 2: Different Probabilities

**Scenario**: Chibany gets tonkatsu 70% of the time for lunch but only 30% for dinner.

Use the code below to answer:
1. What's the most likely outcome?
2. What's P(exactly one tonkatsu)?
3. Does the simulation match theory?

In [None]:
# Generate simulations with different probabilities
lunch_p = 0.7
dinner_p = 0.3
n_sims = 10000

key = jax.random.key(123)
keys = jax.random.split(key, n_sims)

days = jax.vmap(lambda k: chibany_day.simulate(k, (lunch_p, dinner_p)).get_retval())(keys)
lunch_results, dinner_results = days

# Count outcomes
counts = {
    'HH': int(jnp.sum((lunch_results == 0) & (dinner_results == 0))),
    'HT': int(jnp.sum((lunch_results == 0) & (dinner_results == 1))),
    'TH': int(jnp.sum((lunch_results == 1) & (dinner_results == 0))),
    'TT': int(jnp.sum((lunch_results == 1) & (dinner_results == 1)))
}

print("üìä Outcome Counts:")
for outcome, count in counts.items():
    prob_observed = count / n_sims
    print(f"{outcome}: {count:>5} ({prob_observed:.3f})")

print("\n❓ Questions:")
print("1. Which outcome is most common?")
print("2. Calculate P(exactly one tonkatsu) from the counts")
print("3. What should the theoretical probability be? Does it match?")

<details>
<summary><b>üí° Click to see answers</b></summary>

**Answers:**

1. **Most likely outcome**: TH (Tonkatsu for lunch, Hamburger for dinner)
   - Probability = 0.7 × (1 - 0.3) = 0.7 × 0.7 = 0.49 (49%)

2. **P(exactly one tonkatsu)**:
   - Observed: (Count(HT) + Count(TH)) / 10,000
   - Theory: P(HT) + P(TH) = (1 - 0.7) × 0.3 + 0.7 × (1 - 0.3) = 0.09 + 0.49 = 0.58

3. **Theory vs Simulation**:
   - HH: (1 - 0.7) × (1 - 0.3) = 0.3 × 0.7 = 0.21
   - HT: (1 - 0.7) × 0.3 = 0.3 × 0.3 = 0.09
   - TH: 0.7 × (1 - 0.3) = 0.7 × 0.7 = 0.49
   - TT: 0.7 × 0.3 = 0.21
   - With 10,000 simulations, each observed frequency should typically land within about 0.01 of its theoretical value.
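
To check these answers without floating-point rounding, the four theoretical probabilities for lunch 0.7 and dinner 0.3 can be computed with exact fractions. This is a standard-library sketch that assumes the meals are independent:

```python
from fractions import Fraction

p_lunch, p_dinner = Fraction(7, 10), Fraction(3, 10)  # P(Tonkatsu) at each meal

theory = {
    "HH": (1 - p_lunch) * (1 - p_dinner),
    "HT": (1 - p_lunch) * p_dinner,
    "TH": p_lunch * (1 - p_dinner),
    "TT": p_lunch * p_dinner,
}

for outcome, p in theory.items():
    print(outcome, p)                  # HH 21/100, HT 9/100, TH 49/100, TT 21/100
print(max(theory, key=theory.get))     # TH
print(theory["HT"] + theory["TH"])     # 29/50 (= 0.58)
print(sum(theory.values()))            # 1
```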
</details>

## 🎓 What You've Learned

Congratulations! In this notebook, you:

✅ **Simulated probability**: Generated thousands of outcomes computationally  
✅ **Verified theory**: Saw how simulation matches mathematical predictions  
✅ **Explored interactively**: Used sliders to build intuition  
✅ **Calculated events**: Filtered simulations to find probabilities  
✅ **Connected concepts**: Linked set-based thinking to code  

**The key insight**: 
> *Probability as counting becomes probability as simulation. The math is the same, but now we can handle complex problems that are impractical to count by hand!*

---

## 🚀 Next Steps

Ready for more? Try these notebooks:
- **`conditioning.ipynb`**: What if you observe something?
- **`bayesian_learning.ipynb`**: Update beliefs with evidence
- **GenJAX Tutorial Chapter 3**: Understanding traces

---

**Got questions?** Re-read the tutorial chapters and experiment with the sliders. The best way to learn probability is to *play* with it! 🎮