# predict

> Genomic prediction (GBLUP) built on the core data structures used to run a simulation

In [None]:
#| default_exp predict

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

from flax.struct import dataclass as flax_dataclass
import jax.numpy as jnp
from typing import Optional
from chewc.population import Population

In [None]:
#| export
"""
GBLUP (Genomic Best Linear Unbiased Prediction) implementation for chewc library.

This module provides functions for genomic prediction using the GBLUP methodology,
which is a standard approach in genomic selection and animal breeding.
"""

from typing import Optional, Dict
import jax
import jax.numpy as jnp
from jax.numpy.linalg import solve, inv, pinv
from flax.struct import dataclass as flax_dataclass

from chewc.population import Population

@flax_dataclass(frozen=True)
class PredictionResults:
    """
    A container for the results of a genomic prediction.

    Only `ids` and `ebv` are required; every other field defaults to `None`.
    The per-field notes below describe what `gblup_predict` stores in each
    slot when it builds an instance.
    """
    # Individual identifiers; `gblup_predict` passes `pop.id` through unchanged.
    ids: jnp.ndarray
    # Estimated breeding values; `gblup_predict` stores them reshaped to (-1, 1).
    ebv: jnp.ndarray
    # Prediction error variance per individual, as returned by `_gblup_core`.
    pev: Optional[jnp.ndarray] = None
    # Reliability per individual; `_gblup_core` clips it to the range [0, 1].
    reliability: Optional[jnp.ndarray] = None
    # Fixed-effect estimates; `gblup_predict` stores a length-1 array holding
    # the fitted intercept.
    fixed_effects: Optional[jnp.ndarray] = None
    # The heritability value that was supplied to the prediction.
    h2_used: Optional[float] = None
    # Dictionary of variance components and summary statistics produced by
    # `_gblup_core` (e.g. 'var_genetic', 'var_error', 'intercept', 'n_valid').
    var_components: Optional[Dict] = None


