In [None]:
#| default_exp workflow

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import Tuple

# Assume the 'chewc' library is installed
from chewc.structs import (
    Population,
    Trait,
    GeneticMap,
    quick_haplo,
    add_trait
)
from chewc.pheno import calculate_phenotypes
from chewc.select import select_top_k
from chewc.cross import random_mating, cross_pair

# --- Relationship Matrix Core Functions ---

def build_a_matrix(pedigree_meta: jnp.ndarray) -> jnp.ndarray:
    """
    Constructs the Numerator Relationship Matrix (A) from a pedigree.

    The tabular method is used: individuals are processed in row order and
    each new row is derived from the already completed rows of its parents,
    ``A[i, j] = 0.5 * (A[p1, j] + A[p2, j])`` for ``j < i`` and
    ``A[i, i] = 1 + 0.5 * A[p1, p2]`` when both parents are known.

    Args:
        pedigree_meta: Integer array with one row per individual. Columns 1
            and 2 hold the *row indices* of the two parents, with -1 marking
            an unknown parent. Other columns are ignored. Parents must appear
            before their offspring (parent index < own index).

    Returns:
        The dense ``(n, n)`` float32 relationship matrix as a JAX array.

    Raises:
        ValueError: If a known parent index is not in ``[0, i)`` for row
            ``i``. The recursion would otherwise read entries that have not
            been computed yet and silently return wrong relationships.
    """
    n = pedigree_meta.shape[0]
    meta_np = np.array(pedigree_meta, dtype=np.int32)
    first_parent = meta_np[:, 1]
    second_parent = meta_np[:, 2]

    A = np.zeros((n, n), dtype=np.float32)

    for i in range(n):
        p1 = int(first_parent[i])
        p2 = int(second_parent[i])

        for parent in (p1, p2):
            if parent != -1 and not 0 <= parent < i:
                raise ValueError(
                    f"Individual at row {i} has parent index {parent}; "
                    "parents must be -1 (unknown) or precede their offspring."
                )

        # Relationships with all earlier individuals: the mean of the
        # parents' (already complete) rows. Unknown parents contribute zero.
        if i:
            row = np.zeros(i, dtype=np.float32)
            if p1 != -1:
                row += 0.5 * A[p1, :i]
            if p2 != -1:
                row += 0.5 * A[p2, :i]
            A[i, :i] = row
            A[:i, i] = row  # exploit symmetry

        # Diagonal: 1 plus the inbreeding coefficient (half the parents'
        # relationship), which is zero unless both parents are known.
        A[i, i] = 1.0
        if p1 != -1 and p2 != -1:
            A[i, i] += 0.5 * A[p1, p2]

    return jnp.array(A)

@jax.jit
def build_g_matrix(geno: jnp.ndarray) -> jnp.ndarray:
    """
    Constructs the Genomic Relationship Matrix (G) using VanRaden (2008) method 1.

    Args:
        geno: Rank-4 allele array whose first axis indexes individuals.
            Axis 2 is summed to obtain per-marker dosages (presumably the
            haplotype/ploidy axis -- confirm against the caller).

    Returns:
        ``(n_ind, n_ind)`` matrix ``Z @ Z.T / (2 * sum(p * (1 - p)))`` where
        ``Z`` holds dosages centred by twice the allele frequency ``p``.
    """
    # The four-way unpack doubles as a rank check on the input.
    n_ind, _n_chr, _n_copies, _n_loci = geno.shape

    # Dosage per marker, flattened so each individual is one row.
    summed_alleles = jnp.sum(geno, axis=2, dtype=jnp.float32)
    dosage_matrix = jnp.reshape(summed_alleles, (n_ind, -1))

    # Allele frequencies, kept strictly inside (0, 1) so that monomorphic
    # markers cannot drive the scaling factor to zero.
    allele_freq = jnp.clip(jnp.mean(dosage_matrix, axis=0) / 2.0, 1e-6, 1.0 - 1e-6)

    centered = dosage_matrix - 2 * allele_freq
    scale = 2 * jnp.sum(allele_freq * (1 - allele_freq))

    return (centered @ centered.T) / scale


@jax.jit
def solve_mme(
    y: jnp.ndarray,
    relationship_matrix_inv: jnp.ndarray,
    lambda_: float,
) -> jnp.ndarray:
    """
    Solves the Mixed Model Equations for a single fixed effect (mean).

    The model has an overall mean and one random effect per record (both
    design matrices are all-ones / identity), so the system is

        [[1'1, 1'], [1, I + lambda * K_inv]] [mean, u]' = [1'y, y]'

    Args:
        y: Phenotype vector with one entry per individual.
        relationship_matrix_inv: Inverse relationship matrix (A_inv or G_inv).
        lambda_: Variance ratio multiplying the inverse relationship matrix.

    Returns:
        The random-effect solutions (estimated breeding values); the leading
        mean solution is dropped.
    """
    n_ind = relationship_matrix_inv.shape[0]
    ones_col = jnp.ones((n_ind, 1))
    ones_row = ones_col.T

    # Assemble the coefficient matrix row-block by row-block.
    mean_equation = jnp.concatenate([ones_row @ ones_col, ones_row], axis=1)
    animal_equations = jnp.concatenate(
        [ones_col, jnp.identity(n_ind) + relationship_matrix_inv * lambda_],
        axis=1,
    )
    coefficient_matrix = jnp.concatenate([mean_equation, animal_equations], axis=0)

    right_hand_side = jnp.concatenate([ones_row @ y, y])

    # Skip the first solution (the mean); the rest are the breeding values.
    return jnp.linalg.solve(coefficient_matrix, right_hand_side)[1:]


