In [1]:
#x01_populationStatistics

In [2]:
import matplotlib.pyplot as plt
import uuid
import pdb
import torch
from matplotlib.animation import FuncAnimation
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import gymnasium as gym
import numpy as np
device='cpu'

### BREEDING SIMULATOR
class Genome:
    """Layout of a diploid genome: ``n_chr`` chromosomes with ``n_loci`` loci each.

    Attributes:
        ploidy: always 2.
        n_chr: number of chromosomes.
        n_loci: loci per chromosome.
        shape: ``(ploidy, n_chr, n_loci)`` — the per-individual haplotype shape.
    """

    def __init__(self, n_chr, n_loci):
        self.ploidy = 2  # hard-wired diploid
        self.n_chr, self.n_loci = n_chr, n_loci
        self.shape = (self.ploidy, n_chr, n_loci)

class Population:
    """A set of individuals described by their phased haplotypes.

    Args:
        genome: the Genome the haplotypes belong to.
        haplotypes: tensor whose dim 0 indexes individuals and dim 1 the
            haplotype copies; presumably shaped (n, ploidy, n_chr, n_loci)
            — TODO confirm for all callers.
        device: torch device the tensors are stored on.
    """

    def __init__(self, genome, haplotypes, device=device):
        self.genome = genome
        self.device = device
        # Filled in later by ``phenotype`` / ``bv``; declared here so the
        # attributes exist from construction instead of appearing ad hoc.
        self.phenotypes = None
        self.bvs = None  # legacy name; not read by the visible code
        self.breeding_values = None
        self.genetic_var = None
        self.haplotypes = haplotypes.to(device)
        # Allele dosage per locus = sum over the haplotype axis (dim 1),
        # computed from the device-resident copy.
        self.dosages = self.haplotypes.sum(dim=1).float()
        self.size = self.haplotypes.shape[0]

class Trait:
    """Additive trait with random per-locus effects scaled to a target variance.

    Raw effects are drawn standard-normal for every (chromosome, locus),
    centred on zero, then rescaled so that the founder population's scores
    (dosages x effects) have variance ``target_variance``.

    Attributes:
        effects: (n_chr, n_loci) tensor of scaled allele effects.
        scaling_factors: scalar factor applied to the raw effects.
        intercept: offset such that ``founder_scores - intercept`` has mean
            ``target_mean`` (computed with the *scaled* effects).

    Raises:
        ValueError: if the founder scores have zero or undefined variance
            (e.g. monomorphic founders or a single founder); no rescaling can
            then reach ``target_variance``.
    """

    def __init__(self, genome, founder_population, target_mean, target_variance, device=device):
        self.target_mean = target_mean
        self.target_variance = target_variance
        self.device = device
        random_effects = torch.randn(genome.n_chr, genome.n_loci, device=self.device)
        random_effects -= random_effects.mean()
        founder_scores = torch.einsum('kl,hkl->h', random_effects, founder_population.dosages).to(device)
        founder_mean, founder_var = founder_scores.mean(), founder_scores.var()
        # `not > 0` also rejects NaN (variance of a single score).
        if not founder_var > 0:
            raise ValueError(
                "founder population has no genetic variance; cannot scale trait effects"
            )
        scaling_factors = torch.sqrt(self.target_variance / founder_var)
        self.scaling_factors = scaling_factors
        random_effects *= scaling_factors
        self.effects = random_effects
        # Scores are linear in the effects, so rescaling moves their mean by
        # the same factor; the intercept must use the scaled mean.
        self.intercept = founder_mean * scaling_factors - target_mean


def calculate_breeding_value(population, trait, device=device):
    """Return the additive breeding value of every individual.

    Each value is the sum over chromosomes and loci of the individual's allele
    dosage times the trait effect at that locus.

    Args:
        population: object exposing ``dosages`` indexed (individual, chromosome, locus).
        trait: object exposing ``effects`` indexed (chromosome, locus).
        device: device the result is moved to.

    Returns:
        1-D tensor with one breeding value per individual.
    """
    dosage = population.dosages
    effect = trait.effects
    per_individual = torch.einsum('icl,cl->i', dosage, effect)
    return per_individual.to(device)

