Skip to content

Commit

Permalink
Refactor subset_reads, switch to skbio
Browse files Browse the repository at this point in the history
  • Loading branch information
polyatail committed Jan 24, 2019
1 parent 76dad8f commit d78623a
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 112 deletions.
261 changes: 151 additions & 110 deletions onecodex/scripts/subset_reads.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,10 @@
import os

from onecodex.auth import login_required
from onecodex.exceptions import OneCodexException
from onecodex.exceptions import OneCodexException, ValidationError
from onecodex.utils import download_file_helper, get_download_dest, pretty_errors


def with_progress_bar(length, ix, *args, **kwargs):
    """Run ``ix`` while displaying a click progress bar of ``length`` steps.

    ``ix`` is invoked with a ``cb`` keyword bound to the bar's ``update``
    method; an optional ``label`` keyword is consumed here and shown on the
    bar rather than forwarded to ``ix``. Returns whatever ``ix`` returns.
    """
    caption = kwargs.pop('label', None)
    progress_opts = {'length': length}
    if caption:
        progress_opts['label'] = caption
    with click.progressbar(**progress_opts) as progress:
        kwargs['cb'] = progress.update
        return ix(*args, **kwargs)


def filter_rows_by_taxid(results, tax_ids, cb=None):
    """Return the indices of rows in ``results`` whose 'Tax ID' is in ``tax_ids``.

    Parameters
    ----------
    results : iterable of dict
        Row mappings, each with a 'Tax ID' key (e.g. from csv.DictReader).
    tax_ids : iterable of str
        Tax IDs to match against each row's 'Tax ID' value.
    cb : callable, optional
        Progress callback; called with the number of rows processed since the
        previous call (batches of 1000, plus the final partial batch), so the
        total reported equals the number of rows seen.

    Returns
    -------
    list of int
        0-based indices of matching rows, in input order.
    """
    tax_ids = set(tax_ids)  # O(1) membership instead of scanning a list per row
    filtered_rows = []
    row_count = 0
    for i, row in enumerate(results):
        row_count = i + 1
        if row['Tax ID'] in tax_ids:
            filtered_rows.append(i)
        # report each *completed* batch of 1000 (the original fired at i == 0,
        # over-reporting by a full batch and never counting the tail)
        if cb and row_count % 1000 == 0:
            cb(1000)
    if cb and row_count % 1000:
        cb(row_count % 1000)  # final partial batch
    return filtered_rows


def get_record_count(fp):
    """Return the number of records remaining in the iterable ``fp``.

    Consumes the iterator (e.g. lines of an open file handle or records
    from a generator); callers that need to re-read must seek/reopen.
    """
    return sum(1 for _ in fp)


