In [None]:
#| default_exp structs

In [None]:
#|export

from flax.struct import dataclass
import jax
import jax.numpy as jnp
from typing import Optional, Tuple, List

@dataclass(frozen=True)
class GeneticMap:
    """
    Defines the genetic map of the organism.

    Attributes:
        chromosome_lengths (jnp.ndarray): The genetic length of each chromosome.
            Shape: (n_chr,)
        locus_positions (List[jnp.ndarray]): Genetic positions of the loci on
            each chromosome, one entry (row) per chromosome.

    Notes:
        * Units are documented here as Morgans, but the constructors in this
          file do not agree: `quick_haplo` spreads loci over `[0, chr_length]`
          (default 1.0), while `msprime_pop` multiplies by 100 and stores
          centimorgans. NOTE(review): confirm which unit downstream code expects.
        * Both `quick_haplo` and `msprime_pop` pass a single stacked float32
          array of shape (n_chr, n_loci) for `locus_positions`, not a Python list.
    """
    # One genetic length per chromosome, shape (n_chr,).
    chromosome_lengths: jnp.ndarray
    # NOTE(review): annotated as a list of per-chromosome arrays, but every
    # constructor visible in this file supplies one (n_chr, n_loci) array.
    locus_positions: List[jnp.ndarray]


@dataclass(frozen=True)
class Population:
    """
    Represents a collection of individuals within a breeding program.

    Attributes:
        geno (jnp.ndarray): Genotypes, one allele per haplotype copy and locus.
            Shape: (n_individuals, n_chr, 2, n_loci)
        ibd (jnp.ndarray): Identity-by-descent tracking for founder alleles.
            Shape: (n_individuals, n_chr, 2, n_loci)
        meta (jnp.ndarray): Metadata [id, mother_id, father_id, birth_gen].
            Shape: (n_individuals, 4)

    Notes:
        * The founder constructors in this file store 0/1 alleles but with
          different dtypes: `quick_haplo` uses int8, `msprime_pop` uses uint8.
        * They also label IBD differently: `quick_haplo` gives every founder
          haplotype one label shared by all of its loci, whereas `msprime_pop`
          gives every (individual, chromosome, haplotype, locus) cell its own
          label. Both use int32.
        * Founders are written with `mother_id = father_id = -1` and
          `birth_gen = 0`; `meta` is int32.
    """
    # Allele calls; axis 2 indexes the haplotype copies of each chromosome.
    geno: jnp.ndarray
    # Founder-origin label for every allele in `geno` (same shape).
    ibd: jnp.ndarray
    # Columns: id, mother_id, father_id, birth_gen.
    meta: jnp.ndarray

@dataclass(frozen=True)
class Trait:
    """
    Defines the genetic architecture of one or more traits.

    This structure holds the information linking genotypes to phenotypes,
    based on an additive and dominance QTL model.

    Attributes:
        qtl_chromosome (jnp.ndarray): Chromosome index for each QTL.
            Shape: (n_qtl,)
        qtl_position (jnp.ndarray): Locus index for each QTL within its chromosome.
            Shape: (n_qtl,)
        qtl_effects (jnp.ndarray): The additive effect of each QTL on each trait.
            Shape: (n_qtl, n_traits)
        qtl_dominance_effects (jnp.ndarray): The dominance effect of each QTL on each trait.
            Shape: (n_qtl, n_traits)
        intercept (jnp.ndarray): The base value for each trait.
            Shape: (n_traits,)

    Notes:
        `add_trait` in this file fills these fields with int32 indices and
        float32 effects. There the additive effects multiply the allele dosage
        and the dominance effects multiply a heterozygote indicator
        (dosage == 1); the intercept is chosen so the founder mean of the
        genetic values equals the requested trait mean.
    """
    # (chromosome, locus) index pair of each QTL; the two arrays are parallel.
    qtl_chromosome: jnp.ndarray
    qtl_position: jnp.ndarray
    # Per-QTL, per-trait effect sizes.
    qtl_effects: jnp.ndarray
    qtl_dominance_effects: jnp.ndarray
    # Per-trait constant added to the genetic value.
    intercept: jnp.ndarray


