In [2]:
import sys
sys.path.append('../')

import warnings
warnings.filterwarnings("ignore")

import os
import pandas as pd
import numpy as np
import akshare as ak
import sqlite3
import matplotlib.pyplot as plt
%matplotlib inline

from datetime import datetime
from dateutil.relativedelta import relativedelta
from tqdm import tqdm
from database.downloader.downloader_base import DownloaderBase
import database.database_config as db_config

# Show every row and column when a DataFrame is displayed.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

In [3]:
def plot_series_dist(series):
    """Draw a 60-bin histogram of ``series`` on the current matplotlib axes and show it."""
    axes = plt.gca()
    axes.hist(series, bins=60, edgecolor='k', alpha=0.7)
    axes.set_xlabel('Value')
    axes.set_ylabel('Frequency')
    axes.set_title('Histogram of Data')
    plt.show()

class PreProcessing:
    """Builds per-stock feature + label frames from data served by a downloader.

    ``db_downloader`` must expose the ``_download_stock_*`` methods called in
    ``_process_one_stock`` (in this notebook, a ``DownloaderBase`` over SQLite).
    """

    def __init__(self, db_downloader: "DownloaderBase") -> None:
        self.db_downloader = db_downloader

    def _build_reg_label(self, stock_dataframe, n_days=10, lower_quantile=0.01, upper_quantile=0.99):
        """Build the regression label for one stock's bars.

        ``label[t] = close[t + n_days] / open[t + 1] - 1`` is the return from buying
        at the next bar's open to the close ``n_days`` bars ahead. An earlier comment
        described selling at the open of day N+1, but the formula uses the close of
        bar t+N.

        Args:
            stock_dataframe: bars with at least ``datetime``, ``open``, ``high``,
                ``low`` and ``close`` columns. Assumed sorted by ``datetime``
                ascending with one row per trading day — TODO confirm with the
                downloader.
            n_days: holding horizon in bars (default 10, the previously
                hard-coded value).
            lower_quantile, upper_quantile: winsorisation quantiles for the label.

        Returns:
            DataFrame with columns ``['datetime', 'label']``. The last ``n_days``
            rows carry NaN labels (no future prices); ``_process_one_stock`` drops
            them.
        """
        df = stock_dataframe.copy()
        df['label'] = df['close'].shift(-n_days) / df['open'].shift(-1) - 1
        # Winsorise extreme labels. nanquantile ignores the NaN tail; clip keeps NaNs.
        # NOTE(review): the quantiles come from the whole frame passed in, which in
        # this notebook spans both the training and the validation periods.
        df['label'] = np.clip(
            df['label'],
            np.nanquantile(df['label'], lower_quantile),
            np.nanquantile(df['label'], upper_quantile),
        )
        # Drop rows whose next bar is a one-price bar (high == low), presumably a
        # locked limit-up day where the next-open entry could not be filled. The last
        # row compares NaN != NaN (True), so it is kept here.
        df = df[df['high'].shift(-1) != df['low'].shift(-1)]
        return df[['datetime', 'label']]

    # Classification-label variant (currently unused). Each row is labelled 1 if the
    # close reaches buy_price + 2*ATR within N days, 2 if it reaches buy_price - ATR
    # first, and 0 otherwise.
    # def _build_cls_label(self, dataframe, N=10, ATR_period=14):
    #     def calculate_atr(df, period=14):
    #         # True Range computation
    #         df['high_low'] = df['high'] - df['low']
    #         df['high_cp'] = abs(df['high'] - df['close'].shift())
    #         df['low_cp'] = abs(df['low'] - df['close'].shift())
    #         tr = df[['high_low', 'high_cp', 'low_cp']].max(axis=1)
    #         return tr.rolling(window=period, min_periods=1).mean()
    #     df = dataframe.copy()
    #     df['ATR'] = calculate_atr(df, ATR_period)
    #     # Initialise the label column
    #     df['label'] = 0
    #     # Iterate over every row
    #     for i in range(len(df)):
    #         if i + N >= len(df):
    #             continue  # not enough future data: skip
    #         buy_price = df.at[i + 1, 'open']  # next day's open
    #         # No take-profit / stop-loss triggered yet
    #         triggered = False
    #         # Check whether either level is hit within the next N days
    #         for j in range(1, N + 1):
    #             current_close = df.at[i + j, 'close']
    #             atr_value = df.at[i + j, 'ATR']  # ATR on day j
    #             take_profit = buy_price + 2 * atr_value
    #             stop_loss = buy_price - atr_value
    #             if current_close >= take_profit:
    #                 df.at[i, 'label'] = 1 # take-profit = 1
    #                 triggered = True
    #                 break
    #             elif current_close <= stop_loss:
    #                 df.at[i, 'label'] = 2 # stop-loss = 2
    #                 triggered = True
    #                 break
    #         # Neither take-profit nor stop-loss triggered within N days
    #         if not triggered:
    #             df.at[i, 'label'] = 0
    #     # (Original note said "drop the intermediate ATR column", but ATR is still returned.)
    #     df = df[df['high'].shift(-1) != df['low'].shift(-1)]
    #     return df[['datetime', 'label', 'ATR']]

    def _process_one_stock(self, stock_code, start_date, end_date):
        """Assemble the feature + label frame for a single stock.

        The method inner-merges several sources:
        - base info and profile info;
        - daily history and indicators;
        - the regression label;
        - calendar (date) features;
        - qlib-style price/volume factors.

        It then drops every row that still contains a missing value.
        """
        downloader = self.db_downloader
        stock_base = downloader._download_stock_base_info(stock_code)  # basic code/name info
        stock_individual = downloader._download_stock_individual_info(stock_code)  # profile info
        stock_history = downloader._download_stock_history_info(stock_code, start_date, end_date)  # daily bars
        stock_indicator = downloader._download_stock_indicator_info(stock_code, start_date, end_date)  # indicator data
        stock_factor_date = downloader._download_stock_factor_date_info()  # date features
        stock_factor_qlib = downloader._download_stock_factor_qlib_info(stock_code, start_date, end_date)  # price/volume factors
        stock_label = self._build_reg_label(stock_history)  # regression label
        stock_df = (
            stock_base
            .merge(stock_individual, on=['stock_code', 'stock_name'])
            .merge(stock_history, on=['stock_code'])
            .merge(stock_indicator, on=['stock_code', 'datetime'])
            .merge(stock_label, on=['datetime'])
            .merge(stock_factor_date, on=['datetime'])
            .merge(stock_factor_qlib, on=['stock_code', 'datetime'])
        )
        return stock_df.dropna()

    def _process_all_stock(self, code_type, start_date, end_date):
        """Build the combined frame for every constituent of index ``code_type``.

        ``code_type`` is the index code passed to ``ak.index_stock_cons``
        (e.g. '000016', '000300', '000905').

        Returns:
            The concatenation of the non-empty per-stock frames, with a fresh
            RangeIndex (per-stock indices would otherwise collide).

        Raises:
            ValueError: if no constituent produced any rows.
        """
        # Alternative universe: every A-share stock
        # stock_code_list = list(ak.stock_info_a_code_name()['code'].unique())
        stock_code_list = list(ak.index_stock_cons(code_type)['品种代码'].unique())  # constituent codes of the index
        stock_df_list = []
        for stock_code in tqdm(stock_code_list, desc=f'Process: {code_type} ...'):
            stock_df = self._process_one_stock(stock_code, start_date, end_date)
            if not stock_df.empty:
                stock_df_list.append(stock_df)
        if not stock_df_list:
            raise ValueError(
                f"No data for any constituent of index {code_type} between {start_date} and {end_date}"
            )
        return pd.concat(stock_df_list, ignore_index=True)

