In [None]:
#x02_chewc

In [None]:
import torch
import matplotlib.pyplot as plt
from fastcore.basics import patch
import uuid

In [None]:
import torch

device = "cpu"  # Default torch device for the classes and helpers defined below.

class Genome:
    """Layout of a diploid genome: ``n_chr`` chromosomes with ``n_loci`` loci each.

    Attributes:
        ploidy: number of homologous copies per chromosome (always 2 here).
        n_chr: number of chromosomes.
        n_loci: number of loci per chromosome.
        shape: ``(ploidy, n_chr, n_loci)``, the per-individual haplotype shape.
    """

    def __init__(self, n_chr, n_loci):
        ploidy = 2
        self.ploidy, self.n_chr, self.n_loci = ploidy, n_chr, n_loci
        self.shape = (ploidy, n_chr, n_loci)
        
class Population:
    """A set of individuals stored as one haplotype tensor.

    Args:
        genome: object exposing ``shape``; each individual's haplotypes have
            this shape.
        haplotypes: optional tensor of shape ``(n_individuals, *genome.shape)``.
            When omitted, 1000 individuals with random 0/1 alleles are created.
        device: torch device used for randomly created haplotypes.

    Attributes:
        haplotypes: allele tensor, one entry per individual along dim 0.
        dosages: ``haplotypes`` summed over dim 1 (the first axis of
            ``genome.shape``), as floats.
        phenotypes: initialised to None; left for callers to fill in.
    """

    def __init__(self, genome, haplotypes=None, device=device):
        self.g = genome
        self.device = device
        # Compare with None explicitly: truth-testing a tensor that has more
        # than one element raises "Boolean value of Tensor ... is ambiguous".
        if haplotypes is not None:
            self.haplotypes = haplotypes
        else:
            self.haplotypes = self._create_random_haplotypes(1000)
        self._calculate_dosages()
        self.phenotypes = None

    def _calculate_dosages(self):
        """Sum alleles over dim 1 and store the result as float ``dosages``."""
        self.dosages = self.haplotypes.sum(dim=1).float()

    def _create_random_haplotypes(self, num_individuals):
        """Return ``num_individuals`` random 0/1 haplotypes shaped like the genome."""
        return torch.randint(0, 2, (num_individuals, *self.g.shape), device=self.device)

    def __getitem__(self, index):
        """Return the haplotypes selected by ``index`` (indexes dim 0)."""
        return self.haplotypes[index]
    
class Trait:
    """Additive trait with random per-locus effects calibrated on a founder population.

    Effects are drawn from a standard normal and centred to mean zero. They are
    then rescaled so that the founders' scores (effects contracted with the
    founder dosages) have variance ``target_variance``.

    Args:
        genome: object exposing ``n_chr`` and ``n_loci``.
        founder_population: object exposing ``dosages``, a rank-3 tensor whose
            last two dims are ``(n_chr, n_loci)`` (required by the einsum below).
        target_mean: desired mean of the founders' trait values.
        target_variance: desired variance of the founders' scores.
        default_h2: stored on the instance; not used by this class.
        device: torch device for the effect tensor.

    Attributes:
        effects: ``(n_chr, n_loci)`` tensor of rescaled additive effects.
        scaling_factors: factor applied to the raw effects.
        intercept: founders' mean score under ``effects`` minus ``target_mean``.
            Subtracting it from scores centres the founders on ``target_mean``.
    """

    def __init__(self, genome, founder_population, target_mean, target_variance, default_h2=.2, device=device):
        self.target_mean = target_mean
        self.target_variance = target_variance
        self.default_h2 = default_h2
        self.device = device
        random_effects = torch.randn(genome.n_chr, genome.n_loci, device=self.device)
        random_effects -= random_effects.mean()
        founder_scores = torch.einsum('kl,hkl->h', random_effects, founder_population.dosages)
        founder_mean, founder_var = founder_scores.mean(), founder_scores.var()
        # var(s * x) = s**2 * var(x): this factor brings the founder variance to target.
        scaling_factors = torch.sqrt(self.target_variance / founder_var)
        self.scaling_factors = scaling_factors
        random_effects *= scaling_factors
        self.effects = random_effects
        # Scores are linear in the effects, so rescaling the effects rescales the
        # founders' mean score by the same factor. The raw founder_mean describes
        # the unscaled effects and must be rescaled before computing the offset.
        self.intercept = founder_mean * scaling_factors - target_mean

        
def calculate_breeding_value(population_dosages, trait_effects, device=device):
    """Return each individual's breeding value: sum of dosage * effect over all loci.

    Args:
        population_dosages: rank-3 tensor ``(n_individuals, n_chr, n_loci)``,
            e.g. ``Population.dosages``.
        trait_effects: rank-2 tensor ``(n_chr, n_loci)``.
        device: accepted for signature symmetry; not used.

    Returns:
        Tensor of shape ``(n_individuals,)``.
    """
    # Contract the chromosome (c) and locus (l) axes, keeping individuals (i).
    return torch.einsum("icl,cl->i", population_dosages, trait_effects)

