diff --git a/README.md b/README.md index 414dc575a..6e9a8d78b 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ The main program to design probes is [`design.py`](./bin/design.py). To see details on all the arguments that the program accepts, run: ```bash -design.py -h +design.py --help ``` [`design.py`](./bin/design.py) requires one or more `dataset`s that specify input sequence data to target: @@ -116,6 +116,9 @@ Probes are designed such that each `dataset` should be captured by probes that a * `--add-adapters`: Add PCR adapters to the ends of each probe sequence. This selects adapters to add to probe sequences so as to minimize overlap among probes that share an adapter, allowing probes with the same adapter to be amplified together. (See `--adapter-a` and `--adapter-b` too.) +* `--filter-with-lsh-hamming FILTER_WITH_LSH_HAMMING`/`--filter-with-lsh-minhash FILTER_WITH_LSH_MINHASH`: Use locality-sensitive hashing to reduce the space of candidate probes. +This can significantly improve runtime and memory requirements when the input is especially large and diverse. +See `design.py --help` for details on using these options and downsides. * `-o OUTPUT`: Write probe sequences in FASTA format to OUTPUT. ### Pooling across many runs ([`pool.py`](./bin/pool.py)) @@ -125,7 +128,7 @@ It does this by searching over a space of probe sets to solve a constrained opti To see details on all the arguments that the program accepts, run: ```bash -pool.py -h +pool.py --help ``` You need to run [`design.py`](./bin/design.py) on each dataset over a grid of parameters values that spans a reasonable domain. 
diff --git a/bin/design.py b/bin/design.py index 0a0c1feab..b1dc52468 100755 --- a/bin/design.py +++ b/bin/design.py @@ -14,6 +14,7 @@ from catch.filter import duplicate_filter from catch.filter import fasta_filter from catch.filter import n_expansion_filter +from catch.filter import near_duplicate_filter from catch.filter import probe_designer from catch.filter import reverse_complement_filter from catch.filter import set_cover_filter @@ -132,15 +133,38 @@ def main(args): skip_reverse_complements=True) filters += [ff] - # Duplicate filter (df) -- condense all candidate probes that + # Duplicate filter (df) -- condense all candidate probes that # are identical down to one; this is not necessary for # correctness, as the set cover filter achieves the same task # implicitly, but it does significantly lower runtime by # decreasing the input size to the set cover filter - df = duplicate_filter.DuplicateFilter() - filters += [df] + # Near duplicate filter (ndf) -- condense candidate probes that + # are near-duplicates down to one using locality-sensitive + # hashing; like the duplicate filter, this is not necessary + # but can significantly lower runtime and reduce memory usage + # (even more than the duplicate filter) + if (args.filter_with_lsh_hamming is not None and + args.filter_with_lsh_minhash is not None): + raise Exception(("Cannot use both --filter-with-lsh-hamming " + "and --filter-with-lsh-minhash")) + if args.filter_with_lsh_hamming is not None: + if args.filter_with_lsh_hamming > args.mismatches: + logger.warning(("Setting FILTER_WITH_LSH_HAMMING (%d) to be greater " + "than MISMATCHES (%d) may cause the probes to achieve less " + "than the desired coverage"), args.filter_with_lsh_hamming, + args.mismatches) + ndf = near_duplicate_filter.NearDuplicateFilterWithHammingDistance( + args.filter_with_lsh_hamming, args.probe_length) + filters += [ndf] + elif args.filter_with_lsh_minhash is not None: + ndf = near_duplicate_filter.NearDuplicateFilterWithMinHash( + 
args.filter_with_lsh_minhash) + filters += [ndf] + else: + df = duplicate_filter.DuplicateFilter() + filters += [df] - # Set cover filter (scf) -- solve the problem by treating it as + # Set cover filter (scf) -- solve the problem by treating it as # an instance of the set cover problem scf = set_cover_filter.SetCoverFilter( mismatches=args.mismatches, @@ -445,6 +469,49 @@ def check_coverage(val): "replacement")) # Technical adjustments + parser.add_argument('--filter-with-lsh-hamming', + type=int, + help=("(Optional) If set, filter candidate probes for near-" + "duplicates using LSH with a family of hash functions that " + "works with Hamming distance. FILTER_WITH_LSH_HAMMING gives " + "the maximum Hamming distance at which to call near-" + "duplicates; it should be commensurate with (but not greater " + "than) MISMATCHES. Using this may significantly improve " + "runtime and reduce memory usage by reducing the number of " + "candidate probes to consider, but may lead to a slightly " + "sub-optimal solution. It may also, particularly with " + "relatively high values of FILTER_WITH_LSH_HAMMING, cause " + "coverage obtained for each genome to be slightly less than " + "the desired coverage (COVERAGE) when that desired coverage " + "is the complete genome; it is recommended to also use " + "--print-analysis or --write-analysis-to-tsv with this " + "to see the coverage that is obtained.")) + def check_filter_with_lsh_minhash(val): + fval = float(val) + if fval >= 0.0 and fval <= 1.0: + # a float in [0,1] + return fval + else: + raise argparse.ArgumentTypeError(("%s is an invalid Jaccard " + "distance") % val) + parser.add_argument('--filter-with-lsh-minhash', + type=check_filter_with_lsh_minhash, + help=("(Optional) If set, filter candidate probes for near-" + "duplicates using LSH with a MinHash family. 
" + "FILTER_WITH_LSH_MINHASH gives the maximum Jaccard distance " + "(1 minus Jaccard similarity) at which to call near-duplicates; " + "the Jaccard similarity is calculated by treating each probe " + "as a set of overlapping 10-mers. Its value should be " + "commensurate with parameter values determining whether a probe " + "hybridizes to a target sequence, but this can be difficult " + "to measure compared to the input for --filter-with-lsh-hamming. " + "However, this allows more sensitivity in near-duplicate " + "detection than --filter-with-lsh-hamming (e.g., if near-" + "duplicates should involve probes shifted relative to each " + "other). The same caveats mentioned in help for " + "--filter-with-lsh-hamming also apply here. Values of " + "FILTER_WITH_LSH_MINHASH above ~0.7 may start to require " + "significant memory and runtime for near-duplicate detection.")) parser.add_argument('--cover-groupings-separately', dest="cover_groupings_separately", action="store_true", diff --git a/catch/filter/near_duplicate_filter.py b/catch/filter/near_duplicate_filter.py new file mode 100644 index 000000000..b587d503a --- /dev/null +++ b/catch/filter/near_duplicate_filter.py @@ -0,0 +1,181 @@ +"""Removes near-duplicates from an input list of probes. + +This acts as a filter on the probes by removing ones that are +near-duplicates of another using LSH. There might be near-duplicates +in the output that are not detected, but every near-duplicate removed +should indeed be a near-duplicate as defined by the given criteria. +""" + +from collections import defaultdict +import math +import operator + +from catch.filter.base_filter import BaseFilter +from catch.utils import lsh + +__author__ = 'Hayden Metsky ' + + +class NearDuplicateFilter(BaseFilter): + """Filter that removes near-duplicates using LSH. 
+ + This constructs a concatenation of k hash functions, and does + this multiple times so as to achieve a desired probability of + reporting any probe as a near-duplicate of a queried probe. k + can be a constant, and the number of (concatenated) hash functions + to use is calculated to achieve the desired reporting probability. + + This sorts input probes by their multiplicity; therefore, the + duplicate filter should *not* be run before this. + """ + + def __init__(self, k, reporting_prob=0.95): + """ + Args: + k: number of hash functions to draw from a family of + hash functions for amplification; each hash function is then + the concatenation (h_1, h_2, ..., h_k) + reporting_prob: ensure that any probe within dist_thres of + a queried probe is detected as such; this constructs + multiple hash functions (each of which is a concatenation + of k functions drawn from the family) to achieve this + probability + """ + self.k = k + self.reporting_prob = reporting_prob + + def _filter(self, input): + """Filter with an arbitrary LSH family. + + This performs near neighbor lookups using self.lsh_family. It only + calls probes near-duplicates if their distance, according to + self.dist_fn, is within self.dist_thres. + + Args: + input: collection of probes to filter + + Returns: + subset of input + """ + # Sort the probes by their multiplicity (descending) + occurrences = defaultdict(int) + for p in input: + occurrences[p] += 1 + input_sorted = [p for p, count in + sorted(occurrences.items(), key=operator.itemgetter(1), + reverse=True)] + + # Remove exact duplicates from the input + input = list(set(input)) + + # Construct a collection of hash tables for looking up + # near neighbors of each probe + nnl = lsh.NearNeighborLookup(self.lsh_family, self.k, self.dist_thres, + self.dist_fn, self.reporting_prob) + nnl.add(input) + + # Iterate through all probes in order; for each p, remove others + # that are near-duplicates (neighbors) of p. 
Since we iterate + # in sorted order by multiplicity, the ones that hit more targets + # should appear earlier and will be included in the filtered output + to_include = set() + to_exclude = set() + for p in input_sorted: + # p should not have already been included because input_sorted + # should not contain duplicates + assert p not in to_include + + if p in to_exclude: + # p is already being filtered out + continue + + # Include p in the output and exclude all near-duplicates of it + to_include.add(p) + for near_dup in nnl.query(p): + if near_dup not in to_include: + to_exclude.add(near_dup) + + # Check that every probe is either included or excluded and + # that none are both included and excluded + assert len(to_include | to_exclude) == len(input_sorted) + assert len(to_include & to_exclude) == 0 + + return list(to_include) + + +class NearDuplicateFilterWithHammingDistance(NearDuplicateFilter): + """Filter that removes near-duplicates according to Hamming distance. + """ + + def __init__(self, dist_thres, probe_length): + """ + Args: + dist_thres: only call two probes near-duplicates if their + Hamming distance is within this value; this should be + equal to or commensurate with (but not greater than) + the number of mismatches at/below which a probe is + considered to hybridize to a target sequence so that + candidate probes further apart than this value are not + collapsed as near-duplicates + probe_length: length of probes + """ + super().__init__(k=20) + self.lsh_family = lsh.HammingDistanceFamily(probe_length) + self.dist_thres = dist_thres + + def hamming_dist(a, b): + # a and b are probe.Probe objects + return a.mismatches(b) + self.dist_fn = hamming_dist + + def _filter(self, input): + """Filter with LSH using family that works with Hamming distance. 
+ + Args: + input: collection of probes to filter + + Returns: + subset of input + """ + return NearDuplicateFilter._filter(self, input) + + +class NearDuplicateFilterWithMinHash(NearDuplicateFilter): + """Filter that removes near-duplicates using MinHash. + """ + + def __init__(self, dist_thres, kmer_size=10): + """ + Args: + dist_thres: only call two probes near-duplicates if their + Jaccard distance (1 minus Jaccard similarity) is within + this value; the Jaccard similarity is measured by treating + each probe sequence as a set of k-mers and measuring + the overlap of those k-mers + kmer_size: the length of each k-mer to use with MinHash; note + that this is *not* the same as self.k + """ + super().__init__(k=3) + self.lsh_family = lsh.MinHashFamily(kmer_size) + self.dist_thres = dist_thres + + def jaccard_dist(a, b): + a_kmers = [a[i:(i + kmer_size)] for i in range(len(a) - kmer_size + 1)] + b_kmers = [b[i:(i + kmer_size)] for i in range(len(b) - kmer_size + 1)] + a_kmers = set(a_kmers) + b_kmers = set(b_kmers) + jaccard_sim = float(len(a_kmers & b_kmers)) / len(a_kmers | b_kmers) + return 1.0 - jaccard_sim + self.dist_fn = jaccard_dist + + def _filter(self, input): + """Filter with LSH using MinHash family. + + Args: + input: collection of probes to filter + + Returns: + subset of input + """ + return NearDuplicateFilter._filter(self, input) + diff --git a/catch/filter/tests/test_near_duplicate_filter.py b/catch/filter/tests/test_near_duplicate_filter.py new file mode 100644 index 000000000..fdacea31a --- /dev/null +++ b/catch/filter/tests/test_near_duplicate_filter.py @@ -0,0 +1,147 @@ +"""Tests for near_duplicate_filter module. +""" + +import random +import unittest + +from catch.filter import near_duplicate_filter as ndf +from catch import probe + +__author__ = 'Hayden Metsky ' + + +class TestNearDuplicateFilterWithHammingDistance(unittest.TestCase): + """Tests output of near duplicate filter according to Hamming distance. 
+ """ + + def setUp(self): + # Set a random seed so hash functions are always the same + random.seed(0) + + def test_all_similar_no_exact_dup(self): + input = ['ATCGTCGCGG', 'ATCGTGGCGG', 'TTCGTCGCGG', 'ATCGGCGCGG'] + input_probes = [probe.Probe.from_str(s) for s in input] + + f = ndf.NearDuplicateFilterWithHammingDistance(2, 10) + f.k = 3 + f.filter(input_probes) + self.assertEqual(len(f.output_probes), 1) + self.assertIn(f.output_probes[0], input_probes) + + def test_all_similar_with_exact_dup(self): + input = ['ATCGTCGCGG', 'ATCGTCGCGG', 'ATCGTGGCGG', 'TTCGTCGCGG', + 'ATCGGCGCGG'] + input_probes = [probe.Probe.from_str(s) for s in input] + + f = ndf.NearDuplicateFilterWithHammingDistance(2, 10) + f.k = 3 + f.filter(input_probes) + self.assertEqual(len(f.output_probes), 1) + # The first probe in input_probes is the most common, so this + # should be the one that is kept + self.assertEqual(f.output_probes[0], input_probes[0]) + + def test_all_similar_but_zero_dist_thres(self): + input = ['ATCGTCGCGG', 'ATCGTGGCGG', 'TTCGTCGCGG', 'ATCGGCGCGG'] + input_probes = [probe.Probe.from_str(s) for s in input] + + f = ndf.NearDuplicateFilterWithHammingDistance(0, 10) + f.k = 3 + f.filter(input_probes) + self.assertCountEqual(f.output_probes, input_probes) + + def test_all_similar_but_one_too_far(self): + input = ['ATCGTCGCGG', 'ATCGTGGCGG', 'TTCGTCGCGG', 'ATCGGCGCCT'] + input_probes = [probe.Probe.from_str(s) for s in input] + + f = ndf.NearDuplicateFilterWithHammingDistance(2, 10) + f.k = 3 + f.filter(input_probes) + self.assertEqual(len(f.output_probes), 2) + # The last probe in input_probes is barely >2 mismatches + # from the others + self.assertIn(input_probes[-1], f.output_probes) + + def test_two_clusters(self): + cluster1 = ['ATCGTCGCGG', 'ATCGTGGCGG', 'TTCGTCGCGG', 'ATCGGCGCGG'] + cluster2 = ['GGCTTACTGA', 'GGCTTACTGA', 'GGCTTTCTGA', 'GGCTTACTAT'] + input = cluster1 + cluster2 + random.shuffle(input) + input_probes = [probe.Probe.from_str(s) for s in input] + + f = 
ndf.NearDuplicateFilterWithHammingDistance(2, 10) + f.k = 3 + f.filter(input_probes) + self.assertEqual(len(f.output_probes), 2) + self.assertTrue((f.output_probes[0].seq_str in cluster1 and + f.output_probes[1].seq_str in cluster2) or + (f.output_probes[0].seq_str in cluster2 and + f.output_probes[1].seq_str in cluster1)) + + +class TestNearDuplicateFilterWithMinHash(unittest.TestCase): + """Tests output of near duplicate filter using MinHash. + """ + + def setUp(self): + # Set a random seed so hash functions are always the same + random.seed(0) + + def test_all_similar_no_exact_dup(self): + input = ['ATCGTCGCGG', 'ATCGTGGCGG', 'TTCGTCGCGG', 'ATCGGCGCGG'] + input_probes = [probe.Probe.from_str(s) for s in input] + + f = ndf.NearDuplicateFilterWithMinHash(0.8, 3) + f.k = 3 + f.filter(input_probes) + self.assertEqual(len(f.output_probes), 1) + self.assertIn(f.output_probes[0], input_probes) + + def test_all_similar_with_exact_dup(self): + input = ['ATCGTCGCGG', 'ATCGTCGCGG', 'ATCGTGGCGG', 'TTCGTCGCGG', + 'ATCGGCGCGG'] + input_probes = [probe.Probe.from_str(s) for s in input] + + f = ndf.NearDuplicateFilterWithMinHash(0.8, 3) + f.k = 3 + f.filter(input_probes) + self.assertEqual(len(f.output_probes), 1) + # The first probe in input_probes is the most common, so this + # should be the one that is kept + self.assertEqual(f.output_probes[0], input_probes[0]) + + def test_all_similar_but_zero_dist_thres(self): + input = ['ATCGTCGCGG', 'ATCGTGGCGG', 'TTCGTCGCGG', 'ATCGGCGCGG'] + input_probes = [probe.Probe.from_str(s) for s in input] + + f = ndf.NearDuplicateFilterWithMinHash(0, 3) + f.k = 3 + f.filter(input_probes) + self.assertCountEqual(f.output_probes, input_probes) + + def test_all_similar_but_one_too_far(self): + input = ['ATCGTCGCGG', 'ATCGTGGCGG', 'TTCGTCGCGG', 'ATTGGGGCCA'] + input_probes = [probe.Probe.from_str(s) for s in input] + + f = ndf.NearDuplicateFilterWithMinHash(0.8, 3) + f.k = 3 + f.filter(input_probes) + self.assertEqual(len(f.output_probes), 2) + # 
The last probe in input_probes is far from the others + self.assertIn(input_probes[-1], f.output_probes) + + def test_two_clusters(self): + cluster1 = ['ATCGTCGCGG', 'ATCGTGGCGG', 'TTCGTCGCGG', 'ATCGGCGCGG'] + cluster2 = ['GGCTTACTGA', 'GGCTTACTGA', 'GGCTTTCTGA', 'GGCTTACTAT'] + input = cluster1 + cluster2 + random.shuffle(input) + input_probes = [probe.Probe.from_str(s) for s in input] + + f = ndf.NearDuplicateFilterWithMinHash(0.8, 3) + f.k = 3 + f.filter(input_probes) + self.assertEqual(len(f.output_probes), 2) + self.assertTrue((f.output_probes[0].seq_str in cluster1 and + f.output_probes[1].seq_str in cluster2) or + (f.output_probes[0].seq_str in cluster2 and + f.output_probes[1].seq_str in cluster1)) diff --git a/catch/probe.py b/catch/probe.py index eb1f7d617..a1a899956 100644 --- a/catch/probe.py +++ b/catch/probe.py @@ -325,6 +325,12 @@ def __cmp__(self, other): # not equal the corresponding char in other.seq return cmp(self.seq[c[0]], other.seq[c[0]]) + def __len__(self): + return len(self.seq) + + def __getitem__(self, i): + return self.seq_str[i] + def __str__(self): return self.seq_str diff --git a/catch/utils/lsh.py b/catch/utils/lsh.py new file mode 100644 index 000000000..ae8fed87f --- /dev/null +++ b/catch/utils/lsh.py @@ -0,0 +1,224 @@ +"""Classes and methods for applying locality-sensitive hashing. +""" + +from collections import defaultdict +import logging +import math +import random +import zlib + +__author__ = 'Hayden Metsky ' + +logger = logging.getLogger(__name__) + + +class HammingDistanceFamily: + """An LSH family that works with Hamming distance by sampling bases.""" + + def __init__(self, dim): + self.dim = dim + + def make_h(self): + """Construct a random hash function for this family. + + Returns: + hash function + """ + i = random.randint(0, self.dim - 1) + def h(x): + assert len(x) == self.dim + return x[i] + return h + + def P1(self, dist): + """Calculate lower bound on probability of collision for nearby sequences. 
+ + Args: + dist: Hamming distance; suppose two sequences are within this + distance of each other + + Returns: + lower bound on probability that two sequences (e.g., probes) hash + to the same value if they are within dist of each other + """ + return 1.0 - float(dist) / float(self.dim) + + +class MinHashFamily: + """An LSH family that works by taking the minimum permutation of + k-mers in a string/sequence (MinHash). + + See (Broder et al. 1997) and (Andoni and Indyk 2008) for details. + """ + + def __init__(self, kmer_size): + self.kmer_size = kmer_size + + def make_h(self): + """Construct a random hash function for this family. + + Here, we treat a sequence as being a set of k-mers. We calculate + a hash value for each k-mer and the hash function on the sequence + returns the minimum of these. + + Returns: + hash function + """ + # First construct a random hash function for a k-mer that + # is a universal hash function (effectively a "permutation" + # of the k-mer) + # Constrain all values to be in [0, 2^31 - 1] to have a bound + # on the output of the universal hash function; this upper bound + # is nice because it is also a prime, so we can simply work + # modulo (2^31 - 1) + p = 2**31 - 1 + # Let the random hash function be: + # (a*x + b) mod p + # for random integers a, b (a in [1, p] and b in [0, p]) + a = random.randint(1, p) + b = random.randint(0, p) + def kmer_hash(x): + # Hash a k-mer x with the random hash function + # hash(..) uses a random seed in Python 3.3+, so its output + # varies across Python processes; use zlib.adler32(..) 
for a + # deterministic hash value of the k-mer + x_hash = zlib.adler32(x.encode('utf-8')) + return (a * x_hash + b) % p + + def h(s): + # For a string/sequence s, have the MinHash function be the minimum + # hash over all the k-mers in it + assert self.kmer_size <= len(s) + if self.kmer_size >= len(s) / 2: + logger.warning(("The k-mer size %d is large (> (1/2)x) " + "compared to the size of a sequence to hash (%d), which " + "might make it difficult for MinHash to find similar " + "sequence"), self.kmer_size, len(s)) + kmer_hashes = [] + for i in range(len(s) - self.kmer_size + 1): + kmer = s[i:(i + self.kmer_size)] + kmer_hashes += [kmer_hash(kmer)] + return min(kmer_hashes) + return h + + def P1(self, dist): + """Calculate lower bound on probability of collision for nearby sequences. + + Args: + dist: Jaccard distance (1 minus Jaccard similarity); suppose + two sequences are within this distance of each other. The + Jaccard similarity can be thought of as the overlap in k-mers + between the two sequences + + Returns: + lower bound on probability that two sequences (e.g., probes) hash + to the same value if they are within dist of each other + """ + # With MinHash, the collision probability is the Jaccard similarity + return 1.0 - dist + + +class HashConcatenation: + """Concatenated hash functions (AND constructions).""" + + def __init__(self, family, k): + self.family = family + self.k = k + self.hs = [family.make_h() for _ in range(k)] + + def g(self, x): + """Evaluate random hash functions and concatenate the result. + + Args: + x: point (e.g., probe) on which to evaluate hash functions + + Returns: + concatenation of the result of the self.k random hash functions + evaluated at x + """ + return tuple([h(x) for h in self.hs]) + + +class NearNeighborLookup: + """Support for approximate near neighbor lookups. + + This implements the R-near neighbor reporting problem described in + Andoni and Indyk 2008. 
+ """ + + def __init__(self, family, k, dist_thres, dist_fn, reporting_prob): + """ + This selects a number of hash tables (defined as L in the above + reference) according to the strategy it outlines: we want any + neighbor (within dist_thres) of a query to be reported with + probability at least reporting_prob; thus, the number of + tables should be [log_{1 - (P1)^k} (1 - reporting_prob)]. In + the above reference, delta is 1.0 - reporting_prob. + + Args: + family: object giving family of hash functions + k: number of hash functions from family to concatenate + dist_thres: consider any two objects within this threshold + of each other to be neighbors + dist_fn: function f(a, b) that calculates the distance between + a and b, to compare against dist_thres + reporting_prob: report any neighbor of a query with + probability at least equal to this + """ + self.family = family + self.k = k + self.dist_thres = dist_thres + self.dist_fn = dist_fn + + P1 = self.family.P1(dist_thres) + if P1 == 1.0: + # dist_thres might be 0, and any number of hash tables can + # satisfy the reporting probability + self.num_tables = 1 + else: + self.num_tables = math.log(1.0 - reporting_prob, 1.0 - math.pow(P1, k)) + self.num_tables = int(math.ceil(self.num_tables)) + + # Setup self.num_tables hash tables, each with a corresponding + # function for hashing into it (the functions are concatenations + # of k hash functions from the given family) + self.hashtables = [] + self.hashtables_g = [] + for j in range(self.num_tables): + g = HashConcatenation(self.family, self.k) + self.hashtables += [defaultdict(list)] + self.hashtables_g += [g] + + def add(self, pts): + """Insert given points into each of the hash tables. + + Args: + pts: collection of points (e.g., probes) to add to the hash + tables + """ + for j in range(self.num_tables): + ht = self.hashtables[j] + g = self.hashtables_g[j].g + for p in pts: + ht[g(p)].append(p) + + def query(self, q): + """Find neighbors of a query point. 
+ + Args: + q: query point (e.g., probe) + + Returns: + collection of stored points that are within self.dist_thres of + q; all returned points are within this distance, but the + returned points might not include all that are + """ + neighbors = set() + for j in range(self.num_tables): + ht = self.hashtables[j] + g = self.hashtables_g[j].g + for p in ht[g(q)]: + if self.dist_fn(q, p) <= self.dist_thres: + neighbors.add(p) + return neighbors + diff --git a/catch/utils/tests/test_lsh.py b/catch/utils/tests/test_lsh.py new file mode 100644 index 000000000..ad286ded1 --- /dev/null +++ b/catch/utils/tests/test_lsh.py @@ -0,0 +1,231 @@ +"""Tests for lsh module. +""" + +import random +import unittest + +from catch.utils import lsh + +__author__ = 'Hayden Metsky ' + + +class TestHammingDistanceFamily(unittest.TestCase): + """Tests family of hash functions for Hamming distance. + """ + + def setUp(self): + # Set a random seed so hash functions are always the same + random.seed(0) + + self.family = lsh.HammingDistanceFamily(20) + + def test_identical(self): + a = 'ATCGATATGGGCACTGCTAT' + b = str(a) + + # Identical strings should hash to the same value + h1 = self.family.make_h() + self.assertEqual(h1(a), h1(b)) + h2 = self.family.make_h() + self.assertEqual(h2(a), h2(b)) + + def test_similar(self): + a = 'ATCGATATGGGCACTGCTAT' + b = 'ATCGACATGGGCACTGGTAT' + + # a and b should probably collide + collision_count = 0 + for i in range(10): + h = self.family.make_h() + if h(a) == h(b): + collision_count += 1 + self.assertGreater(collision_count, 8) + + def test_not_similar(self): + a = 'ATCGATATGGGCACTGCTAT' + b = 'AGTTGTCACCCTTGACGATA' + + # a and b should rarely collide + collision_count = 0 + for i in range(10): + h = self.family.make_h() + if h(a) == h(b): + collision_count += 1 + self.assertLess(collision_count, 2) + + def test_collision_prob(self): + # Collision probability for 2 mismatches should be + # 1 - 2/20 + self.assertEqual(self.family.P1(2), 0.9) + + +class 
TestMinHashFamily(unittest.TestCase): + """Tests family of hash functions for MinHash. + """ + + def setUp(self): + # Set a random seed so hash functions are always the same + random.seed(0) + + self.family = lsh.MinHashFamily(3) + + def test_identical(self): + a = 'ATCGATATGGGCACTGCTAT' + b = str(a) + + # Identical strings should hash to the same value + h1 = self.family.make_h() + self.assertEqual(h1(a), h1(b)) + h2 = self.family.make_h() + self.assertEqual(h2(a), h2(b)) + + def test_similar(self): + a = 'ATCGATATGGGCACTGCTATGTAGCGC' + b = 'ATCGACATGGGCACTGGTATGTAGCGC' + + # a and b should probably collide; the Jaccard similarity + # of a and b is ~67% (with 3-mers being the elements that + # make up each sequence) so they should collide with that + # probability (check that it is >60%) + collision_count = 0 + for i in range(100): + h = self.family.make_h() + if h(a) == h(b): + collision_count += 1 + self.assertGreater(collision_count, 60) + + def test_not_similar(self): + a = 'ATCGATATGGGCACTGCTAT' + b = 'AGTTGTCACCCTTGACGATA' + + # a and b should rarely collide + collision_count = 0 + for i in range(100): + h = self.family.make_h() + if h(a) == h(b): + collision_count += 1 + self.assertLess(collision_count, 30) + + def test_collision_prob(self): + # Collision probability for two sequences with a Jaccard + # distance of 0.2 should be 0.8 + self.assertEqual(self.family.P1(0.2), 0.8) + + +class TestHammingHashConcatenation(unittest.TestCase): + """Tests concatenations of hash functions with Hamming distance. 
+ """ + + def setUp(self): + # Set a random seed so hash functions are always the same + random.seed(0) + + self.family = lsh.HammingDistanceFamily(20) + self.G = lsh.HashConcatenation(self.family, 100) + + def test_identical(self): + # Identical a and b should collide even with large k + a = 'ATCGATATGGGCACTGCTAT' + b = str(a) + self.assertEqual(self.G.g(a), self.G.g(b)) + + def test_similar(self): + # Similar (but not identical) a and b should rarely + # collide when k is large + a = 'ATCGATATGGGCACTGCTAT' + b = 'ATCGACATGGGCACTGGTAT' + + collision_count = 0 + for i in range(10): + if self.G.g(a) == self.G.g(b): + collision_count += 1 + self.assertLess(collision_count, 2) + + def test_not_similar(self): + a = 'ATCGATATGGGCACTGCTAT' + b = 'AGTTGTCACCCTTGACGATA' + + # a and b should rarely collide + collision_count = 0 + for i in range(10): + if self.G.g(a) == self.G.g(b): + collision_count += 1 + self.assertLess(collision_count, 2) + + +class TestHammingNearNeighborLookup(unittest.TestCase): + """Tests approximate near neighbor lookups with Hamming distance.""" + + def setUp(self): + # Set a random seed so hash functions are always the same + random.seed(0) + + self.family = lsh.HammingDistanceFamily(20) + self.dist_thres = 5 + def f(a, b): + assert len(a) == len(b) + return sum(1 for i in range(len(a)) if a[i] != b[i]) + self.dist_fn = f + + def test_varied_k(self): + a = 'ATCGATATGGGCACTGCTAT' + b = str(a) # identical to a + c = 'ATCGACATGGGCACTGGTAT' # similar to a + d = 'AGTTGTCACCCTTGACGATA' # not similar to a + e = 'AGTTGTCACCCTTGACGATA' # similar to d + + for k in [2, 5, 10]: + nnl = lsh.NearNeighborLookup(self.family, k, self.dist_thres, + self.dist_fn, 0.95) + nnl.add([a, b, c, d]) + + # b and c are within self.dist_thres of a, so only these + # should be returned (along with a); note that since + # a==b, {a,b,c}=={a,c}=={b,c} and nnl.query(a) returns + # a set, which will be {a,c} or {b,c} + self.assertCountEqual(nnl.query(a), {a, b, c}) + + # Although e 
was not added, a query for it should return d + self.assertCountEqual(nnl.query(e), {d}) + + +class TestMinHashNearNeighborLookup(unittest.TestCase): + """Tests approximate near neighbor lookups with MinHash.""" + + def setUp(self): + # Set a random seed so hash functions are always the same + random.seed(0) + + kmer_size = 3 + self.family = lsh.MinHashFamily(kmer_size) + self.dist_thres = 0.5 + def f(a, b): + a_kmers = [a[i:(i + kmer_size)] for i in range(len(a) - kmer_size + 1)] + b_kmers = [b[i:(i + kmer_size)] for i in range(len(b) - kmer_size + 1)] + a_kmers = set(a_kmers) + b_kmers = set(b_kmers) + jaccard_sim = float(len(a_kmers & b_kmers)) / len(a_kmers | b_kmers) + return 1.0 - jaccard_sim + self.dist_fn = f + + def test_varied_k(self): + a = 'ATCGATATGGGCACTGCTAT' + b = str(a) # identical to a + c = 'ATCGACATGGGCACTGGTAT' # similar to a + d = 'AGTTGTCACCCTTGACGATA' # not similar to a + e = 'AGTTGTCACCCTTGACGATA' # similar to d + + for k in [2, 5, 10]: + nnl = lsh.NearNeighborLookup(self.family, k, self.dist_thres, + self.dist_fn, 0.95) + nnl.add([a, b, c, d]) + + # b and c are within self.dist_thres of a, so only these + # should be returned (along with a); note that since + # a==b, {a,b,c}=={a,c}=={b,c} and nnl.query(a) returns + # a set, which will be {a,c} or {b,c} + self.assertCountEqual(nnl.query(a), {a, b, c}) + + # Although e was not added, a query for it should return d + self.assertCountEqual(nnl.query(e), {d}) +