# Population

> Data Structures for Population

In [1]:
#| default_exp population

In [2]:
#| hide
from nbdev.showdoc import *


In [3]:
#| echo: false
from nbdev.showdoc import show_doc

In [4]:
#| export
from dataclasses import field
from typing import List, Optional, Dict, Callable, Any

from flax.struct import dataclass as flax_dataclass # Using flax's dataclass for JAX-friendliness
import jax
import jax.numpy as jnp

from chewc.sp import SimParam
from typing import Tuple
from numpy.random import default_rng
import msprime
import tskit
import numpy as np
import random
from collections import defaultdict
from flax.struct import dataclass as flax_dataclass, field

#testing
import jax
import jax.numpy as jnp
from fastcore.test import test_eq, test_ne

In [5]:
#| export
@flax_dataclass(frozen=True) # Make the class immutable, a JAX best practice
class Population:
    """
    A container for all data related to a population of individuals, designed
    for JAX-based genetic simulations.

    This structure is immutable. All operations that modify a population should
    return a new Population object.

    Attributes:
        geno (jnp.ndarray): A 4D array representing the genotypes of the population.
            Shape: `(nInd, nChr, ploidy, nLoci)`. dtype: `jnp.uint8`.
        ibd (jnp.ndarray): A 4D array of founder-allele identifiers recording the
            founder origin of each allele in `geno`.
            Shape: `(nInd, nChr, ploidy, nLoci)`. The founder generators in this
            module (`quick_haplo`, `msprime_pop`) create it with dtype `jnp.uint32`.
        id (jnp.ndarray): The primary, user-facing identifier for each individual.
            These IDs may not be contiguous or sorted. Shape: `(nInd,)`.
        iid (jnp.ndarray): The internal, zero-indexed, contiguous identifier.
            Crucial for robust indexing in JAX operations. Shape: `(nInd,)`.
        mother (jnp.ndarray): Array of internal IDs (`iid`) for the mother of each
            individual. A value of -1 indicates no known mother. Shape: `(nInd,)`.
        father (jnp.ndarray): Array of internal IDs (`iid`) for the father of each
            individual. A value of -1 indicates no known father. Shape: `(nInd,)`.
        sex (jnp.ndarray): The sex of each individual, represented numerically
            (e.g., 0 for male, 1 for female). dtype: `jnp.int8`. Shape: `(nInd,)`.
        gen (jnp.ndarray): The generation each individual belongs to. Shape: `(nInd,)`.
            The founder generators in this module create it with dtype `jnp.int32`.

        pheno (jnp.ndarray): The phenotypic values for each individual.
            Shape: `(nInd, nTraits)`.
        fixEff (jnp.ndarray): The value of a fixed effect for each individual,
            often used as an intercept in genomic selection models. Shape: `(nInd,)`.

        gv (Optional[jnp.ndarray]): Genetic values (BV + intercept, per the field
            comment). Presumably shaped `(nInd, nTraits)` — TODO confirm.
        bv (Optional[jnp.ndarray]): The true breeding values (additive genetic effects)
            for each individual. Shape: `(nInd, nTraits)`.
        dd (Optional[jnp.ndarray]): The true dominance deviations for each individual.
            Shape: `(nInd, nTraits)`.
        aa (Optional[jnp.ndarray]): The true additive-by-additive epistatic deviations
            for each individual. Shape: `(nInd, nTraits)`.

        ebv (Optional[jnp.ndarray]): The estimated breeding values for each
            individual. Shape: `(nInd, nTraits)`.
        gxe (Optional[jnp.ndarray]): Genotype-by-environment interaction effects.
            Shape depends on the specific GxE model.

        misc (Dict): A dictionary for storing miscellaneous, non-JAX-critical
            metadata about individuals. Static (not a pytree node).
        miscPop (Dict): A dictionary for storing miscellaneous, non-JAX-critical
            metadata about the entire population. Static (not a pytree node).
    """

    # --- Core Genotype Info ---
    geno: jnp.ndarray
    ibd: jnp.ndarray

    # --- Pedigree and Identifiers ---
    id: jnp.ndarray
    iid: jnp.ndarray
    mother: jnp.ndarray
    father: jnp.ndarray
    sex: jnp.ndarray
    gen: jnp.ndarray

    # --- Trait and Value Data ---
    pheno: jnp.ndarray
    fixEff: jnp.ndarray

    gv: Optional[jnp.ndarray] = None      # Genetic Value (BV + Intercept)
    bv: Optional[jnp.ndarray] = None      # Breeding Value (Additive)
    dd: Optional[jnp.ndarray] = None      # Dominance Deviations
    aa: Optional[jnp.ndarray] = None      # Additive-by-Additive Epistatic Deviations

    ebv: Optional[jnp.ndarray] = None     # Estimated Breeding Values
    gxe: Optional[jnp.ndarray] = None     # Genotype-by-Environment effects

    # --- Metadata ---
    # pytree_node=False keeps these dicts out of the JAX pytree (static fields).
    misc: Optional[Dict[str, Any]] = field(default=None, pytree_node=False)
    miscPop: Optional[Dict[str, Any]] = field(default=None, pytree_node=False)

    @property
    def nInd(self) -> int:
        """Returns the number of individuals in the population (axis 0 of `geno`)."""
        return self.geno.shape[0]

    @property
    def nChr(self) -> int:
        """Returns the number of chromosomes in the population (axis 1 of `geno`)."""
        return self.geno.shape[1]

    @property
    def nTraits(self) -> int:
        """
        Returns the number of traits, inferred from the breeding value (bv) shape.

        Returns 0 when `bv` is None or has fewer than two dimensions.
        """
        if self.bv is None or self.bv.ndim <= 1:
            return 0
        return self.bv.shape[1]

    @property
    def haplo_matrix(self) -> jnp.ndarray:
        """
        Returns a haplotype matrix of shape `(nInd * ploidy, nChr * nLoci)`.

        Rows are individual-major: all homologous copies of individual 0, then
        those of individual 1, and so on. Columns run over chromosomes, then loci.
        """
        return self.geno.transpose(0, 2, 1, 3).reshape(self.geno.shape[0] * self.geno.shape[2], -1)

    @property
    def dosage(self) -> jnp.ndarray:
        """
        Calculates the dosage of alternate alleles for each individual.

        The dosage is the sum of alleles across the ploidy dimension, resulting
        in a 2D matrix where each entry represents the count of the alternate
        allele at a specific locus for an individual.

        Returns:
            A JAX array of shape `(nInd, nLoci)`, where nLoci is the total
            number of loci across all chromosomes.
        """
        # Sum over the ploidy axis (axis=2) to get dosage per chromosome
        # Shape: (nInd, nChr, nLoci_per_chr)
        dosage_per_chr = jnp.sum(self.geno, axis=2)

        # Reshape to combine the chromosome and loci dimensions
        # Shape: (nInd, nChr * nLoci_per_chr)
        return dosage_per_chr.reshape(self.nInd, -1)

    def plot_maf(self, genetic_map=None, maf_threshold=None):
        """
        Plot the MAF distribution as a quick sanity check for the population,
        then print a one-line summary.

        For each marker, the alternate-allele frequency `p` is the mean of
        `geno` over all individuals and all homologous copies, and
        MAF = min(p, 1 - p).

        Args:
            genetic_map: Optional array indexed as `[chromosome, locus]`.
                Markers whose map position is NaN are treated as invalid and
                skipped.
            maf_threshold: Optional MAF threshold to highlight on the plot.
        """
        import matplotlib.pyplot as plt

        n_loci = self.geno.shape[3]

        # One vectorized reduction over individuals (axis 0) and ploidy (axis 2)
        # replaces a Python loop that issued a separate JAX call per marker.
        # Shape: (nChr, nLoci). float64 so `1 - p` is evaluated in double
        # precision, matching the previous per-marker Python-float arithmetic.
        allele_freq = np.asarray(jnp.mean(self.geno, axis=(0, 2)), dtype=np.float64)
        maf = np.minimum(allele_freq, 1.0 - allele_freq)

        # Skip markers whose genotypes contain NaN (only possible for float
        # genotype arrays) or, when a map is given, whose position is NaN.
        valid = ~np.asarray(jnp.any(jnp.isnan(self.geno), axis=(0, 2)))
        if genetic_map is not None:
            map_block = jnp.asarray(genetic_map)[: self.nChr, :n_loci]
            valid &= ~np.asarray(jnp.isnan(map_block))

        # Boolean indexing flattens row-major: chromosome by chromosome, then locus.
        maf_values = maf[valid]

        if maf_values.size == 0:
            print("No valid markers found!")
            return

        # Plot MAF distribution
        plt.figure(figsize=(8, 5))
        plt.hist(maf_values, bins=50, alpha=0.7, edgecolor='black')
        plt.xlabel('Minor Allele Frequency (MAF)')
        plt.ylabel('Number of Markers')
        plt.title('MAF Distribution')

        mean_maf = float(np.mean(maf_values))
        plt.axvline(mean_maf, color='red', linestyle='--', label=f'Mean: {mean_maf:.3f}')

        if maf_threshold is not None:
            plt.axvline(maf_threshold, color='green', linestyle=':',
                    label=f'Threshold: {maf_threshold}')

        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.show()

        # Print summary; a marker with MAF 0 is fixed (monomorphic).
        fixed = int(np.sum(maf_values == 0))
        print(f"Markers: {len(maf_values)} | Fixed: {fixed} | Mean MAF: {mean_maf:.3f}")

    def __repr__(self) -> str:
        """Provides a concise representation of the Population object."""
        return (f"Population(nInd={self.nInd}, nTraits={self.nTraits}, "
                f"has_ebv={'Yes' if self.ebv is not None else 'No'})")



