# cross

> Create progeny populations from planned crosses between parent individuals.

In [None]:
#| default_exp cross

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| hide
import jax
from jax import vmap
import jax.numpy as jnp
from chewc.sp import SimParam
from chewc.population import quick_haplo, Population
from chewc.trait import add_trait_a, TraitCollection
from chewc.pheno import set_pheno
from chewc.meiosis import *

# --- 1. JAX setup: one root key from which all later randomness is split ---
key = jax.random.PRNGKey(42)

# --- 2. Genome blueprint: 3 chromosomes x 100 loci, diploid ---
n_chr, n_loci_per_chr, ploidy = 3, 100, 2
# Every chromosome shares the same evenly spaced map positions on [0, 1].
gen_map = jnp.tile(jnp.linspace(0, 1, n_loci_per_chr), (n_chr, 1))
centromeres = jnp.full(n_chr, 0.5)

# --- 3. Initial simulation parameters ---
SP = SimParam(gen_map=gen_map, centromere=centromeres, ploidy=ploidy)

# --- 4. Founder population (outbred) ---
n_founders = 50
key, pop_key = jax.random.split(key)
founder_pop = quick_haplo(key=pop_key, sim_param=SP, n_ind=n_founders, inbred=False)

# Register the founders on the simulation parameters.
SP = SP.replace(founderPop=founder_pop)
SP.founderPop.geno.shape  # (50, 3, 2, 100): individuals, chromosomes, ploidy, markers




(50, 3, 2, 100)

In [None]:
#| export

import jax
from jax import vmap
import jax.numpy as jnp
from chewc.sp import SimParam
from chewc.population import quick_haplo, Population
from chewc.trait import add_trait_a, TraitCollection
from chewc.pheno import set_pheno
from chewc.meiosis import *


def make_cross(key: jax.random.PRNGKey,
               pop: Population,
               cross_plan: jnp.ndarray,
               sim_param: SimParam) -> Population:
    """
    Create progeny from a series of planned crosses in a vectorized manner.

    Every row of `cross_plan` yields exactly one progeny. Meiosis for all
    crosses is executed in a single `vmap`-ed call to `meiosis_for_one_cross`.

    Args:
        key: A single JAX random key for the entire operation. It is split
            internally into independent streams for meiosis and for the
            placeholder sex assignment.
        pop: The parent population.
        cross_plan: A 2D integer array of shape (nCrosses, 2) where each row
            contains the mother and father iid, respectively. The iids are
            used directly as row indices into `pop.geno` and `pop.id`.
        sim_param: The simulation parameters object; its `n_chr`, `gen_map`
            and `recomb_params[0]` are forwarded to the meiosis function.

    Returns:
        A new Population object containing all the generated progeny, whose
        `mother`/`father` fields hold the parents' public ids (`pop.id`).

    Raises:
        ValueError: If `cross_plan` is not a 2D array with two columns.
    """
    # Shapes are static under tracing, so this check is also jit-safe.
    if cross_plan.ndim != 2 or cross_plan.shape[1] != 2:
        raise ValueError(
            f"cross_plan must have shape (nCrosses, 2), got {cross_plan.shape}"
        )
    n_crosses = cross_plan.shape[0]

    # A JAX key must not be reused after it has been split. Derive separate
    # keys for meiosis and for the sex draw so the two are independent.
    meiosis_key, sex_key = jax.random.split(key)

    # 1. Gather the genotypes of all parents in the plan.
    # NOTE: JAX clamps out-of-bounds gather indices instead of raising, and
    # negative indices wrap. An invalid iid therefore silently selects the
    # wrong parent; validate `cross_plan` against `pop.nInd` upstream if needed.
    mother_iids = cross_plan[:, 0]
    father_iids = cross_plan[:, 1]

    mothers_geno = pop.geno[mother_iids]  # (nCrosses, nChr, ploidy, nLoci)
    fathers_geno = pop.geno[father_iids]  # (nCrosses, nChr, ploidy, nLoci)

    # 2. Vectorize the single-cross function: map over the leading axis of
    # (keys, mothers, fathers) and pass the remaining arguments unmapped.
    vmapped_cross_creator = vmap(
        meiosis_for_one_cross,
        in_axes=(0, 0, 0, None, None, None),
    )

    # 3. One key per cross, all derived from the meiosis stream.
    cross_keys = jax.random.split(meiosis_key, n_crosses)

    # 4. Execute all crosses in one parallel operation. The result has shape
    # (nCrosses, nChr, ploidy, nLoci), matching the population's `geno`.
    progeny_geno = vmapped_cross_creator(
        cross_keys,
        mothers_geno,
        fathers_geno,
        sim_param.n_chr,
        sim_param.gen_map,
        sim_param.recomb_params[0],
    )

    # 5. Build the progeny Population.
    # Internal ids index the new population's own rows.
    new_iids = jnp.arange(n_crosses)

    # The pedigree records the parents' public-facing ids.
    mother_ids = pop.id[mother_iids]
    father_ids = pop.id[father_iids]

    # Simplified public ids that continue after the parent population. In the
    # full library these would be managed globally from SimParam.
    new_public_ids = jnp.arange(pop.nInd, pop.nInd + n_crosses)

    progeny_pop = Population(
        geno=progeny_geno,
        id=new_public_ids,
        iid=new_iids,
        mother=mother_ids,
        father=father_ids,
        # Placeholder: sex drawn uniformly from {0, 1} with its own key.
        sex=jax.random.choice(sex_key, jnp.array([0, 1], dtype=jnp.int8), (n_crosses,)),
        pheno=jnp.zeros((n_crosses, 0)),
        fixEff=jnp.ones(n_crosses),
    )

    return progeny_pop

