## Genotypic recalibration with machine learning: usage example

This notebook illustrates how to use the machine learning methods and accompanying code described in the parent paper.

In [1]:
import pandas as pd

import sys
sys.path.insert(0, '../python')

from preprocessing import VCF, load_suffixes, prepare_input
from recalibrator import Recalibrator

## Training & saving a model

Training is performed on VCF files produced by running GATK variant calling on reads from a family trio, together with a 'synthetic abortus' that contains a mixture of the mother's and the child's reads. The code that reads and processes the dataset relies on a specific directory structure. The `VCF` class and the `prepare_input` function are all that is needed to read a VCF file and convert it into an array that can be fed to a model. Once trained, a recalibrator model can be serialized and saved.

In [2]:
trios = ["ajt", "chd", "corpas", "yri"]

# Pre-processing. Uncomment on the first run of the notebook, then
# comment out again to avoid recomputing

# for trio in trios:
#     data_dir = '../data/' + trio + '/'
#     df = load_suffixes(data_dir)
#     df.to_csv(trio + '.csv', index=False)
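
Rather than commenting and uncommenting the loop, the cache check can be made explicit. The following is a minimal sketch and not part of the repository: `cached_csv` is a hypothetical helper that calls a loader (for example `load_suffixes`) only when the CSV file does not exist yet.

```python
import os
import pandas as pd

def cached_csv(csv_path, compute):
    """Return the table stored at csv_path, computing and saving it on first use.

    `compute` is any zero-argument callable that returns a DataFrame
    (hypothetical helper; not part of the repository).
    """
    if os.path.exists(csv_path):
        return pd.read_csv(csv_path)
    df = compute()
    df.to_csv(csv_path, index=False)
    return df
```

With this helper, the loop body would reduce to something like `cached_csv(trio + '.csv', lambda: load_suffixes('../data/' + trio + '/'))`.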

Construct the training dataset by concatenating rows from all the synthetic abortus trios.

In [3]:
# DataFrame.append was removed in pandas 2.0; concatenate the per-trio tables instead
df_train = pd.concat([pd.read_csv(trio + '.csv') for trio in trios])

df_train = df_train[::100]  # Train on a subset: every 100th row

* Split dataset into input and target parameters
* Initialize a recalibrator model and train it on our data
* Save the serialized model

In [4]:
X_train = prepare_input(df_train, target_cols=['justchild^GT'])
y_train = df_train['justchild^GT'].values

r = Recalibrator()
r.train(X_train, y_train)
r.save("model.pickle")


Training logistic regression
Training XGB
[0]	validation_0-merror:0.056753
Will train until validation_0-merror hasn't improved in 20 rounds.
[1]	validation_0-merror:0.054217
[2]	validation_0-merror:0.053266
[3]	validation_0-merror:0.051997
[4]	validation_0-merror:0.05168
[5]	validation_0-merror:0.050729
[6]	validation_0-merror:0.049144
[7]	validation_0-merror:0.048193
[8]	validation_0-merror:0.049461
[9]	validation_0-merror:0.047559
[10]	validation_0-merror:0.045656
[11]	validation_0-merror:0.044705
[12]	validation_0-merror:0.044071
[13]	validation_0-merror:0.044071
[14]	validation_0-merror:0.044071
[15]	validation_0-merror:0.043754
[16]	validation_0-merror:0.04312
[17]	validation_0-merror:0.042803
[18]	validation_0-merror:0.043437
[19]	validation_0-merror:0.042486
[20]	validation_0-merror:0.042486
[21]	validation_0-merror:0.042169
[22]	validation_0-merror:0.042803
[23]	validation_0-merror:0.04312
[24]	validation_0-merror:0.041218
[25]	validation_0-merror:0.040583
[26]	validation_0-me
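
The log above reports `merror`, the multiclass classification error rate on the validation set, and states that training stops once that metric has not improved for 20 rounds. The pure-Python sketch below illustrates both ideas. The names `merror` and `early_stopping_round` are illustrative, and the logic is a simplified stand-in for XGBoost's internal behaviour, not its actual implementation.

```python
def merror(y_true, y_pred):
    """Fraction of samples whose predicted class differs from the true class."""
    wrong = sum(t != p for t, p in zip(y_true, y_pred))
    return wrong / len(y_true)

def early_stopping_round(errors, patience):
    """Return (stop_round, best_round) for a sequence of per-round validation errors.

    Training stops at the first round that lies `patience` rounds past the
    best round without a strict improvement. If that never happens, the
    last round is returned as stop_round.
    """
    best_round, best_err = 0, float("inf")
    for i, err in enumerate(errors):
        if err < best_err:
            best_round, best_err = i, err
        elif i - best_round >= patience:
            return i, best_round
    return len(errors) - 1, best_round
```

In the log, the error rises briefly at rounds 8, 18, 22 and 23 but recovers within a few rounds each time, so a patience of 20 is not exhausted in the portion shown.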

## Loading a model & recalibrating

* Instantiate a recalibrator object and load a saved model
* Instantiate a VCF object from a VCF file

In [5]:
r = Recalibrator()
r.load("model.pickle")

abortus = VCF("../data/ajt/abortus.frac0.5.seed151_trio.vcf")
abortus.df.head()

Unnamed: 0,#CHROM,POS,ID,REF,ALT,QUAL,FILTER,INFO,FORMAT,abortus,father,mother
0,chr1,69270,.,A,G,91.17,.,AC=2;AF=1.00;AN=2;DP=4;ExcessHet=3.0103;FS=0.0...,GT:AD:DP:GQ:PL,"./.:0,0:0:.:0,0,0","1/1:0,4:4:12:117,12,0","./.:0,0:0:.:0,0,0"
1,chr1,69511,.,A,G,3514.9,.,AC=2;AF=1.00;AN=2;DP=128;ExcessHet=3.0103;FS=0...,GT:AD:DP:GQ:PL,"./.:0,0:0:.:0,0,0","1/1:0,128:128:99:3541,383,0","./.:0,0:0:.:0,0,0"
2,chr1,183238,.,G,C,20.18,.,AC=1;AF=0.250;AN=6;BaseQRankSum=-1.282e+00;Cli...,GT:AD:DP:GQ:PL,"0/0:9,0:9:27:0,27,231","0/0:14,0:14:39:0,39,585","0/1:3,3:6:51:51,0,81"
3,chr1,187485,.,G,A,244.19,.,AC=5;AF=0.750;AN=6;BaseQRankSum=-9.670e-01;Cli...,GT:AD:DP:GQ:PGT:PID:PL,"1/1:0,2:2:6:.:.:49,6,0","0/1:1,2:3:19:.:.:40,0,19","1/1:0,5:5:15:1|1:187485_G_A:185,15,0"
4,chr1,942451,.,T,C,980.92,.,AC=6;AF=1.00;AN=6;DP=35;ExcessHet=3.0103;FS=0....,GT:AD:DP:GQ:PL,"1/1:0,20:20:60:561,60,0","1/1:0,8:8:24:242,24,0","1/1:0,7:7:21:204,21,0"
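
Each sample column in the table above packs several values into one colon-separated string whose layout is given by the `FORMAT` column. To illustrate what reading such a field involves, here is a small sketch. `parse_sample` is a hypothetical helper, not the `VCF` class's actual implementation.

```python
def parse_sample(format_str, sample_str):
    """Pair FORMAT keys with a sample's values, e.g. GT -> '0/1', DP -> '6'.

    Values are kept as strings; '.' marks a missing value in VCF.
    """
    keys = format_str.split(":")
    values = sample_str.split(":")
    return dict(zip(keys, values))

# Mother's entry at chr1:183238 from the table above
fields = parse_sample("GT:AD:DP:GQ:PL", "0/1:3,3:6:51:51,0,81")
ref_depth, alt_depth = (int(x) for x in fields["AD"].split(","))
print(fields["GT"], fields["GQ"], ref_depth, alt_depth)  # prints: 0/1 51 3 3
```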


In [6]:
# VCF class automatically infers the contamination fraction
abortus.process("abortus", "mother", "father")
print("Estimated contamination: {}".format(abortus.estimated_contamination))

Estimated contamination: 0.5066857431635673
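
The notebook does not spell out how `VCF` arrives at this estimate. One common way to reason about maternal contamination, shown here purely as a hedged sketch and not as the repository's method, uses sites where the mother is homozygous reference and the father is homozygous alternate, so the child must be heterozygous. If a fraction `c` of the sample's reads come from the mother, the expected alternate-allele fraction at such sites is `(1 - c) / 2`, which gives `c = 1 - 2 * alt_fraction`. The function name and input layout below are assumptions.

```python
def estimate_contamination(ad_pairs):
    """Estimate the maternal read fraction from (ref_depth, alt_depth) pairs.

    Assumes every pair comes from a site where the mother is 0/0 and the
    father is 1/1, so the child is an obligate heterozygote.
    """
    ref_total = sum(ref for ref, _ in ad_pairs)
    alt_total = sum(alt for _, alt in ad_pairs)
    depth = ref_total + alt_total
    if depth == 0:
        raise ValueError("no reads at informative sites")
    alt_fraction = alt_total / depth
    # Clamp to [0, 1] because sampling noise can push the estimate outside.
    return min(1.0, max(0.0, 1.0 - 2.0 * alt_fraction))
```

For example, pooled depths of 225 reference reads and 75 alternate reads give an alternate-allele fraction of 0.25 and therefore an estimated contamination of 0.5.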


In [7]:
abortus.df_processed.head()

Unnamed: 0,#CHROM,POS,REF,ALT,AC,AF,abortus^GT,abortus^GQ,abortus^DP,mother^GT,...,abortus^GT^0^1H,abortus^GT^1^1H,abortus^GT^2^1H,mother^GT^0^1H,mother^GT^1^1H,mother^GT^2^1H,father^GT^0^1H,father^GT^1^1H,father^GT^2^1H,contamination
0,chr1,183238,3,2,1,0.25,0,27,9,1,...,True,False,False,False,True,False,True,False,False,0.506686
1,chr1,187485,3,1,5,0.75,2,6,2,2,...,False,False,True,False,False,True,False,True,False,0.506686
2,chr1,942451,4,2,6,1.0,2,60,20,2,...,False,False,True,False,False,True,False,False,True,0.506686
3,chr1,942934,3,2,3,0.5,1,99,184,0,...,False,True,False,True,False,False,False,False,True,0.506686
4,chr1,943937,2,4,2,0.25,1,99,88,1,...,False,True,False,False,True,False,True,False,False,0.506686
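
Comparing the raw and processed tables, genotype strings such as `0/0`, `0/1` and `1/1` appear in the processed table as the integers 0, 1 and 2, each followed by three one-hot columns (`^GT^0^1H`, `^GT^1^1H`, `^GT^2^1H`). The sites whose `abortus` or `mother` call was `./.` no longer appear. A minimal sketch of such an encoding, assuming biallelic sites and counting alternate alleles, might look as follows. The function names are illustrative and the repository's own code may differ.

```python
def encode_genotype(gt):
    """Map a biallelic genotype string to its alternate-allele count (0, 1 or 2).

    Accepts unphased ('0/1') and phased ('0|1') separators and returns
    None for missing calls such as './.'.
    """
    alleles = gt.replace("|", "/").split("/")
    if len(alleles) != 2 or not all(a in ("0", "1") for a in alleles):
        return None
    return sum(int(a) for a in alleles)

def one_hot(label, n_classes=3):
    """Return a list of booleans with True only at position `label`."""
    return [i == label for i in range(n_classes)]

for gt in ("0/0", "0/1", "1/1"):
    print(gt, encode_genotype(gt), one_hot(encode_genotype(gt)))
```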


In [8]:
# Predict genotype labels with the logistic regression model
preds_lr = r.predict_lr(abortus.prepare_input())

In [9]:
# Write the predicted genotypes for the "abortus" sample to a new VCF file
abortus.save_predictions(preds_lr, filename="recalibrated_lr.vcf", sample="abortus")
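
Conceptually, saving predictions amounts to writing each predicted label back into the genotype (`GT`) subfield of the chosen sample column. The sketch below shows that step for a single record. `set_genotype` and the label-to-genotype mapping are assumptions made for illustration, not the behaviour of `save_predictions` itself.

```python
# Assumed mapping from class label (alternate-allele count) to genotype string.
LABEL_TO_GT = {0: "0/0", 1: "0/1", 2: "1/1"}

def set_genotype(format_str, sample_str, label):
    """Return sample_str with its GT subfield replaced by the predicted genotype."""
    keys = format_str.split(":")
    values = sample_str.split(":")
    values[keys.index("GT")] = LABEL_TO_GT[label]
    return ":".join(values)

# Recalibrate a 0/0 call to 0/1 while leaving the other subfields untouched
print(set_genotype("GT:AD:DP:GQ:PL", "0/0:9,0:9:27:0,27,231", 1))
# prints: 0/1:9,0:9:27:0,27,231
```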

## Recalibrating with confidence intervals

We use the `VCF` class's built-in `process` method to process the VCF and extract the fields required by `confidence_intervals`.

In [10]:
abortus = VCF("../data/ajt/abortus.frac0.5.seed151_trio.vcf")
abortus.process("abortus", "mother", "father")
preds_ci = r.model_ci.predict(abortus.prepare_input())

# abortus.save_predictions(preds_ci, filename="recalibrated_ci.vcf", sample="abortus")