In [6]:
#| export
from typing import Tuple
import jax
import jax.numpy as jnp
# `Population` is defined in an earlier exported cell of this same module
# (`#| default_exp population`), so no import is needed here.



def quick_haplo(
    key: jax.random.PRNGKey,
    n_ind: int,
    n_chr: int,
    n_loci_per_chr: int,
    ploidy: int = 2,
    inbred: bool = False,
    chr_len_cm: float = 100.0
) -> Tuple[Population, jnp.ndarray]:
    """
    Build a founder population with random 0/1 haplotypes plus a uniform genetic map.

    Comparable to AlphaSimR's `quickHaplo`, but independent of any SimParam:
    the genetic map is returned alongside the population so a SimParam can be
    built from it afterwards.

    Args:
        key: JAX PRNG key; the same key always yields the same population.
        n_ind: Number of founder individuals.
        n_chr: Number of chromosomes.
        n_loci_per_chr: Number of loci on each chromosome.
        ploidy: Number of homologous copies per chromosome (default: 2).
        inbred: When True, all homologues of an individual are identical
            (fully homozygous), including their founder IBD identifiers.
        chr_len_cm: Length in centiMorgans of every chromosome in the uniform
            map (default: 100.0).

    Returns:
        `(population, genetic_map)`, where `population` holds `n_ind` founders
        with no parents, generation 0, zero fixed effects and no traits, and
        `genetic_map` has shape `(n_chr, n_loci_per_chr)` with loci spaced evenly
        from 0 to `chr_len_cm`.
    """
    # Three-way split; only the genotype and sex sub-keys are used.
    _, geno_key, sex_key = jax.random.split(key, 3)

    allele_shape = (n_ind, n_chr, ploidy, n_loci_per_chr)

    # Inbred founders draw one haplotype per chromosome and copy it to every
    # homologue; outbred founders draw each homologue independently.
    n_distinct = 1 if inbred else ploidy
    geno = jax.random.randint(
        geno_key, (n_ind, n_chr, n_distinct, n_loci_per_chr), 0, 2, dtype=jnp.uint8
    )

    # Every allele slot gets its own founder identifier 0 .. N-1.
    ibd = jnp.arange(n_ind * n_chr * ploidy * n_loci_per_chr, dtype=jnp.uint32).reshape(allele_shape)

    if inbred:
        # Replicate the single drawn haplotype, and the first homologue's
        # founder identifiers, across the ploidy axis.
        geno = jnp.repeat(geno, ploidy, axis=2)
        ibd = jnp.repeat(ibd[:, :, :1, :], ploidy, axis=2)

    # Uniform map: identical evenly spaced positions on every chromosome.
    genetic_map = jnp.broadcast_to(
        jnp.linspace(0., chr_len_cm, n_loci_per_chr), (n_chr, n_loci_per_chr)
    )

    ids = jnp.arange(n_ind)
    unknown_parent = jnp.full(n_ind, -1, dtype=jnp.int32)

    population = Population(
        geno=geno,
        ibd=ibd,
        id=ids,
        iid=ids,  # fresh population: external and internal ids coincide
        mother=unknown_parent,
        father=unknown_parent,
        sex=jax.random.choice(sex_key, jnp.array([0, 1], dtype=jnp.int8), (n_ind,)),
        gen=jnp.zeros(n_ind, dtype=jnp.int32),
        pheno=jnp.zeros((n_ind, 0)),
        fixEff=jnp.zeros(n_ind, dtype=jnp.float32),
        bv=jnp.zeros((n_ind, 0)),  # zero traits by default
    )
    return population, genetic_map

