In [1]:
import jax
import jax.numpy as jnp
from chewc.population import quick_haplo
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno
from chewc.cross import make_cross
from chewc.population import combine_populations, Population
import matplotlib.pyplot as plt
from chewc.population import quick_haplo, combine_populations, Population, subset_population # Add subset_population
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from chewc.population import msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno, set_bv
from chewc.cross import make_cross


In [2]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from chewc.population import msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno, set_bv
from chewc.cross import make_cross
from functools import partial

# ==================================================
# ---  Simulation Parameters (Unchanged) ---
# ==================================================
simulation_parameters = {
    "n_replicates": 3,
    "n_founder_ind": 10,
    "n_loci_per_chr": 1000,
    "n_chr": 5,
    "n_qtl_per_chr": 100,
    "trait_mean": jnp.array([0.0]),
    "trait_var": jnp.array([1.0]),
    "n_generations": 50,
    "population_size": 200,
    "n_select": 20,
    "key": jax.random.PRNGKey(42),
    "h2": jnp.array([.3]),
}

# ==================================================
# ---  Setup (Largely Unchanged) ---
# ==================================================
key = simulation_parameters["key"]
key, founder_key, sp1_key, sp2_key = jax.random.split(key, 4)

founder_pop, genetic_map = msprime_pop(
    key=founder_key,
    n_ind=simulation_parameters["n_founder_ind"],
    n_loci_per_chr=simulation_parameters["n_loci_per_chr"],
    n_chr=simulation_parameters["n_chr"]
)

# Create SimParam objects for scenarios
sp = SimParam.from_founder_pop(founder_pop, genetic_map)
sp = add_trait_a(
    key=sp1_key, founder_pop=founder_pop, sim_param=sp,
    n_qtl_per_chr=simulation_parameters["n_qtl_per_chr"],
    mean=simulation_parameters["trait_mean"],
    var=simulation_parameters["trait_var"],
    gamma=False
)

def run_simulation(key, founder_pop, sp, heritability, params):
    """Runs a breeding simulation for a given number of generations."""
    genetic_means = []
    genetic_variances = []
    current_pop = founder_pop
    next_id_start = current_pop.nInd
    n_generations = params["n_generations"]
    population_size = params["population_size"]
    n_select = params["n_select"]
    pheno_key, select_key, mate_key = jax.random.split(key, 3)

    for gen in range(n_generations + 1):
        p_subkey, pheno_key = jax.random.split(pheno_key)
        s_subkey, select_key = jax.random.split(select_key)
        m_subkey1, m_subkey2, mate_key = jax.random.split(mate_key, 3)
        current_pop = set_bv(current_pop, sp.traits, sp.ploidy)
        genetic_means.append(jnp.mean(current_pop.bv))
        genetic_variances.append(jnp.var(current_pop.bv))
        current_pop = set_pheno(
            key=p_subkey, pop=current_pop, traits=sp.traits,
            ploidy=sp.ploidy, h2=jnp.array([heritability])
        )
        selected_indices = jnp.argsort(current_pop.pheno[:, 0])[-n_select:]
        mothers = jax.random.choice(m_subkey1, selected_indices, shape=(population_size,))
        fathers = jax.random.choice(m_subkey2, selected_indices, shape=(population_size,))
        cross_plan = jnp.stack([mothers, fathers], axis=1)
        current_pop = make_cross(
            key=s_subkey, pop=current_pop, cross_plan=cross_plan,
            sp=sp, next_id_start=next_id_start
        )
        next_id_start += population_size
    return genetic_means, genetic_variances




In [3]:
# # -------------------------
# # 1) EAGER inspector helper
# # -------------------------
# def inspect_obj(name, obj):
#     print("=== INSPECT:", name, "===")
#     print(" type:", type(obj))
#     # common attributes for Population / SimParam we care about
#     for attr in ("geno", "ibd", "id", "iid", "mother", "father", "pheno", "bv"):
#         if hasattr(obj, attr):
#             val = getattr(obj, attr)
#             print(f"  {attr}: type={type(val)}, shape={getattr(val, 'shape', None)}, dtype={getattr(val, 'dtype', None)}")
#     # SimParam-specific
#     if hasattr(obj, "gen_map"):
#         gm = getattr(obj, "gen_map")
#         print("  gen_map:", type(gm), getattr(gm, "shape", None), getattr(gm, "dtype", None))
#     if hasattr(obj, "traits"):
#         tr = getattr(obj, "traits")
#         print("  traits type:", type(tr))
#         if tr is not None and hasattr(tr, "add_eff"):
#             print("    add_eff.shape:", getattr(tr.add_eff, "shape", None))
#             print("    loci_loc.shape:", getattr(tr.loci_loc, "shape", None))
#     print("  repr (short):", repr(obj)[:200])
#     print("========================\n")


# # -------------------------------------------------------
# # 2) EAGER debug run (no jit/vmap) - use this first.
# # -------------------------------------------------------
# def debug_run_simulation_nojit(key, founder_pop, sp, heritability, n_generations, population_size, n_select):
#     print(">>> Running debug_run_simulation_nojit (eager)")
#     # Eager inspections
#     inspect_obj("founder_pop", founder_pop)
#     inspect_obj("sp", sp)
#     print("heritiability (type):", type(heritability), heritability)

#     genetic_means = []
#     genetic_variances = []
#     current_pop = founder_pop
#     next_id_start = int(current_pop.nInd)
#     pheno_key, select_key, mate_key = jax.random.split(key, 3)

#     for gen in range(n_generations + 1):
#         print(f"\n[EAGER] generation {gen}/{n_generations}")
#         p_subkey, pheno_key = jax.random.split(pheno_key)
#         s_subkey, select_key = jax.random.split(select_key)
#         m_subkey1, m_subkey2, mate_key = jax.random.split(mate_key, 3)

