# predict

> Genomic prediction (GBLUP and ABLUP) built on the core data structures used to run a simulation

In [2]:
#| default_exp predict

In [3]:
#| hide
from nbdev.showdoc import *

In [4]:
#| export

from flax.struct import dataclass as flax_dataclass
import jax.numpy as jnp
from typing import Optional
from chewc.population import Population

In [5]:
#| export
"""
GBLUP (Genomic Best Linear Unbiased Prediction) implementation for chewc library.

This module provides functions for genomic prediction using the GBLUP methodology,
which is a standard approach in genomic selection and animal breeding.
"""

from typing import Optional, Dict
import jax
import jax.numpy as jnp
from jax.numpy.linalg import solve, inv, pinv
from flax.struct import dataclass as flax_dataclass

from chewc.population import Population

@flax_dataclass(frozen=True)
class PredictionResults:
    """
    A container for the results of a genomic prediction.

    Only ``ids`` and ``ebv`` are required; every other field defaults to
    ``None`` so a predictor may leave it unset.

    NOTE(review): an identical ``PredictionResults`` class is defined again in a
    later exported cell of this notebook; once the module is imported, that
    later definition shadows this one.
    """
    # Identifiers of the individuals the predictions refer to.
    ids: jnp.ndarray
    # Estimated breeding values.
    ebv: jnp.ndarray
    # Prediction error variances (optional).
    pev: Optional[jnp.ndarray] = None
    # Reliabilities of the breeding-value estimates (optional).
    reliability: Optional[jnp.ndarray] = None
    # Estimated fixed effects (optional).
    fixed_effects: Optional[jnp.ndarray] = None
    # Heritability used to fit the model (optional).
    h2_used: Optional[float] = None
    # Variance-component estimates (optional).
    var_components: Optional[Dict] = None



In [7]:
#| export

from typing import Optional, Dict, Literal
import jax
import jax.numpy as jnp
from flax.struct import dataclass as flax_dataclass

# Make sure to import the necessary functions from your population module
from chewc.population import Population, calc_relationship_matrices, calc_a_inverse_matrix_pedigree_jax

@flax_dataclass(frozen=True)
class PredictionResults:
    """A container for the results of a prediction.

    Instances are built by ``mme_predict``. The per-individual arrays (``ebv``,
    ``pev``, ``reliability``) follow the row order of the population's
    phenotype matrix, which is presumably the same order as ``ids`` (verify
    against ``Population``).
    """
    # Identifiers of the individuals (``pop.id`` in mme_predict).
    ids: jnp.ndarray
    # Estimated breeding values; mme_predict stores them as an (n_ind, 1) column.
    ebv: jnp.ndarray
    # Prediction error variance, one entry per individual.
    pev: Optional[jnp.ndarray] = None
    # Reliability per individual; mme_predict clamps it to be non-negative.
    reliability: Optional[jnp.ndarray] = None
    # Estimated fixed effects; mme_predict fits only an intercept.
    fixed_effects: Optional[jnp.ndarray] = None
    # Heritability that was passed to the solver.
    h2_used: Optional[float] = None
    # Variance components keyed by name; mme_predict sets 'var_a' and 'var_e'.
    var_components: Optional[Dict[str, jnp.ndarray]] = None