In [7]:
#| hide
#| eval: false

# --- Test Suite for quick_haplo ---

import jax
import jax.numpy as jnp
import numpy as np
from numpy import testing as np_testing

# Common parameters for tests
KEY = jax.random.PRNGKey(42)
N_IND = 10
# NOTE(review): MAX_N_IND is passed as `max_n_ind=` to quick_haplo, whose
# signature in this notebook has no such parameter. Tests that pass it raise
# TypeError as written; they describe a padded Population API not yet implemented.
MAX_N_IND = 15
N_CHR = 2
N_LOCI = 5
PLOIDY = 2
CHR_LEN = 100.0

def test_basic_creation_and_shapes():
    """Test basic population creation and the shapes of its attributes.

    NOTE(review): specifies a padded API -- `quick_haplo(max_n_ind=...)`,
    `Population.max_nInd` and `Population.active_mask` -- that is not defined in
    this notebook, so calling this as written raises TypeError.
    """
    pop, gmap = quick_haplo(
        key=KEY,
        n_ind=N_IND,
        max_n_ind=MAX_N_IND,
        n_chr=N_CHR,
        n_loci_per_chr=N_LOCI,
        ploidy=PLOIDY
    )

    # Check return types
    assert isinstance(pop, Population), f"Expected Population object, got {type(pop)}"
    assert isinstance(gmap, jnp.ndarray), f"Expected jax.numpy.ndarray, got {type(gmap)}"

    # Check Population properties
    assert pop.max_nInd == MAX_N_IND
    assert pop.nChr == N_CHR
    assert pop.nTraits == 0 # No traits created by default

    # Check array shapes (padded to MAX_N_IND rows under the specified API)
    expected_geno_shape = (MAX_N_IND, N_CHR, PLOIDY, N_LOCI)
    assert pop.geno.shape == expected_geno_shape, f"geno shape is {pop.geno.shape}, expected {expected_geno_shape}"
    assert pop.ibd.shape == expected_geno_shape, f"ibd shape is {pop.ibd.shape}, expected {expected_geno_shape}"
    
    expected_1d_shape = (MAX_N_IND,)
    assert pop.active_mask.shape == expected_1d_shape
    assert pop.id.shape == expected_1d_shape
    assert pop.iid.shape == expected_1d_shape
    assert pop.mother.shape == expected_1d_shape
    assert pop.father.shape == expected_1d_shape
    assert pop.sex.shape == expected_1d_shape
    assert pop.gen.shape == expected_1d_shape
    assert pop.fixEff.shape == expected_1d_shape

    expected_pheno_shape = (MAX_N_IND, 0)
    assert pop.pheno.shape == expected_pheno_shape
    assert pop.bv.shape == expected_pheno_shape

    # Check genetic map shape
    expected_gmap_shape = (N_CHR, N_LOCI)
    assert gmap.shape == expected_gmap_shape, f"gmap shape is {gmap.shape}, expected {expected_gmap_shape}"

