In [None]:
import jax
import jax.numpy as jnp
import time
from chewc.population import quick_haplo, Population
from chewc.sp import SimParam
from chewc.cross import make_cross

def run_benchmark():
    """
    Time make_cross across a grid of population sizes and loci-per-chromosome counts.

    For every (individuals, loci/chr) scenario a fresh founder population is built,
    a random mother/father crossing plan of fixed size is drawn, and make_cross is
    invoked twice with fresh keys: the first timing includes any tracing or
    compilation, the second reuses whatever was compiled. A summary table is
    printed at the end.
    """
    print(f"JAX is using: {jax.devices()[0].platform.upper()}")

    # Scenario grid and fixed settings.
    individual_counts = [100, 500, 1000, 2000]
    loci_counts = [1000, 5000, 10000]
    chromosome_count = 10
    progeny_count = 1000  # progeny produced by every make_cross call

    rng = jax.random.PRNGKey(42)
    timings = {}

    def timed_cross(call_key, pop, plan, params):
        """Run make_cross once, wait for the device, and return elapsed wall time."""
        tic = time.time()
        offspring = make_cross(
            key=call_key,
            pop=pop,
            cross_plan=plan,
            sp=params,
            next_id_start=pop.nInd
        )
        # JAX dispatches asynchronously; block so the timer covers the real work.
        offspring.geno.block_until_ready()
        return time.time() - tic

    scenarios = [(ni, nl) for ni in individual_counts for nl in loci_counts]
    for ind_count, loci_count in scenarios:
        print(f"\n--- Running benchmark: {ind_count} individuals, {loci_count} loci/chr ---")

        rng, pop_subkey = jax.random.split(rng)
        founders, genetic_map = quick_haplo(
            key=pop_subkey,
            n_ind=ind_count,
            n_chr=chromosome_count,
            n_loci_per_chr=loci_count
        )
        params = SimParam.from_founder_pop(founders, genetic_map)

        # Random crossing plan: each row is (mother iid, father iid).
        rng, mother_subkey = jax.random.split(rng)
        mothers = jax.random.choice(mother_subkey, founders.iid, shape=(progeny_count,))
        rng, father_subkey = jax.random.split(rng)
        fathers = jax.random.choice(father_subkey, founders.iid, shape=(progeny_count,))
        plan = jnp.stack([mothers, fathers], axis=1)

        # First call: includes any tracing/compilation cost.
        rng, first_subkey = jax.random.split(rng)
        first_elapsed = timed_cross(first_subkey, founders, plan, params)
        print(f"JIT Compilation + First Run Time: {first_elapsed:.4f} seconds")

        # Second call: same shapes, so compiled code can be reused.
        rng, second_subkey = jax.random.split(rng)
        second_elapsed = timed_cross(second_subkey, founders, plan, params)
        print(f"Cached Execution Time: {second_elapsed:.4f} seconds")

        timings[(ind_count, loci_count)] = {
            'jit_time': first_elapsed,
            'cached_time': second_elapsed
        }

    print("\n\n--- Benchmark Summary ---")
    print("Inds\tLoci\tJIT + 1st Run (s)\tCached Run (s)")
    for (ind_count, loci_count), entry in timings.items():
        print(f"{ind_count}\t{loci_count}\t{entry['jit_time']:.4f}\t\t\t{entry['cached_time']:.4f}")

# Entry point: run the make_cross benchmark when this cell/script is executed directly.
if __name__ == '__main__':
    run_benchmark()

In [None]:
import jax
import jax.numpy as jnp
import time
from chewc.population import quick_haplo
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno

def run_phenotype_benchmark():
    """
    Benchmarks the set_pheno function with varying population sizes,
    genome sizes, and trait architectures.

    For every (individuals, loci/chr) combination one founder population is built
    and reused for every QTL density in ``n_qtl_per_chr_list``. For each trait
    architecture set_pheno is called twice with fresh keys: the first timing
    includes any tracing/compilation, the second reuses the compiled code.
    Architectures requesting more QTLs per chromosome than loci per chromosome
    are skipped before the trait is built. Results are printed as a table.
    """
    print(f"JAX is using: {jax.devices()[0].platform.upper()}")

    # --- Benchmarking Parameters ---
    n_individuals_list = [500, 1000, 2000]
    n_loci_per_chr_list = [5000, 10000]
    # Trait architectures: number of causal variants (QTLs) per chromosome
    n_qtl_per_chr_list = [10, 100, 1000]
    n_chr = 10
    h2 = 0.5 # Heritability

    # --- JAX Setup ---
    key = jax.random.PRNGKey(42)

    # --- Store Results ---
    results = []

    # --- Base Population Setup (one population per size, shared by all trait architectures) ---
    for n_ind in n_individuals_list:
        for n_loci in n_loci_per_chr_list:
            print(f"\n--- Setting up Population: {n_ind} individuals, {n_loci * n_chr} total loci ---")
            key, pop_key = jax.random.split(key)
            # 1. Create a base population for the current size
            founder_pop, gen_map = quick_haplo(
                key=pop_key,
                n_ind=n_ind,
                n_chr=n_chr,
                n_loci_per_chr=n_loci
            )
            sp_base = SimParam.from_founder_pop(founder_pop, gen_map)

            # 2. Benchmark different trait architectures for this population
            for n_qtl in n_qtl_per_chr_list:
                print(f"  --- Benchmarking Trait: {n_qtl * n_chr} total QTLs ---")

                # Ensure we don't request more QTLs than available loci. This must
                # run BEFORE add_trait_a, which is asked for n_qtl QTLs per
                # chromosome; checking afterwards cannot protect that call.
                if n_qtl > n_loci:
                    print(f"    Skipping: n_qtl ({n_qtl}) > n_loci ({n_loci}).")
                    continue

                # 2a. Add the trait to define the architecture
                key, trait_key = jax.random.split(key)
                sp_with_trait = add_trait_a(
                    key=trait_key,
                    founder_pop=founder_pop,
                    sim_param=sp_base,
                    n_qtl_per_chr=n_qtl,
                    mean=jnp.array([0.]),
                    var=jnp.array([1.])
                )

                # 2b. JIT Compilation + First Run
                key, pheno_key = jax.random.split(key)
                start_time_jit = time.time()
                pop_with_pheno_jit = set_pheno(
                    key=pheno_key,
                    pop=founder_pop,
                    traits=sp_with_trait.traits,
                    ploidy=sp_with_trait.ploidy,
                    h2=jnp.array([h2])
                )
                # Block on the async result so the timer measures the computation.
                pop_with_pheno_jit.pheno.block_until_ready()
                jit_time = time.time() - start_time_jit
                print(f"    JIT Compilation + First Run Time: {jit_time:.4f} seconds")

                # 2c. Time the Second (Cached) Execution
                key, pheno_key = jax.random.split(key)
                start_time_cached = time.time()
                pop_with_pheno_cached = set_pheno(
                    key=pheno_key,
                    pop=founder_pop,
                    traits=sp_with_trait.traits,
                    ploidy=sp_with_trait.ploidy,
                    h2=jnp.array([h2])
                )
                pop_with_pheno_cached.pheno.block_until_ready()
                cached_time = time.time() - start_time_cached
                print(f"    Cached Execution Time: {cached_time:.4f} seconds")

                # --- Store results ---
                results.append({
                    'n_ind': n_ind,
                    'n_total_loci': n_loci * n_chr,
                    'n_total_qtl': n_qtl * n_chr,
                    'jit_time': jit_time,
                    'cached_time': cached_time
                })

    # --- Print Summary ---
    print("\n\n--- Phenotype Module Benchmark Summary ---")
    print("Inds\tTotal Loci\tTotal QTLs\tJIT + 1st Run (s)\tCached Run (s)")
    print("-" * 75)
    for res in results:
        print(
            f"{res['n_ind']}\t{res['n_total_loci']}\t\t{res['n_total_qtl']}\t\t"
            f"{res['jit_time']:.4f}\t\t\t{res['cached_time']:.4f}"
        )