In [4]:
# Open the local SQLite store and wire the downloader into the preprocessor.
db_conn = sqlite3.connect(database='../database/hh_quant.db')
db_downloader = DownloaderBase(db_conn, db_config)
proprocessor = PreProcessing(db_downloader=db_downloader)  # name (sic) is used by later cells

## 使用TensorFlow

In [5]:
# 使用tensorflow处理原始数据
import numpy as np
import pandas as pd
import tensorflow as tf
from model import QuantModel
print(tf.__version__)

2024-03-25 20:09:10.473040: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


2.15.0


In [6]:
def extract_train_val_data(df, train_start_date, train_end_date, val_start_date, val_end_date):
    """Split ``df`` into training and validation frames by its ``datetime`` column.

    Both date ranges are inclusive at both ends. Dates may be anything
    ``pd.to_datetime`` accepts (e.g. 'YYYYMMDD' strings).

    Returns:
        ``(train_data, val_data)``. Both are independent copies, so callers can
        assign into them (e.g. normalised feature columns) without
        chained-assignment warnings or touching ``df``. Their shapes are printed.
    """
    # Parse the date column once and reuse it for both masks.
    row_dates = pd.to_datetime(df['datetime'])
    train_mask = row_dates.between(pd.to_datetime(train_start_date), pd.to_datetime(train_end_date))
    val_mask = row_dates.between(pd.to_datetime(val_start_date), pd.to_datetime(val_end_date))
    train_data = df[train_mask].copy()
    val_data = df[val_mask].copy()

    print(f"train_data_size: {train_data.shape}")
    print(f"validation_data_size: {val_data.shape}")
    return train_data, val_data