def test_active_individuals_count():
    """Test that nInd and active_mask are set correctly.

    NOTE(review): relies on `quick_haplo(max_n_ind=...)` and
    `Population.active_mask`, neither of which exists in this notebook; as
    written this raises TypeError.
    """
    pop, _ = quick_haplo(
        key=KEY,
        n_ind=N_IND,
        max_n_ind=MAX_N_IND,
        n_chr=N_CHR,
        n_loci_per_chr=N_LOCI
    )
    
    assert pop.nInd == N_IND, f"pop.nInd is {pop.nInd}, expected {N_IND}"
    
    # Check active_mask content: first N_IND rows active, the padding inactive
    assert jnp.sum(pop.active_mask) == N_IND
    np_testing.assert_array_equal(pop.active_mask[:N_IND], jnp.ones(N_IND, dtype=bool))
    np_testing.assert_array_equal(pop.active_mask[N_IND:], jnp.zeros(MAX_N_IND - N_IND, dtype=bool))

def test_padding_values():
    """Test that padded regions of arrays have the correct constant values.

    Specified padding: 0 for genotype/IBD arrays, -1 for integer identifier and
    pedigree arrays, NaN for float arrays.

    NOTE(review): `quick_haplo` in this notebook accepts no `max_n_ind` and does
    not pad, so this raises TypeError as written.
    """
    pop, _ = quick_haplo(
        key=KEY,
        n_ind=N_IND,
        max_n_ind=MAX_N_IND,
        n_chr=N_CHR,
        n_loci_per_chr=N_LOCI
    )
    
    # Check active region is not the padding value
    assert pop.id[N_IND-1] != -1
    assert not jnp.all(pop.geno[N_IND-1] == 0)
    
    # Check padded region
    pad_slice = slice(N_IND, MAX_N_IND)
    
    assert jnp.all(pop.geno[pad_slice] == 0)
    assert jnp.all(pop.ibd[pad_slice] == 0)
    assert jnp.all(pop.id[pad_slice] == -1)
    assert jnp.all(pop.iid[pad_slice] == -1)
    assert jnp.all(pop.mother[pad_slice] == -1)
    assert jnp.all(pop.father[pad_slice] == -1)
    assert jnp.all(pop.sex[pad_slice] == -1)
    assert jnp.all(pop.gen[pad_slice] == -1)
    
    # For float arrays, check for NaN
    assert jnp.all(jnp.isnan(pop.pheno[pad_slice]))
    assert jnp.all(jnp.isnan(pop.fixEff[pad_slice]))
    assert jnp.all(jnp.isnan(pop.bv[pad_slice]))

def test_inbred_true():
    """Test that inbred=True creates individuals with identical homologous chromosomes.

    Both the genotypes and the founder IBD identifiers must match across every
    homologous copy of each chromosome.
    """
    # `quick_haplo` in this module has no `max_n_ind` parameter (passing it
    # raises TypeError); these checks do not depend on padding.
    pop, _ = quick_haplo(
        key=KEY,
        n_ind=N_IND,
        n_chr=N_CHR,
        n_loci_per_chr=N_LOCI,
        ploidy=PLOIDY,
        inbred=True
    )

    # Restrict to the first N_IND rows (all rows for an unpadded population).
    active_geno = pop.geno[:N_IND]
    active_ibd = pop.ibd[:N_IND]

    # For each individual, check that all homologous chromosomes are identical
    for p in range(1, PLOIDY):
        np_testing.assert_array_equal(active_geno[:, :, 0, :], active_geno[:, :, p, :])
        np_testing.assert_array_equal(active_ibd[:, :, 0, :], active_ibd[:, :, p, :])

def test_inbred_false():
    """Test that inbred=False creates individuals with unique founder alleles.

    Every allele slot must carry a distinct IBD identifier, and together they
    must form the contiguous range 0 .. N-1.
    """
    # `quick_haplo` in this module has no `max_n_ind` parameter (passing it
    # raises TypeError); these checks do not depend on padding.
    pop, _ = quick_haplo(
        key=KEY,
        n_ind=N_IND,
        n_chr=N_CHR,
        n_loci_per_chr=N_LOCI,
        ploidy=PLOIDY,
        inbred=False
    )

    # For an outbred population, every allele should have a unique IBD identity.
    active_ibd = pop.ibd[:N_IND]
    n_founder_alleles = N_IND * N_CHR * PLOIDY * N_LOCI

    # Check that all IBD values are unique
    unique_ibd_values = jnp.unique(active_ibd)
    assert len(unique_ibd_values) == n_founder_alleles

    # Check that the values are a contiguous range from 0 to N-1
    expected_values = jnp.arange(n_founder_alleles, dtype=jnp.uint32)
    np_testing.assert_array_equal(jnp.sort(unique_ibd_values), expected_values)

def test_genetic_map_properties():
    """Test the shape and values of the genetic map.

    The map must span 0 to CHR_LEN cM on every chromosome, be identical across
    chromosomes, and be non-decreasing along each chromosome.
    """
    # `quick_haplo` in this module has no `max_n_ind` parameter (passing it
    # raises TypeError); the map does not depend on padding.
    _, gmap = quick_haplo(
        key=KEY,
        n_ind=N_IND,
        n_chr=N_CHR,
        n_loci_per_chr=N_LOCI,
        chr_len_cm=CHR_LEN
    )

    assert gmap.shape == (N_CHR, N_LOCI)

    # Check that the first locus is at 0.0 cM and the last is at CHR_LEN cM
    assert gmap[0, 0] == 0.0
    np_testing.assert_allclose(gmap[0, -1], CHR_LEN)

    # Check that all rows are identical
    for c in range(1, N_CHR):
        np_testing.assert_array_equal(gmap[0, :], gmap[c, :])

    # Check that positions are non-decreasing
    assert jnp.all(jnp.diff(gmap, axis=1) >= 0)

