/
metagenomics.py
executable file
·634 lines (529 loc) · 22.3 KB
/
metagenomics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
#!/usr/bin/env python
''' This script contains a number of utilities for metagenomic analyses.
'''
from __future__ import print_function
from __future__ import division
__author__ = "yesimon@broadinstitute.org"
import argparse
import collections
import csv
import gzip
import itertools
import logging
import os.path
from os.path import join
import operator
import queue
import shutil
import sys
import util.cmd
import util.file
import util.misc
import tools.kraken
import tools.krona
import tools.diamond
import tools.picard
# (command_name, parser_builder) pairs; handed to util.cmd to build the CLI.
__commands__ = list()
# Logger shared by every function in this module.
log = logging.getLogger(__name__)
class TaxonomyDb(object):
    '''In-memory NCBI taxonomy tables.

    Each table may be passed in already loaded (``gis``, ``nodes``, ``names``)
    or as path(s) to NCBI .dmp files that are parsed on construction. Passing
    ``tax_dir`` replaces the ``*_path(s)`` arguments with the standard file
    names inside that directory (gi_taxid_nucl.dmp, gi_taxid_prot.dmp,
    nodes.dmp, names.dmp).

    Attributes:
        gis: dict int gi -> int taxid (merged over all ``gis_paths``), or None.
        ranks: dict int taxid -> rank string, or None.
        parents: dict int taxid -> int parent taxid, or None.
        names: taxid -> name table as returned by load_names, or None.
        gis_paths, nodes_path, names_path: the paths used (may be None).
    '''

    def __init__(self, tax_dir=None, gis=None, nodes=None, names=None,
                 gis_paths=None, nodes_path=None, names_path=None):
        '''
        Args:
            tax_dir: Directory holding the standard NCBI .dmp files; overrides
                gis_paths/nodes_path/names_path when given.
            gis: Pre-loaded gi -> taxid dict (takes precedence over gis_paths).
            nodes: Pre-loaded ``(ranks, parents)`` pair (takes precedence over
                nodes_path).
            names: Pre-loaded taxid -> name dict (takes precedence over names_path).
            gis_paths: List of gi -> taxid .dmp files to merge.
            nodes_path: Path to nodes.dmp.
            names_path: Path to names.dmp.
        '''
        if tax_dir:
            gis_paths = [join(tax_dir, 'gi_taxid_nucl.dmp'),
                         join(tax_dir, 'gi_taxid_prot.dmp')]
            nodes_path = join(tax_dir, 'nodes.dmp')
            names_path = join(tax_dir, 'names.dmp')
        self.gis_paths = gis_paths
        self.nodes_path = nodes_path
        self.names_path = names_path

        # Default every table to None so a partially configured db fails with
        # an explicit None instead of a surprising AttributeError later on.
        self.gis = None
        self.ranks = None
        self.parents = None
        self.names = None

        if gis:
            self.gis = gis
        elif gis_paths:
            self.gis = {}
            for gi_path in gis_paths:
                self.gis.update(load_gi_single_dmp(gi_path))

        if nodes:
            self.ranks, self.parents = nodes
        elif nodes_path:
            self.ranks, self.parents = load_nodes(nodes_path)

        if names:
            self.names = names
        elif names_path:
            self.names = load_names(names_path)
def load_gi_single_dmp(dmp_path):
    '''Load a gi->taxid dmp file from NCBI taxonomy.

    Args:
        dmp_path: Path to a two-column, tab-separated file of gi and taxid.

    Returns:
        dict mapping int gi -> int taxid.
    '''
    gi_array = {}
    with open(dmp_path) as f:
        # start=1 so the progress message reports the number of lines read,
        # not the 0-based index of the last one.
        for i, line in enumerate(f, start=1):
            line = line.strip()
            # Tolerate blank lines (e.g. a trailing newline) instead of failing
            # the two-value unpack below.
            if not line:
                continue
            gi, taxid = line.split('\t')
            gi_array[int(gi)] = int(taxid)
            if i % 1000000 == 0:
                print('Loaded {} gis'.format(i), file=sys.stderr)
    return gi_array
