Skip to content

Commit

Permalink
rewriting the 'write' interface of the FeatureStorage design
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupr committed May 21, 2021
1 parent 317357b commit 3ccd353
Show file tree
Hide file tree
Showing 5 changed files with 321 additions and 169 deletions.
62 changes: 39 additions & 23 deletions qlib/data/data.py
Expand Up @@ -8,6 +8,7 @@
import os
import re
import abc
import copy
import time
import queue
import bisect
Expand All @@ -31,19 +32,27 @@
class ProviderBackendMixin:
    """Mixin that resolves a storage-backend config and instantiates it on demand.

    Reconciles diff-merge artifacts: removes the duplicate ``provider_name``
    assignment, the stale ``@property backend_obj`` (the new version takes
    kwargs), and the stale uri ``setdefault`` in ``get_default_backend`` which
    would have defeated the per-freq ``uri_map`` lookup in ``backend_obj``.
    """

    def get_default_backend(self):
        """Build the default file-based backend config for this provider.

        The provider kind (Calendar/Instrument/Feature) is derived from the
        class name, e.g. ``LocalCalendarProvider`` -> ``Calendar``.
        """
        backend = {}
        provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]
        # set default storage class
        backend.setdefault("class", f"File{provider_name}Storage")
        # set default storage module
        backend.setdefault("module_path", "qlib.data.storage.file_storage")
        # NOTE: "uri" is intentionally NOT set here; backend_obj resolves it
        # per-freq so that one provider can serve multiple frequencies.
        return backend

    def backend_obj(self, **kwargs):
        """Instantiate the backend storage object.

        kwargs (e.g. freq/market/instrument/field) are forwarded to the
        storage constructor; a deepcopy keeps ``self.backend`` untouched.
        """
        backend = self.backend if self.backend else self.get_default_backend()
        backend = copy.deepcopy(backend)

        backend_kwargs = backend.setdefault("kwargs", {})
        if "uri" not in backend_kwargs:
            # if the user has not configured a uri, use: uri = uri_map[freq]
            freq = kwargs.get("freq", "day")
            uri_map = backend_kwargs.setdefault("uri_map", {freq: C.get_data_path()})
            backend_kwargs["uri"] = uri_map[freq]
        backend_kwargs.update(**kwargs)
        return init_instance_by_config(backend)


class CalendarProvider(abc.ABC, ProviderBackendMixin):
Expand All @@ -54,8 +63,6 @@ class CalendarProvider(abc.ABC, ProviderBackendMixin):

def __init__(self, *args, **kwargs):
    """Store the user-supplied backend config; an empty dict means "use defaults".

    The eager ``get_default_backend()`` fallback was removed: default
    resolution is deferred to ``backend_obj`` so the uri can vary per freq.
    """
    self.backend = kwargs.get("backend", {})

@abc.abstractmethod
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
Expand Down Expand Up @@ -159,8 +166,6 @@ class InstrumentProvider(abc.ABC, ProviderBackendMixin):

def __init__(self, *args, **kwargs):
    """Store the user-supplied backend config; an empty dict means "use defaults".

    Default resolution is deferred to ``backend_obj`` (the stale eager
    fallback duplicated that logic and froze the uri at init time).
    """
    self.backend = kwargs.get("backend", {})

@staticmethod
def instruments(market="all", filter_pipe=None):
Expand Down Expand Up @@ -252,8 +257,6 @@ class FeatureProvider(abc.ABC, ProviderBackendMixin):

def __init__(self, *args, **kwargs):
    """Store the user-supplied backend config; an empty dict means "use defaults".

    Default resolution is deferred to ``backend_obj`` (the stale eager
    fallback duplicated that logic and froze the uri at init time).
    """
    self.backend = kwargs.get("backend", {})

