In [None]:
import numpy as np
import matplotlib.pyplot as plt
    
    
    
class ChewC:
    """Container tying together the trait, genome and population of a simulation.

    Parameters:
        trait: the Trait being simulated (optional, may be assigned later).
        genome: the Genome describing haplotype dimensions (optional).
        population: the Population of individuals (optional).
    """

    def __init__(self, trait=None, genome=None, population=None):
        # Spelled with exactly two trailing underscores so Python actually calls it
        # on ChewC(); all three slots always exist, defaulting to None.
        self.trait = trait
        self.genome = genome
        self.population = population

        
class Genome:
    """Dimensions of a diploid genome.

    Parameters:
        n_chr: number of chromosomes.
        n_loci: number of loci per chromosome.

    Attributes:
        ploidy: always 2.
        shape: (ploidy, n_chr, n_loci), the same axis order Individual uses
            when it generates founder haplotypes.
    """

    def __init__(self, n_chr, n_loci):
        self.n_chr, self.n_loci = n_chr, n_loci
        self.ploidy = 2
        self.shape = (self.ploidy, n_chr, n_loci)
        
class Trait:
    """Additive polygenic trait calibrated on a founder population.

    One normally distributed effect is drawn per (chromosome, locus) and applied
    to the allele dosage summed over both homologues. The effects are rescaled
    and an intercept is chosen so that the founder population's genetic values
    have exactly ``target_mean`` and ``target_variance`` (population variance,
    ddof=0).

    NOTE(review): ``Trait`` is defined twice in this cell; the later definition
    shadows this one. Keep a single copy.
    """

    def __init__(self, genome, founder_population, target_mean, target_variance):
        """Sample and calibrate the locus effects.

        Parameters:
            genome: object exposing ``n_chr`` and ``n_loci``.
            founder_population: object whose ``get_haplo()`` returns an array of
                shape (n_individuals, ploidy, n_chr, n_loci).
            target_mean: desired mean genetic value of the founders.
            target_variance: desired variance of the founders' genetic values.

        Raises:
            ValueError: if the founders show no genetic variance under the sampled
                effects, so no scaling can reach ``target_variance``.
        """
        self.target_mean = target_mean
        self.target_variance = target_variance

        # Use the arguments (not notebook globals) so the trait is calibrated on
        # the genome and population actually passed in.
        random_effects = np.random.randn(genome.n_chr, genome.n_loci)

        # Sum over the ploidy axis to get allele dosage, then score each founder.
        dosages = np.sum(founder_population.get_haplo(), axis=1)
        founder_scores = np.einsum('ij,kij->k', random_effects, dosages)
        founder_var = founder_scores.var()
        if not founder_var > 0:
            raise ValueError("founder population has zero genetic variance; cannot scale to target_variance")

        scaling_factor = np.sqrt(self.target_variance / founder_var)
        self.effects = random_effects * scaling_factor
        # Scaling the effects also scales the founder mean, so the intercept must
        # be computed from the scaled mean to land exactly on target_mean.
        self.intercept = target_mean - founder_scores.mean() * scaling_factor

    def __matmul__(self, other):
        """Apply the trait to an Individual or a Population.

        ``trait @ individual`` returns the element-wise product of the effects with
        the individual's haplotype (one contribution per homologue, chromosome and
        locus; no intercept). ``trait @ population`` returns one genetic value per
        individual: the sum of all those contributions plus ``self.intercept``.
        Any other operand yields ``NotImplemented`` so Python raises TypeError.
        """
        if isinstance(other, Individual):
            return self.effects * other.haplotype
        if isinstance(other, Population):
            return np.einsum('ij,kpij->k', self.effects, other.get_haplo()) + self.intercept
        return NotImplemented
        
class Population:
    """A group of founder Individuals that all share one genome layout.

    Parameters:
        genome: the Genome passed to every founder.
        size: number of founders to create.
    """

    def __init__(self, genome, size):
        self.genome, self.size, self.ploidy = genome, size, 2
        # Founders are generated eagerly at construction time.
        self.individuals = self._create_initial_population()
        self.chewc = None  # back-reference slot, left unset here

    def _create_initial_population(self):
        """Build ``self.size`` founder Individuals for ``self.genome``."""
        founders = []
        for _ in range(self.size):
            # NOTE(review): the population passes itself as `chewc`; Individual
            # does not store that argument.
            founders.append(Individual(self.genome, chewc=self))
        return founders

    def get_haplo(self):
        """Stack every member's haplotype into one array, individuals on axis 0."""
        haplos = [member.haplotype for member in self.individuals]
        return np.array(haplos)

    def __repr__(self):
        return 'Population of size: {}'.format(self.size)
    
