# select

> Parent selection methods (truncation selection and Thompson sampling) for running a sim

In [None]:
#| default_exp select

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

from flax.struct import dataclass as flax_dataclass
import jax.numpy as jnp
from typing import Optional
from chewc.population import Population


import jax
import jax.numpy as jnp
from fastcore.test import test_eq, test_close, test_fail
import time

# Import required chewc components
from chewc.population import Population, quick_haplo
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno
from chewc.predict import gblup_predict

In [None]:
#| export

class SelectionMethod:
    """Interface shared by all parent-selection strategies.

    Subclasses pass a human-readable label to ``super().__init__`` and
    override :meth:`select_parents`.
    """

    def __init__(self, name: str) -> None:
        # Label identifying the strategy.
        self.name: str = name

    def select_parents(self, key, pop, sp, n_select, **kwargs):
        """Return the indices of the ``n_select`` chosen parents.

        The base implementation is abstract and always raises
        ``NotImplementedError``; concrete strategies must override it.
        """
        raise NotImplementedError

class TruncationSelection(SelectionMethod):
    """Truncation selection: keep the ``n_select`` highest-ranked individuals.

    Individuals are ranked on trait 0 of ``pop.bv`` when breeding values are
    present and not (numerically) constant; otherwise trait 0 of
    ``pop.pheno`` is used.
    """
    def __init__(self):
        super().__init__("Truncation")

    def select_parents(self, key, pop, sp, n_select, **kwargs):
        """Return the indices of the top ``n_select`` individuals.

        Args:
            key: PRNG key; unused (truncation is deterministic) but kept for
                the common ``SelectionMethod`` interface.
            pop: population exposing ``bv`` (may be ``None``) and ``pheno``;
                column 0 of each is used.
            sp: simulation parameters; unused.
            n_select: number of individuals to keep. Values larger than the
                population size select everyone.
            **kwargs: ignored.

        Returns:
            1-D index array in ascending order of the ranking value (best
            individual last).

        Raises:
            ValueError: if ``n_select`` is negative.
        """
        if n_select < 0:
            raise ValueError(f"n_select must be non-negative, got {n_select}")

        # Prefer breeding values; a (near-)zero variance means they carry no
        # ranking information, so fall back to phenotypes.
        if pop.bv is not None and jnp.var(pop.bv[:, 0]) > 1e-8:
            values = pop.bv[:, 0]
        else:
            values = pop.pheno[:, 0]

        order = jnp.argsort(values)
        # Slice from an explicit start index: ``order[-n_select:]`` returns the
        # whole array when n_select == 0 because ``-0 == 0``.
        start = max(order.shape[0] - n_select, 0)
        return order[start:]

