In [3]:
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
import time
import matplotlib.pyplot as plt
import numpy as np
from typing import Literal, Dict, Optional
from flax.struct import dataclass as flax_dataclass

import jax
from jax import debug
import jax.numpy as jnp
from functools import partial
from jax.scipy.sparse.linalg import cg

# --- Import all necessary functions from your codebase ---
from chewc.population import Population, msprime_pop, combine_populations
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_bv, set_pheno
from chewc.cross import make_cross
# --- ADD THIS IMPORT ---
# This function will calculate our EBVs
# ==================================================
# ---  Simulation Parameters (Unchanged) ---
# ==================================================
simulation_parameters = {
    "n_replicates": 5,
    "n_founder_ind": 10,
    "n_loci_per_chr": 100,
    "n_chr": 2,
    "n_qtl_per_chr": 100,
    "trait_mean": jnp.array([0.0]),
    "trait_var": jnp.array([1.0]),
    "n_generations": 5,
    "population_size": 100,
    "n_select": 50,
    "key": jax.random.PRNGKey(42),
    "h2": jnp.array([.3]),
}

# ==================================================
# ---  Setup (Unchanged) ---
# ==================================================
key = simulation_parameters["key"]
key, founder_key, sp1_key = jax.random.split(key, 3)

founder_pop, genetic_map = msprime_pop(
    key=founder_key,
    n_ind=simulation_parameters["n_founder_ind"],
    n_loci_per_chr=simulation_parameters["n_loci_per_chr"],
    n_chr=simulation_parameters["n_chr"]
)

sp = SimParam.from_founder_pop(founder_pop, genetic_map)
sp = add_trait_a(
    key=sp1_key, founder_pop=founder_pop, sim_param=sp,
    n_qtl_per_chr=simulation_parameters["n_qtl_per_chr"],
    mean=simulation_parameters["trait_mean"],
    var=simulation_parameters["trait_var"],
)

# ==================================================
# --- Step 1: Phenotype the Founder Population ---
# ==================================================
# We need a new JAX random key to ensure this step is reproducible.
key, pheno_key = jax.random.split(key)

# First, set the true breeding values (BVs) for the founder population.
# This isn't strictly necessary if you immediately call set_pheno,
# but it's good practice to see how it's done.
founder_pop = set_bv(
    pop=founder_pop,
    traits=sp.traits,
    ploidy=sp.ploidy
)

# Now, set the phenotypes. This function will also calculate BVs and GVs internally,
# then add environmental noise based on the heritability (h2).
founder_pop = set_pheno(
    key=pheno_key,
    pop=founder_pop,
    traits=sp.traits,
    ploidy=sp.ploidy,
    h2=simulation_parameters["h2"]
)

# --- Verification ---
# Let's check the correlation between the true breeding values and the phenotypes.
# It should be close to the square root of the heritability (sqrt(0.3) ≈ 0.547).
# This confirms our phenotype generation is working correctly.
corr = jnp.corrcoef(founder_pop.bv.flatten(), founder_pop.pheno.flatten())[0, 1]
print(f"Founder Population Ready: {founder_pop}")
print(f"Correlation(BV, Pheno): {corr:.4f}")
print(f"Expected Correlation (sqrt(h2)): {jnp.sqrt(simulation_parameters['h2'][0]):.4f}")

# ==================================================
# --- Step 2: Create the First Generation (G1) ---
# ==================================================
# We need a new JAX random key for parent selection and crossing.
key, select_key, cross_key = jax.random.split(key, 3)

# 1. --- Select Parents ---
# We'll select the top 4 males and top 4 females based on phenotype.
n_select_male = 4
n_select_female = 4

# Phenotypes are in the first (and only) column of the .pheno array
phenos = founder_pop.pheno[:, 0]
sexes = founder_pop.sex

# Get iids of all males and females
male_iids = founder_pop.iid[sexes == 1]
female_iids = founder_pop.iid[sexes == 0]

# Get their corresponding phenotypes
male_phenos = phenos[male_iids]
female_phenos = phenos[female_iids]

# Get the internal IDs (iids) of the top individuals of each sex
top_male_iids = male_iids[jnp.argsort(male_phenos)[-n_select_male:]]
top_female_iids = female_iids[jnp.argsort(female_phenos)[-n_select_female:]]

