In [1]:
"""
Thompson Sampling vs Truncation Selection Experiment
Cell 1: Founder Population and Trait Architecture Setup
"""

import jax
import jax.numpy as jnp
from functools import partial

# Import chewc modules
from chewc.population import quick_haplo, Population, msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno

# Set random seed for reproducibility.
# This is the root PRNG key: every key used further down in this cell is
# obtained by jax.random.split-ting it, and the final split state is stored
# as EXPERIMENT_KEY at the end of the cell for later cells.
key = jax.random.PRNGKey(42)

# Experimental parameters
#   population_size / n_chromosomes / n_loci_per_chr -> passed to msprime_pop below
#   trait_architectures -> one add_trait_a call per entry, where
#       n_qtl_per_chr -> add_trait_a(n_qtl_per_chr=...)
#       gamma_effects -> add_trait_a(gamma=...)
# NOTE(review): 'heritabilities' is not read anywhere in this cell (set_pheno
# below uses a fixed h2 of 0.5); presumably a later experiment cell sweeps
# over it — confirm.
EXPERIMENTAL_CONFIG = {
    'population_size': 2000,
    'n_chromosomes': 10,
    'n_loci_per_chr': 1000,
    'heritabilities': [0.1, 0.3, 0.5, 0.7],
    'trait_architectures': {
        'oligogenic': {'n_qtl_per_chr': 10, 'gamma_effects': True},
        'polygenic': {'n_qtl_per_chr': 100, 'gamma_effects': False},
        'mixed': {'n_qtl_per_chr': 50, 'gamma_effects': True}
    }
}

print("=== Thompson Sampling Experiment Setup ===")
print(f"Population size: {EXPERIMENTAL_CONFIG['population_size']}")
print(f"Genome: {EXPERIMENTAL_CONFIG['n_chromosomes']} chromosomes × {EXPERIMENTAL_CONFIG['n_loci_per_chr']} loci")
print(f"Total loci: {EXPERIMENTAL_CONFIG['n_chromosomes'] * EXPERIMENTAL_CONFIG['n_loci_per_chr']:,}")

# ===== STEP 1: Create Founder Population =====
print("\n1. Creating founder population...")
key, pop_key = jax.random.split(key)

# msprime_pop returns the founder Population plus a genetic map; the recorded
# output of this cell shows the map with shape (n_chromosomes, n_loci_per_chr).
founder_pop, genetic_map = msprime_pop(
    key=pop_key,
    n_ind=EXPERIMENTAL_CONFIG['population_size'],
    n_chr=EXPERIMENTAL_CONFIG['n_chromosomes'],
    n_loci_per_chr=EXPERIMENTAL_CONFIG['n_loci_per_chr'],
    ploidy=2,
    # NOTE(review): the two options below are commented out, so they are not
    # passed to msprime_pop at all.
    # inbred=False,
    # chr_len_cm=100.0  # 100 cM chromosomes
)

print(f"✓ Created founder population: {founder_pop}")
print(f"✓ Genetic map shape: {genetic_map.shape}")

# ===== STEP 2: Initialize SimParam =====
print("\n2. Initializing simulation parameters...")
# Base SimParam built from the founders; it carries no traits yet (the recorded
# output shows nTraits=0). Traits are attached per architecture in step 3.
sp = SimParam.from_founder_pop(founder_pop, genetic_map)
print(f"✓ SimParam initialized: {sp}")

# ===== STEP 3: Add Trait Architectures =====
print("\n3. Setting up trait architectures...")

# We'll create a dictionary to store populations with different trait architectures.
# Layout: architecture name -> {
#     'population': founders with phenotypes / bv / gv set for that trait,
#     'sim_param':  SimParam returned by add_trait_a for that trait,
#     'n_qtl':      total QTL over all chromosomes,
# }
populations_by_architecture = {}

for arch_name, arch_params in EXPERIMENTAL_CONFIG['trait_architectures'].items():
    print(f"\n   Setting up {arch_name} architecture:")
    print(f"   - QTL per chromosome: {arch_params['n_qtl_per_chr']}")
    print(f"   - Gamma effects: {arch_params['gamma_effects']}")
    
    key, trait_key = jax.random.split(key)
    
    # Add single trait with this architecture.
    # Every iteration starts from the same trait-free base `sp`; the returned
    # SimParam is stored per architecture. Presumably add_trait_a does not
    # mutate `sp` in place — confirm in chewc.trait.
    sp_with_trait = add_trait_a(
        key=trait_key,
        founder_pop=founder_pop,
        sim_param=sp,
        n_qtl_per_chr=arch_params['n_qtl_per_chr'],
        mean=jnp.array([10.0]),  # Single trait with mean 10
        var=jnp.array([4.0]),    # Variance of 4 (std dev = 2)
        gamma=arch_params['gamma_effects']
    )
    
    # Calculate genetic values for verification.
    # set_pheno is always applied to the original founder_pop, so all
    # architectures share identical genotypes and differ only in the trait.
    key, pheno_key = jax.random.split(key)
    pop_with_gv = set_pheno(
        key=pheno_key,
        pop=founder_pop,
        traits=sp_with_trait.traits,
        ploidy=sp_with_trait.ploidy,
        h2=jnp.array([0.5])  # Use moderate heritability for setup
    )
    
    populations_by_architecture[arch_name] = {
        'population': pop_with_gv,
        'sim_param': sp_with_trait,
        'n_qtl': arch_params['n_qtl_per_chr'] * EXPERIMENTAL_CONFIG['n_chromosomes']
    }
    
    # Print summary statistics (column 0 = the single trait added above).
    # NOTE(review): in the recorded output gv is centred on the requested mean
    # (10.00) while bv is not; presumably bv excludes the trait intercept —
    # confirm in chewc before comparing the two on an absolute scale.
    bv_mean = jnp.mean(pop_with_gv.bv[:, 0])
    bv_var = jnp.var(pop_with_gv.bv[:, 0])
    gv_mean = jnp.mean(pop_with_gv.gv[:, 0])
    gv_var = jnp.var(pop_with_gv.gv[:, 0])
    
    print(f"   ✓ Total QTL: {populations_by_architecture[arch_name]['n_qtl']}")
    print(f"   ✓ Breeding value: mean={bv_mean:.2f}, var={bv_var:.2f}")
    print(f"   ✓ Genetic value: mean={gv_mean:.2f}, var={gv_var:.2f}")

print(f"\n=== Setup Complete ===")
print(f"Ready to run experiments with {len(populations_by_architecture)} trait architectures")
print("Next: Implement prediction methods (GBLUP) and selection strategies")

