In [5]:
#| default_exp workflow

In [6]:
#|export
import jax
import jax.numpy as jnp

# Assuming 'chewc' is installed or in the python path
from chewc.structs import quick_haplo, add_trait


In [7]:
import jax
import jax.numpy as jnp
from chewc.structs import quick_haplo, add_trait
from chewc.pheno import calculate_phenotypes

def main():
    """
    An example script to demonstrate and verify the setup of a 
    ChewC simulation scenario.
    """
    print("--- Setting up and Verifying ChewC simulation scenario ---")

    # --- 1. Define Simulation Parameters (with increased scale) ---
    n_individuals = 2000  # Increased from 50 for better statistical accuracy
    n_chromosomes = 10
    n_loci_per_chr = 1000
    seed = 42
    print(f"\nIncreased scale: Simulating {n_individuals} individuals.")

    # --- 2. Create JAX PRNG keys ---
    key = jax.random.PRNGKey(seed)
    pop_key, trait_key, pheno_key = jax.random.split(key, 3)

    # --- 3. Generate Founder Population ---
    print(f"\nStep 1: Generating founder population...")
    founder_pop, genetic_map = quick_haplo(
        key=pop_key, n_ind=n_individuals, n_chr=n_chromosomes, seg_sites=n_loci_per_chr
    )
    print(f"  - Population generated. Genotype shape: {founder_pop.geno.shape}")

    # --- 4. Generate Correlated Trait Architecture ---
    print("\nStep 2: Generating a two-trait architecture...")
    target_means = jnp.array([100.0, 50.0])
    target_vars = jnp.array([10.0, 2.0])
    target_genetic_corr = -0.3

    cov = target_genetic_corr * jnp.sqrt(target_vars[0] * target_vars[1])
    sigma = jnp.array([[target_vars[0], cov], [cov, target_vars[1]]])

    trait_architecture = add_trait(
        key=trait_key,
        founder_pop=founder_pop,
        n_qtl_per_chr=50,
        mean=target_means,
        var=target_vars,
        sigma=sigma,
    )
    print(f"  - Trait architecture created.")

    # --- 5. Calculate Phenotypes for the Founder Population ---
    print("\nStep 3: Calculating phenotypes for the founder population...")
    heritabilities = jnp.array([0.4, 0.7]) 
    print(f"  - Target heritabilities (h²): {heritabilities}")

    phenotypes, tbvs = calculate_phenotypes(
        key=pheno_key,
        population=founder_pop,
        trait=trait_architecture,
        heritability=heritabilities
    )
    print(f"  - Phenotypes and TBVs calculated.")

    # --- 6. Verification ---
    print("\n--- Verification Checks ---")
    
    # Check 1: Realized Heritability (h² = V_a / V_p)
    realized_h2 = jnp.var(tbvs, axis=0) / jnp.var(phenotypes, axis=0)
    print("\n[Check 1: Heritability]")
    print(f"  - Target h²:   {heritabilities}")
    print(f"  - Realized h²: {realized_h2.round(3)}")

    # Check 2: Genetic Correlation
    realized_genetic_corr_matrix = jnp.corrcoef(tbvs, rowvar=False)
    print("\n[Check 2: Genetic Correlation]")
    print(f"  - Target:   {target_genetic_corr:.3f}")
    print(f"  - Realized: {realized_genetic_corr_matrix[0, 1]:.3f}")

    # Check 3: Phenotypic Correlation
    # The expected phenotypic correlation is r_p = r_g * h_1 * h_2
    h_1 = jnp.sqrt(heritabilities[0])
    h_2 = jnp.sqrt(heritabilities[1])
    expected_phenotypic_corr = target_genetic_corr * h_1 * h_2
    realized_phenotypic_corr_matrix = jnp.corrcoef(phenotypes, rowvar=False)
    print("\n[Check 3: Phenotypic Correlation]")
    print(f"  - Expected: {expected_phenotypic_corr:.3f}")
    print(f"  - Realized: {realized_phenotypic_corr_matrix[0, 1]:.3f}")

    print("\n--- Verification complete! ---")


if __name__ == "__main__":
    main()


--- Setting up and Verifying ChewC simulation scenario ---

Increased scale: Simulating 2000 individuals.

Step 1: Generating founder population...
  - Population generated. Genotype shape: (2000, 10, 2, 1000)

Step 2: Generating a two-trait architecture...
  - Trait architecture created.

Step 3: Calculating phenotypes for the founder population...
  - Target heritabilities (h²): [0.4 0.7]
  - Phenotypes and TBVs calculated.

--- Verification Checks ---

[Check 1: Heritability]
  - Target h²:   [0.4 0.7]
  - Realized h²: [0.39200002 0.666     ]

[Check 2: Genetic Correlation]
  - Target:   -0.300
  - Realized: -0.329

[Check 3: Phenotypic Correlation]
  - Expected: -0.159
  - Realized: -0.138

--- Verification complete! ---


In [12]:
import jax
import jax.numpy as jnp
from functools import partial

# --- Assuming 'chewc' is in the python path ---
# Datatypes
from chewc.structs import Population, GeneticMap, Trait, quick_haplo, add_trait
# Initialization
from chewc.pheno import calculate_phenotypes
from chewc.cross import cross_pair
from jax import tree

from typing import Any

# --- New functions for this script ---
# These would normally go into `chewc/_internal/selection.py` and `mating.py`

# Define Population as a generic PyTree for the type hint
Population = Any 

@partial(jax.jit, static_argnames=('k',))
def select_top_k(population: Population, values: jnp.ndarray, k: int) -> Population:
    """Selects the top k individuals based on a value array."""
    # Get the indices of the top k individuals by sorting in descending order
    indices = jnp.argsort(values, axis=0)[::-1].squeeze()[:k]

    # Use jax.tree.map to slice all arrays in the Population PyTree
    # This is the modern and correct way to subset a PyTree in JAX
    selected_pop = jax.tree.map(lambda x: x[indices], population)
    
    return selected_pop


@jax.jit
def random_mating(key: jax.Array, n_parents: int, n_crosses: int) -> jnp.ndarray:
    """Creates a random mating plan."""
    # Randomly choose `n_crosses` mothers and `n_crosses` fathers from the parents
    mothers = jax.random.choice(key, n_parents, shape=(n_crosses,))
    fathers = jax.random.choice(key, n_parents, shape=(n_crosses,))
    return jnp.stack([mothers, fathers], axis=-1)


