# SimParam

> The SimParam class is the cornerstone of a simulation, acting as a global container for parameters that are not specific to any single Population but govern the rules of the entire simulation. Just as the Population class holds the state of individuals, SimParam holds the "genetic laws" of the simulated world, such as trait architecture, genetic maps, and SNP chip definitions.



In [14]:
#| default_exp sp

In [15]:
#| hide
from nbdev.showdoc import *

In [16]:
#| export

import jax.numpy as jnp
from flax.struct import dataclass as flax_dataclass
from dataclasses import field
from typing import List, Optional, TYPE_CHECKING
from functools import cached_property

# The TYPE_CHECKING block is still essential for static type analysis.
if TYPE_CHECKING:
    from .population import Population
    from .trait import TraitA, LociMap, TraitCollection

@flax_dataclass(frozen=True)
class SimParam:
    """
    A container for all global simulation parameters, initialized from a
    founder population.

    --- JAX JIT Compilation Notes ---

    This class is a JAX Pytree (built with ``flax.struct.dataclass``). Every
    dataclass field is a pytree *child* unless it is declared with
    ``metadata={'pytree_node': False}``. Such fields are stored as static
    metadata in the tree definition instead.

    - **Static fields (non-traceable)**: ``sexes``, ``recomb_params``,
      ``n_threads``, ``track_pedigree`` and ``track_recomb``. These are plain,
      hashable Python values. Changing any of them *will* trigger
      re-compilation of a jitted function that receives this object.
      They must be static for two reasons:
        - a ``str`` leaf is not a valid JAX type and would make ``jax.jit``
          fail;
        - an ``int``/``bool`` leaf would be traced, so it could not drive
          Python control flow.

    - **Pytree children (traceable)**: ``founderPop``, ``traits``,
      ``snp_chips``, ``pedigree`` and ``var_e``. Their array leaves are
      traced, so their values can change without re-compilation. Shape and
      dtype changes still re-compile. ``None`` is an empty subtree.

      NOTE(review): ``traits`` and ``snp_chips`` are left as pytree
      children. If ``TraitCollection`` / ``LociMap`` are not registered
      pytrees of arrays, passing a SimParam through ``jax.jit`` will fail on
      them. Confirm this. They cannot simply be made static while
      ``snp_chips`` is a list, because static metadata must be hashable.

    Derived attributes (``ploidy``, ``gen_map``, ``centromere``, ``last_id``)
    use ``functools.cached_property``, which stores its result directly in
    the instance ``__dict__`` and therefore works on this frozen dataclass.
    They are not pytree fields, so any instance rebuilt from its leaves
    (e.g. inside a jitted function) starts with an empty cache and
    recomputes them on first access.
    """
    # --- Population State Reference (Primary Input) ---
    founderPop: 'Population'

    # --- Genetic Architecture ---
    traits: Optional['TraitCollection'] = None
    snp_chips: List['LociMap'] = field(default_factory=list)

    # --- Simulation Control ---
    # Static (pytree_node=False): kept in the treedef, never traced.
    sexes: str = field(default="no", metadata={'pytree_node': False})
    recomb_params: tuple = field(
        default=(2.6, 0.0, 0.0), metadata={'pytree_node': False}
    )
    n_threads: int = field(default=1, metadata={'pytree_node': False})

    # --- Pedigree & History Tracking ---
    # The boolean switches are static; the pedigree array itself is traced.
    track_pedigree: bool = field(default=False, metadata={'pytree_node': False})
    track_recomb: bool = field(default=False, metadata={'pytree_node': False})
    pedigree: Optional[jnp.ndarray] = None

    # --- Default Phenotyping Parameters ---
    var_e: Optional[jnp.ndarray] = None

    # --- Core Genome Structure (Derived Properties) ---
    @cached_property
    def ploidy(self) -> int:
        """Ploidy, read from axis 2 of the founder population's ``geno`` array."""
        return self.founderPop.geno.shape[2]

    @cached_property
    def gen_map(self) -> jnp.ndarray:
        """
        Genetic map taken from ``founderPop.miscPop['genetic_map_cm']``.

        Raises:
            ValueError: if the founder's ``miscPop`` has no
                ``'genetic_map_cm'`` entry.
        """
        if 'genetic_map_cm' not in self.founderPop.miscPop:
            raise ValueError(
                "Founder population must have 'genetic_map_cm' in its "
                "`miscPop` dictionary. Use a function like `msprime_pop` "
                "to generate it."
            )
        return self.founderPop.miscPop['genetic_map_cm']

    @cached_property
    def centromere(self) -> jnp.ndarray:
        """Default centromere position: ``0.0`` for each of the ``n_chr`` chromosomes."""
        return jnp.zeros(self.n_chr)

    @cached_property
    def last_id(self) -> int:
        """Starting value for ``last_id``: the founder population's ``nInd``."""
        return self.founderPop.nInd

    @property
    def n_chr(self) -> int:
        """Number of chromosomes: the length of axis 0 of ``gen_map``."""
        return self.gen_map.shape[0]

    @property
    def n_loci_per_chr(self) -> jnp.ndarray:
        """Array of length ``n_chr`` giving the number of loci per chromosome."""
        # Every chromosome gets gen_map.shape[1] loci, i.e. a uniform count.
        # Presumably gen_map is (n_chr, n_loci); TODO confirm against callers.
        return jnp.full((self.n_chr,), self.gen_map.shape[1])

    @property
    def n_traits(self) -> int:
        """Number of defined traits (0 when ``traits`` is None)."""
        if self.traits is None:
            return 0
        return self.traits.n_traits

    def __repr__(self) -> str:
        # Reading n_chr/ploidy triggers (and caches) their computation.
        # gen_map raises ValueError when the founder has no genetic map;
        # a repr must not raise (it is used by debuggers and logging),
        # so show a placeholder instead.
        try:
            n_chr = self.n_chr
        except ValueError:
            n_chr = "?"
        return (f"SimParam(nChr={n_chr}, nTraits={self.n_traits}, "
                f"ploidy={self.ploidy}, sexes='{self.sexes}')")

In [17]:
#| hide
import nbdev
# Regenerate the library module(s) from the notebooks' `#| export` cells.
nbdev.nbdev_export()