diff --git a/treetime/vcf_utils.py b/treetime/vcf_utils.py index 348af91d..eb3c17d5 100644 --- a/treetime/vcf_utils.py +++ b/treetime/vcf_utils.py @@ -411,6 +411,21 @@ def write_vcf(tree_dict, file_name, mask=None):#, compress=False): positions = tree_dict['positions'] ploidy = tree_dict.get('metadata', {}).get('ploidy', 2) chrom_name = tree_dict.get('metadata', {}).get('chrom', '1') + sample_names = list(sequences.keys()) + inferred_const_sites = set(tree_dict.get('inferred_const_sites', [])) + + # For every variable site in sequences, flip the format around so + # we can have fast lookups later on. + alleles = {} + num_samples = len(sample_names) + for idx, name in enumerate(sample_names): + for posn, allele in sequences[name].items(): + if posn not in alleles: + alleles[posn] = np.zeros(num_samples, dtype='U') + alleles[posn][idx] = allele + # fill in reference + for posn,bases in alleles.items(): + bases[bases==''] = ref[posn] def handleDeletions(i, pi, pos, ref, delete, pattern): refb = ref[pi] @@ -475,7 +490,7 @@ def handleDeletions(i, pi, pos, ref, delete, pattern): #prepare the header of the VCF & write out - header=["#CHROM","POS","ID","REF","ALT","QUAL","FILTER","INFO","FORMAT"]+list(sequences.keys()) + header=["#CHROM","POS","ID","REF","ALT","QUAL","FILTER","INFO","FORMAT"]+sample_names opn = gzip.open if file_name.endswith(('.gz', '.GZ')) else open out_file = opn(file_name, 'w') @@ -509,28 +524,9 @@ def handleDeletions(i, pi, pos, ref, delete, pattern): i+=1 continue - #try/except is much more efficient than 'if' statements for constructing patterns, - #as on average a 'variable' location will not be variable for any given sequence - pattern = [] - #pattern2 gets the pattern at next position to check for upcoming deletions - #it's more efficient to get both here rather than loop through sequences twice! - pattern2 = [] - for k,v in sequences.items(): - try: - pattern.append(sequences[k][pi]) - except KeyError: - pattern.append(ref[pi]) - - try: - pattern2.append(sequences[k][pi+1]) - except KeyError: - try: - pattern2.append(ref[pi+1]) - except IndexError: - pass - - pattern = np.array(pattern).astype('U') - pattern2 = np.array(pattern2).astype('U') + # patterns will be empty if there was no variation in the sequences + pattern = alleles[pi] if pi in alleles else np.array([]).astype('U') + pattern2 = alleles[pi+1] if pi+1 in alleles else np.array([]).astype('U') #If a deletion here, need to gather up all bases, and position before if any(pattern == '-'): @@ -551,6 +547,8 @@ def handleDeletions(i, pi, pos, ref, delete, pattern): #If deletion, treat affected bases as 1 'call': if delete or deleteGroup: + if pattern.size==0: # no variation in sequences + pattern = np.full(num_samples, refb, dtype='U') i, pi, pos, refb, pattern = handleDeletions(i, pi, pos, ref, delete, pattern) #If no deletion, replace ref with '0' which means the reference base is unchanged else: @@ -565,10 +563,6 @@ def handleDeletions(i, pi, pos, ref, delete, pattern): for u in uniques: pattern[np.where(pattern==u)[0]] = str(j) j+=1 - #Now convert these calls to a VCF format matching the ploidy. - #In case ploidy>1, we treat it as unphased ('/' as the separator) - #Note that this includes patterns of "." (no-calls) - calls = [ "/".join([j]*ploidy) for j in pattern ] #What if there's no variation at a variable site?? #This can happen when sites are modified by TreeTime - see below. @@ -576,7 +570,7 @@ def handleDeletions(i, pi, pos, ref, delete, pattern): any_variation = len(uniques)!=0 if not any_variation: #If we expect it (it was made constant by TreeTime), it's fine. - if 'inferred_const_sites' in tree_dict and pi in tree_dict['inferred_const_sites']: + if pi in inferred_const_sites: explainedErrors += 1 else: #If we don't expect, raise an error @@ -585,6 +579,7 @@ def handleDeletions(i, pi, pos, ref, delete, pattern): #Write it out - Increment positions by 1 so it's in VCF numbering #If no longer variable, and explained, don't write it out if any_variation: + calls = [ "/".join([j]*ploidy) for j in pattern ] output = [chrom_name, str(pos), ".", refb, ",".join(uniques), ".", "PASS", ".", "GT"] + calls vcfWrite.append("\t".join(output))