def load_names(names_db, scientific_only=True):
    '''Load the names.dmp file from NCBI taxonomy.

    Args:
        names_db: Path to names.dmp ('|'-separated: taxid, name, unique name,
            name class, ...).
        scientific_only: If True, keep only the 'scientific name' entry of each
            taxid.

    Returns:
        If scientific_only, a dict taxid -> name. Otherwise a
        defaultdict(list) taxid -> every name listed for it, in file order.
    '''
    names = {} if scientific_only else collections.defaultdict(list)
    # The context manager closes the handle; the original leaked it.
    with open(names_db) as f:
        for line in f:
            parts = line.strip().split('|')
            taxid = int(parts[0])
            name = parts[1].strip()
            class_ = parts[3].strip()
            if scientific_only:
                if class_ == 'scientific name':
                    names[taxid] = name
            else:
                names[taxid].append(name)
    return names
def load_nodes(nodes_db):
    '''Parse NCBI nodes.dmp into rank and parent lookup tables.

    Args:
        nodes_db: Path to nodes.dmp ('|'-separated; the first three fields
            are taxid, parent taxid and rank).

    Returns:
        (ranks, parents): dicts keyed by int taxid giving the rank string and
        the int parent taxid respectively.
    '''
    ranks, parents = {}, {}
    with open(nodes_db) as handle:
        for raw in handle:
            fields = raw.strip().split('|')
            node_id = int(fields[0])
            parents[node_id] = int(fields[1])
            ranks[node_id] = fields[2].strip()
    return ranks, parents
# One parsed line of blast m8 output; fields are listed in column order.
BlastRecord = collections.namedtuple(
    'BlastRecord',
    'query_id subject_id percent_identity aln_length mismatch_count '
    'gap_open_count query_start query_end subject_start subject_end '
    'e_val bit_score')
def blast_records(f):
    '''Yield blast m8 records line by line.

    Args:
        f: Iterable of text lines in blast m8 (12-column tabular) format.

    Yields:
        BlastRecord with columns 3-9 (0-based) converted to int and columns
        2, 10 and 11 (percent identity, e-value, bit score) to float.
    '''
    for line in f:
        # Skip comment headers and blank lines; a blank line would otherwise
        # fail with IndexError on the column conversions below.
        if line.startswith('#') or not line.strip():
            continue
        parts = line.strip().split()
        for field in range(3, 10):
            parts[field] = int(parts[field])
        for field in (2, 10, 11):
            parts[field] = float(parts[field])
        yield BlastRecord(*parts)
def paired_query_id(record):
    '''Drop a trailing /1 or /2 mate suffix from a record's query id.

    Lets both mates of a read pair fall into the same query group. Records
    whose query id has neither suffix are returned unchanged (same object).
    '''
    qid = record.query_id
    matched = next((sfx for sfx in ('/1', '/2') if qid.endswith(sfx)), None)
    if matched is None:
        return record
    return BlastRecord(qid[:-len(matched)], *record[1:])
def translate_gi_to_tax_id(db, record):
    '''Replace the gi header in a record's subject id with an int taxonomy id.

    The subject id is expected to look like ``gi|<number>|...``.

    Args:
        db: (TaxonomyDb) Taxonomy db whose ``gis`` maps int gi -> int taxid.
        record: (BlastRecord) Hit to translate.

    Returns:
        A copy of ``record`` whose ``subject_id`` is the tax id, or 0 when the
        gi is absent from ``db.gis``.
    '''
    gi = int(record.subject_id.split('|')[1])
    # A gi missing from the dump used to abort the whole LCA run with
    # KeyError. Map it to 0 instead: process_blast_hits drops hits whose
    # subject_id is 0.
    tax_id = db.gis.get(gi, 0)
    return record._replace(subject_id=tax_id)