def transfer_data_type(df, columns, dtype):
    """Cast each column named in ``columns`` to ``dtype``, mutating ``df`` in place.

    Columns are converted one at a time in the given order. The same frame
    object is returned so the call can be chained.
    """
    for column_name in columns:
        converted = df[column_name].astype(dtype)
        df[column_name] = converted
    return df

def get_numeric_boundaries(series, num_bins=30):
    """Return sorted bucket boundaries for a numeric feature.

    If the series has fewer than ``num_bins`` distinct non-null values, every
    distinct value becomes a boundary. Otherwise the equal-frequency
    (``pd.qcut``) bin edges are used, with duplicate edges dropped.

    NaNs are excluded in both branches; the original code counted distinct
    values without NaN but listed them with NaN. Boundaries are plain Python
    scalars in both branches.
    """
    values = series.dropna()
    if values.nunique() < num_bins:
        return sorted(values.unique().tolist())
    _, edges = pd.qcut(values, num_bins, retbins=True, duplicates='drop')
    return sorted(edges.tolist())

def df_to_dataset(dataframe, feature_cols, label_cols, shuffle=True, batch_size=32, seed=None, prefetch_buffer=None):
    """Convert a DataFrame into a batched ``tf.data.Dataset`` of ``(features, labels)``.

    Args:
        dataframe: source frame.
        feature_cols: columns fed to the model. Each element's features are a
            dict keyed by column name.
        label_cols: label column name (a 1-D label) or list of names.
        shuffle: shuffle rows with a buffer covering the whole frame.
        batch_size: rows per batch.
        seed: optional op-level shuffle seed for reproducible ordering.
        prefetch_buffer: number of *batches* to prefetch. ``None`` lets
            tf.data tune it (``tf.data.AUTOTUNE``).
    """
    features = dataframe[feature_cols]
    labels = dataframe[label_cols]
    ds = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    if shuffle:
        # A buffer of the full length gives a uniform shuffle. max(1, ...) keeps the
        # buffer size valid for an empty frame.
        ds = ds.shuffle(buffer_size=max(1, len(features)), seed=seed)
    ds = ds.batch(batch_size)
    # prefetch() runs after batch(), so its buffer counts batches, not rows. The old
    # value (batch_size) held batch_size whole batches in memory.
    ds = ds.prefetch(tf.data.AUTOTUNE if prefetch_buffer is None else prefetch_buffer)
    return ds

