In [1]:
#x01_populationStatistics

In [39]:
import matplotlib.pyplot as plt
from fastcore.basics import patch
import uuid
import pdb
import torch
from matplotlib.animation import FuncAnimation
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
device='cpu'

class Genome:
    def __init__(self, n_chr, n_loci):
        self.ploidy = 2
        self.n_chr = n_chr
        self.n_loci = n_loci
        self.shape = (self.ploidy, self.n_chr, self.n_loci)
        
class Population:
    def __init__(self, genome, haplotypes, device=device):
        self.genome = genome
        self.device = device
        self.phenotypes = None
        self.bvs = None
        self.haplotypes = haplotypes
        self.dosages = haplotypes.sum(dim=1).float()
        self.size = haplotypes.shape[0]
                
class Trait:
    def __init__(self, genome, founder_population, target_mean, target_variance, device=device):
        self.target_mean = target_mean
        self.target_variance = target_variance
        self.device = device
        random_effects = torch.randn(genome.n_chr, genome.n_loci, device=self.device)
        random_effects -= random_effects.mean()
        founder_scores = torch.einsum('kl,hkl->h', random_effects, founder_population.dosages)
        founder_mean, founder_var = founder_scores.mean(), founder_scores.var()
        scaling_factors = torch.sqrt(self.target_variance / founder_var)
        self.scaling_factors = scaling_factors
        random_effects *= scaling_factors
        self.effects = random_effects
        self.intercept = founder_mean - target_mean

        
def calculate_breeding_value(population_dosages, trait_effects, device = device):
    return torch.einsum('hjk,jk->h', population_dosages,trait_effects)

def truncation_selection(population, trait, top_percent):
    return torch.topk(population.phenotypes, top_percent).indices

# meiosis
def recombine(parent_haplo_tensor, recombination_rate=0.1):
    num_individuals, ploidy, num_chromosomes, num_loci = parent_haplo_tensor.shape    
    # Generate crossover masks
    maternal, paternal = parent_haplo_tensor[:,0,:,:],parent_haplo_tensor[:,1,:,:],
    crossovers = torch.bernoulli(torch.full((num_individuals, num_chromosomes, num_loci), recombination_rate, device=device))
    #crossovers = torch.rand((num_individuals, num_chromosomes, num_loci), device=device) < recombination_rate
    progeny = maternal * (1 - crossovers) + paternal * crossovers
    return progeny


def phenotype(population, trait, h2):
    breeding_values = calculate_breeding_value(population.dosages, trait.effects) 
    
    if breeding_values.var() == 0:
#         print('phenotype: no var')
        environmental_variance = 0  
    else:
        environmental_variance = (1 - h2) / h2 * breeding_values.var() 
    
    # Check if environmental_variance is zero before applying torch.sqrt and .clone()
    if environmental_variance == 0:
        environmental_noise = torch.zeros(breeding_values.shape, device=device)
    else:
        environmental_noise = torch.randn(breeding_values.shape, device=device) * torch.sqrt(environmental_variance).detach()
    
    population.breeding_values = breeding_values
    population.phenotypes = breeding_values + environmental_noise
#     def _create_random_haplotypes(self,num_individuals):
#         return torch.randint(0, 2, (num_individuals, *self.g.shape), device=self.device)
def create_random_pop(G, pop_size):
    return torch.randint(0, 2, (pop_size, *G.shape), device= device)

def update_pop(population, haplotype_pop_tensor):
    population.haplotypes = haplotype_pop_tensor
    population.dosages = haplotype_pop_tensor.sum(dim=1).float()
    return population

# meiosis
def recombine(parent_haplo_tensor, recombination_rate=0.1):
    num_individuals, ploidy, num_chromosomes, num_loci = parent_haplo_tensor.shape
    # Generate crossover masks
    maternal, paternal = parent_haplo_tensor[:,0,:,:],parent_haplo_tensor[:,1,:,:],
    crossovers = torch.bernoulli(torch.full((num_individuals, num_chromosomes, num_loci), recombination_rate, device=device))
#     crossovers = torch.rand((num_individuals, num_chromosomes, num_loci), device=device) < recombination_rate
    progeny = maternal * torch.logical_not(crossovers) + paternal * crossovers
    return progeny

def breed(mother_tensor, father_tensor, recombination_rate=0.1):
    eggs = recombine(mother_tensor,recombination_rate)
    pollens = recombine(father_tensor,recombination_rate)
    return torch.stack((eggs,pollens), dim=1)

def create_pop(G, haplotypes):
    return Population(G, haplotypes=haplotypes)

def bv(P,T):
    P.breeding_values = calculate_breeding_value(P.dosages,T.effects)
    
def create_progeny(mother_gametes, father_gametes,reps = 1):
    progeny = []
    for _ in range(reps):
        # Randomly shuffle the gametes from each parent 
        shuffled_mother_indices = torch.randperm(mother_gametes.shape[0])
        shuffled_father_indices = torch.randperm(father_gametes.shape[0])

        # Select the shuffled gametes
        mother_gametes = mother_gametes[shuffled_mother_indices]
        father_gametes = father_gametes[shuffled_father_indices]

        # Stack the gametes to create progeny haplotypes
        progeny_haplotypes = torch.stack((mother_gametes, father_gametes),dim=1)
        progeny.append(progeny_haplotypes)
    return torch.vstack(progeny)
    

