In [None]:
# Draw matplotlib figures inline in the notebook output, rendered as SVG.
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [None]:
import importlib
import os

# Import the notebook's dependencies. If one is missing, offer to install the
# packages listed in ./requirements.txt with conda and then retry the imports
# once, so that the names below are actually bound before later cells run.
# Any failure re-raises instead of calling exit(): that stops "Run All" with a
# readable error rather than leaving later cells to fail with NameError.
for _attempt in (1, 2):
    try:
        from IPython.display import SVG, clear_output
        import numpy as np
        import msprime
        import matplotlib.pyplot as plt
        import statsmodels.api as sm
        break

    except ModuleNotFoundError as missing:
        print('ERROR: missing required packages.')
        if _attempt == 2:
            # The install reported success but the import still fails.
            print(missing)
            print('\nEXITING')
            raise

        install = input('Attempt to install required modules from requirements.txt? [Y/N] ').strip().lower()
        if install != 'y':
            print("EXITING")
            raise

        # --yes: the user has already consented above, and conda's own
        # confirmation prompt cannot be answered from inside os.system here.
        status = os.system("conda install --yes --file ./requirements.txt")
        if status != 0:
            # os.system never raises on a failed command; check its status.
            print("ERROR: failed installing modules.")
            print('\nEXITING')
            raise RuntimeError(
                'conda install exited with status {}'.format(status)
            ) from missing

        # Make freshly installed packages visible to the import system.
        importlib.invalidate_caches()

In [None]:
# Demography object with no populations added. The two commented lines are
# alternative single-population set-ups that are currently disabled.
demography = msprime.Demography()
#demography.add_population(name="A", initial_size=7e9, growth_rate=0.006729)    
#demography.add_population(name="A", initial_size=1, growth_rate=0)   

# Notebook-level accumulators appended to by AncestrySimulations.s_e when it is
# called with live_plot=True (one entry per plotted population size).
N_range_live = []
s_errors_live = []
threshold_exceed_live = []

def ensemble(N):
    """Return the random seeds to average over for population exponent ``N``.

    ``N`` is the value the caller passes as ``pop_exp`` (population size is
    ``10**N``). The number of seeds is ``int(5**(4-N))``, so it shrinks as the
    exponent grows, trading accuracy for simulation time.

    Returns a numpy array ``1..k`` when ``k > 1``; otherwise the list
    ``[1, 2]``, so there are always at least two seeds.
    """
    seed_count = int(5 ** (4 - N))
    if seed_count <= 1:
        # too few seeds for an average: fall back to a fixed pair
        return [1, 2]
    return np.arange(1, seed_count + 1)


def produce_reg_var(N, beta, error, sig_SNP, causal=False, *args, **kwargs):
    """Collect the per-SNP inputs for the OLS regressions.

    Parameters
    ----------
    N : unused here; kept for the existing call signature.
    beta : indexable of effect sizes, looked up at ``int(position)`` of each SNP.
    error : noise term added to every simulated phenotype.
    sig_SNP : iterable of SNPs; element 0 is the site position, element 1 the
        genotype vector.
    causal : which SNPs to keep. A falsy value or ``'all'`` keeps every SNP,
        ``'causal'`` keeps those with non-zero effect, ``'noncausal'`` keeps
        those with zero effect. Any other value keeps none.

    Returns
    -------
    A one-element list ``[[X, beta, Y]]`` where ``X`` is the list of kept
    genotype vectors, ``beta`` a numpy array of their effect sizes and ``Y``
    the list of ``genotype*effect + error`` for each kept SNP.
    """
    # Decide once which effect sizes qualify, instead of per branch in the loop.
    if not causal or causal == 'all':
        wanted = lambda effect: True
    elif causal == 'causal':
        wanted = lambda effect: effect != 0
    elif causal == 'noncausal':
        wanted = lambda effect: effect == 0
    else:
        wanted = lambda effect: False

    genotypes = []      # kept genotype vectors (columns of X)
    effects = []        # effect size of each kept SNP
    phenotypes = []     # simulated Y for each kept SNP

    for snp in sig_SNP:
        effect = beta[int(snp[0])]
        genotype = snp[1]
        if not wanted(effect):
            continue
        genotypes.append(genotype)
        effects.append(effect)
        phenotypes.append(genotype*effect+error)

    # wrapped in an outer list: callers iterate over / index entry 0
    return [[genotypes, np.array(effects), phenotypes]]