def get_class_weights(label_series, udf_class_weight=None):
    """Compute inverse-frequency class weights for integer labels 0..max.

    ``weight[c] = (1 / count[c]) * (total / 2.0) * udf_class_weight.get(c, 1)``.
    The 2.0 divisor is the two-class form of the "balanced" heuristic; it is kept
    as-is, so for K > 2 classes every weight is scaled by K / 2 relative to
    ``total / (K * count)``.

    Args:
        label_series: non-negative integer labels (anything ``np.bincount`` accepts).
        udf_class_weight: optional ``{label: multiplier}`` overrides (default: none).

    Returns:
        dict with one entry for every label id in 0..max(label). A label id with
        no samples gets weight 0.0 instead of raising ZeroDivisionError.
    """
    overrides = {} if udf_class_weight is None else udf_class_weight
    counts = np.bincount(label_series).tolist()
    total = np.sum(counts)
    class_weight = {}
    for label, cnt in enumerate(counts):
        if cnt == 0:
            # No sample has this id, so its weight never applies. The key is still kept
            # so the mapping covers every id 0..max.
            class_weight[label] = 0.0
            continue
        class_weight[label] = (1 / cnt) * (total / 2.0) * overrides.get(label, 1)
    return class_weight
    

In [7]:
# Backtest span ('YYYYMMDD'); windows start while strictly before the end date.
backtest_start_date, backtest_end_date = '20190101', '20240101'
train_period = 6   # years of history in each training window
update_period = 6  # months between model retrains (= validation window length)

def get_rolling_date_period(backtest_start_date, backtest_end_date, training_period, update_period):
    """Build walk-forward (train, validation) date windows.

    Starting at ``backtest_start_date``, the window start advances by
    ``update_period`` months while it is strictly before ``backtest_end_date``.
    Each step emits ``[train_start, train_end, val_start, val_end]`` as
    'YYYYMMDD' strings:

    - ``train_start = val_start - training_period years``
    - ``train_end   = val_start - 1 day``
    - ``val_end     = val_start + update_period months - 1 day``

    The last validation window is not truncated at ``backtest_end_date``.

    Raises:
        ValueError: if ``update_period`` is not positive. The window would never
            advance, and the loop would previously run forever.
    """
    if update_period <= 0:
        raise ValueError(f"update_period must be a positive number of months, got {update_period}")
    window_start = datetime.strptime(backtest_start_date, '%Y%m%d')
    backtest_end = datetime.strptime(backtest_end_date, '%Y%m%d')
    step = relativedelta(months=update_period)
    one_day = relativedelta(days=1)
    result = []
    while window_start < backtest_end:
        result.append([
            (window_start - relativedelta(years=training_period)).strftime("%Y%m%d"),
            (window_start - one_day).strftime("%Y%m%d"),
            window_start.strftime("%Y%m%d"),
            (window_start + step - one_day).strftime("%Y%m%d"),
        ])
        window_start += step  # cumulative month stepping, as before
    return result

rolling_period = get_rolling_date_period(
    backtest_start_date=backtest_start_date,
    backtest_end_date=backtest_end_date,
    training_period=train_period,
    update_period=update_period,
)
rolling_period

[['20130101', '20181231', '20190101', '20190630'],
 ['20130701', '20190630', '20190701', '20191231'],
 ['20140101', '20191231', '20200101', '20200630'],
 ['20140701', '20200630', '20200701', '20201231'],
 ['20150101', '20201231', '20210101', '20210630'],
 ['20150701', '20210630', '20210701', '20211231'],
 ['20160101', '20211231', '20220101', '20220630'],
 ['20160701', '20220630', '20220701', '20221231'],
 ['20170101', '20221231', '20230101', '20230630'],
 ['20170701', '20230630', '20230701', '20231231']]

