Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/refs 30 filter training data by account #40

Merged
merged 2 commits into from
Apr 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 6 additions & 6 deletions smart_importer/machinelearning_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@


def load_training_data(training_data: Union[_FileMemo, List[Transaction], str],
filter_training_data_by_account: str = None,
known_account: str = None,
existing_entries: List[Tuple] = None) -> List[Transaction]:
'''
Loads training data
:param training_data: The training data that shall be loaded.
Can be provided as a string (the filename pointing to a beancount file),
a _FileMemo instance,
or a list of beancount entries
:param filter_training_data_by_account: Optional filter for the training data.
:param known_account: Optional filter for the training data.
If provided, the training data is filtered to only include transactions that involve the specified account.
:param existing_entries: Optional existing entries to use instead of explicit training_data
:return: Returns a list of beancount entries.
Expand All @@ -42,11 +42,11 @@ def load_training_data(training_data: Union[_FileMemo, List[Transaction], str],
assert not errors
training_data = filter_txns(training_data)
logger.debug(f"Finished reading training data.")
if filter_training_data_by_account:
if known_account:
training_data = [t for t in training_data
# ...filtered because the training data must involve the filter_training_data_by_account:
if transaction_involves_account(t, filter_training_data_by_account)]
logger.debug(f"After filtering for account {filter_training_data_by_account}, "
# ...filtered because the training data must involve the account:
if transaction_involves_account(t, known_account)]
logger.debug(f"After filtering for account {known_account}, "
f"the training data consists of {len(training_data)} entries.")
return training_data

Expand Down
6 changes: 3 additions & 3 deletions smart_importer/predict_payees.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ def extract(file):

def __init__(self, *,
training_data: Union[_FileMemo, List[Transaction], str] = None,
filter_training_data_by_account: str = None,
account: str = None,
predict_payees: bool = True,
overwrite_existing_payees=False,
suggest_payees: bool = True):
self.training_data = training_data
self.filter_training_data_by_account = filter_training_data_by_account
self.account = account
self.predict_payees = predict_payees
self.overwrite_existing_payees = overwrite_existing_payees
self.suggest_payees = suggest_payees
Expand Down Expand Up @@ -88,7 +88,7 @@ def wrapper(self, file, existing_entries=None):
def enhance_transactions(self):# load training data
self.training_data = ml.load_training_data(
self.training_data,
filter_training_data_by_account=self.filter_training_data_by_account,
known_account=self.account,
existing_entries=self.existing_entries)

# train the machine learning model
Expand Down
22 changes: 18 additions & 4 deletions smart_importer/predict_postings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class PredictPostings:

@PredictPostings(
training_data="trainingdata.beancount",
filter_training_data_by_account="The:Importers:Already:Known:Accountname"
account="The:Importers:Already:Known:Accountname"
)
class MyImporter(ImporterProtocol):
def extract(file):
Expand All @@ -48,12 +48,12 @@ def __init__(
self,
*,
training_data: Union[_FileMemo, List[Transaction], str] = None,
filter_training_data_by_account: str = None,
account: str = None,
predict_second_posting: bool = True,
suggest_accounts: bool = True
):
self.training_data = training_data
self.filter_training_data_by_account = filter_training_data_by_account
self.account = account
self.predict_second_posting = predict_second_posting
self.suggest_accounts = suggest_accounts

Expand All @@ -76,14 +76,28 @@ def patched_extract_function(self, original_extract_function):

@wraps(original_extract_function)
def wrapper(self, file, existing_entries=None):

# read the importer's existing entries, if provided as argument to its `extract` method:
decorator.existing_entries = existing_entries

# read the importer's `extract`ed entries
logger.debug(f"About to call the importer's extract function to receive entries to be imported...")
if 'existing_entries' in inspect.signature(original_extract_function).parameters:
decorator.imported_transactions = original_extract_function(self, file, existing_entries)
else:
decorator.imported_transactions = original_extract_function(self, file)

# read the importer's file_account, to be used as default value for the decorator's known `account`:
if inspect.ismethod(self.file_account) and not decorator.account:
logger.debug("Trying to read the importer's file_account, "
"to be used as default value for the decorator's `account` argument...")
file_account = self.file_account(file)
if file_account:
decorator.account = file_account
logger.debug(f"Read file_account {file_account} from the importer; "
f"using it as known account in the decorator.")
else:
logger.debug(f"Could not retrieve file_account from the importer.")

return decorator.enhance_transactions()

Expand All @@ -92,7 +106,7 @@ def wrapper(self, file, existing_entries=None):
def enhance_transactions(self):# load training data
self.training_data = ml.load_training_data(
self.training_data,
filter_training_data_by_account=self.filter_training_data_by_account,
known_account=self.account,
existing_entries=self.existing_entries)

# convert training data to a list of TxnPostingAccounts
Expand Down
52 changes: 24 additions & 28 deletions smart_importer/tests/predict_payees_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from beancount.core.data import Transaction
from beancount.ingest.importer import ImporterProtocol
from beancount.parser import parser

from smart_importer import machinelearning_helpers as ml
from smart_importer.predict_payees import PredictPayees

Expand Down Expand Up @@ -85,7 +86,7 @@ class Testdata:
""")
assert not errors

filter_training_data_by_account = "Assets:US:BofA:Checking"
known_account = "Assets:US:BofA:Checking"

correct_predictions = [
'Farmer Fresh',
Expand All @@ -97,10 +98,13 @@ class Testdata:
]


class BasicImporter(ImporterProtocol):
class BasicTestImporter(ImporterProtocol):
def extract(self, file, existing_entries=None):
return Testdata.test_data

def file_account(self, file):
return Testdata.known_account


class PredictPayeesTest(unittest.TestCase):
'''
Expand All @@ -115,14 +119,14 @@ def setUp(self):
# define and decorate an importer:
@PredictPayees(
training_data=Testdata.training_data,
filter_training_data_by_account="Assets:US:BofA:Checking",
account="Assets:US:BofA:Checking",
overwrite_existing_payees=False
)
class DecoratedImporter(BasicImporter):
class DecoratedTestImporter(BasicTestImporter):
pass

