Skip to content

Commit

Permalink
random pairs, closes #830
Browse files Browse the repository at this point in the history
  • Loading branch information
fgregg committed Jan 19, 2022
1 parent 6062838 commit 6c59172
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 43 deletions.
32 changes: 19 additions & 13 deletions dedupe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import itertools
import tempfile
import os
import random
import collections
import warnings
import functools

import multiprocessing
import multiprocessing.dummy
from typing import (Iterator,
Tuple,
Mapping,
Expand All @@ -20,15 +20,22 @@
Type,
Iterable,
overload)

import numpy

from dedupe._typing import (RecordPairs,
RecordID,
Blocks,
Data,
Literal)

import numpy
import multiprocessing
import multiprocessing.dummy
from numpy.random import default_rng
rng = default_rng()

try:
rng_integers = rng.integers
except AttributeError:
rng_integers = rng.randint


class ChildProcessError(Exception):
Expand Down Expand Up @@ -57,8 +64,8 @@ def randomPairs(n_records: int, sample_size: int) -> IndicesIterator:
random_pairs = numpy.arange(n)
else:
try:
random_pairs = numpy.array(random.sample(range(n), sample_size))
except OverflowError:
random_pairs = rng_integers(n, size=sample_size)
except (OverflowError, ValueError):
return randomPairsWithReplacement(n_records, sample_size)

b: int = 1 - 2 * n_records
Expand All @@ -80,8 +87,7 @@ def randomPairsMatch(n_records_A: int, n_records_B: int, sample_size: int) -> In
if sample_size >= n:
random_pairs = numpy.arange(n)
else:
random_pairs = numpy.array(random.sample(range(n), sample_size),
dtype=int)
random_pairs = rng_integers(n, size=sample_size)

i, j = numpy.unravel_index(random_pairs, (n_records_A, n_records_B))

Expand All @@ -94,14 +100,14 @@ def randomPairsWithReplacement(n_records: int, sample_size: int) -> IndicesItera
warnings.warn("The same record pair may appear more than once in the sample")

try:
random_indices = numpy.random.randint(n_records,
size=sample_size * 2)
random_indices = rng_integers(n_records,
size=sample_size * 2)
except (OverflowError, ValueError):
max_int: int = numpy.iinfo('int').max
warnings.warn("Asked to sample pairs from %d records, will only sample pairs from first %d records" % (n_records, max_int))

random_indices = numpy.random.randint(max_int,
size=sample_size * 2)
random_indices = rng_integers(max_int,
size=sample_size * 2)

random_indices = random_indices.reshape((-1, 2))
random_indices.sort(axis=1)
Expand Down
30 changes: 0 additions & 30 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,6 @@

class RandomPairsTest(unittest.TestCase):
def test_random_pair(self):
random.seed(123)

if sys.version_info < (3, 0):
target = [(0, 3), (0, 4), (2, 4), (0, 5), (6, 8)]
else:
target = [(0, 4), (2, 3), (0, 6), (3, 6), (0, 7)]

random_pairs = list(dedupe.core.randomPairs(10, 5))
assert random_pairs == target

random.seed(123)
if sys.version_info < (3, 0):
target = [(265, 3429)]
else:
target = [(357, 8322)]

random_pairs = list(dedupe.core.randomPairs(10**4, 1))
assert random_pairs == target

random_pairs = list(dedupe.core.randomPairs(10**10, 1))

Expand All @@ -35,18 +17,6 @@ def test_random_pair_match(self):
assert len(list(dedupe.core.randomPairsMatch(100, 100, 100))) == 100
assert len(list(dedupe.core.randomPairsMatch(10, 10, 99))) == 99

random.seed(123)
random.seed(123)
if sys.version_info < (3, 0):
target = [(0, 5), (0, 8), (4, 0), (1, 0), (9, 0),
(0, 3), (5, 3), (3, 3), (8, 5), (1, 5)]
else:
target = [(0, 6), (3, 4), (1, 1), (9, 8), (5, 2),
(1, 3), (0, 4), (4, 8), (6, 8), (7, 1)]

pairs = list(dedupe.core.randomPairsMatch(10, 10, 10))
assert pairs == target

pairs = list(dedupe.core.randomPairsMatch(10, 10, 0))
assert pairs == []

Expand Down

0 comments on commit 6c59172

Please sign in to comment.