In [1]:
#x01_populationStatistics

In [2]:
import matplotlib.pyplot as plt
from fastcore.basics import patch
import uuid
import pdb
import torch
from matplotlib.animation import FuncAnimation
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
device = 'cpu'  # default torch device for tensors created by the helpers below

class Genome:
    """Layout of a diploid genome.

    Args:
        n_chr: number of chromosomes.
        n_loci: number of loci per chromosome.

    Attributes:
        ploidy: always 2.
        shape: ``(ploidy, n_chr, n_loci)`` -- the per-individual haplotype
            tensor shape used by ``create_random_pop``.
    """

    def __init__(self, n_chr, n_loci):
        ploidy = 2
        self.ploidy, self.n_chr, self.n_loci = ploidy, n_chr, n_loci
        self.shape = (ploidy, n_chr, n_loci)
        
class Population:
    """Haplotypes of a set of individuals plus derived per-individual data.

    Args:
        genome: genome description (stored as-is).
        haplotypes: tensor whose first dim indexes individuals and whose second
            dim indexes the haplotypes being summed into dosages (the layout
            produced by ``create_random_pop`` is
            ``(individual, ploidy, n_chr, n_loci)``).
        device: stored as-is; tensors are not moved.

    Attributes:
        dosages: ``haplotypes`` summed over dim 1, as float.
        size: number of individuals (``haplotypes.shape[0]``).
        phenotypes, breeding_values: ``None`` until ``phenotype()`` (or ``bv()``
            for breeding values) fills them in.
    """

    def __init__(self, genome, haplotypes, device=device):
        self.genome = genome
        self.device = device
        self.phenotypes = None
        # phenotype()/bv() store breeding values under this attribute name, so
        # initialise it here to make "not yet computed" readable as None.
        self.breeding_values = None
        # Not written by any visible code; kept so existing readers don't break.
        self.bvs = None
        self.haplotypes = haplotypes
        self.dosages = haplotypes.sum(dim=1).float()
        self.size = haplotypes.shape[0]
                
class Trait:
    """Additive trait with random per-locus effects scaled on founders.

    Effects are drawn standard normal for every (chromosome, locus), centered
    to mean zero and rescaled so that the variance of the founders' additive
    scores ``sum(effects * dosages)`` equals ``target_variance``.

    Attributes:
        effects: ``(n_chr, n_loci)`` scaled additive effects.
        scaling_factors: scalar tensor applied to the raw effects.
        intercept: mean founder score (computed with the *scaled* effects)
            minus ``target_mean``. Not applied by the code in this notebook;
            presumably meant to be subtracted from scores to center founders
            on ``target_mean`` -- TODO confirm.

    Raises:
        ValueError: if the founder scores have zero or undefined (NaN)
            variance, in which case no finite scaling exists.
    """

    def __init__(self, genome, founder_population, target_mean, target_variance, device=device):
        self.target_mean = target_mean
        self.target_variance = target_variance
        self.device = device
        random_effects = torch.randn(genome.n_chr, genome.n_loci, device=self.device)
        random_effects -= random_effects.mean()
        founder_scores = torch.einsum('kl,hkl->h', random_effects, founder_population.dosages)
        founder_mean, founder_var = founder_scores.mean(), founder_scores.var()
        # `not x > 0` also catches NaN (e.g. a single founder).
        if not founder_var > 0:
            raise ValueError(
                f"founder scores have no variance ({founder_var}); cannot scale trait effects"
            )
        scaling_factors = torch.sqrt(self.target_variance / founder_var)
        self.scaling_factors = scaling_factors
        random_effects *= scaling_factors
        self.effects = random_effects
        # Scores are linear in the effects, so the founder mean scales with them.
        self.intercept = founder_mean * scaling_factors - target_mean

        
