In [1]:
import jax
import jax.numpy as jnp
from chewc.population import quick_haplo
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno
from chewc.cross import make_cross
from chewc.population import combine_populations, Population
import matplotlib.pyplot as plt
from chewc.population import quick_haplo, combine_populations, Population, subset_population # Add subset_population
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from chewc.population import msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno, set_bv
from chewc.cross import make_cross


In [2]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from chewc.population import msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno, set_bv
from chewc.cross import make_cross
from functools import partial

# ==================================================
# --- Simulation parameters ---
# ==================================================
# Every tunable knob for this notebook, gathered in one dict so it can be
# handed to helpers such as `run_simulation` as a single `params` argument.
simulation_parameters = dict(
    n_replicates=3,
    n_founder_ind=10,
    n_loci_per_chr=1000,
    n_chr=5,
    n_qtl_per_chr=100,
    trait_mean=jnp.array([0.0]),
    trait_var=jnp.array([1.0]),
    n_generations=50,
    population_size=200,
    n_select=20,
    key=jax.random.PRNGKey(42),
    h2=jnp.array([.3]),
)

# ==================================================
# --- Founder population and trait setup ---
# ==================================================
# Derive independent PRNG streams from the master seed: one for the founder
# genomes and one for trait sampling. `key` itself is kept for later cells.
# NOTE(review): sp2_key is not consumed anywhere in the visible notebook.
key, founder_key, sp1_key, sp2_key = jax.random.split(simulation_parameters["key"], num=4)

founder_pop, genetic_map = msprime_pop(
    key=founder_key,
    n_ind=simulation_parameters["n_founder_ind"],
    n_loci_per_chr=simulation_parameters["n_loci_per_chr"],
    n_chr=simulation_parameters["n_chr"],
)

# A single SimParam built from the founders and their genetic map, with one
# trait attached via add_trait_a (the inspector output shows it as `add_eff`).
sp = add_trait_a(
    key=sp1_key,
    founder_pop=founder_pop,
    sim_param=SimParam.from_founder_pop(founder_pop, genetic_map),
    n_qtl_per_chr=simulation_parameters["n_qtl_per_chr"],
    mean=simulation_parameters["trait_mean"],
    var=simulation_parameters["trait_var"],
    gamma=False,
)

def run_simulation(key, founder_pop, sp, heritability, params):
    """Run a truncation-selection breeding loop and track breeding values.

    Each generation: compute breeding values with ``set_bv`` and record their
    mean and variance; then (except for the final, measure-only generation)
    simulate phenotypes with ``set_pheno``, keep the ``n_select`` individuals
    with the highest phenotype on trait 0, draw ``population_size`` mothers
    and fathers from them with ``jax.random.choice`` (default: with
    replacement) and produce the next generation with ``make_cross``.

    Args:
        key: JAX PRNG key driving phenotyping, mating and crossing.
        founder_pop: starting population (must expose ``nInd``).
        sp: simulation parameters providing ``traits`` and ``ploidy``.
        heritability: value passed to ``set_pheno`` as ``h2``; either a scalar
            or a one-element array (e.g. ``simulation_parameters["h2"]``).
        params: dict with ``"n_generations"``, ``"population_size"`` and
            ``"n_select"``.

    Returns:
        ``(genetic_means, genetic_variances)``: two Python lists of length
        ``n_generations + 1`` (founders plus each bred generation).
    """
    n_generations = params["n_generations"]
    population_size = params["population_size"]
    n_select = params["n_select"]

    # Normalise to shape (1,): a bare scalar and a one-element array must reach
    # set_pheno identically (wrapping an array in [...] would give shape (1, 1)).
    h2 = jnp.ravel(jnp.asarray(heritability))

    genetic_means = []
    genetic_variances = []
    current_pop = founder_pop
    next_id_start = current_pop.nInd
    pheno_key, cross_key, mate_key = jax.random.split(key, 3)

    for gen in range(n_generations + 1):
        p_subkey, pheno_key = jax.random.split(pheno_key)
        c_subkey, cross_key = jax.random.split(cross_key)
        m_subkey1, m_subkey2, mate_key = jax.random.split(mate_key, 3)

        current_pop = set_bv(current_pop, sp.traits, sp.ploidy)
        genetic_means.append(jnp.mean(current_pop.bv))
        genetic_variances.append(jnp.var(current_pop.bv))

        # The last generation is only measured: phenotyping, selecting and
        # crossing it would create offspring that are never recorded.
        if gen == n_generations:
            break

        current_pop = set_pheno(
            key=p_subkey, pop=current_pop, traits=sp.traits,
            ploidy=sp.ploidy, h2=h2,
        )
        # argsort is ascending, so the last n_select indices are the top phenotypes.
        selected_indices = jnp.argsort(current_pop.pheno[:, 0])[-n_select:]
        mothers = jax.random.choice(m_subkey1, selected_indices, shape=(population_size,))
        fathers = jax.random.choice(m_subkey2, selected_indices, shape=(population_size,))
        cross_plan = jnp.stack([mothers, fathers], axis=1)
        current_pop = make_cross(
            key=c_subkey, pop=current_pop, cross_plan=cross_plan,
            sp=sp, next_id_start=next_id_start,
        )
        # Keep individual ids unique across generations.
        next_id_start += population_size

    return genetic_means, genetic_variances