def blast_lca(db, m8_file, output, paired=False, min_bit_score=50,
              max_expected_value=0.01, top_percent=10, min_support_percent=0, min_support=1):
    '''Assign each query the LCA taxonomy id of its blast hits.

    Writes one tab-separated ``query_id<TAB>tax_id`` line per query that
    resolves to a tax id. Queries that do not resolve are only logged at
    debug level.

    Args:
        db: (TaxonomyDb) Taxonomy db.
        m8_file: (io) Blast m8 file to read.
        output: (io) Output file.
        paired: (bool) Whether to count paired suffixes /1,/2 as one group.
        min_bit_score: (float) Minimum bit score or discard.
        max_expected_value: (float) Maximum e-val or discard.
        top_percent: (float) Only this percent within top hit are used.
        min_support_percent: (float) Currently unused by this function.
        min_support: (int) Currently unused by this function.
    '''
    kept = (rec for rec in blast_records(m8_file)
            if rec.e_val <= max_expected_value and rec.bit_score >= min_bit_score)
    if paired:
        kept = (paired_query_id(rec) for rec in kept)

    # groupby only merges *adjacent* records, so this relies on all hits of a
    # query being listed consecutively in the m8 input.
    for query_id, group in itertools.groupby(kept, key=operator.attrgetter('query_id')):
        tax_id = process_blast_hits(db, list(group), top_percent)
        if tax_id:
            output.write('{}\t{}\n'.format(query_id, tax_id))
        else:
            log.debug('Query: {} has no valid taxonomy paths.'.format(query_id))
def process_blast_hits(db, blast_hits, top_percent):
    '''Filter one query's blast hits and return the LCA of the survivors.

    Hits are translated to tax ids, hits with tax id 0 are dropped, and only
    hits whose bit score is within ``top_percent`` percent of the best one are
    kept.

    Args:
        db: (TaxonomyDb) Taxonomy db.
        blast_hits: []BlastRecord groups of hits.
        top_percent: (float) Only consider hits within this percent of top bit score.
    Return:
        (int) Tax id of LCA, or None when no hit has a usable tax id.
    '''
    translated = [translate_gi_to_tax_id(db, hit) for hit in blast_hits]
    usable = [hit for hit in translated if hit.subject_id != 0]
    if not usable:
        return None

    top_score = max(hit.bit_score for hit in usable)
    threshold = (100 - top_percent) / 100 * top_score
    # Best-first order: coverage_lca picks the most common node per level via
    # Counter.most_common, which breaks ties by first occurrence.
    ranked = sorted((hit for hit in usable if hit.bit_score >= threshold),
                    key=operator.attrgetter('bit_score'), reverse=True)
    if not ranked:
        return None
    return coverage_lca([hit.subject_id for hit in ranked], db.parents)
def coverage_lca(query_ids, parents, lca_percent=100):
    '''Calculate the lca that will cover at least this percent of queries.

    Each query id is walked up ``parents`` towards the root (taxid 1) to build
    a root-first path. Descending level by level, the most common node at
    each level is accepted while it is shared by at least ``lca_percent``
    percent of the queries. The last accepted node is the LCA.

    Args:
        query_ids: []int list of nodes.
        parents: Mapping/array of node -> parent. A parent of 0, or a node
            absent from ``parents``, ends that query's path early with a
            warning.
        lca_percent: (float) Cover at least this percent of queries.
    Return:
        (int) LCA, or None when ``query_ids`` is empty.
    '''
    lca_needed = lca_percent / 100 * len(query_ids)
    paths = []
    for query_id in query_ids:
        path = []
        while query_id != 1:
            path.append(query_id)
            try:
                parent = parents[query_id]
            except (KeyError, IndexError):
                # Unknown node: handle like an explicit 0 parent instead of
                # aborting the whole LCA computation.
                parent = 0
            if parent == 0:
                log.warning('Parent for query id: {} missing'.format(query_id))
                break
            query_id = parent
        if query_id == 1:
            path.append(1)
        path.reverse()
        paths.append(path)

    if not paths:
        return None

    last_common = 1
    max_path_length = max(len(path) for path in paths)
    for level in range(max_path_length):
        counts = collections.Counter(path[level] for path in paths if len(path) > level)
        max_query_id, hits_covered = counts.most_common(1)[0]
        if hits_covered < lca_needed:
            break
        last_common = max_query_id
    return last_common
