In [1]:
#x01_populationStatistics

In [2]:
import matplotlib.pyplot as plt
import uuid
import pdb
import torch
from matplotlib.animation import FuncAnimation
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import gymnasium as gym
import numpy as np
device='cpu'

### BREEDING SIMULATOR
class Genome:
    """Describe the layout of a diploid genome.

    Attributes:
        ploidy: Number of chromosome copies per individual (always 2).
        n_chr: Number of chromosomes.
        n_loci: Number of loci per chromosome.
        shape: ``(ploidy, n_chr, n_loci)`` tuple describing one individual.
    """

    def __init__(self, n_chr, n_loci):
        ploidy = 2
        self.ploidy, self.n_chr, self.n_loci = ploidy, n_chr, n_loci
        self.shape = (ploidy, n_chr, n_loci)
        
class Population:
    """Hold the haplotypes of a set of individuals plus derived quantities.

    Args:
        genome: Genome layout object; stored as-is.
        haplotypes: Tensor whose first axis indexes individuals and whose
            second axis is summed to obtain dosages (presumably laid out as
            ``(n_individuals, ploidy, n_chr, n_loci)`` -- TODO confirm).
        device: Torch device the haplotypes and dosages are moved to.

    Attributes:
        phenotypes: ``None`` until a phenotyping routine fills it in.
        breeding_values: ``None`` until breeding values are computed.
        bvs: Legacy placeholder, always initialised to ``None``.
        haplotypes: ``haplotypes`` moved to ``device``.
        dosages: Float tensor, ``haplotypes`` summed over axis 1.
        size: Number of individuals (length of axis 0).
    """

    def __init__(self, genome, haplotypes, device=device):
        self.genome = genome
        self.device = device
        self.phenotypes = None
        # `bvs` is kept for backward compatibility; the rest of this file
        # reads and writes `breeding_values`, so initialise that one too to
        # avoid an AttributeError when it is read before phenotyping.
        self.bvs = None
        self.breeding_values = None
        self.haplotypes = haplotypes.to(device)
        self.dosages = self.haplotypes.sum(dim=1).float()
        self.size = haplotypes.shape[0]
                
class Trait:
    """Additive trait whose per-locus effects are scaled to a founder population.

    Effects are drawn from a standard normal, centred on zero, and rescaled so
    that the variance of the founders' scores equals ``target_variance``.

    Args:
        genome: Object exposing ``n_chr`` and ``n_loci``.
        founder_population: Object exposing a ``dosages`` tensor that is
            contracted against the ``(n_chr, n_loci)`` effects.
        target_mean: Desired trait mean; stored and used for ``intercept``.
        target_variance: Desired variance of the founders' scores.
        device: Torch device used for the effect tensor.
    """

    def __init__(self, genome, founder_population, target_mean, target_variance, device=device):
        self.target_mean = target_mean
        self.target_variance = target_variance
        self.device = device

        # One effect per (chromosome, locus), centred so they sum to ~0.
        raw_effects = torch.randn(genome.n_chr, genome.n_loci, device=self.device)
        centred_effects = raw_effects - raw_effects.mean()

        # Score every founder with the unscaled effects.
        scores = torch.einsum('kl,hkl->h', centred_effects, founder_population.dosages).to(device)
        founder_mean = scores.mean()
        founder_var = scores.var()

        # Rescale the effects so the founder score variance hits the target.
        self.scaling_factors = torch.sqrt(self.target_variance / founder_var)
        self.effects = centred_effects * self.scaling_factors
        # NOTE(review): founder_mean is measured before the effects are
        # rescaled -- confirm this is the intended intercept.
        self.intercept = founder_mean - target_mean

        
def calculate_breeding_value(population, trait, device=device):
    """Return one breeding value per individual.

    Contracts ``population.dosages`` (indexed ``h, j, k``) with
    ``trait.effects`` (indexed ``j, k``) and moves the result to ``device``.
    """
    dosages, effects = population.dosages, trait.effects
    breeding_values = torch.einsum('hjk,jk->h', dosages, effects)
    return breeding_values.to(device)

