In [0]:
import scandir
import os, sys
import rpy2
from rpy2.robjects import pandas2ri
pandas2ri.activate()
import rpy2.robjects as ro
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import numpy as np
import dill
import random
import vcf
from hdfstorehelper import HDFStoreHelper
import statsmodels.api as sm
import statsmodels.formula.api as smf
import operator
import traceback
%load_ext rpy2.ipython
from rpy2.robjects import pandas2ri as p2r
p2r.activate()
r = ro.r
import shutil
from utils import read_df, save_df
from pathlib import Path, PurePath
from ipyparallel import Client
from collections import Counter, defaultdict, namedtuple, OrderedDict
from scipy.stats import mannwhitneyu, ks_2samp, f_oneway
import tables
import ujson
import pickle

In [0]:
# Connect to the ipyparallel cluster that was started under the "sge" profile.
rc = Client(profile="sge")

In [0]:
# dview: direct view over all engines (used below to push variables to them).
# lview: load-balanced view (used below to decorate remote functions).
dview = rc[:]
lview = rc.load_balanced_view()
len(dview)  # displayed as the cell output

In [0]:
# Root directory of this analysis; several later paths are built from it.
# NOTE(review): absolute, user-specific path — the notebook only runs where this mount exists.
analysis_dir = "/home/cfriedline/eckertlab/gypsy_indiv/raw_demult/analysis/samtools1.3_masurca3/beagle40/"

In [0]:
# Directory holding the pickled GEMMA results loaded below.
# The original second line joined on `gemma_imputed.ipynba_dir`, a name that is
# never defined (NameError on a fresh kernel). This assumes the intended path is
# <analysis_dir>/gemma_run/output, built in one step so re-running the cell is idempotent.
gemma_dir = os.path.join(analysis_dir, "gemma_run", "output")

In [0]:
# Load the table named "z12_df" from the `ni` directory via the project's read_df helper.
# presumably one column per SNP with -1 marking a missing call (see percent_missing) — TODO confirm
ni_data = read_df("/home/cfriedline/eckertlab/gypsy_indiv/raw_demult/analysis/samtools1.3_masurca3/ni", "z12_df")

In [0]:
def percent_missing(snp):
    """Return the fraction of counted values in ``snp`` that equal -1.

    ``snp`` is a pandas Series. Values are tallied with ``value_counts`` (so
    NaN entries are not part of the denominator). Returns 0 when -1 never
    occurs.
    """
    tally = snp.value_counts()
    return tally[-1] / np.sum(tally) if -1 in tally else 0

In [0]:
# Apply percent_missing to ni_data (DataFrame.apply default: once per column).
# NOTE(review): this rebinds the name `percent_missing` from the function to its
# result, so the cell cannot be re-run without first re-running the cell that defines the function.
percent_missing = ni_data.apply(percent_missing)

In [0]:
phenos = ["mass", "tdt", "pd"]

In [0]:
# Load the per-phenotype combined tables (indexed below by 'tdt', 'mass', 'pd').
# A context manager closes the file handle, which the bare open() call leaked.
# NOTE: pickle.load can execute arbitrary code; only load files this project produced.
with open(os.path.join(gemma_dir, "combined_dfs.pkl"), "rb") as handle:
    combined_dfs = pickle.load(handle)

In [0]:
# Load the effect-SNP sets; later cells index this mapping with
# (phenotype, 'gamma_hmean', 'total_effect', threshold) tuples.
# A context manager closes the file handle, which the bare open() call leaked.
# NOTE: pickle.load can execute arbitrary code; only load files this project produced.
with open(os.path.join(gemma_dir, "effect_snps.pkl"), "rb") as handle:
    effect_snps = pickle.load(handle)

In [0]:
effect_snps.keys()

In [0]:
# Base genotype table read through the project's read_df helper.
# Later cells treat the index as "<population>_..." sample names and the columns as SNPs
# holding genotype strings such as "A/T" — presumably; verify against the stored frame.
gt_base_df = read_df(analysis_dir, 'gt_base_df')