# Store the current state for next cells.
# EXPERIMENT_KEY is the "carry" half of the last split above; it was never
# used for sampling here, so later cells can split it without reusing a key.
FOUNDER_POPULATIONS = populations_by_architecture
EXPERIMENT_KEY = key



=== Thompson Sampling Experiment Setup ===
Population size: 2000
Genome: 10 chromosomes × 1000 loci
Total loci: 10,000

1. Creating founder population...


  out = np.asarray(object, dtype=dtype)


✓ Created founder population: Population(nInd=2000, nTraits=0, has_ebv=No)
✓ Genetic map shape: (10, 1000)

2. Initializing simulation parameters...
✓ SimParam initialized: SimParam(nChr=10, nTraits=0, ploidy=2, sexes='no')

3. Setting up trait architectures...

   Setting up oligogenic architecture:
   - QTL per chromosome: 10
   - Gamma effects: True
   ✓ Total QTL: 100
   ✓ Breeding value: mean=2.90, var=4.00
   ✓ Genetic value: mean=10.00, var=4.00

   Setting up polygenic architecture:
   - QTL per chromosome: 100
   - Gamma effects: False
   ✓ Total QTL: 1000
   ✓ Breeding value: mean=2.38, var=4.00
   ✓ Genetic value: mean=10.00, var=4.00

   Setting up mixed architecture:
   - QTL per chromosome: 50
   - Gamma effects: True
   ✓ Total QTL: 500
   ✓ Breeding value: mean=-0.35, var=4.00
   ✓ Genetic value: mean=10.00, var=4.00

=== Setup Complete ===
Ready to run experiments with 3 trait architectures
Next: Implement prediction methods (GBLUP) and selection strategies


In [None]:
"""
Cell 2: GBLUP with Uncertainty Estimation
Implements the core prediction infrastructure for Thompson Sampling
"""

import jax
import jax.numpy as jnp
from functools import partial
from typing import Tuple, NamedTuple

# Design Decision: Separate Predictor Objects
# We keep predictions separate from Population objects to maintain clean
# separation of concerns and match real breeding program workflows

class PredictionResult(NamedTuple):
    """Per-candidate GBLUP output: point predictions together with their uncertainty.

    All three fields are arrays of shape (nCandidates, nTraits).
    """
    # Predicted breeding values.
    pred_bv: jnp.ndarray
    # Prediction variances; the selection routines treat these as the
    # sampling variance around pred_bv.
    pred_var: jnp.ndarray
    # Prediction accuracies.
    accuracy: jnp.ndarray

class GBLUPPredictor:
    """
    A trained GBLUP model that can make predictions with uncertainty.

    Construction stores the training marker dosages, phenotypes and
    heritabilities and pre-computes the training genomic relationship matrix.
    ``predict`` then only needs to build the training x candidate
    cross-relationship matrix and call ``gblup_with_uncertainty``.
    """

    def __init__(self,
                 training_dosage: jnp.ndarray,
                 training_phenotypes: jnp.ndarray,
                 h2: jnp.ndarray):
        """
        Args:
            training_dosage: Marker dosage matrix from the training population,
                shape (nTrain, nLoci)
            training_phenotypes: Phenotype matrix from the training population,
                shape (nTrain, nTraits)
            h2: Heritability for each trait, shape (nTraits,)
        """
        self.training_dosage = training_dosage
        self.training_phenotypes = training_phenotypes
        self.h2 = h2

        # G_train depends only on the training data, so build it once here
        # instead of on every predict() call.
        self.G_train = compute_genomic_relationship_matrix(training_dosage)

    def predict(self, candidate_dosage: jnp.ndarray) -> PredictionResult:
        """
        Make predictions for candidate individuals.

        Args:
            candidate_dosage: Marker dosage matrix of the candidates, shape
                (nCandidates, nLoci), using the same loci as the training data.

        Returns:
            PredictionResult whose fields have shape (nCandidates, nTraits).

        Raises:
            ValueError: if candidate_dosage is not 2-D or its number of loci
                differs from the training dosage matrix.
        """
        G_cross = self._compute_cross_relationships(candidate_dosage)

        return gblup_with_uncertainty(
            G_train=self.G_train,
            G_cross=G_cross,
            y=self.training_phenotypes,
            h2=self.h2
        )

    def _compute_cross_relationships(self, candidate_dosage: jnp.ndarray) -> jnp.ndarray:
        """
        Cross-relationships between training (rows) and candidate (columns)
        individuals, centred and scaled with training allele frequencies.

        Delegates to the module-level JIT-compiled ``compute_cross_relationships``
        so the cross-relationship formula is not duplicated inside the class.

        Raises:
            ValueError: on a rank or locus-count mismatch with the training data.
        """
        n_loci = self.training_dosage.shape[1]
        if candidate_dosage.ndim != 2 or candidate_dosage.shape[1] != n_loci:
            raise ValueError(
                "candidate_dosage must have shape (nCandidates, "
                f"{n_loci}) to match the training loci, got {candidate_dosage.shape}"
            )
        return compute_cross_relationships(self.training_dosage, candidate_dosage)

@partial(jax.jit, static_argnames=())
def compute_cross_relationships(training_dosage: jnp.ndarray,
                                candidate_dosage: jnp.ndarray) -> jnp.ndarray:
    """
    VanRaden-style cross-relationships between a training and a candidate set.

    Both dosage matrices are centred on the *training* allele frequencies and
    divided by the training denominator 2 * sum(p * (1 - p)), so the result is
    on the same scale as the training G matrix.

    Args:
        training_dosage: (nTrain, nLoci) marker dosages of the training set
        candidate_dosage: (nCandidates, nLoci) marker dosages of the candidates

    Returns:
        (nTrain, nCandidates) array; entry [i, j] relates training individual i
        to candidate j.
    """
    p_train = training_dosage.mean(axis=0) * 0.5
    expected_dosage = 2.0 * p_train
    scale = 2.0 * jnp.sum(p_train * (1.0 - p_train))

    train_dev = training_dosage - expected_dosage
    cand_dev = candidate_dosage - expected_dosage
    return (train_dev @ cand_dev.T) / scale