In [3]:
# -------------------------
# 1) EAGER inspector helper
# -------------------------
def inspect_obj(name, obj):
    """Print a short structural summary of ``obj`` for eager debugging.

    Reports the type of ``obj``, then type/shape/dtype for any of the array
    attributes commonly found on Population objects, plus ``gen_map`` and
    ``traits`` details found on SimParam objects. Missing attributes are
    skipped (or reported as ``None``) rather than raising, so the helper can
    be pointed at any object while hunting for an AttributeError.

    Args:
        name: label printed in the banner.
        obj: object to inspect.
    """
    print("=== INSPECT:", name, "===")
    print(" type:", type(obj))
    # Common array attributes on Population objects.
    for attr in ("geno", "ibd", "id", "iid", "mother", "father", "pheno", "bv"):
        if hasattr(obj, attr):
            val = getattr(obj, attr)
            print(f"  {attr}: type={type(val)}, shape={getattr(val, 'shape', None)}, dtype={getattr(val, 'dtype', None)}")
    # SimParam-specific attributes.
    if hasattr(obj, "gen_map"):
        gm = obj.gen_map
        print("  gen_map:", type(gm), getattr(gm, "shape", None), getattr(gm, "dtype", None))
    if hasattr(obj, "traits"):
        tr = obj.traits
        print("  traits type:", type(tr))
        if tr is not None and hasattr(tr, "add_eff"):
            print("    add_eff.shape:", getattr(tr.add_eff, "shape", None))
            # Looked up defensively: a traits object lacking loci_loc should be
            # reported, not crash the inspector.
            print("    loci_loc.shape:", getattr(getattr(tr, "loci_loc", None), "shape", None))
    print("  repr (short):", repr(obj)[:200])
    print("========================\n")