def test_edge_case_no_padding():
    """Test population creation when n_ind equals max_n_ind.

    NOTE(review): relies on `quick_haplo(max_n_ind=...)` and
    `Population.active_mask`, which this notebook does not define; as written
    this raises TypeError.
    """
    pop, _ = quick_haplo(
        key=KEY,
        n_ind=MAX_N_IND,
        max_n_ind=MAX_N_IND,
        n_chr=N_CHR,
        n_loci_per_chr=N_LOCI
    )
    
    assert pop.nInd == MAX_N_IND
    assert pop.geno.shape[0] == MAX_N_IND
    assert jnp.all(pop.active_mask) # All individuals should be active
    assert not jnp.any(pop.id == -1) # No padded IDs

def test_edge_case_empty_population():
    """Test population creation when n_ind is 0.

    Specifies that an empty population still allocates MAX_N_IND padded rows.
    NOTE(review): `quick_haplo` here accepts no `max_n_ind`, so this raises
    TypeError as written.
    """
    pop, _ = quick_haplo(
        key=KEY,
        n_ind=0,
        max_n_ind=MAX_N_IND,
        n_chr=N_CHR,
        n_loci_per_chr=N_LOCI
    )
    
    assert pop.nInd == 0
    assert pop.geno.shape[0] == MAX_N_IND
    assert not jnp.any(pop.active_mask) # All individuals should be inactive
    assert jnp.all(pop.id == -1) # All IDs should be padded

def test_error_on_invalid_n_ind():
    """Test that a ValueError is raised if n_ind > max_n_ind.

    NOTE(review): against the `quick_haplo` in this notebook, the unknown
    `max_n_ind` keyword raises TypeError, which is not caught below and so
    propagates out of this test.
    """
    try:
        quick_haplo(
            key=KEY,
            n_ind=MAX_N_IND + 1,
            max_n_ind=MAX_N_IND,
            n_chr=N_CHR,
            n_loci_per_chr=N_LOCI
        )
        # If the above line does not raise an error, the test fails.
        raised_error = False
    except ValueError:
        # The expected error was caught.
        raised_error = True
        
    assert raised_error, "ValueError was not raised for n_ind > max_n_ind"

def test_data_types():
    """Test the dtypes of key population attributes.

    NOTE(review): passes `max_n_ind` (unknown to this notebook's `quick_haplo`,
    so TypeError as written) and checks `active_mask`, which `Population` does
    not define. The remaining dtype expectations agree with what `quick_haplo`
    constructs.
    """
    pop, _ = quick_haplo(
        key=KEY,
        n_ind=N_IND,
        max_n_ind=MAX_N_IND,
        n_chr=N_CHR,
        n_loci_per_chr=N_LOCI
    )
    
    assert pop.geno.dtype == jnp.uint8
    assert pop.ibd.dtype == jnp.uint32
    assert pop.active_mask.dtype == jnp.bool_
    assert pop.id.dtype == jnp.int32
    assert pop.iid.dtype == jnp.int32
    assert pop.mother.dtype == jnp.int32
    assert pop.father.dtype == jnp.int32
    assert pop.sex.dtype == jnp.int8
    assert pop.gen.dtype == jnp.int32
    # JAX defaults to float32, which is usually sufficient
    assert pop.pheno.dtype == jnp.float32 
    assert pop.fixEff.dtype == jnp.float32

# These tests do not run automatically: this cell is marked `#| eval: false`,
# and `nbdev_test` only executes notebook cells -- it never calls `test_*`
# functions on its own. Call the test functions manually to run them.



In [8]:
#| export

def combine_populations(pop1, pop2, new_id_start=None):
    """
    Combine two populations into one, handling ID management and all array sizes.

    Individuals of `pop1` come first and keep their `id`s. Individuals of `pop2`
    follow and receive fresh consecutive `id`s starting at `new_id_start`.
    Internal ids (`iid`) are reset to `0 .. nInd - 1` for the combined population.

    Args:
        pop1: The first Population.
        pop2: The second Population, appended after `pop1`.
        new_id_start: First `id` assigned to `pop2`'s individuals. Defaults to
            `max(pop1.id) + 1`, or 0 when `pop1` has no individuals.

    Returns:
        A new Population with `pop1.nInd + pop2.nInd` individuals. Each optional
        array (`gv`, `bv`, `dd`, `aa`, `ebv`, `gxe`) is concatenated when both
        inputs have it and is None otherwise. `misc` and `miscPop` are not
        carried over.

    NOTE(review): `mother`/`father` are copied unchanged, while `pop2`'s `id`s
    are renumbered and every `iid` is reset. Pedigree links that point into
    `pop2` may therefore no longer resolve -- confirm which identifier these
    arrays hold.
    """
    if new_id_start is None:
        # jnp.max has no identity for an empty array and would raise, so an
        # empty pop1 starts numbering at 0.
        new_id_start = jnp.max(pop1.id) + 1 if pop1.id.size > 0 else 0

    # Fresh ids for pop2 so they cannot collide with pop1's ids.
    pop2_new_ids = jnp.arange(new_id_start, new_id_start + pop2.nInd)

    def _concat_optional(a, b):
        """Concatenate two optional per-individual arrays; None unless both exist."""
        if a is None or b is None:
            return None
        return jnp.concatenate([a, b])

    return Population(
        geno=jnp.concatenate([pop1.geno, pop2.geno], axis=0),
        ibd=jnp.concatenate([pop1.ibd, pop2.ibd], axis=0),
        id=jnp.concatenate([pop1.id, pop2_new_ids]),
        iid=jnp.arange(pop1.nInd + pop2.nInd),  # Reset internal IDs
        mother=jnp.concatenate([pop1.mother, pop2.mother]),
        father=jnp.concatenate([pop1.father, pop2.father]),
        sex=jnp.concatenate([pop1.sex, pop2.sex]),
        gen=jnp.concatenate([pop1.gen, pop2.gen]),
        pheno=jnp.concatenate([pop1.pheno, pop2.pheno]),
        fixEff=jnp.concatenate([pop1.fixEff, pop2.fixEff]),
        gv=_concat_optional(pop1.gv, pop2.gv),
        bv=_concat_optional(pop1.bv, pop2.bv),
        dd=_concat_optional(pop1.dd, pop2.dd),
        aa=_concat_optional(pop1.aa, pop2.aa),
        ebv=_concat_optional(pop1.ebv, pop2.ebv),
        gxe=_concat_optional(pop1.gxe, pop2.gxe),
    )