class Trait:
    """Additive polygenic trait calibrated on a founder population.

    One normally distributed effect is drawn per (chromosome, locus) and applied
    to the allele dosage summed over both homologues. The effects are rescaled
    and an intercept is chosen so that the founder population's genetic values
    have exactly ``target_mean`` and ``target_variance`` (population variance,
    ddof=0).

    NOTE(review): this redefines the ``Trait`` class declared earlier in the same
    cell and is the definition that takes effect. Keep a single copy.
    """

    def __init__(self, genome, founder_population, target_mean, target_variance):
        """Sample and calibrate the locus effects.

        Parameters:
            genome: object exposing ``n_chr`` and ``n_loci``.
            founder_population: object whose ``get_haplo()`` returns an array of
                shape (n_individuals, ploidy, n_chr, n_loci).
            target_mean: desired mean genetic value of the founders.
            target_variance: desired variance of the founders' genetic values.

        Raises:
            ValueError: if the founders show no genetic variance under the sampled
                effects, so no scaling can reach ``target_variance``.
        """
        self.target_mean = target_mean
        self.target_variance = target_variance

        # Use the arguments (not notebook globals) so the trait is calibrated on
        # the genome and population actually passed in.
        random_effects = np.random.randn(genome.n_chr, genome.n_loci)

        # Sum over the ploidy axis to get allele dosage, then score each founder.
        dosages = np.sum(founder_population.get_haplo(), axis=1)
        founder_scores = np.einsum('ij,kij->k', random_effects, dosages)
        founder_var = founder_scores.var()
        if not founder_var > 0:
            raise ValueError("founder population has zero genetic variance; cannot scale to target_variance")

        scaling_factor = np.sqrt(self.target_variance / founder_var)
        self.effects = random_effects * scaling_factor
        # Scaling the effects also scales the founder mean, so the intercept must
        # be computed from the scaled mean to land exactly on target_mean.
        self.intercept = target_mean - founder_scores.mean() * scaling_factor

    def __matmul__(self, other):
        """Apply the trait to an Individual or a Population.

        ``trait @ individual`` returns the element-wise product of the effects with
        the individual's haplotype (one contribution per homologue, chromosome and
        locus; no intercept). ``trait @ population`` returns one genetic value per
        individual: the sum of all those contributions plus ``self.intercept``.
        Any other operand yields ``NotImplemented`` so Python raises TypeError.
        """
        if isinstance(other, Individual):
            return self.effects * other.haplotype
        if isinstance(other, Population):
            return np.einsum('ij,kpij->k', self.effects, other.get_haplo()) + self.intercept
        return NotImplemented
        
class Individual:
    """One individual: its haplotype array plus pedigree bookkeeping.

    ``haplotype`` has shape (ploidy, n_chr, n_loci); founders get random 0/1
    alleles. ``source`` records how the individual was produced and decides how
    ``mother`` / ``father`` are filled in.
    """

    def __init__(self, genome, haplotype=None, mother=None, father=None, descendants=0, source='founder', chewc=None):
        """
        Parameters:
            genome: object exposing ``ploidy``, ``n_chr`` and ``n_loci``.
            haplotype: allele array for non-founders; ignored when source='founder'.
            mother: parent Individual ('cross', 'self', 'dh').
            father: parent Individual ('cross' only; 'self'/'dh' reuse mother).
            descendants: accepted for compatibility but currently unused.
            source: one of 'founder', 'cross', 'self', 'dh'.
            chewc: accepted for compatibility but currently unused.

        Raises:
            ValueError: if ``source`` is not one of the values above.
        """
        self.genome = genome  # must be Genome class
        self.haplotype = None
        self.source = source  # 'founder', 'cross', 'self', 'dh'
        # NOTE(review): attribute is spelled 'descendents' and is not linked to the
        # 'descendants' argument.
        self.descendents = []
        self.fitness = 0

        if source == 'founder':
            self.haplotype = self._generate_random_haplotype()
            self.mother = None
            self.father = None
        elif source == 'cross':
            self.haplotype = haplotype
            self.mother = mother
            self.father = father
        elif source in ('dh', 'self'):
            # Single-parent origins record the same parent as mother and father.
            self.haplotype = haplotype
            self.mother = mother
            self.father = mother
        else:
            raise ValueError(f"Invalid source: {source}")

    def __repr__(self):
        return f'Individual with haplotype shape: {self.haplotype.shape}'

    def _generate_random_haplotype(self):
        """Generate a random 0/1 haplotype of shape (ploidy, n_chr, n_loci)."""
        return np.random.choice([0, 1], size=(self.genome.ploidy, self.genome.n_chr, self.genome.n_loci))

    def __add__(self, h2):
        # NOTE(review): `self.phenotype` is a method, so `self.phenotype + h2` raises
        # TypeError for numeric h2, and nothing is returned. The intended expression
        # (perhaps `self.phenotype(h2)`) needs confirming before this is usable.
        self.fitness = self.phenotype + h2

    def gametes(self, crossover_rate=1.3):
        """Produce one recombinant haploid gamete from this individual.

        For every chromosome, the number of crossovers is drawn from
        Poisson(``crossover_rate``) (clipped to the chromosome length) and their
        positions are sampled without replacement among the loci. Segments then
        alternate between the two homologues, starting from a randomly chosen one.

        Parameters:
            crossover_rate: mean number of crossovers per chromosome (default 1.3).

        Returns:
            np.ndarray of shape (n_chr, n_loci) with the gamete's alleles.

        Raises:
            ValueError: if a chromosome does not have exactly two homologues.
        """
        haplotypes = np.asarray(self.haplotype)
        gamete = np.empty(haplotypes.shape[1:], dtype=haplotypes.dtype)
        for chrom in range(haplotypes.shape[1]):
            gamete[chrom] = self._recombine_pair(haplotypes[:, chrom, :], crossover_rate)
        return gamete

    @staticmethod
    def _recombine_pair(chromosome_pair, crossover_rate):
        """Return one recombinant chromosome drawn from a (2, n_loci) homologue pair."""
        if len(chromosome_pair) != 2:
            raise ValueError("gamete formation expects exactly two homologues per chromosome")
        n_loci = chromosome_pair.shape[1]
        # Clip so sampling without replacement cannot fail on short chromosomes.
        n_crossover = min(np.random.poisson(crossover_rate), n_loci)
        crossover_locs = np.sort(np.random.choice(n_loci, n_crossover, replace=False))
        # Start on a random homologue: always starting on homologue 0 would make
        # every crossover-free gamete a copy of homologue 0.
        start = np.random.randint(2)
        # Locus j comes from the start homologue, flipped once per crossover at or before j.
        loci = np.arange(n_loci)
        source_homologue = (start + np.searchsorted(crossover_locs, loci, side='right')) % 2
        return chromosome_pair[source_homologue, loci]

    def phenotype(self, h2):
        """Return ``chewc.trait @ self`` for this individual.

        NOTE(review): relies on a notebook-level global ``chewc`` (its ``trait``
        attribute). ``h2`` (presumably heritability) is accepted but not used yet.
        TODO: add the environmental component.
        """
        return chewc.trait @ self
