In [None]:
import gtools
import sys
import os
import matplotlib.pyplot as plt
plt.rcParams['svg.fonttype'] = 'none'
import pandas as pd
import csv
from tqdm import tqdm
import numpy as np
import math
import argparse
import re

In [None]:
# Command-line interface. Every filter argument (-A .. -P) takes DEFAULT,START:END:STEP: DEFAULT is the
# cutoff drawn as the "selected cutoff" line, and START:END:STEP (END exclusive) is the range swept.
parser = argparse.ArgumentParser(description='''
Visualizes how adjusting the filter thresholds of filter_duplex_variants.py will affect precision and recall. Uses 
the rate of C>T mutations and mutations called in "swapped" control samples to estimate the false positive rate and 
number of true positives passing filters. Default values and tested range of filters can be modified by passing values 
DEFAULT,START:END:STEP to the filter arguments. END is non-inclusive. False positive rate of a given set of filters is 
calculated as (passing variants in swapped samples / callable coverage of swapped samples) / (passing variants in real 
samples / callable coverage of real samples). Note: the callable coverage is static and does not adapt to changes in 
filters, so it's best to use swapped samples made from the real samples to minimize any error introduced by this. 
True positive count is estimated as the number of passing variants in real samples * (1 - false positive rate). C>T 
rate refers to the fraction of passing variants in real samples which are C>T or G>A.
Note: this script may require a large amount of memory, as it needs to hold all the input TSVs in memory. A preliminary 
filter K is applied to reduce memory usage, so setting the END value of -K lower will reduce memory footprint.
    ''')
parser.add_argument('-v', '--real_vars', required=True, dest='real_vars', metavar='TSV1', nargs='+', type=str,
                   help='variant TSV(s) of "real" (non-swapped) samples output by add_duplex_filter_columns.py')
parser.add_argument('-V', '--swapped_vars', required=True, dest='swapped_vars', metavar='TSV1', nargs='+', type=str,
                   help='variant TSV(s) output by running add_duplex_filter_columns.py on "swapped" samples generated '
                   'by duplex_strand_swapper.py')
parser.add_argument('-w', '--real_cov', required=True, dest='real_cov', metavar='PREFIX1', nargs='+', type=str,
                   help='prefix of coverage arrays for the same files as -v, generated by duplex_coverage.py')
parser.add_argument('-W', '--swapped_cov', required=True, dest='swapped_cov', metavar='PREFIX1', nargs='+', type=str,
                   help='prefix of coverage arrays for the same files as -V, generated by duplex_coverage.py')
parser.add_argument('-o', '--output', required=True, dest='output', metavar='FILE', type=str,
                   help='name of the output figure (file format is inferred)')

parser.add_argument('-A', '--min_frac_init', dest='min_frac_init', metavar='FLOAT,FLOAT:FLOAT:FLOAT', type=str, default='0.24,0:1:0.05',
                   help='initial minimum fraction of supporting reads per strand (default: 0.24,0:1:0.05)')
parser.add_argument('-B', '--min_strand_cov', dest='min_strand_cov', metavar='INT,INT:INT:INT', type=str, default='2,0:10:1',
                   help='minimum reads per strand overlapping variant (default: 2,0:10:1)')
parser.add_argument('-C', '--min_total_cov', dest='min_total_cov', metavar='INT,INT:INT:INT', type=str, default='6,0:20:1',
                   help='minimum total reads overlapping variant (default: 6,0:20:1)')
parser.add_argument('-D', '--min_high_bq', dest='min_high_bq', metavar='INT,INT:INT:INT', type=str, default='4,0:20:1',
                   help='minimum supporting reads with high BQ (>=30) (default: 4,0:20:1)')
parser.add_argument('-E', '--min_avg_mq', dest='min_avg_mq', metavar='FLOAT,FLOAT:FLOAT:FLOAT', type=str, default='10,0:44:1',
                   help='minimum average MQ of supporting reads (default: 10,0:44:1)')
parser.add_argument('-F', '--max_end_mismatch', dest='max_end_mismatch', metavar='FLOAT,FLOAT:FLOAT:FLOAT', type=str, default='0.21,0:1:0.02',
                   help='maximum mismatch rate between a position and a fragment end before clipping (default: 0.21,0:1:0.02)')