# -------------------------------------------------------
# 2) EAGER debug run (no jit/vmap) - use this first.
# -------------------------------------------------------
def debug_run_simulation_nojit(key, founder_pop, sp, heritability, n_generations, population_size, n_select):
    """Eager (no jit/vmap) breeding loop that prints shapes at every step.

    Run this before the vmapped variant: nothing is traced, so ordinary
    ``print`` calls fire immediately and the last line printed before an
    exception pinpoints the failing step.

    Each generation: ``set_bv`` -> record mean/variance of ``bv`` ->
    ``set_pheno`` -> keep the ``n_select`` highest phenotypes on trait 0 ->
    draw ``population_size`` mothers and fathers from them -> ``make_cross``.
    A cross is also performed after the final measurement, so every
    generation exercises ``make_cross``.

    Args:
        key: JAX PRNG key.
        founder_pop: starting population (must expose ``nInd``, ``geno``).
        sp: simulation parameters providing ``traits`` and ``ploidy``.
        heritability: value passed to ``set_pheno`` as ``h2``; a scalar or a
            one-element array.
        n_generations: number of bred generations (loop runs n_generations + 1 times).
        population_size: offspring produced per generation.
        n_select: number of parents kept each generation.

    Returns:
        ``(means, variances)``: arrays of shape ``(n_generations + 1,)``.
    """
    print(">>> Running debug_run_simulation_nojit (eager)")
    # Eager inspections
    inspect_obj("founder_pop", founder_pop)
    inspect_obj("sp", sp)
    print("heritability (type):", type(heritability), heritability)

    # Normalise to shape (1,) so a scalar and a one-element array behave the
    # same (wrapping an array in [...] would give shape (1, 1)).
    h2 = jnp.ravel(jnp.asarray(heritability))

    genetic_means = []
    genetic_variances = []
    current_pop = founder_pop
    next_id_start = int(current_pop.nInd)
    pheno_key, select_key, mate_key = jax.random.split(key, 3)

    for gen in range(n_generations + 1):
        print(f"\n[EAGER] generation {gen}/{n_generations}")
        p_subkey, pheno_key = jax.random.split(pheno_key)
        s_subkey, select_key = jax.random.split(select_key)
        m_subkey1, m_subkey2, mate_key = jax.random.split(mate_key, 3)

        # print current_pop summary before BV/PHENO
        print(" current_pop before set_bv: type", type(current_pop))
        print("  geno shape:", getattr(current_pop.geno, "shape", None), "dtype:", getattr(current_pop.geno, "dtype", None))

        current_pop = set_bv(current_pop, sp.traits, sp.ploidy)
        print("  after set_bv, bv shape:", getattr(current_pop.bv, "shape", None))
        genetic_means.append(jnp.mean(current_pop.bv))
        genetic_variances.append(jnp.var(current_pop.bv))

        current_pop = set_pheno(key=p_subkey, pop=current_pop, traits=sp.traits, ploidy=sp.ploidy, h2=h2)
        print("  after set_pheno, pheno.shape:", getattr(current_pop.pheno, "shape", None))

        # Selection: argsort is ascending, so the tail holds the top phenotypes.
        selected_indices = jnp.argsort(current_pop.pheno[:, 0])[-n_select:]
        print("  selected_indices type:", type(selected_indices), "shape:", selected_indices.shape)

        # mothers/fathers -- debug their contents
        mothers = jax.random.choice(m_subkey1, selected_indices, shape=(population_size,))
        fathers = jax.random.choice(m_subkey2, selected_indices, shape=(population_size,))
        print("  mothers.shape:", mothers.shape, "dtype:", mothers.dtype)
        print("  fathers.shape:", fathers.shape, "dtype:", fathers.dtype)

        cross_plan = jnp.stack([mothers, fathers], axis=1)
        print("  cross_plan.shape:", cross_plan.shape)

        # Debug before calling make_cross
        print("  Calling make_cross(...)")
        tmp_pop = make_cross(key=s_subkey, pop=current_pop, cross_plan=cross_plan, sp=sp, next_id_start=next_id_start)
        print("  make_cross returned type:", type(tmp_pop))
        print("   tmp_pop.geno.shape:", getattr(tmp_pop.geno, "shape", None))
        current_pop = tmp_pop
        # Keep individual ids unique across generations.
        next_id_start += population_size

    return jnp.stack(genetic_means), jnp.stack(genetic_variances)