def phenotype(population, trait_effects, h2):
    """Return phenotypes for ``population``, computing and caching them on first use.

    Phenotype = breeding value + Gaussian environmental noise. The noise
    variance ``(1 - h2) / h2 * var(BV)`` makes the heritability within this
    population equal to ``h2``.

    The result is stored in ``population.phenotypes``. While it is set, later
    calls return it unchanged and ignore ``trait_effects`` and ``h2``.

    Args:
        population: object with ``phenotypes`` (None or a tensor) and ``dosages``.
        trait_effects: additive effects forwarded to ``calculate_breeding_value``.
        h2: heritability; must satisfy ``0 < h2 <= 1`` when phenotypes are computed.

    Returns:
        Tensor of phenotypes, one per breeding value.

    Raises:
        ValueError: if phenotypes must be computed and ``h2`` is outside (0, 1].
    """
    if population.phenotypes is None:
        # h2 <= 0 divides by zero; h2 > 1 yields a negative variance (NaN noise).
        if not 0 < h2 <= 1:
            raise ValueError(f"h2 must be in (0, 1], got {h2}")
        breeding_values = calculate_breeding_value(population.dosages.float(), trait_effects)
        environmental_variance = (1 - h2) / h2 * breeding_values.var()
        # randn_like samples on the breeding values' own device and dtype,
        # not on the module-level default device.
        environmental_noise = torch.randn_like(breeding_values) * torch.sqrt(environmental_variance)
        population.phenotypes = breeding_values + environmental_noise
    return population.phenotypes

def truncation_selection(population, trait_effects, h2, top_percent):
    """Return indices of the individuals with the highest phenotypes.

    Note: despite its name, ``top_percent`` is an absolute count. It is passed
    directly to ``torch.topk`` as ``k``, so ``200`` selects 200 individuals.

    Phenotypes come from ``phenotype``, which reuses ``population.phenotypes``
    when it is already set.
    """
    scores = phenotype(population, trait_effects, h2)
    _, best = torch.topk(scores, top_percent)
    return best


# meiosis
def recombine(parent, recombination_rate=0.01):
    """Produce one gamete per individual by recombining its two homologs.

    For every individual and chromosome, a starting homolog is chosen with
    probability 1/2. Walking along the loci, each locus independently carries a
    crossover with probability ``recombination_rate``. Every crossover switches
    the gamete to the other homolog from that locus onward.

    Args:
        parent: tensor of shape ``(n_individuals, 2, n_chr, n_loci)``; index 0
            and 1 on dim 1 are the two homologs.
        recombination_rate: per-locus crossover probability in [0, 1].

    Returns:
        Tensor of shape ``(n_individuals, n_chr, n_loci)`` with the same dtype
        and device as ``parent``.

    Raises:
        ValueError: if dim 1 of ``parent`` does not have size 2.
    """
    num_individuals, ploidy, num_chromosomes, num_loci = parent.shape
    if ploidy != 2:
        raise ValueError(f"recombine expects 2 homologs on dim 1, got {ploidy}")
    maternal, paternal = parent[:, 0], parent[:, 1]
    # float() so an int rate (e.g. 0) still gives the float tensor bernoulli needs.
    crossovers = torch.bernoulli(
        torch.full(
            (num_individuals, num_chromosomes, num_loci),
            float(recombination_rate),
            device=parent.device,
        )
    )
    # Random starting homolog per chromosome; each crossover flips the strand
    # for all following loci (running count of crossovers, mod 2).
    start = torch.randint(0, 2, (num_individuals, num_chromosomes, 1), device=parent.device)
    strand = (crossovers.long().cumsum(dim=-1) + start) % 2
    # where() keeps the haplotype dtype instead of promoting alleles to float.
    return torch.where(strand.bool(), paternal, maternal)

In [None]:


# Founder population and a trait with mean 0 and unit variance in the founders.
G = Genome(50, 10)
P = Population(genome=G)
T = Trait(G, P, 0, 1)

# Keep the top 200 individuals by phenotype (h2 = 0.5), then shuffle them.
# The permutation must index the *selected* indices: torch.randperm(k) on its
# own yields positions 0..k-1 of the whole population and discards the selection.
mother_index = truncation_selection(P, T.effects, .5, 200)
mother_index = mother_index[torch.randperm(mother_index.size(0))]

# phenotype() caches P.phenotypes, so this selects the same 200 individuals;
# only the shuffled order differs from the mothers.
father_index = truncation_selection(P, T.effects, .5, 200)
father_index = father_index[torch.randperm(father_index.size(0))]

# n_individuals, ploidy, chr, loci
mother_tensor = P.haplotypes[mother_index]  # torch.Size([200, 2, 50, 10])
father_tensor = P.haplotypes[father_index]  # torch.Size([200, 2, 50, 10])

In [None]:
mother_tensor.size()  # (n_selected, ploidy, chr, loci)

torch.Size([200, 2, 50, 10])

In [None]:
gamete = recombine(mother_tensor, recombination_rate=0.01)  # one gamete per selected mother

In [None]:
# %%timeit
# gamete = recombine(mother_tensor)

805 µs ± 71 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
