Skip to content

Commit

Permalink
Merge f9f7e7d into f732d3d
Browse files Browse the repository at this point in the history
  • Loading branch information
ernfrid committed Mar 9, 2018
2 parents f732d3d + f9f7e7d commit 61438be
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 45 deletions.
151 changes: 109 additions & 42 deletions svtools/sv_classifier.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
#!/usr/bin/env python

import argparse, sys, copy, gzip, time, math, re
import argparse, sys, copy, gzip, math
import numpy as np
import pandas as pd
from scipy import stats
from collections import Counter, defaultdict, namedtuple
from collections import namedtuple
import statsmodels.formula.api as smf
from operator import itemgetter
import warnings
from svtools.vcf.file import Vcf
from svtools.vcf.genotype import Genotype
from svtools.vcf.variant import Variant
import svtools.utils as su

Expand All @@ -22,7 +21,7 @@
def mad(arr):
""" Median Absolute Deviation: a "Robust" version of standard deviation.
Indices variabililty of the sample.
https://en.wikipedia.org/wiki/Median_absolute_deviation
https://en.wikipedia.org/wiki/Median_absolute_deviation
"""
arr = np.ma.array(arr).compressed() # should be faster to not use masked arrays.
med = np.median(arr)
Expand Down Expand Up @@ -125,7 +124,7 @@ def annotation_intersect(var, ae_dict, threshold):

# dictionary with number of bases of overlap for each class
class_overlap = {}

# first check for reciprocal overlap
if var.chrom in ae_dict:
var_start = var.pos
Expand Down Expand Up @@ -170,12 +169,12 @@ def lld(xx, mean, sd):
ll = 1 / sd * math.exp(-(xx-mean) * (xx-mean) / (2*sd*sd))
return ll

def calc_params(vcf_path):
def calc_params(vcf_path, sex_chrom_names):

tSet = list()
epsilon=0.1
header=[]

in_header = True
vcf = Vcf()
if vcf_path.endswith('.gz'):
Expand All @@ -201,17 +200,28 @@ def calc_params(vcf_path):
svtype = x.split('=')[1]
break

if svtype not in ['DEL', 'DUP'] or v[0]=="X" or v[0]=="Y":
if svtype not in ['DEL', 'DUP'] or v[0] in sex_chrom_names:
continue

var = Variant(v, vcf)

for sample in vcf_samples:
sample_genotype = var.genotype(sample)
if sample_genotype.get_format('GT') != './.':
log2r = math.log((float(sample_genotype.get_format('CN'))+ epsilon)/2,2) #to avoid log(0)
tSet.append(CN_rec(var.var_id, sample, var.info['SVTYPE'], abs(float(var.info['SVLEN'])), var.info['AF'],
sample_genotype.get_format('GT'), sample_genotype.get_format('CN'), sample_genotype.get_format('AB'), math.log(abs(float(var.info['SVLEN']))), log2r))
tSet.append(
CN_rec(
var.var_id,
sample,
var.info['SVTYPE'],
abs(float(var.info['SVLEN'])),
var.info['AF'],
sample_genotype.get_format('GT'),
sample_genotype.get_format('CN'),
sample_genotype.get_format('AB'),
math.log(abs(float(var.info['SVLEN']))), log2r
)
)

df=pd.DataFrame(tSet, columns=CN_rec._fields)
#exclude from training data, DELs and DUPs with CN in the tails of the distribution
Expand Down Expand Up @@ -241,9 +251,9 @@ def calc_params(vcf_path):
df1.loc[:,'log2r_adj']=df1.loc[:,'log2r']
df1=df1.append(small_dels)
params=df1.groupby(['sample', 'svtype', 'GT'])['log2r_adj'].aggregate([np.mean,np.var, len]).reset_index()
params=pd.pivot_table(params, index=['sample', 'svtype'], columns='GT', values=['mean', 'var', 'len']).reset_index()
params=pd.pivot_table(params, index=['sample', 'svtype'], columns='GT', values=['mean', 'var', 'len']).reset_index()
params.columns=['sample', 'svtype', 'mean0', 'mean1', 'mean2', 'var0', 'var1', 'var2', 'len0', 'len1', 'len2']
params['std_pooled']=np.sqrt((params['var0']*params['len0']+params['var1']*params['len1']+params['var2']*params['len2'])/(params['len0']+params['len1']+params['len2']))
params['std_pooled'] = np.sqrt((params['var0']*params['len0']+params['var1']*params['len1']+params['var2']*params['len2'])/(params['len0']+params['len1']+params['len2']))
#params.to_csv('./params.csv')
return (params, het_del_fit, hom_del_fit)