#         # print current_pop summary before BV/PHENO
#         print(" current_pop before set_bv: type", type(current_pop))
#         print("  geno shape:", getattr(current_pop.geno, "shape", None), "dtype:", getattr(current_pop.geno, "dtype", None))

#         current_pop = set_bv(current_pop, sp.traits, sp.ploidy)
#         print("  after set_bv, bv shape:", getattr(current_pop.bv, "shape", None))
#         genetic_means.append(jnp.mean(current_pop.bv))
#         genetic_variances.append(jnp.var(current_pop.bv))

#         current_pop = set_pheno(key=p_subkey, pop=current_pop, traits=sp.traits, ploidy=sp.ploidy, h2=jnp.array([heritability]))
#         print("  after set_pheno, pheno.shape:", getattr(current_pop.pheno, "shape", None))

#         # selection
#         selected_indices = jnp.argsort(current_pop.pheno[:, 0])[-n_select:]
#         print("  selected_indices type:", type(selected_indices), "shape:", selected_indices.shape)

#         # mothers/fathers -- debug their contents
#         mothers = jax.random.choice(m_subkey1, selected_indices, shape=(population_size,))
#         fathers = jax.random.choice(m_subkey2, selected_indices, shape=(population_size,))
#         print("  mothers.shape:", mothers.shape, "dtype:", mothers.dtype)
#         print("  fathers.shape:", fathers.shape, "dtype:", fathers.dtype)

#         cross_plan = jnp.stack([mothers, fathers], axis=1)
#         print("  cross_plan.shape:", cross_plan.shape)

#         # Debug before calling make_cross
#         print("  Calling make_cross(...)")
#         tmp_pop = make_cross(key=s_subkey, pop=current_pop, cross_plan=cross_plan, sp=sp, next_id_start=next_id_start)
#         print("  make_cross returned type:", type(tmp_pop))
#         print("   tmp_pop.geno.shape:", getattr(tmp_pop.geno, "shape", None))
#         # replace current_pop
#         current_pop = tmp_pop
#         next_id_start += population_size

#     return jnp.stack(genetic_means), jnp.stack(genetic_variances)


# # -------------------------------------------------------
# # 3) VMAP-aware function: use jax.debug.print for traced prints
# # -------------------------------------------------------
# # Note: we accept n_generations, population_size, n_select as *python* ints
# # so they are static from JAX's point of view (recommended).
# def run_simulation_vmapped(key, founder_pop, sp, heritability, n_generations, population_size, n_select):
#     # Split RNGs used repeatedly
#     pheno_key, select_key, mate_key = jax.random.split(key, 3)

#     genetic_means = []
#     genetic_variances = []
#     current_pop = founder_pop
#     next_id_start = current_pop.nInd  # should be an int

#     # Debug prints that are safe under JAX tracing:
#     jax.debug.print("TRACE vmapped start: founder_pop.geno.shape={gshape}", gshape=current_pop.geno.shape)
#     jax.debug.print("TRACE vmapped start: sp.gen_map.shape={gm}", gm=sp.gen_map.shape)
#     if sp.traits is not None:
#         jax.debug.print("TRACE traits.add_eff.shape={a}", a=sp.traits.add_eff.shape)

#     for gen in range(n_generations + 1):
#         p_subkey, pheno_key = jax.random.split(pheno_key)
#         s_subkey, select_key = jax.random.split(select_key)
#         m_subkey1, m_subkey2, mate_key = jax.random.split(mate_key, 3)

#         current_pop = set_bv(current_pop, sp.traits, sp.ploidy)
#         jax.debug.print("TRACE gen={gen}: bv.shape={bshape}", gen=gen, bshape=current_pop.bv.shape)
#         genetic_means.append(jnp.mean(current_pop.bv))
#         genetic_variances.append(jnp.var(current_pop.bv))

#         current_pop = set_pheno(key=p_subkey, pop=current_pop, traits=sp.traits, ploidy=sp.ploidy, h2=jnp.array([heritability]))
#         jax.debug.print("TRACE gen={gen}: pheno.shape={pshape}", gen=gen, pshape=current_pop.pheno.shape)

#         selected_indices = jnp.argsort(current_pop.pheno[:, 0])[-n_select:]
#         jax.debug.print("TRACE gen={gen}: selected_indices.shape={s}", gen=gen, s=selected_indices.shape)

#         mothers = jax.random.choice(m_subkey1, selected_indices, shape=(population_size,))
#         fathers = jax.random.choice(m_subkey2, selected_indices, shape=(population_size,))
#         jax.debug.print("TRACE gen={gen}: mothers.shape={ms} fathers.shape={fs}", gen=gen, ms=mothers.shape, fs=fathers.shape)

#         cross_plan = jnp.stack([mothers, fathers], axis=1)
#         jax.debug.print("TRACE gen={gen}: cross_plan.shape={cp}", gen=gen, cp=cross_plan.shape)

#         # Call make_cross (this is probably where the error appears)
#         new_pop = make_cross(key=s_subkey, pop=current_pop, cross_plan=cross_plan, sp=sp, next_id_start=next_id_start)
#         # Immediately print shapes of the new population's arrays
#         jax.debug.print("TRACE gen={gen}: new_pop.geno.shape={g}", gen=gen, g=new_pop.geno.shape)
#         current_pop = new_pop
#         next_id_start += population_size

#     # stack results so batched output has shape (n_reps, n_generations+1)
#     return jnp.stack(genetic_means), jnp.stack(genetic_variances)