@dataclass(frozen=True)
class BreedingState:
    """
    The complete, self-contained state of a single simulation replicate at a point in time.

    This object is the primary carrier passed between steps in the simulation.

    Attributes:
        population (Population): The current population of individuals.
        key (jax.Array): The JAX pseudo-random number generator key.
        generation (int): The current generation number of the simulation.
        next_id (int): Counter used to hand out unique IDs to new individuals.
            NOTE(review): presumably the first unused value of the `id` column
            in `Population.meta`; the code that advances it is not in this file.
    """
    population: Population
    key: jax.Array
    generation: int
    next_id: int  # To ensure unique IDs for new individuals


In [None]:
#|export

import jax
import jax.numpy as jnp
from typing import Tuple



def quick_haplo(
    key: jax.Array,
    n_ind: int,
    n_chr: int,
    seg_sites: int,
    inbred: bool = False,
    chr_length: float = 1.0,
) -> Tuple[Population, GeneticMap]:
    """Build random founders together with an evenly spaced genetic map.

    Args:
        key: JAX PRNG key used for the allele draws.
        n_ind: Number of founder individuals.
        n_chr: Number of chromosomes.
        seg_sites: Number of loci on every chromosome.
        inbred: If True, one haplotype is drawn per individual and chromosome
            and used for both copies, so every founder is fully homozygous.
            Otherwise both copies are drawn independently.
        chr_length: Genetic length assigned to every chromosome; loci are
            spread evenly over `[0, chr_length]`.

    Returns:
        `(population, genetic_map)`:
        * `population.geno` is int8 with shape (n_ind, n_chr, 2, seg_sites) and
          alleles in {0, 1}.
        * `population.ibd` is int32 with the same shape; every founder
          haplotype carries one label in `[0, 2 * n_ind)` at all of its loci.
        * `population.meta` is int32 with shape (n_ind, 4): id, mother_id (-1),
          father_id (-1), birth_gen (0).
        * `genetic_map.locus_positions` is a float32 array of shape
          (n_chr, seg_sites); `genetic_map.chromosome_lengths` has shape (n_chr,).
    """
    full_shape = (n_ind, n_chr, 2, seg_sites)

    # Draw either one haplotype (inbred) or two per chromosome, then expand to
    # the full diploid shape. For the outbred case the broadcast is a no-op.
    haplotypes_drawn = 1 if inbred else 2
    alleles = jax.random.randint(
        key, (n_ind, n_chr, haplotypes_drawn, seg_sites), 0, 2, dtype=jnp.int8
    )
    geno = jnp.broadcast_to(alleles, full_shape)

    # One IBD label per founder haplotype, repeated over chromosomes and loci.
    haplotype_labels = jnp.arange(n_ind * 2, dtype=jnp.int32).reshape(n_ind, 1, 2, 1)
    ibd = jnp.broadcast_to(haplotype_labels, full_shape)

    # Metadata columns are created as int32 up front so the stacked table is int32.
    individual_id = jnp.arange(n_ind, dtype=jnp.int32)
    unknown_parent = jnp.full((n_ind,), -1, dtype=jnp.int32)
    founder_generation = jnp.zeros((n_ind,), dtype=jnp.int32)
    meta = jnp.stack(
        [individual_id, unknown_parent, unknown_parent, founder_generation],
        axis=-1,
    )

    # Every chromosome shares the same length and the same uniform locus grid.
    uniform_grid = jnp.linspace(0.0, chr_length, seg_sites, dtype=jnp.float32)
    genetic_map = GeneticMap(
        chromosome_lengths=jnp.full((n_chr,), chr_length, dtype=jnp.float32),
        locus_positions=jnp.stack([uniform_grid] * n_chr, axis=0),
    )

    return Population(geno=geno, ibd=ibd, meta=meta), genetic_map

#| export

#| export

import jax
import jax.numpy as jnp
import numpy as np
from numpy.random import default_rng
import msprime
import pandas as pd
import matplotlib.pyplot as plt

from typing import Optional, Tuple
from IPython.display import display

