# meiosis

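> Meiosis simulation: sampling crossover positions, forming recombinant gametes (genotype and IBD tracks), and assembling diploid progeny genotypes.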

In [None]:
#| default_exp meiosis

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import jax
import jax.numpy as jnp
from chewc.sp import SimParam
from chewc.popgen import quick_haplo
from chewc.population import Population
from chewc.trait import add_trait_a, TraitCollection
from chewc.phenotype import set_pheno

# --- Imports for Testing ---
from fastcore.test import test_eq, test_close, test_fail
import jax
import jax.numpy as jnp
import numpy as np

# --- Functions from other modules needed for testing ---
from chewc.population import Population
from chewc.sp import SimParam

from jax import lax, vmap
from functools import partial
import matplotlib.pyplot as plt

In [None]:
#| export

@partial(jax.jit, static_argnames=("max_crossovers",))
def _create_gamete(key: jax.random.PRNGKey,
                   parental_geno_haplotypes: jnp.ndarray,
                   parental_ibd_haplotypes: jnp.ndarray,
                   gen_map: jnp.ndarray,
                   v_interference: float,
                   max_crossovers: int = 20) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Build one recombinant gamete for a single chromosome, carrying the genotype
    and IBD tracks through the same crossover pattern.

    Crossover positions come from `_sample_chiasmata` for a map of length
    `gen_map[-1]`. Each position is converted to the index of the first locus
    whose map position is >= it. The gamete starts on a randomly chosen
    parental haplotype and switches to the other one at every crossover index.

    Args:
        key: A JAX random key.
        parental_geno_haplotypes: (2, nLoci) genotype haplotypes of the parent.
        parental_ibd_haplotypes: (2, nLoci) IBD haplotypes of the parent.
        gen_map: (nLoci,) locus positions; the last entry is used as the map length.
        v_interference: Interference parameter forwarded to `_sample_chiasmata`.
        max_crossovers: Static cap on the number of crossovers sampled.

    Returns:
        `(geno_gamete, ibd_gamete)`, each of shape (nLoci,).
    """
    # Three-way split so the sub-keys match the established random stream;
    # the first sub-key is deliberately unused.
    _, chiasma_key, hap_key = jax.random.split(key, 3)

    xo_positions = _sample_chiasmata(chiasma_key, gen_map[-1],
                                     v_interference, max_crossovers)
    # Locus index at which each crossover takes effect.
    # NOTE(review): the second searchsorted below requires these indices to be
    # ascending, i.e. `_sample_chiasmata` returns sorted positions with any NaN
    # padding at the end — TODO confirm.
    xo_loci = jnp.searchsorted(gen_map, xo_positions)

    first_hap = jax.random.choice(hap_key, jnp.array([0, 1], dtype=jnp.uint8))

    # For each locus, count the crossovers whose index is <= that locus
    # (side='right'); an odd count means the gamete switched away from `first_hap`.
    locus_idx = jnp.arange(gen_map.shape[0])
    n_switches = jnp.searchsorted(xo_loci, locus_idx, side='right')
    from_second = (first_hap + n_switches) % 2 == 1

    def _pick(haps):
        # One mask for both tracks keeps genotype and IBD consistent.
        return jnp.where(from_second, haps[1], haps[0])

    return _pick(parental_geno_haplotypes), _pick(parental_ibd_haplotypes)



#| test


# --- Test Setup Helper ---

def _setup_meiosis_test():
    """Build the shared fixtures for the meiosis tests.

    Returns a dict with keys "key", "mother_geno", "father_geno", "sp" and
    "gen_map". The mother's genotype is all zeros and the father's is all
    ones, with the shapes of founder individuals 0 and 1. This makes the
    parental origin of every progeny allele directly checkable.
    """
    # NOTE(review): this helper appears to sit in the same notebook cell as the
    # `#| export` directive above `_create_gamete`, so nbdev would export it into
    # the library module; consider moving it into its own non-exported cell.
    rng = jax.random.PRNGKey(123)
    n_individuals = 2
    n_chromosomes = 5
    # Many loci per chromosome, so sampled crossovers fall between distinct loci.
    loci_per_chromosome = 1000
    ploidy_level = 2

    founders, genetic_map = quick_haplo(rng, n_individuals, n_chromosomes,
                                        loci_per_chromosome, ploidy=ploidy_level)
    sim_params = SimParam.from_founder_pop(founders, genetic_map, v_interference=2.6)

    return dict(
        key=rng,
        mother_geno=jnp.zeros_like(founders.geno[0]),
        father_geno=jnp.ones_like(founders.geno[1]),
        sp=sim_params,
        gen_map=genetic_map,
    )



#| export
@partial(jax.jit, static_argnames=("n_chr",))
def meiosis_for_one_cross(key: jax.random.PRNGKey,
                          mother_geno: jnp.ndarray,
                          father_geno: jnp.ndarray,
                          n_chr: int,
                          gen_map: jnp.ndarray,
                          v_interference: float
                         ) -> jnp.ndarray:
    """
    Creates a single diploid progeny's genotype from two parents' genotypes
    by simulating meiosis for all chromosomes in parallel.

    --- JAX Implementation Notes ---
    1.  **`vmap` over chromosomes**: `_create_gamete` builds one gamete for a
        single chromosome. It takes `(key, geno_haps, ibd_haps, gen_map, v)` and
        returns `(geno_gamete, ibd_gamete)`. It is vectorised with
        `in_axes=(0, 0, 0, 0, None)`. This maps over the chromosome axis of
        the keys, the genotype, the IBD input and `gen_map`, while
        broadcasting the single `v_interference` value.

    2.  **Genotype only**: This function produces genotypes only. The parent's
        genotype is therefore passed in the IBD slot as a placeholder, and the
        IBD output is discarded. Under `jit` the unused IBD computation should
        be eliminated as dead code.

    3.  **`jit`**: Both vmapped gamete calls and the final `jnp.stack` are
        traced into one compiled XLA computation.

    4.  **Static `n_chr`**: `jax.random.split(key, n_chr)` needs a concrete
        count to fix the output shape, so `n_chr` is a static argument. JAX
        recompiles the function whenever `n_chr` changes.

    Args:
        key: A JAX random key.
        mother_geno: The mother's genotype. Shape: `(nChr, ploidy, nLoci)`;
            only haplotypes 0 and 1 are used.
        father_geno: The father's genotype. Shape: `(nChr, ploidy, nLoci)`;
            only haplotypes 0 and 1 are used.
        n_chr: The number of chromosomes. **Must be a static integer.**
        gen_map: Per-chromosome locus positions, mapped over its first axis
            (one row per chromosome).
        v_interference: The interference parameter for chiasma sampling.

    Returns:
        The progeny's diploid genotype, shape `(nChr, 2, nLoci)`. The maternal
        gamete is at index 0 of axis 1 and the paternal gamete at index 1.
    """
    key_mother, key_father = jax.random.split(key)

    # Map over keys, genotype, IBD placeholder and gen_map; broadcast v.
    vmapped_gamete_creator = vmap(_create_gamete, in_axes=(0, 0, 0, 0, None))

    def _parent_gametes(parent_key, parent_geno):
        # Genotype doubles as the IBD input; only the genotype gamete is kept.
        geno_gametes, _ = vmapped_gamete_creator(
            jax.random.split(parent_key, n_chr),
            parent_geno,
            parent_geno,
            gen_map,
            v_interference,
        )
        return geno_gametes

    mother_gametes = _parent_gametes(key_mother, mother_geno)
    father_gametes = _parent_gametes(key_father, father_geno)

    # (nChr, nLoci) x 2 -> (nChr, 2, nLoci)
    return jnp.stack([mother_gametes, father_gametes], axis=1)

In [None]:
#| test

def test_sample_chiasmata_shape_and_bounds():
    "Tests the output shape and that positions are within the chromosome."
    fixtures = _setup_meiosis_test()
    n_max = 20
    # Map length of the first chromosome.
    chrom_length = fixtures["gen_map"][0, -1]

    positions = _sample_chiasmata(fixtures["key"], chrom_length,
                                  fixtures["sp"].recomb_params[0],
                                  max_crossovers=n_max)

    # Fixed-size output, padded up to `max_crossovers`.
    test_eq(positions.shape, (n_max,))

    # Ignore NaN padding; every real crossover lies strictly inside the chromosome.
    real_positions = positions[~jnp.isnan(positions)]
    assert jnp.all(real_positions > 0)
    assert jnp.all(real_positions < chrom_length)

def test_sample_chiasmata_interference():
    "Tests the statistical properties of the interference parameter 'v'."
    # ARRANGE
    key = jax.random.PRNGKey(42)
    n_reps = 1000
    map_length = 1.0  # 1 Morgan
    keys = jax.random.split(key, n_reps)
    sample_many = jax.vmap(_sample_chiasmata, in_axes=(0, None, None))

    def _crossover_counts(v):
        # Number of non-NaN (real) crossovers in each replicate.
        positions = sample_many(keys, map_length, v)
        return jnp.sum(~jnp.isnan(positions), axis=1)

    # ACT
    n_xo_interference = _crossover_counts(5.0)     # strong interference (v > 1)
    n_xo_no_interference = _crossover_counts(1.0)  # v = 1: Poisson process

    # ASSERT
    # Without interference the crossover count on a gamete is Poisson with
    # mean equal to the map length in Morgans (Haldane).
    test_close(jnp.mean(n_xo_no_interference), map_length, eps=0.1)

    # Interference spaces crossovers more regularly, so counts are less
    # variable than Poisson. Means are not compared: in the gamma model the
    # expected count is set by the map length, not by v.
    assert jnp.var(n_xo_interference) < jnp.var(n_xo_no_interference)

In [None]:
#| test

def test_create_gamete_recombination():
    "Tests that a gamete is a valid recombinant of parental haplotypes, with IBD following genotype."
    # ARRANGE
    setup = _setup_meiosis_test()
    key = setup["key"]
    sp = setup["sp"]
    n_loci = sp.n_loci_per_chr[0]
    v = sp.recomb_params[0]

    # Parental haplotypes are distinct: [all 0s, all 1s]
    parental_haps = jnp.stack([
        jnp.zeros(n_loci, dtype=jnp.uint8),
        jnp.ones(n_loci, dtype=jnp.uint8),
    ])
    # IBD labels 10 / 11 identify the source haplotype, so the IBD gamete must
    # equal `geno + 10` at every locus iff the same mask was applied to both.
    parental_ibd = parental_haps.astype(jnp.int32) + 10

    # ACT (with recombination)
    gamete, ibd_gamete = _create_gamete(key, parental_haps, parental_ibd,
                                        setup["gen_map"][0], v)

    # ASSERT (with recombination)
    test_eq(gamete.shape, (n_loci,))
    test_eq(ibd_gamete.shape, (n_loci,))
    # A successful recombination should have both 0s and 1s
    assert jnp.any(gamete == 0)
    assert jnp.any(gamete == 1)
    assert jnp.all(ibd_gamete == gamete.astype(jnp.int32) + 10)

    # ACT (without recombination)
    # A genetic map of zero length prevents crossovers
    zero_map = jnp.zeros_like(setup["gen_map"][0])
    gamete_no_recomb, ibd_no_recomb = _create_gamete(key, parental_haps, parental_ibd,
                                                     zero_map, v)

    # ASSERT (without recombination)
    # The gamete should be identical to one of the parental haplotypes
    is_hap0 = jnp.all(gamete_no_recomb == parental_haps[0])
    is_hap1 = jnp.all(gamete_no_recomb == parental_haps[1])
    assert is_hap0 or is_hap1
    assert jnp.all(ibd_no_recomb == gamete_no_recomb.astype(jnp.int32) + 10)


def test_meiosis_for_one_cross_output():
    "Tests the final output shape and allele origins for a full cross."
    fixtures = _setup_meiosis_test()
    sim = fixtures["sp"]

    offspring = meiosis_for_one_cross(
        fixtures["key"],
        fixtures["mother_geno"],
        fixtures["father_geno"],
        sim.n_chr,
        fixtures["gen_map"],
        sim.recomb_params[0],
    )

    # Output layout is (nChr, ploidy, nLoci).
    test_eq(offspring.shape, (sim.n_chr, sim.ploidy, sim.n_loci_per_chr[0]))

    # Each parent is homozygous (mother all 0s, father all 1s), so recombination
    # within a parent cannot alter alleles: slot 0 must be all 0s, slot 1 all 1s.
    from_mother, from_father = offspring[:, 0, :], offspring[:, 1, :]
    assert jnp.all(from_mother == 0)
    assert jnp.all(from_father == 1)

In [None]:
#| hide
# Regenerate the library modules from the notebooks.
import nbdev
nbdev.nbdev_export()