In [None]:
#x02_chewc

In [None]:
import torch
import matplotlib.pyplot as plt
from fastcore.basics import patch
import uuid

In [None]:
import torch

# Default torch device; used as the default `device` argument and for tensor
# creation by the classes and helpers defined later in this cell.
device='cpu'

class Genome:
    """Describe the layout of a genome shared by every individual.

    Attributes:
        ploidy: Number of chromosome copies per individual; always 2.
        n_chr: Number of chromosomes.
        n_loci: Number of loci on each chromosome.
        shape: ``(ploidy, n_chr, n_loci)``, the haplotype shape of a single
            individual.
    """

    def __init__(self, n_chr, n_loci):
        self.ploidy = 2
        self.n_chr, self.n_loci = n_chr, n_loci
        self.shape = (self.ploidy, n_chr, n_loci)
        
class Population:
    """A group of individuals whose haplotypes follow one genome layout.

    Attributes:
        g: Genome description the haplotypes are expected to conform to.
        device: Torch device the haplotypes are kept on.
        haplotypes: Tensor of alleles.  Randomly generated founders have
            shape ``(individual, *genome.shape)``; caller-supplied tensors
            are stored as given (their shape is not checked).
        dosages: Float tensor obtained by summing ``haplotypes`` over axis 1.
        phenotypes: ``None`` after construction; filled in by other code.
    """

    def __init__(self, genome, haplotypes=None, size=1000, device=device):
        """Build a population from supplied haplotypes or random founders.

        Args:
            genome: Object exposing ``shape`` (used to size random founders).
            haplotypes: Optional pre-built haplotype tensor.  When it is
                ``None`` or has no elements, random founders are drawn.
            size: Number of random founders to draw; ignored when usable
                ``haplotypes`` are supplied.
            device: Torch device for the haplotypes.
        """
        self.g = genome
        self.device = device
        if haplotypes is not None and haplotypes.numel() > 0:
            # Honour the requested device for caller-supplied data too;
            # ``Tensor.to`` is a no-op when the tensor is already there.
            self.haplotypes = haplotypes.to(self.device)
        else:
            self.haplotypes = self._create_random_haplotypes(size)

        self._calculate_dosages()
        self.phenotypes = None

    def _calculate_dosages(self):
        """Recompute ``self.dosages`` from ``self.haplotypes`` and return it.

        The tensor is returned as well as stored so that callers which
        assign the result (``pop.dosages = pop._calculate_dosages()``) do
        not overwrite the dosages with ``None``.
        """
        self.dosages = self.haplotypes.sum(dim=1).float()
        return self.dosages

    def _create_random_haplotypes(self, num_individuals):
        """Draw 0/1 alleles uniformly for ``num_individuals`` founders."""
        return torch.randint(0, 2, (num_individuals, *self.g.shape), device=self.device)

    def __getitem__(self, index):
        """Index straight into the haplotype tensor."""
        return self.haplotypes[index]
    
class Trait:
    """A quantitative trait with random additive per-locus effects.

    Effects are drawn from a standard normal, centred on zero, and rescaled
    so that the additive scores of the founder population have variance
    ``target_variance``.

    Attributes:
        target_mean: Requested founder mean, stored as given.
        target_variance: Requested founder variance, stored as given.
        default_h2: Stored as given; not used inside this class.
        device: Torch device the effects are created on.
        scaling_factors: Scalar tensor the raw effects were multiplied by.
        effects: ``(n_chr, n_loci)`` tensor of scaled effects.
        intercept: Mean founder score under the *scaled* effects minus
            ``target_mean``; subtracting it from a score computed with
            ``effects`` centres the founders on ``target_mean``.
    """

    def __init__(self, genome, founder_population, target_mean, target_variance, default_h2=.2, device=device):
        """Draw and calibrate effects against ``founder_population``.

        Args:
            genome: Object exposing ``n_chr`` and ``n_loci``.
            founder_population: Object exposing ``dosages`` with shape
                ``(individual, n_chr, n_loci)``.
            target_mean: Desired mean of the founders' trait values.
            target_variance: Desired variance of the founders' scores.
            default_h2: Default heritability to associate with the trait.
            device: Torch device for the effect tensor.
        """
        self.target_mean = target_mean
        self.target_variance = target_variance
        self.default_h2 = default_h2
        self.device = device

        raw_effects = torch.randn(genome.n_chr, genome.n_loci, device=self.device)
        raw_effects -= raw_effects.mean()

        # Founder scores under the unscaled effects, used only to calibrate.
        founder_scores = torch.einsum('kl,hkl->h', raw_effects, founder_population.dosages)
        founder_mean, founder_var = founder_scores.mean(), founder_scores.var()

        self.scaling_factors = torch.sqrt(self.target_variance / founder_var)
        self.effects = raw_effects * self.scaling_factors

        # The founder mean was measured before rescaling; scores are linear in
        # the effects, so the mean under ``self.effects`` is the old mean times
        # the scale.  Using the unscaled mean here would give an offset that
        # does not match the stored effects.
        self.intercept = founder_mean * self.scaling_factors - target_mean

        
