In [None]:
"""
Thompson Sampling vs Truncation Selection Experiment
Cell 1: Founder Population and Trait Architecture Setup
"""

import jax
import jax.numpy as jnp
from functools import partial

# Import chewc modules
from chewc.population import quick_haplo, Population, msprime_pop
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno

# Root PRNG key: every later random draw in this experiment is split from it
key = jax.random.PRNGKey(42)

# Experimental parameters: population/genome dimensions, the heritability
# grid, and one entry per trait architecture (QTL density + effect flag)
EXPERIMENTAL_CONFIG = {
    'population_size': 2000,
    'n_chromosomes': 10,
    'n_loci_per_chr': 1000,
    'heritabilities': [0.1, 0.3, 0.5, 0.7],
    'trait_architectures': {
        arch: {'n_qtl_per_chr': qtl_per_chr, 'gamma_effects': use_gamma}
        for arch, qtl_per_chr, use_gamma in (
            ('oligogenic', 10, True),
            ('polygenic', 100, False),
            ('mixed', 50, True),
        )
    },
}

# Emit the four-line setup banner in a single call; sep="\n" puts each
# argument on its own line, exactly as separate print() calls would
print(
    "=== Thompson Sampling Experiment Setup ===",
    f"Population size: {EXPERIMENTAL_CONFIG['population_size']}",
    f"Genome: {EXPERIMENTAL_CONFIG['n_chromosomes']} chromosomes × {EXPERIMENTAL_CONFIG['n_loci_per_chr']} loci",
    f"Total loci: {EXPERIMENTAL_CONFIG['n_chromosomes'] * EXPERIMENTAL_CONFIG['n_loci_per_chr']:,}",
    sep="\n",
)

# ===== STEP 1: Create Founder Population =====
print("\n1. Creating founder population...")
key, pop_key = jax.random.split(key)

# Gather the simulator arguments up front so the call itself is one line.
# The optional `inbred` and `chr_len_cm` arguments are deliberately not
# passed (they were disabled in the original call as well).
founder_settings = {
    'key': pop_key,
    'n_ind': EXPERIMENTAL_CONFIG['population_size'],
    'n_chr': EXPERIMENTAL_CONFIG['n_chromosomes'],
    'n_loci_per_chr': EXPERIMENTAL_CONFIG['n_loci_per_chr'],
    'ploidy': 2,
}
founder_pop, genetic_map = msprime_pop(**founder_settings)

# Report both results in one call; sep="\n" keeps one message per line
print(
    f"✓ Created founder population: {founder_pop}",
    f"✓ Genetic map shape: {genetic_map.shape}",
    sep="\n",
)

# ===== STEP 2: Initialize SimParam =====
# SimParam is derived from the founders and their genetic map; traits are
# attached to it in a later step
print("\n2. Initializing simulation parameters...")
sp = SimParam.from_founder_pop(founder_pop, genetic_map)
print(f"✓ SimParam initialized: {sp}")

# ===== STEP 3: Add Trait Architectures =====
print("\n3. Setting up trait architectures...")

# Trait settings shared by every architecture; built once here rather than
# re-created on each loop iteration
trait_mean = jnp.array([10.0])  # Single trait with mean 10
trait_var = jnp.array([4.0])    # Variance of 4 (std dev = 2)
setup_h2 = jnp.array([0.5])     # Moderate heritability, used for setup only

# Maps architecture name -> dict holding the population returned by
# set_pheno, the SimParam returned by add_trait_a, and the total QTL count
populations_by_architecture = {}

for arch_name, arch_params in EXPERIMENTAL_CONFIG['trait_architectures'].items():
    # Genome-wide QTL count = per-chromosome count × number of chromosomes
    total_qtl = arch_params['n_qtl_per_chr'] * EXPERIMENTAL_CONFIG['n_chromosomes']

    print(f"\n   Setting up {arch_name} architecture:")
    print(f"   - QTL per chromosome: {arch_params['n_qtl_per_chr']}")
    print(f"   - Gamma effects: {arch_params['gamma_effects']}")

    # NOTE: the key is split twice per architecture (trait, then phenotype).
    # Keep this order so the random stream stays reproducible.
    key, trait_key = jax.random.split(key)

    # Add a single trait with this architecture. Every architecture starts
    # from the same trait-free `sp` and the same founders.
    sp_with_trait = add_trait_a(
        key=trait_key,
        founder_pop=founder_pop,
        sim_param=sp,
        n_qtl_per_chr=arch_params['n_qtl_per_chr'],
        mean=trait_mean,
        var=trait_var,
        gamma=arch_params['gamma_effects']
    )

    # Run set_pheno on the founders so the summary below can read .bv/.gv
    key, pheno_key = jax.random.split(key)
    pop_with_gv = set_pheno(
        key=pheno_key,
        pop=founder_pop,
        traits=sp_with_trait.traits,
        ploidy=sp_with_trait.ploidy,
        h2=setup_h2
    )

    populations_by_architecture[arch_name] = {
        'population': pop_with_gv,
        'sim_param': sp_with_trait,
        'n_qtl': total_qtl
    }

    # Summary statistics for the first (only) trait column, as a sanity check
    bv_mean = jnp.mean(pop_with_gv.bv[:, 0])
    bv_var = jnp.var(pop_with_gv.bv[:, 0])
    gv_mean = jnp.mean(pop_with_gv.gv[:, 0])
    gv_var = jnp.var(pop_with_gv.gv[:, 0])

    print(f"   ✓ Total QTL: {total_qtl}")
    print(f"   ✓ Breeding value: mean={bv_mean:.2f}, var={bv_var:.2f}")
    print(f"   ✓ Genetic value: mean={gv_mean:.2f}, var={gv_var:.2f}")

# Plain string: there is nothing to interpolate, so no f-prefix is needed
print("\n=== Setup Complete ===")
print(f"Ready to run experiments with {len(populations_by_architecture)} trait architectures")
print("Next: Implement prediction methods (GBLUP) and selection strategies")

# Store the current state for next cells: the per-architecture results and
# the advanced PRNG key, so later cells continue the same random stream
FOUNDER_POPULATIONS = populations_by_architecture
EXPERIMENT_KEY = key



=== Thompson Sampling Experiment Setup ===
Population size: 2000
Genome: 10 chromosomes × 1000 loci
Total loci: 10,000

1. Creating founder population...


  out = np.asarray(object, dtype=dtype)


✓ Created founder population: Population(nInd=2000, nTraits=0, has_ebv=No)
✓ Genetic map shape: (10, 1000)

2. Initializing simulation parameters...
✓ SimParam initialized: SimParam(nChr=10, nTraits=0, ploidy=2, sexes='no')

3. Setting up trait architectures...

   Setting up oligogenic architecture:
   - QTL per chromosome: 10
   - Gamma effects: True
   ✓ Total QTL: 100
   ✓ Breeding value: mean=2.90, var=4.00
   ✓ Genetic value: mean=10.00, var=4.00

   Setting up polygenic architecture:
   - QTL per chromosome: 100
   - Gamma effects: False
   ✓ Total QTL: 1000
   ✓ Breeding value: mean=2.38, var=4.00
   ✓ Genetic value: mean=10.00, var=4.00

   Setting up mixed architecture:
   - QTL per chromosome: 50
   - Gamma effects: True
   ✓ Total QTL: 500
   ✓ Breeding value: mean=-0.35, var=4.00
   ✓ Genetic value: mean=10.00, var=4.00

=== Setup Complete ===
Ready to run experiments with 3 trait architectures
Next: Implement prediction methods (GBLUP) and selection strategies