def calculate_breeding_value(population_dosages, trait_effects, device = device):
    """Return one additive breeding value per individual.

    Contracts a rank-3 ``population_dosages`` (individual, chromosome, locus)
    with rank-2 ``trait_effects`` (chromosome, locus) over both genomic axes.
    ``device`` is accepted for API compatibility but not used.
    """
    per_individual = torch.einsum('nck,ck->n', population_dosages, trait_effects)
    return per_individual

def truncation_selection(population, trait, top_percent):
    """Return indices of the best individuals by phenotype, best first.

    Args:
        population: object with a 1-D ``phenotypes`` tensor.
        trait: unused; kept for interface compatibility.
        top_percent: how many individuals to keep. An ``int`` is an absolute
            count (the original behavior). A ``float`` in (0, 1] is a fraction
            of the population, rounded down but never below one individual.

    Raises:
        ValueError: if a float ``top_percent`` lies outside (0, 1].
    """
    phenotypes = population.phenotypes
    if isinstance(top_percent, float):
        if not 0.0 < top_percent <= 1.0:
            raise ValueError(
                f"fractional top_percent must be in (0, 1], got {top_percent}"
            )
        k = max(1, int(phenotypes.shape[0] * top_percent))
    else:
        k = top_percent
    return torch.topk(phenotypes, k).indices

# meiosis
def recombine(parent_haplo_tensor, recombination_rate=0.1):
    """Produce one gamete per individual by meiosis with crossovers.

    NOTE: this notebook defines ``recombine`` twice; the later definition is
    the one in effect after the cell runs.

    For every chromosome a starting parental strand is chosen with
    probability 1/2. A crossover occurs independently before each locus with
    probability ``recombination_rate``, and every crossover switches the
    strand that alleles are copied from.

    Args:
        parent_haplo_tensor: ``(n_individuals, ploidy, n_chr, n_loci)``;
            strands ``[:, 0]`` and ``[:, 1]`` are recombined.
        recombination_rate: crossover probability per locus, either a scalar
            or a tensor broadcastable to ``(n_chr, n_loci)``.

    Returns:
        ``(n_individuals, n_chr, n_loci)`` floating-point gametes.
    """
    n_ind, _, n_chr, n_loci = parent_haplo_tensor.shape
    dev = parent_haplo_tensor.device
    maternal = parent_haplo_tensor[:, 0]
    paternal = parent_haplo_tensor[:, 1]

    rates = torch.as_tensor(recombination_rate, dtype=torch.float32, device=dev)
    crossovers = torch.bernoulli(rates.expand(n_ind, n_chr, n_loci))
    start_strand = torch.randint(0, 2, (n_ind, n_chr, 1), device=dev)
    # Running crossover count (plus the random start) gives the active strand:
    # even -> maternal, odd -> paternal.
    strand = (torch.cumsum(crossovers, dim=-1) + start_strand) % 2
    return maternal * (1 - strand) + paternal * strand


def phenotype(population, trait, h2):
    """Assign breeding values and noisy phenotypes to ``population`` in place.

    Breeding values are ``calculate_breeding_value(population.dosages,
    trait.effects)``. Gaussian environmental noise is added with variance
    ``(1 - h2) / h2 * Var(breeding values)``, so the expected heritability of
    the phenotypes is ``h2``. No noise is added when ``h2 == 1`` or when the
    breeding values have no usable variance (all equal, or a single
    individual).

    Args:
        population: object with ``dosages``; ``breeding_values`` and
            ``phenotypes`` are (re)assigned on it.
        trait: object with ``effects``.
        h2: heritability in (0, 1].

    Raises:
        ValueError: if ``h2`` is not in (0, 1].
    """
    if not 0 < h2 <= 1:
        raise ValueError(f"h2 must be in (0, 1], got {h2}")

    breeding_values = calculate_breeding_value(population.dosages, trait.effects)
    genetic_variance = breeding_values.var()

    # var() of a single value is NaN; treat that like zero variance.
    has_variance = bool(torch.isfinite(genetic_variance)) and bool(genetic_variance > 0)
    if h2 < 1 and has_variance:
        environmental_sd = torch.sqrt((1 - h2) / h2 * genetic_variance).detach()
        # randn_like keeps the noise on the breeding values' device and dtype.
        environmental_noise = torch.randn_like(breeding_values) * environmental_sd
    else:
        environmental_noise = torch.zeros_like(breeding_values)

    population.breeding_values = breeding_values
    population.phenotypes = breeding_values + environmental_noise