def _mme_solver(
    y_train: jnp.ndarray,
    train_mask: jnp.ndarray,
    K_inv: jnp.ndarray,
    h2: float,
    n_ind: int
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    A general Mixed Model Equation (MME) solver for ABLUP and GBLUP.

    Solves Henderson's MME for the animal model ``y = 1*mu + Z*u + e``, where:
    - ``mu`` is a single fixed effect (the intercept);
    - ``u`` is one random additive genetic effect per individual, with
      relationship matrix ``K``;
    - ``Z`` is the identity restricted to the phenotyped (training) rows.

    Individuals outside ``train_mask`` contribute no phenotype. They receive
    EBVs only through the relationship structure encoded in ``K_inv``.

    The equations are assembled with 0/1 weights instead of boolean indexing.
    As a result, every array shape depends only on ``n_ind`` and
    ``y_train.shape``. This is what lets the function be traced by
    ``jax.jit``, which cannot index with a traced boolean mask.

    Args:
        y_train: Phenotypes of the training individuals, in the same order as
            the ``True`` entries of ``train_mask``; shape (n_train,) or
            (n_train, 1).
        train_mask: Boolean vector of length ``n_ind`` marking the training
            (phenotyped) individuals.
        K_inv: (n_ind, n_ind) inverse of the relationship matrix (A-inverse or
            G-inverse).
        h2: Heritability; must lie strictly between 0 and 1 for the equations
            to be well posed.
        n_ind: Total number of individuals. It is a static jit argument because
            it fixes array shapes.

    Returns:
        A tuple containing:
        - fixed_effects: shape (1,), the estimated intercept.
        - all_ebv: shape (n_ind,), estimated breeding values for ALL individuals.
        - C_inv: (n_ind + 1, n_ind + 1) inverse of the MME coefficient matrix,
          used for PEV calculation.
    """
    mask = jnp.asarray(train_mask, dtype=bool)
    dtype = jnp.result_type(K_inv.dtype, jnp.float32)
    # Diagonal of Z'R^-1 Z (up to sigma_e^2) for Z = I restricted to the
    # training rows: 1 for phenotyped individuals, 0 otherwise.
    w = mask.astype(dtype)

    y_flat = jnp.ravel(y_train).astype(dtype)
    n_train = y_flat.shape[0]  # a static Python int, even under jit

    # Scatter the compact training phenotypes back to full length (0 for
    # unphenotyped individuals) without data-dependent shapes. The running
    # count of True entries gives each training individual its index in y.
    if n_train == 0:
        y_full = jnp.zeros(n_ind, dtype=dtype)
    else:
        pos = jnp.clip(jnp.cumsum(mask.astype(jnp.int32)) - 1, 0, n_train - 1)
        y_full = jnp.where(mask, y_flat[pos], 0.0)

    # Variance ratio sigma_e^2 / sigma_a^2 implied by the heritability.
    alpha = (1.0 - h2) / h2

    # LHS = [[X'X, X'Z], [Z'X, Z'Z + alpha * K_inv]] with X = 1 (intercept) and
    # Z restricted to training rows, so X'X = n_train and X'Z = Z'X' = w.
    xtx = jnp.sum(w).reshape(1, 1)
    lhs = jnp.block([
        [xtx, w[None, :]],
        [w[:, None], jnp.diag(w) + alpha * K_inv],
    ])
    # RHS = [X'y, Z'y]
    rhs = jnp.concatenate([jnp.sum(y_flat).reshape(1), y_full])

    solutions = jnp.linalg.solve(lhs, rhs)
    C_inv = jnp.linalg.inv(lhs)

    n_fixed = 1  # intercept only
    fixed_effects = solutions[:n_fixed]
    all_ebv = solutions[n_fixed:]

    return fixed_effects, all_ebv, C_inv


# ``n_ind`` determines array shapes, so it must be a static argument under jit;
# every other argument is traced.
_mme_solver = jax.jit(_mme_solver, static_argnames=("n_ind",))


def mme_predict(
    pop: Population,
    h2: float = 0.5,
    trait_idx: int = 0,
    method: Literal['gblup', 'ablup'] = 'gblup'
) -> PredictionResults:
    """
    Perform genetic prediction using a single, unified Mixed Model Equation solve.

    This function can employ either GBLUP (using a genomic relationship matrix)
    or ABLUP (using a pedigree-based relationship matrix).

    Args:
        pop: Population object. Can contain individuals with phenotypes (training set)
             and individuals with NaN phenotypes (prediction set).
        h2: Heritability of the trait; must lie strictly between 0 and 1.
        trait_idx: Index of the trait (column of ``pop.pheno``) to predict.
        method: The prediction method to use, either 'gblup' or 'ablup'.

    Returns:
        PredictionResults object for ALL individuals, containing:
        - ebv: shape (n_ind, 1).
        - pev: shape (n_ind,).
        - reliability: shape (n_ind,).
        - var_components: the variances implied by ``h2`` and the phenotypic
          variance of the training set.

    Raises:
        ValueError: If ``method`` is unknown, if ``h2`` is not strictly
            between 0 and 1, or if no individual has a phenotype for the trait.
    """
    # Validate cheap arguments before building any relationship matrices.
    if method not in ('gblup', 'ablup'):
        raise ValueError(f"Unknown method: {method}. Use 'gblup' or 'ablup'.")
    # h2 <= 0 divides by zero in alpha; h2 >= 1 makes the MME singular.
    if not 0.0 < h2 < 1.0:
        raise ValueError(f"h2 must be strictly between 0 and 1, got {h2}.")

    pheno = pop.pheno[:, trait_idx:trait_idx+1]
    n_ind = pop.nInd

    # --- Step 1: Identify training individuals (non-missing phenotype) ---
    train_mask = ~jnp.isnan(pheno.flatten())
    if not bool(jnp.any(train_mask)):
        raise ValueError("No individuals with phenotypes found for model training.")
    y_train = pheno[train_mask]

    # --- Step 2: Calculate the required inverse relationship matrix (K_inv) ---
    if method == 'gblup':
        G = calc_relationship_matrices(pop, method="genomic")
        # NOTE(review): a centred genomic relationship matrix is often singular,
        # in which case inv(G) is ill-defined; confirm G is regularised upstream.
        K_inv = jnp.linalg.inv(G)
    else:  # 'ablup'
        # Use the efficient JAX-native function to get A-inverse directly
        K_inv = calc_a_inverse_matrix_pedigree_jax(pop)

    # --- Step 3: Solve the MME using the generic, JIT-compiled solver ---
    fixed_effects, all_ebv, C_inv = _mme_solver(
        y_train=y_train,
        train_mask=train_mask,
        K_inv=K_inv,
        h2=h2,
        n_ind=n_ind
    )

    # --- Step 4: Prediction Error Variance (PEV) and Reliability ---
    n_fixed = fixed_effects.shape[0]
    # Diagonal of the random-effect block C22 of C_inv (in units of sigma_e^2).
    c22_diag = jnp.diag(C_inv)[n_fixed:]

    # Variance components implied by h2 and the training phenotypic variance.
    var_p_train = jnp.var(y_train)
    var_e_est = var_p_train * (1 - h2)
    var_a_est = var_p_train * h2

    pev = c22_diag * var_e_est
    # Reliability = 1 - PEV / sigma_a^2 = 1 - C22_ii * (sigma_e^2 / sigma_a^2).
    # Computing it through alpha avoids dividing by var_p, which is zero when
    # only one (or only identical) phenotype is available. The clamp guards
    # against C22_ii * alpha > 1 (e.g. relationship diagonals above 1).
    alpha = (1.0 - h2) / h2
    reliability = jnp.maximum(0.0, 1.0 - c22_diag * alpha)

    return PredictionResults(
        ids=pop.id,
        ebv=all_ebv.reshape(-1, 1),
        pev=pev,
        reliability=reliability,
        fixed_effects=fixed_effects,
        h2_used=h2,
        var_components={'var_a': var_a_est, 'var_e': var_e_est}
    )

In [8]:
#| hide
import nbdev; nbdev.nbdev_export()