diff --git a/dedupe/api.py b/dedupe/api.py index 91e80e08c..f8a3b2bda 100644 --- a/dedupe/api.py +++ b/dedupe/api.py @@ -1052,17 +1052,7 @@ def _read_training(self, training_file: TextIO) -> None: training_pairs = json.load(training_file, object_hook=serializer._from_json) - try: - self.mark_pairs(training_pairs) - except AttributeError as e: - if "Attempting to fingerprint with an index predicate without indexing records" in str(e): - raise UserWarning('Training data has records not known ' - 'to the active learner. Read training ' - 'in before initializing the active ' - 'learner with the sample method, or ' - 'use the prepare_training method.') - else: - raise + self.mark_pairs(training_pairs) def train(self, recall: float = 1.00, @@ -1174,6 +1164,22 @@ def mark_pairs(self, labeled_pairs: TrainingData) -> None: } matcher.mark_pairs(labeled_examples) + .. note:: + `mark_pairs` is primarily designed to be used with `uncertain_pairs` + to incrementally build a training set. + + If you have existing training data, you should likely + format the data into the right form and supply the training + data with the `training_file` argument of the + `prepare_training` method. + + If that is not possible or desirable, you can + train a linker directly using the `mark_pairs` + method. However, you must ensure that every record that + appears in the `labeled_pairs` argument appears in either + the data or `training_file` supplied to the + `prepare_training` method. + ''' self._checkTrainingPairs(labeled_pairs) @@ -1182,7 +1188,16 @@ def mark_pairs(self, labeled_pairs: TrainingData) -> None: if self.active_learner: examples, y = flatten_training(labeled_pairs) - self.active_learner.mark(examples, y) + + try: + self.active_learner.mark(examples, y) + except dedupe.predicates.NoIndexError: + raise UserWarning('Training data has records not known ' + 'to the active learner. 
Make sure data ' + 'are in the data arguments of ' + 'prepare_training method or use ' + 'the training_file argument of ' + 'prepare_training') def _checkTrainingPairs(self, labeled_pairs: TrainingData) -> None: try: @@ -1269,6 +1284,9 @@ def prepare_training(self, ''' + # Reset active learner + self.active_learner = None + if training_file: self._read_training(training_file) self._sample(data, sample_size, blocked_proportion) @@ -1354,6 +1372,8 @@ def prepare_training(self, matcher.prepare_training(data_1, data_2, training_file=f) ''' + # Reset active learner + self.active_learner = None if training_file: self._read_training(training_file) diff --git a/dedupe/predicates.py b/dedupe/predicates.py index 819e23f88..e4e90cb85 100644 --- a/dedupe/predicates.py +++ b/dedupe/predicates.py @@ -25,6 +25,10 @@ PUNCTABLE = str.maketrans("", "", string.punctuation) +class NoIndexError(AttributeError): + pass + + def strip_punc(s): return s.translate(PUNCTABLE) @@ -162,8 +166,8 @@ def __call__(self, record, **kwargs): try: doc_id = self.index._doc_to_id[doc] except AttributeError: - raise AttributeError("Attempting to block with an index " - "predicate without indexing records") + raise NoIndexError("Attempting to block with an index " + "predicate without indexing records") if doc_id in self.canopy: block_key = self.canopy[doc_id] @@ -216,8 +220,8 @@ def __call__(self, record, target=False, **kwargs): else: centers = self.index.search(doc, self.threshold) except AttributeError: - raise AttributeError("Attempting to block with an index " - "predicate without indexing records") + raise NoIndexError("Attempting to block with an index " + "predicate without indexing records") result = [str(center) for center in centers] self._cache[(column, target)] = result return result