def truncation_selection(population, trait, top_percent):
    """Return the indices of the individuals with the largest phenotypes.

    Despite its name, ``top_percent`` is handed to ``torch.topk`` as the count
    ``k`` of individuals to keep. ``trait`` is accepted but not used.
    """
    best = torch.topk(population.phenotypes, top_percent)
    return best.indices.to(device)

# meiosis
def recombine(parent_haplo_tensor, recombination_rate=0.1):
    """Produce one gamete per parent by mixing its two haplotype copies.

    Each locus is drawn independently: with probability
    ``recombination_rate`` the allele comes from copy 1, otherwise from copy 0.

    Args:
        parent_haplo_tensor: 4-D tensor unpacked as
            ``(individuals, ploidy, chromosomes, loci)``.
        recombination_rate: Per-locus probability of taking copy 1.

    Returns:
        Tensor of shape ``(individuals, chromosomes, loci)``.
    """
    n_ind, _, n_chr, n_loci = parent_haplo_tensor.shape
    first_copy = parent_haplo_tensor[:, 0]
    second_copy = parent_haplo_tensor[:, 1]
    # 1 -> take the allele from the second copy, 0 -> keep the first copy.
    swap_prob = torch.full((n_ind, n_chr, n_loci), recombination_rate, device=device)
    swap_mask = torch.bernoulli(swap_prob)
    return first_copy * (1 - swap_mask) + second_copy * swap_mask

def phenotype(population, trait, h2):
    """Simulate phenotypes as breeding value plus environmental noise.

    The environmental variance is ``(1 - h2) / h2`` times the variance of the
    breeding values, or zero when the breeding values have zero variance.
    Stores ``breeding_values`` and ``phenotypes`` on ``population``.

    Returns:
        The largest simulated phenotype (0-d tensor).
    """
    genetic_values = calculate_breeding_value(population, trait)
    population.breeding_values = genetic_values

    genetic_var = genetic_values.var()
    if genetic_var == 0:
        env_var = 0
    else:
        env_var = (1 - h2) / h2 * genetic_var

    # Skip the sqrt / random draw entirely when there is no noise to add.
    if env_var == 0:
        noise = torch.zeros(genetic_values.shape, device=device)
    else:
        noise_sd = torch.sqrt(env_var).detach()
        noise = torch.randn(genetic_values.shape, device=device) * noise_sd

    population.phenotypes = genetic_values + noise
    return population.phenotypes.max()

def create_random_pop(G, pop_size):
    """Draw random 0/1 haplotypes for ``pop_size`` individuals of layout ``G.shape``."""
    haplotype_shape = (pop_size,) + tuple(G.shape)
    return torch.randint(0, 2, haplotype_shape, device=device)

def truncation_selection(population, trait, top_percent):
    """Pick the highest-phenotype individuals and return their indices.

    ``top_percent`` is used as the number of individuals to keep (the ``k``
    of ``torch.topk``), not as a percentage. ``trait`` is unused.
    """
    _, winner_indices = torch.topk(population.phenotypes, top_percent)
    return winner_indices.to(device)

# meiosis
def recombine(parent_haplo_tensor, recombination_rate=0.1):
    """Build one gamete per parent from its two haplotype copies.

    For every locus an independent Bernoulli draw with probability
    ``recombination_rate`` decides whether the allele is taken from copy 1
    (draw == 1) or copy 0 (draw == 0).

    Args:
        parent_haplo_tensor: 4-D tensor unpacked as
            ``(individuals, ploidy, chromosomes, loci)``.
        recombination_rate: Per-locus probability of taking copy 1.

    Returns:
        Tensor of shape ``(individuals, chromosomes, loci)``.
    """
    n_ind, _, n_chr, n_loci = parent_haplo_tensor.shape
    copy_a, copy_b = parent_haplo_tensor[:, 0], parent_haplo_tensor[:, 1]
    from_b = torch.bernoulli(
        torch.full((n_ind, n_chr, n_loci), recombination_rate, device=device)
    )
    from_a = 1 - from_b
    return copy_a * from_a + copy_b * from_b