class ImprovedThompsonSampling(SelectionMethod):
    """Thompson-style parent selection on GBLUP estimated breeding values.

    Each individual's score is one draw from a normal distribution centred on
    its GBLUP EBV for trait 0, with variance ``var(EBV) * (1 - reliability)``
    plus a small epsilon. The ``n_select`` highest draws are kept, so
    individuals with less reliable EBVs get wider draws and more chance of
    being explored.
    """
    def __init__(self):
        super().__init__("Thompson")

    def select_parents(self, key, pop, sp, n_select, **kwargs):
        """Select ``n_select`` parents by sampling around GBLUP EBVs.

        Args:
            key: JAX PRNG key; all randomness in this method derives from it.
            pop: population; ``pheno``, ``bv`` (may be ``None``) and ``nInd``
                are read, and column 0 is used for trait values.
            sp: simulation parameters; unused here.
            n_select: number of parents to return.
            **kwargs: ``h2`` (default 0.5) is forwarded to ``gblup_predict``.

        Returns:
            1-D array of selected individual indices. Score-based paths return
            them in ascending order of score (best last).

        Raises:
            ValueError: if ``n_select`` is negative.

        Fallbacks:
            * Phenotypic variance below 1e-8: uniform random choice without
              replacement.
            * ``gblup_predict`` (or reading its result) raises: rank breeding
              values plus small noise, or choose at random if ``pop.bv`` is
              ``None``.
        """
        if n_select < 0:
            raise ValueError(f"n_select must be non-negative, got {n_select}")
        h2 = kwargs.get('h2', 0.5)

        # With no phenotypic spread GBLUP has nothing to learn from.
        phenotypic_variance = jnp.var(pop.pheno[:, 0])
        if phenotypic_variance < 1e-8:
            return jax.random.choice(key, pop.nInd, shape=(n_select,), replace=False)

        # Keep the try narrow: only GBLUP itself is allowed to trigger the
        # fallback, so bugs in the sampling code below are not disguised as
        # "GBLUP failed".
        try:
            prediction_results = gblup_predict(pop, h2=h2, trait_idx=0)
            ebv_mean = prediction_results.ebv[:, 0]
            reliability = jnp.asarray(prediction_results.reliability)
        except Exception as e:
            print(f"    GBLUP failed: {str(e)[:50]}..., using breeding values")
            return self._fallback_select(key, pop, n_select)

        # NOTE(review): reliability is assumed to be per individual; if it is
        # returned per trait (2-D), use trait 0 to match ``ebv_mean``.
        if reliability.ndim > 1:
            reliability = reliability[:, 0]

        # Prediction error variance shrinks as reliability -> 1. Clip so a
        # reliability slightly above 1 cannot yield a negative variance (NaN
        # after sqrt); the epsilon keeps every draw stochastic.
        genetic_var = jnp.var(ebv_mean)
        pred_error_var = genetic_var * (jnp.clip(1 - reliability, 0.0, None) + 1e-6)

        key, sample_key = jax.random.split(key)
        sample_keys = jax.random.split(sample_key, pop.nInd)
        # One independent standard-normal draw per individual key, vectorised
        # with vmap instead of a Python loop with one dispatch per individual.
        z = jax.vmap(jax.random.normal)(sample_keys)
        sampled_values = z * jnp.sqrt(pred_error_var) + ebv_mean

        # Deterministic top-k on the sampled values keeps selection intensity.
        return self._top_k(sampled_values, n_select)

    @staticmethod
    def _fallback_select(key, pop, n_select):
        """Rank breeding values plus small noise, or pick at random without BVs."""
        if pop.bv is None:
            return jax.random.choice(key, pop.nInd, shape=(n_select,), replace=False)
        values = pop.bv[:, 0]
        key, noise_key = jax.random.split(key)
        # Noise at 10% of the BV standard deviation adds a little exploration.
        noise = jax.random.normal(noise_key, values.shape) * jnp.std(values) * 0.1
        return ImprovedThompsonSampling._top_k(values + noise, n_select)

    @staticmethod
    def _top_k(values, n_select):
        """Indices of the ``n_select`` largest values, ascending (0 -> empty)."""
        order = jnp.argsort(values)
        return order[max(order.shape[0] - n_select, 0):]


In [None]:
#| test


# The classes should be available from the exported module above
# We'll use the class definitions from the current notebook context


def test_truncation_selection_basic():
    """
    Check that TruncationSelection returns exactly the individuals with the
    highest breeding values on a freshly simulated founder population.
    """
    print("Starting basic truncation selection test...")

    root = jax.random.PRNGKey(42)
    k_founder, k_trait, k_pheno = jax.random.split(root, 3)

    # 100 founders on one 500-locus chromosome.
    n_ind = 100
    founders, genetic_map = quick_haplo(
        key=k_founder, n_ind=n_ind, n_chr=1, n_loci_per_chr=500
    )

    # Additive trait (mean 10, variance 2) and phenotypes at h2 = 0.6.
    params = add_trait_a(
        key=k_trait,
        founder_pop=founders,
        sim_param=SimParam.from_founder_pop(founders, genetic_map),
        n_qtl_per_chr=50,
        mean=jnp.array([10.0]),
        var=jnp.array([2.0]),
    )
    phenotyped = set_pheno(
        key=k_pheno,
        pop=founders,
        traits=params.traits,
        ploidy=params.ploidy,
        h2=jnp.array([0.6]),
    )

    n_select = 20
    _, k_select = jax.random.split(root)
    chosen = TruncationSelection().select_parents(
        k_select, phenotyped, params, n_select
    )

    # Shape and index-validity checks.
    assert len(chosen) == n_select, f"Should select {n_select} individuals"
    assert len(jnp.unique(chosen)) == n_select, "Should select unique individuals"
    assert jnp.all(chosen >= 0), "All indices should be non-negative"
    assert jnp.all(chosen < n_ind), "All indices should be valid"

    # The selected group must beat the population mean...
    bv = phenotyped.bv[:, 0]
    assert jnp.mean(bv[chosen]) > jnp.mean(bv), (
        "Selected individuals should have higher BV than average"
    )

    # ...and be exactly the n_select best individuals by breeding value.
    expected = set(jnp.argsort(bv)[-n_select:].tolist())
    assert set(chosen.tolist()) == expected, "Should select exactly the top individuals"

    print("Basic truncation selection test passed!")


