Skip to content

Commit

Permalink
Merge pull request #40 from beancount/feature/refs-30-filter_training_data_by_account
Browse files: browse the repository at this point in the history

Feature/refs 30 filter training data by account
  • Loading branch information
tarioch committed Apr 30, 2018
2 parents d28d0fd + 9dddd5f commit 0b75bea
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 69 deletions.
12 changes: 6 additions & 6 deletions smart_importer/machinelearning_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@


def load_training_data(training_data: Union[_FileMemo, List[Transaction], str],
filter_training_data_by_account: str = None,
known_account: str = None,
existing_entries: List[Tuple] = None) -> List[Transaction]:
'''
Loads training data
:param training_data: The training data that shall be loaded.
Can be provided as a string (the filename pointing to a beancount file),
a _FileMemo instance,
or a list of beancount entries
:param filter_training_data_by_account: Optional filter for the training data.
:param known_account: Optional filter for the training data.
If provided, the training data is filtered to only include transactions that involve the specified account.
:param existing_entries: Optional existing entries to use instead of explicit training_data
:return: Returns a list of beancount entries.
Expand All @@ -42,11 +42,11 @@ def load_training_data(training_data: Union[_FileMemo, List[Transaction], str],
assert not errors
training_data = filter_txns(training_data)
logger.debug(f"Finished reading training data.")
if filter_training_data_by_account:
if known_account:
training_data = [t for t in training_data
# ...filtered because the training data must involve the filter_training_data_by_account:
if transaction_involves_account(t, filter_training_data_by_account)]
logger.debug(f"After filtering for account {filter_training_data_by_account}, "
# ...filtered because the training data must involve the account:
if transaction_involves_account(t, known_account)]
logger.debug(f"After filtering for account {known_account}, "
f"the training data consists of {len(training_data)} entries.")
return training_data

Expand Down
6 changes: 3 additions & 3 deletions smart_importer/predict_payees.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ def extract(file):

def __init__(self, *,
training_data: Union[_FileMemo, List[Transaction], str] = None,
filter_training_data_by_account: str = None,
account: str = None,
predict_payees: bool = True,
overwrite_existing_payees=False,
suggest_payees: bool = True):
self.training_data = training_data
self.filter_training_data_by_account = filter_training_data_by_account
self.account = account
self.predict_payees = predict_payees
self.overwrite_existing_payees = overwrite_existing_payees
self.suggest_payees = suggest_payees
Expand Down Expand Up @@ -88,7 +88,7 @@ def wrapper(self, file, existing_entries=None):
def enhance_transactions(self):# load training data
self.training_data = ml.load_training_data(
self.training_data,
filter_training_data_by_account=self.filter_training_data_by_account,
known_account=self.account,
existing_entries=self.existing_entries)

# train the machine learning model
Expand Down
22 changes: 18 additions & 4 deletions smart_importer/predict_postings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class PredictPostings:
@PredictPostings(
training_data="trainingdata.beancount",
filter_training_data_by_account="The:Importers:Already:Known:Accountname"
account="The:Importers:Already:Known:Accountname"
)
class MyImporter(ImporterProtocol):
def extract(file):
Expand All @@ -48,12 +48,12 @@ def __init__(
self,
*,
training_data: Union[_FileMemo, List[Transaction], str] = None,
filter_training_data_by_account: str = None,
account: str = None,
predict_second_posting: bool = True,
suggest_accounts: bool = True
):
self.training_data = training_data
self.filter_training_data_by_account = filter_training_data_by_account
self.account = account
self.predict_second_posting = predict_second_posting
self.suggest_accounts = suggest_accounts

Expand All @@ -76,14 +76,28 @@ def patched_extract_function(self, original_extract_function):

@wraps(original_extract_function)
def wrapper(self, file, existing_entries=None):
    '''
    Replacement for the importer's `extract` method.

    Calls the original `extract` function, stores its result and the
    `existing_entries` on the enclosing `decorator` object, optionally
    derives the decorator's `account` from the importer's `file_account`,
    and finally returns the result of `decorator.enhance_transactions()`.

    :param file: Passed through unchanged to the importer's original
        `extract` function and to its `file_account` method.
    :param existing_entries: Optional existing entries. They are forwarded
        to the original `extract` function only if its signature declares
        a parameter named `existing_entries`.
    :return: Whatever `decorator.enhance_transactions()` returns.
    '''

    # read the importer's existing entries, if provided as argument to its `extract` method:
    decorator.existing_entries = existing_entries

    # read the importer's `extract`ed entries
    logger.debug("About to call the importer's extract function to receive entries to be imported...")
    if 'existing_entries' in inspect.signature(original_extract_function).parameters:
        decorator.imported_transactions = original_extract_function(self, file, existing_entries)
    else:
        decorator.imported_transactions = original_extract_function(self, file)

    # read the importer's file_account, to be used as default value for the decorator's known `account`:
    # `getattr` with a default keeps importers that do not define `file_account` working;
    # the lookup is only a fallback, so a missing attribute must not abort the import.
    file_account_method = getattr(self, 'file_account', None)
    # An `account` that is already set on the decorator takes precedence. Note that a value
    # stored here stays on the decorator, so later calls will not query `file_account` again.
    if inspect.ismethod(file_account_method) and not decorator.account:
        logger.debug("Trying to read the importer's file_account, "
                     "to be used as default value for the decorator's `account` argument...")
        file_account = file_account_method(file)
        if file_account:
            decorator.account = file_account
            logger.debug(f"Read file_account {file_account} from the importer; "
                         f"using it as known account in the decorator.")
        else:
            logger.debug("Could not retrieve file_account from the importer.")

    return decorator.enhance_transactions()

Expand All @@ -92,7 +106,7 @@ def wrapper(self, file, existing_entries=None):
def enhance_transactions(self):# load training data
self.training_data = ml.load_training_data(
self.training_data,
filter_training_data_by_account=self.filter_training_data_by_account,
known_account=self.account,
existing_entries=self.existing_entries)

# convert training data to a list of TxnPostingAccounts
Expand Down
52 changes: 24 additions & 28 deletions smart_importer/tests/predict_payees_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from beancount.core.data import Transaction
from beancount.ingest.importer import ImporterProtocol
from beancount.parser import parser

from smart_importer import machinelearning_helpers as ml
from smart_importer.predict_payees import PredictPayees

Expand Down Expand Up @@ -85,7 +86,7 @@ class Testdata:
""")
assert not errors

filter_training_data_by_account = "Assets:US:BofA:Checking"
known_account = "Assets:US:BofA:Checking"

correct_predictions = [
'Farmer Fresh',
Expand All @@ -97,10 +98,13 @@ class Testdata:
]


class BasicImporter(ImporterProtocol):
class BasicTestImporter(ImporterProtocol):
def extract(self, file, existing_entries=None):
return Testdata.test_data

def file_account(self, file):
return Testdata.known_account


class PredictPayeesTest(unittest.TestCase):
'''
Expand All @@ -115,14 +119,14 @@ def setUp(self):
# define and decorate an importer:
@PredictPayees(
training_data=Testdata.training_data,
filter_training_data_by_account="Assets:US:BofA:Checking",
account="Assets:US:BofA:Checking",
overwrite_existing_payees=False
)
class DecoratedImporter(BasicImporter):
class DecoratedTestImporter(BasicTestImporter):
pass