class AncestrySimulations():
    '''Simulate ancestries for one population size and analyse per-SNP regressions.

    On construction, one tree sequence is simulated with msprime for every
    random seed, mutations are optionally overlaid, and the SNPs carried by
    more than a threshold number of samples are stored.

    Attributes set by ``__init__``:
        N                   sample count, ``int(10**pop_exp)``
        mutation            whether mutations were simulated
        recombination_rate  value passed to msprime as ``recombination_rate``
        sequence_length     genome length passed to msprime
        is_causal           boolean array marking sites with non-zero effect
        beta                per-site effect sizes (zero where not causal)
        error               per-sample noise term
        sig_SNP_threshold   carrier count a SNP must exceed to be kept
        mts                 tree sequence of the LAST seed simulated
        sig_SNP             kept SNPs ``(position, genotypes)`` of the LAST seed
        all_sig_SNP         list of the kept-SNP lists, one per seed
    '''
    def __init__(
        self,
        pop_exp,
        model=None,
        demography=None,
        sequence_length=int(10000),
        heritability=0.8,
        recombination=False,
        mutation=(True, 1.1e-3),
        restrict_SNPs=(True, 0.95),
        multi_seed=True,
        *args,
        **kwargs
        ):
        '''Run the simulations for population exponent ``pop_exp``.

        Parameters:
            pop_exp         exponent; the sample count is ``int(10**pop_exp)``
            model           accepted but NOT forwarded to msprime (see note below)
            demography      passed to ``msprime.sim_ancestry``
            sequence_length genome length
            heritability    used (squared) to scale the noise term
            recombination   passed to msprime as ``recombination_rate``
            mutation        pair ``(simulate_mutations, mutation_rate)``
            restrict_SNPs   pair ``(apply_carrier_threshold, causal_cutoff)``;
                            a site is causal when a uniform draw exceeds the cutoff
            multi_seed      truthy: seeds come from ``ensemble(pop_exp)``;
                            falsy: the single seed 1 is used
        '''
        # Progress counter. It relies on the notebook-level pop_exp_range, which
        # may not be defined yet; in that case the NameError text is printed.
        try:
            print(str(round((pop_exp-pop_exp_range[0])/(max(pop_exp_range)-min(pop_exp_range))*100, 2)) + '% complete')
        except NameError as e:
            print(e)

        mutation, rate = mutation[0], mutation[1]
        self.mutation = mutation
        # population size is scaled exponentially to cut number of simulations
        self.N = int(10**pop_exp)

        # Seeds for the ensemble average. Any falsy multi_seed selects the
        # single-seed path, so `seeds` is always defined.
        if multi_seed:
            seeds = ensemble(pop_exp)
            print('N: {}\nNumber of seeds for this N: {}.'.format(self.N, len(seeds)))
        else:
            seeds = [1]

        sig_g = 5.2/np.sqrt(self.N)   # empirical variance of Xβ
        # NOTE(review): the plotting cell labels the heritability values as h^2,
        # yet the value is squared again here - confirm this is intended.
        h = heritability**2

        # Fixed numpy seed: effect sizes and noise are identical for every
        # instance with the same N and sequence_length.
        np.random.seed(0)
        self.recombination_rate = recombination
        self.sequence_length = sequence_length

        self.is_causal = np.random.uniform(0, 1, sequence_length) > restrict_SNPs[1]
        self.beta = np.random.normal(sig_g, 1, sequence_length) * self.is_causal
        # NOTE(review): the second argument of np.random.normal is a standard
        # deviation; confirm the expression below is meant as one.
        self.error = np.random.normal(0, (sig_g**2)*((1/(h))-1), self.N)

        # limit for how many N need SNP for it to be considered significant, 1% here
        # (does not depend on the seed, so it is computed once)
        if restrict_SNPs[0]:
            self.sig_SNP_threshold = self.N/100
        else:
            self.sig_SNP_threshold = 0

        self.all_sig_SNP = []

        # iterate over seeds and generate trees for each
        for seed in seeds:
            seed = int(seed)    # seeds may be numpy integers from ensemble()
            # NOTE(review): `model` is deliberately left as None here, exactly
            # as before; the constructor's `model` argument has no effect.
            ts = msprime.sim_ancestry(
                self.N,
                model=None,
                ploidy=1,
                demography=demography,
                sequence_length=self.sequence_length,
                random_seed=seed,
                recombination_rate=recombination
            )

            if mutation:
                self.mts = msprime.sim_mutations(
                    ts,
                    model=msprime.BinaryMutationModel(),
                    rate=rate,
                    random_seed=seed
                )
            else:
                self.mts = ts

            # form of a SNP is (site position, who has mutation)
            self.sig_SNP = [
                (var.site.position, var.genotypes)
                for var in self.mts.variants()
                if np.count_nonzero(var.genotypes) > self.sig_SNP_threshold
            ]
            self.all_sig_SNP.append(self.sig_SNP)

    def s_e(self, causal='all', live_plot=False):
        '''Return the ensemble-averaged error measure of the estimated effect sizes.

        For every seed, each retained SNP is regressed (OLS, no intercept) on
        its own simulated phenotype; the per-SNP values are averaged within the
        seed and the seed means are then averaged. Seeds without any retained
        SNP are skipped; ``nan`` is returned if no seed has one.

        causal      forwarded to ``produce_reg_var`` to select SNPs
        live_plot   when truthy, append the result to the notebook-level lists
                    N_range_live / s_errors_live / threshold_exceed_live and
                    redraw the running plot
        '''
        # One mean per seed, accumulated across ALL seeds (previously this list
        # was reset on every seed, so only the last seed contributed).
        seed_means = []
        for sig_SNP in self.all_sig_SNP:
            # list of the matrices needed to carry out regression, reduced to only significant SNPs
            reg_vars = produce_reg_var(self.N, self.beta, self.error, sig_SNP, causal=causal)
            for data in reg_vars:
                X = data[0]
                Y = data[2]
                if len(X) == 0:
                    continue
                beta_errors = np.zeros(len(X))
                for j, X_j in enumerate(X):
                    # ordinary least squares regression on Y, X
                    ols_result = sm.OLS(np.array(Y[j]), np.array(X_j)).fit()
                    # NOTE(review): statsmodels documents `bse` as already being
                    # the standard error; the square root is kept unchanged so
                    # results stay comparable - confirm whether it is intended.
                    # bse has a single entry (one regressor); index it rather
                    # than assigning a size-1 array to a scalar slot.
                    beta_errors[j] = ols_result.bse[0]**0.5
                seed_means.append(np.mean(beta_errors))

        # standard error for this N is mean of all the seeds' standard errors
        s_error = np.mean(seed_means) if seed_means else float('nan')

        if live_plot:
            fig1, live_ax = plt.subplots()
            live_ax.set_xlabel('N')
            live_ax.set_ylabel(r's.e. of $\^\beta$')
            live_ax.set_title(r'Stadard error of estimated $\^\beta$')

            # deals with anomalous s_error (can happen for low N if there are an insufficient number of seeds)
            if not s_error == float("inf"):
                N_range_live.append(self.N)
                if s_error > 1000:
                    print('s.e. above threshold: {}'.format(s_error))
                    # Repeat the previous value and mark it; if there is no
                    # previous value yet, use nan instead of raising IndexError.
                    previous = s_errors_live[-1] if s_errors_live else float('nan')
                    s_errors_live.append(previous)
                    threshold_exceed_live.append(previous)
                else:
                    s_errors_live.append(s_error)
                    threshold_exceed_live.append(None)

            live_ax.plot(N_range_live, s_errors_live)
            live_ax.plot(N_range_live, threshold_exceed_live, 'rx')
            live_ax.grid()

            clear_output(wait=True)
            plt.show()
        return s_error

    def SNP_plot(self, causal='all'):
        '''Plot estimated and actual absolute effect size against site position.

        Uses ``self.sig_SNP`` (the SNPs of the last seed simulated). ``causal``
        selects SNPs exactly as in ``produce_reg_var``. Returns None.
        '''
        positions = []      # integer site of each retained SNP
        estimates = []      # OLS estimate of its effect size
        actuals = []        # effect size used in the simulation

        # Filter and regress one SNP at a time so that position, estimate and
        # actual effect stay aligned whatever `causal` filters out.
        for snp in self.sig_SNP:
            X, betas, Y = produce_reg_var(self.N, self.beta, self.error, [snp], causal=causal)[0]
            if len(X) == 0:
                continue
            # ordinary least squares regression on Y, X
            fit = sm.OLS(np.array(Y[0]), np.array(X[0])).fit()
            positions.append(int(snp[0]))
            estimates.append(fit.params[0])
            actuals.append(betas[0])

        # recover full length matrix of significant SNPs
        sig_locs = []
        full_beta_est = []
        beta_actual = np.zeros(self.sequence_length)

        for position, estimate, actual in zip(positions, estimates, actuals):
            # estimates that are numerically zero are not drawn
            if abs(estimate) > 1e-14:
                sig_locs.append(position)
                full_beta_est.append(abs(estimate))
            beta_actual[position] = abs(actual)

        SNP_fig, ax = plt.subplots()
        ax.set_xlabel('Site')
        ax.set_ylabel('Effect size')
        ax.set_title('Effect size of SNP at each location,\nrecombination rate {}'.format(self.recombination_rate))
        ax.plot(sig_locs, full_beta_est, 'rx', label='Estimated')
        ax.plot(np.arange(self.sequence_length), beta_actual, 'g:', label='Actual')
        ax.legend()
        return

