In [1]:
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
import numpy as np # For printing
from typing import Tuple

# Import functions from your library
from chewc.structs import Population, BreedingState, Trait, add_trait, GeneticMap
from chewc.pheno import calculate_phenotypes
from chewc.cross import random_mating, cross_pair


# --- 3. Define Metric Functions ---
@jax.jit
def calculate_ld(population: Population) -> jnp.ndarray:
    """Calculates mean squared correlation (r^2) between adjacent loci."""
    # Get dosages (n_ind, n_chr, n_loci)
    dosage = jnp.sum(population.geno, axis=2, dtype=jnp.float32)
    
    # Get adjacent loci pairs
    dosage_i = dosage[:, :, :-1] # (n_ind, n_chr, n_loci - 1)
    dosage_j = dosage[:, :, 1:]  # (n_ind, n_chr, n_loci - 1)
    
    # Calculate means (over individuals, axis=0)
    E_i = jnp.mean(dosage_i, axis=0)
    E_j = jnp.mean(dosage_j, axis=0)
    E_ij = jnp.mean(dosage_i * dosage_j, axis=0)
    
    # Calculate variance and covariance
    var_i = jnp.var(dosage_i, axis=0)
    var_j = jnp.var(dosage_j, axis=0)
    cov_ij = E_ij - E_i * E_j
    
    # Calculate r^2 per pair
    r2 = (cov_ij**2) / (var_i * var_j + 1e-8)
    
    # Return the mean r^2 over all pairs and chromosomes
    return jnp.mean(r2)

@jax.jit
def calculate_metrics(key: jax.Array, population: Population, trait: Trait) -> jnp.ndarray:
    """Calculates VA and LD."""
    # 1. Calculate Additive Variance (VA)
    # We pass a dummy heritability; it doesn't affect TBV calculation
    _, tbvs = calculate_phenotypes(key, population, trait, jnp.array([0.5]))
    V_A = jnp.var(tbvs, axis=0)[0]
    
    # 2. Calculate Linkage Disequilibrium (LD)
    ld = calculate_ld(population)
    
    return jnp.array([V_A, ld])

# --- 4. Define the JIT-compiled Scan Body Function ---
@partial(jax.jit, static_argnames=("n_pop_size", "max_crossovers"))
def burn_in_step(
    carry: BreedingState,
    _, # Placeholder for lax.scan
    trait: Trait,
    genetic_map: GeneticMap,
    n_pop_size: int,
    max_crossovers: int
) -> Tuple[BreedingState, jnp.ndarray]:
    """
    Executes one cycle of random mating and returns metrics.
    """
    key, mating_key, cross_key, pheno_key, metric_key = jax.random.split(carry.key, 5)
    current_pop = carry.population

    # --- A. Calculate metrics *before* breeding ---
    metrics = calculate_metrics(metric_key, current_pop, trait)
    
    # --- B. Breed next generation (no selection) ---
    # 1. Generate a random mating plan (N=200 -> N=200)
    pairings = random_mating(mating_key, n_parents=n_pop_size, n_crosses=n_pop_size)
    mother_indices, father_indices = pairings[:, 0], pairings[:, 1]

    # 2. Get parent data
    mothers_geno = current_pop.geno[mother_indices]
    fathers_geno = current_pop.geno[father_indices]
    mothers_ibd = current_pop.ibd[mother_indices]
    fathers_ibd = current_pop.ibd[father_indices]

    # 3. Vectorize the crossing operation
    vmapped_cross = jax.vmap(
        partial(cross_pair, max_crossovers=max_crossovers),
        in_axes=(0, 0, 0, 0, 0, None)
    )
    offspring_keys = jax.random.split(cross_key, n_pop_size)
    offspring_geno, offspring_ibd = vmapped_cross(
        offspring_keys, mothers_geno, fathers_geno, mothers_ibd, fathers_ibd, genetic_map
    )

    # 4. Form the new population and update the state
    new_generation = carry.generation + 1
    new_ids = jnp.arange(n_pop_size, dtype=jnp.int32) + carry.next_id
    new_meta = jnp.stack(
        [
            new_ids,
            current_pop.meta[mother_indices, 0], # Mother IDs
            current_pop.meta[father_indices, 0], # Father IDs
            jnp.full((n_pop_size,), new_generation, dtype=jnp.int32),
        ],
        axis=-1,
    )
    new_population = Population(geno=offspring_geno, ibd=offspring_ibd, meta=new_meta)

    next_state = BreedingState(
        population=new_population,
        key=key,
        generation=new_generation,
        next_id=carry.next_id + n_pop_size
    )

    return next_state, metrics




