# Simulations for evaluating archaic ancestry inference

## Set paths and import libraries

In [None]:
# run in archanc root directory
import os
# All project paths hang off the current working directory, so the notebook
# has to be started from the repository root.
proj_dir = os.getcwd()

# Each directory string keeps its trailing slash because later cells build
# file names by plain concatenation (directory + file prefix).
msprime_dir, archie_src_dir, archie_out_dir = (
    proj_dir + subdir
    for subdir in ("/output/msprime/", "/src/ArchIE/", "/output/ArchIE/")
)

In [None]:
import sys
sys.path.insert(0, proj_dir + '/src/stdpopsim')

from stdpopsim import homo_sapiens, models
import msprime
import itertools
import random
import numpy as np
import pandas as pd
import math
import cyvcf2

## Classes for generating EIGENSTRAT-formatted output from the `msprime` tree sequence

### Simulating MNMs
The method `generateEigData.mnms()` randomly selects a fraction (`mnm_frac`) of variant sites to become multinucleotide mutations (MNMs). The second mutation of each MNM is placed at a random distance of 1 to `mnm_dist` bp downstream of the first, and the genotypes of the first mutation are duplicated for it.

In [None]:
class mnmEigData:
    """
    plain container for the MNM-augmented tables returned by
    generateEigData.mnms()

    attributes (stored exactly as passed in):
        geno        -- genotype matrix
        snp         -- .snp table
        prefix      -- file-name prefix used when the data are written out
        repeat_rows -- per-site repeat counts used to expand the genotype rows
    """

    def __init__(self, geno, snp, prefix, repeat_rows):
        self.geno, self.snp = geno, snp
        self.prefix, self.repeat_rows = prefix, repeat_rows


class generateEigData:
    """
    build EIGENSTRAT-style .geno and .snp tables from a tree sequence

    After construction, `geno`, `snp` and `prefix` on an instance are DATA
    (a numpy matrix, a pandas DataFrame and a string): __init__ calls the
    builder methods of the same names once and stores their results over
    them. Call `mnms()` to get a copy of the tables with simulated MNMs added.

    parameters:
        ts          -- tree sequence; must provide `variants()`, yielding
                       objects with `.genotypes` and `.site.position`
        model_label -- model name, used as the start of the file prefix
        rep_label   -- replicate label appended to the prefix
        mnm_frac    -- probability that a site is turned into an MNM
        mnm_dist    -- maximum distance (bp) of the second MNM mutation
    """

    # column order of the .snp table
    _SNP_COLUMNS = ['ID', 'CHR', 'POS', 'POS1', 'REF', 'ALT']

    def __init__(self, ts, model_label, rep_label=0, mnm_frac=0, mnm_dist=0):
        self.ts = ts
        self.model_label = model_label
        self.rep_label = rep_label
        self.sim_mnms = bool(mnm_frac > 0)
        self.mnm_frac = mnm_frac
        self.mnm_dist = mnm_dist
        # From here on the instance attributes shadow the builder methods.
        self.geno = self.geno()
        self.snp = self.snp()
        self.prefix = self.prefix()

    def prefix(self):
        """
        get unique identifier for model based on inputs
        """

        return self.model_label + "_rep" + str(self.rep_label)

    def geno(self):
        """
        generate .geno matrix from tree sequence

        returns a 2-D array with one row per variant site and one column per
        sample; with no variant sites the result has shape (0, num_samples)
        """

        # Iterate lazily and copy every genotype vector, so the result is
        # correct even if the tree sequence re-uses one buffer per variant.
        rows = [np.array(variant.genotypes) for variant in self.ts.variants()]

        if not rows:
            return np.zeros((0, self.ts.num_samples), dtype=np.int8)

        # One stack at the end instead of re-allocating the matrix per site.
        return np.vstack(rows)

    def snp(self):
        """
        generate .snp table from tree sequence

        returns a DataFrame with one row per variant site and the string
        columns ID, CHR, POS, POS1, REF, ALT
        """

        records = []
        for variant in self.ts.variants():
            position = variant.site.position
            pos = str(round(position))
            records.append({'ID': "1:" + pos,
                            'CHR': "1",
                            'POS': pos,
                            # NOTE(review): 10e6 is 1e7, not 1e6 -- confirm
                            # this is the intended position scaling
                            'POS1': str(position / 10e6),
                            'REF': "A",
                            'ALT': "G"})

        # Build the frame once; DataFrame.append was removed in pandas 2.0.
        return pd.DataFrame(records, columns=self._SNP_COLUMNS)

    def mnms(self):
        """
        return .geno and .snp data updated to include simulated MNMs

        Each site becomes an MNM with probability mnm_frac. For an MNM a
        second row is inserted directly after the original site, 1..mnm_dist
        bp downstream, and the site's genotype row is duplicated for it.
        The draw is seeded by the site position, so a given position always
        produces the same outcome.

        returns an mnmEigData object
        raises ValueError if mnm_frac > 0 while mnm_dist < 1
        """

        if self.mnm_frac > 0 and self.mnm_dist < 1:
            raise ValueError("mnm_dist must be >= 1 when mnm_frac > 0")

        records = []
        repeat_rows = []
        for _, snp in self.snp.iterrows():
            record = snp.to_dict()
            records.append(record)

            # A private generator gives the same stream as random.seed(POS)
            # without resetting the global random state.
            rng = random.Random(str(snp['POS']))
            if rng.random() < self.mnm_frac:
                dist = rng.randint(1, self.mnm_dist)

                # NOTE(review): the shifted position is not checked against
                # the next site or the sequence length, so rows can end up
                # out of order or with duplicate IDs.
                mnm_record = dict(record)  # copy; the original row is kept
                mnm_record['POS'] = str(int(snp['POS']) + dist)
                # Stored as a string, like every other POS1 entry.
                mnm_record['POS1'] = str(float(mnm_record['POS']) / 10e6)
                mnm_record['ID'] = "1:" + mnm_record['POS']
                records.append(mnm_record)
                repeat_rows.append(2)
            else:
                repeat_rows.append(1)

        snp_mnm = pd.DataFrame(records, columns=list(self.snp.columns))

        # Explicit integer dtype keeps np.repeat working when there are no sites.
        geno_mat_mnm = np.repeat(self.geno,
                                 repeats=np.asarray(repeat_rows, dtype=np.intp),
                                 axis=0)

        mnm_prefix = self.prefix + "_mnm" + str(self.mnm_dist) + "-" + str(self.mnm_frac)

        return mnmEigData(geno=geno_mat_mnm,
                          snp=snp_mnm,
                          prefix=mnm_prefix,
                          repeat_rows=repeat_rows)

    
