In [None]:
# | hide
# from chewc.core import *
# from chewc.pop import *
# from chewc.sp import *
# from chewc.trait import *

# chewc

> JAX breeding sim

## Developer Guide

If you are new to using `nbdev`, here are some useful pointers to get you started.

### Install chewc in Development mode

```sh
# make sure chewc package is installed in development mode
$ pip install -e .

# make changes under nbs/ directory
# ...

# compile to have changes apply to chewc
$ nbdev_prepare
```

## Usage

### Installation

Install the latest version from the GitHub [repository][repo]:

```sh
$ pip install git+https://github.com/cjGO/chewc.git
```


[repo]: https://github.com/cjGO/chewc
[docs]: https://cjGO.github.io/chewc/
[pypi]: https://pypi.org/project/chewc/
[conda]: https://anaconda.org/cjGO/chewc

### Documentation

Documentation is hosted on this GitHub [repository][repo]'s [pages][docs].

[repo]: https://github.com/cjGO/chewc
[docs]: https://cjGO.github.io/chewc/

## How to use

In [None]:
import jax
import jax.numpy as jnp
import time

# Import the necessary classes and the NEW JIT-compiled engine
from chewc.sp import SimParam
from chewc.population import Population, quick_haplo
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno
from chewc.pipe import run_generation # The new, all-in-one JIT engine
from functools import partial

# Example: build a founder population with one additive trait, then run a
# JIT-compiled burn-in of phenotypic truncation selection, reporting the
# mean phenotype per generation and the compile vs. steady-state timings.

# --- 1. JAX Setup ---
# One root PRNG key; every stochastic step below splits off its own sub-key,
# so the whole run is reproducible from this single seed.
key = jax.random.PRNGKey(42)

# --- 2. Define Genome Blueprint ---
# n_chr chromosomes with n_loci_per_chr loci each. gen_map has shape
# (n_chr, n_loci_per_chr): evenly spaced map positions from 0 to 1 on every
# chromosome, with each chromosome's centromere placed at 0.5.
n_chr, n_loci_per_chr, ploidy = 3, 100, 2
gen_map = jnp.array([jnp.linspace(0, 1, n_loci_per_chr) for _ in range(n_chr)])
centromeres = jnp.full(n_chr, 0.5)

# --- 3. Instantiate SimParam ---
# `sp` holds the simulation-wide settings. It is rebound to the value returned
# by each setup step below (`replace`, `add_trait_a`) and is not modified
# inside the generation loop.
sp = SimParam(gen_map=gen_map, centromere=centromeres, ploidy=ploidy)

# --- 4. Create Founder Population ---
key, pop_key = jax.random.split(key)
founder_pop = quick_haplo(key=pop_key, sim_param=sp, n_ind=100, inbred=False)
sp = sp.replace(founderPop=founder_pop)

# --- 5. Add Single Additive Trait ---
trait_mean = 0.0
trait_var = 1.0
trait_h2 = 0.1

key, trait_key = jax.random.split(key)
# Every later step reads traits from this returned SimParam.
# n_qtl_per_chr equals n_loci_per_chr, so presumably every locus is a QTL
# (confirm against add_trait_a).
sp = add_trait_a(
    key=trait_key,
    sim_param=sp,
    n_qtl_per_chr=100,
    mean=jnp.array([trait_mean]),
    var=jnp.array([trait_var])
)

# --- 6. Set Initial Phenotypes ---
key, pheno_key = jax.random.split(key)
h2 = jnp.array([trait_h2])
current_pop = set_pheno(
    key=pheno_key,
    pop=founder_pop,
    traits=sp.traits,
    ploidy=sp.ploidy,
    h2=h2
)

# --- 7. Simulation Parameters ---
# NOTE(review): the founder population has 100 individuals, so selecting 100
# parents in generation 1 keeps all of them (no selection pressure until the
# population presumably grows to n_progeny individuals).
n_parents_select = 100   # Total number of parents to select
n_progeny = 1000
burn_in_generations = 20

print(f"--- Starting Accelerated Burn-in ({burn_in_generations} Generations) ---")
print("Compiling the JIT function for the first generation (this may take a moment)...")

# --- 8. Accelerated Burn-in Loop ---
# Each iteration makes one call to the JIT-compiled `run_generation`. The first
# call triggers compilation; later calls reuse the compiled executable.
# perf_counter is monotonic and high-resolution, unlike time.time(), which
# makes it the right clock for benchmarking.
start_time = time.perf_counter()
compilation_time = 0.0  # set after the first generation; bound up front for safety
for gen in range(burn_in_generations):
    # Fresh sub-key per generation; run_generation derives its own sub-keys.
    key, generation_key = jax.random.split(key)

    current_pop = run_generation(
        key=generation_key,
        pop=current_pop,
        h2=h2,
        n_parents=n_parents_select,
        n_crosses=n_progeny,
        # Pass static arguments
        use_pheno_selection=True,
        select_top_parents=True,
        ploidy=sp.ploidy,
        # Pass static SimParam components
        gen_map=sp.gen_map,
        recomb_param_v=sp.recomb_params[0],
        traits=sp.traits
    )

    # JAX dispatches work asynchronously. Block on a value that depends on the
    # new phenotypes so the timings measure real computation, not just dispatch.
    mean_pheno = jnp.mean(current_pop.pheno).block_until_ready()

    if gen == 0:
        # This figure covers compilation plus the first generation's execution.
        compilation_time = time.perf_counter() - start_time
        print(f"JIT compilation finished in {compilation_time:.2f} seconds.")
        print("-" * 50)

    print(f"Generation {gen + 1:2d}/{burn_in_generations} | Mean Phenotype: {mean_pheno:.4f}")

end_time = time.perf_counter()
total_time = end_time - start_time
# Steady-state average excludes the first (compiling) generation when possible.
avg_time_per_gen = (total_time - compilation_time) / (burn_in_generations - 1) if burn_in_generations > 1 else total_time

print("-" * 50)
print("\n--- Burn-in Complete ---")
print(f"Total simulation time: {total_time:.4f} seconds.")
print(f"Average time per generation (after compilation): {avg_time_per_gen * 1000:.4f} ms")
print(f"\nFinal population state after {burn_in_generations} generations:")
print(current_pop)



--- Starting Accelerated Burn-in (20 Generations) ---
Compiling the JIT function for the first generation (this may take a moment)...
JIT compilation finished in 4.26 seconds.
--------------------------------------------------
Generation  1/20 | Mean Phenotype: -0.0178
Generation  2/20 | Mean Phenotype: 0.5732
Generation  3/20 | Mean Phenotype: 1.1044
Generation  4/20 | Mean Phenotype: 1.5740
Generation  5/20 | Mean Phenotype: 2.2945
Generation  6/20 | Mean Phenotype: 3.1033
Generation  7/20 | Mean Phenotype: 3.5431
Generation  8/20 | Mean Phenotype: 4.1015
Generation  9/20 | Mean Phenotype: 4.5527
Generation 10/20 | Mean Phenotype: 5.1959
Generation 11/20 | Mean Phenotype: 5.5863
Generation 12/20 | Mean Phenotype: 5.9911
Generation 13/20 | Mean Phenotype: 6.3606
Generation 14/20 | Mean Phenotype: 6.7998
Generation 15/20 | Mean Phenotype: 7.2815
Generation 16/20 | Mean Phenotype: 7.6483
Generation 17/20 | Mean Phenotype: 8.0602
Generation 18/20 | Mean Phenotype: 8.2747
Generation 19/20