In [None]:
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Literal, get_args

import pandas as pd
import tushare as ts
from requests import ConnectionError as RequestsConnectionError
from sqlalchemy import create_engine, text
from sqlalchemy.exc import OperationalError
from tenacity import retry, retry_if_exception_type
from tqdm.notebook import tqdm

In [None]:
# Set the token and the http override of the Tushare Pro API
pro = ts.pro_api("20240522230128-6df8980c-e053-4db7-98f2-7b932d27d23d")
pro._DataApi__http_url = "http://tsapi.majors.ltd:7000"

In [None]:
# ...
today = pd.Timestamp.today().normalize()


def dt2date(dt: pd.Timestamp) -> str:
    return dt.strftime("%Y%m%d")


def get_trade_days(start_date: str, end_date: str) -> pd.DatetimeIndex:
    return pd.DatetimeIndex(
        pro.trade_cal(start_date=start_date, end_date=end_date, is_open="1")["cal_date"].sort_values()
    )


# Set the end of the performance evaluation
# Default to the previous year end
end = today - pd.offsets.YearEnd()

# Set the start of the performance evaluation
# Default to the earliest data available on Tushare Pro
start = pd.Timestamp("2017-01-01")

# ...
tds = get_trade_days(dt2date(start), dt2date(end))

# ...
tmes = pd.Series(index=tds).groupby([lambda x: x.year, lambda x: x.month]).tail(1).index

# ...
qes = pd.date_range(start - pd.DateOffset(years=1), end, freq="QE")

## Get the benchmark

Use CSI 300 Index (沪深指数) as the benchmark

In [None]:
# Set the data directory
DATA_DIR = Path.cwd() / "data"

# Setup the database for raw data
raw_data_db = create_engine(f"sqlite:///{DATA_DIR / 'raw_data.db'}")

# Get the CSI 300 Index daily data
index_daily = pro.index_daily(
    ts_code="000300.SH", 
    start_date=dt2date(start), 
    end_date=dt2date(end),
    fields="trade_date,close"
)

# Write the daily data to the raw data database
_ = index_daily.to_sql("index_daily", raw_data_db, if_exists="replace", index=False)

## Get the stock data

### Define the update frequency of the APIs

- TD (Trading Day)
    - Update at the close of each trading day
- TME (Trading Month End) 
    - Update at the close of the last trading day of each month
- QE (Quarter End)
    - Update on the last calendar day of each quarter

In [None]:
TushareAPI = Callable[[str], pd.DataFrame]

Registry = dict[str, TushareAPI]

td_registry: Registry = {}
tme_registry: Registry = {}
qe_registry: Registry = {}

Frequency = Literal["TD", "TME", "QE"]

APIWrapper = Callable[[TushareAPI], TushareAPI]


def register(*, freq: Frequency) -> APIWrapper:
    def wrapper(api: TushareAPI) -> TushareAPI:
        assert freq in get_args(Frequency)
        
        match freq:
            case "TD":
                td_registry[api.__name__] = api
            case "TME":
                tme_registry[api.__name__] = api
            case _:  # QE
                qe_registry[api.__name__] = api

        return api

    return wrapper

### Define the fields to fetch for each API

Using a simplified version of the Barra Total Market Equity Model (CNE6).

In [None]:
@register(freq="TME")
def bak_basic(date: str) -> pd.DataFrame:
    return pro.bak_basic(
        trade_date=date, fields="trade_date,ts_code,name,list_date",
    )


@register(freq="TD")
def daily(date: str) -> pd.DataFrame:
    return pro.daily(trade_date=date)


@register(freq="TME")
def daily_basic(date: str) -> pd.DataFrame:
    return pro.daily_basic(trade_date=date)


@register(freq="TME")
def suspend_d(date: str) -> pd.DataFrame:
    return pro.suspend_d(
        trade_date=date, suspend_type="S", fields="ts_code,trade_date",
    )


@register(freq="QE")
def balancesheet(date: str) -> pd.DataFrame:
    return pro.balancesheet_vip(period=date)


@register(freq="QE")
def fina_indicator(date: str) -> pd.DataFrame:
    return pro.fina_indicator_vip(period=date)

### Fetch the wrapped APIs and write to the raw data database

Note: `Unable to fetch bak_basic on 2021-01-29` is expected as the data is not available on Tushare Pro.

