Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add data storage #372

Merged
merged 12 commits into from May 26, 2021
2 changes: 1 addition & 1 deletion qlib/contrib/backtest/position.py
Expand Up @@ -166,7 +166,7 @@ def update_weight_all(self):
def save_position(self, path, last_trade_date):
path = pathlib.Path(path)
p = copy.deepcopy(self.position)
cash = pd.Series(dtype=np.float)
cash = pd.Series(dtype=float)
cash["init_cash"] = self.init_cash
cash["cash"] = p["cash"]
cash["today_account_value"] = p["today_account_value"]
Expand Down
117 changes: 68 additions & 49 deletions qlib/data/data.py
Expand Up @@ -6,7 +6,9 @@
from __future__ import print_function

import os
import re
import abc
import copy
import time
import queue
import bisect
Expand All @@ -27,12 +29,41 @@
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path


class CalendarProvider(abc.ABC):
class ProviderBackendMixin:
    """Mixin that resolves a provider's storage backend config and instantiates it.

    The host class is expected to expose a ``backend`` attribute holding a
    backend config dict; a falsy value (e.g. ``{}``) selects the default backend.
    """

    def get_default_backend(self):
        """Return the default backend config derived from the class name.

        The class name is split into capitalised words and the second-to-last
        word names the storage kind, e.g. ``LocalCalendarProvider`` maps to
        the class ``FileCalendarStorage``.

        Returns
        -------
        dict
            config with the keys ``class`` and ``module_path``.
        """
        name_words = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)
        provider_name: str = name_words[-2]
        return {
            # default storage class
            "class": f"File{provider_name}Storage",
            # default storage module
            "module_path": "qlib.data.storage.file_storage",
        }

    def backend_obj(self, **kwargs):
        """Instantiate the storage backend for this provider.

        Parameters
        ----------
        **kwargs
            merged into the backend's ``kwargs`` before instantiation;
            ``freq`` (default ``"day"``) also selects the entry of ``uri_map``
            used as ``uri`` when no ``uri`` is configured.

        Returns
        -------
        the object built by ``init_instance_by_config`` from the resolved config.
        """
        # work on a copy so the configured backend is never mutated
        config = copy.deepcopy(self.backend or self.get_default_backend())

        storage_kwargs = config.setdefault("kwargs", {})
        if "uri" not in storage_kwargs:
            # no uri configured by the user: fall back to uri_map[freq]
            freq = kwargs.get("freq", "day")
            fallback_map = {freq: C.get_data_path()}
            uri_map = storage_kwargs.setdefault("uri_map", fallback_map)
            storage_kwargs["uri"] = uri_map[freq]

        # caller-supplied kwargs take precedence over configured ones
        storage_kwargs.update(**kwargs)
        return init_instance_by_config(config)


class CalendarProvider(abc.ABC, ProviderBackendMixin):
"""Calendar provider base class

Provide calendar data.
"""

def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
you-n-g marked this conversation as resolved.
Show resolved Hide resolved

@abc.abstractmethod
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
"""Get calendar of certain market in given time range.
Expand Down Expand Up @@ -127,12 +158,15 @@ def _uri(self, start_time, end_time, freq, future=False):
return hash_args(start_time, end_time, freq, future)


class InstrumentProvider(abc.ABC):
class InstrumentProvider(abc.ABC, ProviderBackendMixin):
"""Instrument provider base class

Provide instrument data.
"""

def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})

@staticmethod
def instruments(market="all", filter_pipe=None):
"""Get the general config dictionary for a base market adding several dynamic filters.
Expand Down Expand Up @@ -215,12 +249,15 @@ def get_inst_type(cls, inst):
raise ValueError(f"Unknown instrument type {inst}")


class FeatureProvider(abc.ABC):
class FeatureProvider(abc.ABC, ProviderBackendMixin):
"""Feature provider class

Provide feature data.
"""

def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})

@abc.abstractmethod
def feature(self, instrument, field, start_time, end_time, freq):
"""Get feature data.
Expand Down Expand Up @@ -497,6 +534,7 @@ class LocalCalendarProvider(CalendarProvider):
"""

def __init__(self, **kwargs):
super(LocalCalendarProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)

