In [1]:
#| default_exp k_fused

In [2]:
#| export
# Exported by nbdev (see `#| default_exp k_fused`) to the chewc package's k_fused module.
from __future__ import annotations
from functools import partial
import jax, jax.numpy as jnp
from jax import random
from chewc.k_pheno import set_pheno_kernel
from chewc.k_select import select_topk, score_from_bv_pheno
from chewc.k_cross import make_cross_plan, clonal_progeny_from_parents

@partial(jax.jit, static_argnames=("config",))
def generation_step(state, sp, config, h2, trait_effects):
    """
    Run one fused breeding generation and return the advanced state.

    Steps:
      1) phenotype the currently active individuals
      2) compute per-trait phenotype mean/variance over those actives
      3) select parents by top-k on the phenotype of trait 0
      4) make progeny (clonal for now) from a cross plan
      5) write the new cohort into the population buffer at ``state.write_pos``
      6) advance ``write_pos`` (modulo ``max_pop_size``), ``gen_idx`` and ``next_id``

    Args:
        state: population state; must support ``.replace(...)`` and expose
            ``key, geno, ibd, bv, pheno, id, mother, father, gen, is_active,
            write_pos, gen_idx, next_id``.
        sp: simulation parameters; only ``sp.ploidy`` is read here.
        config: static (hashable) configuration; ``n_select``,
            ``population_size`` and ``max_pop_size`` are read.
        h2: forwarded to ``set_pheno_kernel`` (presumably heritability — confirm).
        trait_effects: forwarded to ``set_pheno_kernel`` (presumably per-trait
            marker effects — confirm).

    Returns:
        ``(new_state, (mean_ph, var_ph))``. The metrics are nan-mean / nan-var
        over axis 0 of the phenotypes, with rows that were inactive on entry
        masked to NaN.

    Notes:
        The cohort is written through explicit row indices
        ``(write_pos + arange(population_size)) % max_pop_size`` rather than a
        ``start:end`` slice. That way ``write_pos`` / ``next_id`` may be traced
        arrays under ``jit``, where slicing with traced bounds raises. A cohort
        that runs past the end of the buffer also wraps around, matching the
        modular ``write_pos`` update, instead of failing with a shape mismatch.
        Assumes ``population_size <= max_pop_size`` (otherwise rows collide).
    """
    key, k_pheno, k_sel = random.split(state.key, 3)
    n_new = config.population_size

    # 1) phenotype current actives
    bv, pheno = set_pheno_kernel(
        k_pheno, state.geno, state.is_active, trait_effects, sp.ploidy, h2
    )

    # 2) metrics on current actives (computed before the actives are replaced)
    active = state.is_active
    # Inactive rows become NaN so the nan-reductions ignore them.
    ph_act = jnp.where(active[:, None], pheno, jnp.nan)
    mean_ph = jnp.nanmean(ph_act, axis=0)
    var_ph = jnp.nanvar(ph_act, axis=0)
    # NOTE(review): `bv` / `pheno` for the current actives are used only for
    # metrics and selection; they are never stored into state.bv / state.pheno.
    # Confirm that dropping them is intended.

    # 3) select parents via top-k on the phenotype of trait 0
    scores = score_from_bv_pheno(
        bv=bv, pheno=pheno, is_active=active, use="pheno", trait_idx=0
    )
    parent_idx = select_topk(scores, config.n_select)

    # 4) make progeny (clonal for now)
    mothers, fathers = make_cross_plan(k_sel, parent_idx, n_new)
    prog_geno, prog_ibd = clonal_progeny_from_parents(
        state.geno, state.ibd, mothers, fathers
    )

    # 5) write the new cohort into the ring buffer. Index-array scatter works
    # with a traced write_pos and wraps at max_pop_size.
    rows = (state.write_pos + jnp.arange(n_new)) % config.max_pop_size

    geno2 = state.geno.at[rows].set(prog_geno)
    ibd2 = state.ibd.at[rows].set(prog_ibd)
    # reset pheno/bv for the fresh cohort (they'll be phenotyped next generation)
    bv2 = state.bv.at[rows].set(jnp.nan)
    ph2 = state.pheno.at[rows].set(jnp.nan)

    # pedigree / ids / generation; offset + arange(n) stays valid when next_id is traced
    new_ids = (state.next_id + jnp.arange(n_new, dtype=state.id.dtype)).astype(
        state.id.dtype
    )
    id2 = state.id.at[rows].set(new_ids)
    mother2 = state.mother.at[rows].set(state.id[mothers])
    father2 = state.father.at[rows].set(state.id[fathers])
    gen2 = state.gen.at[rows].set(state.gen_idx + 1)  # generation counter

    # activate only the freshly written cohort
    act2 = jnp.zeros_like(state.is_active, dtype=bool).at[rows].set(True)

    # 6) advance pointers
    new_state = state.replace(
        geno=geno2, ibd=ibd2, bv=bv2, pheno=ph2,
        id=id2, mother=mother2, father=father2, gen=gen2,
        is_active=act2,
        key=key,
        write_pos=(state.write_pos + n_new) % config.max_pop_size,
        gen_idx=state.gen_idx + 1,
        next_id=state.next_id + n_new,
    )
    metrics = (mean_ph, var_ph)
    return new_state, metrics


In [3]:
#| hide
# Regenerate the library module(s) from this notebook's `#| export` cells.
import nbdev

nbdev.nbdev_export()