# phenotype

> Simulate phenotypes by adding (optionally correlated) environmental noise to the genetic values of a population.

In [1]:
#| default_exp phenotype

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export
# chewc/phenotype.py

from typing import Optional
from functools import partial

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

from chewc.population import Population
from chewc.trait import TraitCollection, _calculate_gvs_vectorized_alternative # IMPORT FIXED
from jaxtyping import Array, Float, Int


# This function remains a core, JIT-able utility
def _calculate_gvs_vectorized_alternative(
    pop: Population,
    traits: TraitCollection,
    ploidy: int
) -> tuple[Float[Array, "nInd nTraits"], Float[Array, "nInd nTraits"]]:
    """
    Compute breeding values and genetic values for every individual and trait.

    The last two axes of ``pop.geno`` are swapped and the result is reshaped
    to ``(pop.nInd, -1, ploidy)`` so that the trailing axis holds the allele
    copies of one locus. Summing that axis gives a per-locus dosage, which is
    multiplied with the additive effects in a single matrix product.

    NOTE: this module also imports a function with this exact name from
    ``chewc.trait``; the definition here is the one bound at module level
    because it comes later.

    Args:
        pop: Population providing ``geno`` and ``nInd``.
            Assumes ``geno`` is laid out as (nInd, nChr, ploidy, nLoci)
            -- TODO confirm against ``Population``.
        traits: Trait collection providing ``loci_loc`` (indices into the
            flattened locus axis), ``add_eff`` and ``intercept``.
        ploidy: Number of allele copies per locus; used as the size of the
            trailing axis after reshaping.

    Returns:
        A ``(breeding_values, genetic_values)`` tuple, where the genetic
        values are the breeding values shifted by ``traits.intercept``.
    """
    # Bring the allele-copy axis last, then merge the two middle axes so each
    # row of the second axis is one locus with its `ploidy` allele copies.
    allele_last = jnp.transpose(pop.geno, (0, 1, 3, 2))
    per_locus = jnp.reshape(allele_last, (pop.nInd, -1, ploidy))

    # Keep only the QTL and collapse the allele copies into a dosage.
    dosage = per_locus[:, traits.loci_loc, :].sum(axis=-1)

    breeding_values = jnp.dot(dosage, traits.add_eff.T)
    return breeding_values, breeding_values + traits.intercept

# Internal helper that only generates and adds the environmental noise
def _add_environmental_noise(
    key: jax.random.PRNGKey,
    gvs: Float[Array, "nInd nTraits"],
    var_e: Float[Array, "nTraits"],
    cor_e: Float[Array, "nTraits nTraits"],
) -> Float[Array, "nInd nTraits"]:
    """
    Internal JIT-able function to generate and add environmental noise.

    The noise for each individual is drawn from a zero-mean multivariate
    normal whose covariance is ``diag(sqrt(var_e)) @ cor_e @ diag(sqrt(var_e))``.

    Rather than building that covariance and Cholesky-factorising it, only
    ``cor_e`` is factorised and the correlated standard-normal draws are then
    scaled by the per-trait standard deviations. For strictly positive
    variances this is the same construction (``chol(D C D) == D chol(C)``)
    and consumes the key identically, but it also stays finite when a trait
    has ``var_e == 0``: the full covariance is singular in that case and its
    Cholesky factor would be NaN, turning every phenotype into NaN.

    Args:
        key: PRNG key used for the standard-normal draws.
        gvs: Genetic values, one row per individual and one column per trait.
        var_e: Environmental variance per trait (must be non-negative).
        cor_e: Environmental correlation matrix between traits (must be
            positive definite).

    Returns:
        ``gvs`` plus the sampled environmental noise, same shape as ``gvs``.
    """
    n_ind, n_traits = gvs.shape

    # Independent standard-normal draws, one row per individual.
    std_normal = jax.random.normal(key, (n_ind, n_traits))
    dtype = std_normal.dtype

    # Factor only the correlation matrix; a zero variance cannot make it singular.
    cor_factor = jnp.linalg.cholesky(jnp.asarray(cor_e, dtype=dtype))
    sd_e = jnp.sqrt(jnp.asarray(var_e, dtype=dtype))

    # Row-wise L @ z gives the correlation structure; sd_e scales each trait.
    environmental_noise = (std_normal @ cor_factor.T) * sd_e
    return gvs + environmental_noise


# Public entry point: validates inputs, then orchestrates the genetic-value and noise steps
def set_pheno(
    key: jax.random.PRNGKey,
    pop: Population,
    traits: TraitCollection,
    ploidy: int,
    h2: Optional[Float[Array, "nTraits"]] = None,
    varE: Optional[Float[Array, "nTraits"]] = None,
    cor_e: Optional[Float[Array, "nTraits nTraits"]] = None,
) -> Population:
    """
    Sets phenotypes for a population based on its genetic values and
    either a specified heritability (h2) or environmental variance (varE).

    Exactly one of `h2` or `varE` must be provided.

    Args:
        key: PRNG key used to sample the environmental noise.
        pop: Population whose genotypes are used to compute genetic values.
        traits: Trait collection supplying QTL positions, additive effects,
            intercepts and ``n_traits``.
        ploidy: Number of allele copies per locus.
        h2: Target heritability, either one value per trait or a scalar that
            applies to every trait. The environmental variance is derived
            from the observed variance of the genetic values.
        varE: Environmental variance, either one value per trait or a scalar
            that applies to every trait.
        cor_e: Environmental correlation matrix between traits. Defaults to
            the identity (uncorrelated noise).

    Returns:
        The result of ``pop.replace(bv=..., pheno=...)`` with the computed
        breeding values and the simulated phenotypes.

    Raises:
        ValueError: If neither or both of `h2` and `varE` are given.
    """
    # 1. --- Input Validation (Python Land) ---
    if (h2 is None) == (varE is None):
        raise ValueError("Exactly one of 'h2' or 'varE' must be provided.")

    if cor_e is None:
        cor_e = jnp.identity(traits.n_traits)

    # 2. --- Core Genetic Calculation (JIT-able) ---
    # `ploidy` fixes an array shape, so it is marked static. Jitting the
    # module-level function (instead of a fresh `partial` object per call)
    # lets JAX reuse its trace cache across calls to `set_pheno`.
    gvs_calculator = jax.jit(
        _calculate_gvs_vectorized_alternative, static_argnames="ploidy"
    )
    bvs, gvs = gvs_calculator(pop=pop, traits=traits, ploidy=int(ploidy))

    # 3. --- Determine Environmental Variance (Python Land) ---
    # This logic lives outside any JIT compilation path
    if h2 is not None:
        var_g = jnp.var(gvs, axis=0)
        # The small epsilon keeps the division finite when h2 == 0
        var_e = (var_g / (jnp.asarray(h2) + 1e-8)) - var_g
        var_e = jnp.maximum(0, var_e)  # Ensure variance is not negative
    else:  # varE is not None
        var_e = jnp.asarray(varE)
    # Accept a scalar for either input by expanding it to one value per trait.
    var_e = jnp.broadcast_to(var_e, (traits.n_traits,))

    # 4. --- Add Environmental Noise (JIT-able) ---
    # The noise calculator is a pure function, so we can JIT it directly
    noise_adder = jax.jit(_add_environmental_noise)
    pheno = noise_adder(key=key, gvs=gvs, var_e=var_e, cor_e=cor_e)

    # 5. --- Update Population ---
    return pop.replace(bv=bvs, pheno=pheno)

In [4]:
#| hide
import nbdev; nbdev.nbdev_export()

Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