@abc.abstractmethod
def feature(self, instrument, field, start_time, end_time, freq):
Expand Down Expand Up @@ -552,8 +555,18 @@ def load_calendar(self, freq, future):
list
list of timestamps
"""
self.backend.setdefault("kwargs", {}).update(freq=freq, future=future)
return [pd.Timestamp(x) for x in self.backend_obj.data]

backend_obj = self.backend_obj(freq=freq, future=future)
if future and not backend_obj.check_exists():
get_module_logger("data").warning(
f"load calendar error: freq={freq}, future={future}; return current calendar!"
)
get_module_logger("data").warning(
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
)
backend_obj = self.backend_obj(freq=freq, future=False)

return [pd.Timestamp(x) for x in backend_obj.data]

def calendar(self, start_time=None, end_time=None, freq="day", future=False):
_calendar, _calendar_index = self._get_calendar(freq, future)
Expand Down Expand Up @@ -589,17 +602,15 @@ def _uri_inst(self):
"""Instrument file uri."""
return os.path.join(C.get_data_path(), "instruments", "{}.txt")

def _load_instruments(self, market, freq="day"):
    """Load the instrument dict for ``market`` at ``freq`` from the storage backend.

    ``freq`` defaults to "day" so callers of the old single-argument
    signature keep working (backward compatible with the pre-refactor API).
    """
    return self.backend_obj(market=market, freq=freq).data

def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
market = instruments["market"]
if market in H["i"]:
_instruments = H["i"][market]
else:
_instruments = self._load_instruments(market)
_instruments = self._load_instruments(market, freq=freq)
H["i"][market] = _instruments
# strip
# use calendar boundary
Expand Down Expand Up @@ -648,9 +659,14 @@ def feature(self, instrument, field, start_index, end_index, freq):
# validate
field = str(field).lower()[1:]
instrument = code_to_fname(instrument)

self.backend.setdefault("kwargs", {}).update(instrument=instrument, field=field, freq=freq)
return self.backend_obj[start_index : end_index + 1]
try:
data = self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
except Exception as e:
get_module_logger("data").warning(
f"WARN: data not found for {instrument}.{field}\n\tException info: {str(e)}"
)
data = pd.Series(dtype=np.float32)
return data


class LocalExpressionProvider(ExpressionProvider):
Expand Down
122 changes: 66 additions & 56 deletions qlib/data/storage/file_storage.py
Expand Up @@ -3,25 +3,36 @@

import struct
from pathlib import Path
from typing import Iterator, Iterable, Union, Dict, Mapping, Tuple
from typing import Iterable, Union, Dict, Mapping, Tuple, List

import numpy as np
import pandas as pd

from qlib.log import get_module_logger
from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT

logger = get_module_logger("file_storage")

# NOTE: a truncated pre-refactor duplicate of the FileCalendarStorage header
# (old base class, old __init__) was left here by the diff merge; it is
# superseded by the FileStorage-based definition below and has been removed.
class FileStorage:
    """Shared base for file-backed storages.

    Subclasses set ``self.uri`` to a ``pathlib.Path`` pointing at the backing
    file; this base only provides the existence check.
    """

    def check_exists(self) -> bool:
        """Return True when the backing file at ``self.uri`` exists."""
        return self.uri.exists()


class FileCalendarStorage(FileStorage, CalendarStorage):
    def __init__(self, freq: str, future: bool, uri: str, **kwargs):
        """Calendar storage backed by ``<uri>/calendars/<freq>[_future].txt``.

        The dead duplicate ``self.uri`` assignment (old layout without the
        "calendars" subdirectory) left by the diff merge has been removed.
        """
        super(FileCalendarStorage, self).__init__(freq, future, uri, **kwargs)
        # future calendars live in a separate "<freq>_future.txt" file
        _file_name = f"{freq}_future.txt" if future else f"{freq}.txt"
        self.uri = Path(self.uri).expanduser().joinpath("calendars", _file_name.lower())

def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> np.ndarray:
if not self.uri.exists():
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> Iterable[CalVT]:
if not self.check_exists():
self._write_calendar(values=[])
with self.uri.open("rb") as fp:
return np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, encoding="utf-8")
return [
str(x)
for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, delimiter="\n", encoding="utf-8")
]

def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"):
with self.uri.open(mode=mode) as fp:
Expand Down Expand Up @@ -65,23 +76,17 @@ def __delitem__(self, i: Union[int, slice]) -> None:
def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, Iterable[CalVT]]:
    """Index or slice the calendar; the file is re-read on every access."""
    calendar = self._read_calendar()
    return calendar[i]

def __len__(self) -> int:
    """Number of calendar entries currently on disk."""
    entries = self._read_calendar()
    return len(entries)

def __iter__(self):
    """Iterate over the on-disk calendar entries."""
    yield from self._read_calendar()


class FileInstrumentStorage(FileStorage, InstrumentStorage):
    """Instrument storage backed by ``<uri>/instruments/<market>.txt``.

    Reconciled: the diff merge left the pre-refactor class header and
    ``__init__`` (old base class, old path without the "instruments"
    subdirectory) interleaved with the new ones; the dead duplicates
    have been removed.
    """

    # column separator and field names of the on-disk instrument table
    INSTRUMENT_SEP = "\t"
    INSTRUMENT_START_FIELD = "start_datetime"
    INSTRUMENT_END_FIELD = "end_datetime"
    SYMBOL_FIELD_NAME = "instrument"

    def __init__(self, market: str, uri: str, **kwargs):
        super(FileInstrumentStorage, self).__init__(market, uri, **kwargs)
        self.uri = Path(self.uri).expanduser().joinpath("instruments", f"{market.lower()}.txt")

def _read_instrument(self) -> Dict[InstKT, InstVT]:
if not self.uri.exists():
Expand Down Expand Up @@ -138,14 +143,6 @@ def __delitem__(self, k: InstKT) -> None:
def __getitem__(self, k: InstKT) -> InstVT:
    """Look up one instrument's value by symbol (raises KeyError when absent)."""
    instruments = self._read_instrument()
    return instruments[k]