def _gblup_core(
    phenotypes: jnp.ndarray,
    dosages: jnp.ndarray,
    h2: float,
    trait_idx: int = 0,
    regularization: float = 1e-6
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, dict]:
    """
    Core GBLUP calculation. Handles both full and missing data cases.
    """
    y = phenotypes[:, trait_idx]
    n_total = dosages.shape[0]

    valid_mask = ~jnp.isnan(y)
    n_valid = jnp.sum(valid_mask).item()

    if n_valid == 0:
        raise ValueError("No valid (non-NaN) phenotypes found for the selected trait.")

    # --- Shared calculations ---
    p = jnp.mean(dosages[valid_mask], axis=0) / 2 # Allele freqs from phenotyped pop
    W_all = dosages - 2 * p
    denom = 2 * jnp.sum(p * (1 - p))
    G_all = (W_all @ W_all.T) / (denom + regularization)

    y_valid = y[valid_mask]
    y_mean = jnp.mean(y_valid)
    y_centered = y_valid - y_mean
    
    var_y = jnp.var(y_valid)
    var_g = h2 * var_y
    var_e = (1 - h2) * var_y

    # --- Case 1: All individuals have phenotypes (no partitioning needed) ---
    if n_valid == n_total:
        G_reg = G_all + jnp.eye(n_total) * regularization
        lambda_val = var_e / (var_g + regularization)
        X = jnp.ones((n_total, 1))
        Z = jnp.eye(n_total)
        G_inv = pinv(G_reg)

        C11 = X.T @ X
        C12 = X.T @ Z
        C22 = Z.T @ Z + G_inv * lambda_val
        LHS = jnp.block([[C11, C12], [C12.T, C22]])
        RHS = jnp.concatenate([X.T @ y_centered, Z.T @ y_centered])
        
        try:
            solutions = solve(LHS, RHS)
            C_inv = inv(LHS)
        except jnp.linalg.LinAlgError:
            solutions = pinv(LHS) @ RHS
            C_inv = pinv(LHS)

        b_hat = solutions[:1]
        ebv_full = solutions[1:]
        
        pev_full = jnp.diag(C_inv[1:, 1:]) * var_e
        reliability_full = jnp.clip(1 - (pev_full / (var_g + regularization)), 0.0, 1.0)
        
    # --- Case 2: Some individuals have missing phenotypes (partitioning required) ---
    else:
        G_11 = G_all[valid_mask][:, valid_mask]
        G_21 = G_all[~valid_mask][:, valid_mask]
        G_22 = G_all[~valid_mask][:, ~valid_mask]
    
        G_11_reg = G_11 + jnp.eye(n_valid) * regularization
        lambda_val = var_e / (var_g + regularization)
        X1 = jnp.ones((n_valid, 1))
        Z1 = jnp.eye(n_valid)
        G_11_inv = pinv(G_11_reg)

        C11 = X1.T @ X1
        C12 = X1.T @ Z1
        C22 = Z1.T @ Z1 + G_11_inv * lambda_val
        LHS = jnp.block([[C11, C12], [C12.T, C22]])
        RHS = jnp.concatenate([X1.T @ y_centered, Z1.T @ y_centered])

        try:
            solutions = solve(LHS, RHS)
            C_inv = inv(LHS)
        except jnp.linalg.LinAlgError:
            solutions = pinv(LHS) @ RHS
            C_inv = pinv(LHS)

        b_hat = solutions[:1]
        u_hat_1 = solutions[1:]
        u_hat_2 = G_21 @ G_11_inv @ u_hat_1
        
        ebv_full = jnp.zeros(n_total).at[valid_mask].set(u_hat_1).at[~valid_mask].set(u_hat_2)
        
        pev_1 = jnp.diag(C_inv[1:, 1:]) * var_e
        rel_1 = jnp.clip(1 - (pev_1 / (var_g + regularization)), 0.0, 1.0)
        
        pev_2 = jnp.diag(G_22 - G_21 @ G_11_inv @ G_21.T) * var_g
        rel_2 = jnp.clip(1 - (pev_2 / (var_g + regularization)), 0.0, 1.0)
        
        pev_full = jnp.zeros(n_total).at[valid_mask].set(pev_1).at[~valid_mask].set(pev_2)
        reliability_full = jnp.zeros(n_total).at[valid_mask].set(rel_1).at[~valid_mask].set(rel_2)

    # --- Shared summary statistics ---
    h2_realized = jnp.var(ebv_full[valid_mask]) / (var_y + 1e-8)
    var_components = {
        'var_genetic': var_g, 'var_error': var_e, 'var_phenotypic': var_y,
        'h2_input': h2, 'h2_realized': h2_realized,
        'intercept': y_mean + b_hat[0], 'n_valid': n_valid
    }
    return ebv_full, pev_full, reliability_full, var_components


def gblup_predict(
    pop: Population,
    h2: float,
    trait_idx: int = 0,
    regularization: float = 1e-6
) -> PredictionResults:
    """
    Predicts breeding values using GBLUP (Genomic Best Linear Unbiased Prediction).

    Validates the inputs, delegates the numerical work to `_gblup_core` using
    the population's phenotypes and dosages, and wraps the outcome in a
    `PredictionResults`.

    Args:
        pop: Population providing `pheno`, `dosage`, `id` and `nInd`.
        h2: Heritability of the trait; must lie in [0, 1].
        trait_idx: Column of `pop.pheno` to analyse.
        regularization: Small constant forwarded to `_gblup_core`.

    Returns:
        A `PredictionResults` whose `ebv` is a column array and whose
        `fixed_effects` holds the fitted intercept.

    Raises:
        ValueError: If `h2` is outside [0, 1] or the population is empty.
        IndexError: If `trait_idx` is not smaller than the number of traits.
    """
    # Written as a negated chained comparison so that NaN is rejected too.
    h2_in_range = 0 <= h2 <= 1
    if not h2_in_range:
        raise ValueError(f"Heritability must be between 0 and 1, got {h2}")

    n_traits = pop.pheno.shape[1]
    if trait_idx >= n_traits:
        raise IndexError(f"trait_idx {trait_idx} is out of bounds for {n_traits} traits")

    if pop.nInd == 0:
        raise ValueError("Population is empty")

    core_output = _gblup_core(
        pop.pheno,
        pop.dosage,
        h2,
        trait_idx=trait_idx,
        regularization=regularization,
    )
    breeding_values, error_variance, rel, components = core_output

    intercept = jnp.array([components['intercept']])
    return PredictionResults(
        ids=pop.id,
        ebv=breeding_values.reshape(-1, 1),
        pev=error_variance,
        reliability=rel,
        fixed_effects=intercept,
        h2_used=h2,
        var_components=components,
    )


