# select

> Parent-selection methods (truncation selection and Thompson sampling) for running a sim

In [None]:
#| default_exp select

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

from flax.struct import dataclass as flax_dataclass
import jax.numpy as jnp
from typing import Optional
from chewc.population import Population


import jax
import jax.numpy as jnp
from fastcore.test import test_eq, test_close, test_fail
import time

# Import required chewc components
from chewc.population import Population, quick_haplo
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno
from chewc.predict import gblup_predict

In [None]:
#| export

class SelectionMethod:
    """Common interface for parent-selection strategies.

    The base class only stores a human-readable ``name``; concrete
    strategies subclass it and override ``select_parents``.
    """

    def __init__(self, name: str):
        # Label identifying the strategy.
        self.name = name

    def select_parents(self, key, pop, sp, n_select, **kwargs):
        """Choose ``n_select`` parents from ``pop``; subclasses must override.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

class TruncationSelection(SelectionMethod):
    """Truncation selection: keep the ``n_select`` highest-valued individuals.

    Individuals are ranked on the first trait column of ``pop.bv`` when
    breeding values are present and not (near-)constant, otherwise on the
    first trait column of ``pop.pheno``.
    """

    def __init__(self):
        super().__init__("Truncation")

    def select_parents(self, key, pop, sp, n_select, **kwargs):
        """Return the indices of the ``n_select`` top-ranked individuals.

        Args:
            key: Unused; accepted for interface compatibility.
            pop: Population exposing ``bv`` (may be ``None``) and ``pheno``,
                each indexed as ``[:, 0]`` for the first trait.
            sp: Unused; accepted for interface compatibility.
            n_select: Number of individuals to keep. Values larger than the
                population size return every individual.
            **kwargs: Ignored.

        Returns:
            Array of selected indices, ordered by ascending value.

        Raises:
            ValueError: if ``n_select`` is negative.
        """
        if n_select < 0:
            raise ValueError(f"n_select must be non-negative, got {n_select}")

        # Use breeding values if available and informative, otherwise phenotypes
        if pop.bv is not None and jnp.var(pop.bv[:, 0]) > 1e-8:
            values = pop.bv[:, 0]
        else:
            values = pop.pheno[:, 0]

        ranked = jnp.argsort(values)
        # Slice from an explicit non-negative start. `ranked[-n_select:]`
        # would return the whole population for n_select == 0 (since -0 == 0).
        start = max(ranked.shape[0] - n_select, 0)
        return ranked[start:]

class ImprovedThompsonSampling(SelectionMethod):
    """Thompson-style selection driven by GBLUP predictions.

    Each individual's value is drawn once from a normal distribution centred
    on its EBV with a variance that grows as reliability drops, and the
    ``n_select`` largest draws are kept. Lower reliability therefore means
    more exploration; higher reliability means more exploitation.
    """

    def __init__(self):
        super().__init__("Thompson")

    @staticmethod
    def _top_n(values, n_select):
        """Return indices of the ``n_select`` largest ``values`` (ascending order)."""
        ranked = jnp.argsort(values)
        # Explicit start index: `ranked[-n_select:]` would return everything
        # when n_select == 0 because -0 == 0.
        start = max(ranked.shape[0] - n_select, 0)
        return ranked[start:]

    def select_parents(self, key, pop, sp, n_select, **kwargs):
        """Select ``n_select`` parents by ranking one posterior draw per individual.

        Args:
            key: JAX PRNG key used for sampling.
            pop: Population exposing ``pheno``, ``bv`` (may be ``None``) and
                ``nInd``.
            sp: Unused; accepted for interface compatibility.
            n_select: Number of individuals to select.
            **kwargs: ``h2`` (default 0.5) is forwarded to ``gblup_predict``.

        Returns:
            Array of selected individual indices.

        Raises:
            ValueError: if ``n_select`` is negative.
        """
        if n_select < 0:
            raise ValueError(f"n_select must be non-negative, got {n_select}")

        h2 = kwargs.get('h2', 0.5)

        # Check if we have sufficient variation for genomic prediction
        phenotypic_variance = jnp.var(pop.pheno[:, 0])

        if phenotypic_variance < 1e-8:
            # No variation left - random selection
            return jax.random.choice(key, pop.nInd, shape=(n_select,), replace=False)

        try:
            # Get EBVs using GBLUP
            prediction_results = gblup_predict(pop, h2=h2, trait_idx=0)
            ebv_mean = prediction_results.ebv[:, 0]
            # NOTE(review): assumed to be one reliability per individual,
            # shape (nInd,) - confirm against gblup_predict.
            reliability = prediction_results.reliability

            # Prediction error variance per individual. Reliability is
            # clipped to [0, 1] so numerical overshoot cannot make the
            # variance negative (sqrt -> NaN, and NaNs sort last, i.e. would
            # be "selected"). The 1e-6 keeps the variance strictly positive.
            genetic_var = jnp.var(ebv_mean)
            pred_error_var = genetic_var * (1 - jnp.clip(reliability, 0.0, 1.0) + 1e-6)

            # One PRNG key per individual, then one standard-normal draw per
            # key, vectorised with vmap instead of a Python-level loop.
            key, sample_key = jax.random.split(key)
            sample_keys = jax.random.split(sample_key, pop.nInd)
            standard_draws = jax.vmap(jax.random.normal)(sample_keys)
            sampled_values = standard_draws * jnp.sqrt(pred_error_var) + ebv_mean

            # Select based on top sampled values (not probabilistic sampling)
            # This maintains the "selection intensity" more directly
            return self._top_n(sampled_values, n_select)

        except Exception as e:
            # Deliberate best-effort: any failure above falls back to a
            # simpler strategy instead of aborting the caller.
            print(f"    GBLUP failed: {str(e)[:50]}..., using breeding values")
            # Fallback to breeding value-based selection with noise
            if pop.bv is not None:
                values = pop.bv[:, 0]
                # Add some exploration noise (10% of the BV standard deviation)
                key, noise_key = jax.random.split(key)
                noise = jax.random.normal(noise_key, values.shape) * jnp.std(values) * 0.1
                return self._top_n(values + noise, n_select)
            else:
                return jax.random.choice(key, pop.nInd, shape=(n_select,), replace=False)


In [None]:
# --- Test Fixtures and Helper Functions ---

def create_test_population(key, n_ind=100, n_chr=1, n_loci_per_chr=500,
                          n_qtl_per_chr=50, h2=0.6, trait_mean=10.0, trait_var=2.0):
    """
    Build a founder population with one additive trait and phenotypes for tests.

    The incoming PRNG key is split three ways: one key each for the founder
    haplotypes, the trait architecture and the phenotypes.

    Returns:
        tuple: (population_with_phenotypes, simulation_parameters, genetic_map)
    """
    founder_key, trait_key, pheno_key = jax.random.split(key, 3)

    # Founders and their genetic map
    founders, genetic_map = quick_haplo(
        key=founder_key,
        n_ind=n_ind,
        n_chr=n_chr,
        n_loci_per_chr=n_loci_per_chr,
    )

    # Simulation parameters carrying a single additive trait
    sim_param = add_trait_a(
        key=trait_key,
        founder_pop=founders,
        sim_param=SimParam.from_founder_pop(founders, genetic_map),
        n_qtl_per_chr=n_qtl_per_chr,
        mean=jnp.array([trait_mean]),
        var=jnp.array([trait_var]),
    )

    # Phenotypes at the requested heritability
    phenotyped = set_pheno(
        key=pheno_key,
        pop=founders,
        traits=sim_param.traits,
        ploidy=sim_param.ploidy,
        h2=jnp.array([h2]),
    )

    return phenotyped, sim_param, genetic_map


def validate_selection_basic(selected_indices, n_select, n_ind):
    """Assert that a selection has the right size, no duplicates and in-range indices.

    Checks run in order and stop at the first failure, raising
    ``AssertionError`` with the matching message.
    """
    # Each check is wrapped in a lambda so it is only evaluated once all
    # earlier checks have passed.
    checks = (
        (lambda: len(selected_indices) == n_select,
         f"Should select {n_select} individuals"),
        (lambda: len(jnp.unique(selected_indices)) == n_select,
         "Should select unique individuals"),
        (lambda: jnp.all(selected_indices >= 0),
         "All indices should be non-negative"),
        (lambda: jnp.all(selected_indices < n_ind),
         "All indices should be valid"),
    )
    for passes, message in checks:
        assert passes(), message


def validate_selection_quality(selected_indices, all_values, threshold_factor=1.0):
    """Assert that the selected individuals' mean exceeds a scaled population mean.

    The cutoff is ``mean(all_values) * threshold_factor``. Because the factor
    multiplies the mean, a factor below 1 only loosens the check when the
    population mean is positive.

    Returns:
        tuple: (mean of the selected values, mean of all values)
    """
    population_mean = jnp.mean(all_values)
    chosen_mean = jnp.mean(all_values[selected_indices])
    cutoff = population_mean * threshold_factor
    assert chosen_mean > cutoff, (
        f"Selected individuals should have values > {threshold_factor} * population mean"
    )
    return chosen_mean, population_mean


# --- Test Functions ---

def test_truncation_selection_basic():
    """Truncation selection must return exactly the individuals with the highest BVs."""
    print("Starting basic truncation selection test...")

    n_select = 20
    root_key = jax.random.PRNGKey(42)
    pop, sp, _ = create_test_population(root_key)

    _, select_key = jax.random.split(root_key)
    chosen = TruncationSelection().select_parents(select_key, pop, sp, n_select)

    # Size, uniqueness and index range
    validate_selection_basic(chosen, n_select, pop.nInd)

    # Selected mean BV must beat the population mean
    breeding_values = pop.bv[:, 0]
    validate_selection_quality(chosen, breeding_values)

    # The selected set must coincide with the top-n_select BV set
    expected = jnp.argsort(breeding_values)[-n_select:]
    assert set(chosen.tolist()) == set(expected.tolist()), \
        "Should select exactly the top individuals"

    print("Basic truncation selection test passed!")


def test_truncation_selection_no_bv():
    """With ``bv=None``, truncation selection must rank individuals on phenotypes."""
    print("Testing truncation selection fallback to phenotypes...")

    n_select = 10
    root_key = jax.random.PRNGKey(123)
    founder_key, pheno_key = jax.random.split(root_key, 2)

    # Founders with random phenotypes and breeding values explicitly removed
    founders, _ = quick_haplo(key=founder_key, n_ind=50, n_chr=1, n_loci_per_chr=100)
    pheno_values = jax.random.normal(pheno_key, (50, 1))
    pheno_only_pop = founders.replace(pheno=pheno_values, bv=None)

    _, select_key = jax.random.split(root_key)
    chosen = TruncationSelection().select_parents(
        select_key, pheno_only_pop, None, n_select
    )

    # The selected set must coincide with the top phenotypes
    expected = jnp.argsort(pheno_values[:, 0])[-n_select:]
    assert set(chosen.tolist()) == set(expected.tolist()), \
        "Should select top phenotypic individuals"

    print("Truncation selection phenotype fallback test passed!")


def test_thompson_sampling_basic():
    """Thompson sampling must return a valid selection of generally good individuals."""
    print("Starting basic Thompson sampling test...")

    h2 = 0.5
    n_select = 25
    root_key = jax.random.PRNGKey(456)
    pop, sp, _ = create_test_population(
        root_key, n_loci_per_chr=1000, n_qtl_per_chr=100, h2=0.5
    )

    _, select_key = jax.random.split(root_key)
    chosen = ImprovedThompsonSampling().select_parents(
        select_key, pop, sp, n_select, h2=h2
    )

    # Size, uniqueness and index range
    validate_selection_basic(chosen, n_select, pop.nInd)

    # Selected mean BV must exceed 0.9 x the population mean BV
    validate_selection_quality(chosen, pop.bv[:, 0], threshold_factor=0.9)

    print("Basic Thompson sampling test passed!")


def test_thompson_sampling_exploration():
    """Across replicates, Thompson sampling must touch at least as many individuals as truncation."""
    print("Testing Thompson sampling exploration behavior...")

    h2 = 0.4
    n_select = 15
    n_replicates = 10

    root_key = jax.random.PRNGKey(789)
    pop, sp, _ = create_test_population(root_key, n_ind=80, n_loci_per_chr=800,
                                        n_qtl_per_chr=80, h2=0.4, trait_mean=0.0, trait_var=1.0)

    truncation_selector = TruncationSelection()
    thompson_selector = ImprovedThompsonSampling()

    # Row 0 is left unused; the next n_replicates keys drive truncation and
    # the remaining n_replicates keys drive Thompson sampling.
    all_keys = jax.random.split(root_key, n_replicates * 2 + 1)
    truncation_keys = all_keys[1:n_replicates + 1]
    thompson_keys = all_keys[n_replicates + 1:]

    truncation_selections = [
        set(truncation_selector.select_parents(k, pop, sp, n_select).tolist())
        for k in truncation_keys
    ]
    thompson_selections = [
        set(thompson_selector.select_parents(k, pop, sp, n_select, h2=h2).tolist())
        for k in thompson_keys
    ]

    # Diversity = number of distinct individuals picked across all replicates
    n_unique_truncation = len(set().union(*truncation_selections))
    n_unique_thompson = len(set().union(*thompson_selections))

    print(f"Truncation unique individuals: {n_unique_truncation}")
    print(f"Thompson unique individuals: {n_unique_thompson}")

    assert n_unique_thompson >= n_unique_truncation, "Thompson should explore at least as much"
    assert n_unique_truncation == n_select, "Truncation should be deterministic"

    print("Thompson sampling exploration test passed!")


def test_selection_with_no_variation():
    """Both selectors must still return ``n_select`` individuals when nothing varies.

    Phenotypes and breeding values are set to constants, so neither carries
    any ranking information.
    """
    print("Testing selection with zero variation...")

    key = jax.random.PRNGKey(999)
    pop, _ = quick_haplo(key=key, n_ind=50, n_chr=1, n_loci_per_chr=100)

    # Create population with no variation
    pop_no_var = pop.replace(
        pheno=jnp.ones((50, 1)) * 5.0,
        bv=jnp.ones((50, 1)) * 2.0,
    )

    selectors = [
        ("Truncation", TruncationSelection(), {}),
        ("Thompson", ImprovedThompsonSampling(), {'h2': 0.5}),
    ]

    n_select = 10
    key, select_key1, select_key2 = jax.random.split(key, 3)

    for (name, selector, kwargs), select_key in zip(selectors, (select_key1, select_key2)):
        # No try/except here: catching Exception would also swallow the
        # AssertionError below, so the test could never fail.
        selected = selector.select_parents(select_key, pop_no_var, None, n_select, **kwargs)
        assert len(selected) == n_select, f"{name} should still select requested number"
        print(f"{name} selection handled zero variation")

    print("Zero variation test passed!")


def test_selection_edge_cases():
    """Truncation selection at the extremes: one individual, then the whole population."""
    print("Testing selection edge cases...")

    root_key = jax.random.PRNGKey(111)
    pop, sp, _ = create_test_population(root_key, n_ind=10, n_loci_per_chr=100,
                                        n_qtl_per_chr=10, h2=0.5, trait_mean=0.0, trait_var=1.0)

    selector = TruncationSelection()

    # Smallest selection: a single individual
    carry_key, single_key = jax.random.split(root_key)
    picked_one = selector.select_parents(single_key, pop, sp, 1)
    assert len(picked_one) == 1, "Should select exactly one individual"

    # Largest selection: everybody, each exactly once
    _, everyone_key = jax.random.split(carry_key)
    picked_all = selector.select_parents(everyone_key, pop, sp, pop.nInd)
    assert len(picked_all) == pop.nInd, "Should select all individuals"
    assert len(jnp.unique(picked_all)) == pop.nInd, "All should be unique"

    print("Edge cases test passed!")


def test_selection_methods_comparison():
    """Run both selection methods on a larger population and validate each result.

    Each method must return a valid selection whose mean breeding value
    exceeds 0.95 x the population mean; the means are printed for comparison.
    """
    print("Testing selection methods comparison...")

    key = jax.random.PRNGKey(222)
    pop, sp, _ = create_test_population(key, n_ind=200, n_chr=2, n_loci_per_chr=500,
                                      n_qtl_per_chr=100, h2=0.3, trait_mean=10.0, trait_var=4.0)

    # Test both methods
    selectors = [
        ("Truncation", TruncationSelection(), {}),
        ("Thompson", ImprovedThompsonSampling(), {'h2': 0.3}),
    ]

    n_select = 50
    key, select_key1, select_key2 = jax.random.split(key, 3)

    # Loop-invariant: the breeding values are the same for both methods.
    all_bv = pop.bv[:, 0]

    for (name, selector, kwargs), select_key in zip(selectors, (select_key1, select_key2)):
        selected = selector.select_parents(select_key, pop, sp, n_select, **kwargs)
        validate_selection_basic(selected, n_select, pop.nInd)

        mean_selected, mean_all = validate_selection_quality(selected, all_bv, 0.95)

        print(f"{name} selection - Mean BV: {mean_selected:.3f} (pop mean: {mean_all:.3f})")

    print("Both selection methods completed successfully")
    print("Selection methods comparison test passed!")



In [None]:
#| hide
import nbdev; nbdev.nbdev_export()