# phenotype

> Compute breeding values, genetic values, and phenotypes (genetic value plus environmental noise) for a population.

In [11]:
#| default_exp phenotype

In [12]:
#| hide
from nbdev.showdoc import *


In [13]:
#| export


from typing import Optional
from functools import partial

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float


# --- Imports for Testing ---
from fastcore.test import test_eq, test_close, test_fail
import jax
import jax.numpy as jnp

# --- Project-Specific Imports ---
from chewc.population import Population      # <-- Import the Population class
from chewc.trait import TraitCollection       # <-- Import the TraitCollection class


In [14]:
#| export
# chewc/phenotype.py


@partial(jax.jit, static_argnames=('ploidy',))
def _calculate_gvs(
    pop: Population,
    traits: TraitCollection,
    ploidy: int
) -> tuple[Float[Array, "nInd nTraits"], Float[Array, "nInd nTraits"]]:
    """
    (JIT-compiled) Compute breeding values and genetic values for every
    individual and trait with one matrix product.

    Steps: move axis 2 of the genotype array to the end, flatten the two
    middle axes into a single locus index, gather the loci listed in
    `traits.loci_loc`, sum over the last axis to get a per-locus dosage,
    and multiply by the transposed additive effects. Genetic values are
    the breeding values plus `traits.intercept`.

    `ploidy` is a static argument because it fixes a reshape dimension at
    trace time.

    Returns:
        A `(bv, gv)` tuple of arrays.
    """
    # NOTE(review): assumes pop.geno is laid out as
    # (nInd, nChr, ploidy, nLoci) -- confirm against Population.
    alleles_last = jnp.transpose(pop.geno, (0, 1, 3, 2))
    per_locus = jnp.reshape(alleles_last, (pop.nInd, -1, ploidy))

    # Keep only the trait loci, then count allele copies per locus.
    dosage = per_locus[:, traits.loci_loc, :].sum(axis=2)

    breeding_values = jnp.dot(dosage, traits.add_eff.T)
    genetic_values = breeding_values + traits.intercept
    return breeding_values, genetic_values



#| export
# In chewc/phenotype.py, alongside set_pheno

def set_bv(
    pop: Population,
    traits: TraitCollection,
    ploidy: int
) -> Population:
    """
    Compute true breeding values (bv) and genetic values (gv) from the
    population's genotypes and the given trait architecture, and return
    the result of `pop.replace(bv=..., gv=...)`.

    No phenotypes are generated and no environmental noise is drawn.
    """
    breeding_values, genetic_values = _calculate_gvs(
        pop=pop, traits=traits, ploidy=ploidy
    )
    return pop.replace(bv=breeding_values, gv=genetic_values)

def set_pheno(
    key: jax.random.PRNGKey,
    pop: Population,
    traits: TraitCollection,
    ploidy: int,
    h2: Optional[Float[Array, "nTraits"]] = None,
    varE: Optional[Float[Array, "nTraits"]] = None,
    cor_e: Optional[Float[Array, "nTraits nTraits"]] = None,
) -> Population:
    """
    Compute genetic values for `pop` and add environmental noise to
    produce phenotypes.

    The environmental variance is either given directly (`varE`) or
    derived from a target heritability (`h2`) and the observed variance
    of the genetic values. Exactly one of the two must be supplied.

    Args:
        key: PRNG key used to draw the environmental noise.
        pop: Population whose genotypes are evaluated.
        traits: Trait architecture passed to the genetic-value kernel.
        ploidy: Ploidy level, forwarded as a static argument.
        h2: Target heritability per trait.
        varE: Environmental variance per trait.
        cor_e: Environmental correlation matrix; defaults to the identity
            of size `traits.n_traits`.

    Returns:
        The result of `pop.replace(bv=..., gv=..., pheno=...)`.

    Raises:
        ValueError: If both or neither of `h2` and `varE` are given.
    """
    # Both missing or both present is invalid: the two "is None" flags
    # must differ.
    if (h2 is None) == (varE is None):
        raise ValueError("Exactly one of 'h2' or 'varE' must be provided.")

    env_cor = jnp.identity(traits.n_traits) if cor_e is None else cor_e

    breeding_values, genetic_values = _calculate_gvs(
        pop=pop, traits=traits, ploidy=ploidy
    )

    if h2 is None:
        env_var = varE
    else:
        genetic_var = jnp.var(genetic_values, axis=0)
        # varE = varG / h2 - varG. The epsilon keeps the division finite
        # when h2 is exactly zero; the clamp removes the tiny negative
        # values that rounding (or h2 >= 1) can produce.
        env_var = jnp.maximum(0, (genetic_var / (h2 + 1e-8)) - genetic_var)

    phenotypes = _add_environmental_noise(
        key=key, gvs=genetic_values, var_e=env_var, cor_e=env_cor
    )
    return pop.replace(bv=breeding_values, gv=genetic_values, pheno=phenotypes)