class writeEigData:
    """
    functions for writing .snp .geno and .ind files

    parameters:
        eig_data   -- object with `geno` (2-D array, one column per sample),
                      `snp` (DataFrame) and `prefix` (str) attributes
        output_dir -- directory the files are written to (with or without a
                      trailing slash)
        pops       -- population labels, in the column order of `geno`
        pop_size   -- number of consecutive `geno` columns per population
    """

    def __init__(self, eig_data, output_dir="./", pops=("afr", "eur"), pop_size=100):
        self.eig_data = eig_data
        self.prefix = eig_data.prefix
        self.output_dir = output_dir
        self.pops = tuple(pops)
        self.pop_size = pop_size

    def _path(self, suffix):
        """
        full output path for the file `<prefix><suffix>`
        """

        # os.path.join also handles an output_dir without a trailing slash.
        return os.path.join(self.output_dir, self.prefix + suffix)

    def dump(self):
        """
        run all 3 output functions: write_snp, write_geno, and write_ind
        """

        self.write_snp()
        self.write_geno()
        self.write_ind()

    def write_snp(self):
        """
        write .snp file (tab-separated, one file for all populations)
        """

        # NOTE(review): this writes a header line and the row index as the
        # first column -- confirm the downstream reader expects both.
        self.eig_data.snp.to_csv(self._path(".snp"),
                                 index=True,
                                 sep="\t")

    def write_geno(self):
        """
        write one .geno file per population

        population i gets columns [i*pop_size, (i+1)*pop_size) of the
        genotype matrix; each row is written as undelimited integers
        """

        for i, pop in enumerate(self.pops):
            start = i * self.pop_size
            geno_pop = self.eig_data.geno[:, start:start + self.pop_size]
            np.savetxt(self._path("_" + pop + ".geno"), geno_pop,
                       delimiter="", fmt='%i')

    def write_ind(self):
        """
        write one .ind file per population

        columns are sample ID, sex (set as 'U'), and the population label
        """

        for pop in self.pops:
            with open(self._path("_" + pop + ".ind"), "w") as id_file:
                for sample_id in range(self.pop_size):
                    sample_name = self.prefix + "_" + pop + \
                        "_sample_" + str(sample_id)
                    print("\t".join([sample_name, "U", pop]), file=id_file)

## Simulate models

Simulate 200 samples (100 each of European and African ancestry) under each of the specified models (with 1000 replicates each)

