In [9]:
#| default_exp workflow

In [10]:
import jax
import jax.numpy as jnp

# Requires a chewc build that ships the `meiosis` module; it provides both
# `produce_offspring` and `make_crosses_robust`, which are imported here.
from chewc.meiosis import produce_offspring, make_crosses_robust
from chewc.structs import Population, SimConfig, SimParam, init_state_from_founders
from chewc.popgen import quick_haplo
# from chewc.state import 
from chewc.trait import add_trait_a, set_pheno_h2

# 1. Static simulation configuration, grouped by concern.
config = SimConfig(
    # Genome layout
    n_chr=10,
    n_loci_per_chr=100,
    ploidy=2,
    # Population sizing
    max_pop_size=1000,
    population_size=100,  # offspring produced each generation
    n_select=50,          # parents kept by selection
    # Run length
    n_generations=50,
)

# 2. Seed the master PRNG key and derive one sub-key per stochastic step.
#    The first element of the split is kept as the new top-level `key`;
#    the last one is reserved for the phenotype update step.
(
    key,
    founder_key,
    trait_key,
    pheno_key,
    state_key,
    selection_key,
    meiosis_key,
    pheno_update_key,
) = jax.random.split(jax.random.PRNGKey(42), 8)

# 3. Generate the founder population together with its genetic map.
founder_pop, genetic_map = quick_haplo(
    key=founder_key,
    ploidy=config.ploidy,
    n_chr=config.n_chr,
    n_loci_per_chr=config.n_loci_per_chr,
    n_ind=100,
    max_pop_size=config.max_pop_size,
)

# 4. Initial simulation parameters, built on the founder genetic map.
sp = SimParam(
    gen_map=genetic_map,
    ploidy=config.ploidy,
    recomb_params=(2.6, 0.0, 0.0),
)

# 5. Register an additive trait on the simulation parameters.
print("--- Defining Trait ---")
trait_mean = jnp.array([0.0])
trait_var = jnp.array([1.0])  # target *genetic* variance
sp = add_trait_a(
    key=trait_key,
    sim_param=sp,
    founder_pop=founder_pop,
    n_qtl_per_chr=10,
    mean=trait_mean,
    var=trait_var,
)
print(f"Trait added with {sp.traits.n_loci} QTLs.")

# 6. Phenotype the founders at the target narrow-sense heritability (h2).
h2 = 0.4
print("\n--- Setting Initial Phenotypes with h2 ---")
founder_pop_with_pheno = set_pheno_h2(
    pop=founder_pop,
    sp=sp,
    h2=h2,
    key=pheno_key,
)

# 7. Build the dynamic simulation state from the phenotyped founders.
print("\n--- Initializing Simulation State ---")
initial_state = init_state_from_founders(
    config=config,
    sp=sp,
    founder_pop=founder_pop_with_pheno,
    key=state_key,
)

print(f"Initial write position: {initial_state.write_pos}")
print(f"Next available ID: {initial_state.next_id}")
print(f"Founder population active: {jnp.sum(initial_state.is_active)}")

# 8. Sanity check: realized h2 = Var(bv) / Var(pheno), restricted to the
#    rows flagged active so unused buffer slots do not enter the variances.
print("\n--- Verification ---")
active_mask = initial_state.is_active
founder_bv = initial_state.bv[active_mask]
founder_pheno = initial_state.pheno[active_mask]
var_a = jnp.var(founder_bv)
var_p = jnp.var(founder_pheno)
realized_h2 = var_a / var_p

print(f"Target narrow-sense heritability (h2): {h2:.4f}")
print(f"Realized additive variance (VarA) in founders: {var_a:.4f}")
print(f"Realized phenotypic variance (VarP) in founders: {var_p:.4f}")
print(f"Realized narrow-sense heritability (h2) in founders: {realized_h2:.4f}")


# ==============================================================================
# --- Selection, Mating, and Meiosis Workflow ---
# ==============================================================================

print("\n--- 9. Phenotypic Selection ---")

