In [None]:
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
from typing import Tuple

from flax.struct import dataclass
from jax.tree_util import tree_map


# ---------------------------
# Datatypes (immutable PyTrees)
# ---------------------------

@dataclass(frozen=True)
class Population:
    geno: jnp.ndarray  # (n_ind, n_chr, 2, n_loci)
    ibd: jnp.ndarray   # (n_ind, n_chr, 2, n_loci) integer labels per hap-locus
    meta: jnp.ndarray  # (n_ind, 4) -> [id, mother_id, father_id, birth_gen] (int32)


@dataclass(frozen=True)
class Trait:
    qtl_chromosome: jnp.ndarray  # (n_qtl,) int32
    qtl_position: jnp.ndarray    # (n_qtl,) int32
    qtl_effects: jnp.ndarray     # (n_qtl, n_traits) float32
    intercept: jnp.ndarray       # (n_traits,) float32


@dataclass(frozen=True)
class GeneticMap:
    # locus_positions[c, :] is a sorted vector of positions for chromosome c
    chromosome_lengths: jnp.ndarray  # (n_chr,) float32
    locus_positions: jnp.ndarray     # (n_chr, n_loci) float32


# --------------------------------
# Initialization / Founders & Map
# --------------------------------

def quick_haplo(
    key: jax.Array,
    n_ind: int,
    n_chr: int,
    seg_sites: int,
    inbred: bool = False,
    chr_length: float = 1.0,
) -> Tuple[Population, GeneticMap]:
    """Create a simple founder population and a uniform genetic map."""
    if inbred:
        haplotypes = jax.random.randint(
            key, (n_ind, n_chr, 1, seg_sites), 0, 2, dtype=jnp.int8
        )
        geno = jnp.broadcast_to(haplotypes, (n_ind, n_chr, 2, seg_sites))
    else:
        geno = jax.random.randint(
            key, (n_ind, n_chr, 2, seg_sites), 0, 2, dtype=jnp.int8
        )

    # IBD labels: unique per haplotype across individuals
    ibd = jnp.broadcast_to(
        jnp.arange(n_ind * 2, dtype=jnp.int32).reshape(n_ind, 1, 2, 1),
        (n_ind, n_chr, 2, seg_sites),
    )

    # Meta (int32 all the way, stack doesn't accept dtype= kwarg)
    meta = jnp.stack(
        [
            jnp.arange(n_ind, dtype=jnp.int32),               # id
            jnp.full((n_ind,), -1, dtype=jnp.int32),          # mother_id
            jnp.full((n_ind,), -1, dtype=jnp.int32),          # father_id
            jnp.zeros((n_ind,), dtype=jnp.int32),             # birth_gen
        ],
        axis=-1,
    )

    population = Population(geno=geno, ibd=ibd, meta=meta)

    # Genetic map: same number of loci per chromosome, uniform spacing
    chromosome_lengths = jnp.full((n_chr,), chr_length, dtype=jnp.float32)
    locus_positions = jnp.stack(
        [jnp.linspace(0.0, chr_length, seg_sites, dtype=jnp.float32) for _ in range(n_chr)],
        axis=0,
    )
    genetic_map = GeneticMap(
        chromosome_lengths=chromosome_lengths, locus_positions=locus_positions
    )
    return population, genetic_map


# ----------------------------
# Trait architecture utilities
# ----------------------------

def _flatten_gather_chr_locus(dosage_chr_locus: jnp.ndarray,
                              chr_idx: jnp.ndarray,
                              locus_idx: jnp.ndarray) -> jnp.ndarray:
    """
    dosage_chr_locus: (n_ind, n_chr, n_loci)
    chr_idx, locus_idx: (n_qtl,)
    returns: (n_ind, n_qtl) gathered pairwise along (chr,locus).
    """
    n_chr, n_loci = dosage_chr_locus.shape[1], dosage_chr_locus.shape[2]
    flat = dosage_chr_locus.reshape(dosage_chr_locus.shape[0], n_chr * n_loci)
    flat_ids = chr_idx * n_loci + locus_idx
    return jnp.take(flat, flat_ids, axis=1)


