In [5]:
#| default_exp rollout

In [8]:
#| export
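# Exported by nbdev to the package's `rollout` module (see `#| default_exp rollout`); edit this notebook, not the generated .py file.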
from __future__ import annotations
from functools import partial
import jax, jax.numpy as jnp
from jax import lax

@partial(jax.jit, static_argnames=("config",))
def run_one(
    state0,
    sp,
    config,
    h2,                # (T,)
    trait_effects,     # (n_chr, n_loci, T)
):
    """Advance a single replicate for ``config.n_generations`` generations.

    ``generation_step`` is applied repeatedly via ``lax.scan``. ``scan`` stacks
    the per-generation metrics it returns along a new leading generation axis.

    Args:
        state0: initial simulation state (whatever pytree ``generation_step`` accepts).
        sp: forwarded unchanged to ``generation_step``.
        config: static, hashable configuration; ``config.n_generations`` sets the
            number of scan steps.
        h2: array of shape ``(T,)``; forwarded to ``generation_step``
            (presumably per-trait heritabilities — confirm).
        trait_effects: array of shape ``(n_chr, n_loci, T)``; forwarded to
            ``generation_step``.

    Returns:
        ``(stateT, means, vars_)``: the state after the last generation plus the
        per-generation metric arrays. Assuming ``generation_step`` returns
        ``(state, (mean, var))`` with ``mean``/``var`` of shape ``(T,)``, both
        ``means`` and ``vars_`` have shape ``(G, T)`` with ``G = config.n_generations``.
    """
    # Deferred import, presumably to avoid a circular import at module load — confirm.
    from chewc.k_fused import generation_step

    def body(st, _):
        return generation_step(st, sp, config, h2, trait_effects)

    stateT, metrics = lax.scan(body, state0, None, length=config.n_generations)
    # `scan` stacks every leaf of the per-step output along axis 0, so `metrics`
    # is already the pair (means, vars), each (G, T). Do NOT iterate over it and
    # index m[0]/m[1]: that selects generations 0/1 and yields (2, T) arrays.
    means, vars_ = metrics
    return stateT, means, vars_

@partial(jax.jit, static_argnames=("config", "n_reps"))
def run_replicates(
    master_key,
    founder_pop,
    sp,
    config,
    h2,
    trait_effects,
    n_reps: int,
):
    """Run ``n_reps`` independent replicates in parallel with ``jax.vmap``.

    Replicate ``r`` gets its own PRNG key ``jax.random.fold_in(master_key, r)``.
    It builds its starting state with ``init_state_from_founders`` and is then
    advanced by ``run_one``.

    Args:
        master_key: JAX PRNG key from which the per-replicate keys are derived.
        founder_pop: passed to ``init_state_from_founders``.
        sp: forwarded unchanged to ``init_state_from_founders`` and ``run_one``.
        config: static, hashable configuration shared by all replicates.
        h2: array of shape ``(T,)``, forwarded to ``run_one``.
        trait_effects: array of shape ``(n_chr, n_loci, T)``, forwarded to ``run_one``.
        n_reps: number of replicates (static — a new value triggers a retrace).

    Returns:
        ``(stateTs, all_means, all_vars)``: the three outputs of ``run_one``, each
        carrying an extra leading replicate axis of length ``n_reps``.
    """
    # Deferred import, presumably to avoid a circular import at module load — confirm.
    from chewc.state import init_state_from_founders

    def simulate_replicate(rep_idx):
        rep_key = jax.random.fold_in(master_key, rep_idx)
        initial_state = init_state_from_founders(rep_key, founder_pop, sp, config)
        return run_one(initial_state, sp, config, h2, trait_effects)

    replicate_ids = jnp.arange(n_reps)
    # run_one returns a 3-tuple, so the batched result is already
    # (stateTs, all_means, all_vars).
    return jax.vmap(simulate_replicate)(replicate_ids)


In [9]:
#| hide
# Regenerate library modules from the notebooks' `#| export` cells (this notebook exports to `rollout`, per `#| default_exp`).
import nbdev; nbdev.nbdev_export()