diff --git a/pipes/WDL/workflows/tasks/taxon_filter.wdl b/pipes/WDL/workflows/tasks/taxon_filter.wdl
index f72753b7f..629befdac 100644
--- a/pipes/WDL/workflows/tasks/taxon_filter.wdl
+++ b/pipes/WDL/workflows/tasks/taxon_filter.wdl
@@ -7,6 +7,8 @@ task deplete_taxa {
   Array[File]? bmtaggerDbs # .tar.gz, .tgz, .tar.bz2, .tar.lz4, .fasta, or .fasta.gz
   Array[File]? blastDbs # .tar.gz, .tgz, .tar.bz2, .tar.lz4, .fasta, or .fasta.gz
   Int? query_chunk_size
+  Boolean? clear_tags = false
+  String? tags_to_clear_space_separated = "XT X0 X1 XA AM SM BQ CT XN OC OP"
 
   String bam_basename = basename(raw_reads_unmapped_bam, ".bam")
 
@@ -25,6 +27,13 @@ task deplete_taxa {
     DBS_BLAST="${sep=' ' blastDbs}"
     if [ -n "$DBS_BMTAGGER" ]; then DBS_BMTAGGER="--bmtaggerDbs $DBS_BMTAGGER"; fi
     if [ -n "$DBS_BLAST" ]; then DBS_BLAST="--blastDbs $DBS_BLAST"; fi
+
+    if [[ "${clear_tags}" == "true" ]]; then
+      TAGS_TO_CLEAR="--clearTags"
+      if [[ -n "${tags_to_clear_space_separated}" ]]; then
+        TAGS_TO_CLEAR="$TAGS_TO_CLEAR ${'--tagsToClear ' + tags_to_clear_space_separated}"
+      fi
+    fi
 
     # run depletion
     taxon_filter.py deplete_human \
@@ -35,6 +44,7 @@ task deplete_taxa {
       ${bam_basename}.cleaned.bam \
       $DBS_BMTAGGER $DBS_BLAST \
       ${'--chunkSize=' + query_chunk_size} \
+      $TAGS_TO_CLEAR \
       --JVMmemory="$mem_in_mb"m \
       --srprismMemory=$mem_in_mb \
       --loglevel=DEBUG
diff --git a/read_utils.py b/read_utils.py
index 6460aab9a..4b92d1699 100755
--- a/read_utils.py
+++ b/read_utils.py
@@ -15,6 +15,7 @@
 import tempfile
 import shutil
 import sys
+import concurrent.futures
 
 from Bio import SeqIO
 
@@ -715,6 +716,35 @@ def parser_rmdup_cdhit_bam(parser=argparse.ArgumentParser()):
 __commands__.append(('rmdup_cdhit_bam', parser_rmdup_cdhit_bam))
 
 
+def _merge_fastqs_and_mvicuna(lb, files):
+    readList = mkstempfname('.keep_reads.txt')
+    log.info("executing M-Vicuna DupRm on library " + lb)
+
+    # create merged FASTQs per library
+    infastqs = (mkstempfname('.1.fastq'), mkstempfname('.2.fastq'))
+    for d in range(2):
+        with open(infastqs[d], 'wt') as outf:
+            for fprefix in files:
+                fn = '%s_%d.fastq' % (fprefix, d + 1)
+
+                if os.path.isfile(fn):
+                    with open(fn, 'rt') as inf:
+                        for line in inf:
+                            outf.write(line)
+                    os.unlink(fn)
+                else:
+                    log.warn(
+                        """no reads found in %s,
+                        assuming that's because there's no reads in that read group""", fn
+                    )
+
+    # M-Vicuna DupRm to see what we should keep (append IDs to running file)
+    if os.path.getsize(infastqs[0]) > 0 or os.path.getsize(infastqs[1]) > 0:
+        mvicuna_fastqs_to_readlist(infastqs[0], infastqs[1], readList)
+    for fn in infastqs:
+        os.unlink(fn)
+
+    return readList
 
 def rmdup_mvicuna_bam(inBam, outBam, JVMmemory=None):
     ''' Remove duplicate reads from BAM file using M-Vicuna. The
@@ -738,36 +768,26 @@ def rmdup_mvicuna_bam(inBam, outBam, JVMmemory=None):
     log.info("found %d distinct libraries and %d read groups", len(lb_to_files), len(read_groups))
 
     # For each library, merge FASTQs and run rmdup for entire library
-    readList = mkstempfname('.keep_reads.txt')
-    for lb, files in lb_to_files.items():
-        log.info("executing M-Vicuna DupRm on library " + lb)
-
-        # create merged FASTQs per library
-        infastqs = (mkstempfname('.1.fastq'), mkstempfname('.2.fastq'))
-        for d in range(2):
-            with open(infastqs[d], 'wt') as outf:
-                for fprefix in files:
-                    fn = '%s_%d.fastq' % (fprefix, d + 1)
-
-                    if os.path.isfile(fn):
-                        with open(fn, 'rt') as inf:
-                            for line in inf:
-                                outf.write(line)
-                        os.unlink(fn)
-                    else:
-                        log.warn(
-                            """no reads found in %s,
-                            assuming that's because there's no reads in that read group""", fn
-                        )
-
-        # M-Vicuna DupRm to see what we should keep (append IDs to running file)
-        if os.path.getsize(infastqs[0]) > 0 or os.path.getsize(infastqs[1]) > 0:
-            mvicuna_fastqs_to_readlist(infastqs[0], infastqs[1], readList)
-        for fn in infastqs:
-            os.unlink(fn)
+    readListAll = mkstempfname('.keep_reads_all.txt')
+    per_lb_read_lists = []
+    with concurrent.futures.ProcessPoolExecutor(max_workers=util.misc.available_cpu_count()) as executor:
+        futures = [executor.submit(_merge_fastqs_and_mvicuna, lb, files) for lb, files in lb_to_files.items()]
+        for future in concurrent.futures.as_completed(futures):
+            log.info("mvicuna finished processing library")
+            try:
+                readList = future.result()
+                per_lb_read_lists.append(readList)
+            except Exception as exc:
+                log.error('mvicuna process call generated an exception: %s' % (exc))
+
+    # merge per-library read lists together
+    util.file.concat(per_lb_read_lists, readListAll)
+    # remove per-library read lists
+    for fl in per_lb_read_lists:
+        os.unlink(fl)
 
     # Filter original input BAM against keep-list
-    tools.picard.FilterSamReadsTool().execute(inBam, False, readList, outBam, JVMmemory=JVMmemory)
+    tools.picard.FilterSamReadsTool().execute(inBam, False, readListAll, outBam, JVMmemory=JVMmemory)
 
     return 0
diff --git a/taxon_filter.py b/taxon_filter.py
index 72fb8ebbc..9ea942dd7 100755
--- a/taxon_filter.py
+++ b/taxon_filter.py
@@ -65,6 +65,11 @@ def parser_deplete_human(parser=argparse.ArgumentParser()):
     )
     parser.add_argument('--srprismMemory', dest="srprism_memory", type=int, default=7168, help='Memory for srprism.')
     parser.add_argument("--chunkSize", type=int, default=1000000, help='blastn chunk size (default: %(default)s)')
+    parser.add_argument('--clearTags', dest='clear_tags', default=False, action='store_true',
+                        help='When supplying an aligned input file, clear the per-read attribute tags')
+    parser.add_argument("--tagsToClear", type=str, nargs='+', dest="tags_to_clear",
+                        default=["XT", "X0", "X1", "XA", "AM", "SM", "BQ", "CT", "XN", "OC", "OP"],
+                        help='A space-separated list of tags to remove from all reads in the input bam file (default: %(default)s)')
     parser.add_argument(
         '--JVMmemory',
         default=tools.picard.FilterSamReadsTool.jvmMemDefault,
@@ -96,8 +101,14 @@ def main_deplete_human(args):
     with pysam.AlignmentFile(args.inBam, 'rb', check_sq=False) as bam:
         # if it looks like the bam is aligned, revert it
         if 'SQ' in bam.header and len(bam.header['SQ'])>0:
+            picardTagOptions = []
+            if args.clear_tags:
+                for tag in args.tags_to_clear:
+                    picardTagOptions.append("ATTRIBUTE_TO_CLEAR={}".format(tag))
+
             tools.picard.RevertSamTool().execute(
-                args.inBam, revertBamOut, picardOptions=['SORT_ORDER=queryname', 'SANITIZE=true']
+                args.inBam, revertBamOut, picardOptions=['SORT_ORDER=queryname',
+                                                         'SANITIZE=true'] + picardTagOptions
             )
             bamToDeplete = revertBamOut
         else:
@@ -123,8 +134,8 @@ def bmtagger_wrapper(inBam, db, outBam, JVMmemory=None):
     # if the user has not specified saving a revertBam, we used a temp file and can remove it
     if not args.revertBam:
         os.unlink(revertBamOut)
-    read_utils.rmdup_mvicuna_bam(args.bmtaggerBam, args.rmdupBam, JVMmemory=args.JVMmemory)
+    read_utils.rmdup_mvicuna_bam(args.bmtaggerBam, args.rmdupBam, JVMmemory=args.JVMmemory)
 
     multi_db_deplete_bam(
         args.rmdupBam,
         args.blastDbs,
@@ -476,6 +487,7 @@ def deplete_blastn_bam(inBam, db, outBam, threads=None, chunkSize=1000000, JVMme
     blast_hits = mkstempfname('.blast_hits.txt')
 
     with util.file.tmp_dir('-blastn_db_unpack') as tempDbDir:
+        db_dir = ""
         if os.path.exists(db):
             if os.path.isfile(db):
                 # this is a single file
@@ -491,9 +503,12 @@ def deplete_blastn_bam(inBam, db, outBam, threads=None, chunkSize=1000000, JVMme
                 db_dir = db
             # this directory should have a .bitmask and a .srprism file in it somewhere
             hits = list(glob.glob(os.path.join(db_dir, '*.nin')))
-            if len(hits) != 1:
-                raise Exception()
-            db_prefix = hits[0][:-4]  # remove the '.nin'
+            if len(hits) == 0:
+                raise Exception("The BLAST database does not appear to contain a *.nin file.")
+            elif len(hits) == 1:
+                db_prefix = hits[0][:-4]  # remove the '.nin'
+            elif len(hits) > 1:
+                db_prefix = os.path.commonprefix(hits).rsplit('.', 1)[0]  # remove the '.nin' and any split-database numbering
         else:
             # this is simply a prefix to a bunch of files, not an actual file
             db_prefix = db
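
For reviewers unfamiliar with the concurrency change in rmdup_mvicuna_bam above: it is the standard concurrent.futures fan-out/collect idiom, submitting one job per library and harvesting each result with as_completed() so a slow library does not block the others, with worker exceptions surfacing when result() is called. A minimal self-contained sketch of that idiom follows; the worker function and inputs here are hypothetical stand-ins, not viral-ngs code.

    import concurrent.futures
    import os

    # hypothetical stand-in for _merge_fastqs_and_mvicuna; must be a top-level
    # function so it can be pickled and shipped to worker processes
    def process_library(lb, files):
        return '%s: %d FASTQ prefixes' % (lb, len(files))

    if __name__ == '__main__':
        lb_to_files = {'lb1': ['a', 'b'], 'lb2': ['c']}  # hypothetical input
        results = []
        with concurrent.futures.ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
            # fan out: one submitted job per library
            futures = [executor.submit(process_library, lb, files)
                       for lb, files in lb_to_files.items()]
            # collect: futures are yielded in completion order, not submission order
            for future in concurrent.futures.as_completed(futures):
                try:
                    results.append(future.result())  # re-raises any worker exception
                except Exception as exc:
                    print('worker failed: %s' % exc)
        print(results)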