# Population

> Data Structures for Population

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from dataclasses import field
from typing import List, Optional, Dict, Callable

from flax.struct import dataclass as flax_dataclass # Using flax's dataclass for JAX-friendliness
import jax
import jax.numpy as jnp

# --- Utility Functions ---

def _pad_and_stack_arrays(arrays: List[jnp.ndarray], pad_value: float = jnp.nan) -> jnp.ndarray:
    """
    Pads a list of JAX arrays along their leading axis and stacks them.

    This function is critical for handling data where entries can have
    variable lengths, such as genetic maps for different chromosomes. JAX's
    core transformations (like `vmap` for vectorization) require arrays to have
    a uniform shape for efficient batching on accelerators (GPU/TPU). This
    utility converts a Python list of arrays into a single, rectangular
    JAX array that meets this requirement.

    Only the first axis is padded (at the end) up to the longest first-axis
    length in `arrays`; any trailing axes are left untouched and must already
    agree across inputs. For 1-D inputs the result has shape
    `(len(arrays), max_len)`; for inputs of shape `(n_i, *rest)` it has shape
    `(len(arrays), max_len, *rest)`.

    Using `jnp.nan` as the default `pad_value` is a robust choice for
    genetic maps, as it's an unambiguous "missing value" marker. NaN only
    exists in floating-point dtypes, so pass an explicit `pad_value` when
    padding integer arrays.

    Args:
        arrays: Arrays to pad and stack; each must have at least one axis.
        pad_value: Constant written into the padded positions.

    Returns:
        The stacked array, or an empty 1-D array when `arrays` is empty.
    """
    if not arrays:
        return jnp.array([])
    max_len = max(arr.shape[0] for arr in arrays)
    padded_arrays = []
    for arr in arrays:
        # A bare `(0, k)` pad width would be broadcast to *every* axis, so spell
        # out one (before, after) pair per axis and pad only the leading one.
        pad_width = [(0, max_len - arr.shape[0])] + [(0, 0)] * (arr.ndim - 1)
        padded_arrays.append(
            jnp.pad(arr, pad_width, mode='constant', constant_values=pad_value)
        )
    return jnp.stack(padded_arrays)

# --- Core Population Structure ---

@flax_dataclass(frozen=True) # Make the class immutable, a JAX best practice
class Population:
    """
    A container for all data related to a population of individuals, designed
    for JAX-based genetic simulations.

    This structure is immutable. All operations that modify a population should
    return a new Population object.

    Attributes:
        geno (jnp.ndarray): A 4D array representing the genotypes of the population.
            Shape: `(nInd, nChr, ploidy, nLoci)`. dtype: `jnp.uint8`.
        genMap (jnp.ndarray): A 2D array holding the genetic map positions for each
            locus on each chromosome. Positions are in Morgans. Padded with jnp.nan.
            Shape: `(nChr, nLoci)`.
        centromere (jnp.ndarray): A 1D array with the centromere position (in Morgans)
            for each chromosome. Shape: `(nChr,)`.
        ploidy (int): The ploidy level of the individuals (e.g., 2 for diploid).
            Declared as a static (non-pytree) field, so it stays a plain Python
            value under JAX transformations and can be used to build shapes.

        id (jnp.ndarray): The primary, user-facing identifier for each individual.
            These IDs may not be contiguous or sorted. Shape: `(nInd,)`.
        iid (jnp.ndarray): The internal, zero-indexed, contiguous identifier.
            Crucial for robust indexing in JAX operations. Shape: `(nInd,)`.
        mother (jnp.ndarray): Array of internal IDs (`iid`) for the mother of each
            individual. A value of -1 indicates no known mother. Shape: `(nInd,)`.
        father (jnp.ndarray): Array of internal IDs (`iid`) for the father of each
            individual. A value of -1 indicates no known father. Shape: `(nInd,)`.
        sex (jnp.ndarray): The sex of each individual, represented numerically
            (e.g., 0 for male, 1 for female). dtype: `jnp.int8`. Shape: `(nInd,)`.

        gv (jnp.ndarray): The genetic values (true breeding values) for each
            individual across one or more traits. Shape: `(nInd, nTraits)`.
        pheno (jnp.ndarray): The phenotypic values for each individual.
            Shape: `(nInd, nTraits)`.
        fixEff (jnp.ndarray): The value of a fixed effect for each individual,
            often used as an intercept in genomic selection models. Shape: `(nInd,)`.
        ebv (Optional[jnp.ndarray]): The estimated breeding values for each
            individual. Shape: `(nInd, nTraits)`.
        gxe (Optional[jnp.ndarray]): Genotype-by-environment interaction effects.
            Shape depends on the specific GxE model.

        misc (Dict): A dictionary for storing miscellaneous, non-JAX-critical
            metadata about individuals.
        miscPop (Dict): A dictionary for storing miscellaneous, non-JAX-critical
            metadata about the entire population.
            NOTE(review): `misc` and `miscPop` are not marked static, so flax
            treats them as pytree children; any arrays inside them are traced
            along with the other fields.

    Validation:
        On construction, the leading axis of every per-individual field is
        checked against `nInd`. `genMap` and `centromere` are checked against
        `nChr`, and the ploidy axis of `geno` is checked against `ploidy`. A
        mismatch raises `AssertionError`.
    """
    # --- Core Genotype and Map Info ---
    geno: jnp.ndarray
    genMap: jnp.ndarray
    centromere: jnp.ndarray
    # `pytree_node=False` is the metadata key flax.struct reads to keep a field
    # out of the traced leaves (it becomes part of the static tree structure).
    ploidy: int = field(metadata={"pytree_node": False})

    # --- Pedigree and Identifiers ---
    id: jnp.ndarray
    iid: jnp.ndarray
    mother: jnp.ndarray
    father: jnp.ndarray
    sex: jnp.ndarray

    # --- Trait and Value Data ---
    gv: jnp.ndarray
    pheno: jnp.ndarray
    fixEff: jnp.ndarray
    ebv: Optional[jnp.ndarray] = None
    gxe: Optional[jnp.ndarray] = None

    # --- Metadata ---
    misc: Dict = field(default_factory=dict)
    miscPop: Dict = field(default_factory=dict)

    def __post_init__(self):
        """Validates the consistency of the population data after initialization.

        Validation is skipped when the fields do not hold a full, unbatched
        population (see `_has_checkable_shapes`).

        Raises:
            AssertionError: If a per-individual array's length differs from
                `nInd`, if `genMap`/`centromere` disagree with `nChr`, or if
                the ploidy axis of `geno` differs from `ploidy`.
        """
        if not self._has_checkable_shapes():
            return

        n_ind = self.nInd
        per_individual = [
            ("ID (id)", self.id),
            ("Internal ID", self.iid),
            ("Mother", self.mother),
            ("Father", self.father),
            ("Sex", self.sex),
            ("Genetic value (gv)", self.gv),
            ("Phenotype (pheno)", self.pheno),
            ("Fixed effect (fixEff)", self.fixEff),
        ]
        if self.ebv is not None:
            per_individual.append(("EBV", self.ebv))
        for label, values in per_individual:
            length = values.shape[0]
            self._require(
                length == n_ind,
                f"{label} array length ({length}) must match number of individuals ({n_ind}).",
            )

        n_chr = self.nChr
        self._require(
            self.genMap.shape[0] == n_chr,
            f"genMap nChr ({self.genMap.shape[0]}) must match geno nChr ({n_chr}).",
        )
        self._require(
            self.centromere.shape[0] == n_chr,
            f"centromere nChr ({self.centromere.shape[0]}) must match geno nChr ({n_chr}).",
        )
        # Axis 2 of geno is the ploidy axis: (nInd, nChr, ploidy, nLoci).
        self._require(
            self.geno.shape[2] == self.ploidy,
            f"geno ploidy axis ({self.geno.shape[2]}) must match ploidy ({self.ploidy}).",
        )

    @staticmethod
    def _require(condition: bool, message: str) -> None:
        """Raise `AssertionError(message)` unless `condition` is true.

        An explicit raise (instead of an `assert` statement) keeps the checks
        active under `python -O` while preserving the exception type.
        """
        if not condition:
            raise AssertionError(message)

    def _has_checkable_shapes(self) -> bool:
        """Return True when the fields look like a full, unbatched population.

        JAX transformations such as `jax.vmap` can rebuild pytree nodes with
        placeholder leaves (e.g. `object()` or plain ints). Mapping over the
        individual axis also yields per-individual slices with a 3-D `geno`.
        Neither case can be validated meaningfully, so both are skipped.
        """
        arrays = [
            self.geno, self.genMap, self.centromere,
            self.id, self.iid, self.mother, self.father, self.sex,
            self.gv, self.pheno, self.fixEff,
        ]
        if self.ebv is not None:
            arrays.append(self.ebv)
        # Concrete arrays, tracers and abstract shape structs all expose a
        # tuple `shape`; placeholders do not.
        if not all(isinstance(getattr(a, "shape", None), tuple) for a in arrays):
            return False
        return len(self.geno.shape) == 4

    @property
    def nInd(self) -> int:
        """Returns the number of individuals in the population."""
        return self.geno.shape[0]

    @property
    def nChr(self) -> int:
        """Returns the number of chromosomes."""
        return self.geno.shape[1]

    @property
    def nLoci(self) -> int:
        """Returns the number of loci per chromosome."""
        return self.geno.shape[3]

    @property
    def nTraits(self) -> int:
        """Returns the number of traits (size of `gv`'s second axis; 0 if `gv` is 1-D)."""
        return self.gv.shape[1] if self.gv.ndim > 1 else 0

    def __repr__(self) -> str:
        """Provides a concise representation of the Population object."""
        return (f"Population(nInd={self.nInd}, nChr={self.nChr}, nLoci={self.nLoci}, "
                f"nTraits={self.nTraits}, ploidy={self.ploidy}, "
                f"has_ebv={'Yes' if self.ebv is not None else 'No'})")


# --- Factory Functions ---

def quick_haplo(key: jax.random.PRNGKey, n_ind: int, n_chr: int, seg_sites: int, ploidy: int = 2, gen_len: float = 1.0, inbred: bool = False) -> Population:
    """
    Creates a new population with randomly generated haplotypes, analogous to
    AlphaSimR's `quickHaplo` function.

    Each allele is drawn independently as 0 or 1. For each chromosome, the
    genetic map consists of `seg_sites` sorted uniform draws in `[0, gen_len)`
    Morgans. Each chromosome gets its own subkey derived from `key`. The
    centromere is placed at half of that chromosome's largest map position,
    or 0.0 when the map is empty. All individuals are founders: `id` and `iid`
    are `0..n_ind-1`, both parents are -1, and sex is drawn from {0, 1}.

    Args:
        key: A JAX random key. Must be provided by the user to ensure
             reproducibility.
        n_ind: Number of individuals.
        n_chr: Number of chromosomes.
        seg_sites: Number of segregating sites (loci) per chromosome.
        ploidy: Ploidy level of the individuals.
        gen_len: The genetic length of the chromosomes in Morgans (may be
            fractional; integers are accepted too).
        inbred: If True, individuals will be fully inbred (homozygous at all loci).

    Returns:
        A new Population object with random founder individuals and no traits
        (`gv` and `pheno` have shape `(n_ind, 0)`).
    """
    # The first split output is not needed; the other three keys seed the
    # genotype, map and sex draws respectively.
    _, geno_key, map_key, sex_key = jax.random.split(key, 4)

    # Generate random haplotypes for all individuals and chromosomes
    # Shape: (nInd, nChr, ploidy, nLoci)
    if inbred:
        # Generate one set of haplotypes and tile it across the ploidy axis
        base_haplotypes = jax.random.randint(geno_key, (n_ind, n_chr, 1, seg_sites), 0, 2, dtype=jnp.uint8)
        geno = jnp.tile(base_haplotypes, (1, 1, ploidy, 1))
    else:
        # Fully random, outbred individuals
        geno = jax.random.randint(geno_key, (n_ind, n_chr, ploidy, seg_sites), 0, 2, dtype=jnp.uint8)

    # Create genetic map using JAX for random numbers (one subkey per chromosome)
    map_keys = jax.random.split(map_key, n_chr)
    gen_map_list = [gen_len * jnp.sort(jax.random.uniform(m_key, (seg_sites,))) for m_key in map_keys]
    gen_map = _pad_and_stack_arrays(gen_map_list)  # NaN padding by default

    # `jnp.max` has no identity for an empty array, so an empty map
    # (seg_sites == 0) gets a centromere of 0.0 instead of raising.
    centromere = jnp.array([jnp.max(m) / 2 if m.size else 0.0 for m in gen_map_list])

    # --- Create Pedigree and IDs using JAX arrays ---
    ids = jnp.arange(n_ind)
    sex_array = jax.random.choice(sex_key, jnp.array([0, 1], dtype=jnp.int8), (n_ind,))

    return Population(
        geno=geno,
        genMap=gen_map,
        centromere=centromere,
        ploidy=ploidy,
        id=ids,
        iid=ids,  # In a new pop, id and iid are the same
        mother=jnp.full(n_ind, -1, dtype=jnp.int32),
        father=jnp.full(n_ind, -1, dtype=jnp.int32),
        sex=sex_array,
        gv=jnp.zeros((n_ind, 0)),  # No traits by default
        pheno=jnp.zeros((n_ind, 0)),
        fixEff=jnp.zeros(n_ind, dtype=jnp.float32) # Default fixed effect of 0
    )

### Future Work: `MultiPop` Functionality for Parallel Simulations

For the library to efficiently handle simulations involving multiple distinct populations at once (e.g., comparing different breeding programs), it should implement functionality equivalent to AlphaSimR's `MultiPop` class.

The goal is to run computations like meiosis, crossing, and phenotyping across multiple populations in a single, parallelized step that takes full advantage of JAX's `vmap` and `pmap` capabilities.

#### Recommended Implementation: The "Stack and Index" Method

Directly mapping over a Python list of `Population` objects is not feasible if they contain different numbers of individuals, as JAX's transformations require consistent array shapes.

The recommended approach is to write a utility function that temporarily merges multiple `Population` objects into a single, larger object for computation.

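1.  **Add a Population Index:** First, amend the `Population` dataclass to include a population identifier, and add it to the per-individual length checks in `__post_init__`. The sketches below are illustrative and abbreviated: they use a plain `@dataclass` and `np.ndarray`. The current class is a Flax dataclass that stores `id` and the other per-individual fields as `jnp.ndarray`.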

    ```python
    @dataclass
    class Population:
        # ... all other fields ...
        id: np.ndarray
        pop_id: np.ndarray # New field to track origin population
    ```

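2.  **Create a `merge_pops` Utility:** This function will take a list of `Population` objects and concatenate their internal arrays. It will also generate the `pop_id` array to track the origin of each individual. `iid`, `mother`, and `father` are zero-indexed *within* each population. The merge must therefore offset them by the total number of individuals in the preceding populations, leaving `-1` entries (unknown parents) unchanged; otherwise internal IDs from different populations would collide.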

    ```python
    from typing import List

    def merge_pops(pops: List[Population]) -> Population:
        """
        Combines a list of Population objects into a single, large object 
        suitable for parallel processing with JAX.
        """
        if not pops:
            raise ValueError("Cannot merge an empty list of populations.")

        # Concatenate all JAX arrays (GPU data)
        all_geno = jnp.concatenate([p.geno for p in pops], axis=0)
        all_gv = jnp.concatenate([p.gv for p in pops], axis=0)
        # ... and so on for pheno, etc. ...

        # Concatenate all NumPy arrays (CPU data)
        all_id = np.concatenate([p.id for p in pops], axis=0)

        # Create the population tracking index
        pop_ids = np.concatenate([
            np.full(p.nInd, i, dtype=np.int32) for i, p in enumerate(pops)
        ])

        # Create and return the new unified Population object
        return Population(
            geno=all_geno,
            gv=all_gv,
            id=all_id,
            pop_id=pop_ids,
            # Static data (like genMap, ploidy) can be taken from the first pop,
            # assuming it's consistent across all populations being merged.
            genMap=pops[0].genMap,
            ploidy=pops[0].ploidy
            # ... fill other fields ...
        )
    ```

This "stack and index" approach produces a single, unified data structure with consistent array shapes. That is the ideal input for `vmap` and `pmap`, and it allows maximum computational throughput on the accelerator. The `pop_id` field lets results be partitioned and assigned back to their original populations once the computation is complete. For example, use a boolean mask such as `pop_id == i`, or a segment reduction like `jax.ops.segment_sum(values, pop_id, num_segments=len(pops))`.

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()