In [0]:
# Sorted unique population labels: the text before the first "_" of every sample name.
pops = sorted({sample.split("_")[0] for sample in gt_base_df.index})

In [0]:
def count_genotypes(snp):
    """Count allele occurrences across the genotype calls of one SNP.

    Each call is expected to be a string whose first and last characters are
    the two alleles (e.g. ``"A/T"``). Anything that ``float()`` accepts (NaN,
    numbers, numeric strings) is treated as missing and skipped.

    Returns:
        list of ``(allele, count)`` tuples sorted by ascending count.
    """
    counts = Counter()
    for gt in snp:
        try:
            float(gt)  # succeeds for NaN / numeric values -> treated as missing
        except (TypeError, ValueError):
            # Only the "not a number" failures mean "this is a genotype string";
            # a bare except would also swallow KeyboardInterrupt and real bugs.
            counts[gt[0]] += 1
            counts[gt[-1]] += 1
    return sorted(counts.items(), key=operator.itemgetter(1))

In [0]:
# Population label = text before the first "_" of each sample name (same rule used for `pops`).
# Built directly from the index: a row-wise apply(axis=1) materialises every genotype
# row as a Series only to read its name.
gt_base_df['population'] = [sample.split("_")[0] for sample in gt_base_df.index]

In [0]:
pop_allele_data = {}

def add_allele_freq(gt_list):
    """Attach a frequency to every ``(allele, count)`` pair.

    Args:
        gt_list: sequence of ``(allele, count)`` tuples, as returned by
            ``count_genotypes``.

    Returns:
        OrderedDict mapping allele -> ``[count, count / total]`` in input
        order. One- and two-allele inputs give the same values as before;
        an empty input (all calls missing) now yields an empty mapping instead
        of raising IndexError, and inputs with more than two alleles get true
        frequencies instead of a single allele reported at 1.0.
    """
    total = sum(count for _, count in gt_list)
    ret = OrderedDict()
    for allele, count in gt_list:
        ret[allele] = [count, count / total]
    return ret

# Build, for every population, a {snp: {allele: [count, freq]}} mapping.
for pop_name, pop_rows in gt_base_df.groupby('population'):
    pop_genotypes = pop_rows.drop('population', axis=1)  # keep SNP columns only
    print(pop_name, pop_genotypes.shape)
    pop_freqs = pop_genotypes.apply(count_genotypes).apply(add_allele_freq)
    pop_allele_data[pop_name] = pop_freqs.to_dict()

In [0]:
# Table stored as '_gemma_gt', with literal "NA" strings converted to NaN.
# Later cells read its "minor" column by SNP label.
gemma_gt = read_df(analysis_dir, '_gemma_gt').replace("NA", np.nan)

In [0]:
gemma_gt.head()

In [0]:
pd.DataFrame(pop_allele_data['VA1'])

In [0]:
# Allele counts per SNP across all samples.
# The 'population' label column added earlier is dropped first (as the other
# per-SNP cells do); otherwise its labels are tallied as if they were genotypes.
gt_counts = gt_base_df.drop('population', axis=1).apply(count_genotypes)

In [0]:
# Convert each SNP's (allele, count) list into allele -> [count, frequency].
gt_counts_af = gt_counts.apply(add_allele_freq)

In [0]:
gt_counts_af.head()

In [0]:
gemma_gt.head()

In [0]:
# Push the objects that the remote do_pairwise function reads as globals
# (gemma_gt, pops) plus analysis_dir, which the %%px cells use to locate files.
dview['gemma_gt'] = gemma_gt

dview['pops'] = pops

dview['analysis_dir'] = analysis_dir

In [0]:
%%px 
import os, pickle, traceback

In [0]:
# Persist the per-population allele table under analysis_dir; the %%px cell
# that follows loads this same file on the engines.
with open(os.path.join(analysis_dir, "pop_allele_data.pkl"), "wb") as o:
    pickle.dump(pop_allele_data, o, pickle.HIGHEST_PROTOCOL)

