In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

"""
RESTAURANT-SPECIFIC OUTBREAK MODEL

Key differences from generic SEIRX:
1. Staff work in SHIFTS (not continuous exposure)
2. Patrons have SHORT, ONE-TIME exposures (not repeated)
3. Food handlers have HIGHER transmission than other staff
4. Contamination events are EXPLICIT (food/surface contamination)

This is mechanistic but restaurant-appropriate.
"""

def simulate_restaurant_outbreak(
    # Staff parameters
    n_food_handlers=5,
    n_other_staff=5,
    init_infected=1,
    shift_hours=8,
    shifts_per_day=2,  # Lunch and dinner

    # Disease parameters
    latent_period=1.0,  # Days until infectious
    infectious_period=3.0,  # Days infectious
    prob_symptomatic=0.7,

    # Transmission rates
    beta_staff_staff=0.1,  # Per contact per day among staff
    beta_handler_patron=0.02,  # Per patron served by infectious food handler
    beta_other_patron=0.001,  # Per patron contact with other staff (much lower)

    # Restaurant operations
    patrons_per_shift=150,
    patrons_per_handler=30,  # Each handler serves ~30 patrons per shift

    # Contamination
    prob_food_contamination=0.10,  # If handler is infectious
    contamination_attack_rate=0.30,  # Fraction of exposed patrons who get sick

    # Simulation
    max_days=5,

    # Formerly hard-coded constants (defaults reproduce the earlier behaviour)
    prob_symptomatic_stays_home=0.5,
    contamination_exposed_fraction=0.5,
):
    """
    Simulate one restaurant outbreak and return its total size.

    Staff dynamics:
    - Staff move through S -> E -> Ia/Is -> R. A staff member leaves 'E' once
      ``latent_period`` days have passed since infection, and recovers once
      ``latent_period + infectious_period`` days have passed.
    - On becoming infectious a staff member is symptomatic ('Is') with
      probability ``prob_symptomatic``, otherwise asymptomatic ('Ia').
    - A symptomatic staff member is skipped (treated as staying home) with
      probability ``prob_symptomatic_stays_home``. The draw is repeated
      independently for every staff contact and for every shift.

    Patron dynamics (per shift):
    - Each working infectious food handler infects
      Binomial(``patrons_per_handler``, ``beta_handler_patron``) patrons.
    - If any infectious non-handler staff are working, all
      ``patrons_per_shift`` patrons are exposed with per-patron probability
      ``beta_other_patron * (number of such staff)``, clamped to at most 1.
    - If any infectious handler is working, a contamination event happens
      with probability ``prob_food_contamination``; it exposes
      ``int(patrons_per_shift * contamination_exposed_fraction)`` patrons,
      each infected with probability ``contamination_attack_rate``.
    - The three routes are drawn independently, so their sum is capped at
      ``patrons_per_shift``: a shift cannot produce more cases than patrons.

    Notes:
    - ``shift_hours`` is accepted for interface compatibility but is not
      used by the current model.
    - All randomness comes from NumPy's global ``np.random`` state, so call
      ``np.random.seed`` beforehand for reproducible runs.

    Returns:
        int: infected staff (including the initial cases) plus infected
        patrons accumulated over ``max_days`` days.

    Raises:
        ValueError: if ``init_infected`` is negative or larger than the
        total number of staff.
    """
    total_staff = n_food_handlers + n_other_staff
    if not 0 <= init_infected <= total_staff:
        raise ValueError(
            f"init_infected must be between 0 and the number of staff "
            f"({total_staff}), got {init_infected}"
        )

    infectious_states = ('Ia', 'Is')

    # Staff states: 'S', 'E', 'Ia' (asymptomatic infectious),
    # 'Is' (symptomatic infectious), 'R'
    staff_states = ['S'] * total_staff
    staff_infection_day = [None] * total_staff  # Day each member was infected
    staff_is_handler = [True] * n_food_handlers + [False] * n_other_staff

    # Seed the outbreak: the initial cases start in the latent state at day 0.
    for idx in np.random.choice(total_staff, size=init_infected, replace=False):
        staff_states[idx] = 'E'
        staff_infection_day[idx] = 0.0

    total_staff_infected = init_infected
    total_patron_infected = 0

    for day in range(max_days):

        # --- Disease progression -------------------------------------------
        for i, state in enumerate(staff_states):
            days_since_infection = (
                day - staff_infection_day[i] if state != 'S' and state != 'R' else None
            )
            if state == 'E':
                if days_since_infection >= latent_period:
                    if np.random.rand() < prob_symptomatic:
                        staff_states[i] = 'Is'  # Symptomatic - might stay home
                    else:
                        staff_states[i] = 'Ia'  # Asymptomatic - keeps working
            elif state in infectious_states:
                if days_since_infection >= latent_period + infectious_period:
                    staff_states[i] = 'R'

        # --- Staff-to-staff transmission -----------------------------------
        infectious_staff = [i for i, state in enumerate(staff_states)
                            if state in infectious_states]
        susceptible_staff = [i for i, state in enumerate(staff_states)
                             if state == 'S']

        for s_idx in susceptible_staff:
            for i_idx in infectious_staff:
                # A symptomatic colleague may be absent for this contact.
                if (staff_states[i_idx] == 'Is'
                        and np.random.rand() < prob_symptomatic_stays_home):
                    continue

                if np.random.rand() < beta_staff_staff:
                    staff_states[s_idx] = 'E'
                    # Infection happens at a random time during the day.
                    staff_infection_day[s_idx] = day + np.random.uniform(0, 1)
                    total_staff_infected += 1
                    break  # Already infected; no need to test other sources

        # --- Patron infections, one pass per shift -------------------------
        for _ in range(shifts_per_day):

            # Count infectious staff who actually show up for this shift.
            handlers_working = 0
            others_working = 0
            for i, state in enumerate(staff_states):
                if state not in infectious_states:
                    continue
                if state == 'Is' and np.random.rand() < prob_symptomatic_stays_home:
                    continue
                if staff_is_handler[i]:
                    handlers_working += 1
                else:
                    others_working += 1

            shift_infections = 0

            # Direct transmission from each infectious handler to the patrons
            # that handler serves.
            for _ in range(handlers_working):
                shift_infections += np.random.binomial(
                    patrons_per_handler, beta_handler_patron
                )

            # Casual contact with infectious non-handler staff. The linear
            # scaling can exceed 1 for large betas or many staff, which
            # np.random.binomial rejects, so clamp it to a valid probability.
            if others_working:
                p_other = min(1.0, beta_other_patron * others_working)
                shift_infections += np.random.binomial(patrons_per_shift, p_other)

            # Food contamination event: one batch reaches a fraction of the
            # patrons present this shift.
            if handlers_working and np.random.rand() < prob_food_contamination:
                n_exposed = int(patrons_per_shift * contamination_exposed_fraction)
                shift_infections += np.random.binomial(
                    n_exposed, contamination_attack_rate
                )

            # The routes above are independent draws over the same patrons;
            # never count more cases than there were patrons in the shift.
            total_patron_infected += min(shift_infections, patrons_per_shift)

    return int(total_staff_infected + total_patron_infected)