# 2. --- Define a Mating Plan ---
# We will create a cross plan to produce our target population size.
# For simplicity, we'll randomly pair the selected sires and dams.
n_progeny = simulation_parameters["population_size"]
sires = jax.random.choice(select_key, top_male_iids, shape=(n_progeny,))
dams = jax.random.choice(select_key, top_female_iids, shape=(n_progeny,))
cross_plan = jnp.stack([dams, sires], axis=1)


# 3. --- Generate Progeny ---
# The `make_cross` function needs to know the next available public ID.
next_id = founder_pop.id.max() + 1
progeny_pop = make_cross(
    key=cross_key,
    pop=founder_pop,
    cross_plan=cross_plan,
    sp=sp,
    next_id_start=next_id
)

# 4. --- Combine Populations ---
# Now we have a population with both founders and their offspring.
# This population has a pedigree structure.
combined_pop = combine_populations(founder_pop, progeny_pop)

# --- Verification ---
print(f"Progeny Population Created: {progeny_pop}")
print(f"Combined Population Ready: {combined_pop}")
print("\nPedigree of the first 5 offspring:")
print("ID | Mother | Father")
print("--------------------")
for i in range(5):
    offspring = progeny_pop.id[i]
    mother = progeny_pop.mother[i]
    father = progeny_pop.father[i]
    print(f"{offspring} | {mother}    | {father}")


# ===============================================================
# --- Step 3.1: Phenotype the Combined Population (G0 + G1) ---
# ===============================================================
# Get a new key for this round of phenotyping.
key, pheno_key_g1 = jax.random.split(key)

# Run set_pheno on the entire combined population.
# This ensures all individuals (founders and progeny) are evaluated
# under the same environmental variance conditions.
combined_pop = set_pheno(
    key=pheno_key_g1,
    pop=combined_pop,
    traits=sp.traits,
    ploidy=sp.ploidy,
    h2=simulation_parameters["h2"]
)

# --- Verification ---
# Check that the progeny now have breeding values and phenotypes.
print(f"Combined population phenotyped: {combined_pop}")
print(f"Mean BV of Founders (G0): {jnp.mean(combined_pop.bv[combined_pop.gen==0]):.4f}")
print(f"Mean BV of Progeny (G1):  {jnp.mean(combined_pop.bv[combined_pop.gen==1]):.4f}")


@partial(jax.jit, static_argnames=('n_ind',))
def _jit_calc_a_inverse(mother_iids: jnp.ndarray, father_iids: jnp.ndarray, n_ind: int) -> jnp.ndarray:
    initial_A_inv = jnp.zeros((n_ind, n_ind))
    def loop_body(i, A_inv):
        sire_iid, dam_iid = father_iids[i], mother_iids[i]
        case_index = (sire_iid != -1) + 2 * (dam_iid != -1)
        def case_0(mat): return mat.at[i, i].add(1.0)
        def case_1(mat): return mat.at[i, i].add(4/3.).at[sire_iid, sire_iid].add(1/3.).at[i, sire_iid].add(-2/3.).at[sire_iid, i].add(-2/3.)
        def case_2(mat): return mat.at[i, i].add(4/3.).at[dam_iid, dam_iid].add(1/3.).at[i, dam_iid].add(-2/3.).at[dam_iid, i].add(-2/3.)
        def case_3(mat): return (mat.at[i, i].add(2.0).at[sire_iid, sire_iid].add(0.5).at[dam_iid, dam_iid].add(0.5)
                               .at[sire_iid, dam_iid].add(0.5).at[dam_iid, sire_iid].add(0.5).at[i, sire_iid].add(-1.0)
                               .at[sire_iid, i].add(-1.0).at[i, dam_iid].add(-1.0).at[dam_iid, i].add(-1.0))
        return lax.switch(case_index, [case_0, case_1, case_2, case_3], A_inv)
    return lax.fori_loop(0, n_ind, loop_body, initial_A_inv)

def calc_a_inverse_matrix_pedigree_jax(pop: Population) -> jnp.ndarray:
    n_ind = pop.nInd
    id_to_iid_map = jnp.full(pop.id.max() + 2, -1, dtype=jnp.int32).at[pop.id].set(pop.iid)
    mother_iids = jnp.where(pop.mother < 0, -1, id_to_iid_map[pop.mother.clip(min=0)])
    father_iids = jnp.where(pop.father < 0, -1, id_to_iid_map[pop.father.clip(min=0)])
    return _jit_calc_a_inverse(mother_iids, father_iids, n_ind)

