Part 1: Data Preprocessing

This part prepares the raw ETH/USDT 1-minute dataset for reinforcement learning. It sorts the data by timestamp, selects OHLCV and indicator features, creates additional features such as log returns, normalizes the features to [0, 1] with backward-looking 60-minute rolling min-max windows (RSI is simply divided by 100), splits the dataset chronologically into training (first 80%) and testing (last 20%) sets, and exports the results as CSV files.

In [7]:
# Install required packages
%pip install numpy pandas matplotlib seaborn plotly scikit-learn stable-baselines3 gymnasium pyarrow ta


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [None]:

import os
import numpy as np
import pandas as pd

# ===================== Config =====================
PATH = "../ETHUSDT_1m_with_indicators.parquet"   # input parquet with OHLCV + indicators
TS_COL = "ts"                                    # timestamp column name
WIN = 60                                         # rolling normalization window (minutes)
EPS = 1e-12                                      # guards against division by zero
OUTPUT_DIR = "./processed_data_minmax"           # destination for the exported CSVs

# Column groups
OHLCV = ["open", "high", "low", "close", "volume"]
# Indicators expressed in price units (normalized with the rolling price range)
INDICATORS_PRICE_LEVEL = ["BB_mid", "BB_high", "BB_low", "EMA_12", "EMA_26"]
# Indicators on their own scale (normalized with their own rolling range)
INDICATORS_OWN_SCALE = ["MACD", "MACD_signal", "MACD_diff", "ATR"]
# RSI lives on 0..100 and is handled separately
RSI_COL = "RSI"

# ===================== Helpers / 工具函数 =====================
def minmax_rolling(series: pd.Series, win: int, min_periods: int = 1) -> pd.Series:
    """滚动 min-max 归一化到 [0,1]；前期使用逐步扩大的窗口，避免 NaN
    Rolling min-max to [0,1]; uses expanding window at the beginning to avoid NaNs.
    """
    roll_min = series.rolling(win, min_periods=min_periods).min()
    roll_max = series.rolling(win, min_periods=min_periods).max()
    denom = (roll_max - roll_min).replace(0, np.nan)
    out = (series - roll_min) / denom
    # 处理 0 除与边界 / handle zero-division and bounds
    out = out.fillna(0.5)   # 当区间无波动时给 0.5 中性值 / neutral when no range
    return out.clip(0.0, 1.0)

# ===================== Load & clean =====================
df = pd.read_parquet(PATH)
df.columns = df.columns.str.strip()
# The timestamp may be stored as the index; only then move it into a column.
# An unconditional reset_index() adds a stray "index" column when the
# timestamp is already a regular column.
if TS_COL not in df.columns:
    df = df.reset_index()
    df.columns = df.columns.str.strip()

# Sanity check: warn about optional columns, but fail fast on the columns that
# the price normalization and log-return code below use unconditionally.
need = [TS_COL] + OHLCV + [RSI_COL] + INDICATORS_PRICE_LEVEL + INDICATORS_OWN_SCALE
missing = [c for c in need if c not in df.columns]
missing_required = [c for c in (TS_COL, "high", "low", "close") if c not in df.columns]
if missing_required:
    raise KeyError(f"Required columns missing from {PATH}: {missing_required}")
if missing:
    print(f"[WARN] missing columns (will be ignored): {missing}")

df = df.sort_values(TS_COL).reset_index(drop=True)

# Keep only the columns that actually exist
OHLCV = [c for c in OHLCV if c in df.columns]
INDICATORS_PRICE_LEVEL = [c for c in INDICATORS_PRICE_LEVEL if c in df.columns]
INDICATORS_OWN_SCALE = [c for c in INDICATORS_OWN_SCALE if c in df.columns]
has_rsi = (RSI_COL in df.columns)

# Replace +/-Inf with NaN; the normalizers below turn NaNs into neutral values.
value_cols = OHLCV + INDICATORS_PRICE_LEVEL + INDICATORS_OWN_SCALE + ([RSI_COL] if has_rsi else [])
df[value_cols] = df[value_cols].replace([np.inf, -np.inf], np.nan)

# ===================== Price-based normalization =====================
# Prices share one scale: the rolling low (min) and high (max) over the last WIN minutes.
roll_low  = df["low"].rolling(WIN, min_periods=1).min()
roll_high = df["high"].rolling(WIN, min_periods=1).max()
price_denom = (roll_high - roll_low).mask((roll_high - roll_low) == 0, EPS)

# Keep the raw columns and add *_norm versions
for col in ["open", "high", "low", "close"]:
    if col in df.columns:
        df[col + "_norm"] = ((df[col] - roll_low) / price_denom).clip(0.0, 1.0)

# Volume: log1p first (heavy-tailed), then rolling min-max
if "volume" in df.columns:
    vol_log = np.log1p(df["volume"].clip(lower=0))
    df["volume_norm"] = minmax_rolling(vol_log, WIN, min_periods=1)

# ===================== Indicator normalization =====================
# RSI is bounded to 0..100, so /100 maps it to [0, 1]
if has_rsi:
    df["RSI_norm"] = (df["RSI"].clip(0, 100) / 100.0).fillna(0.5)

# Price-level indicators use the same roll_low/roll_high as prices
for col in INDICATORS_PRICE_LEVEL:
    df[col + "_norm"] = ((df[col] - roll_low) / price_denom).clip(0.0, 1.0)

# Own-scale indicators use their own rolling min-max
for col in INDICATORS_OWN_SCALE:
    df[col + "_norm"] = minmax_rolling(df[col], WIN, min_periods=1)

# Extra: 1-minute log return. Non-positive closes have no logarithm; mask them
# to NaN so log() cannot produce -inf (fillna would not remove it, and one
# infinite value would distort a whole rolling window).
close_positive = df["close"].where(df["close"] > 0)
df["logret"] = np.log(close_positive).diff()
df["logret_norm"] = minmax_rolling(df["logret"].fillna(0), WIN, min_periods=1)


# ===================== Split & export =====================
# Time-based split (80/20): earlier rows train, later rows test.
split_idx = int(len(df) * 0.8)
train_df = df.iloc[:split_idx].copy()
test_df  = df.iloc[split_idx:].copy()
train_df["split"] = "train"
test_df["split"]  = "test"

# Columns to export: ts + raw OHLCV + all *_norm features
norm_cols = [c for c in df.columns if c.endswith("_norm")]
export_cols = [TS_COL] + OHLCV + norm_cols + ["split"]

# Ensure all export columns exist in both splits
missing_train = [c for c in export_cols if c not in train_df.columns]
missing_test  = [c for c in export_cols if c not in test_df.columns]
if missing_train or missing_test:
    raise KeyError(f"Export columns missing. "
                   f"train_df missing: {missing_train}, test_df missing: {missing_test}")

# Clean NaN/Inf in *_norm features and fill each column with its median.
def fill_norm_with_median(df_sub, norm_feature_cols):
    """Fill gaps in ``norm_feature_cols`` of ``df_sub`` in place and return it.

    +/-Inf values are first turned into NaN. Each column's NaNs are then
    replaced by that column's median, or by the neutral value 0.5 when the
    column has no finite values at all. Other columns are left untouched.
    """
    cleaned = df_sub[norm_feature_cols].replace([np.inf, -np.inf], np.nan)
    df_sub.loc[:, norm_feature_cols] = cleaned
    for name in norm_feature_cols:
        center = df_sub[name].median()
        fill_value = 0.5 if pd.isna(center) else center
        df_sub.loc[:, name] = df_sub[name].fillna(fill_value)
    return df_sub

# Fill gaps in the *_norm features. Training rows use per-column training
# medians (0.5 when a column is entirely NaN). Test rows are filled with those
# same TRAINING medians: medians computed on the test split would let
# statistics from the evaluation period leak into its own features.
train_df = fill_norm_with_median(train_df, norm_cols)
train_medians = train_df[norm_cols].median().fillna(0.5)
test_df.loc[:, norm_cols] = (test_df[norm_cols]
                             .replace([np.inf, -np.inf], np.nan)
                             .fillna(train_medians))

# Final sanity check: no NaN/Inf may remain
assert np.isfinite(train_df[norm_cols].to_numpy()).all(), "train_df still has NaN/Inf after fill."
assert np.isfinite(test_df[norm_cols].to_numpy()).all(),  "test_df still has NaN/Inf after fill."

# Export CSVs (selected columns only)
os.makedirs(OUTPUT_DIR, exist_ok=True)

train_path = os.path.join(OUTPUT_DIR, "train_minmax.csv")
test_path  = os.path.join(OUTPUT_DIR, "test_minmax.csv")
combo_path = os.path.join(OUTPUT_DIR, "combined_minmax.csv")

train_df[export_cols].to_csv(train_path, index=False, float_format="%.6f")
test_df[export_cols].to_csv(test_path,   index=False, float_format="%.6f")
pd.concat([train_df[export_cols], test_df[export_cols]], ignore_index=True)\
  .to_csv(combo_path, index=False, float_format="%.6f")

print("[OK] Exported CSVs with RAW OHLCV + *_norm features at", OUTPUT_DIR)
print("  -", train_path)
print("  -", test_path)
print("  -", combo_path)


FileNotFoundError: [Errno 2] No such file or directory: '/ETHUSDT_1m_with_indicators.parquet'

### Part 2: Environment Setup

This part defines the custom trading environment `MinuteTradingEnv` based on the Gymnasium API. The environment simulates one-minute trading with continuous position control.

- **Initialization**:
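  Load and clean the input data (rows with NaN/Inf in `close` or any feature column are dropped) and define parameters such as window size, fees, slippage, penalties, stop-loss/trailing-stop thresholds, and optional per-feature observation weights. Set up the action space as a continuous target position in [-1, 1] and the observation space as a flattened window of features.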

- **Observation**:
  Construct state representations by extracting the most recent feature window and flattening it into a vector.

- **Reset**:
  Initialize the environment for a new episode with index at the end of the first window, flat position, and starting NAV = 1.

- **Step Function**:
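  Execute one action, calculate the log return between consecutive closing prices, and apply transaction costs and penalties. Then check the hard stop-loss and the trailing stop; when either triggers, the position is forced flat and one extra closing cost is charged. Finally, update NAV and return the new observation, reward, termination flag, and info (current NAV, position, and whether a forced exit occurred).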


In [None]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
# 可选：如需类型提示，取消下一行注释
# from typing import Dict, Optional, Tuple

class MinuteTradingEnv(gym.Env):
    """Minute-level single-asset trading environment (Gymnasium API).

    Rewards and NAV:
      - At bar ``i`` the agent outputs a target position in [-1, 1].
      - The step reward is ``position * log(close[i+1] / close[i])``, with the
        log return optionally clipped to ``logret_clip``.
      - Fee/slippage and position-change penalties are subtracted, and the
        result is optionally clipped to ``reward_clip``.
      - NAV is multiplied by ``exp(reward)`` each step.

    Risk controls:
      - When the position opens past ``open_thr``, or reverses across zero,
        the current close is recorded as the entry price.
      - If the next close breaches the hard stop-loss (relative to the entry
        price) or the trailing stop (relative to the best close seen since
        entry), one extra closing cost is charged and the position is forced
        to 0 for the next step.

    Observations are the last ``window`` rows of ``feat_cols``. Each column is
    multiplied by a per-feature weight, and the window is flattened to a 1-D
    float32 vector.
    """

    metadata = {"render_modes": []}

    # ===== 1) Initialization =====
    def __init__(self, df, feat_cols, window=60,
                 fee_rate=0.0005, slippage=0.0002, pos_change_penalty=0.001,
                 logret_clip=0.10, reward_clip=0.25, eps=1e-12,
                 # ---- Stop-loss & trailing-stop params ----
                 open_thr=0.10,          # |pos| > open_thr counts as an open position
                 flat_thr=0.05,          # |pos| <= flat_thr counts as flat
                 stop_loss_pct=0.02,     # hard stop: adverse move from the entry price (2%)
                 trailing_stop_pct=0.03, # trailing stop: retracement from the best price since entry (3%)
                 # ---- Feature-weight controller ----
                 feature_weights=None,   # e.g. {"RSI_norm": 1.5, "MACD_norm": 0.8}
                 normalize_weights=False,# divide weights by their mean to avoid overall scale drift
                 weight_clip=(0.1, 5.0)  # each weight is clipped to this range
                 ):
        """Create the environment.

        Args:
            df: DataFrame containing a ``close`` column and every column in
                ``feat_cols``. Rows with NaN/Inf in any of them are dropped.
            feat_cols: feature column names used for observations.
            window: number of rows per observation.
            fee_rate, slippage: cost per unit of position change.
            pos_change_penalty: extra penalty per unit of position change.
            logret_clip: symmetric clip for the per-step log return (None disables).
            reward_clip: symmetric clip for the per-step reward (None disables).
            eps: floor used to avoid division by zero and log(0).
            open_thr, flat_thr, stop_loss_pct, trailing_stop_pct: risk-control
                thresholds described in the class docstring.
            feature_weights: optional ``{column: weight}`` mapping; names that
                are not in ``feat_cols`` are ignored.
            normalize_weights: divide the weight vector by its mean.
            weight_clip: ``(low, high)`` clip applied to each given weight.

        Raises:
            AssertionError: if fewer than ``window + 2`` clean rows remain.
        """
        super().__init__()

        # Data cleaning: drop rows with NaN/Inf in close or any feature column
        self.df = (df.replace([np.inf, -np.inf], np.nan)
                     .dropna(subset=["close"] + list(feat_cols))
                     .reset_index(drop=True))
        assert len(self.df) > window + 1, "Data length must be larger than window size."

        # Basic settings
        self.feat_cols = list(feat_cols)
        self.window = window
        self.fee_rate = fee_rate
        self.slippage = slippage
        self.pos_change_penalty = pos_change_penalty
        self.logret_clip = logret_clip
        self.reward_clip = reward_clip
        self.eps = eps

        # Risk-control params
        self.open_thr = float(open_thr)
        self.flat_thr = float(flat_thr)
        self.stop_loss_pct = float(stop_loss_pct)
        self.trailing_stop_pct = float(trailing_stop_pct)

        # Cached arrays
        self.prices = self.df["close"].to_numpy(dtype=np.float64)
        self.features = self.df[self.feat_cols].to_numpy(dtype=np.float32)

        # Per-feature weights aligned with feat_cols.
        # _raw_scale holds the clipped, un-normalized user weights; _feat_scale
        # (the scale actually applied to observations) is always derived from it.
        self._feat_index = {c: i for i, c in enumerate(self.feat_cols)}
        self._raw_scale = np.ones(len(self.feat_cols), dtype=np.float32)
        self._feat_scale = self._raw_scale.copy()
        self._weight_clip = (float(weight_clip[0]), float(weight_clip[1]))
        if feature_weights:
            self._apply_feature_weights(feature_weights, normalize_weights)

        # Action & observation spaces
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf,
            shape=(self.window * len(self.feat_cols),),
            dtype=np.float32
        )

        # Time indexing. Steps run while i < end and settle at prices[i + 1];
        # end = len - 1 keeps that i + 1 bar in range while still trading the
        # final bar transition (len - 2 -> len - 1).
        self.start = self.window
        self.end = len(self.df) - 1
        self.i = None
        self.position = None
        self.nav = None

        # Entry tracking: entry price, direction (+1/-1, 0 = none), and the
        # best close since entry (peak for longs, trough for shorts).
        self._clear_trackers()

    # === Internal: apply feature weights ===
    def _apply_feature_weights(self, weights, normalize: bool):
        """Set raw weights for the named features and rebuild the applied scale.

        Weights are clipped to ``weight_clip`` and stored un-normalized, so
        repeated calls with the same arguments always produce the same scale.
        Normalizing an already-normalized scale instead would compound across
        calls. With ``normalize=True`` the applied scale is the raw vector
        divided by its mean.
        """
        lo, hi = self._weight_clip
        raw = self._raw_scale.copy()
        for name, w in weights.items():
            idx = self._feat_index.get(name)
            if idx is not None:
                raw[idx] = np.clip(float(w), lo, hi)
        self._raw_scale = raw

        scale = raw
        if normalize:
            m = float(np.mean(raw))
            if m > 0:
                scale = raw / m
        self._feat_scale = scale.astype(np.float32)

    # === Public: update weights at runtime ===
    def set_feature_weights(self, weights: dict, normalize: bool = False):
        """Update (a subset of) feature weights during training/evaluation."""
        self._apply_feature_weights(weights, normalize)

    # ===== 2) Build observation =====
    def _get_obs(self):
        """Return the weighted, flattened feature window ending at row ``self.i``."""
        x = self.features[self.i - self.window + 1 : self.i + 1]  # [W, F]
        x = x * self._feat_scale                                   # broadcast [W, F] * [F]
        obs = x.reshape(-1).astype(np.float32)
        assert np.isfinite(obs).all(), "Observation contains NaN/Inf"
        return obs

    # ===== 3) Reset environment =====
    def reset(self, seed=None, options=None):
        """Start an episode at the end of the first window: flat, NAV = 1."""
        super().reset(seed=seed)
        self.i = self.start
        self.position = 0.0
        self.nav = 1.0
        self._clear_trackers()
        return self._get_obs(), {}

    # ===== Entry / stop helpers =====
    def _clear_trackers(self):
        """Forget the current entry (used on reset, flatten and forced exit)."""
        self.entry_price = None
        self.entry_sign = 0
        self.peak_price = None
        self.trough_price = None

    def _open_entry(self, sign, price):
        """Record a new entry in direction ``sign`` (+1 long / -1 short) at ``price``."""
        self.entry_price = price
        self.entry_sign = sign
        self.peak_price = price if sign > 0 else None
        self.trough_price = price if sign < 0 else None

    def _update_entry_tracking(self, prev_pos, p_now):
        """Update entry/exit state for the move ``prev_pos -> self.position`` at ``p_now``."""
        pos = self.position
        # Exit: the position went flat, or moved to the opposite side of the
        # tracked entry. The latter covers a direct long <-> short reversal,
        # which must not keep the old entry price and direction.
        if self.entry_sign != 0 and (abs(pos) <= self.flat_thr
                                     or np.sign(pos) == -self.entry_sign):
            self._clear_trackers()

        # Entry: |pos| crosses open_thr from inside the band, or a side is
        # opened while nothing is tracked (e.g. right after a reversal).
        may_open = abs(prev_pos) <= self.open_thr or self.entry_sign == 0
        if may_open and pos > self.open_thr:
            self._open_entry(+1, p_now)
        elif may_open and pos < -self.open_thr:
            self._open_entry(-1, p_now)

        # While in a position, track the best close since entry (for the trailing stop)
        if self.entry_sign == +1:
            self.peak_price = p_now if self.peak_price is None else max(self.peak_price, p_now)
        elif self.entry_sign == -1:
            self.trough_price = p_now if self.trough_price is None else min(self.trough_price, p_now)

    def _stop_triggered(self, p_next):
        """Return True if ``p_next`` breaches the hard stop or the trailing stop."""
        if self.entry_sign == 0 or self.entry_price is None:
            return False
        if self.entry_sign == +1:
            hard_stop = p_next <= self.entry_price * (1.0 - self.stop_loss_pct)
            trail_stop = (self.peak_price is not None
                          and p_next <= self.peak_price * (1.0 - self.trailing_stop_pct))
        else:
            hard_stop = p_next >= self.entry_price * (1.0 + self.stop_loss_pct)
            trail_stop = (self.trough_price is not None
                          and p_next >= self.trough_price * (1.0 + self.trailing_stop_pct))
        return bool(hard_stop or trail_stop)

    # ===== 4) Take one step (with stops) =====
    def step(self, action):
        """Apply target position ``action[0]`` for bar i -> i+1.

        Returns:
            ``(obs, reward, terminated, truncated, info)``, where ``info``
            contains ``nav``, ``pos`` and ``forced_flat``.
        """
        # --- Action processing ---
        a = float(np.clip(action[0], -1.0, 1.0))
        prev_pos = self.position
        self.position = a

        # --- Prices (fallbacks for non-finite / non-positive values) ---
        p_now  = float(self.prices[self.i])
        p_next = float(self.prices[self.i + 1])
        if not np.isfinite(p_now) or p_now <= 0:
            p_now = max(1.0, p_next)
        if not np.isfinite(p_next) or p_next <= 0:
            p_next = p_now

        # --- Log return for i -> i+1 ---
        ratio = max(p_next / max(p_now, self.eps), self.eps)
        log_ret = float(np.log(ratio))
        if self.logret_clip is not None:
            log_ret = float(np.clip(log_ret, -self.logret_clip, self.logret_clip))

        # ===== A) Entry & tracking =====
        self._update_entry_tracking(prev_pos, p_now)

        # ===== B) Costs and penalties =====
        pos_change = abs(self.position - prev_pos)
        fee_slip = self.fee_rate + self.slippage
        trade_cost = pos_change * fee_slip                  # fees + slippage
        adj_penalty = self.pos_change_penalty * pos_change  # rebalancing penalty

        # ===== C) Stop-loss & trailing-stop checks =====
        forced_flat = self._stop_triggered(p_next)
        extra_close_cost = 0.0
        if forced_flat:
            # Forced exit: charge the cost of closing the whole position once
            close_size = abs(self.position)
            extra_close_cost = close_size * fee_slip + self.pos_change_penalty * close_size

        # ===== D) Step return & reward =====
        step_ret = self.position * log_ret - trade_cost - adj_penalty
        if forced_flat:
            step_ret -= extra_close_cost
        if self.reward_clip is not None:
            step_ret = float(np.clip(step_ret, -self.reward_clip, self.reward_clip))

        # NAV update
        growth = float(np.exp(step_ret))
        if not np.isfinite(growth):
            growth = 1.0
        self.nav *= growth
        reward = step_ret

        # ===== E) Advance time & enforce forced flat =====
        self.i += 1
        terminated = (self.i >= self.end)
        truncated = False

        if forced_flat:
            self.position = 0.0
            self._clear_trackers()

        info = {
            "nav": float(self.nav),
            "pos": float(self.position),
            "forced_flat": bool(forced_flat),
        }
        return self._get_obs(), reward, terminated, truncated, info


### Part 3: DRL Training

This part configures and trains the reinforcement learning agent using the A2C algorithm.

- **Feature Selection**:
  Choose the normalized `*_norm` columns from the preprocessed dataset as the input features for the environment.

- **Environment Setup**:
  Construct training and testing environments (`DummyVecEnv`) based on `MinuteTradingEnv`, with a window size of 60 to match the normalization window.

- **Model Configuration**:
  Initialize the A2C agent with an MLP policy. Define key hyperparameters such as learning rate, rollout length (`n_steps`), discount factor (`gamma`), GAE lambda, and gradient clipping.

- **Training and Saving**:
  Train the model for 1,000,000 timesteps (a callback updates the feature weights after every rollout) and save the trained policy to `a2c_ethusdt_1m_no_sentiment.zip` for later evaluation.


In [None]:
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback
import numpy as np

# ===== 1) Feature columns for the environment =====
# Every *_norm column produced by preprocessing becomes an observation feature.
feat_cols_env = [col for col in train_df.columns if col.endswith("_norm")]
# To also feed raw OHLCV (usually unnecessary), append them:
# feat_cols_env = feat_cols_env + ["open", "high", "low", "close", "volume"]

# Sanity checks: a non-empty feature list and no NaN/Inf in either split.
assert feat_cols_env, "No *_norm feature columns found for the environment."
for split_name, frame in (("train_df", train_df), ("test_df", test_df)):
    if np.isfinite(frame[feat_cols_env].to_numpy()).all():
        continue
    bad_cols = [col for col in feat_cols_env if not np.isfinite(frame[col]).all()]
    raise ValueError(f"{split_name} contains NaN/Inf in selected features: {bad_cols}")

# ===== 2) Initial per-feature weights =====
# Keys must match column names exactly. All weights are currently 1 (neutral).
# For example, set "RSI_norm" above 1 to emphasize RSI, or "MACD_norm" below 1
# to de-emphasize MACD.
init_feature_weights = {
    "RSI_norm": 1,
    "MACD_norm": 1,
    "BB_high_norm": 1,
}


# ===== 3) Training & test environments (window=60 matches the normalization window) =====
def _make_env(frame):
    """Build a MinuteTradingEnv over ``frame`` with the shared features and weights."""
    return MinuteTradingEnv(
        frame, feat_cols_env, window=60,
        feature_weights=init_feature_weights,
        normalize_weights=True,   # divide weights by their mean to avoid overall scale drift
        # Optional risk-control params:
        # stop_loss_pct=0.02, trailing_stop_pct=0.03,
    )


train_env = DummyVecEnv([lambda: _make_env(train_df)])
test_env = DummyVecEnv([lambda: _make_env(test_df)])

# ===== 4) Optional: dynamic feature-weight schedule for the training callback =====
# The schedule maps the number of timesteps trained so far to a partial
# {column: weight} dict; only the returned keys are updated.

def weight_schedule(total_steps_done: int,
                    start: float = 1.5,
                    end: float = 1.8,
                    warmup_steps: int = 300_000,
                    feature: str = "RSI_norm") -> dict:
    """Return ``{feature: weight}`` for the given number of trained timesteps.

    The weight ramps linearly from ``start`` to ``end`` over the first
    ``warmup_steps`` steps and then stays at ``end``. Negative step counts are
    treated as 0. ``warmup_steps <= 0`` returns ``end`` immediately.

    Note: ``init_feature_weights`` sets "RSI_norm" to 1, so the first
    scheduled update jumps to about ``start`` rather than ramping from 1.
    """
    if warmup_steps <= 0:
        progress = 1.0
    else:
        progress = min(max(total_steps_done, 0), warmup_steps) / warmup_steps
    return {feature: float(start + (end - start) * progress)}

class FeatureWeightScheduler(BaseCallback):
    """Callback that re-applies a feature-weight schedule after every rollout.

    Args:
        schedule_fn: callable mapping ``model.num_timesteps`` (int) to a
            partial ``{column: weight}`` dict.
        normalize: forwarded as ``set_feature_weights(..., normalize=...)``.
        verbose: verbosity level passed to ``BaseCallback``.
    """

    def __init__(self, schedule_fn, normalize=True, verbose=0):
        super().__init__(verbose)
        self.schedule_fn = schedule_fn
        self.normalize = normalize

    def _on_step(self) -> bool:
        # BaseCallback declares _on_step abstract; without this override the
        # callback cannot be instantiated. Returning True keeps training going.
        return True

    def _on_rollout_end(self) -> None:
        steps = int(self.model.num_timesteps)
        new_w = self.schedule_fn(steps)
        # env_method calls the method on every sub-environment of any VecEnv,
        # rather than assuming a DummyVecEnv with a single env at envs[0].
        self.training_env.env_method("set_feature_weights", new_w,
                                     normalize=self.normalize)

fw_scheduler = FeatureWeightScheduler(weight_schedule, normalize=True)

# ===== 5) A2C model =====
# A small entropy bonus (ent_coef) adds exploration so actions do not stay
# stuck at the [-1, 1] boundaries for long stretches.
a2c_kwargs = dict(
    learning_rate=3e-4,
    n_steps=5 * 60,      # rollout length per update: 300 one-minute steps (~5 hours)
    gamma=0.99,
    gae_lambda=0.95,
    ent_coef=0.01,       # raised from 0.0 for exploration
    vf_coef=0.5,
    max_grad_norm=0.5,
    verbose=1,
    seed=42,
)
model = A2C(policy="MlpPolicy", env=train_env, **a2c_kwargs)

# ===== 6) Train and save =====
# Drop callback=fw_scheduler to train with fixed feature weights.
model.learn(total_timesteps=1_000_000, callback=fw_scheduler)
model.save("a2c_ethusdt_1m_no_sentiment.zip")


### Part 4: Evaluation

This part evaluates the trained agent on the test environment and reports detailed performance metrics and plots.

- **Rollout & Logging**:
  Run the trained policy on the test environment to collect NAV, position, and per-step rewards. Detect trades using a small position-change threshold to avoid counting floating-point jitter as trades.

- **Metrics**:
  Compute Final NAV, CAGR (by steps-per-year), Sharpe ratio (risk-free = 0), and Max Drawdown (based on NAV). Aggregate trade-level statistics (trade count, win rate, average win/loss).

- **Visualization**:
  Plot the equity curve (NAV) and the position time series to provide an intuitive view of behavior and performance.

- **Reproducibility Notes**:
  Use the same feature set as training (e.g., `*_norm` columns), the same window length, and the same per-feature observation weights. Set `periods_per_year` to 525,600 for crypto 1-minute data, which trades 24/7 (365 days × 1,440 minutes), or ~98,280 for US equities (252 trading days × 390 minutes).


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import json

# ===== 1) Evaluation function and helpers =====
def _finite_or_nan(x):
    """Return ``float(x)`` when finite, else NaN."""
    return float(x) if np.isfinite(x) else np.nan


def _performance_stats(navs, periods_per_year):
    """Return (cagr, sharpe, maxdd) for a NAV path whose first entry is the starting NAV."""
    n_steps = navs.size - 1
    rets = navs[1:] / navs[:-1] - 1.0 if n_steps >= 1 else np.array([], dtype=float)

    # Sharpe (annualized, risk-free rate 0); a sample std needs >= 2 returns
    sharpe = np.nan
    if rets.size > 1 and np.isfinite(rets).all():
        sigma = rets.std(ddof=1)
        if sigma > 0:
            sharpe = (rets.mean() / sigma) * np.sqrt(periods_per_year)

    # Max drawdown as a positive fraction of the running peak
    maxdd = np.nan
    if navs.size > 0:
        dd = navs / np.maximum.accumulate(navs) - 1.0
        maxdd = float(abs(dd.min()))

    # CAGR over the elapsed number of steps (n_steps returns span n_steps periods)
    cagr = np.nan
    if n_steps >= 1:
        years = n_steps / periods_per_year
        if years > 0:
            cagr = (navs[-1] / navs[0]) ** (1.0 / years) - 1.0
    return cagr, sharpe, maxdd


def _trade_stats(trade_rets):
    """Return (n_trades, win_rate, avg_win, avg_loss); a return <= 0 counts as a loss."""
    n_trades = len(trade_rets)
    if n_trades == 0:
        return 0, np.nan, np.nan, np.nan
    wins = [r for r in trade_rets if r > 0]
    losses = [r for r in trade_rets if r <= 0]
    win_rate = len(wins) / n_trades
    avg_win = float(np.mean(wins)) if wins else np.nan
    avg_loss = float(np.mean(losses)) if losses else np.nan
    return n_trades, win_rate, avg_win, avg_loss


def _plot_rollout(navs, poss):
    """Plot the equity curve and the position series in two stacked panels."""
    plt.figure(figsize=(12, 6))

    plt.subplot(2, 1, 1)
    plt.plot(navs, label="NAV")
    plt.title("Equity Curve")
    plt.legend()

    plt.subplot(2, 1, 2)
    plt.plot(poss, label="Position")
    plt.title("Position over Time")
    plt.legend()

    plt.tight_layout()
    plt.show()


def evaluate_detailed(model, env,
                      periods_per_year: float = 525_600,
                      trade_eps: float = 1e-2,
                      plot: bool = True):
    """Roll out ``model`` deterministically on ``env`` and report performance metrics.

    Args:
        model: object with ``predict(obs, deterministic=True) -> (action, state)``.
        env: single (non-vectorized) env whose ``reset()`` returns
            ``(obs, info)`` and whose ``step`` info dict contains ``"nav"``
            and ``"pos"``. If it exposes a ``nav`` attribute after reset, that
            is used as the starting NAV (otherwise 1.0).
        periods_per_year: steps per year for annualization
            (crypto 1-min ≈ 525,600 = 365×1440; US equities 1-min ≈ 98,280 = 252×390).
        trade_eps: minimum position change counted as a trade, so float jitter
            is not mistaken for trading; ``|pos| <= trade_eps`` counts as flat.
        plot: draw the equity-curve and position plots.

    Returns:
        dict with FinalNAV, CAGR, Sharpe, MaxDD, Trades, WinRate, AvgWin,
        AvgLoss and AvgPos (NaN where a metric is undefined).
    """
    # ===== 2) Initialize state =====
    obs, _ = env.reset()
    start_nav = getattr(env, "nav", None)
    prev_nav = 1.0 if start_nav is None else float(start_nav)
    done = truncated = False

    # navs starts with the pre-trade NAV so the first step's return and any
    # initial drawdown are included in the metrics.
    navs, poss = [prev_nav], []
    trade_rets = []
    last_pos = 0.0
    entry_nav = None

    # ===== 3) Roll out the policy =====
    while not (done or truncated):
        action, _ = model.predict(obs, deterministic=True)
        obs, _reward, done, truncated, info = env.step(action)

        nav = float(info["nav"])
        pos = float(info["pos"])
        navs.append(nav)
        poss.append(pos)

        # Trade event detection
        if abs(pos - last_pos) > trade_eps:
            is_flat_now = abs(pos) <= trade_eps
            was_flat = abs(last_pos) <= trade_eps

            if was_flat and not is_flat_now:
                # Open: this step already earns the new position's return and
                # pays the entry cost, so the trade starts at the NAV before it.
                entry_nav = prev_nav
            elif (not was_flat) and is_flat_now and (entry_nav is not None):
                # Close: the NAV after this step includes the exit cost.
                trade_rets.append(nav / entry_nav - 1.0)
                entry_nav = None
            elif np.sign(last_pos) != np.sign(pos) and (entry_nav is not None):
                # Reversal: the old trade ends where the new one begins,
                # i.e. at the NAV before this step.
                trade_rets.append(prev_nav / entry_nav - 1.0)
                entry_nav = prev_nav

        last_pos = pos
        prev_nav = nav

    navs = np.asarray(navs, dtype=float)
    poss = np.asarray(poss, dtype=float)

    # ===== 4) Close any position still open at the end =====
    if entry_nav is not None and navs.size:
        trade_rets.append(navs[-1] / entry_nav - 1.0)
        entry_nav = None

    # ===== 5) Performance and trade statistics =====
    cagr, sharpe, maxdd = _performance_stats(navs, periods_per_year)
    n_trades, win_rate, avg_win, avg_loss = _trade_stats(trade_rets)

    metrics = {
        "FinalNAV": float(navs[-1]) if navs.size else np.nan,
        "CAGR": _finite_or_nan(cagr),
        "Sharpe": _finite_or_nan(sharpe),
        "MaxDD": _finite_or_nan(maxdd),
        "Trades": int(n_trades),
        "WinRate": _finite_or_nan(win_rate),
        "AvgWin": _finite_or_nan(avg_win),
        "AvgLoss": _finite_or_nan(avg_loss),
        "AvgPos": float(np.mean(np.abs(poss))) if poss.size else np.nan,
    }

    # ===== 6) Visualization =====
    if plot:
        _plot_rollout(navs, poss)

    return metrics


# ===== 8) Usage example =====
# Use feat_cols_env (the training features) and window=60, exactly as in training.
test_env = MinuteTradingEnv(
    test_df, feat_cols_env, window=60,
    fee_rate=0.0005, slippage=0.0002, pos_change_penalty=0.001,
)
# The policy was trained on observations scaled by the training env's
# per-feature weights (normalized initial weights, later updated by the
# scheduler callback). Evaluate with the same scaling; otherwise the policy
# sees a shifted input distribution.
test_env._feat_scale = np.array(train_env.get_attr("_feat_scale")[0], dtype=np.float32)

metrics = evaluate_detailed(model, test_env, periods_per_year=525_600, trade_eps=1e-2)

# Print via json for clean float formatting
print(json.dumps(metrics, indent=2))

