# Series analysis


Analysis updates:
- retrieve calls rejected only because VAF<0.01 (bcbio lower acceptable threshold) for Mutect2, Strelka2 and Vardict
- correct for germline mutations using GATK haplotype calls
- get PR curves for SiNVICT by encoding its 6 output files as thresholds. Assumption: linear filters
- plotting: PR curves stop at 10e-2 on the left
- ground truths:

    1) Consensus: built by majority vote,
        requiring at least 5 of 8 callers for SNVs and 3 of 5 callers for INDELs


    2) Ranked mutations: metascore built using
        a weighted sum of each caller's scores, normalised between 0 and 1,
        with weights inversely proportional to the number of calls made by the caller (a caller making few calls gets a high weight; a caller making many calls gets a low weight)
        threshold = 1/ncallers
        interpretation: if 1 caller is sure (score = 1) of calling this position, add it to the GT
        interpretation: if 2 callers are quite sure of calling this position (score > 0.5 each), add it to the GT
        
- integrate VAF approx

    1) mixture with VAF instead of tumor burden
    
    2) correct for mutations not present in the diluted samples (vaf = 0)
    
    3) pool patients together using VAF

In [None]:
# Imports

%load_ext autoreload
%autoreload 2

import io
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pysam
import warnings
from sklearn.metrics import precision_recall_curve, f1_score, average_precision_score
warnings.filterwarnings('ignore')
from sklearn.metrics import confusion_matrix

# Move up one directory unless the current directory already ends with
# 'cfdna_snv_benchmark' -- presumably so that the relative 'config/' path and
# the `utils` package imports below resolve from the repository root.
startdir = os.getcwd()
if not startdir.endswith('cfdna_snv_benchmark'):
    os.chdir('../')
print('Current working directory: {}'.format(os.getcwd()))

from utils.config import Config
from utils.viz import *
from utils.table import *
from utils.metrics import *
from utils.calltable import *
from utils.calltableseries import *
from utils.groundtruth import *
from utils.metricsseries import *
from utils.venn import venn6, get_labels

In [None]:
# Configuration and display parameters: load the visualisation config,
# apply its display settings and list the callers it declares.
vizconfig_dir, vizconfig_file = "config/", "config_viz.yaml"
config = Config(vizconfig_dir, vizconfig_file)
set_display_params(config)
print(config.methods)

In [None]:
# Analysis parameters

# Default spike-in mixture for the single-mixture cells below
mixtureid = 'CRC-COSMIC-5p_CRC-986_300316-CW-T'
# Mixtures pooled together in Part II
mixtureids = [
    'CRC-COSMIC-5p_CRC-1014_090516-CW-T',
    'CRC-COSMIC-5p_CRC-986_300316-CW-T',
    'CRC-COSMIC-5p_CRC-123_121115-CW-T',
]
reload = False
save = True
fixedvar = 'coverage'
filterparam = 'all'

# Plot styling
markers = ['o', '^', 'X']
linestyles = ['-'] * 3
color_dict = {method: config.colors[idx] for idx, method in enumerate(config.methods)}

muttypes = ['snv', 'indel']
metrics = ['auprc', 'precision', 'recall']

# Chromosomes to analyse: autosomes 1-22 minus the excluded ones below.
# (Alternatives previously used: chroms = 'all' or a single chromosome, e.g. 22.)
excluded_chroms = {5, 9, 12, 13, 14, 15, 20, 22}
chroms = [str(c) for c in range(1, 23) if c not in excluded_chroms]
print(chroms)

# Part I: (1) Load/Generate call tables, (2) Generate Ground truths and (3) Compute/Save metrics per patient

In [None]:
# Per-chromosome analysis of one spike-in dilution series: for each chromosome
# in `chroms`, load the SNV / INDEL / SNP call tables, build the spike-in ground
# truth and compute AUPRC, recall and precision curves for every caller in
# config.methods (via metric_curve).