# Entry point: run the set_pheno benchmark when this cell/script is executed directly.
if __name__ == '__main__':
    run_phenotype_benchmark()

In [None]:
import jax
import jax.numpy as jnp
import time
from functools import partial

# --- chewc imports ---
from chewc.population import Population, quick_haplo
from chewc.sp import SimParam
from chewc.trait import add_trait_a, TraitCollection
from chewc.cross import _make_cross_geno # <-- IMPORT THE JIT-ABLE CORE FUNCTION
from chewc.phenotype import set_pheno

# ---------------------------------------------------------------------------
# 1. Define the full generation pipeline correctly
# ---------------------------------------------------------------------------

@partial(jax.jit, static_argnames=("n_crosses", "n_chr", "ploidy", "n_traits"))
def run_generation_pipeline(
    key,
    parent_pop,
    n_crosses,
    h2,
    # Unpacked arguments from the original 'sp' object:
    n_chr,
    gen_map,
    recomb_param_v,
    traits,
    ploidy,
    n_traits
):
    """
    JIT-compiled function for one generation of random mating plus phenotyping.

    This version calls the core JIT-able functions directly.

    Args:
        key: PRNG key; split internally for parent choice, crossing, sex and phenotype.
        parent_pop: Population mothers and fathers are drawn from, uniformly with
            replacement, using ``parent_pop.iid`` values as row indices
            (assumes iid == row position — TODO confirm).
        n_crosses: number of progeny to create (static).
        h2: heritability, passed to set_pheno as a length-1 array.
        n_chr, gen_map, recomb_param_v: recombination inputs forwarded to
            ``_make_cross_geno``.
        traits, ploidy: trait definition forwarded to ``set_pheno``.
        n_traits: number of traits; sizes the progeny ``pheno``/``bv`` buffers (static).

    Returns:
        The progeny Population with phenotypes set.
    """
    key_selection, key_cross, key_sex, key_pheno = jax.random.split(key, 4)

    # --- Parent Selection ---
    # Mothers and fathers need independent subkeys: drawing both from the same
    # key, array and shape yields identical index vectors (every cross a self).
    key_mother, key_father = jax.random.split(key_selection)
    mother_iids = jax.random.choice(key_mother, parent_pop.iid, shape=(n_crosses,))
    father_iids = jax.random.choice(key_father, parent_pop.iid, shape=(n_crosses,))
    cross_plan = jnp.stack([mother_iids, father_iids], axis=1)

    # --- Genotype Creation Step ---
    # Directly call the core JIT-able function, not the helper
    progeny_geno, progeny_ibd = _make_cross_geno(
        key_cross,
        parent_pop.geno[mother_iids],
        parent_pop.geno[father_iids],
        parent_pop.ibd[mother_iids],
        parent_pop.ibd[father_iids],
        n_chr,
        gen_map,
        recomb_param_v
    )

    # --- Assemble the Progeny Population (logic moved from make_cross) ---
    progeny_pop = Population(
        geno=progeny_geno,
        ibd=progeny_ibd,
        # New ids continue after the parent population's size.
        id=jnp.arange(parent_pop.nInd, parent_pop.nInd + n_crosses),
        iid=jnp.arange(n_crosses),
        mother=parent_pop.id[mother_iids],
        father=parent_pop.id[father_iids],
        sex=jax.random.choice(key_sex, jnp.array([0, 1], dtype=jnp.int8), (n_crosses,)),
        pheno=jnp.zeros((n_crosses, n_traits)),
        fixEff=jnp.zeros(n_crosses, dtype=jnp.float32),
        bv=jnp.zeros((n_crosses, n_traits))
    )

    # --- Phenotyping Step ---
    progeny_with_pheno = set_pheno(
        key=key_pheno,
        pop=progeny_pop,
        traits=traits,
        ploidy=ploidy,
        h2=jnp.array([h2])
    )

    return progeny_with_pheno