from chewc.structs import Population, GeneticMap, add_trait
from chewc.pheno import calculate_phenotypes

def msprime_pop(
    key: jax.Array,
    n_ind: int,
    n_chr: int,
    n_loci_per_chr: int,
    ploidy: int = 2,
    effective_population_size: int = 5_000,
    mutation_rate: float = 1e-7,
    recombination_rate_per_chr: float = 1e-8,
    maf_threshold: float = 0.05,
    base_chr_length: int = 500_000,
    num_simulated_individuals: Optional[int] = None,
    enforce_founder_maf: bool = True,
) -> Tuple[Population, GeneticMap]:
    """Simulate a founder population from an msprime ancestry model.

    A pool of individuals is simulated with msprime, `n_ind` of them are drawn
    at random as founders, and on every chromosome `n_loci_per_chr` bi-allelic
    SNPs that are polymorphic in those founders are sampled as marker loci.

    The output matches the lightweight `Population`/`GeneticMap` structures used by the
    workflow notebooks so it can plug straight into trait sampling, phenotype simulation,
    and selection routines.

    Args:
        key: JAX PRNG key; the msprime seed and the NumPy seed are derived from it.
        n_ind: Number of founders to return.
        n_chr: Number of chromosomes (equal-sized segments of one simulated genome).
        n_loci_per_chr: Number of SNPs kept per chromosome.
        ploidy: Genome copies per individual. `Population` documents a
            haplotype axis of size 2, so other values are untested here.
        effective_population_size: Ne passed to msprime.
        mutation_rate: Per-base, per-generation mutation rate.
        recombination_rate_per_chr: Per-base recombination rate used inside
            every chromosome.
        maf_threshold: Minimum founder minor-allele frequency when
            `enforce_founder_maf` is True.
        base_chr_length: Physical length of every chromosome in base pairs.
        num_simulated_individuals: Size of the simulated pool the founders are
            drawn from. Defaults to `5 * n_ind` with MAF filtering, else `2 * n_ind`.
        enforce_founder_maf: If True keep SNPs with founder MAF >= `maf_threshold`;
            otherwise keep any SNP that is polymorphic in the founders.

    Returns:
        `(population, genetic_map)`. `geno` is uint8 with shape
        (n_ind, n_chr, ploidy, n_loci_per_chr); `ibd` gives every allele cell a
        unique int32 label; `meta` marks all individuals as generation-0
        founders with unknown parents (-1). Map positions are in centimorgans.

    Raises:
        ValueError: On non-positive sizes or a pool smaller than `n_ind`.
        RuntimeError: If a chromosome has no bi-allelic SNPs or too few SNPs
            pass the filter.
    """
    if n_ind <= 0:
        raise ValueError('n_ind must be positive.')
    if n_chr <= 0:
        raise ValueError('n_chr must be positive.')
    if n_loci_per_chr <= 0:
        raise ValueError('n_loci_per_chr must be positive.')
    if ploidy <= 0:
        raise ValueError('ploidy must be positive.')

    if num_simulated_individuals is None:
        multiplier = 5 if enforce_founder_maf else 2
        num_simulated_individuals = n_ind * multiplier
    if num_simulated_individuals < n_ind:
        raise ValueError('num_simulated_individuals must be >= n_ind.')

    _, seed_key, numpy_seed_key = jax.random.split(key, 3)
    # msprime only accepts seeds in [1, 2**32 - 1], so never draw 0.
    random_seed = int(jax.random.randint(seed_key, (), 1, 2**31 - 1).item())
    numpy_seed = int(jax.random.randint(numpy_seed_key, (), 0, 2**31 - 1).item())
    rng = default_rng(numpy_seed)

    # All chromosomes are laid end to end on one simulated sequence.
    # NOTE(review): no free-recombination breakpoint is inserted between
    # chromosomes, so neighbouring "chromosomes" remain linked at their
    # boundary; confirm whether independent assortment is required.
    chromosome_lengths = [base_chr_length] * n_chr
    cumulative = np.cumsum([0] + chromosome_lengths)
    rate_map = msprime.RateMap(
        position=cumulative,
        rate=[recombination_rate_per_chr] * n_chr,
    )

    # An integer `samples` counts individuals, each contributing `ploidy`
    # sample nodes, so pass the individual count and the ploidy separately.
    ts = msprime.sim_ancestry(
        samples=num_simulated_individuals,
        ploidy=ploidy,
        population_size=effective_population_size,
        recombination_rate=rate_map,
        random_seed=random_seed,
    )
    mts = msprime.sim_mutations(ts, rate=mutation_rate, random_seed=random_seed + 1)

    true_num_individuals = mts.num_samples // ploidy
    founder_indices = np.sort(rng.choice(true_num_individuals, n_ind, replace=False))

    def founder_alleles(var) -> np.ndarray:
        """Return the (n_ind, ploidy) allele matrix of the founders at one site."""
        # Assumes the sample nodes of one individual are adjacent in the
        # genotype vector — TODO confirm against the msprime sample layout.
        return var.genotypes.reshape(true_num_individuals, ploidy)[founder_indices]

    def maf_in_founders(var) -> float:
        """Minor-allele frequency of a bi-allelic site among the founders."""
        # The mean over all 0/1 allele copies already is the allele frequency;
        # it must not be divided by the ploidy again.
        freq = float(founder_alleles(var).mean())
        return min(freq, 1.0 - freq)

    geno = np.zeros((n_ind, n_chr, ploidy, n_loci_per_chr), dtype=np.uint8)
    all_chr_positions_list = []
    chr_lengths_cm = []

    all_variants = list(mts.variants())

    for chr_idx in range(n_chr):
        chr_left, chr_right = rate_map.left[chr_idx], rate_map.right[chr_idx]
        chr_variants = [
            var for var in all_variants
            if chr_left <= var.site.position < chr_right and len(var.alleles) == 2
        ]

        if not chr_variants:
            raise RuntimeError(f'No bi-allelic SNPs found on chromosome {chr_idx}.')

        if enforce_founder_maf:
            eligible = [var for var in chr_variants if maf_in_founders(var) >= maf_threshold]
        else:
            # Without a threshold, still drop sites that are fixed in the founders.
            eligible = [var for var in chr_variants if maf_in_founders(var) > 0]

        if len(eligible) < n_loci_per_chr:
            raise RuntimeError(
                f'Only {len(eligible)} SNPs passed the MAF filter on chromosome {chr_idx}; '
                'consider lowering maf_threshold or increasing the mutation rate.'
            )

        # Sample the requested number of SNPs and restore physical order.
        selected_idx = rng.choice(len(eligible), n_loci_per_chr, replace=False)
        selected = sorted((eligible[i] for i in selected_idx), key=lambda v: v.site.position)

        chr_positions = []
        for snp_pos, snp in enumerate(selected):
            geno[:, chr_idx, :, snp_pos] = founder_alleles(snp)

            # Physical offset within the chromosome -> genetic position in cM.
            # NOTE(review): `GeneticMap` documents Morgans; confirm the unit
            # expected by downstream recombination code.
            pos_cm = (snp.site.position - chr_left) * recombination_rate_per_chr * 100.0
            chr_positions.append(pos_cm)

        all_chr_positions_list.append(jnp.array(chr_positions, dtype=jnp.float32))
        # The map length of a chromosome is the position of its last kept SNP.
        chr_lengths_cm.append(chr_positions[-1] if chr_positions else 0.0)

    # Every allele cell of every founder gets its own IBD label.
    ibd = np.arange(n_ind * n_chr * ploidy * n_loci_per_chr, dtype=np.int32)
    ibd = ibd.reshape(n_ind, n_chr, ploidy, n_loci_per_chr)

    # Columns: id, mother_id, father_id, birth_gen.
    meta = np.stack([
        np.arange(n_ind, dtype=np.int32),
        np.full(n_ind, -1, dtype=np.int32),
        np.full(n_ind, -1, dtype=np.int32),
        np.zeros(n_ind, dtype=np.int32),
    ], axis=-1)

    population = Population(
        geno=jnp.array(geno, dtype=jnp.uint8),
        ibd=jnp.array(ibd, dtype=jnp.int32),
        meta=jnp.array(meta, dtype=jnp.int32),
    )

    # Every chromosome keeps exactly n_loci_per_chr SNPs, so the per-chromosome
    # position vectors stack into one (n_chr, n_loci_per_chr) array.
    genetic_map = GeneticMap(
        chromosome_lengths=jnp.array(chr_lengths_cm, dtype=jnp.float32),
        locus_positions=jnp.stack(all_chr_positions_list),
    )

    return population, genetic_map

