In [None]:
# Thompson Sampling vs Truncation Selection - Redesigned Experiment
# Focused on scenarios where Thompson sampling should have advantages

import jax
import jax.numpy as jnp
from functools import partial
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, NamedTuple

# Import chewc modules
from chewc.population import quick_haplo, Population, msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno
from chewc.cross import make_cross
from chewc.predict import gblup_predict

# Global PRNG key; every later random draw in this notebook is split from it.
key = jax.random.PRNGKey(42)

# REDESIGNED EXPERIMENTAL CONFIG - Focus on Thompson sampling advantages.
# Keys (and their order) are read throughout the notebook by name.
EXPERIMENTAL_CONFIG = dict(
    # Small populations where genetic bottlenecks matter more
    population_sizes=[100, 200, 500],
    n_chromosomes=10,
    n_loci_per_chr=100,
    # Small Ne passed to msprime_pop, intended to give more LD / bottlenecks
    effective_population_size=10,
    # Lower heritabilities where prediction uncertainty is higher
    heritabilities=[0.05, 0.1, 0.2, 0.3],
    # Fraction of each generation kept as parents (smaller = more intense)
    selection_intensities=[0.01, 0.02, 0.05, 0.10],
    # Longer time horizon where diversity advantages may emerge
    n_generations=30,
    n_replicates=3,
    # Random-mating generations run before selection starts
    burn_in_generations=50,
)

_cfg = EXPERIMENTAL_CONFIG  # short alias used only for the banner below
print("=== REDESIGNED Thompson Sampling vs Truncation Selection Experiment ===")
print("Focus: Small populations, low heritability, intense selection, longer timeframe")
print(f"Population sizes: {_cfg['population_sizes']}")
print(f"Heritabilities: {_cfg['heritabilities']}")
print(f"Selection intensities: {_cfg['selection_intensities']}")
print(f"Generations: {_cfg['n_generations']}")

# ===== STEP 1: Create Multiple Founder Populations (Different Sizes) =====
# One independently simulated founder population (and genetic map) per size.
print("\n1. Creating founder populations of different sizes...")
founder_populations = {}
genetic_maps = {}

# Arguments shared by every msprime_pop call; only the key and size vary.
_msprime_kwargs = dict(
    n_chr=_cfg['n_chromosomes'],
    n_loci_per_chr=_cfg['n_loci_per_chr'],
    ploidy=2,
    effective_population_size=_cfg['effective_population_size'],
    mutation_rate=2e-8,
    recombination_rate_per_chr=1e-8,
    maf_threshold=0.01,  # low MAF cut-off keeps more variants
)

for pop_size in _cfg['population_sizes']:
    key, pop_key = jax.random.split(key)
    founder_pop, genetic_map = msprime_pop(key=pop_key, n_ind=pop_size, **_msprime_kwargs)
    founder_populations[pop_size], genetic_maps[pop_size] = founder_pop, genetic_map
    print(f"  ✓ Pop size {pop_size}: {founder_pop}")

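# ===== STEP 2: Improved Selection Methods =====
# NOTE: TruncationSelection and ImprovedThompsonSampling are not defined in this cell;
# they must be defined (e.g. in an earlier cell) before run_single_experiment is called.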

# ===== STEP 3: Experimental Loop Function =====

def _homozygosity_inbreeding(dosages):
    """Return a simple heterozygosity-deficit estimate of inbreeding.

    Computes ``1 - mean_l(Ho_l / He_l)`` over loci ``l``. ``He_l = 2p(1-p)``
    uses the current allele frequency ``p``. ``Ho_l`` is the observed fraction
    of individuals with dosage 1.

    Loci that are monomorphic in the current population have ``He_l == 0`` and
    ``Ho_l == 0``. Through the 1e-8 guard they contribute a ratio of 0, which
    pushes the estimate toward 1. The value therefore mixes within-generation
    departure from Hardy-Weinberg with the fraction of fixed loci.

    Assumes ``dosages`` is (n_ind, n_loci) holding diploid allele counts in
    {0, 1, 2} — TODO confirm against ``Population.dosage``.
    """
    allele_freq = jnp.mean(dosages, axis=0) / 2
    expected_het = 2 * allele_freq * (1 - allele_freq)
    observed_het = jnp.mean(dosages == 1, axis=0)
    return 1 - jnp.mean(observed_het / (expected_het + 1e-8))


def _record_generation(results, gen, pop):
    """Append generation ``gen``'s metrics for ``pop`` to the lists in ``results``."""
    trait_bv = pop.bv[:, 0]  # first (only) trait
    results['generation'].append(gen)
    results['mean_bv'].append(float(jnp.mean(trait_bv)))
    results['genetic_variance'].append(float(jnp.var(trait_bv)))
    results['inbreeding'].append(float(_homozygosity_inbreeding(pop.dosage)))


def run_single_experiment(key, founder_pop, genetic_map, pop_size, h2, selection_intensity, method_name):
    """Run one trait setup + burn-in + selection experiment for one method.

    Steps:
      1. Build ``SimParam`` from the founder population and attach one trait via
         ``add_trait_a`` (50 QTL per chromosome, ``gamma=True``).
      2. Phenotype the founders, then run
         ``EXPERIMENTAL_CONFIG['burn_in_generations']`` rounds of random mating.
      3. Run ``EXPERIMENTAL_CONFIG['n_generations']`` rounds of selection. Each
         round uses the chosen selector, then random mating among the selected
         parents produces ``pop_size`` offspring.

    Every random draw is split from ``key``. Two calls with the same key should
    therefore share the same trait and burn-in population, and diverge only in
    the selection phase. This assumes the chewc helpers are deterministic
    given their key.

    ``TruncationSelection`` / ``ImprovedThompsonSampling`` are referenced here
    but defined elsewhere; they must exist before this function is called.

    Args:
        key: JAX PRNG key driving every random draw in the experiment.
        founder_pop: founder ``Population`` (as returned by ``msprime_pop``).
        genetic_map: genetic map returned alongside ``founder_pop``.
        pop_size: number of offspring produced per generation.
        h2: heritability passed to ``set_pheno``.
        selection_intensity: fraction of the population kept as parents.
        method_name: ``"Truncation"`` uses ``TruncationSelection``; any other
            value uses ``ImprovedThompsonSampling``.

    Returns:
        dict with per-generation lists ``generation``, ``mean_bv``,
        ``genetic_variance`` and ``inbreeding``, plus scenario metadata. The
        lists hold ``n_generations + 1`` entries. Entry 0 is the population
        before the first selection round; the last entry is the population
        produced by the final round.
    """
    h2_arr = jnp.array([h2])

    # Initialize simulation parameters and the trait architecture.
    sp = SimParam.from_founder_pop(founder_pop, genetic_map)
    key, trait_key = jax.random.split(key)
    sp = add_trait_a(
        key=trait_key,
        founder_pop=founder_pop,
        sim_param=sp,
        n_qtl_per_chr=50,
        mean=jnp.array([0.0]),
        var=jnp.array([1.0]),
        gamma=True,
    )

    # Initial phenotypes.
    key, pheno_key = jax.random.split(key)
    current_pop = set_pheno(
        key=pheno_key, pop=founder_pop, traits=sp.traits, ploidy=sp.ploidy, h2=h2_arr
    )

    # Burn-in with random mating. Parent sampling and the cross itself get
    # separate keys: reusing one key for both would correlate the two draws.
    for _ in range(EXPERIMENTAL_CONFIG['burn_in_generations']):
        key, parent_key, cross_key, pheno_key = jax.random.split(key, 4)
        parent_indices = jax.random.choice(
            parent_key, current_pop.nInd, shape=(pop_size, 2), replace=True
        )
        next_pop = make_cross(
            key=cross_key,
            pop=current_pop,
            cross_plan=parent_indices,
            sp=sp,
            next_id_start=sp.last_id,
        )
        current_pop = set_pheno(
            key=pheno_key, pop=next_pop, traits=sp.traits, ploidy=sp.ploidy, h2=h2_arr
        )
        sp = sp.replace(last_id=sp.last_id + pop_size)

    selector = TruncationSelection() if method_name == "Truncation" else ImprovedThompsonSampling()

    # At least 1 and at most pop_size parents. The tiny tolerance keeps float
    # products such as 0.29 * 100 == 28.999... from flooring one parent short.
    n_select = min(pop_size, max(1, int(selection_intensity * pop_size + 1e-9)))
    results = {
        'generation': [],
        'mean_bv': [],
        'genetic_variance': [],
        'inbreeding': [],
        'method': method_name,
        'pop_size': pop_size,
        'h2': h2,
        'selection_intensity': selection_intensity
    }

    n_generations = EXPERIMENTAL_CONFIG['n_generations']
    for gen in range(n_generations):
        # Independent keys for selection, mating plan, crossing and phenotyping.
        key, select_key, plan_key, cross_key, pheno_key = jax.random.split(key, 5)

        _record_generation(results, gen, current_pop)

        try:
            selected_indices = selector.select_parents(
                select_key, current_pop, sp, n_select, h2=h2
            )
        except Exception as e:
            # Deliberate best effort: keep the run alive with random parents,
            # but report it so the fallback is visible in the log.
            print(f"Selection failed: {e}, using random selection")
            selected_indices = jax.random.choice(
                select_key, current_pop.nInd, shape=(n_select,), replace=False
            )

        # Random mating among the selected parents.
        cross_plan = jax.random.choice(
            plan_key, selected_indices, shape=(pop_size, 2), replace=True
        )
        next_pop = make_cross(
            key=cross_key,
            pop=current_pop,
            cross_plan=cross_plan,
            sp=sp,
            next_id_start=sp.last_id,
        )
        current_pop = set_pheno(
            key=pheno_key, pop=next_pop, traits=sp.traits, ploidy=sp.ploidy, h2=h2_arr
        )
        sp = sp.replace(last_id=sp.last_id + pop_size)

    # Measure the offspring of the final selection round; otherwise the
    # response to the last generation of selection would never be recorded.
    _record_generation(results, n_generations, current_pop)

    return results

# ===== STEP 4: Run Redesigned Experiment =====
import itertools

print("\n2. Running redesigned experiments...")

n_replicates = EXPERIMENTAL_CONFIG['n_replicates']
scenarios = list(itertools.product(
    EXPERIMENTAL_CONFIG['population_sizes'],
    EXPERIMENTAL_CONFIG['heritabilities'],
    EXPERIMENTAL_CONFIG['selection_intensities'],
))
n_runs = len(scenarios) * n_replicates  # paired (Truncation, Thompson) runs
total_experiments = n_runs * 2  # 2 methods per paired run

all_results = []
experiment_count = 0

print(f"Total experiments: {total_experiments}")

for pop_size, h2, sel_intensity in scenarios:
    # Skip if selection intensity would select < 1 parent. The tolerance
    # matches run_single_experiment (guards float products like 28.999...).
    n_parents = int(sel_intensity * pop_size + 1e-9)
    if n_parents < 1:
        continue

    founder_pop = founder_populations[pop_size]
    genetic_map = genetic_maps[pop_size]

    for rep in range(n_replicates):
        experiment_count += 1
        print(f"\nExperiment {experiment_count}/{n_runs}")
        print(f"Pop={pop_size}, h²={h2}, sel_int={sel_intensity:.2f}, rep={rep + 1}/{n_replicates}")

        # Both methods get the SAME key, so they share the trait architecture
        # and burn-in population and differ only in the selection method
        # (a paired comparison). Different keys would confound method with
        # trait and burn-in randomness.
        key, exp_key = jax.random.split(key)

        print("  Running Truncation...")
        trunc_results = run_single_experiment(
            exp_key, founder_pop, genetic_map,
            pop_size, h2, sel_intensity, "Truncation"
        )

        print("  Running Thompson...")
        thompson_results = run_single_experiment(
            exp_key, founder_pop, genetic_map,
            pop_size, h2, sel_intensity, "Thompson"
        )

        all_results.append({
            'truncation': trunc_results,
            'thompson': thompson_results,
            'pop_size': pop_size,
            'h2': h2,
            'selection_intensity': sel_intensity,
            'replicate': rep,
        })

        # Quick comparison of total gain and final genetic variance.
        final_gain_trunc = trunc_results['mean_bv'][-1] - trunc_results['mean_bv'][0]
        final_gain_thompson = thompson_results['mean_bv'][-1] - thompson_results['mean_bv'][0]
        final_var_trunc = trunc_results['genetic_variance'][-1]
        final_var_thompson = thompson_results['genetic_variance'][-1]

        gain_winner = "Thompson" if final_gain_thompson > final_gain_trunc else "Truncation"
        var_winner = "Thompson" if final_var_thompson > final_var_trunc else "Truncation"

        print(f"    Gain: T={final_gain_trunc:.2f}, Th={final_gain_thompson:.2f} → {gain_winner}")
        print(f"    Var: T={final_var_trunc:.3f}, Th={final_var_thompson:.3f} → {var_winner}")

# ===== STEP 5: Results Analysis =====
print("\n3. Comprehensive Results Analysis")
print("=" * 80)


def _pct(count, total):
    """Return ``count`` as a percentage of ``total``; 0.0 when ``total`` is 0."""
    return 100 * count / total if total else 0.0


# Per-comparison summary. A "win" is a strict improvement of Thompson over
# Truncation, so exact ties are credited to Truncation.
summary_data = []
thompson_gain_wins = 0
thompson_var_wins = 0
thompson_dominates = 0

for result_set in all_results:
    trunc = result_set['truncation']
    thompson = result_set['thompson']

    final_gain_trunc = trunc['mean_bv'][-1] - trunc['mean_bv'][0]
    final_gain_thompson = thompson['mean_bv'][-1] - thompson['mean_bv'][0]
    final_var_trunc = trunc['genetic_variance'][-1]
    final_var_thompson = thompson['genetic_variance'][-1]

    gain_advantage = final_gain_thompson > final_gain_trunc
    var_advantage = final_var_thompson > final_var_trunc

    thompson_gain_wins += int(gain_advantage)
    thompson_var_wins += int(var_advantage)
    thompson_dominates += int(gain_advantage and var_advantage)

    summary_data.append({
        'pop_size': result_set['pop_size'],
        # Older result sets have no replicate tag; treat them as replicate 0.
        'replicate': result_set.get('replicate', 0),
        'h2': result_set['h2'],
        'sel_int': result_set['selection_intensity'],
        'gain_trunc': final_gain_trunc,
        'gain_thompson': final_gain_thompson,
        'var_trunc': final_var_trunc,
        'var_thompson': final_var_thompson,
        'gain_advantage': gain_advantage,
        'var_advantage': var_advantage
    })

total_experiments = len(all_results)
print("Thompson sampling wins:")
print(f"  Genetic gain: {thompson_gain_wins}/{total_experiments} ({_pct(thompson_gain_wins, total_experiments):.1f}%)")
print(f"  Genetic variance: {thompson_var_wins}/{total_experiments} ({_pct(thompson_var_wins, total_experiments):.1f}%)")
print(f"  Both (dominates): {thompson_dominates}/{total_experiments} ({_pct(thompson_dominates, total_experiments):.1f}%)")

# Detailed breakdown (replicates shown 1-based).
print(f"\n{'Pop':<4} {'Rep':<3} {'h²':<5} {'Sel':<5} {'Gain_T':<8} {'Gain_Th':<8} {'Var_T':<8} {'Var_Th':<8} {'Winner':<15}")
print("-" * 74)

for data in summary_data:
    if data['gain_advantage'] and data['var_advantage']:
        winner = "Thompson-Both"
    elif data['gain_advantage']:
        winner = "Thompson-Gain"
    elif data['var_advantage']:
        winner = "Thompson-Var"
    else:
        winner = "Truncation"

    print(f"{data['pop_size']:<4} {data['replicate'] + 1:<3} {data['h2']:<5.2f} {data['sel_int']:<5.2f} "
          f"{data['gain_trunc']:<8.2f} {data['gain_thompson']:<8.2f} "
          f"{data['var_trunc']:<8.3f} {data['var_thompson']:<8.3f} {winner:<15}")

# Scenarios where Thompson sampling won on gain. The relative advantage uses
# |truncation gain| as the baseline so a negative or zero baseline neither
# flips the sign nor divides by zero.
print("\nSCENARIOS WHERE THOMPSON SAMPLING WINS GENETIC GAIN:")
gain_winners = [d for d in summary_data if d['gain_advantage']]
if gain_winners:
    for scenario in gain_winners:
        baseline = abs(scenario['gain_trunc'])
        if baseline > 0:
            diff = scenario['gain_thompson'] - scenario['gain_trunc']
            relative = f"{diff / baseline * 100:+.1f}%"
        else:
            relative = "truncation gain was 0"
        print(f"  Pop={scenario['pop_size']}, h²={scenario['h2']}, sel_int={scenario['sel_int']:.2f}, "
              f"rep={scenario['replicate'] + 1} ({relative})")
else:
    print("  None found under current conditions")

print("\nEXPERIMENTAL CONCLUSION:")
if not all_results:
    print("No experiments were run, so no conclusion can be drawn.")
else:
    if thompson_gain_wins > 0:
        # Report what was measured rather than a hard-coded explanation.
        print(f"Thompson sampling achieved superior genetic gains in {thompson_gain_wins} "
              f"of {total_experiments} comparisons;")
        print("see the scenario list above for the conditions under which this occurred.")
        diversity_lead = "In addition"
    else:
        print("Thompson sampling did not outperform truncation selection in genetic gain")
        print("under any of the tested conditions. This suggests the advantage may require:")
        print("- Even smaller populations (< 100)")
        print("- Lower heritabilities (< 0.05)")
        print("- Longer selection horizons (> 30 generations)")
        print("- Multi-trait scenarios with genetic correlations")
        diversity_lead = "However"

    print(f"\n{diversity_lead}, Thompson sampling maintained genetic diversity better in")
    print(f"{thompson_var_wins} out of {total_experiments} scenarios, which could enable")
    print("sustained selection response in very long-term breeding programs.")



=== REDESIGNED Thompson Sampling vs Truncation Selection Experiment ===
Focus: Small populations, low heritability, intense selection, longer timeframe
Population sizes: [100, 200, 500]
Heritabilities: [0.05, 0.1, 0.2, 0.3]
Selection intensities: [0.01, 0.02, 0.05, 0.1]
Generations: 30

1. Creating founder populations of different sizes...


  out = np.asarray(object, dtype=dtype)


  ✓ Pop size 100: Population(nInd=100, nTraits=0, has_ebv=No)
  ✓ Pop size 200: Population(nInd=200, nTraits=0, has_ebv=No)
  ✓ Pop size 500: Population(nInd=500, nTraits=0, has_ebv=No)

2. Running redesigned experiments...
Total experiments: 96

Experiment 1/48
Pop=100, h²=0.05, sel_int=0.01
  Running Truncation...
  Running Thompson...
    Gain: T=2.24, Th=0.98 → Truncation
    Var: T=0.000, Th=0.000 → Thompson

Experiment 2/48
Pop=100, h²=0.05, sel_int=0.02
  Running Truncation...
  Running Thompson...
    GBLUP failed: No valid (non-NaN) phenotypes found for the select..., using breeding values
    GBLUP failed: No valid (non-NaN) phenotypes found for the select..., using breeding values
    GBLUP failed: No valid (non-NaN) phenotypes found for the select..., using breeding values
    GBLUP failed: No valid (non-NaN) phenotypes found for the select..., using breeding values
    GBLUP failed: No valid (non-NaN) phenotypes found for the select..., using breeding values
    GBLUP fail