# popstats

> Baseline Gymnasium environment for truncation selection, built on the ChewC genetic simulation library.

In [None]:
#| default_exp gym

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
"""
Baseline Gymnasium environment for truncation selection using ChewC genetic simulation library.

This environment simulates a breeding program where agents select the top proportion of individuals
based on their phenotypic values, then randomly mate them to produce the next generation.
"""

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import jax
import jax.numpy as jnp
from typing import Dict, Any, Tuple, Optional
from functools import partial

# Assuming ChewC imports (adjust paths as needed)
from chewc.population import Population, quick_haplo
from chewc.sp import SimParam
from chewc.trait import add_trait_a
from chewc.phenotype import set_pheno
from chewc.pipe import run_generation


class TruncationSelectionEnv(gym.Env):
    """
    A Gymnasium environment for truncation selection in genetic breeding programs.

    At every step the agent supplies a selection proportion. That share of the
    population is passed to ``run_generation`` as ``n_parents`` (with
    ``use_pheno_selection=True`` and ``select_top_parents=True``) and
    ``pop_size`` crosses produce the next generation.

    Observation: ``[trait_mean, trait_var, remaining_generations]`` (float32),
    where mean and variance are taken over the first phenotype column.
    Action: ``Box(0.1, 0.9, shape=(1,))`` -- the selection proportion.
    Reward: change in mean phenotype relative to the previous generation
    (relative to the founder population on the first step).

    Args:
        pop_size (int): Population size (constant across generations)
        n_generations (int): Number of generations per episode
        heritability (float): Heritability of the trait
        trait_mean (float): Initial mean of the trait
        trait_var (float): Initial variance of the trait
        n_qtl_per_chr (int): Number of QTL per chromosome
        n_chr (int): Number of chromosomes
        map_length (float): Genetic map length per chromosome (in Morgans)
        fixed_action (Optional[float]): If provided, uses this fixed selection
            proportion and ignores the action passed to ``step``
        seed (int): Random seed for reproducibility
        render_mode (Optional[str]): ``"human"`` prints a text summary from
            ``render``; any other value makes ``render`` a no-op
        n_loci_per_chr (int): Number of evenly spaced loci placed on each
            chromosome's genetic map (defaults to 100, the previously
            hard-coded value)
    """

    metadata = {"render_modes": ["human"]}

    # Bounds on the selection proportion, shared by the action space and step().
    MIN_SELECTION_PROPORTION = 0.1
    MAX_SELECTION_PROPORTION = 0.9

    def __init__(
        self,
        pop_size: int = 200,
        n_generations: int = 20,
        heritability: float = 0.5,
        trait_mean: float = 0.0,
        trait_var: float = 1.0,
        n_qtl_per_chr: int = 50,
        n_chr: int = 10,
        map_length: float = 1.0,
        fixed_action: Optional[float] = None,
        seed: int = 42,
        render_mode: Optional[str] = None,
        n_loci_per_chr: int = 100,
    ):
        super().__init__()

        # Environment parameters
        self.pop_size = pop_size
        self.n_generations = n_generations
        self.heritability = heritability
        self.trait_mean = trait_mean
        self.trait_var = trait_var
        self.n_qtl_per_chr = n_qtl_per_chr
        self.n_chr = n_chr
        self.map_length = map_length
        self.fixed_action = fixed_action
        self.render_mode = render_mode
        self.n_loci_per_chr = n_loci_per_chr

        # JAX random key management
        self.master_key = jax.random.PRNGKey(seed)
        self.current_key = self.master_key

        # The action space is the same in both modes; in fixed-action mode the
        # submitted action is simply ignored by step().
        self.action_space = spaces.Box(
            low=self.MIN_SELECTION_PROPORTION,
            high=self.MAX_SELECTION_PROPORTION,
            shape=(1,),
            dtype=np.float32,
        )

        # Observation space: [trait_mean, trait_var, remaining_generations].
        # Bounds are built as float32 so they match the Box dtype exactly and
        # no float64 -> float32 precision warning is emitted.
        self.observation_space = spaces.Box(
            low=np.array([-np.inf, 0.0, 0.0], dtype=np.float32),
            high=np.array([np.inf, np.inf, float(n_generations)], dtype=np.float32),
            dtype=np.float32,
        )

        # Initialize simulation parameters
        self._init_simulation()

        # Episode tracking
        self.current_generation = 0
        self.current_pop = None
        self.episode_history = []

    def _init_simulation(self):
        """Build the SimParam, founder population, trait and founder phenotypes.

        Consumes three sub-keys split off ``self.current_key`` (founder
        haplotypes, trait creation, phenotyping) and stores the results in
        ``self.sim_param`` and ``self.founder_pop``.
        """
        # Create genetic map (uniform spacing); every chromosome shares it.
        positions = np.linspace(0, self.map_length, self.n_loci_per_chr)
        gen_map = np.tile(positions, (self.n_chr, 1))

        # Create SimParam object
        self.sim_param = SimParam(
            gen_map=jnp.array(gen_map),
            centromere=jnp.zeros(self.n_chr),  # Centromeres at position 0
            ploidy=2,
            recomb_params=(2.6, 0.0, 0.0),  # Standard interference parameters
            sexes="no",
            track_pedigree=False
        )

        # Create founder population
        self.current_key, subkey = jax.random.split(self.current_key)
        founder_pop = quick_haplo(subkey, self.sim_param, self.pop_size, inbred=False)

        # Add additive trait
        self.current_key, subkey = jax.random.split(self.current_key)
        self.sim_param = add_trait_a(
            subkey,
            self.sim_param.replace(founderPop=founder_pop),
            n_qtl_per_chr=self.n_qtl_per_chr,
            mean=jnp.array([self.trait_mean]),
            var=jnp.array([self.trait_var])
        )

        # Set phenotypes for founder population
        self.current_key, subkey = jax.random.split(self.current_key)
        self.founder_pop = set_pheno(
            subkey,
            founder_pop,
            self.sim_param.traits,
            self.sim_param.ploidy,
            h2=jnp.array([self.heritability])
        )

    def reset(self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Reset the environment to start a new episode.

        Args:
            seed: When given, re-seeds both Gymnasium's ``np_random`` and the
                JAX key, and rebuilds the simulation (founders and trait) from
                that key. When ``None`` the existing founder population is
                reused and the JAX key stream simply continues.
            options: Accepted for API compatibility; not used.

        Returns:
            ``(observation, info)`` for the founder population.
        """
        # Required by the Gymnasium API so that self.np_random is seeded.
        super().reset(seed=seed)

        if seed is not None:
            self.master_key = jax.random.PRNGKey(seed)
            self.current_key = self.master_key
            # Re-initialize simulation with new seed
            self._init_simulation()

        # Reset episode state
        self.current_generation = 0
        self.current_pop = self.founder_pop
        self.episode_history = []

        # Compute the phenotype statistics once and share them.
        stats = self._trait_stats(self.current_pop)
        return self._get_observation(stats), self._get_info(stats)

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
        """Advance the breeding program by one generation.

        Args:
            action: Selection proportion; a length-1 array or a scalar.
                Ignored when ``fixed_action`` was set. Values are clipped to
                ``[0.1, 0.9]``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``terminated`` becomes True once ``n_generations`` steps have run;
            ``truncated`` is always False. If ``run_generation`` raises, the
            population is left unchanged, the episode is terminated with zero
            reward and the error is reported in ``info["error"]``.

        Raises:
            RuntimeError: If called before ``reset``.
            ValueError: If the selection proportion is NaN.
        """
        if self.current_pop is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")

        selection_proportion = self._resolve_selection_proportion(action)
        n_parents = self._n_parents_for(selection_proportion)

        # Run one generation of selection and breeding
        self.current_key, subkey = jax.random.split(self.current_key)

        try:
            self.current_pop = run_generation(
                key=subkey,
                pop=self.current_pop,
                h2=jnp.array([self.heritability]),
                n_parents=n_parents,
                n_crosses=self.pop_size,
                use_pheno_selection=True,
                select_top_parents=True,
                ploidy=self.sim_param.ploidy,
                gen_map=self.sim_param.gen_map,
                recomb_param_v=self.sim_param.recomb_params[0],
                traits=self.sim_param.traits
            )
        except Exception as e:
            # Best-effort fallback: end the episode rather than crash the
            # training loop, but surface the failure to the caller via info
            # instead of only printing it.
            print(f"Error in run_generation: {e}")
            obs = self._get_observation()
            info = self._get_info()
            info["error"] = repr(e)
            return obs, 0.0, True, False, info

        self.current_generation += 1

        # Calculate reward (gain in mean phenotype over the previous generation)
        current_mean, current_var = self._trait_stats(self.current_pop)
        if self.episode_history:
            previous_mean = self.episode_history[-1]['trait_mean']
        else:
            # First generation reward is improvement over founder mean
            previous_mean = self._trait_stats(self.founder_pop)[0]
        reward = current_mean - previous_mean

        # Store history
        self.episode_history.append({
            'generation': self.current_generation,
            'trait_mean': current_mean,
            'trait_var': current_var,
            'selection_proportion': selection_proportion,
            'reward': reward
        })

        # Check if episode is done
        terminated = self.current_generation >= self.n_generations
        truncated = False

        stats = (current_mean, current_var)
        obs = self._get_observation(stats)
        info = self._get_info(stats)

        return obs, reward, terminated, truncated, info

    def _resolve_selection_proportion(self, action) -> float:
        """Return the clipped selection proportion for this step as a float."""
        if self.fixed_action is not None:
            raw = float(self.fixed_action)
        else:
            # Accept scalars and 0-d arrays as well as the canonical (1,) array.
            raw = float(np.asarray(action, dtype=np.float64).reshape(-1)[0])

        if np.isnan(raw):
            raise ValueError("Selection proportion must not be NaN.")

        return float(np.clip(raw, self.MIN_SELECTION_PROPORTION, self.MAX_SELECTION_PROPORTION))

    def _n_parents_for(self, selection_proportion: float) -> int:
        """Convert a selection proportion into a parent count (at least 2).

        Rounds to the nearest integer instead of truncating: a float32 action
        such as 0.9 is really 0.89999998, and truncating 200 * 0.89999998
        would yield 179 parents rather than the intended 180.
        """
        return max(2, int(round(self.pop_size * selection_proportion)))

    @staticmethod
    def _trait_stats(pop) -> Tuple[float, float]:
        """Return ``(mean, variance)`` of the first phenotype column of ``pop``."""
        pheno = pop.pheno[:, 0]
        return float(jnp.mean(pheno)), float(jnp.var(pheno))

    def _get_observation(self, stats: Optional[Tuple[float, float]] = None) -> np.ndarray:
        """Get current observation.

        Args:
            stats: Optional precomputed ``(trait_mean, trait_var)`` of the
                current population, to avoid recomputing them.
        """
        if self.current_pop is None:
            # Before the first reset: report the configured trait parameters.
            return np.array([self.trait_mean, self.trait_var, float(self.n_generations)], dtype=np.float32)

        if stats is None:
            stats = self._trait_stats(self.current_pop)
        trait_mean, trait_var = stats
        remaining_generations = float(self.n_generations - self.current_generation)

        return np.array([trait_mean, trait_var, remaining_generations], dtype=np.float32)

    def _get_info(self, stats: Optional[Tuple[float, float]] = None) -> Dict[str, Any]:
        """Get additional information about the current state.

        Args:
            stats: Optional precomputed ``(trait_mean, trait_var)`` of the
                current population, to avoid recomputing them.
        """
        if self.current_pop is None:
            return {"generation": 0, "history": []}

        if stats is None:
            stats = self._trait_stats(self.current_pop)
        trait_mean, trait_var = stats

        bv = self.current_pop.bv
        return {
            "generation": self.current_generation,
            "population_size": self.pop_size,
            "trait_mean": trait_mean,
            "trait_var": trait_var,
            "breeding_values_mean": float(jnp.mean(bv[:, 0])) if bv is not None else None,
            # Copy each record so callers cannot mutate the internal history.
            "history": [dict(record) for record in self.episode_history]
        }

    def render(self):
        """Print a text summary of the current generation in ``"human"`` mode."""
        if self.render_mode != "human" or self.current_pop is None:
            return

        trait_mean, trait_var = self._trait_stats(self.current_pop)
        print(f"Generation {self.current_generation}/{self.n_generations}")
        print(f"Trait Mean: {trait_mean:.4f}, Trait Variance: {trait_var:.4f}")
        if self.episode_history:
            recent_gain = self.episode_history[-1]['reward']
            print(f"Recent Genetic Gain: {recent_gain:.4f}")
        print("-" * 50)

    def close(self):
        """Clean up the environment (nothing to release)."""
        pass


# Demo / smoke test: one fixed-proportion episode, then one with a varying schedule
if __name__ == "__main__":
    # --- Episode 1: selection proportion pinned at 50% ---
    fixed_env = TruncationSelectionEnv(fixed_action=0.5, render_mode="human")

    obs, info = fixed_env.reset(seed=42)
    print("Initial observation:", obs)
    print("Initial info:", info)

    cumulative_reward = 0.0
    for _ in range(20):  # at most 20 generations
        # The submitted action is ignored because fixed_action is set
        obs, reward, terminated, truncated, info = fixed_env.step(np.array([0.5]))
        cumulative_reward += reward

        fixed_env.render()

        if terminated or truncated:
            break

    print(f"\nEpisode finished! Total reward: {cumulative_reward:.4f}")
    print(f"Final trait mean: {info['trait_mean']:.4f}")

    # --- Episode 2: agent-controlled selection proportion ---
    print("\n" + "="*60)
    print("Testing variable action mode")

    variable_env = TruncationSelectionEnv(fixed_action=None)
    obs, info = variable_env.reset(seed=123)

    # Cycle through five proportions four times -> 20 generations
    schedule = 4 * [0.1, 0.3, 0.5, 0.7, 0.9]

    for gen_idx, proportion in enumerate(schedule):
        obs, reward, terminated, truncated, info = variable_env.step(np.array([proportion]))

        # Report on every fifth generation only
        if not gen_idx % 5:
            print(f"Gen {info['generation']}: Selection {proportion:.1f}, "
                  f"Mean {info['trait_mean']:.3f}, Reward {reward:.4f}")

        if terminated or truncated:
            break

    fixed_env.close()
    variable_env.close()

In [None]:
#| hide
# Run nbdev's export step for this notebook.
# NOTE(review): no cell visible here carries an `#| export` directive, so the
# `gym` module named by `#| default_exp gym` may come out empty -- confirm.
import nbdev; nbdev.nbdev_export()