In [None]:
#| default_exp cross

In [None]:
#|export
import jax
import jax.numpy as jnp
from typing import Tuple
from chewc.meiosis import create_gamete
from chewc.structs import GeneticMap
from functools import partial


@partial(jax.jit, static_argnames=("n_crosses",))
def random_mating(key: jax.Array, n_parents: int, n_crosses: int) -> jnp.ndarray:
    """Draw random (mother, father) index pairs, sampling with replacement.

    Only `n_crosses` is a static (compile-time) argument because it fixes the
    output shape; `n_parents` is traced, so it may be a runtime value.

    Args:
        key: PRNG key; it is split once so each column gets its own stream.
        n_parents: Exclusive upper bound for the sampled indices.
        n_crosses: Number of pairs to draw.

    Returns:
        An int32 array of shape `(n_crosses, 2)` with values in
        `[0, n_parents)`; column 0 holds mothers, column 1 holds fathers.
        The two columns are drawn independently, so the same index can
        appear in both columns of a row.
    """
    mother_key, father_key = jax.random.split(key)

    def draw_indices(subkey: jax.Array) -> jnp.ndarray:
        # One independent uniform draw of `n_crosses` parent indices.
        return jax.random.randint(
            subkey,
            shape=(n_crosses,),
            minval=0,
            maxval=n_parents,
            dtype=jnp.int32,
        )

    return jnp.stack((draw_indices(mother_key), draw_indices(father_key)), axis=-1)



@partial(jax.jit, static_argnames=("max_crossovers", "interference_nu"))
def cross_pair(
    key: jax.Array,
    mother_geno: jax.Array, father_geno: jax.Array,   # (n_chr, 2, n_loci)
    mother_ibd: jax.Array,  father_ibd: jax.Array,    # (n_chr, 2, n_loci)
    genetic_map: GeneticMap,
    max_crossovers: int,
    interference_nu: float = 4.0,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Produce one offspring from a mother and a father.

    One gamete is created per parent with `create_gamete`, each from its own
    PRNG subkey, and the two gametes are stacked along axis 1 so the maternal
    gamete comes first and the paternal gamete second.

    Args:
        key: PRNG key; split once into a maternal and a paternal subkey.
        mother_geno, father_geno: Parental genotypes, `(n_chr, 2, n_loci)`.
        mother_ibd, father_ibd: Parental IBD arrays, `(n_chr, 2, n_loci)`.
        genetic_map: Genetic map forwarded unchanged to `create_gamete`.
        max_crossovers: Forwarded to `create_gamete`; static under `jit`.
        interference_nu: Forwarded to `create_gamete` for both gametes.
            Defaults to 4.0, the value that used to be hard-coded here. It is
            a static argument, so it must be a hashable Python scalar and each
            distinct value triggers a separate compilation.

    Returns:
        `(offspring_geno, offspring_ibd)`, each `(n_chr, 2, n_loci)`.
    """
    km, kf = jax.random.split(key)
    m_geno, m_ibd = create_gamete(
        km, mother_geno, mother_ibd, genetic_map,
        interference_nu=interference_nu, max_crossovers=max_crossovers,
    )
    f_geno, f_ibd = create_gamete(
        kf, father_geno, father_ibd, genetic_map,
        interference_nu=interference_nu, max_crossovers=max_crossovers,
    )
    offspring_geno = jnp.stack([m_geno, f_geno], axis=1)  # (n_chr, 2, n_loci)
    offspring_ibd  = jnp.stack([m_ibd,  f_ibd],  axis=1)  # (n_chr, 2, n_loci)
    return offspring_geno, offspring_ibd


@partial(jax.jit, static_argnames=("n_crosses", "allow_selfing"))
def create_ocs_mating_plan(
    key: jax.Array,
    candidate_ids: jnp.ndarray,
    contributions: jnp.ndarray,
    n_crosses: int,
    allow_selfing: bool = False
) -> jnp.ndarray:
    """Generate a mating list from optimal contribution scores (unisex population).

    Negative contributions are clipped to zero and the remainder is normalised
    into sampling probabilities. Parents are then drawn with replacement.

    Degenerate inputs are given defined behaviour instead of producing NaN
    probabilities:

    * If no candidate has a positive contribution, all candidates are sampled
      uniformly.
    * With `allow_selfing=False`, if parent 1 holds all of the probability
      mass, parent 2 is drawn uniformly from the candidates whose id differs
      from parent 1.
    * If no candidate with a different id exists at all, selfing cannot be
      avoided and parent 2 is drawn from the unmasked distribution.

    Args:
        key: PRNG key.
        candidate_ids: 1-D array of candidate identifiers.
        contributions: 1-D array of contribution scores aligned with
            `candidate_ids`.
        n_crosses: Number of crosses to plan (static under `jit`).
        allow_selfing: If True, both parents of a cross are drawn independently
            and may coincide (static under `jit`).

    Returns:
        Array of shape `(n_crosses, 2)` holding `(parent1, parent2)` ids.
    """
    def _normalize(weights, fallback):
        # Scale `weights` to sum to one; when they carry no mass, return
        # `fallback` instead of dividing by zero.
        total = jnp.sum(weights)
        has_mass = total > 0
        safe_total = jnp.where(has_mass, total, 1)
        return jnp.where(has_mass, weights / safe_total, fallback)

    weights = jnp.maximum(0, contributions)
    uniform = jnp.ones_like(weights) / weights.size
    probs = _normalize(weights, uniform)

    if allow_selfing:
        total_parents = n_crosses * 2
        all_parents = jax.random.choice(key, candidate_ids,
                                        shape=(total_parents,), p=probs)
        return all_parents.reshape((n_crosses, 2))

    # Sample Parent 1s
    key_p1, key_p2 = jax.random.split(key)
    parent1s = jax.random.choice(key_p1, candidate_ids,
                                 shape=(n_crosses,), p=probs)

    # For each cross, build a Parent 2 distribution that gives zero
    # probability to every candidate sharing Parent 1's id.
    def sample_parent2(p1_id, subkey):
        not_self = candidate_ids != p1_id

        # Fallback chain: uniform over the other candidates; if there are
        # none, the unmasked distribution (selfing is then unavoidable).
        any_other = _normalize(not_self.astype(probs.dtype), probs)
        p2_probs = _normalize(probs * not_self, any_other)

        return jax.random.choice(subkey, candidate_ids, p=p2_probs)

    # Vmap the sampling process over all crosses, one subkey per cross.
    parent2s = jax.vmap(sample_parent2)(parent1s, jax.random.split(key_p2, n_crosses))

    return jnp.stack([parent1s, parent2s], axis=-1)


In [None]:
#| hide
# Notebook housekeeping: invoke nbdev's export routine.
import nbdev; nbdev.nbdev_export()