# Simulation Model - Version 2

### Imports

In [4]:
import numpy as np
from sklearn.preprocessing import StandardScaler
import sys

### Class Definitions

In [347]:
class Trait():

    def __init__(self, type, s_causal, s_noncausal, h2, n):
        self.s_causal = s_causal
        self.s_noncausal = s_noncausal
        self.s = s_causal + s_noncausal
        self.h2 = h2
        self.n = n
        self.type = type
    
    def __repr__(self):
        return f"Trait(type='{self.type}', n={self.n})"

    def define_uncorrelated_polygenic_trait(self):
        # compute va and ve from h2
        ve = 1 - self.h2
        va = self.h2

        # simulate genotypes for each person for each SNP
        G = np.zeros((self.n, self.s))
        snp_freqs = []
        for snp in range(self.s):
            freq = np.random.uniform()
            snp_freqs.append(freq)
            s_array = np.array(np.random.binomial(2, freq, size = self.n))
            G[:, snp] = s_array
        self.snp_freqs = snp_freqs

        # scale genotypes
        self.alleles = G.astype(int)
        standardized_G = StandardScaler().fit_transform(G)
        self.standard_alleles = standardized_G

        # choose causal SNPs and effect sizes
        causal_snps = np.random.choice(range(0, self.s), size=self.s_causal, replace=False)
        effect_sizes = list(np.random.normal(loc=0, scale = np.sqrt(va/self.s_causal), size=self.s_causal))
        
        causal_snp_effect = dict(zip(range(0, self.s), [0]*self.s))
        for k in list(causal_snp_effect.keys()):
            if k in causal_snps:
                causal_snp_effect[k] = effect_sizes[0]
                del effect_sizes[0]
        self.causal_snp_effect = causal_snp_effect

        # simulate phenotypes
        phenotypes = []
        genotypes = []
        environments = []
        for i, row in enumerate(self.standard_alleles):
            genotype = np.sum(np.array(row) * np.array(list(causal_snp_effect.values())))
            environment = np.random.normal(loc=0, scale=np.sqrt(ve))
            phenotype = genotype + environment
            phenotypes.append(phenotype)
            genotypes.append(genotype)
            environments.append(environment)

        self.phenotypes = phenotypes
        self.genotypes = genotypes
        self.environment = environments

        return self.genotypes, self.environment, self.phenotypes

    def define_two_correlated_polygenic_traits(self, trait_2, rg, re):
        # simulate genotypes for each person for each SNP
        trait_1_G = np.zeros((self.n, self.s))
        trait_2_G = np.zeros((trait_2.n, trait_2.s))

        assert self.s == trait_2.s
        assert self.s_causal == trait_2.s_causal
        assert self.s_noncausal == trait_2.s_noncausal
        assert self.n == trait_2.n

        snp_freqs = []
        for snp in range(self.s):
            freq = np.random.uniform()
            snp_freqs.append(freq)
            s_array = np.array(np.random.binomial(2, freq, size = self.n))
            trait_1_G[:, snp] = s_array
            trait_2_G[:, snp] = s_array
        self.snp_freqs = snp_freqs
        trait_2.snp_freqs = snp_freqs

        # scale genotypes
        self.alleles = trait_1_G.astype(int)
        standardized_G1 = StandardScaler().fit_transform(trait_1_G)
        self.standard_alleles = standardized_G1

        trait_2.alleles = trait_2_G.astype(int)
        standardized_G2 = StandardScaler().fit_transform(trait_2_G)
        trait_2.standard_alleles = standardized_G2

        # choose causal SNPs
        causal_snps = np.random.choice(range(0, self.s), size=self.s_causal, replace=False)

        # choose effect sizes so rg is true
        effects = np.random.multivariate_normal([0, 0],
            [[1, rg], [rg, 1]], size=self.s_causal)
        trait_1_effect_sizes = effects[:, 0]
        trait_2_effect_sizes = effects[:, 1]

        # normalize so that genetic variance equals h2
        trait_1_effect_sizes *= np.sqrt(self.h2 / np.var(standardized_G1[:, causal_snps] @ trait_1_effect_sizes))
        trait_2_effect_sizes *= np.sqrt(trait_2.h2 / np.var(standardized_G2[:, causal_snps] @ trait_2_effect_sizes))
        trait_1_effect_sizes = list(trait_1_effect_sizes)
        trait_2_effect_sizes = list(trait_2_effect_sizes)

        causal_snp_effect_1 = dict(zip(range(0, self.s), [0]*self.s))
        for k in list(causal_snp_effect_1.keys()):
            if k in causal_snps:
                causal_snp_effect_1[k] = trait_1_effect_sizes[0]
                del trait_1_effect_sizes[0]
        self.causal_snp_effect = causal_snp_effect_1

        causal_snp_effect_2 = dict(zip(range(0, trait_2.s), [0]*trait_2.s))
        for k in list(causal_snp_effect_2.keys()):
            if k in causal_snps:
                causal_snp_effect_2[k] = trait_2_effect_sizes[0]
                del trait_2_effect_sizes[0]
        trait_2.causal_snp_effect = causal_snp_effect_2

        # choose environment so re is true
        mean = [0, 0]
        cov_e = [[1-self.h2, re*np.sqrt((1-self.h2)*(1-trait_2.h2))], 
                [re*np.sqrt((1-self.h2)*(1-trait_2.h2)), 1-trait_2.h2]]
        environment = np.random.multivariate_normal(mean, cov_e, size=self.n)
        trait_1_phenotypes_noise = environment[:, 0]
        trait_2_phenotypes_noise = environment[:, 1]

        # simulate phenotypes
        trait_1_phenotypes_causal = []
        for i, row in enumerate(self.standard_alleles):
            phenotype_causal = np.sum(np.array(row) * np.array(list(self.causal_snp_effect.values())))
            trait_1_phenotypes_causal.append(phenotype_causal)
        trait_1_phenotypes = np.array(trait_1_phenotypes_causal) + np.array(trait_1_phenotypes_noise)

        trait_2_phenotypes_causal = []
        for i, row in enumerate(trait_2.standard_alleles):
            phenotype_causal = np.sum(np.array(row) * np.array(list(trait_2.causal_snp_effect.values())))
            trait_2_phenotypes_causal.append(phenotype_causal)
        trait_2_phenotypes = np.array(trait_2_phenotypes_causal) + np.array(trait_2_phenotypes_noise)

        # update objects
        self.phenotypes = trait_1_phenotypes
        self.genotypes = trait_1_phenotypes_causal
        self.environment = trait_1_phenotypes_noise
        self.correlated_trait = trait_2
        self.re = re
        self.rg = rg
        trait_2.phenotypes = trait_2_phenotypes
        trait_2.genotypes = trait_2_phenotypes_causal
        trait_2.environment = trait_2_phenotypes_noise
        trait_2.correlated_trait = self
        trait_2.re = re
        trait_2.rg = rg

        #return self.genotypes, self.environment, self.phenotypes
        return (trait_1_phenotypes,
            trait_2_phenotypes,
            trait_1_phenotypes_causal,
            trait_2_phenotypes_causal,
            trait_1_phenotypes_noise,
            trait_2_phenotypes_noise)

    def define_monogenic_recessive_trait(self):
        assert self.s == 1
        assert self.s_causal == 1
        assert self.h2 == 1

        # compute va and ve from h2
        ve = 1 - self.h2
        va = self.h2

        # simulate genotypes for each person for each SNP
        snp_freqs = [0.5]
        G = np.array(np.random.binomial(2, snp_freqs[0], size = self.n))
        self.snp_freqs = snp_freqs

        # scale genotypes
        self.alleles = G.astype(int)
        standardized_G = StandardScaler().fit_transform(G.reshape(-1, 1))
        self.standard_alleles = standardized_G

        # simulate phenotypes
        phenotypes = []
        phenotypes_causal = []
        phenotypes_noise = []
        for i, val in enumerate(self.alleles):
            phenotype_causal = val
            phenotype = 1 if val == 0 else 0
            phenotypes.append(phenotype)
            phenotypes_causal.append(phenotype_causal)
            phenotypes_noise.append(0)

        self.phenotypes = phenotypes
        self.genotypes = phenotypes_causal
        self.environment = phenotypes_noise

        return self.genotypes, self.environment, self.phenotypes

    @staticmethod
    def get_assortative_pairs(phenotypes, target_corr=0.3, max_tries=1000):
        phenotypes = np.array(phenotypes)
        n = len(phenotypes)

        sampled_indices = np.random.choice(n, size=n, replace=False)
        phenos_subset = phenotypes[sampled_indices]

        # sort sampled phenotypes and split into two
        sorted_local_indices = np.argsort(phenos_subset)
        sorted_indices = sampled_indices[sorted_local_indices]

        # create initial index pairings (first half with second half)
        p1_idx = sorted_indices[:n // 2].copy()
        p2_idx = sorted_indices[n // 2:].copy()

        # compute initial correlation
        p1_vals = phenotypes[p1_idx]
        p2_vals = phenotypes[p2_idx]
        initial_corr = np.corrcoef(p1_vals, p2_vals)[0, 1]

        for i in range(max_tries):
            prev_corr = np.corrcoef(phenotypes[p1_idx], phenotypes[p2_idx])[0, 1]

            # swap one random pair
            pair_idx = np.random.randint(len(p1_idx))
            p1_idx[pair_idx], p2_idx[pair_idx] = p2_idx[pair_idx], p1_idx[pair_idx]

            # compute new correlation
            new_corr = np.corrcoef(phenotypes[p1_idx], phenotypes[p2_idx])[0, 1]

            # check if we crossed the target correlation
            crossed = (prev_corr > target_corr and new_corr < target_corr) or (prev_corr < target_corr and new_corr > target_corr)
            if crossed:
                if abs(prev_corr - target_corr) < abs(new_corr - target_corr):
                    final_corr = prev_corr
                    # undo last shuffle to return previous state
                    p1_idx[pair_idx], p2_idx[pair_idx] = p2_idx[pair_idx], p1_idx[pair_idx]
                else:
                    final_corr = new_corr

                # return index-based pairs
                index_pairs = list(zip(p1_idx, p2_idx))
                return index_pairs, final_corr

        # if no crossing occurred after max_tries (not likely)
        final_corr = np.corrcoef(phenotypes[p1_idx], phenotypes[p2_idx])[0, 1]
        index_pairs = list(zip(p1_idx, p2_idx))
        return index_pairs, final_corr

    def create_trait_offspring(self, offspring_trait, fitness, method, target_corr = 0.5):
        if self.type != "monogenic_recessive":
            offspring_trait.causal_snp_effect = self.causal_snp_effect
        offspring_trait.snp_freqs = self.snp_freqs # might have to edit this if I want to use allele freqs in graph
        if self.type == "correlated_polygenic":
            trait_2 = self.correlated_trait
            offspring_trait_2 = Trait("correlated_polygenic", trait_2.s_causal, trait_2.s_noncausal, trait_2.h2, offspring_trait.n)
            offspring_trait.correlated_trait = offspring_trait_2
            offspring_trait_2.correlated_trait = offspring_trait
            offspring_trait_2.causal_snp_effect = trait_2.causal_snp_effect
            offspring_trait_2.snp_freqs = trait_2.snp_freqs
            offspring_trait_2_alleles = []
        
        if method == "assortative":
            corr = 0
            while abs(corr - target_corr) > 0.05:
                pairs, corr = Trait.get_assortative_pairs(self.phenotypes, target_corr=target_corr, max_tries=1000)
        offspring_alleles = []
        for n in range(0, offspring_trait.n):
            if method == "random":
                par1, par2 = np.random.choice(range(0, self.n), p = fitness, size = 2, replace=False)
            elif method == "assortative":
                pair = pairs[n % len(pairs)]
                par1 = pair[0]
                par2 = pair[1]
            else:
                print("Method does not match expected options.")
                sys.exit()

            par1_allele = self.alleles[par1]
            par2_allele = self.alleles[par2]
            par1_allele = np.atleast_1d(par1_allele)
            par2_allele = np.atleast_1d(par2_allele)
            
            punnet_dict = {
                    (0, 0): 0,
                    (0, 1): np.random.choice([0, 1], p=[0.75, 0.25]),
                    (0, 2): 1,
                    (1, 1): np.random.choice([0, 1, 2], p=[0.25, 0.5, 0.25]),
                    (1, 2): np.random.choice([1, 2]),
                    (2, 2): 2,
                }
            offspring_allele = [
                    punnet_dict[tuple(sorted((par1_allele[i], par2_allele[i])))]
                    for i in range(len(par1_allele))
                ]
            offspring_alleles.append(offspring_allele)

            if self.type == "correlated_polygenic":
                par1_allele = trait_2.alleles[par1]
                par2_allele = trait_2.alleles[par2]
                par1_allele = np.atleast_1d(par1_allele)
                par2_allele = np.atleast_1d(par2_allele)
                
                punnet_dict = {
                    (0, 0): 0,
                    (0, 1): np.random.choice([0, 1], p=[0.75, 0.25]),
                    (0, 2): 1,
                    (1, 1): np.random.choice([0, 1, 2], p=[0.25, 0.5, 0.25]),
                    (1, 2): np.random.choice([1, 2]),
                    (2, 2): 2,
                }
                trait_2_allele = [
                        punnet_dict[tuple(sorted((par1_allele[i], par2_allele[i])))]
                        for i in range(len(par1_allele))
                    ]
                offspring_trait_2_alleles.append(trait_2_allele)
        offspring_trait.alleles = np.asarray(offspring_alleles, dtype=np.float64)
        offspring_trait.standard_alleles = StandardScaler().fit_transform(offspring_trait.alleles)
        if self.type == "correlated_polygenic":
            offspring_trait_2.alleles = np.asarray(offspring_trait_2_alleles, dtype=np.float64)
            offspring_trait_2.standard_alleles = StandardScaler().fit_transform(offspring_trait_2.alleles)

        if self.type == "monogenic_recessive":
            offspring_trait.environment = [0]*offspring_trait.n
            offspring_trait.genotypes = [int(x[0]) for x in offspring_trait.alleles]
            offspring_trait.phenotypes = (np.array(offspring_trait.genotypes) == 0).astype(int)
        elif self.type == "uncorrelated_polygenic":
            offspring_trait.causal_snp_effect = self.causal_snp_effect
            ve = 1 - offspring_trait.h2
            phenotypes = []
            genotypes = []
            environments = []
            for i, row in enumerate(offspring_trait.standard_alleles):
                genotype = np.sum(np.array(row) * np.array(list(self.causal_snp_effect.values())))
                environment = np.random.normal(loc=0, scale=np.sqrt(ve))
                phenotype = genotype + environment
                phenotypes.append(phenotype)
                genotypes.append(genotype)
                environments.append(environment)
            offspring_trait.phenotypes = phenotypes
            offspring_trait.genotypes = genotypes
            offspring_trait.environment = environments
        elif self.type == "correlated_polygenic":
            trait_2 = self.correlated_trait
            offspring_trait.correlated_trait = offspring_trait_2
            re = self.re
            offspring_trait_2.re = re
            offspring_trait.re = re
            # choose environment so re is true
            mean = [0, 0]
            cov_e = [[1-offspring_trait.h2, re*np.sqrt((1-offspring_trait.h2)*(1-trait_2.h2))], 
                    [re*np.sqrt((1-offspring_trait.h2)*(1-trait_2.h2)), 1-trait_2.h2]]
            environment = np.random.multivariate_normal(mean, cov_e, size=offspring_trait.n)
            offspring_trait.environment = environment[:, 0]
            offspring_trait_2.environment = environment[:, 1]
            # simulate phenotypes
            offspring_trait.genotypes = []
            for i, row in enumerate(offspring_trait.standard_alleles):
                genotype = np.sum(np.array(row) * np.array(list(offspring_trait.causal_snp_effect.values())))
                offspring_trait.genotypes.append(genotype)
            offspring_trait.phenotypes = np.array(offspring_trait.genotypes) + np.array(offspring_trait.environment)

            offspring_trait_2.genotypes = []
            for i, row in enumerate(offspring_trait_2.standard_alleles):
                genotype = np.sum(np.array(row) * np.array(list(offspring_trait_2.causal_snp_effect.values())))
                offspring_trait_2.genotypes.append(genotype)
            offspring_trait_2.phenotypes = np.array(offspring_trait_2.genotypes) + np.array(offspring_trait_2.environment)
        else:
            print("Trait type does not match expected values.")
            sys.exit()

    def create_trait_offspring_ivf(self, offspring_trait, fitness, ivf_n):
        if self.type != "monogenic_recessive":
            offspring_trait.causal_snp_effect = self.causal_snp_effect
        offspring_trait.snp_freqs = self.snp_freqs
        punnet_dict = {
                    (0, 0): 0,
                    (0, 1): np.random.choice([0, 1], p=[0.75, 0.25]),
                    (0, 2): 1,
                    (1, 1): np.random.choice([0, 1, 2], p=[0.25, 0.5, 0.25]),
                    (1, 2): np.random.choice([1, 2]),
                    (2, 2): 2,
                }
        if self.type == "correlated_polygenic":
            trait_2 = self.correlated_trait
            offspring_trait_2 = Trait("correlated_polygenic", trait_2.s_causal, trait_2.s_noncausal, trait_2.h2, offspring_trait.n)
            offspring_trait.correlated_trait = offspring_trait_2
            offspring_trait_2.correlated_trait = offspring_trait
            offspring_trait_2.causal_snp_effect = trait_2.causal_snp_effect
            offspring_trait_2.snp_freqs = trait_2.snp_freqs
            offspring_trait_2_alleles = []
        
        offspring_alleles = []
        for n in range(0, offspring_trait.n-ivf_n):
            par1, par2 = np.random.choice(range(0, self.n), p = fitness, size = 2, replace=False)
            par1_allele = self.alleles[par1]
            par2_allele = self.alleles[par2]
            par1_allele = np.atleast_1d(par1_allele)
            par2_allele = np.atleast_1d(par2_allele)
            
            offspring_allele = [
                    punnet_dict[tuple(sorted((par1_allele[i], par2_allele[i])))]
                    for i in range(len(par1_allele))
                ]
            offspring_alleles.append(offspring_allele)

            if self.type == "correlated_polygenic":
                par1_allele = trait_2.alleles[par1]
                par2_allele = trait_2.alleles[par2]
                par1_allele = np.atleast_1d(par1_allele)
                par2_allele = np.atleast_1d(par2_allele)
                    
                trait_2_allele = [
                        punnet_dict[tuple(sorted((par1_allele[i], par2_allele[i])))]
                        for i in range(len(par1_allele))
                    ]
                offspring_trait_2_alleles.append(trait_2_allele)
        for n in range(offspring_trait.n-ivf_n, offspring_trait.n):
            par1, par2 = np.random.choice(range(0, self.n), p = fitness, size = 2, replace=False)
            par1_allele = self.alleles[par1]
            par2_allele = self.alleles[par2]
            par1_allele = np.atleast_1d(par1_allele)
            par2_allele = np.atleast_1d(par2_allele)

            offspring_allele_ivf = [[
                    punnet_dict[tuple(sorted((par1_allele[i], par2_allele[i])))]
                    for i in range(len(par1_allele))
                ] for j in range(0,10)]
            offspring_allele = max(offspring_allele_ivf)
            offspring_alleles.append(offspring_allele)

            if self.type == "correlated_polygenic":
                par1_allele = trait_2.alleles[par1]
                par2_allele = trait_2.alleles[par2]
                par1_allele = np.atleast_1d(par1_allele)
                par2_allele = np.atleast_1d(par2_allele)
                    
                trait_2_allele_ivf = [[
                        punnet_dict[tuple(sorted((par1_allele[i], par2_allele[i])))]
                        for i in range(len(par1_allele))
                    ] for j in range(0,10)]
                trait_2_allele = max(trait_2_allele)
                offspring_trait_2_alleles.append(trait_2_allele)
        offspring_trait.alleles = offspring_alleles
        standardized_G = StandardScaler().fit_transform(offspring_trait.alleles)
        offspring_trait.standard_alleles = standardized_G
        if self.type == "correlated_polygenic":
            offspring_trait_2.alleles = offspring_trait_2_alleles
            standardized_G = StandardScaler().fit_transform(offspring_trait_2.alleles)
            offspring_trait_2.standard_alleles = standardized_G

        if self.type == "monogenic_recessive":
            offspring_trait.environment = [0]*offspring_trait.n
            offspring_trait.genotype = [0 if x == [0] else 1 for x in offspring_trait.alleles]
            offspring_trait.phenotype = offspring_trait.genotype
        elif self.type == "uncorrelated_polygenic":
            offspring_trait.causal_snp_effect = self.causal_snp_effect
            ve = 1 - offspring_trait.h2
            phenotypes = []
            genotypes = []
            environments = []
            for i, row in enumerate(offspring_trait.standard_alleles):
                genotype = np.sum(np.array(row) * np.array(list(self.causal_snp_effect.values())))
                environment = np.random.normal(loc=0, scale=np.sqrt(ve))
                phenotype = genotype + environment
                phenotypes.append(phenotype)
                genotypes.append(genotype)
                environments.append(environment)
            offspring_trait.phenotypes = phenotypes
            offspring_trait.genotypes = genotypes
            offspring_trait.environment = environments

        elif self.type == "correlated_polygenic":
            trait_2 = self.correlated_trait
            offspring_trait.correlated_trait = offspring_trait_2
            re = self.re
            offspring_trait_2.re = re
            # choose environment so re is true
            mean = [0, 0]
            cov_e = [[1-offspring_trait.h2, re*np.sqrt((1-offspring_trait.h2)*(1-trait_2.h2))], 
                    [re*np.sqrt((1-offspring_trait.h2)*(1-trait_2.h2)), 1-trait_2.h2]]
            environment = np.random.multivariate_normal(mean, cov_e, size=offspring_trait.n)
            offspring_trait.environment = environment[:, 0]
            offspring_trait_2.environment = environment[:, 1]
            # simulate phenotypes
            offspring_trait.genotypes = []
            for i, row in enumerate(offspring_trait.standard_alleles):
                genotype = np.sum(np.array(row) * np.array(list(offspring_trait.causal_snp_effect.values())))
                offspring_trait.genotypes.append(genotype)
            offspring_trait.phenotypes = np.array(offspring_trait.genotypes) + np.array(offspring_trait.environment)

            offspring_trait_2.genotypes = []
            for i, row in enumerate(offspring_trait_2.standard_alleles):
                genotype = np.sum(np.array(row) * np.array(list(offspring_trait_2.causal_snp_effect.values())))
                offspring_trait_2.genotypes.append(genotype)
            offspring_trait_2.phenotypes = np.array(offspring_trait_2.genotypes) + np.array(offspring_trait_2.environment)
        else:
            print("Trait type does not match expected values.")
            sys.exit()
    
    def create_trait_offspring_gene_editing(self, offspring_trait, fitness, edit_n):
        if self.type != "monogenic_recessive":
            offspring_trait.causal_snp_effect = self.causal_snp_effect
        offspring_trait.snp_freqs = self.snp_freqs
        punnet_dict = {
                    (0, 0): 0,
                    (0, 1): np.random.choice([0, 1], p=[0.75, 0.25]),
                    (0, 2): 1,
                    (1, 1): np.random.choice([0, 1, 2], p=[0.25, 0.5, 0.25]),
                    (1, 2): np.random.choice([1, 2]),
                    (2, 2): 2,
                }
        if self.type == "correlated_polygenic":
            trait_2 = self.correlated_trait
            offspring_trait_2 = Trait("correlated_polygenic", trait_2.s_causal, trait_2.s_noncausal, trait_2.h2, offspring_trait.n)
            offspring_trait.correlated_trait = offspring_trait_2
            offspring_trait_2.correlated_trait = offspring_trait
            offspring_trait_2.causal_snp_effect = trait_2.causal_snp_effect
            offspring_trait_2.snp_freqs = trait_2.snp_freqs
            offspring_trait_2_alleles = []
        
        offspring_alleles = []
        for n in range(0, offspring_trait.n-edit_n):
            par1, par2 = np.random.choice(range(0, self.n), p = fitness, size = 2, replace=False)
            par1_allele = self.alleles[par1]
            par2_allele = self.alleles[par2]
            par1_allele = np.atleast_1d(par1_allele)
            par2_allele = np.atleast_1d(par2_allele)
            
            offspring_allele = [
                    punnet_dict[tuple(sorted((par1_allele[i], par2_allele[i])))]
                    for i in range(len(par1_allele))
                ]
            offspring_alleles.append(offspring_allele)

            if self.type == "correlated_polygenic":
                par1_allele = trait_2.alleles[par1]
                par2_allele = trait_2.alleles[par2]
                par1_allele = np.atleast_1d(par1_allele)
                par2_allele = np.atleast_1d(par2_allele)
                    
                trait_2_allele = [
                        punnet_dict[tuple(sorted((par1_allele[i], par2_allele[i])))]
                        for i in range(len(par1_allele))
                    ]
                offspring_trait_2_alleles.append(trait_2_allele)
        for n in range(offspring_trait.n-edit_n, offspring_trait.n):
            par1, par2 = np.random.choice(range(0, self.n), p = fitness, size = 2, replace=False)
            par1_allele = self.alleles[par1]
            par2_allele = self.alleles[par2]
            par1_allele = np.atleast_1d(par1_allele)
            par2_allele = np.atleast_1d(par2_allele)

            offspring_allele = 2
            offspring_alleles.append(offspring_allele)

            if self.type == "correlated_polygenic":
                par1_allele = trait_2.alleles[par1]
                par2_allele = trait_2.alleles[par2]
                par1_allele = np.atleast_1d(par1_allele)
                par2_allele = np.atleast_1d(par2_allele)
                    
                trait_2_allele = 2
                offspring_trait_2_alleles.append(trait_2_allele)
        offspring_trait.alleles = offspring_alleles
        standardized_G = StandardScaler().fit_transform(offspring_trait.alleles)
        offspring_trait.standard_alleles = standardized_G
        if self.type == "correlated_polygenic":
            offspring_trait_2.alleles = offspring_trait_2_alleles
            standardized_G = StandardScaler().fit_transform(offspring_trait_2.alleles)
            offspring_trait_2.standard_alleles = standardized_G

        if self.type == "monogenic_recessive":
            offspring_trait.environment = [0]*offspring_trait.n
            offspring_trait.genotype = [0 if x == [0] else 1 for x in offspring_trait.alleles]
            offspring_trait.phenotype = offspring_trait.genotype
        elif self.type == "uncorrelated_polygenic":
            offspring_trait.causal_snp_effect = self.causal_snp_effect
            ve = 1 - offspring_trait.h2
            phenotypes = []
            genotypes = []
            environments = []
            for i, row in enumerate(offspring_trait.standard_alleles):
                genotype = np.sum(np.array(row) * np.array(list(self.causal_snp_effect.values())))
                environment = np.random.normal(loc=0, scale=np.sqrt(ve))
                phenotype = genotype + environment
                phenotypes.append(phenotype)
                genotypes.append(genotype)
                environments.append(environment)
            offspring_trait.phenotypes = phenotypes
            offspring_trait.genotypes = genotypes
            offspring_trait.environment = environments

        elif self.type == "correlated_polygenic":
            trait_2 = self.correlated_trait
            offspring_trait.correlated_trait = offspring_trait_2
            re = self.re
            offspring_trait_2.re = re
            # choose environment so re is true
            mean = [0, 0]
            cov_e = [[1-offspring_trait.h2, re*np.sqrt((1-offspring_trait.h2)*(1-trait_2.h2))], 
                    [re*np.sqrt((1-offspring_trait.h2)*(1-trait_2.h2)), 1-trait_2.h2]]
            environment = np.random.multivariate_normal(mean, cov_e, size=offspring_trait.n)
            offspring_trait.environment = environment[:, 0]
            offspring_trait_2.environment = environment[:, 1]
            # simulate phenotypes
            offspring_trait.genotypes = []
            for i, row in enumerate(offspring_trait.standard_alleles):
                genotype = np.sum(np.array(row) * np.array(list(offspring_trait.causal_snp_effect.values())))
                offspring_trait.genotypes.append(genotype)
            offspring_trait.phenotypes = np.array(offspring_trait.genotypes) + np.array(offspring_trait.environment)

            offspring_trait_2.genotypes = []
            for i, row in enumerate(offspring_trait_2.standard_alleles):
                genotype = np.sum(np.array(row) * np.array(list(offspring_trait_2.causal_snp_effect.values())))
                offspring_trait_2.genotypes.append(genotype)
            offspring_trait_2.phenotypes = np.array(offspring_trait_2.genotypes) + np.array(offspring_trait_2.environment)
        else:
            print("Trait type does not match expected values.")
            sys.exit()

In [348]:
class Generation():

    def __init__(self, n):
        self.n = n
        self.traits = []
    
    def __repr__(self):
        return f"Generation(n={self.n})"

    def add_trait(self, trait):
        assert self.n == trait.n
        self.traits.append(trait)
    
    def calculate_fitness(self, fitness_func):
        individual_fitness = []
        for i in range(self.n):
            phenotypes_i = [trait.phenotypes[i] for trait in self.traits]
            fitness_i = fitness_func(phenotypes_i)
            individual_fitness.append(fitness_i)
        def squish(x):
            x = np.array(x)
            e = np.exp(x - np.max(x))
            return e / np.sum(e)
        self.fitness = squish(individual_fitness)
    
    def create_offspring_random_mating(self, new_n):
        assert self.fitness.all() != None
        offspring = Generation(new_n)
        processed = set()
        for trait in self.traits:
            if id(trait) in processed:
                continue
            offspring_trait = Trait(
                trait.type, trait.s_causal, trait.s_noncausal, trait.h2, new_n
            )
            trait.create_trait_offspring(offspring_trait, self.fitness, "random")
            offspring.add_trait(offspring_trait)

            if hasattr(offspring_trait, "correlated_trait") and (
                offspring_trait.correlated_trait is not None
            ):
                offspring.add_trait(offspring_trait.correlated_trait)
                processed.add(id(trait))
                processed.add(id(trait.correlated_trait))
        return offspring

    # known bug: assortative mating does not increase offspring variance
    def create_offspring_assortative_mating(self, new_n):
        assert self.fitness.all() != None
        offspring = Generation(new_n)
        processed = set()
        for trait in self.traits:
            if id(trait) in processed:
                continue
            offspring_trait = Trait(
                trait.type, trait.s_causal, trait.s_noncausal, trait.h2, new_n
            )
            trait.create_trait_offspring(offspring_trait, self.fitness, "assortative")
            offspring.add_trait(offspring_trait)

            if hasattr(offspring_trait, "correlated_trait") and (
                offspring_trait.correlated_trait is not None
            ):
                offspring.add_trait(offspring_trait.correlated_trait)
                processed.add(id(trait))
                processed.add(id(trait.correlated_trait))
        return offspring
    
    def create_offspring_ivf(self, new_n, ivf_n):
        assert self.fitness.all() != None
        offspring = Generation(new_n)
        processed = set()
        for trait in self.traits:
            if id(trait) in processed:
                continue
            offspring_trait = Trait(
                trait.type, trait.s_causal, trait.s_noncausal, trait.h2, new_n
            )
            trait.create_trait_offspring_ivf(offspring_trait, self.fitness, ivf_n=ivf_n)
            offspring.add_trait(offspring_trait)

            if hasattr(offspring_trait, "correlated_trait") and (
                offspring_trait.correlated_trait is not None
            ):
                offspring.add_trait(offspring_trait.correlated_trait)
                processed.add(id(trait))
                processed.add(id(trait.correlated_trait))
        return offspring
    
    def create_offspring_gene_editing(self, new_n, edit_n):
        assert self.fitness.all() != None
        offspring = Generation(new_n)
        processed = set()
        for trait in self.traits:
            if id(trait) in processed:
                continue
            offspring_trait = Trait(
                trait.type, trait.s_causal, trait.s_noncausal, trait.h2, new_n
            )
            trait.create_trait_offspring_gene_editing(offspring_trait, self.fitness, edit_n=edit_n)
            offspring.add_trait(offspring_trait)

            if hasattr(offspring_trait, "correlated_trait") and (
                offspring_trait.correlated_trait is not None
            ):
                offspring.add_trait(offspring_trait.correlated_trait)
                processed.add(id(trait))
                processed.add(id(trait.correlated_trait))
        return offspring

### Variables

In [349]:
def linear_w_from_p(phenotypes):
    w = 1 + sum(phenotypes)
    return w
def no_fitness(phenotypes):
    w = 1
    return w
pop_size = [500, 500, 500, 600, 400, 100]
# trait_dict = {0:["monogenic_recessive", 1, 0, 1], 1:["uncorrelated_polygenic", 2, 2, 0.5], 
#               2:["correlated_polygenic", 2, 1, 0.2, 0.1, 0.1], 3:["correlated_polygenic", 2, 1, 0.8, 0.1, 0.1]}
# trait_dict = {0:["correlated_polygenic", 30, 10, 0.2, 0.3, 0.8], 1:["correlated_polygenic", 30, 10, 0.8, 0.3, 0.8]}
trait_dict = {0:["monogenic_recessive", 1, 0, 1], 1:["monogenic_recessive", 1, 0, 1]}
mating = "random" # random, assortative, ivf, gene_editing
fitness_func = no_fitness
# ivf_size = [0, 2, 2]
# edit_size = [0, 2, 2]

In [350]:
num_gen = len(pop_size)
num_traits = len(trait_dict)

### System

In [364]:
generations = []
gen1 = Generation(pop_size[0])
i = 0
while i < num_traits:
    info = trait_dict[i]
    trait = Trait(info[0], info[1], info[2], info[3], pop_size[0])
    if trait.type == "monogenic_recessive":
        trait.define_monogenic_recessive_trait()
        gen1.add_trait(trait)
        i += 1
    elif trait.type == "uncorrelated_polygenic":
        trait.define_uncorrelated_polygenic_trait()
        gen1.add_trait(trait)
        i += 1
    elif trait.type == "correlated_polygenic":
        trait_2_info = trait_dict[i+1]
        trait_2 = Trait(trait_2_info[0], trait_2_info[1], trait_2_info[2], trait_2_info[3], pop_size[0])
        trait.define_two_correlated_polygenic_traits(trait_2, info[4], info[5])
        gen1.add_trait(trait)
        gen1.add_trait(trait_2)
        i += 2
    else:
        print("Trait type does not match expected value.")
        sys.exit()
gen1.calculate_fitness(fitness_func)
generations.append(gen1)

In [365]:
for g in range(1, num_gen):
    if mating == "random":
        nextgen = generations[-1].create_offspring_random_mating(pop_size[g])
    elif mating == "assortative":
        nextgen = generations[-1].create_offspring_assortative_mating(pop_size[g])
    elif mating == "ivf":
        nextgen = generations[-1].create_offspring_ivf(pop_size[g], ivf_size[g])
    elif mating == "gene_editing":
        nextgen = generations[-1].create_offspring_gene_editing(pop_size[g], edit_size[g])
    nextgen.calculate_fitness(fitness_func)
    generations.append(nextgen)

### Testing Functions

In [361]:
def compute_h2(genotypes, phenotypes):
    """Estimate narrow-sense heritability as Var(G) / Var(P)."""
    return np.var(genotypes) / np.var(phenotypes)

def compute_rg(trait1_geno, trait2_geno):
    """Compute genetic correlation."""
    return np.corrcoef(trait1_geno, trait2_geno)[0,1]

def compute_re(trait1_env, trait2_env):
    return np.corrcoef(trait1_env, trait2_env)[0,1]

In [362]:
import numpy as np

def compute_trait_stats(trait):
    """Return h2, genetic variance, environmental variance."""
    
    P = np.array(trait.phenotypes, dtype=float)
    E = np.array(trait.environment, dtype=float)

    if trait.type == "monogenic_recessive":
        G_effect = P.copy()
    else:
        G_effect = np.array(trait.genotypes, dtype=float)
    
    var_G = np.var(G_effect)
    var_E = np.var(E)
    var_P = np.var(P)

    h2 = var_G / var_P if var_P > 0 else 0.0

    return h2, var_G, var_E


def compute_rg_re(trait1, trait2):
    """Return genetic correlation and environmental correlation."""
    G1 = np.array(trait1.genotypes, dtype=float)
    G2 = np.array(trait2.genotypes, dtype=float)
    E1 = np.array(trait1.environment, dtype=float)
    E2 = np.array(trait2.environment, dtype=float)

    # handle zero variance cases
    if np.std(G1) == 0 or np.std(G2) == 0:
        rg = 0.0
    else:
        rg = np.corrcoef(G1, G2)[0,1]

    if np.std(E1) == 0 or np.std(E2) == 0:
        re = 0.0
    else:
        re = np.corrcoef(E1, E2)[0,1]

    return rg, re


In [366]:
print("=== Trait stability statistics ===")

for g_idx, gen in enumerate(generations):
    print(f"\nGeneration {g_idx}")

    # h2 for each trait
    for t_idx, trait in enumerate(gen.traits):
        h2, varG, varE = compute_trait_stats(trait)
        print(f"  Trait {t_idx}: h2={h2:.3f}, VarG={varG:.3f}, VarE={varE:.3f}")

    # correlations between trait 0 and 1 if they exist
    if len(gen.traits) >= 2:
        rg, re = compute_rg_re(gen.traits[0], gen.traits[1])
        print(f"  r_g(0,1) = {rg:.3f}")
        print(f"  r_e(0,1) = {re:.3f}")

=== Trait stability statistics ===

Generation 0
  Trait 0: h2=1.000, VarG=0.174, VarE=0.000
  Trait 1: h2=1.000, VarG=0.167, VarE=0.000
  r_g(0,1) = 0.035
  r_e(0,1) = 0.000

Generation 1
  Trait 0: h2=1.000, VarG=0.196, VarE=0.000
  Trait 1: h2=1.000, VarG=0.210, VarE=0.000
  r_g(0,1) = -0.014
  r_e(0,1) = 0.000

Generation 2
  Trait 0: h2=1.000, VarG=0.208, VarE=0.000
  Trait 1: h2=1.000, VarG=0.206, VarE=0.000
  r_g(0,1) = 0.007
  r_e(0,1) = 0.000

Generation 3
  Trait 0: h2=1.000, VarG=0.226, VarE=0.000
  Trait 1: h2=1.000, VarG=0.228, VarE=0.000
  r_g(0,1) = -0.004
  r_e(0,1) = 0.000

Generation 4
  Trait 0: h2=1.000, VarG=0.227, VarE=0.000
  Trait 1: h2=1.000, VarG=0.239, VarE=0.000
  r_g(0,1) = 0.011
  r_e(0,1) = 0.000

Generation 5
  Trait 0: h2=1.000, VarG=0.245, VarE=0.000
  Trait 1: h2=1.000, VarG=0.245, VarE=0.000
  r_g(0,1) = 0.006
  r_e(0,1) = 0.000
