# Population

> Data Structures for Population

In [1]:
#| default_exp core

In [2]:
#| hide
from nbdev.showdoc import *

In [15]:
#| export
from dataclasses import dataclass, field
from typing import List, Optional, Dict
import jax
import jax.numpy as jnp
import numpy as np

class PopulationValidationError(ValueError, AssertionError):
    """Raised when the arrays given to `Population` are mutually inconsistent.

    Subclasses both ``ValueError`` (the conventional type for invalid
    arguments) and ``AssertionError`` (what the earlier ``assert``-based
    checks raised), so existing ``except AssertionError`` handlers keep working.
    """


@dataclass
class Population:
    """Genotypes, genetic map, pedigree and trait data for a set of individuals.

    All per-individual fields -- ``id``, ``iid``, ``mother``, ``father``,
    ``sex`` and the leading axis of ``geno``, ``gv``, ``pheno``, ``fixEff``
    and (when set) ``ebv`` -- must describe the same number of individuals;
    ``__post_init__`` enforces this.
    """

    # --- Core Genotype and Map Info (from RawPop & MapPop) ---
    geno: jnp.ndarray  # individuals on axis 0; `nLoci` reads loci from axis 3
    genMap: List[jnp.ndarray]  # one array per chromosome
    centromere: jnp.ndarray  # one value per chromosome, same count as genMap
    ploidy: int

    # --- Pedigree and Identifiers (from Pop) ---
    id: np.ndarray
    iid: np.ndarray  # Internal ID for robust tracking
    mother: np.ndarray
    father: np.ndarray
    sex: np.ndarray

    # --- Trait and Value Data (from Pop) ---
    gv: jnp.ndarray  # genetic values; `nTraits` reads the trait count from axis 1
    pheno: jnp.ndarray
    fixEff: jnp.ndarray  # Fixed effect for GS models
    ebv: Optional[jnp.ndarray] = None
    gxe: Optional[List[jnp.ndarray]] = None

    # --- Metadata ---
    misc: Dict = field(default_factory=dict)
    miscPop: Dict = field(default_factory=dict)

    def __post_init__(self):
        """Validate the consistency of the population data after initialization.

        Checks run in a fixed order and stop at the first failure.

        Raises:
            PopulationValidationError: (both a ``ValueError`` and an
                ``AssertionError``) if any per-individual array's length
                disagrees with ``id``, or if ``genMap`` and ``centromere``
                list different numbers of chromosomes.
        """
        # Explicit raises instead of `assert`: assert statements are removed
        # under `python -O`, which would let inconsistent populations through.
        def _require(condition, message):
            if not condition:
                raise PopulationValidationError(message)

        n_ind_from_id = len(self.id)
        _require(len(self.iid) == n_ind_from_id, "Internal ID array length must match ID array length.")
        _require(len(self.mother) == n_ind_from_id, "Mother array length must match ID array length.")
        _require(len(self.father) == n_ind_from_id, "Father array length must match ID array length.")
        _require(len(self.sex) == n_ind_from_id, "Sex array length must match ID array length.")
        _require(self.geno.shape[0] == n_ind_from_id, "Genotype array must have the same number of individuals as the ID array.")
        _require(self.gv.shape[0] == n_ind_from_id, "Genetic value (gv) array must have the same number of individuals as the ID array.")
        _require(self.pheno.shape[0] == n_ind_from_id, "Phenotype (pheno) array must have the same number of individuals as the ID array.")
        _require(self.fixEff.shape[0] == n_ind_from_id, "Fixed effect (fixEff) array must have the same number of individuals as the ID array.")
        if self.ebv is not None:
            _require(self.ebv.shape[0] == n_ind_from_id, "EBV array must have the same number of individuals as the ID array.")
        _require(len(self.genMap) == len(self.centromere), "genMap and centromere must have the same number of chromosomes.")

    @property
    def nInd(self) -> int:
        """Number of individuals (leading axis of ``geno``)."""
        return self.geno.shape[0]

    @property
    def nChr(self) -> int:
        """Number of chromosomes (entries in ``genMap``)."""
        return len(self.genMap)

    @property
    def nLoci(self) -> int:
        """Loci per chromosome, read from axis 3 of ``geno``.

        Assumes every chromosome has the same number of loci.
        """
        return self.geno.shape[3]

    @property
    def nTraits(self) -> int:
        """Number of traits.

        Returns ``gv.shape[1]`` (which is 0 for the default ``(nInd, 0)``
        array); a ``gv`` with fewer than two dimensions is reported as 0.
        """
        return self.gv.shape[1] if self.gv.ndim > 1 else 0

    def __repr__(self) -> str:
        return (f"Population(nInd={self.nInd}, nChr={self.nChr}, nTraits={self.nTraits}, "
                f"ploidy={self.ploidy}, has_ebv={'Yes' if self.ebv is not None else 'No'})")


# --- Factory Functions ---

def quick_haplo(n_ind: int, n_chr: int, seg_sites: int, ploidy: int = 2, gen_len: float = 1, inbred: bool = False, key=None) -> Population:
    """
    Create a new population with randomly generated haplotypes.

    This is analogous to AlphaSimR's quickHaplo function. Alleles are drawn
    as 0/1 (``uint8``), and each chromosome's genetic map is ``seg_sites``
    sorted uniform draws scaled by ``gen_len``.

    Args:
        n_ind: Number of individuals.
        n_chr: Number of chromosomes.
        seg_sites: Number of segregating sites per chromosome (at least 1).
        ploidy: Ploidy level (at least 1).
        gen_len: The genetic length of the chromosomes in Morgans.
        inbred: If True, all haplotype copies of an individual are identical,
            i.e. individuals are fully homozygous.
        key: JAX random key. If None, ``jax.random.PRNGKey(0)`` is used, so
            calls without a key always return the same population.

    Returns:
        A new Population object with ``geno`` of shape
        ``(n_ind, n_chr, ploidy, seg_sites)``, unknown parents (-1), no traits
        (``gv``/``pheno`` of shape ``(n_ind, 0)``) and ``fixEff`` set to 1.

    Raises:
        ValueError: If ``seg_sites`` or ``ploidy`` is less than 1.
    """
    if seg_sites < 1:
        raise ValueError(f"seg_sites must be at least 1, got {seg_sites}.")
    if ploidy < 1:
        raise ValueError(f"ploidy must be at least 1, got {ploidy}.")
    if key is None:
        key = jax.random.PRNGKey(0)

    # One independent sub-key per random component; the first output of the
    # split is not needed (splitting into 4 keeps the sub-keys unchanged).
    _, geno_key, map_key, sex_key = jax.random.split(key, 4)

    # Genotypes, shape (n_ind, n_chr, ploidy, seg_sites).
    if inbred:
        # Draw one haplotype per individual and chromosome, then repeat it
        # along the ploidy axis so every copy is identical (homozygous).
        base_haplotypes = jax.random.randint(geno_key, (n_ind, n_chr, 1, seg_sites), 0, 2, dtype=jnp.uint8)
        geno = jnp.tile(base_haplotypes, (1, 1, ploidy, 1))
    else:
        # Fully random, outbred individuals.
        geno = jax.random.randint(geno_key, (n_ind, n_chr, ploidy, seg_sites), 0, 2, dtype=jnp.uint8)

    # Genetic map: per chromosome, sorted uniform positions scaled by gen_len.
    map_keys = jax.random.split(map_key, n_chr)
    gen_map = [gen_len * jnp.sort(jax.random.uniform(m_key, (seg_sites,))) for m_key in map_keys]
    # Centromere at half of each chromosome's largest map position.
    centromere = jnp.array([jnp.max(m) / 2 for m in gen_map])

    # JAX arrays cannot hold strings, so jax.random.choice cannot sample from
    # ['M', 'F'] directly: draw 0/1 indices with JAX and map them in NumPy.
    sex_idx = np.asarray(jax.random.choice(sex_key, 2, (n_ind,)))
    sex = np.array(['M', 'F'])[sex_idx]

    ids = np.arange(n_ind)

    return Population(
        geno=geno,
        genMap=gen_map,
        centromere=centromere,
        ploidy=ploidy,
        id=ids,
        # In a new pop, id and iid hold the same values, but as separate arrays
        # so an in-place edit of one cannot silently change the other.
        iid=ids.copy(),
        mother=np.full(n_ind, -1, dtype=int),
        father=np.full(n_ind, -1, dtype=int),
        sex=sex,
        gv=jnp.zeros((n_ind, 0)),  # No traits by default
        pheno=jnp.zeros((n_ind, 0)),
        fixEff=jnp.ones(n_ind, dtype=int),
    )

In [9]:
# Build a small example population (10 individuals, 5 chromosomes, 100 sites)
# so this cell does not depend on a `new_pop` defined elsewhere.
new_pop = quick_haplo(n_ind=10, n_chr=5, seg_sites=100)
new_pop.geno.shape

(10, 5, 2, 100)

### Future Work: `MultiPop` Functionality for Parallel Simulations

For the library to handle simulations involving several distinct populations at once efficiently (e.g., when comparing different breeding programs), it should implement functionality equivalent to AlphaSimR's `MultiPop` class.

The goal is to run computations like meiosis, crossing, and phenotyping across multiple populations in a single, parallelized step that takes full advantage of JAX's `vmap` and `pmap` capabilities.

#### Recommended Implementation: The "Stack and Index" Method

Directly mapping over a Python list of `Population` objects is not feasible if they contain different numbers of individuals, as JAX's transformations require consistent array shapes.

The recommended approach is to write a utility function that temporarily merges multiple `Population` objects into a single, larger object for computation.

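1.  **Add a Population Index:** First, amend the `Population` dataclass to include a population identifier. Existing constructors such as `quick_haplo` do not supply this field. Either update every constructor to pass it, or give it a default (e.g. `Optional[np.ndarray] = None`). In a dataclass, a field with a default must be declared after all fields without one.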

    ```python
    @dataclass
    class Population:
        # ... all other fields ...
        id: np.ndarray
        pop_id: np.ndarray # New field to track origin population
    ```

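2.  **Create a `merge_pops` Utility:** This function takes a list of `Population` objects and concatenates their internal arrays along the individual axis. It also generates the `pop_id` array that records the origin of each individual. Concatenation only works if the populations agree on every other dimension: number of chromosomes, loci per chromosome, ploidy and number of traits. They should also share the same genetic map, so the function should verify these properties rather than silently taking them from the first population.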

    ```python
    from typing import List

    def merge_pops(pops: List[Population]) -> Population:
        """
        Combines a list of Population objects into a single, large object 
        suitable for parallel processing with JAX.
        """
        if not pops:
            raise ValueError("Cannot merge an empty list of populations.")

        # Concatenate all JAX arrays (GPU data)
        all_geno = jnp.concatenate([p.geno for p in pops], axis=0)
        all_gv = jnp.concatenate([p.gv for p in pops], axis=0)
        # ... and so on for pheno, etc. ...

        # Concatenate all NumPy arrays (CPU data)
        all_id = np.concatenate([p.id for p in pops], axis=0)

        # Create the population tracking index
        pop_ids = np.concatenate([
            np.full(p.nInd, i, dtype=np.int32) for i, p in enumerate(pops)
        ])

        # Create and return the new unified Population object
        return Population(
            geno=all_geno,
            gv=all_gv,
            id=all_id,
            pop_id=pop_ids,
            # Static data (like genMap, ploidy) can be taken from the first pop,
            # assuming it's consistent across all populations being merged.
            genMap=pops[0].genMap,
            ploidy=pops[0].ploidy
            # ... fill other fields ...
        )
    ```

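This "stack and index" approach creates a single, unified data structure with consistent array shapes. This is the ideal input for `pmap` and `vmap`, allowing for maximum computational throughput on the accelerator. The `pop_id` field ensures that results can be correctly partitioned and assigned back to their original populations after the computation is complete. Note that the mapped axis is then the individual, not the population. Operations defined per population, such as selecting the best individuals within each population, must group individuals by `pop_id` (for example with segment reductions).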

In [None]:
#| hide
# Write this notebook's `#| export` cells out to the library module.
import nbdev
nbdev.nbdev_export()