# Example usage: build a genome, a founder population and a trait, then wire them into ChewC.
SEED = 42
np.random.seed(SEED)  # founder haplotypes and trait effects use the global NumPy RNG

g = Genome(3, 77)
# make a population with 100 founder individuals
population = Population(g, size=100)
# make a trait with target mean 0 and target variance 1 in the founders
trait = Trait(g, population, 0, 1)

# Individual.gametes / Individual.phenotype may read this global `chewc`.
chewc = ChewC()
chewc.trait = trait
chewc.population = population
chewc.genome = g
chewc.population.individuals[0].phenotype(1)

(2, 3, 77)


array([[[-0.        , -0.16740674, -0.06527205,  0.07791231,
         -0.02334672,  0.10651681, -0.04032374, -0.09000908,
         -0.        , -0.        , -0.14130992, -0.10716073,
         -0.        , -0.        ,  0.        ,  0.07216024,
          0.        ,  0.01171841, -0.10167756, -0.0860365 ,
          0.        , -0.00251561,  0.16547332,  0.05970075,
          0.        ,  0.        ,  0.        ,  0.13273609,
          0.01546233,  0.        ,  0.        , -0.07790096,
         -0.        , -0.        ,  0.        , -0.        ,
         -0.        , -0.        ,  0.08016346,  0.        ,
         -0.        ,  0.04080275,  0.        ,  0.        ,
         -0.        ,  0.05893229,  0.        ,  0.03338421,
         -0.        , -0.0037284 ,  0.        ,  0.07843556,
         -0.19221128,  0.11882027,  0.17016669, -0.19694884,
          0.13400745,  0.        , -0.        , -0.        ,
          0.        , -0.        ,  0.        , -0.        ,
         -0.        ,  0

In [None]:
# Genetic values of the founders. Summarise against the trait's targets rather
# than dumping the whole array.
founder_values = chewc.trait @ chewc.population
{
    'mean': founder_values.mean(),
    'target_mean': chewc.trait.target_mean,
    'var': founder_values.var(),
    'target_variance': chewc.trait.target_variance,
    'first_5': founder_values[:5],
}

array([ 0.90058078,  1.30441654,  1.55869168,  1.61792792,  0.52265153,
        1.54847419,  0.6922385 ,  4.24468214,  0.3032351 ,  1.2531959 ,
       -0.37246337,  1.12530289,  0.07714408,  0.96496298,  0.65905819,
       -1.15774067, -0.23176033,  0.40473708, -0.58547203,  0.91608648,
        1.74541433,  1.68537778,  0.02966238,  1.66886463,  0.12316054,
        1.04415549,  1.47553108,  1.75153837,  0.44607381,  0.79847286,
        2.85142217,  0.70347733,  1.04798263,  1.08622491,  2.73747453,
        0.05264775, -0.16099917,  1.26926665,  2.64583469, -0.43661853,
        0.05883183,  0.39777427,  1.12004998, -0.69285545,  0.23794893,
       -0.48797881, -0.19745621,  1.88694249,  0.89300059,  0.41219806,
        1.5067358 ,  0.8994578 ,  1.86402182,  1.77564769, -0.08073889,
        0.19697015,  2.08561036,  0.3340788 , -1.00671384, -0.39244753,
        0.7540654 ,  1.8896223 , -1.01793872,  0.83940053,  0.56879004,
       -0.98796068,  0.99136101,  0.94027401,  0.33533645,  0.49