def remap_pedigree_for_ablup(full_pedigree: jnp.ndarray) -> jnp.ndarray:
    """Remaps original IDs to contiguous integers for matrix construction.

    Args:
        full_pedigree: Array-like with at least four columns. Column 0 holds
            each individual's original ID, columns 1 and 2 hold the parents'
            original IDs (-1 for unknown), and column 3 is carried through
            unchanged.

    Returns:
        An int32 JAX array of shape ``(n, 4)`` in which every ID in columns
        0-2 is replaced by the row position of that ID in column 0 (-1 stays
        -1).

    Raises:
        KeyError: If a parent ID does not occur in column 0.
    """
    # Pull the data to host once; iterating a JAX array element by element
    # triggers a device-to-host conversion per entry.
    pedigree_np = np.asarray(full_pedigree)

    id_map = {int(og_id): i for i, og_id in enumerate(pedigree_np[:, 0])}
    id_map[-1] = -1  # unknown parents stay unknown

    remapped_columns = [
        np.array([id_map[int(og_id)] for og_id in pedigree_np[:, col]])
        for col in range(3)
    ]

    return jnp.stack(
        [*remapped_columns, pedigree_np[:, 3]], axis=-1
    ).astype(jnp.int32)

# --- Main Execution Script ---

# Single-trait demonstration: simulate founders and one generation of
# offspring, then estimate breeding values with a pedigree-based (A) and a
# genomic (G) relationship matrix and compare both against the true values.
if __name__ == "__main__":
    # --- Simulation Parameters ---
    N_FOUNDERS = 50
    N_SELECT = 10
    N_OFFSPRING = 100
    N_CHR, N_LOCI = 5, 500  # Increased loci for better G matrix estimation
    MAX_CROSSOVERS = 10
    HERITABILITY = 0.5
    SEED = 42

    # One independent sub-key per random step; `key` itself is kept so it
    # can be split again later for the offspring phenotypes.
    key = jax.random.PRNGKey(SEED)
    key, pop_key, trait_key, pheno_key, mating_key, cross_key = jax.random.split(key, 6)
    
    # 1. & 2. --- Create Founder Population, Trait, and Phenotypes ---
    print("--- Step 1 & 2: Initializing Founders and Trait ---")
    founder_pop, genetic_map = quick_haplo(
        key=pop_key, n_ind=N_FOUNDERS, n_chr=N_CHR, seg_sites=N_LOCI
    )
    trait_architecture = add_trait(
        key=trait_key, founder_pop=founder_pop, n_qtl_per_chr=20,
        mean=jnp.array([100.0]), var=jnp.array([10.0]), sigma=jnp.array([[1.0]])
    )
    founder_phenotypes, founder_tbvs = calculate_phenotypes(
        key=pheno_key, population=founder_pop,
        trait=trait_architecture, heritability=jnp.array([HERITABILITY])
    )

    # 3. & 4. --- Create and Evaluate Offspring Population ---
    print("--- Step 3 & 4: Creating and Evaluating Offspring ---")
    # Selection is on the first (only) phenotype column.
    selected_parents = select_top_k(founder_pop, founder_phenotypes[:, 0], k=N_SELECT)
    pairings = random_mating(mating_key, n_parents=N_SELECT, n_crosses=N_OFFSPRING)
    
    # Map cross_pair over the first five arguments (key, two parental
    # genotypes, two parental IBD arrays); the genetic map and the
    # crossover limit are shared by every cross.
    vmapped_cross = jax.vmap(cross_pair, in_axes=(0, 0, 0, 0, 0, None, None))
    offspring_keys = jax.random.split(cross_key, N_OFFSPRING)
    offspring_geno, offspring_ibd = vmapped_cross(
        offspring_keys,
        selected_parents.geno[pairings[:, 0]], selected_parents.geno[pairings[:, 1]],
        selected_parents.ibd[pairings[:, 0]], selected_parents.ibd[pairings[:, 1]],
        genetic_map, MAX_CROSSOVERS
    )
    
    # Offspring metadata columns: new ID (continuing after the founders),
    # ID of each parent (column 0 of the parents' meta), and a constant 1
    # (presumably a generation counter -- confirm against chewc's schema).
    new_meta = jnp.stack([
        jnp.arange(N_OFFSPRING) + N_FOUNDERS,
        selected_parents.meta[pairings[:, 0], 0],
        selected_parents.meta[pairings[:, 1], 0],
        jnp.full((N_OFFSPRING,), 1, dtype=jnp.int32),
    ], axis=-1)
    offspring_pop = Population(geno=offspring_geno, ibd=offspring_ibd, meta=new_meta)
    
    key, offspring_pheno_key = jax.random.split(key)
    offspring_phenotypes, offspring_tbvs = calculate_phenotypes(
        key=offspring_pheno_key, population=offspring_pop,
        trait=trait_architecture, heritability=jnp.array([HERITABILITY])
    )

    # Combine data for all individuals
    # Founders come first, so offspring occupy positions N_FOUNDERS onward;
    # the slicing in step 7 relies on this order.
    all_phenotypes = jnp.concatenate([founder_phenotypes[:, 0], offspring_phenotypes[:, 0]])
    # Variance ratio passed to solve_mme: (1 - h2) / h2.
    lambda_ = (1.0 - HERITABILITY) / HERITABILITY

    # --- 5. ABLUP Calculation ---
    print("\n--- Step 5: Performing ABLUP (Pedigree-based) ---")
    full_pedigree = jnp.concatenate([founder_pop.meta, offspring_pop.meta], axis=0)
    remapped_pedigree = remap_pedigree_for_ablup(full_pedigree)
    
    # Dense A followed by a dense inverse.
    A_matrix = build_a_matrix(remapped_pedigree)
    A_inv = jnp.linalg.inv(A_matrix)
    
    ablup_ebvs = solve_mme(all_phenotypes, A_inv, lambda_)
    print("[DEBUG] ABLUP EBVs calculated.")

    # --- 6. GBLUP Calculation ---
    print("\n--- Step 6: Performing GBLUP (Genomic-based) ---")
    # Combine genotypes from founders and offspring
    all_geno = jnp.concatenate([founder_pop.geno, offspring_pop.geno], axis=0)
    
    G_matrix = build_g_matrix(all_geno)
    # Add a small value to the diagonal for invertibility (regularization)
    G_inv = jnp.linalg.inv(G_matrix + jnp.identity(G_matrix.shape[0]) * 1e-4)

    gblup_gebvs = solve_mme(all_phenotypes, G_inv, lambda_)
    print("[DEBUG] GBLUP GEBVs calculated.")

    # --- 7. Verify and Compare Results ---
    print("\n--- Step 7: Comparison of Results for Offspring ---")
    offspring_ablup_ebvs = ablup_ebvs[N_FOUNDERS:]
    offspring_gblup_gebvs = gblup_gebvs[N_FOUNDERS:]

    # Calculate accuracies
    # Accuracy here is the correlation between true and estimated values.
    accuracy_ablup = jnp.corrcoef(offspring_tbvs.flatten(), offspring_ablup_ebvs.flatten())[0, 1]
    accuracy_gblup = jnp.corrcoef(offspring_tbvs.flatten(), offspring_gblup_gebvs.flatten())[0, 1]

    print(f"\nABLUP Accuracy (Correlation): {accuracy_ablup:.4f}")
    print(f"GBLUP Accuracy (Correlation): {accuracy_gblup:.4f}")
    
    print("\n{:<12} | {:>18} | {:>20} | {:>20}".format("Offspring ID", "True Breeding Value", "ABLUP Est. BV", "GBLUP Est. BV"))
    print("-" * 78)
    # NOTE(review): the hard-coded 15 assumes N_OFFSPRING >= 15.
    for i in range(15): # Print first 15 for brevity
        original_id = offspring_pop.meta[i, 0]
        tbv = offspring_tbvs[i, 0]
        ablup_ebv = offspring_ablup_ebvs[i]
        gblup_gebv = offspring_gblup_gebvs[i]
        print("{:<12} | {:>18.3f} | {:>20.3f} | {:>20.3f}".format(
            int(original_id), tbv, ablup_ebv, gblup_gebv
        ))

