In [None]:
import jax
from jax import debug
import jax.numpy as jnp
from functools import partial
from chewc.population import Population, msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_bv, set_pheno
# from chewc.cross import make_cross

# ==================================================
# --- Simulation parameters ---
# ==================================================
# All tunable knobs for this run, looked up by name in the cells below.
# NOTE(review): n_select (20) exceeds n_founder_ind (10); the top-n_select
# slice used for selection therefore keeps every founder, so the founder
# generation is effectively not truncated -- confirm this is intended.
simulation_parameters = dict(
    n_replicates=5,
    n_founder_ind=10,
    n_loci_per_chr=1000,
    n_chr=5,
    n_qtl_per_chr=100,
    trait_mean=jnp.array([0.0]),
    trait_var=jnp.array([1.0]),
    n_generations=5,
    population_size=200,
    n_select=20,
    key=jax.random.PRNGKey(42),
    h2=jnp.array([.3]),
)

# ==================================================
# --- Founder population and trait definition ---
# ==================================================
# Three children of the root key: `key` is kept for later cells (split into
# per-replicate keys), `founder_key` drives msprime_pop, and `sp1_key` is
# passed to add_trait_a.
key, founder_key, sp1_key = jax.random.split(simulation_parameters["key"], 3)

founder_pop, genetic_map = msprime_pop(
    key=founder_key,
    n_ind=simulation_parameters["n_founder_ind"],
    n_chr=simulation_parameters["n_chr"],
    n_loci_per_chr=simulation_parameters["n_loci_per_chr"],
)

# Build the base SimParam from the founders and attach one additive trait.
sp = add_trait_a(
    key=sp1_key,
    founder_pop=founder_pop,
    sim_param=SimParam.from_founder_pop(founder_pop, genetic_map),
    n_qtl_per_chr=simulation_parameters["n_qtl_per_chr"],
    mean=simulation_parameters["trait_mean"],
    var=simulation_parameters["trait_var"],
    gamma=False,
)

# ==================================================
# --- Patched `make_cross` with Debugging ---
# ==================================================
# This is a patched version of your library's function.
def make_cross(key, pop, cross_plan, sp, next_id_start):
    """Create one progeny per row of ``cross_plan`` (local patched copy).

    Local replacement for the library ``make_cross``. New public IDs are
    built as ``next_id_start + jnp.arange(n_crosses)``: the array length comes
    from the static ``cross_plan.shape``, while ``next_id_start`` may be a
    traced value (e.g. a ``lax.scan`` carry).

    Args:
        key: PRNG key; split into one key for genotype/IBD recombination and
            one for progeny sex.
        pop: Parent ``Population``; its rows are gathered with the indices in
            ``cross_plan``.
        cross_plan: Integer array of shape (n_crosses, 2). Column 0 holds the
            mothers' row indices into ``pop``, column 1 the fathers'.
        sp: ``SimParam`` supplying ``n_chr``, ``gen_map``, ``recomb_params``
            and ``n_traits``.
        next_id_start: First public ID given to the progeny (Python int or
            traced scalar).

    Returns:
        A new ``Population`` of ``n_crosses`` individuals whose ``pheno``,
        ``bv`` and ``fixEff`` are zero-filled.
    """
    n_crosses = cross_plan.shape[0]  # static; fixes every output length
    key_geno, key_sex = jax.random.split(key)

    mother_iids = cross_plan[:, 0]
    father_iids = cross_plan[:, 1]

    progeny_geno, progeny_ibd = _make_cross_geno(
        key_geno,
        pop.geno[mother_iids], pop.geno[father_iids],
        pop.ibd[mother_iids], pop.ibd[father_iids],
        sp.n_chr, sp.gen_map, sp.recomb_params[0],
    )

    # Static-length arange plus a (possibly traced) offset stays traceable,
    # whereas an arange whose bounds depend on a traced value would not.
    new_public_ids = next_id_start + jnp.arange(n_crosses)

    # NOTE(review): the progeny generation is read from the first mother
    # only, which assumes every parent in the plan belongs to one generation.
    progeny_gen = pop.gen[mother_iids[0]] + 1

    return Population(
        geno=progeny_geno,
        ibd=progeny_ibd,
        id=new_public_ids,
        iid=jnp.arange(n_crosses),
        mother=pop.id[mother_iids],
        father=pop.id[father_iids],
        sex=jax.random.choice(key_sex, jnp.array([0, 1], dtype=jnp.int8), (n_crosses,)),
        gen=jnp.full((n_crosses,), progeny_gen, dtype=jnp.int32),
        pheno=jnp.zeros((n_crosses, sp.n_traits)),
        fixEff=jnp.zeros(n_crosses, dtype=jnp.float32),
        bv=jnp.zeros((n_crosses, sp.n_traits)),
    )