def add_trait(
    key: jax.Array,
    founder_pop: Population,
    n_qtl_per_chr: int,
    mean: jnp.ndarray,   # (n_traits,)
    var: jnp.ndarray,    # (n_traits,)
    sigma: jnp.ndarray,  # (n_traits, n_traits) PSD
) -> Trait:
    """Sample QTLs and multi-trait effects with covariance `sigma`."""
    key, qtl_key, effect_key = jax.random.split(key, 3)

    n_chr = founder_pop.geno.shape[1]
    n_loci_per_chr = founder_pop.geno.shape[3]

    n_total_qtl = n_qtl_per_chr * n_chr
    all_loci_indices = jnp.arange(n_chr * n_loci_per_chr, dtype=jnp.int32)
    qtl_loc_flat = jax.random.choice(
        qtl_key, all_loci_indices, (n_total_qtl,), replace=False
    )
    qtl_chromosome, qtl_position = jnp.divmod(jnp.sort(qtl_loc_flat), n_loci_per_chr)

    n_traits = int(mean.shape[0])
    raw_effects = jax.random.normal(effect_key, (n_total_qtl, n_traits), dtype=jnp.float32)
    cholesky_factor = jnp.linalg.cholesky(sigma.astype(jnp.float32))  # (T, T)
    effects = raw_effects @ cholesky_factor.T  # (n_qtl, T)

    # Use founders to scale effects to desired trait variances
    founder_dosage = jnp.sum(founder_pop.geno, axis=2, dtype=jnp.int32)  # (n, chr, loci)
    qtl_dosage = _flatten_gather_chr_locus(founder_dosage, qtl_chromosome, qtl_position).astype(jnp.float32)
    gvs = qtl_dosage @ effects  # (n, T)

    # Scale to hit target marginal variances
    scale = jnp.sqrt(var.astype(jnp.float32) / (jnp.var(gvs, axis=0) + 1e-8))
    intercept = mean.astype(jnp.float32) - jnp.mean(gvs, axis=0) * scale
    final_effects = effects * scale  # (n_qtl, T)

    return Trait(
        qtl_chromosome=qtl_chromosome.astype(jnp.int32),
        qtl_position=qtl_position.astype(jnp.int32),
        qtl_effects=final_effects,
        intercept=intercept,
    )


# ----------------------
# Phenotype computation
# ----------------------