# ---------------------------------------------------------------------------
# 2. Define and run the benchmark (call site is updated)
# ---------------------------------------------------------------------------
def run_pipeline_benchmark():
    """
    Time the jitted single-generation pipeline across population, genome and trait sizes.

    For each (individuals, loci/chr) pair a founder population is built once; for
    each QTL density a trait is added and ``run_generation_pipeline`` is called
    twice with fresh keys. The first timing includes compilation, the second
    reuses the compiled function. A summary table is printed at the end.
    """
    print(f"JAX is using: {jax.devices()[0].platform.upper()}")

    # Parameter grid and fixed settings.
    ind_grid = [500, 1000]
    loci_grid = [5000, 10000]
    qtl_grid = [100, 1000]
    chromosomes = 10
    progeny_count = 1000
    heritability = 0.5

    rng = jax.random.PRNGKey(42)
    rows = []

    def time_one_generation(call_key, founders, params):
        """Call the pipeline once, block on its output and return elapsed seconds."""
        tic = time.time()
        offspring = run_generation_pipeline(
            key=call_key,
            parent_pop=founders,
            n_crosses=progeny_count,
            h2=heritability,
            n_chr=params.n_chr,
            gen_map=params.gen_map,
            recomb_param_v=params.recomb_params[0],
            traits=params.traits,
            ploidy=params.ploidy,
            n_traits=params.n_traits
        )
        # Wait for the asynchronous dispatch to finish before stopping the clock.
        offspring.pheno.block_until_ready()
        return time.time() - tic

    for ind_count in ind_grid:
        for loci_count in loci_grid:
            rng, pop_subkey = jax.random.split(rng)
            founders, genetic_map = quick_haplo(
                key=pop_subkey, n_ind=ind_count, n_chr=chromosomes, n_loci_per_chr=loci_count
            )
            base_params = SimParam.from_founder_pop(founders, genetic_map)

            for qtl_count in qtl_grid:
                print(
                    f"\n--- Running: Inds={ind_count}, Loci={loci_count*chromosomes}, "
                    f"QTLs={qtl_count*chromosomes}, Progeny={progeny_count} ---"
                )

                rng, trait_subkey = jax.random.split(rng)
                trait_params = add_trait_a(
                    key=trait_subkey, founder_pop=founders, sim_param=base_params,
                    n_qtl_per_chr=qtl_count, mean=jnp.array([0.]), var=jnp.array([1.])
                )

                # First call compiles; second call hits the compilation cache.
                rng, first_subkey = jax.random.split(rng)
                first_elapsed = time_one_generation(first_subkey, founders, trait_params)
                print(f"    JIT Compilation + First Run Time: {first_elapsed:.4f} seconds")

                rng, second_subkey = jax.random.split(rng)
                second_elapsed = time_one_generation(second_subkey, founders, trait_params)
                print(f"    Cached Execution Time: {second_elapsed:.4f} seconds")

                rows.append({
                    'n_ind': ind_count,
                    'n_total_loci': loci_count * chromosomes,
                    'n_total_qtl': qtl_count * chromosomes,
                    'jit_time': first_elapsed,
                    'cached_time': second_elapsed
                })

    print("\n\n--- Full Pipeline Benchmark Summary ---")
    print("Inds\tTotal Loci\tTotal QTLs\tJIT + 1st Run (s)\tCached Run (s)")
    print("-" * 75)
    for row in rows:
        print(
            f"{row['n_ind']}\t{row['n_total_loci']}\t\t{row['n_total_qtl']}\t\t"
            f"{row['jit_time']:.4f}\t\t\t{row['cached_time']:.4f}"
        )

# Entry point: run the full-pipeline benchmark when this cell/script is executed directly.
if __name__ == '__main__':
    run_pipeline_benchmark()

In [None]:
# add_benchmark_recurrent_selection.py

import jax
import jax.numpy as jnp
import time
from functools import partial
import matplotlib.pyplot as plt

# --- chewc imports ---
from chewc.population import Population, quick_haplo
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.cross import _make_cross_geno
from chewc.phenotype import set_pheno

# ---------------------------------------------------------------------------
# 1. JIT-compatible Phenotypic Selection Function (Unchanged)
# ---------------------------------------------------------------------------

@partial(jax.jit, static_argnames=("n_males", "n_females", "trait_idx"))
def select_pheno_jit(pop: Population, n_males: int, n_females: int, trait_idx: int = 0):
    """
    Truncation selection on one trait, performed separately within each sex.

    Individuals with ``sex == 0`` compete for ``n_males`` slots and those with
    ``sex == 1`` for ``n_females`` slots; the highest phenotypes are kept.

    Args:
        pop: population providing ``pheno`` (read as ``pheno[:, trait_idx]``) and ``sex``.
        n_males: number of row indices returned for sex 0 (static).
        n_females: number of row indices returned for sex 1 (static).
        trait_idx: column of ``pop.pheno`` used for ranking (static).

    Returns:
        ``(male_iids, female_iids)``: row indices into ``pop``, each ordered by
        ascending phenotype (best individual last).

    Note:
        Output sizes are fixed so the function stays JIT-compatible. If one sex
        has fewer members than requested, the remaining slots are filled with
        (-inf-masked) individuals of the other sex.
    """
    scores = pop.pheno[:, trait_idx]

    def ranked(sex_code):
        # Push the other sex to -inf so eligible individuals occupy the tail
        # of the ascending argsort.
        return jnp.argsort(jnp.where(pop.sex == sex_code, scores, -jnp.inf))

    return ranked(0)[-n_males:], ranked(1)[-n_females:]

# ---------------------------------------------------------------------------
# 2. Updated Full Generation Pipeline (Now includes progeny phenotyping)
# ---------------------------------------------------------------------------

@partial(jax.jit, static_argnames=(
    "n_crosses", "n_chr", "ploidy", "n_traits", "n_select_males", "n_select_females"
))
def run_generation_pipeline(
    key,
    parent_pop,
    n_crosses,
    h2,
    n_select_males,
    n_select_females,
    # Unpacked arguments from the 'sp' object:
    n_chr,
    gen_map,
    recomb_param_v,
    traits,
    ploidy,
    n_traits
):
    """
    JIT-compiled function for one generation, including selection and progeny creation.

    Steps: (1) phenotype ``parent_pop`` with a fresh draw, (2) keep the top
    ``n_select_males`` sex-0 and ``n_select_females`` sex-1 individuals via
    ``select_pheno_jit``, (3) pair randomly chosen selected mothers and fathers
    (with replacement) into ``n_crosses`` crosses, (4) build progeny genotypes
    with ``_make_cross_geno``, (5) assemble the progeny Population and
    (6) phenotype it before returning.

    Note:
        Any phenotypes already stored on ``parent_pop`` are not used for
        selection; step 1 draws new ones.

    Returns:
        The phenotyped progeny Population (``n_crosses`` individuals).
    """
    key_pheno_parent, key_selection, key_cross, key_sex, key_pheno_progeny = jax.random.split(key, 5)

    # --- Step 1: Phenotyping the Parent Population for Selection ---
    parent_pop_phenotyped = set_pheno(
        key=key_pheno_parent, pop=parent_pop, traits=traits,
        ploidy=ploidy, h2=jnp.array([h2])
    )

    # --- Step 2: Phenotypic Selection ---
    male_iids, female_iids = select_pheno_jit(
        pop=parent_pop_phenotyped, n_males=n_select_males, n_females=n_select_females
    )

    # --- Step 3: Create Cross Plan ---
    # Mothers and fathers get independent subkeys; drawing both with the same
    # key makes the parent indices deterministic functions of the same random
    # bits rather than independent draws.
    key_mother, key_father = jax.random.split(key_selection)
    mother_iids = jax.random.choice(key_mother, female_iids, shape=(n_crosses,), replace=True)
    father_iids = jax.random.choice(key_father, male_iids, shape=(n_crosses,), replace=True)

    # --- Step 4: Genotype Creation ---
    progeny_geno, progeny_ibd = _make_cross_geno(
        key_cross, parent_pop.geno[mother_iids], parent_pop.geno[father_iids],
        parent_pop.ibd[mother_iids], parent_pop.ibd[father_iids],
        n_chr, gen_map, recomb_param_v
    )

    # --- Step 5: Assemble the Progeny Population ---
    progeny_pop_unphenotyped = Population(
        geno=progeny_geno, ibd=progeny_ibd,
        id=jnp.arange(parent_pop.nInd, parent_pop.nInd + n_crosses),
        iid=jnp.arange(n_crosses), mother=parent_pop.id[mother_iids],
        father=parent_pop.id[father_iids],
        sex=jax.random.choice(key_sex, jnp.array([0, 1], dtype=jnp.int8), (n_crosses,)),
        pheno=jnp.zeros((n_crosses, n_traits)),
        fixEff=jnp.zeros(n_crosses, dtype=jnp.float32),
        bv=jnp.zeros((n_crosses, n_traits))
    )

    # --- Step 6: Phenotype the new Progeny Population ---
    progeny_pop_phenotyped = set_pheno(
        key=key_pheno_progeny, pop=progeny_pop_unphenotyped, traits=traits,
        ploidy=ploidy, h2=jnp.array([h2])
    )

    return progeny_pop_phenotyped

