In [1]:
from flax.struct import dataclass as flax_dataclass, field
import jax.numpy as jnp
from typing import Optional, Dict, Any, Tuple

@flax_dataclass(frozen=True)
class Population:
    """
    Immutable, fixed-capacity record describing a simulated population.

    Every array field is indexed along its first axis by individual slot, and
    the number of slots (`max_pop_size`) is taken from `geno.shape[0]`. Slots
    that do not hold a real individual are flagged through `is_active`, so the
    array shapes stay constant while the number of real individuals varies.

    Array fields are registered as Pytree leaves by the Flax dataclass
    decorator; `misc` and `miscPop` are declared with `pytree_node=False` and
    therefore travel as static metadata rather than as leaves.
    """
    # --- Genotype data ---
    # Per-allele layout: (max_pop_size, nChr, ploidy, nLoci)
    geno: jnp.ndarray
    ibd: jnp.ndarray

    # --- Identity and pedigree, one entry per slot: (max_pop_size,) ---
    id: jnp.ndarray
    iid: jnp.ndarray
    mother: jnp.ndarray
    father: jnp.ndarray
    sex: jnp.ndarray
    gen: jnp.ndarray

    # --- Padding mask ---
    # True for slots holding a real individual, False for padding slots.
    is_active: jnp.ndarray

    # --- Trait-level values, leading axis max_pop_size ---
    pheno: jnp.ndarray
    fixEff: jnp.ndarray

    gv: Optional[jnp.ndarray] = None      # genetic value (BV + intercept)
    bv: Optional[jnp.ndarray] = None      # breeding value (additive)
    dd: Optional[jnp.ndarray] = None      # dominance deviations
    aa: Optional[jnp.ndarray] = None      # additive-by-additive epistatic deviations
    ebv: Optional[jnp.ndarray] = None
    gxe: Optional[jnp.ndarray] = None

    # --- Static metadata (excluded from the Pytree leaves) ---
    # NOTE(review): static fields usually need to be hashable when the object
    # is passed through `jax.jit`; a populated dict may not be - confirm.
    misc: Optional[Dict[str, Any]] = field(default=None, pytree_node=False)
    miscPop: Optional[Dict[str, Any]] = field(default=None, pytree_node=False)

    @property
    def nInd(self) -> int:
        """
        Count of ACTIVE individuals, i.e. the number of True entries in
        `is_active`.

        The value is the result of `jnp.sum`, so it is a 0-d JAX array rather
        than a Python `int`.
        """
        active_count = jnp.sum(self.is_active)
        return active_count

    @property
    def max_pop_size(self) -> int:
        """Total number of slots, read from the leading axis of `geno`."""
        geno_shape = self.geno.shape
        return geno_shape[0]

    @property
    def nChr(self) -> int:
        """Number of chromosomes, read from the second axis of `geno`."""
        geno_shape = self.geno.shape
        return geno_shape[1]

    @property
    def nTraits(self) -> int:
        """
        Number of traits, read from the second axis of `bv`.

        Yields 0 when `bv` is absent or has no second axis.
        """
        breeding_values = self.bv
        has_trait_axis = breeding_values is not None and breeding_values.ndim > 1
        if has_trait_axis:
            return breeding_values.shape[1]
        return 0

    @property
    def dosage(self) -> jnp.ndarray:
        """
        Per-individual allele dosage with one row per slot.

        Alleles are summed over the ploidy axis and the chromosome and locus
        axes are flattened into a single axis. Rows that belong to inactive
        slots are filled with NaN.
        """
        # Collapse the ploidy axis (axis 2) to obtain per-locus dosage.
        summed_over_ploidy = jnp.sum(self.geno, axis=2)

        # Flatten chromosome x locus into one axis per slot.
        flat_dosage = jnp.reshape(summed_over_ploidy, (self.max_pop_size, -1))

        # Broadcast the slot mask across the locus axis; padding rows become NaN.
        row_is_active = jnp.expand_dims(self.is_active, axis=1)
        return jnp.where(row_is_active, flat_dosage, jnp.nan)

In [2]:
from dataclasses import field
from typing import List, Optional, Dict, Callable, Any