def gblup_multi_trait(
    pop: Population,
    h2: jnp.ndarray,
    regularization: float = 1e-6
) -> list[PredictionResults]:
    """
    Performs GBLUP prediction for multiple traits independently.

    Each trait is analysed on its own with `gblup_predict`; no genetic
    correlations between traits are modelled.

    Args:
        pop: Population whose `pheno` has one column per trait.
        h2: One heritability per trait. A 0-d array, a plain Python number or
            a sequence is accepted and promoted to a 1-D array.
        regularization: Small constant forwarded to `gblup_predict`.

    Returns:
        A list with one `PredictionResults` per trait, in trait order.

    Raises:
        ValueError: If the number of heritabilities differs from the number of
            trait columns in `pop.pheno`.
    """
    # Promote scalars / lists to a 1-D array; arrays that are already 1-D or
    # higher pass through unchanged.
    h2 = jnp.atleast_1d(jnp.asarray(h2))

    n_traits = pop.pheno.shape[1]
    if len(h2) != n_traits:
        raise ValueError(f"Number of h2 values ({len(h2)}) must match number of traits ({n_traits})")

    return [
        gblup_predict(pop, float(h2_val), trait_idx, regularization)
        for trait_idx, h2_val in enumerate(h2)
    ]

In [None]:
#| test

import jax
import jax.numpy as jnp

# --- Import your revised GBLUP code ---
# To make this runnable, you would save the code above as, for example, `gblup.py`
# and then import it like this:
# from gblup import gblup_predict
from __main__ import gblup_predict # Or use this line if running in the same file/notebook

# --- Import chewc library components ---
# This assumes 'chewc' is installed and accessible in your environment
from chewc.population import Population, quick_haplo
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno


def test_gblup_accuracy():
    """
    Tests the GBLUP implementation by comparing its predictions
    to known, simulated breeding values.

    A founder population with one additive trait is simulated, phenotypes are
    drawn at a known heritability, and the resulting EBVs are checked for
    accuracy (correlation with the true breeding values), lack of bias
    (regression slope of true BV on EBV close to 1), output shapes and the
    reliability range.
    """
    print("🚀 Starting GBLUP accuracy test...")

    # 1. Simulation parameters
    n_ind = 250
    n_loci_per_chr = 1000
    n_chr = 1
    h2_simulated = 0.5  # The true heritability used to generate data
    h2_for_gblup = 0.5  # The heritability we provide to the GBLUP model

    # 2. Setup JAX random keys
    key = jax.random.PRNGKey(42)
    founder_key, trait_key, pheno_key = jax.random.split(key, 3)

    # 3. Create a founder population using chewc
    print(f"🧬 Simulating a population with {n_ind} individuals and {n_loci_per_chr} loci...")
    founder_pop, gen_map = quick_haplo(
        key=founder_key,
        n_ind=n_ind,
        n_chr=n_chr,
        n_loci_per_chr=n_loci_per_chr
    )

    # 4. Define simulation parameters and add a quantitative trait
    sp = SimParam.from_founder_pop(founder_pop, gen_map)
    sp = add_trait_a(
        key=trait_key,
        founder_pop=founder_pop,
        sim_param=sp,
        n_qtl_per_chr=150,      # A reasonably polygenic trait
        mean=jnp.array([10.0]), # With a non-zero mean
        var=jnp.array([2.0])    # And some genetic variance
    )

    # 5. Generate phenotypes based on the true genetics and simulated heritability
    print(f"🌱 Generating phenotypes with a true h² of {h2_simulated}...")
    pop_with_pheno = set_pheno(
        key=pheno_key,
        pop=founder_pop,
        traits=sp.traits,
        ploidy=sp.ploidy,
        h2=jnp.array([h2_simulated])
    )

    # 6. Run the GBLUP prediction using the specified h2
    print(f"📈 Running GBLUP prediction with an input h² of {h2_for_gblup}...")
    results = gblup_predict(pop=pop_with_pheno, h2=h2_for_gblup)

    # 7. Validate the results
    true_bvs = pop_with_pheno.bv.flatten()
    estimated_bvs = results.ebv.flatten()

    # Calculate the correlation between true and estimated BVs.
    # This is a key metric for "prediction accuracy".
    prediction_accuracy = jnp.corrcoef(true_bvs, estimated_bvs)[0, 1]

    # Calculate the regression of true BVs on estimated BVs.
    # A value close to 1 indicates that the EBVs are not biased (scaled correctly).
    # Covariance and variance are read from the same matrix so both use the
    # same degrees of freedom (jnp.cov defaults to ddof=1, jnp.var to ddof=0).
    bv_cov = jnp.cov(true_bvs, estimated_bvs)
    regression_b_on_a = bv_cov[0, 1] / bv_cov[1, 1]

    print("\n--- ✅ Test Validation ---")
    print(f"Prediction Accuracy (Correlation): {prediction_accuracy:.4f}")
    print(f"Regression of True BV on EBV:     {regression_b_on_a:.4f}")

    # --- Assertions for automated testing ---
    # The expected accuracy is roughly sqrt(h2). We'll test for a reasonable value.
    assert prediction_accuracy > 0.6, f"Prediction accuracy ({prediction_accuracy:.2f}) is lower than expected."

    # The regression coefficient should be close to 1, indicating unbiased predictions.
    assert 0.9 < regression_b_on_a < 1.1, f"EBV estimates appear biased (regression={regression_b_on_a:.2f})."

    # Check that the output shapes are correct
    assert results.ebv.shape == (n_ind, 1)
    assert results.reliability.shape == (n_ind,)

    # Check that reliability is within the valid range [0, 1]
    assert jnp.all(results.reliability >= 0) and jnp.all(results.reliability <= 1)
    print("\n🎉 All assertions passed!")