@flax_dataclass(frozen=True)
class PredictionResults:
    ids: jnp.ndarray; ebv: jnp.ndarray; pev: Optional[jnp.ndarray] = None; reliability: Optional[jnp.ndarray] = None
    fixed_effects: Optional[jnp.ndarray] = None; h2_used: Optional[float] = None; var_components: Optional[Dict] = None

@partial(jax.jit, static_argnames='n_ind')
def _mme_solver_cg(pheno: jnp.ndarray, train_mask: jnp.ndarray, K_inv: jnp.ndarray, h2: float, n_ind: int) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Solves the Mixed Model Equations using the Conjugate Gradient method.
    """
    alpha = (1.0 - h2) / h2
    y = jnp.nan_to_num(pheno.flatten())
    
    # Cast train_mask to float once for reuse
    train_mask_float = train_mask.astype(jnp.float32)

    def lhs_matvec(solution_vector):
        beta = solution_vector[0]
        u = solution_vector[1:]

        xtx_beta = jnp.sum(train_mask_float) * beta
        # --- FIX #1: Use multiplication instead of boolean indexing ---
        xtz_u = jnp.sum(u * train_mask_float)

        ztx_beta = train_mask_float * beta
        ztz_u_plus_a_inv_u = u * train_mask_float + alpha * (K_inv @ u)

        top_part = xtx_beta + xtz_u
        bottom_part = ztx_beta + ztz_u_plus_a_inv_u
        
        return jnp.concatenate([jnp.array([top_part]), bottom_part])

    # --- FIX #2: Use multiplication instead of boolean indexing ---
    rhs_top = jnp.sum(y * train_mask_float)
    rhs_bottom = y
    rhs = jnp.concatenate([jnp.array([rhs_top]), rhs_bottom])
    
    # Preconditioner
    M = jnp.concatenate([
        jnp.array([jnp.sum(train_mask_float)]),
        train_mask_float + alpha * jnp.diag(K_inv)
    ])
    
    # Add a small epsilon to the diagonal to avoid division by zero for non-phenotyped individuals
    M = M + 1e-6
    
    solutions, _ = cg(lhs_matvec, rhs, M=lambda x: x / M)

    return solutions[0:1], solutions[1:]

# You do not need to change `mme_predict_final_fix_cg`
def mme_predict_final_fix_cg(pop: Population, h2: float, trait_idx: int = 0) -> PredictionResults:
    pheno = pop.pheno[:, trait_idx:trait_idx+1]
    n_ind = pop.nInd
    train_mask = ~jnp.isnan(pheno.flatten())

    if jnp.sum(train_mask) == 0:
        raise ValueError("No individuals with phenotypes.")

    K_inv = calc_a_inverse_matrix_pedigree_jax(pop)
    fixed_effects, all_ebv = _mme_solver_cg(pheno, train_mask, K_inv, h2, n_ind)

    return PredictionResults(ids=pop.id, ebv=all_ebv.reshape(-1, 1), fixed_effects=fixed_effects)

# ==================================================
# --- Simulation & Benchmarking Loop ---
# ==================================================
print("--- Starting Pedigree Scaling Benchmark (Tracking Variance) ---")

# --- Initial Setup ---
key = jax.random.PRNGKey(42)
key, founder_key, sp1_key = jax.random.split(key, 3)
founder_pop, genetic_map = msprime_pop(key=founder_key, n_ind=20, n_loci_per_chr=1000, n_chr=5)
sp = SimParam.from_founder_pop(founder_pop, genetic_map)
sp = add_trait_a(key=sp1_key, founder_pop=founder_pop, sim_param=sp, n_qtl_per_chr=100,
                 mean=jnp.array([0.0]), var=jnp.array([1.0]))
key, pheno_key = jax.random.split(key)
current_pop = set_pheno(key=pheno_key, pop=founder_pop, traits=sp.traits, ploidy=sp.ploidy, h2=jnp.array([0.3]))

# --- Data storage for results ---
n_generations = 15
population_sizes, ablup_timings, all_gens_accuracies, latest_gen_accuracies = [], [], [], []
genetic_variances = [jnp.var(current_pop.bv)] # Store initial variance

# --- Generational Loop ---
for gen in range(1, n_generations + 1):
    print(f"\n--- Generation {gen}/{n_generations} ---")
    pop_size = current_pop.nInd
    
    start_time = time.perf_counter()
    ablup_results = mme_predict_final_fix_cg(pop=current_pop, h2=0.3)
    ablup_results.ebv.block_until_ready()
    end_time = time.perf_counter()
    duration = end_time - start_time
    
    current_pop = current_pop.replace(ebv=ablup_results.ebv)
    
    # ========================= THE FIX =========================
    # --- Corrected Variance Check ---
    # The termination condition should be based on the variance of the EBVs
    # across ALL selection candidates, not the true BV of only the last generation.
    # If there is no variance in EBVs, selection is no longer effective.
    selectable_variance = jnp.var(current_pop.ebv)
    
    print(f"Population Size: {pop_size}, ABLUP Time: {duration:.4f}s")
    print(f"  -> Genetic Variance (Latest Gen BV): {jnp.var(current_pop.bv[current_pop.gen == gen - 1]):.6f}")
    print(f"  -> Selectable Variance (All EBVs):   {selectable_variance:.6f}")

    if selectable_variance < 1e-4:
        print("\n" + "="*60)
        print(f"TERMINATING SIMULATION at Generation {gen}.")
        print("Reason: Selectable genetic variance (variance of EBVs) is exhausted.")
        print("="*60)
        break
    # ==========================================================

    # --- If not terminated, append all data for this successful generation ---
    population_sizes.append(pop_size)
    ablup_timings.append(duration)
    latest_gen_mask = (current_pop.gen == (gen - 1))
    
    accuracy_latest = jnp.corrcoef(current_pop.bv[latest_gen_mask].flatten(),
                                   current_pop.ebv[latest_gen_mask].flatten())[0, 1]
    latest_gen_accuracies.append(accuracy_latest)
    
    print(f"  -> Accuracy (Latest Gen):   {latest_gen_accuracies[-1]:.4f}")
    
    # --- Create next generation (Unchanged from your original script) ---
    n_select_male, n_select_female = 20, 100
    ebvs = current_pop.ebv[:, 0]; sexes = current_pop.sex
    male_iids = current_pop.iid[sexes == 1].astype(jnp.int32)
    female_iids = current_pop.iid[sexes == 0].astype(jnp.int32)
    
    male_ebvs = ebvs[male_iids]; female_ebvs = ebvs[female_iids]
    top_male_iids = male_iids[jnp.argsort(male_ebvs)[-n_select_male:]]
    top_female_iids = female_iids[jnp.argsort(female_ebvs)[-n_select_female:]]

    key, select_key, cross_key, pheno_key = jax.random.split(key, 4)
    n_progeny = simulation_parameters['population_size']
    sires = jax.random.choice(select_key, top_male_iids, shape=(n_progeny,), replace=True)
    dams = jax.random.choice(select_key, top_female_iids, shape=(n_progeny,), replace=True)
    cross_plan = jnp.stack([dams, sires], axis=1)

    progeny_pop = make_cross(key=cross_key, pop=current_pop, cross_plan=cross_plan, sp=sp,
                             next_id_start=current_pop.id.max() + 1)
    progeny_pop = set_pheno(key=pheno_key, pop=progeny_pop, traits=sp.traits,
                            ploidy=sp.ploidy, h2=jnp.array([0.3]))
    current_pop = combine_populations(current_pop, progeny_pop)



Founder Population Ready: Population(nInd=10, nTraits=1, has_ebv=No)
Correlation(BV, Pheno): 0.4804
Expected Correlation (sqrt(h2)): 0.5477
Progeny Population Created: Population(nInd=100, nTraits=1, has_ebv=No)
Combined Population Ready: Population(nInd=110, nTraits=1, has_ebv=No)

Pedigree of the first 5 offspring:
ID | Mother | Father
--------------------
10 | 0    | 6
11 | 0    | 6
12 | 2    | 7
13 | 2    | 7
14 | 2    | 7
Combined population phenotyped: Population(nInd=110, nTraits=1, has_ebv=No)
Mean BV of Founders (G0): 0.0954
Mean BV of Progeny (G1):  0.2618
--- Starting Pedigree Scaling Benchmark (Tracking Variance) ---

--- Generation 1/15 ---
Population Size: 20, ABLUP Time: 0.7674s
  -> Genetic Variance (Latest Gen BV): 1.000000
  -> Selectable Variance (All EBVs):   0.332959
  -> Accuracy (Latest Gen):   0.4388

--- Generation 2/15 ---
Population Size: 120, ABLUP Time: 0.3630s
  -> Genetic Variance (Latest Gen BV): 0.725478
  -> Selectable Variance (All EBVs):   0.300516
 