In [None]:
@retry(retry=retry_if_exception_type((RequestsConnectionError, TimeoutError)))
def fetch_registry(api_name: str) -> None:
    date_column = "trade_date"
    
    if api_name in td_registry:
        dts = tds
        api = td_registry[api_name]
    
    elif api_name in tme_registry:
        dts = tmes
        api = tme_registry[api_name]
    
    else:  # qe_registry
        date_column = "end_date"
        dts = qes
        api = qe_registry[api_name]
    
    with raw_data_db.connect() as conn:
        try:
            prefetched: list[pd.Timestamp] = [
                pd.to_datetime(i)
                for i in conn.execute(
                    text(f"SELECT DISTINCT {date_column} FROM {api_name}")
                )
            ]
        except OperationalError:
            prefetched = None
    
    # TODO bulk insert
    for dt in (pbar := tqdm(dts, desc=api_name)):
        if prefetched and dt in prefetched:
            continue
        
        api_result = api(dt2date(dt))
        
        if api_result.empty:
            pbar.write(
                f"Unable to fetch {api_name} on {dt.date()}, please try again later"
            )
            continue
        
        api_result.to_sql(api_name, raw_data_db, if_exists="append", index=False)


with ThreadPoolExecutor() as t:
    _ = t.map(fetch_registry, [*td_registry, *tme_registry, *qe_registry])

In [None]:
# Load the bak_basic data as all Coverage Universes
all_cu = pd.read_sql_table("bak_basic", raw_data_db).query("list_date != '0'")

# Load the patch data
patch = pd.read_csv(DATA_DIR / "patch.csv", dtype=str)

# Append the patch data to the Coverage Universes if not already exists
if not patch["trade_date"][0] in all_cu["trade_date"].unique():
    all_cu = pd.concat([all_cu, patch], ignore_index=True)
    all_cu.to_sql("bak_basic", raw_data_db, if_exists="replace", index=False)

all_cu[["trade_date", "list_date"]] = all_cu[["trade_date", "list_date"]].apply(pd.to_datetime)

In [None]:
# Get the SW2021 industry classification
@retry(retry=retry_if_exception_type((RequestsConnectionError, TimeoutError)))
def get_industry():
    return pd.concat(
        [
            pro.index_member_all(
                l1_code=l1_code,
                fields="l1_name,ts_code,in_date,out_date",
                is_new=is_new,
            )
            for is_new in tqdm(["Y", "N"])
            for l1_code in tqdm(
                pro.index_classify(level="L1", src="SW2021")["index_code"], leave=False
            )
        ]
    ).rename(columns={"l1_name": "industry"})


industry = get_industry()

industry[["in_date", "out_date"]] = industry[["in_date", "out_date"]].map(pd.to_datetime)

In [None]:
_all_estu: list[pd.DataFrame] = []

for tme in tqdm(tmes):
    # Get the Coverage Universe
    cu = all_cu.query("trade_date == @tme")
    
    # Filter ST stocks and stocks listed less than 52 weeks (a year)
    st_filtered = cu.query("~name.str.contains('ST')")
    
    left = st_filtered[tme - st_filtered["list_date"] > pd.Timedelta(52, "W")]

    # Prepare stocks with industry classification
    right = industry.query("(in_date <= @tme) & ((out_date > @tme) | out_date.isna())")

    # Merge the dataframes to get the temporary Estimated Universe
    _estu = left.merge(right[["ts_code", "industry"]], on="ts_code")

    # Get the suspended stocks
    suspended = pd.read_sql_query(
        f"SELECT ts_code FROM suspend_d WHERE trade_date = '{dt2date(tme)}'",
        raw_data_db,
    )

    # Filter suspended stocks
    estu = _estu.query("~ts_code.isin(@suspended.ts_code)")

    _all_estu.append(estu)

all_estu = pd.concat(_all_estu)

# Write the Estimated Universe to the raw data db
all_estu[["trade_date", "list_date"]] = all_estu[["trade_date", "list_date"]].map(dt2date)

_ = all_estu.to_sql("estu", raw_data_db, if_exists="replace", index=False)

In [None]:
fundamentals_db = create_engine(f"sqlite:///{DATA_DIR / 'fundamentals.db'}")

In [217]:
# noinspection PyShadowingNames
def td2tme(column: str, api_name: str) -> pd.DataFrame:
    return pd.concat(
        [
            pd.read_sql_query(
                f"""
                SELECT ts_code,
                       trade_date,
                       {column}
                FROM {api_name}
                WHERE trade_date = '{dt2date(tme)}'
                  AND ts_code IN (
                      SELECT ts_code
                      FROM estu
                      WHERE trade_date = '{dt2date(tme)}'
                  )
                """,
                raw_data_db,
            )
            for tme in tqdm(tmes, desc=column)
        ]
    )