def truncation_selection(population, trait, top_percent):
    """Return the indices of the highest-phenotype individuals.

    NOTE(review): despite its name, ``top_percent`` is handed straight to
    ``torch.topk`` as the number of individuals k, not a percentage.
    ``trait`` is accepted but unused. This definition is shadowed by an
    identical ``truncation_selection`` defined later in the same cell.
    """
    best = torch.topk(population.phenotypes, top_percent)
    return best.indices.to(device)

# meiosis
def recombine(parent_haplo_tensor, recombination_rate=0.1):
    """Form one gamete per individual by a per-locus choice between its two haplotypes.

    Every locus independently takes haplotype 1 with probability
    ``recombination_rate`` and haplotype 0 otherwise. The result is a float
    tensor indexed (individual, chromosome, locus).

    NOTE: shadowed by a later ``recombine`` definition in this cell.
    """
    n_ind, _ploidy, n_chr, n_loci = parent_haplo_tensor.shape
    first = parent_haplo_tensor[:, 0, :, :]
    second = parent_haplo_tensor[:, 1, :, :]
    rate_grid = torch.full((n_ind, n_chr, n_loci), recombination_rate, device=device)
    take_second = torch.bernoulli(rate_grid)
    return first * (1 - take_second) + second * take_second



def create_random_pop(G, pop_size):
    """Draw ``pop_size`` individuals with uniformly random 0/1 alleles.

    The result has shape ``(pop_size, *G.shape)`` and integer dtype.
    NOTE: shadowed by an identical definition later in this cell.
    """
    full_shape = (pop_size,) + tuple(G.shape)
    return torch.randint(0, 2, full_shape, device=device)

def truncation_selection(population, trait, top_percent):
    """Return the indices of the highest-phenotype individuals.

    NOTE(review): despite its name, ``top_percent`` is used as the count k
    passed to ``torch.topk``, not a percentage; ``trait`` is unused.
    """
    best = torch.topk(population.phenotypes, top_percent)
    return best.indices.to(device)

# meiosis
def recombine(parent_haplo_tensor, recombination_rate=0.1):
    """Form one gamete per individual by a per-locus choice between its two haplotypes.

    Every locus independently takes haplotype 1 with probability
    ``recombination_rate`` and haplotype 0 otherwise.

    NOTE: shadowed by a later ``recombine`` definition in this cell.
    """
    n_ind, _ploidy, n_chr, n_loci = parent_haplo_tensor.shape
    hap0 = parent_haplo_tensor[:, 0, :, :]
    hap1 = parent_haplo_tensor[:, 1, :, :]
    pick_hap1 = torch.bernoulli(
        torch.full((n_ind, n_chr, n_loci), recombination_rate, device=device)
    )
    return hap0 * (1 - pick_hap1) + hap1 * pick_hap1

def phenotype(population, trait, h2):
    """Simulate phenotypes as breeding value plus Gaussian environmental noise.

    The noise variance is Ve = (1 - h2) / h2 * Vg, so that Vg / (Vg + Ve)
    equals ``h2`` in expectation.

    Side effects: sets ``population.breeding_values``,
    ``population.genetic_var`` and ``population.phenotypes``.

    Args:
        population: Population whose dosages are scored.
        trait: Trait supplying the per-locus effects.
        h2: heritability, expected in (0, 1]; ``h2 == 0`` divides by zero
            whenever the genetic variance is positive.

    Returns:
        The largest phenotype in the population (0-d tensor).
    """
    breeding_values = calculate_breeding_value(population, trait)
    genetic_var = breeding_values.var()
    population.breeding_values = breeding_values
    population.genetic_var = genetic_var

    # `not > 0` covers both zero variance and NaN (a single individual):
    # there is no genetic variance to scale the noise by.
    if not genetic_var > 0:
        environmental_variance = 0
    else:
        environmental_variance = (1 - h2) / h2 * genetic_var

    # Skip the random draw entirely when there is no noise (e.g. h2 == 1),
    # so the RNG stream is not consumed needlessly.
    if environmental_variance == 0:
        environmental_noise = torch.zeros(breeding_values.shape, device=breeding_values.device)
    else:
        environmental_noise = (
            torch.randn(breeding_values.shape, device=breeding_values.device)
            * torch.sqrt(environmental_variance).detach()
        )
    population.phenotypes = breeding_values + environmental_noise
    return population.phenotypes.max()