# ---------------------------------------------------------------------------
# 3. New Recurrent Selection Benchmark
# ---------------------------------------------------------------------------
def run_recurrent_selection(n_generations=20):
    """
    Runs a multi-generation recurrent selection simulation and plots the response.

    A founder population with one additive trait is created and phenotyped, then
    ``run_generation_pipeline`` is applied ``n_generations`` times, each time
    replacing the population with its phenotyped progeny. The mean phenotype of
    every generation (generation 0 included) is printed and plotted to
    ``recurrent_selection_gain.png``.

    Args:
        n_generations: number of selection cycles to run; must be >= 0.

    Raises:
        ValueError: if ``n_generations`` is negative.
    """
    # A negative count would leave the plot's x-range shorter than the recorded means.
    if n_generations < 0:
        raise ValueError(f"n_generations must be non-negative, got {n_generations}")

    print(f"JAX is using: {jax.devices()[0].platform.upper()}")

    # --- Simulation Parameters ---
    n_ind = 1000
    n_loci_per_chr = 5000
    n_qtl_per_chr = 100
    n_chr = 10
    h2 = 0.3
    n_select_males = 50
    n_select_females = 500

    # The population size is kept constant by creating n_ind progeny each generation
    n_progeny_to_create = n_ind

    key = jax.random.PRNGKey(1337)

    # --- Setup Founder Population ---
    print("1. Setting up founder population...")
    key, pop_key, trait_key = jax.random.split(key, 3)
    founder_pop, gen_map = quick_haplo(
        key=pop_key, n_ind=n_ind, n_chr=n_chr, n_loci_per_chr=n_loci_per_chr
    )
    sp_base = SimParam.from_founder_pop(founder_pop, gen_map)
    sp = add_trait_a(
        key=trait_key, founder_pop=founder_pop, sim_param=sp_base,
        n_qtl_per_chr=n_qtl_per_chr, mean=jnp.array([0.]), var=jnp.array([1.])
    )

    # --- Initial Phenotyping (Generation 0) ---
    key, pheno_key = jax.random.split(key)
    current_pop = set_pheno(
        key=pheno_key, pop=founder_pop, traits=sp.traits,
        ploidy=sp.ploidy, h2=jnp.array([h2])
    )

    # --- Run Recurrent Selection Loop ---
    print(f"\n2. Starting {n_generations}-generation recurrent selection...")
    mean_phenotypes = [current_pop.pheno.mean().item()]

    # NOTE: the first loop iteration also pays for compiling the jitted pipeline,
    # so the reported times include compilation.
    start_time = time.time()

    for gen in range(n_generations):
        key, pipeline_key = jax.random.split(key)

        # Run one full generation of selection and breeding
        progeny_pop = run_generation_pipeline(
            key=pipeline_key,
            parent_pop=current_pop,
            n_crosses=n_progeny_to_create,
            h2=h2,
            n_select_males=n_select_males,
            n_select_females=n_select_females,
            n_chr=sp.n_chr,
            gen_map=sp.gen_map,
            recomb_param_v=sp.recomb_params[0],
            traits=sp.traits,
            ploidy=sp.ploidy,
            n_traits=sp.n_traits
        )

        # Progeny from this generation become the parents for the next
        current_pop = progeny_pop

        # Record the mean phenotype of the new generation
        mean_pheno = current_pop.pheno.mean()
        mean_pheno.block_until_ready() # Ensure calculation is complete before timing
        mean_phenotypes.append(mean_pheno.item())

        print(f"   - Generation {gen + 1:2d}/{n_generations} | Mean Phenotype: {mean_pheno:.4f}")

    total_time = time.time() - start_time
    print("\n3. Simulation finished.")
    print(f"   - Total time for {n_generations} generations: {total_time:.2f} seconds")
    # Skip the per-generation rate when no generations ran (avoids division by zero).
    if n_generations > 0:
        print(f"   - Time per generation: {total_time / n_generations:.3f} seconds")

    # --- Plot Results ---
    plt.figure(figsize=(10, 6))
    plt.plot(range(n_generations + 1), mean_phenotypes, 'o-', label="Mean Phenotype")
    plt.title("Response to Phenotypic Selection over Generations")
    plt.xlabel("Generation")
    plt.ylabel("Mean Phenotypic Value")
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.legend()
    plt.xticks(range(0, n_generations + 1, 2))
    plt.tight_layout()
    plt.savefig("recurrent_selection_gain.png")
    print("\n4. Plot saved to 'recurrent_selection_gain.png'")


# Entry point: run the recurrent-selection simulation (default 20 generations)
# when this cell/script is executed directly.
if __name__ == '__main__':
    run_recurrent_selection()

In [None]:
# add_benchmark_vmap_replicates.py

import jax
import jax.numpy as jnp
import time
from functools import partial
import matplotlib.pyplot as plt

# --- chewc imports ---
from chewc.population import Population, quick_haplo
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.cross import _make_cross_geno
from chewc.phenotype import set_pheno

# ---------------------------------------------------------------------------
# 1. JIT-compatible Phenotypic Selection Function (Unchanged)
# ---------------------------------------------------------------------------

