In [2]:
#| default_exp k_pheno

In [3]:
#| export
from __future__ import annotations
import jax, jax.numpy as jnp
from jax import random


In [4]:
#| export
# chewc/k_pheno.py

def compute_bv_additive(
    geno: jnp.ndarray,        # (N, n_chr, ploidy, n_loci)
    is_active: jnp.ndarray,   # (N,)
    trait_effects: jnp.ndarray,  # (n_chr, n_loci, T)
    ploidy: int,
) -> jnp.ndarray:
    """
    Additive breeding values: allele dosage times per-locus effect, summed
    over loci and then over chromosomes. Returns an (N, T) array.

    `is_active` and `ploidy` are accepted so the kernel signature stays
    uniform, but neither is used. Every row gets a BV whether active or
    not, and dosage is obtained by summing `geno` over axis 2.
    Assumes `trait_effects` has one row per locus on every chromosome
    (no ragged chromosomes).
    """
    # Collapse the ploidy axis -> allele dosage, (N, n_chr, n_loci).
    allele_dosage = jnp.sum(geno, axis=2)
    # Dosage-weighted effects per chromosome and trait, (N, n_chr, T).
    chr_contrib = jnp.einsum("ncl,clt->nct", allele_dosage, trait_effects)
    # Inactive rows are deliberately kept; callers can mask them later.
    return jnp.sum(chr_contrib, axis=1)

def add_env_noise(
    key: jax.Array,
    bv: jnp.ndarray,          # (N, T)
    is_active: jnp.ndarray,   # (N,)
    h2: jnp.ndarray,          # (T,)
) -> jnp.ndarray:
    """
    Phenotype = BV + e, with Var(e) = Var(BV) * (1 - h2) / h2 per trait.

    Var(BV) is the per-trait empirical variance (ddof=0) over the active
    rows only. Noise is drawn for every row but added only to active rows,
    so inactive rows come back as their BV unchanged. If no row is active,
    `bv` is returned unchanged.

    `h2` is clipped to [1e-8, 1] once, and that clipped value is used in
    both the numerator and the denominator. As a result, h2 >= 1 gives
    zero noise and h2 <= 0 gives very large noise. The variance can never
    go negative, which would otherwise make sqrt(var_e) NaN.

    Args:
        key: PRNG key used for the environmental noise draw.
        bv: breeding values, (N, T).
        is_active: boolean row mask, (N,).
        h2: target heritability per trait, (T,); a scalar broadcasts.

    Returns:
        Phenotypes, (N, T).
    """
    n_rows, n_traits = bv.shape
    active = is_active[:, None]

    # Mask inactive rows to NaN so the nan-aware variance ignores them.
    bv_active = jnp.where(active, bv, jnp.nan)
    # Small epsilon keeps var_e > 0 (for h2 < 1) when BV is constant among active rows.
    var_bv = jnp.nanvar(bv_active, axis=0) + 1e-8

    # Clip once and reuse. Clipping only the denominator let (1 - h2) go
    # negative for h2 > 1, producing NaN phenotypes.
    h2_safe = jnp.clip(h2, 1e-8, 1.0)
    var_e = var_bv * (1.0 - h2_safe) / h2_safe

    noise = random.normal(key, shape=(n_rows, n_traits)) * jnp.sqrt(var_e)[None, :]
    return bv + jnp.where(active, noise, 0.0)

def set_pheno_kernel(
    key: jax.Array,
    geno: jnp.ndarray,
    is_active: jnp.ndarray,
    trait_effects: jnp.ndarray,  # (n_chr, n_loci, T)
    ploidy: int,
    h2: jnp.ndarray,             # (T,)
):
    """
    Compute additive breeding values and phenotypes in one call.

    Breeding values come from `compute_bv_additive`. Phenotypes are those
    values plus environmental noise from `add_env_noise`, which draws the
    noise with `key` and adds it only to rows where `is_active` is true.

    Returns:
        (bv, pheno), both (N, T).
    """
    breeding_values = compute_bv_additive(geno, is_active, trait_effects, ploidy)
    return breeding_values, add_env_noise(key, breeding_values, is_active, h2)


In [5]:
#| hide
# Regenerate the library module(s) from this notebook's `#| export` cells (nbdev).
# `#| hide` keeps this housekeeping cell out of the rendered docs.
import nbdev; nbdev.nbdev_export()