def get_filtered_filename(filepath):
filename = os.path.split(filepath)[-1]
prefix, ext = os.path.splitext(filename.rstrip('.gz')
Expand Down Expand Up @@ -79,6 +51,7 @@ def recurse_taxonomy_map(tax_id_map, tax_id, parent=False):
"""

if parent:
# TODO: allow filtering on tax_id and its parents, too
pass
else:
def _child_recurse(tax_id, visited):
Expand All @@ -97,45 +70,59 @@ def _child_recurse(tax_id, visited):
return list(set(_child_recurse(tax_id, [])))


@click.command('filter_reads', help='Filter a FASTX file based on the taxonomic results from a CLASSIFICATION_ID')
@click.command('subset_reads',
help='Subset a FASTX file based on the taxonomic results from a CLASSIFICATION_ID. '
'By default, reads (or pairs or reads) matching the given taxonomic ID are '
'written to the path provided, ignoring low confidence taxonomic assignments. ')
@click.argument('classification_id')
@click.argument('fastx', type=click.Path())
@click.option('-t', '--tax-id', required=True, multiple=True,
help='Filter reads mapping to tax IDs. May be passed multiple times.')
@click.option('-t', '--tax-id', 'tax_ids', required=True, multiple=True,
help='Subset reads mapping to tax IDs. May be passed multiple times.')
@click.option('-r', '--reverse', type=click.Path(),
help='The reverse (R2) read file, optionally.')
@click.option('--with-children', default=False, is_flag=True,
help='Keep reads of child taxa, too. For example, all strains of E. coli')
@click.option('-r', '--reverse', type=click.Path(), help='The reverse (R2) '
'read file, optionally')
@click.option('--split-pairs/--keep-pairs', default=False, help='Keep only '
'the read pair member that matches the list of tax ID\'s')
help='Match child taxa of those given with -t (e.g., all strains of E. coli)')
@click.option('--split-pairs', default=False, is_flag=True,
help='By default, if either read in a pair matches, both will match. Choose this '
'option to consider each paired-end read separately. Resulting files may *not* '
'have the same number of reads!')
@click.option('--exclude-reads', default=False, is_flag=True,
help='Rather than keep reads matching reads, choosing this option will exclude them.')
@click.option('-o', '--out', default='.', type=click.Path(), help='Where '
'to put the filtered outputs')
help='By default, matching reads are kept. Choose this option to instead output '
'reads that do *not* match.')
@click.option('--include-lowconf', default=False, is_flag=True,
help='By default, reads with low confidence taxonomic assignments are ignored. '
'Choose this option to include them.')
@click.option('-o', '--out', default='.', type=click.Path(),
help='Where to save the filtered outputs')
@click.pass_context
@pretty_errors
@login_required
def cli(ctx, classification_id, fastx, reverse, tax_id, with_children,
split_pairs, exclude_reads, out):
def cli(ctx, classification_id, fastx, reverse, tax_ids, with_children,
split_pairs, exclude_reads, include_lowconf, out):
import skbio

if not len(tax_id):
if not len(tax_ids):
raise OneCodexException('You must supply at least one tax ID')

# fetch classification result object from API
classification = ctx.obj['API'].Classifications.get(classification_id)
if classification is None:
raise OneCodexException('Classification {} not found.'.format(classification_id))
raise ValidationError('Classification {} not found.'.format(classification_id))

# if with children, expand tax_ids by referring to the taxonomic tree
if with_children:
tax_id_map = make_taxonomy_dict(classification)

new_tax_ids = []

for t_id in tax_id:
for t_id in tax_ids:
new_tax_ids.extend(recurse_taxonomy_map(tax_id_map, t_id))

tax_id = list(set(new_tax_ids))
tax_ids = new_tax_ids

tax_ids = set(tax_ids)

# pull the classification result TSV
tsv_url = classification.readlevel()['url']
readlevel_path = get_download_dest('./', tsv_url)
if not os.path.exists(readlevel_path):
Expand All @@ -144,28 +131,20 @@ def cli(ctx, classification_id, fastx, reverse, tax_id, with_children,
click.echo('Using cached read-level results: {}'
.format(readlevel_path), err=True)

filtered_rows = []
tsv_row_count = 0
# count the number of rows in the TSV file
with gzip.open(readlevel_path, 'rt') as tsv:
try:
tsv_row_count = get_record_count(tsv) - 1 # discount header line
tsv_row_count = 0
for _ in tsv:
tsv_row_count += 1
tsv_row_count -= 1 # discount header line
except EOFError:
click.echo('\nWe encountered an error while processing the read '
'level results. Please delete {} and try again.'
.format(readlevel_path), err=True)
raise
else:
tsv.seek(0)
reader = csv.DictReader(tsv, delimiter='\t')
click.echo('Selecting results matching tax ID(s): {}'
.format(', '.join(tax_id)), err=True)
filtered_rows = with_progress_bar(
tsv_row_count,
filter_rows_by_taxid,
reader,
tax_id
)

# determine the name of the output file(s)
filtered_filename, ext = get_filtered_filename(fastx)
filtered_filename = os.path.join(out, filtered_filename)
if reverse:
Expand All @@ -181,61 +160,123 @@ def cli(ctx, classification_id, fastx, reverse, tax_id, with_children,
'{}: extension must be one of .fa, .fna, .fasta, .fq, .fastq'.format(fastx)
)

fastx_record_count = get_record_count(skbio.io.read(fastx, **io_kwargs))

if reverse:
fastx_record_count = fastx_record_count * 2

if tsv_row_count != fastx_record_count:
os.remove(readlevel_path)
raise OneCodexException('The supplied file has a different number of '
'records than the requested Classification')

save_msg = 'Saving filtered reads: {}'.format(filtered_filename)
# do the actual filtering
save_msg = 'Saving subsetted reads: {}'.format(filtered_filename)
if reverse:
save_msg += ' and {}'.format(rev_filtered_filename)
click.echo(save_msg, err=True)

# skbio doesn't support built-in open() method in python2. must use io.open()
counter = 0
if reverse:
with io.open(filtered_filename, 'w') as out_file, \
io.open(rev_filtered_filename, 'w') as rev_out_file: # noqa
if split_pairs:
for fwd, rev in zip(skbio.io.read(fastx, **io_kwargs),
skbio.io.read(reverse, **io_kwargs)):
if exclude_reads:
if counter not in filtered_rows:
fwd.write(out_file, **io_kwargs)
if (counter + 1) not in filtered_rows:
rev.write(rev_out_file, **io_kwargs)
with click.progressbar(length=tsv_row_count) as bar, gzip.open(readlevel_path, 'rt') as tsv:
reader = csv.DictReader(tsv, delimiter='\t')

if reverse:
with io.open(filtered_filename, 'w') as out_file, \
io.open(rev_filtered_filename, 'w') as rev_out_file: # noqa
if split_pairs:
if include_lowconf:
if exclude_reads:
for fwd, rev, row in zip(skbio.io.read(fastx, **io_kwargs),
skbio.io.read(reverse, **io_kwargs),
reader):
if row['Tax ID'] not in tax_ids:
fwd.write(out_file, **io_kwargs)
row2 = next(reader)
if row2['Tax ID'] not in tax_ids:
rev.write(rev_out_file, **io_kwargs)
bar.update(2)
else:
for fwd, rev, row in zip(skbio.io.read(fastx, **io_kwargs),
skbio.io.read(reverse, **io_kwargs),
reader):
if row['Tax ID'] in tax_ids:
fwd.write(out_file, **io_kwargs)
row2 = next(reader)
if row2['Tax ID'] in tax_ids:
rev.write(rev_out_file, **io_kwargs)
bar.update(2)
else:
if exclude_reads:
for fwd, rev, row in zip(skbio.io.read(fastx, **io_kwargs),
skbio.io.read(reverse, **io_kwargs),
reader):
if row['Passed Filter'] == 'T' and row['Tax ID'] not in tax_ids:
fwd.write(out_file, **io_kwargs)
row2 = next(reader)
if row2['Passed Filter'] == 'T' and row2['Tax ID'] not in tax_ids:
rev.write(rev_out_file, **io_kwargs)
bar.update(2)
else:
for fwd, rev, row in zip(skbio.io.read(fastx, **io_kwargs),
skbio.io.read(reverse, **io_kwargs),
reader):
if row['Passed Filter'] == 'T' and row['Tax ID'] in tax_ids:
fwd.write(out_file, **io_kwargs)
row2 = next(reader)
if row2['Passed Filter'] == 'T' and row2['Tax ID'] in tax_ids:
rev.write(rev_out_file, **io_kwargs)
bar.update(2)
else:
if include_lowconf:
if exclude_reads:
for fwd, rev, row in zip(skbio.io.read(fastx, **io_kwargs),
skbio.io.read(reverse, **io_kwargs),
reader):
row2 = next(reader)
if row['Tax ID'] not in tax_ids or row2['Tax ID'] not in tax_ids:
fwd.write(out_file, **io_kwargs)
rev.write(rev_out_file, **io_kwargs)
bar.update(2)
else:
for fwd, rev, row in zip(skbio.io.read(fastx, **io_kwargs),
skbio.io.read(reverse, **io_kwargs),
reader):
row2 = next(reader)
if row['Tax ID'] in tax_ids or row2['Tax ID'] in tax_ids:
fwd.write(out_file, **io_kwargs)
rev.write(rev_out_file, **io_kwargs)
bar.update(2)
else:
if counter in filtered_rows:
fwd.write(out_file, **io_kwargs)
if (counter + 1) in filtered_rows:
rev.write(rev_out_file, **io_kwargs)
counter += 2
else:
for fwd, rev in zip(skbio.io.read(fastx, **io_kwargs),
skbio.io.read(reverse, **io_kwargs)):
if exclude_reads:
for fwd, rev, row in zip(skbio.io.read(fastx, **io_kwargs),
skbio.io.read(reverse, **io_kwargs),
reader):
row2 = next(reader)
if (row['Passed Filter'] == 'T' and row['Tax ID'] not in tax_ids) or \
(row2['Passed Filter'] == 'T' and row2['Tax ID'] not in tax_ids):
fwd.write(out_file, **io_kwargs)
rev.write(rev_out_file, **io_kwargs)
bar.update(2)
else:
for fwd, rev, row in zip(skbio.io.read(fastx, **io_kwargs),
skbio.io.read(reverse, **io_kwargs),
reader):
row2 = next(reader)
if (row['Passed Filter'] == 'T' and row['Tax ID'] in tax_ids) or \
(row2['Passed Filter'] == 'T' and row2['Tax ID'] in tax_ids):
fwd.write(out_file, **io_kwargs)
rev.write(rev_out_file, **io_kwargs)
bar.update(2)
else:
with io.open(filtered_filename, 'w') as out_file:
if include_lowconf:
if exclude_reads:
if counter not in filtered_rows and \
(counter + 1) not in filtered_rows:
fwd.write(out_file, **io_kwargs)
rev.write(rev_out_file, **io_kwargs)
for fwd, row in zip(skbio.io.read(fastx, **io_kwargs), reader):
if row['Tax ID'] not in tax_ids:
fwd.write(out_file, **io_kwargs)
bar.update(1)
else:
if counter in filtered_rows or \
(counter + 1) in filtered_rows:
fwd.write(out_file, **io_kwargs)
rev.write(rev_out_file, **io_kwargs)
counter += 2
else:
with io.open(filtered_filename, 'w') as out_file:
for seq in skbio.io.read(fastx, **io_kwargs):
if exclude_reads:
if counter not in filtered_rows:
seq.write(out_file, **io_kwargs)
for fwd, row in zip(skbio.io.read(fastx, **io_kwargs), reader):
if row['Tax ID'] in tax_ids:
fwd.write(out_file, **io_kwargs)
bar.update(1)
else:
if counter in filtered_rows:
seq.write(out_file, **io_kwargs)
counter += 1
if exclude_reads:
for fwd, row in zip(skbio.io.read(fastx, **io_kwargs), reader):
if row['Passed Filter'] == 'T' and row['Tax ID'] not in tax_ids:
fwd.write(out_file, **io_kwargs)
bar.update(1)
else:
for fwd, row in zip(skbio.io.read(fastx, **io_kwargs), reader):
if row['Passed Filter'] == 'T' and row['Tax ID'] in tax_ids:
fwd.write(out_file, **io_kwargs)
bar.update(1)
5 changes: 3 additions & 2 deletions tests/test_scripts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# flake8: noqa
import hashlib
import os
import pytest; pytest.importorskip('skbio') # noqa
import pytest; pytest.importorskip('skbio')
import shutil

from onecodex import Cli
Expand Down Expand Up @@ -97,7 +98,7 @@ def md5sum(filepath):
if include_lowconf:
args += ['--include-lowconf']

result = runner.invoke(Cli, args)
result = runner.invoke(Cli, args, catch_exceptions=False)
assert 'Using cached read-level results' in result.output

results_digests = []
Expand Down

0 comments on commit d78623a

Please sign in to comment.