@partial(jax.jit, static_argnames=("n_males", "n_females", "trait_idx"))
def select_pheno_jit(pop: Population, n_males: int, n_females: int, trait_idx: int = 0):
    """
    Truncation selection on one trait, performed separately within each sex.

    Individuals with ``sex == 0`` compete for ``n_males`` slots and those with
    ``sex == 1`` for ``n_females`` slots; the highest phenotypes are kept.

    Args:
        pop: population providing ``pheno`` (read as ``pheno[:, trait_idx]``) and ``sex``.
        n_males: number of row indices returned for sex 0 (static).
        n_females: number of row indices returned for sex 1 (static).
        trait_idx: column of ``pop.pheno`` used for ranking (static).

    Returns:
        ``(male_iids, female_iids)``: row indices into ``pop``, each ordered by
        ascending phenotype (best individual last).

    Note:
        Output sizes are fixed so the function stays JIT-compatible. If one sex
        has fewer members than requested, the remaining slots are filled with
        (-inf-masked) individuals of the other sex.
    """
    scores = pop.pheno[:, trait_idx]

    def ranked(sex_code):
        # Push the other sex to -inf so eligible individuals occupy the tail
        # of the ascending argsort.
        return jnp.argsort(jnp.where(pop.sex == sex_code, scores, -jnp.inf))

    return ranked(0)[-n_males:], ranked(1)[-n_females:]

# ---------------------------------------------------------------------------
# 2. JIT-compiled Single Generation Pipeline (Unchanged)
# ---------------------------------------------------------------------------

@partial(jax.jit, static_argnames=(
    "n_crosses", "n_chr", "ploidy", "n_traits", "n_select_males", "n_select_females"
))
def run_generation_pipeline(
    key, parent_pop, n_crosses, h2, n_select_males, n_select_females,
    n_chr, gen_map, recomb_param_v, traits, ploidy, n_traits
):
    """
    JIT-compiled function for one generation: phenotype, select, cross, phenotype.

    ``parent_pop`` is phenotyped with a fresh draw, the top ``n_select_males``
    sex-0 and ``n_select_females`` sex-1 individuals are kept, ``n_crosses``
    mother/father pairs are sampled from them with replacement, and the
    resulting progeny Population is phenotyped and returned.
    """
    key_pheno_parent, key_selection, key_cross, key_sex, key_pheno_progeny = jax.random.split(key, 5)
    parent_pop_phenotyped = set_pheno(
        key=key_pheno_parent, pop=parent_pop, traits=traits,
        ploidy=ploidy, h2=jnp.array([h2])
    )
    male_iids, female_iids = select_pheno_jit(
        pop=parent_pop_phenotyped, n_males=n_select_males, n_females=n_select_females
    )
    # Independent subkeys for the two parent draws; sharing one key would make
    # mother and father indices functions of the same random bits.
    key_mother, key_father = jax.random.split(key_selection)
    mother_iids = jax.random.choice(key_mother, female_iids, shape=(n_crosses,), replace=True)
    father_iids = jax.random.choice(key_father, male_iids, shape=(n_crosses,), replace=True)
    progeny_geno, progeny_ibd = _make_cross_geno(
        key_cross, parent_pop.geno[mother_iids], parent_pop.geno[father_iids],
        parent_pop.ibd[mother_iids], parent_pop.ibd[father_iids],
        n_chr, gen_map, recomb_param_v
    )
    progeny_pop_unphenotyped = Population(
        geno=progeny_geno, ibd=progeny_ibd,
        id=jnp.arange(parent_pop.nInd, parent_pop.nInd + n_crosses),
        iid=jnp.arange(n_crosses), mother=parent_pop.id[mother_iids],
        father=parent_pop.id[father_iids],
        sex=jax.random.choice(key_sex, jnp.array([0, 1], dtype=jnp.int8), (n_crosses,)),
        pheno=jnp.zeros((n_crosses, n_traits)),
        fixEff=jnp.zeros(n_crosses, dtype=jnp.float32),
        bv=jnp.zeros((n_crosses, n_traits))
    )
    progeny_pop_phenotyped = set_pheno(
        key=key_pheno_progeny, pop=progeny_pop_unphenotyped, traits=traits,
        ploidy=ploidy, h2=jnp.array([h2])
    )
    return progeny_pop_phenotyped

# ---------------------------------------------------------------------------
# 3. New Function for a Single Replicate (Using jax.lax.scan)
# ---------------------------------------------------------------------------

@partial(jax.jit, static_argnames=(
    "n_generations", "n_ind", "n_loci_per_chr", "n_qtl_per_chr", "n_chr", "h2",
    "n_select_males", "n_select_females"
))
def run_single_replicate(
    replicate_key, n_generations, n_ind, n_loci_per_chr, n_qtl_per_chr,
    n_chr, h2, n_select_males, n_select_females
):
    """
    Simulate one independent replicate of recurrent phenotypic selection.

    Builds and phenotypes a founder population carrying one additive trait, then
    advances it ``n_generations`` times inside ``jax.lax.scan`` so the whole loop
    is compiled. Every argument except ``replicate_key`` is static, which lets
    the function be wrapped in ``jax.vmap`` over a batch of keys.

    Returns:
        Array of length ``n_generations + 1``: the founders' mean phenotype
        followed by the progeny mean phenotype of each generation.
    """
    scan_key, founder_key, trait_key, founder_pheno_key = jax.random.split(replicate_key, 4)

    # Founder genomes, simulation parameters and trait architecture.
    base_pop, genetic_map = quick_haplo(
        key=founder_key, n_ind=n_ind, n_chr=n_chr, n_loci_per_chr=n_loci_per_chr
    )
    params = add_trait_a(
        key=trait_key,
        founder_pop=base_pop,
        sim_param=SimParam.from_founder_pop(base_pop, genetic_map),
        n_qtl_per_chr=n_qtl_per_chr,
        mean=jnp.array([0.]),
        var=jnp.array([1.])
    )
    founders = set_pheno(
        key=founder_pheno_key, pop=base_pop, traits=params.traits,
        ploidy=params.ploidy, h2=jnp.array([h2])
    )

    def advance_one_generation(state, _unused):
        current, rng = state
        rng, step_key = jax.random.split(rng)
        progeny = run_generation_pipeline(
            step_key, current, n_ind, h2, n_select_males, n_select_females,
            params.n_chr, params.gen_map, params.recomb_params[0],
            params.traits, params.ploidy, params.n_traits
        )
        # Carry (progeny, advanced key) forward; emit the progeny mean for this step.
        return (progeny, rng), progeny.pheno.mean()

    _, per_generation_means = jax.lax.scan(
        advance_one_generation, (founders, scan_key), None, length=n_generations
    )

    founder_mean = jnp.array([founders.pheno.mean()])
    return jnp.concatenate([founder_mean, per_generation_means])