In [3]:
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import Tuple
from scipy.sparse import coo_matrix

# JAX's iterative solver and sparse matrix format
from jax.scipy.sparse.linalg import cg
import jax.experimental.sparse as jsparse

# Assume the 'chewc' library is installed
from chewc.structs import (
    Population,
    Trait,
    GeneticMap,
    quick_haplo,
    add_trait
)
from chewc.pheno import calculate_phenotypes
from chewc.select import select_top_k
from chewc.cross import random_mating, cross_pair

# --- Scalable Core Functions ---

def build_a_inverse_sparse(pedigree_meta: np.ndarray) -> jsparse.BCOO:
    """
    Constructs the inverse of the Numerator Relationship Matrix (A_inv)
    directly from a pedigree using Henderson's rules.

    Each individual adds a fixed pattern of values to the rows/columns of
    itself and its known parents. The coefficients used here (1, 4/3 and 2
    on the individual's diagonal) do not account for parental inbreeding.

    Args:
        pedigree_meta: Integer array with one row per individual. Columns 1
            and 2 hold the row indices of the sire and dam, -1 when unknown.

    Returns:
        The ``(n, n)`` inverse relationship matrix in JAX BCOO sparse format
        (float32 values, duplicate coordinates already summed).
    """
    n = pedigree_meta.shape[0]

    # Growable triplet lists. A fixed buffer of n * 7 entries is too small:
    # an animal with two known parents contributes 9 entries, so pedigrees
    # dominated by such animals would index past the end of the buffer.
    rows, cols, data = [], [], []

    def add_element(r, c, val):
        rows.append(r)
        cols.append(c)
        data.append(val)

    # Loop through individuals and add their contributions to A_inv
    for i in range(n):
        sire = int(pedigree_meta[i, 1])
        dam = int(pedigree_meta[i, 2])

        if sire == -1 and dam == -1:  # Founder
            add_element(i, i, 1.0)
        elif sire == -1 or dam == -1:  # One parent known
            parent = sire if sire != -1 else dam
            add_element(i, i, 4/3)
            add_element(i, parent, -2/3)
            add_element(parent, i, -2/3)
            add_element(parent, parent, 1/3)
        else:  # Both parents known
            add_element(i, i, 2.0)
            add_element(i, sire, -1.0)
            add_element(sire, i, -1.0)
            add_element(i, dam, -1.0)
            add_element(dam, i, -1.0)
            add_element(sire, sire, 0.5)
            add_element(dam, dam, 0.5)
            add_element(sire, dam, 0.5)
            add_element(dam, sire, 0.5)

    # Use scipy to efficiently sum duplicate entries from contributions
    scipy_coo = coo_matrix(
        (
            np.array(data, dtype=np.float32),
            (np.array(rows, dtype=np.int32), np.array(cols, dtype=np.int32)),
        ),
        shape=(n, n),
    )
    scipy_coo.sum_duplicates()

    # **DEBUG**: Check for invalid values before returning
    if not np.all(np.isfinite(scipy_coo.data)):
        print("[DEBUG] WARNING: Non-finite values (NaN or inf) found in sparse A_inv data!")
    else:
        print(f"[DEBUG] Sparse A_inv successfully created with {len(scipy_coo.data)} non-zero elements.")
        print(f"[DEBUG] First 5 data points of A_inv: {scipy_coo.data[:5]}")

    # The format requires indices of shape (nse, 2)
    indices = np.stack([scipy_coo.row, scipy_coo.col], axis=1)

    return jsparse.BCOO((jnp.array(scipy_coo.data), jnp.array(indices)), shape=(n, n))


@jax.jit
def build_g_matrix(geno: jnp.ndarray) -> jnp.ndarray:
    """Constructs the Genomic Relationship Matrix (G).

    Axis 2 of the rank-4 ``geno`` array is summed into per-marker dosages,
    which are centred by twice the allele frequency ``p`` and scaled:
    ``G = Z @ Z.T / (2 * sum(p * (1 - p)))``.
    """
    # Four-way unpack: also rejects inputs that are not rank 4.
    n_ind, _, _, _ = geno.shape

    marker_dosage = jnp.reshape(
        jnp.sum(geno, axis=2, dtype=jnp.float32), (n_ind, -1)
    )

    # Frequencies are clipped away from 0 and 1 so the scale stays positive.
    allele_freq = jnp.clip(jnp.mean(marker_dosage, axis=0) / 2.0, 1e-6, 1.0 - 1e-6)

    centered = marker_dosage - 2 * allele_freq
    scale = 2 * jnp.sum(allele_freq * (1 - allele_freq))
    return (centered @ centered.T) / scale

