In [1]:
import jax
import jax.numpy as jnp
import time

# --- 1. JAX Setup ---
key = jax.random.PRNGKey(42)

# --- 2. Imports from the 'chewc' library ---
from chewc.sp import SimParam
from chewc.population import Population, msprime_pop # Use msprime_pop
from chewc.trait import add_trait_a
from chewc.burnin import run_burnin
# Import the generation runner
from chewc.pipe import run_generation




In [2]:
#| export
from typing import List, Optional

import jax
import jax.numpy as jnp
from flax.struct import dataclass as flax_dataclass
from jaxtyping import Array, Float, Int, PyTree

# Assuming Population and SimParam are in these locations
from chewc.population import Population
from chewc.sp import SimParam

# --- Base Loci and Trait Structures (Unchanged) ---

@flax_dataclass(frozen=True)
class LociMap:
    loci_per_chr: Int[Array, "nChr"]
    loci_loc: Int[Array, "nLoci"]
    name: str

    @property
    def n_loci(self) -> int:
        return self.loci_loc.shape[0]

@flax_dataclass(frozen=True)
class TraitA(LociMap):
    add_eff: Float[Array, "nLoci"]
    intercept: float = 0.0

@flax_dataclass(frozen=True)
class TraitCollection:
    loci_loc: Int[Array, "nLoci"]
    add_eff: Float[Array, "nTraits nLoci"]
    intercept: Float[Array, "nTraits"]

    @property
    def n_traits(self) -> int:
        return self.add_eff.shape[0]

    @property
    def n_loci(self) -> int:
        return self.loci_loc.shape[0]

# --- Helper Functions (Unchanged) ---

def _calculate_gvs_vectorized_alternative(
    pop: Population,
    traits: TraitCollection,
    ploidy: int
) -> Float[Array, "nInd nTraits"]:
    """Calculates all genetic values using a single matrix multiplication."""
    flat_geno_alleles = pop.geno.transpose((0, 1, 3, 2)).reshape(pop.nInd, -1, ploidy)
    qtl_alleles = flat_geno_alleles[:, traits.loci_loc, :]
    qtl_geno = jnp.sum(qtl_alleles, axis=2)
    all_bv = jnp.dot(qtl_geno, traits.add_eff.T)
    all_gvs = all_bv + traits.intercept
    return all_gvs

# --- Public Trait-Adding Function (Refactored) ---
#| export
from typing import List, Optional

import jax
import jax.numpy as jnp
from flax.struct import dataclass as flax_dataclass
from jaxtyping import Array, Float, Int, PyTree

# Assuming Population and SimParam are in these locations
from chewc.population import Population
from chewc.sp import SimParam

# ... (Other dataclasses and helper functions remain the same) ...

# --- Public Trait-Adding Function (Refactored) ---
def add_trait_a(
    key: jax.random.PRNGKey,
    founder_pop: Population,
    sim_param: SimParam,
    n_qtl_per_chr: int,
    mean: Float[Array, "nTraits"],
    var: Float[Array, "nTraits"],
    cor_a: Optional[Float[Array, "nTraits nTraits"]] = None,
    gamma: bool = False,
    shape: float = 1.0
) -> SimParam:
    """
    Adds one or more new additive traits to the simulation parameters.
    ... (docstring remains the same) ...
    """
    # --- Input Validation ---
    # Calculate the number of segregating sites available per chromosome.
    # This assumes all chromosomes have the same number of loci.
    n_loci_per_chr = sim_param.gen_map.size // sim_param.n_chr
    if n_qtl_per_chr > n_loci_per_chr:
        raise ValueError(
            f"You requested n_qtl_per_chr={n_qtl_per_chr}, but there are only "
            f"{n_loci_per_chr} segregating loci available per chromosome in the "
            "founder population."
        )

    n_traits = mean.shape[0]
    assert mean.shape == var.shape, "Mean and variance vectors must have the same shape."
    if cor_a is None:
        cor_a = jnp.identity(n_traits)

    key, sample_key, qtl_key, sign_key = jax.random.split(key, 4)

    n_total_qtl = n_qtl_per_chr * sim_param.n_chr
    # Note: Assumes gen_map is a flat array of all loci positions
    all_loci_indices = jnp.arange(sim_param.gen_map.size)
    qtl_loc = jax.random.choice(qtl_key, all_loci_indices, shape=(n_total_qtl,), replace=False)
    qtl_loc = jnp.sort(qtl_loc)

    # --- Sample raw QTL effects from Normal or Gamma distribution ---
    if gamma:
        # Sample from gamma and randomly apply a sign for a symmetric distribution
        gamma_effects = jax.random.gamma(sample_key, shape, shape=(n_total_qtl, n_traits))
        signs = jax.random.choice(sign_key, jnp.array([-1.0, 1.0]), shape=(n_total_qtl, n_traits))
        raw_effects = gamma_effects * signs
    else:
        raw_effects = jax.random.normal(sample_key, (n_total_qtl, n_traits))

    # --- Correlate and Scale Effects (Logic remains the same) ---
    cholesky_matrix = jnp.linalg.cholesky(cor_a)
    correlated_raw_effects = jnp.dot(raw_effects, cholesky_matrix.T)

    temp_traits = TraitCollection(
        loci_loc=qtl_loc,
        add_eff=correlated_raw_effects.T,
        intercept=jnp.zeros(n_traits)
    )

    initial_bvs = _calculate_gvs_vectorized_alternative(founder_pop, temp_traits, sim_param.ploidy)
    initial_vars = jnp.var(initial_bvs, axis=0)

    scaling_factors = jnp.sqrt(var / (initial_vars + 1e-8))
    initial_means = jnp.mean(initial_bvs, axis=0)
    final_intercepts = mean - (initial_means * scaling_factors)

    final_add_eff = correlated_raw_effects * scaling_factors

    trait_collection = TraitCollection(
        loci_loc=qtl_loc,
        add_eff=final_add_eff.T,
        intercept=final_intercepts
    )

    return sim_param.replace(traits=trait_collection)