In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from functools import partial

# Import core chewc components
from chewc.structs import add_trait, BreedingState, Population
from chewc.pheno import calculate_phenotypes
from chewc.select import select_top_k
from chewc.cross import random_mating, cross_pair
# Import the new burnin module
from chewc.burnin import run_burnin

# --- 1. Experiment Parameters ---
N_POP = 150             # Constant population size
N_ENVIRONMENTS = 3      # Locations A, B, C
N_CHR = 5
N_LOCI = 1000
N_QTL = 50
SEED = 123

# Simulation Settings
BURN_IN_GENS = 50
MAX_CROSSOVERS = 10

# Selection Settings
N_SELECTION_GENS = 50
N_SELECT = 20           # Select top 20 parents
N_OFFSPRING = 150       # Next gen size

# --- 2. Burn-in (Initialize & Establish LD) ---
print(f"--- Setting up Multi-Environment Experiment ---")
print(f"Population Size: {N_POP}")
print(f"Environments: {N_ENVIRONMENTS} (A, B, C)")
print(f"\n[Phase 1] Running {BURN_IN_GENS} generations of burn-in...")

key = jax.random.PRNGKey(SEED)
key, burnin_key, trait_key = jax.random.split(key, 3)

# REFACTORED: Single call to handle initialization and burn-in
stable_state, final_ld, genetic_map = run_burnin(
    key=burnin_key,
    n_gens=BURN_IN_GENS,
    n_pop=N_POP,
    n_chr=N_CHR,
    n_loci=N_LOCI,
    max_crossovers=MAX_CROSSOVERS
)

print(f"Burn-in complete at Generation {stable_state.generation}")
print(f"Mean Adjacent LD (r^2) per chromosome: {final_ld}")


# --- 3. Define Multi-Environment Trait Architecture ---
print(f"\n[Phase 2] Defining Correlated Traits for Locations A, B, C...")

genetic_correlation = jnp.array([
    [1.0, 0.8, 0.8],  # Env A vs A, B, C
    [0.8, 1.0, 0.8],  # Env B vs A, B, C
    [0.8, 0.8, 1.0]   # Env C vs A, B, C
])

trait_arch = add_trait(
    key=trait_key,
    founder_pop=stable_state.population,
    n_qtl_per_chr=N_QTL,
    mean=jnp.zeros(N_ENVIRONMENTS),
    var_a=jnp.ones(N_ENVIRONMENTS),
    var_d=jnp.zeros(N_ENVIRONMENTS),
    sigma=genetic_correlation
)

HERITABILITIES = jnp.array([0.5, 0.5, 0.3])