@partial(jax.jit, static_argnames=('n_traits',))
def solve_multi_trait_mme_iterative(
    y: jnp.ndarray,
    relationship_matrix_inv: jsparse.BCOO,
    G0_inv: jnp.ndarray,
    R0_inv: jnp.ndarray,
    n_traits: int
) -> jnp.ndarray:
    """Solves a multi-trait MME using an iterative Conjugate Gradient solver.

    The coefficient matrix is never formed; only its product with a vector
    is implemented. The unknowns are stacked as the ``n_traits`` trait means
    followed by the individual effects in trait-major (column-major) order.

    Args:
        y: ``(n_ind, n_traits)`` phenotypes, one record per individual.
        relationship_matrix_inv: ``(n_ind, n_ind)`` inverse relationship
            matrix; only used through ``@``.
        G0_inv: Inverse genetic (co)variance matrix between traits.
        R0_inv: Inverse residual (co)variance matrix between traits.
        n_traits: Number of traits (static under ``jit``).

    Returns:
        ``(n_ind, n_traits)`` estimated individual effects; the trait-mean
        solutions are dropped.
    """
    n_ind = relationship_matrix_inv.shape[0]
    effects_shape = (n_ind, n_traits)

    def apply_lhs(stacked):
        """Multiply the MME coefficient matrix by a stacked solution vector."""
        trait_means = stacked[:n_traits]
        animal_effects = stacked[n_traits:].reshape(effects_shape, order='F')

        # Equations for the trait means.
        mean_rows = (R0_inv @ trait_means) * n_ind + R0_inv @ animal_effects.sum(axis=0)

        # Equations for the individual effects.
        means_per_animal = jnp.ones((n_ind, 1)) @ trait_means.reshape(1, -1)
        animal_rows = (
            means_per_animal @ R0_inv.T
            + animal_effects @ R0_inv.T
            + (relationship_matrix_inv @ animal_effects) @ G0_inv.T
        )

        return jnp.concatenate([mean_rows, animal_rows.flatten('F')])

    rhs_means = R0_inv @ y.sum(axis=0)
    rhs_animals = (y @ R0_inv.T).flatten('F')
    rhs = jnp.concatenate([rhs_means, rhs_animals])

    # The convergence info returned by cg is not inspected.
    stacked_solution, _ = cg(apply_lhs, rhs, maxiter=1000)

    return stacked_solution[n_traits:].reshape(effects_shape, order='F')


def remap_pedigree(full_pedigree: jnp.ndarray) -> np.ndarray:
    """Remaps IDs and returns a NumPy array for the sparse builder.

    Args:
        full_pedigree: Array-like with at least three columns. Column 0
            holds each individual's original ID and columns 1 and 2 the
            parents' original IDs (-1 for unknown).

    Returns:
        An int32 NumPy array of the same shape. Column 0 becomes the row
        position ``0..n-1``, columns 1 and 2 become the row positions of the
        parents (-1 stays -1), and any further columns are copied unchanged
        (matching ``remap_pedigree_for_ablup``) rather than zeroed.

    Raises:
        KeyError: If a parent ID does not occur in column 0.
    """
    # Convert once; iterating a JAX array element by element is slow.
    pedigree_np = np.asarray(full_pedigree)

    id_map = {int(og_id): i for i, og_id in enumerate(pedigree_np[:, 0])}
    id_map[-1] = -1  # unknown parents stay unknown

    # astype returns a copy, so trailing metadata columns are preserved and
    # the caller's array is never modified.
    remapped_pedigree = pedigree_np.astype(np.int32)
    remapped_pedigree[:, 0] = np.arange(pedigree_np.shape[0])
    remapped_pedigree[:, 1] = [id_map[int(i)] for i in pedigree_np[:, 1]]
    remapped_pedigree[:, 2] = [id_map[int(i)] for i in pedigree_np[:, 2]]
    return remapped_pedigree

# --- Main Execution Script ---

