In [None]:
# | hide
# from chewc.core import *
# from chewc.pop import *
# from chewc.sp import *
# from chewc.trait import *

# chewc

> A breeding-program simulator built on JAX

## Developer Guide

If you are new to using `nbdev`, here are some useful pointers to get you started.

### Install chewc in development mode

```sh
# make sure chewc package is installed in development mode
$ pip install -e .

# make changes under nbs/ directory
# ...

# compile to have changes apply to chewc
$ nbdev_prepare
```

## Usage

### Installation

Install the latest version from the GitHub [repository][repo]:

```sh
$ pip install git+https://github.com/cjGO/chewc.git
```


[repo]: https://github.com/cjGO/chewc
[docs]: https://cjGO.github.io/chewc/
[pypi]: https://pypi.org/project/chewc/
[conda]: https://anaconda.org/cjGO/chewc

### Documentation

Documentation is hosted on GitHub [Pages][docs] for this [repository][repo].

[repo]: https://github.com/cjGO/chewc
[docs]: https://cjGO.github.io/chewc/

## How to use

In [None]:
import jax
import jax.numpy as jnp

# Core chewc building blocks used in this walkthrough.
from chewc.sp import SimParam
from chewc.population import quick_haplo, Population
# Trait construction: `add_trait_a` returns an updated SimParam whose `.traits`
# attribute holds a single TraitCollection describing every trait.
from chewc.trait import add_trait_a, TraitCollection
# Phenotype simulation for a population given a TraitCollection.
from chewc.pheno import set_pheno

# --- 1. JAX Setup ---
# All randomness comes from one explicit PRNG key; each step below splits off
# its own subkey so the whole walkthrough is reproducible.
key = jax.random.PRNGKey(42)

# --- 2. Define the Genome's "Blueprint" ---
# 3 chromosomes x 100 loci, diploid. Every chromosome gets the same evenly
# spaced genetic map over [0, 1] (units assumed to be Morgans — TODO confirm).
n_chr = 3
n_loci_per_chr = 100
ploidy = 2
# Shape (n_chr, n_loci_per_chr): one identical row per chromosome.
gen_map = jnp.tile(jnp.linspace(0, 1, n_loci_per_chr), (n_chr, 1))
# One centromere position per chromosome (presumably on the gen_map scale).
centromeres = jnp.full(n_chr, 0.5)

# --- 3. Instantiate Initial Simulation Parameters ---
# A SimParam with the genome layout only; no founders or traits yet.
SP = SimParam(
    gen_map=gen_map,
    centromere=centromeres,
    ploidy=ploidy
)

print("--- Initial Simulation Parameters Created ---")
print(SP)
print("-" * 35)

# --- 4. Create the Founder Population ---
# Sample 50 non-inbred founders on the genome described by SP.
key, pop_key = jax.random.split(key)
n_founders = 50

founder_pop = quick_haplo(
    key=pop_key,
    sim_param=SP,
    n_ind=n_founders,
    inbred=False
)

print("\n--- Founder Population Created ---")
print(founder_pop)
print(f"Genotype array shape: {founder_pop.geno.shape}")
print("-" * 35)


# --- 5. Finalize Simulation Parameters ---
# `.replace` returns an updated SimParam that carries the founder population;
# rebind SP to it so later steps (e.g. trait creation) can use the founders.
SP = SP.replace(founderPop=founder_pop)

print("\n--- Simulation Parameters Finalized with Founder Pop ---")
print(SP)
print(f"Number of traits before: {SP.n_traits}")
print("-" * 35)


# --- 6. Add Two Correlated Additive Traits ---
# 100 QTL per chromosome, per-trait means and variances, and an additive
# genetic correlation of 0.8 between the two traits.
key, trait_key = jax.random.split(key)

n_qtl_per_chr = 100
trait_means = jnp.array([10.0, 20.0])
trait_vars = jnp.array([1.5, 2.5])
trait_cor = jnp.array([[1.0, 0.8],
                       [0.8, 1.0]])

# Returns an updated SimParam; all traits are stored together in one
# TraitCollection rather than as separate per-trait objects.
SP_with_traits = add_trait_a(
    key=trait_key,
    sim_param=SP,
    n_qtl_per_chr=n_qtl_per_chr,
    mean=trait_means,
    var=trait_vars,
    cor_a=trait_cor
)

print("\n--- Correlated Additive Traits Added ---")
print(f"SimParam object updated: {SP_with_traits}")

# The single TraitCollection covering both traits.
trait_collection = SP_with_traits.traits

# Inspect the vectorized trait data: the QTL are shared across traits, and
# the additive effects are stored as one (n_traits, n_loci) array.
print(f"Number of traits after: {trait_collection.n_traits}")
print("\nDetails of the new TraitCollection:")
print(f"  - Number of shared QTL: {trait_collection.n_loci}")
print(f"  - Shape of additive effects array: {trait_collection.add_eff.shape}")
print(f"  - Intercepts for all traits: {trait_collection.intercept}")
print("-" * 35)

# --- 7. Set Phenotypes for the Founder Population ---
# Per-trait heritabilities and the residual (environmental) correlation
# between the two traits.
key, pheno_key = jax.random.split(key)

h2 = jnp.array([0.5, 0.7])
cor_e = jnp.array([[1.0, 0.3],
                   [0.3, 1.0]])

# set_pheno takes the TraitCollection and ploidy explicitly instead of the
# whole SimParam, which keeps its inputs to plain arrays/pytrees. If it is
# jitted, the first call compiles it; later calls with the same input shapes
# reuse the compiled version and run much faster.
founder_pop_with_pheno = set_pheno(
    key=pheno_key,
    pop=founder_pop,
    traits=trait_collection,
    ploidy=SP_with_traits.ploidy,
    h2=h2,
    cor_e=cor_e
)
# NOTE(review): in the captured output, the Population repr below reports
# nTraits=0 even though `.pheno` has one column per trait — check
# Population.__repr__.
print("\n--- Phenotypes Calculated for Founder Population (JIT-compiled) ---")
print(founder_pop_with_pheno)
print(f"\nPhenotype array shape: {founder_pop_with_pheno.pheno.shape}")
print("\nExample phenotypes (first 5 individuals):")
print(founder_pop_with_pheno.pheno[:5, :])
print("-" * 35)



--- Initial Simulation Parameters Created ---
SimParam(nChr=3, nTraits=0, ploidy=2, sexes='no')
-----------------------------------

--- Founder Population Created ---
Population(nInd=50, nTraits=0, has_ebv=No)
Genotype array shape: (50, 3, 2, 100)
-----------------------------------

--- Simulation Parameters Finalized with Founder Pop ---
SimParam(nChr=3, nTraits=0, ploidy=2, sexes='no')
Number of traits before: 0
-----------------------------------

--- Correlated Additive Traits Added ---
SimParam object updated: SimParam(nChr=3, nTraits=2, ploidy=2, sexes='no')
Number of traits after: 2

Details of the new TraitCollection:
  - Number of shared QTL: 300
  - Shape of additive effects array: (2, 300)
  - Intercepts for all traits: [ 9.365508 17.48284 ]
-----------------------------------

--- Phenotypes Calculated for Founder Population (JIT-compiled) ---
Population(nInd=50, nTraits=0, has_ebv=No)

Phenotype array shape: (50, 2)

Example phenotypes (first 5 individuals):
[[ 9.620646 