In [1]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm import tqdm
from sklearn.metrics import accuracy_score

from preprocessing import *
from recalibrator import Recalibrator
from utils import match
from confidence_intervals import confidence_intervals

In [3]:
trios = ["ajt", "chd", "corpas", "yri"]

# # Pre-processing. Uncomment during first run of the script, then
# # comment to avoid re-computing

# for trio in trios:
#     data_dir = '../data/' + trio + '/'
#     df = load_suffixes(data_dir)
#     df.to_csv(trio + '.csv')

In [None]:

results_cum = {}

for test in trios:
    results = {}
    test = 'yri'
    df_train = pd.DataFrame()

    for train in trios:
        if train != test:
            df_train = df_train.append(pd.read_csv(train + '.csv'))

    df_test = pd.read_csv(test + '.csv')

    gt_cols = list(filter(match("GT", pos=-1), df_train.columns.values))
    to_drop = list(set(['#CHROM', 'POS', 'Unnamed: 0', 'Unnamed: 1'] + gt_cols))

    X_train = df_train.drop(to_drop + ["justchild^GT"], axis=1).values
    y_train = df_train['justchild^GT'].values
    X_test = df_test.drop(to_drop + ["justchild^GT"], axis=1).values
    y_test = df_test['justchild^GT'].values

    contaminations = df_test['contamination'].values
    contamination_values = list(sorted(np.unique(contaminations)))

    X_tests = {}
    y_tests = {}
    idx = {}

    for contamination in contamination_values:
        idx[contamination] = contaminations == contamination
        X_tests[contamination] = X_test[idx[contamination]]
        y_tests[contamination] = y_test[idx[contamination]]

    results['y_test'] = y_test
    results['preds_naive'] = df_test['abortus^GT'].values
    results['idx'] = idx

    r = Recalibrator()
    r.train(X_train, y_train)
    results['preds_lr'] = r.model_lr.predict(X_test)
    results['preds_xgb'] = r.model_xgb.predict(X_test)
    results['preds_ci'] = confidence_intervals(df_test)

    results_cum[test] = results


[0]	validation_0-merror:0.064491
Will train until validation_0-merror hasn't improved in 20 rounds.
[1]	validation_0-merror:0.062706
[2]	validation_0-merror:0.062318
[3]	validation_0-merror:0.06055
[4]	validation_0-merror:0.059001
[5]	validation_0-merror:0.058642
[6]	validation_0-merror:0.058449
[7]	validation_0-merror:0.05359
[8]	validation_0-merror:0.051836
[9]	validation_0-merror:0.050304
[10]	validation_0-merror:0.048739
[11]	validation_0-merror:0.046979
[12]	validation_0-merror:0.046083
[13]	validation_0-merror:0.04349
[14]	validation_0-merror:0.041348
[15]	validation_0-merror:0.039691
[16]	validation_0-merror:0.038705
[17]	validation_0-merror:0.036742
[18]	validation_0-merror:0.035867
[19]	validation_0-merror:0.035226
[20]	validation_0-merror:0.033912
[21]	validation_0-merror:0.03352
[22]	validation_0-merror:0.03296
[23]	validation_0-merror:0.032323
[24]	validation_0-merror:0.031726
[25]	validation_0-merror:0.031204
[26]	validation_0-merror:0.03064
[27]	validation_0-merror:0.0301

In [6]:
import pickle

with open("results_1vA_new.pickle", "wb") as f:
    pickle.dump(results, f)

  lower_bound = contaminations - z*np.sqrt(contaminations*(1 - contaminations)/df_test[sample_name + '^DP'].values)
  upper_bound = contaminations + z*np.sqrt(contaminations*(1 - contaminations)/df_test[sample_name + '^DP'].values)


array([2, 1, 1, ..., 1, 0, 2])

In [7]:
accuracy_score(y_test, df_test['abortus^GT'].values)

0.8757561999434468

In [8]:
sample_name = 'abortus'
df_test[sample_name + '^GT']

0         2
1         1
2         1
3         1
4         0
5         0
6         0
7         0
8         0
9         2
10        2
11        2
12        1
13        0
14        1
15        0
16        1
17        1
18        2
19        1
20        0
21        2
22        0
23        0
24        1
25        1
26        1
27        1
28        2
29        1
         ..
502152    1
502153    1
502154    1
502155    1
502156    1
502157    0
502158    1
502159    1
502160    0
502161    1
502162    1
502163    1
502164    1
502165    1
502166    2
502167    1
502168    1
502169    1
502170    1
502171    1
502172    2
502173    2
502174    0
502175    1
502176    1
502177    0
502178    0
502179    1
502180    0
502181    2
Name: abortus^GT, Length: 502182, dtype: int64