In [1]:
# ## install finrl library
# !pip install git+https://github.com/AI4Finance-Foundation/FinRL.git

# 导入必要库

In [2]:
# 导入必要的库
import pandas as pd
import numpy as np
import datetime
import yfinance as yf
import os
import warnings
import itertools

from finrl.meta.preprocessor.yahoodownloader import YahooDownloader
from finrl.meta.preprocessor.preprocessors import FeatureEngineer, data_split
from finrl.config import INDICATORS

warnings.filterwarnings("ignore")


# 创建项目所需目录结构
def create_directories():
    """Create the project's working directories and report each one.

    Ensures that ``data/raw_data``, ``data/processed_data``, ``models`` and
    ``results`` exist relative to the current working directory. Directories
    that already exist are left untouched (``exist_ok=True``). One status line
    is printed per directory, whether or not it had to be created.
    """
    required_dirs = ("data/raw_data", "data/processed_data", "models", "results")
    for directory in required_dirs:
        # makedirs also creates the intermediate "data" parent when needed.
        os.makedirs(directory, exist_ok=True)
        print(f"创建目录: {directory}")


create_directories()

创建目录: data/raw_data
创建目录: data/processed_data
创建目录: models
创建目录: results


# 获取标普500成分股列表

In [3]:
# Fetch the list of S&P 500 constituents from the Wikipedia "constituents" table.
import requests
from bs4 import BeautifulSoup

url = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies"
# The timeout keeps the cell from hanging forever on a stalled connection, and
# the explicit User-Agent identifies this script to the server.
response = requests.get(
    url,
    headers={"User-Agent": "sp500-data-prep-notebook/1.0 (python-requests)"},
    timeout=30,
)
# Fail loudly on an HTTP error instead of trying to parse an error page.
response.raise_for_status()
soup = BeautifulSoup(response.text, "html.parser")
table = soup.find("table", {"id": "constituents"})
if table is None:
    # Without this check the loop below would fail with an opaque
    # AttributeError on None.
    raise RuntimeError(f"Could not find the 'constituents' table at {url}")

# Parse the table rows into three parallel lists.
tickers = []
company_names = []
sectors = []

# The first row is the header, so it is skipped.
for row in table.find_all("tr")[1:]:
    cells = row.find_all("td")
    # Skip rows that do not carry the three columns read below.
    if len(cells) < 3:
        continue

    # Convert the ticker format: "ticker.x" becomes "ticker-x".
    ticker = cells[0].text.strip().replace(".", "-")
    company = cells[1].text.strip()
    sector = cells[2].text.strip()

    tickers.append(ticker)
    company_names.append(company)
    sectors.append(sector)

if not tickers:
    # An empty list would otherwise surface much later, in the download step.
    raise RuntimeError(f"No constituents could be parsed from {url}")

sp500_tickers_df = pd.DataFrame(
    {"Ticker": tickers, "Company": company_names, "Sector": sectors}
)

print(f"共获取 {len(sp500_tickers_df)} 支标普500成分股")
sp500_tickers_df.head()

共获取 503 支标普500成分股


Unnamed: 0,Ticker,Company,Sector
0,MMM,3M,Industrials
1,AOS,A. O. Smith,Industrials
2,ABT,Abbott Laboratories,Health Care
3,ABBV,AbbVie,Health Care
4,ACN,Accenture,Information Technology


# 下载股票历史数据

In [4]:
# Date range of the dataset (kept the same as the original setup).
# Training window: 2015-01-01 to 2023-01-01 (eight years).
# Trading/test window: 2023-01-01 to 2025-01-01 (two years).
TRAIN_START_DATE = "2015-01-01"
TRAIN_END_DATE = "2023-01-01"
TRADE_START_DATE = "2023-01-01"
TRADE_END_DATE = "2025-01-01"
print(f"数据时间范围: {TRAIN_START_DATE} 至 {TRADE_END_DATE}")

# Download the stock data as in the example code, with an added on-disk cache.
# The cache filename is derived from the date constants so that changing the
# range cannot silently reuse a file downloaded for a different range.
date_range_tag = (
    f"{TRAIN_START_DATE.replace('-', '')}~{TRADE_END_DATE.replace('-', '')}"
)
raw_data_file = f"data/raw_data/sp500_raw_data_{date_range_tag}.csv"

# Check whether a previously downloaded data file already exists.
if os.path.exists(raw_data_file):
    print(f"发现已有数据文件: {raw_data_file}")
    print("直接加载现有数据，跳过下载过程...")
    df_raw = pd.read_csv(raw_data_file)
    print(
        f"成功加载数据，共 {len(df_raw)} 条记录，包含 {df_raw['tic'].nunique()} 支股票"
    )
else:
    print("未找到现有数据文件，开始下载股票数据...")
    df_raw = YahooDownloader(
        start_date=TRAIN_START_DATE,
        end_date=TRADE_END_DATE,
        ticker_list=tickers,  # use every S&P 500 constituent
    ).fetch_data()

    # Save the downloaded data so the next run can load it from disk.
    os.makedirs(os.path.dirname(raw_data_file), exist_ok=True)
    df_raw.to_csv(raw_data_file, index=False)
    print(f"数据已下载并保存至 {raw_data_file}")
    print(
        f"成功下载数据，共 {len(df_raw)} 条记录，包含 {df_raw['tic'].nunique()} 支股票"
    )

