In [4]:
from chewc.config import SimConfig
from chewc.structs import SimParam, Population
from chewc.state import SimState, init_state_from_founders

In [8]:
# examples/min_smoke.py
import jax, jax.numpy as jnp
from chewc.popgen import msprime_pop   # your existing founder builder
from chewc.structs import SimParam
from chewc.config import SimConfig
from chewc.state import init_state_from_founders
from chewc.rollout import run_one

# --- Configuration: every tunable number lives here, named once. ---
SEED = 0
N_FOUNDERS = 20
N_CHR = 3
N_LOCI_PER_CHR = 50
N_GENERATIONS = 5
POPULATION_SIZE = 100   # cohort size written each generation
N_SELECT = 20           # NOTE(review): the original comment said "select all", but 20 < POPULATION_SIZE; confirm intent
N_TRAITS = 1
TRAIT_EFFECT_SD = 0.1   # scale of the demo random additive effects
H2 = 0.5                # heritability used for every trait in this demo
# Buffer must hold the founders plus every cohort written over the run.
# assumes run_one appends POPULATION_SIZE rows per generation after the founders — TODO confirm in chewc
MAX_POP_SIZE = N_FOUNDERS + N_GENERATIONS * POPULATION_SIZE

# Never reuse a JAX PRNG key: split it into independent subkeys, one per consumer.
key = jax.random.PRNGKey(SEED)
key_founders, key_effects, key_state = jax.random.split(key, 3)

# 1) Founders
founders, gen_map = msprime_pop(
    key_founders,
    n_ind=N_FOUNDERS,
    n_chr=N_CHR,
    n_loci_per_chr=N_LOCI_PER_CHR,
    max_pop_size=MAX_POP_SIZE,
)

# 2) SimParam (simple); ploidy is read off axis 2 of the founder genotypes.
sp = SimParam(gen_map=gen_map, ploidy=founders.geno.shape[2])

# 3) Trait effects of shape (n_chr, n_loci, N_TRAITS).
# Small random effects for the demo; in practice derive these from your trait module.
trait_effects = (
    jax.random.normal(key_effects, shape=(gen_map.shape[0], gen_map.shape[1], N_TRAITS))
    * TRAIT_EFFECT_SD
)
h2 = jnp.full((N_TRAITS,), H2)

# 4) Static config. It uses the same MAX_POP_SIZE as the founder builder, so the
# two buffers cannot silently disagree.
config = SimConfig(
    n_chr=gen_map.shape[0],
    ploidy=sp.ploidy,
    max_pop_size=MAX_POP_SIZE,
    n_loci_per_chr=gen_map.shape[1],
    n_generations=N_GENERATIONS,
    n_select=N_SELECT,
    population_size=POPULATION_SIZE,
)

# 5) State
state0 = init_state_from_founders(key_state, founders, sp, config)

# 6) Roll one replicate.
# NOTE(review): this call currently raises
#   IndexError: Array slice indices must have static start/stop/step ...
# i.e. somewhere inside chewc a buffer is sliced with traced bounds under jit.
# Presumably that is `x[write_pos:write_pos + population_size]` (or SimConfig ints being
# traced instead of static). Fix it in chewc with lax.dynamic_slice /
# lax.dynamic_update_slice and/or by marking config static. TODO confirm.
stateT, means, vars_ = run_one(state0, sp, config, h2, trait_effects)

print("means shape:", means.shape)  # expected (G, T)
print("vars  shape:", vars_.shape)  # expected (G, T)
print("final gen_idx:", int(stateT.gen_idx))
print("final write_pos:", int(stateT.write_pos))


IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<~int32[]>with<DynamicJaxprTrace>, Traced<~int32[]>with<DynamicJaxprTrace>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).