import jax
import jax.numpy as jnp

from chewc.sp import SimParam
from typing import Tuple
from numpy.random import default_rng
import msprime
import tskit
import numpy as np
import random
from collections import defaultdict
from flax.struct import dataclass as flax_dataclass, field

#testing
import jax
import jax.numpy as jnp
from fastcore.test import test_eq, test_ne

def quick_haplo(
    key: jax.random.PRNGKey,
    n_ind: int,
    n_chr: int,
    n_loci_per_chr: int,
    max_pop_size: int,
    ploidy: int = 2,
    inbred: bool = False,
    chr_len_cm: float = 100.0
) -> Tuple[Population, jnp.ndarray]:
    """
    Creates a new, padded founder population with random haplotypes.

    Args:
        key: A JAX random key.
        n_ind: Number of founder individuals to create (0 <= n_ind <= max_pop_size).
        n_chr: Number of chromosomes.
        n_loci_per_chr: Number of loci on each chromosome.
        max_pop_size: The total size of the arrays to pre-allocate.
        ploidy: Number of haplotype copies stored per individual and chromosome.
        inbred: If True, a single random haplotype is drawn per individual and
            chromosome and copied `ploidy` times (the IBD labels are copied
            the same way), so every founder is fully homozygous.
        chr_len_cm: Length assigned to every chromosome in the genetic map;
            locus positions are evenly spaced from 0 to this value.

    Returns:
        A tuple of (Population, genetic_map). The Population object contains
        `n_ind` active individuals in arrays of size `max_pop_size`; padding
        slots hold genotype 0, id -1, sex -1 and the maximum uint32 value as
        IBD label. `genetic_map` has shape (n_chr, n_loci_per_chr).

    Raises:
        ValueError: If `n_ind` is negative or greater than `max_pop_size`.
    """
    if n_ind < 0:
        raise ValueError("n_ind cannot be negative.")
    if n_ind > max_pop_size:
        raise ValueError("n_ind cannot be greater than max_pop_size.")

    # Keep the three-way split so the genotype and sex draws stay the same for
    # a given key; the first subkey is intentionally unused.
    _, geno_key, sex_key = jax.random.split(key, 3)

    # 1. Generate data for the initial n_ind founders
    if inbred:
        # One random haplotype per individual/chromosome, replicated across ploidy.
        base_haplotypes = jax.random.randint(
            geno_key, (n_ind, n_chr, 1, n_loci_per_chr), 0, 2, dtype=jnp.uint8
        )
        founder_geno = jnp.tile(base_haplotypes, (1, 1, ploidy, 1))
    else:
        founder_geno = jax.random.randint(
            geno_key, (n_ind, n_chr, ploidy, n_loci_per_chr), 0, 2, dtype=jnp.uint8
        )

    founder_ids = jnp.arange(n_ind)
    founder_sex = jax.random.choice(sex_key, jnp.array([0, 1], dtype=jnp.int8), (n_ind,))

    # Every founder allele receives its own IBD label, numbered in storage order.
    n_founder_alleles = n_ind * n_chr * ploidy * n_loci_per_chr
    founder_ibd_flat = jnp.arange(n_founder_alleles, dtype=jnp.uint32)
    founder_ibd = founder_ibd_flat.reshape(n_ind, n_chr, ploidy, n_loci_per_chr)
    if inbred:
        # Replicated haplotypes share the labels of the first haplotype.
        founder_ibd = jnp.tile(founder_ibd[:, :, 0:1, :], (1, 1, ploidy, 1))

    # 2. Pad all arrays to max_pop_size
    n_pad = max_pop_size - n_ind
    pad_rows_4d = ((0, n_pad), (0, 0), (0, 0), (0, 0))

    # IBD labels are unsigned, so -1 cannot be stored as a padding marker.
    # The largest uint32 value is used as an explicit, typed sentinel instead.
    ibd_pad_value = jnp.array(jnp.iinfo(jnp.uint32).max, dtype=jnp.uint32)

    padded_geno = jnp.pad(founder_geno, pad_rows_4d, constant_values=0)
    padded_ibd = jnp.pad(founder_ibd, pad_rows_4d, constant_values=ibd_pad_value)
    padded_id = jnp.pad(founder_ids, (0, n_pad), constant_values=-1)
    padded_sex = jnp.pad(founder_sex, (0, n_pad), constant_values=-1)

    # 3. Create the is_active mask: the first n_ind slots are real individuals
    is_active = jnp.arange(max_pop_size) < n_ind

    # --- Generate a uniform genetic map (same positions on every chromosome) ---
    loci_pos = jnp.linspace(0., chr_len_cm, n_loci_per_chr)
    genetic_map = jnp.tile(loci_pos, (n_chr, 1))

    population = Population(
        geno=padded_geno,
        ibd=padded_ibd,
        id=padded_id,
        iid=jnp.arange(max_pop_size),  # iid is always contiguous for indexing
        mother=jnp.full(max_pop_size, -1, dtype=jnp.int32),  # founders have no recorded parents
        father=jnp.full(max_pop_size, -1, dtype=jnp.int32),
        sex=padded_sex,
        gen=jnp.zeros(max_pop_size, dtype=jnp.int32),  # all slots (padding included) start at gen 0
        is_active=is_active,
        pheno=jnp.full((max_pop_size, 0), jnp.nan),  # zero trait columns until traits are added
        fixEff=jnp.zeros(max_pop_size, dtype=jnp.float32),
        bv=jnp.full((max_pop_size, 0), jnp.nan),
    )

    return population, genetic_map

