In [None]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm import tqdm
from sklearn.metrics import accuracy_score

from preprocessing import *
from recalibrator import Recalibrator
from utils import match
from confidence_intervals import confidence_intervals

In [None]:
trios = ["ajt", "chd", "corpas", "yri"]

# Pre-processing: turn each trio's raw data directory into one CSV
# (`<trio>.csv`) via `load_suffixes`. The CSV doubles as a cache: a trio is
# only re-processed when its CSV is missing, so "Restart & Run All" works on a
# fresh checkout without manually (un)commenting this cell. Delete a trio's CSV
# to force it to be rebuilt.
for trio in trios:
    csv_path = trio + '.csv'
    if not os.path.exists(csv_path):
        data_dir = '../data/' + trio + '/'
        load_suffixes(data_dir).to_csv(csv_path)

In [None]:
# Held-out trio; the model is trained on every other trio in `trios`.
test = 'yri'

# DataFrame.append was removed in pandas 2.0 (and was quadratic when used in a
# loop), so read all training CSVs first and concatenate them once.
df_train = pd.concat(
    [pd.read_csv(train + '.csv') for train in trios if train != test],
    ignore_index=True,
)
df_test = pd.read_csv(test + '.csv')

# Non-feature columns: locus identifiers, CSV index columns, and every column
# selected by `match("GT", pos=-1)` (presumably the per-sample `...^GT`
# genotype columns -- TODO confirm against utils.match).
# NOTE(review): `drop` raises KeyError if 'Unnamed: 1' is absent from a CSV.
gt_cols = list(filter(match("GT", pos=-1), df_train.columns.values))
to_drop = list(set(['#CHROM', 'POS', 'Unnamed: 0', 'Unnamed: 1'] + gt_cols))

# 'justchild^GT' is the prediction target; it is dropped explicitly so it can
# never leak into the features, whatever `gt_cols` contains.
X_train = df_train.drop(to_drop + ["justchild^GT"], axis=1).values
y_train = df_train['justchild^GT'].values
X_test = df_test.drop(to_drop + ["justchild^GT"], axis=1).values
y_test = df_test['justchild^GT'].values

contaminations = df_test['contamination'].values
contamination_values = sorted(np.unique(contaminations))

# Boolean row mask of the test set for each distinct contamination level, and
# the matching ground-truth labels used when scoring per level.
idx = {contamination: contaminations == contamination
       for contamination in contamination_values}
y_tests = {contamination: y_test[idx[contamination]]
           for contamination in contamination_values}

r = Recalibrator()
r.train(X_train, y_train)

In [None]:
from copy import deepcopy

def test_perturbed(predict, perturbation):
    df_test_perturbed = deepcopy(df_test)


    new_contaminations = contaminations + perturbation
    df_test_perturbed['contamination'] = new_contaminations

    X_test = df_test_perturbed.drop(to_drop, axis=1).values
    y_test = df_test['justchild^GT'].values
    genotype_ab_test = df_test['abortus^GT'].values

    return {contamination: predict(X_test[idx[contamination]]) for contamination in contamination_values}

In [None]:

# Each entry of `errors` is the mean of the noise added to the contamination
# (MCC) estimate; `stdev` is its spread (0.0 = a pure, constant bias).
stdev = 0.0
errors = [-0.05, -0.03, -0.01, 0, 0.01, 0.03, 0.05]
rng = np.random.default_rng(0)  # fixed seed so the sweep is reproducible

preds_perturbed_lr = {}
preds_perturbed_xgb = {}
for mean in errors:
    # Draw the perturbation once per bias so both models see the same input.
    perturbation = rng.normal(mean, stdev, contaminations.shape)
    preds_perturbed_lr[mean] = test_perturbed(r.predict_lr, perturbation)
    preds_perturbed_xgb[mean] = test_perturbed(r.predict_xgb, perturbation)


def _accuracy_by_contamination(preds):
    """Accuracy per entry of `contamination_values` for a {level: predictions} dict."""
    return [accuracy_score(y_test[idx[contamination]], preds[contamination])
            for contamination in contamination_values]


scores_perturbed_lr = {mean: _accuracy_by_contamination(preds_perturbed_lr[mean]) for mean in errors}
scores_perturbed_xgb = {mean: _accuracy_by_contamination(preds_perturbed_xgb[mean]) for mean in errors}

In [None]:
# Accuracy of the logistic-regression recalibrator per true MCC level, one
# curve per bias added to the contamination estimate.
fig, ax = plt.subplots(figsize=(11, 8))
for mean in errors:
    ax.plot(contamination_values, scores_perturbed_lr[mean], label="{}".format(mean))
ax.legend(title="Bias added to MCC estimate")
ax.grid()
ax.set_xlabel("MCC")
ax.set_ylabel("Fraction of correct genotypes")
ax.set_title("Recalibration with logistic regression for biased MCC estimation ({} trio)".format(test.upper()))
# Uncomment to export the figure:
# fig.savefig('stabilities_lr.eps', format='eps', dpi=1000)
plt.show()

In [None]:
# Accuracy of the XGBoost recalibrator (`r.predict_xgb`) per true MCC level,
# one curve per bias added to the contamination estimate.
fig, ax = plt.subplots(figsize=(11, 8))
for mean in errors:
    ax.plot(contamination_values, scores_perturbed_xgb[mean], label="{}".format(mean))
ax.legend(title="Bias added to MCC estimate")
ax.grid()
ax.set_xlabel("MCC")
ax.set_ylabel("Fraction of correct genotypes")
# The title previously said "logistic regression" (copy-paste from the LR cell).
ax.set_title("Recalibration with XGBoost for biased MCC estimation ({} trio)".format(test.upper()))
# Uncomment to export the figure:
# fig.savefig('stabilities_xgb.eps', format='eps', dpi=1000)
plt.show()