#     def _create_random_haplotypes(self,num_individuals):
#         return torch.randint(0, 2, (num_individuals, *self.g.shape), device=self.device)
def create_random_pop(G, pop_size):
    """Return random 0/1 haplotypes of shape ``(pop_size, *G.shape)``.

    The tensor is created on the module-level ``device``.
    """
    full_shape = (pop_size,) + tuple(G.shape)
    return torch.randint(0, 2, full_shape, device=device)

def update_pop(population, haplotype_pop_tensor):
    """Replace a population's haplotypes in place and refresh derived fields.

    ``haplotypes``, ``dosages`` (sum over dim 1, as float) and ``size``
    (first dimension) are updated. Phenotypes and breeding values are NOT
    recomputed; call ``phenotype`` afterwards if they are needed.

    Returns:
        The same ``population`` object, for chaining.
    """
    population.haplotypes = haplotype_pop_tensor
    population.dosages = haplotype_pop_tensor.sum(dim=1).float()
    # Keep size in sync when the new tensor has a different number of individuals.
    population.size = haplotype_pop_tensor.shape[0]
    return population

# meiosis
def recombine(parent_haplo_tensor, recombination_rate=0.1):
    """Produce one gamete per individual by meiosis with crossovers.

    For every chromosome a starting parental strand is chosen with
    probability 1/2. A crossover occurs independently before each locus with
    probability ``recombination_rate``, and every crossover switches the
    strand that alleles are copied from.

    Args:
        parent_haplo_tensor: ``(n_individuals, ploidy, n_chr, n_loci)``;
            strands ``[:, 0]`` and ``[:, 1]`` are recombined.
        recombination_rate: crossover probability per locus, either a scalar
            or a tensor broadcastable to ``(n_chr, n_loci)``.

    Returns:
        ``(n_individuals, n_chr, n_loci)`` floating-point gametes.
    """
    n_ind, _, n_chr, n_loci = parent_haplo_tensor.shape
    dev = parent_haplo_tensor.device
    maternal = parent_haplo_tensor[:, 0]
    paternal = parent_haplo_tensor[:, 1]

    rates = torch.as_tensor(recombination_rate, dtype=torch.float32, device=dev)
    crossovers = torch.bernoulli(rates.expand(n_ind, n_chr, n_loci))
    start_strand = torch.randint(0, 2, (n_ind, n_chr, 1), device=dev)
    # Running crossover count (plus the random start) gives the active strand:
    # even -> maternal, odd -> paternal.
    strand = (torch.cumsum(crossovers, dim=-1) + start_strand) % 2
    return maternal * (1 - strand) + paternal * strand

def breed(mother_tensor, father_tensor, recombination_rate=0.1):
    """Cross mothers with fathers row by row.

    Each mother and father yields one gamete via ``recombine``. The mother's
    gamete (egg) is placed at index 0 of dim 1 and the father's (pollen) at
    index 1, giving a ``(n, 2, n_chr, n_loci)`` offspring tensor.
    """
    gametes = (
        recombine(mother_tensor, recombination_rate),   # eggs
        recombine(father_tensor, recombination_rate),   # pollens
    )
    return torch.stack(gametes, dim=1)

def create_pop(G, haplotypes):
    """Wrap ``haplotypes`` in a ``Population`` for genome ``G``."""
    population = Population(G, haplotypes=haplotypes)
    return population

def bv(P,T):
    """Compute ``P.breeding_values`` from ``P.dosages`` and ``T.effects``.

    ``P`` is modified in place; nothing is returned.
    """
    values = calculate_breeding_value(P.dosages, T.effects)
    P.breeding_values = values
    