def spikein_sample_name(mixtureid, vaf, chrom=None):
    """Build the dilution-series sample name for one spike-in VAF level.

    `mixtureid` is split at its first '_' into a leading part and the
    remainder, and the name is 'spikein_[chr<chrom>_]<leading>_vaf<vaf>_<rest>'.
    The VAF is written with two decimals, except 0.2 -> '0.2',
    0.025 -> '0.025' and 0 -> '0', matching the sample labels looked up in
    the `aux` table returned by get_calltableseries.
    """
    head, _, tail = mixtureid.partition('_')
    if vaf == 0:
        vaflabel = '0'
    elif vaf == 0.2:
        vaflabel = '{:.1f}'.format(vaf)
    elif vaf == 0.025:
        vaflabel = '{:.3f}'.format(vaf)
    else:
        vaflabel = '{:.2f}'.format(vaf)
    chromtag = '' if chrom is None else 'chr{}_'.format(chrom)
    return 'spikein_' + chromtag + head + '_vaf' + vaflabel + '_' + tail


# The series order and sample-name scheme only exist for the fixed-coverage
# design. Fail loudly rather than reuse a stale `seriesorder` left over from a
# previous cell run.
if fixedvar != 'coverage':
    raise ValueError("Unsupported fixedvar {!r}: only 'coverage' is handled".format(fixedvar))
seriesorder = [0.2, 0.15, 0.10, 0.05, 0.025, 0.01, 0]
xaxis = 'vaf'

mixtureid = 'CRC-COSMIC-5p_CRC-986_300316-CW-T'
print('############# {} ############'.format(mixtureid))
# Both parts of the mixture id are constant across chromosomes
plasmasample = '_'.join(mixtureid.split('_')[:1])
print(plasmasample)
healthysample = '_'.join(mixtureid.split('_')[1:])
print(healthysample)

for chrom in chroms:
    calltables = {'sampleid': [], 'tf': [], 'cov': [], 'snv': [], 'indel': [], 'snp': []}
    calltable_snv, aux = get_calltableseries(config, mixtureid, chrom, muttype='snv', filterparam=filterparam, reload=reload, save=save, diltype='spikein')
    calltable_indel, aux = get_calltableseries(config, mixtureid, chrom, muttype='indel', filterparam=filterparam, reload=reload, save=save, diltype='spikein')
    calltable_snp, aux = get_calltableseries(config, mixtureid, chrom, muttype='snp', filterparam=filterparam, reload=reload, save=save, diltype='spikein')
    print(calltable_snv.shape, calltable_indel.shape, calltable_snp.shape)
    # Keep the rows of `aux` for the dilution-series samples, in seriesorder
    samplenames = [spikein_sample_name(mixtureid, s, chrom=chrom) for s in seriesorder]
    dilutionseries = aux.T[samplenames].T
    # Recover the numeric VAF from the 'vaf<value>_' part of each sample name
    dilutionseries['vaf'] = [float(ds.split('vaf')[1].split('_')[0]) for ds in dilutionseries.index]
    print(dilutionseries)
    calltables['snv'] = calltable_snv
    calltables['indel'] = calltable_indel
    calltables['snp'] = calltable_snp
    calltables['sampleid'] = mixtureid
    # NOTE(review): drops the last 5 sorted unique column prefixes, presumably
    # the non-numeric (non-VAF) columns -- confirm against the call-table layout
    calltables['tf'] = np.unique([cn.split('_')[0] for cn in list(calltable_snv.columns)])[:-5].astype(float)
    # Only SNVs are evaluated per chromosome; the all-chromosome cell loops over `muttypes`
    muttype = 'snv'
    calltablesseries = generate_groundtruth(config, calltables[muttype], dilutionseries[['vaf']], ground_truth_method='spikein', muttype=muttype)
    results_auprc_df = metric_curve(config, calltablesseries, plasmasample, healthysample, seriesorder, metric='auprc', ground_truth_method='spikein',
                                    refsample='undiluted', muttype=muttype, chrom=chrom, methods=config.methods, fixedvar=fixedvar)
    results_recall_df = metric_curve(config, calltablesseries, plasmasample, healthysample, seriesorder, metric='recall', ground_truth_method='spikein',
                                     refsample='undiluted', muttype=muttype, chrom=chrom, methods=config.methods, fixedvar=fixedvar)
    results_precision_df = metric_curve(config, calltablesseries, plasmasample, healthysample, seriesorder, metric='precision', ground_truth_method='spikein',
                                        refsample='undiluted', muttype=muttype, chrom=chrom, methods=config.methods, fixedvar=fixedvar)

