# predict

> Common operations on the core data structures for running a simulation

In [2]:
#| default_exp pipe

In [3]:
#| hide
from nbdev.showdoc import *

In [6]:
#| export

from flax.struct import dataclass as flax_dataclass
import jax.numpy as jnp
from typing import Optional

In [7]:

#| export
# Exported because gblup_predict (an exported cell) annotates its return type
# with this class; without it the generated module raises NameError on import.
@flax_dataclass(frozen=True)
class PredictionResults:
    """
    A container for the results of a genomic prediction.

    Being a flax struct dataclass, this object is a registered JAX PyTree, so it
    can flow through JAX transformations (e.g. be returned from a jitted function).

    Attributes:
        ids (jnp.ndarray): The public IDs of the individuals.
        ebv (jnp.ndarray): The estimated breeding values. Shape: (nInd, nTraits).
        pev (Optional[jnp.ndarray]): The prediction error variance for each EBV.
        reliability (Optional[jnp.ndarray]): The reliability of each EBV.
        fixed_effects (Optional[jnp.ndarray]): The estimated fixed effects.
        h2_used (Optional[float]): The heritability value used for the prediction.
    """
    ids: jnp.ndarray
    ebv: jnp.ndarray
    pev: Optional[jnp.ndarray] = None
    reliability: Optional[jnp.ndarray] = None
    fixed_effects: Optional[jnp.ndarray] = None
    h2_used: Optional[float] = None

In [8]:
#| export
import jax.numpy as jnp
from jax.numpy.linalg import solve, inv
from chewc.population import Population

def gblup_predict(
    pop: Population,
    h2: float,
    *,
    trait: int = 0,
    ridge: float = 1e-6,
) -> PredictionResults:
    """
    Predict breeding values with GBLUP using an intercept-only fixed effect.

    Builds a genomic relationship matrix G from `pop.dosage`, sets up the
    mixed model equations (MME) for y = 1*b + u + e with Z = I, and solves them
    for the intercept b and the genomic breeding values u of every individual.

    Args:
        pop: A Population providing `pheno` (indexed as [individual, trait]),
            `dosage` (indexed as [individual, locus]) and `id`.
        h2: Heritability of the trait. Together with the phenotypic variance it
            fixes var_g = h2 * var(y) and var_e = (1 - h2) * var(y). Must lie
            strictly between 0 and 1; at either end the MME become degenerate.
        trait: Column of `pop.pheno` to analyse. Defaults to the first trait.
        ridge: Value added to the diagonal of G before inverting it. Because
            each locus is centred on its mean dosage, G has the all-ones vector
            in its null space and is singular without this term.

    Returns:
        PredictionResults with `ebv` of shape (nInd, 1), per-individual `pev`
        and `reliability` of shape (nInd,), the intercept estimate in
        `fixed_effects` (shape (1,)), and `h2_used` set to `h2`.

    Raises:
        ValueError: if `h2` is a concrete Python number outside (0, 1).
    """
    from jax.scipy.linalg import lu_factor, lu_solve

    # Only checked for concrete numbers so traced values (under jit) still work.
    if isinstance(h2, (int, float)) and not 0.0 < h2 < 1.0:
        raise ValueError(f"h2 must lie strictly between 0 and 1, got {h2!r}")

    # 1. Phenotypes of the requested trait and genotype dosages
    y = pop.pheno[:, trait]
    M = pop.dosage
    n_ind = M.shape[0]

    # 2. Genomic relationship matrix (VanRaden form).
    # NOTE(review): the 2p centring and 2*sum(p(1-p)) scaling assume diploid
    # 0/1/2 dosages — confirm for other ploidies. If every locus is monomorphic
    # the denominator is 0 and G becomes NaN.
    p = jnp.mean(M, axis=0) / 2  # allele frequency per locus
    W = M - 2 * p                # centre each locus on its mean dosage
    G = (W @ W.T) / (2 * jnp.sum(p * (1 - p)))

    # 3. Variance components implied by h2 and the phenotypic variance
    var_y = jnp.var(y)
    var_g = h2 * var_y
    var_e = (1 - h2) * var_y
    lambda_val = var_e / var_g

    # 4. MME: [[X'X, X'], [X, I + G^-1 * lambda]] [b; u] = [X'y; y]
    X = jnp.ones((n_ind, 1))
    G_inv = inv(G + jnp.eye(n_ind) * ridge)
    LHS = jnp.block([
        [X.T @ X, X.T],
        [X, jnp.eye(n_ind) + G_inv * lambda_val],
    ])
    RHS = jnp.concatenate([X.T @ y, y])

    # 5. Factor the LHS once; reuse it for the solutions and for its inverse
    # (needed for the PEV) instead of factorising it twice.
    lu_and_piv = lu_factor(LHS)
    solutions = lu_solve(lu_and_piv, RHS)
    C_inv = lu_solve(lu_and_piv, jnp.eye(LHS.shape[0]))
    b_hat = solutions[:1]
    u_hat = solutions[1:]

    # PEV of u: random-effect block of the inverted LHS scaled by var_e
    pev = jnp.diag(C_inv[1:, 1:]) * var_e
    reliability = 1 - (pev / var_g)

    return PredictionResults(
        ids=pop.id,
        ebv=u_hat[:, jnp.newaxis],  # reshape to (nInd, 1)
        pev=pev,
        reliability=reliability,
        fixed_effects=b_hat,
        h2_used=h2,
    )

In [None]:
# Demo: simulate founders, attach a trait and phenotypes, then run GBLUP.
# gblup_predict is defined earlier in this notebook.
import jax
import jax.numpy as jnp

from chewc.phenotype import set_pheno
from chewc.population import quick_haplo
from chewc.sp import SimParam
from chewc.trait import add_trait_a

# Give every random step its own PRNG key: reusing a single key across founder
# generation, trait definition and phenotyping correlates their random draws.
key = jax.random.PRNGKey(0)
key_founder, key_trait, key_pheno = jax.random.split(key, 3)

founder_pop, genetic_map = quick_haplo(key_founder, n_ind=100, n_chr=1, n_loci_per_chr=1000)

sp = SimParam.from_founder_pop(founder_pop, genetic_map)

# Define a single trait with 100 QTLs
sp = add_trait_a(
    key_trait,
    founder_pop,
    sp,
    n_qtl_per_chr=100,
    mean=jnp.array([0]),
    var=jnp.array([1.0]),
)

# Generate phenotypes with a heritability of 0.5
h2 = 0.5
founder_pop_with_pheno = set_pheno(key_pheno, founder_pop, sp.traits, sp.ploidy, h2=jnp.array([h2]))

# GBLUP: returns a PredictionResults container
predicted_pop = gblup_predict(founder_pop_with_pheno, h2)




In [5]:
#| hide
# Regenerate the Python modules from the exported cells of this notebook.
import nbdev

nbdev.nbdev_export()