In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from chewc.population import msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno, set_bv
from chewc.cross import make_cross

# ==================================================
# --- simulation_parameters ---
# ==================================================
# Run configuration consumed by the setup code that follows. Counts are plain
# ints; trait_mean / trait_var are length-1 jax arrays (presumably one entry
# per trait — confirm against add_trait_a); `key` is the root PRNG key from
# which every other random stream in this notebook is split.
simulation_parameters = dict(
    n_replicates=3,
    n_founder_ind=10,
    n_loci_per_chr=1000,
    n_chr=5,
    n_qtl_per_chr=100,
    trait_mean=jnp.array([0.0]),
    trait_var=jnp.array([1.0]),
    n_generations=50,
    population_size=200,
    n_select=20,
    key=jax.random.PRNGKey(42),
)
# ====

# --- Founder population (driven by simulation_parameters) ---
# Split the configured root key into four independent streams: `key` is
# rebound to a fresh key (presumably for later splits), `founder_key` seeds
# the founder simulation, and sp1_key / sp2_key are per-SimParam trait keys.
key, founder_key, sp1_key, sp2_key = jax.random.split(
    simulation_parameters["key"], 4
)
# NOTE(review): sp2_key is not consumed in this cell — confirm whether a
# second SimParam is still intended.

# Simulate the founders; msprime_pop hands back the founder population and
# the genetic map that SimParam construction needs.
founder_pop, genetic_map = msprime_pop(
    key=founder_key,
    n_ind=simulation_parameters["n_founder_ind"],
    n_chr=simulation_parameters["n_chr"],
    n_loci_per_chr=simulation_parameters["n_loci_per_chr"],
)

# Build the SimParam for scenario sp1 from the founders and their genetic map,
# then attach a trait to it; add_trait_a returns the updated SimParam, so the
# result is rebound to sp1. (Presumably an additive trait, per the `_a`
# suffix — confirm in chewc.trait.)
# gamma=False corresponds to the "sp1 (Normal)" scenario label.
# Only sp1 is constructed here: sp2_key is split off earlier, but no
# gamma-based sp2 SimParam is built in this cell.
sp1 = SimParam.from_founder_pop(founder_pop, genetic_map)
sp1 = add_trait_a(
    key=sp1_key, founder_pop=founder_pop, sim_param=sp1,
    n_qtl_per_chr=simulation_parameters["n_qtl_per_chr"],
    mean=simulation_parameters["trait_mean"],
    var=simulation_parameters["trait_var"],
    gamma=False
)

# Base templates, one per simulation group: a display-name stem, the SimParam
# to simulate with, and the plot colour for that group.
scenario_templates = [
    dict(name_base="sp1 (Normal)", sp=sp1, color="blue"),
]

# Heritabilities (h2) to test.
h2_values_to_test = [0.1, 0.6]

# Plot colour per h2 value, keyed by str(h2) (e.g. "0.1") so keys match the
# string form of the values above. Colours are derived from
# h2_values_to_test in order, cycling through a small palette, so adding a
# new h2 value can never leave it without a colour (a hand-maintained dict
# would raise KeyError on lookup). With the current values this yields
# {"0.1": "blue", "0.6": "red"}.
_H2_PALETTE = ("blue", "red", "green", "orange", "purple")
h2_colors = {
    str(h2): _H2_PALETTE[i % len(_H2_PALETTE)]
    for i, h2 in enumerate(h2_values_to_test)
}

