# XGBoost 模型训练

本 Notebook 用于训练和评估 XGBoost 分类模型，用于预测加密货币价格变动。

In [1]:
!pip install matplotlib seaborn
import sys
import os

# 添加项目路径
sys.path.insert(0, os.path.abspath('../src'))

import logging
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# 设置中文字体
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei']
plt.rcParams['axes.unicode_minus'] = False

# 设置绘图样式
sns.set_style("whitegrid")

print("环境初始化完成")

Looking in indexes: https://pypi.org/simple/

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m26.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
环境初始化完成


## 1. 数据加载和探索

In [2]:
from collect.feature_handler import feature_handler
from config.settings import config

# 配置参数
INST_ID = "ETH-USDT-SWAP"
BAR = "1H"
LIMIT = 10000

# 从 MongoDB 加载特征数据
print(f"从 MongoDB 加载特征数据...")
print(f"  inst_id: {INST_ID}")
print(f"  bar: {BAR}")
print(f"  limit: {LIMIT}")

features = feature_handler.get_features(
    limit=LIMIT, 
    inst_id=INST_ID, 
    bar=BAR
)

print(f"\n加载完成，共 {len(features)} 条记录")

2026-02-03 20:25:04,236 - collect.mongodb_base - INFO - Connected to MongoDB at mongodb://localhost:27017
2026-02-03 20:25:04,237 - collect.mongodb_base - INFO - Database: technical_analysis


从 MongoDB 加载特征数据...
  inst_id: ETH-USDT-SWAP
  bar: 1H
  limit: 10000

加载完成，共 1464 条记录


In [3]:
# 转换为 DataFrame
df = pd.DataFrame(features)

print("数据形状:", df.shape)
print("\n数据类型:")
print(df.dtypes)

数据形状: (1464, 21)

数据类型:
_id                       object
close_1h_normalized      float64
volume_1h_normalized     float64
rsi_14_1h                float64
macd_line_1h             float64
macd_signal_1h           float64
hour_cos                 float64
hour_sin                 float64
day_of_week                int64
rsi_14_15m               float64
volume_impulse_15m       float64
macd_line_15m            float64
macd_signal_15m          float64
rsi_14_4h                float64
trend_continuation_4h    float64
macd_line_4h             float64
macd_signal_4h           float64
inst_id                      str
bar                          str
timestamp                  int64
label                      int64
dtype: object


In [None]:
# 查看前几条数据
print("前 5 条数据:")
df.head()

In [None]:
# 检查缺失值
print("缺失值统计:")
print(df.isnull().sum()[df.isnull().sum() > 0])

In [None]:
# 检查标签分布
if 'label' in df.columns:
    label_counts = df['label'].value_counts().sort_index()
    print("标签分布:")
    print(label_counts)
    
    # 绘制标签分布图
    plt.figure(figsize=(10, 6))
    label_counts.plot(kind='bar')
    plt.title('标签分布')
    plt.xlabel('标签')
    plt.ylabel('数量')
    plt.grid(axis='y', alpha=0.3)
    plt.show()
else:
    print("警告：数据中没有 'label' 字段")

In [4]:
# 统计信息
print("数值特征统计:")
numeric_cols = df.select_dtypes(include=[np.number]).columns
df[numeric_cols].describe()

数值特征统计:


Unnamed: 0,close_1h_normalized,volume_1h_normalized,rsi_14_1h,macd_line_1h,macd_signal_1h,hour_cos,hour_sin,day_of_week,rsi_14_15m,volume_impulse_15m,macd_line_15m,macd_signal_15m,rsi_14_4h,trend_continuation_4h,macd_line_4h,macd_signal_4h,timestamp,label
count,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0
mean,-0.583026,-0.186855,49.360724,-5.316656,-4.651908,-1.820038e-18,-1.5166979999999998e-19,3.060109,49.893511,0.984631,-2.007136,-1.885276,49.330055,0.002869,-9.032656,-6.825497,1767382000000.0,3.842213
std,0.26222,0.988203,17.360276,45.589232,43.503552,0.707335,0.707335,2.001146,15.053805,1.173732,21.473391,20.487542,19.667277,0.016951,78.16198,70.224545,1521953000.0,1.551695
min,-1.7838,-1.01,3.0,-155.839,-128.008,-1.0,-1.0,0.0,6.9,0.07,-119.063,-104.741,7.2,-0.05,-287.123,-234.978,1764749000000.0,1.0
25%,-0.71485,-0.746,38.175,-28.0905,-28.07375,-0.7071,-0.7071,1.0,39.675,0.42,-9.096,-9.08425,35.2,-0.01,-46.085,-35.601,1766066000000.0,3.0
50%,-0.59395,-0.526,50.3,-2.0355,-1.4455,0.0,0.0,3.0,50.4,0.68,0.0135,0.1985,48.4,0.01,0.83,3.556,1767382000000.0,4.0
75%,-0.438575,-0.0655,60.9,19.543,18.1855,0.7071,0.7071,5.0,60.2,1.07,7.538,7.211,63.5,0.01,50.262,49.559,1768699000000.0,5.0
max,-0.084,7.226,94.1,123.795,117.261,1.0,1.0,6.0,90.9,17.94,121.416,108.673,96.6,0.03,127.221,98.818,1770016000000.0,7.0


