In [None]:
'''
导入所需库
'''

import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pandas_market_calendars as mcal
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from scipy.interpolate import interp1d
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from torch.nn.init import xavier_uniform_
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Data Pre-processing

In [None]:
'''
Prepare the raw daily stock file.
1. Download the data from https://wrds-www.wharton.upenn.edu/ with the columns:
   PERMNO, HdrCUSIP, PrimaryExch, USIncFlg, Ticker, PERMCO, DlyCalDt, DlyCap,
   DlyRetx, DlyVol, DlyClose, DlyLow, DlyHigh, DlyOpen
2. Compute trailing N-day moving averages (5, 20, 60 days) and forward N-day
   cumulative returns (5, 20, 60 days).
'''

# WRDS column name -> short name used throughout the notebook.
column_aliases = {
    'PERMNO': 'id',
    'DlyCalDt': 'date',
    'DlyCap': 'cap',
    'DlyRetx': 'ret',
    'DlyVol': 'vol',
    'DlyClose': 'close',
    'DlyLow': 'low',
    'DlyHigh': 'high',
    'DlyOpen': 'open',
}
unused_columns = ['HdrCUSIP', 'Ticker', 'PERMCO', 'PrimaryExch', 'USIncFlg']
kept_columns = ['id', 'date', 'cap', 'open', 'close', 'high', 'low', 'vol', 'ret']

df = (
    pd.read_csv('./your/path.csv')
    .rename(columns=column_aliases)
    .drop(columns=unused_columns)
)
df['date'] = pd.to_datetime(df['date'])
df = df.sort_values(by=['id', 'date'])[kept_columns]

# Everything except the key columns must be numeric; unparseable entries become NaN.
value_columns = [name for name in df.columns if name not in ('id', 'date')]
for name in value_columns:
    df[name] = pd.to_numeric(df[name], errors='coerce')

# Helper column for the moving averages: missing closes are replaced by 0.
# NOTE(review): a 0 close pulls the rolling mean towards zero - confirm this is intended.
df['close1'] = df['close'].fillna(0)

def calculate_ma(group):
    """Return a date-sorted copy of ``group`` with moving averages of ``close1``.

    Adds the columns ``ma5``, ``ma20`` and ``ma60``: rolling means of ``close1``
    over 5, 20 and 60 rows, rounded to 3 decimals. Rows that do not yet have a
    full window are NaN. The input frame is not modified.
    """
    ordered = group.sort_values(by='date')
    price = ordered['close1']
    for span in (5, 20, 60):
        ordered[f'ma{span}'] = price.rolling(window=span).mean().round(3)
    return ordered
# Run calculate_ma on each security's rows (grouped by `id`) to add ma5/ma20/ma60.
df = df.groupby('id', group_keys=False).apply(calculate_ma)

def calculate_returns(group):
    """Return a sorted copy of ``group`` with forward cumulative returns.

    Adds ``ret_5d``, ``ret_20d`` and ``ret_60d``. For a row at position ``t`` the
    value is ``prod(1 + ret[t+1 : t+i+1]) - 1``, i.e. the compounded return over
    the NEXT ``i`` rows, excluding the current row. The value is NaN when fewer
    than ``i`` future rows exist or when any return in that window is NaN.

    The frame is sorted by ``['id', 'date']`` when an ``id`` column is present
    and by ``'date'`` otherwise, so the function also works on a group that has
    had its grouping column removed.
    """
    sort_keys = ['id', 'date'] if 'id' in group.columns else ['date']
    group = group.sort_values(by=sort_keys)

    # Gross returns are computed once; raw=True hands plain ndarrays to np.prod
    # instead of building a Series per window, which is far faster on a daily panel.
    gross = 1.0 + group['ret']
    for i in [5, 20, 60]:
        group[f'ret_{i}d'] = (
            gross
            .rolling(i)
            .apply(np.prod, raw=True)
            .sub(1.0)
            .shift(-i)  # move the trailing window forward: next i rows, excluding today
        )
    return group
# Run calculate_returns on each security's rows to add ret_5d / ret_20d / ret_60d.
df = df.groupby('id', group_keys=False).apply(calculate_returns)

# close1 was only a helper for the moving averages.
df = df.drop(columns=['close1'])
# Sorting leaves a shuffled index; older pandas versions refuse to write a
# non-default index to Feather, so rebuild it before saving.
df = df.sort_values(by=['id', 'date']).reset_index(drop=True)