parser.add_argument('-G', '--min_indel_end_dist', dest='min_indel_end_dist', metavar='INT,INT:INT:INT', type=str, default='6,0:20:1',
                   help='minimum distance between indels and a fragment end (default: 6,0:20:1)')
parser.add_argument('-H', '--min_cov_percentile', dest='min_cov_percentile', metavar='FLOAT,FLOAT:FLOAT:FLOAT', type=str, default='0.007,0:0.2:0.01',
                   help='minimum sequencing coverage percentile of the variant position (default: 0.007,0:0.2:0.01)')
parser.add_argument('-I', '--max_poly_at', dest='max_poly_at', metavar='INT,INT:INT:INT', type=str, default='7,0:15:1',
                   help='maximum length of a poly-A/T repeat ending at the variant position (default: 7,0:15:1)')
parser.add_argument('-J', '--max_poly_rep', dest='max_poly_rep', metavar='INT,INT:INT:INT', type=str, default='4,0:10:1',
                   help='maximum length of a poly-di/trinucleotide repeat ending at the variant position (default: 4,0:10:1)')
parser.add_argument('-K', '--max_control_sup', dest='max_control_sup', metavar='INT,INT:INT:INT', type=str, default='3,0:40:1',
                   help='maximum number of control sample fragments supporting the variant. The END value of this argument '
                   'affects how much memory the script requires with lower values reducing memory footprint (default: 3,0:40:1)')
parser.add_argument('-L', '--max_frag_len', dest='max_frag_len', metavar='INT,INT:INT:INT', type=str, default='300,100:2000:100',
                   help='maximum fragment length (default: 300,100:2000:100)')
parser.add_argument('-M', '--max_contamination_dist', dest='max_contamination_dist', metavar='INT,INT:INT:INT', type=str, default='4,0:10:1',
                   help='maximum distance between two variants passing prior filters (A-L) within a single fragment '
                   '(default: 4,0:10:1)')
parser.add_argument('-N', '--min_frac_final', dest='min_frac_final', metavar='FLOAT,FLOAT:FLOAT:FLOAT', type=str, default='0.76,0:1:0.05',
                   help='final minimum fraction of supporting reads per strand (default: 0.76,0:1:0.05)')
parser.add_argument('-O', '--min_hamming_dist', dest='min_hamming_dist', metavar='INT,INT:INT:INT', type=str, default='0,0:10:1',
                   help='minimum hamming distance between k-mer starting at the fragment position and any other k-mer '
                   'in the genome (default: 0,0:10:1)')
parser.add_argument('-P', '--max_cov_percentile', dest='max_cov_percentile', metavar='FLOAT,FLOAT:FLOAT:FLOAT', type=str, default='1,0.8:1.02:0.01',
                   help='maximum sequencing coverage percentile of the variant position (default: 1, i.e. nothing filtered)')
parser.add_argument('-Q', '--require_overlap', dest='require_overlap', action='store_true',
                   help='whether to require support from both forward and reverse mapping reads, i.e. variant within '
                   'the read1-read2 overlap. Unlike the other filters, this one is just a flag which determines whether '
                   'to require overlap while varying the other filters (default: do not require)')
parser.add_argument('-k', '--keep_dups', dest='keep_dups', action='store_true',
                   help='by default, filtered mutations observed in multiple fragments within a TSV (-v or -V) will only count '
                   'as a single mutation. Setting this flag will instead count each observation of the mutation. Setting this flag '
                   'may be useful if TSVs were merged before being passed to -v or -V, and duplicates could represent real '
                   'mutations or false positives. Note: if the same mutation is found in three different TSVs, it will always count as '
                   'three mutations.')