def test_truncation_selection_no_bv():
    """
    With ``bv`` set to None, TruncationSelection must rank on phenotypes and
    return the individuals with the highest phenotypic values.
    """
    print("Testing truncation selection fallback to phenotypes...")

    root = jax.random.PRNGKey(123)
    k_founder, k_pheno = jax.random.split(root, 2)

    founders, _ = quick_haplo(
        key=k_founder, n_ind=50, n_chr=1, n_loci_per_chr=100
    )

    # Standard-normal phenotypes with no trait architecture behind them, and
    # breeding values removed explicitly.
    phenotypes = jax.random.normal(k_pheno, (50, 1))
    pheno_only = founders.replace(pheno=phenotypes, bv=None)

    n_select = 10
    _, k_select = jax.random.split(root)
    chosen = TruncationSelection().select_parents(
        k_select, pheno_only, None, n_select
    )

    expected = set(jnp.argsort(phenotypes[:, 0])[-n_select:].tolist())
    assert set(chosen.tolist()) == expected, "Should select top phenotypic individuals"

    print("Truncation selection phenotype fallback test passed!")


def test_thompson_sampling_basic():
    """
    Tests basic Thompson sampling functionality: the returned indices are
    valid and unique, and the selected group's mean breeding value is not
    far below the population mean.
    """
    print("Starting basic Thompson sampling test...")

    # Setup simulation
    key = jax.random.PRNGKey(456)
    founder_key, trait_key, pheno_key = jax.random.split(key, 3)

    # Create population
    n_ind = 100
    pop, gen_map = quick_haplo(
        key=founder_key, n_ind=n_ind, n_chr=1, n_loci_per_chr=1000
    )

    # Add trait and phenotypes
    sp = SimParam.from_founder_pop(pop, gen_map)
    sp = add_trait_a(
        key=trait_key, founder_pop=pop, sim_param=sp, n_qtl_per_chr=100,
        mean=jnp.array([5.0]), var=jnp.array([1.0])
    )

    pop_with_pheno = set_pheno(
        key=pheno_key, pop=pop, traits=sp.traits, ploidy=sp.ploidy,
        h2=jnp.array([0.5])
    )

    # Test Thompson sampling
    selector = ImprovedThompsonSampling()
    n_select = 25
    h2 = 0.5

    key, select_key = jax.random.split(key)
    selected_indices = selector.select_parents(
        select_key, pop_with_pheno, sp, n_select, h2=h2
    )

    # Basic validation
    assert len(selected_indices) == n_select, f"Should select {n_select} individuals"
    assert len(jnp.unique(selected_indices)) == n_select, "Should select unique individuals"
    assert jnp.all(selected_indices >= 0), "All indices should be non-negative"
    assert jnp.all(selected_indices < n_ind), "All indices should be valid"

    # Thompson sampling should still tend to select better individuals on average
    all_bv = pop_with_pheno.bv[:, 0]
    mean_selected_bv = jnp.mean(all_bv[selected_indices])
    mean_all_bv = jnp.mean(all_bv)
    std_all_bv = jnp.std(all_bv)

    # The tolerance is measured in BV standard deviations so it does not depend
    # on where the trait mean sits. A multiplicative margin like
    # ``mean_all_bv * 0.9`` is stricter than the mean itself when the mean is
    # negative and meaningless near zero. For the trait configured above
    # (mean 5.0, var 1.0) half an SD matches the old ``0.9 * mean`` margin.
    assert mean_selected_bv > mean_all_bv - 0.5 * std_all_bv, "Selected individuals should generally have higher BV"

    print("Basic Thompson sampling test passed!")