In [0]:
%%px 
# Load the pickled per-population allele table once per engine; re-running this
# cell is a no-op where the name already exists.
# The context manager closes the file handle that the bare open() call leaked.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
if 'pop_allele_data' not in dir():
    with open(os.path.join(analysis_dir, "pop_allele_data.pkl"), "rb") as handle:
        pop_allele_data = pickle.load(handle)

In [0]:
def compute_obs_heterozygosity(snp):
    """Observed heterozygosity of one SNP.

    A call counts only if it is a string whose second character is a
    genotype separator (``"/"`` or ``"|"``); its first and last characters
    are taken as the two alleles. Missing calls (NaN) and malformed values
    are skipped, so numerator and denominator are drawn from the same set of
    calls.

    Returns:
        heterozygous calls / valid calls, or NaN when there is no valid call
        (the original raised ZeroDivisionError, and TypeError on NaN input).
    """
    het = 0
    total = 0
    for gt in snp:
        if not isinstance(gt, str) or len(gt) < 3 or gt[1] not in "/|":
            continue  # missing or malformed call
        total += 1
        if gt[0] != gt[-1]:
            het += 1
    return het / total if total else float("nan")

def compute_exp_heterozygosity(snp):
    """Expected heterozygosity of one SNP: ``1 - sum(p_i ** 2)``.

    The first and last characters of every string call are taken as its two
    alleles; non-string values (NaN = missing) are skipped.

    For two alleles this equals the previous ``2 * p * q``. The previous
    product form returned 2 for a monomorphic SNP (it multiplied 2 by a single
    frequency of 1.0) and was not a heterozygosity for more than two alleles;
    both cases are now correct (0.0 for monomorphic).

    Returns:
        float in [0, 1), or NaN when no allele was observed.
    """
    c = Counter()
    for gt in snp:
        if not isinstance(gt, str) or not gt:
            continue  # missing call
        c[gt[0]] += 1
        c[gt[-1]] += 1
    total = sum(c.values())
    if total == 0:
        return float("nan")
    return 1.0 - sum((n / total) ** 2 for n in c.values())

In [0]:
# 20 evenly spaced bin edges on [0, 1], passed to np.digitize below to bin heterozygosity.
het_bins = np.linspace(0,1,20)

In [0]:
# Expected heterozygosity per SNP as a frame with columns "He" and "rs" (rs = the index).
exp_het = gt_base_df.drop("population", axis=1).apply(compute_exp_heterozygosity)
He = pd.DataFrame(exp_het, columns=["He"]).assign(rs=lambda frame: frame.index)

In [0]:
# Observed heterozygosity per SNP as a frame with columns "Ho" and "rs" (rs = the index).
obs_het = gt_base_df.drop("population", axis=1).apply(compute_obs_heterozygosity)
Ho = pd.DataFrame(obs_het, columns=["Ho"]).assign(rs=lambda frame: frame.index)

In [0]:
He.head()

In [0]:
# Attach observed (Ho) and expected (He) heterozygosity to each phenotype's combined table.
# DataFrame.join matches on the index; how="inner" keeps only rows present in both frames.
tdt_ho = combined_dfs['tdt'].join(Ho, how="inner")
mass_ho = combined_dfs['mass'].join(Ho, how="inner")
pd_ho = combined_dfs['pd'].join(Ho, how="inner")

tdt_he = combined_dfs['tdt'].join(He, how="inner")
mass_he = combined_dfs['mass'].join(He, how="inner")
pd_he = combined_dfs['pd'].join(He, how="inner")

In [0]:
# Add a 'het_bin' column to every joined frame: the np.digitize bin of its
# heterozygosity column (Ho or He) against the shared het_bins edges.
for frame, het_column in (
    (tdt_ho, "Ho"), (mass_ho, "Ho"), (pd_ho, "Ho"),
    (tdt_he, "He"), (mass_he, "He"), (pd_he, "He"),
):
    frame['het_bin'] = np.digitize(frame[het_column], het_bins)