# Parse the command line. When gtools reports that we are inside Jupyter, switch to the scratch data
# directory (if not already under /scratch) and parse a fixed set of test arguments instead of sys.argv.
if gtools.running_notebook():
    print('Determined code is running in Jupyter')
    if not os.getcwd().startswith('/scratch'): # switch to the scratch directory where all the data files are
        os.chdir(f'/scratch/cam02551/{os.getcwd().split("/")[-2]}')

    args = parser.parse_args('--real_vars data/variant/untreated_1_merged_added.tsv data/variant/untreated_2_merged_added.tsv data/variant/untreated_3_merged_added.tsv data/variant/untreated_4_merged_added.tsv data/variant/untreated_5_merged_added.tsv data/variant/untreated_6_merged_added.tsv data/variant/untreated_7_merged_added.tsv data/variant/untreated_8_merged_added.tsv --swapped_vars data/variant/duplex_swapped_0_added.tsv data/variant/duplex_swapped_1_added.tsv data/variant/duplex_swapped_2_added.tsv data/variant/duplex_swapped_3_added.tsv --real_cov data/coverage/untreated_1_merged_duplex/ data/coverage/untreated_2_merged_duplex/ data/coverage/untreated_3_merged_duplex/ data/coverage/untreated_4_merged_duplex/ data/coverage/untreated_5_merged_duplex/ data/coverage/untreated_6_merged_duplex/ data/coverage/untreated_7_merged_duplex/ data/coverage/untreated_8_merged_duplex/ --swapped_cov data/coverage/duplex_swapped_0_duplex/ data/coverage/duplex_swapped_1_duplex/ data/coverage/duplex_swapped_2_duplex/ data/coverage/duplex_swapped_3_duplex/ --output data/metadata/duplex_filters.svg'.split()) # used for testing
    # args = parser.parse_args('-v tmp/added_info_Col-0-1.tsv -V tmp/added_info_Col-0-swapped-1.tsv -w data/coverage/2f1r/big_Col-0-1_ -W data/coverage/2f1r/big_Col-0-swapped-1_ -o tmp/visualize_testing.png'.split()) # used for testing
else: # run this if in a terminal
    args = parser.parse_args()

sys.stderr.write('Running visualize_duplex_filters.py with arguments:\n' + '\n'.join([f'{key}={val}' for key, val in vars(args).items()]) + '\n')

# Create the output figure's directory if it has one; os.path.dirname handles the platform's
# separator and returns '' for a bare filename (written to the current directory).
output_dir = os.path.dirname(args.output)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# Human-readable axis label for every swept filter, keyed by the argparse dest name. The iteration
# order of this dict sets the order in which filters are parsed, swept and plotted.
param_to_desc = {
    'min_frac_init': 'initial min fraction of support per strand (A)',
    'min_strand_cov': 'min read coverage per strand (B)',
    'min_total_cov': 'min total read coverage (C)',
    'min_high_bq': 'min support from high BQ (>30) reads (D)',
    'min_avg_mq': 'min average MQ of supporting reads (E)',
    'max_end_mismatch': 'max mismatch rate before an end is trimmed (F)',
    'min_indel_end_dist': 'min distance from an indel to fragment end (G)',
    'min_cov_percentile': 'min coverage percentile of genomic site (H)',
    'max_poly_at': 'max length of poly-A/T repeat (I)',
    'max_poly_rep': 'max length of di/trinucleotide repeat (J)',
    'max_control_sup': 'max (non-sequencing error) occurrences in control samples (K)',
    'max_frag_len': 'max fragment length (L)',
    'max_contamination_dist': 'max distance between any 2 variants in a fragment (M)',
    'min_frac_final': 'final min fraction of support per strand (N)',
    'min_hamming_dist': 'min hamming distance to another k-mer in the genome (O)',
    'max_cov_percentile': 'max coverage percentile of genomic site (P)',
    'require_overlap': 'require (1) or don\'t require (0) read1-read2 overlap (Q)',
}

In [None]:
# Parse every filter argument of the form DEFAULT,START:END:STEP into
#   defaults[fil] -> the DEFAULT cutoff (drawn as the "selected cutoff" line), and
#   sliders[fil]  -> the cutoffs START, START+STEP, ... strictly below END (START is always included).
# If any of the four numbers contains a '.', all four are parsed as floats, otherwise as ints.
# require_overlap is a flag, so it is swept over [False, True] with the flag's value as the default.
sliders = dict() # key is filter, value is list of cutoffs
defaults = dict()
vargs = vars(args)
arg_pattern = re.compile(r'([0-9.]+),([0-9.]+):([0-9.]+):([0-9.]+)')