Expand All @@ -253,13 +263,13 @@ def rd_support_nb(temp, p_cnv):
temp = pd.merge(temp, tr, on='GT', how='left')
temp['p_mix'] = temp['lld0'] * temp['p0'] + temp['lld1'] * temp['p1'] + temp['lld2'] * temp['p2']
return np.log(p_cnv)+np.sum(np.log(temp['p_mix'])) > np.log(1-p_cnv)+np.sum(np.log(temp['lld0']))


def has_rd_support_by_nb(test_set, het_del_fit, hom_del_fit, params, p_cnv = 0.5):
svtype=test_set['svtype'][0]
svlen=test_set['svlen'][0]
log_len=test_set['log_len'][0]

if svtype == 'DEL' and svlen<1000:
params1=params[params.svtype=='DEL'].copy()
if svlen<50:
Expand Down Expand Up @@ -306,21 +316,21 @@ def has_rd_support_by_nb(test_set, het_del_fit, hom_del_fit, params, p_cnv = 0.5
mm.loc[:,'lld0'] = mm.apply(lambda row:lld(row["log2r"], row["mean0"],row["std_pooled"]), axis=1)
mm.loc[:,'lld1'] = mm.apply(lambda row:lld(row["log2r"], row["mean1_adj"],row["std_pooled"]), axis=1)
mm.loc[:,'lld2'] = mm.apply(lambda row:lld(row["log2r"], row["mean2_adj"],row["std_pooled"]), axis=1)

return rd_support_nb(mm, p_cnv)


def load_df(var, exclude, sex):
epsilon=0.1
def load_df(var, exclude, sex, sex_chrom_names):

epsilon = 0.1
test_set = list()

for s in var.sample_list:
if s in exclude:
continue
cn = var.genotype(s).get_format('CN')
if (var.chrom == 'X' or var.chrom == 'Y') and sex[s] == 1:
cn=str(float(cn)*2)
if (var.chrom in sex_chrom_names) and sex[s] == 1:
cn = str(float(cn) * 2)
log2r = math.log((float(cn)+epsilon)/2, 2) # to avoid log(0)
test_set.append(CN_rec(var.var_id, s, var.info['SVTYPE'], abs(float(var.info['SVLEN'])), var.info['AF'],
var.genotype(s).get_format('GT'), cn , var.genotype(s).get_format('AB'), math.log(abs(float(var.info['SVLEN']))), log2r))
Expand All @@ -332,7 +342,7 @@ def load_df(var, exclude, sex):
def has_low_freq_depth_support(test_set, mad_threshold=2, absolute_cn_diff=0.5):

mad_quorum = 0.5 # this fraction of the pos. genotyped results must meet the mad_threshold

hom_ref_cn=test_set[test_set.GT=="0/0"]['CN'].values.astype(float)
hom_het_alt_cn=test_set[(test_set.GT=="0/1") | (test_set.GT=="1/1")]['CN'].values.astype(float)

Expand All @@ -353,7 +363,7 @@ def has_low_freq_depth_support(test_set, mad_threshold=2, absolute_cn_diff=0.5):
#if test_set['svtype'][0]=='DEL':
if test_set.loc[0, 'svtype']=='DEL':
resid=-resid

resid=resid[(resid > (cn_mad * mad_threshold) ) & (resid>absolute_cn_diff)]

if float(len(resid))/len(hom_het_alt_cn)>mad_quorum:
Expand All @@ -363,10 +373,10 @@ def has_low_freq_depth_support(test_set, mad_threshold=2, absolute_cn_diff=0.5):

# test whether variant has read depth support by regression
def has_high_freq_depth_support(df, slope_threshold, rsquared_threshold):

rd = df[[ 'AB', 'CN']][df['AB']!='.'].values.astype(float)
if len(np.unique(rd[:,0])) > 1 and len(np.unique(rd[:,1])) > 1:

(slope, intercept, r_value, p_value, std_err) = stats.linregress(rd)
if df['svtype'][0] == 'DEL':
slope=-slope
Expand All @@ -387,7 +397,7 @@ def has_rd_support_by_ls(df, slope_threshold, rsquared_threshold, num_pos_samps,
return False

def has_rd_support_hybrid(df, het_del_fit, hom_del_fit, params, p_cnv, slope_threshold, rsquared_threshold, num_pos_samps):

hybrid_support=False
nb_support=has_rd_support_by_nb(df, het_del_fit, hom_del_fit, params, p_cnv)
ls_support=has_rd_support_by_ls(df, slope_threshold, rsquared_threshold, num_pos_samps)
Expand All @@ -402,7 +412,7 @@ def has_rd_support_hybrid(df, het_del_fit, hom_del_fit, params, p_cnv, slope_thr


# primary function
def sv_classify(vcf_in, vcf_out, gender_file, exclude_file, ae_dict, f_overlap, slope_threshold, rsquared_threshold, p_cnv, het_del_fit, hom_del_fit, params, diag_outfile, method):
def sv_classify(vcf_in, vcf_out, gender_file, sex_chrom_names, exclude_file, ae_dict, f_overlap, slope_threshold, rsquared_threshold, p_cnv, het_del_fit, hom_del_fit, params, diag_outfile, method):

vcf = Vcf()
header = []
Expand Down Expand Up @@ -443,7 +453,7 @@ def sv_classify(vcf_in, vcf_out, gender_file, exclude_file, ae_dict, f_overlap,
if svtype not in ['DEL', 'DUP']:
vcf_out.write(line)
continue

var = Variant(v, vcf)

# check intersection with mobile elements
Expand Down Expand Up @@ -473,15 +483,23 @@ def sv_classify(vcf_in, vcf_out, gender_file, exclude_file, ae_dict, f_overlap,
if num_pos_samps == 0:
vcf_out.write(line)
else:
df=load_df(var, exclude, sex)
if method=='large_sample':
df = load_df(var, exclude, sex, sex_chrom_names)
if method == 'large_sample':
ls_support = has_rd_support_by_ls(df, slope_threshold, rsquared_threshold, num_pos_samps)
has_rd_support=ls_support
elif method=='naive_bayes':
has_rd_support = ls_support
elif method == 'naive_bayes':
nb_support = has_rd_support_by_nb(df, het_del_fit, hom_del_fit, params, p_cnv)
has_rd_support=nb_support
elif method=='hybrid':
ls_support, nb_support, hybrid_support = has_rd_support_hybrid(df, het_del_fit, hom_del_fit, params, p_cnv, slope_threshold, rsquared_threshold, num_pos_samps)
has_rd_support = nb_support
elif method == 'hybrid':
ls_support, nb_support, hybrid_support = has_rd_support_hybrid(
df,
het_del_fit,
hom_del_fit,
params, p_cnv,
slope_threshold,
rsquared_threshold,
num_pos_samps
)
has_rd_support=hybrid_support

if has_rd_support:
Expand All @@ -492,7 +510,18 @@ def sv_classify(vcf_in, vcf_out, gender_file, exclude_file, ae_dict, f_overlap,

if diag_outfile is not None:
svlen=df['svlen'][0]
outf.write(var.var_id+"\t"+svtype+"\t"+str(svlen)+"\t"+str(num_pos_samps)+"\t"+str(nb_support)+"\t"+str(ls_support)+"\t"+str(hybrid_support)+"\t"+str(has_rd_support)+"\n")
outf.write(
'\t'.join((
var.var_id,
svtype,
str(svlen),
str(num_pos_samps),
str(nb_support),
str(ls_support),
str(hybrid_support),
str(has_rd_support)
)) + "\n"
)

vcf_out.close()
if diag_outfile is not None:
Expand Down Expand Up @@ -525,7 +554,20 @@ def get_ae_dict(ae_path):
ae_bedfile.close()
return ae_dict

def run_reclassifier(vcf_file, vcf_out, sex_file, ae_path, f_overlap, exclude_list, slope_threshold, rsquared_threshold, training_data, method, diag_outfile):
def run_reclassifier(
vcf_file,
vcf_out,
sex_file,
sex_chrom_names,
ae_path,
f_overlap,
exclude_list,
slope_threshold,
rsquared_threshold,
training_data,
method,
diag_outfile
):

ae_dict = None
params = None
Expand All @@ -536,16 +578,17 @@ def run_reclassifier(vcf_file, vcf_out, sex_file, ae_path, f_overlap, exclude_li
if ae_path is not None:
sys.stderr.write("loading annotations\n")
ae_dict=get_ae_dict(ae_path)

if(method!="large_sample"):
sys.stderr.write("calculating parameters\n")
#calculate per-sample CN profiles on training set
[params, het_del_fit, hom_del_fit]=calc_params(training_data)
[params, het_del_fit, hom_del_fit] = calc_params(training_data, sex_chrom_names)

sys.stderr.write("reclassifying\n")
sv_classify(vcf_file,
vcf_out,
sex_file,
sex_chrom_names,
exclude_list,
ae_dict,
f_overlap,
Expand All @@ -554,13 +597,19 @@ def run_reclassifier(vcf_file, vcf_out, sex_file, ae_path, f_overlap, exclude_li
p_cnv,
het_del_fit,
hom_del_fit,
params,
params,
diag_outfile,
method)

def chromosome_prefix(chrom):
if chrom.startswith('chr'):
# Add a non-chr version
return chrom[3:]
else:
return 'chr' + chrom

def add_arguments_to_parser(parser):
parser.add_argument('-i', '--input', metavar='<VCF>', default=None, help='VCF input')
#parser.add_argument('-i', '--input', metavar='<STRING>', dest='vcf_in', type=argparse.FileType('r'), default=None, help='VCF input [stdin]')
parser.add_argument('-o', '--output', metavar='<VCF>', dest='vcf_out', type=argparse.FileType('w'), default=sys.stdout, help='VCF output [stdout]')
parser.add_argument('-g', '--gender', metavar='<FILE>', dest='gender', type=argparse.FileType('r'), required=True, default=None, help='tab delimited file of sample genders (male=1, female=2)\nex: SAMPLE_A\t2')
parser.add_argument('-a', '--annotation', metavar='<BED>', dest='ae_path', type=str, default=None, help='BED file of annotated elements')
Expand All @@ -571,6 +620,7 @@ def add_arguments_to_parser(parser):
parser.add_argument('-t', '--tSet', metavar='<STRING>', dest='tSet', type=str, default=None, required=False, help='high quality deletions & duplications training dataset[vcf], required by naive Bayes reclassification')
parser.add_argument('-m', '--method', metavar='<STRING>', dest='method', type=str, default="large_sample", required=False, help='reclassification method, one of (large_sample, naive_bayes, hybrid)', choices=['large_sample', 'naive_bayes', 'hybrid'])
parser.add_argument('-d', '--diag_file', metavar='<STRING>', dest='diag_outfile', type=str, default=None, required=False, help='text file to output method comparisons')
parser.add_argument('--sex-chrom', metavar='<STRING>', default='chrX,chrY', help='Comma-separated list of sex chromosome names [chrX,chrY]')
parser.set_defaults(entry_point=run_from_args)

def description():
Expand All @@ -589,7 +639,24 @@ def run_from_args(args):
parser.print_help()
sys.exit(1)
with su.InputStream(args.input) as stream:
run_reclassifier(stream, args.vcf_out, args.gender, args.ae_path, args.f_overlap, args.exclude, args.slope_threshold, args.rsquared_threshold, args.tSet, args.method, args.diag_outfile)
sex_chrom_names = set(args.sex_chrom.strip().split(','))
for chrom in sex_chrom_names:
sex_chrom_names.add(chromosome_prefix(chrom))
sys.stderr.write('sex chromosome names are: {0}\n'.format(str(sex_chrom_names)))
run_reclassifier(
stream,
args.vcf_out,
args.gender,
sex_chrom_names,
args.ae_path,
args.f_overlap,
args.exclude,
args.slope_threshold,
args.rsquared_threshold,
args.tSet,
args.method,
args.diag_outfile
)

if __name__ == '__main__':
parser = command_parser()
Expand Down
13 changes: 10 additions & 3 deletions tests/reclassifier_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@

class IntegrationTest_sv_classify(TestCase):

def test_chromosome_prefix(self):
self.assertEqual(svtools.sv_classifier.chromosome_prefix('chrBLAH'), 'BLAH')
self.assertEqual(svtools.sv_classifier.chromosome_prefix('BLAH'), 'chrBLAH')

def test_integration_nb(self):
test_directory = os.path.dirname(os.path.abspath(__file__))
test_data_dir = os.path.join(test_directory, 'test_data', 'sv_classifier')
Expand All @@ -22,8 +26,9 @@ def test_integration_nb(self):
diags_handle, diags_file = tempfile.mkstemp(suffix='.txt')
temp_descriptor, temp_output_path = tempfile.mkstemp(suffix='.vcf')
sex=open(sex_file, 'r')
sex_chrom_names = set(('X', 'Y'))
with gzip.open(input, 'rb') as input_handle, os.fdopen(temp_descriptor, 'w') as output_handle:
svtools.sv_classifier.run_reclassifier(input_handle, output_handle , sex , annot, 0.9, None, 1.0, 0.2, train, 'naive_bayes', diags_file)
svtools.sv_classifier.run_reclassifier(input_handle, output_handle, sex, sex_chrom_names, annot, 0.9, None, 1.0, 0.2, train, 'naive_bayes', diags_file)
expected_lines = gzip.open(expected_result, 'rb').readlines()
expected_lines[1] = '##fileDate=' + time.strftime('%Y%m%d') + '\n'
produced_lines = open(temp_output_path).readlines()
Expand All @@ -50,8 +55,9 @@ def test_integration_ls(self):
diags_handle, diags_file = tempfile.mkstemp(suffix='.txt')
temp_descriptor, temp_output_path = tempfile.mkstemp(suffix='.vcf')
sex=open(sex_file, 'r')
sex_chrom_names = set(('X', 'Y'))
with gzip.open(input, 'rb') as input_handle, os.fdopen(temp_descriptor, 'w') as output_handle:
svtools.sv_classifier.run_reclassifier(input_handle, output_handle , sex , annot, 0.9, None, 1.0, 0.2, train, 'large_sample', diags_file)
svtools.sv_classifier.run_reclassifier(input_handle, output_handle, sex, sex_chrom_names, annot, 0.9, None, 1.0, 0.2, train, 'large_sample', diags_file)
expected_lines = gzip.open(expected_result, 'rb').readlines()
expected_lines[1] = '##fileDate=' + time.strftime('%Y%m%d') + '\n'
produced_lines = open(temp_output_path).readlines()
Expand Down Expand Up @@ -79,8 +85,9 @@ def test_integration_hyb(self):
diags_handle, diags_file = tempfile.mkstemp(suffix='.txt')
temp_descriptor, temp_output_path = tempfile.mkstemp(suffix='.vcf')
sex=open(sex_file, 'r')
sex_chrom_names = set(('X', 'Y'))
with gzip.open(input, 'rb') as input_handle, os.fdopen(temp_descriptor, 'w') as output_handle:
svtools.sv_classifier.run_reclassifier(input_handle, output_handle , sex , annot, 0.9, None, 1.0, 0.2, train, 'hybrid', diags_file)
svtools.sv_classifier.run_reclassifier(input_handle, output_handle, sex, sex_chrom_names, annot, 0.9, None, 1.0, 0.2, train, 'hybrid', diags_file)
expected_lines = gzip.open(expected_result, 'rb').readlines()
expected_lines[1] = '##fileDate=' + time.strftime('%Y%m%d') + '\n'
produced_lines = open(temp_output_path).readlines()
Expand Down

0 comments on commit 61438be

Please sign in to comment.