Skip to content

Commit

Permalink
refs #30 renames filter_training_data_by_account to simply `account…
Browse files Browse the repository at this point in the history
…` or sometimes to be be more specific to `known_account`.
  • Loading branch information
johannesjh committed Apr 30, 2018
1 parent d28d0fd commit 5e1b2d2
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 21 deletions.
12 changes: 6 additions & 6 deletions smart_importer/machinelearning_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@


def load_training_data(training_data: Union[_FileMemo, List[Transaction], str],
filter_training_data_by_account: str = None,
known_account: str = None,
existing_entries: List[Tuple] = None) -> List[Transaction]:
'''
Loads training data
:param training_data: The training data that shall be loaded.
Can be provided as a string (the filename pointing to a beancount file),
a _FileMemo instance,
or a list of beancount entries
:param filter_training_data_by_account: Optional filter for the training data.
:param known_account: Optional filter for the training data.
If provided, the training data is filtered to only include transactions that involve the specified account.
:param existing_entries: Optional existing entries to use instead of explicit training_data
:return: Returns a list of beancount entries.
Expand All @@ -42,11 +42,11 @@ def load_training_data(training_data: Union[_FileMemo, List[Transaction], str],
assert not errors
training_data = filter_txns(training_data)
logger.debug(f"Finished reading training data.")
if filter_training_data_by_account:
if known_account:
training_data = [t for t in training_data
# ...filtered because the training data must involve the filter_training_data_by_account:
if transaction_involves_account(t, filter_training_data_by_account)]
logger.debug(f"After filtering for account {filter_training_data_by_account}, "
# ...filtered because the training data must involve the account:
if transaction_involves_account(t, known_account)]
logger.debug(f"After filtering for account {known_account}, "
f"the training data consists of {len(training_data)} entries.")
return training_data

Expand Down
6 changes: 3 additions & 3 deletions smart_importer/predict_payees.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ def extract(file):

def __init__(self, *,
training_data: Union[_FileMemo, List[Transaction], str] = None,
filter_training_data_by_account: str = None,
account: str = None,
predict_payees: bool = True,
overwrite_existing_payees=False,
suggest_payees: bool = True):
self.training_data = training_data
self.filter_training_data_by_account = filter_training_data_by_account
self.account = account
self.predict_payees = predict_payees
self.overwrite_existing_payees = overwrite_existing_payees
self.suggest_payees = suggest_payees
Expand Down Expand Up @@ -88,7 +88,7 @@ def wrapper(self, file, existing_entries=None):
def enhance_transactions(self):# load training data
self.training_data = ml.load_training_data(
self.training_data,
filter_training_data_by_account=self.filter_training_data_by_account,
known_account=self.account,
existing_entries=self.existing_entries)

# train the machine learning model
Expand Down
8 changes: 4 additions & 4 deletions smart_importer/predict_postings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class PredictPostings:
@PredictPostings(
training_data="trainingdata.beancount",
filter_training_data_by_account="The:Importers:Already:Known:Accountname"
account="The:Importers:Already:Known:Accountname"
)
class MyImporter(ImporterProtocol):
def extract(file):
Expand All @@ -48,12 +48,12 @@ def __init__(
self,
*,
training_data: Union[_FileMemo, List[Transaction], str] = None,
filter_training_data_by_account: str = None,
account: str = None,
predict_second_posting: bool = True,
suggest_accounts: bool = True
):
self.training_data = training_data
self.filter_training_data_by_account = filter_training_data_by_account
self.account = account
self.predict_second_posting = predict_second_posting
self.suggest_accounts = suggest_accounts

Expand Down Expand Up @@ -92,7 +92,7 @@ def wrapper(self, file, existing_entries=None):
def enhance_transactions(self):# load training data
self.training_data = ml.load_training_data(
self.training_data,
filter_training_data_by_account=self.filter_training_data_by_account,
known_account=self.account,
existing_entries=self.existing_entries)

# convert training data to a list of TxnPostingAccounts
Expand Down
8 changes: 4 additions & 4 deletions smart_importer/tests/predict_payees_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class Testdata:
""")
assert not errors

filter_training_data_by_account = "Assets:US:BofA:Checking"
known_account = "Assets:US:BofA:Checking"

correct_predictions = [
'Farmer Fresh',
Expand Down Expand Up @@ -115,7 +115,7 @@ def setUp(self):
# define and decorate an importer:
@PredictPayees(
training_data=Testdata.training_data,
filter_training_data_by_account="Assets:US:BofA:Checking",
account="Assets:US:BofA:Checking",
overwrite_existing_payees=False
)
class DecoratedImporter(BasicImporter):
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_class_decoration_with_arguments(self):

@PredictPayees(
training_data=Testdata.training_data,
filter_training_data_by_account=Testdata.filter_training_data_by_account
account=Testdata.known_account
)
class SmartImporter(BasicImporter):
pass
Expand All @@ -213,7 +213,7 @@ def test_method_decoration_with_arguments(self):
class SmartImporter(BasicImporter):
@PredictPayees(
training_data=Testdata.training_data,
filter_training_data_by_account=Testdata.filter_training_data_by_account
account=Testdata.known_account
)
def extract(self, file, existing_entries=None):
testcase.assertIsInstance(self, SmartImporter)
Expand Down
8 changes: 4 additions & 4 deletions smart_importer/tests/predict_postings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class Testdata:
""")
assert not errors

filter_training_data_by_account = "Assets:US:BofA:Checking"
known_account = "Assets:US:BofA:Checking"

correct_predictions = [
'Expenses:Food:Groceries',
Expand Down Expand Up @@ -111,7 +111,7 @@ def setUp(self):
# define and decorate an importer:
@PredictPostings(
training_data=Testdata.training_data,
filter_training_data_by_account="Assets:US:BofA:Checking"
account="Assets:US:BofA:Checking"
)
class DecoratedImporter(BasicImporter):
pass
Expand Down Expand Up @@ -176,7 +176,7 @@ def test_class_decoration_with_arguments(self):

@PredictPostings(
training_data=Testdata.training_data,
filter_training_data_by_account=Testdata.filter_training_data_by_account
account=Testdata.known_account
)
class SmartImporter(BasicImporter):
pass
Expand All @@ -199,7 +199,7 @@ def test_method_decoration_with_arguments(self):
class SmartImporter(BasicImporter):
@PredictPostings(
training_data=Testdata.training_data,
filter_training_data_by_account=Testdata.filter_training_data_by_account
account=Testdata.known_account
)
def extract(self, file, existing_entries=None):
testcase.assertIsInstance(self, SmartImporter)
Expand Down

0 comments on commit 5e1b2d2

Please sign in to comment.