def _sweep_cutoffs(start, end, step):
    """Return [start, start + step, start + 2*step, ...] for all values below end; start is always kept.

    Values are computed as start + i * step (rounded to 10 decimals for floats) instead of by repeated
    addition, so float steps such as 0.01 do not accumulate error and add a value just below END.
    """
    cutoffs = [start]
    i = 1
    while True:
        value = start + i * step
        if isinstance(value, float):
            value = round(value, 10)
        if value >= end:
            return cutoffs
        cutoffs.append(value)
        i += 1


for fil in param_to_desc:
    # treat require_overlap filter differently
    if fil == 'require_overlap':
        sliders[fil] = [False, True]
        defaults[fil] = args.require_overlap
        continue

    spec = vargs[fil]
    match = arg_pattern.fullmatch(spec)
    groups = None
    if match is not None:
        to_number = float if '.' in spec else int
        try:
            groups = [to_number(g) for g in match.groups()]
        except ValueError: # e.g. '1.2.3' matches [0-9.]+ but is not a valid number
            groups = None
    # a malformed spec used to crash on None.groups(); a non-positive STEP would loop forever
    if groups is None or groups[3] <= 0:
        sys.stderr.write(f'ERROR: filter argument for --{fil} must be in format "DEFAULT,START:END:STEP" with STEP > 0, got "{spec}"\n')
        sys.exit(1)

    default, start, end, step = groups
    defaults[fil] = default
    sliders[fil] = _sweep_cutoffs(start, end, step)


In [None]:
# Load each real-sample variant TSV in chunks of 100000 rows, keeping only variants whose control
# support is at or below the largest swept K cutoff (limits memory). Each kept frame is sorted by
# fragment, then position/allele, because passes_m expects fragment-sorted input.
dfs_real = []
for tsv_path in args.real_vars:
    control_sup_ceiling = max(sliders['max_control_sup'])
    n_loaded = 0
    kept_chunks = []
    reader = pd.read_table(tsv_path, quoting=csv.QUOTE_NONE, chunksize=100000, dtype={'mq':str, 'bq':str})
    for chunk in reader:
        n_loaded += len(chunk)
        kept_chunks.append(chunk.loc[chunk.control_sup <= control_sup_ceiling].copy())
    sort_keys = ['chrom', 'frag_start', 'frag_len', 'frag_umi', 'pos', 'ref', 'alt']
    loaded = pd.concat(kept_chunks).sort_values(sort_keys)
    dfs_real.append(loaded)
    sys.stderr.write(f'Loaded {n_loaded} variants from real file {tsv_path} and kept {len(loaded)} after pre-filtering on control_sup <= {max(sliders["max_control_sup"])}\n')

In [None]:
# Load each swapped-sample variant TSV the same way as the real ones, except that 2 is subtracted
# from control_sup first so a fragment does not contribute to its own control support.
dfs_swapped = []
for tsv_path in args.swapped_vars:
    control_sup_ceiling = max(sliders['max_control_sup'])
    n_loaded = 0
    kept_chunks = []
    reader = pd.read_table(tsv_path, quoting=csv.QUOTE_NONE, chunksize=100000, dtype={'mq':str, 'bq':str})
    for chunk in reader:
        n_loaded += len(chunk)
        chunk['control_sup'] = chunk['control_sup'] - 2 # don't let the fragment contribute to its own control support
        kept_chunks.append(chunk.loc[chunk.control_sup <= control_sup_ceiling].copy())
    sort_keys = ['chrom', 'frag_start', 'frag_len', 'frag_umi', 'pos', 'ref', 'alt']
    loaded = pd.concat(kept_chunks).sort_values(sort_keys)
    dfs_swapped.append(loaded)
    sys.stderr.write(f'Loaded {n_loaded} variants from swapped file {tsv_path} and kept {len(loaded)} after pre-filtering on control_sup <= {max(sliders["max_control_sup"])}\n')