In [299]:
# noinspection PyShadowingNames
def qe2tme(column: str, api_name: str, date_column: str) -> pd.DataFrame:
    return pd.concat(
        [
            pd.read_sql_query(
                f"""
                WITH point_in_time AS (
                    SELECT *,
                           ROW_NUMBER() OVER (
                               PARTITION BY ts_code
                               ORDER BY {date_column} DESC, update_flag DESC
                           ) AS rn
                    FROM {api_name}
                    WHERE {date_column} <= '{dt2date(tme)}'
                          AND ts_code IN (
                              SELECT ts_code
                              FROM estu
                              WHERE trade_date = '{dt2date(tme)}'
                          )
                )
                SELECT ts_code, {column}
                FROM point_in_time
                WHERE rn = 1
                """,
                raw_data_db,
            ).assign(trade_date=tme)
            for tme in tqdm(tmes, desc=column)
        ],
        ignore_index=True,
    )[["ts_code", "trade_date", column]]

In [311]:
# 1.
total_mv = td2tme("total_mv", "daily_basic")

# 2. 
bp = td2tme("pb", "daily_basic").rename(columns={"pb": "bp"})

bp["bp"] = 1 / bp["bp"]

# 3.
epttm = td2tme("pe_ttm", "daily_basic").rename(columns={"pe_ttm": "epttm"})

epttm["epttm"] = 1 / epttm["epttm"]

# 4.

# 5.

# 6.
dv_ratio = td2tme("dv_ratio", "daily_basic") 

total_mv:   0%|          | 0/84 [00:00<?, ?it/s]

pb:   0%|          | 0/84 [00:00<?, ?it/s]

pe_ttm:   0%|          | 0/84 [00:00<?, ?it/s]

In [316]:
bp

Unnamed: 0,ts_code,trade_date,bp
0,300268.SZ,20170126,0.048974
1,300263.SZ,20170126,0.400834
2,300294.SZ,20170126,0.142278
3,002616.SZ,20170126,0.238299
4,300274.SZ,20170126,0.396825
...,...,...,...
4891,300148.SZ,20231229,0.328127
4892,603878.SH,20231229,0.643004
4893,002443.SZ,20231229,0.925241
4894,300886.SZ,20231229,0.293850


In [None]:
"""
import requests

start_m = dt2m(start)
end_m = dt2m(end)

m1_yoy = pro.cn_m(start_m=start_m, end_m=end_m, fields="month,m1_yoy")
cpi_yoy = pro.cn_cpi(start_m=start_m, end_m=end_m, fields="month,nt_yoy")
ppi_yoy = pro.cn_ppi(start_m=start_m, end_m=end_m, fields="month,ppi_yoy")"""


"""def bond_china_yield(
        start_date: str, end_date: str, gjqx: int, qxId: str = "hzsylqx"
) -> pd.DataFrame:
    start_date = start_date.date()
    end_date = end_date.date()
    
    url = "http://yield.chinabond.com.cn/cbweb-pbc-web/pbc/historyQuery"
    params = {
        "startDate": start_date,
        "endDate": end_date,
        "gjqx": str(gjqx),
        "qxId": qxId,
        "locale": "cn_ZH",
    }
    headers = {
        "User-Agent":
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/79.0.3945.130 Safari/537.36",
    }
    res = requests.get(url, params=params, headers=headers)
    data_text = res.text.replace("&nbsp", "")
    data_df = pd.read_html(data_text, header=0)[1]
    return data_df


def get_bond_yield(
        start_date: str, end_date: str, periods: int, bond_type: str
) -> pd.DataFrame:
    dates = get_trade_days(start_date, end_date)
    n_days = len(dates)
    limit = 244

    if n_days > limit:

        n = n_days // limit
        df_list = []
        i = 0
        pos1, pos2 = n * i, n * (i + 1) - 1
        while pos2 < n_days:
            print(pos2)
            df = bond_china_yield(start_date=dates[pos1], end_date=dates[pos2],gjqx=periods,qxId=bond_type)
            df_list.append(df)
            i += 1
            pos1, pos2 = n * i, n * (i + 1) - 1

        if pos1 < n_days:
            df = bond_china_yield(start_date=dates[pos1], end_date=dates[-1],gjqx=periods,qxId=bond_type)
            df_list.append(df)
        df = pd.concat(df_list, axis=0)
    else:
        df = bond_china_yield(start_date=start_date, end_date=end_date,gjqx=periods,qxId=bond_type)

    return df.dropna(axis=1)
    
    
get_bond_yield(dt2date(start), dt2date(end), 0.25, "hzsylqx")
"""