# phenotype

> Compute breeding values, genetic values, and phenotypes (with correlated environmental noise) for padded populations.

In [None]:
#| default_exp phenotype

In [None]:
#| hide
from nbdev.showdoc import *


In [None]:
#| export


from typing import Optional
from functools import partial

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float


# --- Imports for Testing ---
from fastcore.test import test_eq, test_close, test_fail
import jax
import jax.numpy as jnp

# --- Project-Specific Imports ---
from chewc.population import Population      # <-- Import the Population class
from chewc.trait import TraitCollection       # <-- Import the TraitCollection class


In [None]:
#| export
# chewc/phenotype.py

@partial(jax.jit, static_argnames=('ploidy',))
def _calculate_gvs(
    pop: Population,
    traits: TraitCollection,
    ploidy: int
) -> tuple[Float[Array, "nInd nTraits"], Float[Array, "nInd nTraits"]]:
    """
    (JIT-compiled) Compute breeding values and genetic values for every row of
    the padded population, including inactive (padding) rows.

    `ploidy` is a static argument because it fixes the trailing dimension of
    the reshape below, and JAX needs concrete shapes at trace time.

    Returns:
        `(bv, gv)`, where `bv = dosage @ traits.add_eff.T` and
        `gv = bv + traits.intercept`.
    """
    # Row count comes from the static shape of the genotype array rather than
    # the dynamic pop.nInd, so the reshape stays trace-safe under jit.
    rows = pop.geno.shape[0]
    # Swap the last two axes, then collapse the middle axes into one locus axis
    # with the allele axis (length `ploidy`) last.
    # assumes geno is (nInd, nChr, ploidy, nLoci) — TODO confirm
    per_locus = jnp.transpose(pop.geno, (0, 1, 3, 2)).reshape(rows, -1, ploidy)
    # Allele dosage at each QTL: select QTL columns, then sum over the allele axis.
    dosage = per_locus[:, traits.loci_loc, :].sum(axis=2)
    breeding_values = jnp.dot(dosage, traits.add_eff.T)
    return breeding_values, breeding_values + traits.intercept

def set_bv(
    pop: Population,
    traits: TraitCollection,
    ploidy: int
) -> Population:
    """
    Compute true breeding values (bv) and genetic values (gv) and store them on
    the population.

    Values are computed for every padded row by the JIT kernel. Rows whose
    `pop.is_active` flag is False are then overwritten with NaN so that
    downstream steps (e.g. selection) cannot pick them up.

    Returns:
        A new Population produced by `pop.replace(bv=..., gv=...)`.
    """
    bvs, gvs = _calculate_gvs(pop=pop, traits=traits, ploidy=ploidy)

    # Column view of the row mask so it broadcasts across the trait axis.
    active_rows = pop.is_active[:, None]
    masked_bv, masked_gv = (
        jnp.where(active_rows, values, jnp.nan) for values in (bvs, gvs)
    )
    return pop.replace(bv=masked_bv, gv=masked_gv)