# You need to also import the internal `_make_cross_geno` for the patch to work
from chewc.cross import _make_cross_geno

# ==================================================
# ---  Corrected `lax.scan` with Patched Function ---
# ==================================================

def _simulation_generation_step(carry, _, sp, heritability, population_size, n_select):
    """One ``lax.scan`` step: score, select and mate the current generation.

    Args:
        carry: ``(current_pop, next_id_start, key)`` threaded between steps.
        _: Unused per-step input supplied by ``lax.scan``.
        sp: ``SimParam`` holding ``traits`` and ``ploidy``.
        heritability: Scalar heritability, wrapped in a length-1 array for
            ``set_pheno``.
        population_size: Number of crosses (progeny) to produce.
        n_select: Number of top-phenotype individuals kept as candidate parents.

    Returns:
        ``(next_carry, metrics)``: ``next_carry`` is
        ``(progeny_pop, next_id_start + population_size, key)`` and ``metrics``
        is ``(mean(bv), var(bv))`` of the incoming population, measured before
        selection.
    """
    pop, id_offset, rng = carry
    rng, pheno_key, cross_key, mate_key = jax.random.split(rng, 4)
    mother_key, father_key = jax.random.split(mate_key)

    debug.print("\n--- Running Generation Step ---")
    debug.print("Carry-in next_id_start: {}", id_offset)

    # Breeding-value summary of the generation entering this step.
    pop = set_bv(pop, sp.traits, sp.ploidy)
    bv_stats = (jnp.mean(pop.bv), jnp.var(pop.bv))

    # Truncation selection on the first trait's phenotype.
    pop = set_pheno(pheno_key, pop, sp.traits, sp.ploidy, jnp.array([heritability]))
    parents = jnp.argsort(pop.pheno[:, 0])[-n_select:]

    # Random mating with replacement: column 0 = mothers, column 1 = fathers.
    cross_plan = jnp.stack(
        [jax.random.choice(k, parents, shape=(population_size,)) for k in (mother_key, father_key)],
        axis=1,
    )

    offspring = make_cross(
        key=cross_key, pop=pop, cross_plan=cross_plan,
        sp=sp, next_id_start=id_offset,
    )

    next_offset = id_offset + population_size
    debug.print("Carry-out next_id_start: {}", next_offset)
    return (offspring, next_offset, rng), bv_stats

# The main runner function remains the same, it will just use our patched step function
def run_simulation_scan(key, founder_pop, sp, heritability, n_generations, population_size, n_select):
    """Run one replicate of truncation selection and return BV statistics.

    The founder generation is handled outside ``lax.scan`` because its size
    (``founder_pop``) differs from the ``population_size``-sized generations
    produced afterwards, and a scan carry must keep a fixed shape.

    Args:
        key: PRNG key for this replicate.
        founder_pop: Founder ``Population``. Presumably must be concrete (not
            vmapped/jitted) because ``int(founder_pop.nInd)`` is evaluated.
        sp: ``SimParam`` with the trait definitions.
        heritability: Scalar heritability forwarded to ``set_pheno``.
        n_generations: Number of ``lax.scan`` steps after the first cross.
        population_size: Progeny produced per generation.
        n_select: Top-phenotype individuals kept as candidate parents.

    Returns:
        ``(all_means, all_variances)``: 1-D arrays of length
        ``n_generations + 1`` with mean/variance of ``bv`` for the founders
        followed by each scanned generation.
    """
    # Independent streams. make_cross splits its key exactly like the
    # mate-selection split below, so sharing one key would correlate
    # recombination/sex draws with parent choice.
    cross_key, pheno_key, mate_key, scan_key = jax.random.split(key, 4)

    # --- Generation 0: founders ---
    pop_g0 = set_bv(founder_pop, sp.traits, sp.ploidy)
    g0_metrics = (jnp.mean(pop_g0.bv), jnp.var(pop_g0.bv))

    pop_g0_pheno = set_pheno(pheno_key, pop_g0, sp.traits, sp.ploidy, jnp.array([heritability]))
    selected_indices_g0 = jnp.argsort(pop_g0_pheno.pheno[:, 0])[-n_select:]
    m1_key_g0, m2_key_g0 = jax.random.split(mate_key)
    mothers_g0 = jax.random.choice(m1_key_g0, selected_indices_g0, shape=(population_size,))
    fathers_g0 = jax.random.choice(m2_key_g0, selected_indices_g0, shape=(population_size,))
    cross_plan_g0 = jnp.stack([mothers_g0, fathers_g0], axis=1)

    # Progeny public IDs continue after the founders' count.
    n_founders = int(founder_pop.nInd)
    pop_g1 = make_cross(
        key=cross_key, pop=pop_g0_pheno, cross_plan=cross_plan_g0, sp=sp,
        next_id_start=n_founders,
    )

    # --- Generations 1..n_generations ---
    step = partial(
        _simulation_generation_step,
        sp=sp, heritability=heritability,
        population_size=population_size, n_select=n_select,
    )
    initial_carry = (pop_g1, n_founders + population_size, scan_key)
    _, (scan_means, scan_variances) = jax.lax.scan(
        step, initial_carry, None, length=n_generations
    )

    all_means = jnp.hstack([g0_metrics[0], scan_means])
    all_variances = jnp.hstack([g0_metrics[1], scan_variances])
    return all_means, all_variances

