In [3]:
#| default_exp k_cross

In [4]:
#| export
# chewc/_src/genetics/cross_kernels.py
from __future__ import annotations
import jax, jax.numpy as jnp

def make_cross_plan(
    key: jax.Array,
    selected_idx: jnp.ndarray,    # (K,)
    population_size: int,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Build a clonal mating plan with `population_size` crosses.

    Each mother is drawn with replacement from `selected_idx`; the father of
    every cross is the same individual as its mother (clone mating, for now).

    Args:
        key: PRNG key consumed by the single `jax.random.randint` draw.
        selected_idx: (K,) array of selected parent indices.
        population_size: number of crosses to plan.

    Returns:
        (mothers, fathers), each of shape (population_size,), whose entries
        are taken from `selected_idx`. `fathers` is the same array as `mothers`.

    Raises:
        ValueError: if `selected_idx` is empty while `population_size` > 0,
            because there is no parent to sample from.
    """
    n_selected = selected_idx.shape[0]

    # An empty selection gives randint an empty range [0, 0) and would then
    # index into an empty array; fail loudly instead. The shape is read as a
    # plain Python int, so this check does not depend on array values.
    if n_selected == 0 and population_size > 0:
        raise ValueError(
            f"make_cross_plan: selected_idx is empty; cannot sample "
            f"{population_size} parents"
        )

    # Sample positions into `selected_idx` with replacement.
    # Using randint to avoid XLA scatter-permutation complexities.
    picks = jax.random.randint(key, (population_size,), 0, n_selected)
    mothers = selected_idx[picks]
    # Clone mating: each individual is crossed with itself.
    fathers = mothers
    return mothers, fathers

def clonal_progeny_from_parents(
    parent_geno: jnp.ndarray,   # (N, n_chr, ploidy, n_loci)
    parent_ibd: jnp.ndarray,    # same shape as parent_geno
    mothers: jnp.ndarray,       # (M,)
    fathers: jnp.ndarray,       # (M,) accepted for API consistency; not read
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Produce clonal offspring: child i is a copy of parent `mothers[i]`.

    Both the genotype array and the ibd array are gathered along their
    leading axis using `mothers`; `fathers` plays no role in clonal
    reproduction.

    Returns (geno_child, ibd_child), each (M, n_chr, ploidy, n_loci).
    """
    del fathers  # clonal progeny depend on the mother only
    offspring_geno, offspring_ibd = (
        source[mothers] for source in (parent_geno, parent_ibd)
    )
    return offspring_geno, offspring_ibd


In [5]:
#| hide
# Run nbdev's export step so `#| export` cells are written to the library modules.
import nbdev

nbdev.nbdev_export()