In [8]:
# df = proprocessor._process_all_stock(code_type='000016', start_date='20130101', end_date='20191231')
# dd = proprocessor._process_one_stock('601398', start_date='20130101', end_date='20191231')
# dd = proprocessor.db_downloader._download_stock_history_info('601398', start_date='20130101', end_date='20191231')

In [9]:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, StandardScaler, QuantileTransformer

# 创建一个预处理管道
# Feature-normalisation pipeline. It is refit on every training window, in two steps:
# 1) map each numeric feature onto a standard normal via its empirical quantiles;
# 2) standardise the result to zero mean / unit variance.
# random_state fixes QuantileTransformer's random subsampling, which kicks in when the
# training set exceeds its `subsample` size, so refits are reproducible. 1024 matches
# the model seed.
preprocessing_pipeline = Pipeline([
    ('quantile_transformer', QuantileTransformer(output_distribution='normal', n_quantiles=1000, random_state=1024)),
    # ('minmax_scaler', MinMaxScaler()),
    ('standard_scaler', StandardScaler())
])

In [10]:
# Feature schema consumed by the normaliser and QuantModel. The numeric list is
# generated; its order is part of the contract and is unchanged.
feature_config = {
    "target_feature_name": "label",
    "numeric_features": (
        # indicator columns
        ['turnover_rate', 'pe_ttm', 'ps_ttm', 'pcf_ncf_ttm', 'pb_mrq']
        # qlib K-bar factors
        + ['KMID', 'KLEN', 'KMID2', 'KUP', 'KUP2', 'KLOW', 'KLOW2', 'KSFT', 'KSFT2']
        # <FIELD><k> for k in 0..4
        + [f'{field}{k}' for field in ('OPEN', 'HIGH', 'LOW', 'CLOSE', 'VOLUME') for k in range(5)]
        # <OP><N> for N in 5/10/20/30/60
        + [
            f'{op}{window}'
            for op in (
                'ROC', 'MAX', 'MIN', 'MA', 'STD', 'BETA', 'RSQR', 'RESI', 'QTLU', 'QTLD',
                'TSRANK', 'RSV', 'IMAX', 'IMIN', 'IMXD', 'CORR', 'CORD', 'CNTP', 'CNTN', 'CNTD',
                'SUMP', 'SUMN', 'SUMD', 'VMA', 'VSTD', 'WVMA', 'VSUMP', 'VSUMN', 'VSUMD',
            )
            for window in (5, 10, 20, 30, 60)
        ]
    ),
    "integer_categorical_features": ['month'],
    "string_categorical_features": ['industry', 'season'],
}
full_feature_names = [
    *feature_config.get('numeric_features', []),
    *feature_config.get('integer_categorical_features', []),
    *feature_config.get('string_categorical_features', []),
]
# Constituents of this index form the training universe
# (SSE 50: 000016, CSI 300: 000300, CSI 500: 000905).
benchmark = '000016'
batch_size = 256