In [0]:
# Per-phenotype bundle. As populated below: He / Ho = combined table joined with
# expected / observed heterozygosity, hmean = the combined table itself,
# sig / relaxed = effect_snps entries at the 0.999 / 0.995 thresholds.
# NOTE(review): later cells access `.upper20` and `.lower20`, which are not fields of this tuple.
PhenoContainer = namedtuple("PhenoContainer", ["He", "Ho", "hmean", "sig", "relaxed"])

In [0]:
# One PhenoContainer per phenotype, keyed "mass", "pd", "tdt" (in that order).
PC = {
    pheno: PhenoContainer(
        he_frame,
        ho_frame,
        combined_dfs[pheno],
        effect_snps[(pheno, 'gamma_hmean', 'total_effect', 0.999)],
        effect_snps[(pheno, 'gamma_hmean', 'total_effect', 0.995)],
    )
    for pheno, he_frame, ho_frame in (
        ("mass", mass_he, mass_ho),
        ("pd", pd_he, pd_ho),
        ("tdt", tdt_he, tdt_ho),
    )
}

In [0]:
# Scatter of gamma_hmean (labelled PIP) against expected heterozygosity, one figure per phenotype.
for pheno in PC:
    plt.scatter(PC[pheno].He.gamma_hmean, PC[pheno].He.He)
    plt.xlabel("PIP")
    plt.ylabel(r"$H_{exp}$")
    # Title follows the phenotype being plotted; it was hard-coded to "TDT",
    # which mislabelled the mass and pd figures.
    plt.title(pheno.upper())
    plt.show()

In [0]:
# Scatter of gamma_hmean (labelled PIP) against observed heterozygosity, one figure per phenotype.
for pheno, container in PC.items():
    observed = container.Ho
    plt.scatter(observed.gamma_hmean, observed.Ho)
    plt.xlabel("PIP")
    plt.ylabel(r"$H_{obs}$")
    plt.title(pheno.upper())
    plt.show()

In [0]:
# Spot check: print one SNP's minor allele and its [count, freq] entry in every
# population where that allele was observed.
test_snp = 'ctg7180005039298_50'
# .loc replaces .ix, which was removed from pandas; the lookup is purely label-based.
test_minor = gemma_gt.loc[test_snp, "minor"]
print(test_minor)
for pop in pop_allele_data:
    snp_alleles = pop_allele_data[pop][test_snp]
    if test_minor in snp_alleles:
        print(snp_alleles[test_minor])

In [0]:
@lview.remote()
def do_pairwise(sig_list):
    """Pairwise D between the minor-allele frequencies of the SNPs in ``sig_list``.

    For every unordered pair (i, j) with j < i::

        D = mean_over_pops(p_i * p_j) - mean_over_pops(p_i) * mean_over_pops(p_j)

    where ``p`` is the frequency of the SNP's minor allele in a population
    (0.0 when that allele was not observed there).

    Runs on an engine and reads the globals ``gemma_gt``, ``pops`` and
    ``pop_allele_data`` from the engine namespace.

    Fixes relative to the previous version:
      * D is computed once per pair after all populations are visited. It was
        computed inside the population loop, which appended one partial
        running value per population (the first always 0).
      * ``.ix`` (removed from pandas) is replaced by ``.loc``.
      * Frequencies are kept per list position, so a SNP that appears twice in
        ``sig_list`` no longer collapses the two-entry dict and raises.

    Args:
        sig_list: sequence of SNP labels.

    Returns:
        list of float, one D per pair, ordered by i then j (j < i).
    """
    import numpy as np  # imported here because the body executes on the engine

    # Per-population minor-allele frequencies, looked up once per SNP rather
    # than once per pair.
    pop_freqs = []
    for snp in sig_list:
        minor_allele = gemma_gt.loc[snp, "minor"]
        freqs = []
        for p in pops:
            alleles = pop_allele_data[p][snp]
            freqs.append(alleles[minor_allele][1] if minor_allele in alleles else 0.0)
        pop_freqs.append(np.asarray(freqs, dtype=float))

    ret = []
    for i in range(len(pop_freqs)):
        freqs_i = pop_freqs[i]
        for j in range(i):
            freqs_j = pop_freqs[j]
            avg_in_prod = np.mean(freqs_i * freqs_j)
            across_prod = np.mean(freqs_i) * np.mean(freqs_j)
            ret.append(avg_in_prod - across_prod)
    return ret

