In [1]:
# 导入包
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sys
import zipfile
import warnings
import json
import os
from collections import defaultdict
from tqdm import tqdm
# import mplfinance as mpf
import pickle
import argparse

In [3]:
# Location of the competition archive and of the CSV members inside it
zip_file_path = '../data/data285396/初赛数据集.zip'
train_file_name = '数据集/初赛-训练集.csv'
test_file_name = '数据集/初赛-测试集.csv'

# Read the GBK-encoded test-set CSV straight out of the zip archive
with zipfile.ZipFile(zip_file_path) as archive, archive.open(test_file_name) as csv_member:
    test_df = pd.read_csv(csv_member, encoding="gbk")

# NOTE: the derived columns below must be appended in exactly this order,
# because a later cell renames every column by position.
# NOTE(review): the rolling / exponential windows run over the rows of each
# stock in their current order — presumably chronological; confirm.

# Simple moving averages of the close price, per stock.
# min_periods=1: the first rows of a stock average over whatever rows exist so far.
for ma_window in (5, 10, 20, 30):
    test_df[f'MA_{ma_window}'] = test_df.groupby('股票')['收盘价'].transform(
        lambda closes, w=ma_window: closes.rolling(window=w, min_periods=1).mean()
    )

# Simple moving averages of the traded volume, per stock
for vol_window in (5, 10):
    test_df[f'Volume_MA_{vol_window}'] = test_df.groupby('股票')['成交量'].transform(
        lambda volumes, w=vol_window: volumes.rolling(window=w, min_periods=1).mean()
    )

# Turnover rate in percent. The float is a placeholder of 10,000 shares —
# adjust it to the real figure when one is available.
test_df['换手率'] = test_df['成交量'] / 10000 * 100

# MACD: fast and slow exponential moving averages of the close price, per stock
for ema_span in (12, 26):
    test_df[f'EMA_{ema_span}'] = test_df.groupby('股票')['收盘价'].transform(
        lambda closes, s=ema_span: closes.ewm(span=s, adjust=False).mean()
    )
# DIFF is the gap between the two EMAs, DEA its 9-span EMA (per stock),
# and the histogram is what is left of DIFF after subtracting DEA.
test_df['DIFF'] = test_df['EMA_12'] - test_df['EMA_26']
test_df['DEA'] = test_df.groupby('股票')['DIFF'].transform(
    lambda diff: diff.ewm(span=9, adjust=False).mean()
)
test_df['MACD Histogram'] = test_df['DIFF'] - test_df['DEA']

In [4]:
# Same processing logic as for the training set: bucket the date codes into groups.
# How the one-liner below works:
#   1. take the unique date codes in order of first appearance;
#   2. flag every step between neighbouring codes that is not exactly 1 and
#      cumulatively sum the flags, so the counter increases at every such gap;
#   3. shift that counter down by one position and stack it under the date
#      codes as a second row, then transpose into two columns;
#   4. bfill/ffill the missing counter values that remain.
# NOTE(review): the shifted counter is one entry shorter than the date list and
# starts with NaN, so the first date appears to be back-filled from the second
# and the last date forward-filled from the one before it. A gap directly after
# the first date or directly before the last date would therefore not start a
# new group — confirm this is intended.
grouper = pd.DataFrame([test_df["日期代码"].unique(), pd.Series((np.diff(test_df["日期代码"].unique()) != 1).cumsum()).shift(1)]).T.bfill().ffill()
grouper.columns = ['日期代码', '组别']
# Attach the group id to every row through its date code
merged_test = pd.merge(test_df, grouper, on='日期代码', how='left')

# Rename all columns by position to English names. Compared with the training
# set, the test set has no label column.
# NOTE(review): 'Trunover' looks like a misspelling of 'Turnover', but it is a
# runtime column name — check downstream consumers before changing it.
merged_test.columns = ['Stock_name', 'Data_time', 'Open', 'High', 'Low', 'Close', 'Volume', 'Money', 'MA_5', 'MA_10', 'MA_20', 'MA_30', 'Volume_MA_5', 'Volume_MA_10', 'Trunover', 'EMA_12', 'EMA_26', 'DIFF', 'DEA', 'MACD Histogram', 'Group']
# One sub-frame per (stock, group) pair
grouped_test = merged_test.groupby(['Stock_name', 'Group'])

In [7]:
# Target number of rows per (stock, group) window; windows with a different
# row count go through the padding logic in the cells below.
group_length = 5

In [8]:

# Special handling for windows with too few rows ("misshaped_features")
def prepend_to_five_rows(df, target_rows=None):
    """Pad a short DataFrame at the front by repeating its first row.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to pad.
    target_rows : int, optional
        Minimum number of rows the result should have. When omitted, the
        module-level ``group_length`` is used, which keeps existing
        one-argument calls working unchanged.

    Returns
    -------
    pandas.DataFrame
        ``df`` itself when it already has at least ``target_rows`` rows, or
        when it has no rows at all (there is no first row to repeat).
        Otherwise a new frame with a fresh 0..n-1 index whose leading rows
        are copies of the original first row.
    """
    if target_rows is None:
        target_rows = group_length

    rows_to_add = target_rows - len(df)
    if rows_to_add <= 0 or len(df) == 0:
        return df

    # Repeat the first row positionally so every column keeps its dtype
    # (a round trip through a dict can silently change dtypes).
    padding = df.iloc[[0] * rows_to_add]
    return pd.concat([padding, df], ignore_index=True)

# Build a nested mapping: stock -> group id -> window DataFrame.
# Windows that do not have exactly group_length rows are padded at the front:
#   * preferably with the rows that sit right before the window in the same
#     stock's history (i.e. the tail of the previous week);
#   * if the history does not reach back far enough, by repeating the
#     window's own first row (prepend_to_five_rows).
# This padding strategy is only a baseline and can be refined further.
# Note: a window with MORE than group_length rows yields a negative shortfall,
# so the borrowed slice is empty and the window keeps its full length.
grouped_dict = defaultdict(dict)
# Full history of every stock, re-indexed 0..n-1 so index labels double as row positions
stock_dfs = {name: frame.reset_index(drop=True) for name, frame in merged_test.groupby('Stock_name')}
cnt = 0  # number of windows processed
for (stock, group), sub_df in tqdm(grouped_test):
    shortfall = group_length - len(sub_df)
    if shortfall != 0:
        history = stock_dfs[stock]
        first_day = sub_df['Data_time'].values[0]
        # Position of the window's first row inside the stock's history
        start = history.loc[history['Data_time'] == first_day].index[0]

        if start >= shortfall:
            # Enough earlier rows exist: borrow the ones right before the window
            borrowed = history.iloc[start - shortfall:start]
            sub_df = pd.concat([borrowed, sub_df], ignore_index=True)
        else:
            sub_df = prepend_to_five_rows(sub_df)

    # mpf plotting needs normalised dates: replace Data_time with a synthetic
    # run of consecutive days and use it as the index.
    synthetic_days = pd.date_range('1/10/2021', periods=len(sub_df), freq='D')
    sub_df = sub_df.assign(Data_time=synthetic_days).set_index('Data_time')

    grouped_dict[stock][group] = sub_df
    cnt += 1

100%|██████████| 144685/144685 [02:52<00:00, 838.59it/s] 


In [None]:
# Persist grouped_dict to disk as a pickle file (written to the working directory)
with open('test_grouped_dict.pkl', 'wb') as pickle_out:
    pickle.dump(grouped_dict, pickle_out)