# Walk-forward training. `sample_period` is a single hand-picked window for quick
# experiments; swap in the commented loop header to run every window of `rolling_period`.
sample_period = [['20130101', '20181231', '20190101', '20191231']]
for date_period in tqdm(sample_period, desc='Rolling Training...'):
# for date_period in tqdm(rolling_period, desc='Rolling Training...'):
    train_start_date, train_end_date, val_start_date, val_end_date = date_period
    print(f"train_start: {train_start_date}, train_end: {train_end_date}, val_start: {val_start_date}, val_end: {val_end_date}")
    # 1. Load features + labels for every index constituent over [train_start, val_end].
    df = proprocessor._process_all_stock(code_type=benchmark, start_date=train_start_date, end_date=val_end_date)
    # 2. Split into training and validation frames by date. Take explicit copies so the
    #    column assignments below write to these frames, never to a view of `df`.
    train_data, val_data = extract_train_val_data(df, train_start_date, train_end_date, val_start_date, val_end_date)
    train_data = train_data.copy()
    val_data = val_data.copy()
    # NOTE(review): the regression label looks ahead 10 bars by default
    # (PreProcessing._build_reg_label). Training rows just before val_start are therefore
    # labelled with prices from inside the validation window. Consider an embargo gap of
    # that many trading days between train_end and val_start — TODO confirm.
    # 2.1 Class weights (classification task only)
    # class_weights = get_class_weights(train_data['label'])
    # print(f"class_weights: {class_weights}")
    # 2.2 Feature engineering: fit the normaliser on training rows only, then apply the
    #     fitted transform to validation rows.
    norm_feature_columns = feature_config.get('numeric_features', [])
    train_data[norm_feature_columns] = preprocessing_pipeline.fit_transform(train_data[norm_feature_columns])
    val_data[norm_feature_columns] = preprocessing_pipeline.transform(val_data[norm_feature_columns])
    # 2.3 Target engineering (label normalisation): not implemented; labels are used as-is.

    # 3. Model hyper-parameters. Bucket boundaries and vocabularies come from training rows only.
    model_config = {
        "seed": 1024,
        # "l2_reg": 0.001,
        "reduction_ratio": 3,
        "dnn_hidden_units": [256,128,64],
        "dnn_activation": 'relu',
        "dnn_dropout": 0.2,
        "dnn_use_bn": True,
        "numeric_features_with_boundaries": {k: list(get_numeric_boundaries(train_data[k])) for k in feature_config.get('numeric_features', [])},
        "integer_categorical_features_with_vocab": {k: list(train_data[k].unique()) for k in feature_config.get('integer_categorical_features', [])},
        "string_categorical_features_with_vocab": {k: list(train_data[k].unique()) for k in feature_config.get('string_categorical_features', [])},
        "feature_embedding_dims": 4,
    }
    # 4. Seed the Python, NumPy and TensorFlow RNGs before building datasets and the model,
    #    so that shuffling, weight initialisation and dropout are reproducible.
    tf.keras.utils.set_random_seed(model_config["seed"])
    # 5. Build the tf.data pipelines. Validation is not shuffled, so it stays row-aligned
    #    with val_data.
    label_name = feature_config.get('target_feature_name', [])
    train_ds = df_to_dataset(train_data, full_feature_names, label_name, shuffle=True, batch_size=batch_size)
    val_ds = df_to_dataset(val_data, full_feature_names, label_name, shuffle=False, batch_size=batch_size)
    # 6. Initialise the model.
    model = QuantModel(model_config)

    # 7. Optimiser: inverse-time decay, lr = lr0 / (1 + step / decay_steps). The learning
    #    rate halves after about 5 epochs' worth of steps. max(1, ...) avoids
    #    decay_steps == 0 when the training set is smaller than one batch.
    steps_per_epoch = max(1, len(train_data) // batch_size)
    initial_learning_rate = 5e-4
    lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate,
        decay_steps=steps_per_epoch * 5,
        decay_rate=1,
        staircase=False)
    model.compile(
        optimizer=tf.keras.optimizers.legacy.Adam(lr_schedule),
        # loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # classification task
        loss=tf.keras.losses.MeanAbsoluteError(),
        # loss=tf.keras.losses.MeanSquaredError(),
        metrics=[
            # tf.keras.metrics.SparseCategoricalAccuracy(),  # classification task
            tf.keras.metrics.MeanAbsoluteError(),
            # tf.keras.metrics.MeanSquaredError(),
        ],
    )
    # 8. Train with early stopping on validation loss. restore_best_weights makes `model`
    #    end up with the best-epoch weights instead of the last epoch's. TensorBoard logs
    #    go to a timestamped directory.
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=50,
        verbose=2,
        # class_weight=class_weights,
        callbacks=[
            tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, verbose=1, mode='min', restore_best_weights=True),
            tf.keras.callbacks.TensorBoard(log_dir="./logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S"), histogram_freq=1)
        ],
    )
    # 9. Model saving (disabled)
    # model_save_path = f'./models/saved_model/model_of_{val_start_date}'
    # model.save(model_save_path)
    # best_model = tf.keras.models.load_model('./best_model')
    # 10. Prediction export for the last trained window is done in a separate cell below.