In [None]:
#|export

import jax.numpy as jnp
import jax

@jax.jit
def compute_dosage(population: Population) -> jnp.ndarray:
    """
    Collapse the haplotype axis of the genotype array into per-locus dosages.

    The allele codes stored on the homologous copies of each chromosome are
    added together; with 0/1-coded alleles on two copies this yields the
    count of '1' alleles (0, 1 or 2) per individual and locus.

    Args:
        population: A `Population` whose `geno` has shape
            (n_individuals, n_chr, 2, n_loci).

    Returns:
        Dosage array of shape (n_individuals, n_chr, n_loci).
    """
    haplotype_axis = 2  # (individual, chromosome, haplotype, locus)
    alleles = population.geno
    return jnp.sum(alleles, axis=haplotype_axis)


def _flatten_gather_chr_locus(dosage_chr_locus: jnp.ndarray,
                              chr_idx: jnp.ndarray,
                              locus_idx: jnp.ndarray) -> jnp.ndarray:
    """
    dosage_chr_locus: (n_ind, n_chr, n_loci)
    chr_idx, locus_idx: (n_qtl,)
    returns: (n_ind, n_qtl) gathered pairwise along (chr,locus).
    """
    n_chr, n_loci = dosage_chr_locus.shape[1], dosage_chr_locus.shape[2]
    flat = dosage_chr_locus.reshape(dosage_chr_locus.shape[0], n_chr * n_loci)
    flat_ids = chr_idx * n_loci + locus_idx
    return jnp.take(flat, flat_ids, axis=1)