def phenotype(population, trait, h2):
    """Assign phenotypes to ``population`` for ``trait`` at heritability ``h2``.

    Phenotype = breeding value + Gaussian noise whose variance is
    ``(1 - h2) / h2 * var(breeding values)``; no noise is added when the
    breeding values have zero variance. ``population.breeding_values`` and
    ``population.phenotypes`` are overwritten.

    Returns:
        The maximum phenotype in the population (0-d tensor).
    """
    bvs = calculate_breeding_value(population, trait)
    population.breeding_values = bvs

    bv_variance = bvs.var()
    # Only evaluate the h2 ratio when there is genetic variance to scale.
    noise_variance = 0 if bv_variance == 0 else (1 - h2) / h2 * bv_variance

    if noise_variance == 0:
        residuals = torch.zeros(bvs.shape, device=device)
    else:
        residuals = torch.randn(bvs.shape, device=device) * torch.sqrt(noise_variance).detach()

    population.phenotypes = bvs + residuals
    return population.phenotypes.max()

def create_random_pop(G, pop_size):
    """Return a random 0/1 integer tensor of shape ``(pop_size, *G.shape)``."""
    return torch.randint(low=0, high=2, size=(pop_size, *G.shape), device=device)

def update_pop(population, haplotype_pop_tensor):
    """Replace a population's haplotypes in place and refresh derived fields.

    Args:
        population: Object whose ``haplotypes``, ``dosages`` and ``size``
            attributes are overwritten.
        haplotype_pop_tensor: New haplotype tensor; axis 0 indexes
            individuals and axis 1 is summed to obtain dosages.

    Returns:
        The same ``population`` object, for chaining.
    """
    population.haplotypes = haplotype_pop_tensor
    population.dosages = haplotype_pop_tensor.sum(dim=1).float()
    # Keep `size` consistent with the new haplotypes; it was previously left
    # at the old individual count.
    population.size = haplotype_pop_tensor.shape[0]
    # NOTE(review): any phenotypes / breeding values already stored on the
    # population still describe the old haplotypes and must be recomputed.
    return population

# meiosis
def recombine(parent_haplo_tensor, recombination_rate=0.1):
    """Create one gamete per parent by sampling alleles from its two copies.

    Every locus gets an independent Bernoulli draw with probability
    ``recombination_rate``; a 1 takes the allele from copy 1 (paternal),
    a 0 keeps the allele from copy 0 (maternal).

    Args:
        parent_haplo_tensor: 4-D tensor unpacked as
            ``(individuals, ploidy, chromosomes, loci)``.
        recombination_rate: Per-locus probability of taking the paternal allele.

    Returns:
        Tensor of shape ``(individuals, chromosomes, loci)``.
    """
    n_ind, _, n_chr, n_loci = parent_haplo_tensor.shape
    maternal = parent_haplo_tensor[:, 0]
    paternal = parent_haplo_tensor[:, 1]
    take_paternal = torch.bernoulli(
        torch.full((n_ind, n_chr, n_loci), recombination_rate, device=device)
    )
    # Boolean complement of the draw marks the loci inherited maternally.
    take_maternal = torch.logical_not(take_paternal)
    return maternal * take_maternal + paternal * take_paternal

def breed(mother_tensor, father_tensor, recombination_rate=0.1):
    """Cross mothers with fathers pairwise and return the progeny haplotypes.

    One gamete is produced per mother and per father via ``recombine``; the
    two gamete tensors are stacked along a new axis 1 (egg first, pollen second).
    """
    gametes = (
        recombine(mother_tensor, recombination_rate),
        recombine(father_tensor, recombination_rate),
    )
    return torch.stack(gametes, dim=1)

def create_pop(G, haplotypes):
    """Wrap ``haplotypes`` in a ``Population`` that uses genome ``G``."""
    new_population = Population(genome=G, haplotypes=haplotypes)
    return new_population

def bv(P,T):
    """Compute breeding values of population ``P`` for trait ``T``.

    The result is stored on ``P.breeding_values``; nothing is returned.
    """
    # calculate_breeding_value reads `.dosages` / `.effects` itself, so it must
    # receive the population and trait objects rather than their tensors.
    P.breeding_values = calculate_breeding_value(P, T)
    
