In [1]:
#| default_exp burnin

In [2]:
#|export


import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
from typing import Tuple, Optional

from chewc.structs import Population, BreedingState, GeneticMap, quick_haplo, compute_dosage
from chewc.cross import random_mating, cross_pair

# %% ../nbs/10_burnin.ipynb 2
@jax.jit
def calculate_adjacent_ld(population: Population) -> jnp.ndarray:
    """
    Calculates the mean Linkage Disequilibrium (r^2) between adjacent markers
    on each chromosome.

    For every pair of neighbouring loci (i, i+1) on a chromosome, r^2 is the
    squared Pearson correlation of their dosages across individuals. The
    per-chromosome value is the mean r^2 over those pairs.

    Pairs in which either locus is monomorphic (zero dosage variance) have an
    undefined correlation and are excluded from the mean. Counting them as
    r^2 = 0 would bias the estimate downward as drift fixes loci.

    Args:
        population (Population): The population to analyze.

    Returns:
        jnp.ndarray: Mean adjacent-marker r^2 per chromosome, float32,
                     shape (n_chr,). A chromosome with fewer than two loci,
                     or with no adjacent pair where both loci segregate,
                     gets 0.0.
    """
    # Dosage matrix: shape (n_ind, n_chr, n_loci)
    dosages = compute_dosage(population).astype(jnp.float32)
    _, n_chr, n_loci = dosages.shape

    # Array shapes are static under jit, so this branch is resolved at trace time.
    if n_loci < 2:
        return jnp.zeros(n_chr, dtype=jnp.float32)

    def _chr_ld(mat: jnp.ndarray) -> jnp.ndarray:
        """Mean r^2 over informative adjacent pairs of one (n_ind, n_loci) matrix."""
        left = mat[:, :-1]   # loci 0 .. n_loci-2
        right = mat[:, 1:]   # loci 1 .. n_loci-1

        left_std = jnp.std(left, axis=0)
        right_std = jnp.std(right, axis=0)

        # A pair carries LD information only if both loci segregate.
        informative = (left_std > 0) & (right_std > 0)

        # Replace zero std with 1 so the z-scores stay finite. The affected
        # pairs are masked out below, so the substitute value never matters.
        left_z = (left - jnp.mean(left, axis=0)) / jnp.where(left_std > 0, left_std, 1.0)
        right_z = (right - jnp.mean(right, axis=0)) / jnp.where(right_std > 0, right_std, 1.0)

        # Pearson r = mean product of (population-std) z-scores.
        r = jnp.mean(left_z * right_z, axis=0)
        # Clip guards against float32 rounding pushing r^2 marginally above 1.
        r2 = jnp.where(informative, jnp.clip(r ** 2, 0.0, 1.0), 0.0)

        n_informative = jnp.sum(informative)
        return jnp.where(
            n_informative > 0,
            jnp.sum(r2) / jnp.maximum(n_informative, 1),
            0.0,
        )

    # vmap over chromosomes: (n_ind, n_chr, n_loci) -> (n_chr, n_ind, n_loci)
    dosages_per_chr = jnp.transpose(dosages, (1, 0, 2))
    return jax.vmap(_chr_ld)(dosages_per_chr)

# %% ../nbs/10_burnin.ipynb 3
@partial(jax.jit, static_argnames=("n_pop", "max_crossovers"))
def _burnin_step(carry, _, genetic_map, n_pop, max_crossovers):
    """
    Advance the burn-in by one generation of random mating.

    Shaped as a `lax.scan` body: `carry` is the current BreedingState and the
    second positional argument (the scanned-over element) is ignored.

    Args:
        carry (BreedingState): State before this generation.
        _: Unused per-iteration scan input.
        genetic_map (GeneticMap): Map forwarded unchanged to every `cross_pair` call.
        n_pop (int): Number of crosses / offspring produced (static under jit).
        max_crossovers (int): Crossover cap forwarded to `cross_pair` (static).

    Returns:
        Tuple[BreedingState, None]: The offspring generation as the new carry,
        and no per-step output.
    """
    pop = carry.population
    next_key, mating_key, meiosis_key = jax.random.split(carry.key, 3)

    # Panmixia plan: n_pop crosses among n_parents=n_pop parents.
    # Each row of `pairings` is [mother index, father index].
    pairings = random_mating(mating_key, n_parents=n_pop, n_crosses=n_pop)
    dam_idx = pairings[:, 0]
    sire_idx = pairings[:, 1]

    # One meiosis key per offspring; the genetic map is shared (in_axes None).
    cross_many = jax.vmap(
        partial(cross_pair, max_crossovers=max_crossovers),
        in_axes=(0, 0, 0, 0, 0, None),
    )
    offspring_geno, offspring_ibd = cross_many(
        jax.random.split(meiosis_key, n_pop),
        pop.geno[dam_idx],
        pop.geno[sire_idx],
        pop.ibd[dam_idx],
        pop.ibd[sire_idx],
        genetic_map,
    )

    generation = carry.generation + 1

    # meta columns: [id, mother id, father id, generation]
    offspring_meta = jnp.stack(
        [
            carry.next_id + jnp.arange(n_pop, dtype=jnp.int32),
            pop.meta[dam_idx, 0],
            pop.meta[sire_idx, 0],
            jnp.full((n_pop,), generation, dtype=jnp.int32),
        ],
        axis=-1,
    )

    offspring = Population(geno=offspring_geno, ibd=offspring_ibd, meta=offspring_meta)
    next_state = BreedingState(
        population=offspring,
        key=next_key,
        generation=generation,
        next_id=carry.next_id + n_pop,
    )
    return next_state, None