In [4]:
import jax
import jax.numpy as jnp
from fastcore.test import test_eq

# Smoke test for `quick_haplo`: build a small padded founder population and
# check shapes, the active mask, padding values and founder identifiers.
key = jax.random.PRNGKey(42)
n_founders = 10
max_pop = 25
n_chr = 2
n_loci = 5
n_padded = max_pop - n_founders

# --- Function Call ---
founder_pop, gen_map = quick_haplo(
    key=key,
    n_ind=n_founders,
    n_chr=n_chr,
    n_loci_per_chr=n_loci,
    max_pop_size=max_pop,
)

# --- Assertions ---
print("--- Verifying Population Shapes ---")
test_eq(founder_pop.geno.shape, (max_pop, n_chr, 2, n_loci))
test_eq(founder_pop.ibd.shape, (max_pop, n_chr, 2, n_loci))
test_eq(founder_pop.id.shape, (max_pop,))
test_eq(founder_pop.is_active.shape, (max_pop,))
# The genetic map holds one row of locus positions per chromosome
test_eq(gen_map.shape, (n_chr, n_loci))

print("\n--- Verifying Active Population Size ---")
test_eq(founder_pop.nInd, n_founders)
test_eq(founder_pop.max_pop_size, max_pop)

print("\n--- Verifying `is_active` Mask ---")
# First n_founders should be True
test_eq(jnp.sum(founder_pop.is_active[:n_founders]), n_founders)
# The rest should be False
test_eq(jnp.sum(founder_pop.is_active[n_founders:]), 0)

print("\n--- Verifying Padding Values ---")
# Check that every padded ID is -1, not just the first one
test_eq(founder_pop.id[n_founders:], jnp.full(n_padded, -1))
# Check that every padded sex code is -1
test_eq(founder_pop.sex[n_founders:], jnp.full(n_padded, -1))
# Check that padded genotypes are 0
test_eq(jnp.sum(founder_pop.geno[n_founders:]), 0)

print("\n--- Verifying Founder Data Integrity ---")
# Check that the first n_founders have correct, non-padded IDs
expected_ids = jnp.arange(n_founders)
test_eq(founder_pop.id[:n_founders], expected_ids)
# Founders have no recorded parents
test_eq(founder_pop.mother[:n_founders], jnp.full(n_founders, -1))
test_eq(founder_pop.father[:n_founders], jnp.full(n_founders, -1))

print("\n✅ All tests passed!")


--- Verifying Population Shapes ---

--- Verifying Active Population Size ---

--- Verifying `is_active` Mask ---

--- Verifying Padding Values ---

--- Verifying Founder Data Integrity ---

✅ All tests passed!