self.importerClass = DecoratedImporter
self.importer = DecoratedImporter()
self.importerClass = DecoratedTestImporter
self.importer = DecoratedTestImporter()

def test_dummy_importer(self):
'''
Expand Down Expand Up @@ -190,13 +194,13 @@ def test_class_decoration_with_arguments(self):

@PredictPayees(
training_data=Testdata.training_data,
filter_training_data_by_account=Testdata.filter_training_data_by_account
account=Testdata.known_account
)
class SmartImporter(BasicImporter):
class SmartTestImporter(BasicTestImporter):
pass

i = SmartImporter()
self.assertIsInstance(i, SmartImporter,
i = SmartTestImporter()
self.assertIsInstance(i, SmartTestImporter,
'The decorated importer shall still be an instance of the undecorated class.')
transactions = i.extract('file', existing_entries=Testdata.training_data)
predicted_payees = [transaction.payee for transaction in transactions]
Expand All @@ -210,24 +214,20 @@ def test_method_decoration_with_arguments(self):
logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1]))
testcase = self

class SmartImporter(BasicImporter):
class SmartTestImporter(BasicTestImporter):
@PredictPayees(
training_data=Testdata.training_data,
filter_training_data_by_account=Testdata.filter_training_data_by_account
account=Testdata.known_account
)
def extract(self, file, existing_entries=None):
testcase.assertIsInstance(self, SmartImporter)
testcase.assertIsInstance(self, SmartTestImporter)
return super().extract(file, existing_entries=existing_entries)

i = SmartImporter()
i = SmartTestImporter()
transactions = i.extract('file', existing_entries=Testdata.training_data)
predicted_payees = [transaction.payee for transaction in transactions]
self.assertEqual(predicted_payees, Testdata.correct_predictions)

# TODO: implement reasonable defaults to fix this test case:
@unittest.skip(
"smart imports without arguments currently fail "
"because the already known account is not filtered from the training data")
def test_class_decoration_without_arguments(self):
'''
Verifies that the decorator can be applied to importer classes,
Expand All @@ -236,19 +236,15 @@ def test_class_decoration_without_arguments(self):
logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1]))