def run_burnin(
    key: jax.Array,
    n_gens: int,
    n_pop: int,
    n_chr: int = 5,
    n_loci: int = 1000,
    chr_length: float = 1.0,
    max_crossovers: int = 10,
    founder_pop: Optional[Population] = None,
    genetic_map: Optional[GeneticMap] = None
) -> Tuple[BreedingState, jnp.ndarray, GeneticMap]:
    """
    Runs a random-mating burn-in simulation to establish Linkage Disequilibrium (LD).

    Whichever of `founder_pop` / `genetic_map` is not provided is generated
    with quick_haplo (using n_pop, n_chr, n_loci and chr_length).

    Args:
        key (jax.Array): Random number generator key.
        n_gens (int): Number of generations to burn in (>= 0).
        n_pop (int): Population size, kept constant throughout (>= 1).
        n_chr (int): Number of chromosomes (used if generating founders/map).
        n_loci (int): Number of loci per chromosome (used if generating founders/map).
        chr_length (float): Length of chromosomes in Morgans (used if generating founders/map).
        max_crossovers (int): Max crossovers per chromosome during meiosis.
        founder_pop (Optional[Population]): Custom starting population; must
            contain exactly n_pop individuals.
        genetic_map (Optional[GeneticMap]): Custom genetic map.

    Returns:
        Tuple containing:
        - final_state (BreedingState): The simulation state after burn-in.
        - final_ld (jnp.ndarray): Array of mean r^2 between adjacent markers per chromosome.
        - genetic_map (GeneticMap): The genetic map used/created.

    Raises:
        ValueError: If n_gens is negative, n_pop is not positive, or
            founder_pop does not contain exactly n_pop individuals (the
            burn-in keeps population size constant, so the scan carry
            shape must not change between generations).
    """
    if n_gens < 0:
        raise ValueError(f"n_gens must be non-negative, got {n_gens}.")
    if n_pop < 1:
        raise ValueError(f"n_pop must be positive, got {n_pop}.")
    if founder_pop is not None and founder_pop.geno.shape[0] != n_pop:
        raise ValueError(
            f"founder_pop has {founder_pop.geno.shape[0]} individuals but "
            f"n_pop={n_pop}; they must match because the burn-in keeps the "
            "population size constant."
        )

    key, init_key = jax.random.split(key)

    # 1. Fill in whichever of founders / map was not supplied.
    start_pop, start_map = founder_pop, genetic_map
    if start_pop is None or start_map is None:
        # NOTE(review): when only one of founder_pop / genetic_map is supplied,
        # the other is built from n_chr / n_loci / chr_length, which must agree
        # with the supplied object's dimensions; that is not checked here.
        founders, gmap = quick_haplo(
            key=init_key,
            n_ind=n_pop,
            n_chr=n_chr,
            seg_sites=n_loci,
            chr_length=chr_length
        )
        if start_pop is None:
            start_pop = founders
        if start_map is None:
            start_map = gmap

    # 2. Initialize State.
    # meta[:, 0] holds individual IDs (the column _burnin_step reads as parent
    # IDs). Starting at max(ID) + 1 guarantees offspring IDs never collide with
    # founder IDs, however the founders were numbered.
    initial_state = BreedingState(
        population=start_pop,
        key=key,
        generation=0,
        next_id=jnp.max(start_pop.meta[:, 0]) + 1
    )

    # 3. Compile and Run Scan.
    # partial bakes in the map and the static arguments required for JIT.
    scan_fn = partial(
        _burnin_step,
        genetic_map=start_map,
        n_pop=n_pop,
        max_crossovers=max_crossovers
    )

    final_state, _ = lax.scan(scan_fn, initial_state, None, length=n_gens)

    # 4. Calculate Final Metrics
    final_ld = calculate_adjacent_ld(final_state.population)

    return final_state, final_ld, start_map



In [3]:
#| hide
# Run nbdev's export so the `#|export` cells above are written to the library's
# `burnin` module (target set by `#| default_exp burnin` at the top).
import nbdev; nbdev.nbdev_export()