143 changes: 124 additions & 19 deletions scripts/data_collector/csi/collector.py
@@ -2,6 +2,7 @@
# Licensed under the MIT License.

import re
import abc
import sys
import bisect
from io import BytesIO
@@ -18,14 +19,12 @@
from data_collector.utils import get_hs_calendar_list as get_calendar_list


NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/000300cons.xls"
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls"

CSI300_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"

CSI300_START_DATE = pd.Timestamp("2005-01-01")


class CSI300:
class CSIIndex:

REMOVE = "remove"
ADD = "add"
@@ -45,14 +44,51 @@ def __init__(self, qlib_dir=None):
self.instruments_dir.mkdir(exist_ok=True, parents=True)
self._calendar_list = None

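# NOTE: raw downloads (change notices, company lists) are also written under ~/.cache/csi so they can be inspected or reused later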
self.cache_dir = Path("~/.cache/csi").expanduser().resolve()
self.cache_dir.mkdir(exist_ok=True, parents=True)

@property
def calendar_list(self) -> list:
"""get history trading date

Returns
-------
history trading date list
"""
return get_calendar_list(bench=True)
return get_calendar_list(bench_code=self.index_name.upper())

@property
def new_companies_url(self):
return NEW_COMPANIES_URL.format(index_code=self.index_code)

@property
def changes_url(self):
return INDEX_CHANGES_URL

@property
@abc.abstractmethod
def bench_start_date(self) -> pd.Timestamp:
raise NotImplementedError()

@property
@abc.abstractmethod
def index_code(self):
raise NotImplementedError()

@property
@abc.abstractmethod
def index_name(self):
raise NotImplementedError()

@property
@abc.abstractmethod
def html_table_index(self):
"""Index of the table holding this index's changes in the HTML notice page

CSI300: 0
CSI100: 1
:return: table index
"""
raise NotImplementedError()

def _get_trading_date_by_shift(self, trading_date: pd.Timestamp, shift=1):
"""get trading date by shift
@@ -119,14 +155,18 @@ def _read_change_from_url(self, url: str) -> pd.DataFrame:
remove_date = self._get_trading_date_by_shift(add_date, shift=-1)
logger.info(f"get {add_date} changes")
try:

excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0]
_io = BytesIO(requests.get(f"http://www.csindex.com.cn{excel_url}").content)
content = requests.get(f"http://www.csindex.com.cn{excel_url}").content
_io = BytesIO(content)
df_map = pd.read_excel(_io, sheet_name=None)
with self.cache_dir.joinpath(
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.{excel_url.split('.')[-1]}"
).open("wb") as fp:
fp.write(content)
tmp = []
for _s_name, _type, _date in [("调入", self.ADD, add_date), ("调出", self.REMOVE, remove_date)]:
_df = df_map[_s_name]
_df = _df.loc[_df["指数代码"] == "000300", ["证券代码"]]
_df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]]
_df = _df.applymap(self.normalize_symbol)
_df.columns = ["symbol"]
_df["type"] = _type
@@ -135,9 +175,13 @@ def _read_change_from_url(self, url: str) -> pd.DataFrame:
df = pd.concat(tmp)
except Exception:
df = None
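# fallback when the Excel file is unavailable: scan the page's 4-column HTML tables and take the one at html_table_index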
_tmp_count = 0
for _df in pd.read_html(resp.content):
if _df.shape[-1] != 4:
continue
_tmp_count += 1
if self.html_table_index + 1 > _tmp_count:
continue
tmp = []
for _s, _type, _date in [
(_df.iloc[2:, 0], self.REMOVE, remove_date),
@@ -149,31 +193,42 @@ def _read_change_from_url(self, url: str) -> pd.DataFrame:
_tmp_df["date"] = _date
tmp.append(_tmp_df)
df = pd.concat(tmp)
df.to_csv(
str(
self.cache_dir.joinpath(
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.csv"
).resolve()
)
)
break
return df

@staticmethod
def _get_change_notices_url() -> list:
def _get_change_notices_url(self) -> list:
"""get change notices url

Returns
-------
list of change notice urls
"""
resp = requests.get(CSI300_CHANGES_URL)
resp = requests.get(self.changes_url)
html = etree.HTML(resp.text)
return html.xpath("//*[@id='itemContainer']//li/a/@href")

def _get_new_companies(self):

logger.info("get new companies")
_io = BytesIO(requests.get(NEW_COMPANIES_URL).content)
context = requests.get(self.new_companies_url).content
with self.cache_dir.joinpath(
f"{self.index_name.lower()}_new_companies.{self.new_companies_url.split('.')[-1]}"
).open("wb") as fp:
fp.write(context)
_io = BytesIO(context)
df = pd.read_excel(_io)
df = df.iloc[:, [0, 4]]
df.columns = ["end_date", "symbol"]
df["symbol"] = df["symbol"].map(self.normalize_symbol)
df["end_date"] = pd.to_datetime(df["end_date"])
df["start_date"] = CSI300_START_DATE
df["start_date"] = self.bench_start_date
return df

def parse_instruments(self):
@@ -183,7 +238,7 @@ def parse_instruments(self):
-------
$ python collector.py parse_instruments --qlib_dir ~/.qlib/qlib_data/cn_data
"""
logger.info("start parse csi300 companies.....")
logger.info(f"start parse {self.index_name.lower()} companies.....")
instruments_columns = ["symbol", "start_date", "end_date"]
changers_df = self._get_changes()
new_df = self._get_new_companies()
@@ -196,15 +251,65 @@ def parse_instruments(self):
] = _row.date
else:
_tmp_df = pd.DataFrame(
[[_row.symbol, CSI300_START_DATE, _row.date]], columns=["symbol", "start_date", "end_date"]
[[_row.symbol, self.bench_start_date, _row.date]], columns=["symbol", "start_date", "end_date"]
)
new_df = new_df.append(_tmp_df, sort=False)

new_df.loc[:, instruments_columns].to_csv(
self.instruments_dir.joinpath("csi300.txt"), sep="\t", index=False, header=None
self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None
)
logger.info("parse csi300 companies finished.")
logger.info(f"parse {self.index_name.lower()} companies finished.")


class CSI300(CSIIndex):
@property
def index_code(self):
return "000300"

@property
def index_name(self):
return "csi300"

@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2005-01-01")

@property
def html_table_index(self):
return 0


class CSI100(CSIIndex):
@property
def index_code(self):
return "000903"

@property
def index_name(self):
return "csi100"

@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2006-05-29")

@property
def html_table_index(self):
return 1
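
# --- example (not in this PR): the two subclasses above are the whole per-index
# surface, so a further index would be one more subclass. A hypothetical CSI500
# sketch follows; its index code, start date, and table position are unverified
# assumptions.
class CSI500(CSIIndex):
    @property
    def index_code(self):
        return "000905"  # assumed CSI500 index code; verify against csindex.com.cn

    @property
    def index_name(self):
        return "csi500"

    @property
    def bench_start_date(self) -> pd.Timestamp:
        return pd.Timestamp("2007-01-15")  # assumed start date; verify before use

    @property
    def html_table_index(self):
        return 2  # assumed position of the CSI500 table in the change notices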


def parse_instruments(qlib_dir: str):
"""parse instruments for the CSI300 and CSI100 indexes

Parameters
----------
qlib_dir: str
qlib data dir, default "Path(__file__).parent/qlib_data"
"""
qlib_dir = Path(qlib_dir).expanduser().resolve()
qlib_dir.mkdir(exist_ok=True, parents=True)
CSI300(qlib_dir).parse_instruments()
CSI100(qlib_dir).parse_instruments()


if __name__ == "__main__":
fire.Fire(CSI300)
fire.Fire(parse_instruments)
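
Since the entry point is now the module-level parse_instruments, a single invocation covers both indexes. A minimal usage sketch; the data path and import path below are examples, not part of this PR:

# CLI, run from scripts/data_collector/csi (fire exposes parse_instruments directly):
#   python collector.py --qlib_dir ~/.qlib/qlib_data/cn_data

# Programmatic equivalent:
from pathlib import Path
from data_collector.csi.collector import CSI100, CSI300  # import path assumed

qlib_dir = Path("~/.qlib/qlib_data/cn_data").expanduser().resolve()
CSI300(qlib_dir).parse_instruments()  # writes instruments/csi300.txt
CSI100(qlib_dir).parse_instruments()  # writes instruments/csi100.txt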
67 changes: 51 additions & 16 deletions scripts/data_collector/utils.py
@@ -2,7 +2,10 @@
# Licensed under the MIT License.

import re
import time
import pickle
import requests
from pathlib import Path

import pandas as pd
from lxml import etree
@@ -11,39 +14,46 @@
CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
SH600000_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.600000&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"

CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"

CALENDAR_BENCH_URL_MAP = {
"CSI300": CALENDAR_URL_BASE.format(bench_code="000300"),
"CSI100": CALENDAR_URL_BASE.format(bench_code="000903"),
# NOTE: Use the SH600000 time series as the trading calendar for all stocks
"ALL": CALENDAR_URL_BASE.format(bench_code="600000"),
}

_BENCH_CALENDAR_LIST = None
_ALL_CALENDAR_LIST = None
_HS_SYMBOLS = None
_CALENDAR_MAP = {}

# NOTE: symbol count as of 2020-10-20 20:00:00
MINIMUM_SYMBOLS_NUM = 3900


def get_hs_calendar_list(bench=False) -> list:
def get_hs_calendar_list(bench_code="CSI300") -> list:
"""get SH/SZ history calendar list

Parameters
----------
bench: bool
whether to get the bench calendar list, by default False
bench_code: str
value from ["CSI300", "CSI500", "ALL"]

Returns
-------
history calendar list
"""
global _ALL_CALENDAR_LIST
global _BENCH_CALENDAR_LIST

def _get_calendar(url):
_value_list = requests.get(url).json()["data"]["klines"]
return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list))

# TODO: get calendar from MSN
if bench:
if _BENCH_CALENDAR_LIST is None:
_BENCH_CALENDAR_LIST = _get_calendar(CSI300_BENCH_URL)
return _BENCH_CALENDAR_LIST

if _ALL_CALENDAR_LIST is None:
_ALL_CALENDAR_LIST = _get_calendar(SH600000_BENCH_URL)
return _ALL_CALENDAR_LIST
calendar = _CALENDAR_MAP.get(bench_code, None)
if calendar is None:
calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code])
_CALENDAR_MAP[bench_code] = calendar
return calendar


def get_hs_stock_symbols() -> list:
@@ -54,7 +64,8 @@ def get_hs_stock_symbols() -> list:
stock symbols
"""
global _HS_SYMBOLS
if _HS_SYMBOLS is None:

def _get_symbol():
_res = set()
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
resp = requests.get(SYMBOLS_URL.format(s_type=_k))
@@ -64,7 +75,27 @@ def get_hs_stock_symbols() -> list:
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"),
)
)
_HS_SYMBOLS = sorted(list(_res))
return _res

if _HS_SYMBOLS is None:
symbols = set()
_retry = 60
# It may take multiple requests to fetch the complete symbol list
while len(symbols) < MINIMUM_SYMBOLS_NUM:
symbols |= _get_symbol()
time.sleep(3)

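# union with the on-disk pickle cache so symbols from earlier runs are kept, then rewrite the cache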
symbol_cache_path = Path("~/.cache/hs_symbols_cache.pkl").expanduser().resolve()
symbol_cache_path.parent.mkdir(parents=True, exist_ok=True)
if symbol_cache_path.exists():
with symbol_cache_path.open("rb") as fp:
cache_symbols = pickle.load(fp)
symbols |= cache_symbols
with symbol_cache_path.open("wb") as fp:
pickle.dump(symbols, fp)

_HS_SYMBOLS = sorted(list(symbols))

return _HS_SYMBOLS


@@ -104,3 +135,7 @@ def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str:
"""
res = f"{symbol[:-2]}.{symbol[-2:]}"
return res.upper() if capital else res.lower()


if __name__ == '__main__':
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM
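
With the boolean bench flag replaced by bench_code, callers now name the calendar they want, and results are memoized per code in _CALENDAR_MAP. A minimal migration sketch; the import path is assumed:

from data_collector.utils import get_hs_calendar_list

# before: get_hs_calendar_list(bench=True)   -> CSI300 calendar
# after:  name the benchmark explicitly
csi300_cal = get_hs_calendar_list(bench_code="CSI300")
csi100_cal = get_hs_calendar_list(bench_code="CSI100")
all_cal = get_hs_calendar_list(bench_code="ALL")  # SH600000 series stands in for the full market

# repeated calls return the memoized object
assert get_hs_calendar_list(bench_code="CSI300") is csi300_cal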