Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiplier on source prices. #73

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions beanprice/price.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import argparse
import collections
import datetime
from decimal import Decimal
import functools
from os import path
import shelve
Expand Down Expand Up @@ -38,10 +39,13 @@
# module: A Python module, the module to be called to create a price source.
# symbol: A ticker symbol in the universe of the source.
# invert: A boolean, true if we need to invert the currency.
# multiplier: A Decimal instance to be multiplied on prices from the source.
# This is useful with sources returning 1.23 USD as 123.
class PriceSource(NamedTuple):
module: Any
symbol: str
invert: bool
multiplier: Decimal


# A dated price source description.
Expand Down Expand Up @@ -151,7 +155,7 @@ def parse_single_source(source: str) -> PriceSource:

Source specifications follow the syntax:

<module>/[^]<ticker>
[<multiplier>*]<module>/[^]<ticker>

The <module> is resolved against the Python path, but first looked up
under the package where the default price extractors lie.
Expand All @@ -163,12 +167,15 @@ def parse_single_source(source: str) -> PriceSource:
Raises:
ValueError: If invalid.
"""
match = re.match(r'([a-zA-Z]+[a-zA-Z0-9\._]+)/(\^?)([a-zA-Z0-9:=_\-\.\(\)]+)$', source)
match = re.match(r'(?:([0-9]+(?:\.[0-9]+)?)\*)?'
r'([a-zA-Z]+[a-zA-Z0-9\._]+)/(\^?)([a-zA-Z0-9:=_\-\.\(\)]+)$', source)
if not match:
raise ValueError('Invalid source name: "{}"'.format(source))
short_module_name, invert, symbol = match.groups()
multiplier_str, short_module_name, invert, symbol = match.groups()
module = import_source(short_module_name)
return PriceSource(module, symbol, bool(invert))
multiplier = Decimal(multiplier_str) if multiplier_str else Decimal(1)
print(f'{multiplier=} {multiplier_str=}')
return PriceSource(module, symbol, bool(invert), multiplier)


def import_source(module_name: str):
Expand Down Expand Up @@ -323,7 +330,7 @@ def get_price_jobs_at_date(entries: data.Entries,

# If there are no sources, create a default one.
if not psources:
psources = [PriceSource(default_source, base, False)]
psources = [PriceSource(default_source, base, False, Decimal(1))]

jobs.append(DatedPrice(base, quote, date, psources))
return sorted(jobs)
Expand Down Expand Up @@ -599,7 +606,7 @@ def fetch_price(dprice: DatedPrice, swap_inverted: bool = False) -> Optional[dat

base = dprice.base
quote = dprice.quote or srcprice.quote_currency
price = srcprice.price
price = srcprice.price * psource.multiplier

# Invert the rate if requested.
if psource.invert:
Expand Down
73 changes: 53 additions & 20 deletions beanprice/price_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@


PS = price.PriceSource
ONE = Decimal(1)


def run_with_args(function, args, runner_file=None):
Expand Down Expand Up @@ -225,7 +226,7 @@ def test_expressions(self):
self.assertEqual(
[price.DatedPrice(
'AAPL', 'USD', None,
[price.PriceSource(yahoo, 'AAPL', False)])], jobs)
[PS(yahoo, 'AAPL', False, ONE)])], jobs)


class TestClobber(cmptest.TestCase):
Expand Down Expand Up @@ -292,7 +293,7 @@ def test_fetch_price__naive_time_no_timeozne(self, fetch_cached):
dprice = price.DatedPrice('JPY', 'USD', datetime.date(2015, 11, 22), None)
with self.assertRaises(ValueError):
price.fetch_price(dprice._replace(sources=[
price.PriceSource(yahoo, 'USDJPY', False)]), False)
PS(yahoo, 'USDJPY', False, ONE)]), False)


class TestInverted(unittest.TestCase):
Expand All @@ -309,23 +310,41 @@ def setUp(self):

def test_fetch_price__normal(self):
entry = price.fetch_price(self.dprice._replace(sources=[
price.PriceSource(yahoo, 'USDJPY', False)]), False)
PS(yahoo, 'USDJPY', False, ONE)]), False)
self.assertEqual(('JPY', 'USD'), (entry.currency, entry.amount.currency))
self.assertEqual(Decimal('125.00'), entry.amount.number)

def test_fetch_price__inverted(self):
entry = price.fetch_price(self.dprice._replace(sources=[
price.PriceSource(yahoo, 'USDJPY', True)]), False)
PS(yahoo, 'USDJPY', True, ONE)]), False)
self.assertEqual(('JPY', 'USD'), (entry.currency, entry.amount.currency))
self.assertEqual(Decimal('0.008'), entry.amount.number)

def test_fetch_price__swapped(self):
entry = price.fetch_price(self.dprice._replace(sources=[
price.PriceSource(yahoo, 'USDJPY', True)]), True)
PS(yahoo, 'USDJPY', True, ONE)]), True)
self.assertEqual(('USD', 'JPY'), (entry.currency, entry.amount.currency))
self.assertEqual(Decimal('125.00'), entry.amount.number)


class TestMultiplier(unittest.TestCase):

def test_multiplier(self):
fetch_cached = mock.patch('beanprice.price.fetch_cached_price').start()
self.addCleanup(mock.patch.stopall)
fetch_cached.return_value = SourcePrice(
Decimal('16824.00'), datetime.datetime(2023, 1, 1, 16, 0, 0,
tzinfo=tz.tzlocal()),
None)
dprice = price.DatedPrice(
'GBP', 'XSDR', datetime.date(2023, 1, 1), [
PS(yahoo, 'XSDR.L', False, Decimal('0.01')),
])
entry = price.fetch_price(dprice)
self.assertEqual(('GBP', 'XSDR'), (entry.currency, entry.amount.currency))
self.assertEqual('168.2400', str(entry.amount.number))


class TestImportSource(unittest.TestCase):

def test_import_source_valid(self):
Expand All @@ -352,22 +371,25 @@ def test_source_invalid(self):
with self.assertRaises(ImportError):
price.parse_single_source('invalid.module.name/NASDAQ:AAPL')

def test_source_valid(self):
psource = price.parse_single_source('yahoo/CNYUSD=X')
self.assertEqual(PS(yahoo, 'CNYUSD=X', False), psource)

# Make sure that an invalid name at the tail doesn't succeed.
with self.assertRaises(ValueError):
psource = price.parse_single_source('yahoo/CNYUSD&X')

def test_source_valid(self):
psource = price.parse_single_source('yahoo/CNYUSD=X')
self.assertEqual(PS(yahoo, 'CNYUSD=X', False, ONE), psource)

psource = price.parse_single_source('beanprice.sources.yahoo/AAPL')
self.assertEqual(PS(yahoo, 'AAPL', False), psource)
self.assertEqual(PS(yahoo, 'AAPL', False, ONE), psource)

psource = price.parse_single_source('0.01*yahoo/XSDR.L')
self.assertEqual(PS(yahoo, 'XSDR.L', False, Decimal('0.01')), psource)


class TestParseSourceMap(unittest.TestCase):

def _clean_source_map(self, smap):
return {currency: [PS(s[0].__name__, s[1], s[2]) for s in sources]
return {currency: [PS(s[0].__name__, s[1], s[2], s[3]) for s in sources]
for currency, sources in smap.items()}

def test_source_map_invalid(self):
Expand All @@ -378,39 +400,50 @@ def test_source_map_invalid(self):
def test_source_map_onecur_single(self):
smap = price.parse_source_map('USD:yahoo/AAPL')
self.assertEqual(
{'USD': [PS('beanprice.sources.yahoo', 'AAPL', False)]},
{'USD': [PS('beanprice.sources.yahoo', 'AAPL', False, ONE)]},
self._clean_source_map(smap))

def test_source_map_onecur_multiple(self):
smap = price.parse_source_map('USD:oanda/USDCAD,yahoo/CAD=X')
self.assertEqual(
{'USD': [PS('beanprice.sources.oanda', 'USDCAD', False),
PS('beanprice.sources.yahoo', 'CAD=X', False)]},
{'USD': [PS('beanprice.sources.oanda', 'USDCAD', False, ONE),
PS('beanprice.sources.yahoo', 'CAD=X', False, ONE)]},
self._clean_source_map(smap))

def test_source_map_manycur_single(self):
smap = price.parse_source_map('USD:yahoo/USDCAD '
'CAD:yahoo/CAD=X')
self.assertEqual(
{'USD': [PS('beanprice.sources.yahoo', 'USDCAD', False)],
'CAD': [PS('beanprice.sources.yahoo', 'CAD=X', False)]},
{'USD': [PS('beanprice.sources.yahoo', 'USDCAD', False, ONE)],
'CAD': [PS('beanprice.sources.yahoo', 'CAD=X', False, ONE)]},
self._clean_source_map(smap))

def test_source_map_manycur_multiple(self):
smap = price.parse_source_map('USD:yahoo/GBPUSD,oanda/GBPUSD '
'CAD:yahoo/GBPCAD')
self.assertEqual(
{'USD': [PS('beanprice.sources.yahoo', 'GBPUSD', False),
PS('beanprice.sources.oanda', 'GBPUSD', False)],
'CAD': [PS('beanprice.sources.yahoo', 'GBPCAD', False)]},
{'USD': [PS('beanprice.sources.yahoo', 'GBPUSD', False, ONE),
PS('beanprice.sources.oanda', 'GBPUSD', False, ONE)],
'CAD': [PS('beanprice.sources.yahoo', 'GBPCAD', False, ONE)]},
self._clean_source_map(smap))

def test_source_map_inverse(self):
smap = price.parse_source_map('USD:yahoo/^GBPUSD')
self.assertEqual(
{'USD': [PS('beanprice.sources.yahoo', 'GBPUSD', True)]},
{'USD': [PS('beanprice.sources.yahoo', 'GBPUSD', True, ONE)]},
self._clean_source_map(smap))

def test_source_map_multiplier(self):
smap = price.parse_source_map(
'GBP:0.01*yahoo/XSDR.L;GBP:yahoo/XSDR;USD:1000*yahoo/mXSDRUSD')
print(smap)
self.assertEqual({
'GBP': [PS('beanprice.sources.yahoo', 'XSDR.L', False, Decimal('0.01')),
PS('beanprice.sources.yahoo', 'XSDR', False, ONE)],
'USD': [PS('beanprice.sources.yahoo', 'mXSDRUSD', False, Decimal(1000))],

}, self._clean_source_map(smap))


class TestFilters(unittest.TestCase):

Expand Down