In [15]:
#| export
@jax.jit
def _add_environmental_noise(
    key: jax.random.PRNGKey,
    gvs: Float[Array, "nInd nTraits"],
    var_e: Float[Array, "nTraits"],
    cor_e: Float[Array, "nTraits nTraits"],
) -> Float[Array, "nInd nTraits"]:
    """
    (JIT-compiled) Draw correlated environmental noise and add it to the
    genetic values.

    The noise for each individual has mean zero and covariance
    `diag(sd) @ cor_e @ diag(sd)` with `sd = sqrt(max(var_e, 0))`.

    Only the correlation matrix is Cholesky-factorised; the per-trait
    standard deviations are applied afterwards as a plain scaling.
    Factorising the full covariance instead breaks down as soon as one
    trait has zero environmental variance (e.g. a heritability of 1): the
    covariance is then singular, the factorisation fails, and every trait
    comes back as NaN. With the scaling approach such a trait simply
    receives zero noise.

    Args:
        key: PRNG key for the standard-normal draws.
        gvs: Genetic values, one row per individual.
        var_e: Environmental variance per trait (length-nTraits vector);
            negative entries are treated as zero.
        cor_e: Environmental correlation matrix. It must be positive
            definite, otherwise the factorisation yields NaNs.

    Returns:
        `gvs` plus the sampled noise, same shape as `gvs`.
    """
    n_ind, n_traits = gvs.shape

    # Clamp before the square root so rounding noise cannot produce NaN.
    env_sd = jnp.sqrt(jnp.maximum(var_e, 0.))

    # Lower-triangular factor L with L @ L.T == cor_e.
    cor_factor = jnp.linalg.cholesky(cor_e)

    # Rows of `standard_draws @ L.T` have covariance cor_e; scaling
    # column j by sd[j] turns that into the target covariance.
    standard_draws = jax.random.normal(key, (n_ind, n_traits))
    environmental_noise = (standard_draws @ cor_factor.T) * env_sd

    return gvs + environmental_noise


In [16]:
#| test



# --- Test Setup Helper ---

def _setup_phenotype_test():
    """
    Build the shared fixture for the phenotype tests.

    Returns a dict with a PRNG key for phenotype sampling ("key"), the
    founder population ("founder_pop"), a SimParam carrying one additive
    trait ("sp"), and the ploidy used ("ploidy").
    """
    # NOTE(review): quick_haplo, SimParam and add_trait_a are not imported
    # in the visible part of this file -- confirm they are in scope.
    root_key = jax.random.PRNGKey(42)
    founder_key, trait_key, pheno_key = jax.random.split(root_key, 3)

    # Dimensions of the simulated founder population.
    ploidy = 2
    n_ind, n_chr, n_loci_per_chr = 100, 2, 50

    founder_pop, gen_map = quick_haplo(
        founder_key, n_ind, n_chr, n_loci_per_chr, ploidy=ploidy
    )
    base_sp = SimParam.from_founder_pop(founder_pop, gen_map)

    # One additive trait is enough to exercise phenotype generation.
    sp_with_trait = add_trait_a(
        key=trait_key,
        founder_pop=founder_pop,
        sim_param=base_sp,
        n_qtl_per_chr=10,
        mean=jnp.array([10.0]),
        var=jnp.array([1.5]),
    )

    return dict(
        key=pheno_key,
        founder_pop=founder_pop,
        sp=sp_with_trait,
        ploidy=ploidy,
    )

In [17]:
#| test

def test_calculate_gvs():
    "Test the internal GV calculation function for correct output shapes."
    # ARRANGE
    setup = _setup_phenotype_test()
    pop = setup["founder_pop"]
    sp = setup["sp"]

    # ACT
    # Call the kernel defined in this module. The previous name,
    # `_calculate_gvs_vectorized_alternative`, is not defined or imported
    # in this file, so the test could only fail with a NameError.
    bvs, gvs = _calculate_gvs(pop, sp.traits, sp.ploidy)

    # ASSERT
    # Both outputs hold one row per individual and one column per trait.
    test_eq(gvs.shape, (pop.nInd, sp.n_traits))
    test_eq(bvs.shape, (pop.nInd, sp.n_traits))