# --- Vectorization and execution ---
# This code is not at the top of its cell, where a %%time cell magic would
# have to go, so the call is timed explicitly instead.
import time

print(">>> Compiling and running vmapped simulation with lax.scan...")
n_replicates = simulation_parameters["n_replicates"]
keys = jax.random.split(key, n_replicates)

# Map over the replicate keys only; all other arguments are shared.
batched_scan_simulation = jax.vmap(
    run_simulation_scan,
    in_axes=(0, None, None, None, None, None, None)
)

# jax.debug.print (used in _simulation_generation_step) also works under
# vmap/scan/jit, so the per-generation trace prints during this call.
t_start = time.perf_counter()
means, variances = batched_scan_simulation(
    keys,
    founder_pop,
    sp,
    simulation_parameters['h2'][0],
    simulation_parameters['n_generations'],
    simulation_parameters['population_size'],
    simulation_parameters['n_select']
)
# JAX dispatches asynchronously; wait for the results so the timing is real.
means, variances = jax.block_until_ready((means, variances))
elapsed_s = time.perf_counter() - t_start

print("... Vmapped simulation successful!")
print(f"Wall time (compile + run): {elapsed_s:.2f} s")
print("Batched means shape:", means.shape)
print("Batched variances shape:", variances.shape)

# One row per replicate; founders plus each scanned generation.
expected_shape = (n_replicates, simulation_parameters['n_generations'] + 1)
assert means.shape == expected_shape, (means.shape, expected_shape)
assert variances.shape == expected_shape, (variances.shape, expected_shape)



In [12]:
import jax
from jax import debug
import jax.numpy as jnp
from functools import partial
from chewc.population import Population, msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_bv, set_pheno
from chewc.cross import make_cross, _make_cross_geno

# ==================================================
# --- Simulation parameters ---
# ==================================================
# Knobs for the heritability sweep; `h2_values` lists one heritability per
# experiment.
# NOTE(review): n_select (20) exceeds n_founder_ind (10); the top-n_select
# slice used for selection therefore keeps every founder, so the founder
# generation is effectively not truncated -- confirm this is intended.
simulation_parameters = dict(
    n_replicates=5,
    n_founder_ind=10,
    n_loci_per_chr=1000,
    n_chr=5,
    n_qtl_per_chr=100,
    trait_mean=jnp.array([0.0]),
    trait_var=jnp.array([1.0]),
    n_generations=50,
    population_size=200,
    n_select=20,
    key=jax.random.PRNGKey(42),
    h2_values=jnp.array([0.3, 0.5, 0.7, 0.9]),
)

# ==================================================
# --- Founder population and trait definition ---
# ==================================================
# Three children of the root key: `key` is kept for the execution cell,
# `founder_key` drives msprime_pop, and `sp1_key` is passed to add_trait_a.
key, founder_key, sp1_key = jax.random.split(simulation_parameters["key"], 3)

founder_pop, genetic_map = msprime_pop(
    key=founder_key,
    n_ind=simulation_parameters["n_founder_ind"],
    n_chr=simulation_parameters["n_chr"],
    n_loci_per_chr=simulation_parameters["n_loci_per_chr"],
)

# Build the base SimParam from the founders and attach one additive trait.
sp = add_trait_a(
    key=sp1_key,
    founder_pop=founder_pop,
    sim_param=SimParam.from_founder_pop(founder_pop, genetic_map),
    n_qtl_per_chr=simulation_parameters["n_qtl_per_chr"],
    mean=simulation_parameters["trait_mean"],
    var=simulation_parameters["trait_var"],
    gamma=False,
)