# Truncation selection on the first trait column of the phenotype array.
pheno = initial_state.pheno[:, 0]

# Guard: `top_k` always returns exactly `k` indices. If fewer than `n_select`
# individuals are active, the remainder would be taken from the masked (-inf)
# rows, i.e. inactive buffer slots would silently become parents.
n_active = int(jnp.sum(initial_state.is_active))
assert n_active >= config.n_select, (
    f"Cannot select {config.n_select} parents from only {n_active} active individuals."
)

# Mask inactive individuals with -inf so they rank below any finite phenotype.
selectable_pheno = jnp.where(initial_state.is_active, pheno, -jnp.inf)

# Indices (into the state arrays) of the `n_select` largest phenotypes.
_, top_indices = jax.lax.top_k(selectable_pheno, k=config.n_select)

print(f"Selected {len(top_indices)} individuals with top phenotypes.")


print("\n--- 10. Create Mating Pairs ---")

# Pair up the selected parents: one (mother, father) index pair is requested
# per offspring, using the sexless crossing helper.
# NOTE(review): presumably pairs are sampled with replacement from the
# `n_select` parents — confirm against chewc.meiosis.make_crosses_robust.
mother_indices, father_indices = make_crosses_robust(
    n_offspring=config.population_size,
    selected_indices=top_indices,
    key=selection_key,
)

print(f"Created {config.population_size} random pairs for mating.")


print("\n--- 11. Produce Offspring via Meiosis ---")

# Run meiosis for every requested cross against the current state; the call
# returns the offspring genotype array and the matching IBD array.
offspring_geno, offspring_ibd = produce_offspring(
    mother_indices=mother_indices,
    father_indices=father_indices,
    state=initial_state,
    sp=sp,
    config=config,
    key=meiosis_key,
)

print("Meiosis complete.")
print(f"Shape of offspring geno array: {offspring_geno.shape}")


# ==============================================================================
# --- Update Simulation State with New Offspring ---
# ==============================================================================
print("\n--- 12. Update State for Next Generation ---")

state = initial_state
n_offspring = config.population_size

# The whole cohort must fit in the buffer; otherwise wrapped slots would
# collide and later offspring would overwrite earlier ones.
assert n_offspring <= config.max_pop_size, (
    f"Cohort of {n_offspring} does not fit in a buffer of {config.max_pop_size} slots."
)

# A. Get the slots for the new cohort from the ring buffer `write_pos`.
#    The indices are wrapped modulo `max_pop_size`, matching the wrap applied
#    to `write_pos` in step F. Without the wrap, a cohort straddling the end of
#    the buffer would produce out-of-range indices, and JAX scatter updates
#    (`.at[...].set`) drop out-of-bounds writes silently, losing offspring.
write_indices = (state.write_pos + jnp.arange(n_offspring)) % config.max_pop_size

# B. Update core data arrays. The `.at[...].set(...)` pattern is the
#    standard JAX method for immutable array updates.
next_geno = state.geno.at[write_indices].set(offspring_geno)
next_ibd = state.ibd.at[write_indices].set(offspring_ibd)

# C. Update pedigree and metadata for the new cohort. Parent IDs are read
#    from the *old* state, so they stay correct even if a parent's slot is
#    among the ones being overwritten.
new_ids = state.next_id + jnp.arange(n_offspring)
next_id_arr = state.id.at[write_indices].set(new_ids)
next_mother_arr = state.mother.at[write_indices].set(state.id[mother_indices])
next_father_arr = state.father.at[write_indices].set(state.id[father_indices])
next_gen_arr = state.gen.at[write_indices].set(state.gen_idx + 1)

# D. Update the active mask. Only the new cohort is active in the next generation.
next_is_active = jnp.zeros_like(state.is_active).at[write_indices].set(True)

