In [None]:
import jax
import jax.numpy as jnp
from chewc.population import quick_haplo
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno
from chewc.cross import make_cross
from chewc.predict import gblup_predict
from chewc.population import combine_populations, Population
import matplotlib.pyplot as plt
from chewc.population import quick_haplo, combine_populations, Population, subset_population # Add subset_population


In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from chewc.population import msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno, set_bv
from chewc.cross import make_cross

# ==================================================
# ---  simulation_parameters ---
# ==================================================
# Central configuration for the truncation-selection simulation below.
# Scalar counts are plain ints; per-trait quantities are length-1 JAX arrays
# (one entry per trait), and `key` seeds every random stream in the run.
simulation_parameters = dict(
    n_replicates=3,  # not used in this cell
    n_founder_ind=10,
    n_loci_per_chr=1000,
    n_chr=5,
    n_qtl_per_chr=100,
    trait_mean=jnp.array([0.0]),
    trait_var=jnp.array([1.0]),
    n_generations=50,
    population_size=200,
    n_select=20,
    h2=jnp.array([0.3]),
    key=jax.random.PRNGKey(42),
)
# ==================================================

# --- Setup Founder Population and Traits using Parameters ---
key = simulation_parameters["key"]
# Split the master key into dedicated streams. `sp2_key` is currently unused
# (only one trait is added below); it is kept so the split count -- and thus
# every derived key, including the `key` reused by the selection loop --
# stays the same. Changing the count passed to split would change all keys.
key, founder_key, sp1_key, sp2_key = jax.random.split(key, 4)

# Founder population and its genetic map, simulated with msprime_pop.
founder_pop, genetic_map = msprime_pop(
    key=founder_key,
    n_ind=simulation_parameters["n_founder_ind"],
    n_loci_per_chr=simulation_parameters["n_loci_per_chr"],
    n_chr=simulation_parameters["n_chr"]
)

# Build a single SimParam from the founders and attach one additive trait.
# gamma=False selects the Normal-effect variant (the Gamma variant is not
# used in this cell).
sp = SimParam.from_founder_pop(founder_pop, genetic_map)
sp = add_trait_a(
    key=sp1_key, founder_pop=founder_pop, sim_param=sp,
    n_qtl_per_chr=simulation_parameters["n_qtl_per_chr"],
    mean=simulation_parameters["trait_mean"],
    var=simulation_parameters["trait_var"],
    gamma=False
)



# --- Selection loop: truncation selection on phenotype, random mating ---
# Per-generation summaries of the true breeding values; one entry per
# evaluated generation (index 0 = founders), n_generations + 1 in total.
genetic_means = []
genetic_variances = []
current_pop = founder_pop
# Offspring IDs continue after the founders' IDs.
next_id_start = current_pop.nInd
n_generations = simulation_parameters["n_generations"]
population_size = simulation_parameters["population_size"]
n_select = simulation_parameters["n_select"]
# Three independent key streams: phenotyping noise, crossing (historically
# named `select_key`; the truncation selection itself is deterministic), and
# parent sampling.
pheno_key, select_key, mate_key = jax.random.split(key, 3)
# Heritability passed to set_pheno. It is loop-invariant, so build it once.
# Only the first entry is used, matching the single trait added above.
h2 = jnp.array([simulation_parameters["h2"][0]])

for gen in range(n_generations + 1):
    # Advance each key stream; the *_subkey values are consumed this generation.
    p_subkey, pheno_key = jax.random.split(pheno_key)
    cross_subkey, select_key = jax.random.split(select_key)
    m_subkey1, m_subkey2, mate_key = jax.random.split(mate_key, 3)

    # True breeding values, summarised before selection.
    current_pop = set_bv(current_pop, sp.traits, sp.ploidy)
    genetic_means.append(jnp.mean(current_pop.bv))
    genetic_variances.append(jnp.var(current_pop.bv))  # population variance (ddof=0)

    # Measure phenotypes (with noise based on h2).
    current_pop = set_pheno(
        key=p_subkey, pop=current_pop, traits=sp.traits,
        ploidy=sp.ploidy, h2=h2
    )

    # Naive truncation selection: indices of the n_select highest phenotypes.
    # If the population has fewer than n_select individuals (e.g. the 10
    # founders with n_select=20) the slice simply keeps everyone.
    selected_indices = jnp.argsort(current_pop.pheno[:, 0])[-n_select:]

    # Random mating among the selected: mothers and fathers are drawn
    # independently and with replacement, so an individual can be paired
    # with itself.
    mothers = jax.random.choice(m_subkey1, selected_indices, shape=(population_size,))
    fathers = jax.random.choice(m_subkey2, selected_indices, shape=(population_size,))
    cross_plan = jnp.stack([mothers, fathers], axis=1)

    # Offspring become the next generation. NOTE: the final iteration also
    # crosses, so after the loop current_pop is one generation past the last
    # recorded summary (its breeding values were never computed).
    current_pop = make_cross(
        key=cross_subkey, pop=current_pop, cross_plan=cross_plan,
        sp=sp, next_id_start=next_id_start
    )
    next_id_start += population_size
    

: 