In [None]:
from functools import partial
from typing import Any

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from flax.struct import dataclass

from chewc.cross import make_cross
from chewc.phenotype import set_pheno, set_bv
from chewc.population import msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a

# ==================================================
# ---      Simulation Parameters & State         ---
# ==================================================

@dataclass
class SimState:
    """Loop-carried state of the simulation, threaded through ``lax.scan``.

    Attributes:
        pop: The population of the current generation. It is annotated as
            ``Any`` because the population class is not imported in this file.
        next_id_start: Value forwarded to ``make_cross`` as ``next_id_start``
            and advanced by ``population_size`` after every generation. It
            starts as a plain integer but is a traced JAX value while it is
            carried through ``lax.scan``.
        key: PRNG key for the current generation; it is split at every step
            and one of the sub-keys is carried forward.
    """
    # `any` (the builtin function) and `jax.random.PRNGKey` (a key
    # constructor) are not types, so proper annotations are used instead.
    pop: Any
    next_id_start: int
    key: jax.Array

@dataclass
class SimParams:
    """Per-scenario settings that stay constant for an entire simulation run.

    Attributes:
        sp: The ``SimParam`` object; its ``traits`` and ``ploidy`` are used
            for breeding-value/phenotype calculation and the whole object is
            passed to ``make_cross``.
        heritability: Value passed (wrapped in a length-1 array) as ``h2``
            to ``set_pheno``.
        n_generations: Number of generations to simulate; the scan runs
            ``n_generations + 1`` steps so that the founders are recorded too.
        population_size: Number of crosses planned per generation; also the
            increment applied to ``next_id_start`` after each generation.
        n_select: Number of top-phenotype individuals kept as parents.
    """
    sp: SimParam
    heritability: float
    n_generations: int
    population_size: int
    n_select: int

# Single place where every tunable value of this script is declared.
simulation_parameters = dict(
    # Number of replicate runs (one PRNG key is derived per replicate).
    n_replicates=3,
    # Founder population / genome layout handed to `msprime_pop`.
    n_founder_ind=10,
    n_loci_per_chr=1000,
    n_chr=5,
    # Trait definition handed to `add_trait_a`.
    n_qtl_per_chr=100,
    trait_mean=jnp.array([0.0]),
    trait_var=jnp.array([1.0]),
    # Breeding-program settings copied into each `SimParams`.
    n_generations=50,
    population_size=200,
    n_select=20,
    # Root PRNG key from which all other keys are split.
    key=jax.random.PRNGKey(42),
)

# ==================================================
# ---     Core Simulation Logic (for lax.scan)   ---
# ==================================================

def run_generation_step(state: SimState, _, params: SimParams):
    """
    Advance the simulation by one generation; written as a ``lax.scan`` body.

    The incoming population is evaluated (breeding values, then phenotypes),
    the ``params.n_select`` individuals with the highest phenotype in trait
    column 0 become the parent pool, and ``params.population_size`` crosses
    are planned by drawing a mother and a father independently (with
    replacement) from that pool.

    Args:
        state: Carry holding the current population, the next id offset and
            the PRNG key.
        _: Unused per-iteration input required by the ``lax.scan`` signature.
        params: Constant simulation settings.

    Returns:
        ``(new_state, (genetic_mean, genetic_variance))`` where the two
        statistics are the mean and variance of the breeding values of the
        population that *entered* this step (i.e. before mating).
    """
    sim_param = params.sp
    n_offspring = params.population_size

    # One sub-key per random operation, plus one carried to the next step.
    pheno_key, dam_key, sire_key, cross_key, carry_key = jax.random.split(state.key, 5)

    # --- Evaluate the current generation ---
    evaluated = set_pheno(
        key=pheno_key,
        pop=set_bv(state.pop, sim_param.traits, sim_param.ploidy),
        traits=sim_param.traits,
        ploidy=sim_param.ploidy,
        h2=jnp.array([params.heritability]),
    )

    # --- Per-generation summary statistics ---
    generation_stats = (jnp.mean(evaluated.bv), jnp.var(evaluated.bv))

    # --- Truncation selection on the first trait ---
    # argsort is ascending, so the tail holds the highest phenotypes.
    ranking = jnp.argsort(evaluated.pheno[:, 0])
    parent_pool = ranking[-params.n_select:]

    # --- Random mating among the selected parents ---
    dams = jax.random.choice(dam_key, parent_pool, shape=(n_offspring,))
    sires = jax.random.choice(sire_key, parent_pool, shape=(n_offspring,))

    # --- Produce the offspring generation ---
    offspring = make_cross(
        key=cross_key,
        pop=evaluated,
        cross_plan=jnp.stack([dams, sires], axis=1),
        sp=sim_param,
        next_id_start=state.next_id_start,
    )

    # --- Carry for the following generation ---
    next_state = SimState(
        pop=offspring,
        next_id_start=state.next_id_start + n_offspring,
        key=carry_key,
    )
    return next_state, generation_stats