os.makedirs('./data', exist_ok=True)
df.to_feather('./data/train_data.feather')
# df.to_feather('./data/test_data.feather')
df

In [None]:
'''
Build the NYSE trading-day calendar and flag the rebalancing ("anchor") days.
'''
nyse = mcal.get_calendar('NYSE')
trading_days = nyse.valid_days('1993-01-01', '2020-12-31')
df = pd.DataFrame({'date': pd.to_datetime(trading_days)})

# Distance of each trading day from the last trading day of its month
# (0 = month end, 1 = the day before, ...).
days_to_month_end = df.groupby(
    [df['date'].dt.year.rename('year'), df['date'].dt.month.rename('month')]
).cumcount(ascending=False)
is_month_end = days_to_month_end == 0
is_quarter_end = is_month_end & df['date'].dt.month.isin([3, 6, 9, 12])

# Anchors are 1.0 on flagged days and NaN elsewhere.
# anchor_5 : every 5th trading day counting back from each month's last trading day
# anchor_20: last trading day of every month
# anchor_60: last trading day of March, June, September and December
df['anchor_5'] = np.where(days_to_month_end % 5 == 0, 1.0, np.nan)
df['anchor_20'] = np.where(is_month_end, 1.0, np.nan)
df['anchor_60'] = np.where(is_quarter_end, 1.0, np.nan)

# Split into the 1993-2000 and 2001-2019 sub-periods. Dates are written as plain
# `date` objects; convert them with pd.to_datetime before joining against a
# datetime64 column.
mask1 = df['date'].dt.year.between(1993, 2000)
df1 = df.loc[mask1].sort_values('date').reset_index(drop=True)
df1['date'] = df1['date'].dt.date
mask2 = df['date'].dt.year.between(2001, 2019)
df2 = df.loc[mask2].sort_values('date').reset_index(drop=True)
df2['date'] = df2['date'].dt.date

os.makedirs('./data', exist_ok=True)
df1.to_feather('./data/trading_days_anchor_1993_2000.feather')
print(df1.head())
df2.to_feather('./data/trading_days_anchor_2001_2019.feather')
print(df2.head())

# Chart Drawing

In [None]:
'''
Chart rendering:
    1) rebuild a normalised close series from the daily returns
    2) rescale open / high / low / MA by the same per-day factor
    3) scale the price panel and the volume panel
    4) draw, day by day, the high-low bar, the open and close ticks, the MA point and the volume bar
    5) join the MA points by linear interpolation
    6) save the image
'''

