Skip to content

Commit

Permalink
remove uri parameter from storage && modify file_storage
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupr committed May 26, 2021
1 parent 602f78b commit 5da3356
Show file tree
Hide file tree
Showing 7 changed files with 373 additions and 201 deletions.
28 changes: 28 additions & 0 deletions docs/reference/api.rst
Expand Up @@ -53,6 +53,34 @@ Cache
.. autoclass:: qlib.data.cache.DiskDatasetCache
:members:


Storage
-------------
.. autoclass:: qlib.data.storage.storage.BaseStorage
:members:

.. autoclass:: qlib.data.storage.storage.CalendarStorage
:members:

.. autoclass:: qlib.data.storage.storage.InstrumentStorage
:members:

.. autoclass:: qlib.data.storage.storage.FeatureStorage
:members:

.. autoclass:: qlib.data.storage.file_storage.FileStorageMixin
:members:

.. autoclass:: qlib.data.storage.file_storage.FileCalendarStorage
:members:

.. autoclass:: qlib.data.storage.file_storage.FileInstrumentStorage
:members:

.. autoclass:: qlib.data.storage.file_storage.FileFeatureStorage
:members:


Dataset
---------------

Expand Down
41 changes: 19 additions & 22 deletions qlib/data/data.py
Expand Up @@ -45,12 +45,12 @@ def backend_obj(self, **kwargs):

# set default storage kwargs
backend_kwargs = backend.setdefault("kwargs", {})
# default uri map
if "uri" not in backend_kwargs:
# default provider_uri map
if "provider_uri" not in backend_kwargs:
# if the user has no uri configured, use: uri = uri_map[freq]
freq = kwargs.get("freq", "day")
uri_map = backend_kwargs.setdefault("uri_map", {freq: C.get_data_path()})
backend_kwargs["uri"] = uri_map[freq]
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {freq: C.get_data_path()})
backend_kwargs["provider_uri"] = provider_uri_map[freq]
backend.setdefault("kwargs", {}).update(**kwargs)
return init_instance_by_config(backend)

