

Step 1: Simulating base population with msprime...
-> Initial founder population created: Population(nInd=200, nTraits=0, has_ebv=No)
-> SimParam object created: SimParam(nChr=10, nTraits=0, ploidy=2, sexes='no')

Step 2: Establishing the primary trait for the main experiment...
-> Trait with h^2=0.4 and gamma effects added to SimParam.

Step 3: Starting 10-generation burn-in with weak selection...
   > Gen 1/10 complete. Living animals: 200. Total in pedigree: 400
   > Gen 2/10 complete. Living animals: 200. Total in pedigree: 600
   > Gen 3/10 complete. Living animals: 200. Total in pedigree: 800
   > Gen 4/10 complete. Living animals: 200. Total in pedigree: 1000
   > Gen 5/10 complete. Living animals: 200. Total in pedigree: 1200
   > Gen 6/10 complete. Living animals: 200. Total in pedigree: 1400
   > Gen 7/10 complete. Living animals: 200. Total in pedigree: 1600
   > Gen 8/10 complete. Living animals: 200. Total in pedigree: 1800
   > Gen 9/10 complete. Living animals: 200. Tota

--- Step 1: Phenotypes Generated ---
Phenotyped 200 animals from the current generation.

--- Step 2: Building A-Inverse Matrix from Pedigree Book---
Successfully calculated A-inverse matrix with shape: (2200, 2200)

--- Step 3: Calculating Alpha ---
With h^2 = 0.3, alpha = 2.3333

--- Step 4: Solving Mixed Model Equations ---
Solved MME for all 2200 animals. Overall mean estimate (μ̂): 4.6339
Example EBV for a current animal: 0.4198

--- Step 5: Selecting Parents and Creating Progeny ---
Created 200 progeny for Generation 11.


--- Step 1: Phenotypes Generated ---
Phenotyped 200 animals from the current generation.

--- Step 2: Building A-Inverse Matrix from Pedigree Book---
Successfully calculated A-inverse matrix with shape: (2200, 2200)

--- Step 3: Calculating Alpha ---
With h^2 = 0.3, alpha = 2.3333

--- Step 4: Solving Mixed Model Equations ---
Solved MME for all 2200 animals. Overall mean estimate (μ̂): 4.6577
Example EBV for a current animal: 0.4484

--- Step 5: Selecting Parents and Creating Progeny ---
Created 200 progeny for Generation 11.


In [6]:
# 1) Make all the important imports
import jax
import jax.numpy as jnp
import numpy as np # Still useful for setting up some initial arrays
from functools import partial

# Import the necessary functions and classes from the chewc library
from chewc.population import msprime_pop, combine_populations
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno
from chewc.select import TruncationSelection
from chewc.cross import make_cross


# --- Simulation Setup ---
# Create a master JAX random key for reproducibility
key = jax.random.PRNGKey(42)

# --- Parameters for the founder population ---
N_IND = 200       # Number of individuals in the founder pop
N_CHR = 10        # Number of chromosomes
N_LOCI_PER_CHR = 1000 # Number of loci per chromosome
EFFECTIVE_POP_SIZE = 10000 # Effective population size for msprime
N_QTL_PER_CHR = 100

# 2) Use msprime_pop to simulate a population
print("Step 1: Simulating base population with msprime...")
key, pop_key = jax.random.split(key)
founder_pop, genetic_map = msprime_pop(
    key=pop_key,
    n_ind=N_IND,
    n_loci_per_chr=N_LOCI_PER_CHR,
    n_chr=N_CHR,
    effective_population_size=EFFECTIVE_POP_SIZE,
    enforce_founder_maf=True # Ensures we have good variation to start
)
print(f"-> Initial founder population created: {founder_pop}")
# founder_pop.plot_maf()



# Create the main simulation parameter object
sp = SimParam.from_founder_pop(founder_pop, genetic_map, sexes='no')
print(f"-> SimParam object created: {sp}")


# 3) Define the trait's genetic architecture: 100 QTL per chromosome with gamma-distributed effects
print("\nStep 2: Establishing the primary trait for the main experiment...")
key, trait_key = jax.random.split(key)
sp = add_trait_a(
    key=trait_key,
    founder_pop=founder_pop,
    sim_param=sp,
    n_qtl_per_chr=N_QTL_PER_CHR, # 100 QTLs per chromosome
    mean=jnp.array([0.0]), # A single trait with a mean of 0
    var=jnp.array([1.0]),  # and a genetic variance of 1
    gamma=True,            # Use the Gamma distribution for effects
    shape=0.4              # Shape parameter for the Gamma distribution
)
print("-> Trait with h^2=0.4 and gamma effects added to SimParam.")
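# Note: add_trait_a defines only the *genetic architecture* (QTL, effect distribution, genetic variance).
# Heritability is imposed later, when set_pheno adds environmental noise: h2 = 0.1 during the burn-in
# and h2 = 0.3 for ABLUP below, so the "h^2=0.4" in the message above is not what the phenotypes use.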

# 4) Simulate a burn-in for 10 generations
print("\nStep 3: Starting 10-generation burn-in with weak selection...")

# --- Parameters for the burn-in phase ---
N_GEN = 10
H2_BURN_IN = 0.1 # Low heritability for selection during burn-in
N_MALES_SELECTED = 20
N_FEMALES_SELECTED = 100
POP_SIZE_FIXED = 200 # Keep the population size constant


# --- REVISED Burn-in Loop (Pedigree Book Method) ---
from chewc.population import subset_population # Ensure this is imported

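# Initialize: the pedigree book holds one row per animal ever born, with columns [id, father (sire), mother (dam)]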
current_pop = founder_pop.replace(gen=jnp.zeros(founder_pop.nInd, dtype=jnp.int32))
pedigree_book = jnp.stack([current_pop.id, current_pop.father, current_pop.mother], axis=1)

for gen in range(1, N_GEN + 1):
    key, pheno_key, select_key, cross_key, sample_key = jax.random.split(key, 5)

    # 1. Phenotype the current population
    current_pop = set_pheno(
        key=pheno_key, pop=current_pop, traits=sp.traits,
        ploidy=sp.ploidy, h2=jnp.array([H2_BURN_IN])
    )

    # 2. Select parents from the current generation
    selection_method = TruncationSelection()
    n_total_to_select = N_MALES_SELECTED + N_FEMALES_SELECTED
    top_indices = selection_method.select_parents(
        key=select_key, pop=current_pop, sp=sp, n_select=n_total_to_select
    )
    elite_parent_iids = current_pop.iid[top_indices]

    # 3. Create mating plan
    key_dams, key_sires = jax.random.split(sample_key)
    dam_iids = jax.random.choice(key_dams, elite_parent_iids, shape=(POP_SIZE_FIXED,), replace=True)
    sire_iids = jax.random.choice(key_sires, elite_parent_iids, shape=(POP_SIZE_FIXED,), replace=True)
    cross_plan = jnp.stack([dam_iids, sire_iids], axis=1)

    # 4. Create progeny, passing the full current_pop for parent lookup
    progeny_pop = make_cross(
        key=cross_key, pop=current_pop, cross_plan=cross_plan, sp=sp,
        next_id_start=pedigree_book[:, 0].max() + 1
    )
    # Assign the correct generation number to the new progeny
    progeny_pop = progeny_pop.replace(gen=jnp.full(progeny_pop.nInd, gen, dtype=jnp.int32))
    
    # 5. Update the pedigree book and replace the current population
    progeny_pedigree = jnp.stack([progeny_pop.id, progeny_pop.father, progeny_pop.mother], axis=1)
    pedigree_book = jnp.vstack([pedigree_book, progeny_pedigree])
    current_pop = progeny_pop
    
    print(f"   > Gen {gen}/{N_GEN} complete. Living animals: {current_pop.nInd}. Total in pedigree: {pedigree_book.shape[0]}")

final_founder_pop = current_pop
print(f"\n--- Burn-in complete! ---")
print(f"Final living population: {final_founder_pop}")
print(f"Total animals in pedigree book: {pedigree_book.shape[0]}")
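The A-inverse code that follows identifies animals by their row in the pedigree book, and its comments assume the book is sorted by ID with parents listed before their offspring (the searchsorted lookup in In [8] relies on the sorting). A minimal check of both properties, written as a hypothetical helper `check_pedigree_book` (not a chewc function) that assumes -1 codes an unknown parent:

```python
import numpy as np


def check_pedigree_book(pedigree_book) -> None:
    """Hypothetical sanity check for an (n, 3) [id, sire, dam] pedigree book.

    Raises ValueError if IDs are not strictly increasing, or if a known parent
    (any value other than -1) is missing or does not appear on an earlier row.
    """
    book = np.asarray(pedigree_book)
    ids = book[:, 0]
    if np.any(np.diff(ids) <= 0):
        raise ValueError("IDs must be unique and strictly increasing")
    offspring_rows = np.arange(len(ids))
    for col, role in ((1, "sire"), (2, "dam")):
        parents = book[:, col]
        known = parents != -1
        # Candidate row of each parent ID; only meaningful where the ID exists.
        rows = np.clip(np.searchsorted(ids, parents), 0, len(ids) - 1)
        if np.any(known & (ids[rows] != parents)):
            raise ValueError(f"a {role} ID is not in the pedigree book")
        if np.any(known & (rows >= offspring_rows)):
            raise ValueError(f"a {role} does not precede its offspring")


# Toy book: animals 1 and 2 are founders, 3 is their offspring, 4 is 3 x unknown.
check_pedigree_book(np.array([[1, -1, -1], [2, -1, -1], [3, 1, 2], [4, 3, -1]]))
```

In the notebook it could be called on `pedigree_book` right after the burn-in, before any A-inverse is built.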


import jax.numpy as jnp

# --- ADD THIS NEW FUNCTION to chewc/population.py ---

def calc_a_inverse_matrix_from_pedigree_book(pedigree_book: jnp.ndarray) -> jnp.ndarray:
    """
    Calculates A-inverse directly from a pedigree book array.

    Args:
        pedigree_book: A JAX array of shape (n_animals, 3) with columns
                       for Animal_ID, Sire_ID, and Dam_ID.
    """
    # Copy the book to host memory once; indexing a device array element by
    # element inside the Python loop below is slow for large pedigrees.
    book = np.asarray(pedigree_book)
    n_ind = book.shape[0]
    all_ids = book[:, 0]
    father_ids = book[:, 1]  # Sire is column 1
    mother_ids = book[:, 2]  # Dam is column 2

    # 1. --- Python-side Logic ---
    # Create the public ID to internal index (0, 1, 2...) mapping.
    # The pedigree book is assumed to be sorted by ID with parents listed before
    # their offspring, so an animal's internal index is simply its row number.
    id_to_iid = {int(pub_id): i for i, pub_id in enumerate(all_ids)}
    unknown_parent_iid = -1

    # Convert public parent IDs to internal indices (iids)
    dam_iids_np = np.full(n_ind, unknown_parent_iid, dtype=np.int32)
    sire_iids_np = np.full(n_ind, unknown_parent_iid, dtype=np.int32)

    for i in range(n_ind):
        dam_pub_id = mother_ids[i]
        sire_pub_id = father_ids[i]
        if dam_pub_id != -1:
            dam_iids_np[i] = id_to_iid[int(dam_pub_id)]
        if sire_pub_id != -1:
            sire_iids_np[i] = id_to_iid[int(sire_pub_id)]

    # 2. --- Call the JIT-compiled Core ---
    dam_iids_jax = jnp.asarray(dam_iids_np)
    sire_iids_jax = jnp.asarray(sire_iids_np)

    # `_calc_a_inv_jax_loop` is defined further down; it only has to exist
    # by the time this function is called.
    return _calc_a_inv_jax_loop(n_ind, dam_iids_jax, sire_iids_jax)
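For a pedigree book sorted by ID, the dictionary-plus-loop mapping above can be replaced by one vectorized lookup. A sketch with NumPy's `np.searchsorted`, assuming column 0 is sorted ascending and -1 marks an unknown parent; `parent_ids_to_rows` is a hypothetical helper name:

```python
import numpy as np


def parent_ids_to_rows(pedigree_book):
    """Map sire and dam IDs of an (n, 3) [id, sire, dam] book to row indices.

    Assumes column 0 is sorted ascending. Unknown parents (-1) map to -1;
    a known parent ID that is absent from column 0 raises KeyError.
    """
    book = np.asarray(pedigree_book)
    ids = book[:, 0]
    mapped = []
    for parents in (book[:, 1], book[:, 2]):
        # searchsorted gives the row an ID would occupy; clip keeps it in range.
        rows = np.clip(np.searchsorted(ids, parents), 0, len(ids) - 1)
        known = parents != -1
        if np.any(known & (ids[rows] != parents)):
            raise KeyError("parent ID not found in pedigree book")
        mapped.append(np.where(known, rows, -1).astype(np.int32))
    sire_rows, dam_rows = mapped
    return sire_rows, dam_rows


demo_book = np.array([[10, -1, -1], [11, -1, -1], [12, 10, 11], [13, 12, -1]])
demo_sire_rows, demo_dam_rows = parent_ids_to_rows(demo_book)
print(demo_sire_rows, demo_dam_rows)  # [-1 -1  0  2] [-1 -1  1 -1]
```

The two arrays could then be passed to `_calc_a_inv_jax_loop(n_ind, jnp.asarray(dam_rows), jnp.asarray(sire_rows))`; note that the loop takes dams before sires.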


# Intended for chewc/population.py (inside that module, drop the self-import of Population below)

import jax
import jax.numpy as jnp
from jax import lax
import numpy as np # Use numpy for the non-JIT part
from chewc.population import Population

from functools import partial

@partial(jax.jit, static_argnames=['n_ind'])
def _calc_a_inv_jax_loop(
    n_ind: int,
    dam_iids: jnp.ndarray,
    sire_iids: jnp.ndarray
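) -> jnp.ndarray:
    """
    Core JIT-compiled loop for A-inverse calculation (Henderson's rules).
    Assumes parent IDs are already mapped to internal indices (iids), with -1
    for an unknown parent. The coefficients used (2, 4/3 and 1 on the diagonal)
    assume the parents are not inbred; when they are, the result approximates
    A-inverse rather than equalling it.
    """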
    unknown_parent_iid = -1

    def loop_body(i, A_inv):
        s = sire_iids[i]
        d = dam_iids[i]

        # Encode which parents are known as a branch index for lax.switch:
        # 0 = neither, 1 = sire only, 2 = dam only, 3 = both.
        case_index = (s != unknown_parent_iid) + (d != unknown_parent_iid) * 2

        def case_0(mat): # Both parents unknown
            return mat.at[i, i].add(1.0)

        def case_1(mat): # Sire known, Dam unknown
            return mat.at[i, i].add(4/3).at[s, s].add(1/3).at[i, s].add(-2/3).at[s, i].add(-2/3)

        def case_2(mat): # Sire unknown, Dam known
            return mat.at[i, i].add(4/3).at[d, d].add(1/3).at[i, d].add(-2/3).at[d, i].add(-2/3)

        def case_3(mat): # Both parents known
            return (mat.at[i, i].add(2.0).at[s, s].add(0.5).at[d, d].add(0.5)
                       .at[s, d].add(0.5).at[d, s].add(0.5).at[i, s].add(-1.0)
                       .at[s, i].add(-1.0).at[i, d].add(-1.0).at[d, i].add(-1.0))

        return lax.switch(case_index, [case_0, case_1, case_2, case_3], A_inv)

    initial_A_inv = jnp.zeros((n_ind, n_ind))
    A_inv = lax.fori_loop(0, n_ind, loop_body, initial_A_inv)
    return A_inv
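Because the loop uses the fixed coefficients 2, 4/3 and 1, it reproduces the exact A-inverse only when no parent is inbred; in a closed, selected population inbreeding accumulates over generations, so the result becomes an approximation. A small NumPy cross-check makes this concrete: build A with the tabular method, invert it directly, and compare with the same update rules. The helper names are hypothetical and the pedigrees are toy examples:

```python
import numpy as np


def tabular_a(sire, dam):
    """Numerator relationship matrix A by the tabular method.

    sire/dam hold parent row indices (-1 = unknown); parents must precede
    their offspring.
    """
    n = len(sire)
    A = np.zeros((n, n))
    for i in range(n):
        s, d = sire[i], dam[i]
        for j in range(i):
            a_js = A[j, s] if s >= 0 else 0.0
            a_jd = A[j, d] if d >= 0 else 0.0
            A[i, j] = A[j, i] = 0.5 * (a_js + a_jd)
        A[i, i] = 1.0 + (0.5 * A[s, d] if s >= 0 and d >= 0 else 0.0)
    return A


def henderson_a_inv(sire, dam):
    """The update rules of _calc_a_inv_jax_loop, in plain NumPy."""
    n = len(sire)
    A_inv = np.zeros((n, n))
    for i in range(n):
        known = [p for p in (sire[i], dam[i]) if p >= 0]
        delta = {0: 1.0, 1: 4.0 / 3.0, 2: 2.0}[len(known)]
        A_inv[i, i] += delta
        for p in known:
            A_inv[i, p] -= delta / 2
            A_inv[p, i] -= delta / 2
            for q in known:
                A_inv[p, q] += delta / 4
    return A_inv


# No inbred parents: 0, 1 founders; 2 = 0 x 1; 3 = 0 x ?; 4 = 2 x 3.
demo_sire, demo_dam = [-1, -1, 0, 0, 2], [-1, -1, 1, -1, 3]
print(np.allclose(henderson_a_inv(demo_sire, demo_dam),
                  np.linalg.inv(tabular_a(demo_sire, demo_dam))))  # True

# Full sibs 2 and 3 produce an inbred 4, whose offspring 5 breaks the rules.
demo_sire, demo_dam = [-1, -1, 0, 0, 2, 4], [-1, -1, 1, 1, 3, -1]
print(np.allclose(henderson_a_inv(demo_sire, demo_dam),
                  np.linalg.inv(tabular_a(demo_sire, demo_dam))))  # False
```

Exact A-inverse in the presence of inbreeding needs the parents' inbreeding coefficients (for example Quaas' method), which the loop does not compute.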

# --- REVISED: Public-facing wrapper function ---
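def calc_a_inverse_matrix_pedigree_jax(pop: Population) -> jnp.ndarray:
    """
    Calculates the inverse of the pedigree-based A-matrix (A-inverse).
    This function handles the Python-side logic of mapping IDs and then
    calls the JIT-compiled core loop.

    Every parent referenced by `pop` must itself be in `pop` (e.g. a founder
    population); otherwise the ID lookup below raises a KeyError. For a
    population whose ancestors are no longer alive, use
    calc_a_inverse_matrix_from_pedigree_book instead.
    """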
    n_ind = pop.nInd

    # 1. --- Python-side Logic (Not JIT-compiled) ---
    # Create the ID-to-internal-index mapping using standard Python/NumPy
    id_to_iid = {int(pub_id): int(iid) for pub_id, iid in zip(pop.id, pop.iid)}
    unknown_parent_iid = -1

    # Convert public parent IDs to internal indices (iids)
    # Using NumPy here is fine as this part of the code is not traced by JAX.
    dam_iids_np = np.full(n_ind, unknown_parent_iid, dtype=np.int32)
    sire_iids_np = np.full(n_ind, unknown_parent_iid, dtype=np.int32)

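    # Copy parent IDs to host memory once instead of indexing device arrays per element.
    mothers = np.asarray(pop.mother)
    fathers = np.asarray(pop.father)
    for i in range(n_ind):
        dam_pub_id = mothers[i]
        sire_pub_id = fathers[i]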
        if dam_pub_id != -1:
            dam_iids_np[i] = id_to_iid[int(dam_pub_id)]
        if sire_pub_id != -1:
            sire_iids_np[i] = id_to_iid[int(sire_pub_id)]

    # 2. --- Call the JIT-compiled Core ---
    # Convert the prepared NumPy arrays to JAX arrays before calling the kernel.
    dam_iids_jax = jnp.asarray(dam_iids_np)
    sire_iids_jax = jnp.asarray(sire_iids_np)

    return _calc_a_inv_jax_loop(n_ind, dam_iids_jax, sire_iids_jax)
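Both wrappers pass a static `n_ind` to the jitted loop, so JAX recompiles whenever the pedigree grows. One possible alternative, sketched here under the same no-inbreeding assumption, applies the four update rules as vectorized scatter-adds on a fixed-size, padded pedigree book, so that only the maximum size is static. `a_inv_padded` is a hypothetical name, not a chewc function; rows at or beyond `n_active` are treated as padding:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=["max_size"])
def a_inv_padded(pedigree_book, n_active, max_size):
    """A-inverse (Henderson's rules, inbreeding ignored) on a padded book.

    pedigree_book: int array (max_size, 3) with columns [id, sire, dam], IDs
    ascending over the first n_active rows, -1 for an unknown parent.
    Returns (max_size, max_size): the leading n_active block is A-inverse and
    all padding rows and columns are zero.
    """
    rows = jnp.arange(max_size)
    active = rows < n_active
    # Give padding rows the largest ID so the ID column stays sorted.
    ids = jnp.where(active, pedigree_book[:, 0], jnp.iinfo(pedigree_book.dtype).max)
    sire_ids, dam_ids = pedigree_book[:, 1], pedigree_book[:, 2]
    s_known = active & (sire_ids != -1)
    d_known = active & (dam_ids != -1)
    # Row index of each parent; unknown parents point at row 0 with zero weight.
    s = jnp.where(s_known, jnp.searchsorted(ids, sire_ids), 0)
    d = jnp.where(d_known, jnp.searchsorted(ids, dam_ids), 0)
    ws = s_known.astype(jnp.float32)
    wd = d_known.astype(jnp.float32)
    # Inverse Mendelian-sampling variance for non-inbred parents: 1, 4/3 or 2.
    delta = jnp.where(active, 1.0 / (1.0 - 0.25 * (ws + wd)), 0.0)

    mat = jnp.zeros((max_size, max_size), dtype=jnp.float32)
    mat = mat.at[rows, rows].add(delta)
    # Animal-parent terms: -delta/2 each (repeated indices accumulate).
    mat = mat.at[rows, s].add(-0.5 * delta * ws).at[s, rows].add(-0.5 * delta * ws)
    mat = mat.at[rows, d].add(-0.5 * delta * wd).at[d, rows].add(-0.5 * delta * wd)
    # Parent-parent terms: +delta/4 for every ordered pair of known parents.
    mat = mat.at[s, s].add(0.25 * delta * ws).at[d, d].add(0.25 * delta * wd)
    mat = mat.at[s, d].add(0.25 * delta * ws * wd).at[d, s].add(0.25 * delta * ws * wd)
    return mat


# Five live animals (1, 2 founders; 3 = 1 x 2; 4 = 1 x ?; 5 = 3 x 4)
# plus two padding rows reserved for future births.
padded_book = jnp.array([[1, -1, -1], [2, -1, -1], [3, 1, 2], [4, 1, -1], [5, 3, 4],
                         [-1, -1, -1], [-1, -1, -1]], dtype=jnp.int32)
padded_a_inv = a_inv_padded(padded_book, 5, max_size=7)
```

Because `n_active` is an ordinary traced argument, calling again with 6 after writing a new birth into a padding row reuses the compiled function; only a new `max_size` triggers recompilation. The dense `(max_size, max_size)` output remains the memory bottleneck for large pedigrees.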


# Import the new function
# from chewc.population import calc_a_inverse_matrix_from_pedigree_book

# --- ABLUP Generation 1 ---

# Parameters for our selection experiment
H2_SELECT = 0.3
N_MALES_TO_SELECT = 20
N_FEMALES_TO_SELECT = 100
N_PROGENY = 200

# Let's call the final living population from the burn-in `gen10_pop`
gen10_pop = final_founder_pop

# --- Step 1: Collect Phenotypes ---
# We only phenotype the current, living generation.
key, pheno_key_g11 = jax.random.split(key)
gen10_pop_phenotyped = set_pheno(
    key=pheno_key_g11,
    pop=gen10_pop,
    traits=sp.traits,
    ploidy=sp.ploidy,
    h2=jnp.array([H2_SELECT])
)
print("--- Step 1: Phenotypes Generated ---")
print(f"Phenotyped {gen10_pop_phenotyped.nInd} animals from the current generation.")


# --- Step 2: Assemble the Pedigree & Build A⁻¹ ---
print("\n--- Step 2: Building A-Inverse Matrix from Pedigree Book---")
# We use our new function and the complete pedigree_book
A_inv = calc_a_inverse_matrix_from_pedigree_book(pedigree_book)
print(f"Successfully calculated A-inverse matrix with shape: {A_inv.shape}")


# --- Step 3: Calculate Alpha ---
print("\n--- Step 3: Calculating Alpha ---")
alpha = (1 - H2_SELECT) / H2_SELECT
print(f"With h^2 = {H2_SELECT}, alpha = {alpha:.4f}")
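The ratio follows from the variance components: with phenotypic variance scaled to 1, the additive variance is h^2 and the residual variance is 1 - h^2, so alpha = sigma_e^2 / sigma_a^2 = (1 - h^2) / h^2. A larger alpha means noisier phenotypes relative to breeding values and therefore stronger shrinkage of EBVs toward zero. A quick check of the value printed above (`mme_alpha` is an illustrative name):

```python
def mme_alpha(h2: float) -> float:
    """Variance ratio sigma_e^2 / sigma_a^2 that multiplies A-inverse in the MME."""
    if not 0.0 < h2 < 1.0:
        raise ValueError("h2 must lie strictly between 0 and 1")
    sigma2_a = h2        # additive genetic variance, phenotypic variance scaled to 1
    sigma2_e = 1.0 - h2  # residual variance
    return sigma2_e / sigma2_a


print(f"{mme_alpha(0.3):.4f}")  # 2.3333
print(f"{mme_alpha(0.1):.4f}")  # 9.0000
```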


# --- Step 4: Assemble and Solve the MME ---
print("\n--- Step 4: Solving Mixed Model Equations ---")
# The MME must be solved for ALL animals in the pedigree.
n_total_animals = pedigree_book.shape[0]

# Create a full-sized phenotype vector with NaNs for ancestors
y_full = jnp.full((n_total_animals, 1), jnp.nan)

# Map the current generation's phenotypes into the correct slots in y_full.
# The current generation occupies the last nInd rows of the pedigree book,
# in the same order as gen10_pop (progeny rows are appended generation by generation).
current_gen_start_index = pedigree_book.shape[0] - gen10_pop_phenotyped.nInd
y_full = y_full.at[current_gen_start_index:, 0].set(gen10_pop_phenotyped.pheno.flatten())

# Create a mask to identify animals with phenotypes
pheno_mask = ~jnp.isnan(y_full.flatten())

# Incidence matrices (X and Z) for the animals WITH phenotypes
n_pheno = gen10_pop_phenotyped.nInd
X_pheno = jnp.ones((n_pheno, 1))
# Z_pheno maps each phenotyped animal (row) to its column in the full pedigree
pheno_rows = jnp.arange(n_pheno)
Z_pheno = jnp.zeros((n_pheno, n_total_animals)).at[
    pheno_rows, current_gen_start_index + pheno_rows
].set(1)

# Left-Hand Side (LHS)
LHS_top_row = jnp.hstack([X_pheno.T @ X_pheno, X_pheno.T @ Z_pheno])
LHS_bottom_row = jnp.hstack([Z_pheno.T @ X_pheno, Z_pheno.T @ Z_pheno + A_inv * alpha])
LHS = jnp.vstack([LHS_top_row, LHS_bottom_row])

# Right-Hand Side (RHS)
y_pheno = y_full[pheno_mask]
RHS = jnp.vstack([X_pheno.T @ y_pheno, Z_pheno.T @ y_pheno])

# Solve for estimates
solutions = jnp.linalg.solve(LHS, RHS)
mu_hat = solutions[0]
ebvs_all = solutions[1:]

print(f"Solved MME for all {n_total_animals} animals. Overall mean estimate (μ̂): {mu_hat[0]:.4f}")

# Extract EBVs for the current, selectable generation
ebvs_current_gen = ebvs_all[current_gen_start_index:]
gen10_pop_ebv = gen10_pop_phenotyped.replace(ebv=ebvs_current_gen)
print(f"Example EBV for a current animal: {gen10_pop_ebv.ebv[0][0]:.4f}")
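As a sanity check on the construction above, MME solutions should match the equivalent GLS/BLUP formulas b = (X'V^-1 X)^-1 X'V^-1 y and u = A Z' V^-1 (y - X b), with V = Z A Z' + alpha I (everything in units of the additive variance). A toy sketch in JAX, assuming a five-animal pedigree in which only the last three animals are phenotyped; `solve_animal_mme` is a hypothetical helper that mirrors the notebook's LHS/RHS layout:

```python
import jax.numpy as jnp
import numpy as np


def solve_animal_mme(y, pheno_rows, A_inv, alpha):
    """Henderson's MME for an overall mean plus one additive effect per animal.

    y: (n_pheno,) phenotypes; pheno_rows: pedigree row of each phenotyped
    animal; A_inv: (n, n) inverse relationship matrix; alpha = s2e / s2a.
    Returns (mu_hat, ebv) where ebv has one entry per pedigree row.
    """
    n = A_inv.shape[0]
    n_pheno = y.shape[0]
    X = jnp.ones((n_pheno, 1))
    Z = jnp.zeros((n_pheno, n)).at[jnp.arange(n_pheno), pheno_rows].set(1.0)
    lhs = jnp.block([[X.T @ X, X.T @ Z],
                     [Z.T @ X, Z.T @ Z + alpha * A_inv]])
    rhs = jnp.concatenate([X.T @ y, Z.T @ y])
    sol = jnp.linalg.solve(lhs, rhs)
    return sol[0], sol[1:]


# Pedigree rows: 0, 1 founders; 2 = 0 x 1; 3 = 0 x ?; 4 = 2 x 3 (A by the tabular method).
demo_A = np.array([[1.0, 0.0, 0.5, 0.5, 0.5],
                   [0.0, 1.0, 0.5, 0.0, 0.25],
                   [0.5, 0.5, 1.0, 0.25, 0.625],
                   [0.5, 0.0, 0.25, 1.0, 0.625],
                   [0.5, 0.25, 0.625, 0.625, 1.125]])
demo_mu, demo_ebv = solve_animal_mme(
    jnp.array([10.0, 12.0, 11.0]),  # phenotypes of rows 2, 3 and 4 only
    jnp.array([2, 3, 4]),
    jnp.asarray(np.linalg.inv(demo_A), dtype=jnp.float32),
    alpha=0.7 / 0.3,
)
```

As in the notebook, the unphenotyped founders (rows 0 and 1) still get EBVs, driven entirely by their relationships to phenotyped descendants.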


# --- Step 5: Select Parents and Create Next Generation ---
print("\n--- Step 5: Selecting Parents and Creating Progeny ---")
key, select_key_g1, cross_key_g1 = jax.random.split(key, 3)

# Select parents from the current population based on their new EBVs
selection_method = TruncationSelection()
top_indices = selection_method.select_parents(
    key=select_key_g1,
    pop=gen10_pop_ebv, # Use the population object that has the EBVs
    sp=sp,
    n_select=(N_MALES_TO_SELECT + N_FEMALES_TO_SELECT)
)
elite_parent_iids = gen10_pop_ebv.iid[top_indices]

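# Assign sires and dams from the elite pool. With sexes='no' the roles are assigned
# by rank: assuming select_parents returns indices best-first, the top
# N_MALES_TO_SELECT animals become sires and the next N_FEMALES_TO_SELECT become dams.
selected_male_iids = elite_parent_iids[:N_MALES_TO_SELECT]
selected_female_iids = elite_parent_iids[N_MALES_TO_SELECT:N_MALES_TO_SELECT + N_FEMALES_TO_SELECT]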

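# Create mating plan, drawing dams and sires with separate keys
# (reusing one key for both draws would make them dependent)
key, key_dams, key_sires = jax.random.split(key, 3)
dam_iids = jax.random.choice(key_dams, selected_female_iids, shape=(N_PROGENY,), replace=True)
sire_iids = jax.random.choice(key_sires, selected_male_iids, shape=(N_PROGENY,), replace=True)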
cross_plan_g1 = jnp.stack([dam_iids, sire_iids], axis=1)

# Create the next generation
progeny_pop_g1 = make_cross(
    key=cross_key_g1,
    pop=gen10_pop_ebv,
    cross_plan=cross_plan_g1,
    sp=sp,
    next_id_start=pedigree_book[:, 0].max() + 1
)
progeny_pop_g1 = progeny_pop_g1.replace(gen=jnp.full(progeny_pop_g1.nInd, N_GEN + 1, dtype=jnp.int32))

print(f"Created {progeny_pop_g1.nInd} progeny for Generation {N_GEN + 1}.")
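Random mating recurs in the burn-in and in every ABLUP generation, so it may be worth factoring out. A minimal sketch of a hypothetical helper `random_mating_plan` that always splits its key before drawing, so dam and sire draws are independent and the caller's master key is never reused; the `[dam, sire]` column order follows the cross plans built above:

```python
import jax
import jax.numpy as jnp


def random_mating_plan(key, dam_pool, sire_pool, n_progeny):
    """Random mating with replacement: one (dam, sire) row per progeny."""
    key_dams, key_sires = jax.random.split(key)
    dams = jax.random.choice(key_dams, dam_pool, shape=(n_progeny,), replace=True)
    sires = jax.random.choice(key_sires, sire_pool, shape=(n_progeny,), replace=True)
    return jnp.stack([dams, sires], axis=1)


demo_key = jax.random.PRNGKey(0)
demo_key, plan_key = jax.random.split(demo_key)  # keep demo_key fresh for later steps
demo_plan = random_mating_plan(plan_key, jnp.arange(100, 200), jnp.arange(0, 20), n_progeny=200)
print(demo_plan.shape)  # (200, 2)
```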

Step 1: Simulating base population with msprime...
-> Initial founder population created: Population(nInd=200, nTraits=0, has_ebv=No)
-> SimParam object created: SimParam(nChr=10, nTraits=0, ploidy=2, sexes='no')

Step 2: Establishing the primary trait for the main experiment...
-> Trait with h^2=0.4 and gamma effects added to SimParam.

Step 3: Starting 10-generation burn-in with weak selection...
   > Gen 1/10 complete. Living animals: 200. Total in pedigree: 400
   > Gen 2/10 complete. Living animals: 200. Total in pedigree: 600
   > Gen 3/10 complete. Living animals: 200. Total in pedigree: 800
   > Gen 4/10 complete. Living animals: 200. Total in pedigree: 1000
   > Gen 5/10 complete. Living animals: 200. Total in pedigree: 1200
   > Gen 6/10 complete. Living animals: 200. Total in pedigree: 1400
   > Gen 7/10 complete. Living animals: 200. Total in pedigree: 1600
   > Gen 8/10 complete. Living animals: 200. Total in pedigree: 1800
   > Gen 9/10 complete. Living animals: 200. Tota


In [8]:
# Place this function in your chewc/population.py or a utils file
@partial(jax.jit, static_argnames=['max_pedigree_size'])
def calc_a_inv_jit(pedigree_book: jnp.ndarray, current_pedigree_size: int, max_pedigree_size: int) -> jnp.ndarray:
    """
    JIT-compatible function to calculate A-inverse from a pre-allocated pedigree book.

    current_pedigree_size is traced under jit, so it cannot be used as a slice size
    or loop length (jax.lax.dynamic_slice needs static sizes). Instead, the whole
    padded book of max_pedigree_size rows is processed. Rows at or beyond
    current_pedigree_size are treated as unrelated animals with unknown parents.
    """
    all_ids = pedigree_book[:, 0]
    sire_ids = pedigree_book[:, 1]
    dam_ids = pedigree_book[:, 2]

    # Rows at or beyond current_pedigree_size are padding (filled with -1).
    is_active = jnp.arange(max_pedigree_size) < current_pedigree_size

    # Find parent indices (iids) with searchsorted, which is JIT-compatible but needs
    # a sorted array. Active IDs are ascending, so push padding IDs to the end.
    sortable_ids = jnp.where(is_active, all_ids, jnp.iinfo(all_ids.dtype).max)
    sire_iids = jnp.searchsorted(sortable_ids, sire_ids)
    dam_iids = jnp.searchsorted(sortable_ids, dam_ids)

    # searchsorted maps an unknown parent (-1) to index 0, so reset those to -1.
    # Padding rows also get unknown parents.
    sire_iids = jnp.where((sire_ids == -1) | ~is_active, -1, sire_iids)
    dam_iids = jnp.where((dam_ids == -1) | ~is_active, -1, dam_iids)

    # The loop length is the static max_pedigree_size, not the traced pedigree size.
    # This assumes _calc_a_inv_jax_loop treats an animal with both parents unknown
    # as a base animal (diagonal of 1), so padding rows form an identity block.
    return _calc_a_inv_jax_loop(max_pedigree_size, dam_iids, sire_iids)
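`_calc_a_inv_jax_loop` itself is not shown in this section. For reference, here is a minimal vectorized sketch of Henderson's rules for A-inverse. It assumes that no parent is inbred, and it is not chewc's implementation. `a_inverse_no_inbreeding` and `a_matrix_tabular` are hypothetical names. The conventions match the pedigree book after the lookup above: parents are row indices, -1 marks an unknown parent, and parents precede their offspring. The check compares the result against the inverse of A built by the tabular method.

```python
import jax.numpy as jnp
import numpy as np


def a_inverse_no_inbreeding(sire: jnp.ndarray, dam: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical sketch of Henderson's rules for A^-1, ignoring inbreeding.

    sire, dam: integer row indices of each animal's parents, -1 if unknown.
    Exact only when no parent is inbred.
    """
    n = sire.shape[0]
    has_s, has_d = sire >= 0, dam >= 0
    n_known = has_s.astype(jnp.int32) + has_d.astype(jnp.int32)
    # delta = 1 / Mendelian sampling variance: 1, 4/3 or 2 for 0, 1 or 2 known parents.
    delta = jnp.where(n_known == 2, 2.0, jnp.where(n_known == 1, 4.0 / 3.0, 1.0))
    # Point unknown parents at row 0 but give them zero weight.
    s, d = jnp.where(has_s, sire, 0), jnp.where(has_d, dam, 0)
    ws, wd = has_s.astype(delta.dtype), has_d.astype(delta.dtype)
    i = jnp.arange(n)

    a_inv = jnp.zeros((n, n), dtype=delta.dtype)
    # .at[...].add accumulates repeated indices, e.g. a sire with many offspring.
    a_inv = a_inv.at[i, i].add(delta)
    a_inv = a_inv.at[i, s].add(-0.5 * delta * ws)
    a_inv = a_inv.at[s, i].add(-0.5 * delta * ws)
    a_inv = a_inv.at[i, d].add(-0.5 * delta * wd)
    a_inv = a_inv.at[d, i].add(-0.5 * delta * wd)
    a_inv = a_inv.at[s, s].add(0.25 * delta * ws)
    a_inv = a_inv.at[d, d].add(0.25 * delta * wd)
    a_inv = a_inv.at[s, d].add(0.25 * delta * ws * wd)
    a_inv = a_inv.at[d, s].add(0.25 * delta * ws * wd)
    return a_inv


def a_matrix_tabular(sire: np.ndarray, dam: np.ndarray) -> np.ndarray:
    """Numerator relationship matrix A by the tabular method (used only for checking)."""
    n = len(sire)
    A = np.zeros((n, n))
    for i in range(n):
        s, d = sire[i], dam[i]
        for j in range(i):
            A[i, j] = A[j, i] = 0.5 * ((A[j, s] if s >= 0 else 0.0) + (A[j, d] if d >= 0 else 0.0))
        A[i, i] = 1.0 + (0.5 * A[s, d] if s >= 0 and d >= 0 else 0.0)
    return A


# Founders 0 and 1; 2 = 0 x 1; 3 has only sire 0; 4 = 2 x 3 (inbred, but not a parent).
sire = np.array([-1, -1, 0, 0, 2])
dam = np.array([-1, -1, 1, -1, 3])
a_inv = a_inverse_no_inbreeding(jnp.asarray(sire), jnp.asarray(dam))
print(np.allclose(np.asarray(a_inv), np.linalg.inv(a_matrix_tabular(sire, dam)), atol=1e-4))
```

Once parents themselves are inbred (as happens after several generations of selection), the delta terms need the parents' inbreeding coefficients, so this shortcut is no longer exact.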

In [9]:
from typing import NamedTuple

class SimState(NamedTuple):
    key: jax.Array  # PRNG key
    current_pop: Population
    pedigree_book: jnp.ndarray
    current_pedigree_size: int  # becomes a traced int32 scalar inside lax.scan
    # You can add more metrics to track here
    genetic_mean_history: jnp.ndarray
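`SimState` can serve as a `lax.scan` carry because `NamedTuple` subclasses are JAX pytrees. The carry must keep identical shapes and dtypes on every iteration, which is why the pedigree book and history are pre-allocated rather than grown. Here is a toy sketch with a hypothetical `ToyState` stand-in:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for SimState."""
    key: jax.Array
    pedigree_size: jax.Array  # int32 scalar, like current_pedigree_size
    history: jax.Array        # fixed-length float vector, like genetic_mean_history


def toy_step(state: ToyState, gen_idx):
    key, draw_key = jax.random.split(state.key)
    value = jax.random.normal(draw_key)
    # Every field keeps the same shape and dtype from one iteration to the next.
    new_state = ToyState(
        key=key,
        pedigree_size=state.pedigree_size + 200,
        history=state.history.at[gen_idx].set(value),
    )
    return new_state, None


init = ToyState(jax.random.PRNGKey(0), jnp.int32(2200), jnp.zeros(5))
final, _ = jax.lax.scan(toy_step, init, xs=jnp.arange(5))
print(int(final.pedigree_size))  # 3200
```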

In [10]:
def generation_step(state: SimState, gen_idx: int, sp: SimParam, max_pedigree_size: int):
    """
    A single step of the simulation for lax.scan.
    gen_idx is the loop counter (generation number).
    """
    key, pheno_key, select_key, cross_key = jax.random.split(state.key, 4)

    # 1. Phenotype the current population
    pop = set_pheno(
        key=pheno_key, pop=state.current_pop, traits=sp.traits,
        ploidy=sp.ploidy, h2=jnp.array([H2_SELECT])
    )

    # 2. Build A-inverse from the current pedigree
    A_inv = calc_a_inv_jit(state.pedigree_book, state.current_pedigree_size, max_pedigree_size)

    # 3. Solve Mixed Model Equations
    # Inside lax.scan, current_pedigree_size is traced, so it cannot set array shapes
    # or slice bounds. The MME are built at the static size max_pedigree_size and the
    # unused (padding) rows are masked out.
    # Note: this solves a dense (max_pedigree_size + 1)^2 system every generation.
    alpha = (1 - H2_SELECT) / H2_SELECT
    n_total_animals = state.current_pedigree_size
    n_current_gen = pop.nInd
    current_gen_start_idx = n_total_animals - n_current_gen

    # Padding rows act as unrelated animals without records (identity block in A-inverse).
    is_active = jnp.arange(max_pedigree_size) < n_total_animals
    A_inv_active = jnp.where(is_active[:, None] & is_active[None, :], A_inv, jnp.eye(max_pedigree_size))

    # Only the current generation has phenotypes; Z maps record k to its pedigree row.
    y_pheno = pop.pheno.flatten()
    pheno_rows = current_gen_start_idx + jnp.arange(n_current_gen)
    X_pheno = jnp.ones((n_current_gen, 1))
    Z_pheno = jnp.zeros((n_current_gen, max_pedigree_size)).at[jnp.arange(n_current_gen), pheno_rows].set(1.0)

    LHS_top = jnp.hstack([X_pheno.T @ X_pheno, X_pheno.T @ Z_pheno])
    LHS_bottom = jnp.hstack([Z_pheno.T @ X_pheno, Z_pheno.T @ Z_pheno + A_inv_active * alpha])
    LHS = jnp.vstack([LHS_top, LHS_bottom])

    # y_pheno is 1-D, so concatenate the two RHS blocks and make a column vector.
    RHS = jnp.concatenate([X_pheno.T @ y_pheno, Z_pheno.T @ y_pheno])[:, None]

    solutions = jnp.linalg.solve(LHS, RHS)
    ebvs_all = solutions[1:]

    ebvs_current_gen = jax.lax.dynamic_slice(ebvs_all, (current_gen_start_idx, 0), (n_current_gen, 1))
    pop = pop.replace(ebv=ebvs_current_gen)

    # 4. Select parents and create progeny
    selection_method = TruncationSelection()
    top_indices = selection_method.select_parents(
        key=select_key, pop=pop, sp=sp, n_select=(N_MALES_SELECTED + N_FEMALES_SELECTED)
    )
    elite_parent_iids = pop.iid[top_indices]
    
    selected_female_iids = elite_parent_iids[:N_FEMALES_SELECTED]
    selected_male_iids = elite_parent_iids[N_FEMALES_SELECTED:]
    
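    key, dam_key, sire_key = jax.random.split(key, 3)  # separate keys so dam and sire draws are independent
    dam_iids = jax.random.choice(dam_key, selected_female_iids, shape=(N_PROGENY,), replace=True)
    sire_iids = jax.random.choice(sire_key, selected_male_iids, shape=(N_PROGENY,), replace=True)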
    cross_plan = jnp.stack([dam_iids, sire_iids], axis=1)

    progeny_pop = make_cross(
        key=cross_key, pop=pop, cross_plan=cross_plan, sp=sp,
        next_id_start=state.pedigree_book[:, 0].max() + 1
    )
    progeny_pop = progeny_pop.replace(gen=jnp.full(progeny_pop.nInd, gen_idx + 1, dtype=jnp.int32))

    # 5. Update the pedigree book and state
    progeny_pedigree = jnp.stack([progeny_pop.id, progeny_pop.father, progeny_pop.mother], axis=1)
    
    # Use dynamic_update_slice to add progeny to the pre-allocated array
    updated_ped_book = jax.lax.dynamic_update_slice(
        state.pedigree_book, progeny_pedigree, (state.current_pedigree_size, 0)
    )
    new_pedigree_size = state.current_pedigree_size + progeny_pop.nInd

    # Track metrics
    genetic_mean = jnp.mean(pop.gv)
    updated_history = state.genetic_mean_history.at[gen_idx].set(genetic_mean)

    # Return the new state for the next iteration
    new_state = SimState(
        key=key,
        current_pop=progeny_pop,
        pedigree_book=updated_ped_book,
        current_pedigree_size=new_pedigree_size,
        genetic_mean_history=updated_history
    )
    
    return new_state, None # Second element is for per-iteration outputs, which we ignore here
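The padding approach used for the MME can be checked on a toy problem. Solving Henderson's mixed model equations on a pedigree padded with unrelated, unphenotyped animals should give the same μ̂ and EBVs for the real animals, and zero EBVs for the padding. Here is a minimal sketch with a hypothetical `solve_animal_model` helper and illustrative records (not simulation output):

```python
import jax.numpy as jnp


def solve_animal_model(y, rows, a_inv, h2):
    """Hypothetical helper: dense MME for y = 1*mu + Z a + e, with alpha = (1 - h2) / h2."""
    n_rec, n_anim = y.shape[0], a_inv.shape[0]
    alpha = (1.0 - h2) / h2
    X = jnp.ones((n_rec, 1))
    Z = jnp.zeros((n_rec, n_anim)).at[jnp.arange(n_rec), rows].set(1.0)
    lhs = jnp.block([[X.T @ X, X.T @ Z], [Z.T @ X, Z.T @ Z + alpha * a_inv]])
    rhs = jnp.concatenate([X.T @ y, Z.T @ y])
    sol = jnp.linalg.solve(lhs, rhs)
    return sol[0], sol[1:]  # (mu_hat, EBVs)


# Toy pedigree: founders 0 and 1, animal 2 is their offspring.
A = jnp.array([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.5, 0.5, 1.0]])
a_inv = jnp.linalg.inv(A)
y = jnp.array([1.0, 2.0, 0.5])   # illustrative records
rows = jnp.array([0, 1, 2])      # record k belongs to animal rows[k]

mu, ebv = solve_animal_model(y, rows, a_inv, h2=0.3)

# Pad to 6 animals: the extra rows get an identity block in A-inverse and no records.
a_inv_padded = jnp.eye(6).at[:3, :3].set(a_inv)
mu_p, ebv_p = solve_animal_model(y, rows, a_inv_padded, h2=0.3)

print(jnp.allclose(mu, mu_p, atol=1e-5), jnp.allclose(ebv, ebv_p[:3], atol=1e-5))
print(ebv_p[3:])  # EBVs of the padding animals
```

The padding block has no coupling to the real animals and a zero right-hand side, so it cannot change their solutions.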

In [13]:
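# --- Main Simulation Parameters ---
import jax
from functools import partial  # needed by the decorator below, before the later imports in this cell run

N_BURN_IN_GEN = 10
N_SELECT_GEN = 50
TOTAL_GEN = N_BURN_IN_GEN + N_SELECT_GEN
# N_IND, POP_SIZE_FIXED and N_PROGENY must already be defined by earlier cells
# (N_IND and POP_SIZE_FIXED are only assigned further down this cell).
MAX_PEDIGREE_SIZE = N_IND + N_BURN_IN_GEN * POP_SIZE_FIXED + N_SELECT_GEN * N_PROGENY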

# Make this a function that can be JIT-compiled
@partial(jax.jit, static_argnames=['n_gen'])
def run_simulation(initial_state: SimState, n_gen: int, sp: SimParam):
    
    # Create a partial function for the step, with static parameters baked in
    step_fn = partial(generation_step, sp=sp, max_pedigree_size=MAX_PEDIGREE_SIZE)
    
    # Run the scan
    final_state, _ = jax.lax.scan(step_fn, initial_state, xs=jnp.arange(n_gen))
    
    return final_state

# --- Setup and Run ---
# 1. Run the burn-in phase eagerly in a plain Python loop (below); only the
#    selection phase is JIT-compiled.


# 1) Make all the important imports
import jax
import jax.numpy as jnp
import numpy as np # Still useful for setting up some initial arrays
from functools import partial

# Import the necessary functions and classes from the chewc library
from chewc.population import msprime_pop, combine_populations
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno
from chewc.select import TruncationSelection
from chewc.cross import make_cross


# --- Simulation Setup ---
# Create a master JAX random key for reproducibility
key = jax.random.PRNGKey(42)

# --- Parameters for the founder population ---
N_IND = 200       # Number of individuals in the founder pop
N_CHR = 10        # Number of chromosomes
N_LOCI_PER_CHR = 1000 # Number of loci per chromosome
EFFECTIVE_POP_SIZE = 10000 # Effective population size for msprime
N_QTL_PER_CHR = 100

# 2) Use msprime_pop to simulate a population
print("Step 1: Simulating base population with msprime...")
key, pop_key = jax.random.split(key)
founder_pop, genetic_map = msprime_pop(
    key=pop_key,
    n_ind=N_IND,
    n_loci_per_chr=N_LOCI_PER_CHR,
    n_chr=N_CHR,
    effective_population_size=EFFECTIVE_POP_SIZE,
    enforce_founder_maf=True # Ensures we have good variation to start
)
print(f"-> Initial founder population created: {founder_pop}")
# founder_pop.plot_maf()



# Create the main simulation parameter object
sp = SimParam.from_founder_pop(founder_pop, genetic_map, sexes='no')
print(f"-> SimParam object created: {sp}")


# 3) Establish a trait with 0.4 heritability and gamma-distributed effects
print("\nStep 2: Establishing the primary trait for the main experiment...")
key, trait_key = jax.random.split(key)
sp = add_trait_a(
    key=trait_key,
    founder_pop=founder_pop,
    sim_param=sp,
    n_qtl_per_chr=N_QTL_PER_CHR, # 100 QTLs per chromosome
    mean=jnp.array([0.0]), # A single trait with a mean of 0
    var=jnp.array([1.0]),  # and a genetic variance of 1
    gamma=True,            # Use the Gamma distribution for effects
    shape=0.4              # Shape parameter for the Gamma distribution
)
print("-> Trait with h^2=0.4 and gamma effects added to SimParam.")
# Note: The heritability of 0.4 isn't set here; this just defines the *genetic architecture*.
# The heritability is realized when we add environmental noise with set_pheno.
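A common way to realize a target h² is to add environmental noise with variance σ²_e = σ²_g (1 - h²)/h². The same ratio (1 - h²)/h² is the α used in the MME; for h² = 0.3, α = 0.7/0.3 ≈ 2.3333. This section does not show whether `set_pheno` takes σ²_g from the founders or from the current generation. The sketch below, built on a hypothetical `add_env_noise` helper, is therefore only an illustration of the idea.

```python
import jax
import jax.numpy as jnp


def add_env_noise(key, gv, h2):
    """Hypothetical helper: phenotype = gv + e, with Var(e) = Var(gv) * (1 - h2) / h2."""
    var_g = jnp.var(gv)
    var_e = var_g * (1.0 - h2) / h2
    return gv + jnp.sqrt(var_e) * jax.random.normal(key, gv.shape)


g_key, e_key = jax.random.split(jax.random.PRNGKey(1))
gv = jax.random.normal(g_key, (100_000,))  # genetic values with variance close to 1
pheno = add_env_noise(e_key, gv, h2=0.4)
realized_h2 = jnp.var(gv) / jnp.var(pheno)
print(float(realized_h2))  # close to 0.4, up to sampling noise

alpha = (1 - 0.3) / 0.3    # variance ratio used in the MME for h2 = 0.3
```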

# 4) Simulate a burn-in for 10 generations
print("\nStep 3: Starting 10-generation burn-in with weak selection...")

# --- Parameters for the burn-in phase ---
N_GEN = 10
H2_BURN_IN = 0.1 # Low heritability for selection during burn-in
N_MALES_SELECTED = 20
N_FEMALES_SELECTED = 100
POP_SIZE_FIXED = 200 # Keep the population size constant


# --- REVISED Burn-in Loop (Pedigree Book Method) ---
from chewc.population import subset_population # Ensure this is imported

# Initialize
current_pop = founder_pop.replace(gen=jnp.zeros(founder_pop.nInd, dtype=jnp.int32))
pedigree_book = jnp.stack([current_pop.id, current_pop.father, current_pop.mother], axis=1)

for gen in range(1, N_GEN + 1):
    key, pheno_key, select_key, cross_key, sample_key = jax.random.split(key, 5)

    # 1. Phenotype the current population
    current_pop = set_pheno(
        key=pheno_key, pop=current_pop, traits=sp.traits,
        ploidy=sp.ploidy, h2=jnp.array([H2_BURN_IN])
    )

    # 2. Select parents from the current generation
    selection_method = TruncationSelection()
    n_total_to_select = N_MALES_SELECTED + N_FEMALES_SELECTED
    top_indices = selection_method.select_parents(
        key=select_key, pop=current_pop, sp=sp, n_select=n_total_to_select
    )
    elite_parent_iids = current_pop.iid[top_indices]

    # 3. Create mating plan
    key_dams, key_sires = jax.random.split(sample_key)
    dam_iids = jax.random.choice(key_dams, elite_parent_iids, shape=(POP_SIZE_FIXED,), replace=True)
    sire_iids = jax.random.choice(key_sires, elite_parent_iids, shape=(POP_SIZE_FIXED,), replace=True)
    cross_plan = jnp.stack([dam_iids, sire_iids], axis=1)

    # 4. Create progeny, passing the full current_pop for parent lookup
    progeny_pop = make_cross(
        key=cross_key, pop=current_pop, cross_plan=cross_plan, sp=sp,
        next_id_start=pedigree_book[:, 0].max() + 1
    )
    # Assign the correct generation number to the new progeny
    progeny_pop = progeny_pop.replace(gen=jnp.full(progeny_pop.nInd, gen, dtype=jnp.int32))
    
    # 5. Update the pedigree book and replace the current population
    progeny_pedigree = jnp.stack([progeny_pop.id, progeny_pop.father, progeny_pop.mother], axis=1)
    pedigree_book = jnp.vstack([pedigree_book, progeny_pedigree])
    current_pop = progeny_pop
    
    print(f"   > Gen {gen}/{N_GEN} complete. Living animals: {current_pop.nInd}. Total in pedigree: {pedigree_book.shape[0]}")

final_founder_pop = current_pop
initial_pedigree_book = pedigree_book
print(f"\n--- Burn-in complete! ---")
print(f"Final living population: {final_founder_pop}")
print(f"Total animals in pedigree book: {pedigree_book.shape[0]}")
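The burn-in grows `pedigree_book` with `jnp.vstack`, which changes its shape every generation. That is fine in an eager Python loop but not inside `lax.scan`, where the carry shape must stay fixed. The selection phase therefore pre-allocates a book filled with -1 and writes into it with `jax.lax.dynamic_update_slice`. A minimal sketch follows; the `append_rows` helper is hypothetical. Note that `dynamic_update_slice` clamps the start index so the update fits. An undersized book therefore silently overwrites its last rows instead of raising an error, so `MAX_PEDIGREE_SIZE` has to cover every generation.

```python
import jax
import jax.numpy as jnp


@jax.jit
def append_rows(book, rows, size):
    """Hypothetical helper: write `rows` at the traced offset `size` of a fixed-shape book."""
    book = jax.lax.dynamic_update_slice(book, rows, (size, 0))
    return book, size + rows.shape[0]


book = jnp.full((6, 3), -1, dtype=jnp.int32)  # columns: id, sire, dam
founders = jnp.array([[0, -1, -1], [1, -1, -1]], dtype=jnp.int32)
offspring = jnp.array([[2, 0, 1], [3, 1, 0]], dtype=jnp.int32)

book, size = append_rows(book, founders, jnp.int32(0))
book, size = append_rows(book, offspring, size)  # same input shapes, so the compiled code is reused
print(int(size))  # 4
print(book)
```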


initial_pedigree_size = pedigree_book.shape[0]


# 2. Set up the initial state for the main selection phase
pedigree_book_padded = jnp.full((MAX_PEDIGREE_SIZE, 3), -1, dtype=jnp.int32)
pedigree_book_padded = jax.lax.dynamic_update_slice(pedigree_book_padded, initial_pedigree_book, (0, 0))

initial_sim_state = SimState(
    key=key,
    current_pop=final_founder_pop,
    pedigree_book=pedigree_book_padded,
    current_pedigree_size=initial_pedigree_size,
    genetic_mean_history=jnp.zeros(N_SELECT_GEN)
)

# 3. Run the JIT-compiled simulation
print("\n--- Running JIT-compiled selection phase ---")
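# n_gen is the static (hashable) argument. Passing sp in that positional slot, as in
# run_simulation(initial_sim_state, sp, N_SELECT_GEN), raises the ValueError shown below.
# Passing sp as a traced argument assumes SimParam is registered as a JAX pytree.
final_state = run_simulation(initial_sim_state, n_gen=N_SELECT_GEN, sp=sp)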
print("--- Simulation complete! ---")
print(f"Final genetic mean: {final_state.genetic_mean_history[-1]}")

Step 1: Simulating base population with msprime...
-> Initial founder population created: Population(nInd=200, nTraits=0, has_ebv=No)
-> SimParam object created: SimParam(nChr=10, nTraits=0, ploidy=2, sexes='no')

Step 2: Establishing the primary trait for the main experiment...
-> Trait with h^2=0.4 and gamma effects added to SimParam.

Step 3: Starting 10-generation burn-in with weak selection...
   > Gen 1/10 complete. Living animals: 200. Total in pedigree: 400
   > Gen 2/10 complete. Living animals: 200. Total in pedigree: 600
   > Gen 3/10 complete. Living animals: 200. Total in pedigree: 800
   > Gen 4/10 complete. Living animals: 200. Total in pedigree: 1000
   > Gen 5/10 complete. Living animals: 200. Total in pedigree: 1200
   > Gen 6/10 complete. Living animals: 200. Total in pedigree: 1400
   > Gen 7/10 complete. Living animals: 200. Total in pedigree: 1600
   > Gen 8/10 complete. Living animals: 200. Total in pedigree: 1800
   > Gen 9/10 complete. Living animals: 200. Tota

ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'chewc.sp.SimParam'>, SimParam(nChr=10, nTraits=1, ploidy=2, sexes='no'). The error was:
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
  File "/home/glect/.local/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
  File "/home/glect/.local/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
  File "/home/glect/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
  File "/home/glect/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 211, in start
  File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
  File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
  File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run
  File "/home/glect/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
  File "/home/glect/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 534, in process_one
  File "/home/glect/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
  File "/home/glect/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
  File "/home/glect/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
  File "/home/glect/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
  File "/home/glect/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
  File "/home/glect/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3077, in run_cell
  File "/home/glect/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3132, in _run_cell
  File "/home/glect/.local/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
  File "/home/glect/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3336, in run_cell_async
  File "/home/glect/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3519, in run_ast_nodes
  File "/home/glect/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3579, in run_code
  File "/tmp/ipykernel_195654/92951164.py", line 170, in <module>
  File "<string>", line 3, in __hash__
TypeError: unhashable type: 'jaxlib._jax.ArrayImpl'
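The `ValueError` above comes from argument order. `run_simulation` is declared as `(initial_state, n_gen, sp)` with `n_gen` static, but it was called as `run_simulation(initial_sim_state, sp, N_SELECT_GEN)`. As a result, `sp` landed in the static slot and JAX tried to hash it. Static arguments must be hashable, and a frozen dataclass whose fields are JAX arrays is not. The sketch below reproduces this with hypothetical `run` and `Params` stand-ins. Calling with the arguments in order (or by keyword) avoids the error, assuming the real `SimParam` is a registered pytree so it can be passed as a traced argument.

```python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=["n_gen"])
def run(x, n_gen: int, scale):
    """Hypothetical stand-in for run_simulation: n_gen sets the scan length."""
    def body(carry, _):
        return carry * scale, None
    out, _ = jax.lax.scan(body, x, xs=None, length=n_gen)
    return out


@dataclass(frozen=True)
class Params:
    """Hypothetical stand-in for an object that holds JAX arrays."""
    effects: jax.Array


p = Params(effects=jnp.ones(3))

ok = run(jnp.float32(1.0), n_gen=3, scale=2.0)  # n_gen is a hashable int

try:
    run(jnp.float32(1.0), p, 3)  # swapped: p lands in the static n_gen slot
    swapped_failed = False
except (ValueError, TypeError):
    swapped_failed = True

print(float(ok), swapped_failed)  # 8.0 True
```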
