In [1]:
#| default_exp pipe

In [2]:
#|export
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_pipe.ipynb.

# %% auto 0
__all__ = ['phenotypic_selection_step', 'run_simulation_cycles']

# %% ../nbs/00_pipe.ipynb 1
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
from typing import Tuple

from chewc.structs import BreedingState, Population, Trait, GeneticMap
from chewc.pheno import calculate_phenotypes
from chewc.select import select_top_k
from chewc.cross import random_mating, cross_pair

# %% ../nbs/00_pipe.ipynb 2
@partial(
    jax.jit,
    static_argnames=("n_select", "n_offspring", "max_crossovers", "selection_trait_index"),
)
def phenotypic_selection_step(
    carry: BreedingState,
    _,  # Per-iteration `xs` slice from lax.scan (None when xs=None); unused.
    trait: Trait,
    genetic_map: GeneticMap,
    heritabilities: jnp.ndarray,
    n_select: int,
    n_offspring: int,
    max_crossovers: int,
    selection_trait_index: int = 0,
) -> Tuple[BreedingState, jnp.ndarray]:
    """
    Executes one full cycle of phenotypic selection and breeding.

    Shaped as a `lax.scan` body: `(carry, x) -> (next_carry, y)`.

    This function performs the following steps:
    1. Phenotyping: calculates phenotypes and True Breeding Values (TBVs)
       for the current population.
    2. Selection: keeps the top `n_select` individuals ranked by the
       phenotype in column `selection_trait_index` (trait 0 by default).
    3. Mating: draws a random mating plan of `n_offspring` parent pairs.
    4. Crossing: simulates meiosis and crossover for every pair to create
       the next generation.
    5. Update: builds the offspring population (new IDs, parent IDs,
       generation number) and advances the simulation state.

    Args:
        carry (BreedingState): Current simulation state.
        _ (Any): Per-iteration `xs` slice supplied by `lax.scan`; ignored.
        trait (Trait): Trait architecture.
        genetic_map (GeneticMap): Genetic map for recombination.
        heritabilities (jnp.ndarray): Heritability values for phenotypes,
            forwarded to `calculate_phenotypes`.
        n_select (int): Number of parents to select. Static under jit.
        n_offspring (int): Number of offspring to generate. Static under jit.
        max_crossovers (int): Maximum crossovers per chromosome. Static under jit.
        selection_trait_index (int): Trait column used both for ranking
            parents and for the returned metrics. Defaults to 0, which
            reproduces the previous hard-coded behaviour. Static under jit.

    Returns:
        next_state (BreedingState): The updated simulation state: offspring
            population, fresh PRNG key, generation + 1, and `next_id` advanced
            by `n_offspring`.
        metrics (jnp.ndarray): Array [mean_tbv, mean_phenotype] for trait
            `selection_trait_index`. It is measured on the population that
            entered this step, before selection.

    Raises:
        ValueError: if `selection_trait_index` is outside the phenotype
            matrix's trait axis. JAX would otherwise clamp the index silently.
    """
    key, pheno_key, mate_key, cross_key = jax.random.split(carry.key, 4)
    current_pop = carry.population

    # 1. Evaluate the population (Phenotype)
    phenotypes, tbvs = calculate_phenotypes(
        key=pheno_key, population=current_pop, trait=trait, heritability=heritabilities
    )

    # Out-of-bounds gathers are clamped in JAX rather than raising an error.
    # A bad index would therefore silently select on the wrong trait, so
    # validate it here. shape[1] is static at trace time.
    n_traits = phenotypes.shape[1]
    if not -n_traits <= selection_trait_index < n_traits:
        raise ValueError(
            f"selection_trait_index={selection_trait_index} is out of range "
            f"for {n_traits} trait(s)"
        )

    # 2. Select top parents (Selection)
    selected_parents = select_top_k(
        current_pop, phenotypes[:, selection_trait_index], k=n_select
    )

    # 3. Generate a random mating plan (Mating)
    # Each row of `pairings` is (mother_index, father_index) into selected_parents.
    pairings = random_mating(mate_key, n_parents=n_select, n_crosses=n_offspring)
    mother_idx = pairings[:, 0]
    father_idx = pairings[:, 1]

    # Gather parent genotypes/IBD; axis 0 indexes individuals.
    mothers_geno = selected_parents.geno[mother_idx]
    fathers_geno = selected_parents.geno[father_idx]
    mothers_ibd = selected_parents.ibd[mother_idx]
    fathers_ibd = selected_parents.ibd[father_idx]

    # 4. Create the next generation (Crossing)
    # Map over offspring (keys and parent data); the genetic map is shared.
    vmapped_cross = jax.vmap(
        partial(cross_pair, max_crossovers=max_crossovers),
        in_axes=(0, 0, 0, 0, 0, None),
    )
    offspring_keys = jax.random.split(cross_key, n_offspring)
    offspring_geno, offspring_ibd = vmapped_cross(
        offspring_keys, mothers_geno, fathers_geno, mothers_ibd, fathers_ibd, genetic_map
    )

    # 5. Form the new population and update the state
    # meta columns, as built here: [id, mother_id, father_id, generation].
    new_generation = carry.generation + 1
    new_ids = jnp.arange(n_offspring, dtype=jnp.int32) + carry.next_id
    new_meta = jnp.stack(
        [
            new_ids,
            selected_parents.meta[mother_idx, 0],  # Mother IDs
            selected_parents.meta[father_idx, 0],  # Father IDs
            jnp.full((n_offspring,), new_generation, dtype=jnp.int32),
        ],
        axis=-1,
    )
    new_population = Population(geno=offspring_geno, ibd=offspring_ibd, meta=new_meta)

    next_state = BreedingState(
        population=new_population,
        key=key,
        generation=new_generation,
        next_id=carry.next_id + n_offspring,
    )

    # Track metrics for the selected trait: mean TBV and mean phenotype.
    metrics = jnp.array(
        [
            jnp.mean(tbvs[:, selection_trait_index]),
            jnp.mean(phenotypes[:, selection_trait_index]),
        ]
    )

    return next_state, metrics

