diff --git a/Makefile b/Makefile index d25017f..65b602d 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,8 @@ test: coverage erase - python `which nosetests` -dsv --with-coverage --cover-package ofxtools tests/*.py mypy ofxtools mypy tests + python `which nosetests` -dsv --with-coverage --cover-package ofxtools tests/*.py clean: find -regex '.*\.pyc' -exec rm {} \; diff --git a/ofxtools/scripts/ofxget.py b/ofxtools/scripts/ofxget.py index 8226abd..b9e5406 100644 --- a/ofxtools/scripts/ofxget.py +++ b/ofxtools/scripts/ofxget.py @@ -6,9 +6,7 @@ # stdlib imports import os import argparse -from argparse import ArgumentParser, ArgumentError, Action import configparser -from configparser import ConfigParser import datetime from collections import defaultdict, OrderedDict, ChainMap import getpass @@ -21,7 +19,6 @@ from io import BytesIO import itertools from operator import attrgetter -import sys import warnings import pydoc import typing @@ -35,6 +32,7 @@ MutableMapping, Dict, Sequence, + Iterable, ) # 3rd party imports @@ -58,55 +56,62 @@ from ofxtools.Parser import OFXTree, ParseError -CONFIGPATH = os.path.join(config.CONFIGDIR, "fi.cfg") -USERCONFIGPATH = os.path.join(config.USERCONFIGDIR, "ofxget.cfg") -UserConfig = ConfigParser() -UserConfig.read([CONFIGPATH, USERCONFIGPATH]) +class OfxgetWarning(UserWarning): + """ Base class for warnings in this module """ -DEFAULTS: Dict[str, Union[str, int, bool, list]] = { - "server": "", "url": "", "ofxhome": "", "version": 203, "org": "", "fid": "", - "appid": "", "appver": "", "language": "", "bankid": "", "brokerid": "", - "unclosedelements": False, "pretty": False, "user": "", "clientuid": "", - "checking": [], "savings": [], "moneymrkt": [], "creditline": [], - "creditcard": [], "investment": [], "dtstart": "", "dtend": "", - "dtasof": "", "inctran": True, "incbal": True, "incpos": True, - "incoo": False, "all": False, "years": [], "acctnum": "", "recid": "", - "dryrun": False, "unsafe": False, "write": False, "savepass": False, -} +############################################################################### +# TYPE ALIASES +############################################################################### +# Loaded Argparser args +ArgType = typing.ChainMap[str, Any] +# OFX connection params (OFX version, prettyprint, close_elements) tagged onto +# the OFXClient.request_profile() job submitted to the ThreadPoolExecutor +# during a profile scan +OFXVersion = int -NULL_ARGS = [None, "", []] +ScanMetadata = Tuple[OFXVersion, bool, bool] -class OfxgetWarning(UserWarning): - """ Base class for warnings in this module """ +# ScanMetadata without OFX version, i.e. (prettyprint, close_elements) +FormatArgs = Tuple[bool, bool] +# All working FormatArgs for a given OFX version +FormatMap = Mapping[OFXVersion, List[FormatArgs]] -# Type alias for loaded Argparser args -ArgType = typing.ChainMap[str, Any] +# FormatArgs made ArgType-compatible, i.e. (pretty, unclosedelements) +FormatConfig = OrderedDict -# Type alias for scan result of a single OFX protocol version +# Scan result of a single OFX protocol version ScanVersionResult = Mapping[str, Union[list, dict]] -# Type alias for a full set of profile scan results -ScanResults = Tuple[ScanVersionResult, ScanVersionResult, Mapping[str, bool]] +# Auth information parsed out of SIGNONINFO during a profile scan - +# CLIENTUIDREQ et al. +SignoninfoReport = Mapping[str, bool] + +# Full set of profile scan results +ScanResults = Tuple[ScanVersionResult, ScanVersionResult, SignoninfoReport] -# Type alias -FormatType = Tuple[bool, bool] -FormatConfig = MutableMapping[str, bool] +AcctInfo = Union[models.BANKACCTINFO, models.CCACCTINFO, models.INVACCTINFO] +ParsedAcctinfo = Mapping[str, Union[str, list]] -class UuidAction(Action): +############################################################################### +# DEFINE CLI +############################################################################### +class UuidAction(argparse.Action): + """ + Generates a random UUID4 each time called + """ def __call__(self, parser, namespace, values, option_string=None): uuid = OFXClient.uuid setattr(namespace, self.dest, uuid) -def make_argparser() -> ArgumentParser: - argparser = ArgumentParser( +def make_argparser() -> argparse.ArgumentParser: + argparser = argparse.ArgumentParser( description="Download OFX financial data", - # epilog="FIs configured: {}".format(fi_index()), prog="ofxget", ) argparser.add_argument( @@ -272,33 +277,6 @@ def make_argparser() -> ArgumentParser: ############################################################################### # CLI METHODS ############################################################################### -def scan_profiles(start: int, - stop: int, - timeout: Optional[float] = None - ) -> Dict[str, ScanResults]: - """ - Scan OFX Home's list of FIs for working connection configs. - """ - results = {} - - institutions = ofxhome.list_institutions() - for ofxhome_id in list(institutions.keys())[start:stop]: - lookup = ofxhome.lookup(ofxhome_id) - if lookup is None\ - or lookup.url is None\ - or ofxhome.ofx_invalid(lookup)\ - or ofxhome.ssl_invalid(lookup): - continue - scan_results = _scan_profile(url=lookup.url, - org=lookup.org, - fid=lookup.fid, - timeout=timeout) - if scan_results: - results[ofxhome_id] = scan_results - - return results - - def scan_profile(args: ArgType) -> None: """ Report working connection parameters @@ -309,15 +287,14 @@ def scan_profile(args: ArgType) -> None: if (not v2["versions"]) and (not v1["versions"]): msg = "Scan found no working formats for {}" print(msg.format(args["url"])) - sys.exit() - - print(json.dumps(scan_results)) - - if args["write"] and not args["dryrun"]: - extra_args = _best_scan_format(scan_results) - args.maps.insert(0, extra_args) + else: + print(json.dumps(scan_results)) - write_config(args) + if args["write"] and not args["dryrun"]: + extra_args = _best_scan_format(scan_results) + # args.maps.insert(0, extra_args) + # write_config(args) + write_config(ChainMap(extra_args, dict(args))) def _best_scan_format(scan_results: ScanResults) -> MutableMapping: @@ -399,6 +376,27 @@ def _request_acctinfo(args: ArgType, password: str) -> BytesIO: return BytesIO(response) +def _merge_acctinfo(args: ArgType, markup: BytesIO) -> None: + # *ACCTINFO classes don't have rich comparison methods; + # can't sort by class + sortKey = attrgetter("__class__.__name__") + acctinfos: List[AcctInfo] = sorted(extract_acctinfos(markup), key=sortKey) + + def parse_acctinfos(clsName, acctinfos): + dispatcher = {"BANKACCTINFO": parse_bankacctinfos, + "CCACCTINFO": parse_ccacctinfos, + "INVACCTINFO": parse_invacctinfos} + parser = dispatcher.get(clsName, lambda x: {}) + return parser(acctinfos) + + parsed_args: List[ParsedAcctinfo] = [parse_acctinfos(clsnm, infos) + for clsnm, infos in itertools.groupby( + acctinfos, key=sortKey)] + + # Insert extracted ACCTINFO after CLI commands, but before config files + args.maps.insert(1, ChainMap(*parsed_args)) + + def request_stmt(args: ArgType) -> None: """ Send *STMTRQ @@ -511,79 +509,30 @@ def request_tax1099(args: ArgType) -> None: ############################################################################### # ARGUMENT/CONFIG HANDLERS ############################################################################### -def init_client(args: ArgType) -> OFXClient: - """ - Initialize OFXClient with connection info from args - """ - client = OFXClient( - args["url"], - userid=args["user"] or None, - clientuid=args["clientuid"] or None, - org=args["org"] or None, - fid=args["fid"] or None, - version=args["version"], - appid=args["appid"] or None, - appver=args["appver"] or None, - language=args["language"] or None, - prettyprint=args["pretty"], - close_elements=not args["unclosedelements"], - bankid=args["bankid"] or None, - brokerid=args["brokerid"] or None, - ) - return client - - -def read_config(cfg, section): - return {k: config2arg(k, v) - for k, v in cfg[section].items()} if section in cfg else {} - - -def merge_config(args: argparse.Namespace, - config: configparser.ConfigParser) -> ArgType: - """ - Merge CLI args > user config > library config > OFX Home > defaults - """ - # All ArgumentParser args that have a value set - _args = {k: v for k, v in vars(args).items() if v is not None} - - if "server" in _args: - user_cfg = read_config(config, _args["server"]) - else: - user_cfg = {} - merged = ChainMap(_args, user_cfg, DEFAULTS) +CONFIGPATH = os.path.join(config.CONFIGDIR, "fi.cfg") +USERCONFIGPATH = os.path.join(config.USERCONFIGDIR, "ofxget.cfg") +UserConfig = configparser.ConfigParser() +UserConfig.read([CONFIGPATH, USERCONFIGPATH]) - ofxhome_id = merged["ofxhome"] - if ofxhome_id: - lookup = ofxhome.lookup(ofxhome_id) - if lookup is not None: - # Insert OFX Home lookup ahead of DEFAULTS but after - # user configs and library configs - merged.maps.insert(-1, {"url": lookup.url, "org": lookup.org, - "fid": lookup.fid, - "brokerid": lookup.brokerid}) +DEFAULTS: Dict[str, Union[str, int, bool, list]] = { + "server": "", "url": "", "ofxhome": "", "version": 203, "org": "", "fid": "", + "appid": "", "appver": "", "language": "", "bankid": "", "brokerid": "", + "unclosedelements": False, "pretty": False, "user": "", "clientuid": "", + "checking": [], "savings": [], "moneymrkt": [], "creditline": [], + "creditcard": [], "investment": [], "dtstart": "", "dtend": "", + "dtasof": "", "inctran": True, "incbal": True, "incpos": True, + "incoo": False, "all": False, "years": [], "acctnum": "", "recid": "", + "dryrun": False, "unsafe": False, "write": False, "savepass": False, +} - if not (merged.get("url", None) - or merged.get("dryrun", False) - or merged.get("request", None) == "list"): - err = "Missing URL" - if "server" not in _args: - msg = (f"{err} - please provide a server nickname, " - "or configure 'url' / 'ofxhome'") - raise ValueError(msg) +NULL_ARGS = [None, "", []] - server = _args["server"] - # Allow sloppy CLI args - passing URL as "server" positional arg - if urllib_parse.urlparse(server).scheme: - merged["url"] = server - merged["server"] = None - else: - msg = (f"{err} - please configure 'url' or 'ofxhome' " - f"for server '{server}'") - raise ValueError(msg) - return merged +def read_config(cfg, section): + return {k: config2arg(k, v) + for k, v in cfg[section].items()} if section in cfg else {} def config2arg(key: str, value: str) -> Union[List[str], bool, int, str]: @@ -597,7 +546,7 @@ def read_int(string: str) -> int: return int(value) def read_bool(string: str) -> bool: - BOOLY = ConfigParser.BOOLEAN_STATES # type: ignore + BOOLY = configparser.ConfigParser.BOOLEAN_STATES # type: ignore keys = list(BOOLY.keys()) if string not in BOOLY: msg = f"Can't interpret '{list}' as bool; must be one of {keys}" @@ -625,35 +574,16 @@ def read_list(string: str) -> List[str]: return handlers[cfg_type](value) # type: ignore -def arg2config(key: str, value: Union[list, bool, int, str]) -> str: - """ - Transform a config value from ArgParser format to ConfigParser format - """ - def write_string(value: str) -> str: - return value - - def write_int(value: int) -> str: - return str(value) - - def write_bool(value: bool) -> str: - return {True: "true", False: "false"}[value] - - def write_list(value: list) -> str: - # INI-friendly string representation of Python list type - return str(value).strip("[]").replace("'", "") - - handlers = {str: write_string, - bool: write_bool, - list: write_list, - int: write_int} +def write_config(args: ArgType) -> None: + mk_server_cfg(args) - if key not in DEFAULTS: - msg = f"Don't know type of {key}; define in ofxget.DEFAULTS" - raise ValueError(msg) + # msg = "\nWriting '{}' configs {} to {}..." + # print(msg.format(args["server"], dict(cfg.items()), USERCONFIGPATH)) - cfg_type = type(DEFAULTS[key]) + with open(USERCONFIGPATH, "w") as f: + UserConfig.write(f) - return handlers[cfg_type](value) # type: ignore + # print("...write OK") def mk_server_cfg(args: ArgType) -> configparser.SectionProxy: @@ -678,7 +608,7 @@ def mk_server_cfg(args: ArgType) -> configparser.SectionProxy: UserConfig[server] = {} cfg = UserConfig[server] - LibraryConfig = ConfigParser() + LibraryConfig = configparser.ConfigParser() LibraryConfig.read(CONFIGPATH) lib_cfg = read_config(LibraryConfig, server) @@ -702,20 +632,109 @@ def mk_server_cfg(args: ArgType) -> configparser.SectionProxy: return cfg -def write_config(args: ArgType) -> None: - mk_server_cfg(args) +def arg2config(key: str, value: Union[list, bool, int, str]) -> str: + """ + Transform a config value from ArgParser format to ConfigParser format + """ + def write_string(value: str) -> str: + return value - # msg = "\nWriting '{}' configs {} to {}..." - # print(msg.format(args["server"], dict(cfg.items()), USERCONFIGPATH)) + def write_int(value: int) -> str: + return str(value) - with open(USERCONFIGPATH, "w") as f: - UserConfig.write(f) + def write_bool(value: bool) -> str: + return {True: "true", False: "false"}[value] - # print("...write OK") + def write_list(value: list) -> str: + # INI-friendly string representation of Python list type + return str(value).strip("[]").replace("'", "") + + handlers = {str: write_string, + bool: write_bool, + list: write_list, + int: write_int} + + if key not in DEFAULTS: + msg = f"Don't know type of {key}; define in ofxget.DEFAULTS" + raise ValueError(msg) + + cfg_type = type(DEFAULTS[key]) + + return handlers[cfg_type](value) # type: ignore + + +def merge_config(args: argparse.Namespace, + config: configparser.ConfigParser) -> ArgType: + """ + Merge CLI args > user config > library config > OFX Home > defaults + """ + # All ArgumentParser args that have a value set + _args = {k: v for k, v in vars(args).items() if v is not None} + + if "server" in _args: + user_cfg = read_config(config, _args["server"]) + else: + user_cfg = {} + merged = ChainMap(_args, user_cfg, DEFAULTS) + + ofxhome_id = merged["ofxhome"] + if ofxhome_id: + lookup = ofxhome.lookup(ofxhome_id) + + if lookup is not None: + # Insert OFX Home lookup ahead of DEFAULTS but after + # user configs and library configs + merged.maps.insert(-1, {"url": lookup.url, "org": lookup.org, + "fid": lookup.fid, + "brokerid": lookup.brokerid}) + + if not (merged.get("url", None) + or merged.get("dryrun", False) + or merged.get("request", None) == "list"): + err = "Missing URL" + + if "server" not in _args: + msg = (f"{err} - please provide a server nickname, " + "or configure 'url' / 'ofxhome'") + raise ValueError(msg) + + server = _args["server"] + # Allow sloppy CLI args - passing URL as "server" positional arg + if urllib_parse.urlparse(server).scheme: + merged["url"] = server + merged["server"] = None + else: + msg = (f"{err} - please configure 'url' or 'ofxhome' " + f"for server '{server}'") + raise ValueError(msg) + + return merged + + +def init_client(args: ArgType) -> OFXClient: + """ + Initialize OFXClient with connection info from args + """ + client = OFXClient( + args["url"], + userid=args["user"] or None, + clientuid=args["clientuid"] or None, + org=args["org"] or None, + fid=args["fid"] or None, + version=args["version"], + appid=args["appid"] or None, + appver=args["appver"] or None, + language=args["language"] or None, + prettyprint=args["pretty"], + close_elements=not args["unclosedelements"], + bankid=args["bankid"] or None, + brokerid=args["brokerid"] or None, + ) + return client ############################################################################### -# HEAVY LIFTING +# PROFILE SCAN ############################################################################### def _scan_profile(url: str, org: Optional[str], @@ -728,39 +747,36 @@ def _scan_profile(url: str, Returns a 3-tuple of (OFXv1 results, OFXv2 results, signoninfo), each type(dict). OFX results provide ``ofxget`` configs that will work to - make a basic OFX connection. SIGNONINFO provides further auth information - that may be needed to succssfully log in. + make a basic OFX connection. SIGNONINFO reports further information + that may be helpful to authenticate successfully. """ client = OFXClient(url, org=org, fid=fid) - futures = schedule_scan(client, max_workers, timeout) - - # The only thing we're measuring here is success (indicated by receiving - # a valid HTTP response) or failure (indicated by the request's - # throwing any of various errors). We don't examine the actual response - # beyond simply parsing it to verify that it's valid OFX. The data we keep - # is actually the metadata (i.e. connection parameters like OFX version - # tried for a request) stored as values in the ``futures`` dict. - working: Mapping[int, List[FormatType]] = defaultdict(list) - signoninfos: MutableMapping[int, Any] = defaultdict(OrderedDict) - + futures = _queue_scans(client, max_workers, timeout) + + # The primary data we keep is actually the metadata (i.e. connection + # parameters - OFX version; prettyprint; unclosedelements) tagged on + # the Future by _queue_scans() that gave us a successful OFX connection. + success_params: FormatMap = defaultdict(list) + # If possible, we also parse out some data from SIGNONINFO included in + # the PROFRS. + signoninfo: SignoninfoReport = {} + + # Assume that SIGNONINFO is the same for each successful OFX PROFRS. + # Tell _read_scan_response() to stop parsing out SIGNONINFO once + # it's successfully extracted one. for future in concurrent.futures.as_completed(futures): (version, prettyprint, close_elements) = futures[future] + valid, signoninfo_ = _read_scan_response(future, not signoninfo) - if not test_scan_response(future, version, signoninfos): + if not valid: continue + if not signoninfo and signoninfo_: + signoninfo = signoninfo_ + success_params[version].append((prettyprint, close_elements)) - working[version].append((prettyprint, close_elements)) - - signoninfos = {k: v for k, v in signoninfos.items() if v} - if signoninfos: - highest_version = max(signoninfos.keys()) - signoninfo = signoninfos[highest_version] - else: - signoninfo = OrderedDict() - - v2, v1 = utils.partition(lambda result: result[0] < 200, working.items()) - v1_versions, v1_formats = collate_results(v1) - v2_versions, v2_formats = collate_results(v2) + v2, v1 = utils.partition(lambda it: it[0] < 200, success_params.items()) + v1_versions, v1_formats = collate_scan_results(v1) + v2_versions, v2_formats = collate_scan_results(v2) # V2 always has closing tags for elements; just report prettyprint for format in v2_formats: @@ -769,102 +785,110 @@ def _scan_profile(url: str, return (OrderedDict([("versions", v1_versions), ("formats", v1_formats)]), OrderedDict([("versions", v2_versions), ("formats", v2_formats)]), - signoninfo, - ) + signoninfo) + + +def _queue_scans(client: OFXClient, + max_workers: Optional[int], + timeout: Optional[float], + ) -> Dict[concurrent.futures.Future, ScanMetadata]: + ofxv1 = [102, 103, 151, 160] + ofxv2 = [200, 201, 202, 203, 210, 211, 220] + BOOLS = (False, True) + + futures = {} + with concurrent.futures.ThreadPoolExecutor(max_workers) as executor: + for version, pretty, close in itertools.product(ofxv1, BOOLS, BOOLS): + future = executor.submit(client.request_profile, + version=version, + prettyprint=pretty, + close_elements=close, + timeout=timeout) + futures[future] = (version, pretty, close) + + for version, pretty in itertools.product(ofxv2, BOOLS): + future = executor.submit(client.request_profile, + version=version, + prettyprint=pretty, + close_elements=True, + timeout=timeout) + futures[future] = (version, pretty, True) + + return futures + + +def _read_scan_response(future: concurrent.futures.Future, + read_signoninfo: bool = False, + ) -> Tuple[bool, SignoninfoReport]: + valid: bool = False + signoninfo: SignoninfoReport = {} -def test_scan_response(future, version, signoninfos): try: # ``future.result()`` returns an http.client.HTTPResponse response = future.result() - except (URLError, - HTTPError, - ConnectionError, - OSError, - socket.timeout, - ) as exc: + except (URLError, HTTPError, ConnectionError, OSError, socket.timeout): future.cancel() - return False + return valid, signoninfo # ``response`` is an HTTPResponse; doesn't have seek() method used # by ``header.parse_header()``. Repackage as BytesIO for parsing. - if not signoninfos[version]: + if read_signoninfo: with response as f: response_ = f.read() try: if not response_: - return False + return valid, signoninfo - signoninfos_ = extract_signoninfos(BytesIO(response_)) - assert len(signoninfos_) > 0 - info = signoninfos_[0] + signoninfos: List[models.SIGNONINFO] \ + = extract_signoninfos(BytesIO(response_)) + + assert len(signoninfos) > 0 + valid = True + info = signoninfos[0] bool_attrs = ("chgpinfirst", "clientuidreq", "authtokenfirst", - "mfachallengefirst", - ) - signoninfo_ = OrderedDict([ + "mfachallengefirst") + signoninfo = OrderedDict([ (attr, getattr(info, attr, None) or False) for attr in bool_attrs]) - signoninfos[version] = signoninfo_ except (socket.timeout, ): # We didn't receive a response at all - return False + valid = False except (ParseError, ET.ParseError, OFXHeaderError): # We didn't receive valid OFX in the response - return False + valid = False except (ValueError, ): # We received OFX, but not a valid PROFRS - pass - - return True - - -def schedule_scan(client, max_workers, timeout): - ofxv1 = [102, 103, 151, 160] - ofxv2 = [200, 201, 202, 203, 210, 211, 220] - - futures = {} - with concurrent.futures.ThreadPoolExecutor(max_workers) as executor: - for prettyprint in (False, True): - for close_elements in (False, True): - futures.update({executor.submit( - client.request_profile, - version=version, - prettyprint=prettyprint, - close_elements=close_elements, - timeout=timeout): - (version, prettyprint, close_elements) - for version in ofxv1}) - - futures.update({executor.submit( - client.request_profile, - version=version, - prettyprint=prettyprint, - close_elements=True, - timeout=timeout): - (version, prettyprint, True) for version in ofxv2}) + valid = True + else: + # IF we're not parsing the PROFRS, then we interpret receiving a good + # HTTP response as valid. + valid = True - return futures + return valid, signoninfo -def collate_results( - results: Tuple[int, FormatType] -) -> Tuple[List[int], List[FormatConfig]]: +def collate_scan_results( + scan_results: Iterable[Tuple[OFXVersion, FormatArgs]] +) -> Tuple[List[OFXVersion], List[FormatConfig]]: """ Transform the metadata (version, prettyprint, close_elements) tagged onto a concurrent.futures.Future instance for a successful run of - OFXClient.request_profile(). Returns a 2-tuple of ([OFX version], [format]) - where each format is a dict of {"pretty": bool, "unclosedelements": bool} - representing configs that successully connect for those versions. + OFXClient.request_profile(). + + Returns a 2-tuple of ([OFX version], [format]) where each format is a dict + of {"pretty": bool, "unclosedelements": bool} representing configs that + successully connect for those versions. - Input ``results`` needs to be a complete set for either OFXv1 or v2, + Input ``scan_results`` needs to be a complete set for either OFXv1 or v2, with no results for the other version admixed. """ - results_ = list(results) + results_ = list(scan_results) if not results_: return [], [] - versions, formats = zip(*results_) # type: ignore + versions, formats = zip(*results_) # Assumption: the same formatting requirements apply to all # sub-versions (e.g. 1.0.2 and 1.0.3, or 2.0.3 and 2.2.0). @@ -875,13 +899,16 @@ def collate_results( # Translation: just pick the longest sequence of successful # formats and assume it applies to the whole version. formats = max(formats, key=len) - formats.sort() - formats = [OrderedDict([("pretty", fmt[0]), - ("unclosedelements", not fmt[1])]) - for fmt in formats] - return sorted(list(versions)), formats + + formats_ = [OrderedDict([("pretty", fmt[0]), + ("unclosedelements", not fmt[1])]) + for fmt in sorted(formats, key=lambda x: (x[0], not x[1]))] + return sorted(list(versions)), formats_ +############################################################################### +# OFX PARSING +############################################################################### def verify_status(trnrs: models.Aggregate) -> None: """ Input a models.Aggregate instance representing a transaction wrapper. @@ -894,6 +921,10 @@ def verify_status(trnrs: models.Aggregate) -> None: raise ValueError(msg) +def _acctIsActive(acctinfo: AcctInfo) -> bool: + return acctinfo.svcstatus == "ACTIVE" + + def extract_signoninfos(markup: BytesIO) -> List[models.SIGNONINFO]: """ Input seralized OFX containing PROFRS @@ -910,8 +941,7 @@ def extract_signoninfos(markup: BytesIO) -> List[models.SIGNONINFO]: msgs = ofx.profmsgsrsv1 assert msgs is not None - def extract_signoninfo(trnrs): - assert isinstance(trnrs, models.PROFTRNRS) + def extract_signoninfo(trnrs: models.PROFTRNRS) -> List[models.SIGNONINFO]: verify_status(trnrs) rs = trnrs.profrs assert rs is not None @@ -920,11 +950,15 @@ def extract_signoninfo(trnrs): assert list_ is not None return list_[:] + # return list(itertools.chain.from_iterable( + # [extract_signoninfo(trnrs) for trnrs in msgs])) return list(itertools.chain.from_iterable( - [extract_signoninfo(trnrs) for trnrs in msgs])) + extract_signoninfo(trnrs) for trnrs in msgs)) -def extract_acctinfos(markup: BytesIO) -> Mapping: +def extract_acctinfos( + markup: BytesIO +) -> Iterable[AcctInfo]: """ Input seralized OFX containing ACCTINFORS Output dict-like object containing parsed *ACCTINFOs @@ -946,69 +980,50 @@ def extract_acctinfos(markup: BytesIO) -> Mapping: acctinfors = trnrs.acctinfors assert isinstance(acctinfors, models.ACCTINFORS) - # *ACCTINFO classes don't have rich comparison methods; - # can't sort by class - sortKey = attrgetter("__class__.__name__") - # ACCTINFOs are ListItems of ACCTINFORS # *ACCTINFOs are ListItems of ACCTINFO # The data we want is in a nested list - acctinfos = sorted(itertools.chain.from_iterable(acctinfors), key=sortKey) - - def _unique(ids, label): - ids = set(ids) - if len(ids) > 1: - msg = f"Multiple {label} {list(ids)}; can't configure automatically" - raise ValueError(msg) - try: - id = ids.pop() - except KeyError: - raise ValueError("{label} is empty") - return id - - def _ready(acctinfo): - return acctinfo.svcstatus == "ACTIVE" + return itertools.chain.from_iterable(acctinfors) - def parse_bank(acctinfos): - bankids = [] - args_ = defaultdict(list) - for inf in acctinfos: - if _ready(inf): - bankids.append(inf.bankid) - args_[inf.accttype.lower()].append(inf.acctid) - args_["bankid"] = _unique(bankids, "BANKIDs") - return dict(args_) +def parse_bankacctinfos( + acctinfos: Sequence[models.BANKACCTINFO] +) -> ParsedAcctinfo: + bankids = [] + args_: MutableMapping = defaultdict(list) + for inf in acctinfos: + if _acctIsActive(inf): + bankids.append(inf.bankid) + args_[inf.accttype.lower()].append(inf.acctid) - def parse_inv(acctinfos): - brokerids = [] - args_ = defaultdict(list) - for inf in acctinfos: - if _ready(inf): - acctfrom = inf.invacctfrom - brokerids.append(acctfrom.brokerid) - args_["investment"].append(acctfrom.acctid) + args_["bankid"] = utils.collapseToSingle(bankids, "BANKIDs") + return dict(args_) - args_["brokerid"] = _unique(brokerids, "BROKERIDs") - return dict(args_) - def parse_cc(acctinfos): - return {"creditcard": [inf.acctid for inf in acctinfos if _ready(inf)]} +def parse_invacctinfos( + acctinfos: Sequence[models.INVACCTINFO] +) -> ParsedAcctinfo: + brokerids = [] + args_: MutableMapping = defaultdict(list) + for inf in acctinfos: + if _acctIsActive(inf): + acctfrom = inf.invacctfrom + brokerids.append(acctfrom.brokerid) + args_["investment"].append(acctfrom.acctid) - dispatcher = {"BANKACCTINFO": parse_bank, - "CCACCTINFO": parse_cc, - "INVACCTINFO": parse_inv} + args_["brokerid"] = utils.collapseToSingle(brokerids, "BROKERIDs") + return dict(args_) - return ChainMap(*[dispatcher.get(clsName, lambda x: {})(_acctinfos) - for clsName, _acctinfos in itertools.groupby( - acctinfos, key=sortKey)]) - -def _merge_acctinfo(args: ArgType, markup: BytesIO) -> None: - # Insert extracted ACCTINFO after CLI commands, but before config files - args.maps.insert(1, extract_acctinfos(markup)) +def parse_ccacctinfos( + acctinfos: Sequence[models.CCACCTINFO] +) -> ParsedAcctinfo: + return {"creditcard": [i.acctid for i in acctinfos if _acctIsActive(i)]} +############################################################################### +# CLI UTILITIES +############################################################################### def list_fis(args: ArgType) -> None: server = args["server"] if server in (None, ""): diff --git a/ofxtools/utils.py b/ofxtools/utils.py index 8fe04e7..db1f1ef 100644 --- a/ofxtools/utils.py +++ b/ofxtools/utils.py @@ -6,7 +6,14 @@ import os import itertools import xml.etree.ElementTree as ET -from typing import Optional +from typing import ( + Any, + Optional, + Tuple, + Callable, + Iterable, + Sequence, +) import math @@ -30,6 +37,22 @@ def fixpath(path: str) -> str: return path +def collapseToSingle(items: Sequence, label: str): + """ + Given a sequence of repeated items, return the item that's repeated. + Throw an error if sequence is empty or contains >1 distinct item. + + ``label`` is the name used in error reporting. + """ + items_ = set(items) + if len(items_) == 0: + raise ValueError("{label} is empty") + if len(items_) > 1: + raise ValueError((f"Multiple {label} {list(items)}; " + "can't configure automatically")) + return items_.pop() + + ############################################################################### # date/time utilities ############################################################################### @@ -54,7 +77,7 @@ def gmt_offset(hours: int, minutes: int) -> datetime.timedelta: # itertools recipes # https://docs.python.org/2/library/itertools.html#recipes ############################################################################### -def pairwise(iterable): +def pairwise(iterable: Iterable) -> Iterable[Tuple[Any, Any]]: """ s -> (s0,s1), (s1,s2), (s2, s3), ... """ a, b = itertools.tee(iterable) next(b, None) @@ -67,7 +90,7 @@ def all_equal(iterable): return next(g, True) and not next(g, False) -def partition(pred, iterable): +def partition(pred: Callable, iterable: Iterable) -> Tuple[Iterable, Iterable]: """ Use a predicate to partition entries into false entries and true entries """ diff --git a/tests/test_ofxget.py b/tests/test_ofxget.py index 7d71a53..e933d73 100644 --- a/tests/test_ofxget.py +++ b/tests/test_ofxget.py @@ -3,7 +3,7 @@ # stdlib imports import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import patch, DEFAULT from datetime import datetime from io import BytesIO import argparse @@ -11,9 +11,15 @@ import collections import urllib from configparser import ConfigParser +from collections import ChainMap +import xml.etree.ElementTree as ET +import concurrent.futures +from urllib.error import HTTPError, URLError +import socket # local imports +from ofxtools import models, header, Parser from ofxtools.Client import ( OFXClient, StmtRq, @@ -25,9 +31,9 @@ from ofxtools.utils import UTC from ofxtools.scripts import ofxget from ofxtools.ofxhome import OFXServer -from ofxtools import models # test imports +import base import test_models_msgsets import test_models_signup @@ -39,6 +45,9 @@ def testMakeArgparser(self): self.assertGreater(len(argparser._actions), 0) +############################################################################### +# CLI METHODS +############################################################################### class CliTestCase(unittest.TestCase): @property def args(self): @@ -111,11 +120,7 @@ def testScanProfile(self): args, kwargs = mock_print.call_args self.assertEqual(len(args), 1) - # FIXME - the string output of dicts in json.dumps() appears - # not to be stable; below is for Py 3.5 - 3.7, whereas - # Py 3.4 looks like this: - #[{"versions": [102, 103], "formats": [{"unclosedelements": true, "pretty": false}, {"unclosedelements": false, "pretty": true}]}, {"versions": [203], "formats": [{"pretty": false}, {"pretty": true}]}] - + # FIXME - json.dumps output for dicts isn't stable # self.maxDiff = None # self.assertEqual( # args[0], @@ -125,6 +130,79 @@ def testScanProfile(self): # '{"versions": [203], "formats": [{"pretty": false}, ' # '{"pretty": true}]}]')) + def testScanProfileNoResult(self): + with patch("ofxtools.scripts.ofxget._scan_profile") as mock_scan_prof: + with patch("builtins.print") as mock_print: + mock_scan_prof.return_value = ({"versions": []}, + {"versions": []}, + {}) + result = ofxget.scan_profile(self.args) + self.assertIsNone(result, None) + + args, kwargs = mock_scan_prof.call_args + self.assertEqual(len(args), 3) + url, org, fid = args + self.assertEqual(url, self.args["url"]) + self.assertEqual(org, self.args["org"]) + self.assertEqual(fid, self.args["fid"]) + self.assertEqual(len(kwargs), 0) + + args, kwargs = mock_print.call_args + self.assertEqual(len(args), 1) + + self.maxDiff = None + self.assertEqual( + args[0], f"Scan found no working formats for {url}") + + def testScanProfileWrite(self): + with patch.multiple("ofxtools.scripts.ofxget", + _scan_profile=DEFAULT, + write_config=DEFAULT) as MOCKS: + mock_scan_prof = MOCKS["_scan_profile"] + mock_write_config = MOCKS["write_config"] + + ofxv1 = collections.OrderedDict([ + ("versions", [102, 103]), + ("formats", [{"pretty": False, "unclosedelements": True}, + {"pretty": True, "unclosedelements": False}])]) + + ofxv2 = collections.OrderedDict([ + ("versions", [203]), + ("formats", [{"pretty": False}, + {"pretty": True}])]) + + signoninfo = collections.OrderedDict([ + ('chgpinfirst', False), + ('clientuidreq', False), + ('authtokenfirst', False), + ('mfachallengefirst', False)]) + + mock_scan_prof.return_value = (ofxv1, ofxv2, signoninfo) + + ARGS = ChainMap({"write": True, "dryrun": False}, self.args) + + with patch("builtins.print") as mock_print: + result = ofxget.scan_profile(ARGS) + + self.assertEqual(result, None) + + args, kwargs = mock_scan_prof.call_args + self.assertEqual(len(args), 3) + url, org, fid = args + self.assertEqual(url, self.args["url"]) + self.assertEqual(org, self.args["org"]) + self.assertEqual(fid, self.args["fid"]) + self.assertEqual(len(kwargs), 0) + + args, kwargs = mock_write_config.call_args + self.assertEqual(len(args), 1) + args = args[0] + + ARGS["version"] = 203 # best version + # self.assertEqual(dict(args), dict(ARGS)) + self.assertEqual(dict(args), dict(ARGS)) + self.assertEqual(len(kwargs), 0) + def testRequestProfile(self): with patch("ofxtools.Client.OFXClient.request_profile") as fake_rqprof: with patch("builtins.print") as mock_print: @@ -193,7 +271,27 @@ def test_RequestAcctinfo(self): self.assertEqual(kwargs, {"dryrun": self.args["dryrun"], - "verify_ssl": not self.args["unsafe"]}) + "verify_ssl": not self.args["unsafe"]}) + + def testMergeAcctinfo(self): + """ Unit test for ofxtools.scripts.ofxget._merge_acctinfo() """ + cli = {"dryrun": True} + config = {"pretty": False} + args = ChainMap(cli, config) + + markup = OFXClient("").serialize(ExtractAcctInfosTestCase.ofx) + + ofxget._merge_acctinfo(args, BytesIO(markup)) + + # Have extracted bankid, brokerid, checking, creditcard, investment + self.assertEqual(len(args), 7) + self.assertEqual(args["dryrun"], True) + self.assertEqual(args["pretty"], False) + self.assertEqual(args["bankid"], '111000614') + self.assertEqual(args["brokerid"], '111000614') + self.assertEqual(args["checking"], ["123456789123456789"]) + self.assertEqual(args["creditcard"], ["123456789123456789"]) + self.assertEqual(args["investment"], ["123456789123456789"]) def testRequestStmt(self): args = self.args @@ -526,88 +624,152 @@ def testInitClient(self): self.assertEqual(getattr(client, arg), args[arg]) -class _ScanProfileTestCase(unittest.TestCase): - """ Unit tests for ofxtools.scripts.ofxget._scan_profile() """ - ofx = models.OFX( - signonmsgsrsv1=test_models_msgsets.Signonmsgsrsv1TestCase.aggregate, - profmsgsrsv1=test_models_msgsets.Profmsgsrsv1TestCase.aggregate) +############################################################################### +# ARGUMENT/CONFIG HANDLERS +############################################################################### +class MkServerCfgTestCase(unittest.TestCase): + """ Unit tests for ofxtools.script.ofxget.mk_server_cfg() """ + def testMkservercfg(self): + with patch("ofxtools.scripts.ofxget.UserConfig", new=ConfigParser()): + # FIXME - patching the classproperty isn't working + # with patch("ofxtools.Client.OFXClient.uuid", new="DEADBEEF"): - @property - def client(self): - return OFXClient("https://ofx.test.com") + # Must have "server" arg + with self.assertRaises(ValueError): + ofxget.mk_server_cfg({"foo": "bar"}) - errcount = 0 + # "server" arg can't have been sourced from "url" arg + with self.assertRaises(ValueError): + ofxget.mk_server_cfg({"server": "foo", "url": "foo"}) - def prof_result(self, version, prettyprint, close_elements, **kwargs): - # Sequence of errors caught for futures.result() in _scan_profile() - errors = (urllib.error.URLError(None, None), - urllib.error.HTTPError(None, None, None, None, None), - ConnectionError(), - OSError(), - ) - accept = [ - (102, False, False), - (102, True, True), - (103, False, False), - (103, True, True), - (203, False, True), - (203, True, True) - ] - if (version, prettyprint, close_elements) in accept: - ofx = self.client.serialize(self.ofx, - version, - prettyprint, - close_elements) - return BytesIO(ofx) - else: - error = errors[self.errcount % len(errors)] - self.errcount += 1 - raise error + results = dict(ofxget.mk_server_cfg( + {"server": "myserver", "url": "https://ofxget.test.com", + "version": 203, "ofxhome": "123", "org": "TEST", "fid": "321", + "brokerid": "test.com", "bankid": "11235813", + "user": "porkypig", "pretty": True, + "unclosedelements": False})) - def test_scanProfile(self): - with patch("ofxtools.Client.OFXClient.request_profile") as mock_profrq: - mock_profrq.side_effect = self.prof_result - results = ofxget._scan_profile(None, None, None) + self.assertIn("clientuid", results) - ofxv1 = collections.OrderedDict([ - ("versions", [102, 103]), - ("formats", [{"pretty": False, "unclosedelements": True}, - {"pretty": True, "unclosedelements": False}])]) + # FIXME - patching the classproperty isn't working + del results["clientuid"] - ofxv2 = collections.OrderedDict([ - ("versions", [203]), - ("formats", [{"pretty": False}, - {"pretty": True}])]) + # args equal to defaults are omitted from the results + predicted = { + "url": "https://ofxget.test.com", "ofxhome": "123", + "org": "TEST", "fid": "321", "brokerid": "test.com", + "bankid": "11235813", "user": "porkypig", "pretty": "true"} - signoninfo = collections.OrderedDict([ - ('chgpinfirst', False), - ('clientuidreq', False), - ('authtokenfirst', False), - ('mfachallengefirst', False)]) + self.assertEqual(dict(results), predicted) - self.assertEqual(results, (ofxv1, ofxv2, signoninfo)) + for opt, val in predicted.items(): + self.assertEqual(ofxget.UserConfig["myserver"][opt], val) -class ExtractAcctInfosTestCase(unittest.TestCase): - """ Unit tests for ofxtools.scripts.ofxget.extract_acctinfos() """ - ofx = models.OFX( - signonmsgsrsv1=test_models_msgsets.Signonmsgsrsv1TestCase.aggregate, - signupmsgsrsv1=models.SIGNUPMSGSRSV1( - test_models_signup.AcctinfotrnrsTestCase.aggregate)) +class ArgConfigTestCase(unittest.TestCase): + """ + Unit tests for ofxtools.scripts.ofxget.config2arg() and + ofxtools.scripts.ofxget.arg2config() + """ + def testList2arg(self): + for cfg in ("checking", "savings", "moneymrkt", "creditline", + "creditcard", "investment", "years"): + self.assertEqual(ofxget.config2arg(cfg, "123"), ["123"]) + self.assertEqual(ofxget.config2arg(cfg, "123,456"), ["123", "456"]) - @property - def client(self): - return OFXClient("https://ofx.test.com") + # Surrounding whitespace is stripped + self.assertEqual(ofxget.config2arg(cfg, " 123 "), ["123"]) + self.assertEqual(ofxget.config2arg(cfg, "123, 456"), ["123", "456"]) - def test_extract_acctinfos(self): - ofx = self.client.serialize(self.ofx) - results = ofxget.extract_acctinfos(BytesIO(ofx)) - self.assertEqual(len(results), 5) - self.assertEqual(results["bankid"], "111000614") - self.assertEqual(results["brokerid"], "111000614") - self.assertEqual(results["checking"], ["123456789123456789"]) - self.assertEqual(results["creditcard"], ["123456789123456789"]) - self.assertEqual(results["investment"], ["123456789123456789"]) + def testList2config(self): + for cfg in ("checking", "savings", "moneymrkt", "creditline", + "creditcard", "investment", "years"): + self.assertEqual(ofxget.arg2config(cfg, ["123"]), "123") + self.assertEqual(ofxget.arg2config(cfg, ["123", "456"]), "123, 456") + + def testListRoundtrip(self): + for cfg in ("checking", "savings", "moneymrkt", "creditline", + "creditcard", "investment", "years"): + self.assertEqual( + ofxget.config2arg(cfg, ofxget.arg2config(cfg, ["123", "456"])), + ["123", "456"]) + self.assertEqual( + ofxget.arg2config(cfg, ofxget.config2arg(cfg, "123, 456")), + "123, 456") + + def testBool2arg(self): + for cfg in ("dryrun", "unsafe", "unclosedelements", "pretty", + "inctran", "incbal", "incpos", "incoo", "all", "write"): + self.assertEqual(ofxget.config2arg(cfg, "true"), True) + self.assertEqual(ofxget.config2arg(cfg, "false"), False) + self.assertEqual(ofxget.config2arg(cfg, "yes"), True) + self.assertEqual(ofxget.config2arg(cfg, "no"), False) + self.assertEqual(ofxget.config2arg(cfg, "on"), True) + self.assertEqual(ofxget.config2arg(cfg, "off"), False) + self.assertEqual(ofxget.config2arg(cfg, "1"), True) + self.assertEqual(ofxget.config2arg(cfg, "0"), False) + + def testBool2config(self): + for cfg in ("dryrun", "unsafe", "unclosedelements", "pretty", + "inctran", "incbal", "incpos", "incoo", "all", "write"): + self.assertEqual(ofxget.arg2config(cfg, True), "true") + self.assertEqual(ofxget.arg2config(cfg, False), "false") + + def testBoolRoundtrip(self): + for cfg in ("dryrun", "unsafe", "unclosedelements", "pretty", + "inctran", "incbal", "incpos", "incoo", "all", "write"): + self.assertEqual( + ofxget.config2arg(cfg, ofxget.arg2config(cfg, True)), + True) + self.assertEqual( + ofxget.config2arg(cfg, ofxget.arg2config(cfg, False)), + False) + self.assertEqual( + ofxget.arg2config(cfg, ofxget.config2arg(cfg, "true")), + "true") + self.assertEqual( + ofxget.arg2config(cfg, ofxget.config2arg(cfg, "false")), + "false") + + def testInt2arg(self): + for cfg in ("version", ): + self.assertEqual(ofxget.config2arg(cfg, "1"), 1) + + def testInt2config(self): + for cfg in ("version", ): + self.assertEqual(ofxget.arg2config(cfg, 1), "1") + + def testIntRoundtrip(self): + for cfg in ("version", ): + self.assertEqual( + ofxget.config2arg(cfg, ofxget.arg2config(cfg, 1)), + 1) + self.assertEqual( + ofxget.arg2config(cfg, ofxget.config2arg(cfg, "1")), + "1") + + def testString2arg(self): + for cfg in ("url", "org", "fid", "appid", "appver", "bankid", + "brokerid", "user", "clientuid", "language", "acctnum", + "recid"): + self.assertEqual(ofxget.config2arg(cfg, "Something"), "Something") + + def testString2config(self): + for cfg in ("url", "org", "fid", "appid", "appver", "bankid", + "brokerid", "user", "clientuid", "language", "acctnum", + "recid"): + self.assertEqual(ofxget.arg2config(cfg, "Something"), "Something") + + def testStringRoundtrip(self): + for cfg in ("url", "org", "fid", "appid", "appver", "bankid", + "brokerid", "user", "clientuid", "language", "acctnum", + "recid"): + self.assertEqual( + ofxget.config2arg(cfg, ofxget.arg2config(cfg, "Something")), + "Something") + self.assertEqual( + ofxget.arg2config(cfg, ofxget.config2arg(cfg, "Something")), + "Something") class MergeConfigTestCase(unittest.TestCase): @@ -753,149 +915,227 @@ def testMergeConfigUnknownFiArg(self): ofxget.merge_config(args, ofxget.UserConfig) -class ArgConfigTestCase(unittest.TestCase): - """ - Unit tests for ofxtools.scripts.ofxget.config2arg() and - ofxtools.scripts.ofxget.arg2config() - """ - def testList2arg(self): - for cfg in ("checking", "savings", "moneymrkt", "creditline", - "creditcard", "investment", "years"): - self.assertEqual(ofxget.config2arg(cfg, "123"), ["123"]) - self.assertEqual(ofxget.config2arg(cfg, "123,456"), ["123", "456"]) +############################################################################### +# PROFILE SCAN +############################################################################### +class ScanProfileTestCase(unittest.TestCase): + """ Unit tests for ofxtools.scripts.ofxget._scan_profile() and helpers """ + ofx = models.OFX( + signonmsgsrsv1=test_models_msgsets.Signonmsgsrsv1TestCase.aggregate, + profmsgsrsv1=test_models_msgsets.Profmsgsrsv1TestCase.aggregate) - # Surrounding whitespace is stripped - self.assertEqual(ofxget.config2arg(cfg, " 123 "), ["123"]) - self.assertEqual(ofxget.config2arg(cfg, "123, 456"), ["123", "456"]) + @property + def client(self): + return OFXClient("https://ofx.test.com") - def testList2config(self): - for cfg in ("checking", "savings", "moneymrkt", "creditline", - "creditcard", "investment", "years"): - self.assertEqual(ofxget.arg2config(cfg, ["123"]), "123") - self.assertEqual(ofxget.arg2config(cfg, ["123", "456"]), "123, 456") + errcount = 0 - def testListRoundtrip(self): - for cfg in ("checking", "savings", "moneymrkt", "creditline", - "creditcard", "investment", "years"): - self.assertEqual( - ofxget.config2arg(cfg, ofxget.arg2config(cfg, ["123", "456"])), - ["123", "456"]) - self.assertEqual( - ofxget.arg2config(cfg, ofxget.config2arg(cfg, "123, 456")), - "123, 456") + def prof_result(self, version, prettyprint, close_elements, **kwargs): + # Sequence of errors caught for futures.result() in _scan_profile() + errors = (urllib.error.URLError(None, None), + urllib.error.HTTPError(None, None, None, None, None), + ConnectionError(), + OSError(), + ) + accept = [ + (102, False, False), + (102, True, True), + (103, False, False), + (103, True, True), + (203, False, True), + (203, True, True) + ] + if (version, prettyprint, close_elements) in accept: + ofx = self.client.serialize(self.ofx, + version, + prettyprint, + close_elements) + return BytesIO(ofx) + else: + error = errors[self.errcount % len(errors)] + self.errcount += 1 + raise error - def testBool2arg(self): - for cfg in ("dryrun", "unsafe", "unclosedelements", "pretty", - "inctran", "incbal", "incpos", "incoo", "all", "write"): - self.assertEqual(ofxget.config2arg(cfg, "true"), True) - self.assertEqual(ofxget.config2arg(cfg, "false"), False) - self.assertEqual(ofxget.config2arg(cfg, "yes"), True) - self.assertEqual(ofxget.config2arg(cfg, "no"), False) - self.assertEqual(ofxget.config2arg(cfg, "on"), True) - self.assertEqual(ofxget.config2arg(cfg, "off"), False) - self.assertEqual(ofxget.config2arg(cfg, "1"), True) - self.assertEqual(ofxget.config2arg(cfg, "0"), False) + def test_scanProfile(self): + with patch("ofxtools.Client.OFXClient.request_profile") as mock_profrq: + mock_profrq.side_effect = self.prof_result + results = ofxget._scan_profile(None, None, None) - def testBool2config(self): - for cfg in ("dryrun", "unsafe", "unclosedelements", "pretty", - "inctran", "incbal", "incpos", "incoo", "all", "write"): - self.assertEqual(ofxget.arg2config(cfg, True), "true") - self.assertEqual(ofxget.arg2config(cfg, False), "false") + ofxv1 = collections.OrderedDict([ + ("versions", [102, 103]), + ("formats", [collections.OrderedDict([("pretty", False), + ("unclosedelements", True)]), + collections.OrderedDict([("pretty", True), + ("unclosedelements", False)]), + ])]) - def testBoolRoundtrip(self): - for cfg in ("dryrun", "unsafe", "unclosedelements", "pretty", - "inctran", "incbal", "incpos", "incoo", "all", "write"): - self.assertEqual( - ofxget.config2arg(cfg, ofxget.arg2config(cfg, True)), - True) - self.assertEqual( - ofxget.config2arg(cfg, ofxget.arg2config(cfg, False)), - False) - self.assertEqual( - ofxget.arg2config(cfg, ofxget.config2arg(cfg, "true")), - "true") - self.assertEqual( - ofxget.arg2config(cfg, ofxget.config2arg(cfg, "false")), - "false") + ofxv2 = collections.OrderedDict([ + ("versions", [203]), + ("formats", [collections.OrderedDict([("pretty", False)]), + collections.OrderedDict([("pretty", True)])])]) - def testInt2arg(self): - for cfg in ("version", ): - self.assertEqual(ofxget.config2arg(cfg, "1"), 1) + signoninfo = collections.OrderedDict([ + ('chgpinfirst', False), + ('clientuidreq', False), + ('authtokenfirst', False), + ('mfachallengefirst', False)]) - def testInt2config(self): - for cfg in ("version", ): - self.assertEqual(ofxget.arg2config(cfg, 1), "1") + self.assertEqual(len(results), 3) + self.assertEqual(results[0], ofxv1) + self.assertEqual(results[1], ofxv2) + self.assertEqual(results[2], signoninfo) - def testIntRoundtrip(self): - for cfg in ("version", ): - self.assertEqual( - ofxget.config2arg(cfg, ofxget.arg2config(cfg, 1)), - 1) - self.assertEqual( - ofxget.arg2config(cfg, ofxget.config2arg(cfg, "1")), - "1") + def testQueueScanResponse(self): + """ Test ofxget._queue_scans() """ + with patch("ofxtools.Client.OFXClient.request_profile") as mock_profrq: + mock_profrq.side_effect = self.prof_result - def testString2arg(self): - for cfg in ("url", "org", "fid", "appid", "appver", "bankid", - "brokerid", "user", "clientuid", "language", "acctnum", - "recid"): - self.assertEqual(ofxget.config2arg(cfg, "Something"), "Something") + futures = ofxget._queue_scans(self.client, max_workers=1, timeout=1.0) - def testString2config(self): - for cfg in ("url", "org", "fid", "appid", "appver", "bankid", - "brokerid", "user", "clientuid", "language", "acctnum", - "recid"): - self.assertEqual(ofxget.arg2config(cfg, "Something"), "Something") + # OFXv1: pretty, unclosed True/False for 6 versions; 4 * 4 = 16 + # OFXv2: pretty True/False for 7 versions ; 7 * 2 = 12 + self.assertEqual(len(futures), 30) - def testStringRoundtrip(self): - for cfg in ("url", "org", "fid", "appid", "appver", "bankid", - "brokerid", "user", "clientuid", "language", "acctnum", - "recid"): - self.assertEqual( - ofxget.config2arg(cfg, ofxget.arg2config(cfg, "Something")), - "Something") - self.assertEqual( - ofxget.arg2config(cfg, ofxget.config2arg(cfg, "Something")), - "Something") + for future, format in futures.items(): + self.assertIsInstance(future, concurrent.futures.Future) + self.assertEqual(len(format), 3) + self.assertIn( + format[0], + [102, 103, 151, 160, 200, 201, 202, 203, 210, 211, 220]) + self.assertIsInstance(format[1], bool) + self.assertIsInstance(format[2], bool) -class MkServerCfgTestCase(unittest.TestCase): - """ Unit tests for ofxtools.script.ofxget.mk_server_cfg() """ - def testMkservercfg(self): - with patch("ofxtools.scripts.ofxget.UserConfig", new=ConfigParser()): - # FIXME - patching the classproperty isn't working - # with patch("ofxtools.Client.OFXClient.uuid", new="DEADBEEF"): +class ReadScanResponseTestCase(unittest.TestCase): + ofx = models.OFX( + signonmsgsrsv1=test_models_msgsets.Signonmsgsrsv1TestCase.aggregate, + profmsgsrsv1=test_models_msgsets.Profmsgsrsv1TestCase.aggregate) - # Must have "server" arg - with self.assertRaises(ValueError): - ofxget.mk_server_cfg({"foo": "bar"}) + @property + def client(self): + return OFXClient("https://ofx.test.com") - # "server" arg can't have been sourced from "url" arg - with self.assertRaises(ValueError): - ofxget.mk_server_cfg({"server": "foo", "url": "foo"}) + def testReadScanResponse(self): + markup = self.client.serialize(self.ofx) + + # Connection error: return False, empty SIGNONINFO parameters + rq_errors = [URLError(""), + HTTPError(None, None, None, None, None), + ConnectionError(""), + OSError(""), + socket.timeout, + ] + + for error in rq_errors: + with patch("concurrent.futures.Future.result") as mock_result: + mock_result.side_effect = error + + future = concurrent.futures.Future() + result = ofxget._read_scan_response(future) + + self.assertEqual(len(result), 2) + self.assertFalse(result[0]) + self.assertEqual(result[1], {}) + + # No valid OFX: return False, empty SIGNONINFO parameters + ofx_errors = [ + socket.timeout, + ET.ParseError(), + Parser.ParseError(), + header.OFXHeaderError(), + ] + for error in ofx_errors: + with patch("concurrent.futures.Future.result") as mock_result: + mock_result.return_value = BytesIO(markup) + with patch("ofxtools.scripts.ofxget.extract_signoninfos") as mock_extract_signoninfos: + mock_extract_signoninfos.side_effect = error + + future = concurrent.futures.Future() + result = ofxget._read_scan_response(future, read_signoninfo=True) + + self.assertEqual(len(result), 2) + self.assertFalse(result[0]) + self.assertEqual(result[1], {}) + + # Valid OFX with no good SIGNONINFO: return True, empty SIGNONINFO + with patch("concurrent.futures.Future.result") as mock_result: + mock_result.return_value = BytesIO(markup) + with patch("ofxtools.scripts.ofxget.extract_signoninfos") as mock_extract_signoninfos: + mock_extract_signoninfos.side_effect = ValueError() + + future = concurrent.futures.Future() + result = ofxget._read_scan_response(future, read_signoninfo=True) + + self.assertEqual(len(result), 2) + self.assertTrue(result[0]) + self.assertEqual(result[1], {}) + + # Valid OFX with good SIGNONINFO: return True, SIGNONINFO parameters + with patch("concurrent.futures.Future.result") as mock_result: + mock_result.return_value = BytesIO(markup) + + future = concurrent.futures.Future() + result = ofxget._read_scan_response(future, read_signoninfo=True) + + self.assertEqual(len(result), 2) + self.assertTrue(result[0]) + signoninfo = result[1] + self.assertIsInstance(signoninfo, collections.OrderedDict) + self.assertEqual(len(signoninfo), 4) + self.assertEqual(set(signoninfo.keys()), + set(["chgpinfirst", "clientuidreq", + "authtokenfirst", "mfachallengefirst"])) + + +class CollateScanResultsTestCase(unittest.TestCase): + def testCollateScanResults(self): + v1 = [(160, [(True, False), (False, True), (False, False)]), + (102, [(True, False), (False, True), (False, False)]), + (103, [(True, False), (False, True), (False, False)]), + ] + versions, formats = ofxget.collate_scan_results(v1) + self.assertEqual(versions, [102, 103, 160]) + + self.assertEqual(formats, [ + collections.OrderedDict([("pretty", False), ("unclosedelements", False)]), + collections.OrderedDict([("pretty", False), ("unclosedelements", True)]), + collections.OrderedDict([("pretty", True), ("unclosedelements", True)]), + ]) + + +############################################################################### +# OFX PARSING +############################################################################### +class ExtractAcctInfosTestCase(unittest.TestCase): + """ Unit tests for ofxtools.scripts.ofxget.extract_acctinfos() """ + ofx = models.OFX( + signonmsgsrsv1=test_models_msgsets.Signonmsgsrsv1TestCase.aggregate, + signupmsgsrsv1=models.SIGNUPMSGSRSV1( + test_models_signup.AcctinfotrnrsTestCase.aggregate)) - results = dict(ofxget.mk_server_cfg( - {"server": "myserver", "url": "https://ofxget.test.com", - "version": 203, "ofxhome": "123", "org": "TEST", "fid": "321", - "brokerid": "test.com", "bankid": "11235813", - "user": "porkypig", "pretty": True, - "unclosedelements": False})) + @property + def client(self): + return OFXClient("https://ofx.test.com") - self.assertIn("clientuid", results) + def test_extract_acctinfos(self): + ofx = self.client.serialize(self.ofx) + results = ofxget.extract_acctinfos(BytesIO(ofx)) + # results is an iterator - sorted by *ACCTINFO classname + results = list(results) - # FIXME - patching the classproperty isn't working - del results["clientuid"] + acctinfo = test_models_signup.AcctinfoTestCase.aggregate - # args equal to defaults are omitted from the results - predicted = { - "url": "https://ofxget.test.com", "ofxhome": "123", - "org": "TEST", "fid": "321", "brokerid": "test.com", - "bankid": "11235813", "user": "porkypig", "pretty": "true"} + # HACK - Reuse base.OfxTestCase._eqAggregate to determine that + # our results are the same as the children of + # test_models_signup.AcctinfoTestCase, which was used to construct + # self.ofx + class Foo(unittest.TestCase, base.OfxTestCase): + ... - self.assertEqual(dict(results), predicted) + tc = Foo() - for opt, val in predicted.items(): - self.assertEqual(ofxget.UserConfig["myserver"][opt], val) + for n in range(3): + tc._eqAggregate(results[n], acctinfo[n]) if __name__ == "__main__":