Skip to content

Commit

Permalink
Merge pull request quantopian#1467 from quantopian/check_param-string…
Browse files Browse the repository at this point in the history
…_types

Check param string types
  • Loading branch information
richafrank committed Sep 8, 2016
2 parents 9ad670b + 7f6db68 commit 7a10d93
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 19 deletions.
43 changes: 30 additions & 13 deletions tests/test_algorithm.py
Expand Up @@ -25,7 +25,7 @@
from logbook import TestHandler, WARNING
from mock import MagicMock
from nose_parameterized import parameterized
from six import iteritems, itervalues
from six import iteritems, itervalues, string_types
from six.moves import range
from testfixtures import TempDirectory

Expand All @@ -37,7 +37,7 @@
from zipline import run_algorithm
from zipline import TradingAlgorithm
from zipline.api import FixedSlippage
from zipline.assets import Equity, Future
from zipline.assets import Equity, Future, Asset
from zipline.assets.synthetic import (
make_jagged_equity_info,
make_simple_equity_info,
Expand Down Expand Up @@ -1523,29 +1523,46 @@ class TestAlgoScript(WithLogger,
DATA_PORTAL_USE_MINUTE_DATA = False
EQUITY_DAILY_BAR_LOOKBACK_DAYS = 5 # max history window length

STRING_TYPE_NAMES = [s.__name__ for s in string_types]
STRING_TYPE_NAMES_STRING = ', '.join(STRING_TYPE_NAMES)
ASSET_TYPE_NAME = Asset.__name__
ASSET_OR_STRING_TYPE_NAMES = ', '.join([ASSET_TYPE_NAME] +
STRING_TYPE_NAMES)
ARG_TYPE_TEST_CASES = (
('history__assets', (bad_type_history_assets, 'Asset, str', True)),
('history__fields', (bad_type_history_fields, 'str', True)),
('history__assets', (bad_type_history_assets,
ASSET_OR_STRING_TYPE_NAMES,
True)),
('history__fields', (bad_type_history_fields,
STRING_TYPE_NAMES_STRING,
True)),
('history__bar_count', (bad_type_history_bar_count, 'int', False)),
('history__frequency', (bad_type_history_frequency, 'str', False)),
('current__assets', (bad_type_current_assets, 'Asset, str', True)),
('current__fields', (bad_type_current_fields, 'str', True)),
('history__frequency', (bad_type_history_frequency,
STRING_TYPE_NAMES_STRING,
False)),
('current__assets', (bad_type_current_assets,
ASSET_OR_STRING_TYPE_NAMES,
True)),
('current__fields', (bad_type_current_fields,
STRING_TYPE_NAMES_STRING,
True)),
('is_stale__assets', (bad_type_is_stale_assets, 'Asset', True)),
('can_trade__assets', (bad_type_can_trade_assets, 'Asset', True)),
('history_kwarg__assets',
(bad_type_history_assets_kwarg, 'Asset, str', True)),
(bad_type_history_assets_kwarg, ASSET_OR_STRING_TYPE_NAMES, True)),
('history_kwarg_bad_list__assets',
(bad_type_history_assets_kwarg_list, 'Asset, str', True)),
(bad_type_history_assets_kwarg_list,
ASSET_OR_STRING_TYPE_NAMES,
True)),
('history_kwarg__fields',
(bad_type_history_fields_kwarg, 'str', True)),
(bad_type_history_fields_kwarg, STRING_TYPE_NAMES_STRING, True)),
('history_kwarg__bar_count',
(bad_type_history_bar_count_kwarg, 'int', False)),
('history_kwarg__frequency',
(bad_type_history_frequency_kwarg, 'str', False)),
(bad_type_history_frequency_kwarg, STRING_TYPE_NAMES_STRING, False)),
('current_kwarg__assets',
(bad_type_current_assets_kwarg, 'Asset, str', True)),
(bad_type_current_assets_kwarg, ASSET_OR_STRING_TYPE_NAMES, True)),
('current_kwarg__fields',
(bad_type_current_fields_kwarg, 'str', True)),
(bad_type_current_fields_kwarg, STRING_TYPE_NAMES_STRING, True)),
)

sids = 0, 1, 3, 133
Expand Down
15 changes: 9 additions & 6 deletions zipline/_protocol.pyx
Expand Up @@ -20,7 +20,7 @@ from pandas.tslib import normalize_date
import pandas as pd
import numpy as np

from six import iteritems, PY2
from six import iteritems, PY2, string_types
from cpython cimport bool
from collections import Iterable

Expand All @@ -29,7 +29,7 @@ from zipline.zipline_warnings import ZiplineDeprecationWarning


cdef bool _is_iterable(obj):
return isinstance(obj, Iterable) and not isinstance(obj, str)
return isinstance(obj, Iterable) and not isinstance(obj, string_types)


# Wraps doesn't work for method objects in python2. Docs should be generated
Expand Down Expand Up @@ -247,7 +247,8 @@ cdef class BarData:

return dt

@check_parameters(('assets', 'fields'), ((Asset, str), str))
@check_parameters(('assets', 'fields'),
((Asset,) + string_types, string_types))
def current(self, assets, fields):
"""
Returns the current value of the given assets for the given fields
Expand Down Expand Up @@ -568,8 +569,10 @@ cdef class BarData:

return not (last_traded_dt is pd.NaT)

@check_parameters(('assets', 'fields', 'bar_count', 'frequency'),
((Asset, str), str, int, str))
@check_parameters(('assets', 'fields', 'bar_count',
'frequency'),
((Asset,) + string_types, string_types, int,
string_types))
def history(self, assets, fields, bar_count, frequency):
"""
Returns a window of data for the given assets and fields.
Expand Down Expand Up @@ -615,7 +618,7 @@ cdef class BarData:
If the current simulation time is not a valid market time, we use the
last market close instead.
"""
if isinstance(fields, str):
if isinstance(fields, string_types):
single_asset = isinstance(assets, Asset)

if single_asset:
Expand Down

0 comments on commit 7a10d93

Please sign in to comment.