@PredictPayees()
class SmartImporter(BasicImporter): pass
class SmartTestImporter(BasicTestImporter): pass

i = SmartImporter()
self.assertIsInstance(i, SmartImporter,
i = SmartTestImporter()
self.assertIsInstance(i, SmartTestImporter,
'The decorated importer shall still be an instance of the undecorated class.')
transactions = i.extract('file', existing_entries=Testdata.training_data)
predicted_payees = [transaction.payee for transaction in transactions]
self.assertEqual(predicted_payees, Testdata.correct_predictions)

# TODO: implement reasonable defaults to fix this test case:
@unittest.skip(
"smart imports without arguments currently fail "
"because the already known account is not filtered from the training data")
def test_method_decoration_without_arguments(self):
'''
Verifies that the decorator can be applied to an importer's extract method,
Expand All @@ -257,13 +253,13 @@ def test_method_decoration_without_arguments(self):
logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1]))
testcase = self

class SmartImporter(BasicImporter):
class SmartTestImporter(BasicTestImporter):
@PredictPayees()
def extract(self, file, existing_entries=None):
testcase.assertIsInstance(self, SmartImporter)
testcase.assertIsInstance(self, SmartTestImporter)
return super().extract(file, existing_entries=existing_entries)

i = SmartImporter()
i = SmartTestImporter()
transactions = i.extract('file', existing_entries=Testdata.training_data)
predicted_payees = [transaction.payee for transaction in transactions]
self.assertEqual(predicted_payees, Testdata.correct_predictions)
Expand Down
52 changes: 24 additions & 28 deletions smart_importer/tests/predict_postings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from beancount.core.data import Transaction
from beancount.ingest.importer import ImporterProtocol
from beancount.parser import parser

from smart_importer.predict_postings import PredictPostings

LOG_LEVEL = logging.DEBUG
Expand Down Expand Up @@ -81,7 +82,7 @@ class Testdata:
""")
assert not errors

filter_training_data_by_account = "Assets:US:BofA:Checking"
known_account = "Assets:US:BofA:Checking"

correct_predictions = [
'Expenses:Food:Groceries',
Expand All @@ -93,10 +94,13 @@ class Testdata:
]


class BasicImporter(ImporterProtocol):
class BasicTestImporter(ImporterProtocol):
def extract(self, file, existing_entries=None):
return Testdata.test_data

def file_account(self, file):
return Testdata.known_account


class PredictPostingsTest(unittest.TestCase):
'''
Expand All @@ -111,13 +115,13 @@ def setUp(self):
# define and decorate an importer:
@PredictPostings(
training_data=Testdata.training_data,
filter_training_data_by_account="Assets:US:BofA:Checking"
account="Assets:US:BofA:Checking"
)
class DecoratedImporter(BasicImporter):
class DecoratedTestImporter(BasicTestImporter):
pass

self.importerClass = DecoratedImporter
self.importer = DecoratedImporter()
self.importerClass = DecoratedTestImporter
self.importer = DecoratedTestImporter()

def test_unchanged_narrations(self):
'''
Expand Down Expand Up @@ -176,13 +180,13 @@ def test_class_decoration_with_arguments(self):

@PredictPostings(
training_data=Testdata.training_data,
filter_training_data_by_account=Testdata.filter_training_data_by_account
account=Testdata.known_account
)
class SmartImporter(BasicImporter):
class SmartTestImporter(BasicTestImporter):
pass

i = SmartImporter()
self.assertIsInstance(i, SmartImporter,
i = SmartTestImporter()
self.assertIsInstance(i, SmartTestImporter,
'The decorated importer shall still be an instance of the undecorated class.')
transactions = i.extract('file', existing_entries=Testdata.training_data)
predicted_accounts = [entry.postings[-1].account for entry in transactions]
Expand All @@ -196,24 +200,20 @@ def test_method_decoration_with_arguments(self):
logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1]))
testcase = self

class SmartImporter(BasicImporter):
class SmartTestImporter(BasicTestImporter):
@PredictPostings(
training_data=Testdata.training_data,
filter_training_data_by_account=Testdata.filter_training_data_by_account
account=Testdata.known_account
)
def extract(self, file, existing_entries=None):
testcase.assertIsInstance(self, SmartImporter)
testcase.assertIsInstance(self, SmartTestImporter)
return super().extract(file, existing_entries=existing_entries)

i = SmartImporter()
i = SmartTestImporter()
transactions = i.extract('file', existing_entries=Testdata.training_data)
predicted_accounts = [entry.postings[-1].account for entry in transactions]
self.assertEqual(predicted_accounts, Testdata.correct_predictions)

# TODO: implement reasonable defaults to fix this test case:
@unittest.skip(
"smart imports without arguments currently fail "
"because the already known account is not filtered from the training data")
def test_class_decoration_with_empty_arguments(self):
'''
Verifies that the decorator can be applied to importer classes,
Expand All @@ -222,19 +222,15 @@ def test_class_decoration_with_empty_arguments(self):
logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1]))

