Skip to content

Commit

Permalink
[vcf] speed up writing
Browse files Browse the repository at this point in the history
The 2 main changes here are changing the const sites list to a set to
improve lookup time and pre-computing the positions at each variable
site in a more efficient manner.

For my test VCF (100k rows, 100 samples) `write_vcf` went from 95
seconds to 0.7 seconds. For the full 800k rows VCF, `write_vcf` went
from ~8 hours to 5.6 seconds (5000x speedup).

Times & profiling using cProfile

CLoses nextstrain/augur#1378
  • Loading branch information
jameshadfield committed Dec 29, 2023
1 parent 1aae85e commit 8322d54
Showing 1 changed file with 23 additions and 28 deletions.
51 changes: 23 additions & 28 deletions treetime/vcf_utils.py
Expand Up @@ -411,6 +411,21 @@ def write_vcf(tree_dict, file_name, mask=None):#, compress=False):
positions = tree_dict['positions']
ploidy = tree_dict.get('metadata', {}).get('ploidy', 2)
chrom_name = tree_dict.get('metadata', {}).get('chrom', '1')
sample_names = list(sequences.keys())
inferred_const_sites = set(tree_dict.get('inferred_const_sites', []))

# For every variable site in sequences, flip the format around so
# we can have fast lookups later on.
alleles = {}
num_samples = len(sample_names)
for idx, name in enumerate(sample_names):
for posn, allele in sequences[name].items():
if posn not in alleles:
alleles[posn] = np.zeros(num_samples, dtype='U')
alleles[posn][idx] = allele
# fill in reference
for posn,bases in alleles.items():
bases[bases==''] = ref[posn]

def handleDeletions(i, pi, pos, ref, delete, pattern):
refb = ref[pi]
Expand Down Expand Up @@ -475,7 +490,7 @@ def handleDeletions(i, pi, pos, ref, delete, pattern):


#prepare the header of the VCF & write out
header=["#CHROM","POS","ID","REF","ALT","QUAL","FILTER","INFO","FORMAT"]+list(sequences.keys())
header=["#CHROM","POS","ID","REF","ALT","QUAL","FILTER","INFO","FORMAT"]+sample_names

opn = gzip.open if file_name.endswith(('.gz', '.GZ')) else open
out_file = opn(file_name, 'w')
Expand Down Expand Up @@ -508,28 +523,8 @@ def handleDeletions(i, pi, pos, ref, delete, pattern):
i+=1
continue

#try/except is much more efficient than 'if' statements for constructing patterns,
#as on average a 'variable' location will not be variable for any given sequence
pattern = []
#pattern2 gets the pattern at next position to check for upcoming deletions
#it's more efficient to get both here rather than loop through sequences twice!
pattern2 = []
for k,v in sequences.items():
try:
pattern.append(sequences[k][pi])
except KeyError:
pattern.append(ref[pi])

try:
pattern2.append(sequences[k][pi+1])
except KeyError:
try:
pattern2.append(ref[pi+1])
except IndexError:
pass

pattern = np.array(pattern).astype('U')
pattern2 = np.array(pattern2).astype('U')
pattern = alleles[pi] if pi in alleles else np.array([]).astype('U')
pattern2 = alleles[pi+1] if pi+1 in alleles else np.array([]).astype('U')

#If a deletion here, need to gather up all bases, and position before
if any(pattern == '-'):
Expand Down Expand Up @@ -564,17 +559,13 @@ def handleDeletions(i, pi, pos, ref, delete, pattern):
for u in uniques:
pattern[np.where(pattern==u)[0]] = str(j)
j+=1
#Now convert these calls to a VCF format matching the ploidy.
#In case ploidy>1, we treat it as unphased ('/' as the separator)
#Note that this includes patterns of "." (no-calls)
calls = [ "/".join([j]*ploidy) for j in pattern ]

#What if there's no variation at a variable site??
#This can happen when sites are modified by TreeTime - see below.
printPos = True
if len(uniques)==0:
#If we expect it (it was made constant by TreeTime), it's fine.
if 'inferred_const_sites' in tree_dict and pi in tree_dict['inferred_const_sites']:
if pi in inferred_const_sites:
explainedErrors += 1
printPos = False #and don't output position to the VCF
else:
Expand All @@ -584,6 +575,10 @@ def handleDeletions(i, pi, pos, ref, delete, pattern):
#Write it out - Increment positions by 1 so it's in VCF numbering
#If no longer variable, and explained, don't write it out
if printPos:
#Now convert these calls to a VCF format matching the ploidy.
#In case ploidy>1, we treat it as unphased ('/' as the separator)
#Note that this includes patterns of "." (no-calls)
calls = [ "/".join([j]*ploidy) for j in pattern ]
output = [chrom_name, str(pos), ".", refb, ",".join(uniques), ".", "PASS", ".", "GT"] + calls
vcfWrite.append("\t".join(output))

Expand Down

0 comments on commit 8322d54

Please sign in to comment.