def create_random_pop(G, pop_size):
    """Draw ``pop_size`` individuals with uniformly random 0/1 alleles.

    Returns an integer tensor of shape ``(pop_size, *G.shape)``.
    """
    full_shape = (pop_size,) + tuple(G.shape)
    return torch.randint(0, 2, full_shape, device=device)

def update_pop(population, haplotype_pop_tensor):
    """Replace a population's haplotypes in place and refresh derived fields.

    Updates ``haplotypes``, ``dosages`` (sum over dim 1, as float) and
    ``size`` (dim 0 length), then returns the same population object.
    Phenotypes and breeding values are not recomputed.
    """
    population.haplotypes = haplotype_pop_tensor
    population.dosages = haplotype_pop_tensor.sum(dim=1).float()
    # Keep size consistent with the new haplotypes; previously it went stale
    # whenever the number of individuals changed.
    population.size = haplotype_pop_tensor.shape[0]
    return population

# meiosis
def recombine(parent_haplo_tensor, recombination_rate=0.1):
    """Produce one recombinant gamete per individual (meiosis).

    Each chromosome starts on a randomly chosen parental haplotype (50/50)
    and switches to the other haplotype at every crossover point. A crossover
    occurs independently between each pair of adjacent loci with probability
    ``recombination_rate``.

    The previous implementation treated the Bernoulli draws as a per-locus
    *source selector*. That made gametes ~(1 - rate) haplotype 0 and is not
    a crossover model.

    Args:
        parent_haplo_tensor: 4-D tensor (individual, haplotype, chromosome,
            locus); haplotype index 0 and 1 are the two parental copies.
        recombination_rate: crossover probability per adjacent-locus interval.

    Returns:
        Tensor (individual, chromosome, locus) with the parents' dtype,
        on the parents' device.
    """
    num_individuals, _ploidy, num_chromosomes, num_loci = parent_haplo_tensor.shape
    source_device = parent_haplo_tensor.device
    maternal = parent_haplo_tensor[:, 0, :, :]
    paternal = parent_haplo_tensor[:, 1, :, :]

    # Crossover events; position l marks a switch between loci l-1 and l,
    # so nothing can happen before the first locus.
    crossovers = torch.rand((num_individuals, num_chromosomes, num_loci), device=source_device) < recombination_rate
    crossovers[..., 0] = False

    # Random starting strand per chromosome, then flip at each crossover.
    start_strand = torch.randint(0, 2, (num_individuals, num_chromosomes, 1), device=source_device)
    strand = (start_strand + crossovers.long().cumsum(dim=-1)) % 2

    progeny = torch.where(strand.bool(), paternal, maternal)
    return progeny

def breed(mother_tensor, father_tensor, recombination_rate=0.1):
    """Cross mothers with fathers row by row.

    One gamete is drawn from each mother (egg) and each father (pollen) via
    ``recombine``. They are stacked on dim 1 so that index 0 is the egg and
    index 1 the pollen of each progeny.
    """
    gametes = (
        recombine(mother_tensor, recombination_rate),
        recombine(father_tensor, recombination_rate),
    )
    return torch.stack(gametes, dim=1)

def create_pop(G, haplotypes):
    """Wrap a haplotype tensor in a ``Population`` for genome ``G`` (default device)."""
    population = Population(G, haplotypes=haplotypes)
    return population

def bv(P, T):
    """Compute breeding values for population ``P`` under trait ``T``.

    Stores the result on ``P.breeding_values`` and also returns it.
    """
    # calculate_breeding_value reads ``.dosages`` / ``.effects`` itself, so it
    # must receive the population and trait objects, not their tensors
    # (passing tensors raised AttributeError).
    P.breeding_values = calculate_breeding_value(P, T)
    return P.breeding_values

