In [1]:
# 1) Make all the important imports
import jax
import jax.numpy as jnp
import numpy as np # Still useful for setting up some initial arrays
from functools import partial

# Import the necessary functions and classes from the chewc library
from chewc.population import msprime_pop, combine_populations
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno
from chewc.select import TruncationSelection
from chewc.cross import make_cross

#| export
from dataclasses import field
from typing import List, Optional, Dict, Callable

from flax.struct import dataclass as flax_dataclass # Using flax's dataclass for JAX-friendliness
import jax
import jax.numpy as jnp

from chewc.sp import SimParam
from typing import Tuple
from numpy.random import default_rng
import msprime
import tskit
import numpy as np
import random
from collections import defaultdict

#testing
import jax
import jax.numpy as jnp
from fastcore.test import test_eq, test_ne


# --- Simulation Setup ---
# Founder-population parameters. These are module-level on purpose: later
# cells (trait definition, burn-in) read them too.
N_IND = 200                  # individuals in the founder population
N_CHR = 10                   # number of chromosomes
N_LOCI_PER_CHR = 1000        # loci simulated on each chromosome
EFFECTIVE_POP_SIZE = 10000   # effective population size handed to msprime
N_QTL_PER_CHR = 100          # QTL per chromosome, used when the trait is added

# Single master PRNG key: every random draw below is split off from it, so the
# whole run is reproducible from this one seed.
key = jax.random.PRNGKey(42)

# 2) Simulate the base (founder) population with msprime.
print("Step 1: Simulating base population with msprime...")
key, pop_key = jax.random.split(key)
founder_pop, genetic_map = msprime_pop(
    key=pop_key,
    n_ind=N_IND,
    n_chr=N_CHR,
    n_loci_per_chr=N_LOCI_PER_CHR,
    effective_population_size=EFFECTIVE_POP_SIZE,
    # Intended to guarantee usable starting variation (presumably a minor-allele
    # frequency filter on founder loci -- confirm in chewc.population).
    enforce_founder_maf=True,
)
print(f"-> Initial founder population created: {founder_pop}")

# Bundle founders + genetic map into the global simulation parameters.
# sexes='no' -> individuals are not split by sex (see SimParam for details).
sp = SimParam.from_founder_pop(founder_pop, genetic_map, sexes='no')
print(f"-> SimParam object created: {sp}")


# 3) Define the trait's genetic architecture: QTL with gamma-distributed effects.
# NOTE: shape=0.4 below is the gamma *shape* parameter, not a heritability.
# Heritability is not part of the trait definition; it only materializes when
# set_pheno adds environmental noise for a target h2 (the burn-in uses
# H2_BURN_IN; the main experiment presumably targets h2=0.4 -- confirm).
print("\nStep 2: Establishing the primary trait for the main experiment...")
key, trait_key = jax.random.split(key)
sp = add_trait_a(
    key=trait_key,
    founder_pop=founder_pop,
    sim_param=sp,
    n_qtl_per_chr=N_QTL_PER_CHR,  # QTL sampled on every chromosome
    mean=jnp.array([0.0]),        # a single trait with mean 0 ...
    var=jnp.array([1.0]),         # ... and genetic variance 1
    gamma=True,                   # draw QTL effects from a gamma distribution
    shape=0.4,                    # gamma shape parameter (NOT h^2)
)
print("-> Trait with gamma effects (shape=0.4) added to SimParam; h^2 is applied later via set_pheno.")

# 4) Simulate a burn-in for N_GEN generations with weak truncation selection.

# --- Parameters for the burn-in phase ---
N_GEN = 10
H2_BURN_IN = 0.1  # Low heritability for selection during burn-in
# SimParam was built with sexes='no', so these two counts are only used as a
# total number of elite parents; any elite parent may act as dam or sire.
N_MALES_SELECTED = 20
N_FEMALES_SELECTED = 100
POP_SIZE_FIXED = 200  # Keep the population size constant

print(f"\nStep 3: Starting {N_GEN}-generation burn-in with weak selection...")

# --- Burn-in Loop (Pedigree Book Method) ---
from chewc.population import subset_population  # NOTE(review): not used in this cell