In [None]:
# Total callable coverage (sum of each contig's coverage array) of the real and of the swapped samples,
# used to normalise variant counts into per-bp rates. Contigs are the union over ALL loaded variant
# frames (previously only the first real TSV), so coverage is counted for every contig whose variants
# are counted. Names are normalised to str because pandas may infer a different chrom dtype per file.
chroms = set()
for frame in dfs_real + dfs_swapped:
    chroms.update(frame.chrom.astype(str))
sys.stderr.write(f'loading coverage of contigs {sorted(chroms)}\n')
real_cov = 0
swapped_cov = 0
for chrom in sorted(chroms):
    for prefix in args.real_cov:
        real_cov += np.sum(np.load(f'{prefix}{chrom}.npy'))
    for prefix in args.swapped_cov:
        swapped_cov += np.sum(np.load(f'{prefix}{chrom}.npy'))
sys.stderr.write(f'coverage of real samples is {real_cov}bp, coverage of swapped samples is {swapped_cov}bp\n')

In [None]:
# this is all copied from filter_duplex_variants, but now uses a passed dictionary, copies the df, and doesn't sort

def passes_g(df, filter_cutoffs):
    """Return one bool per row of ``df``: whether the variant passes filter G.

    Variants with neither ref nor alt equal to '*' always pass. For the others, the variant must lie at
    least ``filter_cutoffs['min_indel_end_dist']`` bp from the fragment ends:
      * alt == '*' (deletion) in the second half of the fragment: only the distance to the fragment
        end is checked, as frag_start + frag_len - pos - len(alt);
      * otherwise: both pos - frag_start and frag_start + frag_len - pos - 1 are checked.

    NOTE(review): on the deletion branch alt is '*', so len(alt) is always 1; if the deletion length
    was intended (e.g. len(ref)), confirm against filter_duplex_variants.py.
    """
    def _row_passes(row):
        if row.ref != '*' and row.alt != '*':
            return True
        frag_end = row.frag_start + row.frag_len
        in_second_half = row.pos > row.frag_start + row.frag_len // 2
        if row.alt == '*' and in_second_half: # if a deletion nearer the end of the fragment
            return frag_end - row.pos - len(row.alt) >= filter_cutoffs['min_indel_end_dist'] # count distance from the last base of the deletion
        start_ok = row.pos - row.frag_start >= filter_cutoffs['min_indel_end_dist']
        return start_ok and frag_end - row.pos - 1 >= filter_cutoffs['min_indel_end_dist']

    return [_row_passes(row) for row in df.itertuples()]

def passes_m(df, filter_cutoffs):
    """Return one bool per row of ``df``: whether the variant passes filter M.

    A fragment is identified by (chrom, frag_start, frag_len, frag_umi). All variants of a fragment fail
    together when last pos - first pos within the fragment exceeds
    ``filter_cutoffs['max_contamination_dist']``: max_dist=0 never allows more than one variant in a
    fragment, max_dist=1 allows a pair of back-to-back variants.

    The input must be sorted by chrom, frag_start, frag_len, frag_umi, pos: fragments are detected as
    consecutive runs of rows with the same key. ``df`` is not modified; missing UMIs are mapped to ''
    only for this grouping (NaN != NaN would otherwise split every UMI-less fragment into single rows).
    """
    max_dist = filter_cutoffs['max_contamination_dist']
    umis = df.frag_umi.fillna('')
    frag_keys = zip(df.chrom, df.frag_start, df.frag_len, umis)

    keep = []
    run_positions = [] # positions of the fragment currently being scanned
    prev_key = None
    for key, pos in zip(frag_keys, df.pos):
        if run_positions and key != prev_key:
            keep.extend([run_positions[-1] - run_positions[0] <= max_dist] * len(run_positions))
            run_positions = []
        prev_key = key
        run_positions.append(pos)
    if run_positions: # flush the final fragment
        keep.extend([run_positions[-1] - run_positions[0] <= max_dist] * len(run_positions))
    return keep