In [3]:
import jax
import jax.numpy as jnp

# Assuming these are the locations of your modules
from chewc.population import msprime_pop
from chewc.sp import SimParam
# from chewc.trait import add_trait_a, _calculate_gvs_vectorized_alternative


# JAX random key setup
key = jax.random.PRNGKey(42)
key, pop_key, trait_key = jax.random.split(key, 3)

# 1. Generate the founder population and its genetic map together.
founder_pop, genetic_map = msprime_pop(
    key=pop_key, n_ind=100, n_loci_per_chr=50, n_chr=3, ploidy=2
)

# 2. Use the founder population and its map to configure the simulation's rules.
SP = SimParam.from_founder_pop(
    founder_pop=founder_pop,
    gen_map=genetic_map,
    sexes="no"
)

# 3. Add two additive traits with a negative genetic correlation.
# Define trait means, variances, and the correlation matrix
trait_means = jnp.array([100.0, 50.0])
trait_vars = jnp.array([10.0, 5.0])
neg_cor_matrix = jnp.array([[1.0, -0.6],
                           [-0.6, 1.0]])

# Call the function to add the traits
SP = add_trait_a(
    key=trait_key,
    founder_pop=founder_pop,
    sim_param=SP,
    n_qtl_per_chr=50,
    mean=trait_means,
    var=trait_vars,
    cor_a=neg_cor_matrix
)

# --- Verification (Optional) ---
# Calculate genetic values for the new traits in the founder population
gvs = _calculate_gvs_vectorized_alternative(founder_pop, SP.traits, SP.ploidy)

# Calculate the realized correlation of the genetic values
realized_cor = jnp.corrcoef(gvs, rowvar=False)


# --- Print Results ---
print("--- Configuration (derived from founders) ---")
print(SP)
print("\n--- Initial State ---")
print(founder_pop)
print("\n--- Added Trait Details ---")
print(SP.traits)
print(f"\nTarget Correlation:\n{neg_cor_matrix}")
print(f"\nRealized Correlation of Genetic Values:\n{realized_cor}")

--- Configuration (derived from founders) ---
SimParam(nChr=3, nTraits=2, ploidy=2, sexes='no')

--- Initial State ---
Population(nInd=100, nTraits=0, has_ebv=No)

