From 5e1b2d20bb489fc9a3b9263d45b2f78a9a6676a4 Mon Sep 17 00:00:00 2001 From: Johannes Harms Date: Mon, 30 Apr 2018 19:23:34 +0200 Subject: [PATCH 1/2] refs #30 renames `filter_training_data_by_account` to simply `account` or sometimes to be be more specific to `known_account`. --- smart_importer/machinelearning_helpers.py | 12 ++++++------ smart_importer/predict_payees.py | 6 +++--- smart_importer/predict_postings.py | 8 ++++---- smart_importer/tests/predict_payees_test.py | 8 ++++---- smart_importer/tests/predict_postings_test.py | 8 ++++---- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/smart_importer/machinelearning_helpers.py b/smart_importer/machinelearning_helpers.py index 5e40f62..00d7f6f 100644 --- a/smart_importer/machinelearning_helpers.py +++ b/smart_importer/machinelearning_helpers.py @@ -15,7 +15,7 @@ def load_training_data(training_data: Union[_FileMemo, List[Transaction], str], - filter_training_data_by_account: str = None, + known_account: str = None, existing_entries: List[Tuple] = None) -> List[Transaction]: ''' Loads training data @@ -23,7 +23,7 @@ def load_training_data(training_data: Union[_FileMemo, List[Transaction], str], Can be provided as a string (the filename pointing to a beancount file), a _FileMemo instance, or a list of beancount entries - :param filter_training_data_by_account: Optional filter for the training data. + :param known_account: Optional filter for the training data. If provided, the training data is filtered to only include transactions that involve the specified account. :param existing_entries: Optional existing entries to use instead of explicit training_data :return: Returns a list of beancount entries. @@ -42,11 +42,11 @@ def load_training_data(training_data: Union[_FileMemo, List[Transaction], str], assert not errors training_data = filter_txns(training_data) logger.debug(f"Finished reading training data.") - if filter_training_data_by_account: + if known_account: training_data = [t for t in training_data - # ...filtered because the training data must involve the filter_training_data_by_account: - if transaction_involves_account(t, filter_training_data_by_account)] - logger.debug(f"After filtering for account {filter_training_data_by_account}, " + # ...filtered because the training data must involve the account: + if transaction_involves_account(t, known_account)] + logger.debug(f"After filtering for account {known_account}, " f"the training data consists of {len(training_data)} entries.") return training_data diff --git a/smart_importer/predict_payees.py b/smart_importer/predict_payees.py index 0dbd48f..37ae77c 100644 --- a/smart_importer/predict_payees.py +++ b/smart_importer/predict_payees.py @@ -44,12 +44,12 @@ def extract(file): def __init__(self, *, training_data: Union[_FileMemo, List[Transaction], str] = None, - filter_training_data_by_account: str = None, + account: str = None, predict_payees: bool = True, overwrite_existing_payees=False, suggest_payees: bool = True): self.training_data = training_data - self.filter_training_data_by_account = filter_training_data_by_account + self.account = account self.predict_payees = predict_payees self.overwrite_existing_payees = overwrite_existing_payees self.suggest_payees = suggest_payees @@ -88,7 +88,7 @@ def wrapper(self, file, existing_entries=None): def enhance_transactions(self):# load training data self.training_data = ml.load_training_data( self.training_data, - filter_training_data_by_account=self.filter_training_data_by_account, + known_account=self.account, existing_entries=self.existing_entries) # train the machine learning model diff --git a/smart_importer/predict_postings.py b/smart_importer/predict_postings.py index 46864dd..6fc2c68 100644 --- a/smart_importer/predict_postings.py +++ b/smart_importer/predict_postings.py @@ -32,7 +32,7 @@ class PredictPostings: @PredictPostings( training_data="trainingdata.beancount", - filter_training_data_by_account="The:Importers:Already:Known:Accountname" + account="The:Importers:Already:Known:Accountname" ) class MyImporter(ImporterProtocol): def extract(file): @@ -48,12 +48,12 @@ def __init__( self, *, training_data: Union[_FileMemo, List[Transaction], str] = None, - filter_training_data_by_account: str = None, + account: str = None, predict_second_posting: bool = True, suggest_accounts: bool = True ): self.training_data = training_data - self.filter_training_data_by_account = filter_training_data_by_account + self.account = account self.predict_second_posting = predict_second_posting self.suggest_accounts = suggest_accounts @@ -92,7 +92,7 @@ def wrapper(self, file, existing_entries=None): def enhance_transactions(self):# load training data self.training_data = ml.load_training_data( self.training_data, - filter_training_data_by_account=self.filter_training_data_by_account, + known_account=self.account, existing_entries=self.existing_entries) # convert training data to a list of TxnPostingAccounts diff --git a/smart_importer/tests/predict_payees_test.py b/smart_importer/tests/predict_payees_test.py index 0522a91..0195c3d 100644 --- a/smart_importer/tests/predict_payees_test.py +++ b/smart_importer/tests/predict_payees_test.py @@ -85,7 +85,7 @@ class Testdata: """) assert not errors - filter_training_data_by_account = "Assets:US:BofA:Checking" + known_account = "Assets:US:BofA:Checking" correct_predictions = [ 'Farmer Fresh', @@ -115,7 +115,7 @@ def setUp(self): # define and decorate an importer: @PredictPayees( training_data=Testdata.training_data, - filter_training_data_by_account="Assets:US:BofA:Checking", + account="Assets:US:BofA:Checking", overwrite_existing_payees=False ) class DecoratedImporter(BasicImporter): @@ -190,7 +190,7 @@ def test_class_decoration_with_arguments(self): @PredictPayees( training_data=Testdata.training_data, - filter_training_data_by_account=Testdata.filter_training_data_by_account + account=Testdata.known_account ) class SmartImporter(BasicImporter): pass @@ -213,7 +213,7 @@ def test_method_decoration_with_arguments(self): class SmartImporter(BasicImporter): @PredictPayees( training_data=Testdata.training_data, - filter_training_data_by_account=Testdata.filter_training_data_by_account + account=Testdata.known_account ) def extract(self, file, existing_entries=None): testcase.assertIsInstance(self, SmartImporter) diff --git a/smart_importer/tests/predict_postings_test.py b/smart_importer/tests/predict_postings_test.py index c680765..d3ba062 100644 --- a/smart_importer/tests/predict_postings_test.py +++ b/smart_importer/tests/predict_postings_test.py @@ -81,7 +81,7 @@ class Testdata: """) assert not errors - filter_training_data_by_account = "Assets:US:BofA:Checking" + known_account = "Assets:US:BofA:Checking" correct_predictions = [ 'Expenses:Food:Groceries', @@ -111,7 +111,7 @@ def setUp(self): # define and decorate an importer: @PredictPostings( training_data=Testdata.training_data, - filter_training_data_by_account="Assets:US:BofA:Checking" + account="Assets:US:BofA:Checking" ) class DecoratedImporter(BasicImporter): pass @@ -176,7 +176,7 @@ def test_class_decoration_with_arguments(self): @PredictPostings( training_data=Testdata.training_data, - filter_training_data_by_account=Testdata.filter_training_data_by_account + account=Testdata.known_account ) class SmartImporter(BasicImporter): pass @@ -199,7 +199,7 @@ def test_method_decoration_with_arguments(self): class SmartImporter(BasicImporter): @PredictPostings( training_data=Testdata.training_data, - filter_training_data_by_account=Testdata.filter_training_data_by_account + account=Testdata.known_account ) def extract(self, file, existing_entries=None): testcase.assertIsInstance(self, SmartImporter) From 9dddd5f514bc2b5f29311df02651229a2e56e5b0 Mon Sep 17 00:00:00 2001 From: Johannes Harms Date: Mon, 30 Apr 2018 20:01:38 +0200 Subject: [PATCH 2/2] refs #30 reads default value for known `account` from the importer instance. --- smart_importer/predict_postings.py | 14 ++++++ smart_importer/tests/predict_payees_test.py | 44 +++++++++---------- smart_importer/tests/predict_postings_test.py | 44 +++++++++---------- 3 files changed, 54 insertions(+), 48 deletions(-) diff --git a/smart_importer/predict_postings.py b/smart_importer/predict_postings.py index 6fc2c68..a454e8d 100644 --- a/smart_importer/predict_postings.py +++ b/smart_importer/predict_postings.py @@ -76,14 +76,28 @@ def patched_extract_function(self, original_extract_function): @wraps(original_extract_function) def wrapper(self, file, existing_entries=None): + + # read the importer's existing entries, if provided as argument to its `extract` method: decorator.existing_entries = existing_entries + # read the importer's `extract`ed entries logger.debug(f"About to call the importer's extract function to receive entries to be imported...") if 'existing_entries' in inspect.signature(original_extract_function).parameters: decorator.imported_transactions = original_extract_function(self, file, existing_entries) else: decorator.imported_transactions = original_extract_function(self, file) + # read the importer's file_account, to be used as default value for the decorator's known `account`: + if inspect.ismethod(self.file_account) and not decorator.account: + logger.debug("Trying to read the importer's file_account, " + "to be used as default value for the decorator's `account` argument...") + file_account = self.file_account(file) + if file_account: + decorator.account = file_account + logger.debug(f"Read file_account {file_account} from the importer; " + f"using it as known account in the decorator.") + else: + logger.debug(f"Could not retrieve file_account from the importer.") return decorator.enhance_transactions() diff --git a/smart_importer/tests/predict_payees_test.py b/smart_importer/tests/predict_payees_test.py index 0195c3d..5235f44 100644 --- a/smart_importer/tests/predict_payees_test.py +++ b/smart_importer/tests/predict_payees_test.py @@ -7,6 +7,7 @@ from beancount.core.data import Transaction from beancount.ingest.importer import ImporterProtocol from beancount.parser import parser + from smart_importer import machinelearning_helpers as ml from smart_importer.predict_payees import PredictPayees @@ -97,10 +98,13 @@ class Testdata: ] -class BasicImporter(ImporterProtocol): +class BasicTestImporter(ImporterProtocol): def extract(self, file, existing_entries=None): return Testdata.test_data + def file_account(self, file): + return Testdata.known_account + class PredictPayeesTest(unittest.TestCase): ''' @@ -118,11 +122,11 @@ def setUp(self): account="Assets:US:BofA:Checking", overwrite_existing_payees=False ) - class DecoratedImporter(BasicImporter): + class DecoratedTestImporter(BasicTestImporter): pass - self.importerClass = DecoratedImporter - self.importer = DecoratedImporter() + self.importerClass = DecoratedTestImporter + self.importer = DecoratedTestImporter() def test_dummy_importer(self): ''' @@ -192,11 +196,11 @@ def test_class_decoration_with_arguments(self): training_data=Testdata.training_data, account=Testdata.known_account ) - class SmartImporter(BasicImporter): + class SmartTestImporter(BasicTestImporter): pass - i = SmartImporter() - self.assertIsInstance(i, SmartImporter, + i = SmartTestImporter() + self.assertIsInstance(i, SmartTestImporter, 'The decorated importer shall still be an instance of the undecorated class.') transactions = i.extract('file', existing_entries=Testdata.training_data) predicted_payees = [transaction.payee for transaction in transactions] @@ -210,24 +214,20 @@ def test_method_decoration_with_arguments(self): logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1])) testcase = self - class SmartImporter(BasicImporter): + class SmartTestImporter(BasicTestImporter): @PredictPayees( training_data=Testdata.training_data, account=Testdata.known_account ) def extract(self, file, existing_entries=None): - testcase.assertIsInstance(self, SmartImporter) + testcase.assertIsInstance(self, SmartTestImporter) return super().extract(file, existing_entries=existing_entries) - i = SmartImporter() + i = SmartTestImporter() transactions = i.extract('file', existing_entries=Testdata.training_data) predicted_payees = [transaction.payee for transaction in transactions] self.assertEqual(predicted_payees, Testdata.correct_predictions) - # TODO: implement reasonable defaults to fix this test case: - @unittest.skip( - "smart imports without arguments currently fail " - "because the already known account is not filtered from the training data") def test_class_decoration_without_arguments(self): ''' Verifies that the decorator can be applied to importer classes, @@ -236,19 +236,15 @@ def test_class_decoration_without_arguments(self): logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1])) @PredictPayees() - class SmartImporter(BasicImporter): pass + class SmartTestImporter(BasicTestImporter): pass - i = SmartImporter() - self.assertIsInstance(i, SmartImporter, + i = SmartTestImporter() + self.assertIsInstance(i, SmartTestImporter, 'The decorated importer shall still be an instance of the undecorated class.') transactions = i.extract('file', existing_entries=Testdata.training_data) predicted_payees = [transaction.payee for transaction in transactions] self.assertEqual(predicted_payees, Testdata.correct_predictions) - # TODO: implement reasonable defaults to fix this test case: - @unittest.skip( - "smart imports without arguments currently fail " - "because the already known account is not filtered from the training data") def test_method_decoration_without_arguments(self): ''' Verifies that the decorator can be applied to an importer's extract method, @@ -257,13 +253,13 @@ def test_method_decoration_without_arguments(self): logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1])) testcase = self - class SmartImporter(BasicImporter): + class SmartTestImporter(BasicTestImporter): @PredictPayees() def extract(self, file, existing_entries=None): - testcase.assertIsInstance(self, SmartImporter) + testcase.assertIsInstance(self, SmartTestImporter) return super().extract(file, existing_entries=existing_entries) - i = SmartImporter() + i = SmartTestImporter() transactions = i.extract('file', existing_entries=Testdata.training_data) predicted_payees = [transaction.payee for transaction in transactions] self.assertEqual(predicted_payees, Testdata.correct_predictions) diff --git a/smart_importer/tests/predict_postings_test.py b/smart_importer/tests/predict_postings_test.py index d3ba062..25b5c95 100644 --- a/smart_importer/tests/predict_postings_test.py +++ b/smart_importer/tests/predict_postings_test.py @@ -7,6 +7,7 @@ from beancount.core.data import Transaction from beancount.ingest.importer import ImporterProtocol from beancount.parser import parser + from smart_importer.predict_postings import PredictPostings LOG_LEVEL = logging.DEBUG @@ -93,10 +94,13 @@ class Testdata: ] -class BasicImporter(ImporterProtocol): +class BasicTestImporter(ImporterProtocol): def extract(self, file, existing_entries=None): return Testdata.test_data + def file_account(self, file): + return Testdata.known_account + class PredictPostingsTest(unittest.TestCase): ''' @@ -113,11 +117,11 @@ def setUp(self): training_data=Testdata.training_data, account="Assets:US:BofA:Checking" ) - class DecoratedImporter(BasicImporter): + class DecoratedTestImporter(BasicTestImporter): pass - self.importerClass = DecoratedImporter - self.importer = DecoratedImporter() + self.importerClass = DecoratedTestImporter + self.importer = DecoratedTestImporter() def test_unchanged_narrations(self): ''' @@ -178,11 +182,11 @@ def test_class_decoration_with_arguments(self): training_data=Testdata.training_data, account=Testdata.known_account ) - class SmartImporter(BasicImporter): + class SmartTestImporter(BasicTestImporter): pass - i = SmartImporter() - self.assertIsInstance(i, SmartImporter, + i = SmartTestImporter() + self.assertIsInstance(i, SmartTestImporter, 'The decorated importer shall still be an instance of the undecorated class.') transactions = i.extract('file', existing_entries=Testdata.training_data) predicted_accounts = [entry.postings[-1].account for entry in transactions] @@ -196,24 +200,20 @@ def test_method_decoration_with_arguments(self): logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1])) testcase = self - class SmartImporter(BasicImporter): + class SmartTestImporter(BasicTestImporter): @PredictPostings( training_data=Testdata.training_data, account=Testdata.known_account ) def extract(self, file, existing_entries=None): - testcase.assertIsInstance(self, SmartImporter) + testcase.assertIsInstance(self, SmartTestImporter) return super().extract(file, existing_entries=existing_entries) - i = SmartImporter() + i = SmartTestImporter() transactions = i.extract('file', existing_entries=Testdata.training_data) predicted_accounts = [entry.postings[-1].account for entry in transactions] self.assertEqual(predicted_accounts, Testdata.correct_predictions) - # TODO: implement reasonable defaults to fix this test case: - @unittest.skip( - "smart imports without arguments currently fail " - "because the already known account is not filtered from the training data") def test_class_decoration_with_empty_arguments(self): ''' Verifies that the decorator can be applied to importer classes, @@ -222,19 +222,15 @@ def test_class_decoration_with_empty_arguments(self): logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1])) @PredictPostings() - class SmartImporter(BasicImporter): pass + class SmartTestImporter(BasicTestImporter): pass - i = SmartImporter() - self.assertIsInstance(i, SmartImporter, + i = SmartTestImporter() + self.assertIsInstance(i, SmartTestImporter, 'The decorated importer shall still be an instance of the undecorated class.') transactions = i.extract('file', existing_entries=Testdata.training_data) predicted_accounts = [transaction.postings[-1].account for transaction in transactions] self.assertEqual(predicted_accounts, Testdata.correct_predictions) - # TODO: implement reasonable defaults to fix this test case: - @unittest.skip( - "smart imports without arguments currently fail " - "because the already known account is not filtered from the training data") def test_method_decoration_with_empty_arguments(self): ''' Verifies that the decorator can be applied to an importer's extract method, @@ -243,13 +239,13 @@ def test_method_decoration_with_empty_arguments(self): logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1])) testcase = self - class SmartImporter(BasicImporter): + class SmartTestImporter(BasicTestImporter): @PredictPostings() def extract(self, file, existing_entries=None): - testcase.assertIsInstance(self, SmartImporter) + testcase.assertIsInstance(self, SmartTestImporter) return super().extract(file, existing_entries=existing_entries) - i = SmartImporter() + i = SmartTestImporter() transactions = i.extract('file', existing_entries=Testdata.training_data) predicted_accounts = [entry.postings[-1].account for entry in transactions] self.assertEqual(predicted_accounts, Testdata.correct_predictions)