def create_progeny(mother_gametes, father_gametes, reps=1, device=device):
    """Pair maternal and paternal gametes at random, ``reps`` times over.

    In every repetition both gamete tensors are re-shuffled along axis 0
    (each shuffle is applied on top of the previous one) and stacked along a
    new axis 1 to form progeny haplotypes.

    Args:
        mother_gametes: Tensor of maternal gametes, individuals on axis 0.
        father_gametes: Tensor of paternal gametes, individuals on axis 0.
        reps: Number of shuffled batches to generate.
        device: Torch device on which the permutations are created.

    Returns:
        All batches concatenated along axis 0.
    """
    def _shuffle_rows(gametes):
        # Random permutation of the individuals (axis 0).
        order = torch.randperm(gametes.shape[0], device=device)
        return gametes[order]

    batches = []
    for _ in range(reps):
        mother_gametes = _shuffle_rows(mother_gametes)
        father_gametes = _shuffle_rows(father_gametes)
        batches.append(torch.stack((mother_gametes, father_gametes), dim=1))
    return torch.vstack(batches)

class SimParams:
    """Plain bundle of the settings shared across one simulation.

    Attributes:
        G: Genome layout.
        T: Trait definition.
        h2: Heritability passed to the phenotyping routine.
        reps: Repetition count (stored; not interpreted here).
        pop_size: Population size (stored; not interpreted here).
        max_generations: Number of generations per episode.
        founder_pop: Founder population used to start an episode.
    """

    def __init__(self,G,T,h2,reps,pop_size,max_generations,founder_pop):
        self.G, self.T = G, T
        self.h2, self.reps = h2, reps
        self.pop_size, self.max_generations = pop_size, max_generations
        self.founder_pop = founder_pop

#RL ENVIRONMENT
class BreedingEnvironment(gym.Env):
    """Gymnasium environment in which an agent chooses the selection intensity.

    Each step the agent picks one of five actions. An action maps to a pair
    ``(total_parents, total_crosses)``: the ``total_parents`` individuals with
    the highest phenotype are kept, each contributes gametes, and the gametes
    are randomly paired ``total_crosses`` times to form the next generation.
    The reward is the maximum phenotype of the new generation.

    Args:
        SP: Simulation parameters; ``G``, ``T``, ``h2``, ``max_generations``
            and ``founder_pop`` are read.
    """

    def __init__(self, SP):
        super(BreedingEnvironment, self).__init__()
        self.SP = SP
        self.current_generation = 0
        self.max_generations = SP.max_generations

        # Define action and observation space
        self.action_space = gym.spaces.Discrete(5)  # 0, 1, 2, 3, or 4
        self.action_history = []

        # NOTE(review): the population shape is hard-coded to 200 individuals,
        # ploidy 2, 1 chromosome and 200 loci rather than derived from SP.
        self.observation_space = gym.spaces.Dict({
            "population": gym.spaces.Box(low=0, high=1, shape=(200, 2, 1, 200), dtype=np.int32),
            "generation": gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
        })

        # action -> (number of top parents kept, number of shuffled progeny
        # batches). Every pair multiplies to 200, so the population size
        # stays constant across generations.
        self._action_to_direction = {0:(200,1),
               1:(100,2),
               2:(50,4),
               3:(25,8),
               4:(5,40),}

    def _get_obs(self):
        """Return the current haplotypes and the normalised generation index."""
        return {
            "population": self.population.haplotypes.cpu(),
            "generation": torch.tensor([self.generation / self.SP.max_generations], dtype=torch.float32).cpu()
        }

    def _get_info(self):
        """Return the current max phenotype and genetic variance as floats."""
        return {
            "phenotype": self.phenotype.cpu().item(),
            "genetic_variance": self.population.breeding_values.var().cpu().item()
        }

    def reset(self, seed=None, options=None):
        """Restart the episode from the founder population.

        Returns:
            ``(observation, info)`` for the founder generation.
        """
        super().reset(seed=seed)
        # The founder population object is shared across episodes; phenotyping
        # overwrites its phenotypes / breeding values in place.
        self.population = self.SP.founder_pop
        self.phenotype = phenotype(self.population, self.SP.T, self.SP.h2)
        self.generation = 0
        # Clear the per-episode step counter as well; otherwise `done` in
        # step() stays true on every step of every later episode.
        self.current_generation = 0

        observation = self._get_obs()
        info = self._get_info()
        self.action_history = []
        return observation, info

    def step(self, action):
        """Advance one generation using the selection scheme picked by ``action``.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            ``reward`` is the max phenotype of the new generation and ``info``
            gains a ``'final_generation'`` entry on the last step of an episode.
        """
        # Map the action (element of {0,1,2,3,4}) to the selection scheme:
        # how many top parents to keep and how many crosses to make per parent.
        total_parents, total_crosses = self._action_to_direction[int(action)]

        top_k = torch.topk(self.population.phenotypes, total_parents).indices
        selected = self.population.haplotypes[top_k]

        # Breeding: the same selected individuals act as mothers and fathers.
        m = recombine(selected)  # Mother gametes
        f = recombine(selected)  # Father gametes
        progeny = create_progeny(m, f, reps=total_crosses)

        # Create new population from progeny
        self.population = create_pop(self.SP.G, progeny)
        self.phenotype = phenotype(self.population, self.SP.T, self.SP.h2)

        self.generation += 1
        observation = self._get_obs()
        info = self._get_info()
        terminated = self.generation >= self.SP.max_generations
        truncated = False
        reward = self.phenotype
        self.action_history.append((action))
        self.current_generation += 1
        done = self.current_generation >= self.max_generations

        if done:
            info['final_generation'] = {
                'phenotype': self.phenotype,
                'genetic_variance': self.population.breeding_values.var().item()
            }
        return observation, reward, terminated, truncated, info

    