In [None]:
# Number of ground-truth positions in the current call table
calltablesseries.loc[:, 'truth'].sum()

In [None]:
# Sum of the VarDict column for the '0.00' sample
calltablesseries.loc[:, '0.00_vardict'].sum()

In [None]:
# Side-by-side view of every VarDict column and the ground truth (first 50 rows)
vardict_columns = [col for col in calltablesseries.columns if col.endswith('vardict')]
calltablesseries[vardict_columns + ['truth']].head(50)

In [None]:
# All-chromosome analysis of the spike-in dilution series for `mixtureid`:
# load the call tables for the whole `chroms` list at once, then for every
# mutation type in `muttypes` build the spike-in ground truth and compute
# AUPRC, recall and precision curves with metric_curve_allchr.


def spikein_sample_name(mixtureid, vaf, chrom=None):
    """Build the dilution-series sample name for one spike-in VAF level.

    `mixtureid` is split at its first '_' into a leading part and the
    remainder, and the name is 'spikein_[chr<chrom>_]<leading>_vaf<vaf>_<rest>'.
    The VAF is written with two decimals, except 0.2 -> '0.2',
    0.025 -> '0.025' and 0 -> '0', matching the sample labels looked up in
    the `aux` table returned by get_calltableseries.
    """
    head, _, tail = mixtureid.partition('_')
    if vaf == 0:
        vaflabel = '0'
    elif vaf == 0.2:
        vaflabel = '{:.1f}'.format(vaf)
    elif vaf == 0.025:
        vaflabel = '{:.3f}'.format(vaf)
    else:
        vaflabel = '{:.2f}'.format(vaf)
    chromtag = '' if chrom is None else 'chr{}_'.format(chrom)
    return 'spikein_' + chromtag + head + '_vaf' + vaflabel + '_' + tail


# The series order and sample-name scheme only exist for the fixed-coverage
# design. Fail loudly rather than reuse a stale `seriesorder` left over from a
# previous cell run.
if fixedvar != 'coverage':
    raise ValueError("Unsupported fixedvar {!r}: only 'coverage' is handled".format(fixedvar))
seriesorder = [0.2, 0.15, 0.10, 0.05, 0.025, 0.01, 0]
xaxis = 'vaf'

print('############# {} ############'.format(mixtureid))
# The whole chromosome list is passed to get_calltableseries
chrom = chroms
calltables = {'sampleid': [], 'tf': [], 'cov': [], 'snv': [], 'indel': [], 'snp': []}
calltable_snv, aux = get_calltableseries(config, mixtureid, chrom, muttype='snv', filterparam=filterparam, reload=reload, save=save, diltype='spikein')
calltable_indel, aux = get_calltableseries(config, mixtureid, chrom, muttype='indel', filterparam=filterparam, reload=reload, save=save, diltype='spikein')
calltable_snp, aux = get_calltableseries(config, mixtureid, chrom, muttype='snp', filterparam=filterparam, reload=reload, save=save, diltype='spikein')
print(calltable_snv.shape, calltable_indel.shape, calltable_snp.shape)
# Keep the rows of `aux` for the dilution-series samples (no chromosome tag), in seriesorder
samplenames = [spikein_sample_name(mixtureid, s) for s in seriesorder]
dilutionseries = aux.T[samplenames].T
# Recover the numeric VAF from the 'vaf<value>_' part of each sample name
dilutionseries['vaf'] = [float(ds.split('vaf')[1].split('_')[0]) for ds in dilutionseries.index]
print(dilutionseries)
plasmasample = '_'.join(mixtureid.split('_')[:1])
print(plasmasample)
healthysample = '_'.join(mixtureid.split('_')[1:])
print(healthysample)
calltables['snv'] = calltable_snv
calltables['indel'] = calltable_indel
calltables['snp'] = calltable_snp
calltables['sampleid'] = mixtureid
# NOTE(review): drops the last 5 sorted unique column prefixes, presumably
# the non-numeric (non-VAF) columns -- confirm against the call-table layout
calltables['tf'] = np.unique([cn.split('_')[0] for cn in list(calltable_snv.columns)])[:-5].astype(float)
for muttype in muttypes:
    calltablesseries = generate_groundtruth(config, calltables[muttype], dilutionseries[['vaf']], ground_truth_method='spikein', muttype=muttype)
    results_auprc_df = metric_curve_allchr(config, calltablesseries, dilutionseries, mixtureid, metric='auprc', ground_truth_method='spikein',
                                           refsample='undiluted', muttype=muttype, methods=config.methods, fixedvar=fixedvar, xaxis=xaxis)
    results_recall_df = metric_curve_allchr(config, calltablesseries, dilutionseries, mixtureid, metric='recall', ground_truth_method='spikein',
                                            refsample='undiluted', muttype=muttype, methods=config.methods, fixedvar=fixedvar, xaxis=xaxis)
    results_precision_df = metric_curve_allchr(config, calltablesseries, dilutionseries, mixtureid, metric='precision', ground_truth_method='spikein',
                                               refsample='undiluted', muttype=muttype, methods=config.methods, fixedvar=fixedvar, xaxis=xaxis)