def set_pheno(
    key: jax.random.PRNGKey,
    pop: Population,
    traits: TraitCollection,
    ploidy: int,
    h2: Optional[Float[Array, "nTraits"]] = None,
    varE: Optional[Float[Array, "nTraits"]] = None,
    cor_e: Optional[Float[Array, "nTraits nTraits"]] = None,
) -> Population:
    """
    Sets phenotypes for a population, handling padded arrays correctly.

    Genetic values are computed for every padded row, and environmental noise
    is added. Rows whose `pop.is_active` flag is False are set to NaN in the
    returned `bv`, `gv` and `pheno` arrays.

    Args:
        key: JAX PRNG key used to draw the environmental noise.
        pop: Population whose `geno` and `is_active` arrays are read.
        traits: Trait definitions. After argument validation, if this is None
            or has zero traits, `pop` is returned unchanged.
        ploidy: Length of the allele axis used by the genotype reshape in
            the JIT kernel (static argument).
        h2: Target heritability per trait (scalar or length-nTraits). The
            environmental variance is derived from the genetic variance of
            the active individuals only.
        varE: Environmental variance per trait (scalar or length-nTraits).
            A scalar is broadcast to every trait.
        cor_e: Environmental correlation matrix (defaults to the identity).

    Returns:
        A new Population produced by `pop.replace(bv=..., gv=..., pheno=...)`.

    Raises:
        ValueError: if both or neither of `h2` and `varE` are provided.
    """
    # 1. --- Input Validation ---
    if (h2 is None) == (varE is None):
        raise ValueError("Exactly one of 'h2' or 'varE' must be provided.")

    n_traits = traits.n_traits if traits is not None else 0
    if n_traits == 0:
        return pop  # No traits to add phenotypes for

    # Normalise to a JAX array: Python lists cannot be used with `@` inside
    # the jitted noise helper.
    cor_e = jnp.identity(n_traits) if cor_e is None else jnp.asarray(cor_e)

    # 2. --- Core Genetic Calculation ---
    bvs, gvs = _calculate_gvs(pop=pop, traits=traits, ploidy=ploidy)
    # Column view of the row mask so it broadcasts across the trait axis.
    active = pop.is_active[:, None]

    # 3. --- Determine Environmental Variance ---
    if h2 is not None:
        # Genetic variance over active individuals only (padding rows -> NaN,
        # ignored by nanvar).
        var_g = jnp.nanvar(jnp.where(active, gvs, jnp.nan), axis=0)
        var_e = (var_g / (jnp.asarray(h2) + 1e-8)) - var_g
        var_e = jnp.maximum(0, var_e)
    else:
        # The noise helper builds diag(sqrt(var_e)), which needs a 1-D vector.
        # Broadcasting lets a single scalar variance apply to every trait.
        var_e = jnp.broadcast_to(jnp.asarray(varE), (n_traits,))

    # 4. --- Add Environmental Noise ---
    # Noise is added row-wise, so nothing in padded rows can leak into active
    # rows; those rows are masked below. A non-finite gv of an *active*
    # individual is deliberately left to surface as a non-finite phenotype.
    pheno_full = _add_environmental_noise(key=key, gvs=gvs, var_e=var_e, cor_e=cor_e)

    # 5. --- Mask Final Arrays & Update Population ---
    final_bvs = jnp.where(active, bvs, jnp.nan)
    final_gvs = jnp.where(active, gvs, jnp.nan)
    final_pheno = jnp.where(active, pheno_full, jnp.nan)

    return pop.replace(bv=final_bvs, gv=final_gvs, pheno=final_pheno)

#| export
@jax.jit
def _add_environmental_noise(
    key: jax.random.PRNGKey,
    gvs: Float[Array, "nInd nTraits"],
    var_e: Float[Array, "nTraits"],
    cor_e: Float[Array, "nTraits nTraits"],
) -> Float[Array, "nInd nTraits"]:
    """
    (JIT-compiled) Add correlated Gaussian environmental noise to genetic values.

    Each row receives noise drawn from N(0, D @ cor_e @ D) with
    D = diag(sqrt(var_e)). The noise is sampled as

        z @ cholesky(cor_e).T * sqrt(var_e),   z ~ N(0, I), shape (nInd, nTraits)

    rather than via a Cholesky factor of the full covariance. The Cholesky of a
    covariance with any zero-variance trait is not defined (JAX yields NaN).
    That used to turn every phenotype into NaN whenever e.g. h2 == 1 or a
    trait had zero genetic variance. With this factorisation a zero-variance
    trait simply gets zero noise.

    For strictly positive `var_e`, the values match
    `jax.random.multivariate_normal(key, 0, cov)` up to floating-point error,
    because the same `normal(key, (nInd, nTraits))` draw is used.

    NOTE(review): `cor_e` itself must be positive definite. A singular
    correlation matrix (e.g. perfectly correlated traits) still yields NaN.
    """
    n_ind, n_traits = gvs.shape
    # Clip negative variances before taking the square root.
    sd_e = jnp.sqrt(jnp.maximum(var_e, 0.))
    chol_cor = jnp.linalg.cholesky(cor_e)
    z = jax.random.normal(key, (n_ind, n_traits))
    # (z @ L.T) has covariance cor_e; scaling column j by sd_e[j] gives D cor_e D.
    environmental_noise = (z @ chol_cor.T) * sd_e
    return gvs + environmental_noise


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()