# ==================================================
# --- Patched `make_cross` with Debugging (Unchanged) ---
# ==================================================
def make_cross_debug(key, pop, cross_plan, sp, next_id_start):
    """Produce one offspring per (mother, father) row of ``cross_plan``.

    Local copy of the library cross. Public IDs are a static-length
    ``jnp.arange`` shifted by ``next_id_start``, so the offset may be a traced
    ``lax.scan`` carry. Despite the name, this version prints nothing.

    Args:
        key: PRNG key; split into a recombination key and a sex key.
        pop: Parent ``Population`` whose rows are gathered via ``cross_plan``.
        cross_plan: (n_crosses, 2) integer array; column 0 = mother rows,
            column 1 = father rows.
        sp: ``SimParam`` supplying ``n_chr``, ``gen_map``, ``recomb_params``
            and ``n_traits``.
        next_id_start: First public ID assigned to the offspring.

    Returns:
        A ``Population`` of ``n_crosses`` offspring with zeroed ``pheno``,
        ``bv`` and ``fixEff``.
    """
    geno_key, sex_key = jax.random.split(key)
    n_progeny = cross_plan.shape[0]
    dam_rows, sire_rows = cross_plan[:, 0], cross_plan[:, 1]

    progeny_geno, progeny_ibd = _make_cross_geno(
        geno_key,
        pop.geno[dam_rows],
        pop.geno[sire_rows],
        pop.ibd[dam_rows],
        pop.ibd[sire_rows],
        sp.n_chr,
        sp.gen_map,
        sp.recomb_params[0],
    )

    # Generation is taken from the first mother (assumes a single parent generation).
    offspring_gen = pop.gen[dam_rows[0]] + 1
    trait_shape = (n_progeny, sp.n_traits)

    return Population(
        geno=progeny_geno,
        ibd=progeny_ibd,
        id=next_id_start + jnp.arange(n_progeny),
        iid=jnp.arange(n_progeny),
        mother=pop.id[dam_rows],
        father=pop.id[sire_rows],
        sex=jax.random.choice(sex_key, jnp.array([0, 1], dtype=jnp.int8), (n_progeny,)),
        gen=jnp.full((n_progeny,), offspring_gen, dtype=jnp.int32),
        pheno=jnp.zeros(trait_shape),
        fixEff=jnp.zeros(n_progeny, dtype=jnp.float32),
        bv=jnp.zeros(trait_shape),
    )

# ==================================================
# --- Corrected `lax.scan` with Patched Function (Unchanged) ---
# ==================================================
def _simulation_generation_step(carry, _, sp, heritability, population_size, n_select):
    """Advance one generation inside ``lax.scan`` (no debug output).

    Args:
        carry: ``(current_pop, next_id_start, key)``.
        _: Unused per-step input from ``lax.scan``.
        sp: ``SimParam`` holding ``traits`` and ``ploidy``.
        heritability: Scalar heritability forwarded to ``set_pheno``.
        population_size: Number of progeny to create.
        n_select: Number of top-phenotype individuals kept as parents.

    Returns:
        ``((progeny_pop, next_id_start + population_size, key), (mean_bv, var_bv))``
        where the statistics describe the incoming population before selection.
    """
    pop, id_offset, rng = carry
    rng, pheno_key, cross_key, mate_key = jax.random.split(rng, 4)
    mother_key, father_key = jax.random.split(mate_key)

    pop = set_bv(pop, sp.traits, sp.ploidy)
    bv_stats = (jnp.mean(pop.bv), jnp.var(pop.bv))

    # Keep the n_select best individuals by the first trait's phenotype.
    pop = set_pheno(pheno_key, pop, sp.traits, sp.ploidy, jnp.array([heritability]))
    parents = jnp.argsort(pop.pheno[:, 0])[-n_select:]

    # Sample mothers and fathers independently, with replacement.
    cross_plan = jnp.stack(
        [jax.random.choice(k, parents, shape=(population_size,)) for k in (mother_key, father_key)],
        axis=1,
    )

    offspring = make_cross_debug(
        key=cross_key, pop=pop, cross_plan=cross_plan,
        sp=sp, next_id_start=id_offset,
    )
    return (offspring, id_offset + population_size, rng), bv_stats

