From c3f755908b1bdb92282901c04eab2c24ca73ee8c Mon Sep 17 00:00:00 2001 From: Sufiyan Adhikari Date: Fri, 7 Aug 2020 23:59:48 +0530 Subject: [PATCH] Fix mypy errors --- .../source/generic_importer_source.py | 30 ++++++++----------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/beancount_import/source/generic_importer_source.py b/beancount_import/source/generic_importer_source.py index 2398e5bc..b4e4ec39 100644 --- a/beancount_import/source/generic_importer_source.py +++ b/beancount_import/source/generic_importer_source.py @@ -13,17 +13,14 @@ """ import os -import hashlib from glob import glob -from typing import List from collections import defaultdict import itertools -import datetime +from typing import Hashable, List, Dict, Optional from beancount.core.data import Transaction, Posting, Directive from beancount.core.amount import Amount from beancount.ingest.importer import ImporterProtocol -from beancount.core.compare import hash_entry from beancount.ingest.cache import get_file from ..matching import FIXME_ACCOUNT, SimpleInventory @@ -54,12 +51,13 @@ def __init__(self, self.files = [f for f in files if self.importer.identify(f)] @property - def name(self): + def name(self) -> str: return self.importer.name() def prepare(self, journal: 'JournalEditor', results: SourceResults) -> None: results.add_account(self.account) - entries = defaultdict(list) + + entries:Dict[Hashable,List[Directive]] = defaultdict(list) for f in self.files: f_entries = self.importer.extract(f, existing_entries=journal.entries) # collect all entries in current statement, grouped by hash @@ -85,7 +83,7 @@ def prepare(self, journal: 'JournalEditor', results: SourceResults) -> None: def _add_description(self, entry: Transaction): if not isinstance(entry, Transaction): return None - postings = entry.postings #type: ignore + postings: List[Posting] = entry.postings to_mutate = [] for i, posting in enumerate(postings): if posting.account != self.account: continue @@ -102,14 +100,18 @@ def _add_description(self, entry: Transaction): {"source_desc":entry.narration, "date": entry.date}) postings.insert(i, p) - def _get_source_posting(self, entry:Transaction): + def _get_source_posting(self, entry:Transaction) -> Optional[Posting]: for posting in entry.postings: - if posting.account == self.account: return posting + if posting.account == self.account: + return posting - def _get_key_from_imported_entry(self, entry:Transaction): + def _get_key_from_imported_entry(self, entry:Transaction) -> Hashable: + source_posting = self._get_source_posting(entry) + if source_posting is None: + raise ValueError("entry has no postings for {self.account}") return (self.account, entry.date, - self._get_source_posting(entry).units, + source_posting.units, entry.narration) def _make_import_result(self, imported_entry:Directive): @@ -121,12 +123,6 @@ def _make_import_result(self, imported_entry:Directive): imported_entry.meta.pop('filename') return result -def _get_key_from_posting(entry: Transaction, posting: Posting, - source_postings: List[Posting], source_desc: str, - posting_date: datetime.date): - del entry - del source_postings - return (posting.account, posting_date, posting.units, source_desc) def get_info(raw_entry: Directive) -> dict: return dict(