def calibrate_restaurant_model(real_sizes, n_sims=200):
    """
    Calibrate the restaurant-specific model by grid search.

    For every combination of ``beta_handler_patron``,
    ``prob_food_contamination`` and ``contamination_attack_rate`` on a fixed
    grid, ``n_sims`` outbreaks are simulated and scored by the weighted mean
    absolute difference between the real and simulated outbreak sizes at the
    25/50/75/90/95/99th percentiles (upper percentiles weigh more). The
    combination with the lowest score wins.

    Side effects: prints a report, shows a 2x2 diagnostic figure and saves
    it to 'restaurant_model_calibration.png'.

    Args:
        real_sizes: observed outbreak sizes (array-like of numbers).
        n_sims: number of simulated outbreaks per grid point.

    Returns:
        (best_params, best_sim): the winning parameter dict and the array of
        simulated outbreak sizes produced with it.

    Raises:
        ValueError: if ``real_sizes`` is empty, ``n_sims`` is less than 1, or
        no grid point produced a finite score (e.g. NaN in ``real_sizes``).
    """
    from itertools import product

    real_sizes = np.asarray(real_sizes)
    if real_sizes.size == 0:
        raise ValueError("real_sizes must contain at least one outbreak size")
    if n_sims < 1:
        raise ValueError(f"n_sims must be at least 1, got {n_sims}")

    print("="*70)
    print("RESTAURANT OUTBREAK MODEL CALIBRATION")
    print("="*70)
    print("\nModel structure:")
    print("  - Food handlers work shifts, serve patrons")
    print("  - Staff-to-staff transmission during shifts")
    print("  - Direct transmission: handler → patron")
    print("  - Contamination events: contaminated food → many patrons")
    print()

    # Calibration grid - focus on key parameters
    beta_handler_patron_range = np.linspace(0.005, 0.03, 6)
    prob_contamination_range = np.linspace(0.05, 0.25, 5)
    contamination_attack_range = np.linspace(0.2, 0.5, 4)
    grid = list(product(beta_handler_patron_range,
                        prob_contamination_range,
                        contamination_attack_range))

    # Scoring setup. The real-data percentiles do not depend on the grid
    # point, so they are computed once up front.
    percentiles = [25, 50, 75, 90, 95, 99]
    weights = np.array([1, 2, 1, 2, 3, 5])
    real_pct = np.percentile(real_sizes, percentiles)

    best_score = np.inf
    best_params = None
    best_sim = None

    # The context manager closes the bar even if a simulation raises.
    with tqdm(total=len(grid), desc="Calibrating") as pbar:
        for beta_hp, p_contam, contam_ar in grid:

            # Run simulations
            sim_sizes = np.array([
                simulate_restaurant_outbreak(
                    beta_handler_patron=beta_hp,
                    prob_food_contamination=p_contam,
                    contamination_attack_rate=contam_ar
                )
                for _ in range(n_sims)
            ])

            # Score with weighted percentiles
            sim_pct = np.percentile(sim_sizes, percentiles)
            score = np.average(np.abs(real_pct - sim_pct), weights=weights)

            if score < best_score:
                best_score = score
                # The last five entries record the fixed values the
                # simulation was run with; they are not calibrated here.
                best_params = {
                    'beta_handler_patron': beta_hp,
                    'prob_food_contamination': p_contam,
                    'contamination_attack_rate': contam_ar,
                    'beta_staff_staff': 0.1,
                    'n_food_handlers': 5,
                    'n_other_staff': 5,
                    'patrons_per_shift': 150,
                    'patrons_per_handler': 30
                }
                best_sim = sim_sizes

            pbar.update(1)

    if best_params is None:
        # Every score was NaN/inf, so no grid point ever beat the initial inf.
        raise ValueError(
            "calibration produced no finite score; check real_sizes for "
            "NaN or infinite values"
        )

    print("\n" + "="*70)
    print("CALIBRATION RESULTS")
    print("="*70)
    print("\nBest parameters:")
    print(f"  β (handler→patron) = {best_params['beta_handler_patron']:.4f}")
    print(f"  P(food contamination) = {best_params['prob_food_contamination']:.3f}")
    print(f"  Contamination attack rate = {best_params['contamination_attack_rate']:.3f}")
    print(f"  Score = {best_score:.3f}")

    # Compare statistics
    print(f"\n{'Metric':<20} {'Real':<12} {'Simulated':<12} {'Difference'}")
    print("-"*60)

    metrics = [
        ('Mean', np.mean),
        ('Median', np.median),
        ('Std', np.std),
        ('75th pct', lambda x: np.percentile(x, 75)),
        ('90th pct', lambda x: np.percentile(x, 90)),
        ('95th pct', lambda x: np.percentile(x, 95)),
        ('99th pct', lambda x: np.percentile(x, 99))
    ]

    for name, func in metrics:
        real_val = func(real_sizes)
        sim_val = func(best_sim)
        diff = abs(real_val - sim_val)
        print(f"{name:<20} {real_val:<12.1f} {sim_val:<12.1f} {diff:.1f}")

    # Plot results
    _, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Histogram
    ax = axes[0, 0]
    ax.hist(real_sizes, bins=40, alpha=0.6, label='Real NORS',
            density=True, edgecolor='black', color='blue')
    ax.hist(best_sim, bins=40, alpha=0.6, label='Simulated',
            density=True, edgecolor='black', color='green')
    ax.set_xlabel('Outbreak size', fontsize=11)
    ax.set_ylabel('Density', fontsize=11)
    ax.legend(fontsize=10)
    ax.set_title('Distribution Comparison', fontsize=12, fontweight='bold')
    ax.set_xlim(0, 300)
    ax.grid(True, alpha=0.3)

    # CDF
    ax = axes[0, 1]
    ax.plot(np.sort(real_sizes), np.linspace(0, 1, len(real_sizes)),
            label='Real', lw=2, color='blue')
    ax.plot(np.sort(best_sim), np.linspace(0, 1, len(best_sim)),
            label='Simulated', lw=2, ls='--', color='green')
    ax.set_xlabel('Outbreak size', fontsize=11)
    ax.set_ylabel('Cumulative probability', fontsize=11)
    ax.legend(fontsize=10)
    ax.set_title('Cumulative Distribution', fontsize=12, fontweight='bold')
    ax.set_xlim(0, 300)
    ax.grid(True, alpha=0.3)

    # Q-Q Plot
    # Compare both samples at the same probability levels. Truncating the
    # larger sample to its first n rows would plot an arbitrary, file-order
    # dependent subset instead of the quantiles of the whole distribution.
    ax = axes[1, 0]
    n = min(len(real_sizes), len(best_sim))
    probs = np.linspace(0, 100, n)
    real_quantiles = np.percentile(real_sizes, probs)
    sim_quantiles = np.percentile(best_sim, probs)
    ax.scatter(real_quantiles, sim_quantiles, alpha=0.5, s=20, color='green')
    max_val = max(real_quantiles.max(), sim_quantiles.max())
    ax.plot([0, max_val], [0, max_val], 'r--', lw=2, label='Perfect fit')
    ax.set_xlabel('Real outbreak size', fontsize=11)
    ax.set_ylabel('Simulated outbreak size', fontsize=11)
    ax.legend(fontsize=10)
    ax.set_title('Q-Q Plot', fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3)

    # Percentile comparison
    ax = axes[1, 1]
    best_pct = np.percentile(best_sim, percentiles)
    x = np.arange(len(percentiles))
    width = 0.35
    ax.bar(x - width/2, real_pct, width, label='Real',
           alpha=0.7, color='blue', edgecolor='black')
    ax.bar(x + width/2, best_pct, width, label='Simulated',
           alpha=0.7, color='green', edgecolor='black')
    ax.set_xticks(x)
    ax.set_xticklabels([f'{p}th' for p in percentiles], rotation=45)
    ax.set_ylabel('Outbreak size', fontsize=11)
    ax.legend(fontsize=10)
    ax.set_title('Percentile Comparison', fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig('restaurant_model_calibration.png', dpi=300, bbox_inches='tight')
    plt.show()

    return best_params, best_sim


def main():
    """
    Load observed outbreak sizes, calibrate the model and save the result.

    Reads the first column of the header-less file 'NORS.csv' as integer
    outbreak sizes (missing values dropped), runs
    ``calibrate_restaurant_model`` on them, writes the best parameters to
    'calibrated_restaurant_params.csv' and prints an interpretation.

    If 'NORS.csv' is missing, empty, or holds no usable values, an error
    message is printed and the function returns without calibrating.
    """
    # Load data
    try:
        nors = pd.read_csv("NORS.csv", header=None)
    except FileNotFoundError:
        print("ERROR: NORS.csv not found!")
        return
    except pd.errors.EmptyDataError:
        # pandas raises this for a file with no columns to parse.
        print("ERROR: NORS.csv is empty!")
        return

    real_sizes = nors[0].dropna().astype(int).values
    if real_sizes.size == 0:
        # .min()/.max() and the percentile scoring all need at least one value.
        print("ERROR: NORS.csv contains no outbreak sizes!")
        return

    print(f"Loaded {len(real_sizes)} outbreaks")
    print(f"Range: [{real_sizes.min()}, {real_sizes.max()}]")
    print(f"Mean: {np.mean(real_sizes):.1f}")
    print(f"Median: {np.median(real_sizes):.1f}")
    print()

    # Calibrate (the simulated sample of the best fit is not needed here)
    best_params, _ = calibrate_restaurant_model(real_sizes)

    # Save parameters
    params_df = pd.DataFrame([best_params])
    params_df.to_csv('calibrated_restaurant_params.csv', index=False)

    print("\n" + "="*70)
    print("MODEL INTERPRETATION")
    print("="*70)
    print("\nThis model represents restaurant outbreaks as:")
    print("  1. Food handlers working shifts")
    print("  2. Direct transmission: infected handler → patrons they serve")
    print("  3. Food contamination: infected handler contaminates food → many patrons")
    print("  4. Staff-to-staff: transmission among workers during shifts")
    print()
    print("Calibrated parameters suggest:")
    print(f"  - {best_params['prob_food_contamination']*100:.0f}% chance infected handler contaminates food")
    print(f"  - When contamination occurs, {best_params['contamination_attack_rate']*100:.0f}% of exposed patrons get sick")
    print(f"  - Direct transmission rate: {best_params['beta_handler_patron']:.1%} per patron served")
    print()
    print("Parameters saved to: calibrated_restaurant_params.csv")


# Run the calibration workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()