# ---------------------------------------------------------------------------
# 4. New Vmapped Benchmark
# ---------------------------------------------------------------------------
def run_vmapped_benchmark(n_replicates=10, n_generations=20):
    """
    Runs multiple replicates of the simulation in parallel using vmap.

    ``run_single_replicate`` is vectorised over a batch of ``n_replicates`` PRNG
    keys while all other settings are shared. The first call is timed as
    compilation, the second as the actual run. The mean ±1 standard deviation of
    the per-generation mean phenotype across replicates is plotted to
    ``recurrent_selection_vmapped.png``.

    Args:
        n_replicates: number of independent replicates; must be >= 1.
        n_generations: number of selection cycles per replicate.

    Raises:
        ValueError: if ``n_replicates`` is less than 1.
    """
    if n_replicates < 1:
        raise ValueError(f"n_replicates must be >= 1, got {n_replicates}")

    print(f"JAX is using: {jax.devices()[0].platform.upper()}")

    # --- Simulation Parameters ---
    sim_params = {
        "n_generations": n_generations,
        "n_ind": 1000,
        "n_loci_per_chr": 1000,
        "n_qtl_per_chr": 100,
        "n_chr": 10,
        "h2": 0.3,
        "n_select_males": 50,
        "n_select_females": 500
    }

    # --- Create a vmapped version of the single-replicate function ---
    # The shared settings are bound BY KEYWORD, so they reach run_single_replicate's
    # static_argnames regardless of the dict's ordering; vmap then maps only over
    # the single remaining positional argument, the replicate key.
    vmapped_simulation = jax.vmap(partial(run_single_replicate, **sim_params))

    print("1. JIT-compiling the vmapped simulation...")
    start_jit = time.time()
    main_key = jax.random.PRNGKey(42)
    replicate_keys = jax.random.split(main_key, n_replicates)

    # JIT compile by running once
    all_results = vmapped_simulation(replicate_keys).block_until_ready()
    jit_time = time.time() - start_jit
    print(f"   JIT compilation took: {jit_time:.2f} seconds")

    # --- Run Benchmarked Simulation ---
    print(f"\n2. Starting {n_replicates} vmapped replicates for {n_generations} generations...")
    start_run = time.time()
    all_results = vmapped_simulation(replicate_keys).block_until_ready()
    run_time = time.time() - start_run

    print("3. Simulation finished.")
    print(f"   - Total time for {n_replicates} replicates: {run_time:.2f} seconds")
    print(f"   - Time per replicate: {run_time / n_replicates:.3f} seconds")

    # --- Process and Plot Results ---
    # all_results has one row per replicate and one column per generation (0..n_generations).
    mean_over_reps = all_results.mean(axis=0)
    std_over_reps = all_results.std(axis=0)

    plt.figure(figsize=(12, 7))
    generations = jnp.arange(n_generations + 1)

    # Plot the mean response across replicates
    plt.plot(generations, mean_over_reps, 'o-', color='royalblue', label="Mean Phenotype")

    # Shade ±1 standard deviation across replicates
    plt.fill_between(generations,
                     mean_over_reps - std_over_reps,
                     mean_over_reps + std_over_reps,
                     color='royalblue', alpha=0.2, label="±1 Std. Dev.")

    plt.title(f"Response to Selection ({n_replicates} Replicates)")
    plt.xlabel("Generation")
    plt.ylabel("Mean Phenotypic Value")
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.legend()
    plt.xticks(range(0, n_generations + 1, 2))
    plt.tight_layout()
    plt.savefig("recurrent_selection_vmapped.png")
    print("\n4. Plot saved to 'recurrent_selection_vmapped.png'")

# Entry point: run 50 vmapped replicates of a 20-generation simulation when this
# cell/script is executed directly.
if __name__ == '__main__':
    run_vmapped_benchmark(n_replicates=50, n_generations=20)

In [None]:
# add_benchmark_founders.py

import jax
import jax.numpy as jnp
import time
from functools import partial
import matplotlib.pyplot as plt

# --- chewc imports ---
from chewc.population import Population, quick_haplo
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.cross import _make_cross_geno
from chewc.phenotype import set_pheno

# ---------------------------------------------------------------------------
# 1. JIT-compatible Phenotypic Selection Function (Unchanged)
# ---------------------------------------------------------------------------

@partial(jax.jit, static_argnames=("n_males", "n_females", "trait_idx"))
def select_pheno_jit(pop: Population, n_males: int, n_females: int, trait_idx: int = 0):
    """
    Truncation selection on one trait, performed separately within each sex.

    Individuals with ``sex == 0`` compete for ``n_males`` slots and those with
    ``sex == 1`` for ``n_females`` slots; the highest phenotypes are kept.

    Args:
        pop: population providing ``pheno`` (read as ``pheno[:, trait_idx]``) and ``sex``.
        n_males: number of row indices returned for sex 0 (static).
        n_females: number of row indices returned for sex 1 (static).
        trait_idx: column of ``pop.pheno`` used for ranking (static).

    Returns:
        ``(male_iids, female_iids)``: row indices into ``pop``, each ordered by
        ascending phenotype (best individual last).

    Note:
        Output sizes are fixed so the function stays JIT-compatible. If one sex
        has fewer members than requested, the remaining slots are filled with
        (-inf-masked) individuals of the other sex.
    """
    scores = pop.pheno[:, trait_idx]

    def ranked(sex_code):
        # Push the other sex to -inf so eligible individuals occupy the tail
        # of the ascending argsort.
        return jnp.argsort(jnp.where(pop.sex == sex_code, scores, -jnp.inf))

    return ranked(0)[-n_males:], ranked(1)[-n_females:]

# ---------------------------------------------------------------------------
# 2. JIT-compiled Single Generation Pipeline (Unchanged)
# ---------------------------------------------------------------------------

def _sample_parents_of_sex(key, candidate_iids, sex, required_sex, n):
    """Draw ``n`` parents with replacement from ``candidate_iids``, giving
    zero weight to candidates whose ``sex`` is not ``required_sex``.

    ``select_pheno_jit`` returns fixed-size index arrays. When a sex has too
    few members, it pads them with individuals of the other sex; weighting by
    sex keeps those padding entries from ever being used as parents. If no
    candidate has the required sex, all candidates get equal weight so the
    probabilities stay finite under JIT.
    """
    is_required_sex = sex[candidate_iids] == required_sex
    # Guard against an all-zero weight vector (0/0 -> NaN probabilities).
    is_required_sex = jnp.where(jnp.any(is_required_sex), is_required_sex, True)
    weights = is_required_sex.astype(jnp.float32)
    return jax.random.choice(
        key, candidate_iids, shape=(n,), replace=True, p=weights / jnp.sum(weights)
    )