def draw_chart(df, save_path,
               NUM_DAYS,
               IMG_HEIGHT,
               PRICE_HEIGHT,
               VOL_HEIGHT,
               ma_col,
               interp_points):
    """Render one price/volume window as a black-and-white image and save it.

    Parameters
    ----------
    df : DataFrame with the columns open, high, low, close, vol, ret and ``ma_col``;
        one row per day, indexed positionally for ``NUM_DAYS`` days.
    save_path : destination passed to ``Image.save``.
    NUM_DAYS : number of days drawn; the image is ``3 * NUM_DAYS`` pixels wide.
    IMG_HEIGHT : total image height in pixels.
    PRICE_HEIGHT : number of rows (from the top) used by the price panel.
    VOL_HEIGHT : maximum height of a volume bar, drawn upwards from the bottom row.
    ma_col : name of the moving-average column to overlay.
    interp_points : number of sample points used to draw the MA line.

    Each day uses three pixel columns: open tick (left), high-low bar, MA and
    volume (middle), close tick (right). Lit pixels are 255 on a 0 background.
    """
    px_per_day = 3

    returns = df['ret'].values
    volume = df['vol'].values

    # Close series rebased to 1.0 on the first day and compounded with the returns.
    level = 1.0
    rebased = [level]
    for t in range(1, NUM_DAYS):
        level = level * (1 + returns[t])
        rebased.append(level)
    close_px = np.array(rebased)

    # Per-day factor mapping raw prices onto the rebased scale; a zero or missing
    # raw close gives NaN so that day's prices are skipped.
    factor = close_px / df['close'].replace(0, np.nan).values
    factor[np.isinf(factor)] = np.nan
    open_px = df['open'].values * factor
    high_px = df['high'].values * factor
    low_px = df['low'].values * factor
    ma_px = df[ma_col].values * factor

    # Vertical range of the price panel.
    observed = np.concatenate(
        [s[~np.isnan(s)] for s in (low_px, open_px, close_px, high_px, ma_px)]
    )
    if observed.size:
        price_min, price_max = observed.min(), observed.max()
    else:
        price_min = price_max = 0
    if price_max != price_min:
        price_scale = (PRICE_HEIGHT - 1) / (price_max - price_min)
    else:
        price_scale = 0

    # Volume is scaled so the largest bar is VOL_HEIGHT pixels tall.
    seen_volume = volume[~np.isnan(volume)]
    vol_max = seen_volume.max() if seen_volume.size else 0
    vol_scale = (VOL_HEIGHT / vol_max) if vol_max else 0

    img = np.zeros((IMG_HEIGHT, NUM_DAYS * px_per_day), dtype=np.uint8)

    def to_row(price):
        # Row 0 is the top of the image, so the highest price maps to row 0.
        return int(round((price_max - price) * price_scale))

    ma_pts = []
    for day in range(NUM_DAYS):
        left = day * px_per_day
        mid, right = left + 1, left + 2

        # High-low bar in the middle column.
        hi, lo = high_px[day], low_px[day]
        if not np.isnan(hi) and not np.isnan(lo):
            row_a, row_b = to_row(hi), to_row(lo)
            img[min(row_a, row_b):max(row_a, row_b) + 1, mid] = 255

        # Open tick on the left, close tick on the right.
        for col, price in ((left, open_px[day]), (right, close_px[day])):
            if not np.isnan(price):
                img[to_row(price), col] = 255

        # MA points are collected here and joined after the loop.
        ma_val = ma_px[day]
        if not np.isnan(ma_val):
            ma_pts.append((mid, to_row(ma_val)))

        # Volume bar grows upwards from the bottom row.
        vol = volume[day]
        if not np.isnan(vol):
            bar = int(round(vol * vol_scale))
            img[IMG_HEIGHT - bar:, mid] = 255

    # Join the MA points with straight segments, sampled at interp_points positions.
    if len(ma_pts) >= 2:
        xs, ys = zip(*sorted(ma_pts))
        curve = interp1d(xs, ys, kind='linear', bounds_error=False)
        grid = np.linspace(xs[0], xs[-1], interp_points)
        for x_f, y_f in zip(grid, curve(grid)):
            if not np.isnan(y_f):
                col, row = int(round(x_f)), int(round(y_f))
                if 0 <= col < img.shape[1] and 0 <= row < PRICE_HEIGHT:
                    img[row, col] = 255

    Image.fromarray(img).save(save_path)


def Chart5d(df, save_path):
    """Draw a 5-day chart with the ``ma5`` overlay: 32 rows (26 price + 6 volume)."""
    settings = dict(
        NUM_DAYS=5,
        IMG_HEIGHT=32,
        PRICE_HEIGHT=26,
        VOL_HEIGHT=6,
        ma_col='ma5',
        interp_points=100,
    )
    draw_chart(df, save_path, **settings)

def Chart20d(df, save_path):
    """Draw a 20-day chart with the ``ma20`` overlay: 64 rows (52 price + 12 volume)."""
    settings = dict(
        NUM_DAYS=20,
        IMG_HEIGHT=64,
        PRICE_HEIGHT=52,
        VOL_HEIGHT=12,
        ma_col='ma20',
        interp_points=200,
    )
    draw_chart(df, save_path, **settings)

def Chart60d(df, save_path):
    """Draw a 60-day chart with the ``ma60`` overlay: 96 rows (78 price + 18 volume)."""
    settings = dict(
        NUM_DAYS=60,
        IMG_HEIGHT=96,
        PRICE_HEIGHT=78,
        VOL_HEIGHT=18,
        ma_col='ma60',
        interp_points=500,
    )
    draw_chart(df, save_path, **settings)

In [None]:
'''
Chart generation: load the price data, attach the anchor flags, sort by id + date,
then for every window configuration walk each security and render one image per
anchor day.
'''
# 1. Load the price panel and the trading-day anchors
df = pd.read_feather('./data/train_data.feather')
anchor_df = pd.read_feather('./data/trading_days_anchor_1993_2000.feather')