def __len__(self) -> int:
    """Number of instruments stored on disk."""
    return len(self._read_instrument())

def __iter__(self) -> Iterator[InstKT]:
    """Yield instrument symbols (the keys of the on-disk mapping)."""
    yield from self._read_instrument().keys()

def update(self, *args, **kwargs) -> None:

if len(args) > 1:
Expand All @@ -168,11 +165,11 @@ def update(self, *args, **kwargs) -> None:
self._write_instrument(inst)


class FileFeatureStorage(FileStorage, FeatureStorage):
    """Feature storage: ``<uri>/features/<instrument>/<field>.<freq>.bin``.

    Binary layout (see ``write``): one little-endian float32 holding the
    start index, followed by the float32 values.

    Reconciled: the pre-refactor class header and ``__init__`` (old base
    class, old path without the "features" subdirectory) left interleaved
    by the diff merge have been removed.
    """

    def __init__(self, instrument: str, field: str, freq: str, uri: str, **kwargs):
        super(FileFeatureStorage, self).__init__(instrument, field, freq, uri, **kwargs)
        self.uri = (
            Path(self.uri).expanduser().joinpath("features", instrument.lower(), f"{field.lower()}.{freq.lower()}.bin")
        )

def clear(self):
Expand All @@ -183,18 +180,45 @@ def clear(self):
def data(self) -> pd.Series:
    """Return the whole stored feature as a Series (full-range read)."""
    return self[slice(None)]

def extend(self, series: pd.Series) -> None:
    """Append *series* to the end of the stored feature, NaN-filling any gap.

    NOTE(review): assumes ``series.index`` is an integer index aligned with
    the storage's index space and that ``series`` ends at or after the
    current end of storage — confirm against callers.
    """
    # Next free slot: stored start index (self[0][0] reads the header) plus
    # the current value count; for a fresh file, start where the series starts.
    extend_start_index = self[0][0] + len(self) if self.uri.exists() else series.index[0]
    # Reindex over the contiguous range so any gap between the current end
    # and the incoming data is padded with NaN.
    series = series.reindex(pd.RangeIndex(extend_start_index, series.index[-1] + 1))
    with self.uri.open("ab") as fp:
        # values are stored as little-endian float32, matching the file format
        np.array(series.values).astype("<f").tofile(fp)