@jax.jit
def calculate_phenotypes(
    key: jax.Array,
    population: Population,
    trait: Trait,
    heritability: jnp.ndarray,  # (n_traits,)
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return (phenotypes, true_breeding_values)."""
    # Dosage per locus (0-2) as float
    dosage = jnp.sum(population.geno, axis=2, dtype=jnp.int32)  # (n, chr, loci)
    qtl_dosage = _flatten_gather_chr_locus(
        dosage, trait.qtl_chromosome, trait.qtl_position
    ).astype(jnp.float32)  # (n, n_qtl)

    tbv = trait.intercept + qtl_dosage @ trait.qtl_effects  # (n, n_traits)
    add_var = jnp.var(tbv, axis=0)  # (n_traits,)

    h2 = jnp.clip(heritability.astype(jnp.float32), 1e-8, 1.0 - 1e-8)
    env_var = jnp.maximum(0.0, add_var * ((1.0 / h2) - 1.0))
    noise = jax.random.normal(key, tbv.shape)
    phenotypes = tbv + noise * jnp.sqrt(env_var)

    return phenotypes, tbv


# ----------------------
# Meiosis / Crossing
# ----------------------

@partial(jax.jit, static_argnames=("max_crossovers",))
def create_gamete(
    key: jax.Array,
    parent_geno: jax.Array,  # (n_chr, 2, n_loci)
    parent_ibd: jax.Array,   # (n_chr, 2, n_loci)
    genetic_map: GeneticMap,
    interference_nu: float,
    max_crossovers: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return a haploid gamete (geno, ibd) of shape (n_chr, n_loci)."""

    def _sample_chiasmata(k: jax.Array, map_length: jnp.ndarray, nu: float, L: int):
        """Renewal process with gamma-distributed inter-chiasma distances."""
        shape, scale = nu, 1.0 / (2.0 * nu)

        def body(carry, _):
            kk, last_pos = carry
            kk, sub = jax.random.split(kk)
            step = jax.random.gamma(sub, shape) * scale
            pos = last_pos + step
            return (kk, pos), pos

        k, start_k = jax.random.split(k)
        start = jax.random.uniform(start_k, (), minval=-10.0, maxval=0.0)

        (_, _), positions = lax.scan(body, (k, start), None, length=L)
        valid = (positions > 0.0) & (positions < map_length)
        sentinel = map_length + 1.0
        # Keep shape static, sort to ensure monotonic positions
        positions = jnp.where(valid, positions, sentinel)
        positions = jnp.sort(positions)
        return positions  # (L,)

    def _recombine_chr(
        k: jax.Array,
        chr_geno: jnp.ndarray,   # (2, n_loci)
        chr_ibd: jnp.ndarray,    # (2, n_loci)
        loci_pos: jnp.ndarray,   # (n_loci,)
        nu: float,
        L: int,
    ):
        k_chiasma, k_hap = jax.random.split(k)
        cross_pos = _sample_chiasmata(k_chiasma, loci_pos[-1], nu, L)         # (L,)
        cross_idx = jnp.searchsorted(loci_pos, cross_pos, side="left")        # (L,)
        # segments = number of crossovers <= each locus index
        segments = jnp.searchsorted(cross_idx, jnp.arange(loci_pos.shape[0]), side="right")
        start_hap = jax.random.randint(k_hap, (), 0, 2, dtype=jnp.int32)
        hap_choice = (start_hap + segments) & 1  # 0/1 per locus

        gam_geno = jnp.where(hap_choice == 0, chr_geno[0], chr_geno[1])  # (n_loci,)
        gam_ibd  = jnp.where(hap_choice == 0, chr_ibd[0],  chr_ibd[1])   # (n_loci,)
        return gam_geno, gam_ibd

    n_chr = parent_geno.shape[0]
    chr_keys = jax.random.split(key, num=n_chr)
    vmapped = jax.vmap(_recombine_chr, in_axes=(0, 0, 0, 0, None, None))
    gamete_geno, gamete_ibd = vmapped(
        chr_keys, parent_geno, parent_ibd, genetic_map.locus_positions, interference_nu, max_crossovers
    )
    return gamete_geno, gamete_ibd  # (n_chr, n_loci), (n_chr, n_loci)


@partial(jax.jit, static_argnames=("max_crossovers",))
def cross_pair(
    key: jax.Array,
    mother_geno: jax.Array, father_geno: jax.Array,   # (n_chr, 2, n_loci)
    mother_ibd: jax.Array,  father_ibd: jax.Array,    # (n_chr, 2, n_loci)
    genetic_map: GeneticMap,
    max_crossovers: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    km, kf = jax.random.split(key)
    m_geno, m_ibd = create_gamete(km, mother_geno, mother_ibd, genetic_map, interference_nu=4.0, max_crossovers=max_crossovers)
    f_geno, f_ibd = create_gamete(kf, father_geno, father_ibd, genetic_map, interference_nu=4.0, max_crossovers=max_crossovers)
    offspring_geno = jnp.stack([m_geno, f_geno], axis=1)  # (n_chr, 2, n_loci)
    offspring_ibd  = jnp.stack([m_ibd,  f_ibd],  axis=1)  # (n_chr, 2, n_loci)
    return offspring_geno, offspring_ibd


# ---------------
# Breeder actions
# ---------------

@partial(jax.jit, static_argnames=("k",))
def select_top_k(population: Population, values: jnp.ndarray, k: int) -> Population:
    """Select top-k individuals by `values` (1-D)."""
    vals = jnp.ravel(values).astype(jnp.float32)  # (n_ind,)
    _, top_idx = lax.top_k(vals, k)               # (k,)
    return tree_map(lambda x: x[top_idx], population)


@partial(jax.jit, static_argnames=("n_crosses",))
def random_mating(key: jax.Array, n_parents: int, n_crosses: int) -> jnp.ndarray:
    """Sample (mother, father) pairs with replacement.
    Works even when `n_parents` is not a Python literal.
    """
    k_m, k_f = jax.random.split(key)
    mothers = jax.random.randint(k_m, (n_crosses,), minval=0, maxval=n_parents, dtype=jnp.int32)
    fathers = jax.random.randint(k_f, (n_crosses,), minval=0, maxval=n_parents, dtype=jnp.int32)
    return jnp.stack([mothers, fathers], axis=-1)

# -------------------------
# Demo / One-generation run
# -------------------------

def main():
    print("--- Running Full ChewC Breeding Cycle ---")

    n_individuals, n_chromosomes, n_loci_per_chr = 2000, 10, 1000
    seed, selection_intensity, n_offspring = 42, 0.1, 2000
    max_crossovers = 20

    key = jax.random.PRNGKey(seed)
    pop_key, trait_key, g0_pheno_key, mating_key, cross_key, g1_pheno_key = jax.random.split(key, 6)

    print("\nStep 1: Generating founder population (G0)...")
    founder_pop, genetic_map = quick_haplo(
        key=pop_key,
        n_ind=n_individuals,
        n_chr=n_chromosomes,
        seg_sites=n_loci_per_chr,
        inbred=False,
        chr_length=1.0,
    )

    print("\nStep 2: Generating a two-trait architecture...")
    cov = -0.3 * jnp.sqrt(jnp.array(20.0, dtype=jnp.float32))  # -0.3 * sqrt(10*2)
    sigma = jnp.array([[10.0, cov], [cov, 2.0]], dtype=jnp.float32)
    trait_architecture = add_trait(
        key=trait_key,
        founder_pop=founder_pop,
        n_qtl_per_chr=50,
        mean=jnp.array([100.0, 50.0], dtype=jnp.float32),
        var=jnp.array([10.0, 2.0], dtype=jnp.float32),
        sigma=sigma,
    )

    print("\nStep 3: Calculating phenotypes for G0...")
    heritabilities = jnp.array([0.4, 0.7], dtype=jnp.float32)
    g0_phenotypes, g0_tbvs = calculate_phenotypes(
        key=g0_pheno_key, population=founder_pop, trait=trait_architecture, heritability=heritabilities
    )
    mean_g0_tbv = float(jnp.mean(g0_tbvs[:, 0]))
    print(f"  - Mean TBV of Trait 1 in G0: {mean_g0_tbv:.3f}")

    print("\nStep 4: Performing selection and random mating...")
    n_select = int(n_individuals * selection_intensity)
    selected_parents = select_top_k(founder_pop, g0_phenotypes[:, 0], k=n_select)
    pairings = random_mating(mating_key, n_parents=n_select, n_crosses=n_offspring)
    print(f"  - Selected top {n_select} and created {n_offspring} pairs.")

    print("\nStep 5: Creating the G1 population...")
    mother_indices, father_indices = pairings[:, 0], pairings[:, 1]
    mothers_geno, fathers_geno = selected_parents.geno[mother_indices], selected_parents.geno[father_indices]
    mothers_ibd, fathers_ibd = selected_parents.ibd[mother_indices], selected_parents.ibd[father_indices]

    vmapped_cross = jax.vmap(cross_pair, in_axes=(0, 0, 0, 0, 0, None, None))
    offspring_keys = jax.random.split(cross_key, n_offspring)

    g1_geno, g1_ibd = vmapped_cross(
        offspring_keys,
        mothers_geno,
        fathers_geno,
        mothers_ibd,
        fathers_ibd,
        genetic_map,
        max_crossovers,
    )
    print(f"  - G1 genotypes created. Shape: {tuple(g1_geno.shape)}")

    g1_meta = jnp.stack(
        [
            jnp.arange(n_offspring, dtype=jnp.int32) + n_individuals,
            selected_parents.meta[mother_indices, 0],
            selected_parents.meta[father_indices, 0],
            jnp.ones((n_offspring,), dtype=jnp.int32),
        ],
        axis=-1,
    )
    g1_population = Population(geno=g1_geno, ibd=g1_ibd, meta=g1_meta)

    print("\nStep 6: Analyzing G1 and calculating genetic gain...")
    _, g1_tbvs = calculate_phenotypes(
        key=g1_pheno_key, population=g1_population, trait=trait_architecture, heritability=heritabilities
    )
    mean_g1_tbv = float(jnp.mean(g1_tbvs[:, 0]))
    print(f"  - Mean TBV of Trait 1 in G1: {mean_g1_tbv:.3f}")

    genetic_gain = mean_g1_tbv - mean_g0_tbv
    print(f"\n--- Result: Genetic Gain for Trait 1 after one generation: {genetic_gain:+.3f} ---")


if __name__ == "__main__":
    main()

In [None]:
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
from typing import Tuple

from flax.struct import dataclass
from jax.tree_util import tree_map



# --------------------------------
# Initialization / Founders & Map
# --------------------------------

# ----------------------------
# Trait architecture utilities
# ----------------------------


# ----------------------
# Phenotype computation
# ----------------------


# ----------------------
# Meiosis / Crossing
# ----------------------




# ---------------
# Breeder actions
# ---------------



# -------------------------
# Demo / One-generation run
# -------------------------

def main():
    print("--- Running Full ChewC Breeding Cycle ---")

    n_individuals, n_chromosomes, n_loci_per_chr = 2000, 10, 1000
    seed, selection_intensity, n_offspring = 42, 0.1, 2000
    max_crossovers = 20

    key = jax.random.PRNGKey(seed)
    pop_key, trait_key, g0_pheno_key, mating_key, cross_key, g1_pheno_key = jax.random.split(key, 6)

    print("\nStep 1: Generating founder population (G0)...")
    founder_pop, genetic_map = quick_haplo(
        key=pop_key,
        n_ind=n_individuals,
        n_chr=n_chromosomes,
        seg_sites=n_loci_per_chr,
        inbred=False,
        chr_length=1.0,
    )

    print("\nStep 2: Generating a two-trait architecture...")
    cov = -0.3 * jnp.sqrt(jnp.array(20.0, dtype=jnp.float32))  # -0.3 * sqrt(10*2)
    sigma = jnp.array([[10.0, cov], [cov, 2.0]], dtype=jnp.float32)
    trait_architecture = add_trait(
        key=trait_key,
        founder_pop=founder_pop,
        n_qtl_per_chr=50,
        mean=jnp.array([100.0, 50.0], dtype=jnp.float32),
        var=jnp.array([10.0, 2.0], dtype=jnp.float32),
        sigma=sigma,
    )

    print("\nStep 3: Calculating phenotypes for G0...")
    heritabilities = jnp.array([0.4, 0.7], dtype=jnp.float32)
    g0_phenotypes, g0_tbvs = calculate_phenotypes(
        key=g0_pheno_key, population=founder_pop, trait=trait_architecture, heritability=heritabilities
    )
    mean_g0_tbv = float(jnp.mean(g0_tbvs[:, 0]))
    print(f"  - Mean TBV of Trait 1 in G0: {mean_g0_tbv:.3f}")

    print("\nStep 4: Performing selection and random mating...")
    n_select = int(n_individuals * selection_intensity)
    selected_parents = select_top_k(founder_pop, g0_phenotypes[:, 0], k=n_select)
    pairings = random_mating(mating_key, n_parents=n_select, n_crosses=n_offspring)
    print(f"  - Selected top {n_select} and created {n_offspring} pairs.")

    print("\nStep 5: Creating the G1 population...")
    mother_indices, father_indices = pairings[:, 0], pairings[:, 1]
    mothers_geno, fathers_geno = selected_parents.geno[mother_indices], selected_parents.geno[father_indices]
    mothers_ibd, fathers_ibd = selected_parents.ibd[mother_indices], selected_parents.ibd[father_indices]

    vmapped_cross = jax.vmap(cross_pair, in_axes=(0, 0, 0, 0, 0, None, None))
    offspring_keys = jax.random.split(cross_key, n_offspring)

    g1_geno, g1_ibd = vmapped_cross(
        offspring_keys,
        mothers_geno,
        fathers_geno,
        mothers_ibd,
        fathers_ibd,
        genetic_map,
        max_crossovers,
    )
    print(f"  - G1 genotypes created. Shape: {tuple(g1_geno.shape)}")

    g1_meta = jnp.stack(
        [
            jnp.arange(n_offspring, dtype=jnp.int32) + n_individuals,
            selected_parents.meta[mother_indices, 0],
            selected_parents.meta[father_indices, 0],
            jnp.ones((n_offspring,), dtype=jnp.int32),
        ],
        axis=-1,
    )
    g1_population = Population(geno=g1_geno, ibd=g1_ibd, meta=g1_meta)

    print("\nStep 6: Analyzing G1 and calculating genetic gain...")
    _, g1_tbvs = calculate_phenotypes(
        key=g1_pheno_key, population=g1_population, trait=trait_architecture, heritability=heritabilities
    )
    mean_g1_tbv = float(jnp.mean(g1_tbvs[:, 0]))
    print(f"  - Mean TBV of Trait 1 in G1: {mean_g1_tbv:.3f}")

    genetic_gain = mean_g1_tbv - mean_g0_tbv
    print(f"\n--- Result: Genetic Gain for Trait 1 after one generation: {genetic_gain:+.3f} ---")


if __name__ == "__main__":
    main()

In [1]:
from chewc.structs import *
from chewc.pheno import *
from chewc.select import *
from chewc.cross import *
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
from typing import Tuple

from flax.struct import dataclass
from jax.tree_util import tree_map



@partial(jax.jit, static_argnames=("n_select", "n_offspring", "max_crossovers"))
def selection_step(
    carry: BreedingState,
    _, # Placeholder for lax.scan's iteration number
    trait: Trait,
    genetic_map: GeneticMap,
    heritabilities: jnp.ndarray,
    n_select: int,
    n_offspring: int,
    max_crossovers: int
) -> Tuple[BreedingState, jnp.ndarray]:
    """
    Executes one full cycle of selection and breeding.
    This function is designed to be the body of a lax.scan loop.
    """
    key, pheno_key, mating_key, cross_key = jax.random.split(carry.key, 4)
    current_pop = carry.population

    # 1. Evaluate the population
    phenotypes, tbvs = calculate_phenotypes(
        key=pheno_key, population=current_pop, trait=trait, heritability=heritabilities
    )
    mean_tbv = jnp.mean(tbvs[:, 0]) # Track genetic gain for trait 1

    # 2. Select top parents
    selected_parents = select_top_k(current_pop, phenotypes[:, 0], k=n_select)

    # 3. Generate a random mating plan
    pairings = random_mating(mating_key, n_parents=n_select, n_crosses=n_offspring)
    mother_indices, father_indices = pairings[:, 0], pairings[:, 1]

    # 4. Create the next generation
    mothers_geno = selected_parents.geno[mother_indices]
    fathers_geno = selected_parents.geno[father_indices]
    mothers_ibd = selected_parents.ibd[mother_indices]
    fathers_ibd = selected_parents.ibd[father_indices]

    # Vectorize the crossing operation over all pairs
    vmapped_cross = jax.vmap(
        cross_pair, in_axes=(0, 0, 0, 0, 0, None, None)
    )
    offspring_keys = jax.random.split(cross_key, n_offspring)
    offspring_geno, offspring_ibd = vmapped_cross(
        offspring_keys, mothers_geno, fathers_geno, mothers_ibd, fathers_ibd, genetic_map, max_crossovers
    )

    # 5. Form the new population and update the state
    new_generation = carry.generation + 1
    new_ids = jnp.arange(n_offspring, dtype=jnp.int32) + carry.next_id
    new_meta = jnp.stack(
        [
            new_ids,
            selected_parents.meta[mother_indices, 0],
            selected_parents.meta[father_indices, 0],
            jnp.full((n_offspring,), new_generation, dtype=jnp.int32),
        ],
        axis=-1,
    )
    new_population = Population(geno=offspring_geno, ibd=offspring_ibd, meta=new_meta)

    next_state = BreedingState(
        population=new_population,
        key=key,
        generation=new_generation,
        next_id=carry.next_id + n_offspring
    )

    # The scan function requires a `(carry, output)` return signature.
    # `carry` is the state for the next iteration.
    # `output` is the data we want to collect at each step.
    return next_state, mean_tbv


def run_simulation_cycles(
    initial_state: BreedingState,
    trait: Trait,
    genetic_map: GeneticMap,
    heritabilities: jnp.ndarray,
    n_cycles: int,
    n_select: int,
    n_offspring: int,
    max_crossovers: int
):
    """
    Runs the entire multi-cycle simulation using lax.scan for maximum performance.
    """
    # `lax.scan` requires the body function to be a callable that takes
    # (carry, x) and returns (new_carry, y). We use a lambda to fix the
    # static arguments and data that doesn't change over the loop.
    scan_fn = lambda carry, _: selection_step(
        carry,
        _,
        trait=trait,
        genetic_map=genetic_map,
        heritabilities=heritabilities,
        n_select=n_select,
        n_offspring=n_offspring,
        max_crossovers=max_crossovers
    )

    # Run the scan. The `None` is a placeholder for the `xs` array,
    # as we only care about the number of iterations (`length`).
    final_state, tbv_history = lax.scan(
        scan_fn, initial_state, None, length=n_cycles
    )

    return final_state, tbv_history


# ----------------------------------------------------
# Main execution script
# ----------------------------------------------------

if __name__ == "__main__":
    print("--- Running 10-Cycle Phenotypic Selection Experiment ---")

    # --- Simulation Parameters ---
    N_FOUNDERS = 100
    N_SELECT = 100
    N_OFFSPRING = 100 # Population size is kept constant
    N_CYCLES = 50

    N_CHR, N_LOCI = 5, 1000
    MAX_CROSSOVERS = 10
    SEED = 42

    key = jax.random.PRNGKey(SEED)

    # --- Setup ---
    key, pop_key, trait_key = jax.random.split(key, 3)
    print("\nStep 1: Initializing founder population...")
    founder_pop, genetic_map = quick_haplo(
        key=pop_key, n_ind=N_FOUNDERS, n_chr=N_CHR, seg_sites=N_LOCI,
        inbred=False, chr_length=1.0
    )

    print("Step 2: Initializing trait architecture...")
    trait_architecture = add_trait(
        key=trait_key, founder_pop=founder_pop, n_qtl_per_chr=50,
        mean=jnp.array([0.0]), var=jnp.array([1.0]),
        sigma=jnp.array([[10.0]])
    )
    heritabilities = jnp.array([0.4])

    # --- Initial State ---
    initial_state = BreedingState(
        population=founder_pop,
        key=key,
        generation=0,
        next_id=N_FOUNDERS # Next available individual ID
    )
    print(f"  - Founder population size: {N_FOUNDERS}")
    print(f"  - Selection: Top {N_SELECT} individuals")
    print(f"  - Offspring per cycle: {N_OFFSPRING}")


    # --- Run Simulation ---
    print(f"\nStep 3: Running {N_CYCLES} selection cycles (JIT compiling...)\n")
    final_state, tbv_history = run_simulation_cycles(
        initial_state=initial_state,
        trait=trait_architecture,
        genetic_map=genetic_map,
        heritabilities=heritabilities,
        n_cycles=N_CYCLES,
        n_select=N_SELECT,
        n_offspring=N_OFFSPRING,
        max_crossovers=MAX_CROSSOVERS,
    )

    # We need to calculate the TBV of the initial generation to show the full trend
    key, initial_pheno_key = jax.random.split(initial_state.key)
    _, initial_tbvs = calculate_phenotypes(
        key=initial_pheno_key, population=initial_state.population,
        trait=trait_architecture, heritability=heritabilities
    )
    initial_mean_tbv = jnp.mean(initial_tbvs[:, 0])

    # --- Report Results ---
    print("--- Results ---")
    print(f"Generation 00 Mean TBV: {initial_mean_tbv:.3f}")

    # Combine initial TBV with the history from the scan
    full_history = jnp.concatenate([jnp.array([initial_mean_tbv]), tbv_history])
    for i, tbv in enumerate(tbv_history, 1):
        gain = tbv - full_history[i-1]
        print(f"Generation {i:02d} Mean TBV: {tbv:.3f} (Gain: {gain:+.3f})")

    total_gain = tbv_history[-1] - initial_mean_tbv
    print(f"\n--- Total Genetic Gain over {N_CYCLES} generations: {total_gain:+.3f} ---")



--- Running 10-Cycle Phenotypic Selection Experiment ---

Step 1: Initializing founder population...
Step 2: Initializing trait architecture...
  - Founder population size: 100
  - Selection: Top 100 individuals
  - Offspring per cycle: 100

Step 3: Running 50 selection cycles (JIT compiling...)

--- Results ---
Generation 00 Mean TBV: -0.000
Generation 01 Mean TBV: -0.000 (Gain: +0.000)
Generation 02 Mean TBV: 0.034 (Gain: +0.034)
Generation 03 Mean TBV: -0.012 (Gain: -0.047)
Generation 04 Mean TBV: 0.120 (Gain: +0.132)
Generation 05 Mean TBV: 0.175 (Gain: +0.056)
Generation 06 Mean TBV: 0.030 (Gain: -0.145)
Generation 07 Mean TBV: -0.040 (Gain: -0.070)
Generation 08 Mean TBV: -0.002 (Gain: +0.038)
Generation 09 Mean TBV: 0.016 (Gain: +0.017)
Generation 10 Mean TBV: 0.110 (Gain: +0.095)
Generation 11 Mean TBV: 0.163 (Gain: +0.053)
Generation 12 Mean TBV: 0.056 (Gain: -0.107)
Generation 13 Mean TBV: 0.159 (Gain: +0.104)
Generation 14 Mean TBV: 0.252 (Gain: +0.093)
Generation 15 Mean TB