In [6]:
import jax
from jax import debug
import jax.numpy as jnp
from functools import partial
from chewc.population import Population, msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_bv, set_pheno
from chewc.cross import make_cross
# from chewc.cross import make_cross

# ==================================================
# ---  Simulation Parameters ---
# ==================================================
# All experiment settings in one dict; the code below reads them by key.
simulation_parameters = dict(
    n_replicates=50,             # number of vmapped replicate runs
    n_founder_ind=10,            # n_ind passed to msprime_pop
    n_loci_per_chr=1000,
    n_chr=5,
    n_qtl_per_chr=100,           # passed to add_trait_a
    trait_mean=jnp.array([0.0]),
    trait_var=jnp.array([1.0]),
    n_generations=5,             # lax.scan length after generation 0
    population_size=200,         # rows in each generation's cross plan
    n_select=20,                 # top-phenotype individuals kept as parents
    key=jax.random.PRNGKey(42),  # master PRNG key
    h2=jnp.array([.3]),          # heritability handed to set_pheno
)

# ==================================================
# ---  Setup: founders and trait ---
# ==================================================
# Derive three keys from the master key: `key` is kept for the replicate
# runs further down, the other two are consumed right here.
key, founder_key, sp1_key = jax.random.split(simulation_parameters["key"], 3)

founder_pop, genetic_map = msprime_pop(
    key=founder_key,
    n_ind=simulation_parameters["n_founder_ind"],
    n_loci_per_chr=simulation_parameters["n_loci_per_chr"],
    n_chr=simulation_parameters["n_chr"],
)

# Build SimParam from the founders and immediately attach the trait via
# add_trait_a; only the trait-bearing SimParam is kept as `sp`.
sp = add_trait_a(
    key=sp1_key,
    founder_pop=founder_pop,
    sim_param=SimParam.from_founder_pop(founder_pop, genetic_map),
    n_qtl_per_chr=simulation_parameters["n_qtl_per_chr"],
    mean=simulation_parameters["trait_mean"],
    var=simulation_parameters["trait_var"],
    gamma=False,
)

# ==================================================
# ---  Corrected `lax.scan` with Patched Function ---
# ==================================================

def _simulation_generation_step(carry, _, sp, heritability, population_size, n_select):
    """Advance one generation inside ``lax.scan``: score, select, and cross.

    Args:
        carry: ``(current_pop, next_id_start, key)`` threaded across steps.
        _: Unused per-step scan input (the scan iterates over ``None``).
        sp: Simulation parameters; ``sp.traits`` and ``sp.ploidy`` are read.
        heritability: Scalar heritability, wrapped into a length-1 array for
            ``set_pheno``.
        population_size: Number of rows in the cross plan; the id counter
            advances by this amount.
        n_select: How many top-phenotype individuals are eligible as parents.

    Returns:
        ``(next_carry, metrics)``. ``next_carry`` is ``(offspring, next id,
        key)``. ``metrics`` is ``(mean, variance)`` of the breeding values of
        the population entering this step, computed before selection.
    """
    pop, id_start, rng = carry
    rng, pheno_rng, cross_rng, mate_rng = jax.random.split(rng, 4)
    dam_rng, sire_rng = jax.random.split(mate_rng)

    debug.print("\n--- Running Generation Step ---")
    debug.print("Carry-in next_id_start: {}", id_start)

    # Breeding-value statistics are recorded before any selection.
    pop = set_bv(pop, sp.traits, sp.ploidy)
    bv = pop.bv
    metrics = (jnp.mean(bv), jnp.var(bv))

    # Truncation selection on the first trait's phenotype. argsort is
    # ascending, so the last `n_select` indices are the highest phenotypes.
    pop = set_pheno(pheno_rng, pop, sp.traits, sp.ploidy, jnp.array([heritability]))
    ranking = jnp.argsort(pop.pheno[:, 0])
    parents = ranking[-n_select:]

    # Random mating among selected parents (jax.random.choice samples with
    # replacement by default); one (mother, father) pair per row.
    dams = jax.random.choice(dam_rng, parents, shape=(population_size,))
    sires = jax.random.choice(sire_rng, parents, shape=(population_size,))
    plan = jnp.column_stack((dams, sires))

    offspring = make_cross(
        key=cross_rng, pop=pop, cross_plan=plan, sp=sp, next_id_start=id_start
    )

    id_next = id_start + population_size
    debug.print("Carry-out next_id_start: {}", id_next)
    return (offspring, id_next, rng), metrics