def test_add_environmental_noise():
    "Check that the sampled noise has roughly the requested mean and covariance."
    # ARRANGE
    rng = _setup_phenotype_test()["key"]
    n_ind, n_traits = 1000, 2

    # All-zero genetic values, so every bit of spread comes from the noise.
    genetic = jnp.zeros((n_ind, n_traits))
    genetic_var = jnp.var(genetic, axis=0)

    env_var = jnp.array([1.0, 2.0])
    env_cor = jnp.array([[1.0, 0.5], [0.5, 1.0]])

    # ACT
    pheno = _add_environmental_noise(rng, genetic, env_var, env_cor)

    # ASSERT
    # Noise is zero-mean, so the phenotype mean tracks the genetic mean.
    test_close(jnp.mean(pheno, axis=0), jnp.mean(genetic, axis=0), eps=0.1)

    # Target covariance: uncorrelated genetic part plus correlated
    # environmental part, each built as diag(sd) @ cor @ diag(sd).
    observed_cov = jnp.cov(pheno, rowvar=False)
    genetic_sd = jnp.diag(jnp.sqrt(genetic_var))
    env_sd = jnp.diag(jnp.sqrt(env_var))
    target_cov = (
        genetic_sd @ jnp.identity(n_traits) @ genetic_sd
        + env_sd @ env_cor @ env_sd
    )

    test_close(observed_cov, target_cov, eps=0.15)

In [18]:
#| test

def test_set_pheno_with_h2():
    "Exercise set_pheno through the heritability (h2) pathway."
    # ARRANGE
    fixture = _setup_phenotype_test()
    founders, sim_param = fixture["founder_pop"], fixture["sp"]
    h2_target = jnp.array([0.5])

    # ACT
    result = set_pheno(
        key=fixture["key"],
        pop=founders,
        traits=sim_param.traits,
        ploidy=sim_param.ploidy,
        h2=h2_target,
    )

    # ASSERT
    assert isinstance(result, Population)
    test_eq(result.pheno.shape, (founders.nInd, sim_param.n_traits))

    # Realised heritability = Var(G) / Var(P); the tolerance is loose
    # because the noise is a finite random sample.
    genetic_var = jnp.var(result.gv, axis=0)
    phenotypic_var = jnp.var(result.pheno, axis=0)
    test_close(genetic_var / phenotypic_var, h2_target, eps=0.1)

def test_set_pheno_with_varE():
    "Exercise set_pheno through the environmental variance (varE) pathway."
    # ARRANGE
    fixture = _setup_phenotype_test()
    founders, sim_param = fixture["founder_pop"], fixture["sp"]
    env_var_target = jnp.array([2.0])

    # ACT
    result = set_pheno(
        key=fixture["key"],
        pop=founders,
        traits=sim_param.traits,
        ploidy=sim_param.ploidy,
        varE=env_var_target,
    )

    # ASSERT
    test_eq(result.pheno.shape, (founders.nInd, sim_param.n_traits))

    # Phenotypic variance should be close to Var(G) + Var(E).
    genetic_var = jnp.var(result.gv, axis=0)
    phenotypic_var = jnp.var(result.pheno, axis=0)
    test_close(phenotypic_var, genetic_var + env_var_target, eps=0.1)

def test_set_pheno_validation():
    "Check that set_pheno rejects invalid h2/varE combinations."
    # ARRANGE
    fixture = _setup_phenotype_test()
    founders, sim_param = fixture["founder_pop"], fixture["sp"]
    expected_message = "Exactly one of 'h2' or 'varE' must be provided."

    # Extra keyword arguments for each invalid call:
    #   1. both h2 and varE supplied
    #   2. neither supplied
    invalid_cases = (
        {"h2": jnp.array([0.5]), "varE": jnp.array([1.0])},
        {},
    )

    # ACT & ASSERT
    for extra in invalid_cases:
        test_fail(
            lambda extra=extra: set_pheno(
                key=fixture["key"],
                pop=founders,
                traits=sim_param.traits,
                ploidy=sim_param.ploidy,
                **extra,
            ),
            contains=expected_message,
        )

In [19]:
#| hide
import nbdev; nbdev.nbdev_export()