def add_trait(
    key: jax.Array,
    founder_pop: Population,
    n_qtl_per_chr: int,
    mean: jnp.ndarray,      # (n_traits,)
    var_a: jnp.ndarray,     # Additive variance (n_traits,)
    var_d: jnp.ndarray,     # Dominance variance (n_traits,)
    sigma: jnp.ndarray,     # (n_traits, n_traits) PSD
) -> Trait:
    """Sample QTLs and multi-trait effects for both additive and dominance components.

    QTL locations are drawn without replacement and shared by the additive and
    dominance components. Raw effects are standard normal draws correlated
    across traits through the Cholesky factor of `sigma`, then rescaled per
    trait so that the founder variance of the additive (resp. dominance)
    genetic values equals `var_a` (resp. `var_d`). The intercept is set so the
    founder mean of the total genetic value equals `mean`.

    Args:
        key: JAX PRNG key.
        founder_pop: Founders; `geno` has shape (n_ind, n_chr, 2, n_loci).
        n_qtl_per_chr: Average number of QTL per chromosome;
            `n_qtl_per_chr * n_chr` QTL are sampled in total.
        mean: Target founder mean per trait, shape (n_traits,).
        var_a: Target additive variance per trait, shape (n_traits,).
        var_d: Target dominance variance per trait, shape (n_traits,).
        sigma: Between-trait covariance of the raw effects,
            shape (n_traits, n_traits). It is Cholesky-factorised, so it
            presumably must be positive definite rather than merely PSD.

    Returns:
        A `Trait` with int32 QTL indices (sorted by genome position) and
        float32 effects and intercept.

    Notes:
        If the founders show no variance for a component (for example fully
        inbred founders have no heterozygotes, hence no dominance variance),
        the target variance cannot be calibrated. The raw genetic values are
        then treated as having unit variance, i.e. effects are scaled by
        `sqrt(target_variance)`, instead of being divided by the 1e-8
        regulariser, which would inflate them by roughly four orders of
        magnitude.
    """
    _, qtl_key, effect_key, dom_effect_key = jax.random.split(key, 4)

    n_chr = founder_pop.geno.shape[1]
    n_loci_per_chr = founder_pop.geno.shape[3]
    n_traits = int(mean.shape[0])
    n_total_qtl = n_qtl_per_chr * n_chr

    # --- 1. Sample QTL locations (same for additive and dominance) ---
    # NOTE(review): sampling is genome-wide, so individual chromosomes can end
    # up with more or fewer than n_qtl_per_chr QTL; confirm this is intended.
    all_loci_indices = jnp.arange(n_chr * n_loci_per_chr, dtype=jnp.int32)
    qtl_loc_flat = jax.random.choice(qtl_key, all_loci_indices, (n_total_qtl,), replace=False)
    qtl_chromosome, qtl_position = jnp.divmod(jnp.sort(qtl_loc_flat), n_loci_per_chr)

    # --- 2. Sample and correlate raw effects ---
    cholesky_factor = jnp.linalg.cholesky(sigma.astype(jnp.float32))

    # Right-multiplying by L^T gives rows with covariance L @ L^T = sigma.
    raw_add_effects = jax.random.normal(effect_key, (n_total_qtl, n_traits), dtype=jnp.float32)
    add_effects = raw_add_effects @ cholesky_factor.T

    raw_dom_effects = jax.random.normal(dom_effect_key, (n_total_qtl, n_traits), dtype=jnp.float32)
    dom_effects = raw_dom_effects @ cholesky_factor.T

    # --- 3. Compute founder dosages and scale effects ---
    founder_dosage_full = jnp.sum(founder_pop.geno, axis=2, dtype=jnp.int32)
    qtl_dosage = _flatten_gather_chr_locus(founder_dosage_full, qtl_chromosome, qtl_position).astype(jnp.float32)

    variance_floor = 1e-8

    def variance_scale(target_var: jnp.ndarray, genetic_values: jnp.ndarray) -> jnp.ndarray:
        """Per-trait factor mapping the founder variance onto `target_var`."""
        observed = jnp.var(genetic_values, axis=0)
        # Keep the original regularised ratio whenever the founder variance is
        # measurable; fall back to a unit denominator when it is not.
        denominator = jnp.where(observed > variance_floor, observed + variance_floor, 1.0)
        return jnp.sqrt(target_var.astype(jnp.float32) / denominator)

    # Scale additive effects
    add_gvs = qtl_dosage @ add_effects
    final_add_effects = add_effects * variance_scale(var_a, add_gvs)

    # Scale dominance effects
    # Dominance dosage is 1 for heterozygotes (additive dosage is 1), and 0 otherwise.
    dominance_qtl_dosage = (qtl_dosage == 1).astype(jnp.float32)
    dom_gvs = dominance_qtl_dosage @ dom_effects
    final_dom_effects = dom_effects * variance_scale(var_d, dom_gvs)

    # --- 4. Calculate intercept based on total genetic value ---
    # The intercept centers the total genetic value (A + D) in the founder population.
    total_gvs = (qtl_dosage @ final_add_effects) + (dominance_qtl_dosage @ final_dom_effects)
    intercept = mean.astype(jnp.float32) - jnp.mean(total_gvs, axis=0)

    return Trait(
        qtl_chromosome=qtl_chromosome.astype(jnp.int32),
        qtl_position=qtl_position.astype(jnp.int32),
        qtl_effects=final_add_effects,
        qtl_dominance_effects=final_dom_effects,
        intercept=intercept,
    )


In [None]:
key = jax.random.PRNGKey(42)
# quick_haplo returns a (Population, GeneticMap) pair; unpack it so that
# `founder_pop` really is a Population and can be passed on to e.g. add_trait.
founder_pop, genetic_map = quick_haplo(
    key=key,
    n_ind=10,
    n_chr=3,
    seg_sites=100,
    inbred=False,
)



In [None]:
#| hide
# Run nbdev's export step. With `#| default_exp structs` at the top of this
# notebook, the `#|export` cells are presumably written to the `structs` module.
import nbdev; nbdev.nbdev_export()