# # -------------------------------------------------------
# # 4) How to call (suggested debugging order)
# # -------------------------------------------------------
# # Example (do these steps interactively):
# # 1) EAGER inspect before using vmap:
# inspect_obj("founder_pop", founder_pop)
# inspect_obj("sp", sp)
# #
# # 2) Run one replicate eagerly to see python-level prints:
# k = jax.random.PRNGKey(123)
# means, vars_ = debug_run_simulation_nojit(k, founder_pop, sp, 0.3, 
#                                         n_generations=5, population_size=50, n_select=5)
# #
# # 3) If that works, try the vmapped version (this will produce traced prints):
# n_reps = 4
# keys = jax.random.split(jax.random.PRNGKey(123), n_reps)
# herits = jnp.array([0.2, 0.3, 0.5, 0.7])
# batched = jax.vmap(run_simulation_vmapped, in_axes=(0, None, None, 0, None, None, None), out_axes=(0,0))
# means_b, vars_b = batched(keys, founder_pop, sp, herits, 5, 50, 5)
# #
# # Notes:
# # - If you see `type=object` for any `.geno`/`.ibd` etc in the eager inspector,
# #   that is the immediate cause of `AttributeError: 'object' object has no attribute 'shape'`.
# # - The `make_cross(...)` call is the most likely place where non-array inputs slip in;
# #   pay attention to the debug print immediately after the call in the eager run and
# #   the jax.debug.print in the vmapped run.


In [4]:
from functools import partial # Make sure this is imported

def _simulation_generation_step(carry, _):
    """Logic for a single generation, designed for lax.scan."""
    current_pop, next_id_start, key, sp, heritability, population_size, n_select = carry

    # Split keys
    p_key, s_key, m_key, next_loop_key = jax.random.split(key, 4)
    m1_key, m2_key = jax.random.split(m_key)

    # Core logic
    current_pop = set_bv(current_pop, sp.traits, sp.ploidy)
    current_pop = set_pheno(p_key, current_pop, sp.traits, sp.ploidy, jnp.array([heritability]))

    # Selection & Mating
    selected_indices = jnp.argsort(current_pop.pheno[:, 0])[-n_select:]
    mothers = jax.random.choice(m1_key, selected_indices, shape=(population_size,))
    fathers = jax.random.choice(m2_key, selected_indices, shape=(population_size,))
    cross_plan = jnp.stack([mothers, fathers], axis=1)
    
    # Create next generation
    new_pop = make_cross(s_key, current_pop, cross_plan, sp, next_id_start)

    # Collect metrics & prepare state for next iteration
    metrics = (jnp.mean(current_pop.bv), jnp.var(current_pop.bv))
    next_carry = (new_pop, next_id_start + population_size, next_loop_key, sp, heritability, population_size, n_select)
    
    return next_carry, metrics

@partial(jax.jit, static_argnames=("n_generations", "population_size", "n_select"))
def run_simulation_scan(key, founder_pop, sp, heritability, n_generations, population_size, n_select):
    """JAX-native simulation runner using lax.scan."""
    initial_carry = (founder_pop, int(founder_pop.nInd), key, sp, heritability, population_size, n_select)
    
    final_carry, all_metrics = jax.lax.scan(
        _simulation_generation_step, initial_carry, None, length=n_generations + 1
    )
    
    genetic_means, genetic_variances = all_metrics
    return genetic_means, genetic_variances

# --- Your final calling code ---
n_reps = 4
keys = jax.random.split(jax.random.PRNGKey(123), n_reps)
herits = jnp.array([0.2, 0.3, 0.5, 0.7])

# Vmap over the NEW scan-based function
batched = jax.vmap(
    run_simulation_scan,  # <-- This is the one to use
    in_axes=(0, None, None, 0, None, None, None), 
    out_axes=(0, 0)
)

print(">>> Running vmapped simulation with lax.scan...")
means_b, vars_b = batched(keys, founder_pop, sp, herits, 5, 50, 5)
print("... vmapped simulation successful!")

>>> Running vmapped simulation with lax.scan...


ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>, Traced<~int32[]>with<DynamicJaxprTrace>. The error was:
TypeError: unhashable type: 'DynamicJaxprTracer'


In [7]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from chewc.population import msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno, set_bv
from chewc.cross import make_cross
from functools import partial

# ==================================================
# ---  Simulation Parameters (Unchanged) ---
# ==================================================
simulation_parameters = {
    "n_replicates": 50,  # Increased replicates for a smoother plot
    "n_founder_ind": 10,
    "n_loci_per_chr": 1000,
    "n_chr": 5,
    "n_qtl_per_chr": 100,
    "trait_mean": jnp.array([0.0]),
    "trait_var": jnp.array([1.0]),
    "n_generations": 50,
    "population_size": 200,
    "n_select": 20,
    "key": jax.random.PRNGKey(42),
    "h2": jnp.array([.3]),
}

# ==================================================
# ---  Setup (Unchanged) ---
# ==================================================
key = simulation_parameters["key"]
key, founder_key, sp1_key = jax.random.split(key, 3)

founder_pop, genetic_map = msprime_pop(
    key=founder_key,
    n_ind=simulation_parameters["n_founder_ind"],
    n_loci_per_chr=simulation_parameters["n_loci_per_chr"],
    n_chr=simulation_parameters["n_chr"]
)

sp = SimParam.from_founder_pop(founder_pop, genetic_map)
sp = add_trait_a(
    key=sp1_key, founder_pop=founder_pop, sim_param=sp,
    n_qtl_per_chr=simulation_parameters["n_qtl_per_chr"],
    mean=simulation_parameters["trait_mean"],
    var=simulation_parameters["trait_var"],
)

