# cross

> Create progeny populations from planned mother/father pairs, passing the parents' genotypes and IBD arrays through meiosis.

In [None]:
#| default_exp cross

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

import jax
from jax import vmap
import jax.numpy as jnp
from functools import partial

# Assuming these imports are correctly set up in your project structure
from chewc.sp import SimParam
from chewc.population import Population
from chewc.meiosis import meiosis_for_one_cross



All IBD tests passed!


In [None]:
#| export

@partial(jax.jit, static_argnames=("n_chr",))
def _make_cross_geno(
    key: jax.random.PRNGKey,
    mothers_geno: jnp.ndarray,
    fathers_geno: jnp.ndarray,
    mothers_ibd: jnp.ndarray,
    fathers_ibd: jnp.ndarray,
    n_chr: int,
    gen_map: jnp.ndarray,
    recomb_param_v: float
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    (Internal JIT-compiled core) Run one meiosis per cross, in parallel.

    Only JAX-traceable operations are used, so the whole function compiles
    under `jax.jit`; `n_chr` is a static (compile-time) argument.

    Args:
        key: JAX random key for this batch; it is split into one key per cross.
        mothers_geno: Mothers' genotypes. Shape: (nCrosses, nChr, ploidy, nLoci).
        fathers_geno: Fathers' genotypes. Shape: (nCrosses, nChr, ploidy, nLoci).
        mothers_ibd: Mothers' IBD arrays. Shape: (nCrosses, nChr, ploidy, nLoci).
        fathers_ibd: Fathers' IBD arrays. Shape: (nCrosses, nChr, ploidy, nLoci).
        n_chr: Number of chromosomes (static).
        gen_map: Genetic map, shared by every cross.
        recomb_param_v: The 'v' interference parameter, shared by every cross.

    Returns:
        `(progeny_geno, progeny_ibd)`, each of shape (nCrosses, nChr, ploidy, nLoci).
    """
    # Derive an independent key for every cross from the single batch key.
    per_cross_keys = jax.random.split(key, mothers_geno.shape[0])

    # Batch over the leading (cross) axis of the keys and the four parent
    # arrays; the trailing three arguments are broadcast unchanged.
    batched_meiosis = vmap(
        meiosis_for_one_cross,
        in_axes=(0, 0, 0, 0, 0, None, None, None),
    )

    child_geno, child_ibd = batched_meiosis(
        per_cross_keys,
        mothers_geno, fathers_geno,
        mothers_ibd, fathers_ibd,
        n_chr, gen_map, recomb_param_v,
    )
    return child_geno, child_ibd


def make_cross(
    key: jax.random.PRNGKey,
    pop: Population,
    cross_plan: jnp.ndarray,
    sp: SimParam,
    next_id_start: int
) -> Population:
    """
    (Public-facing) Creates progeny from a planned series of crosses.

    This function handles the "CPU-side" logic: preparing data from the main
    Population object, calling the JIT-compiled core `_make_cross_geno`, and
    then assembling the results into a new Population object with updated metadata.

    Args:
        key: A single JAX random key. It is split into one key for meiosis and
            one for sex assignment.
        pop: The parent population. Its `geno`, `ibd` and `id` arrays are
            indexed by the internal ids in `cross_plan`.
        cross_plan: Array-like of shape (nCrosses, 2). Column 0 holds mother
            iids and column 1 holds father iids. Lists and NumPy arrays are
            accepted and converted with `jnp.asarray`.
        sp: The simulation parameters. `n_chr`, `gen_map`, `recomb_params[0]`
            and `n_traits` are read.
        next_id_start: The public ID given to the first progeny; the rest get
            consecutive IDs.

    Returns:
        A new Population object with one individual per row of `cross_plan`.
        Its `mother`/`father` fields hold the parents' public IDs.

    Raises:
        ValueError: If `cross_plan` is not 2-D with exactly two columns.

    Note:
        Parent iids are not range-checked here. JAX clamps out-of-bounds
        gather indices instead of raising, so an iid >= the population size
        silently selects the last individual. Validate plans upstream.
    """
    cross_plan = jnp.asarray(cross_plan)
    # Static shape check (safe under tracing). Without it, a 1-D plan fails
    # with an obscure IndexError below and extra columns are silently ignored.
    if cross_plan.ndim != 2 or cross_plan.shape[1] != 2:
        raise ValueError(
            "cross_plan must have shape (nCrosses, 2) with [mother_iid, father_iid] "
            f"rows; got shape {cross_plan.shape}"
        )

    n_crosses = cross_plan.shape[0]
    key_geno, key_sex = jax.random.split(key)

    # 1. Gather each cross's parental genotypes and IBD arrays.
    mother_iids = cross_plan[:, 0]
    father_iids = cross_plan[:, 1]
    mothers_geno = pop.geno[mother_iids]
    fathers_geno = pop.geno[father_iids]
    mothers_ibd = pop.ibd[mother_iids]
    fathers_ibd = pop.ibd[father_iids]

    # 2. Run all meioses in the JIT-compiled core.
    progeny_geno, progeny_ibd = _make_cross_geno(
        key_geno,
        mothers_geno,
        fathers_geno,
        mothers_ibd,
        fathers_ibd,
        sp.n_chr,
        sp.gen_map,
        sp.recomb_params[0]
    )

    # 3. Build metadata and assemble the progeny Population.
    new_public_ids = jnp.arange(next_id_start, next_id_start + n_crosses)
    new_iids = jnp.arange(n_crosses)  # Internal IDs are always 0-indexed for the new pop
    mother_public_ids = pop.id[mother_iids]
    father_public_ids = pop.id[father_iids]

    progeny_pop = Population(
        geno=progeny_geno,
        ibd=progeny_ibd,
        id=new_public_ids,
        iid=new_iids,
        mother=mother_public_ids,
        father=father_public_ids,
        sex=jax.random.choice(key_sex, jnp.array([0, 1], dtype=jnp.int8), (n_crosses,)),
        pheno=jnp.zeros((n_crosses, sp.n_traits)),
        fixEff=jnp.zeros(n_crosses, dtype=jnp.float32),
        bv=jnp.zeros((n_crosses, sp.n_traits))
    )

    return progeny_pop

In [None]:
# --- Imports and setup for cross tests ---
from fastcore.test import test_eq
import jax
import jax.numpy as jnp

# --- Functions and classes from other modules needed for testing ---
from chewc.population import Population, quick_haplo
from chewc.sp import SimParam
from chewc.meiosis import meiosis_for_one_cross # Needed for the cross module itself

def _setup_cross_test():
    """Build the shared fixture used by the crossing tests.

    Returns:
        A dict with three entries: `"key"` is a fresh PRNG key for the test
        itself, distinct from the one used to build the founders. `"pop"` is a
        10-individual founder population with 2 chromosomes of 50 loci each.
        `"sp"` is the `SimParam` derived from those founders.
    """
    founders, genetic_map = quick_haplo(
        key=jax.random.PRNGKey(42),
        n_ind=10,
        n_chr=2,
        n_loci_per_chr=50,
    )
    sim_params = SimParam.from_founder_pop(founders, genetic_map)

    fixture = {}
    fixture["key"] = jax.random.PRNGKey(43)
    fixture["pop"] = founders
    fixture["sp"] = sim_params
    return fixture

In [None]:
#| test

def test_make_cross_geno_output_shape():
    "Test the internal JIT-compiled function for correct output shapes and dtype."
    # ARRANGE
    setup = _setup_cross_test()
    pop = setup["pop"]
    sp = setup["sp"]
    key = setup["key"]
    n_crosses = 5

    # Disjoint, equally sized slices serve as arbitrary mothers and fathers.
    mothers = slice(0, n_crosses)
    fathers = slice(n_crosses, 2 * n_crosses)

    # ACT
    # `_make_cross_geno` takes the parents' IBD arrays and returns (geno, ibd).
    progeny_geno, progeny_ibd = _make_cross_geno(
        key,
        pop.geno[mothers],
        pop.geno[fathers],
        pop.ibd[mothers],
        pop.ibd[fathers],
        sp.n_chr,
        sp.gen_map,
        sp.recomb_params[0]
    )

    # ASSERT
    expected_shape = (n_crosses, sp.n_chr, sp.ploidy, sp.n_loci_per_chr[0])
    test_eq(progeny_geno.shape, expected_shape)
    test_eq(progeny_geno.dtype, jnp.uint8)
    test_eq(progeny_ibd.shape, expected_shape)

def test_make_cross_progeny_population_attributes():
    "Test the public API for correct Population object construction."
    # ARRANGE
    setup = _setup_cross_test()
    pop = setup["pop"]
    sp = setup["sp"]
    key = setup["key"]

    # Cross individuals 0 & 1, and 2 & 3
    cross_plan = jnp.array([[0, 1], [2, 3]])
    n_crosses = cross_plan.shape[0]
    next_id_start = pop.nInd  # Start new IDs after the last parent

    # ACT
    progeny_pop = make_cross(
        key,
        pop,
        cross_plan,
        sp,
        next_id_start=next_id_start
    )

    # ASSERT
    assert isinstance(progeny_pop, Population)
    test_eq(progeny_pop.nInd, n_crosses)

    # Public IDs continue consecutively from next_id_start. They are derived
    # rather than hard-coded so the test tracks the fixture's population size.
    expected_ids = jnp.arange(next_id_start, next_id_start + n_crosses)
    test_eq(progeny_pop.id, expected_ids)

    # Internal IDs are 0-indexed within the new population
    test_eq(progeny_pop.iid, jnp.arange(n_crosses))

    # Pedigree records the parents' public IDs
    test_eq(progeny_pop.mother, pop.id[cross_plan[:, 0]])
    test_eq(progeny_pop.father, pop.id[cross_plan[:, 1]])

    # Genetic arrays have one row per cross
    test_eq(progeny_pop.geno.shape[0], n_crosses)
    test_eq(progeny_pop.ibd.shape[0], n_crosses)

    # Trait arrays are initialized with the correct shape
    test_eq(progeny_pop.pheno.shape, (n_crosses, sp.n_traits))
    test_eq(progeny_pop.bv.shape, (n_crosses, sp.n_traits))

def test_make_cross_reproducibility_and_randomness():
    "Test that JAX keys control reproducibility of genotypes and IBD."
    # ARRANGE
    setup = _setup_cross_test()
    pop = setup["pop"]
    sp = setup["sp"]
    key1, key2 = jax.random.split(setup["key"])
    cross_plan = jnp.array([[0, 1], [2, 3], [4, 5]])

    # ACT
    # Run 1 and 2 use the same key
    progeny1 = make_cross(key1, pop, cross_plan, sp, next_id_start=10)
    progeny2 = make_cross(key1, pop, cross_plan, sp, next_id_start=10)

    # Run 3 uses a different key
    progeny3 = make_cross(key2, pop, cross_plan, sp, next_id_start=10)

    # ASSERT
    # Reproducibility: the same key must yield identical genotypes AND IBD,
    # since both come out of the same meiosis draw.
    assert jnp.array_equal(progeny1.geno, progeny2.geno)
    assert jnp.array_equal(progeny1.ibd, progeny2.ibd)

    # Randomness: different keys should produce different genotypes
    assert not jnp.array_equal(progeny1.geno, progeny3.geno)

In [None]:
#| hide
# Regenerate the package's Python modules from the notebooks' `#| export` cells.
import nbdev; nbdev.nbdev_export()