def tree_level_lookup(parents, node, level_cache):
    '''Return the depth of ``node`` in the tree (root == 1).

    Climbs ``parents`` until it reaches a node whose depth is already cached
    (with a truthy value). It then caches the depth of every node visited on
    the way, so later lookups stop sooner.

    Args:
        parents: Mapping of node -> parent node.
        node: Node to get level (root == 1).
        level_cache: dict node -> known depth, updated in place. Must contain
            an ancestor of ``node`` (e.g. ``{1: 1}``).
    Returns:
        (int) level of node
    '''
    climbed = []
    current = node
    while not level_cache.get(current):
        climbed.append(current)
        current = parents[current]
    base = level_cache[current]
    # Fill from the cached ancestor downwards: its child is base + 1, and so on.
    for offset, visited in enumerate(reversed(climbed), start=1):
        level_cache[visited] = base + offset
    return base + len(climbed)
def push_up_tree_hits(parents, hits, min_support_percent=None, min_support=None,
                      update_assignments=False):
    '''Push up hits on nodes until min support is reached.

    Nodes holding fewer than the minimum number of hits hand all their hits to
    their parent. The deepest nodes are processed first, repeating until every
    remaining node meets the minimum. If the root (taxid 1) itself ends up
    below the minimum, it is removed.

    Args:
        parents: Mapping of node -> parent node.
        hits: Counter of hits on each node. Mutated in place.
        min_support_percent: Push up hits until each node has
            this percent of the sum of all hits. Ignored when min_support is set.
        min_support: Push up hits until each node has this number of hits.
        update_assignments: Currently unused; kept for interface compatibility.
    Returns:
        (counter) The same ``hits`` Counter, with hits pushed up the tree.
    '''
    assert min_support_percent or min_support
    total_hits = sum(hits.values())
    if not min_support:
        min_support = round(min_support_percent * 0.01 * total_hits)

    pq_level = queue.PriorityQueue()
    level_cache = {1: 1}

    def enqueue(node):
        # Negate the depth so the deepest node is popped first.
        pq_level.put((-tree_level_lookup(parents, node, level_cache), node))

    for hit_id, num_hits in hits.items():
        if num_hits < min_support:
            enqueue(hit_id)

    while not pq_level.empty():
        _, hit_id = pq_level.get()
        if hits[hit_id] >= min_support:
            continue
        if hit_id == 1:
            del hits[1]
            break
        parent_hit_id = parents[hit_id]
        hits[parent_hit_id] += hits[hit_id]
        # A node can be queued more than once and may already have been pushed
        # up (and removed); Counter returns 0 for missing keys, so guard the delete.
        if hit_id in hits:
            del hits[hit_id]
        if hits[parent_hit_id] < min_support:
            enqueue(parent_hit_id)
    return hits
def parents_to_children(parents):
    '''Invert a node -> parent mapping into per-parent lists of children.

    Node 1 (the root) is never listed as a child, and nodes whose parent is 0
    are skipped.

    Returns:
        (dict[list]) defaultdict(list) of children per parent; each list
        follows the iteration order of ``parents``.
    '''
    children = collections.defaultdict(list)
    edges = ((node, parent) for node, parent in parents.items()
             if node != 1 and parent != 0)
    for node, parent in edges:
        children[parent].append(node)
    return children
# Rank name -> single-letter code used in the kraken-style report.
_RANK_CODES = {
    'species': 'S',
    'genus': 'G',
    'family': 'F',
    'order': 'O',
    'class': 'C',
    'phylum': 'P',
    'kingdom': 'K',
    'superkingdom': 'D',
}


def rank_code(rank):
    '''Get the short 1 letter rank code for named ranks ('-' for any other rank).'''
    return _RANK_CODES.get(rank, '-')