# ==================================================
# --- Simulation Function (Unchanged) ---
# ==================================================
def run_simulation(key, founder_pop, sp, heritability, params):
    """Runs a breeding simulation for a given number of generations."""
    genetic_means = []
    genetic_variances = []
    current_pop = founder_pop
    next_id_start = current_pop.nInd
    n_generations = params["n_generations"]
    population_size = params["population_size"]
    n_select = params["n_select"]
    pheno_key, select_key, mate_key = jax.random.split(key, 3)

    for gen in range(n_generations + 1):
        p_subkey, pheno_key = jax.random.split(pheno_key)
        s_subkey, select_key = jax.random.split(select_key)
        m_subkey1, m_subkey2, mate_key = jax.random.split(mate_key, 3)
        
        current_pop = set_bv(current_pop, sp.traits, sp.ploidy)
        genetic_means.append(jnp.mean(current_pop.bv))
        genetic_variances.append(jnp.var(current_pop.bv))
        
        current_pop = set_pheno(
            key=p_subkey, pop=current_pop, traits=sp.traits,
            ploidy=sp.ploidy, h2=jnp.array([heritability])
        )
        
        selected_indices = jnp.argsort(current_pop.pheno[:, 0])[-n_select:]
        mothers = jax.random.choice(m_subkey1, selected_indices, shape=(population_size,))
        fathers = jax.random.choice(m_subkey2, selected_indices, shape=(population_size,))
        
        cross_plan = jnp.stack([mothers, fathers], axis=1)
        
        current_pop = make_cross(
            key=s_subkey, pop=current_pop, cross_plan=cross_plan,
            sp=sp, next_id_start=next_id_start
        )
        next_id_start += population_size
        
    return jnp.array(genetic_means), jnp.array(genetic_variances)

# ==================================================
# ---  Vectorization and Execution ---
# ==================================================
# 1. Create a batch of random keys, one for each replicate
n_replicates = simulation_parameters["n_replicates"]
replicate_keys = jax.random.split(key, n_replicates)

# 2. Use vmap to create a version of the function that maps over the keys
#    in_axes=(0, None, None, None, None) means:
#    - Map over the first (and only) axis of `replicate_keys`.
#    - Broadcast the other arguments (founder_pop, sp, etc.) so they are
#      the same for every replicate.
vmapped_simulation = jax.vmap(
    run_simulation, in_axes=(0, None, None, None, None)
)

# 3. Run all replicates in parallel
print(f"Running {n_replicates} replicates...")
all_means, all_variances = vmapped_simulation(
    replicate_keys,
    founder_pop,
    sp,
    simulation_parameters["h2"][0],
    simulation_parameters
)
print("Simulation complete.")

# 4. Calculate statistics across replicates
mean_of_means = jnp.mean(all_means, axis=0)
std_err_of_means = jnp.std(all_means, axis=0) / jnp.sqrt(n_replicates)

mean_of_variances = jnp.mean(all_variances, axis=0)
std_err_of_variances = jnp.std(all_variances, axis=0) / jnp.sqrt(n_replicates)

# ==================================================
# ---  Plotting Results ---
# ==================================================
generations = jnp.arange(simulation_parameters["n_generations"] + 1)
plt.style.use('seaborn-v0_8-whitegrid')
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Plot Genetic Mean
ax1.plot(generations, mean_of_means, label='Mean Genetic Value', color='royalblue')
ax1.fill_between(
    generations,
    mean_of_means - 1.96 * std_err_of_means,
    mean_of_means + 1.96 * std_err_of_means,
    alpha=0.2, color='royalblue', label='95% CI'
)
ax1.set_xlabel('Generation')
ax1.set_ylabel('Genetic Mean')
ax1.set_title('Genetic Gain Over Time')
ax1.legend()
ax1.grid(True)

# Plot Genetic Variance
ax2.plot(generations, mean_of_variances, label='Mean Genetic Variance', color='firebrick')
ax2.fill_between(
    generations,
    mean_of_variances - 1.96 * std_err_of_variances,
    mean_of_variances + 1.96 * std_err_of_variances,
    alpha=0.2, color='firebrick', label='95% CI'
)
ax2.set_xlabel('Generation')
ax2.set_ylabel('Genetic Variance')
ax2.set_title('Change in Genetic Variance')
ax2.legend()
ax2.grid(True)


plt.tight_layout()
plt.show()