Rolling Training...:   0%|          | 0/1 [00:00<?, ?it/s]

train_start: 20130101, train_end: 20181231, val_start: 20190101, val_end: 20191231


Process: 000016 ...: 100%|██████████| 50/50 [00:07<00:00,  6.30it/s]


train_data_size: (47508, 207)
validation_data_size: (8747, 207)
Epoch 1/50


Rolling Training...:   0%|          | 0/1 [02:36<?, ?it/s]


InvalidArgumentError: Graph execution error:

Detected at node gradient_tape/quant_model/add/BroadcastGradientArgs defined at (most recent call last):
  File "/Users/alsc/.pyenv/versions/3.10.13/lib/python3.10/runpy.py", line 196, in _run_module_as_main

  File "/Users/alsc/.pyenv/versions/3.10.13/lib/python3.10/runpy.py", line 86, in _run_code

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start

  File "/Users/alsc/.pyenv/versions/3.10.13/lib/python3.10/asyncio/base_events.py", line 603, in run_forever

  File "/Users/alsc/.pyenv/versions/3.10.13/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once

  File "/Users/alsc/.pyenv/versions/3.10.13/lib/python3.10/asyncio/events.py", line 80, in _run

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 542, in dispatch_queue

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 531, in process_one

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 359, in execute_request

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 775, in execute_request

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 446, in do_execute

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 549, in run_cell

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3051, in run_cell

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3106, in _run_cell

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3311, in run_cell_async

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3493, in run_ast_nodes

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code

  File "/var/folders/vc/j8df25m509sdsv8x_v9j7gk00000ks/T/ipykernel_72044/1115789825.py", line 73, in <module>

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/keras/src/engine/training.py", line 1807, in fit

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/keras/src/engine/training.py", line 1401, in train_function

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/keras/src/engine/training.py", line 1384, in step_function

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/keras/src/engine/training.py", line 1373, in run_step

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/keras/src/engine/training.py", line 1154, in train_step

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/keras/src/optimizers/legacy/optimizer_v2.py", line 598, in minimize

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/keras/src/optimizers/legacy/optimizer_v2.py", line 656, in _compute_gradients

  File "/Users/alsc/VscodeProject/hh_quant/.venv/lib/python3.10/site-packages/keras/src/optimizers/legacy/optimizer_v2.py", line 532, in _get_gradients

Incompatible shapes: [256,187,1] vs. [256,1]
	 [[{{node gradient_tape/quant_model/add/BroadcastGradientArgs}}]] [Op:__inference_train_function_24088]

In [None]:
# model.summary()
# model_config

In [None]:
# Export validation-window predictions of the most recently trained `model` for backtesting.
# val_ds is built with shuffle=False, so predictions are row-aligned with val_data.
model_pred_result = model.predict(val_ds)
output_df = val_data[['stock_code', 'industry', 'stock_name', 'datetime']].copy()
# Regression results
output_df['label'] = val_data['label']
# predict() returns one row per sample, presumably shaped (n, 1) for this single-output
# regressor; flatten it to a 1-D column.
output_df['label_pred'] = np.asarray(model_pred_result).reshape(-1)
output_path = '../../Offline/backtest/backtest_data/test/stock_selection_results_test.pkl'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
output_df.to_pickle(output_path)



In [None]:
# output_df.head()

In [None]:
# train_data.columns[:100]

In [None]:
# plot_series_dist(train_data['MIN60'])