# ==================================================
# ---     JIT-compiled Simulation Function       ---
# ==================================================
@partial(jax.jit, static_argnames=("params",))
def run_simulation_scan(key, founder_pop, params: SimParams):
    """
    Run one full replicate of the simulation and return its genetic trend.

    Args:
        key: PRNG key seeding this replicate.
        founder_pop: Founder population used as generation 0.
        params: Constant simulation settings. The argument is static for
            ``jax.jit``, so it must be hashable and every distinct value
            triggers a separate compilation.
            NOTE(review): hashability depends on what ``params.sp`` contains;
            confirm it holds no unhashable arrays.

    Returns:
        ``(genetic_means, genetic_variances)``: two arrays with
        ``params.n_generations + 1`` entries each. Entry 0 describes the
        founder population and entry ``i`` the population after ``i`` rounds
        of selection and mating.
    """
    founder_state = SimState(
        pop=founder_pop,
        next_id_start=founder_pop.nInd,
        key=key
    )

    # Generation 0 is processed outside the scan. `lax.scan` requires the
    # carry to keep the same pytree structure, shapes and dtypes on every
    # iteration, but the founder population is built independently of
    # `make_cross` (and, presumably, holds a different number of individuals
    # than the `population_size` offspring each step produces). After this
    # first step the carry always originates from `make_cross`, so it is
    # invariant across the scanned iterations.
    first_state, (founder_mean, founder_variance) = run_generation_step(
        founder_state, None, params=params
    )

    # Remaining generations: 1 step outside + n_generations steps inside the
    # scan gives the same n_generations + 1 records as a single long scan.
    _, (later_means, later_variances) = jax.lax.scan(
        f=partial(run_generation_step, params=params),
        init=first_state,
        xs=None,
        length=params.n_generations
    )

    genetic_means = jnp.concatenate([founder_mean[None], later_means])
    genetic_variances = jnp.concatenate([founder_variance[None], later_variances])

    return genetic_means, genetic_variances

# ==================================================
# ---      Setup and Execution of Scenarios      ---
# ==================================================

# --- Setup Founder Population and Traits ---
# Split the root key into: the key kept for later use, one key for the
# founder population, and one key per trait architecture.
key, founder_key, sp1_key, sp2_key = jax.random.split(simulation_parameters["key"], 4)

founder_pop, genetic_map = msprime_pop(
    key=founder_key,
    n_ind=simulation_parameters["n_founder_ind"],
    n_loci_per_chr=simulation_parameters["n_loci_per_chr"],
    n_chr=simulation_parameters["n_chr"],
)

# Arguments that are identical for both trait definitions below.
_shared_trait_args = dict(
    founder_pop=founder_pop,
    n_qtl_per_chr=simulation_parameters["n_qtl_per_chr"],
    mean=simulation_parameters["trait_mean"],
    var=simulation_parameters["trait_var"],
)

# sp1: trait added with gamma=False (used by the "Normal" scenarios).
sp1 = add_trait_a(
    key=sp1_key,
    sim_param=SimParam.from_founder_pop(founder_pop, genetic_map),
    gamma=False,
    **_shared_trait_args,
)

# sp2: trait added with gamma=True, shape=0.4 (used by the "Gamma" scenarios).
sp2 = add_trait_a(
    key=sp2_key,
    sim_param=SimParam.from_founder_pop(founder_pop, genetic_map),
    gamma=True,
    shape=0.4,
    **_shared_trait_args,
)

# --- Define Scenarios ---
# Each template pairs a trait architecture with a plot colour; every template
# is combined with every heritability value below.
scenario_templates = [
    {"name_base": "Normal", "sp": sp1, "color": "#0077BB"},
    {"name_base": "Gamma", "sp": sp2, "color": "#009988"},
]
h2_values_to_test = [0.1, 0.6]
h2_linestyles = {0.1: ':', 0.6: '-'}


def _make_scenario(template, h2):
    """Build the scenario record for one (template, heritability) pair."""
    scenario_params = SimParams(
        sp=template['sp'],
        heritability=h2,
        n_generations=simulation_parameters["n_generations"],
        population_size=simulation_parameters["population_size"],
        n_select=simulation_parameters["n_select"],
    )
    return {
        "name": f"{template['name_base']}, h2={h2}",
        "params": scenario_params,
        "color": template['color'],
        "linestyle": h2_linestyles[h2],
    }


# One scenario per template/heritability combination, templates outermost.
scenarios = []
for template in scenario_templates:
    for h2 in h2_values_to_test:
        scenarios.append(_make_scenario(template, h2))

# --- Run Replicated Simulations in Parallel ---
print("--- Running Replicated Simulations in Parallel ---")

# Derive one PRNG key per replicate and keep a fresh `key` for later use.
# `replicate_keys` stacks the per-replicate keys along a new leading axis,
# which is the axis `jax.vmap` maps over below.
key, *rep_keys = jax.random.split(key, simulation_parameters["n_replicates"] + 1)
replicate_keys = jnp.stack(rep_keys)

# `jax.vmap` vectorizes the simulation over replicates: a single call
# evaluates every replicate as one batched computation.
# `in_axes=(0, None, None)` means:
#   - map over the first axis of `replicate_keys`;
#   - pass `founder_pop` and the scenario's params unchanged to every replicate.
# The wrapper does not depend on the scenario, so it is built once here
# instead of being re-created on every loop iteration.
vmapped_simulation = jax.vmap(run_simulation_scan, in_axes=(0, None, None))

# Maps scenario name -> {'means', 'variances'}; each array has one row per
# replicate and one column per recorded generation.
results = {}
for scenario in scenarios:
    print(f"  Running scenario: {scenario['name']}")

    # Run all replicates of this scenario in a single batched call.
    all_means, all_variances = vmapped_simulation(
        replicate_keys, founder_pop, scenario["params"]
    )

    results[scenario['name']] = {
        'means': all_means,
        'variances': all_variances
    }

print("--- All simulations complete ---")