def main():
    """
    An example script demonstrating a full breeding cycle with ChewC.
    """
    print("--- Running Full ChewC Breeding Cycle ---")

    # --- 1. Define Simulation Parameters ---
    n_individuals = 2000
    n_chromosomes = 10
    n_loci_per_chr = 1000
    seed = 42
    selection_intensity = 0.1  # Select the top 10%
    n_offspring = n_individuals # Keep population size constant

    # --- 2. Create JAX PRNG keys ---
    key = jax.random.PRNGKey(seed)
    pop_key, trait_key, g0_pheno_key, mating_key, cross_key, g1_pheno_key = jax.random.split(key, 6)

    # --- 3. Generate Founder Population (G0) ---
    print(f"\nStep 1: Generating founder population (G0) of {n_individuals} individuals...")
    founder_pop, genetic_map = quick_haplo(
        key=pop_key, n_ind=n_individuals, n_chr=n_chromosomes, seg_sites=n_loci_per_chr
    )
    print(f"  - G0 Population generated.")

    # --- 4. Generate Trait Architecture ---
    print("\nStep 2: Generating a two-trait architecture...")
    trait_architecture = add_trait(
        key=trait_key, founder_pop=founder_pop, n_qtl_per_chr=50,
        mean=jnp.array([100.0, 50.0]), var=jnp.array([10.0, 2.0]),
        sigma=jnp.array([[10.0, -0.3 * jnp.sqrt(10*2)], [-0.3 * jnp.sqrt(10*2), 2.0]])
    )
    print(f"  - Trait architecture created.")

    # --- 5. Phenotype G0 and Calculate Founder TBV ---
    print("\nStep 3: Calculating phenotypes for G0 population...")
    heritabilities = jnp.array([0.4, 0.7])
    g0_phenotypes, g0_tbvs = calculate_phenotypes(
        key=g0_pheno_key, population=founder_pop, trait=trait_architecture, heritability=heritabilities
    )
    mean_g0_tbv = jnp.mean(g0_tbvs[:, 0])
    print(f"  - Phenotypes and TBVs calculated for G0.")
    print(f"  - Mean TBV of Trait 1 in G0: {mean_g0_tbv:.3f}")

    # --- 6. Selection and Mating ---
    print("\nStep 4: Performing selection and random mating...")
    n_select = int(n_individuals * selection_intensity)
    
    # We select based on the phenotype of the first trait
    selected_parents = select_top_k(founder_pop, g0_phenotypes[:, 0], k=n_select)
    print(f"  - Selected the top {n_select} individuals based on Trait 1 phenotype.")

    # Create a mating plan
    pairings = random_mating(mating_key, n_parents=n_select, n_crosses=n_offspring)
    print(f"  - Created {n_offspring} random mating pairs.")

    # --- 7. Create Next Generation (G1) ---
    print("\nStep 5: Creating the G1 population...")
    
    # Get the geno/ibd data for all mothers and fathers based on the mating plan
    mother_indices, father_indices = pairings[:, 0], pairings[:, 1]
    mothers_geno = selected_parents.geno[mother_indices]
    fathers_geno = selected_parents.geno[father_indices]
    mothers_ibd = selected_parents.ibd[mother_indices]
    fathers_ibd = selected_parents.ibd[father_indices]

    # The magic of JAX: vmap our JIT-compiled cross_pair function.
    # This creates all offspring in a single, massively parallel operation.
    vmapped_cross = jax.vmap(cross_pair, in_axes=(0, 0, 0, 0, 0, None))
    
    offspring_keys = jax.random.split(cross_key, n_offspring)
    g1_geno, g1_ibd = vmapped_cross(
        offspring_keys, mothers_geno, fathers_geno, mothers_ibd, fathers_ibd, genetic_map
    )
    print(f"  - G1 genotypes created. Shape: {g1_geno.shape}")

    # Assemble the new Population object
    g1_meta = jnp.stack([
        jnp.arange(n_offspring) + n_individuals, # New IDs
        selected_parents.meta[mother_indices, 0], # Mother IDs
        selected_parents.meta[father_indices, 0], # Father IDs
        jnp.ones(n_offspring)], # Birth generation
        axis=-1, dtype=jnp.int32
    )
    g1_population = Population(geno=g1_geno, ibd=g1_ibd, meta=g1_meta)
    
    # --- 8. Analyze G1 and Genetic Gain ---
    print("\nStep 6: Analyzing G1 and calculating genetic gain...")
    
    # Phenotype the new G1 population to find their true breeding values
    # Note: We only need the TBVs, so we can ignore the phenotypes `_`
    _, g1_tbvs = calculate_phenotypes(
        key=g1_pheno_key, population=g1_population, trait=trait_architecture, heritability=heritabilities
    )
    mean_g1_tbv = jnp.mean(g1_tbvs[:, 0])

    print(f"  - Mean TBV of Trait 1 in G1: {mean_g1_tbv:.3f}")
    
    genetic_gain = mean_g1_tbv - mean_g0_tbv
    print(f"\n--- Result: Genetic Gain for Trait 1 after one generation of selection: {genetic_gain:+.3f} ---")


if __name__ == "__main__":
    main()

--- Running Full ChewC Breeding Cycle ---

Step 1: Generating founder population (G0) of 2000 individuals...
  - G0 Population generated.

Step 2: Generating a two-trait architecture...
  - Trait architecture created.

Step 3: Calculating phenotypes for G0 population...
  - Phenotypes and TBVs calculated for G0.
  - Mean TBV of Trait 1 in G0: 100.000

Step 4: Performing selection and random mating...
  - Selected the top 200 individuals based on Trait 1 phenotype.


ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
The error occurred in jax.random.choice()
The error occurred while tracing the function random_mating at /tmp/ipykernel_173382/2788785294.py:34 for jit. This concrete value was not available in Python because it depends on the value of the argument n_parents.

See https://docs.jax.dev/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [None]:
import jax
import jax.numpy as jnp
from functools import partial

# --- Assuming 'chewc' is in the python path ---
# Datatypes
from chewc.structs import Population, GeneticMap, Trait, quick_haplo, add_trait
# Initialization
from chewc.pheno import calculate_phenotypes
from chewc.cross import cross_pair
from jax import tree

from typing import Any



# --- Functions for this script ---

@partial(jax.jit, static_argnames=('k',))
def select_top_k(population: Population, values: jnp.ndarray, k: int) -> Population:
    """Selects the top k individuals based on a value array."""
    indices = jnp.argsort(values, axis=0)[::-1].squeeze()[:k]
    selected_pop = jax.tree.map(lambda x: x[indices], population)
    return selected_pop

# --- THE FIX IS HERE ---
# `n_parents` must be a static argument because `jax.random.choice` needs
# to know the concrete size of the population it is sampling from at compile time.
@partial(jax.jit, static_argnames=('n_parents', 'n_crosses'))
def random_mating(key: jax.Array, n_parents: int, n_crosses: int) -> jnp.ndarray:
    """Creates a random mating plan."""
    mothers = jax.random.choice(key, n_parents, shape=(n_crosses,))
    fathers = jax.random.choice(key, n_parents, shape=(n_crosses,))
    return jnp.stack([mothers, fathers], axis=-1)


def main():
    """
    An example script demonstrating a full breeding cycle with ChewC.
    """
    print("--- Running Full ChewC Breeding Cycle ---")

    # --- 1. Define Simulation Parameters ---
    n_individuals = 2000
    n_chromosomes = 10
    n_loci_per_chr = 1000
    seed = 42
    selection_intensity = 0.1
    n_offspring = n_individuals

    # --- 2. Create JAX PRNG keys ---
    key = jax.random.PRNGKey(seed)
    pop_key, trait_key, g0_pheno_key, mating_key, cross_key, g1_pheno_key = jax.random.split(key, 6)

    # --- 3. Generate Founder Population (G0) ---
    print(f"\nStep 1: Generating founder population (G0) of {n_individuals} individuals...")
    founder_pop, genetic_map = quick_haplo(
        key=pop_key, n_ind=n_individuals, n_chr=n_chromosomes, seg_sites=n_loci_per_chr
    )
    print(f"  - G0 Population generated.")

    # --- 4. Generate Trait Architecture ---
    print("\nStep 2: Generating a two-trait architecture...")
    trait_architecture = add_trait(
        key=trait_key, founder_pop=founder_pop, n_qtl_per_chr=50,
        mean=jnp.array([100.0, 50.0]), var=jnp.array([10.0, 2.0]),
        sigma=jnp.array([[10.0, -0.3 * jnp.sqrt(10*2)], [-0.3 * jnp.sqrt(10*2), 2.0]])
    )
    print(f"  - Trait architecture created.")

    # --- 5. Phenotype G0 and Calculate Founder TBV ---
    print("\nStep 3: Calculating phenotypes for G0 population...")
    heritabilities = jnp.array([0.4, 0.7])
    g0_phenotypes, g0_tbvs = calculate_phenotypes(
        key=g0_pheno_key, population=founder_pop, trait=trait_architecture, heritability=heritabilities
    )
    mean_g0_tbv = jnp.mean(g0_tbvs[:, 0])
    print(f"  - Phenotypes and TBVs calculated for G0.")
    print(f"  - Mean TBV of Trait 1 in G0: {mean_g0_tbv:.3f}")

    # --- 6. Selection and Mating ---
    print("\nStep 4: Performing selection and random mating...")
    n_select = int(n_individuals * selection_intensity)
    selected_parents = select_top_k(founder_pop, g0_phenotypes[:, 0], k=n_select)
    print(f"  - Selected the top {n_select} individuals based on Trait 1 phenotype.")

    pairings = random_mating(mating_key, n_parents=n_select, n_crosses=n_offspring)
    print(f"  - Created {n_offspring} random mating pairs.")

    # --- 7. Create Next Generation (G1) ---
    print("\nStep 5: Creating the G1 population...")
    mother_indices, father_indices = pairings[:, 0], pairings[:, 1]
    mothers_geno = selected_parents.geno[mother_indices]
    fathers_geno = selected_parents.geno[father_indices]
    mothers_ibd = selected_parents.ibd[mother_indices]
    fathers_ibd = selected_parents.ibd[father_indices]

    vmapped_cross = jax.vmap(cross_pair, in_axes=(0, 0, 0, 0, 0, None))
    offspring_keys = jax.random.split(cross_key, n_offspring)
    g1_geno, g1_ibd = vmapped_cross(
        offspring_keys, mothers_geno, fathers_geno, mothers_ibd, fathers_ibd, genetic_map
    )
    print(f"  - G1 genotypes created. Shape: {g1_geno.shape}")

    g1_meta = jnp.stack([
        jnp.arange(n_offspring, dtype=jnp.int32) + n_individuals,
        selected_parents.meta[mother_indices, 0],
        selected_parents.meta[father_indices, 0],
        jnp.ones(n_offspring, dtype=jnp.int32)],
        axis=-1
    )
    g1_population = Population(geno=g1_geno, ibd=g1_ibd, meta=g1_meta)
    
    # --- 8. Analyze G1 and Genetic Gain ---
    print("\nStep 6: Analyzing G1 and calculating genetic gain...")
    _, g1_tbvs = calculate_phenotypes(
        key=g1_pheno_key, population=g1_population, trait=trait_architecture, heritability=heritabilities
    )
    mean_g1_tbv = jnp.mean(g1_tbvs[:, 0])

    print(f"  - Mean TBV of Trait 1 in G1: {mean_g1_tbv:.3f}")
    
    genetic_gain = mean_g1_tbv - mean_g0_tbv
    print(f"\n--- Result: Genetic Gain for Trait 1 after one generation of selection: {genetic_gain:+.3f} ---")