# Two-trait demonstration: simulate founders and one generation of offspring,
# then solve a multi-trait mixed model iteratively with a sparse pedigree
# inverse (ABLUP) and with a dense genomic inverse (GBLUP).
if __name__ == "__main__":
    # --- Parameters ---
    N_FOUNDERS, N_SELECT, N_OFFSPRING = 100, 20, 200
    N_CHR, N_LOCI = 5, 1000
    SEED = 42
    N_TRAITS = 2
    h2_trait1, h2_trait2 = 0.6, 0.4
    genetic_corr = 0.5
    
    # --- Setup (Co)variance matrices ---
    # Genetic variances are set equal to the heritabilities and residual
    # variances to their complements, so each trait's total variance is 1.
    var_g1, var_g2 = h2_trait1, h2_trait2
    cov_g12 = genetic_corr * jnp.sqrt(var_g1 * var_g2)
    G0 = jnp.array([[var_g1, cov_g12], [cov_g12, var_g2]])
    
    # Residuals are uncorrelated between traits (diagonal R0).
    var_e1, var_e2 = 1 - h2_trait1, 1 - h2_trait2
    R0 = jnp.diag(jnp.array([var_e1, var_e2]))
    
    G0_inv, R0_inv = jnp.linalg.inv(G0), jnp.linalg.inv(R0)

    # --- Population Simulation ---
    print("--- Step 1-4: Simulating population and multi-trait phenotypes ---")
    key = jax.random.PRNGKey(SEED)
    key, pop_key, trait_key, pheno_key, mating_key, cross_key = jax.random.split(key, 6)
    
    founder_pop, genetic_map = quick_haplo(key=pop_key, n_ind=N_FOUNDERS, n_chr=N_CHR, seg_sites=N_LOCI)
    
    trait_architecture = add_trait(
        key=trait_key, founder_pop=founder_pop, n_qtl_per_chr=50,
        mean=jnp.array([100.0, 50.0]), var=jnp.array([var_g1, var_g2]), sigma=G0
    )
    
    founder_phenotypes, founder_tbvs = calculate_phenotypes(
        key=pheno_key, population=founder_pop, trait=trait_architecture,
        heritability=jnp.array([h2_trait1, h2_trait2])
    )
    
    # Parents are chosen on the first trait's phenotype only (column 0).
    selected_parents = select_top_k(founder_pop, founder_phenotypes[:, 0], k=N_SELECT)
    pairings = random_mating(mating_key, n_parents=N_SELECT, n_crosses=N_OFFSPRING)
    
    # Map cross_pair over its first five arguments; the genetic map and the
    # trailing 10 (the last positional argument, called MAX_CROSSOVERS in
    # the single-trait script) are shared by every cross.
    vmapped_cross = jax.vmap(cross_pair, in_axes=(0, 0, 0, 0, 0, None, None))
    offspring_keys = jax.random.split(cross_key, N_OFFSPRING)
    offspring_geno, offspring_ibd = vmapped_cross(
        offspring_keys, selected_parents.geno[pairings[:, 0]], selected_parents.geno[pairings[:, 1]],
        selected_parents.ibd[pairings[:, 0]], selected_parents.ibd[pairings[:, 1]],
        genetic_map, 10
    )
    
    # Offspring metadata columns: new ID (continuing after the founders),
    # ID of each parent, and a constant 1 (presumably a generation counter
    # -- confirm against chewc's schema).
    new_meta = jnp.stack([
        jnp.arange(N_OFFSPRING) + N_FOUNDERS,
        selected_parents.meta[pairings[:, 0], 0],
        selected_parents.meta[pairings[:, 1], 0],
        jnp.full((N_OFFSPRING,), 1),
    ], axis=-1)
    offspring_pop = Population(geno=offspring_geno, ibd=offspring_ibd, meta=new_meta)
    
    key, offspring_pheno_key = jax.random.split(key)
    offspring_phenotypes, offspring_tbvs = calculate_phenotypes(
        key=offspring_pheno_key, population=offspring_pop, trait=trait_architecture,
        heritability=jnp.array([h2_trait1, h2_trait2])
    )
    
    # Founders first, offspring after: the result slices below depend on it.
    all_phenotypes = jnp.concatenate([founder_phenotypes, offspring_phenotypes], axis=0)
    print("--- Population simulation complete ---")

    # --- ABLUP (Sparse, Iterative) ---
    print("\n--- Performing Multi-Trait ABLUP (Sparse Iterative) ---")
    full_pedigree = jnp.concatenate([founder_pop.meta, offspring_pop.meta], axis=0)
    remapped_ped_np = remap_pedigree(full_pedigree)
    
    A_inv_sparse = build_a_inverse_sparse(remapped_ped_np)
    ablup_ebvs = solve_multi_trait_mme_iterative(
        all_phenotypes, A_inv_sparse, G0_inv, R0_inv, n_traits=N_TRAITS
    )
    print("ABLUP calculation complete.")
    
    # --- GBLUP (Iterative) ---
    print("\n--- Performing Multi-Trait GBLUP (Iterative) ---")
    all_geno = jnp.concatenate([founder_pop.geno, offspring_pop.geno], axis=0)
    G_matrix = build_g_matrix(all_geno)
    # A small ridge on the diagonal before the dense inverse. G_inv is a
    # dense array even though the solver's annotation says BCOO; the solver
    # only uses it through `@`.
    G_inv = jnp.linalg.inv(G_matrix + jnp.identity(G_matrix.shape[0]) * 1e-4)
    
    gblup_gebvs = solve_multi_trait_mme_iterative(
        all_phenotypes, G_inv, G0_inv, R0_inv, n_traits=N_TRAITS
    )
    print("GBLUP calculation complete.")

    # --- Compare Results ---
    print("\n--- Comparison of Results for Offspring ---")
    offspring_ablup = ablup_ebvs[N_FOUNDERS:]
    offspring_gblup = gblup_gebvs[N_FOUNDERS:]

    # Accuracy per trait: correlation between true and estimated values.
    acc_ablup_t1 = jnp.corrcoef(offspring_tbvs[:, 0], offspring_ablup[:, 0])[0, 1]
    acc_ablup_t2 = jnp.corrcoef(offspring_tbvs[:, 1], offspring_ablup[:, 1])[0, 1]
    acc_gblup_t1 = jnp.corrcoef(offspring_tbvs[:, 0], offspring_gblup[:, 0])[0, 1]
    acc_gblup_t2 = jnp.corrcoef(offspring_tbvs[:, 1], offspring_gblup[:, 1])[0, 1]

    print(f"\nABLUP Accuracy -> Trait 1: {acc_ablup_t1:.4f}, Trait 2: {acc_ablup_t2:.4f}")
    print(f"GBLUP Accuracy -> Trait 1: {acc_gblup_t1:.4f}, Trait 2: {acc_gblup_t2:.4f}")

    print("\n{:<6} | {:>12} {:>12} | {:>12} {:>12} | {:>12} {:>12}".format(
        "ID", "TBV T1", "TBV T2", "ABLUP T1", "ABLUP T2", "GBLUP T1", "GBLUP T2"))
    print("-" * 88)
    # NOTE(review): the hard-coded 10 assumes N_OFFSPRING >= 10. Judging by
    # the recorded output, TBVs include the trait mean while the estimates
    # do not, so the columns are on different scales.
    for i in range(10):
        print("{:<6} | {:>12.3f} {:>12.3f} | {:>12.3f} {:>12.3f} | {:>12.3f} {:>12.3f}".format(
            int(offspring_pop.meta[i, 0]),
            offspring_tbvs[i, 0], offspring_tbvs[i, 1],
            offspring_ablup[i, 0], offspring_ablup[i, 1],
            offspring_gblup[i, 0], offspring_gblup[i, 1]
        ))

--- Step 1-4: Simulating population and multi-trait phenotypes ---
--- Population simulation complete ---