def test_gblup_with_missing_phenotypes():
    """
    Verify that GBLUP tolerates NaN phenotypes and still returns a prediction
    for every individual.

    A fraction of the simulated phenotypes is replaced by NaN. The test then
    compares accuracy and average reliability between the individuals that
    kept their phenotype and those that lost it.
    """
    print("\n🚀 Starting GBLUP test with missing phenotypes...")

    # Simulation settings
    n_ind = 300
    h2 = 0.6                # slightly higher h2 to make the contrast clearer
    missing_fraction = 0.4  # share of individuals left without a phenotype
    founder_shape = dict(n_ind=n_ind, n_chr=1, n_loci_per_chr=1000)

    # Independent random streams for each simulation stage
    stage_keys = jax.random.split(jax.random.PRNGKey(101), 4)
    founder_key = stage_keys[0]
    trait_key = stage_keys[1]
    pheno_key = stage_keys[2]
    missing_key = stage_keys[3]

    # Base population with a single additive trait and complete phenotypes
    print(f"🧬 Simulating a population with {n_ind} individuals...")
    founder_pop, gen_map = quick_haplo(key=founder_key, **founder_shape)
    sim_param = add_trait_a(
        key=trait_key,
        founder_pop=founder_pop,
        sim_param=SimParam.from_founder_pop(founder_pop, gen_map),
        n_qtl_per_chr=150,
        mean=jnp.array([10.0]),
        var=jnp.array([2.0]),
    )
    complete_pop = set_pheno(
        key=pheno_key,
        pop=founder_pop,
        traits=sim_param.traits,
        ploidy=sim_param.ploidy,
        h2=jnp.array([h2]),
    )

    # Knock out phenotypes for a random subset of individuals
    n_missing = int(n_ind * missing_fraction)
    print(f"🔪 Introducing {n_missing} missing phenotypes (NaNs)...")

    n_kept = n_ind - n_missing
    kept_indices = jax.random.choice(missing_key, n_ind, shape=(n_kept,), replace=False)
    phenotyped_mask = jnp.zeros(n_ind, dtype=bool).at[kept_indices].set(True)
    unphenotyped_mask = ~phenotyped_mask

    # The mask gets a trailing axis so it broadcasts over the trait columns
    masked_pheno = jnp.where(phenotyped_mask[:, None], complete_pop.pheno, jnp.nan)
    pop_missing = complete_pop.replace(pheno=masked_pheno)

    # Predict on the incomplete data
    print("📈 Running GBLUP prediction on incomplete data...")
    results = gblup_predict(pop=pop_missing, h2=h2)

    true_bvs = pop_missing.bv.flatten()
    estimated_bvs = results.ebv.flatten()

    def group_accuracy(group_mask):
        # Correlation between true and estimated BVs within one group
        return jnp.corrcoef(true_bvs[group_mask], estimated_bvs[group_mask])[0, 1]

    accuracy_phenotyped = group_accuracy(phenotyped_mask)
    accuracy_non_phenotyped = group_accuracy(unphenotyped_mask)

    print("\n--- ✅ Test Validation ---")
    print(f"Accuracy for individuals WITH phenotypes:   {accuracy_phenotyped:.4f}")
    print(f"Accuracy for individuals WITHOUT phenotypes: {accuracy_non_phenotyped:.4f}")

    # No NaN may leak into the predicted breeding values
    assert not jnp.any(jnp.isnan(results.ebv)), "Output EBVs should not contain NaNs."

    # Phenotyped individuals should be predicted well
    assert accuracy_phenotyped > jnp.sqrt(h2) * 0.8, "Accuracy for phenotyped individuals is too low."

    # Unphenotyped individuals: worse than phenotyped, but still informative
    assert accuracy_non_phenotyped > 0.1, "Accuracy for non-phenotyped individuals should be positive."
    assert accuracy_non_phenotyped < accuracy_phenotyped, "Non-phenotyped accuracy should not be higher than phenotyped."

    # Mean reliability must follow the same ordering
    reliability_phenotyped = jnp.mean(results.reliability[phenotyped_mask])
    reliability_non_phenotyped = jnp.mean(results.reliability[unphenotyped_mask])
    print(f"Average reliability (phenotyped):           {reliability_phenotyped:.4f}")
    print(f"Average reliability (non-phenotyped):       {reliability_non_phenotyped:.4f}")
    assert reliability_non_phenotyped < reliability_phenotyped, "Reliability should be lower for non-phenotyped individuals."

    print("\n🎉 All assertions passed!")