def subset_population(pop: Population, indices: jnp.ndarray) -> Population:
    """
    Creates a new Population object containing only the individuals specified by indices.

    Every per-individual array is indexed along its first axis, including the
    optional `gv`, `bv`, `dd`, `aa`, `ebv` and `gxe`. Selected individuals
    therefore appear in the order given by `indices`. `iid` is renumbered to
    `0 .. len(indices) - 1`; `id`, `mother` and `father` are copied unchanged.

    Args:
        pop: The original Population object.
        indices: Integer indices of the individuals to select (a JAX array or
            anything `jnp.asarray` accepts). A boolean mask of length `pop.nInd`
            is also accepted and is converted to the indices of its True entries.

    Returns:
        A new Population object with the subset of individuals.
    """
    indices = jnp.asarray(indices)
    if indices.dtype == jnp.bool_:
        # A mask's length is nInd, not the number of selected individuals, so
        # convert it to explicit indices before sizing `iid`.
        indices = jnp.nonzero(indices)[0]
    n_new_ind = indices.shape[0]

    def _take(arr):
        """Select rows of an optional per-individual array (None stays None)."""
        return None if arr is None else arr[indices]

    return Population(
        geno=pop.geno[indices],
        ibd=pop.ibd[indices],
        id=pop.id[indices],
        iid=jnp.arange(n_new_ind, dtype=jnp.int32),  # New internal IDs are 0-indexed
        mother=pop.mother[indices],
        father=pop.father[indices],
        sex=pop.sex[indices],
        gen=pop.gen[indices],
        pheno=pop.pheno[indices],
        fixEff=pop.fixEff[indices],
        gv=_take(pop.gv),  # previously not passed, which silently dropped gv
        bv=_take(pop.bv),
        dd=_take(pop.dd),
        aa=_take(pop.aa),
        ebv=_take(pop.ebv),
        gxe=_take(pop.gxe),
        # misc and miscPop are static and apply to the whole population,
        # so they are carried over as is.
        misc=pop.misc,
        miscPop=pop.miscPop,
    )



In [9]:
#| export