def filter_vars(df, filter_cutoffs):
    """Return the rows of ``df`` that pass every filter A-Q at the cutoffs in ``filter_cutoffs``.

    ``filter_cutoffs`` maps argparse dest names (e.g. 'min_frac_init', 'require_overlap') to cutoffs.
    ``df`` is not modified. It must already be sorted by fragment (chrom, frag_start, frag_len,
    frag_umi, pos), because filter M is evaluated on consecutive rows. Boolean columns A-Q recording
    each filter are added to the returned frame.
    """
    df = df.copy()
    # per-strand fraction of supporting reads; shared by the initial (A) and final (N) thresholds
    f_frac = df.f_read1_sup / df.f_read1_cov
    r_frac = df.r_read1_sup / df.r_read1_cov

    # add columns for whether each filter is passed
    df['A'] = (f_frac >= filter_cutoffs['min_frac_init']) & (r_frac >= filter_cutoffs['min_frac_init'])
    df['B'] = (df.f_read1_cov >= filter_cutoffs['min_strand_cov']) & (df.r_read1_cov >= filter_cutoffs['min_strand_cov'])
    df['C'] = (df.f_read1_cov + df.r_read1_cov >= filter_cutoffs['min_total_cov'])
    df['D'] = (df.high_sup >= filter_cutoffs['min_high_bq'])
    df['E'] = (df.avg_mq >= filter_cutoffs['min_avg_mq'])
    df['F'] = (df.end_mismatch_rate <= filter_cutoffs['max_end_mismatch'])
    # G is a row-wise Python loop, so it is only computed on rows surviving the vectorized filters
    df['H'] = (df.cov_per >= filter_cutoffs['min_cov_percentile'])
    df['I'] = (df.poly_at <= filter_cutoffs['max_poly_at'])
    df['J'] = (df.poly_rep <= filter_cutoffs['max_poly_rep'])
    df['K'] = (df.control_sup <= filter_cutoffs['max_control_sup'])
    df['L'] = (df.frag_len <= filter_cutoffs['max_frag_len'])
    # M depends on which variants survive A-L and G, so it is computed last
    df['N'] = (f_frac >= filter_cutoffs['min_frac_final']) & (r_frac >= filter_cutoffs['min_frac_final'])
    df['O'] = (df.dup_dist >= filter_cutoffs['min_hamming_dist'])
    df['P'] = (df.cov_per <= filter_cutoffs['max_cov_percentile'])
    df['Q'] = ((df.f_sup > 0) & (df.r_sup > 0)) if filter_cutoffs['require_overlap'] else True

    # .copy() after each row selection so the G/M column assignments below act on an independent
    # frame rather than a slice (avoids pandas' chained-assignment / SettingWithCopy ambiguity)
    df = df[df.A & df.B & df.C & df.D & df.E & df.F & df.H & df.I & df.J & df.K & df.L].copy()
    df['G'] = passes_g(df, filter_cutoffs)
    if len(df) > 0:
        df = df[df.G].copy() # df would lose its columns if this were run on an empty df

    # filter M; no re-sort needed because the input order is preserved by the selections above
    df['M'] = passes_m(df, filter_cutoffs)
    if len(df) > 0:
        df = df[df.M & df.N & df.O & df.P & df.Q]
    return df

In [None]:
# For each filter, sweep its cutoffs while holding every other filter at its default, and record per
# cutoff: passing variants in real samples, passing variants in swapped samples, and passing
# C>T/G>A variants in real samples.
real_counts = {fil:[] for fil in sliders}
swapped_counts = {fil:[] for fil in sliders}
ct_counts = {fil:[] for fil in sliders}


def _count_passing(frames, cutoffs):
    """Filter each frame with ``cutoffs``; return (passing variants, passing C>T or G>A variants).

    Unless --keep_dups was given, repeated chrom/pos/ref/alt within one frame count once; counts are
    then summed across frames, so the same mutation in different TSVs counts once per TSV.
    """
    n_pass = 0
    n_ct = 0
    for frame in frames:
        passed = filter_vars(frame, cutoffs)
        if not args.keep_dups:
            passed = passed.drop_duplicates(['chrom', 'pos', 'ref', 'alt'])
        is_ct = ((passed.ref == 'C') & (passed.alt == 'T')) | ((passed.ref == 'G') & (passed.alt == 'A'))
        n_pass += len(passed)
        n_ct += int(is_ct.sum()) # vectorized; builtin sum() iterated the Series in Python
    return n_pass, n_ct