# E. Calculate phenotypes and breeding values for the new active population.
#    First, construct a temporary `Population` object with the updated arrays;
#    stale pheno/bv/gv values are blanked with NaN so nothing old leaks through.
temp_pop = founder_pop_with_pheno.replace(
    geno=next_geno,
    ibd=next_ibd,
    id=next_id_arr,
    mother=next_mother_arr,
    father=next_father_arr,
    gen=next_gen_arr,
    is_active=next_is_active,
    pheno=jnp.full_like(founder_pop_with_pheno.pheno, jnp.nan),  # Clear old values
    bv=jnp.full_like(founder_pop_with_pheno.bv, jnp.nan),
    gv=jnp.full_like(founder_pop_with_pheno.gv, jnp.nan),
)

#    Then, call `set_pheno_h2` with its own PRNG key; its `pheno` and `bv`
#    outputs are what get copied into the next state below.
updated_pop = set_pheno_h2(
    key=pheno_update_key, pop=temp_pop, sp=sp, h2=h2
)

# F. Assemble the final state object for the next generation
next_state = state.replace(
    # Updated data arrays from phenotyping
    geno=updated_pop.geno,
    ibd=updated_pop.ibd,
    pheno=updated_pop.pheno,
    bv=updated_pop.bv,
    is_active=updated_pop.is_active,
    # Updated pedigree arrays
    id=updated_pop.id,
    mother=updated_pop.mother,
    father=updated_pop.father,
    gen=updated_pop.gen,
    # Update state pointers
    key=key,  # Use the new top-level key
    write_pos=(state.write_pos + n_offspring) % config.max_pop_size,
    gen_idx=state.gen_idx + 1,
    next_id=state.next_id + n_offspring,
)

print(f"New write position: {next_state.write_pos}")
print(f"Next available ID: {next_state.next_id}")
print(f"Number of active individuals in next generation: {jnp.sum(next_state.is_active)}")

# ==============================================================================
# --- Verification of the New State ---
# ==============================================================================
print("\n--- 13. Verification of New State ---")

# Exactly one full cohort should be flagged active after the update.
g1_mask = next_state.is_active
assert jnp.sum(g1_mask) == config.population_size

# New IDs are expected to continue from where the founder state left off.
offspring_ids = next_state.id[g1_mask]
print(f"First 5 new offspring IDs: {offspring_ids[:5]}")
assert offspring_ids[0] == initial_state.next_id

# Variance components of the new generation, over active rows only.
g1_bv = next_state.bv[g1_mask]
g1_pheno = next_state.pheno[g1_mask]
var_a_g1 = jnp.var(g1_bv)
var_p_g1 = jnp.var(g1_pheno)
realized_h2_g1 = var_a_g1 / var_p_g1

print(f"\nAdditive variance (VarA) in G1: {var_a_g1:.4f}")
print(f"Phenotypic variance (VarP) in G1: {var_p_g1:.4f}")
print(f"Realized heritability (h2) in G1: {realized_h2_g1:.4f}")

--- Defining Trait ---
Trait added with 100 QTLs.

--- Setting Initial Phenotypes with h2 ---

--- Initializing Simulation State ---
Initial write position: 100
Next available ID: 100
Founder population active: 100

--- Verification ---
Target narrow-sense heritability (h2): 0.4000
Realized additive variance (VarA) in founders: 1.0000
Realized phenotypic variance (VarP) in founders: 2.6278
Realized narrow-sense heritability (h2) in founders: 0.3805

--- 9. Phenotypic Selection ---
Selected 50 individuals with top phenotypes.

--- 10. Create Mating Pairs ---
Created 100 random pairs for mating.

--- 11. Produce Offspring via Meiosis ---
Meiosis complete.
Shape of offspring geno array: (100, 10, 2, 100)

--- 12. Update State for Next Generation ---
New write position: 200
Next available ID: 200
Number of active individuals in next generation: 100

--- 13. Verification of New State ---
First 5 new offspring IDs: [100 101 102 103 104]

Additive variance (VarA) in G1: 0.7489
Phenotypic vari

In [11]:
#| hide
# Export this notebook to its target module via nbdev.
import nbdev

nbdev.nbdev_export()