In [1]:
import jax
import jax.numpy as jnp
import numpy as np

# Import core chewc components
from chewc.structs import add_trait
from chewc.burnin import run_burnin

# UPDATED: Import the pipeline runner from chewc.pipe
from chewc.pipe import run_simulation_cycles

# --- 1. Experiment Parameters ---
# All tunable knobs for the run live here so the rest of the script can be
# read without hunting for magic numbers.
N_POP = 150             # Constant population size (run_burnin's n_pop)
N_ENVIRONMENTS = 1      # Single Location/Trait
N_CHR = 5               # Passed to run_burnin as n_chr
N_LOCI = 1000           # Passed to run_burnin as n_loci
N_QTL = 50              # Passed to add_trait as n_qtl_per_chr
SEED = 123              # Seed for the root jax.random.PRNGKey

# Simulation Settings
BURN_IN_GENS = 50       # Generations of burn-in before selection starts
MAX_CROSSOVERS = 10     # Shared by the burn-in and the selection cycles

# Selection Settings
N_SELECTION_GENS = 50   # Number of selection cycles to run
N_SELECT = 20           # Select top 20 parents
# Next gen size. Derived from N_POP rather than repeated as a literal so the
# "constant population size" invariant cannot silently drift.
N_OFFSPRING = N_POP

# --- 2. Burn-in (Initialize & Establish LD) ---
# Banner: plain string literals are used where nothing is interpolated.
print("--- Setting up Single-Trait Experiment ---")
print(f"Population Size: {N_POP}")
print(f"Environments: {N_ENVIRONMENTS} (Single Trait)")
print(f"\n[Phase 1] Running {BURN_IN_GENS} generations of burn-in...")

# Split the root key three ways: one sub-key for the burn-in, one for the
# trait definition further down, and a refreshed root key so that no key is
# ever consumed twice.
key = jax.random.PRNGKey(SEED)
key, burnin_key, trait_key = jax.random.split(key, 3)

# Single call to handle initialization and burn-in. It returns the
# post-burn-in state, an LD summary (printed below), and the genetic map that
# the selection phase is given later.
stable_state, final_ld, genetic_map = run_burnin(
    key=burnin_key,
    n_gens=BURN_IN_GENS,
    n_pop=N_POP,
    n_chr=N_CHR,
    n_loci=N_LOCI,
    max_crossovers=MAX_CROSSOVERS
)

print(f"Burn-in complete at Generation {stable_state.generation}")
print(f"Mean Adjacent LD (r^2) per chromosome: {final_ld}")


# --- 3. Define Single Trait Architecture ---
# Plain string literal: there is nothing to interpolate in this banner.
print("\n[Phase 2] Defining Single Trait Architecture...")

# Identity matrix of size N_ENVIRONMENTS; with a single environment this is
# the 1x1 matrix [[1.0]].
genetic_correlation = jnp.eye(N_ENVIRONMENTS)

# Every per-trait argument is a length-N_ENVIRONMENTS vector, and the trait is
# built on top of the post-burn-in population.
trait_arch = add_trait(
    key=trait_key,
    founder_pop=stable_state.population,
    n_qtl_per_chr=N_QTL,
    mean=jnp.zeros(N_ENVIRONMENTS),
    var_a=jnp.ones(N_ENVIRONMENTS),
    var_d=jnp.zeros(N_ENVIRONMENTS),
    sigma=genetic_correlation
)

# Single heritability value
# NOTE(review): this array is hard-coded to length 1 while the vectors above
# scale with N_ENVIRONMENTS; presumably it needs one entry per environment if
# N_ENVIRONMENTS is ever raised - confirm against run_simulation_cycles.
HERITABILITIES = jnp.array([0.5])


# --- 4. Execute Selection Loop ---
print(f"\n--- Starting {N_SELECTION_GENS} Generations of Selection ---")

# What the simulation operates on: the post-burn-in state, the trait defined
# on it, the genetic map from burn-in, and the heritabilities.
simulation_inputs = dict(
    initial_state=stable_state,
    trait=trait_arch,
    genetic_map=genetic_map,
    heritabilities=HERITABILITIES,
)

# How the breeding scheme is run: cycle count, parents kept, offspring made.
breeding_scheme = dict(
    n_cycles=N_SELECTION_GENS,
    n_select=N_SELECT,
    n_offspring=N_OFFSPRING,
    max_crossovers=MAX_CROSSOVERS,
)

# The whole multi-generation loop is delegated to chewc.pipe. Per the original
# author's note it performs the partial binding and lax.scan internally (that
# implementation is not visible from this script).
final_state, history = run_simulation_cycles(**simulation_inputs, **breeding_scheme)


# --- 5. Results Analysis ---
print("\nGeneration | Mean TBV (Genetic Gain) | Mean Phenotype")
print("---------------------------------------------------")

# Convert the recorded history to a NumPy array for plain indexing.
# Assumes one row per selection cycle, with column 0 reported as the mean TBV
# and column 1 as the mean phenotype - TODO confirm against chewc.pipe.
metrics_history = np.array(history)

# Selection generations are numbered on from where the burn-in stopped.
first_selection_gen = stable_state.generation + 1

# Walk the rows that were actually recorded instead of indexing by a
# separately maintained cycle count, so the table always matches the data.
for offset, row in enumerate(metrics_history):
    gen = first_selection_gen + offset
    tbv = row[0]
    pheno = row[1]
    print(f"Gen {gen:<3}    | {tbv:<23.4f} | {pheno:.4f}")

# Gain is the change in mean TBV between the first and last recorded cycle.
total_gain = metrics_history[-1, 0] - metrics_history[0, 0]
print(f"\nTotal Genetic Gain: {total_gain:.4f}")



--- Setting up Single-Trait Experiment ---
Population Size: 150
Environments: 1 (Single Trait)

[Phase 1] Running 50 generations of burn-in...
Burn-in complete at Generation 50
Mean Adjacent LD (r^2) per chromosome: [0.15110835 0.14266446 0.14125235 0.14965521 0.1522844 ]

[Phase 2] Defining Single Trait Architecture...

--- Starting 50 Generations of Selection ---

Generation | Mean TBV (Genetic Gain) | Mean Phenotype
---------------------------------------------------
Gen 51     | 0.3050                  | 0.0893
Gen 52     | 1.5707                  | 1.3299
Gen 53     | 2.5418                  | 2.2419
Gen 54     | 3.4875                  | 3.1454
Gen 55     | 4.3080                  | 3.9606
Gen 56     | 5.1795                  | 4.8874
Gen 57     | 5.7102                  | 5.3484
Gen 58     | 6.5932                  | 6.2297
Gen 59     | 7.3113                  | 7.1080
Gen 60     | 7.8562                  | 7.5016
Gen 61     | 8.4342                  | 8.0798
Gen 62     | 8.9574