## 2. 模型训练

In [5]:
!pip install joblib xgboost scikit-learn
from models.xgboost_trainer import xgb_trainer

print("开始训练 XGBoost 模型...")
print(f"训练时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

# 训练模型
results = xgb_trainer.train_model(
    inst_id=INST_ID,
    bar=BAR,
    limit=LIMIT,
    test_size=0.2,
    cv_folds=5,
    use_class_weight=True
)

print(f"\n训练完成！")

Looking in indexes: https://pypi.org/simple/

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m26.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


2026-02-03 20:25:35,769 - models.xgboost_trainer - INFO - Starting XGBoost model training
2026-02-03 20:25:35,781 - models.xgboost_trainer - INFO - Retrieved 1464 features from MongoDB
2026-02-03 20:25:35,792 - models.xgboost_trainer - INFO - Prepared training data: 1464 samples, 16 features
2026-02-03 20:25:35,792 - models.xgboost_trainer - INFO - Feature columns: ['close_1h_normalized', 'volume_1h_normalized', 'rsi_14_1h', 'macd_line_1h', 'macd_signal_1h', 'hour_cos', 'hour_sin', 'day_of_week', 'rsi_14_15m', 'volume_impulse_15m', 'macd_line_15m', 'macd_signal_15m', 'rsi_14_4h', 'trend_continuation_4h', 'macd_line_4h', 'macd_signal_4h']
2026-02-03 20:25:35,794 - models.xgboost_trainer - INFO - Label distribution:
label
0    102
1    239
2    252
3    327
4    336
5    155
6     53
Name: count, dtype: int64
2026-02-03 20:25:35,795 - models.xgboost_trainer - INFO - Number of features: 16
2026-02-03 20:25:35,804 - models.xgboost_trainer - INFO - Training set size: 1171
2026-02-03 20:25:3

开始训练 XGBoost 模型...
训练时间: 2026-02-03 20:25:35


2026-02-03 20:25:37,333 - models.xgboost_trainer - INFO - Performing cross-validation...
2026-02-03 20:25:37,336 - models.xgboost_trainer - INFO - Cross-validation fold 1/5
2026-02-03 20:25:39,037 - models.xgboost_trainer - INFO -   Fold 1 accuracy: 0.7404
2026-02-03 20:25:39,038 - models.xgboost_trainer - INFO - Cross-validation fold 2/5
2026-02-03 20:25:40,878 - models.xgboost_trainer - INFO -   Fold 2 accuracy: 0.7436
2026-02-03 20:25:40,878 - models.xgboost_trainer - INFO - Cross-validation fold 3/5
2026-02-03 20:25:42,500 - models.xgboost_trainer - INFO -   Fold 3 accuracy: 0.7137
2026-02-03 20:25:42,501 - models.xgboost_trainer - INFO - Cross-validation fold 4/5
2026-02-03 20:25:44,094 - models.xgboost_trainer - INFO -   Fold 4 accuracy: 0.7009
2026-02-03 20:25:44,095 - models.xgboost_trainer - INFO - Cross-validation fold 5/5
2026-02-03 20:25:45,691 - models.xgboost_trainer - INFO -   Fold 5 accuracy: 0.6752
2026-02-03 20:25:45,710 - models.xgboost_trainer - INFO - Accuracy: 0.7


训练完成！


## 3. 训练结果分析

In [6]:
# 打印训练结果
print("="*60)
print("训练结果摘要")
print("="*60)
print(f"准确率: {results['accuracy']:.4f}")
print(f"交叉验证准确率: {results['cv_mean_accuracy']:.4f} (+/- {results['cv_std_accuracy']*2:.4f})")
print(f"训练时间: {results['trained_at']}")

训练结果摘要
准确率: 0.7031
交叉验证准确率: 0.7148 (+/- 0.0510)
训练时间: 2026-02-03T20:25:45.710812


In [8]:
# 每个类别的置信度
print("\n各类别置信度:")
for label, confidence in results['class_confidence'].items():
    print(f"  类别 {label}: {confidence:.4f}")


各类别置信度:
  类别 1: 0.7688
  类别 2: 0.6771
  类别 3: 0.5398
  类别 4: 0.5131
  类别 5: 0.5324
  类别 6: 0.6001
  类别 7: 0.8669


In [None]:
# 绘制各类别置信度
plt.figure(figsize=(12, 6))
labels = list(results['class_confidence'].keys())
confidences = list(results['class_confidence'].values())
plt.bar(labels, confidences)
plt.xlabel('类别')
plt.ylabel('平均置信度')
plt.title('各类别平均置信度')
plt.grid(axis='y', alpha=0.3)
plt.show()

In [None]:
# 绘制混淆矩阵
conf_matrix = np.array(results['confusion_matrix'])

plt.figure(figsize=(12, 10))
sns.heatmap(
    conf_matrix, 
    annot=True, 
    fmt='d', 
    cmap='Blues',
    xticklabels=sorted(config.CLASSIFICATION_THRESHOLDS.keys()),
    yticklabels=sorted(config.CLASSIFICATION_THRESHOLDS.keys())
)
plt.xlabel('预测类别')
plt.ylabel('真实类别')
plt.title('混淆矩阵')
plt.tight_layout()
plt.show()

In [9]:
# 打印详细的分类报告
print("\n详细分类报告:")
print("="*60)
class_report = results['classification_report']
for label in sorted(class_report.keys()):
    if label == 'accuracy' or label == 'macro avg' or label == 'weighted avg':
        print(f"\n{label}: {class_report[label]}")
    else:
        print(f"\n类别 {label}:")
        for metric, value in class_report[label].items():
            print(f"  {metric}: {value:.4f}")


详细分类报告:

类别 1:
  precision: 0.8947
  recall: 0.8500
  f1-score: 0.8718
  support: 20.0000

类别 2:
  precision: 0.8298
  recall: 0.8125
  f1-score: 0.8211
  support: 48.0000

类别 3:
  precision: 0.6383
  recall: 0.6000
  f1-score: 0.6186
  support: 50.0000

类别 4:
  precision: 0.5732
  recall: 0.7121
  f1-score: 0.6351
  support: 66.0000

类别 5:
  precision: 0.7368
  recall: 0.6269
  f1-score: 0.6774
  support: 67.0000

类别 6:
  precision: 0.7778
  recall: 0.6774
  f1-score: 0.7241
  support: 31.0000

类别 7:
  precision: 0.7143
  recall: 0.9091
  f1-score: 0.8000
  support: 11.0000

accuracy: 0.7030716723549488

macro avg: {'precision': 0.7378426110746014, 'recall': 0.7411424496703746, 'f1-score': 0.7354423750590106, 'support': 293.0}

weighted avg: {'precision': 0.7126462832994457, 'recall': 0.7030716723549488, 'f1-score': 0.7041930046362278, 'support': 293.0}


## 4. 特征重要性分析

In [10]:
# 获取特征重要性
importance = xgb_trainer.model.get_score(importance_type='gain')

# 转换为 DataFrame
importance_df = pd.DataFrame({
    'feature': list(importance.keys()),
    'importance': list(importance.values())
})
importance_df = importance_df.sort_values('importance', ascending=False)

print("特征重要性 Top 20:")
print(importance_df.head(20))

特征重要性 Top 20:
   feature  importance
13     f13    2.757510
7       f7    2.332153
0       f0    1.994593
15     f15    1.842307
12     f12    1.800961
14     f14    1.725730
3       f3    1.399534
4       f4    1.234479
10     f10    0.985291
11     f11    0.935887
2       f2    0.848002
6       f6    0.819492
1       f1    0.709581
8       f8    0.639467
5       f5    0.615157
9       f9    0.560023


In [None]:
# 绘制特征重要性
plt.figure(figsize=(12, 8))
top_n = 20
plt.barh(range(top_n), importance_df['importance'].head(top_n)[::-1])
plt.yticks(range(top_n), importance_df['feature'].head(top_n)[::-1])
plt.xlabel('重要性 (Gain)')
plt.title(f'特征重要性 Top {top_n}')
plt.tight_layout()
plt.show()

In [None]:
# 按特征类型分组统计
def categorize_feature(feature_name):
    if '1h' in feature_name.lower():
        return '1H'
    elif '15m' in feature_name.lower():
        return '15M'
    elif '4h' in feature_name.lower():
        return '4H'
    elif 'hour' in feature_name.lower() or 'day' in feature_name.lower():
        return 'Time'
    else:
        return 'Other'

importance_df['category'] = importance_df['feature'].apply(categorize_feature)
category_importance = importance_df.groupby('category')['importance'].sum().sort_values(ascending=False)

print("按特征类型分组的重要性:")
print(category_importance)

# 绘制饼图
plt.figure(figsize=(10, 8))
plt.pie(category_importance.values, labels=category_importance.index, autopct='%1.1f%%')
plt.title('各特征类型的重要性占比')
plt.show()

## 5. 模型预测测试

In [None]:
# 加载一些新数据进行预测
test_features = feature_handler.get_features(limit=5, inst_id=INST_ID, bar=BAR)

if test_features:
    print(f"测试预测，使用 {len(test_features)} 条新数据...")
    
    predictions, probabilities = xgb_trainer.predict(test_features)
    
    print("\n预测结果:")
    for i, (pred, prob, feature) in enumerate(zip(predictions, probabilities, test_features)):
        timestamp = feature.get('timestamp', 'N/A')
        actual_label = feature.get('label', 'N/A')
        print(f"\n样本 {i+1}:")
        print(f"  时间戳: {timestamp}")
        print(f"  实际标签: {actual_label}")
        print(f"  预测标签: {pred}")
        print(f"  预测概率: {prob}")
        print(f"  最高概率类别: {np.argmax(prob) + 1} (置信度: {np.max(prob):.4f})")

## 6. 模型保存信息

In [None]:
print("模型保存位置:")
print(f"  模型文件: {config.MODEL_SAVE_PATH}")
print(f"  Scaler 文件: {config.MODEL_SAVE_PATH.replace('.json', '_scaler.pkl')}")
print(f"  特征列文件: {config.MODEL_SAVE_PATH.replace('.json', '_features.json')}")
print(f"\n使用的特征列 ({len(xgb_trainer.feature_columns)} 个):")
for i, col in enumerate(xgb_trainer.feature_columns, 1):
    print(f"  {i}. {col}")

## 7. 参数调优建议

In [11]:
# 基于当前结果，给出参数调优建议
print("\n参数调优建议:")
print("="*60)

# 分析类别不平衡
if 'label' in df.columns:
    label_counts = df['label'].value_counts()
    min_count = label_counts.min()
    max_count = label_counts.max()
    imbalance_ratio = max_count / min_count if min_count > 0 else float('inf')
    
    if imbalance_ratio > 5:
        print(f"⚠️ 类别不平衡严重 (比例: {imbalance_ratio:.2f})")
        print("   建议: 使用 class_weight 或 SMOTE 进行平衡")

# 分析模型性能
if results['accuracy'] < 0.7:
    print("⚠️ 模型准确率较低")
    print("   建议: 检查特征质量，考虑增加更多特征")
elif results['accuracy'] > 0.9:
    print("✅ 模型准确率很高")
    print("   注意: 检查是否存在过拟合，考虑使用更多验证数据")

# 分析交叉验证稳定性
cv_std = results['cv_std_accuracy']
if cv_std > 0.05:
    print(f"⚠️ 交叉验证标准差较高 ({cv_std:.4f})")
    print("   建议: 模型稳定性有待提高，考虑调整超参数")

# 分析特征重要性
if len(importance_df) > 0:
    top_features = importance_df.head(5)['feature'].tolist()
    print("\n最重要的 5 个特征:")
    for i, feat in enumerate(top_features, 1):
        print(f"  {i}. {feat}")


参数调优建议:
⚠️ 类别不平衡严重 (比例: 6.34)
   建议: 使用 class_weight 或 SMOTE 进行平衡

最重要的 5 个特征:
  1. f13
  2. f7
  3. f0
  4. f15
  5. f12