# Show the first few rows of the data.
df_raw.head()

数据时间范围: 2015-01-01 至 2025-01-01
未找到现有数据文件，开始下载股票数据...


[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%********

Shape of DataFrame:  (1232011, 8)
数据已下载并保存至 data/raw_data/sp500_raw_data_20150101~20250101.csv
成功下载数据，共 1232011 条记录，包含 503 支股票


Price,date,close,high,low,open,volume,tic,day
0,2015-01-02,37.353016,41.310001,40.369999,41.18,1529200,A,4
1,2015-01-02,24.320433,27.860001,26.8375,27.8475,212818400,AAPL,4
2,2015-01-02,43.156208,66.400002,65.440002,65.440002,5086100,ABBV,4
3,2015-01-02,36.915035,45.450001,44.639999,45.25,3216600,ABT,4
4,2015-01-02,18.539352,19.860001,19.426666,19.733334,1101600,ACGL,4


# 添加技术指标

In [5]:
# Add technical indicators, using the same settings as the example code.
print("添加技术指标...")
feature_engineer_settings = {
    "use_technical_indicator": True,
    "tech_indicator_list": INDICATORS,
    "use_vix": True,
    "use_turbulence": True,
    "user_defined_feature": False,
}
fe = FeatureEngineer(**feature_engineer_settings)

# NOTE(review): the recorded run reports fewer rows after this step than in
# df_raw, so preprocessing appears to drop some data; confirm this is expected.
processed = fe.preprocess_data(df_raw)
print("技术指标添加完成")

添加技术指标...
Successfully added technical indicators


[*********************100%***********************]  1 of 1 completed


Shape of DataFrame:  (2515, 8)
Successfully added vix
Successfully added turbulence index
技术指标添加完成


# 数据划分与保存

In [6]:
# Build the full date x ticker grid so every ticker has a row on every date.
print("创建完整的日期-股票组合...")
list_ticker = processed["tic"].unique().tolist()
# assumes processed["date"] holds "YYYY-MM-DD" strings, so that it matches the
# stringified date_range values in the merge and isin below; TODO confirm
list_date = list(
    pd.date_range(processed["date"].min(), processed["date"].max()).astype(str)
)
combination = list(itertools.product(list_date, list_ticker))

processed_full = pd.DataFrame(combination, columns=["date", "tic"]).merge(
    processed, on=["date", "tic"], how="left"
)
# Keep only the dates that actually occur in `processed`; the calendar range
# above also contains days with no data at all.
processed_full = processed_full[processed_full["date"].isin(processed["date"])]
processed_full = processed_full.sort_values(["date", "tic"])
# (date, ticker) pairs that had no row in `processed` are filled with zeros.
processed_full = processed_full.fillna(0)

print(f"完整数据集大小: {len(processed_full)}条记录")

# Split into the training set and the trading (test) set.
print("划分训练集和测试集...")
train = data_split(processed_full, TRAIN_START_DATE, TRAIN_END_DATE)
trade = data_split(processed_full, TRADE_START_DATE, TRADE_END_DATE)
print(f"训练集大小: {len(train)}条记录")
print(f"交易集大小: {len(trade)}条记录")

# Save both sets. The filenames are derived from the date constants so they
# always describe the range that was actually processed.
print("保存数据集...")
date_range_tag = (
    f"{TRAIN_START_DATE.replace('-', '')}~{TRADE_END_DATE.replace('-', '')}"
)
# Make sure the output directory exists, as is done before the raw-data save.
os.makedirs("data/processed_data", exist_ok=True)
# The DataFrame index is written as the first CSV column (to_csv default).
# NOTE(review): confirm downstream readers expect that index column.
train.to_csv(f"data/processed_data/train_data_{date_range_tag}.csv")
trade.to_csv(f"data/processed_data/test_data_{date_range_tag}.csv")
print("数据处理完成！")

# Summarize the processing results.
print("\n========== 数据处理总结 ==========")
print(f"1. 原始数据: {len(df_raw)} 条记录，{df_raw['tic'].nunique()} 支股票")
print(f"2. 添加技术指标后数据: {len(processed)} 条记录")
print(f"3. 完整处理后数据: {len(processed_full)} 条记录")
print(f"4. 训练集: {len(train)} 条记录 (保存至 data/processed_data/)")
print(f"5. 测试集: {len(trade)} 条记录 (保存至 data/processed_data/)")

创建完整的日期-股票组合...
完整数据集大小: 1187080条记录
划分训练集和测试集...
训练集大小: 950608条记录
交易集大小: 236472条记录
保存数据集...
数据处理完成！

1. 原始数据: 1232011 条记录，503 支股票
2. 添加技术指标后数据: 1187080 条记录
3. 完整处理后数据: 1187080 条记录
4. 训练集: 950608 条记录 (保存至 data/processed_data/)
5. 测试集: 236472 条记录 (保存至 data/processed_data/)