def calculate_breeding_value(population_dosages, trait_effects, device = device):
    """Return one additive score per individual.

    Args:
        population_dosages: Tensor indexed ``(individual, chromosome, locus)``.
        trait_effects: Tensor indexed ``(chromosome, locus)``.
        device: Accepted for interface compatibility; not used.

    Returns:
        1-D tensor holding, for each individual, the sum over chromosomes
        and loci of dosage times effect.
    """
    # n = individual, c = chromosome, l = locus
    return torch.einsum('ncl,cl->n', population_dosages, trait_effects)

def truncation_selection(population, trait, top_percent):
    """Return the indices of the individuals with the largest phenotypes.

    Args:
        population: Object exposing a 1-D ``phenotypes`` tensor.
        trait: Not used.
        top_percent: Passed to ``torch.topk`` as ``k``, so despite its name
            it is the *number* of individuals to keep, not a percentage.

    Returns:
        Index tensor of length ``top_percent``, ordered from the largest
        phenotype downwards.
    """
    _, best_indices = torch.topk(population.phenotypes, top_percent)
    return best_indices

def update_population(population, progeny):
    """Replace a population's haplotypes and refresh everything derived from them.

    Args:
        population: Object with ``haplotypes``, ``dosages``, ``phenotypes``
            attributes and a ``_calculate_dosages()`` method that writes
            ``dosages`` from ``haplotypes``.
        progeny: Haplotype tensor that becomes the new ``haplotypes``.

    Returns:
        The same ``population`` object, mutated in place.
    """
    population.haplotypes = progeny
    # The helper stores the dosages on the population itself.  Its return
    # value must not be assigned: whenever the helper returns None, doing so
    # would wipe out the dosages it has just computed.
    population._calculate_dosages()
    # Any phenotypes stored so far describe the previous haplotypes; clear
    # them so they are recomputed for the new generation.
    population.phenotypes = None
    return population

# meiosis
def recombine(parent_haplo_tensor, recombination_rate=0.01):
    """Produce one gamete per parent by mixing its two haplotypes.

    At every locus, independently with probability ``recombination_rate``,
    the allele is taken from the haplotype at ploidy index 1; otherwise it
    is taken from the haplotype at ploidy index 0.

    Args:
        parent_haplo_tensor: Tensor of shape
            ``(individual, ploidy, chromosome, locus)``.  Only ploidy
            indices 0 and 1 are read.
        recombination_rate: Per-locus probability of taking the allele from
            the second haplotype.

    Returns:
        Floating-point tensor of shape ``(individual, chromosome, locus)``.
    """
    num_individuals, _, num_chromosomes, num_loci = parent_haplo_tensor.shape
    maternal = parent_haplo_tensor[:, 0]
    paternal = parent_haplo_tensor[:, 1]

    # Build the mask on the input's own device so the function also works for
    # tensors that do not live on the module-level default device.  float()
    # keeps the probabilities floating-point even if an int rate is passed,
    # which torch.bernoulli requires.
    swap_probability = torch.full(
        (num_individuals, num_chromosomes, num_loci),
        float(recombination_rate),
        device=parent_haplo_tensor.device,
    )
    crossovers = torch.bernoulli(swap_probability)
    return maternal * (1 - crossovers) + paternal * crossovers