--- Performing Multi-Trait ABLUP (Sparse Iterative) ---
[DEBUG] Sparse A_inv successfully created with 1306 non-zero elements.
[DEBUG] First 5 data points of A_inv: [1. 1. 1. 1. 1.]
ABLUP calculation complete.

--- Performing Multi-Trait GBLUP (Iterative) ---
GBLUP calculation complete.

--- Comparison of Results for Offspring ---

ABLUP Accuracy -> Trait 1: 0.8360, Trait 2: 0.7404
GBLUP Accuracy -> Trait 1: 0.8594, Trait 2: 0.7734

ID     |       TBV T1       TBV T2 |     ABLUP T1     ABLUP T2 |     GBLUP T1     GBLUP T2
----------------------------------------------------------------------------------------
100    |      101.551       51.007 |        1.006        0.817 |        0.380        0.571
101    |      100.531       50.623 |        0.697        0.056 |        0.387       -0.103
102    |      100.679       49.953 |        1.434        0.276 |        0.575       -0.341
103

In [5]:
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import Tuple
from scipy.sparse import coo_matrix

# JAX's iterative solver and sparse matrix format
from jax.scipy.sparse.linalg import cg
import jax.experimental.sparse as jsparse

# Assume the 'chewc' library is installed
from chewc.structs import (
    Population,
    Trait,
    GeneticMap,
    quick_haplo,
    add_trait
)
from chewc.pheno import calculate_phenotypes
from chewc.select import select_top_k
from chewc.cross import random_mating, cross_pair

# --- Scalable Core Functions ---

def build_a_inverse_sparse(pedigree_meta: np.ndarray) -> jsparse.BCOO:
    """
    Builds the inverse Numerator Relationship Matrix (A_inv) straight from a
    pedigree with Henderson's rules and returns it in JAX's BCOO format.

    Columns 1 and 2 of ``pedigree_meta`` are read as the row indices of the
    sire and dam (-1 when unknown). Every individual contributes a fixed
    pattern of (row, col, value) triplets; overlapping triplets are summed.
    """
    n = pedigree_meta.shape[0]
    # At most 9 triplets per individual (the two-known-parents case).
    capacity = n * 9
    values = np.zeros(capacity, dtype=np.float32)
    row_idx = np.zeros(capacity, dtype=np.int32)
    col_idx = np.zeros(capacity, dtype=np.int32)
    used = 0

    def contributions(animal, sire, dam):
        """Return the (row, col, value) triplets one individual adds."""
        if sire == -1 and dam == -1:
            # Founder
            return [(animal, animal, 1.0)]
        if sire == -1 or dam == -1:
            # One parent known
            known = dam if sire == -1 else sire
            return [
                (animal, animal, 4/3),
                (animal, known, -2/3),
                (known, animal, -2/3),
                (known, known, 1/3),
            ]
        # Both parents known
        return [
            (animal, animal, 2.0),
            (animal, sire, -1.0),
            (sire, animal, -1.0),
            (animal, dam, -1.0),
            (dam, animal, -1.0),
            (sire, sire, 0.5),
            (dam, dam, 0.5),
            (sire, dam, 0.5),
            (dam, sire, 0.5),
        ]

    for animal in range(n):
        sire_idx = int(pedigree_meta[animal, 1])
        dam_idx = int(pedigree_meta[animal, 2])
        for r, c, val in contributions(animal, sire_idx, dam_idx):
            if used >= capacity:
                raise ValueError("Exceeded pre-allocated memory for sparse matrix.")
            row_idx[used], col_idx[used], values[used] = r, c, val
            used += 1

    # Let scipy collapse repeated coordinates by summing them.
    summed = coo_matrix(
        (values[:used], (row_idx[:used], col_idx[:used])), shape=(n, n)
    )
    summed.sum_duplicates()

    # **DEBUG**: report on the assembled values before handing them to JAX.
    if np.all(np.isfinite(summed.data)):
        print(f"[DEBUG] Sparse A_inv successfully created with {len(summed.data)} non-zero elements.")
        print(f"[DEBUG] First 5 data points of A_inv: {summed.data[:5]}")
    else:
        print("[DEBUG] WARNING: Non-finite values (NaN or inf) found in sparse A_inv data!")

    # BCOO expects coordinates as an (nse, 2) array.
    coords = np.stack([summed.row, summed.col], axis=1)

    return jsparse.BCOO((jnp.array(summed.data), jnp.array(coords)), shape=(n, n))


@jax.jit
def build_g_matrix(geno: jnp.ndarray) -> jnp.ndarray:
    """Constructs the Genomic Relationship Matrix (G).

    ``geno`` must be rank 4 with individuals on the first axis. Axis 2 is
    summed to give per-marker dosages; these are centred by twice the allele
    frequency and the cross-product is divided by ``2 * sum(p * (1 - p))``.
    """
    # Unpacking four dimensions rejects inputs of any other rank.
    n_ind, _, _, _ = geno.shape

    per_marker = jnp.sum(geno, axis=2, dtype=jnp.float32)
    dosage = jnp.reshape(per_marker, (n_ind, -1))

    # Keep frequencies off the 0/1 boundary so the denominator is positive.
    freq = jnp.mean(dosage, axis=0) / 2.0
    freq = jnp.clip(freq, 1e-6, 1.0 - 1e-6)

    deviations = dosage - 2 * freq
    denominator = 2 * jnp.sum(freq * (1 - freq))
    return (deviations @ deviations.T) / denominator