@partial(jax.jit, static_argnames=(
    "n_crosses", "n_chr", "ploidy", "n_traits", "n_select_males", "n_select_females"
))
def run_generation_pipeline(
    key, parent_pop, n_crosses, h2, n_select_males, n_select_females,
    n_chr, gen_map, recomb_param_v, traits, ploidy, n_traits
):
    """Advance a population by one generation of phenotypic truncation selection.

    The function does the following:
      1. Phenotypes ``parent_pop``.
      2. Keeps the best ``n_select_males`` males and ``n_select_females``
         females.
      3. Makes ``n_crosses`` random crosses among them (parents drawn with
         replacement).
      4. Phenotypes the resulting progeny.

    Args:
        key: PRNG key for all randomness in this generation.
        parent_pop: ``Population`` to select parents from.
        n_crosses: number of progeny to produce (static).
        h2: heritability passed to ``set_pheno``. It is wrapped as a
            1-element array, so presumably a single trait is assumed
            (confirm before multi-trait use).
        n_select_males, n_select_females: truncation sizes (static).
        n_chr, gen_map, recomb_param_v: genome/recombination inputs forwarded
            to ``_make_cross_geno``.
        traits, ploidy, n_traits: trait inputs forwarded to ``set_pheno``;
            ``n_traits`` also sizes the progeny ``pheno``/``bv`` arrays.

    Returns:
        Phenotyped ``Population`` of ``n_crosses`` progeny. Progeny ``id``s
        continue after the largest id in ``parent_pop``; ``iid`` runs
        ``0..n_crosses-1``.
    """
    (key_pheno_parent, key_mothers, key_fathers,
     key_cross, key_sex, key_pheno_progeny) = jax.random.split(key, 6)

    # NOTE(review): parents are phenotyped again with a fresh key, even when
    # parent_pop already carries phenotypes. Selection therefore presumably
    # acts on a new environmental draw rather than the stored one; confirm
    # this is intended.
    parent_pop_phenotyped = set_pheno(
        key=key_pheno_parent, pop=parent_pop, traits=traits,
        ploidy=ploidy, h2=jnp.array([h2])
    )
    male_iids, female_iids = select_pheno_jit(
        pop=parent_pop_phenotyped, n_males=n_select_males, n_females=n_select_females
    )

    # Mothers and fathers use separate keys. One shared key would feed both
    # draws the same random bits and deterministically tie each mother to a
    # father.
    mother_iids = _sample_parents_of_sex(
        key_mothers, female_iids, parent_pop_phenotyped.sex, 1, n_crosses
    )
    father_iids = _sample_parents_of_sex(
        key_fathers, male_iids, parent_pop_phenotyped.sex, 0, n_crosses
    )

    progeny_geno, progeny_ibd = _make_cross_geno(
        key_cross, parent_pop.geno[mother_iids], parent_pop.geno[father_iids],
        parent_pop.ibd[mother_iids], parent_pop.ibd[father_iids],
        n_chr, gen_map, recomb_param_v
    )

    # Continue numbering after the largest existing id. A fixed offset such
    # as nInd would reuse the same id range in every generation.
    first_progeny_id = jnp.max(parent_pop.id) + 1
    progeny_pop_unphenotyped = Population(
        geno=progeny_geno, ibd=progeny_ibd,
        id=first_progeny_id + jnp.arange(n_crosses),
        iid=jnp.arange(n_crosses), mother=parent_pop.id[mother_iids],
        father=parent_pop.id[father_iids],
        sex=jax.random.choice(key_sex, jnp.array([0, 1], dtype=jnp.int8), (n_crosses,)),
        pheno=jnp.zeros((n_crosses, n_traits)),
        fixEff=jnp.zeros(n_crosses, dtype=jnp.float32),
        bv=jnp.zeros((n_crosses, n_traits))
    )
    progeny_pop_phenotyped = set_pheno(
        key=key_pheno_progeny, pop=progeny_pop_unphenotyped, traits=traits,
        ploidy=ploidy, h2=jnp.array([h2])
    )
    return progeny_pop_phenotyped

# ---------------------------------------------------------------------------
# 3. One full replicate, starting from a small founder population
# ---------------------------------------------------------------------------

