In [None]:
# | hide
# from chewc.core import *
# from chewc.pop import *
# from chewc.sp import *
# from chewc.trait import *

# chewc

> JAX breeding sim

## Developer Guide

If you are new to using `nbdev`, here are some useful pointers to get you started.

### Install chewc in Development mode

```sh
# make sure chewc package is installed in development mode
$ pip install -e .

# make changes under nbs/ directory
# ...

# compile to have changes apply to chewc
$ nbdev_prepare
```

## Usage

### Installation

Install latest from the GitHub [repository][repo]:

```sh
$ pip install git+https://github.com/cjGO/chewc.git
```


[repo]: https://github.com/cjGO/chewc
[docs]: https://cjGO.github.io/chewc/
[pypi]: https://pypi.org/project/chewc/
[conda]: https://anaconda.org/cjGO/chewc

### Documentation

Documentation is hosted on this GitHub [repository][repo]'s [pages][docs].

[repo]: https://github.com/cjGO/chewc
[docs]: https://cjGO.github.io/chewc/

## How to use

In [None]:
import jax
import jax.numpy as jnp

# Import the necessary classes and functions from your library
from chewc.sp import SimParam
from chewc.population import Population, quick_haplo
from chewc.trait import TraitCollection, add_trait_a
from chewc.pheno import set_pheno
from chewc.cross import make_cross
from chewc.pipe import update_pop_values # <-- Import from the new module

# --- 1. JAX Setup ---
# Root PRNG key. Every stochastic step below draws a fresh subkey from it via
# jax.random.split, so the statement order determines the random stream.
key = jax.random.PRNGKey(42)

# --- 2. Define Genome Blueprint ---
# 3 chromosomes, 100 loci per chromosome, ploidy 2.
n_chr, n_loci_per_chr, ploidy = 3, 100, 2
# One row per chromosome: n_loci_per_chr positions evenly spaced from 0 to 1.
# NOTE(review): map units are whatever SimParam expects (presumably Morgans) - confirm.
gen_map = jnp.array([jnp.linspace(0, 1, n_loci_per_chr) for _ in range(n_chr)])
# Same centromere position (0.5) for every chromosome.
centromeres = jnp.full(n_chr, 0.5)

# --- 3. Instantiate SimParam ---
SP = SimParam(gen_map=gen_map, centromere=centromeres, ploidy=ploidy)

# --- 4. Create Founder Population ---
key, pop_key = jax.random.split(key)
founder_pop = quick_haplo(key=pop_key, sim_param=SP, n_ind=100, inbred=False)
# Rebind SP to the result of replace() so it carries the founder population.
SP = SP.replace(founderPop=founder_pop)

# --- 5. Add Single Additive Trait ---
# n_qtl_per_chr matches n_loci_per_chr above, so presumably every locus is a QTL.
# mean/var are length-1 arrays (one entry for the single trait); their exact
# interpretation is defined by add_trait_a - TODO confirm against chewc.trait.
key, trait_key = jax.random.split(key)
SP_with_trait = add_trait_a(
    key=trait_key,
    sim_param=SP,
    n_qtl_per_chr=100,
    mean=jnp.array([10.0]),
    var=jnp.array([1.5])
)

# --- 6. Set Initial Phenotypes ---
# h2 is a length-1 array (one entry for the single trait); presumably the
# heritability used by set_pheno to scale environmental noise - confirm.
key, pheno_key = jax.random.split(key)
h2 = jnp.array([0.9])
founder_pop_with_pheno = set_pheno(
    key=pheno_key,
    pop=founder_pop,
    traits=SP_with_trait.traits,
    ploidy=SP_with_trait.ploidy,
    h2=h2
)

# --- 8. Burn-in Selection ---
# Each generation: rank individuals by the first trait's phenotype, keep the
# best males and females (truncation selection within sex), mate the selected
# parents at random with replacement, then recompute values for the progeny.
n_generations = 20
print(f"\n--- Starting Burn-in Phenotypic Selection ({n_generations} Generations) ---")

pop_burn_in = founder_pop_with_pheno
sp_burn_in = SP_with_trait

# Selection parameters
n_males_select = 10
n_females_select = 25
n_progeny = 100

for gen in range(n_generations):
    # One independent subkey per stochastic step. A key is never both consumed
    # by a sampler and split again, and the loop-carried `key` is only ever
    # replaced by the first output of this split.
    key, mother_key, father_key, cross_key, update_key = jax.random.split(key, 5)

    # Selection
    # Rank all individuals from highest to lowest phenotype, then filter that
    # ranked list by sex. Boolean-mask indexing preserves the rank order, so the
    # leading entries really are the best individuals of each sex.
    # (jnp.intersect1d must not be used here: it returns its result sorted by
    # index value, which throws the phenotype ranking away.)
    phenotypes = pop_burn_in.pheno[:, 0]
    ranked_indices = jnp.argsort(phenotypes)[::-1]
    ranked_sex = pop_burn_in.sex[ranked_indices]  # sex coding as in the original: 0 = male, 1 = female
    top_males = ranked_indices[ranked_sex == 0][:n_males_select]
    top_females = ranked_indices[ranked_sex == 1][:n_females_select]

    # Mating
    # Sample parents with replacement; each row of cross_plan is (mother, father).
    mother_iids = jax.random.choice(mother_key, top_females, shape=(n_progeny,))
    father_iids = jax.random.choice(father_key, top_males, shape=(n_progeny,))
    cross_plan = jnp.stack([mother_iids, father_iids], axis=1)

    # Create Next Generation
    progeny_pop = make_cross(cross_key, pop_burn_in, cross_plan, sp_burn_in)

    # Update Values for the New Generation using the pipeline function
    pop_burn_in = update_pop_values(update_key, progeny_pop, sp_burn_in, h2=h2)

    # Track Progress
    mean_pheno = jnp.mean(pop_burn_in.pheno)
    print(f"Generation {gen + 1:2d}/{n_generations} | Mean Phenotype: {mean_pheno:.4f}")

print("\n--- Burn-in Complete ---")
print(f"Final population state after {n_generations} generations of selection:")
print(pop_burn_in)



--- Starting Burn-in Phenotypic Selection (20 Generations) ---
Generation  1/20 | Mean Phenotype: 9.8385
Generation  2/20 | Mean Phenotype: 9.9082
Generation  3/20 | Mean Phenotype: 9.7989
Generation  4/20 | Mean Phenotype: 9.9779
Generation  5/20 | Mean Phenotype: 9.6955
Generation  6/20 | Mean Phenotype: 9.5632
Generation  7/20 | Mean Phenotype: 9.6790
Generation  8/20 | Mean Phenotype: 9.7858
Generation  9/20 | Mean Phenotype: 9.9831
Generation 10/20 | Mean Phenotype: 10.0481
Generation 11/20 | Mean Phenotype: 9.9873
Generation 12/20 | Mean Phenotype: 9.9383
Generation 13/20 | Mean Phenotype: 10.2945
Generation 14/20 | Mean Phenotype: 10.4920
Generation 15/20 | Mean Phenotype: 10.5587
Generation 16/20 | Mean Phenotype: 10.7865
Generation 17/20 | Mean Phenotype: 10.6711
Generation 18/20 | Mean Phenotype: 10.5779
Generation 19/20 | Mean Phenotype: 10.9127
Generation 20/20 | Mean Phenotype: 11.0136

--- Burn-in Complete ---
Final population state after 20 generations of selection:
Pop