In [16]:
import pandas as pd
import numpy as np
import typing
import joblib # >= 1.2.0
import matplotlib.pyplot as plt
import alphagens
from alphagens.factor_utils import get_clean_factor_and_forward_returns, get_clean_factor_and_current_returns
from alphagens.backtest import QuickBackTestor, QuickFactorTestor
from alphagens.utils.metrics import FactorMetrics, StrategyMetrics
from alphagens.data_source.tushare import pro, Stock, Index
from alphagens.calendars import DEFAULT_CALENDAR
from alphagens.utils.format_output import df_to_html
from alphagens.edbt.test import SimulationEngine, SimulatedBroker, Account, BaseStrategy

In [17]:
# Record the installed alphagens version alongside the results for reproducibility.
alphagens.__version__

'0.4.1'

In [2]:
class Context:
    """Notebook-wide configuration: data location, backtest window and benchmark.

    Used purely as a namespace; attributes are read directly off the class.
    """
    # Directory that the joblib-serialised data files are loaded from.
    DATA_PATH = "./data"
    # Bounds of the backtest window, written as "YYYYMMDD" strings.
    START_DATE = "20100101"
    END_DATE = "20180101"
    # Code of the index used as the benchmark.
    BENCHMARK = "000300.SH"
    
    # Sessions returned by the default calendar for the window above.
    # NOTE(review): whether START_DATE / END_DATE are inclusive depends on
    # DEFAULT_CALENDAR.sessions_in_range — confirm.
    trade_dates = DEFAULT_CALENDAR.sessions_in_range(START_DATE, END_DATE)
    # Presumably a once-per-week subset of trade_dates — confirm against
    # DEFAULT_CALENDAR.Weekly.
    REBALANCE_DATES = DEFAULT_CALENDAR.Weekly(trade_dates)

