In [1]:
import jax
import jax.numpy as jnp
from jax.numpy.linalg import inv
from functools import partial
import time
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, Optional
from flax.struct import dataclass as flax_dataclass
from jax.scipy.sparse.linalg import cg

# --- Import all necessary functions from your codebase ---
from chewc.population import Population, msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_bv, set_pheno
from chewc.cross import make_cross

# ==================================================
# --- Simulation Parameters (Unchanged) ---
# ==================================================
simulation_parameters = {
    "n_founder_ind": 50,  # Increased founders for more initial variance
    "n_loci_per_chr": 1000,
    "n_chr": 5,
    "n_qtl_per_chr": 100,
    "trait_mean": jnp.array([0.0]),
    "trait_var": jnp.array([1.0]),
    "n_generations": 20, # Reduced generations to see effect faster
    "population_size": 200, # Size of each new generation
    "n_select_male": 10,
    "n_select_female": 50,
    "key": jax.random.PRNGKey(42),
    "h2": jnp.array([.4]), # Slightly higher h2 for stronger selection response
}

# ==================================================
# --- Setup (Unchanged) ---
# ==================================================
key = simulation_parameters["key"]
key, founder_key, sp1_key = jax.random.split(key, 3)

founder_pop, genetic_map = msprime_pop(
    key=founder_key,
    n_ind=simulation_parameters["n_founder_ind"],
    n_loci_per_chr=simulation_parameters["n_loci_per_chr"],
    n_chr=simulation_parameters["n_chr"]
)

sp = SimParam.from_founder_pop(founder_pop, genetic_map)
sp = add_trait_a(
    key=sp1_key, founder_pop=founder_pop, sim_param=sp,
    n_qtl_per_chr=simulation_parameters["n_qtl_per_chr"],
    mean=simulation_parameters["trait_mean"],
    var=simulation_parameters["trait_var"],
)

# ==================================================
# --- GBLUP-Specific Functions (Unchanged) ---
# ==================================================
@jax.jit
def calc_g_matrix(geno_dosage: jnp.ndarray) -> jnp.ndarray:
    n_ind, n_markers = geno_dosage.shape
    p = jnp.mean(geno_dosage, axis=0) / 2.0
    P = 2 * p
    M = geno_dosage - P
    denominator = 2 * jnp.sum(p * (1 - p))
    G = (M @ M.T) / denominator
    epsilon = 1e-4
    return G + jnp.identity(n_ind) * epsilon

@flax_dataclass(frozen=True)
class PredictionResults:
    ids: jnp.ndarray; ebv: jnp.ndarray; fixed_effects: Optional[jnp.ndarray] = None; h2_used: Optional[float] = None

@partial(jax.jit, static_argnames='n_ind')
def _mme_solver_cg(pheno: jnp.ndarray, train_mask: jnp.ndarray, K_inv: jnp.ndarray, h2: float, n_ind: int) -> tuple[jnp.ndarray, jnp.ndarray]:
    alpha = (1.0 - h2) / h2
    y = jnp.nan_to_num(pheno.flatten())
    train_mask_float = train_mask.astype(jnp.float32)

    def lhs_matvec(solution_vector):
        beta, u = solution_vector[0], solution_vector[1:]
        xtx_beta = jnp.sum(train_mask_float) * beta
        xtz_u = jnp.sum(u * train_mask_float)
        ztx_beta = train_mask_float * beta
        ztz_u_plus_kinv_u = u * train_mask_float + alpha * (K_inv @ u)
        return jnp.concatenate([jnp.array([xtx_beta + xtz_u]), ztx_beta + ztz_u_plus_kinv_u])

    rhs = jnp.concatenate([jnp.array([jnp.sum(y * train_mask_float)]), y])
    M_diag = jnp.concatenate([jnp.array([jnp.sum(train_mask_float)]), train_mask_float + alpha * jnp.diag(K_inv)])
    M_diag = jnp.maximum(M_diag, 1e-6)
    solutions, _ = cg(lhs_matvec, rhs, M=lambda x: x / M_diag)
    return solutions[0:1], solutions[1:]

def mme_predict_gblup(pop: Population, h2: float, trait_idx: int = 0) -> PredictionResults:
    pheno = pop.pheno[:, trait_idx:trait_idx+1]
    n_ind = pop.nInd
    train_mask = ~jnp.isnan(pheno.flatten())
    if jnp.sum(train_mask) == 0: raise ValueError("No individuals with phenotypes.")
    G = calc_g_matrix(pop.dosage)
    K_inv = jnp.linalg.inv(G)
    fixed_effects, all_ebv = _mme_solver_cg(pheno, train_mask, K_inv, h2, n_ind)
    return PredictionResults(ids=pop.id, ebv=all_ebv.reshape(-1, 1), fixed_effects=fixed_effects, h2_used=h2)


