In [None]:
# | hide
# from chewc.core import *
# from chewc.pop import *
# from chewc.sp import *
# from chewc.trait import *

# chewc

> JAX breeding sim

## Developer Guide

If you are new to using `nbdev`, here are some useful pointers to get you started.

### Install chewc in development mode

```sh
# make sure the chewc package is installed in development mode
$ pip install -e .

# make changes under the nbs/ directory
# ...

# compile so that your changes are applied to chewc
$ nbdev_prepare
```

## Usage

### Installation

Install the latest version from the GitHub [repository][repo]:

```sh
$ pip install git+https://github.com/cjGO/chewc.git
```


[repo]: https://github.com/cjGO/chewc
[docs]: https://cjGO.github.io/chewc/
[pypi]: https://pypi.org/project/chewc/
[conda]: https://anaconda.org/cjGO/chewc

### Documentation

Documentation is hosted on this GitHub [repository][repo]'s [pages][docs].

[repo]: https://github.com/cjGO/chewc
[docs]: https://cjGO.github.io/chewc/

## How to use

In [None]:
import jax
import jax.numpy as jnp
from typing import Callable, Union

# Import the necessary classes and functions from your library
from chewc.sp import SimParam
from chewc.population import Population, quick_haplo
from chewc.trait import TraitCollection, add_trait_a
from chewc.pheno import set_pheno
from chewc.cross import make_cross
from chewc.pipe import update_pop_values

# --- 🧬 New High-Level Pipeline Functions ---

def select_ind(
    pop: Population,
    n_ind: int,
    use: Union[str, Callable[[Population], jnp.ndarray]] = "pheno",
    select_top: bool = True
) -> Population:
    """Select the `n_ind` highest- (or lowest-) scoring individuals of `pop`.

    Args:
        pop: Population to select from.
        n_ind: Number of individuals to keep. Must be between 0 and the
            number of candidates (the length of the selection criterion).
        use: Either the name of a `Population` attribute holding the
            selection criterion (e.g. ``"pheno"``), or a callable that maps
            the population to one score per individual. If the criterion is
            2-D, only its first column (the first trait) is used; this rule
            applies to both the attribute and the callable form.
        select_top: If True keep the highest scores, otherwise the lowest.

    Returns:
        A Population in which every array leaf whose leading dimension equals
        the number of candidates is gathered at the selected rows. Rows are
        ordered from best to worst according to the criterion. All other
        leaves (non-arrays, 0-d arrays, arrays of a different leading length)
        are passed through unchanged.

    Raises:
        ValueError: If the named attribute is unset (None), the criterion is
            not one score per individual, or `n_ind` is out of range.
    """
    # Resolve the selection criterion.
    if isinstance(use, str):
        selection_values = getattr(pop, use)
        if selection_values is None:
            raise ValueError(
                f"Population attribute '{use}' is not set; cannot select on it."
            )
    else:
        selection_values = use(pop)

    selection_values = jnp.asarray(selection_values)
    if selection_values.ndim > 1:
        selection_values = selection_values[:, 0]  # Default to first trait
    if selection_values.ndim != 1:
        raise ValueError(
            "Selection criterion must provide one score per individual; "
            f"got an array of shape {selection_values.shape}."
        )

    n_candidates = selection_values.shape[0]
    # top_k needs a static k within range; fail early with a readable message
    # instead of an XLA shape error.
    if not 0 <= n_ind <= n_candidates:
        raise ValueError(
            f"n_ind must be between 0 and {n_candidates}, got {n_ind}."
        )

    # top_k always picks the largest values, so negate to select the bottom.
    if not select_top:
        selection_values = -selection_values

    _, indices = jax.lax.top_k(selection_values, k=n_ind)

    def _take_selected(leaf):
        # Gather only per-individual arrays (leading axis == candidate count).
        # Scalars, 0-d arrays and other-sized arrays cannot be indexed by
        # individual and are kept as-is.
        if (
            isinstance(leaf, jnp.ndarray)
            and leaf.ndim > 0
            and leaf.shape[0] == n_candidates
        ):
            return leaf[indices]
        return leaf

    # NOTE: non-array fields (e.g. misc dictionaries) must be static fields of
    # the flax dataclass for tree_map to leave them untouched.
    return jax.tree_util.tree_map(_take_selected, pop)

def select_and_cross(
    key: jax.random.PRNGKey,
    pop: Population,
    sp: SimParam,
    n_parents: int,
    n_crosses: int,
    use: Union[str, Callable[[Population], jnp.ndarray]] = "pheno"
) -> Population:
    """Select a parent pool and create one generation of random crosses.

    Args:
        key: JAX PRNG key. It is split into independent sub-keys for mother
            sampling, father sampling and `make_cross`.
        pop: Population to select parents from.
        sp: Simulation parameters forwarded to `make_cross`.
        n_parents: Size of the parent pool kept by `select_ind` (top scores).
        n_crosses: Number of crosses, i.e. rows in the cross plan.
        use: Selection criterion forwarded to `select_ind`: an attribute
            name (default ``"pheno"``) or a callable returning scores.

    Returns:
        The progeny Population returned by `make_cross`.
    """
    # Each random draw gets its own key. Reusing a key across draws would
    # correlate father choice with the randomness inside make_cross.
    key_mother, key_father, key_cross = jax.random.split(key, 3)

    # 1. Select the best individuals to form the parent pool (sexes ignored).
    parent_pool = select_ind(pop, n_parents, use=use)

    # 2. Random cross plan, sampled with replacement from the pool. Any parent
    #    may act as mother or father, so a mother == father pair (selfing)
    #    is possible.
    mother_iids = jax.random.choice(key_mother, parent_pool.iid, shape=(n_crosses,))
    father_iids = jax.random.choice(key_father, parent_pool.iid, shape=(n_crosses,))
    cross_plan = jnp.stack([mother_iids, father_iids], axis=1)

    # 3. Create progeny with the low-level crossing routine. The full `pop` is
    #    passed because the plan refers to parents by iid.
    return make_cross(key_cross, pop, cross_plan, sp)