--- Added Trait Details ---
TraitCollection(loci_loc=Array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
       130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 1

In [4]:
from chewc.phenotype import set_pheno # Import the new function

# 4. Calculate phenotypes for the population.
# Define the narrow-sense heritability for each trait.
heritabilities = jnp.array([0.3, 0.6])
# JAX random key setup
key = jax.random.PRNGKey(12)
key, pop_key, trait_key, pheno_key = jax.random.split(key, 4)

# Call the JIT-compiled set_pheno function
phenotyped_pop = set_pheno(
    key=pheno_key,
    pop=founder_pop,
    traits=SP.traits,
    ploidy=SP.ploidy,
    h2=heritabilities
)


# --- Verification (Optional) ---
# Calculate the realized heritability from the phenotyped population
var_g = jnp.var(phenotyped_pop.bv, axis=0)
var_p = jnp.var(phenotyped_pop.pheno, axis=0)
realized_h2 = var_g / var_p


# --- Print Results ---
print("--- Final Population State (with Phenotypes) ---")
print(phenotyped_pop)
print("\n--- Phenotype Summary ---")
# Display the first 5 individuals' breeding values and phenotypes
for i in range(5):
    print(f"Ind {i+1}: BV={phenotyped_pop.bv[i]}, Pheno={phenotyped_pop.pheno[i]}")

print(f"\nTarget Heritabilities: {heritabilities}")
print(f"Realized Heritabilities: {realized_h2}")


--- Final Population State (with Phenotypes) ---
Population(nInd=100, nTraits=2, has_ebv=No)

--- Phenotype Summary ---
Ind 1: BV=[-0.7797859  1.1650431], Pheno=[110.31276  50.46098]
Ind 2: BV=[-9.329305   2.0032213], Pheno=[101.59567   49.227432]
Ind 3: BV=[-7.0039377  6.330233 ], Pheno=[106.53382  53.23699]
Ind 4: BV=[-4.487576   -0.06843241], Pheno=[105.65081  43.68624]
Ind 5: BV=[-10.379346   4.628529], Pheno=[99.61451  49.806454]

Target Heritabilities: [0.3 0.6]
Realized Heritabilities: [0.3824622 0.6015502]


In [5]:
#| export
# chewc/phenotype.py

from typing import Optional, Union
from functools import partial

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

from chewc.population import Population
from chewc.trait import TraitCollection
from jaxtyping import Array, Float, Int

# This function remains a core, JIT-able utility
def _calculate_gvs_vectorized_alternative(
    pop: Population,
    traits: TraitCollection,
    ploidy: int
) -> tuple[Float[Array, "nInd nTraits"], Float[Array, "nInd nTraits"]]:
    """
    Calculates all genetic values using a single, highly-optimized matrix
    multiplication.
    """
    flat_geno_alleles = pop.geno.transpose((0, 1, 3, 2)).reshape(pop.nInd, -1, ploidy)
    qtl_alleles = flat_geno_alleles[:, traits.loci_loc, :]
    qtl_geno = jnp.sum(qtl_alleles, axis=2)
    all_bv = jnp.dot(qtl_geno, traits.add_eff.T)
    all_gvs = all_bv + traits.intercept
    return all_bv, all_gvs

# NEW: A simplified internal function for just the noise calculation
def _add_environmental_noise(
    key: jax.random.PRNGKey,
    gvs: Float[Array, "nInd nTraits"],
    var_e: Float[Array, "nTraits"],
    cor_e: Float[Array, "nTraits nTraits"],
) -> Float[Array, "nInd nTraits"]:
    """
    Internal JIT-able function to generate and add environmental noise.
    """
    n_ind, n_traits = gvs.shape
    cov_e = jnp.diag(jnp.sqrt(var_e)) @ cor_e @ jnp.diag(jnp.sqrt(var_e))
    environmental_noise = jax.random.multivariate_normal(
        key, jnp.zeros(n_traits), cov_e, (n_ind,)
    )
    return gvs + environmental_noise


# REFACTORED: The public-facing function now orchestrates the steps
def set_pheno(
    key: jax.random.PRNGKey,
    pop: Population,
    traits: TraitCollection,
    ploidy: int,
    h2: Optional[Float[Array, "nTraits"]] = None,
    varE: Optional[Float[Array, "nTraits"]] = None,
    cor_e: Optional[Float[Array, "nTraits nTraits"]] = None,
) -> Population:
    """
    Sets phenotypes for a population based on its genetic values and
    either a specified heritability (h2) or environmental variance (varE).

    Exactly one of `h2` or `varE` must be provided.
    """
    # 1. --- Input Validation (Python Land) ---
    if (h2 is None and varE is None) or (h2 is not None and varE is not None):
        raise ValueError("Exactly one of 'h2' or 'varE' must be provided.")

    if cor_e is None:
        cor_e = jnp.identity(traits.n_traits)

    # 2. --- Core Genetic Calculation (JIT-able) ---
    # We bake `ploidy` into a partial function and JIT it
    gvs_calculator = jax.jit(
        partial(_calculate_gvs_vectorized_alternative, ploidy=ploidy)
    )
    bvs, gvs = gvs_calculator(pop=pop, traits=traits)

    # 3. --- Determine Environmental Variance (Python Land) ---
    # This logic now lives outside any JIT compilation path
    if h2 is not None:
        var_g = jnp.var(gvs, axis=0)
        # Add a small epsilon to prevent division by zero for traits with no variance
        var_e = (var_g / (h2 + 1e-8)) - var_g
        var_e = jnp.maximum(0, var_e) # Ensure variance is not negative
    else: # varE is not None
        var_e = varE

    # 4. --- Add Environmental Noise (JIT-able) ---
    # The noise calculator is a pure function, so we can JIT it directly
    noise_adder = jax.jit(_add_environmental_noise)
    pheno = noise_adder(key=key, gvs=gvs, var_e=var_e, cor_e=cor_e)

    # 5. --- Update Population ---
    return pop.replace(bv=bvs, pheno=pheno)

In [7]:
import jax
import jax.numpy as jnp

# Assuming the previous setup code has been run to create founder_pop and SP
# founder_pop, SP, ...

# from chewc.phenotype import set_pheno # Import the function

# Split the JAX key for two separate phenotyping operations
key = jax.random.PRNGKey(42) # A fresh key for this block
key, h2_key, varE_key = jax.random.split(key, 3)

# --- Example 1: Phenotyping with Heritability (h2) ---

# Define the narrow-sense heritability for each trait
heritabilities = jnp.array([0.3, 0.6])

# Call the JIT-compiled set_pheno function using the h2 argument
pop_with_h2 = set_pheno(
    key=h2_key,
    pop=founder_pop,
    traits=SP.traits,
    ploidy=SP.ploidy,
    h2=heritabilities
)

# Verification for the h2-based population
var_g_h2 = jnp.var(pop_with_h2.bv, axis=0)
var_p_h2 = jnp.var(pop_with_h2.pheno, axis=0)
realized_h2 = var_g_h2 / var_p_h2

# --- Print h2 Results ---
print("--- Example 1: Phenotyping with Heritability (h2) ---")
print(f"\nTarget Heritabilities: {heritabilities}")
print(f"Realized Heritabilities: {realized_h2}")
print(f"\nGenetic Variance (from pop): {var_g_h2}")
print(f"Phenotypic Variance (from pop): {var_p_h2}")
print("\n--- Phenotype Summary (first 5 individuals) ---")
for i in range(5):
    print(f"Ind {i+1}: BV={pop_with_h2.bv[i]}, Pheno={pop_with_h2.pheno[i]}")


print("\n" + "="*60 + "\n") # Separator


# --- Example 2: Phenotyping with Environmental Variance (varE) ---

# Define the absolute environmental variance for each trait
environmental_variances = jnp.array([23.0, 3.0])

# Call the JIT-compiled set_pheno function using the varE argument
pop_with_varE = set_pheno(
    key=varE_key,
    pop=founder_pop,
    traits=SP.traits,
    ploidy=SP.ploidy,
    varE=environmental_variances
)

# Verification for the varE-based population
var_g_varE = jnp.var(pop_with_varE.bv, axis=0)
var_p_varE = jnp.var(pop_with_varE.pheno, axis=0)
implied_h2 = var_g_varE / var_p_varE

# --- Print varE Results ---
print("--- Example 2: Phenotyping with Environmental Variance (varE) ---")
print(f"\nTarget Environmental Variances: {environmental_variances}")
print(f"Implied Heritabilities: {implied_h2}") # Note: this is a result, not a target
print(f"\nGenetic Variance (from pop): {var_g_varE}")
print(f"Phenotypic Variance (from pop): {var_p_varE}")
print("\n--- Phenotype Summary (first 5 individuals) ---")
for i in range(5):
    print(f"Ind {i+1}: BV={pop_with_varE.bv[i]}, Pheno={pop_with_varE.pheno[i]}")

--- Example 1: Phenotyping with Heritability (h2) ---

Target Heritabilities: [0.3 0.6]
Realized Heritabilities: [0.2922263  0.65621984]

Genetic Variance (from pop): [10.000003  5.      ]
Phenotypic Variance (from pop): [34.220062  7.619398]

--- Phenotype Summary (first 5 individuals) ---
Ind 1: BV=[-0.7797859  1.1650431], Pheno=[108.899376  48.453247]
Ind 2: BV=[-9.329305   2.0032213], Pheno=[93.0332  46.67276]
Ind 3: BV=[-7.0039377  6.330233 ], Pheno=[93.84311 50.6401 ]
Ind 4: BV=[-4.487576   -0.06843241], Pheno=[99.97499  43.477955]
Ind 5: BV=[-10.379346   4.628529], Pheno=[95.51975  50.368103]


--- Example 2: Phenotyping with Environmental Variance (varE) ---

Target Environmental Variances: [23.  3.]
Implied Heritabilities: [0.28627202 0.6583786 ]

Genetic Variance (from pop): [10.000003  5.      ]
Phenotypic Variance (from pop): [34.93182   7.594414]

--- Phenotype Summary (first 5 individuals) ---
Ind 1: BV=[-0.7797859  1.1650431], Pheno=[108.04653  48.01157]
Ind 2: BV=[-9.32