In [0]:
def get_nulls_by_het(n, sig_df, df):
    """Draw ``n`` random SNP sets matched to ``sig_df`` by heterozygosity bin.

    Args:
        n: number of null sets to draw.
        sig_df: frame whose index lists the associated SNPs (labels of ``df``).
        df: frame indexed by SNP with at least the columns ``het_bin`` and ``rs``.

    Returns:
        ``(data, het_bins)``: ``data`` is a list of ``n`` lists of ``rs``
        values, each holding, for every het bin seen among the associated
        SNPs, as many unassociated SNPs from that bin as there are associated
        ones; ``het_bins`` is the list of those bins.

    Raises:
        ValueError: from ``Series.sample`` when a bin holds fewer unassociated
            SNPs than must be drawn from it.
    """
    unassoc = df.drop(sig_df.index)
    # .loc / .items replace .ix / .iteritems, both removed from pandas.
    het_bin_counts = df.loc[sig_df.index, 'het_bin'].value_counts()
    sig_het_bins = het_bin_counts.index.tolist()
    # The candidate pool of each bin does not change between replicates, so
    # filter once instead of once per (replicate, bin).
    pools = {
        het_bin: unassoc.loc[unassoc.het_bin == het_bin, 'rs']
        for het_bin in sig_het_bins
    }
    data = []
    for _ in range(n):
        inner = []
        for het_bin, het_count in het_bin_counts.items():
            inner.extend(pools[het_bin].sample(het_count).tolist())
        data.append(inner)
    return data, sig_het_bins

def get_nulls_naive(n, sig_df, df):
    """Draw ``n`` random SNP sets of size ``len(sig_df)`` from the unassociated SNPs.

    SNPs whose labels are in ``sig_df.index`` are excluded; no matching on
    heterozygosity is done. Returns ``(draws, [])``, where the empty list
    mirrors the second element returned by ``get_nulls_by_het``.
    """
    sample_size = len(sig_df)
    candidates = df.drop(sig_df.index)
    draws = []
    for _ in range(n):
        draws.append(candidates.rs.sample(sample_size).tolist())
    return draws, []

In [0]:
# Use het-bin-matched sampling as the default null generator, and push the
# sampling helpers to every engine through the direct view.
get_nulls = get_nulls_by_het
dview['get_nulls_naive'] = get_nulls_naive
dview['get_nulls'] = get_nulls
dview['get_nulls_by_het'] = get_nulls  # same function object as get_nulls_by_het (aliased above)

In [0]:
# |D| for all SNP pairs in each phenotype's `sig` set.
# do_pairwise is dispatched through the load-balanced view; `.r` presumably blocks
# until the remote result is available — TODO confirm against ipyparallel AsyncResult.
# NOTE(review): only the "upper" set is built here; `gwas_lower_D`, which later
# cells read, is never assigned anywhere in the visible notebook.
gwas_upper_D = {}
for pheno in PC:
    gwas_upper_D[pheno] = np.abs(do_pairwise(list(PC[pheno].sig)).r)