# --- 4. Selection Loop Definition ---
@partial(jax.jit, static_argnames=("n_select", "n_offspring", "max_crossovers"))
def phenotypic_selection_step(carry, _, trait, genetic_map, heritability, n_select, n_offspring, max_crossovers):
    state = carry
    key, pheno_key, mate_key, cross_key = jax.random.split(state.key, 4)

    # A. Phenotype
    phenotypes, tbvs = calculate_phenotypes(pheno_key, state.population, trait, heritability)

    # B. Select
    selected_parents = select_top_k(state.population, phenotypes[:, 0], n_select)

    # C. Mate
    pairings = random_mating(mate_key, n_parents=n_select, n_crosses=n_offspring)

    mothers_geno = selected_parents.geno[pairings[:, 0]]
    fathers_geno = selected_parents.geno[pairings[:, 1]]
    mothers_ibd = selected_parents.ibd[pairings[:, 0]]
    fathers_ibd = selected_parents.ibd[pairings[:, 1]]

    # D. Cross
    vmapped_cross = jax.vmap(
        partial(cross_pair, max_crossovers=max_crossovers), 
        in_axes=(0, 0, 0, 0, 0, None)
    )
    off_keys = jax.random.split(cross_key, n_offspring)
    off_geno, off_ibd = vmapped_cross(off_keys, mothers_geno, fathers_geno, mothers_ibd, fathers_ibd, genetic_map)

    # E. Update State
    new_gen = state.generation + 1
    new_ids = jnp.arange(n_offspring, dtype=jnp.int32) + state.next_id
    
    new_meta = jnp.stack([
        new_ids,
        selected_parents.meta[pairings[:, 0], 0],
        selected_parents.meta[pairings[:, 1], 0],
        jnp.full((n_offspring,), new_gen, dtype=jnp.int32)
    ], axis=-1)

    new_pop = Population(geno=off_geno, ibd=off_ibd, meta=new_meta)
    
    next_state = BreedingState(
        population=new_pop,
        key=key,
        generation=new_gen,
        next_id=state.next_id + n_offspring
    )

    metrics = jnp.array([jnp.mean(tbvs[:, 0]), jnp.mean(phenotypes[:, 0])])
    return next_state, metrics


# --- 5. Execute Selection Loop ---
print(f"\n--- Starting {N_SELECTION_GENS} Generations of Selection ---")

select_scan_fn = partial(
    phenotypic_selection_step,
    trait=trait_arch,
    genetic_map=genetic_map,
    heritability=HERITABILITIES,
    n_select=N_SELECT,
    n_offspring=N_OFFSPRING,
    max_crossovers=MAX_CROSSOVERS
)

final_state, history = lax.scan(
    select_scan_fn, stable_state, None, length=N_SELECTION_GENS
)


# --- 6. Results Analysis ---
print("\nGeneration | Mean TBV (Genetic Gain) | Mean Phenotype")
print("---------------------------------------------------")
metrics_history = np.array(history)

for i in range(N_SELECTION_GENS):
    gen = stable_state.generation + i + 1
    tbv = metrics_history[i, 0]
    pheno = metrics_history[i, 1]
    print(f"Gen {gen:<3}    | {tbv:<23.4f} | {pheno:.4f}")

total_gain = metrics_history[-1, 0] - metrics_history[0, 0]
print(f"\nTotal Genetic Gain: {total_gain:.4f}")

--- Setting up Multi-Environment Experiment ---
Population Size: 150
Environments: 3 (A, B, C)

[Phase 1] Running 50 generations of burn-in...
Burn-in complete at Generation 50
Mean Adjacent LD (r^2) per chromosome: [0.15110835 0.14266446 0.14125235 0.14965521 0.1522844 ]

[Phase 2] Defining Correlated Traits for Locations A, B, C...

--- Starting 50 Generations of Selection ---

Generation | Mean TBV (Genetic Gain) | Mean Phenotype
---------------------------------------------------
Gen 51     | -3.1896                 | 0.1446
Gen 52     | -2.1088                 | 1.1518
Gen 53     | -0.9219                 | 2.2910
Gen 54     | -0.0862                 | 3.0515
Gen 55     | 0.8446                  | 4.1066
Gen 56     | 1.7936                  | 4.9539
Gen 57     | 2.6764                  | 5.7927
Gen 58     | 3.7430                  | 6.8852
Gen 59     | 4.4662                  | 7.6291
Gen 60     | 5.0609                  | 8.3524
Gen 61     | 5.8086                  | 8.9705
Gen 6