self.importerClass = DecoratedImporter
self.importer = DecoratedImporter()
self.importerClass = DecoratedTestImporter
self.importer = DecoratedTestImporter()

def test_dummy_importer(self):
'''
Expand Down Expand Up @@ -190,13 +194,13 @@ def test_class_decoration_with_arguments(self):

@PredictPayees(
training_data=Testdata.training_data,
filter_training_data_by_account=Testdata.filter_training_data_by_account
account=Testdata.known_account
)
class SmartImporter(BasicImporter):
class SmartTestImporter(BasicTestImporter):
pass

i = SmartImporter()
self.assertIsInstance(i, SmartImporter,
i = SmartTestImporter()
self.assertIsInstance(i, SmartTestImporter,
'The decorated importer shall still be an instance of the undecorated class.')
transactions = i.extract('file', existing_entries=Testdata.training_data)
predicted_payees = [transaction.payee for transaction in transactions]
Expand All @@ -210,24 +214,20 @@ def test_method_decoration_with_arguments(self):
logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1]))
testcase = self

class SmartImporter(BasicImporter):
class SmartTestImporter(BasicTestImporter):
@PredictPayees(
training_data=Testdata.training_data,
filter_training_data_by_account=Testdata.filter_training_data_by_account
account=Testdata.known_account
)
def extract(self, file, existing_entries=None):
testcase.assertIsInstance(self, SmartImporter)
testcase.assertIsInstance(self, SmartTestImporter)
return super().extract(file, existing_entries=existing_entries)

i = SmartImporter()
i = SmartTestImporter()
transactions = i.extract('file', existing_entries=Testdata.training_data)
predicted_payees = [transaction.payee for transaction in transactions]
self.assertEqual(predicted_payees, Testdata.correct_predictions)

# TODO: implement reasonable defaults to fix this test case:
@unittest.skip(
"smart imports without arguments currently fail "
"because the already known account is not filtered from the training data")
def test_class_decoration_without_arguments(self):
'''
Verifies that the decorator can be applied to importer classes,
Expand All @@ -236,19 +236,15 @@ def test_class_decoration_without_arguments(self):
logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1]))

@PredictPayees()
class SmartImporter(BasicImporter): pass
class SmartTestImporter(BasicTestImporter): pass