def create_progeny(mother_gametes, father_gametes,reps=1):
    """Randomly pair maternal and paternal gametes into offspring.

    Each of ``reps`` rounds shuffles both gamete sets (each round reshuffles
    the previous round's order) and stacks them along a new dim 1, mother at
    index 0 and father at index 1. The rounds are concatenated along dim 0,
    so the result has ``reps * len(mother_gametes)`` rows.
    """
    moms, dads = mother_gametes, father_gametes
    batches = []
    for _ in range(reps):
        mom_order = torch.randperm(moms.shape[0])
        dad_order = torch.randperm(dads.shape[0])
        moms, dads = moms[mom_order], dads[dad_order]
        batches.append(torch.stack((moms, dads), dim=1))
    return torch.vstack(batches)
    

class BreedingEnvironment:
    """Episodic truncation-selection breeding simulator with an RL-style API.

    Each ``step`` selects the top-scoring individuals of the current
    population as parents. It produces gametes with ``recombine``, pairs them
    with ``create_progeny`` and phenotypes the offspring. An episode lasts
    ``max_generations`` steps. The reward is the best phenotype of the final
    generation, and 0 on every other step.

    Args:
        G: genome passed to ``create_random_pop``/``create_pop``.
        T: trait passed to ``phenotype``.
        h2: heritability used when phenotyping.
        reps: progeny rounds per step (``create_progeny(reps=...)``). The next
            generation has ``k * reps`` individuals, ``k`` = parents selected.
        pop_size: size of the random founder population. It also sets
            ``k = int(pop_size * 0.1)``.
        max_generations: episode length in steps.
        variance_threshold: stored; not used by the visible code.
    """

    def __init__(self, G, T, h2, reps, pop_size, max_generations=10, variance_threshold=1e-6):
        self.G = G
        self.T = T
        self.h2 = h2
        self.reps = reps
        self.pop_size = pop_size
        self.max_generations = max_generations
        self.variance_threshold = variance_threshold

        # Create and store the initial population
        self.initial_haplotypes = create_random_pop(G, pop_size)
        self.initial_population = create_pop(G, self.initial_haplotypes)
        phenotype(self.initial_population, self.T, self.h2)

        # Initialize current population
        self.population = self.initial_population
        self.history = []
        self.current_generation = 0

    def step(self, selection_scores):
        """Advance one generation, choosing parents by ``selection_scores``.

        Args:
            selection_scores: one score per individual of the current
                population, shape ``(n,)`` or ``(n, 1)`` (e.g. the output of
                ``BreedingNetwork``). Higher scores are selected.

        Returns:
            ``(new_state, reward, done)``:
              - ``new_state`` is ``get_state()`` for the new generation.
              - ``reward`` is ``calculate_reward()``.
              - ``done`` becomes True on the step that completes
                ``max_generations`` generations.
        """
        # Log the population we are about to breed from.
        self.history.append(self.get_state())

        parent_indices = self.select_parents(selection_scores)
        selected = self.population.haplotypes[parent_indices]

        # Each selected parent contributes one egg and one pollen gamete.
        mother_gametes = recombine(selected)
        father_gametes = recombine(selected)
        progeny = create_progeny(mother_gametes, father_gametes, reps=self.reps)

        new_pop = create_pop(self.G, progeny)
        phenotype(new_pop, self.T, self.h2)
        self.population = new_pop

        # The reward is keyed on the generation index *before* it advances, so
        # the last step (index max_generations - 1) is the rewarded one.
        reward = self.calculate_reward()
        self.current_generation += 1
        # Checked after advancing, so the rewarded final step also reports done.
        done = self.is_done()
        return self.get_state(), reward, done

    def select_parents(self, selection_scores):
        """Return indices of the top 10 % of individuals by ``selection_scores``.

        ``k = int(pop_size * 0.1)``, clamped to at least one parent and at most
        the number of scores supplied. A trailing singleton dimension (network
        output of shape ``(n, 1)``) is dropped before ranking.
        """
        scores = selection_scores
        if scores.dim() == 2 and scores.shape[-1] == 1:
            scores = scores.squeeze(-1)
        k = int(self.pop_size * 0.1)  # Select top 10% as in the paper
        k = max(1, min(k, scores.shape[0]))
        return torch.topk(scores, k).indices

    def calculate_reward(self):
        """Best phenotype when on the final generation index, else 0."""
        if self.current_generation == self.max_generations - 1:
            return self.population.phenotypes.max()
        return 0

    def is_done(self):
        """True once ``max_generations`` steps have been taken."""
        return self.current_generation >= self.max_generations

    def reset(self):
        """Restore a copy of the initial haplotypes, re-phenotype and restart."""
        self.population = create_pop(self.G, self.initial_haplotypes.clone())
        phenotype(self.population, self.T, self.h2)
        self.history = []
        self.current_generation = 0
        return self.get_state()

    def get_state(self):
        """Return ``(current haplotypes, current_generation / max_generations)``."""
        generation_progress = self.current_generation / self.max_generations
        return self.population.haplotypes, generation_progress

    
    # Neural Network