@property
Expand All @@ -517,21 +555,18 @@ def load_calendar(self, freq, future):
list
list of timestamps
"""
if future:
fname = self._uri_cal.format(freq + "_future")
# if future calendar not exists, return current calendar
if not os.path.exists(fname):
get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!")
get_module_logger("data").warning(
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
)
fname = self._uri_cal.format(freq)
else:
fname = self._uri_cal.format(freq)
if not os.path.exists(fname):
raise ValueError("calendar not exists for freq " + freq)
with open(fname) as f:
return [pd.Timestamp(x.strip()) for x in f]

backend_obj = self.backend_obj(freq=freq, future=future)
if future and not backend_obj.check_exists():
get_module_logger("data").warning(
f"load calendar error: freq={freq}, future={future}; return current calendar!"
)
get_module_logger("data").warning(
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
)
backend_obj = self.backend_obj(freq=freq, future=False)

return [pd.Timestamp(x) for x in backend_obj.data]

def calendar(self, start_time=None, end_time=None, freq="day", future=False):
_calendar, _calendar_index = self._get_calendar(freq, future)
Expand Down Expand Up @@ -562,38 +597,20 @@ class LocalInstrumentProvider(InstrumentProvider):
Provide instrument data from local data source.
"""

def __init__(self):
pass

@property
def _uri_inst(self):
"""Instrument file uri."""
return os.path.join(C.get_data_path(), "instruments", "{}.txt")

def _load_instruments(self, market):
fname = self._uri_inst.format(market)
if not os.path.exists(fname):
raise ValueError("instruments not exists for market " + market)

_instruments = dict()
df = pd.read_csv(
fname,
sep="\t",
usecols=[0, 1, 2],
names=["inst", "start_datetime", "end_datetime"],
dtype={"inst": str},
parse_dates=["start_datetime", "end_datetime"],
)
for row in df.itertuples(index=False):
_instruments.setdefault(row[0], []).append((row[1], row[2]))
return _instruments
def _load_instruments(self, market, freq):
return self.backend_obj(market=market, freq=freq).data

def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
market = instruments["market"]
if market in H["i"]:
_instruments = H["i"][market]
else:
_instruments = self._load_instruments(market)
_instruments = self._load_instruments(market, freq=freq)
H["i"][market] = _instruments
# strip
# use calendar boundary
Expand All @@ -604,7 +621,7 @@ def list_instruments(self, instruments, start_time=None, end_time=None, freq="da
inst: list(
filter(
lambda x: x[0] <= x[1],
[(max(start_time, x[0]), min(end_time, x[1])) for x in spans],
[(max(start_time, pd.Timestamp(x[0])), min(end_time, pd.Timestamp(x[1]))) for x in spans],
)
)
for inst, spans in _instruments.items()
Expand All @@ -630,6 +647,7 @@ class LocalFeatureProvider(FeatureProvider):
"""

def __init__(self, **kwargs):
super(LocalFeatureProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)

@property
Expand All @@ -641,14 +659,14 @@ def feature(self, instrument, field, start_index, end_index, freq):
# validate
field = str(field).lower()[1:]
instrument = code_to_fname(instrument)
uri_data = self._uri_data.format(instrument.lower(), field, freq)
if not os.path.exists(uri_data):
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))
return pd.Series(dtype=np.float32)
# raise ValueError('uri_data not found: ' + uri_data)
# load
series = read_bin(uri_data, start_index, end_index)
return series
try:
data = self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
except Exception as e:
you-n-g marked this conversation as resolved.
Show resolved Hide resolved
get_module_logger("data").warning(
f"WARN: data not found for {instrument}.{field}\n\tFeature exception info: {str(e)}"
)
data = pd.Series(dtype=np.float32)
return data


class LocalExpressionProvider(ExpressionProvider):
Expand Down Expand Up @@ -1065,7 +1083,8 @@ def register_all_wrappers(C):
register_wrapper(Cal, _calendar_provider, "qlib.data")
logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}")

register_wrapper(Inst, C.instrument_provider, "qlib.data")
_instrument_provider = init_instance_by_config(C.instrument_provider, module)
register_wrapper(Inst, _instrument_provider, "qlib.data")
logger.debug(f"registering Inst {C.instrument_provider}")

if getattr(C, "feature_provider", None) is not None:
Expand Down
2 changes: 1 addition & 1 deletion qlib/data/dataset/__init__.py
Expand Up @@ -357,7 +357,7 @@ def build_index(data: pd.DataFrame) -> dict:
# get the previous index of a line given index
"""
# use object dtype in case pandas converts int to float
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=np.object)
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
idx_df = lazy_sort_index(idx_df.unstack())
# NOTE: the correctness of `__getitem__` depends on columns sorted here
idx_df = lazy_sort_index(idx_df, axis=1)
Expand Down
4 changes: 4 additions & 0 deletions qlib/data/storage/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstVT, InstKT