# -------------------------------------------------------
# 3) VMAP-aware function: use jax.debug.print for traced prints
# -------------------------------------------------------
# Note: we accept n_generations, population_size, n_select as *python* ints
# so they are static from JAX's point of view (recommended).
def run_simulation_vmapped(key, founder_pop, sp, heritability, n_generations, population_size, n_select):
    """Breeding loop meant to be wrapped in ``jax.vmap`` (one replicate per key).

    Each generation: ``set_bv`` -> record mean/variance of ``bv`` ->
    ``set_pheno`` -> keep the ``n_select`` highest phenotypes on trait 0 ->
    random mothers/fathers from them -> ``make_cross``. Progress is reported
    with ``jax.debug.print`` so messages also appear under tracing.

    Args:
        key: PRNG key for this replicate (the mapped axis under vmap).
        founder_pop: starting population (unmapped under vmap).
        sp: simulation parameters providing ``traits``, ``ploidy``, ``gen_map``.
        heritability: value passed to ``set_pheno`` as ``h2``; a scalar or a
            one-element array (mapped per replicate under vmap).
        n_generations, population_size, n_select: Python ints. They size the
            loop and arrays, so they must stay static (in_axes=None).

    Returns:
        ``(means, variances)``: arrays of shape ``(n_generations + 1,)`` per
        replicate, i.e. ``(n_reps, n_generations + 1)`` once vmapped.
    """
    # Split RNGs used repeatedly
    pheno_key, select_key, mate_key = jax.random.split(key, 3)

    # Normalise to shape (1,): under vmap each replicate sees a scalar (or a
    # (1,) slice), and wrapping an array in [...] would give shape (1, 1).
    h2 = jnp.ravel(jnp.asarray(heritability))

    genetic_means = []
    genetic_variances = []
    current_pop = founder_pop
    next_id_start = current_pop.nInd  # should be an int

    # Debug prints that are safe under JAX tracing:
    jax.debug.print("TRACE vmapped start: founder_pop.geno.shape={gshape}", gshape=current_pop.geno.shape)
    jax.debug.print("TRACE vmapped start: sp.gen_map.shape={gm}", gm=sp.gen_map.shape)
    if sp.traits is not None:
        jax.debug.print("TRACE traits.add_eff.shape={a}", a=sp.traits.add_eff.shape)

    for gen in range(n_generations + 1):
        p_subkey, pheno_key = jax.random.split(pheno_key)
        s_subkey, select_key = jax.random.split(select_key)
        m_subkey1, m_subkey2, mate_key = jax.random.split(mate_key, 3)

        current_pop = set_bv(current_pop, sp.traits, sp.ploidy)
        jax.debug.print("TRACE gen={gen}: bv.shape={bshape}", gen=gen, bshape=current_pop.bv.shape)
        genetic_means.append(jnp.mean(current_pop.bv))
        genetic_variances.append(jnp.var(current_pop.bv))

        current_pop = set_pheno(key=p_subkey, pop=current_pop, traits=sp.traits, ploidy=sp.ploidy, h2=h2)
        jax.debug.print("TRACE gen={gen}: pheno.shape={pshape}", gen=gen, pshape=current_pop.pheno.shape)

        # argsort is ascending, so the tail holds the top phenotypes.
        selected_indices = jnp.argsort(current_pop.pheno[:, 0])[-n_select:]
        jax.debug.print("TRACE gen={gen}: selected_indices.shape={s}", gen=gen, s=selected_indices.shape)

        mothers = jax.random.choice(m_subkey1, selected_indices, shape=(population_size,))
        fathers = jax.random.choice(m_subkey2, selected_indices, shape=(population_size,))
        jax.debug.print("TRACE gen={gen}: mothers.shape={ms} fathers.shape={fs}", gen=gen, ms=mothers.shape, fs=fathers.shape)

        cross_plan = jnp.stack([mothers, fathers], axis=1)
        jax.debug.print("TRACE gen={gen}: cross_plan.shape={cp}", gen=gen, cp=cross_plan.shape)

        # make_cross is the prime suspect for shape errors; its output shape is
        # printed immediately afterwards.
        new_pop = make_cross(key=s_subkey, pop=current_pop, cross_plan=cross_plan, sp=sp, next_id_start=next_id_start)
        jax.debug.print("TRACE gen={gen}: new_pop.geno.shape={g}", gen=gen, g=new_pop.geno.shape)
        current_pop = new_pop
        # Keep individual ids unique across generations.
        next_id_start += population_size

    # stack results so batched output has shape (n_reps, n_generations+1)
    return jnp.stack(genetic_means), jnp.stack(genetic_variances)