def test_thompson_sampling_exploration_vs_exploitation():
    """
    Repeat each selection method with independent keys and compare how many
    distinct individuals are picked across replicates. Truncation must be
    deterministic; Thompson sampling must cover at least as many individuals.
    """
    print("Testing Thompson sampling exploration behavior...")

    root = jax.random.PRNGKey(789)
    k_founder, k_trait, k_pheno = jax.random.split(root, 3)

    n_ind = 80
    founders, genetic_map = quick_haplo(
        key=k_founder, n_ind=n_ind, n_chr=1, n_loci_per_chr=800
    )

    params = add_trait_a(
        key=k_trait,
        founder_pop=founders,
        sim_param=SimParam.from_founder_pop(founders, genetic_map),
        n_qtl_per_chr=80,
        mean=jnp.array([0.0]),
        var=jnp.array([1.0]),
    )
    phenotyped = set_pheno(
        key=k_pheno,
        pop=founders,
        traits=params.traits,
        ploidy=params.ploidy,
        h2=jnp.array([0.4]),
    )

    n_select, n_replicates, h2 = 15, 10, 0.4
    truncation = TruncationSelection()
    thompson = ImprovedThompsonSampling()

    # First half of the keys drives truncation, second half Thompson sampling.
    _, *keys = jax.random.split(root, n_replicates * 2 + 1)
    trunc_keys, thompson_keys = keys[:n_replicates], keys[n_replicates:]

    truncation_picks = []
    thompson_picks = []
    for k_trunc, k_thompson in zip(trunc_keys, thompson_keys):
        # Truncation ignores the key, so it should repeat itself.
        picked = truncation.select_parents(k_trunc, phenotyped, params, n_select)
        truncation_picks.append(set(picked.tolist()))
        # Thompson sampling draws new scores for each key.
        picked = thompson.select_parents(
            k_thompson, phenotyped, params, n_select, h2=h2
        )
        thompson_picks.append(set(picked.tolist()))

    n_unique_truncation = len(set().union(*truncation_picks))
    n_unique_thompson = len(set().union(*thompson_picks))

    print(f"Truncation unique individuals: {n_unique_truncation}")
    print(f"Thompson unique individuals: {n_unique_thompson}")

    # NOTE: every Thompson replicate returns n_select unique indices, so its
    # union already holds at least n_select individuals; this comparison is a
    # deliberately loose check, not a strong test of exploration.
    assert n_unique_thompson >= n_unique_truncation, "Thompson should explore at least as much as truncation"

    # Deterministic truncation means the union is just one selection.
    assert n_unique_truncation == n_select, "Truncation should always select the same individuals"

    print("Thompson sampling exploration test passed!")


def test_selection_with_no_variation():
    """
    Tests selection methods when there is no phenotypic or breeding-value
    variation. Both methods must still return ``n_select`` unique, in-range
    indices rather than raising.
    """
    print("Testing selection with zero variation...")

    key = jax.random.PRNGKey(999)
    founder_key = key

    # Create population
    n_ind = 50
    pop, gen_map = quick_haplo(
        key=founder_key, n_ind=n_ind, n_chr=1, n_loci_per_chr=100
    )

    # Create population with identical phenotypes (no variation)
    constant_pheno = jnp.ones((n_ind, 1)) * 5.0
    pop_no_var = pop.replace(
        pheno=constant_pheno,
        bv=jnp.ones((n_ind, 1)) * 2.0  # Also constant breeding values
    )

    # Test both selection methods
    truncation_selector = TruncationSelection()
    thompson_selector = ImprovedThompsonSampling()
    n_select = 10

    key, select_key1, select_key2 = jax.random.split(key, 3)

    # No try/except around these calls: catching ``Exception`` here also
    # swallowed the AssertionErrors below, so the test could never fail.
    trunc_selected = truncation_selector.select_parents(
        select_key1, pop_no_var, None, n_select
    )
    assert len(trunc_selected) == n_select, "Should still select requested number"
    assert len(jnp.unique(trunc_selected)) == n_select, "Should select unique individuals"
    assert jnp.all((trunc_selected >= 0) & (trunc_selected < n_ind)), "All indices should be valid"
    print("Truncation selection handled zero variation")

    thompson_selected = thompson_selector.select_parents(
        select_key2, pop_no_var, None, n_select, h2=0.5
    )
    assert len(thompson_selected) == n_select, "Should still select requested number"
    assert len(jnp.unique(thompson_selected)) == n_select, "Should select unique individuals"
    assert jnp.all((thompson_selected >= 0) & (thompson_selected < n_ind)), "All indices should be valid"
    print("Thompson sampling handled zero variation")

    print("Zero variation test passed!")