@partial(jax.jit, static_argnames=('n_traits',))
def solve_multi_trait_mme_iterative(
    y: jnp.ndarray,
    relationship_matrix_inv: jsparse.BCOO,
    G0_inv: jnp.ndarray,
    R0_inv: jnp.ndarray,
    n_traits: int
) -> jnp.ndarray:
    """Solves a multi-trait MME using an iterative Conjugate Gradient solver.

    Only the matrix-vector product of the coefficient matrix is implemented,
    so the full system is never built. Unknowns are ordered as the
    ``n_traits`` trait means, then the individual effects flattened in
    column-major (trait-major) order.

    Args:
        y: ``(n_ind, n_traits)`` phenotype matrix.
        relationship_matrix_inv: ``(n_ind, n_ind)`` inverse relationship
            matrix, applied with ``@``.
        G0_inv: Inverse between-trait genetic (co)variance matrix.
        R0_inv: Inverse between-trait residual (co)variance matrix.
        n_traits: Number of traits; static for ``jit``.

    Returns:
        ``(n_ind, n_traits)`` solutions for the individual effects.
    """
    n_ind = relationship_matrix_inv.shape[0]
    ones_column = jnp.ones((n_ind, 1))

    def coefficient_times(vec):
        """Apply the (implicit) coefficient matrix to ``vec``."""
        b_part = vec[:n_traits]
        u_part = vec[n_traits:].reshape((n_ind, n_traits), order='F')

        # Rows belonging to the trait means.
        upper = (R0_inv @ b_part) * n_ind + R0_inv @ u_part.sum(axis=0)

        # Rows belonging to the individual effects.
        from_means = (ones_column @ b_part.reshape(1, -1)) @ R0_inv.T
        from_records = u_part @ R0_inv.T
        from_relationships = (relationship_matrix_inv @ u_part) @ G0_inv.T
        lower = (from_means + from_records + from_relationships).flatten('F')

        return jnp.concatenate([upper, lower])

    right_hand_side = jnp.concatenate(
        [R0_inv @ y.sum(axis=0), (y @ R0_inv.T).flatten('F')]
    )

    # cg also returns convergence info, which is discarded here.
    solved, _ = cg(coefficient_times, right_hand_side, maxiter=1000)

    return solved[n_traits:].reshape((n_ind, n_traits), order='F')


def remap_pedigree(full_pedigree: jnp.ndarray) -> np.ndarray:
    """Remaps IDs and returns a NumPy array for the sparse builder.

    Args:
        full_pedigree: Array-like with at least three columns. Column 0
            holds each individual's original ID and columns 1 and 2 the
            parents' original IDs (-1 for unknown).

    Returns:
        An int32 NumPy array of the same shape. Column 0 becomes the row
        position ``0..n-1``, columns 1 and 2 become the row positions of the
        parents (-1 stays -1), and any further columns are copied unchanged
        instead of being reset to zero.

    Raises:
        KeyError: If a parent ID does not occur in column 0.
    """
    # One host conversion up front avoids a device round-trip per element.
    pedigree_np = np.asarray(full_pedigree)

    id_map = {int(og_id): i for i, og_id in enumerate(pedigree_np[:, 0])}
    id_map[-1] = -1  # unknown parents stay unknown

    # astype copies, which keeps trailing metadata columns and leaves the
    # caller's array untouched.
    remapped_pedigree = pedigree_np.astype(np.int32)
    remapped_pedigree[:, 0] = np.arange(pedigree_np.shape[0])
    remapped_pedigree[:, 1] = [id_map[int(i)] for i in pedigree_np[:, 1]]
    remapped_pedigree[:, 2] = [id_map[int(i)] for i in pedigree_np[:, 2]]
    return remapped_pedigree

# --- Main Execution Script ---

