# trait

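> Additive trait definitions (`LociMap`, `TraitA`) and `add_trait_a`, which places QTL on the genome and scales their effects to target trait means and additive genetic variances.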

In [None]:
#| default_exp trait

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from typing import List, Optional

import jax
import jax.numpy as jnp
from flax.struct import dataclass as flax_dataclass
from flax.struct import field as flax_field
from jaxtyping import Array, Float, Int, PyTree

# Assuming Population and SimParam are in these locations
from chewc.pop import Population
from chewc.sp import SimParam

# --- Base Loci and Trait Structures ---

@flax_dataclass(frozen=True)
class LociMap:
    """
    Set of loci on the genome; base class for SNP chips and traits.

    Implemented as a Flax struct dataclass, so instances are JAX pytrees. The array
    fields are pytree leaves. `name` is declared with `pytree_node=False`, so it is
    kept as static auxiliary data rather than as a leaf. A `str` leaf is not a valid
    JAX type and would make `jax.jit` reject any function argument containing this
    object.

    Attributes:
        loci_per_chr (Int[Array, "nChr"]): The number of loci on each chromosome.
        loci_loc (Int[Array, "nLoci"]): The specific indices of the loci on the genetic map.
        name (str): The name of this loci map (e.g., "Trait1", "Chip1"). Static field.
    """

    loci_per_chr: Int[Array, "nChr"]
    loci_loc: Int[Array, "nLoci"]
    # Static (non-pytree) field: stored in the treedef rather than as a leaf, so it
    # must be hashable and is not traced by JAX transformations.
    name: str = flax_field(pytree_node=False)

    @property
    def n_loci(self) -> int:
        """Total number of loci in this map (the length of `loci_loc`)."""
        return self.loci_loc.shape[0]

@flax_dataclass(frozen=True)
class TraitA(LociMap):
    """
    Additive trait: the inherited `LociMap` fields locate the trait's QTL, and this
    class adds one additive effect per QTL plus a global intercept.

    Attributes:
        add_eff (Float[Array, "nLoci"]): Additive effect of each QTL, aligned with
            `loci_loc`.
        intercept (float): Constant added to every individual's genetic value for this
            trait; used to shift the trait mean.
    """

    # One additive effect per QTL, in the same order as `loci_loc`.
    add_eff: Float[Array, "nLoci"]
    # Mean-shifting constant; defaults to no shift.
    intercept: float = 0.0


# --- Helper Functions ---

def _calculate_genetic_params(
    trait: TraitA,
    pop: Population,
    sim_param: SimParam
) -> PyTree:
    """
    Compute breeding values, genetic values and breeding-value variance for one trait.

    The function uses no randomness and mutates nothing, so it composes with JAX
    transformations.

    Args:
        trait: Additive trait supplying QTL indices (`loci_loc`), per-QTL effects
            (`add_eff`) and an `intercept`.
        pop: Population supplying the genotype array `geno` and individual count `nInd`.
        sim_param: Simulation parameters; only `ploidy` is read here.

    Returns:
        A dict with:
            "bv": per-individual breeding value (QTL allele dosage dotted with `add_eff`),
            "gv": per-individual genetic value (`bv + trait.intercept`),
            "var_bv": variance of "bv" across individuals (jnp.var, i.e. ddof=0).
    """
    ploidy = sim_param.ploidy

    # Swap the last two axes, then merge chromosome and locus axes so every locus has
    # a single flat index. Presumably geno is (nInd, nChr, ploidy, nLoci), giving
    # (nInd, nChr * nLoci, ploidy) — TODO confirm against Population.
    swapped = pop.geno.transpose((0, 1, 3, 2))
    alleles_by_locus = swapped.reshape(pop.nInd, -1, ploidy)

    # Keep only the trait's QTL columns, then add up allele states across the ploidy
    # axis to get each individual's dosage at every QTL.
    dosage = jnp.sum(alleles_by_locus[:, trait.loci_loc, :], axis=-1)

    breeding_values = jnp.dot(dosage, trait.add_eff)
    # Purely additive trait: genetic value is the breeding value shifted by the intercept.
    genetic_values = breeding_values + trait.intercept

    return dict(
        bv=breeding_values,
        gv=genetic_values,
        var_bv=jnp.var(breeding_values),
    )


# --- Public Trait-Adding Function ---
def _sample_qtl_loci(
    key: jax.random.PRNGKey,
    n_chr: int,
    n_loci_per_chr: int,
    n_qtl_per_chr: int,
) -> Int[Array, "nQtl"]:
    """
    Sample `n_qtl_per_chr` distinct loci on every chromosome.

    Returns sorted flat locus indices of length `n_chr * n_qtl_per_chr`. The flat index
    of within-chromosome locus `j` on chromosome `c` is `c * n_loci_per_chr + j`
    (chromosome-major order).
    """
    chr_keys = jax.random.split(key, n_chr)

    def _pick_on_one_chr(chr_key):
        picks = jax.random.choice(
            chr_key, n_loci_per_chr, shape=(n_qtl_per_chr,), replace=False
        )
        return jnp.sort(picks)

    # (n_chr, n_qtl_per_chr) within-chromosome positions; each row is sorted.
    within_chr = jax.vmap(_pick_on_one_chr)(chr_keys)
    # Shift each row to its chromosome's flat range. The ranges are disjoint and
    # increasing, so the flattened result is globally sorted.
    chr_offsets = jnp.arange(n_chr)[:, None] * n_loci_per_chr
    return (within_chr + chr_offsets).reshape(-1)


def add_trait_a(
    key: jax.random.PRNGKey,
    sim_param: SimParam,
    n_qtl_per_chr: int,
    mean: Float[Array, "nTraits"],
    var: Float[Array, "nTraits"],
    cor_a: Optional[Float[Array, "nTraits nTraits"]] = None
) -> SimParam:
    """
    Add one or more additive traits to the simulation parameters.

    Placement and scaling:
    - All new traits share one set of QTL: `n_qtl_per_chr` distinct loci sampled on
      each chromosome.
    - Standard-normal effects are drawn per QTL and per trait.
    - Effects are correlated across traits via the Cholesky factor of `cor_a`.
    - Each trait is then rescaled so that, in `sim_param.founderPop`, the variance of
      breeding values equals `var[i]`. An intercept makes the mean genetic value
      equal `mean[i]`.

    Args:
        key: JAX PRNG key for QTL placement and effect sampling.
        sim_param: Simulation parameters. Reads `n_chr`, `gen_map`, `founderPop`,
            `ploidy`, `n_traits` and `traits`.
        n_qtl_per_chr: Number of QTL sampled on each chromosome.
        mean: Target mean genetic value per trait. Scalars and sequences are converted
            to a 1-D array.
        var: Target breeding-value variance per trait. Must match the shape of `mean`
            and be non-negative.
        cor_a: Optional symmetric, positive-definite correlation matrix of additive
            effects between traits. Defaults to the identity (independent traits).

    Returns:
        A new SimParam with the new `TraitA` objects appended to `sim_param.traits`.

    Raises:
        ValueError: Raised when:
            - the shapes are inconsistent;
            - a variance is negative;
            - `cor_a` is not symmetric or not positive definite;
            - the genetic map cannot be split evenly across chromosomes;
            - more QTL are requested per chromosome than it has loci.
    """
    # --- Validate and normalize inputs (explicit raises survive `python -O`). ---
    mean = jnp.atleast_1d(jnp.asarray(mean))
    var = jnp.atleast_1d(jnp.asarray(var))
    if mean.shape != var.shape:
        raise ValueError("Mean and variance vectors must have the same shape.")
    if mean.ndim != 1:
        raise ValueError("Mean and variance must be 1-D vectors (one entry per trait).")
    n_traits = mean.shape[0]
    if bool(jnp.any(var < 0)):
        raise ValueError("Variances must be non-negative.")

    if cor_a is None:
        cor_a = jnp.identity(n_traits)
    else:
        cor_a = jnp.asarray(cor_a)
        if cor_a.shape != (n_traits, n_traits):
            raise ValueError("Correlation matrix has incorrect dimensions.")
        if not bool(jnp.allclose(cor_a, cor_a.T)):
            raise ValueError("Correlation matrix must be symmetric.")

    _, sample_key, qtl_key = jax.random.split(key, 3)

    # --- QTL placement: exactly n_qtl_per_chr loci on every chromosome. ---
    # Assumes gen_map covers all chromosomes with an equal number of loci each,
    # flattened chromosome-major. The flat indices must match the genotype
    # flattening in _calculate_genetic_params — TODO confirm against SimParam.
    n_chr = int(sim_param.n_chr)
    n_sites = int(sim_param.gen_map.size)
    if n_chr <= 0 or n_sites % n_chr != 0:
        raise ValueError(
            "Genetic map size must be a positive multiple of the number of chromosomes."
        )
    n_loci_per_chr = n_sites // n_chr
    if not 0 <= n_qtl_per_chr <= n_loci_per_chr:
        raise ValueError(
            f"n_qtl_per_chr must be between 0 and {n_loci_per_chr} (loci per chromosome)."
        )

    qtl_loc = _sample_qtl_loci(qtl_key, n_chr, n_loci_per_chr, n_qtl_per_chr)
    loci_per_chr = jnp.full((n_chr,), n_qtl_per_chr)

    # --- Correlated raw effects: rows are QTL, columns are traits. ---
    raw_effects = jax.random.normal(sample_key, (qtl_loc.shape[0], n_traits))
    cholesky_matrix = jnp.linalg.cholesky(cor_a)
    # jnp.linalg.cholesky does not raise on a non-positive-definite input; it
    # returns NaNs instead, which would silently poison every effect.
    if not bool(jnp.all(jnp.isfinite(cholesky_matrix))):
        raise ValueError("Correlation matrix must be positive definite.")
    correlated_raw_effects = raw_effects @ cholesky_matrix.T

    # --- Scale each trait to the requested founder variance and mean. ---
    new_traits = []
    for i in range(n_traits):
        unscaled_trait = TraitA(
            loci_per_chr=loci_per_chr,
            loci_loc=qtl_loc,
            name=f"temp_trait_{i}",
            add_eff=correlated_raw_effects[:, i],
        )
        founder_params = _calculate_genetic_params(
            unscaled_trait, sim_param.founderPop, sim_param
        )

        # The small epsilon avoids division by zero if no QTL segregates in the founders.
        scaling_factor = jnp.sqrt(var[i] / (founder_params["var_bv"] + 1e-8))
        # After scaling, mean(bv) becomes mean_bv * scaling_factor. The intercept
        # moves the mean genetic value onto the target.
        intercept = mean[i] - jnp.mean(founder_params["bv"]) * scaling_factor

        new_traits.append(
            TraitA(
                loci_per_chr=loci_per_chr,
                loci_loc=qtl_loc,
                name=f"Trait{sim_param.n_traits + i + 1}",
                add_eff=unscaled_trait.add_eff * scaling_factor,
                intercept=intercept,
            )
        )

    # Return a new, updated SimParam (the input is not mutated).
    return sim_param.replace(
        traits=sim_param.traits + new_traits
    )


In [None]:
#| hide
# Write this notebook's exported cells out to the Python module.
import nbdev

nbdev.nbdev_export()