In [0]:
# Draw 1000 null SNP sets per phenotype for the upper and lower sets, keeping
# the het bins returned alongside each draw.
# NOTE(review): `upper20` / `lower20` are not fields of PhenoContainer as defined
# in this notebook (He, Ho, hmean, sig, relaxed), so this raises AttributeError on
# a fresh kernel — confirm which fields were intended.
nulls_upper = {}
upper_sig_het_bins = {}
nulls_lower = {}
lower_sig_het_bins = {}
for pheno in PC:
    nulls_upper[pheno], upper_sig_het_bins[pheno] = get_nulls(1000, PC[pheno].upper20, PC[pheno].He)
    nulls_lower[pheno], lower_sig_het_bins[pheno] = get_nulls(1000, PC[pheno].lower20, PC[pheno].He)


In [0]:
# Submit one do_pairwise task per null SNP set; the lists hold whatever the
# remote call returns and are resolved in the following cell.
nulls_upper_D = {}
nulls_lower_D = {}
for pheno in PC:
    nulls_upper_D[pheno] = [do_pairwise(null_set) for null_set in nulls_upper[pheno]]
    nulls_lower_D[pheno] = [do_pairwise(null_set) for null_set in nulls_lower[pheno]]

In [0]:
# Replace each pending remote result with the absolute values of its `.r` payload.
# NOTE(review): the dicts are overwritten in place, so running this cell a second
# time would look up `.r` on the arrays stored here and fail.
for pheno in PC:
    nulls_upper_D[pheno] = [np.abs(x.r) for x in nulls_upper_D[pheno]]
    nulls_lower_D[pheno] = [np.abs(x.r) for x in nulls_lower_D[pheno]]

In [0]:
sns.set_context("talk")

In [0]:
# Overlay the pairwise-D distributions of the upper and lower sets, one figure per phenotype.
# NOTE(review): `gwas_lower_D` is not assigned anywhere in the visible notebook.
for pheno in PC:
    sns.distplot(gwas_upper_D[pheno], label="upper tail")
    sns.distplot(gwas_lower_D[pheno], label="lower tail")
    plt.xlabel("Pairwise D")
    plt.title(pheno.upper())
    plt.legend()
    plt.show()

In [0]:
len(nulls_upper_D['mass'][0])

In [0]:
# Collect, per phenotype, four equally long |D| samples for the box plots:
# the two GWAS sets and one null replicate for each of them.
NULL_REPLICATE = 100  # index of the null replicate shown next to the GWAS sets

pheno_boxdata = {}
for pheno in PC:
    series = {
        "Upper tail": gwas_upper_D[pheno],
        "Lower tail": gwas_lower_D[pheno],
        "Random Upper": np.abs(nulls_upper_D[pheno][NULL_REPLICATE]),
        "Random Lower": np.abs(nulls_lower_D[pheno][NULL_REPLICATE]),
    }
    # Truncate to the shortest of the four series that are actually plotted.
    # The previous code computed a minimum (over the number of null replicates
    # rather than the replicate's length) and then overwrote it with
    # len(gwas_lower_D[pheno]), so unequal lengths reached pd.DataFrame, which
    # rejects columns of different length.
    minpoints = min(len(values) for values in series.values())
    pheno_boxdata[pheno] = {
        label: list(values[:minpoints]) for label, values in series.items()
    }

In [0]:
# Wrap the missing fractions in a one-column frame ("missing") and expose the index as an 'rs' column.
# NOTE(review): `percent_missing` is rebound again here, so this cell is not safe to run twice.
percent_missing = pd.DataFrame(percent_missing, columns=["missing"])

percent_missing['rs'] = percent_missing.index

percent_missing.head()

In [0]:
# Attach the missing fraction to the upper / lower SNP sets of every phenotype,
# matching each set's "rs" column against the index of percent_missing.
# NOTE(review): `upper20` / `lower20` are not fields of PhenoContainer as defined in this notebook.
upper_missing = {}
lower_missing = {}

for pheno in PC:
    upper_missing[pheno] = PC[pheno].upper20.join(percent_missing, on="rs", rsuffix="_pm", how="inner")
    lower_missing[pheno] = PC[pheno].lower20.join(percent_missing, on="rs", rsuffix="_pm", how="inner")