def create_progeny(mother_gametes, father_gametes, reps=1, device=device):
    """Pair maternal and paternal gametes at random to form progeny.

    In each of ``reps`` rounds both gamete sets are permuted independently,
    and the i-th egg is paired with the i-th pollen. Each round permutes the
    order left by the previous round.

    Args:
        mother_gametes: tensor with one gamete per row (dim 0).
        father_gametes: tensor with one gamete per row; must match the shape
            of ``mother_gametes`` for ``torch.stack``.
        reps: number of pairing rounds.
        device: device for the random permutations.

    Returns:
        Tensor concatenated along dim 0, shaped (reps * n, 2, *gamete_shape),
        with the egg at index 0 and the pollen at index 1 of dim 1.
    """
    eggs, pollen = mother_gametes, father_gametes
    batches = []
    for _round in range(reps):
        egg_order = torch.randperm(eggs.shape[0], device=device)
        pollen_order = torch.randperm(pollen.shape[0], device=device)
        eggs = eggs[egg_order]
        pollen = pollen[pollen_order]
        batches.append(torch.stack((eggs, pollen), dim=1))
    return torch.vstack(batches)


def random_crosses(parent_population, total_crosses, device=device):
    """Generate ``total_crosses`` progeny from uniformly random parent pairs.

    Mother and father indices are drawn independently, so an individual may
    be paired with itself (selfing).

    Args:
        parent_population: 4-D haplotype tensor (individual, ploidy,
            chromosome, locus).
        total_crosses: number of progeny to produce.
        device: device for the random parent indices.

    Returns:
        Tensor (total_crosses, 2, chromosome, locus); index 0 of dim 1 is the
        mother's gamete and index 1 the father's.
    """
    n_parents = parent_population.shape[0]
    # Unpacked only as a shape check: parents must be 4-D.
    _ploidy, _n_chr, _n_loci = parent_population.shape[1:]

    pairs = torch.randint(0, n_parents, (total_crosses, 2), device=device)
    mother_idx, father_idx = pairs[:, 0], pairs[:, 1]

    eggs = recombine(parent_population[mother_idx])
    pollen = recombine(parent_population[father_idx])
    return torch.stack((eggs, pollen), dim=1)

class SimParams:
    """Bundle of the objects and settings that parameterise a breeding simulation.

    Attributes:
        G: the genome layout.
        T: the trait being selected on.
        h2: heritability passed to ``phenotype``.
        reps: replicate count (stored; not read by the visible code).
        pop_size: number of individuals produced per generation.
        max_generations: episode length in generations.
        founder_pop: population used at every environment reset.
    """

    def __init__(self, G, T, h2, reps, pop_size, max_generations, founder_pop):
        settings = (
            ("G", G),
            ("T", T),
            ("h2", h2),
            ("reps", reps),
            ("pop_size", pop_size),
            ("max_generations", max_generations),
            ("founder_pop", founder_pop),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)