In [None]:
#| hide
# --- Setup ---
# Crossing plan, one progeny per row. It covers two distinct pairs (0, 1) and
# (2, 3), the pair (0, 3) listed twice (repeated cross), and a self-cross (3, 3).
cross_plan = jnp.array([
    [0, 1],
    [2, 3],
    [0, 3],
    [0, 3],
    [3, 3],
], dtype=jnp.int32)
n_crosses = cross_plan.shape[0]

# --- Run Validation ---
key, progeny_key = jax.random.split(key)
progeny_population = make_cross(progeny_key, founder_pop, cross_plan, SP)

# --- Analyze and Report ---
print("Validating `make_cross` function:")
print(f"Number of crosses made: {n_crosses}")
print(f"Number of progeny produced: {progeny_population.nInd}")
print(f"Shape of progeny genotype array: {progeny_population.geno.shape}")
print(f"Progeny pedigree:\n{progeny_population.mother=}\n{progeny_population.father=}")

# 1. Exactly one progeny per planned cross.
assert progeny_population.nInd == n_crosses
# 2. Genotype array is (progeny, chromosomes, ploidy, loci).
assert progeny_population.geno.shape == (n_crosses, SP.n_chr, SP.ploidy, SP.gen_map.shape[1])
# 3. Pedigree records the parents' public ids, row for row.
assert jnp.all(progeny_population.mother == founder_pop.id[cross_plan[:, 0]])
assert jnp.all(progeny_population.father == founder_pop.id[cross_plan[:, 1]])

print("\n✅ Validation successful: Population-scale crossing works correctly.")

Validating `make_cross` function:
Number of crosses made: 5
Number of progeny produced: 5
Shape of progeny genotype array: (5, 3, 2, 100)
Progeny pedigree:
progeny_population.mother=Array([0, 2, 0, 0, 3], dtype=int32)
progeny_population.father=Array([1, 3, 3, 3, 3], dtype=int32)

✅ Validation successful: Population-scale crossing works correctly.


In [None]:
#| hide
# Regenerate the library modules from the notebooks' `#| export` cells.
import nbdev
nbdev.nbdev_export()