class BreedingEnvironment:
    def __init__(self, G, T, h2, reps, pop_size, max_generations=10, variance_threshold=1e-6):
        self.G = G
        self.T = T
        self.h2 = h2
        self.reps = reps
        self.pop_size = pop_size
        self.max_generations = max_generations
        self.variance_threshold = variance_threshold

        # Create and store the initial population
        self.initial_haplotypes = create_random_pop(G, pop_size)
        self.initial_population = create_pop(G, self.initial_haplotypes)
        phenotype(self.initial_population, self.T, self.h2)
        
        # Initialize current population
        self.population = self.initial_population
        self.history = []
        self.current_generation = 0
        
    def step(self, selection_scores):
        # Log current population
        current_state = self.get_state()

        # Select parents based on actions
        selected_parent_indices = self.select_parents(actions)
        selected = self.population.haplotypes[selected_parent_indices]

        # Breeding
        m = recombine(selected)  # Mother gametes
        f = recombine(selected)  # Father gametes
        progeny = create_progeny(m, f, reps=action2[actions])  # Create progeny

        # Create new population from progeny
        new_pop = create_pop(self.G, progeny)
        phenotype(new_pop, self.T, self.h2)

        # Switch current population to progeny population
        self.population = new_pop
        # Check if episode is done
        done = self.is_done()
        # Calculate reward only if the episode is done
        reward = self.calculate_reward() if self.is_done() else 0
        reward = self.calculate_reward()
#         print(self.population.phenotypes.var())
        # Get new state
        new_state = self.get_state()

        # Increment generation counter
        self.current_generation += 1
        return new_state, reward, done

    def select_parents(self, selection_scores):
        k = int(self.pop_size * 0.1)  # Select top 10% as in the paper
        parents = torch.topk(selection_scores, k).indices
        return parents
    

    def calculate_reward(self):
        if self.current_generation == self.max_generations - 1:
            return self.population.phenotypes.max()
        return 0
    
    def is_done(self):
        return self.current_generation >= self.max_generations

    def reset(self):
        # Reset to the initial population
        self.population = create_pop(self.G, self.initial_haplotypes.clone())
        phenotype(self.population, self.T, self.h2)
        self.history = []
        self.current_generation = 0
        return self.get_state()
    
    def get_state(self):
        generation_progress = self.current_generation / self.max_generations
        return self.population.haplotypes, generation_progress


In [69]:
n_chr = 1
n_loci = 500
founder_pop_size = 200

G = Genome(n_chr, n_loci)
founder_pop = create_pop(G, create_random_pop(G, founder_pop_size))
T = Trait(G, founder_pop, target_mean=0.0, target_variance=1.0)


sim = BreedingEnvironment(G,T,h2=.5,reps=1,pop_size=founder_pop_size)
net = BreedingNetwork(n_loci*n_chr*2)

In [70]:
# Neural Network
class BreedingNetwork(nn.Module):
    def __init__(self, genotype_size, hidden_size=64):
        super(BreedingNetwork, self).__init__()
        self.fc1 = nn.Linear(genotype_size + 1, hidden_size)  # +1 for generation %
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
    
    def forward(self, genotype, generation_percent):
        x = torch.cat([genotype, generation_percent.unsqueeze(1)], dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    
    
# Agent
class BreedingAgent:
    def __init__(self, genotype_size, learning_rate=0.001):
        self.network = BreedingNetwork(genotype_size)
        self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)
        self.genotype_size = genotype_size
    
    def select_action(self, genotype, generation_percent):
        with torch.no_grad():
            action = self.network(genotype, generation_percent)
        return action.item()
    
    def update(self, genotype, generation_percent, reward):
        self.optimizer.zero_grad()
        action_value = self.network(genotype, generation_percent)
        loss = F.mse_loss(action_value, torch.tensor([reward]))
        loss.backward()
        self.optimizer.step()
        return loss.item()

In [71]:
BreedingAgent(n_loci)

<__main__.BreedingAgent at 0x7f5f579a7150>

In [73]:
input_net = sim.population.haplotypes.view(founder_pop_size, n_loci*n_chr*2)

In [79]:
net(input_net,torch.ones(200)*.5)

tensor([[ 0.1061],
        [-0.0283],
        [ 0.0145],
        [ 0.1000],
        [-0.0073],
        [-0.0155],
        [ 0.0316],
        [ 0.0099],
        [ 0.0568],
        [ 0.0663],
        [ 0.0082],
        [ 0.0284],
        [ 0.1492],
        [-0.0112],
        [ 0.0193],
        [ 0.0524],
        [ 0.0547],
        [ 0.1129],
        [ 0.0910],
        [-0.0030],
        [ 0.0185],
        [ 0.1059],
        [-0.0120],
        [ 0.0660],
        [ 0.0538],
        [ 0.0858],
        [ 0.0410],
        [ 0.0304],
        [ 0.0678],
        [ 0.0477],
        [ 0.0366],
        [ 0.0924],
        [ 0.0772],
        [ 0.0551],
        [ 0.0888],
        [ 0.0737],
        [ 0.0225],
        [ 0.0459],
        [-0.0043],
        [ 0.0049],
        [ 0.0682],
        [ 0.0187],
        [ 0.0400],
        [ 0.0264],
        [ 0.0715],
        [ 0.0327],
        [ 0.0908],
        [ 0.0263],
        [ 0.0003],
        [ 0.0678],
        [ 0.0065],
        [ 0.0121],
        [ 0.