diff --git a/README.md b/README.md index 996f4c4..175790a 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ It does not however validate or use a whitelist of tickers. * The results are in the order they are first found. * By default, the results are deduplicated, although this can be disabled. * A configurable blacklist of common false-positives is used. -* A configurable remapping of specific tickers is supported. +* A configurable remapping of tickers is supported. * For lower level use, a configurably created compiled regular expression can be accessed. ## Links @@ -76,4 +76,7 @@ Python ≥3.8 is required. To install, run: >>> reticker.config.MAPPING["BTC"] = "BTC-USD" >>> reticker.TickerExtractor().extract("What is the Yahoo Finance symbol for BTC?") ['BTC-USD'] +>>> reticker.config.MAPPING["COMP"] = ["COMP", "COMP-USD"] +>>> reticker.TickerExtractor().extract('Is COMP for the equity named "Compass" or for the crypto named "Compound"? I want both!') +['COMP', 'COMP-USD'] ``` diff --git a/reticker/config/__init__.py b/reticker/config/__init__.py index ace7c0d..8f5e1c1 100644 --- a/reticker/config/__init__.py +++ b/reticker/config/__init__.py @@ -1,8 +1,8 @@ """Package config.""" from pathlib import Path -from typing import Dict, Final, List, Set +from typing import Dict, Final, List, Set, Union _CONFIG_PATH: Final[Path] = Path(__file__).parent BLACKLIST_PATHS: Final[List[Path]] = list((_CONFIG_PATH / "blacklist").glob("*.txt")) BLACKLIST: Final[Set[str]] = set(term for path in BLACKLIST_PATHS for term in path.read_text().strip().split("\n")) -MAPPING: Final[Dict[str, str]] = {} +MAPPING: Final[Dict[str, Union[str, List[str]]]] = {} diff --git a/reticker/reticker.py b/reticker/reticker.py index 60768d9..f381a0b 100644 --- a/reticker/reticker.py +++ b/reticker/reticker.py @@ -87,8 +87,9 @@ def append_patterns(part1: str, part2: str) -> None: def extract(self, text: str) -> List[str]: """Return possible tickers extracted from the given text.""" matches = [match.upper() for match in self.pattern.findall(text)] - matches = [match for match in matches if match not in config.BLACKLIST] # Is done _before_ 'mapping'. + matches = [match for match in matches if match not in config.BLACKLIST] # Is done _before_ mapping. matches = [config.MAPPING.get(match, match) for match in matches] + matches = [inner_m for m_list in [[outer_m] if isinstance(outer_m, str) else outer_m for outer_m in matches] for inner_m in m_list] # Conditional flattening of mapping. if self.deduplicate: - matches = list(dict.fromkeys(matches)) # Is done _after_ 'mapping'. + matches = list(dict.fromkeys(matches)) # Is done _after_ mapping. return matches diff --git a/tests/test_extractor.py b/tests/test_extractor.py index 47d9710..a255c6d 100644 --- a/tests/test_extractor.py +++ b/tests/test_extractor.py @@ -52,7 +52,8 @@ def test_blacklist(self): reticker.config.BLACKLIST.remove("BLCN") self.assertEqual(reticker.config.BLACKLIST, original_blacklist) - def test_mapping(self): + def test_mapping_to_str(self): + self.assertNotIn("ADA", reticker.config.MAPPING) self.assertNotIn("BTC", reticker.config.MAPPING) text = "ADA BTC RIOT" self.assertEqual(self.default_ticker_extractor.extract(text), ["ADA", "BTC", "RIOT"]) @@ -69,6 +70,22 @@ def test_mapping(self): del reticker.config.MAPPING["BTC"] self.assertEqual(reticker.config.MAPPING, original_mapping) + def test_mapping_to_list(self): + self.assertNotIn("COMP", reticker.config.MAPPING) + self.assertNotIn("USD", reticker.config.MAPPING) + text = 'Is COMP for the equity "Compass, Inc." or is it for the cryptocurrency "Compound USD"?' + self.assertEqual(self.default_ticker_extractor.extract(text), ["COMP", "USD"]) + + original_mapping = reticker.config.MAPPING.copy() + reticker.config.MAPPING["COMP"] = ["COMP", "COMP-USD"] + self.assertEqual(self.default_ticker_extractor.extract(text), ["COMP", "COMP-USD", "USD"]) + reticker.config.MAPPING["USD"] = ["DXY"] + self.assertEqual(self.default_ticker_extractor.extract(text), ["COMP", "COMP-USD", "DXY"]) + + del reticker.config.MAPPING["COMP"] + del reticker.config.MAPPING["USD"] + self.assertEqual(reticker.config.MAPPING, original_mapping) + class TestCustomExtraction(unittest.TestCase): def test_prefixed_uppercase(self):