Skip to content

Commit

Permalink
better error message when mark_pairs sees unseen records
Browse files Browse the repository at this point in the history
mark_pairs can only handle records that have been passed, in some form,
to prepare_training.

This commit documents that, and gives a more useful error message
about it.
  • Loading branch information
fgregg committed Feb 3, 2022
1 parent 0d62b2e commit 9b21a90
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 16 deletions.
44 changes: 32 additions & 12 deletions dedupe/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,17 +1052,7 @@ def _read_training(self, training_file: TextIO) -> None:
training_pairs = json.load(training_file,
object_hook=serializer._from_json)

try:
self.mark_pairs(training_pairs)
except AttributeError as e:
if "Attempting to fingerprint with an index predicate without indexing records" in str(e):
raise UserWarning('Training data has records not known '
'to the active learner. Read training '
'in before initializing the active '
'learner with the sample method, or '
'use the prepare_training method.')
else:
raise
self.mark_pairs(training_pairs)

def train(self,
recall: float = 1.00,
Expand Down Expand Up @@ -1174,6 +1164,22 @@ def mark_pairs(self, labeled_pairs: TrainingData) -> None:
}
matcher.mark_pairs(labeled_examples)
.. note::
`mark_pairs` is primarily designed to be used with `uncertain_pairs`
to incrementally build a training set.
If you have existing training data, you should likely
format the data into the right form and supply the training
data with the `training_file` argument of the
`prepare_training` method.
If that is not possible or desirable, you can train a
linker directly with the `mark_pairs`
method. However, you must ensure that every record that
appears in the `labeled_pairs` argument appears in either
the data or `training_file` supplied to the
`prepare_training` method.
'''
self._checkTrainingPairs(labeled_pairs)

Expand All @@ -1182,7 +1188,16 @@ def mark_pairs(self, labeled_pairs: TrainingData) -> None:

if self.active_learner:
examples, y = flatten_training(labeled_pairs)
self.active_learner.mark(examples, y)

try:
self.active_learner.mark(examples, y)
except dedupe.predicates.NoIndexError:
raise UserWarning('Training data has records not known '
'to the active learner. Make sure data '
'are in the data arguments of '
'prepare_training method or use '
'the training_file argument of '
'prepare_training')

def _checkTrainingPairs(self, labeled_pairs: TrainingData) -> None:
try:
Expand Down Expand Up @@ -1269,6 +1284,9 @@ def prepare_training(self,
'''

# Reset active learner
self.active_learner = None

if training_file:
self._read_training(training_file)
self._sample(data, sample_size, blocked_proportion)
Expand Down Expand Up @@ -1354,6 +1372,8 @@ def prepare_training(self,
matcher.prepare_training(data_1, data_2, training_file=f)
'''
# Reset active learner
self.active_learner = None

if training_file:
self._read_training(training_file)
Expand Down
12 changes: 8 additions & 4 deletions dedupe/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
PUNCTABLE = str.maketrans("", "", string.punctuation)


class NoIndexError(AttributeError):
    """Raised when an index predicate is used before records are indexed.

    Subclasses ``AttributeError`` so that callers which already catch
    ``AttributeError`` keep working, while new callers can catch this
    specific condition by type instead of matching on the message text.
    """


def strip_punc(s):
    """Return ``s`` with all ASCII punctuation characters removed.

    Uses the module-level ``PUNCTABLE`` translation table, which maps every
    character in ``string.punctuation`` to deletion, so the whole string is
    processed in a single ``str.translate`` pass.
    """
    return s.translate(PUNCTABLE)

Expand Down Expand Up @@ -162,8 +166,8 @@ def __call__(self, record, **kwargs):
try:
doc_id = self.index._doc_to_id[doc]
except AttributeError:
raise AttributeError("Attempting to block with an index "
"predicate without indexing records")
raise NoIndexError("Attempting to block with an index "
"predicate without indexing records")

if doc_id in self.canopy:
block_key = self.canopy[doc_id]
Expand Down Expand Up @@ -216,8 +220,8 @@ def __call__(self, record, target=False, **kwargs):
else:
centers = self.index.search(doc, self.threshold)
except AttributeError:
raise AttributeError("Attempting to block with an index "
"predicate without indexing records")
raise NoIndexError("Attempting to block with an index "
"predicate without indexing records")
result = [str(center) for center in centers]
self._cache[(column, target)] = result
return result
Expand Down

0 comments on commit 9b21a90

Please sign in to comment.