Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add data storage #372

Merged
merged 12 commits into from May 26, 2021
94 changes: 50 additions & 44 deletions qlib/data/data.py
Expand Up @@ -6,6 +6,7 @@
from __future__ import print_function

import os
import re
import abc
import time
import queue
Expand All @@ -27,12 +28,35 @@
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path


class CalendarProvider(abc.ABC):
class ProviderBackendMixin:
def get_default_backend(self):
backend = {}
provider_name = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2] # type: str
you-n-g marked this conversation as resolved.
Show resolved Hide resolved
# set default storage class
backend.setdefault("class", f"File{provider_name}Storage")
# set default storage module
backend.setdefault("module_path", "qlib.data.storage.file_storage")
# set default storage kwargs
backend_kwargs = backend.setdefault("kwargs", {}) # type: dict
backend_kwargs.setdefault("uri", os.path.join(C.get_data_path(), f"{provider_name.lower()}s"))
return backend

@property
def backend_obj(self):
return init_instance_by_config(self.backend)


class CalendarProvider(abc.ABC, ProviderBackendMixin):
"""Calendar provider base class

Provide calendar data.
"""

def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
you-n-g marked this conversation as resolved.
Show resolved Hide resolved
if not self.backend:
self.backend = self.get_default_backend()

@abc.abstractmethod
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
"""Get calendar of certain market in given time range.
Expand Down Expand Up @@ -127,12 +151,17 @@ def _uri(self, start_time, end_time, freq, future=False):
return hash_args(start_time, end_time, freq, future)


class InstrumentProvider(abc.ABC):
class InstrumentProvider(abc.ABC, ProviderBackendMixin):
"""Instrument provider base class

Provide instrument data.
"""

def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
if not self.backend:
self.backend = self.get_default_backend()

@staticmethod
def instruments(market="all", filter_pipe=None):
"""Get the general config dictionary for a base market adding several dynamic filters.
Expand Down Expand Up @@ -215,12 +244,17 @@ def get_inst_type(cls, inst):
raise ValueError(f"Unknown instrument type {inst}")


class FeatureProvider(abc.ABC):
class FeatureProvider(abc.ABC, ProviderBackendMixin):
"""Feature provider class

Provide feature data.
"""

def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
if not self.backend:
self.backend = self.get_default_backend()

@abc.abstractmethod
def feature(self, instrument, field, start_time, end_time, freq):
"""Get feature data.
Expand Down Expand Up @@ -497,6 +531,7 @@ class LocalCalendarProvider(CalendarProvider):
"""

def __init__(self, **kwargs):
super(LocalCalendarProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)

@property
Expand All @@ -517,18 +552,8 @@ def load_calendar(self, freq, future):
list
list of timestamps
"""
if future:
fname = self._uri_cal.format(freq + "_future")
# if future calendar not exists, return current calendar
if not os.path.exists(fname):
get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!")
fname = self._uri_cal.format(freq)
else:
fname = self._uri_cal.format(freq)
if not os.path.exists(fname):
raise ValueError("calendar not exists for freq " + freq)
with open(fname) as f:
return [pd.Timestamp(x.strip()) for x in f]
self.backend.setdefault("kwargs", {}).update(freq=freq, future=future)
return [pd.Timestamp(x) for x in self.backend_obj.data]
you-n-g marked this conversation as resolved.
Show resolved Hide resolved

def calendar(self, start_time=None, end_time=None, freq="day", future=False):
_calendar, _calendar_index = self._get_calendar(freq, future)
Expand Down Expand Up @@ -559,31 +584,15 @@ class LocalInstrumentProvider(InstrumentProvider):
Provide instrument data from local data source.
"""

def __init__(self):
pass

@property
def _uri_inst(self):
"""Instrument file uri."""
return os.path.join(C.get_data_path(), "instruments", "{}.txt")

def _load_instruments(self, market):
fname = self._uri_inst.format(market)
if not os.path.exists(fname):
raise ValueError("instruments not exists for market " + market)

_instruments = dict()
df = pd.read_csv(
fname,
sep="\t",
usecols=[0, 1, 2],
names=["inst", "start_datetime", "end_datetime"],
dtype={"inst": str},
parse_dates=["start_datetime", "end_datetime"],
)
for row in df.itertuples(index=False):
_instruments.setdefault(row[0], []).append((row[1], row[2]))
return _instruments

self.backend.setdefault("kwargs", {}).update(market=market)
you-n-g marked this conversation as resolved.
Show resolved Hide resolved
return self.backend_obj.data

def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
market = instruments["market"]
Expand All @@ -601,7 +610,7 @@ def list_instruments(self, instruments, start_time=None, end_time=None, freq="da
inst: list(
filter(
lambda x: x[0] <= x[1],
[(max(start_time, x[0]), min(end_time, x[1])) for x in spans],
[(max(start_time, pd.Timestamp(x[0])), min(end_time, pd.Timestamp(x[1]))) for x in spans],
)
)
for inst, spans in _instruments.items()
Expand All @@ -627,6 +636,7 @@ class LocalFeatureProvider(FeatureProvider):
"""

def __init__(self, **kwargs):
super(LocalFeatureProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)

@property
Expand All @@ -638,14 +648,9 @@ def feature(self, instrument, field, start_index, end_index, freq):
# validate
field = str(field).lower()[1:]
instrument = code_to_fname(instrument)
uri_data = self._uri_data.format(instrument.lower(), field, freq)
if not os.path.exists(uri_data):
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))
return pd.Series(dtype=np.float32)
# raise ValueError('uri_data not found: ' + uri_data)
# load
series = read_bin(uri_data, start_index, end_index)
return series

self.backend.setdefault("kwargs", {}).update(instrument=instrument, field=field, freq=freq)
you-n-g marked this conversation as resolved.
Show resolved Hide resolved
return self.backend_obj[start_index : end_index + 1]


class LocalExpressionProvider(ExpressionProvider):
Expand Down Expand Up @@ -1061,7 +1066,8 @@ def register_all_wrappers(C):
register_wrapper(Cal, _calendar_provider, "qlib.data")
logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}")

register_wrapper(Inst, C.instrument_provider, "qlib.data")
_instrument_provider = init_instance_by_config(C.instrument_provider, module)
register_wrapper(Inst, _instrument_provider, "qlib.data")
logger.debug(f"registering Inst {C.instrument_provider}")

if getattr(C, "feature_provider", None) is not None:
Expand Down
4 changes: 4 additions & 0 deletions qlib/data/storage/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstVT, InstKT