for sliding_fil, cutoff_values in sliders.items():
    for cutoff in tqdm(cutoff_values, position=0, desc='varying cutoff for ' + sliding_fil):
        adjusted_cutoffs = {**defaults, sliding_fil: cutoff}
        real_count, ct_count = _count_passing(dfs_real, adjusted_cutoffs)
        swapped_count, _ = _count_passing(dfs_swapped, adjusted_cutoffs)
        real_counts[sliding_fil].append(real_count)
        swapped_counts[sliding_fil].append(swapped_count)
        ct_counts[sliding_fil].append(ct_count)


In [None]:
# One panel per filter. Left axis (shared, capped at 0.2): estimated FP rate and C>T depletion.
# Right axis (per panel): estimated true positives. Black vertical line: the default cutoff.
n_cols = 3
# squeeze=False keeps axs 2-D even if there are <= n_cols filters
fig, axs = plt.subplots(math.ceil(len(sliders) / n_cols), n_cols, figsize=(20, 15), gridspec_kw={'hspace':0.33}, sharey=True, squeeze=False)
cmap = plt.get_cmap('tab20')
for i, fil in enumerate(sliders):
    tp_rates = []
    fp_rates = []
    ct_rates = []
    for n_real, n_swapped, n_ct in zip(real_counts[fil], swapped_counts[fil], ct_counts[fil]):
        if n_real > 0:
            # FP rate = per-bp rate of passing calls in swapped samples / per-bp rate in real samples
            fp_rate = (n_swapped / swapped_cov) / (n_real / real_cov)
            ct_rate = n_ct / n_real
        else:
            fp_rate = np.nan
            ct_rate = np.nan
        # estimated TPs = passing real calls * (1 - FP rate), floored at 0 (also 0 when fp_rate is NaN)
        tp_rates.append(max(0, n_real * (1 - fp_rate)))
        fp_rates.append(fp_rate)
        ct_rates.append(ct_rate)

    # C>T depletion relative to the highest C>T rate in this panel. NaNs are skipped when taking the
    # max: builtin max() returns NaN whenever the first element is NaN, blanking the whole line.
    valid_ct = [x for x in ct_rates if not np.isnan(x)]
    max_ct = max(valid_ct) if valid_ct else np.nan
    ct_depletion = [max_ct - x for x in ct_rates]

    # cap FP rate and C>T depletion at 0.2
    fp_rates = [min(0.2, x) if not np.isnan(x) else np.nan for x in fp_rates]
    ct_depletion = [min(0.2, x) if not np.isnan(x) else np.nan for x in ct_depletion]

    ax = axs[i // n_cols, i % n_cols]
    ax_count = ax.twinx()
    fp_line, = ax.plot(sliders[fil], fp_rates, c=cmap(6))
    ct_line, = ax.plot(sliders[fil], ct_depletion, c=cmap(7))
    tp_line, = ax_count.plot(sliders[fil], tp_rates, c=cmap(0))
    ax.set_xlabel(param_to_desc[fil])
    default_line = ax.vlines(defaults[fil], ymin=-1, ymax=1, color='black')

# hide grid cells left empty when the number of filters is not a multiple of n_cols
for unused_ax in axs.flat[len(sliders):]:
    unused_ax.set_visible(False)

axs[0, 0].set_ylim(0, 0.205)
fig.legend([fp_line, ct_line, tp_line, default_line], ['FP rate (est swapped samples)', 'C>T rate below max (max rate within plot - rate at threshold)', 'TPs detected (est swapped samples)', 'selected cutoff'], loc='upper center')
fig.savefig(args.output, dpi=300, bbox_inches='tight')

In [None]:
# Report completion on stderr, where the script writes all of its progress messages.
completion_msg = f'completed visualize_duplex_filters.py with output {args.output}\n'
sys.stderr.write(completion_msg)