@PredictPostings()
class SmartImporter(BasicImporter): pass
class SmartTestImporter(BasicTestImporter): pass

i = SmartImporter()
self.assertIsInstance(i, SmartImporter,
i = SmartTestImporter()
self.assertIsInstance(i, SmartTestImporter,
'The decorated importer shall still be an instance of the undecorated class.')
transactions = i.extract('file', existing_entries=Testdata.training_data)
predicted_accounts = [transaction.postings[-1].account for transaction in transactions]
self.assertEqual(predicted_accounts, Testdata.correct_predictions)

# TODO: implement reasonable defaults to fix this test case:
@unittest.skip(
"smart imports without arguments currently fail "
"because the already known account is not filtered from the training data")
def test_method_decoration_with_empty_arguments(self):
'''
Verifies that the decorator can be applied to an importer's extract method,
Expand All @@ -243,13 +239,13 @@ def test_method_decoration_with_empty_arguments(self):
logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1]))
testcase = self

class SmartImporter(BasicImporter):
class SmartTestImporter(BasicTestImporter):
@PredictPostings()
def extract(self, file, existing_entries=None):
testcase.assertIsInstance(self, SmartImporter)
testcase.assertIsInstance(self, SmartTestImporter)
return super().extract(file, existing_entries=existing_entries)

i = SmartImporter()
i = SmartTestImporter()
transactions = i.extract('file', existing_entries=Testdata.training_data)
predicted_accounts = [entry.postings[-1].account for entry in transactions]
self.assertEqual(predicted_accounts, Testdata.correct_predictions)
Expand Down

0 comments on commit 0b75bea

Please sign in to comment.