Running 50 replicates...
TRACE make_cross geno type=<class 'jax._src.interpreters.batching.BatchTracer'> shape=(Array(200, dtype=int32, weak_type=True), Array(5, dtype=int32, weak_type=True), Array(2, dtype=int32, weak_type=True), Array(1000, dtype=int32, weak_type=True))
TRACE make_cross new_public_ids type=<class 'jaxlib._jax.ArrayImpl'> shape=(Array(200, dtype=int32, weak_type=True),)
TRACE make_cross geno type=<class 'jax._src.interpreters.batching.BatchTracer'> shape=(Array(200, dtype=int32, weak_type=True), Array(5, dtype=int32, weak_type=True), Array(2, dtype=int32, weak_type=True), Array(1000, dtype=int32, weak_type=True))
TRACE make_cross new_public_ids type=<class 'jaxlib._jax.ArrayImpl'> shape=(Array(200, dtype=int32, weak_type=True),)
TRACE make_cross geno type=<class 'jax._src.interpreters.batching.BatchTracer'> shape=(Array(200, dtype=int32, weak_type=True), Array(5, dtype=int32, weak_type=True), Array(2, dtype=int32, weak_type=True), Array(1000, dtype=int32, weak_type=Tr

KeyboardInterrupt: 

In [5]:
# In your main simulation notebook/script

from functools import partial # Add this import

def _simulation_generation_step(carry, _):
    """Logic for a single generation, designed for lax.scan."""
    current_pop, next_id_start, key, sp, heritability, population_size, n_select = carry

    # Split keys
    p_key, s_key, m_key, next_loop_key = jax.random.split(key, 4)
    m1_key, m2_key = jax.random.split(m_key)

    # Core logic
    current_pop = set_bv(current_pop, sp.traits, sp.ploidy)
    current_pop = set_pheno(p_key, current_pop, sp.traits, sp.ploidy, jnp.array([heritability]))

    # Selection & Mating
    selected_indices = jnp.argsort(current_pop.pheno[:, 0])[-n_select:]
    mothers = jax.random.choice(m1_key, selected_indices, shape=(population_size,))
    fathers = jax.random.choice(m2_key, selected_indices, shape=(population_size,))
    cross_plan = jnp.stack([mothers, fathers], axis=1)
    
    # Create next generation
    new_pop = make_cross(s_key, current_pop, cross_plan, sp, next_id_start)

    # Collect metrics & prepare state for next iteration
    metrics = (jnp.mean(current_pop.bv), jnp.var(current_pop.bv))
    next_carry = (new_pop, next_id_start + population_size, next_loop_key, sp, heritability, population_size, n_select)
    
    return next_carry, metrics

@partial(jax.jit, static_argnames=("n_generations", "population_size", "n_select"))
def run_simulation_scan(key, founder_pop, sp, heritability, n_generations, population_size, n_select):
    """JAX-native simulation runner using lax.scan."""
    initial_carry = (founder_pop, int(founder_pop.nInd), key, sp, heritability, population_size, n_select)
    
    final_carry, all_metrics = jax.lax.scan(
        _simulation_generation_step, initial_carry, None, length=n_generations + 1
    )
    
    genetic_means, genetic_variances = all_metrics
    return genetic_means, genetic_variances

# --- Your final calling code ---
n_reps = 4
keys = jax.random.split(jax.random.PRNGKey(123), n_reps)
herits = jnp.array([0.2, 0.3, 0.5, 0.7])

# Vmap over the NEW scan-based function
batched = jax.vmap(
    run_simulation_scan,  # <-- This is the one to use
    in_axes=(0, None, None, 0, None, None, None), 
    out_axes=(0, 0)
)

print(">>> Running vmapped simulation with lax.scan...")
means_b, vars_b = batched(keys, founder_pop, sp, herits, 5, 50, 5)
print("... vmapped simulation successful!")
print("Batched means shape:", means_b.shape)

>>> Running vmapped simulation with lax.scan...


ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>, Traced<~int32[]>with<DynamicJaxprTrace>. The error was:
TypeError: unhashable type: 'DynamicJaxprTracer'


In [9]:
import jax
import jax.numpy as jnp
from functools import partial
from chewc.population import Population, msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_bv, set_pheno
from chewc.cross import make_cross

# (Your simulation_parameters and initial setup are correct and remain unchanged)
# ==================================================
# ---  Simulation Parameters (Unchanged) ---
# ==================================================
simulation_parameters = {
    "n_replicates": 3,
    "n_founder_ind": 10,
    "n_loci_per_chr": 1000,
    "n_chr": 5,
    "n_qtl_per_chr": 100,
    "trait_mean": jnp.array([0.0]),
    "trait_var": jnp.array([1.0]),
    "n_generations": 50,
    "population_size": 200,
    "n_select": 20,
    "key": jax.random.PRNGKey(42),
    "h2": jnp.array([.3]),
}

# ==================================================
# ---  Setup (Largely Unchanged) ---
# ==================================================
key = simulation_parameters["key"]
key, founder_key, sp1_key, sp2_key = jax.random.split(key, 4)

founder_pop, genetic_map = msprime_pop(
    key=founder_key,
    n_ind=simulation_parameters["n_founder_ind"],
    n_loci_per_chr=simulation_parameters["n_loci_per_chr"],
    n_chr=simulation_parameters["n_chr"]
)

# Create SimParam objects for scenarios
sp = SimParam.from_founder_pop(founder_pop, genetic_map)
sp = add_trait_a(
    key=sp1_key, founder_pop=founder_pop, sim_param=sp,
    n_qtl_per_chr=simulation_parameters["n_qtl_per_chr"],
    mean=simulation_parameters["trait_mean"],
    var=simulation_parameters["trait_var"],
    gamma=False
)

# ==================================================
# ---  Corrected lax.scan Implementation ---
# ==================================================

def _simulation_generation_step(carry, _, sp, heritability, population_size, n_select):
    """
    This function now takes static parameters (sp, heritability, etc.)
    as regular arguments, not as part of the carry.
    The second argument `_` is a placeholder for the `xs` in lax.scan, which we are not using.
    """
    current_pop, next_id_start, key = carry
    
    # Split key for this generation's operations
    key, p_key, s_key, m_key = jax.random.split(key, 4)
    m1_key, m2_key = jax.random.split(m_key)

    # --- Core Logic (Unchanged) ---
    current_pop = set_bv(current_pop, sp.traits, sp.ploidy)
    
    # Capture metrics
    genetic_mean = jnp.mean(current_pop.bv)
    genetic_variance = jnp.var(current_pop.bv)
    metrics = (genetic_mean, genetic_variance)

    current_pop = set_pheno(p_key, current_pop, sp.traits, sp.ploidy, jnp.array([heritability]))
    
    # Selection & Mating
    selected_indices = jnp.argsort(current_pop.pheno[:, 0])[-n_select:]
    mothers = jax.random.choice(m1_key, selected_indices, shape=(population_size,))
    fathers = jax.random.choice(m2_key, selected_indices, shape=(population_size,))
    cross_plan = jnp.stack([mothers, fathers], axis=1)

    # Create next generation
    next_pop = make_cross(
        key=s_key, pop=current_pop, cross_plan=cross_plan,
        sp=sp, next_id_start=next_id_start
    )
    
    # Prepare the carry for the *next* iteration
    next_carry = (next_pop, next_id_start + population_size, key)
    
    return next_carry, metrics

def run_simulation_scan(key, founder_pop, sp, heritability, n_generations, population_size, n_select):
    """
    JAX-native simulation runner using lax.scan with the corrected pattern.
    """
    # The initial state to be "carried" over iterations.
    # Notice `sp` and other params are NOT here.
    initial_carry = (founder_pop, int(founder_pop.nInd), key)

    # Use a lambda to "close over" the static parameters.
    # This creates a function with the signature `(carry, _)` that lax.scan expects.
    generation_func = lambda carry, _ : _simulation_generation_step(
        carry, _, sp, heritability, population_size, n_select
    )

    # Run the scan
    final_carry, all_metrics = jax.lax.scan(
        generation_func, initial_carry, None, length=n_generations + 1
    )

    genetic_means, genetic_variances = all_metrics
    return genetic_means, genetic_variances

# --- Vectorization and Execution ---
print(">>> Compiling and running vmapped simulation with lax.scan...")
n_replicates = simulation_parameters["n_replicates"]
keys = jax.random.split(key, n_replicates)

# The vmapped function now correctly handles static vs. dynamic arguments.
# We map over the `keys` array, while all other arguments are broadcast (`None`).
batched_scan_simulation = jax.vmap(
    run_simulation_scan,
    in_axes=(0, None, None, None, None, None, None)
)

means, variances = batched_scan_simulation(
    keys,
    founder_pop,
    sp,
    simulation_parameters['h2'][0],
    simulation_parameters['n_generations'],
    simulation_parameters['population_size'],
    simulation_parameters['n_select']
)

print("... Vmapped simulation successful!")
print("Batched means shape:", means.shape)
print("Batched variances shape:", variances.shape)

>>> Compiling and running vmapped simulation with lax.scan...


ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
It arose in the jnp.arange argument 'start'
The error occurred while tracing the function <lambda> at /tmp/ipykernel_166737/179897431.py:105 for scan. This concrete value was not available in Python because it depends on the value of the argument carry[1].

See https://docs.jax.dev/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [11]:
import jax
import jax.numpy as jnp
from functools import partial
from chewc.population import Population, msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_bv, set_pheno
from chewc.cross import make_cross

# ==================================================
# ---  Simulation Parameters (Unchanged) ---
# ==================================================
simulation_parameters = {
    "n_replicates": 3,
    "n_founder_ind": 10,
    "n_loci_per_chr": 1000,
    "n_chr": 5,
    "n_qtl_per_chr": 100,
    "trait_mean": jnp.array([0.0]),
    "trait_var": jnp.array([1.0]),
    "n_generations": 5,
    "population_size": 200,
    "n_select": 20,
    "key": jax.random.PRNGKey(42),
    "h2": jnp.array([.3]),
}

# ==================================================
# ---  Setup (Unchanged) ---
# ==================================================
key = simulation_parameters["key"]
key, founder_key, sp1_key, sp2_key = jax.random.split(key, 4)

founder_pop, genetic_map = msprime_pop(
    key=founder_key,
    n_ind=simulation_parameters["n_founder_ind"],
    n_loci_per_chr=simulation_parameters["n_loci_per_chr"],
    n_chr=simulation_parameters["n_chr"]
)

sp = SimParam.from_founder_pop(founder_pop, genetic_map)
sp = add_trait_a(
    key=sp1_key, founder_pop=founder_pop, sim_param=sp,
    n_qtl_per_chr=simulation_parameters["n_qtl_per_chr"],
    mean=simulation_parameters["trait_mean"],
    var=simulation_parameters["trait_var"],
    gamma=False
)

# ==================================================
# ---  Corrected lax.scan with Stable Shapes ---
# ==================================================

# The generation step function remains the same as our last fix.
def _simulation_generation_step(carry, _, sp, heritability, population_size, n_select):
    current_pop, next_id_start, key = carry
    key, p_key, s_key, m_key = jax.random.split(key, 4)
    m1_key, m2_key = jax.random.split(m_key)
    
    current_pop = set_bv(current_pop, sp.traits, sp.ploidy)
    metrics = (jnp.mean(current_pop.bv), jnp.var(current_pop.bv))
    
    current_pop = set_pheno(p_key, current_pop, sp.traits, sp.ploidy, jnp.array([heritability]))
    selected_indices = jnp.argsort(current_pop.pheno[:, 0])[-n_select:]
    mothers = jax.random.choice(m1_key, selected_indices, shape=(population_size,))
    fathers = jax.random.choice(m2_key, selected_indices, shape=(population_size,))
    cross_plan = jnp.stack([mothers, fathers], axis=1)
    
    next_pop = make_cross(key=s_key, pop=current_pop, cross_plan=cross_plan, sp=sp, next_id_start=next_id_start)
    next_carry = (next_pop, next_id_start + population_size, key)
    
    return next_carry, metrics

def run_simulation_scan(key, founder_pop, sp, heritability, n_generations, population_size, n_select):
    """
    JAX-native runner that stabilizes population size *before* the main scan loop.
    """
    # --- STEP 1: Handle Generation 0 (Founders) ---
    key, g0_key, g1_key, scan_key = jax.random.split(key, 4)
    
    # Calculate metrics for the initial founder population
    pop_g0 = set_bv(founder_pop, sp.traits, sp.ploidy)
    g0_metrics = (jnp.mean(pop_g0.bv), jnp.var(pop_g0.bv))
    
    # --- STEP 2: Run the FIRST cross to establish the stable population size ---
    pop_g0_pheno = set_pheno(g0_key, pop_g0, sp.traits, sp.ploidy, jnp.array([heritability]))
    selected_indices_g0 = jnp.argsort(pop_g0_pheno.pheno[:, 0])[-n_select:]
    m1_key_g0, m2_key_g0 = jax.random.split(g1_key)
    mothers_g0 = jax.random.choice(m1_key_g0, selected_indices_g0, shape=(population_size,))
    fathers_g0 = jax.random.choice(m2_key_g0, selected_indices_g0, shape=(population_size,))
    cross_plan_g0 = jnp.stack([mothers_g0, fathers_g0], axis=1)

    pop_g1 = make_cross(
        key=g1_key, pop=pop_g0_pheno, cross_plan=cross_plan_g0, sp=sp, 
        next_id_start=int(founder_pop.nInd)
    )

    # --- STEP 3: Run the rest of the generations with lax.scan ---
    # The initial carry for the scan now uses pop_g1, which has the correct, stable shape.
    initial_carry = (pop_g1, int(founder_pop.nInd) + population_size, scan_key)
    
    generation_func = lambda carry, _ : _simulation_generation_step(
        carry, _, sp, heritability, population_size, n_select
    )
    
    # We scan for `n_generations` because we already completed one.
    final_carry, scan_metrics = jax.lax.scan(
        generation_func, initial_carry, None, length=n_generations
    )

    # --- STEP 4: Combine metrics from all generations ---
    # `jnp.stack` the metrics from G0 with the metrics from the scanned generations
    all_means = jnp.hstack([g0_metrics[0], scan_metrics[0]])
    all_variances = jnp.hstack([g0_metrics[1], scan_metrics[1]])
    
    return all_means, all_variances

# --- Vectorization and Execution (Unchanged) ---
print(">>> Compiling and running vmapped simulation with lax.scan...")
n_replicates = simulation_parameters["n_replicates"]
keys = jax.random.split(key, n_replicates)

batched_scan_simulation = jax.vmap(
    run_simulation_scan,
    in_axes=(0, None, None, None, None, None, None)
)

means, variances = batched_scan_simulation(
    keys,
    founder_pop,
    sp,
    simulation_parameters['h2'][0],
    simulation_parameters['n_generations'],
    simulation_parameters['population_size'],
    simulation_parameters['n_select']
)

print("... Vmapped simulation successful!")
print("Batched means shape:", means.shape)
print("Batched variances shape:", variances.shape)

>>> Compiling and running vmapped simulation with lax.scan...
TRACE make_cross geno type=<class 'jax._src.interpreters.batching.BatchTracer'> shape=(Array(200, dtype=int32, weak_type=True), Array(5, dtype=int32, weak_type=True), Array(2, dtype=int32, weak_type=True), Array(1000, dtype=int32, weak_type=True))
TRACE make_cross new_public_ids type=<class 'jaxlib._jax.ArrayImpl'> shape=(Array(200, dtype=int32, weak_type=True),)


ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
It arose in the jnp.arange argument 'start'
The error occurred while tracing the function <lambda> at /tmp/ipykernel_166737/1116422661.py:102 for scan. This concrete value was not available in Python because it depends on the value of the argument carry[1].

See https://docs.jax.dev/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [12]:
import jax
from jax import debug
import jax.numpy as jnp
from functools import partial
from chewc.population import Population, msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_bv, set_pheno
from chewc.cross import make_cross

# ==================================================
# ---  Simulation Parameters (Unchanged) ---
# ==================================================
simulation_parameters = {
    "n_replicates": 3,
    "n_founder_ind": 10,
    "n_loci_per_chr": 1000,
    "n_chr": 5,
    "n_qtl_per_chr": 100,
    "trait_mean": jnp.array([0.0]),
    "trait_var": jnp.array([1.0]),
    "n_generations": 50,
    "population_size": 200,
    "n_select": 20,
    "key": jax.random.PRNGKey(42),
    "h2": jnp.array([.3]),
}

# ==================================================
# ---  Setup (Unchanged) ---
# ==================================================
key = simulation_parameters["key"]
key, founder_key, sp1_key = jax.random.split(key, 3)

founder_pop, genetic_map = msprime_pop(
    key=founder_key,
    n_ind=simulation_parameters["n_founder_ind"],
    n_loci_per_chr=simulation_parameters["n_loci_per_chr"],
    n_chr=simulation_parameters["n_chr"]
)

sp = SimParam.from_founder_pop(founder_pop, genetic_map)
sp = add_trait_a(
    key=sp1_key, founder_pop=founder_pop, sim_param=sp,
    n_qtl_per_chr=simulation_parameters["n_qtl_per_chr"],
    mean=simulation_parameters["trait_mean"],
    var=simulation_parameters["trait_var"],
    gamma=False
)

# ==================================================
# --- Patched `make_cross` with Debugging ---
# ==================================================
# This is a patched version of your library's function.
def make_cross_debug(key, pop, cross_plan, sp, next_id_start):
    """
    Patched version of make_cross with the arange fix and debug prints.
    """
    n_crosses = cross_plan.shape[0]
    key_geno, key_sex = jax.random.split(key)

    mother_iids = cross_plan[:, 0]
    father_iids = cross_plan[:, 1]
    mothers_geno = pop.geno[mother_iids]
    fathers_geno = pop.geno[father_iids]
    mothers_ibd = pop.ibd[mother_iids]
    fathers_ibd = pop.ibd[father_iids]

    # This call is already JIT-friendly
    progeny_geno, progeny_ibd = _make_cross_geno(
        key_geno, mothers_geno, fathers_geno, mothers_ibd, fathers_ibd,
        sp.n_chr, sp.gen_map, sp.recomb_params[0]
    )

    # --- THE FIX ---
    # Create an array of shape (n_crosses,) starting from 0. The length is static.
    # Then, add the dynamic `next_id_start` value. JAX can trace this.
    new_public_ids = next_id_start + jnp.arange(n_crosses)
    # --- END OF FIX ---
    
    new_iids = jnp.arange(n_crosses)
    mother_public_ids = pop.id[mother_iids]
    father_public_ids = pop.id[father_iids]
    parent_gen = pop.gen[mother_iids[0]]
    progeny_gen = parent_gen + 1

    # --- LIBERAL DEBUGGING STATEMENTS ---
    debug.print("--- Inside make_cross_debug ---")
    debug.print("next_id_start: {}", next_id_start)
    debug.print("n_crosses: {}", n_crosses)
    debug.print("Shape of progeny_geno: {}", progeny_geno.shape)
    debug.print("Shape of new_public_ids: {}", new_public_ids.shape)
    debug.print("First 5 new_public_ids: {}", new_public_ids[:5])
    
    progeny_pop = Population(
        geno=progeny_geno, ibd=progeny_ibd, id=new_public_ids, iid=new_iids,
        mother=mother_public_ids, father=father_public_ids,
        sex=jax.random.choice(key_sex, jnp.array([0, 1], dtype=jnp.int8), (n_crosses,)),
        gen=jnp.full((n_crosses,), progeny_gen, dtype=jnp.int32),
        pheno=jnp.zeros((n_crosses, sp.n_traits)),
        fixEff=jnp.zeros(n_crosses, dtype=jnp.float32),
        bv=jnp.zeros((n_crosses, sp.n_traits))
    )
    return progeny_pop

# You need to also import the internal `_make_cross_geno` for the patch to work
from chewc.cross import _make_cross_geno

# ==================================================
# ---  Corrected `lax.scan` with Patched Function ---
# ==================================================

def _simulation_generation_step(carry, _, sp, heritability, population_size, n_select):
    current_pop, next_id_start, key = carry
    key, p_key, s_key, m_key = jax.random.split(key, 4)
    m1_key, m2_key = jax.random.split(m_key)
    
    debug.print("\n--- Running Generation Step ---")
    debug.print("Carry-in next_id_start: {}", next_id_start)

    current_pop = set_bv(current_pop, sp.traits, sp.ploidy)
    metrics = (jnp.mean(current_pop.bv), jnp.var(current_pop.bv))
    
    current_pop = set_pheno(p_key, current_pop, sp.traits, sp.ploidy, jnp.array([heritability]))
    selected_indices = jnp.argsort(current_pop.pheno[:, 0])[-n_select:]
    mothers = jax.random.choice(m1_key, selected_indices, shape=(population_size,))
    fathers = jax.random.choice(m2_key, selected_indices, shape=(population_size,))
    cross_plan = jnp.stack([mothers, fathers], axis=1)
    
    # Use the patched and debug-ready version of the function
    next_pop = make_cross_debug(
        key=s_key, pop=current_pop, cross_plan=cross_plan,
        sp=sp, next_id_start=next_id_start
    )
    
    next_id_start_out = next_id_start + population_size
    next_carry = (next_pop, next_id_start_out, key)
    
    debug.print("Carry-out next_id_start: {}", next_id_start_out)
    return next_carry, metrics

# The main runner function remains the same, it will just use our patched step function
def run_simulation_scan(key, founder_pop, sp, heritability, n_generations, population_size, n_select):
    key, g0_key, g1_key, scan_key = jax.random.split(key, 4)
    
    pop_g0 = set_bv(founder_pop, sp.traits, sp.ploidy)
    g0_metrics = (jnp.mean(pop_g0.bv), jnp.var(pop_g0.bv))
    
    pop_g0_pheno = set_pheno(g0_key, pop_g0, sp.traits, sp.ploidy, jnp.array([heritability]))
    selected_indices_g0 = jnp.argsort(pop_g0_pheno.pheno[:, 0])[-n_select:]
    m1_key_g0, m2_key_g0 = jax.random.split(g1_key)
    mothers_g0 = jax.random.choice(m1_key_g0, selected_indices_g0, shape=(population_size,))
    fathers_g0 = jax.random.choice(m2_key_g0, selected_indices_g0, shape=(population_size,))
    cross_plan_g0 = jnp.stack([mothers_g0, fathers_g0], axis=1)

    pop_g1 = make_cross_debug(
        key=g1_key, pop=pop_g0_pheno, cross_plan=cross_plan_g0, sp=sp, 
        next_id_start=int(founder_pop.nInd)
    )

    initial_carry = (pop_g1, int(founder_pop.nInd) + population_size, scan_key)
    
    generation_func = lambda carry, _ : _simulation_generation_step(
        carry, _, sp, heritability, population_size, n_select
    )
    
    final_carry, scan_metrics = jax.lax.scan(
        generation_func, initial_carry, None, length=n_generations
    )

    all_means = jnp.hstack([g0_metrics[0], scan_metrics[0]])
    all_variances = jnp.hstack([g0_metrics[1], scan_metrics[1]])
    
    return all_means, all_variances

# --- Vectorization and Execution (Unchanged) ---
print(">>> Compiling and running vmapped simulation with lax.scan...")
n_replicates = simulation_parameters["n_replicates"]
keys = jax.random.split(key, n_replicates)

batched_scan_simulation = jax.vmap(
    run_simulation_scan,
    in_axes=(0, None, None, None, None, None, None)
)

# To see the debug output, run JAX in a non-JIT context first, or use a
# configurable print function that works inside JIT.
# For simplicity, let's just run it. The fix should prevent the crash.
means, variances = batched_scan_simulation(
    keys,
    founder_pop,
    sp,
    simulation_parameters['h2'][0],
    simulation_parameters['n_generations'],
    simulation_parameters['population_size'],
    simulation_parameters['n_select']
)

print("... Vmapped simulation successful!")
print("Batched means shape:", means.shape)
print("Batched variances shape:", variances.shape)

>>> Compiling and running vmapped simulation with lax.scan...
--- Inside make_cross_debug ---
next_id_start: 10
n_crosses: 200
Shape of progeny_geno: (Array(200, dtype=int32, weak_type=True), Array(5, dtype=int32, weak_type=True), Array(2, dtype=int32, weak_type=True), Array(1000, dtype=int32, weak_type=True))
Shape of new_public_ids: (Array(200, dtype=int32, weak_type=True),)
First 5 new_public_ids: [10 11 12 13 14]

--- Running Generation Step ---
Carry-in next_id_start: 210
next_id_start: 210
Carry-out next_id_start: 410
First 5 new_public_ids: [210 211 212 213 214]
--- Inside make_cross_debug ---
n_crosses: 200
Shape of new_public_ids: (Array(200, dtype=int32),)
Shape of progeny_geno: (Array(200, dtype=int32), Array(5, dtype=int32), Array(2, dtype=int32), Array(1000, dtype=int32))

--- Running Generation Step ---
--- Inside make_cross_debug ---
n_crosses: 200
Shape of new_public_ids: (Array(200, dtype=int32),)
Carry-in next_id_start: 410
next_id_start: 410
Shape of progeny_geno: (A