# The main runner function remains the same, it will just use our patched step function
def run_simulation_scan(key, founder_pop, sp, heritability, n_generations, population_size, n_select):
    """Run one replicate of phenotypic truncation selection.

    Generation 0 (the founders) is scored, selected, and crossed outside
    ``lax.scan``. A scan carry must keep a fixed shape, and the founder
    population's size (``n_founder_ind``) presumably differs from
    ``population_size``. The following ``n_generations`` generations are
    driven by ``_simulation_generation_step`` inside the scan.

    Args:
        key: PRNG key for this replicate (the vmapped axis).
        founder_pop: Founder population. ``founder_pop.nInd`` must be concrete
            because it is converted with ``int()``. This holds when the
            population is passed unbatched (``in_axes=None``) under vmap.
        sp: Simulation parameters (``traits``, ``ploidy`` are read).
        heritability: Scalar heritability used for phenotypes.
        n_generations: Number of scanned generations after generation 0.
        population_size: Number of crosses planned per generation.
        n_select: Number of top-phenotype individuals used as parents.

    Returns:
        ``(all_means, all_variances)``, each of length ``n_generations + 1``.
        Each holds the breeding-value mean/variance for generation 0 followed
        by one entry per scanned generation.
    """
    # Independent streams: generation-0 phenotypes, generation-0 mating and
    # crossing, and the scan. The fourth (leading) key is not needed.
    _, g0_key, g1_key, scan_key = jax.random.split(key, 4)
    # Give mate sampling and make_cross their own keys. Reusing g1_key for
    # both would correlate parent choice with the cross randomness. This
    # mirrors the key layout of _simulation_generation_step.
    mate_key_g0, cross_key_g0 = jax.random.split(g1_key)
    m1_key_g0, m2_key_g0 = jax.random.split(mate_key_g0)

    pop_g0 = set_bv(founder_pop, sp.traits, sp.ploidy)
    g0_metrics = (jnp.mean(pop_g0.bv), jnp.var(pop_g0.bv))

    # Truncation selection on the first trait: the last n_select indices of
    # an ascending argsort are the highest phenotypes.
    pop_g0_pheno = set_pheno(g0_key, pop_g0, sp.traits, sp.ploidy, jnp.array([heritability]))
    selected_indices_g0 = jnp.argsort(pop_g0_pheno.pheno[:, 0])[-n_select:]
    mothers_g0 = jax.random.choice(m1_key_g0, selected_indices_g0, shape=(population_size,))
    fathers_g0 = jax.random.choice(m2_key_g0, selected_indices_g0, shape=(population_size,))
    cross_plan_g0 = jnp.stack([mothers_g0, fathers_g0], axis=1)

    # New individual ids start right after the founders.
    n_founders = int(founder_pop.nInd)
    pop_g1 = make_cross(
        key=cross_key_g0, pop=pop_g0_pheno, cross_plan=cross_plan_g0, sp=sp,
        next_id_start=n_founders
    )

    initial_carry = (pop_g1, n_founders + population_size, scan_key)

    # Bind the per-run constants; lax.scan supplies (carry, x) positionally.
    generation_func = partial(
        _simulation_generation_step,
        sp=sp,
        heritability=heritability,
        population_size=population_size,
        n_select=n_select,
    )

    _final_carry, scan_metrics = jax.lax.scan(
        generation_func, initial_carry, None, length=n_generations
    )

    all_means = jnp.hstack([g0_metrics[0], scan_metrics[0]])
    all_variances = jnp.hstack([g0_metrics[1], scan_metrics[1]])

    return all_means, all_variances


# ==================================================
# ---  Vectorization and Execution ---
# ==================================================
print(">>> Compiling and running vmapped simulation with lax.scan...")
n_replicates = simulation_parameters["n_replicates"]
# One PRNG key per replicate, derived from the key left over after setup.
keys = jax.random.split(key, n_replicates)

# Map over the replicate keys only. The remaining six arguments are shared
# by every replicate (in_axes=None); run_simulation_scan calls
# int(founder_pop.nInd), which relies on founder_pop not being batched.
batched_scan_simulation = jax.vmap(
    run_simulation_scan,
    in_axes=(0,) + (None,) * 6,
)

# The jax.debug.print calls in the scan step work under vmap/scan, so no
# special non-JIT run is needed to see them.
means, variances = batched_scan_simulation(
    keys,
    founder_pop,
    sp,
    simulation_parameters["h2"][0],
    simulation_parameters["n_generations"],
    simulation_parameters["population_size"],
    simulation_parameters["n_select"],
)

print("... Vmapped simulation successful!")
print("Batched means shape:", means.shape)
print("Batched variances shape:", variances.shape)

>>> Compiling and running vmapped simulation with lax.scan...

--- Running Generation Step ---
Carry-in next_id_start: 210
Carry-out next_id_start: 410

--- Running Generation Step ---
Carry-in next_id_start: 410
Carry-out next_id_start: 610

--- Running Generation Step ---
Carry-in next_id_start: 610
Carry-out next_id_start: 810

--- Running Generation Step ---
Carry-in next_id_start: 810
Carry-out next_id_start: 1010

--- Running Generation Step ---
Carry-in next_id_start: 1010
Carry-out next_id_start: 1210
... Vmapped simulation successful!
Batched means shape: (50, 6)
Batched variances shape: (50, 6)