#RL ENVIRONMENT
class BreedingEnvironment(gym.Env):
    """Discrete-action breeding environment.

    Each action picks a (parents kept, crosses per parent) preset from
    ``_action_to_direction``. The top parents by phenotype are recombined and
    randomly paired to form the next generation. Every preset yields
    200 progeny. The reward is the largest phenotype in the new generation.

    NOTE(review): the observation space assumes 200 individuals; the founder
    population returned by ``reset`` must also have 200 rows — confirm
    ``SP.pop_size``.
    """

    def __init__(self, SP):
        super(BreedingEnvironment, self).__init__()
        self.SP = SP
        self.current_generation = 0
        self.generation = 0
        self.max_generations = SP.max_generations

        # Five selection-intensity presets (see _action_to_direction).
        self.action_space = gym.spaces.Discrete(5)  # 0, 1, 2, 3, or 4
        self.action_history = []

        # Per-individual haplotype shape comes from the genome instead of a
        # hard-coded (2, 1, 200).
        self.observation_space = gym.spaces.Dict({
            "population": gym.spaces.Box(low=0, high=1, shape=(200, *SP.G.shape), dtype=np.int32),
            "generation": gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
        })

        # action -> (number of top parents, crosses per parent); 200 progeny each.
        self._action_to_direction = {0: (200, 1),
                                     1: (100, 2),
                                     2: (50, 4),
                                     3: (25, 8),
                                     4: (5, 40)}

    def _get_obs(self):
        # Gymnasium spaces expect numpy arrays matching the declared dtypes.
        return {
            "population": self.population.haplotypes.cpu().numpy().astype(np.int32),
            "generation": np.array([self.generation / self.SP.max_generations], dtype=np.float32)
        }

    def _get_info(self):
        # Read the max phenotype from the population itself, so this does not
        # depend on which `phenotype` definition (returning a value or None)
        # is active.
        return {
            "phenotype": self.population.phenotypes.max().cpu().item(),
            "genetic_variance": self.population.breeding_values.var().cpu().item()
        }

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.population = self.SP.founder_pop
        self.phenotype = phenotype(self.population, self.SP.T, self.SP.h2)
        # Both counters must restart, otherwise every later episode is done
        # from its first step.
        self.generation = 0
        self.current_generation = 0
        self.action_history = []

        observation = self._get_obs()
        info = self._get_info()
        return observation, info

    def step(self, action):
        # Map the action (element of {0,1,2,3,4}) to its selection preset.
        total_parents, total_crosses = self._action_to_direction[int(action)]

        top_k = torch.topk(self.population.phenotypes, total_parents).indices
        selected = self.population.haplotypes[top_k]

        # Breeding
        m = recombine(selected)  # Mother gametes
        f = recombine(selected)  # Father gametes
        progeny = create_progeny(m, f, reps=total_crosses)

        # Create new population from progeny
        self.population = create_pop(self.SP.G, progeny)
        self.phenotype = phenotype(self.population, self.SP.T, self.SP.h2)

        self.generation += 1
        self.current_generation += 1
        self.action_history.append(action)

        observation = self._get_obs()
        info = self._get_info()
        reward = info["phenotype"]  # plain float, not a tensor
        terminated = self.generation >= self.SP.max_generations
        truncated = False

        if terminated:
            info['final_generation'] = {
                'phenotype': info["phenotype"],
                'genetic_variance': info["genetic_variance"]
            }
        return observation, reward, terminated, truncated, info



In [3]:
def phenotype(population, trait, h2):
    """Simulate phenotypes as breeding value plus Gaussian environmental noise.

    The noise variance is Ve = (1 - h2) / h2 * Vg, so that Vg / (Vg + Ve)
    equals ``h2`` in expectation.

    Side effects: sets ``population.breeding_values``,
    ``population.genetic_var`` and ``population.phenotypes``.

    Args:
        population: Population whose dosages are scored.
        trait: Trait supplying the per-locus effects.
        h2: heritability, expected in (0, 1]; ``h2 == 0`` divides by zero
            whenever the genetic variance is positive.

    Returns:
        The largest phenotype in the population (0-d tensor). Callers such as
        ``BreedingEnvironment`` use this; this redefinition previously
        returned None.
    """
    breeding_values = calculate_breeding_value(population, trait)
    genetic_var = breeding_values.var()
    population.breeding_values = breeding_values
    population.genetic_var = genetic_var

    # `not > 0` covers both zero variance and NaN (a single individual):
    # there is no genetic variance to scale the noise by.
    if not genetic_var > 0:
        environmental_variance = 0
    else:
        environmental_variance = (1 - h2) / h2 * genetic_var

    # Skip the random draw entirely when there is no noise (e.g. h2 == 1).
    if environmental_variance == 0:
        environmental_noise = torch.zeros(breeding_values.shape, device=breeding_values.device)
    else:
        environmental_noise = (
            torch.randn(breeding_values.shape, device=breeding_values.device)
            * torch.sqrt(environmental_variance).detach()
        )
    population.phenotypes = breeding_values + environmental_noise
    return population.phenotypes.max()


In [4]:
########

In [10]:
import gymnasium as gym
import numpy as np
import torch
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.env_checker import check_env