# ==================================================
# --- Corrected Simulation & Benchmarking Loop ---
# ==================================================
print("--- Starting GBLUP Benchmark with Non-Overlapping Generations ---")

# --- Initial Setup ---
key, pheno_key = jax.random.split(key)
current_pop = set_pheno(key=pheno_key, pop=founder_pop, traits=sp.traits,
                        ploidy=sp.ploidy, h2=simulation_parameters["h2"])
# This is our initial population for the loop
genetic_variances = [jnp.var(current_pop.bv)]
accuracies = []
mean_bvs = [jnp.mean(current_pop.bv)]

# --- Generational Loop ---
for gen in range(simulation_parameters["n_generations"]):
    print(f"\n--- Generation {gen + 1}/{simulation_parameters['n_generations']} ---")
    
    # 1. --- Prediction ---
    # Predict EBVs for the current generation
    gblup_results = mme_predict_gblup(pop=current_pop, h2=simulation_parameters["h2"][0])
    current_pop = current_pop.replace(ebv=gblup_results.ebv)

    # 2. --- Reporting ---
    accuracy = jnp.corrcoef(current_pop.bv.flatten(), current_pop.ebv.flatten())[0, 1]
    accuracies.append(accuracy)
    print(f"  -> Population Size: {current_pop.nInd}")
    print(f"  -> Prediction Accuracy: {accuracy:.4f}")
    print(f"  -> Genetic Variance:    {genetic_variances[-1]:.4f}")

    # 3. --- Selection ---
    ebvs, sexes = current_pop.ebv[:, 0], current_pop.sex
    male_iids = current_pop.iid[sexes == 1]
    female_iids = current_pop.iid[sexes == 0]
    
    num_males_to_select = min(simulation_parameters['n_select_male'], len(male_iids))
    num_females_to_select = min(simulation_parameters['n_select_female'], len(female_iids))

    top_male_iids = male_iids[jnp.argsort(ebvs[male_iids])[-num_males_to_select:]]
    top_female_iids = female_iids[jnp.argsort(ebvs[female_iids])[-num_females_to_select:]]

    # 4. --- Mating & Creating the NEXT Generation ---
    key, select_key, cross_key, pheno_key = jax.random.split(key, 4)
    sires = jax.random.choice(select_key, top_male_iids, shape=(simulation_parameters['population_size'],), replace=True)
    dams = jax.random.choice(select_key, top_female_iids, shape=(simulation_parameters['population_size'],), replace=True)
    cross_plan = jnp.stack([dams, sires], axis=1)

    next_pop = make_cross(key=cross_key, pop=current_pop, cross_plan=cross_plan, sp=sp,
                          next_id_start=current_pop.id.max() + 1)
    
    # Phenotype the new generation and calculate their true BVs
    next_pop = set_pheno(key=pheno_key, pop=next_pop, traits=sp.traits,
                         ploidy=sp.ploidy, h2=simulation_parameters["h2"])
    
    # 5. --- LOGIC FIX: Replace the old generation with the new one ---
    current_pop = next_pop
    
    # Store metrics for the new population
    genetic_variances.append(jnp.var(current_pop.bv))
    mean_bvs.append(jnp.mean(current_pop.bv))
    
    # 6. --- Termination Check ---
    # Now this check is meaningful because it's on the entire population of selection candidates
    if genetic_variances[-1] < 1e-4:
        print(f"\nTERMINATING: Genetic variance in generation {gen + 1} is exhausted.")
        break



--- Starting GBLUP Benchmark with Non-Overlapping Generations ---

--- Generation 1/20 ---
  -> Population Size: 50
  -> Prediction Accuracy: 0.6715
  -> Genetic Variance:    1.0000

--- Generation 2/20 ---
  -> Population Size: 200
  -> Prediction Accuracy: 0.7608
  -> Genetic Variance:    1.0125

--- Generation 3/20 ---
  -> Population Size: 200
  -> Prediction Accuracy: 0.7308
  -> Genetic Variance:    0.8886

--- Generation 4/20 ---
  -> Population Size: 200
  -> Prediction Accuracy: 0.7706
  -> Genetic Variance:    0.8711

--- Generation 5/20 ---
  -> Population Size: 200
  -> Prediction Accuracy: 0.7069
  -> Genetic Variance:    0.8466

--- Generation 6/20 ---
  -> Population Size: 200
  -> Prediction Accuracy: 0.7186
  -> Genetic Variance:    0.9572

--- Generation 7/20 ---
  -> Population Size: 200
  -> Prediction Accuracy: 0.7862
  -> Genetic Variance:    0.9303

--- Generation 8/20 ---
  -> Population Size: 200
  -> Prediction Accuracy: 0.7809
  -> Genetic Variance:    0.8641