In [None]:
# | hide
# from chewc.core import *
# from chewc.pop import *
# from chewc.sp import *
# from chewc.trait import *

# chewc

> JAX breeding sim

## Developer Guide

If you are new to using `nbdev`, here are some useful pointers to get you started.

### Install chewc in Development mode

```sh
# make sure chewc package is installed in development mode
$ pip install -e .

# make changes under nbs/ directory
# ...

# compile to have changes apply to chewc
$ nbdev_prepare
```

## Usage

### Installation

Install the latest version from the GitHub [repository][repo]:

```sh
$ pip install git+https://github.com/cjGO/chewc.git
```


[repo]: https://github.com/cjGO/chewc
[docs]: https://cjGO.github.io/chewc/
[pypi]: https://pypi.org/project/chewc/
[conda]: https://anaconda.org/cjGO/chewc

### Documentation

Documentation is hosted on this GitHub [repository][repo]'s [pages][docs].

[repo]: https://github.com/cjGO/chewc
[docs]: https://cjGO.github.io/chewc/

## How to use

In [None]:
import jax
import jax.numpy as jnp
import time

# --- 1. JAX Setup ---
# Root PRNG key for the example; each random step below splits a subkey off it.
key = jax.random.PRNGKey(seed=42)

# --- 2. Imports from the 'chewc' library ---
from chewc.sp import SimParam
from chewc.population import Population, msprime_pop # Use msprime_pop
from chewc.trait import add_trait_a
from chewc.burnin import run_burnin
# Import the generation runner
from chewc.pipe import run_generation

# --- 3. Define Genome and Founder Population Blueprint ---
# Genome layout and founder count handed to the msprime-based simulation.
n_chr = 3
n_loci_per_chr = 100
ploidy = 2
n_founder_ind = 100

# --- 4. Define Trait Architecture ---
# Target mean/variance of the trait and the heritability used during burn-in.
trait_mean, trait_var = 0.0, 1.0
trait_h2 = 0.05

# --- 5. Define Burn-in Parameters ---
# Parents kept and progeny produced per generation, and how many generations.
n_parents_select, n_progeny = 10, 100
burn_in_generations = 20

# --- 6. Create Founder Population using msprime_pop ---
# Reserve a dedicated subkey for the founder simulation.
key, pop_key = jax.random.split(key)

# A throwaway SimParam whose only job is to carry structural information
# (chromosome/locus counts, ploidy) into msprime_pop; the arrays are placeholders.
temp_sp = SimParam(
    ploidy=ploidy,
    centromere=jnp.full(shape=n_chr, fill_value=0.5),  # dummy centromeres
    gen_map=jnp.empty(shape=(n_chr, n_loci_per_chr)),  # shape placeholder only
)

# msprime_pop builds the founder population together with the real genetic map.
founder_pop = msprime_pop(
    sim_param=temp_sp,
    key=pop_key,
    n_loci_per_chr=n_loci_per_chr,
    n_ind=n_founder_ind,
)

# --- 7. Instantiate the Definitive SimParam ---
# Rebuild SimParam from what the founder simulation actually produced: the
# msprime genetic map, and sizes read off the founder genotype array instead
# of the placeholder values used for the temporary object.
sp = SimParam(
    founderPop=founder_pop,
    ploidy=founder_pop.geno.shape[2],
    # One 0.5 entry per element of genotype axis 1 (real shape, not n_chr).
    centromere=jnp.full(shape=founder_pop.geno.shape[1], fill_value=0.5),
    gen_map=founder_pop.miscPop['genetic_map_cm'],
)


# --- 8. Add Trait to SimParam ---
# Register one additive trait on `sp` using its own PRNG subkey.
key, trait_key = jax.random.split(key)
sp = add_trait_a(
    key=trait_key,
    sim_param=sp,
    # Every locus is a QTL in this example. Tie the count to the genome
    # configuration rather than repeating the literal 100, so the call stays
    # correct if n_loci_per_chr is changed.
    n_qtl_per_chr=n_loci_per_chr,
    mean=jnp.array([trait_mean]),
    var=jnp.array([trait_var]),
)

# --- 9. Run the entire Burn-in Phase with a single function call ---
key, burnin_key = jax.random.split(key)
h2 = jnp.array([trait_h2])

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed intervals; time.time() can jump if the system clock is adjusted.
start_time = time.perf_counter()
final_pop = run_burnin(
    key=burnin_key,
    sp=sp,
    n_generations=burn_in_generations,
    n_parents=n_parents_select,
    n_progeny=n_progeny,
    h2=h2,
    verbose=True,
)
# JAX dispatches work asynchronously, so wait for the result to materialise
# before stopping the timer; otherwise only dispatch time would be measured.
final_pop.geno.block_until_ready()
end_time = time.perf_counter()

# --- 10. Report Results ---
# Elapsed burn-in time, overall and averaged per generation (seconds).
total_time = end_time - start_time
avg_time_per_gen = total_time / burn_in_generations

print("-" * 50)
print(
    f"Total simulation time for burn-in: {total_time:.4f} seconds.",
    f"Average time per generation (including compilation): {avg_time_per_gen * 1000:.4f} ms",
    sep="\n",
)
# Header line followed by the population's own printed representation.
print(
    f"\nFinal population state after {burn_in_generations} generations:",
    final_pop,
    sep="\n",
)

# --- 11. Create a new generation from the final population ---
print("\n" + "-" * 50)
print("Running one more generation with all parents...")

# Every individual in `final_pop` acts as a parent for this extra generation.
n_new_progeny = 2000
n_all_parents = final_pop.nInd

# Fresh subkey for this generation's randomness.
key, next_gen_key = jax.random.split(key)

# One call to the generation runner; the selection flags are irrelevant here
# because the whole population is retained as parents.
expanded_pop = run_generation(
    key=next_gen_key,
    pop=final_pop,
    h2=h2,
    n_parents=n_all_parents,
    n_crosses=n_new_progeny,
    use_pheno_selection=True,
    select_top_parents=True,
    traits=sp.traits,
    gen_map=sp.gen_map,
    ploidy=sp.ploidy,
    recomb_param_v=sp.recomb_params[0],
)
# Wait for the computation to finish before reporting on it.
expanded_pop.geno.block_until_ready()

print(
    f"\nCreated new expanded population of {n_new_progeny} individuals:",
    expanded_pop,
    sep="\n",
)

--- Starting Accelerated Burn-in (20 Generations) ---
Generation  1/20 | Mean Phenotype: 0.1321
