Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding arbitrary parameters to blastn_chunked_fasta function #46

Merged
merged 7 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 52 additions & 18 deletions classify/blast.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@
import util.misc

TOOL_NAME = "blastn"

# Configure logging to write to blast_py.log and to the console
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("blast_py.log"),
golu099 marked this conversation as resolved.
Show resolved Hide resolved
logging.StreamHandler()
]
)
_log = logging.getLogger(__name__)

class BlastTools(tools.Tool):
Expand Down Expand Up @@ -38,44 +46,70 @@ class BlastnTool(BlastTools):
""" Tool wrapper for blastn """
subtool_name = 'blastn'

def get_hits_pipe(self, inPipe, db, threads=None, task=None, outfmt='6', max_target_seqs=1, output_type="read_id"):
    """Run blastn on the sequences read from ``inPipe`` and yield its hits.

    Args:
        inPipe: open file object / pipe handed to blastn as its stdin.
        db: value passed to blastn's ``-db`` option.
        threads: requested thread count; normalized by
            ``util.misc.sanitize_thread_count``.
        task: value for blastn's ``-task`` option; ``'blastn'`` is used
            when this is falsy.
        outfmt: value for blastn's ``-outfmt`` option.
        max_target_seqs: value for blastn's ``-max_target_seqs`` option.
        output_type: ``'read_id'`` yields only the first tab-separated
            field of each output line, skipping consecutive duplicates;
            ``'full_line'`` yields every output line unchanged. Any other
            value is logged as a warning and treated as ``'read_id'``.

    Yields:
        str: one read ID or one full output line, per ``output_type``.

    Raises:
        subprocess.CalledProcessError: if blastn exits with a non-zero
            status. Nothing is yielded in that case.
    """
    _log.debug(f"Executing get_hits_pipe function. Called with outfmt: {outfmt}")
    # toggle between extracting read IDs only or full blast output (all lines)
    if output_type not in ('read_id', 'full_line'):
        _log.warning(f"Invalid output_type '{output_type}' specified. Defaulting to 'read_id'.")
        output_type = 'read_id'

    threads = util.misc.sanitize_thread_count(threads)
    cmd = [self.install_and_get_path(),
           '-db', db,
           '-word_size', 16,
           '-num_threads', threads,
           '-evalue', '1e-6',
           '-outfmt', outfmt,
           '-max_target_seqs', max_target_seqs,
           '-task', task if task else 'blastn',
           ]
    cmd = [str(x) for x in cmd]
    # Log BLAST command executed
    _log.debug('Running blastn command: {}'.format(' '.join(cmd)))
    _log.debug('| ' + ' '.join(cmd) + ' |')

    # stderr is deliberately not piped: blastn's own diagnostics go straight
    # to this process's stderr, so there is no captured error text to log.
    blast_pipe = subprocess.Popen(cmd, stdin=inPipe, stdout=subprocess.PIPE)
    output, _ = blast_pipe.communicate()

    # Single failure check: communicate() has already waited for the process,
    # so returncode is final here.
    if blast_pipe.returncode != 0:
        _log.error("Blastn process failed with exit code: %s", blast_pipe.returncode)
        raise subprocess.CalledProcessError(blast_pipe.returncode, cmd)

    last_read_id = None
    for line in output.decode('UTF-8').splitlines():
        if output_type == 'read_id':
            # Split line by tab, and take the first element
            read_id = line.split('\t')[0]
            # Only emit if it is not a duplicate of the previous read ID
            if read_id != last_read_id:
                last_read_id = read_id
                yield read_id
        else:
            # 'full_line': yield the line as blastn printed it
            yield line

    _log.info("Blastn process completed succesfully.")

def get_hits_bam(self, inBam, db, threads=None, task=None, outfmt='6', max_target_seqs=1, output_type='read_id'):
    """Run blastn on the reads of a BAM file and return the hit generator.

    The BAM is converted to a FASTA pipe with
    ``tools.samtools.SamtoolsTool().bam2fa_pipe`` and handed to
    ``get_hits_pipe``. The remaining arguments are forwarded unchanged, so
    this method accepts the same blastn options as ``get_hits_fasta``.
    The defaults match those of ``get_hits_pipe``, so existing callers that
    pass only ``inBam``, ``db`` and ``threads`` behave as before.
    """
    return self.get_hits_pipe(
        tools.samtools.SamtoolsTool().bam2fa_pipe(inBam),
        db,
        threads=threads,
        task=task,
        outfmt=outfmt,
        max_target_seqs=max_target_seqs,
        output_type=output_type)