def breed(mother_tensor, father_tensor):
    """Cross mothers with fathers, one offspring per aligned pair.

    Each parent tensor is passed through ``recombine`` (mothers first, then
    fathers) to get one gamete per parent; the two gamete tensors are then
    stacked along a new axis 1, with the mother's gamete at index 0 and the
    father's at index 1.
    """
    gametes = [recombine(parent) for parent in (mother_tensor, father_tensor)]
    return torch.stack(gametes, dim=1)


def phenotype(population, trait, h2):
    """Attach phenotypes to ``population`` unless it already has them.

    Phenotypes are the additive scores from ``calculate_breeding_value``
    plus normal noise whose variance is ``(1 - h2) / h2`` times the variance
    of those scores.

    Args:
        population: Object exposing ``dosages`` and a ``phenotypes``
            attribute (``None`` when not yet computed).
        trait: Object exposing ``effects``.
        h2: Heritability used to size the noise; must satisfy
            ``0 < h2 <= 1``.

    Returns:
        The same ``population`` object, with ``phenotypes`` set.

    Raises:
        ValueError: If phenotypes have to be computed and ``h2`` is outside
            ``(0, 1]``.  Zero would divide by zero, and values above one
            give a negative noise variance and therefore NaN phenotypes.
    """
    # Phenotypes are computed once and then reused.
    if population.phenotypes is not None:
        return population

    if not 0 < h2 <= 1:
        raise ValueError(f"h2 must be in the interval (0, 1], got {h2!r}")

    breeding_values = calculate_breeding_value(population.dosages, trait.effects)
    environmental_variance = (1 - h2) / h2 * breeding_values.var()
    # Draw the noise on the same device as the values it is added to.
    environmental_noise = torch.randn(
        breeding_values.shape, device=breeding_values.device
    ) * torch.sqrt(environmental_variance.detach())
    population.phenotypes = breeding_values + environmental_noise
    return population


# Founder setup: 10 chromosomes x 1000 loci, 1000 random individuals, and a
# trait calibrated against them (target mean 0, target variance 1).
G = Genome(n_chr=10, n_loci=1000)
P = Population(genome=G, size=1000)
T = Trait(genome=G, founder_population=P, target_mean=0, target_variance=1)



In [None]:
# h2 = 1 means no environmental noise is added to the additive scores.
P = phenotype(P, T, h2=1)

In [None]:

# Rebuild the founders and the trait from scratch, then phenotype them
# with h2 = 1 (no environmental noise).
G = Genome(n_chr=10, n_loci=1000)
P = Population(genome=G, size=1000)
T = Trait(genome=G, founder_population=P, target_mean=0, target_variance=1)

P = phenotype(P, T, h2=1)

In [None]:
# Indices of the 100 individuals with the largest phenotypes.
mating_pool = truncation_selection(P, T, 100)
# Shuffle the *selected* indices before looking up haplotypes.  Indexing
# P.haplotypes with the bare permutation would pick individuals 0..99
# regardless of which ones were selected.
n_selected = mating_pool.size(0)
mothers = P.haplotypes[mating_pool[torch.randperm(n_selected, device=mating_pool.device)]]
fathers = P.haplotypes[mating_pool[torch.randperm(n_selected, device=mating_pool.device)]]
progeny = breed(mothers, fathers)

In [None]:
# Wrap the offspring haplotypes in a Population so dosages are computed.
progeny = Population(genome=G, haplotypes=progeny)

In [None]:
# Mean phenotype of the founder population, for comparison with the progeny.
P.phenotypes.mean()

tensor(-0.0010)

In [None]:
# Phenotype the offspring with the same trait and h2 = 1 as the founders.
progeny = phenotype(progeny, T, h2=1)

In [None]:
# Mean phenotype of the offspring generation.
progeny.phenotypes.mean()

tensor(0.3487)