In [1]:
#| default_exp popgen

In [2]:
#| export
import jax
import jax.numpy as jnp
from typing import Tuple
from chewc.config import StaticConfig, Population

In [3]:
#| export

def create_founders(
    key: jax.random.PRNGKey,
    s_config: StaticConfig,
    n_founders: int,
    chr_length_cM: float = 100.0
) -> Tuple[Population, jnp.ndarray, jnp.ndarray]:
    """
    Build a random founder population together with a random genetic map.

    Loci are scattered uniformly at random along each chromosome, and every
    founder allele is drawn uniformly from {0, 1}.

    Args:
        key: JAX PRNG key. It is split once: the first subkey drives the
            genotypes, the second drives the locus positions.
        s_config: Static simulation configuration; only `n_chr`,
            `n_loci_per_chr` and `ploidy` are read here.
        n_founders: Number of founder individuals to generate.
        chr_length_cM: Length of every chromosome in centiMorgans.

    Returns:
        A tuple `(population, genetic_map, locus_positions)`:
        - population: `Population` whose `geno` is an int8 array of shape
          (n_founders, n_chr, ploidy, n_loci_per_chr) holding 0/1 alleles and
          whose `meta` is an int32 array of shape (n_founders, 4) with columns
          [id, parent 1 id, parent 2 id, birth generation]. Founder ids run
          0..n_founders-1, both parent ids are -1 and the birth generation is 0.
        - genetic_map: (n_chr, n_loci_per_chr) array of gaps in Morgans between
          consecutive loci; the first entry of each row is the distance from
          the chromosome start (position 0) to the first locus.
        - locus_positions: (n_chr, n_loci_per_chr) array of per-chromosome
          sorted locus positions in Morgans, between 0 and chr_length_cM / 100.
    """
    geno_key, map_key = jax.random.split(key)
    n_chr = s_config.n_chr
    n_loci = s_config.n_loci_per_chr
    ploidy = s_config.ploidy

    # ---- Genetic map ----
    # 1 cM == 0.01 M, so express the chromosome length in Morgans first.
    chr_length_M = chr_length_cM / 100.0
    unit_draws = jax.random.uniform(map_key, shape=(n_chr, n_loci))
    # Scatter loci over each chromosome, then order them along the chromosome.
    locus_positions = jnp.sort(unit_draws * chr_length_M, axis=-1)
    # Gap between neighbouring loci; prepending 0 makes the first gap run
    # from the chromosome start to the first locus.
    genetic_map = jnp.diff(locus_positions, axis=-1, prepend=0.)

    # ---- Founder genotypes: independent 0/1 alleles ----
    geno_shape = (n_founders, n_chr, ploidy, n_loci)
    geno = jax.random.randint(geno_key, geno_shape, 0, 2, dtype=jnp.int8)

    # ---- Pedigree metadata: one row per founder ----
    founder_ids = jnp.arange(n_founders, dtype=jnp.int32).reshape(n_founders, 1)
    unknown_parents = jnp.full((n_founders, 2), -1, dtype=jnp.int32)
    generation_zero = jnp.zeros((n_founders, 1), dtype=jnp.int32)
    meta = jnp.concatenate([founder_ids, unknown_parents, generation_zero], axis=1)

    return Population(geno=geno, meta=meta), genetic_map, locus_positions


In [4]:
#| hide
# Regenerate the library module from this notebook's `#| export` cells
# (target module set by the `#| default_exp` directive); `#| hide` keeps
# this housekeeping cell out of the rendered documentation.
import nbdev; nbdev.nbdev_export()