From da2a2ef1864f5f87e9cc5c4668470565543052d4 Mon Sep 17 00:00:00 2001 From: zhupr Date: Tue, 20 Oct 2020 22:10:17 +0800 Subject: [PATCH] Support csi100 data collection && Fix data collector --- scripts/data_collector/csi/collector.py | 143 +++++++++++++++++++--- scripts/data_collector/utils.py | 67 +++++++--- scripts/data_collector/yahoo/collector.py | 32 +++-- 3 files changed, 194 insertions(+), 48 deletions(-) diff --git a/scripts/data_collector/csi/collector.py b/scripts/data_collector/csi/collector.py index cc6833a7ce..af10c12d68 100644 --- a/scripts/data_collector/csi/collector.py +++ b/scripts/data_collector/csi/collector.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import re +import abc import sys import bisect from io import BytesIO @@ -18,14 +19,12 @@ from data_collector.utils import get_hs_calendar_list as get_calendar_list -NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/000300cons.xls" +NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls" -CSI300_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A" +INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A" -CSI300_START_DATE = pd.Timestamp("2005-01-01") - -class CSI300: +class CSIIndex: REMOVE = "remove" ADD = "add" @@ -45,6 +44,9 @@ def __init__(self, qlib_dir=None): self.instruments_dir.mkdir(exist_ok=True, parents=True) self._calendar_list = None + self.cache_dir = Path("~/.cache/csi").expanduser().resolve() + self.cache_dir.mkdir(exist_ok=True, parents=True) + @property def calendar_list(self) -> list: """get history trading date 
@@ -52,7 +54,41 @@ def calendar_list(self) -> list: Returns ------- """ - return get_calendar_list(bench=True) + return get_calendar_list(bench_code=self.index_name.upper()) + + @property + def new_companies_url(self): + return NEW_COMPANIES_URL.format(index_code=self.index_code) + + @property + def changes_url(self): + return INDEX_CHANGES_URL + + @property + @abc.abstractmethod + def bench_start_date(self) -> pd.Timestamp: + raise NotImplementedError() + + @property + @abc.abstractmethod + def index_code(self): + raise NotImplementedError() + + @property + @abc.abstractmethod + def index_name(self): + raise NotImplementedError() + + @property + @abc.abstractmethod + def html_table_index(self): + """Which table of changes in html + + CSI300: 0 + CSI100: 1 + :return: + """ + raise NotImplementedError() def _get_trading_date_by_shift(self, trading_date: pd.Timestamp, shift=1): """get trading date by shift @@ -119,14 +155,18 @@ def _read_change_from_url(self, url: str) -> pd.DataFrame: remove_date = self._get_trading_date_by_shift(add_date, shift=-1) logger.info(f"get {add_date} changes") try: - excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0] - _io = BytesIO(requests.get(f"http://www.csindex.com.cn{excel_url}").content) + content = requests.get(f"http://www.csindex.com.cn{excel_url}").content + _io = BytesIO(content) df_map = pd.read_excel(_io, sheet_name=None) + with self.cache_dir.joinpath( + f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.{excel_url.split('.')[-1]}" + ).open("wb") as fp: + fp.write(content) tmp = [] for _s_name, _type, _date in [("调入", self.ADD, add_date), ("调出", self.REMOVE, remove_date)]: _df = df_map[_s_name] - _df = _df.loc[_df["指数代码"] == "000300", ["证券代码"]] + _df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]] _df = _df.applymap(self.normalize_symbol) _df.columns = ["symbol"] _df["type"] = _type @@ -135,9 +175,13 @@ def _read_change_from_url(self, url: str) -> pd.DataFrame: df = pd.concat(tmp) except 
Exception: df = None + _tmp_count = 0 for _df in pd.read_html(resp.content): if _df.shape[-1] != 4: continue + _tmp_count += 1 + if self.html_table_index + 1 > _tmp_count: + continue tmp = [] for _s, _type, _date in [ (_df.iloc[2:, 0], self.REMOVE, remove_date), @@ -149,31 +193,42 @@ def _read_change_from_url(self, url: str) -> pd.DataFrame: _tmp_df["date"] = _date tmp.append(_tmp_df) df = pd.concat(tmp) + df.to_csv( + str( + self.cache_dir.joinpath( + f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.csv" + ).resolve() + ) + ) break return df - @staticmethod - def _get_change_notices_url() -> list: + def _get_change_notices_url(self) -> list: """get change notices url Returns ------- """ - resp = requests.get(CSI300_CHANGES_URL) + resp = requests.get(self.changes_url) html = etree.HTML(resp.text) return html.xpath("//*[@id='itemContainer']//li/a/@href") def _get_new_companies(self): logger.info("get new companies") - _io = BytesIO(requests.get(NEW_COMPANIES_URL).content) + context = requests.get(self.new_companies_url).content + with self.cache_dir.joinpath( + f"{self.index_name.lower()}_new_companies.{self.new_companies_url.split('.')[-1]}" + ).open("wb") as fp: + fp.write(context) + _io = BytesIO(context) df = pd.read_excel(_io) df = df.iloc[:, [0, 4]] df.columns = ["end_date", "symbol"] df["symbol"] = df["symbol"].map(self.normalize_symbol) df["end_date"] = pd.to_datetime(df["end_date"]) - df["start_date"] = CSI300_START_DATE + df["start_date"] = self.bench_start_date return df def parse_instruments(self): @@ -183,7 +238,7 @@ def parse_instruments(self): ------- $ python collector.py parse_instruments --qlib_dir ~/.qlib/qlib_data/cn_data """ - logger.info("start parse csi300 companies.....") + logger.info(f"start parse {self.index_name.lower()} companies.....") instruments_columns = ["symbol", "start_date", "end_date"] changers_df = self._get_changes() new_df = self._get_new_companies() @@ -196,15 +251,65 @@ def parse_instruments(self): ] = 
_row.date else: _tmp_df = pd.DataFrame( - [[_row.symbol, CSI300_START_DATE, _row.date]], columns=["symbol", "start_date", "end_date"] + [[_row.symbol, self.bench_start_date, _row.date]], columns=["symbol", "start_date", "end_date"] ) new_df = new_df.append(_tmp_df, sort=False) new_df.loc[:, instruments_columns].to_csv( - self.instruments_dir.joinpath("csi300.txt"), sep="\t", index=False, header=None + self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None ) - logger.info("parse csi300 companies finished.") + logger.info(f"parse {self.index_name.lower()} companies finished.") + + +class CSI300(CSIIndex): + @property + def index_code(self): + return "000300" + + @property + def index_name(self): + return "csi300" + + @property + def bench_start_date(self) -> pd.Timestamp: + return pd.Timestamp("2005-01-01") + + @property + def html_table_index(self): + return 0 + + +class CSI100(CSIIndex): + @property + def index_code(self): + return "000903" + + @property + def index_name(self): + return "csi100" + + @property + def bench_start_date(self) -> pd.Timestamp: + return pd.Timestamp("2006-05-29") + + @property + def html_table_index(self): + return 1 + + +def parse_instruments(qlib_dir: str): + """ + + Parameters + ---------- + qlib_dir: str + qlib data dir, default "Path(__file__).parent/qlib_data" + """ + qlib_dir = Path(qlib_dir).expanduser().resolve() + qlib_dir.mkdir(exist_ok=True, parents=True) + CSI300(qlib_dir).parse_instruments() + CSI100(qlib_dir).parse_instruments() if __name__ == "__main__": - fire.Fire(CSI300) + fire.Fire(parse_instruments) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index d70aaa8b4c..d2b3835c18 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -2,7 +2,10 @@ # Licensed under the MIT License. 
import re +import time +import pickle import requests +from pathlib import Path import pandas as pd from lxml import etree @@ -11,39 +14,46 @@ CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101" SH600000_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.600000&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101" +CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101" + +CALENDAR_BENCH_URL_MAP = { + "CSI300": CALENDAR_URL_BASE.format(bench_code="000300"), + "CSI100": CALENDAR_URL_BASE.format(bench_code="000903"), + # NOTE: Use the time series of SH600000 as the sequence of all stocks + "ALL": CALENDAR_URL_BASE.format(bench_code="600000"), +} + _BENCH_CALENDAR_LIST = None _ALL_CALENDAR_LIST = None _HS_SYMBOLS = None +_CALENDAR_MAP = {} + +# NOTE: Until 2020-10-20 20:00:00 +MINIMUM_SYMBOLS_NUM = 3900 -def get_hs_calendar_list(bench=False) -> list: +def get_hs_calendar_list(bench_code="CSI300") -> list: """get SH/SZ history calendar list Parameters ---------- - bench: bool - whether to get the bench calendar list, by default False + bench_code: str + value from ["CSI300", "CSI100", "ALL"] Returns ------- history calendar list """ - global _ALL_CALENDAR_LIST - global _BENCH_CALENDAR_LIST def _get_calendar(url): _value_list = requests.get(url).json()["data"]["klines"] return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list)) - # TODO: get calendar from MSN - if bench: - if _BENCH_CALENDAR_LIST is None: - _BENCH_CALENDAR_LIST = _get_calendar(CSI300_BENCH_URL) - return _BENCH_CALENDAR_LIST - - if _ALL_CALENDAR_LIST is None: - 
_ALL_CALENDAR_LIST = _get_calendar(SH600000_BENCH_URL) - return _ALL_CALENDAR_LIST + calendar = _CALENDAR_MAP.get(bench_code, None) + if calendar is None: + calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code]) + _CALENDAR_MAP[bench_code] = calendar + return calendar def get_hs_stock_symbols() -> list: @@ -54,7 +64,8 @@ def get_hs_stock_symbols() -> list: stock symbols """ global _HS_SYMBOLS - if _HS_SYMBOLS is None: + + def _get_symbol(): _res = set() for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")): resp = requests.get(SYMBOLS_URL.format(s_type=_k)) @@ -64,7 +75,27 @@ def get_hs_stock_symbols() -> list: etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"), ) ) - _HS_SYMBOLS = sorted(list(_res)) + return _res + + if _HS_SYMBOLS is None: + symbols = set() + _retry = 60 + # It may take multiple requests to get the complete symbol list + while len(symbols) < MINIMUM_SYMBOLS_NUM: + symbols |= _get_symbol() + time.sleep(3) + + symbol_cache_path = Path("~/.cache/hs_symbols_cache.pkl").expanduser().resolve() + symbol_cache_path.parent.mkdir(parents=True, exist_ok=True) + if symbol_cache_path.exists(): + with symbol_cache_path.open("rb") as fp: + cache_symbols = pickle.load(fp) + symbols |= cache_symbols + with symbol_cache_path.open("wb") as fp: + pickle.dump(symbols, fp) + + _HS_SYMBOLS = sorted(list(symbols)) + return _HS_SYMBOLS @@ -104,3 +135,7 @@ def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str: """ res = f"{symbol[:-2]}.{symbol[-2:]}" return res.upper() if capital else res.lower() + + +if __name__ == '__main__': + assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index c54b1b8bfb..9456c6bc39 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -19,7 +19,7 @@ from dump_bin import DumpData from data_collector.utils import get_hs_calendar_list as get_calendar_list, 
get_hs_stock_symbols -CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101" +INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101" MIN_NUMBERS_TRADING = 252 / 4 @@ -130,17 +130,23 @@ def collector_data(self): logger.warning(f"less than {MIN_NUMBERS_TRADING} stock list: {list(self._mini_symbol_map.keys())}") - self.download_csi300_data() + self.download_index_data() - def download_csi300_data(self): + def download_index_data(self): # TODO: from MSN - logger.info(f"get bench data: csi300(SH000300)......") - df = pd.DataFrame(map(lambda x: x.split(","), requests.get(CSI300_BENCH_URL).json()["data"]["klines"])) - df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"] - df["date"] = pd.to_datetime(df["date"]) - df = df.astype(float, errors="ignore") - df["adjclose"] = df["close"] - df.to_csv(self.save_dir.joinpath("sh000300.csv"), index=False) + for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items(): + logger.info(f"get bench data: {_index_name}({_index_code})......") + df = pd.DataFrame( + map( + lambda x: x.split(","), + requests.get(INDEX_BENCH_URL.format(index_code=_index_code)).json()["data"]["klines"], + ) + ) + df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"] + df["date"] = pd.to_datetime(df["date"]) + df = df.astype(float, errors="ignore") + df["adjclose"] = df["close"] + df.to_csv(self.save_dir.joinpath(f"sh{_index_code}.csv"), index=False) class Run: @@ -192,7 +198,7 @@ def _normalize(file_path: Path): df = df[~df.index.duplicated(keep="first")] # using China stock market data calendar - df = df.reindex(pd.Index(get_calendar_list())) + df = 
df.reindex(pd.Index(get_calendar_list("ALL"))) df.sort_index(inplace=True) df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {"symbol"}] = np.nan @@ -274,8 +280,8 @@ def download_data(self, asynchronous=False, max_collector_count=5, delay=0): delay=delay, ).collector_data() - def download_csi300_data(self): - YahooCollector(self.source_dir).download_csi300_data() + def download_index_data(self): + YahooCollector(self.source_dir).download_index_data() def download_bench_data(self): """download bench stock data(SH000300)"""