Skip to content

Commit

Permalink
Support multiplier on source prices.
Browse files Browse the repository at this point in the history
  • Loading branch information
SEIAROTg committed Jan 3, 2023
1 parent b761102 commit eee4050
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 26 deletions.
19 changes: 13 additions & 6 deletions beanprice/price.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import argparse
import collections
import datetime
from decimal import Decimal
import functools
from os import path
import shelve
Expand Down Expand Up @@ -38,10 +39,13 @@
# module: A Python module, the module to be called to create a price source.
# symbol: A ticker symbol in the universe of the source.
# invert: A boolean, true if we need to invert the currency.
# multiplier: A Decimal instance to be multiplied on prices from the source.
# This is useful with sources returning 1.23 USD as 123.
class PriceSource(NamedTuple):
module: Any
symbol: str
invert: bool
multiplier: Decimal


# A dated price source description.
Expand Down Expand Up @@ -151,7 +155,7 @@ def parse_single_source(source: str) -> PriceSource:
Source specifications follow the syntax:
<module>/[^]<ticker>
[<multiplier>*]<module>/[^]<ticker>
The <module> is resolved against the Python path, but first looked up
under the package where the default price extractors lie.
Expand All @@ -163,12 +167,15 @@ def parse_single_source(source: str) -> PriceSource:
Raises:
ValueError: If invalid.
"""
match = re.match(r'([a-zA-Z]+[a-zA-Z0-9\._]+)/(\^?)([a-zA-Z0-9:=_\-\.\(\)]+)$', source)
match = re.match(r'(?:([0-9]+(?:\.[0-9]+)?)\*)?'
r'([a-zA-Z]+[a-zA-Z0-9\._]+)/(\^?)([a-zA-Z0-9:=_\-\.\(\)]+)$', source)
if not match:
raise ValueError('Invalid source name: "{}"'.format(source))
short_module_name, invert, symbol = match.groups()
multiplier_str, short_module_name, invert, symbol = match.groups()
module = import_source(short_module_name)
return PriceSource(module, symbol, bool(invert))
multiplier = Decimal(multiplier_str) if multiplier_str else Decimal(1)
print(f'{multiplier=} {multiplier_str=}')
return PriceSource(module, symbol, bool(invert), multiplier)


def import_source(module_name: str):
Expand Down Expand Up @@ -323,7 +330,7 @@ def get_price_jobs_at_date(entries: data.Entries,

# If there are no sources, create a default one.
if not psources:
psources = [PriceSource(default_source, base, False)]
psources = [PriceSource(default_source, base, False, Decimal(1))]

jobs.append(DatedPrice(base, quote, date, psources))
return sorted(jobs)
Expand Down Expand Up @@ -599,7 +606,7 @@ def fetch_price(dprice: DatedPrice, swap_inverted: bool = False) -> Optional[dat

base = dprice.base
quote = dprice.quote or srcprice.quote_currency
price = srcprice.price
price = srcprice.price * psource.multiplier

# Invert the rate if requested.
if psource.invert:
Expand Down
73 changes: 53 additions & 20 deletions beanprice/price_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@


PS = price.PriceSource
ONE = Decimal(1)


def run_with_args(function, args, runner_file=None):
Expand Down Expand Up @@ -225,7 +226,7 @@ def test_expressions(self):
self.assertEqual(
[price.DatedPrice(
'AAPL', 'USD', None,
[price.PriceSource(yahoo, 'AAPL', False)])], jobs)
[PS(yahoo, 'AAPL', False, ONE)])], jobs)


class TestClobber(cmptest.TestCase):
Expand Down Expand Up @@ -292,7 +293,7 @@ def test_fetch_price__naive_time_no_timeozne(self, fetch_cached):
dprice = price.DatedPrice('JPY', 'USD', datetime.date(2015, 11, 22), None)
with self.assertRaises(ValueError):
price.fetch_price(dprice._replace(sources=[
price.PriceSource(yahoo, 'USDJPY', False)]), False)
PS(yahoo, 'USDJPY', False, ONE)]), False)


class TestInverted(unittest.TestCase):
Expand All @@ -309,23 +310,41 @@ def setUp(self):

def test_fetch_price__normal(self):
entry = price.fetch_price(self.dprice._replace(sources=[
price.PriceSource(yahoo, 'USDJPY', False)]), False)
PS(yahoo, 'USDJPY', False, ONE)]), False)
self.assertEqual(('JPY', 'USD'), (entry.currency, entry.amount.currency))
self.assertEqual(Decimal('125.00'), entry.amount.number)

def test_fetch_price__inverted(self):
entry = price.fetch_price(self.dprice._replace(sources=[
price.PriceSource(yahoo, 'USDJPY', True)]), False)
PS(yahoo, 'USDJPY', True, ONE)]), False)
self.assertEqual(('JPY', 'USD'), (entry.currency, entry.amount.currency))
self.assertEqual(Decimal('0.008'), entry.amount.number)

def test_fetch_price__swapped(self):
entry = price.fetch_price(self.dprice._replace(sources=[
price.PriceSource(yahoo, 'USDJPY', True)]), True)
PS(yahoo, 'USDJPY', True, ONE)]), True)
self.assertEqual(('USD', 'JPY'), (entry.currency, entry.amount.currency))
self.assertEqual(Decimal('125.00'), entry.amount.number)


class TestMultiplier(unittest.TestCase):

def test_multiplier(self):
fetch_cached = mock.patch('beanprice.price.fetch_cached_price').start()
self.addCleanup(mock.patch.stopall)
fetch_cached.return_value = SourcePrice(
Decimal('16824.00'), datetime.datetime(2023, 1, 1, 16, 0, 0,
tzinfo=tz.tzlocal()),
None)
dprice = price.DatedPrice(
'GBP', 'XSDR', datetime.date(2023, 1, 1), [
PS(yahoo, 'XSDR.L', False, Decimal('0.01')),
])
entry = price.fetch_price(dprice)
self.assertEqual(('GBP', 'XSDR'), (entry.currency, entry.amount.currency))
self.assertEqual('168.2400', str(entry.amount.number))


class TestImportSource(unittest.TestCase):

def test_import_source_valid(self):
Expand All @@ -352,22 +371,25 @@ def test_source_invalid(self):
with self.assertRaises(ImportError):
price.parse_single_source('invalid.module.name/NASDAQ:AAPL')

def test_source_valid(self):
psource = price.parse_single_source('yahoo/CNYUSD=X')
self.assertEqual(PS(yahoo, 'CNYUSD=X', False), psource)

# Make sure that an invalid name at the tail doesn't succeed.
with self.assertRaises(ValueError):
psource = price.parse_single_source('yahoo/CNYUSD&X')

def test_source_valid(self):
psource = price.parse_single_source('yahoo/CNYUSD=X')
self.assertEqual(PS(yahoo, 'CNYUSD=X', False, ONE), psource)

psource = price.parse_single_source('beanprice.sources.yahoo/AAPL')
self.assertEqual(PS(yahoo, 'AAPL', False), psource)
self.assertEqual(PS(yahoo, 'AAPL', False, ONE), psource)

psource = price.parse_single_source('0.01*yahoo/XSDR.L')
self.assertEqual(PS(yahoo, 'XSDR.L', False, Decimal('0.01')), psource)


class TestParseSourceMap(unittest.TestCase):

def _clean_source_map(self, smap):
return {currency: [PS(s[0].__name__, s[1], s[2]) for s in sources]
return {currency: [PS(s[0].__name__, s[1], s[2], s[3]) for s in sources]
for currency, sources in smap.items()}

def test_source_map_invalid(self):
Expand All @@ -378,39 +400,50 @@ def test_source_map_invalid(self):
def test_source_map_onecur_single(self):
smap = price.parse_source_map('USD:yahoo/AAPL')
self.assertEqual(
{'USD': [PS('beanprice.sources.yahoo', 'AAPL', False)]},
{'USD': [PS('beanprice.sources.yahoo', 'AAPL', False, ONE)]},
self._clean_source_map(smap))

def test_source_map_onecur_multiple(self):
smap = price.parse_source_map('USD:oanda/USDCAD,yahoo/CAD=X')
self.assertEqual(
{'USD': [PS('beanprice.sources.oanda', 'USDCAD', False),
PS('beanprice.sources.yahoo', 'CAD=X', False)]},
{'USD': [PS('beanprice.sources.oanda', 'USDCAD', False, ONE),
PS('beanprice.sources.yahoo', 'CAD=X', False, ONE)]},
self._clean_source_map(smap))

def test_source_map_manycur_single(self):
smap = price.parse_source_map('USD:yahoo/USDCAD '
'CAD:yahoo/CAD=X')
self.assertEqual(
{'USD': [PS('beanprice.sources.yahoo', 'USDCAD', False)],
'CAD': [PS('beanprice.sources.yahoo', 'CAD=X', False)]},
{'USD': [PS('beanprice.sources.yahoo', 'USDCAD', False, ONE)],
'CAD': [PS('beanprice.sources.yahoo', 'CAD=X', False, ONE)]},
self._clean_source_map(smap))

def test_source_map_manycur_multiple(self):
smap = price.parse_source_map('USD:yahoo/GBPUSD,oanda/GBPUSD '
'CAD:yahoo/GBPCAD')
self.assertEqual(
{'USD': [PS('beanprice.sources.yahoo', 'GBPUSD', False),
PS('beanprice.sources.oanda', 'GBPUSD', False)],
'CAD': [PS('beanprice.sources.yahoo', 'GBPCAD', False)]},
{'USD': [PS('beanprice.sources.yahoo', 'GBPUSD', False, ONE),
PS('beanprice.sources.oanda', 'GBPUSD', False, ONE)],
'CAD': [PS('beanprice.sources.yahoo', 'GBPCAD', False, ONE)]},
self._clean_source_map(smap))

def test_source_map_inverse(self):
smap = price.parse_source_map('USD:yahoo/^GBPUSD')
self.assertEqual(
{'USD': [PS('beanprice.sources.yahoo', 'GBPUSD', True)]},
{'USD': [PS('beanprice.sources.yahoo', 'GBPUSD', True, ONE)]},
self._clean_source_map(smap))

def test_source_map_multiplier(self):
smap = price.parse_source_map(
'GBP:0.01*yahoo/XSDR.L;GBP:yahoo/XSDR;USD:1000*yahoo/mXSDRUSD')
print(smap)
self.assertEqual({
'GBP': [PS('beanprice.sources.yahoo', 'XSDR.L', False, Decimal('0.01')),
PS('beanprice.sources.yahoo', 'XSDR', False, ONE)],
'USD': [PS('beanprice.sources.yahoo', 'mXSDRUSD', False, Decimal(1000))],

}, self._clean_source_map(smap))


class TestFilters(unittest.TestCase):

Expand Down

0 comments on commit eee4050

Please sign in to comment.