US stock code supports Windows
zhupr committed Dec 20, 2020
1 parent 7fc88b3 commit 7c16ef1
Showing 4 changed files with 84 additions and 76 deletions.
42 changes: 15 additions & 27 deletions qlib/data/data.py
@@ -15,14 +15,13 @@
import traceback
import numpy as np
import pandas as pd
from pathlib import Path
from multiprocessing import Pool

from .cache import H
from ..config import C
from .ops import *
from ..log import get_module_logger
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname
from .base import Feature
from .cache import DiskDatasetCache, DiskExpressionCache
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
@@ -215,23 +214,6 @@ def get_inst_type(cls, inst):
return cls.LIST
raise ValueError(f"Unknown instrument type {inst}")

def convert_instruments(self, instrument):
_instruments_map = getattr(self, "_instruments_map", None)
if _instruments_map is None:
_df_list = []
# FIXME: each process will read these files
for _path in Path(C.get_data_path()).joinpath("instruments").glob("*.txt"):
_df = pd.read_csv(_path, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
_df_list.append(_df.iloc[:, [0, -1]])
df = pd.concat(_df_list, sort=False)
df["inst"] = df["inst"].astype(str)
df = df.fillna(axis=1, method="ffill")
df = df.sort_values("inst").drop_duplicates(subset=["inst"], keep="first")
df["save_inst"] = df["save_inst"].astype(str)
_instruments_map = df.set_index("inst").iloc[:, 0].to_dict()
setattr(self, "_instruments_map", _instruments_map)
return _instruments_map.get(instrument, instrument)


class FeatureProvider(abc.ABC):
"""Feature provider class
@@ -591,13 +573,19 @@ def _load_instruments(self, market):
if not os.path.exists(fname):
raise ValueError("instruments not exists for market " + market)
_instruments = dict()
df = pd.read_csv(fname, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
df["start_datetime"] = pd.to_datetime(df["start_datetime"])
df["end_datetime"] = pd.to_datetime(df["end_datetime"])
df["inst"] = df["inst"].astype(str)
df["save_inst"] = df.loc[:, ["inst", "save_inst"]].fillna(axis=1, method="ffill")["save_inst"].astype(str)
for row in df.itertuples(index=False):
_instruments.setdefault(row[0], []).append((row[1], row[2]))
with open(fname) as f:
    for line in f:
        inst_time = line.strip().split()
        inst = inst_time[0]
        if len(inst_time) == 3:
            # `day` frequency: <inst> <start_date> <end_date>
            begin = inst_time[1]
            end = inst_time[2]
        elif len(inst_time) == 5:
            # `1min` frequency: the dates carry a time component, so split() yields five fields
            begin = inst_time[1] + " " + inst_time[2]
            end = inst_time[3] + " " + inst_time[4]
        _instruments.setdefault(inst, []).append((pd.Timestamp(begin), pd.Timestamp(end)))
return _instruments

def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
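For reference, a minimal sketch of the two line formats the `_load_instruments` parser above accepts (symbols and dates are made up; real instrument files are tab-separated):

    import pandas as pd

    day_line = "SH600000\t2008-01-02\t2020-12-18"                    # `day`: 3 fields
    min_line = "SH600000\t2008-01-02 09:31:00\t2020-12-18 15:00:00"  # `1min`: 5 fields after split()

    for line in (day_line, min_line):
        fields = line.strip().split()
        if len(fields) == 3:
            begin, end = fields[1], fields[2]
        else:  # len(fields) == 5
            begin, end = " ".join(fields[1:3]), " ".join(fields[3:5])
        print(fields[0], pd.Timestamp(begin), pd.Timestamp(end))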
@@ -652,7 +640,7 @@ def _uri_data(self):
def feature(self, instrument, field, start_index, end_index, freq):
# validate
field = str(field).lower()[1:]
instrument = Inst.convert_instruments(instrument)
instrument = code_to_fname(instrument)
uri_data = self._uri_data.format(instrument.lower(), field, freq)
if not os.path.exists(uri_data):
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))
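As a sketch of what the updated `feature` lookup now produces (the URI template here is an assumption based on qlib's on-disk layout; it is not shown in this diff):

    from qlib.utils import code_to_fname

    # feature("PRN", "$close", ..., freq="day") resolves roughly as:
    instrument = code_to_fname("PRN")  # -> "_qlib_PRN"
    uri_data = "features/{}/{}.{}.bin".format(instrument.lower(), "close", "day")
    # -> "features/_qlib_prn/close.day.bin", a path Windows can actually create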
37 changes: 34 additions & 3 deletions qlib/utils/__init__.py
@@ -613,9 +613,7 @@ def exists_qlib_data(qlib_dir):
# check instruments
code_names = set(map(lambda x: x.name.lower(), features_dir.iterdir()))
_instrument = instruments_dir.joinpath("all.txt")
df = pd.read_csv(_instrument, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
df = df.iloc[:, [0, -1]].fillna(axis=1, method="ffill")
miss_code = set(df.iloc[:, -1].apply(str.lower)) - set(code_names)
miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names)
if miss_code and any(map(lambda x: "sht" not in x, miss_code)):
return False

@@ -711,3 +709,36 @@ def load_dataset(path_or_obj):
elif extension == ".csv":
return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1])
raise ValueError(f"unsupported file type `{extension}`")


def code_to_fname(code: str):
    """Convert a stock code to a file name that is safe to create on Windows.

    Parameters
    ----------
    code: str
    """
    # NOTE: On Windows the following names are reserved for I/O devices, so a file or
    # directory with such a name cannot be created
    # reference: https://superuser.com/questions/86999/why-cant-i-name-a-folder-or-file-con-in-windows
    replace_names = ["CON", "PRN", "AUX", "NUL"]
    replace_names += [f"COM{i}" for i in range(10)]
    replace_names += [f"LPT{i}" for i in range(10)]

    prefix = "_qlib_"
    if str(code).upper() in replace_names:
        code = prefix + str(code)

    return code


def fname_to_code(fname: str):
    """Convert a file name produced by `code_to_fname` back to the stock code.

    Parameters
    ----------
    fname: str
    """
    prefix = "_qlib_"
    if fname.startswith(prefix):
        # NOTE: str.lstrip strips a *set of characters*, not a prefix, so slice instead
        fname = fname[len(prefix):]
    return fname
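
A quick round-trip sketch of the two helpers (illustrative values):

    code_to_fname("PRN")        # -> "_qlib_PRN"  (reserved device name, prefixed)
    code_to_fname("AAPL")       # -> "AAPL"       (ordinary codes pass through)
    fname_to_code("_qlib_PRN")  # -> "PRN"
    fname_to_code("AAPL")       # -> "AAPL"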
3 changes: 2 additions & 1 deletion scripts/data_collector/yahoo/collector.py
@@ -18,6 +18,7 @@
from loguru import logger
from yahooquery import Ticker
from dateutil.tz import tzlocal
from qlib.utils import code_to_fname

CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
@@ -350,7 +351,7 @@ def download_index_data(self):
pass

def normalize_symbol(self, symbol):
return symbol.upper()
return code_to_fname(symbol).upper()

@property
def _timezone(self):
78 changes: 33 additions & 45 deletions scripts/dump_bin.py
@@ -14,6 +14,7 @@
import pandas as pd
from tqdm import tqdm
from loguru import logger
from qlib.utils import fname_to_code, code_to_fname


class DumpDataBase:
@@ -27,7 +28,6 @@ class DumpDataBase:
HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S"
INSTRUMENTS_SEP = "\t"
INSTRUMENTS_FILE_NAME = "all.txt"
SAVE_INST_FIELD = "save_inst"

UPDATE_MODE = "update"
ALL_MODE = "all"
@@ -45,7 +45,6 @@ def __init__(
exclude_fields: str = "",
include_fields: str = "",
limit_nums: int = None,
inst_prefix: str = "",
):
"""
@@ -73,9 +72,6 @@
fields not dumped
limit_nums: int
Use when debugging, default None
inst_prefix: str
add a column to the instruments file and record the saved instrument name,
the US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix.
"""
csv_path = Path(csv_path).expanduser()
if isinstance(exclude_fields, str):
@@ -84,7 +80,6 @@
include_fields = include_fields.split(",")
self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
self._inst_prefix = inst_prefix.strip()
self.file_suffix = file_suffix
self.symbol_field_name = symbol_field_name
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
@@ -145,7 +140,7 @@ def _get_source_data(self, file_path: Path) -> pd.DataFrame:
return df

def get_symbol_from_file(self, file_path: Path) -> str:
return file_path.name[: -len(self.file_suffix)].strip().lower()
return fname_to_code(file_path.name[: -len(self.file_suffix)].strip()).lower()

def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
return (
@@ -173,7 +168,6 @@ def _read_instruments(self, instrument_path: Path) -> pd.DataFrame:
self.symbol_field_name,
self.INSTRUMENTS_START_FIELD,
self.INSTRUMENTS_END_FIELD,
self.SAVE_INST_FIELD,
],
)

@@ -190,13 +184,9 @@ def save_instruments(self, instruments_data: Union[list, pd.DataFrame]):
instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve())
if isinstance(instruments_data, pd.DataFrame):
_df_fields = [self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]
if self._inst_prefix:
_df_fields.append(self.SAVE_INST_FIELD)
instruments_data[self.SAVE_INST_FIELD] = instruments_data[self.symbol_field_name].apply(
lambda x: f"{self._inst_prefix}{x}"
)
instruments_data = instruments_data.loc[:, _df_fields]
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP)
instruments_data[self.symbol_field_name] = instruments_data[self.symbol_field_name].apply(fname_to_code)
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP, index=False)
else:
np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8")

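With `inst_prefix` gone, the instruments file returns to three tab-separated columns and stores the real stock code rather than the prefixed file name; a hypothetical all.txt row (dates are made up):

    PRN	2020-01-02	2020-12-18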
@@ -223,26 +213,26 @@ def _data_to_bin(self, df: pd.DataFrame, calendar_list: List[pd.Timestamp], features_dir: Path):
logger.warning(f"{features_dir.name} data is None or empty")
return
# align index
_df = self.data_merge_calendar(df, self._calendars_list)
_df = self.data_merge_calendar(df, calendar_list)
# used when creating a bin file
date_index = self.get_datetime_index(_df, calendar_list)
for field in self.get_dump_fields(_df.columns):
bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}")
if field not in _df.columns:
continue
if self._mode == self.UPDATE_MODE:
if bin_path.exists() and self._mode == self.UPDATE_MODE:
# update
with bin_path.open("ab") as fp:
np.array(_df[field]).astype("<f").tofile(fp)
elif self._mode == self.ALL_MODE:
np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))
else:
raise ValueError(f"{self._mode} cannot support!")
# append; self._mode == self.ALL_MODE or not bin_path.exists()
np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))

def _dump_bin(self, file_or_data: [Path, pd.DataFrame], calendar_list: List[pd.Timestamp]):
if isinstance(file_or_data, pd.DataFrame):
if file_or_data.empty:
return
code = file_or_data.iloc[0][self.symbol_field_name].lower()
code = fname_to_code(file_or_data.iloc[0][self.symbol_field_name]).lower()
df = file_or_data
elif isinstance(file_or_data, Path):
code = self.get_symbol_from_file(file_or_data)
Expand All @@ -253,8 +243,7 @@ def _dump_bin(self, file_or_data: [Path, pd.DataFrame], calendar_list: List[pd.T
logger.warning(f"{code} data is None or empty")
return
# features save dir
code = self._inst_prefix + code if self._inst_prefix else code
features_dir = self._features_dir.joinpath(code)
features_dir = self._features_dir.joinpath(code_to_fname(code))
features_dir.mkdir(parents=True, exist_ok=True)
self._data_to_bin(df, calendar_list, features_dir)

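For context, a sketch of the .bin layout `_data_to_bin` writes, inferred from the `np.hstack([date_index, ...])` call above (the path is a placeholder):

    import numpy as np

    data = np.fromfile("features/_qlib_prn/close.day.bin", dtype="<f")
    start_index, values = int(data[0]), data[1:]  # leading float32 is the offset into the calendar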
@@ -283,8 +272,6 @@ def _get_all_date(self):
_end_time = self._format_datetime(_end_time)
symbol = self.get_symbol_from_file(file_path)
_inst_fields = [symbol.upper(), _begin_time, _end_time]
if self._inst_prefix:
_inst_fields.append(self._inst_prefix + symbol.upper())
date_range_list.append(f"{self.INSTRUMENTS_SEP.join(_inst_fields)}")
p_bar.update()
self._kwargs["all_datetime_set"] = all_datetime
@@ -323,12 +310,12 @@ class DumpDataFix(DumpDataAll):
def _dump_instruments(self):
logger.info("start dump instruments......")
_fun = partial(self._get_date, is_begin_end=True)
new_stock_files = sorted(filter(lambda x: x.name not in self._old_instruments, self.csv_files))
new_stock_files = sorted(filter(lambda x: fname_to_code(x.name) not in self._old_instruments, self.csv_files))
with tqdm(total=len(new_stock_files)) as p_bar:
with ProcessPoolExecutor(max_workers=self.works) as execute:
for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)):
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
symbol = self.get_symbol_from_file(file_path).upper()
symbol = fname_to_code(self.get_symbol_from_file(file_path).upper())
_dt_map = self._old_instruments.setdefault(symbol, dict())
_dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
_dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
@@ -406,10 +393,10 @@ def __init__(
)
self._mode = self.UPDATE_MODE
self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
self._update_instruments = self._read_instruments(
self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)
).to_dict(
orient="index"
self._update_instruments = (
self._read_instruments(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME))
.set_index([self.symbol_field_name])
.to_dict(orient="index")
) # type: dict

# load all csv files
@@ -425,18 +412,15 @@ def _load_all_source_data(self):
all_df = []

def _read_csv(file_path: Path):
if self._include_fields:
_df = pd.read_csv(file_path, usecols=self._include_fields)
else:
_df = pd.read_csv(file_path)
_df = pd.read_csv(file_path, parse_dates=[self.date_field_name])
if self.symbol_field_name not in _df.columns:
_df[self.symbol_field_name] = self.get_symbol_from_file(file_path)
return _df

with tqdm(total=len(self.csv_files)) as p_bar:
with ThreadPoolExecutor(max_workers=self.works) as executor:
for df in executor.map(_read_csv, self.csv_files):
if df:
if not df.empty:
all_df.append(df)
p_bar.update()

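A note on the `if df:` fix above: a pandas DataFrame cannot be truth-tested, so the old check raised instead of filtering out empty frames. A minimal repro:

    import pandas as pd

    df = pd.DataFrame({"a": [1]})
    # bool(df) raises ValueError: "The truth value of a DataFrame is ambiguous..."
    if not df.empty:  # the explicit emptiness check used above
        print("non-empty")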
Expand All @@ -455,33 +439,37 @@ def _dump_features(self):
with ProcessPoolExecutor(max_workers=self.works) as executor:
futures = {}
for _code, _df in self._all_data.groupby(self.symbol_field_name):
_code = str(_code).upper()
_code = fname_to_code(str(_code)).upper()
_start, _end = self._get_date(_df, is_begin_end=True)
if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)):
continue
if _code in self._update_instruments:
self._update_instruments[_code]["end_time"] = _end
self._update_instruments[_code][self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)
futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code
else:
# new stock
_dt_range = self._update_instruments.setdefault(_code, dict())
_dt_range["start_time"] = _start
_dt_range["end_time"] = _end
_dt_range[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_start)
_dt_range[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end)
futures[executor.submit(self._dump_bin, _df, self._new_calendar_list)] = _code

for _future in tqdm(as_completed(futures)):
try:
_future.result()
except Exception:
error_code[futures[_future]] = traceback.format_exc()
with tqdm(total=len(futures)) as p_bar:
for _future in as_completed(futures):
try:
_future.result()
except Exception:
error_code[futures[_future]] = traceback.format_exc()
p_bar.update()
logger.info(f"dump bin errors: {error_code}")

logger.info("end of features dump.\n")

def dump(self):
self.save_calendars(self._new_calendar_list)
self._dump_features()
self.save_instruments(pd.DataFrame.from_dict(self._update_instruments, orient="index"))
df = pd.DataFrame.from_dict(self._update_instruments, orient="index")
df.index.names = [self.symbol_field_name]
self.save_instruments(df.reset_index())


if __name__ == "__main__":
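A hypothetical end-to-end use of the updated dumper (paths are placeholders, and DumpDataUpdate's full constructor is not shown in this diff):

    from dump_bin import DumpDataUpdate

    DumpDataUpdate(
        csv_path="~/.qlib/csv_data/us_data",   # directory of new/updated CSVs
        qlib_dir="~/.qlib/qlib_data/us_data",  # existing qlib data directory
    ).dump()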