# Loop invariants, built once instead of every generation.
selection_method = TruncationSelection()
n_total_to_select = N_MALES_SELECTED + N_FEMALES_SELECTED

# Founders are generation 0; the pedigree book accumulates one
# (id, father, mother) row per individual ever created.
current_pop = founder_pop.replace(gen=jnp.zeros(founder_pop.nInd, dtype=jnp.int32))
pedigree_book = jnp.stack([current_pop.id, current_pop.father, current_pop.mother], axis=1)

for gen in range(1, N_GEN + 1):
    key, pheno_key, select_key, cross_key, sample_key = jax.random.split(key, 5)

    # 1. Phenotype the current population at the burn-in heritability.
    current_pop = set_pheno(
        key=pheno_key, pop=current_pop, traits=sp.traits,
        ploidy=sp.ploidy, h2=jnp.array([H2_BURN_IN])
    )

    # 2. Truncation-select the elite parents of this generation.
    top_indices = selection_method.select_parents(
        key=select_key, pop=current_pop, sp=sp, n_select=n_total_to_select
    )
    elite_parent_iids = current_pop.iid[top_indices]

    # 3. Random mating among the elite parents, excluding selfing: each sire is
    #    drawn uniformly from the elite parents other than the cross's dam, by
    #    shifting the dam's index by a nonzero offset (cyclically).
    n_parents = elite_parent_iids.shape[0]
    key_dams, key_sires = jax.random.split(sample_key)
    dam_idx = jax.random.randint(key_dams, (POP_SIZE_FIXED,), 0, n_parents)
    sire_offset = jax.random.randint(key_sires, (POP_SIZE_FIXED,), 1, n_parents)
    sire_idx = (dam_idx + sire_offset) % n_parents
    dam_iids = elite_parent_iids[dam_idx]
    sire_iids = elite_parent_iids[sire_idx]
    cross_plan = jnp.stack([dam_iids, sire_iids], axis=1)

    # 4. Create progeny, passing the full current_pop for parent lookup.
    #    New ids continue after the largest id recorded so far.
    progeny_pop = make_cross(
        key=cross_key, pop=current_pop, cross_plan=cross_plan, sp=sp,
        next_id_start=pedigree_book[:, 0].max() + 1
    )
    # Tag the new progeny with their generation number.
    progeny_pop = progeny_pop.replace(gen=jnp.full(progeny_pop.nInd, gen, dtype=jnp.int32))

    # 5. Record the progeny in the pedigree book; discrete generations, so the
    #    progeny fully replace the current population.
    progeny_pedigree = jnp.stack([progeny_pop.id, progeny_pop.father, progeny_pop.mother], axis=1)
    pedigree_book = jnp.vstack([pedigree_book, progeny_pedigree])
    current_pop = progeny_pop

    print(f"   > Gen {gen}/{N_GEN} complete. Living animals: {current_pop.nInd}. Total in pedigree: {pedigree_book.shape[0]}")

final_founder_pop = current_pop
print("\n--- Burn-in complete! ---")
print(f"Final living population: {final_founder_pop}")
print(f"Total animals in pedigree book: {pedigree_book.shape[0]}")



Step 1: Simulating base population with msprime...
-> Initial founder population created: Population(nInd=200, nTraits=0, has_ebv=No)
-> SimParam object created: SimParam(nChr=10, nTraits=0, ploidy=2, sexes='no')

Step 2: Establishing the primary trait for the main experiment...
-> Trait with h^2=0.4 and gamma effects added to SimParam.

Step 3: Starting 10-generation burn-in with weak selection...
   > Gen 1/10 complete. Living animals: 200. Total in pedigree: 400
   > Gen 2/10 complete. Living animals: 200. Total in pedigree: 600
   > Gen 3/10 complete. Living animals: 200. Total in pedigree: 800
   > Gen 4/10 complete. Living animals: 200. Total in pedigree: 1000
   > Gen 5/10 complete. Living animals: 200. Total in pedigree: 1200
   > Gen 6/10 complete. Living animals: 200. Total in pedigree: 1400
   > Gen 7/10 complete. Living animals: 200. Total in pedigree: 1600
   > Gen 8/10 complete. Living animals: 200. Total in pedigree: 1800
   > Gen 9/10 complete. Living animals: 200. Tota