def taxa_hits_from_tsv(f, tax_id_column=2):
    '''Return a counter of hits from tsv.

    Args:
        f: Iterable of tab-separated text lines (e.g. an open file).
        tax_id_column: 1-based index of the column holding the int tax id.

    Returns:
        collections.Counter mapping tax id -> number of rows naming it.
    '''
    c = collections.Counter()
    for row in csv.reader(f, delimiter='\t'):
        # csv yields [] for a blank line (e.g. a trailing newline); skip it
        # instead of failing with IndexError.
        if not row:
            continue
        c[int(row[tax_id_column - 1])] += 1
    return c
def kraken_dfs_report(db, taxa_hits, prepend_column=True):
    '''Return a kraken compatible DFS report of taxa hits.

    Args:
        db: (TaxonomyDb) Taxonomy db. ``parents``, ``ranks`` and ``names`` are
            read, and a ``children`` attribute is (re)built on it.
        taxa_hits: (collections.Counter) # of hits per tax id. Tax ids 0 and
            -1 are reported together as a single 'unclassified' line.
        prepend_column: If True, every report line starts with an extra empty
            tab-separated column.
    Return:
        Iterator over the str lines of the report: the unclassified line (if
        any) first, then each classified taxon before its descendants. Empty
        when ``taxa_hits`` holds no hits.
    '''
    line_prefix = '\t' if prepend_column else ''
    db.children = parents_to_children(db.parents)
    total_hits = sum(taxa_hits.values())
    lines = []
    # With no hits at all every percentage would divide by zero.
    if total_hits == 0:
        return reversed(lines)

    # kraken_dfs appends children before parents; reversing at the end puts
    # parents first.
    kraken_dfs(db, lines, taxa_hits, total_hits, 1, 0)
    unclassified_hits = taxa_hits.get(0, 0) + taxa_hits.get(-1, 0)
    if unclassified_hits > 0:
        percent_covered = '%.2f' % (unclassified_hits / total_hits * 100)
        lines.append('\t'.join([
            percent_covered, str(unclassified_hits),
            str(unclassified_hits), 'U', '0', 'unclassified'
        ]))
    # Apply the optional leading column to every line, not only the
    # unclassified one, so all rows have the same number of columns.
    return reversed([line_prefix + line for line in lines])
def kraken_dfs(db, lines, taxa_hits, total_hits, taxid, level):
    '''Recursively do DFS for number of hits per taxa.

    Appends one report line to ``lines`` for each taxon in the subtree with a
    non-zero cumulative count (children before their parent).

    Args:
        db: (TaxonomyDb) Taxonomy db with ``children``, ``ranks`` and ``names``.
        lines: list that report lines are appended to.
        taxa_hits: Mapping tax id -> hits assigned directly to that tax id.
        total_hits: Denominator of the percentage column.
        taxid: Tax id at the top of the subtree to visit.
        level: Depth of ``taxid`` below the DFS start; controls name indent.

    Returns:
        (int) Hits assigned to ``taxid`` or any of its descendants.
    '''
    num_hits = taxa_hits.get(taxid, 0)
    cum_hits = num_hits
    for child_taxid in db.children[taxid]:
        cum_hits += kraken_dfs(db, lines, taxa_hits, total_hits, child_taxid, level + 1)
    # Only taxa with hits are reported, so only compute the percentage and look
    # up rank/name for those. This avoids dividing by a zero total and failing
    # on a missing name for a taxon that would never be printed.
    if cum_hits > 0:
        percent_covered = '%.2f' % (cum_hits / total_hits * 100)
        rank = rank_code(db.ranks[taxid])
        name = db.names[taxid]
        lines.append('\t'.join([percent_covered, str(cum_hits), str(num_hits),
                                rank, str(taxid), ' ' * level + name]))
    return cum_hits