In [None]:
# the range of exponents used to generate N = 10**(pop_exp)
# (also read by AncestrySimulations.__init__ for its progress counter)
pop_exp_range = np.linspace(2.5,5,50)

# one result array per heritability setting, one slot per exponent
n_points = len(pop_exp_range)
s_errors_0_8 = np.zeros(n_points)
s_errors_0_1 = np.zeros(n_points)
s_errors_0_99 = np.zeros(n_points)
s_errors_0_01 = np.zeros(n_points)

# (extra constructor arguments, destination array), in the order the
# simulations are run for each exponent. The first entry relies on the
# constructor's default heritability.
heritability_runs = [
    ({}, s_errors_0_8),
    ({'heritability': 0.1}, s_errors_0_1),
    ({'heritability': 0.99}, s_errors_0_99),
    ({'heritability': 0.01}, s_errors_0_01),
]

# N is the population exponent here, not the population size
for count, N in enumerate(pop_exp_range):
    for extra_kwargs, target in heritability_runs:
        simulation = AncestrySimulations(
            N,
            model='dtwf',
            mutation=[True,1e-3],
            **extra_kwargs
            )
        target[count] = simulation.s_e(causal='all')

In [None]:
# plotting

# population sizes on the x-axis
pop_sizes = 10**pop_exp_range