class BreedingNetwork(nn.Module):
    """MLP scoring each individual from its genotype and episode progress.

    Args:
        genotype_size: genotype features per individual. The first layer is
            ``genotype_size + 1`` wide; the extra column is the generation
            fraction.
        hidden_size: width of the two hidden layers.
    """

    def __init__(self, genotype_size, hidden_size=64):
        super().__init__()
        self.fc1 = nn.Linear(genotype_size + 1, hidden_size)  # +1 for generation %
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, genotype, generation_percent):
        """Return a ``(n, 1)`` tensor of scores.

        Args:
            genotype: tensor whose first dim indexes individuals. Any trailing
                dims (e.g. a ``(n, ploidy, n_chr, n_loci)`` haplotype tensor)
                are flattened to ``genotype_size`` features, and the values
                are cast to the layer dtype.
            generation_percent: a Python scalar or 0-d tensor shared by all
                individuals, or a ``(n,)`` tensor with one value each.
        """
        dtype = self.fc1.weight.dtype
        n = genotype.shape[0]
        features = genotype.reshape(n, -1).to(dtype=dtype)
        # Build the progress column on the genotype's device with the layer dtype.
        progress = torch.as_tensor(generation_percent, dtype=dtype, device=features.device)
        progress = progress.expand(n).unsqueeze(1)
        x = torch.cat([features, progress], dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class PPOAgent:
    """Minimal PPO-clip agent with separate actor (softmax policy) and critic MLPs.

    Args:
        state_dim: size of a state vector.
        action_dim: number of discrete actions.
        hidden_dim: width of the hidden layers in both networks.
        lr: Adam learning rate shared by actor and critic.
        gamma: discount factor for the one-step TD target.
        epsilon: PPO clipping range.
        value_coef: weight of the critic loss in the total loss.
        entropy_coef: weight of the policy-entropy bonus.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64, lr=3e-4, gamma=0.99, epsilon=0.2, value_coef=0.5, entropy_coef=0.01):
        self.actor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
        self.critic = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.optimizer = optim.Adam(list(self.actor.parameters()) + list(self.critic.parameters()), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef

    def get_action(self, state):
        """Sample an action for one state, ``(state_dim,)`` or ``(1, state_dim)``.

        Returns:
            ``(action, prob)``: the sampled action index and its probability
            under the current policy, as Python scalars.
        """
        probs = self.actor(state)
        if probs.dim() == 1:  # unbatched state
            probs = probs.unsqueeze(0)
        action = torch.multinomial(probs, 1)
        return action.item(), probs[0, action.item()].item()

    def update(self, states, actions, old_probs, rewards, next_states, dones):
        """Take one PPO-clip gradient step on a batch of one-step transitions.

        Args:
            states, next_states: ``(batch, state_dim)`` array-likes.
            actions: ``(batch,)`` action indices.
            old_probs: ``(batch,)`` probabilities of ``actions`` under the
                policy that collected the data (as returned by ``get_action``).
            rewards: ``(batch,)`` rewards.
            dones: ``(batch,)``, 1.0 where the episode ended.

        Returns:
            The total loss as a Python float.
        """
        states = torch.as_tensor(states, dtype=torch.float32)
        actions = torch.as_tensor(actions, dtype=torch.long)
        old_probs = torch.as_tensor(old_probs, dtype=torch.float32)
        rewards = torch.as_tensor(rewards, dtype=torch.float32)
        next_states = torch.as_tensor(next_states, dtype=torch.float32)
        dones = torch.as_tensor(dones, dtype=torch.float32)

        values = self.critic(states).squeeze(-1)
        # The bootstrap target is a fixed regression target (semi-gradient TD).
        with torch.no_grad():
            next_values = self.critic(next_states).squeeze(-1)
        td_targets = rewards + self.gamma * next_values * (1 - dones)
        # Advantages only weight the policy gradient; detaching keeps the
        # actor loss from pushing gradients into the critic.
        advantages = (td_targets - values).detach()

        # Compute actor loss
        action_probs = self.actor(states)
        new_probs = action_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
        ratio = new_probs / old_probs
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantages
        actor_loss = -torch.min(surr1, surr2).mean()

        # Compute critic loss
        critic_loss = (td_targets - values).pow(2).mean()

        # Entropy of the full action distribution (not just the taken action);
        # the clamp avoids 0 * log(0) = NaN.
        entropy = -(action_probs * torch.log(action_probs.clamp_min(1e-8))).sum(dim=-1).mean()

        # Total loss
        loss = actor_loss + self.value_coef * critic_loss - self.entropy_coef * entropy

        # Update networks
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()
    
# Example: build founders, a trait scaled on them, a breeding environment and
# a scoring network, then score the environment's current population.
n_chr, n_loci = 1, 500
founder_pop_size = 200

G = Genome(n_chr, n_loci)
founder_pop = create_pop(G, create_random_pop(G, founder_pop_size))
T = Trait(G, founder_pop, target_mean=0.0, target_variance=1.0)

sim = BreedingEnvironment(G, T, h2=0.5, reps=1, pop_size=founder_pop_size)
net = BreedingNetwork(2 * n_chr * n_loci)  # two haplotypes per individual

# Flatten each individual's (ploidy, n_chr, n_loci) haplotypes into one row.
input_net = sim.population.haplotypes.reshape(founder_pop_size, -1)
net(input_net, 0.5)

tensor([[-0.0115],
        [ 0.0443],
        [ 0.0966],
        [ 0.1589],
        [ 0.0359],
        [ 0.1215],
        [ 0.0256],
        [ 0.1162],
        [ 0.0761],
        [ 0.1222],
        [ 0.0108],
        [ 0.1478],
        [ 0.0938],
        [ 0.1359],
        [ 0.0126],
        [ 0.1062],
        [ 0.0533],
        [ 0.0419],
        [ 0.1082],
        [ 0.0670],
        [ 0.0770],
        [ 0.1396],
        [ 0.0725],
        [ 0.0962],
        [ 0.1536],
        [ 0.0791],
        [ 0.0505],
        [ 0.0868],
        [ 0.0786],
        [ 0.0854],
        [ 0.0248],
        [ 0.0910],
        [ 0.1261],
        [ 0.1952],
        [ 0.1334],
        [ 0.1311],
        [ 0.0898],
        [ 0.1687],
        [ 0.1043],
        [ 0.0831],
        [ 0.0560],
        [ 0.1386],
        [ 0.1083],
        [ 0.1072],
        [ 0.0761],
        [ 0.0632],
        [ 0.1177],
        [ 0.0033],
        [ 0.0366],
        [ 0.0608],
        [ 0.0932],
        [ 0.1157],
        [ 0.