# 2. Attach the anchor flags.
# The anchor file stores plain `date` objects while the price panel uses
# datetime64; pandas refuses to merge those, so bring both keys to datetime64.
df['date'] = pd.to_datetime(df['date'])
anchor_df['date'] = pd.to_datetime(anchor_df['date'])
df = df.merge(
    anchor_df[['date', 'anchor_5', 'anchor_20', 'anchor_60']],
    on=['date'],
    how='left'
)

# 3. Sort by id + date so positional windows are chronological within a security
df = df.sort_values(['id', 'date']).reset_index(drop=True)

# 4. Per-window settings: window length / anchor column / drawing function / output directory
# NOTE(review): all three windows use anchor_5, whereas the labeling step selects
# 20d/60d images by anchor_20/anchor_60. Images are therefore rendered for more
# dates than are labeled - confirm this is intended before narrowing it.
configs = [
    (5,  'anchor_5',  Chart5d,  './charts_train/5d_charts'),
    (20, 'anchor_5', Chart20d, './charts_train/20d_charts'),
    (60, 'anchor_5', Chart60d, './charts_train/60d_charts'),
]

for window, anchor_col, chart_func, output_dir in configs:
    os.makedirs(output_dir, exist_ok=True)

    # One security at a time
    for id_val, group in df.groupby('id'):
        group = group.reset_index(drop=True)

        # Row positions of all anchor days for this security
        anchor_idxs = group.index[group[anchor_col] == 1.0].tolist()
        for idx in anchor_idxs:
            # Skip anchors that do not have a full window of history behind them
            if idx >= (window - 1):
                window_df = group.iloc[idx - (window - 1): idx + 1]

                # Anchor date (last day of the window) formatted as YYYYMMDD
                anchor_date = pd.to_datetime(window_df['date'].iloc[-1]).strftime('%Y%m%d')

                # Output file: id_<id>_<YYYYMMDD>.png
                save_path = os.path.join(
                    output_dir,
                    f'id_{id_val}_{anchor_date}.png'
                )

                # Render and save the image
                chart_func(window_df, save_path)

    print(f"{window}d 图像生成完成！")

In [None]:
'''
Spot check: show one randomly chosen image for each window length.
'''
for i in [5, 20, 60]:
    image_folder = f'./charts_train/{i}d_charts'
    # image_folder = f'./charts_test/{i}d_charts'
    image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')]
    if not image_files:
        # random.choice raises on an empty list; report and move on instead.
        print(f'No PNG images found in {image_folder}')
        continue

    random_image = random.choice(image_files)
    image_path = os.path.join(image_folder, random_image)
    print(image_path)
    # `Image` is the PIL module here, so the file is opened with Image.open;
    # the notebook renders PIL images directly through display().
    with Image.open(image_path) as chart_img:
        display(chart_img)

# Labeling and Splitting

In [None]:
import os
import pandas as pd

# 1. Load the price panel and the anchor calendar from the directory the
#    earlier cells wrote them to.
df = pd.read_feather('./data/train_data.feather')
anchors = pd.read_feather('./data/trading_days_anchor_1993_2000.feather')

# 2. Attach the anchor flags.
# The anchor file stores plain `date` objects while the price panel uses
# datetime64; pandas refuses to merge those, so bring both keys to datetime64.
df['date'] = pd.to_datetime(df['date'])
anchors['date'] = pd.to_datetime(anchors['date'])
df = df.merge(
    anchors[['date', 'anchor_5', 'anchor_20', 'anchor_60']],
    on=['date'],
    how='left'
).sort_values(['id', 'date']).reset_index(drop=True)

# 3. Per-window settings: window length / anchor column / image directory / label file
configs = [
    (5,  'anchor_5',  './charts_train/5d_charts',  './labels_train/image_labels_i5.feather'),
    (20, 'anchor_20', './charts_train/20d_charts', './labels_train/image_labels_i20.feather'),
    (60, 'anchor_60', './charts_train/60d_charts', './labels_train/image_labels_i60.feather'),
]

label_file_columns = ['image_path', 'id', 'date', 'label_5', 'label_20', 'label_60']


def _direction_label(forward_return):
    """Return 1 for a positive forward return, 0 otherwise, NaN when it is missing.

    A missing forward return must not be labeled 0: `NaN > 0` is False, which
    would silently mark every window without future data as a "down" sample.
    """
    if pd.isna(forward_return):
        return float('nan')
    return int(forward_return > 0)