def kraken(inBam, db, outReport=None, outReads=None,
           filterThreshold=None, numThreads=1):
    '''
    Classify reads by taxon using Kraken
    '''
    assert outReads or outReport, (
        'Either --outReads or --outReport must be specified.')
    kraken_tool = tools.kraken.Kraken()
    # Every temp file is registered here and removed in `finally`, so a failed
    # classify/filter/report step no longer leaves intermediates behind.
    tmp_files = []
    try:
        # kraken classify
        tmp_reads = util.file.mkstempfname('.kraken')
        tmp_files.append(tmp_reads)
        kraken_tool.classify(inBam, db, tmp_reads, numThreads=numThreads)

        # kraken filter
        if filterThreshold:
            tmp_filtered_reads = util.file.mkstempfname('.filtered-kraken')
            tmp_files.append(tmp_filtered_reads)
            kraken_tool.filter(tmp_reads, db, tmp_filtered_reads, filterThreshold)
            # The unfiltered output is no longer needed; free the space early.
            os.unlink(tmp_reads)
        else:
            tmp_filtered_reads = tmp_reads

        # copy outReads. The source is read as bytes, so the destination must
        # be opened in binary mode as well. NOTE(review): assumes
        # open_or_gzopen passes the mode through to open()/gzip.open().
        if outReads:
            with open(tmp_filtered_reads, 'rb') as f_in:
                with util.file.open_or_gzopen(outReads, 'wb') as f_out:
                    shutil.copyfileobj(f_in, f_out)

        # kraken report
        if outReport:
            kraken_tool.report(tmp_filtered_reads, db, outReport)
    finally:
        for tmp_file in tmp_files:
            if os.path.exists(tmp_file):
                os.unlink(tmp_file)
def parser_kraken(parser=None):
    # Create the parser per call. A default expression is evaluated only once,
    # at definition time, so the old `parser=argparse.ArgumentParser()` default
    # was shared: a second call would re-add the same options and argparse
    # would raise a conflict error.
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('inBam', help='Input unaligned reads, BAM format.')
    parser.add_argument('db', help='Kraken database directory.')
    parser.add_argument('--outReport', help='Kraken report output file.')
    parser.add_argument('--outReads', help='Kraken per read output file.')
    parser.add_argument('--filterThreshold',
                        default=0.05,
                        type=float,
                        help='Kraken filter threshold (default %(default)s)')
    parser.add_argument('--numThreads', type=int, default=1, help='Number of threads to run. (default %(default)s)')
    util.cmd.common_args(parser, (('loglevel', None), ('version', None),
                                  ('tmp_dir', None)))
    util.cmd.attach_main(parser, kraken, split_args=True)
    return parser