# %% ../nbs/00_pipe.ipynb 3
def run_simulation_cycles(
    initial_state: BreedingState,
    trait: Trait,
    genetic_map: GeneticMap,
    heritabilities: jnp.ndarray,
    n_cycles: int,
    n_select: int,
    n_offspring: int,
    max_crossovers: int
) -> Tuple[BreedingState, jnp.ndarray]:
    """
    Run `n_cycles` rounds of phenotypic selection inside one `lax.scan`.

    Every round calls `phenotypic_selection_step` with the same trait,
    genetic map, heritabilities and breeding sizes. Only the `BreedingState`
    is threaded from one round to the next.

    Args:
        initial_state (BreedingState): State the first cycle starts from.
        trait (Trait): Trait architecture.
        genetic_map (GeneticMap): Genetic map used for recombination.
        heritabilities (jnp.ndarray): Heritability values for phenotyping.
        n_cycles (int): How many selection cycles to run (scan length).
        n_select (int): Parents kept per cycle.
        n_offspring (int): Offspring produced per cycle.
        max_crossovers (int): Maximum crossovers per chromosome.

    Returns:
        final_state (BreedingState): State after the last cycle. Its
            population has not itself been phenotyped.
        history (jnp.ndarray): Shape (n_cycles, 2). Row i holds
            [mean_tbv, mean_phenotype] of the first trait for the population
            evaluated at the start of cycle i, before selection.
    """
    # Arguments that stay fixed across all cycles.
    fixed_kwargs = dict(
        trait=trait,
        genetic_map=genetic_map,
        heritabilities=heritabilities,
        n_select=n_select,
        n_offspring=n_offspring,
        max_crossovers=max_crossovers,
    )

    def one_cycle(state, unused_x):
        # lax.scan body: (carry, x) -> (new_carry, per-cycle output).
        return phenotypic_selection_step(state, unused_x, **fixed_kwargs)

    # No per-cycle inputs (xs=None), so the length must be given explicitly.
    # lax.scan returns (final_carry, stacked_outputs).
    return lax.scan(one_cycle, initial_state, xs=None, length=n_cycles)



In [3]:
#| hide
# Export the project's `#|export` notebook cells into the library's Python modules.
import nbdev; nbdev.nbdev_export()