# Two-trait demonstration with 200 founders: simulate founders and one
# generation of offspring, then solve a multi-trait mixed model iteratively
# with a sparse pedigree inverse (ABLUP) and a dense genomic inverse (GBLUP).
if __name__ == "__main__":
    # --- Parameters ---
    N_FOUNDERS, N_SELECT, N_OFFSPRING = 200, 20, 200
    N_CHR, N_LOCI = 5, 1000
    SEED = 42
    N_TRAITS = 2
    h2_trait1, h2_trait2 = 0.6, 0.4
    genetic_corr = 0.5
    
    # --- Setup (Co)variance matrices ---
    # Genetic variances equal the heritabilities and residual variances
    # their complements, so each trait's total variance is 1.
    var_g1, var_g2 = h2_trait1, h2_trait2
    cov_g12 = genetic_corr * jnp.sqrt(var_g1 * var_g2)
    G0 = jnp.array([[var_g1, cov_g12], [cov_g12, var_g2]])
    
    # Residuals are uncorrelated between traits (diagonal R0).
    var_e1, var_e2 = 1 - h2_trait1, 1 - h2_trait2
    R0 = jnp.diag(jnp.array([var_e1, var_e2]))
    
    G0_inv, R0_inv = jnp.linalg.inv(G0), jnp.linalg.inv(R0)

    # --- Population Simulation ---
    print("--- Step 1-4: Simulating population and multi-trait phenotypes ---")
    key = jax.random.PRNGKey(SEED)
    key, pop_key, trait_key, pheno_key, mating_key, cross_key = jax.random.split(key, 6)
    
    founder_pop, genetic_map = quick_haplo(key=pop_key, n_ind=N_FOUNDERS, n_chr=N_CHR, seg_sites=N_LOCI)
    
    trait_architecture = add_trait(
        key=trait_key, founder_pop=founder_pop, n_qtl_per_chr=50,
        mean=jnp.array([100.0, 50.0]), var=jnp.array([var_g1, var_g2]), sigma=G0
    )
    
    founder_phenotypes, founder_tbvs = calculate_phenotypes(
        key=pheno_key, population=founder_pop, trait=trait_architecture,
        heritability=jnp.array([h2_trait1, h2_trait2])
    )
    
    # Parents are chosen on the first trait's phenotype only (column 0).
    selected_parents = select_top_k(founder_pop, founder_phenotypes[:, 0], k=N_SELECT)
    pairings = random_mating(mating_key, n_parents=N_SELECT, n_crosses=N_OFFSPRING)
    
    # Map cross_pair over its first five arguments; the genetic map and the
    # trailing 10 (the last positional argument, called MAX_CROSSOVERS in
    # the single-trait script) are shared by every cross.
    vmapped_cross = jax.vmap(cross_pair, in_axes=(0, 0, 0, 0, 0, None, None))
    offspring_keys = jax.random.split(cross_key, N_OFFSPRING)
    offspring_geno, offspring_ibd = vmapped_cross(
        offspring_keys, selected_parents.geno[pairings[:, 0]], selected_parents.geno[pairings[:, 1]],
        selected_parents.ibd[pairings[:, 0]], selected_parents.ibd[pairings[:, 1]],
        genetic_map, 10
    )
    
    # Offspring metadata columns: new ID (continuing after the founders),
    # ID of each parent, and a constant 1 (presumably a generation counter
    # -- confirm against chewc's schema).
    new_meta = jnp.stack([
        jnp.arange(N_OFFSPRING) + N_FOUNDERS,
        selected_parents.meta[pairings[:, 0], 0],
        selected_parents.meta[pairings[:, 1], 0],
        jnp.full((N_OFFSPRING,), 1),
    ], axis=-1)
    offspring_pop = Population(geno=offspring_geno, ibd=offspring_ibd, meta=new_meta)
    
    key, offspring_pheno_key = jax.random.split(key)
    offspring_phenotypes, offspring_tbvs = calculate_phenotypes(
        key=offspring_pheno_key, population=offspring_pop, trait=trait_architecture,
        heritability=jnp.array([h2_trait1, h2_trait2])
    )
    
    # Founders first, offspring after: the result slices below depend on it.
    all_phenotypes = jnp.concatenate([founder_phenotypes, offspring_phenotypes], axis=0)
    print("--- Population simulation complete ---")

    # --- ABLUP (Sparse, Iterative) ---
    print("\n--- Performing Multi-Trait ABLUP (Sparse Iterative) ---")
    full_pedigree = jnp.concatenate([founder_pop.meta, offspring_pop.meta], axis=0)
    remapped_ped_np = remap_pedigree(full_pedigree)
    
    A_inv_sparse = build_a_inverse_sparse(remapped_ped_np)
    ablup_ebvs = solve_multi_trait_mme_iterative(
        all_phenotypes, A_inv_sparse, G0_inv, R0_inv, n_traits=N_TRAITS
    )
    print("ABLUP calculation complete.")
    
    # --- GBLUP (Iterative) ---
    print("\n--- Performing Multi-Trait GBLUP (Iterative) ---")
    all_geno = jnp.concatenate([founder_pop.geno, offspring_pop.geno], axis=0)
    G_matrix = build_g_matrix(all_geno)
    # A small ridge on the diagonal before the dense inverse. G_inv is a
    # dense array even though the solver's annotation says BCOO; the solver
    # only uses it through `@`.
    G_inv = jnp.linalg.inv(G_matrix + jnp.identity(G_matrix.shape[0]) * 1e-4)
    
    gblup_gebvs = solve_multi_trait_mme_iterative(
        all_phenotypes, G_inv, G0_inv, R0_inv, n_traits=N_TRAITS
    )
    print("GBLUP calculation complete.")

    # --- Compare Results ---
    print("\n--- Comparison of Results for Offspring ---")
    offspring_ablup = ablup_ebvs[N_FOUNDERS:]
    offspring_gblup = gblup_gebvs[N_FOUNDERS:]

    # Accuracy per trait: correlation between true and estimated values.
    acc_ablup_t1 = jnp.corrcoef(offspring_tbvs[:, 0], offspring_ablup[:, 0])[0, 1]
    acc_ablup_t2 = jnp.corrcoef(offspring_tbvs[:, 1], offspring_ablup[:, 1])[0, 1]
    acc_gblup_t1 = jnp.corrcoef(offspring_tbvs[:, 0], offspring_gblup[:, 0])[0, 1]
    acc_gblup_t2 = jnp.corrcoef(offspring_tbvs[:, 1], offspring_gblup[:, 1])[0, 1]

    print(f"\nABLUP Accuracy -> Trait 1: {acc_ablup_t1:.4f}, Trait 2: {acc_ablup_t2:.4f}")
    print(f"GBLUP Accuracy -> Trait 1: {acc_gblup_t1:.4f}, Trait 2: {acc_gblup_t2:.4f}")

    print("\n{:<6} | {:>12} {:>12} | {:>12} {:>12} | {:>12} {:>12}".format(
        "ID", "TBV T1", "TBV T2", "ABLUP T1", "ABLUP T2", "GBLUP T1", "GBLUP T2"))
    print("-" * 88)
    # NOTE(review): the hard-coded 10 assumes N_OFFSPRING >= 10. Judging by
    # the recorded output, TBVs include the trait mean while the estimates
    # do not, so the columns are on different scales.
    for i in range(10):
        print("{:<6} | {:>12.3f} {:>12.3f} | {:>12.3f} {:>12.3f} | {:>12.3f} {:>12.3f}".format(
            int(offspring_pop.meta[i, 0]),
            offspring_tbvs[i, 0], offspring_tbvs[i, 1],
            offspring_ablup[i, 0], offspring_ablup[i, 1],
            offspring_gblup[i, 0], offspring_gblup[i, 1]
        ))

--- Step 1-4: Simulating population and multi-trait phenotypes ---
--- Population simulation complete ---

--- Performing Multi-Trait ABLUP (Sparse Iterative) ---
[DEBUG] Sparse A_inv successfully created with 1406 non-zero elements.
[DEBUG] First 5 data points of A_inv: [1. 1. 1. 1. 1.]
ABLUP calculation complete.

--- Performing Multi-Trait GBLUP (Iterative) ---
GBLUP calculation complete.

--- Comparison of Results for Offspring ---

ABLUP Accuracy -> Trait 1: 0.8543, Trait 2: 0.7700
GBLUP Accuracy -> Trait 1: 0.8698, Trait 2: 0.7921

ID     |       TBV T1       TBV T2 |     ABLUP T1     ABLUP T2 |     GBLUP T1     GBLUP T2
----------------------------------------------------------------------------------------
200    |      100.404       49.503 |        0.553        0.084 |       -0.161       -0.159
201    |      100.290       50.093 |        0.555       -0.027 |       -0.007       -0.264
202    |      101.150       49.840 |        1.583       -0.072 |        0.880       -0.544
203

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()