In [None]:
# Unique column-name prefixes (text before the first '_') of calltablesseries,
# skipping its first 5 columns and its last column, as floats
colprefixes = [name.split('_')[0] for name in calltablesseries.columns[5:-1]]
np.unique(colprefixes).astype(float)

In [None]:
# Unique column prefixes after the first 5 columns, minus the last 5 sorted
# prefixes, as floats -- printed together with their count
vaflevels = np.unique([name.split('_')[0] for name in calltablesseries.columns[5:]])[:-5].astype(float)
print(vaflevels)
print(len(vaflevels))
calltablesseries.columns[5:]

In [None]:
# Same prefix extraction applied to the raw SNV call table
snv_columns = calltables['snv'].columns[5:]
print(np.unique([name.split('_')[0] for name in snv_columns])[:-5].astype(float))

# Part II: (4) Load back metric results and plot combined metric plots

In [None]:
# Precision-recall curves over the dilution series. Uses the same 'spikein'
# ground-truth method that generated `calltablesseries` above (the previous
# `nref` argument was never defined and raised NameError).
figure_curve_allchr(config, calltablesseries, dilutionseries, mixtureid, xy='pr', ground_truth_method='spikein',
                    refsample='undiluted', muttype=muttype.upper(), chrom=chrom, methods=None, fixedvar=fixedvar, save=save)

In [None]:
# Part II: load the per-mixture metric tables from the 'mixtures_allchr'
# results folder and plot every metric for each mutation type, pooled over
# patients and per patient, on linear and log scales.
fixedvar = 'coverage'
xaxis = 'vaf'
# Result file names abbreviate the 'tumor burden' x-axis as 'tb'
xa = xaxis if xaxis != 'tumor burden' else 'tb'
# (allpatients, logscale) combinations, in the original plotting order
plotvariants = [(True, False), (True, True), (False, False), (False, True)]

for mt in muttypes:
    # Minimum number of callers in the undiluted sample: 5 for SNVs, 2 otherwise
    gtm = 5 if mt == 'snv' else 2
    refname = 'inundilutedsamplebyatleast' + str(gtm) + 'callers'
    print(refname)
    for metric in metrics:
        # load results tables
        restables = {'snv': [], 'indel': []}
        # `mixid` (not `mixtureid`) so the global mixture id used by Part I is not clobbered
        for mixid in mixtureids:
            plasmasample = '_'.join(mixid.split('_')[:2])
            print(mixid, plasmasample)
            resfile = os.path.join(*config.mixturefolder, 'mixtures_allchr', 'results',
                                   mixid + '_' + mt + '_' + metric + '_' + refname + '_fixed' + fixedvar + '_' + xa + '.csv')
            restable = pd.read_csv(resfile, index_col=0)
            restable['plasma sample'] = plasmasample
            restables[mt].append(restable)
        restables[mt] = pd.concat(restables[mt])
        for allpatients, logscale in plotvariants:
            plot_metricsseries(config, restables, mixtureids, 'all', metric=metric, muttype=mt,
                               ground_truth_method='mixture', fixedvar=fixedvar, refname=refname,
                               allpatients=allpatients, logscale=logscale, save=save)