@partial(jax.jit, static_argnames=(
    "n_generations", "n_founders", "n_ind", "n_loci_per_chr", "n_qtl_per_chr",
    "n_chr", "h2", "n_select_males", "n_select_females"
))
def run_single_replicate(
    replicate_key, n_generations, n_founders, n_ind, n_loci_per_chr, n_qtl_per_chr,
    n_chr, h2, n_select_males, n_select_females
):
    """Simulate one replicate of recurrent phenotypic selection from a small founder group.

    Steps:
      1. Simulate ``n_founders`` founders with ``quick_haplo``. Add one
         additive trait (mean 0, variance 1) with ``n_qtl_per_chr`` QTL per
         chromosome.
      2. Build generation 0 from ``n_ind`` random founder crosses and
         phenotype it at heritability ``h2``. Selfing can occur and founder
         sex is not considered.
      3. Run ``n_generations`` cycles of ``run_generation_pipeline``, keeping
         the population size at ``n_ind``.

    Args:
        replicate_key: PRNG key for this replicate; all randomness is split
            from it.
        n_generations, n_founders, n_ind, n_loci_per_chr, n_qtl_per_chr,
        n_chr, h2, n_select_males, n_select_females: static simulation
            settings.

    Returns:
        Array of shape ``(n_generations + 1,)``. The first entry is the mean
        phenotype of generation 0, followed by the mean for each selected
        generation. Each mean is taken over all individuals and traits.
    """
    # --- 1. Setup small founder group and define trait architecture ---
    key, founder_key, trait_key, base_pop_key = jax.random.split(replicate_key, 4)

    initial_founders, gen_map = quick_haplo(
        key=founder_key, n_ind=n_founders, n_chr=n_chr, n_loci_per_chr=n_loci_per_chr
    )
    sp_base = SimParam.from_founder_pop(initial_founders, gen_map)
    sp = add_trait_a(
        key=trait_key, founder_pop=initial_founders, sim_param=sp_base,
        n_qtl_per_chr=n_qtl_per_chr, mean=jnp.array([0.]), var=jnp.array([1.])
    )

    # --- 2. Create Generation 0 by crossing the founders ---
    # Mothers and fathers need separate keys. Drawing both from one key over
    # the same candidates returns identical arrays, which would make every
    # base individual a selfed offspring.
    key_mothers, key_fathers, key_cross, key_sex, key_pheno = jax.random.split(base_pop_key, 5)

    # Randomly mate founders (selfing can occur by chance).
    mother_iids = jax.random.choice(key_mothers, initial_founders.iid, shape=(n_ind,), replace=True)
    father_iids = jax.random.choice(key_fathers, initial_founders.iid, shape=(n_ind,), replace=True)

    base_pop_geno, base_pop_ibd = _make_cross_geno(
        key_cross, initial_founders.geno[mother_iids], initial_founders.geno[father_iids],
        initial_founders.ibd[mother_iids], initial_founders.ibd[father_iids],
        n_chr, gen_map, sp.recomb_params[0]
    )
    base_pop_unphenotyped = Population(
        geno=base_pop_geno, ibd=base_pop_ibd,
        id=jnp.arange(n_founders, n_founders + n_ind), iid=jnp.arange(n_ind),
        mother=initial_founders.id[mother_iids], father=initial_founders.id[father_iids],
        sex=jax.random.choice(key_sex, jnp.array([0, 1], dtype=jnp.int8), (n_ind,)),
        pheno=jnp.zeros((n_ind, sp.n_traits)), fixEff=jnp.zeros(n_ind, dtype=jnp.float32),
        bv=jnp.zeros((n_ind, sp.n_traits))
    )
    base_pop = set_pheno(
        key=key_pheno, pop=base_pop_unphenotyped, traits=sp.traits,
        ploidy=sp.ploidy, h2=jnp.array([h2])
    )

    # --- 3. Body of the scan: one generation of selection and mating ---
    def generation_step(carry, _):
        pop, key = carry
        key, pipeline_key = jax.random.split(key)
        # n_ind is passed as n_crosses so the population size stays constant.
        next_pop = run_generation_pipeline(
            pipeline_key, pop, n_ind, h2, n_select_males, n_select_females,
            sp.n_chr, sp.gen_map, sp.recomb_params[0], sp.traits, sp.ploidy, sp.n_traits
        )
        return (next_pop, key), next_pop.pheno.mean()

    # --- 4. Run the recurrent selection loop using jax.lax.scan ---
    initial_carry = (base_pop, key)
    (_, _), mean_phenotypes_over_gens = jax.lax.scan(
        generation_step, initial_carry, None, length=n_generations
    )

    # --- 5. Prepend generation 0's mean to the per-generation means ---
    all_mean_phenotypes = jnp.concatenate([
        jnp.array([base_pop.pheno.mean()]),
        mean_phenotypes_over_gens
    ])

    return all_mean_phenotypes

# ---------------------------------------------------------------------------
# 4. Benchmark: many replicates in parallel with jax.vmap
# ---------------------------------------------------------------------------
def run_vmapped_benchmark(n_replicates=10, n_generations=20):
    """Run many replicates of ``run_single_replicate`` in parallel with ``jax.vmap`` and plot them.

    Each replicate gets its own PRNG key, split from a fixed seed (42). All
    other simulation settings are shared. The batched simulation is called
    twice. The first call includes tracing and compilation; the second
    measures the cached, steady-state runtime. The mean ±1 standard
    deviation of the per-generation mean phenotype across replicates is
    plotted and saved to ``recurrent_selection_from_founders.png``.

    Args:
        n_replicates: number of independent replicates (must be >= 1).
        n_generations: number of selection cycles after generation 0.

    Raises:
        ValueError: if ``n_replicates < 1``.
    """
    if n_replicates < 1:
        raise ValueError(f"n_replicates must be >= 1, got {n_replicates}")

    print(f"JAX is using: {jax.devices()[0].platform.upper()}")

    # --- Simulation Parameters ---
    sim_params = {
        "n_generations": n_generations,
        "n_founders": 5,
        "n_ind": 1000,
        "n_loci_per_chr": 1000,
        "n_qtl_per_chr": 100,
        "n_chr": 10,
        "h2": 0.3,
        "n_select_males": 50,
        "n_select_females": 500
    }

    # --- vmap the single-replicate function ---
    # Settings are bound by keyword, so they cannot drift out of order with
    # run_single_replicate's signature; they also stay Python constants for
    # its static_argnames. Only the replicate key is mapped. The outer
    # jax.jit caches the compiled batched program for the timed second call.
    vmapped_simulation = jax.jit(
        jax.vmap(lambda rep_key: run_single_replicate(rep_key, **sim_params))
    )

    main_key = jax.random.PRNGKey(42)
    replicate_keys = jax.random.split(main_key, n_replicates)

    print("1. JIT-compiling the vmapped simulation...")
    start_jit = time.time()
    all_results = vmapped_simulation(replicate_keys).block_until_ready()
    jit_time = time.time() - start_jit
    print(f"   First call (JIT compilation + run) took: {jit_time:.2f} seconds")

    # --- Run Benchmarked Simulation ---
    print(f"\n2. Starting {n_replicates} vmapped replicates for {n_generations} generations...")
    start_run = time.time()
    all_results = vmapped_simulation(replicate_keys).block_until_ready()
    run_time = time.time() - start_run

    print("3. Simulation finished.")
    print(f"   - Total time for {n_replicates} replicates: {run_time:.2f} seconds")
    print(f"   - Time per replicate: {run_time / n_replicates:.3f} seconds")

    # --- Process and Plot Results ---
    # all_results has shape (n_replicates, n_generations + 1).
    mean_over_reps = all_results.mean(axis=0)
    std_over_reps = all_results.std(axis=0)
    plt.figure(figsize=(12, 7))
    generations = jnp.arange(n_generations + 1)

    plt.plot(generations, mean_over_reps, 'o-', color='royalblue', label="Mean Phenotype")
    plt.fill_between(generations,
                     mean_over_reps - std_over_reps,
                     mean_over_reps + std_over_reps,
                     color='royalblue', alpha=0.2, label="±1 Std. Dev.")

    plt.title(f"Response to Selection ({n_replicates} Replicates from {sim_params['n_founders']} Founders)")
    plt.xlabel("Generation")
    plt.ylabel("Mean Phenotypic Value")
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.legend()
    plt.xticks(range(0, n_generations + 1, 2))
    plt.tight_layout()
    plt.savefig("recurrent_selection_from_founders.png")
    print("\n4. Plot saved to 'recurrent_selection_from_founders.png'")

if __name__ == '__main__':
    # Script entry point for the founder-population benchmark. In a
    # Jupyter/IPython kernel __name__ is also '__main__', so this runs
    # whenever the cell is executed.
    replicate_count, generation_count = 50, 20
    run_vmapped_benchmark(n_replicates=replicate_count, n_generations=generation_count)