Expand Down Expand Up @@ -556,17 +556,21 @@ def load_calendar(self, freq, future):
list of timestamps
"""

backend_obj = self.backend_obj(freq=freq, future=future)
if future and not backend_obj.check_exists():
get_module_logger("data").warning(
f"load calendar error: freq={freq}, future={future}; return current calendar!"
)
get_module_logger("data").warning(
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
)
backend_obj = self.backend_obj(freq=freq, future=False)
try:
backend_obj = self.backend_obj(freq=freq, future=future).data
except ValueError:
if future:
get_module_logger("data").warning(
f"load calendar error: freq={freq}, future={future}; return current calendar!"
)
get_module_logger("data").warning(
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
)
backend_obj = self.backend_obj(freq=freq, future=False).data
else:
raise

return [pd.Timestamp(x) for x in backend_obj.data]
return [pd.Timestamp(x) for x in backend_obj]

def calendar(self, start_time=None, end_time=None, freq="day", future=False):
_calendar, _calendar_index = self._get_calendar(freq, future)
Expand Down Expand Up @@ -659,14 +663,7 @@ def feature(self, instrument, field, start_index, end_index, freq):
# validate
field = str(field).lower()[1:]
instrument = code_to_fname(instrument)
try:
data = self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
except Exception as e:
get_module_logger("data").warning(
f"WARN: data not found for {instrument}.{field}\n\tFeature exception info: {str(e)}"
)
data = pd.Series(dtype=np.float32)
return data
return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]


class LocalExpressionProvider(ExpressionProvider):
Expand Down
110 changes: 74 additions & 36 deletions qlib/data/storage/file_storage.py
Expand Up @@ -14,19 +14,35 @@
logger = get_module_logger("file_storage")


class FileStorage:
def check_exists(self):
return self.uri.exists()
class FileStorageMixin:
@property
def uri(self) -> Path:
_provider_uri = self.kwargs.get("provider_uri", None)
if _provider_uri is None:
raise ValueError(
f"The `provider_uri` parameter is not found in {self.__class__.__name__}, "
f'please specify `provider_uri` in the "provider\'s backend"'
)
return Path(_provider_uri).expanduser().joinpath(f"{self.storage_name}s", self.file_name)

def check(self):
"""check self.uri
Raises
-------
ValueError
"""
if not self.uri.exists():
raise ValueError(f"{self.storage_name} not exists: {self.uri}")


class FileCalendarStorage(FileStorage, CalendarStorage):
def __init__(self, freq: str, future: bool, uri: str, **kwargs):
super(FileCalendarStorage, self).__init__(freq, future, uri, **kwargs)
_file_name = f"{freq}_future.txt" if future else f"{freq}.txt"
self.uri = Path(self.uri).expanduser().joinpath("calendars", _file_name.lower())
class FileCalendarStorage(FileStorageMixin, CalendarStorage):
def __init__(self, freq: str, future: bool, **kwargs):
super(FileCalendarStorage, self).__init__(freq, future, **kwargs)
self.file_name = f"{freq}_future.txt" if future else f"{freq}.txt".lower()

def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> Iterable[CalVT]:
if not self.check_exists():
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]:
if not self.uri.exists():
self._write_calendar(values=[])
with self.uri.open("rb") as fp:
return [
Expand All @@ -39,7 +55,8 @@ def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"):
np.savetxt(fp, values, fmt="%s", encoding="utf-8")

@property
def data(self) -> Iterable[CalVT]:
def data(self) -> List[CalVT]:
self.check()
return self._read_calendar()

def extend(self, values: Iterable[CalVT]) -> None:
Expand All @@ -49,6 +66,7 @@ def clear(self) -> None:
self._write_calendar(values=[])

def index(self, value: CalVT) -> int:
self.check()
calendar = self._read_calendar()
return int(np.argwhere(calendar == value)[0])

Expand All @@ -58,6 +76,7 @@ def insert(self, index: int, value: CalVT):
self._write_calendar(values=calendar)

def remove(self, value: CalVT) -> None:
self.check()
index = self.index(value)
calendar = self._read_calendar()
calendar = np.delete(calendar, index)
Expand All @@ -69,24 +88,29 @@ def __setitem__(self, i: Union[int, slice], values: Union[CalVT, Iterable[CalVT]
self._write_calendar(values=calendar)

def __delitem__(self, i: Union[int, slice]) -> None:
self.check()
calendar = self._read_calendar()
calendar = np.delete(calendar, i)
self._write_calendar(values=calendar)

def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, Iterable[CalVT]]:
def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, List[CalVT]]:
self.check()
return self._read_calendar()[i]

def __len__(self) -> int:
return len(self.data)


class FileInstrumentStorage(FileStorage, InstrumentStorage):
class FileInstrumentStorage(FileStorageMixin, InstrumentStorage):

INSTRUMENT_SEP = "\t"
INSTRUMENT_START_FIELD = "start_datetime"
INSTRUMENT_END_FIELD = "end_datetime"
SYMBOL_FIELD_NAME = "instrument"

def __init__(self, market: str, uri: str, **kwargs):
super(FileInstrumentStorage, self).__init__(market, uri, **kwargs)
self.uri = Path(self.uri).expanduser().joinpath("instruments", f"{market.lower()}.txt")
def __init__(self, market: str, **kwargs):
super(FileInstrumentStorage, self).__init__(market, **kwargs)
self.file_name = f"{market.lower()}.txt"

def _read_instrument(self) -> Dict[InstKT, InstVT]:
if not self.uri.exists():
Expand Down Expand Up @@ -128,6 +152,7 @@ def clear(self) -> None:

@property
def data(self) -> Dict[InstKT, InstVT]:
self.check()
return self._read_instrument()

def __setitem__(self, k: InstKT, v: InstVT) -> None:
Expand All @@ -136,11 +161,13 @@ def __setitem__(self, k: InstKT, v: InstVT) -> None:
self._write_instrument(inst)

def __delitem__(self, k: InstKT) -> None:
self.check()
inst = self._read_instrument()
del inst[k]
self._write_instrument(inst)

def __getitem__(self, k: InstKT) -> InstVT:
self.check()
return self._read_instrument()[k]

def update(self, *args, **kwargs) -> None:
Expand All @@ -164,13 +191,14 @@ def update(self, *args, **kwargs) -> None:

self._write_instrument(inst)

def __len__(self) -> int:
return len(self.data)

class FileFeatureStorage(FileStorage, FeatureStorage):
def __init__(self, instrument: str, field: str, freq: str, uri: str, **kwargs):
super(FileFeatureStorage, self).__init__(instrument, field, freq, uri, **kwargs)
self.uri = (
Path(self.uri).expanduser().joinpath("features", instrument.lower(), f"{field.lower()}.{freq.lower()}.bin")
)

class FileFeatureStorage(FileStorageMixin, FeatureStorage):
def __init__(self, instrument: str, field: str, freq: str, **kwargs):
super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs)
self.file_name = f"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin"

def clear(self):
with self.uri.open("wb") as _:
Expand Down Expand Up @@ -214,35 +242,44 @@ def write(self, data_array: Union[List, np.ndarray], index: int = None) -> None:

@property
def start_index(self) -> Union[int, None]:
if len(self) == 0:
if not self.uri.exists():
return None
with open(self.uri, "rb") as fp:
with self.uri.open("rb") as fp:
index = int(np.frombuffer(fp.read(4), dtype="<f")[0])
return index

@property
def end_index(self) -> Union[int, None]:
if not self.uri.exists():
return None
# The next data appending index point will be `end_index + 1`
return self.start_index + len(self) - 1

def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Series]:
if not self.uri.exists():
if isinstance(i, int):
return None, None
elif isinstance(i, slice):
return pd.Series()
return pd.Series(dtype=np.float32)
else:
raise TypeError(f"type(i) = {type(i)}")

with open(self.uri, "rb") as fp:

storage_start_index = self.start_index
storage_end_index = self.end_index
with self.uri.open("rb") as fp:
if isinstance(i, int):
if self.start_index > i:
raise IndexError(f"{i}: start index is {self.start_index}")
fp.seek(4 * (i - self.start_index) + 4)

if storage_start_index > i:
raise IndexError(f"{i}: start index is {storage_start_index}")
fp.seek(4 * (i - storage_start_index) + 4)
return i, struct.unpack("f", fp.read(4))[0]
elif isinstance(i, slice):
start_index = self.start_index if i.start is None else i.start
end_index = self.end_index if i.stop is None else i.stop - 1
si = max(self.start_index, start_index)
start_index = storage_start_index if i.start is None else i.start
end_index = storage_end_index if i.stop is None else i.stop - 1
si = max(start_index, storage_start_index)
if si > end_index:
return pd.Series()
fp.seek(4 * (si - self.start_index) + 4)
return pd.Series(dtype=np.float32)
fp.seek(4 * (si - storage_start_index) + 4)
# read n bytes
count = end_index - si + 1
data = np.frombuffer(fp.read(4 * count), dtype="<f")
Expand All @@ -251,4 +288,5 @@ def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Serie
raise TypeError(f"type(i) = {type(i)}")

def __len__(self) -> int:
return self.uri.stat().st_size // 4 - 1 if self.check_exists() else 0
self.check()
return self.uri.stat().st_size // 4 - 1

0 comments on commit 5da3356

Please sign in to comment.