In [1]:
# | hide
# from chewc.core import *
# from chewc.pop import *
# from chewc.sp import *
# from chewc.trait import *

# chewc

> JAX breeding sim

## Developer Guide

If you are new to using `nbdev`, here are some useful pointers to get you started.

### Install chewc in development mode

```sh
# make sure chewc package is installed in development mode
$ pip install -e .

# make changes under nbs/ directory
# ...

# compile to have changes apply to chewc
$ nbdev_prepare
```

## Usage

### Installation

Install the latest version from the GitHub [repository][repo]:

```sh
$ pip install git+https://github.com/cjGO/chewc.git
```


[repo]: https://github.com/cjGO/chewc
[docs]: https://cjGO.github.io/chewc/
[pypi]: https://pypi.org/project/chewc/
[conda]: https://anaconda.org/cjGO/chewc

### Documentation

Documentation is hosted on this GitHub [repository][repo]'s [pages][docs].

[repo]: https://github.com/cjGO/chewc
[docs]: https://cjGO.github.io/chewc/

### Multi trait simulation example

This code block sets up the simulation parameters, creates a founder population in linkage equilibrium across loci, and simulates two correlated traits.

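Then a single generation of offspring is produced by selecting the top founders on their trait-1 phenotype and mating them at random. Multi-trait ABLUP and GBLUP models are fitted to the founders and offspring together, and prediction accuracy (correlation with true breeding values) is reported for the offspring.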

Both the breeding simulation and the prediction models run on JAX.

## How to use

In [2]:
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import Tuple
from scipy.sparse import coo_matrix

# JAX's iterative solver and sparse matrix format
from jax.scipy.sparse.linalg import cg
import jax.experimental.sparse as jsparse

# Assume the 'chewc' library is installed
from chewc.structs import (
    Population,
    Trait,
    GeneticMap,
    quick_haplo,
    add_trait
)
from chewc.pheno import calculate_phenotypes
from chewc.select import select_top_k
from chewc.cross import random_mating, cross_pair
from chewc.blup import *


if __name__ == "__main__":
    # --- Parameters ---
    N_FOUNDERS, N_SELECT, N_OFFSPRING = 200, 20, 200
    N_CHR, N_LOCI = 5, 1000
    N_QTL_PER_CHR = 50  # QTL sampled per chromosome by add_trait
    SEED = 42
    N_TRAITS = 2
    G_RIDGE = 1e-4  # diagonal jitter added to G before inversion
    N_SHOW = 10  # offspring rows printed in the comparison table
    h2_trait1, h2_trait2 = 0.6, 0.4
    genetic_corr = 0.5

    # --- Setup (Co)variance matrices ---
    # Phenotypic variance is taken as 1 per trait: additive genetic variance
    # equals h2 and residual variance equals 1 - h2.
    var_g1, var_g2 = h2_trait1, h2_trait2
    cov_g12 = genetic_corr * jnp.sqrt(var_g1 * var_g2)
    G0 = jnp.array([[var_g1, cov_g12], [cov_g12, var_g2]])

    # Residuals are modelled as uncorrelated between traits (diagonal R0).
    var_e1, var_e2 = 1 - h2_trait1, 1 - h2_trait2
    R0 = jnp.diag(jnp.array([var_e1, var_e2]))

    G0_inv, R0_inv = jnp.linalg.inv(G0), jnp.linalg.inv(R0)

    # --- Population Simulation ---
    print("--- Step 1-4: Simulating population and multi-trait phenotypes ---")
    key = jax.random.PRNGKey(SEED)
    # One independent sub-key per stochastic step keeps the run reproducible.
    key, pop_key, trait_key, pheno_key, mating_key, cross_key = jax.random.split(key, 6)

    founder_pop, genetic_map = quick_haplo(key=pop_key, n_ind=N_FOUNDERS, n_chr=N_CHR, seg_sites=N_LOCI)

    trait_architecture = add_trait(
        key=trait_key, founder_pop=founder_pop, n_qtl_per_chr=N_QTL_PER_CHR,
        mean=jnp.array([100.0, 50.0]), var_a=jnp.array([var_g1, var_g2]), var_d=jnp.array([0.0, 0.9]), sigma=G0
    )

    founder_phenotypes, founder_tbvs = calculate_phenotypes(
        key=pheno_key, population=founder_pop, trait=trait_architecture,
        heritability=jnp.array([h2_trait1, h2_trait2])
    )

    # Truncation selection of founders on the trait-1 phenotype only.
    selected_parents = select_top_k(founder_pop, founder_phenotypes[:, 0], k=N_SELECT)
    # Columns 0 and 1 of each row index the two parents within selected_parents.
    pairings = random_mating(mating_key, n_parents=N_SELECT, n_crosses=N_OFFSPRING)

    # Vectorise cross_pair over offspring. Keys and parental geno/ibd are
    # mapped per cross; genetic_map and the final argument are shared.
    # NOTE(review): meaning of the trailing `10` is defined by chewc.cross.cross_pair — confirm.
    vmapped_cross = jax.vmap(cross_pair, in_axes=(0, 0, 0, 0, 0, None, None))
    offspring_keys = jax.random.split(cross_key, N_OFFSPRING)
    offspring_geno, offspring_ibd = vmapped_cross(
        offspring_keys, selected_parents.geno[pairings[:, 0]], selected_parents.geno[pairings[:, 1]],
        selected_parents.ibd[pairings[:, 0]], selected_parents.ibd[pairings[:, 1]],
        genetic_map, 10
    )

    # Offspring metadata columns:
    #   - new IDs continuing after the founders;
    #   - both parents' IDs (column 0 of the parents' meta);
    #   - a constant 1 (presumably the generation index — confirm against chewc.structs).
    new_meta = jnp.stack([
        jnp.arange(N_OFFSPRING) + N_FOUNDERS,
        selected_parents.meta[pairings[:, 0], 0],
        selected_parents.meta[pairings[:, 1], 0],
        jnp.full((N_OFFSPRING,), 1),
    ], axis=-1)
    offspring_pop = Population(geno=offspring_geno, ibd=offspring_ibd, meta=new_meta)

    key, offspring_pheno_key = jax.random.split(key)
    offspring_phenotypes, offspring_tbvs = calculate_phenotypes(
        key=offspring_pheno_key, population=offspring_pop, trait=trait_architecture,
        heritability=jnp.array([h2_trait1, h2_trait2])
    )

    # Founders first, then offspring; the same order is used for the pedigree and genotypes below.
    all_phenotypes = jnp.concatenate([founder_phenotypes, offspring_phenotypes], axis=0)
    print("--- Population simulation complete ---")

    # --- ABLUP (Sparse, Iterative) ---
    print("\n--- Performing Multi-Trait ABLUP (Sparse Iterative) ---")
    full_pedigree = jnp.concatenate([founder_pop.meta, offspring_pop.meta], axis=0)
    remapped_ped_np = remap_pedigree(full_pedigree)

    A_inv_sparse = build_a_inverse_sparse(remapped_ped_np)
    ablup_ebvs = solve_multi_trait_mme_iterative(
        all_phenotypes, A_inv_sparse, G0_inv, R0_inv, n_traits=N_TRAITS
    )
    print("ABLUP calculation complete.")

    # --- GBLUP (Iterative) ---
    print("\n--- Performing Multi-Trait GBLUP (Iterative) ---")
    all_geno = jnp.concatenate([founder_pop.geno, offspring_pop.geno], axis=0)
    G_matrix = build_g_matrix(all_geno)
    # A small diagonal ridge keeps G numerically invertible.
    G_inv = jnp.linalg.inv(G_matrix + jnp.identity(G_matrix.shape[0]) * G_RIDGE)

    gblup_gebvs = solve_multi_trait_mme_iterative(
        all_phenotypes, G_inv, G0_inv, R0_inv, n_traits=N_TRAITS
    )
    print("GBLUP calculation complete.")

    # --- Compare Results ---
    print("\n--- Comparison of Results for Offspring ---")
    # Assumes the solvers return rows in the same order as all_phenotypes
    # (founders first) — TODO confirm remap_pedigree does not reorder.
    offspring_ablup = ablup_ebvs[N_FOUNDERS:]
    offspring_gblup = gblup_gebvs[N_FOUNDERS:]

    def _accuracy(true_bvs, est_bvs, trait_idx):
        """Pearson correlation between true and estimated breeding values for one trait column."""
        return jnp.corrcoef(true_bvs[:, trait_idx], est_bvs[:, trait_idx])[0, 1]

    acc_ablup_t1 = _accuracy(offspring_tbvs, offspring_ablup, 0)
    acc_ablup_t2 = _accuracy(offspring_tbvs, offspring_ablup, 1)
    acc_gblup_t1 = _accuracy(offspring_tbvs, offspring_gblup, 0)
    acc_gblup_t2 = _accuracy(offspring_tbvs, offspring_gblup, 1)

    print(f"\nABLUP Accuracy -> Trait 1: {acc_ablup_t1:.4f}, Trait 2: {acc_ablup_t2:.4f}")
    print(f"GBLUP Accuracy -> Trait 1: {acc_gblup_t1:.4f}, Trait 2: {acc_gblup_t2:.4f}")

    header = "{:<6} | {:>12} {:>12} | {:>12} {:>12} | {:>12} {:>12}".format(
        "ID", "TBV T1", "TBV T2", "ABLUP T1", "ABLUP T2", "GBLUP T1", "GBLUP T2")
    print("\n" + header)
    # Derive the rule width from the header so the two always line up.
    print("-" * len(header))
    # Bounded by N_OFFSPRING so a small simulation cannot index past the end.
    for i in range(min(N_SHOW, N_OFFSPRING)):
        print("{:<6} | {:>12.3f} {:>12.3f} | {:>12.3f} {:>12.3f} | {:>12.3f} {:>12.3f}".format(
            int(offspring_pop.meta[i, 0]),
            offspring_tbvs[i, 0], offspring_tbvs[i, 1],
            offspring_ablup[i, 0], offspring_ablup[i, 1],
            offspring_gblup[i, 0], offspring_gblup[i, 1]
        ))



--- Step 1-4: Simulating population and multi-trait phenotypes ---
--- Population simulation complete ---

--- Performing Multi-Trait ABLUP (Sparse Iterative) ---
[DEBUG] Sparse A_inv successfully created with 1406 non-zero elements.
[DEBUG] First 5 data points of A_inv: [1. 1. 1. 1. 1.]
ABLUP calculation complete.

--- Performing Multi-Trait GBLUP (Iterative) ---
GBLUP calculation complete.

--- Comparison of Results for Offspring ---

ABLUP Accuracy -> Trait 1: 0.8429, Trait 2: 0.7658
GBLUP Accuracy -> Trait 1: 0.8597, Trait 2: 0.7926

ID     |       TBV T1       TBV T2 |     ABLUP T1     ABLUP T2 |     GBLUP T1     GBLUP T2
----------------------------------------------------------------------------------------
200    |        1.776        0.456 |        0.661        0.440 |       -0.111       -0.035
201    |        1.663        1.046 |        0.612        0.519 |        0.061        0.297
202    |        2.522        0.793 |        1.711        0.530 |        1.056        0.423
203