def krona(inTsv, db, outHtml, queryColumn=None, taxidColumn=None,
          scoreColumn=None, noHits=None, noRank=None):
    '''
    Create an interactive HTML report from a tabular metagenomic report
    '''
    krona_tool = tools.krona.Krona()
    tmp_tsv = None
    try:
        if inTsv.endswith('.gz'):
            # Gzipped input is decompressed to a temp file before import.
            # NOTE(review): presumably the Krona import step needs plain text.
            tmp_tsv = util.file.mkstempfname('.tsv')
            with gzip.open(inTsv, 'rb') as f_in, open(tmp_tsv, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
            to_import = [tmp_tsv]
        else:
            to_import = [inTsv]

        krona_tool.import_taxonomy(
            db, to_import, outHtml, query_column=queryColumn, taxid_column=taxidColumn,
            score_column=scoreColumn, no_hits=noHits, no_rank=noRank)
    finally:
        # The decompressed copy was previously never removed.
        if tmp_tsv and os.path.exists(tmp_tsv):
            os.unlink(tmp_tsv)
def parser_krona(parser=None):
    # Create the parser per call: a default ArgumentParser instance would be
    # built once and shared, so a second call would re-add the same options.
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('inTsv', help='Input tab delimited file.')
    parser.add_argument('db', help='Krona taxonomy database directory.')
    parser.add_argument('outHtml', help='Output html report.')
    parser.add_argument('--queryColumn', help='Column of query id. (default %(default)s)',
                        type=int, default=2)
    parser.add_argument('--taxidColumn', help='Column of taxonomy id. (default %(default)s)',
                        type=int, default=3)
    parser.add_argument('--scoreColumn', help='Column of score. (default %(default)s)',
                        type=int)
    parser.add_argument('--noHits', help='Include wedge for no hits.',
                        action='store_true')
    parser.add_argument('--noRank', help='Include no rank assignments.',
                        action='store_true')
    util.cmd.common_args(parser, (('loglevel', None), ('version', None)))
    util.cmd.attach_main(parser, krona, split_args=True)
    return parser
def diamond(inBam, db, taxDb, outReport, outM8=None, outLca=None, numThreads=1):
    '''
    Classify reads by the taxon of the Lowest Common Ancestor (LCA)
    '''
    tmp_fastq = util.file.mkstempfname('.fastq')
    tmp_fastq2 = util.file.mkstempfname('.fastq')
    tmp_alignment = util.file.mkstempfname('.daa')
    tmp_m8 = util.file.mkstempfname('.diamond.m8')
    tmp_lca_tsv = util.file.mkstempfname('.tsv')
    # None of these intermediates were removed before; clean them up even if
    # a step fails.
    tmp_files = (tmp_fastq, tmp_fastq2, tmp_alignment, tmp_m8, tmp_lca_tsv)
    try:
        # do not convert this to samtools bam2fq unless we can figure out how to replicate
        # the clipping functionality of Picard SamToFastq
        picard = tools.picard.SamToFastqTool()
        picard_opts = {
            'CLIPPING_ATTRIBUTE': tools.picard.SamToFastqTool.illumina_clipping_attribute,
            'CLIPPING_ACTION': 'X'
        }
        picard.execute(inBam, tmp_fastq, tmp_fastq2,
                       picardOptions=tools.picard.PicardTools.dict_to_picard_opts(picard_opts),
                       JVMmemory=picard.jvmMemDefault)

        diamond_tool = tools.diamond.Diamond()
        diamond_tool.install()
        diamond_tool.blastx(db, [tmp_fastq, tmp_fastq2], tmp_alignment,
                            options={'--threads': numThreads})
        diamond_tool.view(tmp_alignment, tmp_m8,
                          options={'--threads': numThreads})
        if outM8:
            with open(tmp_m8, 'rb') as f_in:
                with gzip.open(outM8, 'wb') as f_out:
                    shutil.copyfileobj(f_in, f_out)

        # Per-read LCA assignment (mate suffixes /1,/2 merged into one query).
        tax_db = TaxonomyDb(tax_dir=taxDb)
        with open(tmp_m8) as m8, open(tmp_lca_tsv, 'w') as lca:
            blast_lca(tax_db, m8, lca, paired=True, min_bit_score=50)
        if outLca:
            with open(tmp_lca_tsv, 'rb') as f_in:
                with gzip.open(outLca, 'wb') as f_out:
                    shutil.copyfileobj(f_in, f_out)

        # Summarize assignments as a kraken-style report.
        with open(tmp_lca_tsv) as f:
            hits = taxa_hits_from_tsv(f)
        with open(outReport, 'w') as f:
            for line in kraken_dfs_report(tax_db, hits, prepend_column=True):
                print(line, file=f)
    finally:
        for tmp_file in tmp_files:
            if os.path.exists(tmp_file):
                os.unlink(tmp_file)
def parser_diamond(parser=None):
    # Create the parser per call: a default ArgumentParser instance would be
    # built once and shared, so a second call would re-add the same options.
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('inBam', help='Input unaligned reads, BAM format.')
    parser.add_argument('db', help='Diamond database directory.')
    parser.add_argument('taxDb', help='Taxonomy database directory.')
    parser.add_argument('outReport', help='Output taxonomy report.')
    parser.add_argument('--outM8', help='Blast m8 formatted output file.')
    parser.add_argument('--outLca', help='Output LCA assignments for each read.')
    # type=int, matching parser_kraken, so a number (not a string) is passed on.
    parser.add_argument('--numThreads', type=int, default=1, help='Number of threads (default: %(default)s)')
    util.cmd.common_args(parser, (('loglevel', None), ('version', None), ('tmp_dir', None)))
    util.cmd.attach_main(parser, diamond, split_args=True)
    return parser
def metagenomic_report_merge(metagenomic_reports, out_kraken_summary, kraken_db, out_krona_input):
    '''
    Merge multiple metagenomic reports into a single metagenomic report.
    Optionally writes a combined Krona input file (--outByQueryToTaxonID)
    and/or a human-readable kraken-report summary (--outSummaryReport).
    '''
    assert out_kraken_summary or out_krona_input, (
        "Either --outSummaryReport or --outByQueryToTaxonID must be specified")
    assert kraken_db or not out_kraken_summary, (
        'A Kraken db must be provided via --krakenDB if outSummaryReport is specified')

    def _path(f):
        # Accept an open file object (argparse.FileType) or a plain path string.
        return getattr(f, 'name', f)

    report_paths = [_path(f) for f in (metagenomic_reports or [])]

    # if we're creating a Krona input file
    if out_krona_input:
        # open the output file (as gz if necessary)
        with util.file.open_or_gzopen(out_krona_input, "wt") as outf:
            output_writer = csv.writer(outf, delimiter='\t', lineterminator='\n')
            # Copy every row of every input report unchanged. The query id and
            # tax id are expected in the 2nd and 3rd (1-based) columns.
            # See: http://ccb.jhu.edu/software/kraken/MANUAL.html#output-format
            for report_path in report_paths:
                with util.file.open_or_gzopen(report_path, "rt") as inf:
                    for row in csv.reader(inf, delimiter='\t'):
                        output_writer.writerow(row)

    # create a human-readable summary of the Kraken reports
    # kraken-report can only be used on kraken reports since it depends on queries being in its database
    if out_kraken_summary:
        # temporary file holding the concatenated kraken reports
        tmp_metag_combined_txt = util.file.mkstempfname('.txt')
        try:
            util.file.cat(tmp_metag_combined_txt, report_paths)
            kraken_tool = tools.kraken.Kraken()
            kraken_tool.report(tmp_metag_combined_txt, _path(kraken_db), out_kraken_summary)
        finally:
            if os.path.exists(tmp_metag_combined_txt):
                os.unlink(tmp_metag_combined_txt)


def parser_metagenomic_report_merge(parser=None):
    # Create the parser per call: a default ArgumentParser instance would be
    # built once and shared, so a second call would re-add the same options.
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument("metagenomic_reports", help="Input metagenomic reports with the query ID and taxon ID in the 2nd and 3rd columns (Kraken format)", nargs='+', type=argparse.FileType('r'))
    parser.add_argument("--outSummaryReport", dest="out_kraken_summary", help="Path of human-readable metagenomic summary report, created by kraken-report")
    # The Kraken database is a directory (see parser_kraken), which
    # argparse.FileType('r') cannot open, so take it as a plain path.
    parser.add_argument("--krakenDB", dest="kraken_db", help="Kraken database (needed for outSummaryReport)")
    parser.add_argument("--outByQueryToTaxonID", dest="out_krona_input", help="Output metagenomic report suitable for Krona input. ")
    util.cmd.common_args(parser, (('loglevel', None), ('version', None), ('tmp_dir', None)))
    util.cmd.attach_main(parser, metagenomic_report_merge, split_args=True)
    return parser
# Register each sub-command name with the function that builds its parser.
__commands__.extend([
    ('kraken', parser_kraken),
    ('diamond', parser_diamond),
    ('krona', parser_krona),
    ('report_merge', parser_metagenomic_report_merge),
])
def full_parser():
    '''Build the top-level parser covering every registered sub-command.'''
    parser = util.cmd.make_parser(__commands__, __doc__)
    return parser


if __name__ == '__main__':
    # Hand the registered sub-commands and module docstring to util.cmd's CLI entry point.
    util.cmd.main_argparse(__commands__, __doc__)