Commit

Merge ec12960 into f363cb3
haydenm committed Sep 7, 2022
2 parents f363cb3 + ec12960 commit 04f2a9d
Showing 14 changed files with 1,373 additions and 433 deletions.
167 changes: 126 additions & 41 deletions README.md

Large diffs are not rendered by default.

171 changes: 143 additions & 28 deletions bin/design.py

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions bin/design_large.py
@@ -0,0 +1,21 @@
#!/usr/bin/env python3
"""Design probes for genome capture, with options and parameters that
optimize resource usage for large, highly diverse input.
The downside of these options is usually a small increase in the number of
designed probes. As with design.py, this program still supports full
customization of the argument values.
This wraps design.py and offers a way to run design.py without requiring
deep familiarity with CATCH's options. That is, it takes into account
recommendations that often work well in practice.
"""

import design

__author__ = 'Hayden Metsky <hayden@broadinstitute.org>'


if __name__ == "__main__":
    args = design.init_and_parse_args(args_type='large')
    design.main(args)
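
The design.py diff above is not rendered, so the behavior behind `init_and_parse_args(args_type='large')` is not shown here. The sketch below is a hypothetical illustration of that kind of entry point; the option names and defaults are placeholders, not CATCH's actual arguments.

```python
# Hypothetical sketch of an init_and_parse_args(args_type=...) entry point
# of the kind design_large.py relies on. Option names and defaults are
# illustrative placeholders, not CATCH's actual arguments.
import argparse

def init_and_parse_args(args_type='standard'):
    parser = argparse.ArgumentParser(
            description="Design probes for genome capture")
    parser.add_argument('dataset', nargs='+',
            help="input genomes to design probes against")
    parser.add_argument('--max-num-processes', type=int, default=None,
            help="cap on processes when parallelizing over groupings")
    if args_type == 'large':
        # Shift defaults toward lower memory and runtime for large, highly
        # diverse input, accepting a small increase in probe count
        parser.set_defaults(max_num_processes=8)
    return parser.parse_args()
```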
144 changes: 135 additions & 9 deletions catch/filter/base_filter.py
@@ -2,10 +2,36 @@
"""

import inspect
import multiprocessing

__author__ = 'Hayden Metsky <hayden@mit.edu>'


def set_max_num_processes_for_filter_over_groupings(max_num_processes=8):
"""Set the maximum number of processes to use for parallelizing calls
to _filter() across groupings.
Note that parallelization defined in this module does not always occur.
See the `num_processes` arg in BaseFilter.filter() for when it does
apply.
Args:
max_num_processes: an int (>= 1) specifying the maximum number of
processes to use in a multiprocessing.Pool when parallelizing
over groupings, i.e., the maximum number of target groupings
to filter in parallel; it uses min(the number of CPUs
in the system, max_num_processes) processes
"""
global _fg_max_num_processes
_fg_max_num_processes = max_num_processes
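# Initialize the default cap when the module is loaded; callers can override
# it later by calling the function above with a different max_num_processes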
set_max_num_processes_for_filter_over_groupings()

# Define filter function to use in multiprocessing Pool; this must be
# top-level in the module, and we will ensure only one can be set at a time
global _global_filter_fn
_global_filter_fn = None


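As a quick illustration of the cap defined above, a caller could limit cross-grouping parallelism before running any filters. The snippet below is an editorial sketch, not part of the commit; the import path follows this file's location (catch/filter/base_filter.py).

```python
import multiprocessing

from catch.filter import base_filter

# Cap the Pool at 4 worker processes, regardless of how many CPUs exist
base_filter.set_max_num_processes_for_filter_over_groupings(max_num_processes=4)

# When BaseFilter.filter() later parallelizes over groupings and no
# num_processes is given, it uses min(multiprocessing.cpu_count(), 4)
print(min(multiprocessing.cpu_count(), 4))
```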
class BaseFilter:
"""Abstract class representing a filter for processing candidate probes.
@@ -18,31 +44,131 @@ class BaseFilter:
    versions of the input or there may even be more output probes than
    input probes.

    For information about parallelization over groupings, see the
    `num_processes` argument below.

    All subclasses must implement a _filter(..) method that returns a
    list of probes after processing from the given input list.
    """

def filter(self, input, target_genomes=None):
def filter(self, input, target_genomes=None, input_is_grouped=False,
num_processes=None):
"""Perform the filtering.
Args:
input: list of candidate probes
target_genomes: list [g_1, g_2, g_m] of m groupings of genomes,
input: candidate probes from which to filter; see input_is_grouped
for details
target_genomes: list [g_1, g_2, ..., g_m] of m groupings of genomes,
where each g_i is a list of genome.Genomes belonging to group
i, that should be targeted by the probes; for example a
group may be a species and each g_i would be a list of the
target genomes of species i
input_is_grouped: if True, input is list [p_1, p_2, ..., p_m] of
                m groupings of candidate probes (one per grouping of
                genomes), where each p_i is a list of candidate probes for
                group i; if False, input is a single list of candidate
                probes (ungrouped)
num_processes: number of processes to use when parallelizing over
groupings; if None, this determines a number based on the
maximum specified and the number of CPUs. Note that
parallelization only happens when input_is_grouped is True
*and* self.requires_probe_groupings is not set or is False; if
that parameter is True and input_is_grouped is True, then
all groupings are passed to the subclass's filter and it
is up to self._filter() to parallelize over groupings
Returns:
list of probes after applying a filter to the input
if input_is_grouped is True:
                list [q_1, q_2, ..., q_m] where each q_i is a list of probes
                after applying a filter to the corresponding input grouping
else:
list of probes after applying a filter to the input
"""
_filter_params = inspect.signature(self._filter).parameters
if len(_filter_params) == 2:
# _filter() should accept both probes and target genomes
return self._filter(input, target_genomes)

# Determine whether self._filter() requires probes being
# split into groupings, or whether each group must be passed
# separately
if (hasattr(self, 'requires_probe_groupings') and
self.requires_probe_groupings is True):
pass_groupings = True
else:
pass_groupings = False

if pass_groupings:
# Input must already be grouped
assert input_is_grouped is True

if len(_filter_params) == 2:
# self._filter() should accept both probes and target genomes
return self._filter(input, target_genomes)
else:
# self._filter() may not need target genomes, and does not
# accept it
return self._filter(input)
else:
# _filter() may not need target genomes, and does not accept it
return self._filter(input)
if input_is_grouped:
# Call _filter() separately for each group, and parallelize
# calls across groupings

global _fg_max_num_processes
if num_processes is None:
num_processes = min(multiprocessing.cpu_count(),
_fg_max_num_processes)
pool = multiprocessing.Pool(num_processes)

# Order groupings in descending order
# by the number of possible probes (input size) in the group.
# The number is an indication of how long the grouping may
# take to filter, and we want to start the slower groupings
# first in the pool
input_lens = list(enumerate([len(x) for x in input]))
input_idx_ordered = [x[0] for x in sorted(input_lens,
key=lambda y: y[1], reverse=True)]
input_idx_revert = {y: x for x, y in
enumerate(input_idx_ordered)}
# Note that the reordered input is:
# [input[i] for i in input_idx_ordered]

# The function called by a multiprocessing Pool must be
# top-level
global _global_filter_fn
if _global_filter_fn is not None:
raise Exception(("Only one filter() function can be "
"called in parallel at a time"))
_global_filter_fn = self._filter

# Construct args to _filter()
if len(_filter_params) == 2:
# self._filter() should accept both probes and target genomes
pool_args = [(input[i], target_genomes)
for i in input_idx_ordered]
else:
# self._filter() may not need target genomes, and does not
# accept it
pool_args = [tuple([input[i]])
for i in input_idx_ordered]

# Run the pool, giving 1 grouping (chunksize=1) at a time
pool_out = pool.starmap(_global_filter_fn, pool_args,
chunksize=1)
pool.close()
_global_filter_fn = None

# Revert the order of the output to go back to the original
# ordering of the input
pool_out_reordered = [pool_out[input_idx_revert[i]]
for i in range(len(pool_out))]
return pool_out_reordered
else:
# Input is not grouped and there is no need to pass it grouped

if len(_filter_params) == 2:
# self._filter() should accept both probes and target genomes
return self._filter(input, target_genomes)
else:
# self._filter() may not need target genomes, and does not
# accept it
return self._filter(input)

def _filter(self, input):
raise Exception(("A subclass of BaseFilter must implement "
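
To see how the grouped-input path behaves end to end, the following editorial sketch defines a toy subclass; the class, the string "probes", and the even-length rule are hypothetical (real filters operate on probe.Probe objects), but the call into BaseFilter.filter() matches the signature added in this commit.

```python
from catch.filter.base_filter import BaseFilter

class KeepEvenLengthFilter(BaseFilter):
    # No requires_probe_groupings attribute, so BaseFilter.filter() calls
    # _filter() once per grouping, parallelized across groupings
    def _filter(self, input):
        return [p for p in input if len(p) % 2 == 0]

if __name__ == '__main__':
    # Two groupings of toy "probes" (plain strings stand in for probe.Probe)
    grouped = [["ACGT", "ACGTA"], ["GGGG", "TTT", "CCCCCC"]]
    f = KeepEvenLengthFilter()
    out = f.filter(grouped, input_is_grouped=True, num_processes=2)
    # Output order matches the input grouping order, even though larger
    # groupings are scheduled first in the pool
    print(out)  # [['ACGT'], ['GGGG', 'CCCCCC']]
```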
33 changes: 21 additions & 12 deletions catch/filter/near_duplicate_filter.py
@@ -29,7 +29,7 @@ class NearDuplicateFilter(BaseFilter):
duplicate filter should *not* be run before this.
"""

def __init__(self, k, reporting_prob=0.95):
def __init__(self, k, reporting_prob=0.80):
"""
Args:
k: number of hash functions to draw from a family of
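
The default reporting_prob drops from 0.95 to 0.80 in this commit. As a rough illustration of why that saves work (this is the standard LSH OR-construction arithmetic, not necessarily the exact computation in CATCH's lsh module): if a pair of probes within the distance threshold collides in a single hash table with probability p, then b independent tables report the pair with probability 1 - (1 - p)^b, so fewer tables are needed to reach 0.80 than 0.95.

```python
import math

def num_tables_needed(p, reporting_prob):
    # Smallest b with 1 - (1 - p)**b >= reporting_prob
    return math.ceil(math.log(1.0 - reporting_prob) / math.log(1.0 - p))

# Example with an assumed per-table collision probability p = 0.2
for rp in (0.95, 0.80):
    print(rp, num_tables_needed(0.2, rp))
# 0.95 -> 14 tables, 0.80 -> 8 tables: lowering reporting_prob trades a
# little deduplication sensitivity for less time and memory
```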
@@ -103,6 +103,11 @@ def _filter(self, input):
return list(to_include)


# Keep top-level in module so it can be pickled
def hamming_dist(a, b):
# a and b are probe.Probe objects
return a.mismatches(b)

class NearDuplicateFilterWithHammingDistance(NearDuplicateFilter):
"""Filter that removes near-duplicates according to Hamming distance.
"""
@@ -123,9 +128,6 @@ def __init__(self, dist_thres, probe_length):
self.lsh_family = lsh.HammingDistanceFamily(probe_length)
self.dist_thres = dist_thres

def hamming_dist(a, b):
# a and b are probe.Probe objects
return a.mismatches(b)
self.dist_fn = hamming_dist

def _filter(self, input):
@@ -140,6 +142,20 @@ def _filter(self, input):
return NearDuplicateFilter._filter(self, input)


# Keep top-level in module so it can be pickled
# Since we cannot pickle nested functions, structure it as an object that
# can be called
class jaccard_dist_fn(object):
def __init__(self, kmer_size):
self.kmer_size = kmer_size
def __call__(self, a, b):
a_kmers = [a[i:(i + self.kmer_size)] for i in range(len(a) - self.kmer_size + 1)]
b_kmers = [b[i:(i + self.kmer_size)] for i in range(len(b) - self.kmer_size + 1)]
a_kmers = set(a_kmers)
b_kmers = set(b_kmers)
jaccard_sim = float(len(a_kmers & b_kmers)) / len(a_kmers | b_kmers)
return 1.0 - jaccard_sim
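
For a quick sanity check of the distance this callable computes, the editorial snippet below uses plain strings and a small kmer_size for readability (the filter's default kmer_size is 10).

```python
from catch.filter.near_duplicate_filter import jaccard_dist_fn

fn = jaccard_dist_fn(kmer_size=3)
a = "ACGTAC"   # 3-mers: {ACG, CGT, GTA, TAC}
b = "ACGTAG"   # 3-mers: {ACG, CGT, GTA, TAG}
# Intersection has 3 k-mers, union has 5, so Jaccard similarity is 3/5
print(fn(a, b))  # 1 - 3/5 = 0.4 (up to float rounding)
```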

class NearDuplicateFilterWithMinHash(NearDuplicateFilter):
"""Filter that removes near-duplicates using MinHash.
"""
@@ -160,14 +176,7 @@ def __init__(self, dist_thres, kmer_size=10):
use_fast_str_hash=True) # safe as long as not parallelized
self.dist_thres = dist_thres

def jaccard_dist(a, b):
a_kmers = [a[i:(i + kmer_size)] for i in range(len(a) - kmer_size + 1)]
b_kmers = [b[i:(i + kmer_size)] for i in range(len(b) - kmer_size + 1)]
a_kmers = set(a_kmers)
b_kmers = set(b_kmers)
jaccard_sim = float(len(a_kmers & b_kmers)) / len(a_kmers | b_kmers)
return 1.0 - jaccard_sim
self.dist_fn = jaccard_dist
self.dist_fn = jaccard_dist_fn(kmer_size)

def _filter(self, input):
"""Filter with LSH using MinHash family.
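
The "keep top-level in module so it can be pickled" comments above reflect a multiprocessing constraint: anything sent to worker processes must be picklable, and nested (locally defined) functions are not. The standalone sketch below, unrelated to CATCH's code, demonstrates the difference and why the distance functions were hoisted to module level, with per-filter state (such as kmer_size) carried on a callable object.

```python
import pickle

def top_level_fn(x):
    # Importable by qualified name, so pickle can serialize a reference to it
    return x + 1

class CallableObj(object):
    # A module-level class whose instances carry state (like kmer_size)
    def __init__(self, k):
        self.k = k
    def __call__(self, x):
        return x + self.k

def make_nested_fn():
    def nested_fn(x):
        return x + 1
    return nested_fn

pickle.dumps(top_level_fn)       # works
pickle.dumps(CallableObj(10))    # works; the instance state travels with it
try:
    pickle.dumps(make_nested_fn())
except (pickle.PicklingError, AttributeError):
    print("nested functions cannot be pickled")
```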