class SelectionIntensityEnvironment(gym.Env):
    """Continuous-action breeding environment.

    The action is the fraction of the population (by phenotype rank) kept as
    parents. The parents are randomly mated to produce the next generation
    of ``SP.pop_size`` individuals, which replaces the current population.
    The reward is the largest breeding value in the new generation.

    An episode ends after ``SP.max_generations`` steps, or once the genetic
    variance drops below 0.005.
    """

    def __init__(self, SP):
        super(SelectionIntensityEnvironment, self).__init__()
        self.SP = SP
        self.current_generation = 0
        self.max_generations = SP.max_generations
        # Action: fraction of the population selected as parents.
        self.action_space = spaces.Box(
            low=np.array([0.1]),
            high=np.array([1.0]),
            dtype=np.float32
        )
        # Observation: every haplotype plus normalised generation counter.
        self.observation_space = gym.spaces.Dict({
            "population": gym.spaces.Box(low=0, high=1, shape=(self.SP.pop_size, 2, self.SP.G.n_chr, self.SP.G.n_loci), dtype=np.int32),
            "generation": gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
        })
        # Logging: action_values grows every step; the rest once per episode.
        self.action_values = []
        self.genetic_variance = []
        self.max_breeding_values = []
        self.final_generations = []
        self.episode_count = 0

    def _get_obs(self):
        population = self.population.haplotypes.cpu().numpy().astype(np.int32)
        generation = np.array([self.current_generation / self.SP.max_generations], dtype=np.float32)
        return {
            "population": population,
            "generation": generation
        }

    def _get_info(self):
        return {
            "phenotype": self.population.phenotypes.max().cpu().item(),
            "genetic_variance": self.population.breeding_values.var().cpu().item()
        }

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        # Restart the generation counter, otherwise every later episode
        # terminates on its first step.
        self.current_generation = 0
        self.population = self.SP.founder_pop
        self.phenotype = phenotype(self.population, self.SP.T, self.SP.h2)

        observation = self._get_obs()
        info = self._get_info()
        return observation, info

    def step(self, action):
        action_scalar = action.item() if isinstance(action, np.ndarray) else action

        # Keep between 1 and all individuals as parents: a fraction below
        # 1/size would otherwise request zero parents from topk.
        n_parents = int(action_scalar * self.population.size)
        n_parents = min(max(n_parents, 1), self.population.size)
        selected = torch.topk(self.population.phenotypes, n_parents).indices
        progeny = random_crosses(self.population.haplotypes[selected], self.SP.pop_size)

        # Advance to the new generation; previously the progeny were
        # discarded and the population never changed.
        self.population = create_pop(self.SP.G, progeny)
        self.phenotype = phenotype(self.population, self.SP.T, self.SP.h2)
        self.current_generation += 1

        observation = self._get_obs()
        info = self._get_info()
        genetic_var = float(self.population.breeding_values.var())
        terminated = self.current_generation >= self.SP.max_generations or genetic_var < .005
        reward = float(self.population.breeding_values.max())

        self.action_values.append((self.current_generation, action_scalar))
        # Log genetic variance and max breeding value at the end of each episode.
        if terminated:
            self.genetic_variance.append(genetic_var)
            self.max_breeding_values.append(reward)
            self.final_generations.append(self.current_generation)
            self.episode_count += 1

        return observation, reward, bool(terminated), False, info

# Example usage: build a founder population and train PPO on it.

# Simulation configuration.
n_chr, n_loci = 1, 222
founder_pop_size = 333
h2 = 1
reps = 1
max_generations = 20

G = Genome(n_chr, n_loci)

# Expand 5 random individuals into the founder population by random mating.
seed_pop = create_pop(G, create_random_pop(G, 5))
founder_haplotypes = random_crosses(seed_pop.haplotypes, founder_pop_size)
founder_pop = create_pop(G, founder_haplotypes)

T = Trait(G, founder_pop, target_mean=0.0, target_variance=1.0)
SP = SimParams(G, T, h2, reps, founder_pop_size, max_generations, founder_pop)

env = SelectionIntensityEnvironment(SP)
check_env(env)

# NOTE: PPO collects n_steps (2048 by default) per rollout, so learn() runs
# at least that many timesteps even though 1000 is requested.
vec_env = DummyVecEnv([lambda: env])
model = PPO("MultiInputPolicy", vec_env, verbose=1)
total_timesteps = 1000
model.learn(total_timesteps=total_timesteps)

Using cpu device
-----------------------------
| time/              |      |
|    fps             | 141  |
|    iterations      | 1    |
|    time_elapsed    | 14   |
|    total_timesteps | 2048 |
-----------------------------


<stable_baselines3.ppo.ppo.PPO at 0x7f3db6872c50>