# Population

> Data Structures for Population

In [1]:
#| default_exp core

In [2]:
#| hide
from nbdev.showdoc import *

In [1]:
#| export

from dataclasses import dataclass
from typing import List, Optional
import jax.numpy as jnp
import numpy as np

# Using a standard numpy array for non-GPU data like IDs
# Using jax.numpy for data that will be used in computations
# on the GPU.

@dataclass
class Population:
    # --- Core Genotype and Map Info (from RawPop & MapPop) ---
    # Shape: (nInd, nChr, ploidy, nLoci_per_chr) or similar flattened representation
    # We will unpack the bit-packed representation into a JAX array of integers (0s and 1s).
    # This is less memory efficient but massively more friendly for GPU computation.
    geno: jnp.ndarray

    # List of JAX arrays, one for each chromosome's map
    genMap: List[jnp.ndarray]
    centromere: jnp.ndarray # Shape: (nChr,)
    ploidy: int

    # --- Pedigree and Identifiers (from Pop) ---
    # We use standard numpy arrays for these as they won't be used in JAX transformations
    id: np.ndarray # Shape: (nInd,)
    mother: np.ndarray # Shape: (nInd,)
    father: np.ndarray # Shape: (nInd,)
    sex: np.ndarray # Shape: (nInd,)

    # --- Trait and Value Data (from Pop) ---
    gv: jnp.ndarray # Shape: (nInd, nTraits)
    pheno: jnp.ndarray # Shape: (nInd, nTraits)
    ebv: Optional[jnp.ndarray] = None # Shape: (nInd, nEbvModels)
    gxe: Optional[List[jnp.ndarray]] = None

    # --- Metadata ---
    # Using a dictionary for flexibility
    misc: Optional[dict] = None

    # --- Computed properties for convenience ---
    @property
    def nInd(self) -> int:
        return self.geno.shape[0]

    @property
    def nChr(self) -> int:
        return len(self.genMap)

    @property
    def nTraits(self) -> int:
        return self.gv.shape[1]

### Future Work: `MultiPop` Functionality for Parallel Simulations

For the library to efficiently handle simulations involving multiple distinct populations at once (e.g., comparing different breeding programs), functionality equivalent to the original AlphaSimR's `MultiPop` class should be implemented.

The goal is to run computations like meiosis, crossing, and phenotyping across multiple populations in a single, parallelized step that takes full advantage of JAX's `vmap` and `pmap` capabilities.

#### Recommended Implementation: The "Stack and Index" Method

Directly mapping over a Python list of `Population` objects is not feasible if they contain different numbers of individuals, as JAX's transformations require consistent array shapes.
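
To make the shape constraint concrete, here is a minimal sketch. It assumes a flattened `(nInd, nLoci)` genotype layout and uses a hypothetical `allele_freq` helper; neither is part of the library. When every population has the same number of individuals, the genotype arrays can be stacked along a new leading axis and mapped with `jax.vmap`:

```python
import jax
import jax.numpy as jnp

def allele_freq(geno):
    # Hypothetical per-population function.
    # geno: (nInd, nLoci) array of 0/1 alleles; returns (nLoci,) frequencies.
    return geno.mean(axis=0)

# Two toy populations with the SAME number of individuals (2) and loci (3).
pop_a = jnp.array([[0, 1, 1], [1, 1, 0]], dtype=jnp.float32)
pop_b = jnp.array([[0, 0, 1], [0, 1, 1]], dtype=jnp.float32)

# Stack along a new leading "population" axis: shape (nPop, nInd, nLoci).
stacked = jnp.stack([pop_a, pop_b])

# vmap applies allele_freq to each population in one vectorized call.
freqs = jax.vmap(allele_freq)(stacked)  # shape (nPop, nLoci)
```

If the populations had different `nInd`, the stacking step would have no valid output shape, which is the limitation the merge-based approach in this section works around.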

The recommended approach is to write a utility function that temporarily merges multiple `Population` objects into a single, larger object for computation.

1.  **Add a Population Index:** First, the `Population` dataclass should be amended to include a population identifier.

    ```python
    @dataclass
    class Population:
        # ... all other fields ...
        id: np.ndarray
        pop_id: np.ndarray # New field to track origin population
    ```

2.  **Create a `merge_pops` Utility:** This function will take a list of `Population` objects and concatenate their internal arrays. It will also generate the `pop_id` array to track the origin of each individual.

    ```python
    from typing import List

    def merge_pops(pops: List[Population]) -> Population:
        """
        Combines a list of Population objects into a single, large object 
        suitable for parallel processing with JAX.
        """
        if not pops:
            raise ValueError("Cannot merge an empty list of populations.")

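        # Concatenate all JAX arrays (accelerator data) along the individual axis
        all_geno = jnp.concatenate([p.geno for p in pops], axis=0)
        all_gv = jnp.concatenate([p.gv for p in pops], axis=0)
        all_pheno = jnp.concatenate([p.pheno for p in pops], axis=0)

        # Concatenate all NumPy arrays (CPU data)
        all_id = np.concatenate([p.id for p in pops], axis=0)
        all_mother = np.concatenate([p.mother for p in pops], axis=0)
        all_father = np.concatenate([p.father for p in pops], axis=0)
        all_sex = np.concatenate([p.sex for p in pops], axis=0)

        # Create the population tracking index: individual j gets the position
        # (0, 1, 2, ...) of its source population in `pops`
        pop_ids = np.concatenate([
            np.full(p.nInd, i, dtype=np.int32) for i, p in enumerate(pops)
        ])

        # Create and return the new unified Population object
        return Population(
            geno=all_geno,
            gv=all_gv,
            pheno=all_pheno,
            id=all_id,
            mother=all_mother,
            father=all_father,
            sex=all_sex,
            pop_id=pop_ids,
            # Static data (like genMap, centromere, ploidy) can be taken from
            # the first pop, assuming it's consistent across all populations
            # being merged.
            genMap=pops[0].genMap,
            centromere=pops[0].centromere,
            ploidy=pops[0].ploidy,
            # Optional fields (ebv, gxe, misc) are left at their defaults here.
        )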
    ```
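
The merge above takes `genMap` and `ploidy` from the first population on the assumption that they match across all inputs. One way to make that assumption explicit is a small validation step before merging. The sketch below uses a hypothetical helper, `check_compatible`, which is not part of the library, and `SimpleNamespace` stand-ins in place of real `Population` objects so that it runs on its own:

```python
from types import SimpleNamespace
import numpy as np

def check_compatible(pops):
    """Raise ValueError if ploidy or genetic maps differ from the first population."""
    ref = pops[0]
    for i, p in enumerate(pops[1:], start=1):
        if p.ploidy != ref.ploidy:
            raise ValueError(
                f"Population {i} has ploidy {p.ploidy}, expected {ref.ploidy}."
            )
        if len(p.genMap) != len(ref.genMap):
            raise ValueError(f"Population {i} has a different number of chromosomes.")
        # Compare each chromosome's map element-wise (works for NumPy or JAX arrays).
        for c, (m, m_ref) in enumerate(zip(p.genMap, ref.genMap)):
            if not np.array_equal(np.asarray(m), np.asarray(m_ref)):
                raise ValueError(f"Population {i} differs in genMap for chromosome {c}.")

# Stand-ins exposing only the attributes the check reads.
a = SimpleNamespace(ploidy=2, genMap=[np.array([0.0, 0.5, 1.0])])
b = SimpleNamespace(ploidy=2, genMap=[np.array([0.0, 0.5, 1.0])])
check_compatible([a, b])  # passes silently when the populations match
```

Calling such a check at the top of `merge_pops` would turn a silent mismatch into an early, descriptive error.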

This "stack and index" approach creates a single, unified data structure with consistent array shapes. Individual-level operations can then be vectorized with `vmap` over one leading axis, or distributed with `pmap` once the individual axis has been reshaped into equal per-device chunks. The `pop_id` field lets results be partitioned and assigned back to their original populations after the computation is complete.
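
As an illustration of how `pop_id` can be used after a merged computation, the sketch below computes per-population mean genetic values with `jax.ops.segment_sum` and then splits the merged array back into per-population pieces with boolean masks. The toy `gv` and `pop_id` arrays are made up for the example, and the variable names are assumptions rather than library API:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Toy merged data: 5 individuals from 2 populations, 1 trait.
gv = jnp.array([[1.0], [3.0], [2.0], [4.0], [6.0]])   # (nInd, nTraits)
pop_id = np.array([0, 0, 1, 1, 1], dtype=np.int32)    # (nInd,)
n_pops = 2

# Per-population sums and counts, computed without a Python loop.
seg = jnp.asarray(pop_id)
sums = jax.ops.segment_sum(gv, seg, num_segments=n_pops)  # (nPop, nTraits)
counts = jnp.bincount(seg, length=n_pops)                 # (nPop,)
means = sums / counts[:, None]                            # (nPop, nTraits)

# Partition the merged array back into one array per original population.
gv_np = np.asarray(gv)
per_pop_gv = [gv_np[pop_id == i] for i in range(n_pops)]
```

Passing `num_segments` and `length` explicitly keeps the output shapes static, which matters if this logic is later wrapped in `jax.jit`.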

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()