def msprime_pop(
    key: jax.random.PRNGKey,
    n_ind: int,
    n_loci_per_chr: int,
    n_chr: int,
    ploidy: int = 2,
    effective_population_size: int = 10_000,
    mutation_rate: float = 2e-8,
    recombination_rate_per_chr: float = 2e-8,
    maf_threshold: float = 0.1,
    num_simulated_individuals: Optional[int] = None,
    base_chr_length: int = 1_000_000,
    enforce_founder_maf: bool = True
) -> Tuple[Population, jnp.ndarray]:
    """
    Creates a new founder population using msprime coalescent simulation.

    Generates genotypes and a genetic map based on population genetics principles.
    Updated with improved parameter validation and more reasonable defaults.

    Args:
        key: JAX random key.
        n_ind: Number of founder individuals to generate.
        n_loci_per_chr: Number of SNPs (loci) to select per chromosome.
        n_chr: Number of chromosomes.
        ploidy: The ploidy level of the individuals (default: 2).
        effective_population_size: The effective population size for simulation.
        mutation_rate: The mutation rate for the simulation.
        recombination_rate_per_chr: Recombination rate per chromosome.
        maf_threshold: Minimum allele frequency threshold for SNPs.
        num_simulated_individuals: Number of individuals to simulate initially.
            If None, it is set to max(n_ind * multiplier, 1000), capped at 10,000.
            The multiplier is 5 when enforce_founder_maf is True and 2 otherwise.
        base_chr_length: Length of each chromosome in base pairs.
        enforce_founder_maf: If True, ensures MAF threshold is met in the final
            founder population. If False, applies MAF filter to the full simulated
            population (original behavior).

    Returns:
        A tuple containing:
        - A new Population object with random founder individuals.
        - A JAX array representing the genetic map, with shape
          `(n_chr, n_loci_per_chr)`.
        A chromosome may yield fewer than `n_loci_per_chr` eligible SNPs. The
        unfilled loci then have genotype 0 and a NaN map position, and a
        UserWarning is issued.

    Raises:
        ValueError: If parameters are invalid or likely to cause memory issues.
        RuntimeError: If the msprime simulation itself fails.
    """
    import warnings

    # --- Parameter Validation ---
    if effective_population_size > 100_000:
        raise ValueError(
            f"Effective population size {effective_population_size} is too large and may cause "
            f"memory issues. Consider using values <= 50,000. For very large populations, "
            f"consider using quick_haplo() instead."
        )

    if effective_population_size < 10:
        raise ValueError(
            f"Effective population size {effective_population_size} is too small. "
            f"Use values >= 10 for realistic simulations."
        )

    # Set num_simulated_individuals dynamically if not provided
    if num_simulated_individuals is None:
        # If enforcing founder MAF, we need more individuals to ensure diversity
        multiplier = 5 if enforce_founder_maf else 2
        num_simulated_individuals = min(max(n_ind * multiplier, 1000), 10_000)

    if n_ind > num_simulated_individuals:
        raise ValueError(
            f"Number of founders requested ({n_ind}) cannot exceed the base simulated "
            f"population size ({num_simulated_individuals})."
        )

    # Additional warning for founder MAF enforcement
    if enforce_founder_maf and n_ind < 20:
        warnings.warn(
            f"Small founder population size ({n_ind}) with enforce_founder_maf=True "
            f"may result in few usable markers. Consider increasing n_ind or setting "
            f"enforce_founder_maf=False.",
            UserWarning
        )

    # --- Derive Seeds ---
    key, seed_key, sex_key, numpy_seed_key = jax.random.split(key, 4)
    # msprime only accepts seeds in [1, 2**32). The key sum is normally already
    # in range; folding it (and mapping 0 to 1) guards against the rare key
    # that would otherwise be rejected.
    random_seed = int(jnp.sum(seed_key)) % 2**32 or 1
    numpy_seed = int(jnp.sum(numpy_seed_key))
    rng = default_rng(numpy_seed)

    # --- Chromosome Lengths ---
    chromosome_lengths = [base_chr_length] * n_chr

    # --- Run msprime Simulation ---
    num_haplotypes = num_simulated_individuals * ploidy

    # Create the recombination map for msprime: one interval per chromosome.
    # NOTE(review): chromosomes are laid end to end with no 0.5-recombination
    # breakpoint between them, so adjacent chromosomes are simulated as linked
    # sequence rather than as independently assorting ones -- confirm intent.
    rate_map_positions = [0] + list(np.cumsum(chromosome_lengths))
    rate_map_rates = [recombination_rate_per_chr] * len(chromosome_lengths)
    rate_map = msprime.RateMap(position=rate_map_positions, rate=rate_map_rates)

    try:
        ts = msprime.sim_ancestry(
            samples=num_haplotypes,
            population_size=effective_population_size,
            recombination_rate=rate_map,
            random_seed=random_seed
        )
        mts = msprime.sim_mutations(ts, rate=mutation_rate, random_seed=random_seed)
    except Exception as e:
        if "memory" in str(e).lower() or "malloc" in str(e).lower():
            raise RuntimeError(
                f"Memory allocation failed during msprime simulation. This is likely due to "
                f"too large parameter combination. Try reducing effective_population_size "
                f"(current: {effective_population_size}), num_simulated_individuals "
                f"(current: {num_simulated_individuals}), or genome size. "
                f"Original error: {str(e)}"
            ) from e
        else:
            raise RuntimeError(f"msprime simulation failed: {str(e)}") from e

    # --- Sample Founders FIRST ---
    true_num_individuals = mts.num_samples // ploidy
    founder_indices = np.sort(rng.choice(true_num_individuals, n_ind, replace=False))

    # --- Data Extraction with Proper MAF Filtering ---
    all_variants = list(mts.variants())
    genetic_map = np.full((n_chr, n_loci_per_chr), np.nan)
    # Zero-initialised uint8, so loci that cannot be filled are a defined 0
    # instead of NaN later cast to an integer type (an undefined conversion).
    founder_haplotype_matrix = np.zeros((n_ind, n_chr, ploidy, n_loci_per_chr), dtype=np.uint8)
    population_type = "founder" if enforce_founder_maf else "full simulated"

    for i in range(n_chr):
        chr_start, chr_end, recomb_rate = rate_map.left[i], rate_map.right[i], rate_map.rate[i]

        # Get all biallelic SNPs in this chromosome
        chromosome_snps = [
            var for var in all_variants
            if chr_start <= var.site.position < chr_end and len(var.alleles) == 2
        ]

        if enforce_founder_maf:
            # Apply MAF filter to the FOUNDER population only
            eligible_snps = []
            for var in chromosome_snps:
                # Extract genotypes for founders only
                all_genotypes = var.genotypes.reshape(true_num_individuals, ploidy)
                founder_freq = np.mean(all_genotypes[founder_indices])
                founder_maf = min(founder_freq, 1 - founder_freq)

                if founder_maf > maf_threshold:
                    eligible_snps.append(var)
        else:
            # Apply MAF filter to the full simulated population (original behavior)
            eligible_snps = [
                var for var in chromosome_snps
                if min(np.mean(var.genotypes), 1 - np.mean(var.genotypes)) > maf_threshold
            ]

        num_found = len(eligible_snps)
        num_to_select = min(num_found, n_loci_per_chr)

        if num_to_select > 0:
            selected_indices = rng.choice(num_found, num_to_select, replace=False)
            # sorted() is stable, like list.sort: order loci by physical position.
            selected_snps = sorted(
                (eligible_snps[j] for j in selected_indices),
                key=lambda v: v.site.position,
            )

            for snp_idx, snp in enumerate(selected_snps):
                # Extract genotypes for ALL individuals, then subset to founders
                all_genotypes = snp.genotypes.reshape(true_num_individuals, ploidy)
                founder_haplotype_matrix[:, i, :, snp_idx] = all_genotypes[founder_indices]

            # Physical offset within the chromosome (bp) * rate (per bp) * 100 -> cM.
            positions_cm = [(v.site.position - chr_start) * recomb_rate * 100 for v in selected_snps]
            genetic_map[i, :num_to_select] = positions_cm

        if num_found == 0:
            warnings.warn(
                f"No variants found for chromosome {i} with MAF > {maf_threshold} in the "
                f"{population_type} population. Consider lowering maf_threshold or "
                f"increasing mutation_rate/effective_population_size.",
                UserWarning
            )
        elif num_found < n_loci_per_chr:
            warnings.warn(
                f"Only {num_found} of {n_loci_per_chr} requested variants found for "
                f"chromosome {i} with MAF > {maf_threshold} in the {population_type} "
                f"population; the remaining loci have genotype 0 and a NaN genetic-map "
                f"position.",
                UserWarning
            )

    # Convert to JAX arrays
    geno = jnp.array(founder_haplotype_matrix, dtype=jnp.uint8)
    gen_map_jax = jnp.array(genetic_map)

    # --- Create IBD tracking for msprime founders ---
    # For msprime-generated founders, create unique IBD identifiers
    # This is a simplified approach - a more sophisticated version would
    # track actual coalescent relationships from the tree sequence
    n_founder_alleles = n_ind * n_chr * ploidy * n_loci_per_chr
    founder_ids = jnp.arange(n_founder_alleles, dtype=jnp.uint32)
    ibd = founder_ids.reshape(n_ind, n_chr, ploidy, n_loci_per_chr)

    # --- Create Pedigree and IDs using JAX arrays ---
    ids = jnp.arange(n_ind)
    sex_array = jax.random.choice(sex_key, jnp.array([0, 1], dtype=jnp.int8), (n_ind,))

    pop = Population(
        geno=geno,
        ibd=ibd,  # Include IBD tracking
        id=ids,
        iid=ids,
        mother=jnp.full(n_ind, -1, dtype=jnp.int32),
        father=jnp.full(n_ind, -1, dtype=jnp.int32),
        sex=sex_array,
        gen=jnp.zeros(n_ind, dtype=jnp.int32),
        pheno=jnp.zeros((n_ind, 0)),
        fixEff=jnp.zeros(n_ind, dtype=jnp.float32),
        bv=jnp.zeros((n_ind, 0)),
        miscPop={
            'msprime_params': {
                'effective_population_size': effective_population_size,
                'mutation_rate': mutation_rate,
                'recombination_rate_per_chr': recombination_rate_per_chr,
                'maf_threshold': maf_threshold,
                'num_simulated_individuals': num_simulated_individuals,
                'base_chr_length': base_chr_length,
                'enforce_founder_maf': enforce_founder_maf
            }
        }
    )

    return pop, gen_map_jax