def test_selection_edge_cases():
    """
    Exercise TruncationSelection at its extremes on a 10-individual
    population: choosing a single individual and choosing everyone.
    """
    print("Testing selection edge cases...")

    root = jax.random.PRNGKey(111)
    k_founder, k_trait, k_pheno = jax.random.split(root, 3)

    # A deliberately tiny population.
    n_ind = 10
    founders, genetic_map = quick_haplo(
        key=k_founder, n_ind=n_ind, n_chr=1, n_loci_per_chr=100
    )

    # Give the population some genetic and phenotypic variation.
    params = add_trait_a(
        key=k_trait,
        founder_pop=founders,
        sim_param=SimParam.from_founder_pop(founders, genetic_map),
        n_qtl_per_chr=10,
        mean=jnp.array([0.0]),
        var=jnp.array([1.0]),
    )
    phenotyped = set_pheno(
        key=k_pheno,
        pop=founders,
        traits=params.traits,
        ploidy=params.ploidy,
        h2=jnp.array([0.5]),
    )

    selector = TruncationSelection()

    # Lower bound: a single parent.
    root, k_one = jax.random.split(root)
    only_one = selector.select_parents(k_one, phenotyped, params, 1)
    assert len(only_one) == 1, "Should select exactly one individual"

    # Upper bound: the whole population.
    _, k_all = jax.random.split(root)
    everyone = selector.select_parents(k_all, phenotyped, params, n_ind)
    assert len(everyone) == n_ind, "Should select all individuals"
    assert len(jnp.unique(everyone)) == n_ind, "All individuals should be unique"

    print("Edge cases test passed!")


def test_selection_performance_comparison():
    """
    Tests that both selection methods run on a moderately sized population
    (200 individuals, 2 chromosomes), return valid unique indices, and pick
    groups whose mean breeding value is above (truncation) or not far below
    (Thompson) the population mean.
    """
    print("Testing selection performance comparison...")

    key = jax.random.PRNGKey(222)
    founder_key, trait_key, pheno_key = jax.random.split(key, 3)

    # Create larger population
    n_ind = 200
    pop, gen_map = quick_haplo(
        key=founder_key, n_ind=n_ind, n_chr=2, n_loci_per_chr=500
    )

    # Add trait and phenotypes
    sp = SimParam.from_founder_pop(pop, gen_map)
    sp = add_trait_a(
        key=trait_key, founder_pop=pop, sim_param=sp, n_qtl_per_chr=100,
        mean=jnp.array([10.0]), var=jnp.array([4.0])
    )

    pop_with_pheno = set_pheno(
        key=pheno_key, pop=pop, traits=sp.traits, ploidy=sp.ploidy,
        h2=jnp.array([0.3])
    )

    # Test both methods
    truncation_selector = TruncationSelection()
    thompson_selector = ImprovedThompsonSampling()
    n_select = 50
    h2 = 0.3

    key, select_key1, select_key2 = jax.random.split(key, 3)

    # Run both selection methods
    trunc_selected = truncation_selector.select_parents(
        select_key1, pop_with_pheno, sp, n_select
    )

    thompson_selected = thompson_selector.select_parents(
        select_key2, pop_with_pheno, sp, n_select, h2=h2
    )

    print("Both selection methods completed successfully")

    # Both should produce valid results
    assert len(trunc_selected) == n_select
    assert len(thompson_selected) == n_select
    assert len(jnp.unique(trunc_selected)) == n_select
    assert len(jnp.unique(thompson_selected)) == n_select

    # Compare selection quality
    all_bv = pop_with_pheno.bv[:, 0]
    trunc_mean_bv = jnp.mean(all_bv[trunc_selected])
    thompson_mean_bv = jnp.mean(all_bv[thompson_selected])
    overall_mean_bv = jnp.mean(all_bv)
    overall_std_bv = jnp.std(all_bv)

    print(f"Overall mean BV: {overall_mean_bv:.3f}")
    print(f"Truncation selected mean BV: {trunc_mean_bv:.3f}")
    print(f"Thompson selected mean BV: {thompson_mean_bv:.3f}")

    # Both should select above-average individuals
    assert trunc_mean_bv > overall_mean_bv, "Truncation should select above-average individuals"
    # The Thompson tolerance is measured in BV standard deviations so it does
    # not depend on the trait's location (``overall_mean_bv * 0.95`` flips
    # direction for a negative mean). For the trait configured above
    # (mean 10.0, var 4.0) a quarter SD matches the old ``0.95 * mean`` margin.
    assert thompson_mean_bv > overall_mean_bv - 0.25 * overall_std_bv, "Thompson should generally select good individuals"

    print("Performance comparison test passed!")


In [None]:
#| hide
import nbdev
nbdev.nbdev_export()