for window, anchor_col, img_dir, out_feather in configs:
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(os.path.dirname(out_feather), exist_ok=True)

    image_labels = []

    for id_val, grp in df.groupby('id'):
        grp = grp.reset_index(drop=True)
        # Row positions of all anchor days for this security
        idxs = grp.index[grp[anchor_col] == 1.0].tolist()

        for idx in idxs:
            # A full window needs window-1 rows of history before the anchor
            if idx >= window - 1:
                win = grp.iloc[idx - (window - 1): idx + 1]
                last_row = win.iloc[-1]

                # Anchor date = last day of the window
                anchor_date = pd.to_datetime(last_row['date']).strftime('%Y%m%d')

                # Image rendered for this window, if any
                image_path = os.path.join(img_dir, f'id_{id_val}_{anchor_date}.png')

                if os.path.exists(image_path):
                    # One label per horizon: sign of the forward return at the anchor day
                    image_labels.append({
                        'image_path': image_path,
                        'id':         id_val,
                        'date':       last_row['date'],
                        'label_5':    _direction_label(last_row['ret_5d']),
                        'label_20':   _direction_label(last_row['ret_20d']),
                        'label_60':   _direction_label(last_row['ret_60d']),
                    })

    # Save the labels; explicit columns keep the schema even when nothing matched
    labels_df = pd.DataFrame(image_labels, columns=label_file_columns)
    labels_df.to_feather(out_feather)
    print(f"{window}d 图像标签已生成并保存到 {out_feather}")


In [None]:
def balance_and_split(labels_file, label_column, train_file, test_file):
    """Build a class-balanced 70/30 train/test split of a label file.

    Parameters
    ----------
    labels_file : path of the Feather file holding the image labels.
    label_column : name of the 0/1 label column to balance and stratify on.
    train_file, test_file : output Feather paths for the two splits.

    Rows whose label is neither 0 nor 1 (for example missing) are left out.
    The larger class is down-sampled to the size of the smaller one, the result
    is shuffled, split 70/30 with stratification and written with a fresh
    0..n-1 index. All random steps use seed 42.

    Raises
    ------
    ValueError
        If one of the two classes has no samples.
    """
    labels_df = pd.read_feather(labels_file)

    # Rows of each class
    label_0 = labels_df[labels_df[label_column] == 0]
    label_1 = labels_df[labels_df[label_column] == 1]

    # Down-sample both classes to the size of the smaller one
    num_samples = min(len(label_0), len(label_1))
    if num_samples == 0:
        raise ValueError(
            f"Cannot balance '{label_column}' in {labels_file}: "
            f"{len(label_0)} rows with label 0 and {len(label_1)} rows with label 1"
        )
    label_0 = label_0.sample(n=num_samples, random_state=42)
    label_1 = label_1.sample(n=num_samples, random_state=42)

    # Combine and shuffle
    balanced_df = pd.concat([label_0, label_1]).sample(frac=1, random_state=42).reset_index(drop=True)
    # The column may be stored as float when the source contained missing labels;
    # only 0/1 rows remain here, so an integer label is safe.
    balanced_df[label_column] = balanced_df[label_column].astype(int)

    # 70% train / 30% test, stratified on the label
    train_df, test_df = train_test_split(
        balanced_df, train_size=0.7, stratify=balanced_df[label_column], random_state=42
    )
    # The split keeps the shuffled row labels; older pandas versions refuse to
    # write a non-default index to Feather.
    train_df = train_df.reset_index(drop=True)
    test_df = test_df.reset_index(drop=True)

    # Report the class balance of both splits
    print(f"训练集标签分布 ({label_column}):")
    print(train_df[label_column].value_counts())
    print(f"\n测试集标签分布 ({label_column}):")
    print(test_df[label_column].value_counts())

    # Save both splits
    train_df.to_feather(train_file)
    test_df.to_feather(test_file)

    print(f"数据集 {label_column} 已划分并保存！")

In [None]:
horizons = [5, 20, 60]
labels = [5, 20, 60]

# One balanced split per (image window, return horizon) pair; the output files
# are tagged i<image window>r<return horizon>.
for image_days in horizons:
    for return_days in labels:
        balance_and_split(
            labels_file=f'./labels_train/image_labels_i{image_days}.feather',
            label_column=f'label_{return_days}',
            train_file=f'./labels_train/train_labels_i{image_days}r{return_days}.feather',
            test_file=f'./labels_train/test_labels_i{image_days}r{return_days}.feather',
        )