In [None]:
# coalescent simulation parameters (shared by every model below)
sample_size = 100  # samples drawn from EACH of the two populations
length = 50000
mu = 1.15e-8
rr = 1e-8
replicates = 1000
seed = 30

# keyword arguments that are identical for every msprime.simulate call;
# sample order is: first 100 samples from AFR (population 0),
# next 100 from EUR (population 1)
_common_sim_args = dict(
    samples=[msprime.Sample(0, 0)] * sample_size + [msprime.Sample(1, 0)] * sample_size,
    length=length,
    mutation_rate=mu,
    recombination_rate=rr,
    random_seed=seed,
    num_replicates=replicates,
)

# Gutenkunst 3-population model
GutenkunstThreePop_model = homo_sapiens.GutenkunstThreePopOutOfAfrica()
GutenkunstThreePop_ts = msprime.simulate(
    **_common_sim_args,
    **GutenkunstThreePop_model.asdict())

# Tennessen 2-population model
TennessenTwoPop_model = homo_sapiens.TennessenTwoPopOutOfAfrica()
TennessenTwoPop_ts = msprime.simulate(
    **_common_sim_args,
    **TennessenTwoPop_model.asdict())

#-------------------------------------------------------
# define other models here and add to model_dict below
#-------------------------------------------------------

# template: modify demographic parameters to include archaic branches
# GutenkunstThreePopArchaic_model = homo_sapiens.GutenkunstThreePopArchaic()

# GutenkunstThreePopArchaic_ts = msprime.simulate(
#     **_common_sim_args,
#     **GutenkunstThreePopArchaic_model.asdict())

# map each model label to its replicate tree sequences; the label becomes
# the start of every output file name
model_dict = {
    "GutenkunstThreePop": GutenkunstThreePop_ts,
    "TennessenTwoPop": TennessenTwoPop_ts,
}

## Get simulated data and run archaic admixture detection methods

In [None]:
# When True, the ArchIE summary-statistics command is built and printed for
# every data set / population pair (actually running it is still commented out).
run_archie = False

for model_label, model in model_dict.items():
    for j, ts in enumerate(model):
        # Test mode: only the replicate with index 1 (the SECOND tree
        # sequence of each model) is processed.
        # NOTE(review): the loop still iterates over every remaining
        # replicate of the model without using it.
        if j != 1:
            continue

        eig_data = generateEigData(ts, model_label, rep_label=j, mnm_frac=0.015, mnm_dist=100)
        print("---Without MNMs---")
        print(eig_data.prefix)
        print(eig_data.geno.shape)
        print(eig_data.snp.shape)
        writeEigData(eig_data, msprime_dir).dump()

        # write non-MNM data to VCF; write_vcf needs a writable file object,
        # not a path string
        with open(msprime_dir + eig_data.prefix + ".vcf", "w") as vcf_file:
            ts.write_vcf(vcf_file, ploidy=2)

        print("---With MNMs---")
        eig_data_mnms = eig_data.mnms()
        print(eig_data_mnms.prefix)
        print(eig_data_mnms.geno.shape)
        print(eig_data_mnms.snp.shape)
        writeEigData(eig_data_mnms, msprime_dir).dump()

        # TODO: add code for converting .snp/.geno/.ind data to VCF
        # alternatively, use cyvcf2 to read the non-MNM VCF into a numpy array
        # and duplicate rows with the repeat_rows indices
        #
        # old_gts = VCF genotypes as numpy array
        # ...
        # new_gts = np.repeat(old_gts, repeats=eig_data_mnms.repeat_rows, axis=0)
        # ...
        # write new VCF

        if run_archie:
            for data in [eig_data, eig_data_mnms]:
                for pop in ["afr", "eur"]:

                    # the other population serves as the reference panel
                    ref_pop = "eur" if pop == "afr" else "afr"

                    prefix = data.prefix
                    stats_pop_cmd = "python " + archie_src_dir + "data/calc_stats_window_data.py" + \
                        " -s " + msprime_dir + prefix + ".snp" + \
                        " -i " + msprime_dir + prefix + "_" + pop + ".ind" + \
                        " -a " + msprime_dir + prefix + "_" + pop + ".geno" + \
                        " -r " + msprime_dir + prefix + "_" + ref_pop + ".geno" + \
                        " -c 1 -b 0 -e 50000 -w 50000 -z 50000 " + \
                        " > " + archie_out_dir + prefix + "_" + pop + ".txt"
                    print(stats_pop_cmd + "\n")
                    # os.system(stats_pop_cmd)

        # add code for evaluating additional methods
        # if run_sprime:
        # ...
        # ...

        # if run_moments:
        # ...
        # ...

        # if run_idetect:
        # ...
        # ...