def write(self, data_array: Union[List, np.ndarray], index: int = None) -> None:
    """Write ``data_array`` into storage starting at ``index``.

    File layout: one little-endian float32 header holding the start index,
    followed by the float32 values. ``index=None`` means "append at the end".

    Fixes over the previous revision:
      - the rewrite branch now writes the start-index header back; it used
        to seek(0) and dump only the values, overwriting the header slot
        with the first data value and corrupting the file;
      - the stored header is cast to ``int`` before being used as a
        ``range`` bound (``numpy.float32`` has no ``__index__`` and raises
        TypeError in ``range``);
      - the empty-input log message had no separator between its two parts
        ("...writeif you need...").
    """
    if len(data_array) == 0:
        logger.info(
            "len(data_array) == 0, write is ignored; "
            "if you need to clear the FeatureStorage, please execute: FeatureStorage.clear"
        )
        return
    if not self.uri.exists():
        # fresh file: header (start index) followed by the values
        index = 0 if index is None else index
        with self.uri.open("wb") as fp:
            np.hstack([index, data_array]).astype("<f").tofile(fp)
    elif index is None or index > self.end_index:
        # append, NaN-padding any gap after the current end
        index = 0 if index is None else index
        with self.uri.open("ab+") as fp:
            np.hstack([[np.nan] * (index - self.end_index - 1), data_array]).astype("<f").tofile(fp)
    else:
        # rewrite: overlay the new values on the stored ones, then dump
        with self.uri.open("rb+") as fp:
            _old_data = np.fromfile(fp, dtype="<f")
            _old_index = int(_old_data[0])  # header is persisted as float32
            _old_df = pd.DataFrame(
                _old_data[1:], index=range(_old_index, _old_index + len(_old_data) - 1), columns=["old"]
            )
            _new_df = pd.DataFrame(data_array, index=range(index, index + len(data_array)), columns=["new"])
            _df = pd.concat([_old_df, _new_df], sort=False, axis=1)
            _df = _df.reindex(range(_df.index.min(), _df.index.max() + 1))
            fp.seek(0)
            # header first, then the merged values (new wins, old fills gaps)
            np.hstack([_df.index.min(), _df["new"].fillna(_df["old"]).values]).astype("<f").tofile(fp)

def rebase(self, series: pd.Series) -> None:
    """Overwrite storage from the start of ``series``, keeping stored values
    that lie beyond its end.

    Fixes:
      - the reindex upper bound is now inclusive (``+ 1``): ``pd.RangeIndex``
        excludes its stop, so the previous code silently dropped the final
        value (sibling ``extend`` already uses ``series.index[-1] + 1``);
      - the start-index header is written back — ``__iter__``/``__len__``/
        ``extend`` all treat the first float32 as the start index, but the
        old code dumped only the values, shifting the whole file.
    """
    origin_series = self[:]
    # keep existing data located after the end of the incoming series
    series = series.append(origin_series.loc[origin_series.index > series.index[-1]])
    series = series.reindex(pd.RangeIndex(series.index[0], series.index[-1] + 1))
    with self.uri.open("wb") as fp:
        np.hstack([series.index[0], series.values]).astype("<f").tofile(fp)
@property
def start_index(self) -> Union[int, None]:
    """Index of the first stored value, or None when the storage is empty.

    The start index is persisted as the leading little-endian float32.
    """
    if not len(self):
        return None
    with self.uri.open("rb") as fp:
        raw = fp.read(4)
    return int(np.frombuffer(raw, dtype="<f")[0])

def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Series]:
if not self.uri.exists():
Expand Down Expand Up @@ -228,18 +252,4 @@ def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Serie
raise TypeError(f"type(i) = {type(i)}")

def __len__(self) -> int:
    """Stored value count: total float32 slots minus the start-index header."""
    if not self.uri.exists():
        return 0
    return self.uri.stat().st_size // 4 - 1

def __iter__(self):
    """Yield ``(index, value)`` pairs for every stored float.

    The first 4 bytes hold the start index (persisted as float32); each
    following 4-byte chunk is one value, yielded with consecutive indices.
    """
    if not self.uri.exists():
        return
    with open(self.uri, "rb") as fp:
        # header: start index stored as little-endian float32
        ref_start_index = int(np.frombuffer(fp.read(4), dtype="<f")[0])
        fp.seek(4)
        while True:
            v = fp.read(4)
            if v:
                # one float32 value per 4-byte chunk
                yield ref_start_index, struct.unpack("f", v)[0]
                ref_start_index += 1
            else:
                break
return self.uri.stat().st_size // 4 - 1 if self.check_exists() else 0

0 comments on commit 3ccd353

Please sign in to comment.