In [10]:
class DataPortal:
    """Data access layer handed to the simulation engine.

    Loads the serialised market/factor data once and exposes small query
    helpers on top of it.

    Attributes:
        _all_basic_data: Frame of per-bar fields ("close", "pct_chg", ...).
            It is indexed with ``.loc[(dates, symbols), field]``, i.e. a
            two-level (date, symbol) index.
        prices: Close prices unstacked to one column per symbol, forward-filled.
        universe: Sorted benchmark constituent codes with their last three
            characters removed (presumably the exchange suffix — confirm).
        factors: Factor frame restricted to ``Context.trade_dates``.
        industry_map: Industry mapping loaded from disk.
        benchmark: Close series of the benchmark index.
    """

    # Daily percentage move (in percent) treated as hitting the price limit.
    # NOTE(review): one flat threshold is applied to every symbol — confirm
    # that this is appropriate for the whole universe.
    _LIMIT_PCT = 10

    def __init__(self):
        """Load every data set from ``Context.DATA_PATH`` and the index API."""
        start_date = Context.START_DATE
        end_date = Context.END_DATE

        # NOTE: joblib.load unpickles arbitrary objects; only point this at
        # trusted files.
        self._all_basic_data: pd.DataFrame = joblib.load(f"{Context.DATA_PATH}/tushare.ex_basic")
        # `.ffill()` replaces the deprecated `fillna(method="ffill")` spelling.
        self.prices: pd.DataFrame = self._all_basic_data["close"].unstack().ffill()

        # Use the configured window end / benchmark rather than repeating the
        # literals, so changing Context changes every consumer consistently.
        components = Index.components(ts_code=Context.BENCHMARK, end_date=end_date).index
        self.universe: list[str] = sorted(code[:-3] for code in components)

        self.factors: pd.DataFrame = joblib.load(f"{Context.DATA_PATH}/uqer.factor").loc[Context.trade_dates]
        self.industry_map: pd.Series = joblib.load(f"{Context.DATA_PATH}/uqer.industry_map")
        self.benchmark: pd.Series = Index.history([Context.BENCHMARK], start_date, end_date)[0]["close"]

    def history(self, date, symbols: list, field: str, lookback: int = None):
        """Return ``field`` for ``symbols`` over a trailing window ending at ``date``.

        Args:
            date: Anchor date passed to ``DEFAULT_CALENDAR.history``.
            symbols: Symbols to select on the second index level.
            field: Column of the basic data frame to return.
            lookback: Window length passed to ``DEFAULT_CALENDAR.history``.

        Returns:
            The selected Series, or ``None`` when ``lookback`` is ``None``
            (no single-date query is implemented yet).
        """
        if lookback is None:
            # TODO: implement a single-date lookup; kept as None for
            # backward compatibility with existing callers.
            return None
        slice_dates = DEFAULT_CALENDAR.history(date, lookback)
        return self._all_basic_data.loc[(slice_dates, symbols), field]

    def query_covariance(self, date, symbols, lookback):
        """Estimate mean returns and their covariance over a trailing window.

        Mind the order of magnitude of the covariance matrix: returns are
        converted from percent to decimal fractions (``pct_chg / 100``)
        before the statistics are computed.

        Args:
            date: Anchor date of the window.
            symbols: Symbols to include.
            lookback: Window length; must be at least ``2 * len(symbols)``.

        Returns:
            Tuple ``(mean, cov)``: per-symbol mean return and the covariance
            matrix across symbols.

        Raises:
            ValueError: If ``lookback`` is shorter than twice the symbol count.
        """
        if lookback < 2 * len(symbols):
            raise ValueError("lookback must be twice as long as length of symbols")
        returns = self.history(date, symbols, "pct_chg", lookback) / 100
        # Missing observations are treated as a zero return.
        returns = returns.unstack().fillna(0)
        return returns.mean(axis=0), returns.cov()

    @property
    def factor_names(self):
        """List of the column names of the factor frame."""
        return self.factors.columns.to_list()

    def factor_get(self, factor_name, dates=None):
        """Return one factor column, optionally restricted to ``dates``.

        Args:
            factor_name: Must be one of ``factor_names``.
            dates: Optional first-level index selection; ``None`` returns
                the whole column.

        Raises:
            AssertionError: If ``factor_name`` is not a known factor.
        """
        assert factor_name in self.factor_names, f"unknown factor: {factor_name!r}"
        if dates is None:
            return self.factors[factor_name]
        return self.factors.loc[(dates, slice(None)), factor_name]

    def get_trading_constraints(self, type: typing.Literal["limit_up", "limit_down"]):
        """Return a boolean date x symbol frame flagging price-limit bars.

        Args:
            type: ``"limit_up"`` flags ``pct_chg >= 10``; ``"limit_down"``
                flags ``pct_chg <= -10``.

        Returns:
            Boolean DataFrame; (date, symbol) pairs with no data are ``False``.

        Raises:
            ValueError: If ``type`` is neither ``"limit_up"`` nor ``"limit_down"``.
        """
        pct_chg = self._all_basic_data["pct_chg"]
        if type == "limit_up":
            flags = pct_chg >= self._LIMIT_PCT
        elif type == "limit_down":
            flags = pct_chg <= -self._LIMIT_PCT
        else:
            raise ValueError(f"type must be 'limit_up' or 'limit_down', got {type!r}")
        # fill_value keeps the result boolean instead of round-tripping
        # through an object-dtype frame and fillna.
        return flags.unstack(fill_value=False)
# Build the shared data portal once; constructing it loads every data file.
data_portal = DataPortal()

In [11]:
class BuyAndHold(BaseStrategy):
    """Strategy that places a single order on the first session and then holds.

    On the first entry of ``trade_dates`` it targets a weight of 1 in
    symbol "000001"; on every other session it does nothing.
    """

    def __init__(self, engine, broker, account):
        """Forward the engine, broker and account to ``BaseStrategy``."""
        super().__init__(engine, broker, account)

    def before_trading_end(self):
        """Submit the one-off target order when the current date is the first session."""
        on_first_session = self.current_date == self.trade_dates[0]
        if not on_first_session:
            return
        target_weights = pd.Series(1, index=["000001"])
        self.account.order_target_pct_to(target_weights)

In [12]:
# Wire up the backtest: the engine is built from the session list and the data
# portal; the broker and the account are both attached to that engine.
engine = SimulationEngine(Context.trade_dates, data_portal)
broker = SimulatedBroker(engine)
# capital_base is the account's starting capital.
account = Account(engine, capital_base=1e6)
# The strategy receives all three collaborators.
algo = BuyAndHold(engine, broker, account)

In [13]:
# Run the strategy over the configured sessions (prints a progress bar).
algo.run_backtest()

1945it [00:00, 21835.04it/s]            


In [14]:
# Cash left in the account after the backtest.
algo.account.cash

1.6509999999543652

In [15]:
# Final portfolio value; compare against the capital_base of 1e6.
algo.account.portfolio_value

1660825.3190000001