@partial(jax.jit, static_argnames=())
def compute_genomic_relationship_matrix(dosage_matrix: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the genomic relationship matrix (G) using the VanRaden method.
    
    G = (M - 2p)(M - 2p)' / (2 * sum(p_i * (1 - p_i)))
    
    Where M is the marker matrix and p is the vector of allele frequencies
    estimated from M itself.
    
    Args:
        dosage_matrix: Marker dosage matrix of shape (nInd, nLoci)
        
    Returns:
        Genomic relationship matrix of shape (nInd, nInd)

    Raises:
        ValueError: if dosage_matrix is not 2-D (raised at trace time).
    """
    # Rank is static under jit, so this check runs once per traced shape.
    if dosage_matrix.ndim != 2:
        raise ValueError(
            f"dosage_matrix must be 2-D (nInd, nLoci), got shape {dosage_matrix.shape}"
        )

    # Calculate allele frequencies (p)
    allele_freqs = jnp.mean(dosage_matrix, axis=0) / 2.0
    
    # Center the marker matrix: M - 2p
    centered_markers = dosage_matrix - 2.0 * allele_freqs
    
    # Calculate the denominator: 2 * sum(p_i * (1 - p_i))
    denominator = 2.0 * jnp.sum(allele_freqs * (1.0 - allele_freqs))
    
    # Compute G = (M - 2p)(M - 2p)' / denominator
    return jnp.dot(centered_markers, centered_markers.T) / denominator

@partial(jax.jit, static_argnames=())
def gblup_with_uncertainty(
    G_train: jnp.ndarray,           # Training population genomic relationships (nTrain, nTrain)
    G_cross: jnp.ndarray,           # Cross-relationships (nTrain, nCandidates)
    y: jnp.ndarray,                 # Training phenotypes (nTrain, nTraits)
    h2: jnp.ndarray,                # Heritability (nTraits,)
    G_cand_diag: "jnp.ndarray | None" = None  # Candidate self-relationships (nCandidates,), optional
) -> PredictionResult:
    """
    GBLUP predictions with Prediction Error Variance (PEV) for uncertainty estimation.

    This is the core function enabling Thompson Sampling. For every candidate
    it returns the posterior mean of its breeding value, the posterior
    (prediction-error) variance around that mean, and the accuracy.

    Model, per trait, with the overall mean estimated by the sample mean ybar:
        y = 1*mu + u + e,  u ~ N(0, s2_g * G),  e ~ N(0, s2_e * I)
        lambda = s2_e / s2_g = (1 - h2) / h2,   s2_g = h2 * var(y)

    Mathematical Details (C = G_train + lambda*I, g_j = column j of G_cross):
    1. alpha = C^-1 (y - ybar)          solved with a linear solve; C is never inverted
    2. u_hat_j = g_j' alpha             predicted breeding value
    3. q_j = g_j' C^-1 g_j              variance of u_hat_j on the G scale
    4. PEV_j = s2_g * (G_jj - q_j)      returned as pred_var
    5. accuracy_j = sqrt(q_j / G_jj)    square root of reliability

    The PEV treats mu as known, so it slightly understates uncertainty when
    the training set is small.

    Args:
        G_train: Genomic relationship matrix for training individuals
        G_cross: Cross-relationships between training and candidate individuals
        y: Phenotypic observations for training individuals
        h2: Heritability for each trait
        G_cand_diag: Diagonal of the candidates' own genomic relationship
            matrix (G_jj). If omitted, G_jj = 1 is used. That is the expected
            VanRaden self-relationship of a non-inbred individual. Pass the
            real diagonal for inbred or selected populations.

    Returns:
        PredictionResult whose fields all have shape (nCandidates, nTraits)
    """
    n_train, n_traits = y.shape
    n_candidates = G_cross.shape[1]

    # Small regularization for numerical stability, also used as a floor
    # for variances and denominators.
    epsilon = 1e-6

    if G_cand_diag is None:
        g_diag = jnp.ones((n_candidates,), dtype=G_cross.dtype)
    else:
        g_diag = jnp.asarray(G_cand_diag, dtype=G_cross.dtype)

    identity = jnp.eye(n_train, dtype=G_train.dtype)

    # Initialize output arrays (this also yields correctly-shaped empty
    # results when n_traits == 0).
    pred_bv = jnp.zeros((n_candidates, n_traits))
    pred_var = jnp.zeros((n_candidates, n_traits))
    accuracy = jnp.zeros((n_candidates, n_traits))

    # Process each trait separately (allows different heritabilities)
    for trait_idx in range(n_traits):
        current_y = y[:, trait_idx]
        current_h2 = h2[trait_idx]

        # Mixed model coefficient matrix C = G + lambda*I
        lambda_reg = (1.0 - current_h2) / (current_h2 + epsilon)
        C = G_train + (lambda_reg + epsilon) * identity

        # Solve C @ [alpha, W] = [y - ybar, G_cross] in a single call.
        # W = C^-1 G_cross is needed for the PEV below.
        y_centered = current_y - jnp.mean(current_y)
        rhs = jnp.concatenate([y_centered[:, None], G_cross], axis=1)
        solution = jnp.linalg.solve(C, rhs)
        alpha = solution[:, 0]
        C_inv_G_cross = solution[:, 1:]

        # Predicted breeding values for candidates
        pred_bv_trait = G_cross.T @ alpha

        # q_j = diag(G_cross' C^-1 G_cross). Only the diagonal is needed, so
        # avoid materialising the full (nCandidates, nCandidates) product.
        q = jnp.sum(G_cross * C_inv_G_cross, axis=0)

        # Genetic variance is h2 times the phenotypic variance
        var_g = current_h2 * jnp.var(current_y)

        # Prediction error variance = posterior variance of u_j given y.
        # The floor guards against round-off and against q_j > G_jj when the
        # default G_jj = 1 underestimates a candidate's self-relationship.
        pred_var_trait = jnp.maximum(var_g * (g_diag - q), epsilon)

        # Accuracy = sqrt(reliability), with reliability = Var(u_hat)/Var(u) in [0, 1]
        reliability = jnp.clip(q / jnp.maximum(g_diag, epsilon), 0.0, 1.0)
        accuracy_trait = jnp.sqrt(reliability)

        # Store results
        pred_bv = pred_bv.at[:, trait_idx].set(pred_bv_trait)
        pred_var = pred_var.at[:, trait_idx].set(pred_var_trait)
        accuracy = accuracy.at[:, trait_idx].set(accuracy_trait)

    return PredictionResult(
        pred_bv=pred_bv,
        pred_var=pred_var,
        accuracy=accuracy
    )

@partial(jax.jit, static_argnames=())
def self_prediction_gblup(
    pop: Population,
    traits_config,  # TraitCollection from SimParam
    ploidy: int,
    h2: jnp.ndarray
) -> PredictionResult:
    """
    GBLUP in which every individual is both a training record and a candidate.

    Handy as a sanity check: it shows the predictions and uncertainty the model
    assigns to the very population it was fitted on.

    Args:
        pop: Population whose dosages and phenotypes are used
        traits_config: Trait configuration from SimParam (not read here)
        ploidy: Ploidy level (not read here)
        h2: Heritability array

    Returns:
        PredictionResult with self-predictions and uncertainties
    """
    relationship = compute_genomic_relationship_matrix(pop.dosage)

    # Training and candidate sets coincide, so the cross block is G itself.
    return gblup_with_uncertainty(relationship, relationship, pop.pheno, h2)

# ===== DEMONSTRATION: Train and Test GBLUP Predictors =====
print("\n=== GBLUP Training and Validation ===")
print("Training GBLUP predictors on founder populations...")

# Store trained predictors for each architecture (consumed by the selection cell).
trained_predictors = {}

# Test GBLUP on each trait architecture
for arch_name, arch_data in FOUNDER_POPULATIONS.items():
    print(f"\n--- {arch_name.upper()} Architecture ---")
    
    pop = arch_data['population']
    # NOTE(review): `sp` is not used in this loop, but the assignment rebinds the
    # global `sp` from Cell 1 (the trait-free base SimParam) to the last
    # architecture's SimParam. Check that no later cell relies on either meaning.
    sp = arch_data['sim_param']
    
    # Train predictor on founder population.
    # 0.5 matches the h2 that Cell 1 used when simulating these phenotypes.
    test_h2 = jnp.array([0.5])
    
    predictor = GBLUPPredictor(
        training_dosage=pop.dosage,
        training_phenotypes=pop.pheno,
        h2=test_h2
    )
    
    # Test self-prediction performance.
    # Training and candidate sets are the same individuals, so the correlation
    # below is an in-sample fit, not an out-of-sample validation accuracy.
    prediction_result = predictor.predict(pop.dosage)
    
    # Calculate prediction statistics (column 0 = the single trait)
    true_bv = pop.bv[:, 0]
    pred_bv = prediction_result.pred_bv[:, 0]
    pred_var = prediction_result.pred_var[:, 0]
    
    # Correlation between true and predicted breeding values
    correlation = jnp.corrcoef(true_bv, pred_bv)[0, 1]
    
    # Mean prediction standard deviation (mean of sqrt(pred_var), not of pred_var)
    mean_uncertainty = jnp.mean(jnp.sqrt(pred_var))
    
    print(f"Prediction accuracy (correlation): {correlation:.3f}")
    print(f"Mean prediction uncertainty (std): {mean_uncertainty:.3f}")
    print(f"Prediction variance range: [{jnp.min(pred_var):.3f}, {jnp.max(pred_var):.3f}]")
    
    # Check whether high-uncertainty individuals have lower prediction accuracy.
    # This is only printed, not asserted.
    # Split into high/low uncertainty groups at the median; ties with the median
    # fall into the "low" group because the comparison is strict.
    # NOTE(review): if pred_var is constant, the high group is empty and its
    # correlation is NaN.
    uncertainty_threshold = jnp.median(pred_var)
    high_uncertainty_mask = pred_var > uncertainty_threshold
    
    high_unc_corr = jnp.corrcoef(
        true_bv[high_uncertainty_mask], 
        pred_bv[high_uncertainty_mask]
    )[0, 1]
    
    low_unc_corr = jnp.corrcoef(
        true_bv[~high_uncertainty_mask], 
        pred_bv[~high_uncertainty_mask]
    )[0, 1]
    
    print(f"High uncertainty individuals: r = {high_unc_corr:.3f}")
    print(f"Low uncertainty individuals: r = {low_unc_corr:.3f}")
    print(f"Uncertainty captures prediction quality: {low_unc_corr > high_unc_corr}")
    
    # Store the trained predictor for use in selection experiments
    trained_predictors[arch_name] = predictor

# NOTE(review): the checklist below is printed unconditionally; it does not
# reflect the outcome of the comparisons above.
print("\n=== GBLUP Implementation Complete ===")
print("✓ Genomic relationship matrix computation")
print("✓ Mixed model equation solving")  
print("✓ Prediction Error Variance (PEV) calculation")
print("✓ Uncertainty estimation validated")
print("✓ Trained predictors stored for selection experiments")
print("\nNext: Implement Thompson Sampling selection strategy")

In [None]:
"""
Cell 3: Thompson Sampling Selection Strategy
Implements the core selection methods for comparing strategies
"""

import jax
import jax.numpy as jnp
from functools import partial
from typing import Dict, Tuple
import numpy as np

# ===== SELECTION STRATEGIES =====

@partial(jax.jit, static_argnames=('n_parents',))
def truncation_selection(
    pred_bv: jnp.ndarray,       # Predicted breeding values (nCandidates, nTraits) or (nCandidates,)
    n_parents: int              # Number of parents to select
) -> jnp.ndarray:
    """
    Traditional truncation selection - always picks top predicted individuals.
    
    This is the baseline method used in most breeding programs.
    Ignores prediction uncertainty completely.
    
    Args:
        pred_bv: Predicted breeding values for all candidates; for a 2-D
            input only the first trait (column 0) is used
        n_parents: Number of parents to select (static for JIT)
        
    Returns:
        Array of selected candidate indices, ordered from lowest to highest
        predicted value (the best candidate is last). It has
        min(n_parents, nCandidates) entries, so n_parents == 0 gives an
        empty array.

    Raises:
        ValueError: if n_parents is negative.
    """
    if n_parents < 0:
        raise ValueError(f"n_parents must be non-negative, got {n_parents}")

    # For single trait, select based on first (and only) trait
    trait_bv = pred_bv[:, 0] if pred_bv.ndim > 1 else pred_bv

    # argsort places NaN after every finite value, which would make NaN
    # predictions the first ones picked; rank them lowest instead.
    ranking_score = jnp.where(jnp.isnan(trait_bv), -jnp.inf, trait_bv)

    # Slice from an explicit start index. The former `[-n_parents:]` turned
    # into `[-0:]` (i.e. every candidate) when n_parents == 0.
    n_candidates = ranking_score.shape[0]
    start = max(n_candidates - n_parents, 0)
    return jnp.argsort(ranking_score)[start:]

@partial(jax.jit, static_argnames=('n_parents',))
def thompson_sampling_selection(
    key: jax.random.PRNGKey,
    pred_bv: jnp.ndarray,       # Predicted breeding values (nCandidates, nTraits) or (nCandidates,)
    pred_var: jnp.ndarray,      # Prediction variances, same shape as pred_bv
    n_parents: int              # Number of parents to select
) -> jnp.ndarray:
    """
    Thompson Sampling selection - balances exploitation vs exploration.
    
    Instead of always picking the "best" predicted individuals, this method:
    1. Samples once from each candidate's posterior distribution N(mu, sigma^2)
    2. Selects the top n_parents from this single sample
    3. Naturally balances high predicted merit vs high uncertainty
    
    High uncertainty individuals get a "chance" to be selected if they
    sample high, enabling exploration of promising but uncertain genetics.
    
    Args:
        key: JAX random key for sampling
        pred_bv: Predicted breeding values for all candidates (column 0 is
            used for 2-D input)
        pred_var: Prediction variances (uncertainty) for all candidates;
            negative values are treated as 0
        n_parents: Number of parents to select (static for JIT)
        
    Returns:
        Array of selected candidate indices, ordered from lowest to highest
        sampled value. It has min(n_parents, nCandidates) entries, so
        n_parents == 0 gives an empty array.

    Raises:
        ValueError: if n_parents is negative.
    """
    if n_parents < 0:
        raise ValueError(f"n_parents must be non-negative, got {n_parents}")

    # For single trait, use first trait
    trait_bv = pred_bv[:, 0] if pred_bv.ndim > 1 else pred_bv
    trait_var = pred_var[:, 0] if pred_var.ndim > 1 else pred_var

    # Sample once from each candidate's posterior distribution.
    # This is the heart of Thompson Sampling. Clamping the variance at 0
    # keeps a slightly negative (round-off) variance from producing a NaN
    # sample.
    noise = jax.random.normal(key, trait_bv.shape)
    sampled_bv = trait_bv + noise * jnp.sqrt(jnp.maximum(trait_var, 0.0))

    # NaN samples would sort last and be selected first; rank them lowest.
    sampled_bv = jnp.where(jnp.isnan(sampled_bv), -jnp.inf, sampled_bv)

    # Select the top n_parents from this single sample, slicing from an
    # explicit start so that n_parents == 0 selects nobody (not everybody).
    n_candidates = sampled_bv.shape[0]
    start = max(n_candidates - n_parents, 0)
    return jnp.argsort(sampled_bv)[start:]

@partial(jax.jit, static_argnames=())
def calculate_selection_metrics_jax(
    true_bv: jnp.ndarray,
    selected_indices: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Summary statistics of a selection decision, computed on true breeding values.

    Args:
        true_bv: True breeding values of all candidates (1-D)
        selected_indices: Indices of the selected candidates

    Returns:
        (selection_response, genetic_diversity, selection_differential, population_mean):
        mean of the selected values, variance of the selected values,
        that mean minus the population mean, and the population mean.
    """
    chosen = true_bv[selected_indices]
    chosen_mean = chosen.mean()
    overall_mean = true_bv.mean()
    return chosen_mean, chosen.var(), chosen_mean - overall_mean, overall_mean

def calculate_selection_metrics(
    true_bv: jnp.ndarray,
    selected_indices: jnp.ndarray,
    selection_intensity: float
) -> Dict[str, float]:
    """
    Dictionary view of ``calculate_selection_metrics_jax`` with plain Python floats.

    Args:
        true_bv: True breeding values of all candidates
        selected_indices: Indices of the selected candidates
        selection_intensity: Accepted for call-site symmetry; it does not
            enter any of the returned metrics.

    Returns:
        Dict with keys 'selection_response', 'genetic_diversity',
        'selection_differential' and 'population_mean', in that order.
    """
    metric_names = (
        'selection_response',
        'genetic_diversity',
        'selection_differential',
        'population_mean',
    )
    metric_values = calculate_selection_metrics_jax(true_bv, selected_indices)
    return {name: float(value) for name, value in zip(metric_names, metric_values)}

# ===== SELECTION COMPARISON FUNCTION =====

def compare_selection_strategies(
    key: jax.random.PRNGKey,
    pop: Population,
    predictor: GBLUPPredictor,
    n_parents: int = 100,
    n_replicates: int = 10
) -> Dict[str, Dict]:
    """
    Compare Thompson Sampling vs Truncation Selection on a single population.

    Truncation is deterministic given the predictions, so it is run once.
    Thompson Sampling is stochastic, so it is run ``n_replicates`` times with
    independent keys, and its metrics are averaged.

    Args:
        key: JAX random key (split once per Thompson replicate)
        pop: Population to select from; ``pop.bv[:, 0]`` is used as the true
            breeding value of the single trait
        predictor: Trained GBLUP predictor
        n_parents: Number of parents to select (must be >= 1)
        n_replicates: Number of Thompson Sampling replicates to run (must be >= 1)

    Returns:
        Dictionary with:
          'truncation' / 'thompson_sampling': metric dicts (see
              ``calculate_selection_metrics``; Thompson values are replicate means)
          'overlap_percentage': mean % of truncation picks also chosen by Thompson
          'uncertainty_difference': mean pred_var of Thompson picks minus that
              of truncation picks
        All numeric values are Python floats.

    Raises:
        ValueError: if n_parents < 1 or n_replicates < 1. Averaging over zero
            replicates, or normalising overlap by zero parents, is undefined.
    """
    if n_parents < 1:
        raise ValueError(f"n_parents must be >= 1, got {n_parents}")
    if n_replicates < 1:
        raise ValueError(f"n_replicates must be >= 1, got {n_replicates}")

    # Get predictions and uncertainties
    predictions = predictor.predict(pop.dosage)
    true_bv = pop.bv[:, 0]  # Single trait
    selection_intensity = n_parents / pop.nInd

    print(f"\nComparing selection strategies (selecting {n_parents}/{pop.nInd} individuals)")
    print(f"Selection intensity: {100 * n_parents / pop.nInd:.1f}%")

    # ===== TRUNCATION SELECTION (DETERMINISTIC) =====
    trunc_selected = truncation_selection(predictions.pred_bv, n_parents)
    trunc_metrics = calculate_selection_metrics(true_bv, trunc_selected, selection_intensity)

    print("\n--- Truncation Selection ---")
    print(f"Selection response: {trunc_metrics['selection_response']:.3f}")
    print(f"Genetic diversity: {trunc_metrics['genetic_diversity']:.3f}")
    print(f"Selection differential: {trunc_metrics['selection_differential']:.3f}")

    # ===== THOMPSON SAMPLING (STOCHASTIC) =====
    thompson_metrics_list = []
    thompson_selected_sets = []

    for _ in range(n_replicates):
        # A fresh sub-key per replicate; `key` only carries the stream forward.
        key, rep_key = jax.random.split(key)

        thompson_selected = thompson_sampling_selection(
            rep_key, predictions.pred_bv, predictions.pred_var, n_parents
        )
        thompson_metrics_list.append(
            calculate_selection_metrics(true_bv, thompson_selected, selection_intensity)
        )
        thompson_selected_sets.append(thompson_selected)

    # Average across replicates (`metric` avoids shadowing the `key` argument)
    avg_thompson_metrics = {
        metric: float(np.mean([rep_metrics[metric] for rep_metrics in thompson_metrics_list]))
        for metric in thompson_metrics_list[0]
    }

    print(f"\n--- Thompson Sampling (avg of {n_replicates} replicates) ---")
    print(f"Selection response: {avg_thompson_metrics['selection_response']:.3f}")
    print(f"Genetic diversity: {avg_thompson_metrics['genetic_diversity']:.3f}")
    print(f"Selection differential: {avg_thompson_metrics['selection_differential']:.3f}")

    # ===== ANALYZE SELECTION DIFFERENCES =====

    # How many individuals overlap between strategies?
    trunc_selected_np = np.asarray(trunc_selected)
    overlap_counts = [
        len(np.intersect1d(trunc_selected_np, np.asarray(selected)))
        for selected in thompson_selected_sets
    ]
    avg_overlap = float(np.mean(overlap_counts))

    print("\n--- Strategy Comparison ---")
    print(f"Average overlap: {avg_overlap:.1f}/{n_parents} individuals ({100*avg_overlap/n_parents:.1f}%)")
    print(f"Thompson Sampling explores {n_parents - avg_overlap:.1f} different individuals on average")

    # Analyze uncertainty of selected individuals
    trunc_uncertainty = float(jnp.mean(predictions.pred_var[trunc_selected, 0]))
    avg_thompson_uncertainty = float(np.mean([
        float(jnp.mean(predictions.pred_var[selected, 0]))
        for selected in thompson_selected_sets
    ]))

    print("\nUncertainty of selected individuals:")
    print(f"Truncation: {trunc_uncertainty:.4f}")
    print(f"Thompson Sampling: {avg_thompson_uncertainty:.4f}")
    print(f"Thompson selects {'higher' if avg_thompson_uncertainty > trunc_uncertainty else 'lower'} uncertainty individuals")

    return {
        'truncation': trunc_metrics,
        'thompson_sampling': avg_thompson_metrics,
        'overlap_percentage': 100 * avg_overlap / n_parents,
        'uncertainty_difference': avg_thompson_uncertainty - trunc_uncertainty
    }

# ===== DEMONSTRATION: Selection Strategy Comparison =====
print("\n=== Selection Strategy Comparison ===")
print("Testing Thompson Sampling vs Truncation Selection...")

# Restart from the key handed off by Cell 1, so re-running this cell
# reproduces the same draws regardless of how often it has been run.
key = EXPERIMENT_KEY
selection_results = {}

# Test on each trait architecture
for arch_name, arch_data in FOUNDER_POPULATIONS.items():
    print(f"\n{'='*50}")
    print(f"TESTING: {arch_name.upper()} Architecture")
    print(f"{'='*50}")
    
    # NOTE(review): selection is evaluated on the same founder population the
    # predictor was trained on (see the GBLUP cell), i.e. in-sample.
    pop = arch_data['population']
    
    # Use the trained predictor from previous cell
    predictor = trained_predictors[arch_name]
    
    key, selection_key = jax.random.split(key)
    
    # Compare strategies
    comparison_results = compare_selection_strategies(
        key=selection_key,
        pop=pop,
        predictor=predictor,
        n_parents=100,  # Select top 5% as parents (100 of the 2000 founders)
        n_replicates=20  # More replicates for robust statistics
    )
    
    selection_results[arch_name] = comparison_results

print(f"\n{'='*60}")
print("SUMMARY: Thompson Sampling vs Truncation Selection")
print(f"{'='*60}")

for arch_name, results in selection_results.items():
    trunc = results['truncation']
    thompson = results['thompson_sampling']
    
    print(f"\n{arch_name.upper()}:")
    print(f"  Selection Response - Truncation: {trunc['selection_response']:.3f}")
    print(f"  Selection Response - Thompson:   {thompson['selection_response']:.3f}")
    print(f"  Diversity - Truncation: {trunc['genetic_diversity']:.3f}")
    print(f"  Diversity - Thompson:   {thompson['genetic_diversity']:.3f}")
    print(f"  Overlap: {results['overlap_percentage']:.1f}%")
    # Positive = Thompson picked individuals with higher mean pred_var than truncation did.
    print(f"  Uncertainty exploration: {'+' if results['uncertainty_difference'] > 0 else ''}{results['uncertainty_difference']:.4f}")

# NOTE(review): these closing lines are printed unconditionally; they are not
# derived from the results above.
print(f"\n✓ Selection strategies implemented and tested")
print("✓ Thompson Sampling shows exploration of uncertain individuals")
print("Next: Multi-generation breeding program simulation")

In [None]:
"""
Cell 4: Multi-generational Breeding Program Simulation
Complete factorial experiment comparing Thompson Sampling vs Truncation Selection
"""

import jax
import jax.numpy as jnp
from functools import partial
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd

# Import chewc crossing functionality
from chewc.cross import make_cross
from chewc.phenotype import set_pheno

def select_individuals_from_population(
    pop: Population,
    selected_indices: jnp.ndarray
) -> Population:
    """
    Return a new Population made of the rows of ``pop`` picked by ``selected_indices``.

    Every per-individual array (genotypes, ids, parents, sex, phenotypes,
    fixed effects and, when present, bv/gv/ebv) is gathered in the order given
    by ``selected_indices``. The internal IDs are then renumbered 0..k-1, so
    the subset can be addressed contiguously, e.g. by a crossing plan.

    Args:
        pop: Source population
        selected_indices: Row positions to keep. They are used directly as
            array indices, which equals selecting by iid when pop's iids are
            0..n-1.

    Returns:
        Population containing only the selected individuals
    """
    n_kept = len(selected_indices)

    def take_rows(values):
        # bv / gv / ebv may be absent (None) on a Population; keep them absent.
        return None if values is None else values[selected_indices]

    # An empty phenotype matrix is rebuilt with the right number of rows
    # rather than being indexed.
    if pop.pheno.size > 0:
        kept_pheno = pop.pheno[selected_indices]
    else:
        kept_pheno = jnp.zeros((n_kept, 0))

    return Population(
        geno=pop.geno[selected_indices],
        id=pop.id[selected_indices],
        iid=jnp.arange(n_kept),  # contiguous, 0-indexed
        mother=pop.mother[selected_indices],
        father=pop.father[selected_indices],
        sex=pop.sex[selected_indices],
        pheno=kept_pheno,
        fixEff=pop.fixEff[selected_indices],
        bv=take_rows(pop.bv),
        gv=take_rows(pop.gv),
        ebv=take_rows(pop.ebv),
    )

def create_random_crosses(
    key: jax.random.PRNGKey,
    parent_pop: "Population",
    n_offspring: int
) -> jnp.ndarray:
    """
    Create random crossing plan from parent population.

    Mothers and fathers are drawn independently and with replacement.
    - If the parents carry more than one sex code and both code 1 (used as
      female) and code 0 (used as male) are present, mothers come from sex == 1
      and fathers from sex == 0.
    - Otherwise any parent can fill either role, so selfing (the same iid in
      both columns) can occur.

    Args:
        key: JAX random key
        parent_pop: Population of selected parents (reads ``iid`` and ``sex``)
        n_offspring: Number of offspring to produce

    Returns:
        Cross plan array of shape (n_offspring, 2): column 0 holds mother
        iids, column 1 father iids.

    Raises:
        ValueError: if parent_pop contains no individuals.
    """
    parent_iids = parent_pop.iid
    sexes = parent_pop.sex
    if parent_iids.shape[0] == 0:
        raise ValueError("create_random_crosses: parent_pop contains no individuals")

    # Derive both keys up front so the mother and father draws never share
    # (or re-split) an already-consumed key.
    mother_key, father_key = jax.random.split(key)

    # Default: random pairing with no sex structure.
    mother_pool = parent_iids
    father_pool = parent_iids

    if not bool(jnp.all(sexes == sexes[0])):
        # NOTE: sex code 0 = male, 1 = female is this function's convention;
        # confirm it matches chewc's encoding.
        males = parent_iids[sexes == 0]
        females = parent_iids[sexes == 1]
        # Fall back to random pairing if either sex is missing.
        if males.shape[0] > 0 and females.shape[0] > 0:
            mother_pool, father_pool = females, males

    mothers = jax.random.choice(mother_key, mother_pool, (n_offspring,), replace=True)
    fathers = jax.random.choice(father_key, father_pool, (n_offspring,), replace=True)

    return jnp.stack([mothers, fathers], axis=1)

def run_single_generation(
    key: jax.random.PRNGKey,
    pop: Population,
    sp: SimParam,
    predictor: GBLUPPredictor,
    selection_strategy: str,
    h2: float,
    n_parents: int = 100,
    n_offspring: int = 1000
) -> Tuple[Population, Dict]:
    """
    Run a single generation of the breeding program.

    Steps:
    1. Phenotype ``pop``.
    2. Predict breeding values from dosages with ``predictor``.
    3. Select ``n_parents`` parents according to ``selection_strategy``.
    4. Build a random crossing plan.
    5. Generate offspring with ``make_cross``.
    6. Phenotype the offspring.
    7. Collect selection and offspring metrics.

    Args:
        key: JAX random key
        pop: Current population
        sp: Simulation parameters
        predictor: Trained GBLUP predictor (from founder generation)
        selection_strategy: 'truncation' or 'thompson'
        h2: Heritability for phenotyping
        n_parents: Number of parents to select
        n_offspring: Number of offspring to produce

    Returns:
        Tuple of (offspring_population, generation_metrics). The offspring
        population has phenotypes set. The metrics dict is the output of
        ``calculate_selection_metrics``, extended with the mean and variance
        of column 0 of ``bv`` for the offspring and for the selected parents.

    Raises:
        ValueError: If ``selection_strategy`` is not 'truncation' or 'thompson'.
    """
    # Fail fast, before any phenotyping or prediction work is done.
    if selection_strategy not in ('truncation', 'thompson'):
        raise ValueError(f"Unknown selection strategy: {selection_strategy}")

    # One independent subkey per stochastic step. The crossing plan and the
    # meiosis inside make_cross need *different* keys; sharing one would
    # correlate parent pairing with recombination/segregation.
    pheno_key, select_key, plan_key, meiosis_key, offspring_pheno_key = jax.random.split(key, 5)

    # 1. Set phenotypes for current population
    pop_with_pheno = set_pheno(
        key=pheno_key,
        pop=pop,
        traits=sp.traits,
        ploidy=sp.ploidy,
        h2=jnp.array([h2])
    )

    # 2. Make predictions using founder-trained predictor
    predictions = predictor.predict(pop_with_pheno.dosage)

    # 3. Select parents based on strategy
    if selection_strategy == 'truncation':
        selected_indices = truncation_selection(predictions.pred_bv, n_parents)
    else:  # 'thompson' (validated above)
        selected_indices = thompson_sampling_selection(
            select_key, predictions.pred_bv, predictions.pred_var, n_parents
        )

    # 4. Create parent population
    parent_pop = select_individuals_from_population(pop_with_pheno, selected_indices)

    # 5. Create crosses
    cross_plan = create_random_crosses(plan_key, parent_pop, n_offspring)

    # 6. Generate offspring with IDs continuing after the current population's
    next_id_start = jnp.max(pop.id) + 1
    offspring_pop = make_cross(
        key=meiosis_key,
        pop=parent_pop,
        cross_plan=cross_plan,
        sp=sp,
        next_id_start=next_id_start
    )

    # 7. Phenotype offspring; the bv values it exposes are used for metrics
    offspring_with_gv = set_pheno(
        key=offspring_pheno_key,
        pop=offspring_pop,
        traits=sp.traits,
        ploidy=sp.ploidy,
        h2=jnp.array([h2])
    )

    # 8. Calculate generation metrics (first trait only, column 0)
    metrics = calculate_selection_metrics(
        pop_with_pheno.bv[:, 0], selected_indices, n_parents / pop.nInd
    )

    # Add offspring metrics
    metrics.update({
        'offspring_mean_bv': float(jnp.mean(offspring_with_gv.bv[:, 0])),
        'offspring_var_bv': float(jnp.var(offspring_with_gv.bv[:, 0])),
        'parent_mean_bv': float(jnp.mean(parent_pop.bv[:, 0])),
        'parent_var_bv': float(jnp.var(parent_pop.bv[:, 0]))
    })

    return offspring_with_gv, metrics

def run_breeding_program(
    key: jax.random.PRNGKey,
    founder_pop: Population,
    sp: SimParam,
    selection_strategy: str,
    h2: float,
    n_generations: int = 10,
    n_parents: int = 100,
    n_offspring: int = 1000
) -> Tuple[List[Population], List[Dict]]:
    """
    Run a complete multi-generational breeding program.

    The founder population is phenotyped once, and a ``GBLUPPredictor`` is
    trained on the founder dosages and phenotypes. That same predictor is
    then reused for selection in every generation, through
    ``run_single_generation``.

    Args:
        key: JAX random key
        founder_pop: Starting population
        sp: Simulation parameters
        selection_strategy: 'truncation' or 'thompson'
        h2: Heritability
        n_generations: Number of generations to simulate
        n_parents: Parents per generation
        n_offspring: Offspring per generation

    Returns:
        Tuple of (populations_by_generation, metrics_by_generation).
        ``populations_by_generation`` has ``n_generations + 1`` entries,
        starting with the phenotyped founders. ``metrics_by_generation`` has
        one dict per simulated generation.
    """
    print(f"Running {n_generations} generations with {selection_strategy} selection (h²={h2})")

    # Give founder phenotyping its own subkey. `key` must not be consumed
    # here and then split again in the generation loop below.
    key, founder_pheno_key = jax.random.split(key)

    # Train predictor on founder population
    founder_with_pheno = set_pheno(
        key=founder_pheno_key,
        pop=founder_pop,
        traits=sp.traits,
        ploidy=sp.ploidy,
        h2=jnp.array([h2])
    )

    predictor = GBLUPPredictor(
        training_dosage=founder_with_pheno.dosage,
        training_phenotypes=founder_with_pheno.pheno,
        h2=jnp.array([h2])
    )

    populations = [founder_with_pheno]
    all_metrics = []
    current_pop = founder_with_pheno

    for generation in range(n_generations):
        key, gen_key = jax.random.split(key)

        print(f"  Generation {generation + 1}/{n_generations}... ", end="")

        current_pop, gen_metrics = run_single_generation(
            key=gen_key,
            pop=current_pop,
            sp=sp,
            predictor=predictor,
            selection_strategy=selection_strategy,
            h2=h2,
            n_parents=n_parents,
            n_offspring=n_offspring
        )

        populations.append(current_pop)
        all_metrics.append(gen_metrics)

        print(f"Mean BV: {gen_metrics['offspring_mean_bv']:.3f}")

    return populations, all_metrics

# ===== FACTORIAL EXPERIMENT =====
print("\n=== Multi-generational Breeding Program Simulation ===")

# Experimental parameters
N_GENERATIONS = 15
N_PARENTS = 100
N_OFFSPRING = 1000
HERITABILITIES = [0.3, 0.5]  # Reduced for speed
STRATEGIES = ['truncation', 'thompson']

key = EXPERIMENT_KEY
experiment_results = {}

for arch_name, arch_data in FOUNDER_POPULATIONS.items():
    print(f"\n{'='*60}")
    print(f"ARCHITECTURE: {arch_name.upper()}")
    print(f"{'='*60}")

    founder_pop = arch_data['population']
    sp = arch_data['sim_param']

    arch_results = {}

    for h2 in HERITABILITIES:
        # Common random numbers: every strategy in this (architecture, h²)
        # cell runs from the same key. The founder phenotypes, the trained
        # GBLUP predictor and the first-generation phenotypes are then
        # identical across strategies. The truncation-vs-Thompson difference
        # below therefore reflects the selection rule, not independent
        # sampling noise in predictor training.
        key, exp_key = jax.random.split(key)

        for strategy in STRATEGIES:
            print(f"\n--- {strategy.upper()} Selection, h² = {h2} ---")

            populations, metrics = run_breeding_program(
                key=exp_key,
                founder_pop=founder_pop,
                sp=sp,
                selection_strategy=strategy,
                h2=h2,
                n_generations=N_GENERATIONS,
                n_parents=N_PARENTS,
                n_offspring=N_OFFSPRING
            )

            arch_results[f"{strategy}_h2_{h2}"] = {
                'populations': populations,
                'metrics': metrics,
                'final_mean_bv': metrics[-1]['offspring_mean_bv'] if metrics else 0.0
            }

    experiment_results[arch_name] = arch_results

# ===== ANALYZE RESULTS =====
print(f"\n{'='*80}")
print("EXPERIMENT RESULTS SUMMARY")
print(f"{'='*80}")

results_summary = []

for arch_name, arch_results in experiment_results.items():
    print(f"\n{arch_name.upper()} Architecture:")

    for h2 in HERITABILITIES:
        trunc_key = f"truncation_h2_{h2}"
        thompson_key = f"thompson_h2_{h2}"

        # Only summarise cells where both strategies were run.
        if trunc_key not in arch_results or thompson_key not in arch_results:
            continue

        trunc_final = arch_results[trunc_key]['final_mean_bv']
        thompson_final = arch_results[thompson_key]['final_mean_bv']

        improvement = thompson_final - trunc_final
        # Normalise by |truncation| so the percentage always has the same sign
        # as `improvement`. Nothing here constrains final BVs to be positive,
        # and dividing by a negative baseline would report a Thompson gain as
        # a loss. The percentage is undefined for a zero baseline; 0.0 is
        # reported as a sentinel in that case.
        percent_improvement = (
            100.0 * improvement / abs(trunc_final) if trunc_final != 0 else 0.0
        )

        print(f"  h² = {h2}:")
        print(f"    Truncation final BV:     {trunc_final:.3f}")
        print(f"    Thompson Sampling final: {thompson_final:.3f}")
        print(f"    Improvement:             {improvement:+.3f} ({percent_improvement:+.1f}%)")

        results_summary.append({
            'Architecture': arch_name,
            'Heritability': h2,
            'Truncation_Final_BV': trunc_final,
            'Thompson_Final_BV': thompson_final,
            'Improvement': improvement,
            'Percent_Improvement': percent_improvement
        })

# Convert to DataFrame for easy analysis. The column list is passed
# explicitly so the expected columns exist even when no strategy pair was
# completed; otherwise indexing 'Percent_Improvement' below would raise
# KeyError on an empty, column-less frame.
SUMMARY_COLUMNS = [
    'Architecture', 'Heritability', 'Truncation_Final_BV',
    'Thompson_Final_BV', 'Improvement', 'Percent_Improvement',
]
results_df = pd.DataFrame(results_summary, columns=SUMMARY_COLUMNS)
print(f"\n{'='*80}")
print("OVERALL RESULTS")
print(f"{'='*80}")
print(results_df.to_string(index=False, float_format='%.3f'))

# Calculate average improvements
total_conditions = len(results_df)
avg_improvement = results_df['Percent_Improvement'].mean() if total_conditions else float('nan')
positive_improvements = int((results_df['Percent_Improvement'] > 0).sum())

print("\nSUMMARY STATISTICS:")
if total_conditions == 0:
    # Without this branch, `0 >= 0.8 * 0` would print a vacuous "consistent" verdict.
    print("No paired truncation/Thompson conditions were available to compare.")
else:
    print(f"Average improvement: {avg_improvement:.2f}%")
    print(f"Positive improvements: {positive_improvements}/{total_conditions} conditions")
    print(f"Thompson Sampling shows {'consistent' if positive_improvements >= total_conditions * 0.8 else 'mixed'} advantages")

print("\n✓ Multi-generational experiment complete!")
print("✓ Results demonstrate Thompson Sampling performance across architectures and heritabilities")
print("Next: Analyze and visualize the genetic gain trajectories")