#RLAGENT

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define our custom feature extractor
class CustomFeatureExtractor(BaseFeaturesExtractor):
    """Two-layer MLP over the flattened population plus the generation scalar.

    Args:
        observation_space: Dict space with a ``'population'`` Box and a
            ``'generation'`` Box of shape ``(1,)``.
        features_dim: Width of both linear layers and of the output.
    """

    def __init__(self, observation_space: gym.spaces.Dict, features_dim: int = 64):
        super(CustomFeatureExtractor, self).__init__(observation_space, features_dim)

        population_shape = observation_space['population'].shape
        # forward() flattens every non-batch axis, so the input width must be
        # the product of ALL population dimensions (including the chromosome
        # axis, which was previously skipped and only worked for n_chr == 1).
        flattened_size = 1
        for dim in population_shape:
            flattened_size *= dim
        self.input_size = flattened_size  # Flattened population size

        self.fc1 = nn.Linear(self.input_size + 1, features_dim).to(device)  # +1 for generation input
        self.fc2 = nn.Linear(features_dim, features_dim).to(device)

    def forward(self, observations):
        """Map a batch of observations to ``(batch, features_dim)`` features."""
        # Follow the layers' current device rather than the module-level
        # default: the layers may have been moved after construction.
        layer_device = self.fc1.weight.device
        population = observations['population'].float().to(layer_device)
        generation = observations['generation'].float().to(layer_device)

        # Flatten the population input (reshape also handles non-contiguous input)
        x = population.reshape(population.size(0), -1)

        # Concatenate generation input
        x = torch.cat([x, generation], dim=1)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        return x

# Define our custom policy
class CustomActorCriticPolicy(ActorCriticPolicy):
    """Actor-critic policy wired to ``CustomFeatureExtractor``.

    Uses two hidden layers of 64 units for both the policy (``pi``) and value
    (``vf``) heads, and an 8-dimensional feature extractor output. Any extra
    positional / keyword arguments are forwarded to ``ActorCriticPolicy``.
    """

    def __init__(self, observation_space, action_space, lr_schedule, *args, **kwargs):
        hidden_layers = {"pi": [64, 64], "vf": [64, 64]}
        extractor_options = {"features_dim": 8}
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            *args,
            net_arch=hidden_layers,
            features_extractor_class=CustomFeatureExtractor,
            features_extractor_kwargs=extractor_options,
            **kwargs,
        )
        
#LOGGING
import torch
import numpy as np
from stable_baselines3.common.callbacks import BaseCallback