def test_gblup_heritability_effect():
    """
    Check that a more heritable trait is predicted more accurately.

    Two independent simulations are run, one at low and one at high
    heritability, and the correlation between true and estimated breeding
    values is compared between them.
    """
    print("\n🚀 Starting GBLUP heritability effect test...")

    def simulate_accuracy(key, h2):
        # Simulate one population at the given h2 and return the GBLUP accuracy.
        founder_key, trait_key, pheno_key = jax.random.split(key, 3)
        founders, gen_map = quick_haplo(
            key=founder_key,
            n_ind=250,
            n_chr=1,
            n_loci_per_chr=1000,
        )
        sim_param = add_trait_a(
            key=trait_key,
            founder_pop=founders,
            sim_param=SimParam.from_founder_pop(founders, gen_map),
            n_qtl_per_chr=150,
            mean=jnp.array([0.]),
            var=jnp.array([1.]),
        )
        phenotyped = set_pheno(
            key=pheno_key,
            pop=founders,
            traits=sim_param.traits,
            ploidy=sim_param.ploidy,
            h2=jnp.array([h2]),
        )
        prediction = gblup_predict(pop=phenotyped, h2=h2)
        true_bv = phenotyped.bv.flatten()
        est_bv = prediction.ebv.flatten()
        return jnp.corrcoef(true_bv, est_bv)[0, 1]

    # Separate random streams for the two scenarios
    key_low_h2, key_high_h2 = jax.random.split(jax.random.PRNGKey(202))

    h2_low = 0.2
    h2_high = 0.8

    print(f"📈 Running prediction for low h² ({h2_low})...")
    accuracy_low = simulate_accuracy(key_low_h2, h2_low)

    print(f"📈 Running prediction for high h² ({h2_high})...")
    accuracy_high = simulate_accuracy(key_high_h2, h2_high)

    print("\n--- ✅ Test Validation ---")
    print(f"Prediction Accuracy (h²={h2_low}): {accuracy_low:.4f}")
    print(f"Prediction Accuracy (h²={h2_high}): {accuracy_high:.4f}")

    # Higher heritability must give the better prediction ...
    assert accuracy_high > accuracy_low, "Accuracy should be higher for a more heritable trait."
    # ... and both scenarios must clear a sensible floor.
    assert accuracy_low > 0.2, "Low h2 accuracy is lower than expected."
    assert accuracy_high > 0.7, "High h2 accuracy is lower than expected."

    print("\n🎉 All assertions passed!")

def test_gblup_pop_size_effect_on_reliability():
    """
    Check that mean prediction reliability grows with the size of the
    reference population.

    The same simulation recipe is run for a small and a large population at a
    fixed heritability, and the average reliabilities are compared.
    """
    print("\n🚀 Starting GBLUP population size effect test...")

    def mean_reliability(key, n_ind, h2):
        # Simulate one population of size n_ind and return its mean reliability.
        founder_key, trait_key, pheno_key = jax.random.split(key, 3)
        founders, gen_map = quick_haplo(
            key=founder_key,
            n_ind=n_ind,
            n_chr=1,
            n_loci_per_chr=1000,
        )
        sim_param = add_trait_a(
            key=trait_key,
            founder_pop=founders,
            sim_param=SimParam.from_founder_pop(founders, gen_map),
            n_qtl_per_chr=150,
            mean=jnp.array([0.]),
            var=jnp.array([1.]),
        )
        phenotyped = set_pheno(
            key=pheno_key,
            pop=founders,
            traits=sim_param.traits,
            ploidy=sim_param.ploidy,
            h2=jnp.array([h2]),
        )
        prediction = gblup_predict(pop=phenotyped, h2=h2)
        return jnp.mean(prediction.reliability)

    # Scenario settings and one random stream per scenario
    h2 = 0.5
    n_small = 100
    n_large = 500
    key_small, key_large = jax.random.split(jax.random.PRNGKey(303))

    print(f"📈 Running prediction for small population (n={n_small})...")
    reliability_small = mean_reliability(key_small, n_ind=n_small, h2=h2)

    print(f"📈 Running prediction for large population (n={n_large})...")
    reliability_large = mean_reliability(key_large, n_ind=n_large, h2=h2)

    print("\n--- ✅ Test Validation ---")
    print(f"Average Reliability (n={n_small}): {reliability_small:.4f}")
    print(f"Average Reliability (n={n_large}): {reliability_large:.4f}")

    # More reference individuals should yield more reliable predictions
    assert reliability_large > reliability_small, "Reliability should increase with population size."

    print("\n🎉 All assertions passed!")


# --- To run all tests ---
# if __name__ == "__main__":
#     test_gblup_accuracy()
#     test_gblup_with_missing_phenotypes()
#     test_gblup_heritability_effect()
#     test_gblup_pop_size_effect_on_reliability()

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()