diff --git a/beancount_import/reconcile.py b/beancount_import/reconcile.py
index df375ef5..78edd5d0 100644
--- a/beancount_import/reconcile.py
+++ b/beancount_import/reconcile.py
@@ -25,6 +25,8 @@
 from .matching import FIXME_ACCOUNT, is_unknown_account, CLEARED_KEY
 
 
+UNCONFIRMED_ACCOUNT_KEY = 'unconfirmed_account'
+NEW_ACCOUNT_KEY = 'new_account'
 display_prediction_explanation = False
 
 classifier_cache_version_number = 1
@@ -722,18 +724,63 @@ def _get_primary_transaction_amount_number(self, transaction: Transaction):
             return -source_posting.units.number
         return None
 
+    def _get_unknown_account_names(self, transaction: Transaction) -> List[str]:
+        '''Returns the account of every posting tagged with NEW_ACCOUNT_KEY.'''
+        return [
+            posting.account for posting in transaction.postings
+            if posting.meta is not None and NEW_ACCOUNT_KEY in posting.meta
+        ]
+
+    def _has_unconfirmed_account(self, transaction: Transaction) -> bool:
+        '''Returns True if any posting is tagged with UNCONFIRMED_ACCOUNT_KEY.'''
+        return any(
+            posting.meta is not None and UNCONFIRMED_ACCOUNT_KEY in posting.meta
+            for posting in transaction.postings)
+
+    def _group_predicted_accounts_by_name(self, transaction: Transaction):
+        '''Groups the postings with predicted accounts by account name.
+
+        Only postings tagged with UNCONFIRMED_ACCOUNT_KEY are considered.
+        Postings that share the same account name get the same group number;
+        group numbers are assigned in order of first appearance, starting at 0.
+
+        Returns (predicted_account_names, group_numbers): two parallel lists
+        with one entry per tagged posting, in posting order.
+        '''
+        predicted_account_names = []  # type: List[str]
+        group_numbers = []  # type: List[int]
+        existing_groups = {}  # type: Dict[str, int]
+        for posting in transaction.postings:
+            if posting.meta is None or UNCONFIRMED_ACCOUNT_KEY not in posting.meta:
+                continue
+            predicted_account_names.append(posting.account)
+            # A name seen for the first time gets the next unused group number.
+            group_numbers.append(
+                existing_groups.setdefault(posting.account,
+                                           len(existing_groups)))
+        return predicted_account_names, group_numbers
+
     def _get_unknown_account_predictions(self,
                                          transaction: Transaction) -> List[str]:
+        if self._has_unconfirmed_account(transaction):
+            # If any posting has an unconfirmed account, prediction was already
+            # handled by smart_importer.
+            predicted_account_names, _ = self._group_predicted_accounts_by_name(
+                transaction)
+            # Flag the transaction so that _make_candidate_with_substitutions
+            # derives its groups from the same tagged postings.
+            transaction.meta[UNCONFIRMED_ACCOUNT_KEY] = True
+            return predicted_account_names
         group_prediction_inputs = self._feature_extractor.extract_unknown_account_group_features(
             transaction)
         group_predictions = [
             self.predict_account(prediction_input)
             for prediction_input in group_prediction_inputs
         ]
         group_numbers = training.get_unknown_account_group_numbers(transaction)
         return [
             group_predictions[group_number] for group_number in group_numbers
         ]
 
     def _make_candidate_with_substitutions(self,
                                            transaction: Transaction,
@@ -754,8 +801,16 @@ def _make_candidate_with_substitutions(self,
             unique_id: account
             for unique_id, account in zip(unique_ids, new_accounts)
         }
-        group_numbers = training.get_unknown_account_group_numbers(transaction)
-        unknown_names = training.get_unknown_account_names(transaction)
+        if transaction.meta is not None and UNCONFIRMED_ACCOUNT_KEY in transaction.meta:
+            # The transaction was flagged by _get_unknown_account_predictions:
+            # group postings by their already-predicted account names.
+            # NOTE(review): names are selected by NEW_ACCOUNT_KEY but groups by
+            # UNCONFIRMED_ACCOUNT_KEY -- confirm both are always set together.
+            _, group_numbers = self._group_predicted_accounts_by_name(transaction)
+            unknown_names = self._get_unknown_account_names(transaction)
+        else:
+            group_numbers = training.get_unknown_account_group_numbers(transaction)
+            unknown_names = training.get_unknown_account_names(transaction)
         substitutions = [
             AccountSubstitution(
                 unique_name=unique_id,