# --- 1. JAX setup ---
key = jax.random.PRNGKey(42)

# --- 2. Genome blueprint ---
n_chr, n_loci_per_chr, ploidy = 3, 100, 2
# Identical, evenly spaced map positions in [0, 1] for every chromosome,
# giving shape (n_chr, n_loci_per_chr). Map units are whatever SimParam
# expects (presumably Morgans; confirm).
gen_map = jnp.tile(jnp.linspace(0, 1, n_loci_per_chr), (n_chr, 1))
centromeres = jnp.full(n_chr, 0.5)

SP = SimParam(gen_map=gen_map, centromere=centromeres, ploidy=ploidy)

# --- 3. Founder population ---
key, pop_key = jax.random.split(key)
founder_pop = quick_haplo(key=pop_key, sim_param=SP, n_ind=100, inbred=False)
SP = SP.replace(founderPop=founder_pop)

# --- 4. Single additive trait ---
# Float literals so the mean/var arrays below are floating point
# (int literals would produce int32 arrays).
trait_mean = 0.0
trait_var = 1.0
trait_h2 = 0.2

key, trait_key = jax.random.split(key)
SP_with_trait = add_trait_a(
    key=trait_key,
    sim_param=SP,
    n_qtl_per_chr=100,
    mean=jnp.array([trait_mean]),
    var=jnp.array([trait_var])
)

# --- 5. Initial phenotypes ---
key, pheno_key = jax.random.split(key)
h2 = jnp.array([trait_h2])
founder_pop_with_pheno = set_pheno(
    key=pheno_key,
    pop=founder_pop,
    traits=SP_with_trait.traits,
    ploidy=SP_with_trait.ploidy,
    h2=h2
)

# --- 6. Burn-in phenotypic selection ---
n_generations = 20
print(f"\n--- Starting Burn-in Phenotypic Selection ({n_generations} Generations) ---")

pop_burn_in = founder_pop_with_pheno
sp_burn_in = SP_with_trait

# Selection parameters
n_parents_select = 5  # Total number of parents to select
n_progeny = 100

for gen in range(n_generations):
    key, cross_key, update_key = jax.random.split(key, 3)

    # One high-level call handles selection and crossing for a full generation.
    progeny_pop = select_and_cross(
        key=cross_key,
        pop=pop_burn_in,
        sp=sp_burn_in,
        n_parents=n_parents_select,
        n_crosses=n_progeny,
        use="pheno"  # Select based on phenotype
    )

    # Update genetic and phenotypic values for the new generation.
    pop_burn_in = update_pop_values(update_key, progeny_pop, sp_burn_in, h2=h2)

    # Track progress. NOTE(review): with only a handful of parents per
    # generation, the mean may plateau once genetic variation is exhausted.
    mean_pheno = jnp.mean(pop_burn_in.pheno)
    print(f"Generation {gen + 1:2d}/{n_generations} | Mean Phenotype: {mean_pheno:.4f}")

print("\n--- Burn-in Complete ---")
print(f"Final population state after {n_generations} generations of selection:")
print(pop_burn_in)


--- Starting Burn-in Phenotypic Selection (20 Generations) ---
Generation  1/20 | Mean Phenotype: 1.1126
Generation  2/20 | Mean Phenotype: 2.0168
Generation  3/20 | Mean Phenotype: 2.6536
Generation  4/20 | Mean Phenotype: 3.5688
Generation  5/20 | Mean Phenotype: 4.5619
Generation  6/20 | Mean Phenotype: 5.0159
Generation  7/20 | Mean Phenotype: 5.5932
Generation  8/20 | Mean Phenotype: 6.0638
Generation  9/20 | Mean Phenotype: 6.1357
Generation 10/20 | Mean Phenotype: 6.2606
Generation 11/20 | Mean Phenotype: 6.3443
Generation 12/20 | Mean Phenotype: 6.4400
Generation 13/20 | Mean Phenotype: 6.5349
Generation 14/20 | Mean Phenotype: 6.6008
Generation 15/20 | Mean Phenotype: 6.6359
Generation 16/20 | Mean Phenotype: 6.6355
Generation 17/20 | Mean Phenotype: 6.6355
Generation 18/20 | Mean Phenotype: 6.6355
Generation 19/20 | Mean Phenotype: 6.6355
Generation 20/20 | Mean Phenotype: 6.6355

--- Burn-in Complete ---
Final population state after 20 generations of selection:
Population(n