def get_hits_fasta(self, inFasta, db, threads=None, task=None, outfmt='6', max_target_seqs=1, output_type='read_id'):
    """Run blastn on a FASTA file and yield the hits from ``get_hits_pipe``.

    Args:
        inFasta: path of the FASTA file; opened in text mode and used as
            blastn's stdin.
        db, threads, task, outfmt, max_target_seqs, output_type: forwarded
            unchanged to ``get_hits_pipe``.

    Yields:
        Whatever ``get_hits_pipe`` yields for the given ``output_type``.
    """
    _log.debug(f"Executing get_hits_fasta function. Called with outfmt: {outfmt}")
    with open(inFasta, 'rt') as inf:
        # Forward the caller's task; it was previously hard-coded to None,
        # which silently discarded the requested search type.
        yield from self.get_hits_pipe(
            inf, db,
            threads=threads,
            task=task,
            outfmt=outfmt,
            max_target_seqs=max_target_seqs,
            output_type=output_type)


Expand Down
75 changes: 63 additions & 12 deletions taxon_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,17 @@
import classify.bmtagger
import read_utils

log = logging.getLogger(__name__)


# Configure logging (file and console) to help diagnose issues and track time spent
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("task.log"), # Log file name
golu099 marked this conversation as resolved.
Show resolved Hide resolved
logging.StreamHandler() # Keeps the console output if desired
]
)
log = logging.getLogger(__name__)
# =======================
# *** deplete ***
# =======================
Expand Down Expand Up @@ -397,15 +405,22 @@ def multi_db_deplete_bam(inBam, refDbs, deplete_method, outBam, **kwargs):
# ========================


def _run_blastn_chunk(db, input_fasta, out_hits, blast_threads, task=None, outfmt='6', max_target_seqs=1, output_type='read_id'):
    """ run blastn on the input fasta file. this is intended to be run in parallel
        by blastn_chunked_fasta

        Every hit yielded by BlastnTool.get_hits_fasta is written to
        ``out_hits`` (opened via util.file.open_or_gzopen), one per line.
        ``task``, ``outfmt``, ``max_target_seqs`` and ``output_type`` are
        forwarded to get_hits_fasta. Any exception is logged and re-raised.
    """
    try:
        with util.file.open_or_gzopen(out_hits, 'wt') as outf:
            hits = classify.blast.BlastnTool().get_hits_fasta(
                input_fasta, db,
                threads=blast_threads,
                task=task,
                outfmt=outfmt,
                # previously accepted by this function but never passed on
                max_target_seqs=max_target_seqs,
                output_type=output_type)
            for line in hits:
                outf.write(line + '\n')
        log.info("_run_blastn_chunk completed succesfully.")
    except Exception as e:
        log.error("An error occurred in _run_blastn_chunk.:%s", e)
        # bare raise keeps the original traceback intact
        raise

