Skip to content

Commit

Permalink
Consistently return iterators from randomPairs and randomPairsMatch
Browse files — browse the repository at this point in the history
  • Loading branch information
fgregg committed Jan 27, 2022
1 parent 7317798 commit 0d62b2e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 19 deletions.
10 changes: 6 additions & 4 deletions dedupe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import multiprocessing
import multiprocessing.dummy
import queue

from typing import (Iterator,
Tuple,
Mapping,
Expand All @@ -25,7 +24,6 @@
overload)

import numpy
from numpy.random import default_rng

from dedupe._typing import (RecordPairs,
RecordID,
Expand All @@ -52,7 +50,9 @@ def randomPairs(n_records: int, sample_size: int) -> IndicesIterator:
"""
n: int = n_records * (n_records - 1) // 2

if sample_size >= n:
if not sample_size:
return iter([])
elif sample_size >= n:
random_pairs = numpy.arange(n)
else:
try:
Expand All @@ -76,7 +76,9 @@ def randomPairsMatch(n_records_A: int, n_records_B: int, sample_size: int) -> In
"""
n: int = n_records_A * n_records_B

if sample_size >= n:
if not sample_size:
return iter([])
elif sample_size >= n:
random_pairs = numpy.arange(n)
else:
random_pairs = numpy.array(random.sample(range(n), sample_size))
Expand Down
19 changes: 4 additions & 15 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,13 @@ class RandomPairsTest(unittest.TestCase):
def test_random_pair(self):
random.seed(123)

if sys.version_info < (3, 0):
target = [(0, 3), (0, 4), (2, 4), (0, 5), (6, 8)]
else:
target = [(0, 4), (2, 3), (0, 6), (3, 6), (0, 7)]
target = [(0, 4), (2, 3), (0, 6), (3, 6), (0, 7)]

random_pairs = list(dedupe.core.randomPairs(10, 5))
assert random_pairs == target

random.seed(123)
if sys.version_info < (3, 0):
target = [(265, 3429)]
else:
target = [(357, 8322)]
target = [(357, 8322)]

random_pairs = list(dedupe.core.randomPairs(10**4, 1))
assert random_pairs == target
Expand All @@ -35,13 +29,8 @@ def test_random_pair_match(self):
assert len(list(dedupe.core.randomPairsMatch(10, 10, 99))) == 99

random.seed(123)
random.seed(123)
if sys.version_info < (3, 0):
target = [(0, 5), (0, 8), (4, 0), (1, 0), (9, 0),
(0, 3), (5, 3), (3, 3), (8, 5), (1, 5)]
else:
target = [(0, 6), (3, 4), (1, 1), (9, 8), (5, 2),
(1, 3), (0, 4), (4, 8), (6, 8), (7, 1)]
target = [(0, 6), (3, 4), (1, 1), (9, 8), (5, 2),
(1, 3), (0, 4), (4, 8), (6, 8), (7, 1)]

pairs = list(dedupe.core.randomPairsMatch(10, 10, 10))
assert pairs == target
Expand Down

0 comments on commit 0d62b2e

Please sign in to comment.