# Your First GenJAX Model: Chibany's Meals

This notebook accompanies Chapter 2 of the GenJAX tutorial. You'll:
- Write your first generative function
- Simulate thousands of random outcomes
- Use interactive sliders to explore probability
- See visualizations update automatically

**No coding experience needed!** Just run each cell and play with the sliders.

---

## Step 1: Install GenJAX

This takes about 1-2 minutes the first time. You'll see lots of output, which is normal!

In [None]:
# Install GenJAX and required packages
!pip install genjax ipywidgets matplotlib -q

## Step 2: Import Libraries

Load the tools we'll use. Think of this as getting your ingredients ready before cooking.

In [None]:
import jax
import jax.numpy as jnp
from genjax import gen, bernoulli
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

# Enable 64-bit numbers (optional). This does NOT set a random seed:
# randomness is controlled by the JAX keys created in later cells.
jax.config.update('jax_enable_x64', True)

print("✅ All libraries loaded successfully!")

## Step 3: Your First Generative Function

Here's Chibany's daily meals in GenJAX code!

**What it does:**
- Flips a coin for lunch (0 = Hamburger, 1 = Tonkatsu)
- Flips a coin for dinner (0 = Hamburger, 1 = Tonkatsu)
- Returns both meals as a pair

In [None]:
@gen
def chibany_day(lunch_prob=0.5, dinner_prob=0.5):
    """Generate one day of Chibany's meals.
    
    Args:
        lunch_prob: Probability that lunch is tonkatsu (default 0.5)
        dinner_prob: Probability that dinner is tonkatsu (default 0.5)
    
    Returns:
        Tuple of (lunch_is_tonkatsu, dinner_is_tonkatsu)
    """
    # Lunch: flip a coin (0=Hamburger, 1=Tonkatsu)
    lunch_is_tonkatsu = bernoulli(lunch_prob) @ "lunch"
    
    # Dinner: flip another coin
    dinner_is_tonkatsu = bernoulli(dinner_prob) @ "dinner"
    
    # Return the pair
    return (lunch_is_tonkatsu, dinner_is_tonkatsu)

print("✅ Generative function defined!")
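Because the lunch and dinner flips are independent, the probability of each combined outcome is the product of the two individual probabilities. The sketch below enumerates the four outcomes in plain Python (no GenJAX needed); `outcome_probabilities` is a hypothetical helper written for this illustration, not part of GenJAX.

```python
from fractions import Fraction
from itertools import product

def outcome_probabilities(lunch_prob, dinner_prob):
    """Exact probability of each (lunch, dinner) outcome, assuming independent flips.

    1 means tonkatsu (T), 0 means hamburger (H).
    """
    probs = {}
    for lunch, dinner in product([0, 1], repeat=2):
        p_lunch = lunch_prob if lunch == 1 else 1 - lunch_prob
        p_dinner = dinner_prob if dinner == 1 else 1 - dinner_prob
        label = ("T" if lunch else "H") + ("T" if dinner else "H")
        probs[label] = p_lunch * p_dinner  # independence: multiply
    return probs

probs = outcome_probabilities(Fraction(1, 2), Fraction(1, 2))
for label, p in probs.items():
    print(label, p)
```

Floats such as 0.8 work too; `Fraction` just keeps the arithmetic exact. These are the same products that the interactive cell in Step 5 uses for its "Theoretical" bars.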

## Step 4: Test It Out!

Let's generate one day and see what Chibany gets for his meals.

In [None]:
# Create a random key (JAX requirement for randomness)
key = jax.random.key(42)

# Generate one day
trace = chibany_day.simulate(key, (0.5, 0.5))

# What happened?
meals = trace.get_retval()
choices = trace.get_choices()

print(f"Today's meals: {meals}")
print(f"  Lunch was tonkatsu: {choices['lunch']}")
print(f"  Dinner was tonkatsu: {choices['dinner']}")

# Decode the outcome
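outcome_map = {(0, 0): "HH", (0, 1): "HT", (1, 0): "TH", (1, 1): "TT"}
# JAX arrays can't be dictionary keys, so convert each meal to a plain Python int first
outcome = outcome_map[(int(meals[0]), int(meals[1]))]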
print(f"\nOutcome: {outcome}")
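JAX has no hidden global random state: every random draw is a pure function of the key you pass in. A minimal sketch using plain `jax.random` (not GenJAX) shows what that means. The same key always gives the same coin flip, and `jax.random.split` produces fresh, independent keys. Note that `jax.random.bernoulli` takes a *probability* `p`.

```python
import jax

key = jax.random.key(42)

# Same key in, same flip out: rerunning this cell never changes the result.
flip_a = jax.random.bernoulli(key, p=0.5)
flip_b = jax.random.bernoulli(key, p=0.5)
print("Same key, same result:", bool(flip_a) == bool(flip_b))

# To get new randomness, split the key into independent subkeys.
subkeys = jax.random.split(key, 5)
flips = [bool(jax.random.bernoulli(k, p=0.5)) for k in subkeys]
print("Five flips from five subkeys:", flips)
```

This is why rerunning Step 4 with `jax.random.key(42)` gives the same day every time; change the seed to see a different day.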

## Step 5: Interactive Exploration! 🎮

**This is where it gets fun!** Use the sliders below to:
- Change the probability of tonkatsu for lunch and dinner
- Adjust the number of simulations
- Watch the chart update automatically!

**Try this:**
1. Set lunch probability to 0.8 (80% tonkatsu)
2. Set dinner probability to 0.2 (20% tonkatsu)
3. Run 10,000 simulations
4. What do you notice about the distribution?

In [None]:
# Create sliders
lunch_slider = widgets.FloatSlider(
    value=0.5,
    min=0.0,
    max=1.0,
    step=0.05,
    description='Lunch P(T):',
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
)

dinner_slider = widgets.FloatSlider(
    value=0.5,
    min=0.0,
    max=1.0,
    step=0.05,
    description='Dinner P(T):',
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
)

n_sims_slider = widgets.SelectionSlider(
    options=[100, 1000, 10000, 100000],
    value=10000,
    description='Simulations:',
    continuous_update=False,
    orientation='horizontal',
    readout=True,
)

# Output widget for the plot
output = widgets.Output()

def update_plot(lunch_prob, dinner_prob, n_sims):
    """Run simulations and update the visualization."""
    with output:
        clear_output(wait=True)
        
        # Generate random keys
        key = jax.random.key(42)
        keys = jax.random.split(key, n_sims)
        
        # Run simulations
        def run_one_day(k):
            trace = chibany_day.simulate(k, (lunch_prob, dinner_prob))
            return trace.get_retval()
        
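        # Use JAX's vmap to run all simulations in one vectorized call.
        # vmap returns a tuple (lunches, dinners); stack into an (n_sims, 2) array of 0s and 1s.
        lunches, dinners = jax.vmap(run_one_day)(keys)
        days = jnp.stack([lunches, dinners], axis=1).astype(jnp.int32)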
        
        # Count each outcome
        HH = jnp.sum((days[:, 0] == 0) & (days[:, 1] == 0))
        HT = jnp.sum((days[:, 0] == 0) & (days[:, 1] == 1))
        TH = jnp.sum((days[:, 0] == 1) & (days[:, 1] == 0))
        TT = jnp.sum((days[:, 0] == 1) & (days[:, 1] == 1))
        
        # Calculate theoretical probabilities
        theory_HH = (1 - lunch_prob) * (1 - dinner_prob) * n_sims
        theory_HT = (1 - lunch_prob) * dinner_prob * n_sims
        theory_TH = lunch_prob * (1 - dinner_prob) * n_sims
        theory_TT = lunch_prob * dinner_prob * n_sims
        
        # Create bar chart
        outcomes = ['HH', 'HT', 'TH', 'TT']
        observed_counts = [int(HH), int(HT), int(TH), int(TT)]
        theoretical_counts = [theory_HH, theory_HT, theory_TH, theory_TT]
        
        fig, ax = plt.subplots(figsize=(10, 6))
        
        x = range(len(outcomes))
        width = 0.35
        
        # Plot observed and theoretical
        bars1 = ax.bar([i - width/2 for i in x], observed_counts, width, 
                       label='Observed', color='#4ecdc4', alpha=0.8)
        bars2 = ax.bar([i + width/2 for i in x], theoretical_counts, width,
                       label='Theoretical', color='#ff6b6b', alpha=0.6)
        
        ax.set_xlabel('Outcome', fontsize=12)
        ax.set_ylabel(f'Count (out of {n_sims:,})', fontsize=12)
        ax.set_title(f"Chibany's Meals: {n_sims:,} Simulated Days\n" +
                    f"Lunch P(Tonkatsu)={lunch_prob:.2f}, Dinner P(Tonkatsu)={dinner_prob:.2f}",
                    fontsize=14, fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels(outcomes)
        ax.legend()
        ax.grid(axis='y', alpha=0.3)
        
        # Add value labels on bars
        for bars in [bars1, bars2]:
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height,
                       f'{height:.0f}',  # round (not truncate) so e.g. 1599.9999 shows as 1600
                       ha='center', va='bottom', fontsize=9)
        
        plt.tight_layout()
        plt.show()
        
        # Calculate and display probability of at least one tonkatsu
        has_tonkatsu = jnp.logical_or(days[:, 0], days[:, 1])
        prob_tonkatsu = jnp.mean(has_tonkatsu)
        theory_prob = 1 - (1 - lunch_prob) * (1 - dinner_prob)
        
        print(f"\n📊 Probability of at least one tonkatsu:")
        print(f"   Observed:    {prob_tonkatsu:.4f} ({int(jnp.sum(has_tonkatsu))}/{n_sims})")
        print(f"   Theoretical: {theory_prob:.4f}")
        print(f"   Difference:  {abs(prob_tonkatsu - theory_prob):.4f}")

# Create interactive widget
interactive = widgets.interactive(
    update_plot,
    lunch_prob=lunch_slider,
    dinner_prob=dinner_slider,
    n_sims=n_sims_slider
)

# Display
display(interactive)
display(output)

# Run initial plot
update_plot(0.5, 0.5, 10000)
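`jax.vmap` runs a function over a whole batch of keys at once. One detail matters for the counting code: when the function returns a tuple, `vmap` returns a tuple of batched arrays (one per element), not a single 2-D array. The sketch below uses plain `jax.random.bernoulli` as a stand-in for `chibany_day` (an assumption made so it runs without GenJAX; `fake_day` is a hypothetical name) and stacks the two arrays into an `(n, 2)` table.

```python
import jax
import jax.numpy as jnp

def fake_day(k, lunch_prob=0.5, dinner_prob=0.5):
    # Hypothetical stand-in for chibany_day: two independent coin flips from one key.
    k_lunch, k_dinner = jax.random.split(k)
    lunch = jax.random.bernoulli(k_lunch, p=lunch_prob)
    dinner = jax.random.bernoulli(k_dinner, p=dinner_prob)
    return (lunch, dinner)

keys = jax.random.split(jax.random.key(0), 1000)
lunch, dinner = jax.vmap(fake_day)(keys)   # a tuple of two arrays, each shape (1000,)
days = jnp.stack([lunch, dinner], axis=1).astype(jnp.int32)  # shape (1000, 2)

print(days.shape)
print("Fraction of days with tonkatsu lunch:", float(jnp.mean(days[:, 0])))
```

Indexing such as `days[:, 0]` needs the stacked array, since a plain tuple has no second axis.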

## Connection to Set-Based Probability

Remember from the probability tutorial:

| Set-Based Concept | GenJAX Equivalent |
|-------------------|-------------------|
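| Outcome space $\Omega$ | All values `simulate()` can return (sampled by running it many times) |
| One outcome $\omega$ | One call to `simulate()` |
| Event $A \subseteq \Omega$ | Filtering simulations |
| $\|A\|$ (count elements) | `jnp.sum(condition)` |
| $P(A) = \|A\|/\|\Omega\|$ (equally likely outcomes) | `jnp.mean(condition)` (estimated from many simulations) |

**It's the same concept!** The computer does the counting instead of you, and with more simulations the estimated fractions get closer to the exact probabilities.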
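To see the table in action, here is a minimal pure-Python sketch for the fair-coin case, where all four outcomes are equally likely. It builds $\Omega$ explicitly, picks out the event "at least one tonkatsu" by filtering, and checks that the fraction of elements equals the average of a 0/1 indicator, which is the quantity `jnp.mean(condition)` computes over simulated days.

```python
from fractions import Fraction
from itertools import product

# Outcome space: every (lunch, dinner) pair; 1 = tonkatsu, 0 = hamburger.
omega = list(product([0, 1], repeat=2))          # [(0,0), (0,1), (1,0), (1,1)]

# Event A: at least one tonkatsu, found by filtering omega.
A = [w for w in omega if w[0] == 1 or w[1] == 1]

# Counting formula (valid because the fair-coin outcomes are equally likely).
p_counting = Fraction(len(A), len(omega))

# Same number as the mean of an indicator: 1 if the outcome is in A, else 0.
indicator = [1 if w in A else 0 for w in omega]
p_mean = Fraction(sum(indicator), len(indicator))

print(len(A), len(omega), p_counting, p_mean)   # 3 4 3/4 3/4
```

When the probabilities are unequal, the counting formula no longer applies directly, but averaging the indicator over many simulated days still estimates $P(A)$.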

## Exercises

Try these to deepen your understanding!

### Exercise 1: Different Probabilities

Using the sliders above:
- Set lunch probability to 0.7 (70% tonkatsu)
- Set dinner probability to 0.3 (30% tonkatsu)
- Run 10,000 simulations

**Questions:**
1. Which outcome is most common?
2. Why does this make sense?
3. How close are the observed counts to the theoretical counts?
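A useful yardstick for question 3: if an outcome has probability $p$ and you simulate $n$ days, its count is binomially distributed with mean $np$ and standard deviation $\sqrt{np(1-p)}$, so observed counts usually land within a couple of standard deviations of the theoretical bar. The sketch below computes that spread in plain Python; `count_spread` is a hypothetical helper, and the probability 0.25 is just the fair-coin value for each outcome, not the answer to this exercise.

```python
import math

def count_spread(p, n):
    """Standard deviation of a binomial count: how far observed counts
    typically stray from the theoretical count n * p."""
    return math.sqrt(n * p * (1 - p))

for n in [100, 1_000, 10_000, 100_000]:
    expected = n * 0.25
    spread = count_spread(0.25, n)
    # Relative spread shrinks like 1/sqrt(n): more simulations, tighter match.
    print(f"n={n:>7,}: expected {expected:>8.0f}, typical deviation ±{spread:.1f} "
          f"({spread / expected:.1%} of expected)")
```

Plug in your own outcome probabilities to judge whether a gap between the two bars is ordinary simulation noise.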

### Exercise 2: Counting Total Tonkatsu

Write code to count **how many tonkatsu** Chibany gets across all days (not just which days have tonkatsu, but the total count).

In [None]:
# Your code here!
# Hint: Generate days using the code from above, then sum days[:, 0] and days[:, 1]

# Generate 10,000 days with equal probabilities
key = jax.random.key(42)
keys = jax.random.split(key, 10000)

def run_one_day(k):
    trace = chibany_day.simulate(k, (0.5, 0.5))
    return trace.get_retval()

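# vmap returns a tuple (lunches, dinners); stack them into a (10000, 2) array of 0s and 1s
lunches, dinners = jax.vmap(run_one_day)(keys)
days = jnp.stack([lunches, dinners], axis=1).astype(jnp.int32)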

# TODO: Calculate total tonkatsu and average per day
# total_tonkatsu = ...
# avg_per_day = ...

<details>
<summary><b>Click to see solution</b></summary>

```python
total_tonkatsu = jnp.sum(days[:, 0]) + jnp.sum(days[:, 1])
avg_per_day = total_tonkatsu / len(days)

print(f"Total tonkatsu: {total_tonkatsu}")
print(f"Average per day: {avg_per_day:.2f}")
```

With equal probabilities (0.5 each), you should get close to 1.0 tonkatsu per day on average!
</details>
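The expected value of about 1.0 tonkatsu per day comes from linearity of expectation. Writing $L$ and $D$ for the lunch and dinner indicators (1 for tonkatsu, 0 for hamburger), and noting that linearity holds whether or not the meals are independent:

$$
\mathbb{E}[L + D] = \mathbb{E}[L] + \mathbb{E}[D] = p_{\text{lunch}} + p_{\text{dinner}} = 0.5 + 0.5 = 1.0
$$

With other slider settings, the same rule predicts $p_{\text{lunch}} + p_{\text{dinner}}$ tonkatsu per day.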

### Exercise 3: Three Meals?

Extend the model to include breakfast! Complete the `chibany_three_meals` function below.

In [None]:
@gen
def chibany_three_meals(breakfast_prob=0.5, lunch_prob=0.5, dinner_prob=0.5):
    """Generate one day of Chibany's meals including breakfast."""
    
    # TODO: Add breakfast!
    # breakfast_is_tonkatsu = ...
    
    lunch_is_tonkatsu = bernoulli(lunch_prob) @ "lunch"
    dinner_is_tonkatsu = bernoulli(dinner_prob) @ "dinner"
    
    # TODO: Return all three meals
    # return (...)
    
    pass

<details>
<summary><b>Click to see solution</b></summary>

```python
@gen
def chibany_three_meals(breakfast_prob=0.5, lunch_prob=0.5, dinner_prob=0.5):
    """Generate one day of Chibany's meals including breakfast."""
    breakfast_is_tonkatsu = bernoulli(breakfast_prob) @ "breakfast"
    lunch_is_tonkatsu = bernoulli(lunch_prob) @ "lunch"
    dinner_is_tonkatsu = bernoulli(dinner_prob) @ "dinner"
    return (breakfast_is_tonkatsu, lunch_is_tonkatsu, dinner_is_tonkatsu)
```

Now the outcome space has $2^3 = 8$ possible outcomes!
</details>
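Each extra independent meal doubles the outcome space. A quick pure-Python check (enumerating outcomes with `itertools.product`, no GenJAX needed) lists all eight breakfast-lunch-dinner combinations and confirms that their probabilities still add up to 1; the values 0.3, 0.5, 0.9 below are arbitrary, and `outcome_prob` is a hypothetical helper.

```python
from itertools import product
import math

meal_probs = [0.3, 0.5, 0.9]   # arbitrary P(tonkatsu) for breakfast, lunch, dinner

outcomes = list(product("HT", repeat=len(meal_probs)))
print(len(outcomes), "outcomes:", ["".join(o) for o in outcomes])

def outcome_prob(outcome, probs):
    # Independent meals: multiply each meal's probability of the observed food.
    return math.prod(p if meal == "T" else 1 - p for meal, p in zip(outcome, probs))

total = sum(outcome_prob(o, meal_probs) for o in outcomes)
print("Total probability:", round(total, 10))
```

With $n$ independent meals, the same enumeration has $2^n$ entries.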

## What You've Learned

In this notebook, you:

✅ Wrote your first generative function  
✅ Simulated thousands of random outcomes  
✅ Used interactive sliders to explore probability  
✅ Saw how simulation matches theoretical predictions  
✅ Understood the connection between sets and simulation  
✅ Learned about traces and random choices  

**The key insight:** Generative functions let computers do what you did by hand with sets, but now you can handle millions of possibilities!

---

## Next Steps

Ready for more? Check out:
- **Chapter 3: Understanding Traces** - How GenJAX records random choices
- **Chapter 4: Conditioning and Observations** - What if I observe something?
- **Chapter 5: Inference in Action** - The taxicab problem with code!

---

**Questions? Stuck?** That's normal! Go back to the tutorial text and re-read the explanations. Try changing the slider values to build intuition.