def run_simulation_scan(key, founder_pop, sp, heritability, n_generations, population_size, n_select):
    """Run one replicate of truncation selection and return BV statistics.

    The founder generation is handled outside ``lax.scan`` because its size
    (``founder_pop``) differs from the ``population_size``-sized generations
    produced afterwards, and a scan carry must keep a fixed shape.

    Args:
        key: PRNG key for this replicate.
        founder_pop: Founder ``Population``. Presumably must be concrete (not
            vmapped/jitted) because ``int(founder_pop.nInd)`` is evaluated.
        sp: ``SimParam`` with the trait definitions.
        heritability: Scalar heritability forwarded to ``set_pheno``.
        n_generations: Number of ``lax.scan`` steps after the first cross.
        population_size: Progeny produced per generation.
        n_select: Top-phenotype individuals kept as candidate parents.

    Returns:
        ``(all_means, all_variances)``: 1-D arrays of length
        ``n_generations + 1`` with mean/variance of ``bv`` for the founders
        followed by each scanned generation.
    """
    # Independent streams. make_cross_debug splits its key exactly like the
    # mate-selection split below, so sharing one key would correlate
    # recombination/sex draws with parent choice.
    cross_key, pheno_key, mate_key, scan_key = jax.random.split(key, 4)

    # --- Generation 0: founders ---
    pop_g0 = set_bv(founder_pop, sp.traits, sp.ploidy)
    g0_metrics = (jnp.mean(pop_g0.bv), jnp.var(pop_g0.bv))

    pop_g0_pheno = set_pheno(pheno_key, pop_g0, sp.traits, sp.ploidy, jnp.array([heritability]))
    selected_indices_g0 = jnp.argsort(pop_g0_pheno.pheno[:, 0])[-n_select:]
    m1_key_g0, m2_key_g0 = jax.random.split(mate_key)
    mothers_g0 = jax.random.choice(m1_key_g0, selected_indices_g0, shape=(population_size,))
    fathers_g0 = jax.random.choice(m2_key_g0, selected_indices_g0, shape=(population_size,))
    cross_plan_g0 = jnp.stack([mothers_g0, fathers_g0], axis=1)

    # Progeny public IDs continue after the founders' count.
    n_founders = int(founder_pop.nInd)
    pop_g1 = make_cross_debug(
        key=cross_key, pop=pop_g0_pheno, cross_plan=cross_plan_g0, sp=sp,
        next_id_start=n_founders,
    )

    # --- Generations 1..n_generations ---
    step = partial(
        _simulation_generation_step,
        sp=sp, heritability=heritability,
        population_size=population_size, n_select=n_select,
    )
    initial_carry = (pop_g1, n_founders + population_size, scan_key)
    _, (scan_means, scan_variances) = jax.lax.scan(
        step, initial_carry, None, length=n_generations
    )

    all_means = jnp.hstack([g0_metrics[0], scan_means])
    all_variances = jnp.hstack([g0_metrics[1], scan_variances])
    return all_means, all_variances




In [13]:
%%time


# ==================================================
# --- Vectorization and Execution (Updated) ---
# ==================================================
print(">>> Compiling and running nested vmapped simulation with lax.scan...")

# 1. Inner vmap for replicates (over keys)
# This maps the simulation over the different random keys for replicates.
batched_over_replicates = jax.vmap(
    run_simulation_scan,
    in_axes=(0, None, None, None, None, None, None)
)

# 2. Outer vmap for heritabilities
# This takes the already-batched-over-replicates function and maps it
# over the different heritability values.
run_all_experiments = jax.vmap(
    batched_over_replicates,
    in_axes=(None, None, None, 0, None, None, None)
)

# --- Execute the full experiment ---
n_replicates = simulation_parameters["n_replicates"]
keys = jax.random.split(key, n_replicates)
h2_values = simulation_parameters['h2_values']

means, variances = run_all_experiments(
    keys,
    founder_pop,
    sp,
    h2_values,
    simulation_parameters['n_generations'],
    simulation_parameters['population_size'],
    simulation_parameters['n_select']
)

print("... Nested vmapped simulation successful!")
print("Batched means shape:", means.shape)
print("Batched variances shape:", variances.shape)

# The output shapes will be (number_of_heritabilities, n_replicates, n_generations + 1)
# For this script, it will be (4, 5, 6)

>>> Compiling and running nested vmapped simulation with lax.scan...
... Nested vmapped simulation successful!
Batched means shape: (4, 5, 51)
Batched variances shape: (4, 5, 51)
CPU times: user 39.2 s, sys: 2.12 s, total: 41.3 s
Wall time: 8.1 s
