diff --git a/beancount_import/reconcile.py b/beancount_import/reconcile.py
index c4c4f8dd..a92811c8 100644
--- a/beancount_import/reconcile.py
+++ b/beancount_import/reconcile.py
@@ -25,6 +25,8 @@
 from .matching import FIXME_ACCOUNT, is_unknown_account, CLEARED_KEY
 
 
+UNCONFIRMED_ACCOUNT_KEY = 'unconfirmed_account'
+NEW_ACCOUNT_KEY = 'new_account'
 display_prediction_explanation = False
 
 classifier_cache_version_number = 1
@@ -722,18 +724,74 @@ def _get_primary_transaction_amount_number(self, transaction: Transaction):
             return -source_posting.units.number
         return None
 
+    def _get_unknown_account_names(self, transaction: Transaction):
+        '''
+        Returns one boolean per posting of `transaction`, telling whether
+        that posting has a NEW_ACCOUNT_KEY meta field.
+        '''
+        # NOTE(review): despite the name this yields flags rather than account
+        # names, and it is not filtered on UNCONFIRMED_ACCOUNT_KEY the way the
+        # group numbers are -- confirm this is what the caller should zip.
+        return [
+            bool(posting.meta) and NEW_ACCOUNT_KEY in posting.meta
+            for posting in transaction.postings
+        ]
+
+    def _group_predicted_accounts_by_name(self, transaction: Transaction):
+        '''
+        Groups the postings that have an UNCONFIRMED_ACCOUNT_KEY meta field by
+        account name, so that postings sharing a predicted account also share
+        a group.  Postings without that meta field are skipped.
+
+        Returns `(predicted_account_names, group_numbers)`.  Both lists have
+        one entry per considered posting, in posting order; group numbers
+        start at 0 and are assigned in order of first appearance.
+        '''
+        predicted_account_names = []
+        group_numbers = []
+        existing_groups = {}  # type: Dict[str, int]
+        for posting in transaction.postings:
+            # Posting.meta is None when the posting has no metadata at all.
+            if not posting.meta or UNCONFIRMED_ACCOUNT_KEY not in posting.meta:
+                continue
+            predicted_account_names.append(posting.account)
+            # A name seen for the first time gets the next unused number.
+            group_numbers.append(
+                existing_groups.setdefault(posting.account,
+                                           len(existing_groups)))
+        return predicted_account_names, group_numbers
+
     def _get_unknown_account_predictions(self,
                                          transaction: Transaction) -> List[str]:
-        group_prediction_inputs = self._feature_extractor.extract_unknown_account_group_features(
-            transaction)
-        group_predictions = [
-            self.predict_account(prediction_input)
-            for prediction_input in group_prediction_inputs
-        ]
-        group_numbers = training.get_unknown_account_group_numbers(transaction)
-        return [
-            group_predictions[group_number] for group_number in group_numbers
-        ]
+        '''
+        Returns the predicted account names for `transaction`.
+
+        If any posting has an UNCONFIRMED_ACCOUNT_KEY meta field, the accounts
+        of those postings are returned as they are, and `transaction.meta` is
+        tagged with UNCONFIRMED_ACCOUNT_KEY as a side effect.  Otherwise one
+        account is predicted per group and returned once per group number.
+        '''
+        if any(posting.meta and UNCONFIRMED_ACCOUNT_KEY in posting.meta
+               for posting in transaction.postings):
+            # If any of the postings have the UNCONFIRMED_ACCOUNT_KEY, then
+            # prediction was handled by smart_importer.
+            predicted_account_names, _ = (
+                self._group_predicted_accounts_by_name(transaction))
+            # _make_candidate_with_substitutions checks this flag to pick the
+            # matching group numbers.
+            transaction.meta[UNCONFIRMED_ACCOUNT_KEY] = True
+            return predicted_account_names
+        else:
+            group_prediction_inputs = self._feature_extractor.extract_unknown_account_group_features(
+                transaction)
+            group_predictions = [
+                self.predict_account(prediction_input)
+                for prediction_input in group_prediction_inputs
+            ]
+            group_numbers = training.get_unknown_account_group_numbers(transaction)
+            return [
+                group_predictions[group_number] for group_number in group_numbers
+            ]
 
     def _make_candidate_with_substitutions(self,
                                            transaction: Transaction,
@@ -754,8 +812,15 @@ def _make_candidate_with_substitutions(self,
             unique_id: account
             for unique_id, account in zip(unique_ids, new_accounts)
         }
-        group_numbers = training.get_unknown_account_group_numbers(transaction)
-        unknown_names = training.get_unknown_account_names(transaction)
+        if UNCONFIRMED_ACCOUNT_KEY in transaction.meta:
+            # Set by _get_unknown_account_predictions when the postings
+            # already carried their predicted accounts.
+            _, group_numbers = self._group_predicted_accounts_by_name(
+                transaction)
+            unknown_names = self._get_unknown_account_names(transaction)
+        else:
+            group_numbers = training.get_unknown_account_group_numbers(transaction)
+            unknown_names = training.get_unknown_account_names(transaction)
         substitutions = [
             AccountSubstitution(
                 unique_name=unique_id,