# -------------------------------------------------------
# 4) How to call (suggested debugging order)
# -------------------------------------------------------
# Shared settings so the eager and vmapped runs exercise the same problem size.
DEBUG_SEED = 123
DEBUG_N_GENERATIONS = 5
DEBUG_POPULATION_SIZE = 50
DEBUG_N_SELECT = 5

# 1) Eager inspection of the inputs before any JAX transformation:
inspect_obj("founder_pop", founder_pop)
inspect_obj("sp", sp)

# 2) Run one replicate eagerly so Python-level prints show where it breaks:
k = jax.random.PRNGKey(DEBUG_SEED)
means, vars_ = debug_run_simulation_nojit(
    k, founder_pop, sp, 0.3,
    n_generations=DEBUG_N_GENERATIONS,
    population_size=DEBUG_POPULATION_SIZE,
    n_select=DEBUG_N_SELECT,
)

# 3) If that works, run the vmapped version (prints come from jax.debug.print).
# One replicate per heritability value: n_reps is derived from `herits` so the
# two mapped inputs always share the same batch size.
herits = jnp.array([0.2, 0.3, 0.5, 0.7])
n_reps = herits.shape[0]
keys = jax.random.split(jax.random.PRNGKey(DEBUG_SEED), n_reps)
batched = jax.vmap(
    run_simulation_vmapped,
    in_axes=(0, None, None, 0, None, None, None),  # map over keys and herits only
    out_axes=(0, 0),
)
means_b, vars_b = batched(
    keys, founder_pop, sp, herits,
    DEBUG_N_GENERATIONS, DEBUG_POPULATION_SIZE, DEBUG_N_SELECT,
)

# Notes:
# - If the eager inspector reports `type=object` for `.geno`, `.ibd`, etc., that
#   is the immediate cause of `AttributeError: 'object' object has no attribute 'shape'`.
# - `make_cross(...)` is the most likely place for non-array inputs to slip in;
#   check the debug print right after that call in the eager run and the
#   jax.debug.print output in the vmapped run.


=== INSPECT: founder_pop ===
 type: <class 'chewc.population.Population'>
  geno: type=<class 'jaxlib._jax.ArrayImpl'>, shape=(10, 5, 2, 1000), dtype=uint8
  ibd: type=<class 'jaxlib._jax.ArrayImpl'>, shape=(10, 5, 2, 1000), dtype=uint32
  id: type=<class 'jaxlib._jax.ArrayImpl'>, shape=(10,), dtype=int32
  iid: type=<class 'jaxlib._jax.ArrayImpl'>, shape=(10,), dtype=int32
  mother: type=<class 'jaxlib._jax.ArrayImpl'>, shape=(10,), dtype=int32
  father: type=<class 'jaxlib._jax.ArrayImpl'>, shape=(10,), dtype=int32
  pheno: type=<class 'jaxlib._jax.ArrayImpl'>, shape=(10, 0), dtype=float32
  bv: type=<class 'jaxlib._jax.ArrayImpl'>, shape=(10, 0), dtype=float32
  repr (short): Population(nInd=10, nTraits=0, has_ebv=No)

=== INSPECT: sp ===
 type: <class 'chewc.sp.SimParam'>
  gen_map: <class 'jaxlib._jax.ArrayImpl'> (5, 1000) float32
  traits type: <class 'chewc.trait.TraitCollection'>
    add_eff.shape: (1, 500)
    loci_loc.shape: (500,)
  repr (short): SimParam(nChr=5, nTraits=1, 