In [0]:
# scipy f_oneway (one-way ANOVA) on the missing fraction of the upper vs. lower set, per phenotype.
for pheno in PC:
    print(pheno, f_oneway(upper_missing[pheno].missing, lower_missing[pheno].missing))

In [0]:
# Box plot of the four |D| samples per phenotype on a log10 y axis.
for pheno in PC:
    bd = pd.DataFrame(pheno_boxdata[pheno])
    ax = sns.boxplot(data=bd)
    plt.ylabel("log D")
    # Base 10 is matplotlib's default for the log scale. The `basey=` keyword
    # was removed from matplotlib, and its replacement `base=` is not accepted
    # by older releases, so omitting it works on both.
    ax.set_yscale("log")
    plt.xlabel("n = %d" % len(bd))
    plt.title(pheno.upper())
    plt.show()

In [0]:
@lview.remote()
def ks_test(arr1, arr2):
    """Two-sample Kolmogorov-Smirnov test of ``arr1`` against ``arr2``, run through ``lview``."""
    # scipy is imported inside the body because the call is dispatched via lview.
    import scipy.stats as stats
    return stats.ks_2samp(arr1, arr2)

@lview.remote()
def mwu_test(arr1, arr2):
    """Mann-Whitney U test of ``arr1`` against ``arr2`` (scipy defaults), run through ``lview``."""
    # scipy is imported inside the body because the call is dispatched via lview.
    import scipy.stats as stats
    return stats.mannwhitneyu(arr1, arr2)

In [0]:
# Submit one KS test per null replicate: GWAS |D| vs. that replicate's |D|.
# NOTE(review): `gwas_lower_D` is not assigned anywhere in the visible notebook.
ks_upper = {}
ks_lower = {}
for pheno in PC:
    ks_upper[pheno] = [ks_test(gwas_upper_D[pheno], x) for x in nulls_upper_D[pheno]]
    ks_lower[pheno] = [ks_test(gwas_lower_D[pheno], x) for x in nulls_lower_D[pheno]]

In [0]:
# Resolve every submitted KS test to its `.r` payload.
gwas_null_ks_upper = {}
gwas_null_ks_lower = {}
for pheno in PC:
    gwas_null_ks_upper[pheno] = [pending.r for pending in ks_upper[pheno]]
    gwas_null_ks_lower[pheno] = [pending.r for pending in ks_lower[pheno]]

In [0]:
# Keep only the p-value of each (statistic, p-value) KS result.
gwas_null_pvals_upper = {}
gwas_null_pvals_lower = {}

for pheno in PC:
    gwas_null_pvals_upper[pheno] = [pval for _stat, pval in gwas_null_ks_upper[pheno]]
    gwas_null_pvals_lower[pheno] = [pval for _stat, pval in gwas_null_ks_lower[pheno]]

In [0]:
# Two further independent batches of null SNP sets per phenotype and set.
# Each name gets its own dict: the chained `a = b = c = d = {}` bound all four
# names to ONE dict, so every store for a phenotype overwrote the previous one.
nulls2_upper, nulls3_upper, nulls2_lower, nulls3_lower = {}, {}, {}, {}

for pheno in PC:
    print(pheno)

    # NOTE(review): `upper20` / `lower20` are not fields of PhenoContainer as
    # defined in this notebook — confirm which fields were intended.
    nulls2_upper[pheno] = get_nulls(1000, PC[pheno].upper20, PC[pheno].He)
    nulls3_upper[pheno] = get_nulls(1000, PC[pheno].upper20, PC[pheno].He)

    nulls2_lower[pheno] = get_nulls(1000, PC[pheno].lower20, PC[pheno].He)
    nulls3_lower[pheno] = get_nulls(1000, PC[pheno].lower20, PC[pheno].He)

In [0]:
# Submit do_pairwise for every null set of both batches ([0] = the list of SNP
# sets from the (sets, het_bins) pair that get_nulls returns).
# Each name gets its own dict: the chained `= {}` assignment made all four
# names share one dict, so only the last assignment per phenotype survived.
n2_upper, n3_upper, n2_lower, n3_lower = {}, {}, {}, {}

