diff --git a/dedupe/api.py b/dedupe/api.py index 776cad01..e58c69cb 100644 --- a/dedupe/api.py +++ b/dedupe/api.py @@ -14,7 +14,7 @@ import sqlite3 import tempfile import warnings -from typing import TYPE_CHECKING, cast, overload +from typing import TYPE_CHECKING, Optional, cast, overload import numpy import sklearn.linear_model @@ -1182,7 +1182,7 @@ def _read_training(self, training_file: TextIO) -> None: self.mark_pairs(training_pairs) def train( - self, recall: float = 1.00, index_predicates: bool = True + self, recall: float = 1.00, index_predicates: bool = True, branch_bound_max_calls: Optional[int] = None ) -> None: # pragma: no cover """ Learn final pairwise classifier and fingerprinting rules. Requires that @@ -1212,7 +1212,7 @@ def train( examples, y = flatten_training(self.training_pairs) self.classifier.fit(self.data_model.distances(examples), y) - self.predicates = self.active_learner.learn_predicates(recall, index_predicates) + self.predicates = self.active_learner.learn_predicates(recall, index_predicates, branch_bound_max_calls) self._fingerprinter = blocking.Fingerprinter(self.predicates) self.fingerprinter.reset_indices() diff --git a/dedupe/labeler.py b/dedupe/labeler.py index 52e87cd7..776afae9 100644 --- a/dedupe/labeler.py +++ b/dedupe/labeler.py @@ -3,7 +3,7 @@ import logging import random from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, Optional, overload from warnings import warn import numpy @@ -131,13 +131,14 @@ def candidate_scores(self) -> numpy.typing.NDArray[numpy.float_]: return self._cached_scores def learn_predicates( - self, dupes: TrainingExamples, recall: float, index_predicates: bool + self, dupes: TrainingExamples, recall: float, index_predicates: bool, branch_bound_max_calls: Optional[int] = None ) -> tuple[Predicate, ...]: return self.block_learner.learn( dupes, recall=recall, index_predicates=index_predicates, candidate_types="random forest", + branch_bound_max_calls=branch_bound_max_calls ) def _predict(self, pairs: TrainingExamples) -> Labels: @@ -391,11 +392,11 @@ def mark(self, pairs: TrainingExamples, y: LabelsLike) -> None: learner.fit(self.pairs, self.y) def learn_predicates( - self, recall: float, index_predicates: bool + self, recall: float, index_predicates: bool, branch_bound_max_calls: Optional[int] = None ) -> tuple[Predicate, ...]: dupes = [pair for label, pair in zip(self.y, self.pairs) if label] return self.blocker.learn_predicates( - dupes, recall=recall, index_predicates=index_predicates + dupes, recall=recall, index_predicates=index_predicates, branch_bound_max_calls=branch_bound_max_calls ) diff --git a/dedupe/training.py b/dedupe/training.py index b38d5c73..65cd1747 100644 --- a/dedupe/training.py +++ b/dedupe/training.py @@ -7,7 +7,7 @@ import math import random from abc import ABC -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, Optional, overload from warnings import warn from . import blocking @@ -40,6 +40,7 @@ def learn( recall: float, index_predicates: bool, candidate_types: Literal["simple", "random forest"] = "simple", + branch_bound_max_calls: Optional[int] = None ) -> tuple[Predicate, ...]: """ Takes in a set of training pairs and predicates and tries to find @@ -75,7 +76,7 @@ def learn( else: raise ValueError("candidate_type is not valid") - searcher = BranchBound(target_cover, 2500) + searcher = BranchBound(target_cover, branch_bound_max_calls or 2500) final_predicates = searcher.search(candidate_cover) logger.info("Final predicate set:")