i = SmartImporter()
self.assertIsInstance(i, SmartImporter,
i = SmartTestImporter()
self.assertIsInstance(i, SmartTestImporter,
'The decorated importer shall still be an instance of the undecorated class.')
transactions = i.extract('file', existing_entries=Testdata.training_data)
predicted_payees = [transaction.payee for transaction in transactions]
self.assertEqual(predicted_payees, Testdata.correct_predictions)

# TODO: implement reasonable defaults to fix this test case:
@unittest.skip(
"smart imports without arguments currently fail "
"because the already known account is not filtered from the training data")
def test_method_decoration_without_arguments(self):
'''
Verifies that the decorator can be applied to an importer's extract method,
Expand All @@ -257,13 +253,13 @@ def test_method_decoration_without_arguments(self):
logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1]))
testcase = self

class SmartImporter(BasicImporter):
class SmartTestImporter(BasicTestImporter):
@PredictPayees()
def extract(self, file, existing_entries=None):
testcase.assertIsInstance(self, SmartImporter)
testcase.assertIsInstance(self, SmartTestImporter)
return super().extract(file, existing_entries=existing_entries)

i = SmartImporter()
i = SmartTestImporter()
transactions = i.extract('file', existing_entries=Testdata.training_data)
predicted_payees = [transaction.payee for transaction in transactions]
self.assertEqual(predicted_payees, Testdata.correct_predictions)
Expand Down
52 changes: 24 additions & 28 deletions smart_importer/tests/predict_postings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from beancount.core.data import Transaction
from beancount.ingest.importer import ImporterProtocol
from beancount.parser import parser

from smart_importer.predict_postings import PredictPostings

LOG_LEVEL = logging.DEBUG
Expand Down Expand Up @@ -81,7 +82,7 @@ class Testdata:
""")
assert not errors

filter_training_data_by_account = "Assets:US:BofA:Checking"
known_account = "Assets:US:BofA:Checking"

correct_predictions = [
'Expenses:Food:Groceries',
Expand All @@ -93,10 +94,13 @@ class Testdata:
]


class BasicImporter(ImporterProtocol):
class BasicTestImporter(ImporterProtocol):
def extract(self, file, existing_entries=None):
return Testdata.test_data

def file_account(self, file):
return Testdata.known_account


class PredictPostingsTest(unittest.TestCase):
'''
Expand All @@ -111,13 +115,13 @@ def setUp(self):
# define and decorate an importer:
@PredictPostings(
training_data=Testdata.training_data,
filter_training_data_by_account="Assets:US:BofA:Checking"
account="Assets:US:BofA:Checking"
)
class DecoratedImporter(BasicImporter):
class DecoratedTestImporter(BasicTestImporter):
pass

self.importerClass = DecoratedImporter
self.importer = DecoratedImporter()
self.importerClass = DecoratedTestImporter
self.importer = DecoratedTestImporter()

def test_unchanged_narrations(self):
'''
Expand Down Expand Up @@ -176,13 +180,13 @@ def test_class_decoration_with_arguments(self):

@PredictPostings(
training_data=Testdata.training_data,
filter_training_data_by_account=Testdata.filter_training_data_by_account
account=Testdata.known_account
)
class SmartImporter(BasicImporter):
class SmartTestImporter(BasicTestImporter):
pass

i = SmartImporter()
self.assertIsInstance(i, SmartImporter,
i = SmartTestImporter()
self.assertIsInstance(i, SmartTestImporter,
'The decorated importer shall still be an instance of the undecorated class.')
transactions = i.extract('file', existing_entries=Testdata.training_data)
predicted_accounts = [entry.postings[-1].account for entry in transactions]
Expand All @@ -196,24 +200,20 @@ def test_method_decoration_with_arguments(self):
logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1]))
testcase = self

class SmartImporter(BasicImporter):
class SmartTestImporter(BasicTestImporter):
@PredictPostings(
training_data=Testdata.training_data,
filter_training_data_by_account=Testdata.filter_training_data_by_account
account=Testdata.known_account
)
def extract(self, file, existing_entries=None):
testcase.assertIsInstance(self, SmartImporter)
testcase.assertIsInstance(self, SmartTestImporter)
return super().extract(file, existing_entries=existing_entries)

i = SmartImporter()
i = SmartTestImporter()
transactions = i.extract('file', existing_entries=Testdata.training_data)
predicted_accounts = [entry.postings[-1].account for entry in transactions]
self.assertEqual(predicted_accounts, Testdata.correct_predictions)

# TODO: implement reasonable defaults to fix this test case:
@unittest.skip(
"smart imports without arguments currently fail "
"because the already known account is not filtered from the training data")
def test_class_decoration_with_empty_arguments(self):
'''
Verifies that the decorator can be applied to importer classes,
Expand All @@ -222,19 +222,15 @@ def test_class_decoration_with_empty_arguments(self):
logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1]))

@PredictPostings()
class SmartImporter(BasicImporter): pass
class SmartTestImporter(BasicTestImporter): pass

i = SmartImporter()
self.assertIsInstance(i, SmartImporter,
i = SmartTestImporter()
self.assertIsInstance(i, SmartTestImporter,
'The decorated importer shall still be an instance of the undecorated class.')
transactions = i.extract('file', existing_entries=Testdata.training_data)
predicted_accounts = [transaction.postings[-1].account for transaction in transactions]
self.assertEqual(predicted_accounts, Testdata.correct_predictions)

# TODO: implement reasonable defaults to fix this test case:
@unittest.skip(
"smart imports without arguments currently fail "
"because the already known account is not filtered from the training data")
def test_method_decoration_with_empty_arguments(self):
'''
Verifies that the decorator can be applied to an importer's extract method,
Expand All @@ -243,13 +239,13 @@ def test_method_decoration_with_empty_arguments(self):
logger.info("Running Test Case: {id}".format(id=self.id().split('.')[-1]))
testcase = self

class SmartImporter(BasicImporter):
class SmartTestImporter(BasicTestImporter):
@PredictPostings()
def extract(self, file, existing_entries=None):
testcase.assertIsInstance(self, SmartImporter)
testcase.assertIsInstance(self, SmartTestImporter)
return super().extract(file, existing_entries=existing_entries)

i = SmartImporter()
i = SmartTestImporter()
transactions = i.extract('file', existing_entries=Testdata.training_data)
predicted_accounts = [entry.postings[-1].account for entry in transactions]
self.assertEqual(predicted_accounts, Testdata.correct_predictions)
Expand Down