for pheno in PC:
    n2_upper[pheno] = [do_pairwise(x) for x in nulls2_upper[pheno][0]]
    n3_upper[pheno] = [do_pairwise(x) for x in nulls3_upper[pheno][0]]

    n2_lower[pheno] = [do_pairwise(x) for x in nulls2_lower[pheno][0]]
    n3_lower[pheno] = [do_pairwise(x) for x in nulls3_lower[pheno][0]]

In [0]:
# Resolve the submitted tasks to their `.r` payloads.
# Each name gets its own dict: the chained `= {}` assignment made all four
# names share one dict, so only the last assignment per phenotype survived.
n2_r_upper, n2_r_lower, n3_r_upper, n3_r_lower = {}, {}, {}, {}

for pheno in PC:
    n2_r_upper[pheno] = [x.r for x in n2_upper[pheno]]
    n3_r_upper[pheno] = [x.r for x in n3_upper[pheno]]

    n2_r_lower[pheno] = [x.r for x in n2_lower[pheno]]
    n3_r_lower[pheno] = [x.r for x in n3_lower[pheno]]

In [0]:
# Null-vs-null reference: KS p-value for every pair of null replicates drawn
# from the two independent batches.
# Separate dicts are required: with `upper = lower = {}` both names referred to
# the same dict, so the "upper" p-values were replaced by the "lower" ones.
null_null_pvals_upper = {}
null_null_pvals_lower = {}

for pheno in PC:
    null_null_pvals_upper[pheno] = []
    for x, y in zip(n2_r_upper[pheno], n3_r_upper[pheno]):
        stat, p = ks_2samp(x, y)
        null_null_pvals_upper[pheno].append(p)

    null_null_pvals_lower[pheno] = []
    for x, y in zip(n2_r_lower[pheno], n3_r_lower[pheno]):
        stat, p = ks_2samp(x, y)
        null_null_pvals_lower[pheno].append(p)

In [0]:
# Compare the GWAS-vs-null p-value distribution with the null-vs-null one (KS test).
for pheno in PC:
    print("upper", pheno, ks_2samp(gwas_null_pvals_upper[pheno], null_null_pvals_upper[pheno]))
    # Label corrected from the typo "loser".
    print("lower", pheno, ks_2samp(gwas_null_pvals_lower[pheno], null_null_pvals_lower[pheno]))

In [0]:
# postrb_hmean (labelled PIP) vs. expected heterozygosity for the lower-20 SNPs.
for pheno in PC:
    # A distinct local name: rebinding `He` here overwrote the module-level
    # He frame built earlier in the notebook.
    he_frame = PC[pheno].He
    # .loc replaces .ix, which was removed from pandas.
    # NOTE(review): `lower20` is not a field of PhenoContainer as defined in this notebook.
    lower_rows = he_frame.loc[PC[pheno].lower20.index]
    plt.scatter(lower_rows.postrb_hmean, lower_rows.He)
    plt.title(pheno.upper() + " lower 20")
    plt.xlabel("PIP")
    plt.ylabel("Hexp")
    plt.show()

In [0]:
# postrb_hmean (labelled PIP) vs. expected heterozygosity for the upper-20 SNPs.
for pheno in PC:
    # A distinct local name: rebinding `He` here overwrote the module-level
    # He frame built earlier in the notebook.
    he_frame = PC[pheno].He
    # .loc replaces .ix, which was removed from pandas.
    # NOTE(review): `upper20` is not a field of PhenoContainer as defined in this notebook.
    upper_rows = he_frame.loc[PC[pheno].upper20.index]
    plt.scatter(upper_rows.postrb_hmean, upper_rows.He)
    plt.title(pheno.upper() + " upper 20")
    plt.xlabel("PIP")
    plt.ylabel("Hexp")
    plt.show()