class AverageFinalGenerationCallback(BaseCallback):
    """Log averaged final-generation statistics every ``log_freq`` episodes.

    For every finished environment whose info dict carries a
    ``'final_generation'`` entry, the phenotype and genetic variance are
    collected. Every ``log_freq`` such episodes the running means and the best
    phenotype seen so far are recorded on the logger and the buffers cleared.
    """

    def __init__(self, log_freq=100, verbose=0):
        super().__init__(verbose)
        self.log_freq = log_freq
        self.phenotypes = []
        self.genetic_variances = []
        self.best_phenotype = -float('inf')
        self.episode_count = 0

    def _on_step(self) -> bool:
        """Collect stats from finished episodes; always lets training continue."""
        for env_idx, done in enumerate(self.locals['dones']):
            if not done:
                continue
            info = self.locals['infos'][env_idx]
            if 'final_generation' not in info:
                continue

            self.episode_count += 1
            final_stats = info['final_generation']

            # Tensors are moved to CPU / numpy; plain numbers pass through.
            phenotype = self._to_numpy(final_stats['phenotype'])
            genetic_variance = self._to_numpy(final_stats['genetic_variance'])
            self.phenotypes.append(phenotype)
            self.genetic_variances.append(genetic_variance)
            self.best_phenotype = max(self.best_phenotype, phenotype)

            if self.episode_count % self.log_freq != 0:
                continue

            # Flush the averages for this logging window.
            mean_phenotype = np.mean(self.phenotypes)
            mean_genetic_variance = np.mean(self.genetic_variances)
            self.logger.record("final_generation/avg_phenotype", mean_phenotype)
            self.logger.record("final_generation/avg_genetic_variance", mean_genetic_variance)
            self.logger.record("final_generation/best_phenotype", self.best_phenotype)

            # Start the next window with empty buffers.
            self.phenotypes = []
            self.genetic_variances = []

        return True

    def _to_numpy(self, tensor):
        """Convert a torch tensor to a numpy array; return anything else as-is."""
        if not isinstance(tensor, torch.Tensor):
            return tensor
        return tensor.cpu().numpy()


In [4]:

# n_chr = 1
# n_loci = 200
# founder_pop_size = 200
# h2 = .3
# reps=1
# max_generations=20
# G = Genome(n_chr, n_loci)
# founder_pop = create_pop(G, create_random_pop(G, founder_pop_size))
# T = Trait(G, founder_pop, target_mean=0.0, target_variance=1.0)
# SP = SimParams(G,T,h2, reps, founder_pop_size, max_generations, founder_pop)

# from stable_baselines3 import PPO
# from gymnasium.spaces import Box
# import gymnasium.spaces as spaces

# env = BreedingEnvironment(SP)
# obs,info = env.reset()

# #BASELINES
# # Create the environment
# env = BreedingEnvironment(SP)
# env.reset()
# #calculate naive (e.g. choosing the same option every time)
# results = {}
# REPS=3
# for a in range(5):
#     env.reset()
#     results[a] = []
#     for r in range(REPS): #reps
#         env.reset()
#         for i in range(env.SP.max_generations):
#             env.step(a)
#         results[a].append(env.phenotype.cpu().numpy())
# random_results = []
# for i in results.keys():
#     print(np.array(results[i]).mean())

# #set up tensorboard
# # %load_ext tensorboard
# tensorboard_log = f'./ppo_breeding_tensorboard/'
# # %tensorboard --logdir {tensorboard_log}


# from stable_baselines3.common.vec_env import DummyVecEnv
# from stable_baselines3.common.monitor import Monitor

# bs = 2

# model = PPO(CustomActorCriticPolicy, env, batch_size=bs, n_steps=bs*10, device='cuda', verbose=0, 
#             tensorboard_log=tensorboard_log)

# # Create an instance of your custom callback
# callback = AverageFinalGenerationCallback(log_freq=1)

# # Set up TensorBoard logging
# tb_log_name = f"PPO_bs_{bs}_nsteps_{bs*10}"

# # Train the model with the custom callback
# model.learn(total_timesteps=10000, callback=callback, tb_log_name=tb_log_name)
  

5.9014497
10.319417
10.001578
7.9427414
2.8446248


<stable_baselines3.ppo.ppo.PPO at 0x7f372a2cdc90>