# (values, optional format string, legend label) in drawing order;
# the 0.8 curve has no format string and so takes the default colour
curves = [
    (s_errors_0_01, ('m',), '0.01'),
    (s_errors_0_1, ('r',), '0.1'),
    (s_errors_0_8, (), '0.8'),
    (s_errors_0_99, ('g',), '0.99'),
]

# figure 1: log-log axes
fig, ax = plt.subplots()
ax.set_xlabel('log(N)')
ax.set_ylabel(r'log(s.e.($\^\beta))$')
ax.set_title(r'log(s.e.($\^\beta)) \, vs. \, log(N)$')
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_ylim([10e-5,10e-1])
ax.grid()
for values, fmt, label in curves:
    ax.plot(pop_sizes, values, *fmt, label=label)
ax.legend(title='$h^{2}$')

# figure 2: same curves on linear axes
fig, ax2 = plt.subplots()
ax2.set_xlabel('N')
ax2.set_ylabel(r's.e.($\^\beta)$')
ax2.set_title(r's.e.($\^\beta) \, vs. \, N$')
ax2.grid()
for values, fmt, label in curves:
    ax2.plot(pop_sizes, values, *fmt, label=label)
ax2.legend(title='$h^{2}$')

# end-to-end slope of the 0.8 curve in log-log space; 316 and 100000 are
# hard-coded stand-ins for the smallest and largest population size
log_se_change = np.log(s_errors_0_8[-1])-np.log(s_errors_0_8[0])
log_N_change = np.log(100000)-np.log(316)
print(log_se_change/log_N_change)