In [10]:
#| test


# 1. ARRANGE: shared parameters for the msprime_pop tests below
key = jax.random.PRNGKey(123)
n_ind, n_loci_per_chr, n_chr, ploidy = 10, 5, 2, 2

# 2. ACT: build one population (and its genetic map) for the tests to inspect
test_pop, test_gen_map = msprime_pop(
    key,
    n_ind=n_ind,
    n_loci_per_chr=n_loci_per_chr,
    n_chr=n_chr,
    ploidy=ploidy,
)

# 3. ASSERT: Perform tests

def test_msprime_pop_output_type():
    "Test that msprime_pop returns a Population object and a JAX array."
    assert isinstance(test_pop, Population)
    assert isinstance(test_gen_map, jnp.ndarray)

# nbdev executes cells, not test_* functions, so invoke the test explicitly.
test_msprime_pop_output_type()

def test_msprime_pop_dimensions():
    "Test if the created population has the correct dimensions."
    test_eq(test_pop.nInd, n_ind)
    test_eq(test_pop.nChr, n_chr)
    test_eq(test_pop.geno.shape, (n_ind, n_chr, ploidy, n_loci_per_chr))

# nbdev executes cells, not test_* functions, so invoke the test explicitly.
test_msprime_pop_dimensions()

def test_msprime_pop_genetic_map_dimensions():
    "Test that the returned genetic map has the correct dimensions."
    test_eq(test_gen_map.shape, (n_chr, n_loci_per_chr))

# nbdev executes cells, not test_* functions, so invoke the test explicitly.
test_msprime_pop_genetic_map_dimensions()

def test_msprime_pop_not_all_zeros():
    "Test that the genotypes are not all zeros (some variation exists)."
    assert jnp.any(test_pop.geno != 0)

# nbdev executes cells, not test_* functions, so invoke the test explicitly.
test_msprime_pop_not_all_zeros()

def test_msprime_pop_different_key():
    "Test that a different key produces a different population."
    key2 = jax.random.PRNGKey(456)
    test_pop2, _ = msprime_pop(
        key2,
        n_ind,
        n_loci_per_chr,
        n_chr,
        ploidy
    )
    # Compare the full genotype arrays: two distinct populations can have the
    # same allele total, so the previous sum-based check could fail spuriously.
    assert not jnp.array_equal(test_pop.geno, test_pop2.geno)

# nbdev executes cells, not test_* functions, so invoke the test explicitly.
test_msprime_pop_different_key()

def test_msprime_pop_reproducibility():
    "Test that the same key produces the exact same population and map."
    key1 = jax.random.PRNGKey(789)
    pop1, map1 = msprime_pop(key1, n_ind, n_loci_per_chr, n_chr, ploidy)
    pop2, map2 = msprime_pop(key1, n_ind, n_loci_per_chr, n_chr, ploidy)
    assert jnp.array_equal(pop1.geno, pop2.geno)
    # equal_nan: loci that could not be filled carry NaN map positions.
    assert jnp.allclose(map1, map2, equal_nan=True)

# nbdev executes cells, not test_* functions, so invoke the test explicitly.
test_msprime_pop_reproducibility()



In [11]:
#| hide
# Write all `#| export` cells of this notebook into the library module.
from nbdev import nbdev_export
nbdev_export()