if __name__ == "__main__":
    main()

--- Running Full ChewC Breeding Cycle ---

Step 1: Generating founder population (G0) of 2000 individuals...
  - G0 Population generated.

Step 2: Generating a two-trait architecture...
  - Trait architecture created.

Step 3: Calculating phenotypes for G0 population...
  - Phenotypes and TBVs calculated for G0.
  - Mean TBV of Trait 1 in G0: 100.000

Step 4: Performing selection and random mating...
  - Selected the top 200 individuals based on Trait 1 phenotype.
  - Created 2000 random mating pairs.

Step 5: Creating the G1 population...


ValueError: vmap in_axes must be an int, None, or a tuple of entries corresponding to the positional arguments passed to the function, but got len(in_axes)=6, len(args)=5

In [19]:
# For this example, we put all the necessary functions directly in the script.
# In the real library, these would be in their respective modules.
from typing import Tuple

# --- From _internal/phenotype.py ---
@jax.jit
def compute_dosage(population: Population) -> jnp.ndarray:
    return jnp.sum(population.geno, axis=2)

@jax.jit
def calculate_phenotypes(key: jax.Array, population: Population, trait: Trait, heritability: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    dosage = compute_dosage(population)
    qtl_dosage = dosage[:, trait.qtl_chromosome, trait.qtl_position]
    true_breeding_values = trait.intercept + qtl_dosage @ trait.qtl_effects
    additive_variance = jnp.var(true_breeding_values, axis=0)
    h2 = jnp.clip(heritability, a_min=1e-8, a_max=1.0 - 1e-8)
    environmental_variance = additive_variance * ((1 / h2) - 1)
    environmental_variance = jnp.maximum(0.0, environmental_variance)
    noise = jax.random.normal(key, true_breeding_values.shape)
    environmental_effects = noise * jnp.sqrt(environmental_variance)
    phenotypes = true_breeding_values + environmental_effects
    return phenotypes, true_breeding_values

# --- From _internal/meiosis.py ---
@partial(jax.jit, static_argnames=("max_crossovers",))
def _sample_chiasmata(key: jax.Array, map_length: float, interference_nu: float, max_crossovers: int) -> jax.Array:
    shape, scale = interference_nu, 1.0 / (2.0 * interference_nu)
    def scan_body(carry, _):
        key, last_pos = carry
        key, subkey = jax.random.split(key)
        new_pos = last_pos + jax.random.gamma(subkey, shape) * scale
        return (key, new_pos), new_pos
    key, initial_key = jax.random.split(key)
    initial_start_pos = jax.random.uniform(initial_key, minval=-10.0, maxval=0.0)
    _, crossover_positions = lax.scan(scan_body, (key, initial_start_pos), None, length=max_crossovers)
    valid_mask = (crossover_positions > 0) & (crossover_positions < map_length)
    return jnp.where(valid_mask, crossover_positions, jnp.nan)

@partial(jax.jit, static_argnames=("max_crossovers",))
def _create_recombinant_chromosome(key: jax.Array, parent_chr_geno: jax.Array, parent_chr_ibd: jax.Array, locus_positions: jax.Array, interference_nu: float, max_crossovers: int) -> Tuple[jax.Array, jax.Array]:
    key, chiasma_key, hap_key = jax.random.split(key, 3)
    crossover_positions = _sample_chiasmata(chiasma_key, locus_positions[-1], interference_nu, max_crossovers)
    crossover_indices = jnp.searchsorted(locus_positions, crossover_positions)
    start_hap = jax.random.randint(hap_key, (), 0, 2)
    locus_segments = jnp.searchsorted(crossover_indices, jnp.arange(locus_positions.shape[0]), side='right')
    haplotype_choice = (start_hap + locus_segments) % 2
    gamete_geno = jnp.where(haplotype_choice == 0, parent_chr_geno[0], parent_chr_geno[1])
    gamete_ibd = jnp.where(haplotype_choice == 0, parent_chr_ibd[0], parent_chr_ibd[1])
    return gamete_geno, gamete_ibd

@partial(jax.jit, static_argnames=('max_crossovers',))
def create_gamete(key: jax.Array, parent_geno: jax.Array, parent_ibd: jax.Array, genetic_map: GeneticMap, interference_nu: float, max_crossovers: int) -> Tuple[jax.Array, jax.Array]:
    chr_keys = jax.random.split(key, num=parent_geno.shape[0])
    vmapped_recombine = jax.vmap(_create_recombinant_chromosome, in_axes=(0, 0, 0, 0, None, None))
    # --- THE FIX IS HERE: Pass `max_crossovers` to the vmapped function ---
    gamete_geno, gamete_ibd = vmapped_recombine(chr_keys, parent_geno, parent_ibd, genetic_map.locus_positions, interference_nu, max_crossovers)
    return gamete_geno, gamete_ibd

# --- From _internal/crossing.py ---
@partial(jax.jit, static_argnames=('max_crossovers',))
def cross_pair(key: jax.Array, mother_geno: jax.Array, father_geno: jax.Array, mother_ibd: jax.Array, father_ibd: jax.Array, genetic_map: GeneticMap, max_crossovers: int) -> Tuple[jax.Array, jax.Array]:
    key_mother, key_father = jax.random.split(key)
    mother_gamete_geno, mother_gamete_ibd = create_gamete(key_mother, mother_geno, mother_ibd, genetic_map, interference_nu=4.0, max_crossovers=max_crossovers)
    father_gamete_geno, father_gamete_ibd = create_gamete(key_father, father_geno, father_ibd, genetic_map, interference_nu=4.0, max_crossovers=max_crossovers)
    offspring_geno = jnp.stack([mother_gamete_geno, father_gamete_geno], axis=1)
    offspring_ibd = jnp.stack([mother_gamete_ibd, father_gamete_ibd], axis=1)
    return offspring_geno, offspring_ibd

# --- From `selection.py` and `mating.py` ---
@partial(jax.jit, static_argnames=('k',))
def select_top_k(population: Population, values: jnp.ndarray, k: int) -> Population:
    indices = jnp.argsort(values, axis=0)[::-1].squeeze()[:k]
    return jax.tree.map(lambda x: x[indices], population)

@partial(jax.jit, static_argnames=('n_parents', 'n_crosses'))
def random_mating(key: jax.Array, n_parents: int, n_crosses: int) -> jnp.ndarray:
    mothers = jax.random.choice(key, n_parents, shape=(n_crosses,))
    fathers = jax.random.choice(key, n_parents, shape=(n_crosses,))
    return jnp.stack([mothers, fathers], axis=-1)

# --- Main simulation script ---
def main():
    print("--- Running Full ChewC Breeding Cycle ---")

    # Parameters
    n_individuals, n_chromosomes, n_loci_per_chr = 2000, 10, 1000
    seed, selection_intensity, n_offspring = 42, 0.1, n_individuals
    max_crossovers = 20 # Static parameter for meiosis

    key = jax.random.PRNGKey(seed)
    pop_key, trait_key, g0_pheno_key, mating_key, cross_key, g1_pheno_key = jax.random.split(key, 6)

    print(f"\nStep 1: Generating founder population (G0)...")
    founder_pop, genetic_map = quick_haplo(key=pop_key, n_ind=n_individuals, n_chr=n_chromosomes, seg_sites=n_loci_per_chr)
    
    print("\nStep 2: Generating a two-trait architecture...")
    trait_architecture = add_trait(
        key=trait_key, founder_pop=founder_pop, n_qtl_per_chr=50,
        mean=jnp.array([100.0, 50.0]), var=jnp.array([10.0, 2.0]),
        sigma=jnp.array([[10.0, -0.3 * jnp.sqrt(10*2)], [-0.3 * jnp.sqrt(10*2), 2.0]])
    )
    
    print("\nStep 3: Calculating phenotypes for G0...")
    heritabilities = jnp.array([0.4, 0.7])
    g0_phenotypes, g0_tbvs = calculate_phenotypes(key=g0_pheno_key, population=founder_pop, trait=trait_architecture, heritability=heritabilities)
    mean_g0_tbv = jnp.mean(g0_tbvs[:, 0])
    print(f"  - Mean TBV of Trait 1 in G0: {mean_g0_tbv:.3f}")

    print("\nStep 4: Performing selection and random mating...")
    n_select = int(n_individuals * selection_intensity)
    selected_parents = select_top_k(founder_pop, g0_phenotypes[:, 0], k=n_select)
    pairings = random_mating(mating_key, n_parents=n_select, n_crosses=n_offspring)
    print(f"  - Selected top {n_select} and created {n_offspring} pairs.")

    print("\nStep 5: Creating the G1 population...")
    mother_indices, father_indices = pairings[:, 0], pairings[:, 1]
    mothers_geno, fathers_geno = selected_parents.geno[mother_indices], selected_parents.geno[father_indices]
    mothers_ibd, fathers_ibd = selected_parents.ibd[mother_indices], selected_parents.ibd[father_indices]

    # vmap `cross_pair`. Note that `max_crossovers` is static and doesn't need an in_axes entry.
    vmapped_cross = jax.vmap(cross_pair, in_axes=(0, 0, 0, 0, 0, None, None))
    offspring_keys = jax.random.split(cross_key, n_offspring)
    
    g1_geno, g1_ibd = vmapped_cross(offspring_keys, mothers_geno, fathers_geno, mothers_ibd, fathers_ibd, genetic_map, max_crossovers)
    print(f"  - G1 genotypes created. Shape: {g1_geno.shape}")

    g1_meta = jnp.stack([
        jnp.arange(n_offspring, dtype=jnp.int32) + n_individuals,
        selected_parents.meta[mother_indices, 0],
        selected_parents.meta[father_indices, 0],
        jnp.ones(n_offspring, dtype=jnp.int32)], axis=-1)
    g1_population = Population(geno=g1_geno, ibd=g1_ibd, meta=g1_meta)
    
    print("\nStep 6: Analyzing G1 and calculating genetic gain...")
    _, g1_tbvs = calculate_phenotypes(key=g1_pheno_key, population=g1_population, trait=trait_architecture, heritability=heritabilities)
    mean_g1_tbv = jnp.mean(g1_tbvs[:, 0])
    print(f"  - Mean TBV of Trait 1 in G1: {mean_g1_tbv:.3f}")
    
    genetic_gain = mean_g1_tbv - mean_g0_tbv
    print(f"\n--- Result: Genetic Gain for Trait 1 after one generation: {genetic_gain:+.3f} ---")

if __name__ == "__main__":
    main()








--- Running Full ChewC Breeding Cycle ---

Step 1: Generating founder population (G0)...

Step 2: Generating a two-trait architecture...

Step 3: Calculating phenotypes for G0...
  - Mean TBV of Trait 1 in G0: 100.000

Step 4: Performing selection and random mating...
  - Selected top 200 and created 2000 pairs.

Step 5: Creating the G1 population...


ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (10 of them) had size 1000, e.g. axis 0 of argument locus_positions[0] of type float32[1000];
  * some axes (3 of them) had size 10, e.g. axis 0 of argument key of type uint32[10,2]

In [21]:
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
from typing import Tuple, List

# Assume chewc is installed and datatypes are importable
# We redefine them here for a self-contained script
from flax.struct import dataclass

@dataclass(frozen=True)
class Population:
    geno: jnp.ndarray
    ibd: jnp.ndarray
    meta: jnp.ndarray

@dataclass(frozen=True)
class Trait:
    qtl_chromosome: jnp.ndarray
    qtl_position: jnp.ndarray
    qtl_effects: jnp.ndarray
    intercept: jnp.ndarray

# --- THE FIX IS HERE (Part 1) ---
# locus_positions is now a single JAX array, not a list of arrays.
@dataclass(frozen=True)
class GeneticMap:
    chromosome_lengths: jnp.ndarray
    locus_positions: jnp.ndarray

# --- From initialization.py ---
def quick_haplo(key: jax.Array, n_ind: int, n_chr: int, seg_sites: int, inbred: bool = False, chr_length: float = 1.0) -> Tuple[Population, GeneticMap]:
    if inbred:
        haplotypes = jax.random.randint(key, (n_ind, n_chr, 1, seg_sites), 0, 2, dtype=jnp.int8)
        geno = jnp.broadcast_to(haplotypes, (n_ind, n_chr, 2, seg_sites))
    else:
        geno = jax.random.randint(key, (n_ind, n_chr, 2, seg_sites), 0, 2, dtype=jnp.int8)
    ibd = jnp.broadcast_to(jnp.arange(n_ind * 2, dtype=jnp.int32).reshape(n_ind, 1, 2, 1), (n_ind, n_chr, 2, seg_sites))
    meta = jnp.stack([jnp.arange(n_ind), jnp.full((n_ind,), -1), jnp.full((n_ind,), -1), jnp.zeros(n_ind)], axis=-1, dtype=jnp.int32)
    population = Population(geno=geno, ibd=ibd, meta=meta)
    
    chromosome_lengths = jnp.full((n_chr,), chr_length, dtype=jnp.float32)
    # --- THE FIX IS HERE (Part 2) ---
    # Create a list of arrays, then stack them into a single 2D array.
    locus_positions_list = [jnp.linspace(0., chr_length, seg_sites) for _ in range(n_chr)]
    locus_positions = jnp.stack(locus_positions_list)
    
    genetic_map = GeneticMap(chromosome_lengths=chromosome_lengths, locus_positions=locus_positions)
    return population, genetic_map

# The `add_trait` function remains correct.
def add_trait(key: jax.Array, founder_pop: Population, n_qtl_per_chr: int, mean: jnp.ndarray, var: jnp.ndarray, sigma: jnp.ndarray) -> Trait:
    key, qtl_key, effect_key = jax.random.split(key, 3)
    n_chr, _, n_loci_per_chr = founder_pop.geno.shape[1:]
    n_total_qtl = n_qtl_per_chr * n_chr
    all_loci_indices = jnp.arange(n_chr * n_loci_per_chr)
    qtl_loc_flat = jax.random.choice(qtl_key, all_loci_indices, (n_total_qtl,), replace=False)
    qtl_chromosome, qtl_position = jnp.divmod(jnp.sort(qtl_loc_flat), n_loci_per_chr)
    n_traits = mean.shape[0]
    raw_effects = jax.random.normal(effect_key, (n_total_qtl, n_traits))
    cholesky_factor = jnp.linalg.cholesky(sigma)
    effects = raw_effects @ cholesky_factor.T
    founder_dosage = jnp.sum(founder_pop.geno, axis=2)
    qtl_dosage = founder_dosage[:, qtl_chromosome, qtl_position]
    gvs = qtl_dosage @ effects
    scale = jnp.sqrt(var / (jnp.var(gvs, axis=0) + 1e-8))
    intercept = mean - jnp.mean(gvs, axis=0) * scale
    final_effects = effects * scale
    return Trait(qtl_chromosome=qtl_chromosome, qtl_position=qtl_position, qtl_effects=final_effects, intercept=intercept)

# --- From _internal/phenotype.py ---
@jax.jit
def calculate_phenotypes(key: jax.Array, population: Population, trait: Trait, heritability: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    dosage = jnp.sum(population.geno, axis=2)
    qtl_dosage = dosage[:, trait.qtl_chromosome, trait.qtl_position]
    true_breeding_values = trait.intercept + qtl_dosage @ trait.qtl_effects
    additive_variance = jnp.var(true_breeding_values, axis=0)
    h2 = jnp.clip(heritability, a_min=1e-8, a_max=1.0 - 1e-8)
    environmental_variance = additive_variance * ((1 / h2) - 1)
    environmental_variance = jnp.maximum(0.0, environmental_variance)
    noise = jax.random.normal(key, true_breeding_values.shape)
    environmental_effects = noise * jnp.sqrt(environmental_variance)
    phenotypes = true_breeding_values + environmental_effects
    return phenotypes, true_breeding_values

# --- From _internal/meiosis.py and crossing.py ---
@partial(jax.jit, static_argnames=("max_crossovers",))
def create_gamete(key: jax.Array, parent_geno: jax.Array, parent_ibd: jax.Array, genetic_map: GeneticMap, interference_nu: float, max_crossovers: int) -> Tuple[jax.Array, jax.Array]:
    @partial(jax.jit, static_argnames=("max_crossovers",))
    def _create_recombinant_chromosome(key, parent_chr_geno, parent_chr_ibd, locus_positions, interference_nu, max_crossovers):
        @partial(jax.jit, static_argnames=("max_crossovers",))
        def _sample_chiasmata(key, map_length, interference_nu, max_crossovers):
            shape, scale = interference_nu, 1.0 / (2.0 * interference_nu)
            def scan_body(carry, _):
                key, last_pos = carry; key, subkey = jax.random.split(key); new_pos = last_pos + jax.random.gamma(subkey, shape) * scale; return (key, new_pos), new_pos
            key, initial_key = jax.random.split(key); initial_start_pos = jax.random.uniform(initial_key, minval=-10.0, maxval=0.0)
            _, crossover_positions = lax.scan(scan_body, (key, initial_start_pos), None, length=max_crossovers)
            valid_mask = (crossover_positions > 0) & (crossover_positions < map_length)
            return jnp.where(valid_mask, crossover_positions, jnp.nan)

        key, chiasma_key, hap_key = jax.random.split(key, 3)
        crossover_positions = _sample_chiasmata(chiasma_key, locus_positions[-1], interference_nu, max_crossovers)
        crossover_indices = jnp.searchsorted(locus_positions, crossover_positions)
        start_hap = jax.random.randint(hap_key, (), 0, 2)
        locus_segments = jnp.searchsorted(crossover_indices, jnp.arange(locus_positions.shape[0]), side='right')
        haplotype_choice = (start_hap + locus_segments) % 2
        gamete_geno = jnp.where(haplotype_choice == 0, parent_chr_geno[0], parent_chr_geno[1])
        gamete_ibd = jnp.where(haplotype_choice == 0, parent_chr_ibd[0], parent_chr_ibd[1])
        return gamete_geno, gamete_ibd

    chr_keys = jax.random.split(key, num=parent_geno.shape[0])
    vmapped_recombine = jax.vmap(_create_recombinant_chromosome, in_axes=(0, 0, 0, 0, None, None))
    gamete_geno, gamete_ibd = vmapped_recombine(chr_keys, parent_geno, parent_ibd, genetic_map.locus_positions, interference_nu, max_crossovers)
    return gamete_geno, gamete_ibd

@partial(jax.jit, static_argnames=('max_crossovers',))
def cross_pair(key: jax.Array, mother_geno: jax.Array, father_geno: jax.Array, mother_ibd: jax.Array, father_ibd: jax.Array, genetic_map: GeneticMap, max_crossovers: int) -> Tuple[jax.Array, jax.Array]:
    key_mother, key_father = jax.random.split(key)
    mother_gamete_geno, mother_gamete_ibd = create_gamete(key_mother, mother_geno, mother_ibd, genetic_map, interference_nu=4.0, max_crossovers=max_crossovers)
    father_gamete_geno, father_gamete_ibd = create_gamete(key_father, father_geno, father_ibd, genetic_map, interference_nu=4.0, max_crossovers=max_crossovers)
    offspring_geno = jnp.stack([mother_gamete_geno, father_gamete_geno], axis=1)
    offspring_ibd = jnp.stack([mother_gamete_ibd, father_gamete_ibd], axis=1)
    return offspring_geno, offspring_ibd

@partial(jax.jit, static_argnames=('k',))
def select_top_k(population: Population, values: jnp.ndarray, k: int) -> Population:
    indices = jnp.argsort(values, axis=0)[::-1].squeeze()[:k]
    return jax.tree.map(lambda x: x[indices], population)

@partial(jax.jit, static_argnames=('n_parents', 'n_crosses'))
def random_mating(key: jax.Array, n_parents: int, n_crosses: int) -> jnp.ndarray:
    mothers = jax.random.choice(key, n_parents, shape=(n_crosses,))
    fathers = jax.random.choice(key, n_parents, shape=(n_crosses,))
    return jnp.stack([mothers, fathers], axis=-1)

# --- Main simulation script ---
def main():
    print("--- Running Full ChewC Breeding Cycle ---")

    n_individuals, n_chromosomes, n_loci_per_chr = 2000, 10, 1000
    seed, selection_intensity, n_offspring = 42, 0.1, n_individuals
    max_crossovers = 20

    key = jax.random.PRNGKey(seed)
    pop_key, trait_key, g0_pheno_key, mating_key, cross_key, g1_pheno_key = jax.random.split(key, 6)

    print(f"\nStep 1: Generating founder population (G0)...")
    founder_pop, genetic_map = quick_haplo(key=pop_key, n_ind=n_individuals, n_chr=n_chromosomes, seg_sites=n_loci_per_chr)
    
    print("\nStep 2: Generating a two-trait architecture...")
    trait_architecture = add_trait(
        key=trait_key, founder_pop=founder_pop, n_qtl_per_chr=50,
        mean=jnp.array([100.0, 50.0]), var=jnp.array([10.0, 2.0]),
        sigma=jnp.array([[10.0, -0.3 * jnp.sqrt(10*2)], [-0.3 * jnp.sqrt(10*2), 2.0]])
    )
    
    print("\nStep 3: Calculating phenotypes for G0...")
    heritabilities = jnp.array([0.4, 0.7])
    g0_phenotypes, g0_tbvs = calculate_phenotypes(key=g0_pheno_key, population=founder_pop, trait=trait_architecture, heritability=heritabilities)
    mean_g0_tbv = jnp.mean(g0_tbvs[:, 0])
    print(f"  - Mean TBV of Trait 1 in G0: {mean_g0_tbv:.3f}")

    print("\nStep 4: Performing selection and random mating...")
    n_select = int(n_individuals * selection_intensity)
    selected_parents = select_top_k(founder_pop, g0_phenotypes[:, 0], k=n_select)
    pairings = random_mating(mating_key, n_parents=n_select, n_crosses=n_offspring)
    print(f"  - Selected top {n_select} and created {n_offspring} pairs.")

    print("\nStep 5: Creating the G1 population...")
    mother_indices, father_indices = pairings[:, 0], pairings[:, 1]
    mothers_geno, fathers_geno = selected_parents.geno[mother_indices], selected_parents.geno[father_indices]
    mothers_ibd, fathers_ibd = selected_parents.ibd[mother_indices], selected_parents.ibd[father_indices]

    vmapped_cross = jax.vmap(cross_pair, in_axes=(0, 0, 0, 0, 0, None, None))
    offspring_keys = jax.random.split(cross_key, n_offspring)
    
    g1_geno, g1_ibd = vmapped_cross(offspring_keys, mothers_geno, fathers_geno, mothers_ibd, fathers_ibd, genetic_map, max_crossovers)
    print(f"  - G1 genotypes created. Shape: {g1_geno.shape}")

    g1_meta = jnp.stack([
        jnp.arange(n_offspring, dtype=jnp.int32) + n_individuals,
        selected_parents.meta[mother_indices, 0],
        selected_parents.meta[father_indices, 0],
        jnp.ones(n_offspring, dtype=jnp.int32)], axis=-1)
    g1_population = Population(geno=g1_geno, ibd=g1_ibd, meta=g1_meta)
    
    print("\nStep 6: Analyzing G1 and calculating genetic gain...")
    _, g1_tbvs = calculate_phenotypes(key=g1_pheno_key, population=g1_population, trait=trait_architecture, heritability=heritabilities)
    mean_g1_tbv = jnp.mean(g1_tbvs[:, 0])
    print(f"  - Mean TBV of Trait 1 in G1: {mean_g1_tbv:.3f}")
    
    genetic_gain = mean_g1_tbv - mean_g0_tbv
    print(f"\n--- Result: Genetic Gain for Trait 1 after one generation: {genetic_gain:+.3f} ---")

if __name__ == "__main__":
    main()

--- Running Full ChewC Breeding Cycle ---

Step 1: Generating founder population (G0)...

Step 2: Generating a two-trait architecture...

Step 3: Calculating phenotypes for G0...
  - Mean TBV of Trait 1 in G0: 100.000

Step 4: Performing selection and random mating...
  - Selected top 200 and created 2000 pairs.

Step 5: Creating the G1 population...
  - G1 genotypes created. Shape: (2000, 10, 2, 1000)

Step 6: Analyzing G1 and calculating genetic gain...
  - Mean TBV of Trait 1 in G1: 103.635

--- Result: Genetic Gain for Trait 1 after one generation: +3.635 ---


In [None]:
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
from typing import Tuple

from flax.struct import dataclass
from jax.tree_util import tree_map


# ---------------------------
# Datatypes (immutable PyTrees)
# ---------------------------

@dataclass(frozen=True)
class Population:
    geno: jnp.ndarray  # (n_ind, n_chr, 2, n_loci)
    ibd: jnp.ndarray   # (n_ind, n_chr, 2, n_loci) integer labels per hap-locus
    meta: jnp.ndarray  # (n_ind, 4) -> [id, mother_id, father_id, birth_gen] (int32)


@dataclass(frozen=True)
class Trait:
    qtl_chromosome: jnp.ndarray  # (n_qtl,) int32
    qtl_position: jnp.ndarray    # (n_qtl,) int32
    qtl_effects: jnp.ndarray     # (n_qtl, n_traits) float32
    intercept: jnp.ndarray       # (n_traits,) float32


@dataclass(frozen=True)
class GeneticMap:
    # locus_positions[c, :] is a sorted vector of positions for chromosome c
    chromosome_lengths: jnp.ndarray  # (n_chr,) float32
    locus_positions: jnp.ndarray     # (n_chr, n_loci) float32


# --------------------------------
# Initialization / Founders & Map
# --------------------------------

def quick_haplo(
    key: jax.Array,
    n_ind: int,
    n_chr: int,
    seg_sites: int,
    inbred: bool = False,
    chr_length: float = 1.0,
) -> Tuple[Population, GeneticMap]:
    """Create a simple founder population and a uniform genetic map."""
    if inbred:
        haplotypes = jax.random.randint(
            key, (n_ind, n_chr, 1, seg_sites), 0, 2, dtype=jnp.int8
        )
        geno = jnp.broadcast_to(haplotypes, (n_ind, n_chr, 2, seg_sites))
    else:
        geno = jax.random.randint(
            key, (n_ind, n_chr, 2, seg_sites), 0, 2, dtype=jnp.int8
        )

    # IBD labels: unique per haplotype across individuals
    ibd = jnp.broadcast_to(
        jnp.arange(n_ind * 2, dtype=jnp.int32).reshape(n_ind, 1, 2, 1),
        (n_ind, n_chr, 2, seg_sites),
    )

    # Meta (int32 all the way, stack doesn't accept dtype= kwarg)
    meta = jnp.stack(
        [
            jnp.arange(n_ind, dtype=jnp.int32),               # id
            jnp.full((n_ind,), -1, dtype=jnp.int32),          # mother_id
            jnp.full((n_ind,), -1, dtype=jnp.int32),          # father_id
            jnp.zeros((n_ind,), dtype=jnp.int32),             # birth_gen
        ],
        axis=-1,
    )

    population = Population(geno=geno, ibd=ibd, meta=meta)

    # Genetic map: same number of loci per chromosome, uniform spacing
    chromosome_lengths = jnp.full((n_chr,), chr_length, dtype=jnp.float32)
    locus_positions = jnp.stack(
        [jnp.linspace(0.0, chr_length, seg_sites, dtype=jnp.float32) for _ in range(n_chr)],
        axis=0,
    )
    genetic_map = GeneticMap(
        chromosome_lengths=chromosome_lengths, locus_positions=locus_positions
    )
    return population, genetic_map


# ----------------------------
# Trait architecture utilities
# ----------------------------

def _flatten_gather_chr_locus(dosage_chr_locus: jnp.ndarray,
                              chr_idx: jnp.ndarray,
                              locus_idx: jnp.ndarray) -> jnp.ndarray:
    """
    dosage_chr_locus: (n_ind, n_chr, n_loci)
    chr_idx, locus_idx: (n_qtl,)
    returns: (n_ind, n_qtl) gathered pairwise along (chr,locus).
    """
    n_chr, n_loci = dosage_chr_locus.shape[1], dosage_chr_locus.shape[2]
    flat = dosage_chr_locus.reshape(dosage_chr_locus.shape[0], n_chr * n_loci)
    flat_ids = chr_idx * n_loci + locus_idx
    return jnp.take(flat, flat_ids, axis=1)


def add_trait(
    key: jax.Array,
    founder_pop: Population,
    n_qtl_per_chr: int,
    mean: jnp.ndarray,   # (n_traits,)
    var: jnp.ndarray,    # (n_traits,)
    sigma: jnp.ndarray,  # (n_traits, n_traits) PSD
) -> Trait:
    """Sample QTLs and multi-trait effects with covariance `sigma`."""
    key, qtl_key, effect_key = jax.random.split(key, 3)

    n_chr = founder_pop.geno.shape[1]
    n_loci_per_chr = founder_pop.geno.shape[3]

    n_total_qtl = n_qtl_per_chr * n_chr
    all_loci_indices = jnp.arange(n_chr * n_loci_per_chr, dtype=jnp.int32)
    qtl_loc_flat = jax.random.choice(
        qtl_key, all_loci_indices, (n_total_qtl,), replace=False
    )
    qtl_chromosome, qtl_position = jnp.divmod(jnp.sort(qtl_loc_flat), n_loci_per_chr)

    n_traits = int(mean.shape[0])
    raw_effects = jax.random.normal(effect_key, (n_total_qtl, n_traits), dtype=jnp.float32)
    cholesky_factor = jnp.linalg.cholesky(sigma.astype(jnp.float32))  # (T, T)
    effects = raw_effects @ cholesky_factor.T  # (n_qtl, T)

    # Use founders to scale effects to desired trait variances
    founder_dosage = jnp.sum(founder_pop.geno, axis=2, dtype=jnp.int32)  # (n, chr, loci)
    qtl_dosage = _flatten_gather_chr_locus(founder_dosage, qtl_chromosome, qtl_position).astype(jnp.float32)
    gvs = qtl_dosage @ effects  # (n, T)

    # Scale to hit target marginal variances
    scale = jnp.sqrt(var.astype(jnp.float32) / (jnp.var(gvs, axis=0) + 1e-8))
    intercept = mean.astype(jnp.float32) - jnp.mean(gvs, axis=0) * scale
    final_effects = effects * scale  # (n_qtl, T)

    return Trait(
        qtl_chromosome=qtl_chromosome.astype(jnp.int32),
        qtl_position=qtl_position.astype(jnp.int32),
        qtl_effects=final_effects,
        intercept=intercept,
    )


# ----------------------
# Phenotype computation
# ----------------------




# ----------------------
# Meiosis / Crossing
# ----------------------




@partial(jax.jit, static_argnames=("max_crossovers",))
def cross_pair(
    key: jax.Array,
    mother_geno: jax.Array, father_geno: jax.Array,   # (n_chr, 2, n_loci)
    mother_ibd: jax.Array,  father_ibd: jax.Array,    # (n_chr, 2, n_loci)
    genetic_map: GeneticMap,
    max_crossovers: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    km, kf = jax.random.split(key)
    m_geno, m_ibd = create_gamete(km, mother_geno, mother_ibd, genetic_map, interference_nu=4.0, max_crossovers=max_crossovers)
    f_geno, f_ibd = create_gamete(kf, father_geno, father_ibd, genetic_map, interference_nu=4.0, max_crossovers=max_crossovers)
    offspring_geno = jnp.stack([m_geno, f_geno], axis=1)  # (n_chr, 2, n_loci)
    offspring_ibd  = jnp.stack([m_ibd,  f_ibd],  axis=1)  # (n_chr, 2, n_loci)
    return offspring_geno, offspring_ibd


# ---------------
# Breeder actions
# ---------------

@partial(jax.jit, static_argnames=("k",))
def select_top_k(population: Population, values: jnp.ndarray, k: int) -> Population:
    """Select top-k individuals by `values` (1-D)."""
    vals = jnp.ravel(values).astype(jnp.float32)  # (n_ind,)
    _, top_idx = lax.top_k(vals, k)               # (k,)
    return tree_map(lambda x: x[top_idx], population)


@partial(jax.jit, static_argnames=("n_crosses",))
def random_mating(key: jax.Array, n_parents: int, n_crosses: int) -> jnp.ndarray:
    """Sample (mother, father) pairs with replacement.
    Works even when `n_parents` is not a Python literal.
    """
    k_m, k_f = jax.random.split(key)
    mothers = jax.random.randint(k_m, (n_crosses,), minval=0, maxval=n_parents, dtype=jnp.int32)
    fathers = jax.random.randint(k_f, (n_crosses,), minval=0, maxval=n_parents, dtype=jnp.int32)
    return jnp.stack([mothers, fathers], axis=-1)

# -------------------------
# Demo / One-generation run
# -------------------------

def main():
    print("--- Running Full ChewC Breeding Cycle ---")

    n_individuals, n_chromosomes, n_loci_per_chr = 2000, 10, 1000
    seed, selection_intensity, n_offspring = 42, 0.1, 2000
    max_crossovers = 20

    key = jax.random.PRNGKey(seed)
    pop_key, trait_key, g0_pheno_key, mating_key, cross_key, g1_pheno_key = jax.random.split(key, 6)

    print("\nStep 1: Generating founder population (G0)...")
    founder_pop, genetic_map = quick_haplo(
        key=pop_key,
        n_ind=n_individuals,
        n_chr=n_chromosomes,
        seg_sites=n_loci_per_chr,
        inbred=False,
        chr_length=1.0,
    )

    print("\nStep 2: Generating a two-trait architecture...")
    cov = -0.3 * jnp.sqrt(jnp.array(20.0, dtype=jnp.float32))  # -0.3 * sqrt(10*2)
    sigma = jnp.array([[10.0, cov], [cov, 2.0]], dtype=jnp.float32)
    trait_architecture = add_trait(
        key=trait_key,
        founder_pop=founder_pop,
        n_qtl_per_chr=50,
        mean=jnp.array([100.0, 50.0], dtype=jnp.float32),
        var=jnp.array([10.0, 2.0], dtype=jnp.float32),
        sigma=sigma,
    )

    print("\nStep 3: Calculating phenotypes for G0...")
    heritabilities = jnp.array([0.4, 0.7], dtype=jnp.float32)
    g0_phenotypes, g0_tbvs = calculate_phenotypes(
        key=g0_pheno_key, population=founder_pop, trait=trait_architecture, heritability=heritabilities
    )
    mean_g0_tbv = float(jnp.mean(g0_tbvs[:, 0]))
    print(f"  - Mean TBV of Trait 1 in G0: {mean_g0_tbv:.3f}")

    print("\nStep 4: Performing selection and random mating...")
    n_select = int(n_individuals * selection_intensity)
    selected_parents = select_top_k(founder_pop, g0_phenotypes[:, 0], k=n_select)
    pairings = random_mating(mating_key, n_parents=n_select, n_crosses=n_offspring)
    print(f"  - Selected top {n_select} and created {n_offspring} pairs.")

    print("\nStep 5: Creating the G1 population...")
    mother_indices, father_indices = pairings[:, 0], pairings[:, 1]
    mothers_geno, fathers_geno = selected_parents.geno[mother_indices], selected_parents.geno[father_indices]
    mothers_ibd, fathers_ibd = selected_parents.ibd[mother_indices], selected_parents.ibd[father_indices]

    vmapped_cross = jax.vmap(cross_pair, in_axes=(0, 0, 0, 0, 0, None, None))
    offspring_keys = jax.random.split(cross_key, n_offspring)

    g1_geno, g1_ibd = vmapped_cross(
        offspring_keys,
        mothers_geno,
        fathers_geno,
        mothers_ibd,
        fathers_ibd,
        genetic_map,
        max_crossovers,
    )
    print(f"  - G1 genotypes created. Shape: {tuple(g1_geno.shape)}")

    g1_meta = jnp.stack(
        [
            jnp.arange(n_offspring, dtype=jnp.int32) + n_individuals,
            selected_parents.meta[mother_indices, 0],
            selected_parents.meta[father_indices, 0],
            jnp.ones((n_offspring,), dtype=jnp.int32),
        ],
        axis=-1,
    )
    g1_population = Population(geno=g1_geno, ibd=g1_ibd, meta=g1_meta)

    print("\nStep 6: Analyzing G1 and calculating genetic gain...")
    _, g1_tbvs = calculate_phenotypes(
        key=g1_pheno_key, population=g1_population, trait=trait_architecture, heritability=heritabilities
    )
    mean_g1_tbv = float(jnp.mean(g1_tbvs[:, 0]))
    print(f"  - Mean TBV of Trait 1 in G1: {mean_g1_tbv:.3f}")

    genetic_gain = mean_g1_tbv - mean_g0_tbv
    print(f"\n--- Result: Genetic Gain for Trait 1 after one generation: {genetic_gain:+.3f} ---")


if __name__ == "__main__":
    main()


--- Running Full ChewC Breeding Cycle ---

Step 1: Generating founder population (G0)...

Step 2: Generating a two-trait architecture...

Step 3: Calculating phenotypes for G0...
  - Mean TBV of Trait 1 in G0: 100.000

Step 4: Performing selection and random mating...
  - Selected top 200 and created 2000 pairs.

Step 5: Creating the G1 population...
  - G1 genotypes created. Shape: (2000, 10, 2, 1000)

Step 6: Analyzing G1 and calculating genetic gain...
  - Mean TBV of Trait 1 in G1: 103.749

--- Result: Genetic Gain for Trait 1 after one generation: +3.749 ---


In [None]:


import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
from typing import Tuple

from flax.struct import dataclass
from jax.tree_util import tree_map


# ----------------------------------------------------
# NEW: State Carrier for the Simulation Loop
# ----------------------------------------------------
@dataclass(frozen=True)
class BreedingState:
    """The complete state of one simulation at a point in time."""
    population: Population
    key: jax.Array
    generation: int
    next_id: int # To ensure unique IDs for new individuals

# ----------------------------------------------------
# NEW: The JIT-Compiled Simulation Step
# ----------------------------------------------------

@partial(jax.jit, static_argnames=("n_select", "n_offspring", "max_crossovers"))
def selection_step(
    carry: BreedingState,
    _, # Placeholder for lax.scan's iteration number
    trait: Trait,
    genetic_map: GeneticMap,
    heritabilities: jnp.ndarray,
    n_select: int,
    n_offspring: int,
    max_crossovers: int
) -> Tuple[BreedingState, jnp.ndarray]:
    """
    Executes one full cycle of selection and breeding.
    This function is designed to be the body of a lax.scan loop.
    """
    key, pheno_key, mating_key, cross_key = jax.random.split(carry.key, 4)
    current_pop = carry.population

    # 1. Evaluate the population
    phenotypes, tbvs = calculate_phenotypes(
        key=pheno_key, population=current_pop, trait=trait, heritability=heritabilities
    )
    mean_tbv = jnp.mean(tbvs[:, 0]) # Track genetic gain for trait 1

    # 2. Select top parents
    selected_parents = select_top_k(current_pop, phenotypes[:, 0], k=n_select)

    # 3. Generate a random mating plan
    pairings = random_mating(mating_key, n_parents=n_select, n_crosses=n_offspring)
    mother_indices, father_indices = pairings[:, 0], pairings[:, 1]

    # 4. Create the next generation
    mothers_geno = selected_parents.geno[mother_indices]
    fathers_geno = selected_parents.geno[father_indices]
    mothers_ibd = selected_parents.ibd[mother_indices]
    fathers_ibd = selected_parents.ibd[father_indices]

    # Vectorize the crossing operation over all pairs
    vmapped_cross = jax.vmap(
        cross_pair, in_axes=(0, 0, 0, 0, 0, None, None)
    )
    offspring_keys = jax.random.split(cross_key, n_offspring)
    offspring_geno, offspring_ibd = vmapped_cross(
        offspring_keys, mothers_geno, fathers_geno, mothers_ibd, fathers_ibd, genetic_map, max_crossovers
    )

    # 5. Form the new population and update the state
    new_generation = carry.generation + 1
    new_ids = jnp.arange(n_offspring, dtype=jnp.int32) + carry.next_id
    new_meta = jnp.stack(
        [
            new_ids,
            selected_parents.meta[mother_indices, 0],
            selected_parents.meta[father_indices, 0],
            jnp.full((n_offspring,), new_generation, dtype=jnp.int32),
        ],
        axis=-1,
    )
    new_population = Population(geno=offspring_geno, ibd=offspring_ibd, meta=new_meta)

    next_state = BreedingState(
        population=new_population,
        key=key,
        generation=new_generation,
        next_id=carry.next_id + n_offspring
    )

    # The scan function requires a `(carry, output)` return signature.
    # `carry` is the state for the next iteration.
    # `output` is the data we want to collect at each step.
    return next_state, mean_tbv

# ----------------------------------------------------
# NEW: The Top-Level Simulation Runner
# ----------------------------------------------------

def run_simulation_cycles(
    initial_state: BreedingState,
    trait: Trait,
    genetic_map: GeneticMap,
    heritabilities: jnp.ndarray,
    n_cycles: int,
    n_select: int,
    n_offspring: int,
    max_crossovers: int
):
    """
    Runs the entire multi-cycle simulation using lax.scan for maximum performance.
    """
    # `lax.scan` requires the body function to be a callable that takes
    # (carry, x) and returns (new_carry, y). We use a lambda to fix the
    # static arguments and data that doesn't change over the loop.
    scan_fn = lambda carry, _: selection_step(
        carry,
        _,
        trait=trait,
        genetic_map=genetic_map,
        heritabilities=heritabilities,
        n_select=n_select,
        n_offspring=n_offspring,
        max_crossovers=max_crossovers
    )

    # Run the scan. The `None` is a placeholder for the `xs` array,
    # as we only care about the number of iterations (`length`).
    final_state, tbv_history = lax.scan(
        scan_fn, initial_state, None, length=n_cycles
    )

    return final_state, tbv_history


# ----------------------------------------------------
# Main execution script
# ----------------------------------------------------

if __name__ == "__main__":
    print("--- Running 10-Cycle Phenotypic Selection Experiment ---")

    # --- Simulation Parameters ---
    N_FOUNDERS = 100
    N_SELECT = 50
    N_OFFSPRING = 100 # Population size is kept constant
    N_CYCLES = 50

    N_CHR, N_LOCI = 5, 1000
    MAX_CROSSOVERS = 10
    SEED = 42

    key = jax.random.PRNGKey(SEED)

    # --- Setup ---
    key, pop_key, trait_key = jax.random.split(key, 3)
    print("\nStep 1: Initializing founder population...")
    founder_pop, genetic_map = quick_haplo(
        key=pop_key, n_ind=N_FOUNDERS, n_chr=N_CHR, seg_sites=N_LOCI,
        inbred=False, chr_length=1.0
    )

    print("Step 2: Initializing trait architecture...")
    trait_architecture = add_trait(
        key=trait_key, founder_pop=founder_pop, n_qtl_per_chr=50,
        mean=jnp.array([0.0]), var=jnp.array([1.0]),
        sigma=jnp.array([[10.0]])
    )
    heritabilities = jnp.array([0.4])

    # --- Initial State ---
    initial_state = BreedingState(
        population=founder_pop,
        key=key,
        generation=0,
        next_id=N_FOUNDERS # Next available individual ID
    )
    print(f"  - Founder population size: {N_FOUNDERS}")
    print(f"  - Selection: Top {N_SELECT} individuals")
    print(f"  - Offspring per cycle: {N_OFFSPRING}")


    # --- Run Simulation ---
    print(f"\nStep 3: Running {N_CYCLES} selection cycles (JIT compiling...)\n")
    final_state, tbv_history = run_simulation_cycles(
        initial_state=initial_state,
        trait=trait_architecture,
        genetic_map=genetic_map,
        heritabilities=heritabilities,
        n_cycles=N_CYCLES,
        n_select=N_SELECT,
        n_offspring=N_OFFSPRING,
        max_crossovers=MAX_CROSSOVERS,
    )

    # We need to calculate the TBV of the initial generation to show the full trend
    key, initial_pheno_key = jax.random.split(initial_state.key)
    _, initial_tbvs = calculate_phenotypes(
        key=initial_pheno_key, population=initial_state.population,
        trait=trait_architecture, heritability=heritabilities
    )
    initial_mean_tbv = jnp.mean(initial_tbvs[:, 0])

    # --- Report Results ---
    print("--- Results ---")
    print(f"Generation 00 Mean TBV: {initial_mean_tbv:.3f}")

    # Combine initial TBV with the history from the scan
    full_history = jnp.concatenate([jnp.array([initial_mean_tbv]), tbv_history])
    for i, tbv in enumerate(tbv_history, 1):
        gain = tbv - full_history[i-1]
        print(f"Generation {i:02d} Mean TBV: {tbv:.3f} (Gain: {gain:+.3f})")

    total_gain = tbv_history[-1] - initial_mean_tbv
    print(f"\n--- Total Genetic Gain over {N_CYCLES} generations: {total_gain:+.3f} ---")

--- Running 10-Cycle Phenotypic Selection Experiment ---

Step 1: Initializing founder population...
Step 2: Initializing trait architecture...
  - Founder population size: 100
  - Selection: Top 50 individuals
  - Offspring per cycle: 100

Step 3: Running 50 selection cycles (JIT compiling...)

--- Results ---
Generation 00 Mean TBV: -0.000
Generation 01 Mean TBV: -0.000 (Gain: +0.000)
Generation 02 Mean TBV: 0.598 (Gain: +0.598)
Generation 03 Mean TBV: 1.062 (Gain: +0.465)
Generation 04 Mean TBV: 1.456 (Gain: +0.394)
Generation 05 Mean TBV: 1.847 (Gain: +0.391)
Generation 06 Mean TBV: 2.173 (Gain: +0.326)
Generation 07 Mean TBV: 2.898 (Gain: +0.725)
Generation 08 Mean TBV: 3.335 (Gain: +0.437)
Generation 09 Mean TBV: 3.767 (Gain: +0.432)
Generation 10 Mean TBV: 4.069 (Gain: +0.302)
Generation 11 Mean TBV: 4.565 (Gain: +0.496)
Generation 12 Mean TBV: 4.932 (Gain: +0.366)
Generation 13 Mean TBV: 5.395 (Gain: +0.463)
Generation 14 Mean TBV: 5.772 (Gain: +0.378)
Generation 15 Mean TBV: 5

In [None]:

def run_simulation_cycles(
    initial_state: BreedingState,
    trait: Trait,
    genetic_map: GeneticMap,
    heritabilities: jnp.ndarray,
    n_cycles: int,
    n_select: int,
    n_offspring: int,
    max_crossovers: int
):
    """
    Runs the entire multi-cycle simulation using lax.scan for maximum performance.
    """
    # `lax.scan` requires the body function to be a callable that takes
    # (carry, x) and returns (new_carry, y). We use a lambda to fix the
    # static arguments and data that doesn't change over the loop.
    scan_fn = lambda carry, _: selection_step(
        carry,
        _,
        trait=trait,
        genetic_map=genetic_map,
        heritabilities=heritabilities,
        n_select=n_select,
        n_offspring=n_offspring,
        max_crossovers=max_crossovers
    )

    # Run the scan. The `None` is a placeholder for the `xs` array,
    # as we only care about the number of iterations (`length`).
    final_state, tbv_history = lax.scan(
        scan_fn, initial_state, None, length=n_cycles
    )

    return final_state, tbv_history


# ----------------------------------------------------
# Main execution script
# ----------------------------------------------------

if __name__ == "__main__":
    print("--- Running 10-Cycle Phenotypic Selection Experiment ---")

    # --- Simulation Parameters ---
    N_FOUNDERS = 100
    N_SELECT = 50
    N_OFFSPRING = 100 # Population size is kept constant
    N_CYCLES = 50

    N_CHR, N_LOCI = 5, 1000
    MAX_CROSSOVERS = 10
    SEED = 42

    key = jax.random.PRNGKey(SEED)

    # --- Setup ---
    key, pop_key, trait_key = jax.random.split(key, 3)
    print("\nStep 1: Initializing founder population...")
    founder_pop, genetic_map = quick_haplo(
        key=pop_key, n_ind=N_FOUNDERS, n_chr=N_CHR, seg_sites=N_LOCI,
        inbred=False, chr_length=1.0
    )

    print("Step 2: Initializing trait architecture...")
    trait_architecture = add_trait(
        key=trait_key, founder_pop=founder_pop, n_qtl_per_chr=50,
        mean=jnp.array([0.0]), var=jnp.array([1.0]),
        sigma=jnp.array([[10.0]])
    )
    heritabilities = jnp.array([0.4])

    # --- Initial State ---
    initial_state = BreedingState(
        population=founder_pop,
        key=key,
        generation=0,
        next_id=N_FOUNDERS # Next available individual ID
    )
    print(f"  - Founder population size: {N_FOUNDERS}")
    print(f"  - Selection: Top {N_SELECT} individuals")
    print(f"  - Offspring per cycle: {N_OFFSPRING}")


    # --- Run Simulation ---
    print(f"\nStep 3: Running {N_CYCLES} selection cycles (JIT compiling...)\n")
    final_state, tbv_history = run_simulation_cycles(
        initial_state=initial_state,
        trait=trait_architecture,
        genetic_map=genetic_map,
        heritabilities=heritabilities,
        n_cycles=N_CYCLES,
        n_select=N_SELECT,
        n_offspring=N_OFFSPRING,
        max_crossovers=MAX_CROSSOVERS,
    )

    # We need to calculate the TBV of the initial generation to show the full trend
    key, initial_pheno_key = jax.random.split(initial_state.key)
    _, initial_tbvs = calculate_phenotypes(
        key=initial_pheno_key, population=initial_state.population,
        trait=trait_architecture, heritability=heritabilities
    )
    initial_mean_tbv = jnp.mean(initial_tbvs[:, 0])

    # --- Report Results ---
    print("--- Results ---")
    print(f"Generation 00 Mean TBV: {initial_mean_tbv:.3f}")

    # Combine initial TBV with the history from the scan
    full_history = jnp.concatenate([jnp.array([initial_mean_tbv]), tbv_history])
    for i, tbv in enumerate(tbv_history, 1):
        gain = tbv - full_history[i-1]
        print(f"Generation {i:02d} Mean TBV: {tbv:.3f} (Gain: {gain:+.3f})")

    total_gain = tbv_history[-1] - initial_mean_tbv
    print(f"\n--- Total Genetic Gain over {N_CYCLES} generations: {total_gain:+.3f} ---")

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