def blastn_chunked_fasta(fasta, db, out_hits, chunkSize=1000000, threads=None, task=None, outfmt='6', max_target_seqs=1, output_type='read_id'):
"""
Helper function: blastn a fasta file, overcoming apparent memory leaks on
an input with many query sequences, by splitting it into multiple chunks
Expand All @@ -415,6 +430,9 @@ def blastn_chunked_fasta(fasta, db, out_hits, chunkSize=1000000, threads=None):
# the lower bound of how small a fasta chunk can be.
# too small and the overhead of spawning a new blast process
# will be detrimental relative to actual computation time

#checks if the blastn_chunked_fasta function is being called
log.info("Calling blastn_chunked_fasta function...")
MIN_CHUNK_SIZE = 20000

# just in case blast is not installed, install it once, not many times in parallel!
Expand Down Expand Up @@ -455,7 +473,7 @@ def blastn_chunked_fasta(fasta, db, out_hits, chunkSize=1000000, threads=None):
log.debug("blastn chunk size %s" % chunkSize)
log.debug("number of chunks to create %s" % (number_of_reads / chunkSize))
log.debug("blastn parallel instances %s" % threads)

log.debug(f"outfmt value: {outfmt}")
# chunk the input file. This is a sequential operation
input_fastas = []
with open(fasta, "rt") as fastaFile:
Expand All @@ -464,8 +482,10 @@ def blastn_chunked_fasta(fasta, db, out_hits, chunkSize=1000000, threads=None):
chunk_fasta = mkstempfname('.fasta')

with open(chunk_fasta, "wt") as handle:
SeqIO.write(batch, handle, "fasta")
count= SeqIO.write(batch, handle, "fasta")
batch = None
#detail chunk sizes being processed
log.info(f"Created chunk {chunk_fasta} with {count} records")
input_fastas.append(chunk_fasta)

num_chunks = len(input_fastas)
Expand All @@ -481,8 +501,7 @@ def blastn_chunked_fasta(fasta, db, out_hits, chunkSize=1000000, threads=None):
cpus_leftover = threads - num_chunks
blast_threads = 2*max(1, int(cpus_leftover / num_chunks))
for i in range(num_chunks):
executor.submit(
_run_blastn_chunk, db, input_fastas[i], hits_files[i], blast_threads)
executor.submit(_run_blastn_chunk, db, input_fastas[i], hits_files[i], blast_threads, task=task, outfmt=outfmt, max_target_seqs=max_target_seqs, output_type=output_type)

# merge results and clean up
util.file.cat(out_hits, hits_files)
Expand Down Expand Up @@ -515,6 +534,38 @@ def deplete_blastn_bam(inBam, db, outBam, threads=None, chunkSize=1000000, JVMme
tools.picard.FilterSamReadsTool().execute(inBam, True, blast_hits, outBam, JVMmemory=JVMmemory)
os.unlink(blast_hits)

def chunk_blast_hits(inFasta, db, blast_hits_output, threads=None, chunkSize=1000000, task=None, outfmt='6', max_target_seqs=1, output_type='read_id'):
    'Process BLAST hits from a FASTA file by dividing the file into smaller chunks for parallel processing (blastn_chunked_fasta).'
    # NOTE(review): the docstring text above is kept verbatim; it appears to be
    # reused as the command description by util.cmd.attach_main -- confirm.

    if not chunkSize:
        # Chunking disabled: run a single multithreaded blastn over the whole
        # FASTA and write its hits directly.
        ids_only = output_type == 'read_id'
        with open(blast_hits_output, 'wt') as outf:
            hits = classify.blast.BlastnTool().get_hits_fasta(
                inFasta, db, threads,
                task=task,
                outfmt=outfmt,
                max_target_seqs=max_target_seqs,
                output_type=output_type)
            for hit in hits:
                # In read-ID mode keep only the first tab-separated column;
                # otherwise write the BLAST output line as-is.
                record = hit.split('\t')[0] if ids_only else hit
                outf.write(record + '\n')
        return

    log.info("Running BLASTN on %s against database %s", inFasta, db)
    # Split the FASTA and blast the chunks in parallel.
    blastn_chunked_fasta(
        inFasta, db, blast_hits_output,
        chunkSize, threads, task, outfmt, max_target_seqs,
        output_type=output_type)

def parser_chunk_blast_hits(parser=None):
    """Populate and return the argument parser for ``chunk_blast_hits``.

    Args:
        parser: an existing ``argparse.ArgumentParser`` (for example a
            sub-command parser) to add the arguments to. When omitted, a
            fresh parser is created. A supplied parser used to be discarded
            and replaced, which left it without any arguments.

    Returns:
        The populated parser, with ``chunk_blast_hits`` attached as its
        main function.
    """
    if parser is None:
        parser = argparse.ArgumentParser(description="Run BLASTN on chunks of a FASTA file.")
    # Positional name matches the chunk_blast_hits parameter it feeds.
    parser.add_argument('inFasta', help='Input FASTA file.')
    parser.add_argument('db', help='BLASTN database.')
    parser.add_argument('blast_hits_output', help='Stores hits found by BLASTN.')
    parser.add_argument("--chunkSize", type=int, default=1000000, help='FASTA chunk size (default: %(default)s)')
    parser.add_argument("-task", help="details the type of search (i.e. megablast,blastn,etc)")
    # argparse does not run `type` on non-string defaults, so the default is
    # given as a string to match the '6' default used by chunk_blast_hits.
    parser.add_argument("-outfmt", type=str, default='6', help="Custom output formats(default: %(default)s)")
    parser.add_argument("-max_target_seqs", type=int, default=1, help="BLAST will return the first (if set to default) database hits for a sequence query. (default: %(default)s)")
    parser.add_argument("--output_type", default="read_id", choices=["read_id", "full_line"], help="Specify the type of output: 'read_id' for read IDs only, or 'full_line' for full BLAST output lines. Default is 'read_id'. Useful when adding taxonomy IDs to outfmt type 6.")
    util.cmd.common_args(parser, (('threads', None), ('loglevel', None), ('version', None), ('tmp_dir', None)))
    # NOTE(review): split_args=True is expected to make util.cmd unpack the
    # parsed namespace into chunk_blast_hits' keyword arguments; without it
    # the function would receive the namespace object as `inFasta`. Confirm
    # against util.cmd.attach_main.
    util.cmd.attach_main(parser, chunk_blast_hits, split_args=True)
    return parser

def parser_deplete_blastn_bam(parser=argparse.ArgumentParser()):
parser.add_argument('inBam', help='Input BAM file.')
Expand Down
Loading