In [132]:
import pandas as pd

In [133]:
# Load the two raw HS300 CSV extracts; they are concatenated into a single
# frame further down in the notebook.
data = pd.read_csv(filepath_or_buffer="all_hs300_stocks.csv")
all_hs300_data = pd.read_csv(filepath_or_buffer="all_hs300_stocks_not_connect.csv")

In [134]:
# Parse the 'date' column of both frames into datetimes first ...
for frame in (data, all_hs300_data):
    frame['date'] = pd.to_datetime(frame['date'])

# ... then promote 'date' to the index of each frame (mutating them in place).
for frame in (data, all_hs300_data):
    frame.set_index(['date'], inplace=True)

In [135]:
# Stack the two frames row-wise, then order the rows by the date index.
combined_data = pd.concat([data, all_hs300_data]).sort_index()

combined_data

Unnamed: 0_level_0,code,open,close,high,low,volume
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2008-01-02,sh.600000,53.00,53.55,55.05,51.81,13158390.0
2008-01-02,sz.000012,21.42,21.75,21.90,21.10,4640052.0
2008-01-02,sz.000009,15.82,15.97,16.14,15.54,8910692.0
2008-01-02,sz.000002,29.02,29.19,29.51,28.39,93879671.0
2008-01-02,sz.000001,38.50,37.98,38.74,37.65,20052473.0
...,...,...,...,...,...,...
2022-12-30,sh.601901,6.38,6.38,6.39,6.35,12006609.0
2022-12-30,sh.601899,9.88,10.00,10.09,9.85,132916479.0
2022-12-30,sh.601898,8.60,8.62,8.70,8.60,11139116.0
2022-12-30,sh.601881,9.32,9.29,9.38,9.26,13362296.0


In [136]:
def transform_code(code):
    """Normalize a stock code such as ``'sh.600000'`` to ``'SH600000'``.

    The two-character exchange prefix is upper-cased and a ``'.'`` separator
    directly after it is removed. Unlike a blind "drop the third character",
    an already-normalized code (``'SH600000'``) is returned unchanged, so
    re-running the notebook cell does not silently delete a digit.

    Args:
        code: Stock code string, expected as ``<2-letter exchange>.<number>``.

    Returns:
        The code with an upper-cased prefix and without the ``'.'`` separator.
        Strings shorter than three characters are simply upper-cased.
    """
    prefix, rest = code[:2].upper(), code[2:]
    # Only strip a real separator; never discard a character of the number.
    if rest.startswith('.'):
        rest = rest[1:]
    return prefix + rest

# Rewrite every entry of the 'code' column with transform_code.
combined_data['code'] = combined_data['code'].map(transform_code)

In [137]:
# Display the frame to check the 'code' column after transform_code was applied.
combined_data

Unnamed: 0_level_0,code,open,close,high,low,volume
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2008-01-02,SH600000,53.00,53.55,55.05,51.81,13158390.0
2008-01-02,SZ000012,21.42,21.75,21.90,21.10,4640052.0
2008-01-02,SZ000009,15.82,15.97,16.14,15.54,8910692.0
2008-01-02,SZ000002,29.02,29.19,29.51,28.39,93879671.0
2008-01-02,SZ000001,38.50,37.98,38.74,37.65,20052473.0
...,...,...,...,...,...,...
2022-12-30,SH601901,6.38,6.38,6.39,6.35,12006609.0
2022-12-30,SH601899,9.88,10.00,10.09,9.85,132916479.0
2022-12-30,SH601898,8.60,8.62,8.70,8.60,11139116.0
2022-12-30,SH601881,9.32,9.29,9.38,9.26,13362296.0


In [138]:
# Sort-key helper for ordering stock codes.
def custom_sort(code):
    """Return a tuple key that orders 'SH' codes first, then by number.

    The first element is ``False`` for codes whose two-character prefix is
    ``'SH'`` and ``True`` otherwise, so 'SH' codes sort ahead of the rest.
    The second element is the remainder of the code parsed as an integer.
    """
    exchange, number = code[:2], code[2:]
    is_not_sh = exchange != 'SH'
    return (is_not_sh, int(number))

# Order rows by date, and within each date by the custom code key
# ('SH' codes first, then ascending code number).
combined_data = (
    combined_data
    # temporary column holding the tuple key produced by custom_sort
    .assign(sort_key=combined_data['code'].apply(custom_sort))
    # 'date' is the index level name; 'sort_key' is the temporary column
    .sort_values(['date', 'sort_key'])
    # the helper column is no longer needed once the rows are ordered
    .drop('sort_key', axis=1)
)

In [139]:
# Display the frame to check the row order after sorting by date and code key.
combined_data

Unnamed: 0_level_0,code,open,close,high,low,volume
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2008-01-02,SH600000,53.00,53.55,55.05,51.81,13158390.0
2008-01-02,SH600001,8.12,8.38,8.60,7.95,62828328.0
2008-01-02,SH600003,6.21,6.21,6.21,6.21,0.0
2008-01-02,SH600004,20.95,22.15,22.20,20.95,10329598.0
2008-01-02,SH600005,19.79,19.30,19.89,19.00,26984436.0
...,...,...,...,...,...,...
2022-12-30,SZ300896,568.09,566.35,570.60,557.00,1286736.0
2022-12-30,SZ300919,68.50,65.61,69.12,65.61,7426278.0
2022-12-30,SZ300957,151.00,149.24,151.70,146.89,3534444.0
2022-12-30,SZ300979,57.53,57.11,58.00,56.63,1839075.0


In [140]:
# Move the 'date' index back into a regular column, leaving a default integer index.
combined_data =  combined_data.reset_index()

combined_data

Unnamed: 0,date,code,open,close,high,low,volume
0,2008-01-02,SH600000,53.00,53.55,55.05,51.81,13158390.0
1,2008-01-02,SH600001,8.12,8.38,8.60,7.95,62828328.0
2,2008-01-02,SH600003,6.21,6.21,6.21,6.21,0.0
3,2008-01-02,SH600004,20.95,22.15,22.20,20.95,10329598.0
4,2008-01-02,SH600005,19.79,19.30,19.89,19.00,26984436.0
...,...,...,...,...,...,...,...
1075557,2022-12-30,SZ300896,568.09,566.35,570.60,557.00,1286736.0
1075558,2022-12-30,SZ300919,68.50,65.61,69.12,65.61,7426278.0
1075559,2022-12-30,SZ300957,151.00,149.24,151.70,146.89,3534444.0
1075560,2022-12-30,SZ300979,57.53,57.11,58.00,56.63,1839075.0


In [141]:
# Index the frame by the (date, code) pair.
combined_data = combined_data.set_index(['date', 'code'])

In [142]:
# Display the frame, now indexed by (date, code).
combined_data

Unnamed: 0_level_0,Unnamed: 1_level_0,open,close,high,low,volume
date,code,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2008-01-02,SH600000,53.00,53.55,55.05,51.81,13158390.0
2008-01-02,SH600001,8.12,8.38,8.60,7.95,62828328.0
2008-01-02,SH600003,6.21,6.21,6.21,6.21,0.0
2008-01-02,SH600004,20.95,22.15,22.20,20.95,10329598.0
2008-01-02,SH600005,19.79,19.30,19.89,19.00,26984436.0
...,...,...,...,...,...,...
2022-12-30,SZ300896,568.09,566.35,570.60,557.00,1286736.0
2022-12-30,SZ300919,68.50,65.61,69.12,65.61,7426278.0
2022-12-30,SZ300957,151.00,149.24,151.70,146.89,3534444.0
2022-12-30,SZ300979,57.53,57.11,58.00,56.63,1839075.0


In [143]:
# Write the merged, sorted frame to disk; with the default index=True the
# 'date' and 'code' index levels are written as the leading CSV columns.
combined_data.to_csv("all_hs300_stocks_complete.csv")

In [151]:
# Reload the file written above. NOTE: this rebinds `data`, which earlier held
# one of the raw input frames. No parse_dates is passed, so 'date' is
# presumably read back as plain strings — TODO confirm.
data = pd.read_csv("all_hs300_stocks_complete.csv")

In [152]:
# Index the reloaded frame by 'date' only (in place); 'code' stays a regular column.
data.set_index('date',inplace=True)

data

Unnamed: 0_level_0,code,open,close,high,low,volume
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2008-01-02,SH600000,53.00,53.55,55.05,51.81,13158390.0
2008-01-02,SH600001,8.12,8.38,8.60,7.95,62828328.0
2008-01-02,SH600003,6.21,6.21,6.21,6.21,0.0
2008-01-02,SH600004,20.95,22.15,22.20,20.95,10329598.0
2008-01-02,SH600005,19.79,19.30,19.89,19.00,26984436.0
...,...,...,...,...,...,...
2022-12-30,SZ300896,568.09,566.35,570.60,557.00,1286736.0
2022-12-30,SZ300919,68.50,65.61,69.12,65.61,7426278.0
2022-12-30,SZ300957,151.00,149.24,151.70,146.89,3534444.0
2022-12-30,SZ300979,57.53,57.11,58.00,56.63,1839075.0


In [163]:
# Kept for backward compatibility in case other cells use it; the lookup
# below does not need an IndexSlice for a single label.
idx = pd.IndexSlice

# Collect the stock codes that have a row dated '2008-01-02'.
# A list-valued row selector always yields a Series (even when only one row
# matches), and selecting rows and column in a single .loc call avoids
# chained indexing. A missing date still raises KeyError.
ls = data.loc[['2008-01-02'], 'code'].tolist()

ls

['SH600000',
 'SH600001',
 'SH600003',
 'SH600004',
 'SH600005',
 'SH600006',
 'SH600007',
 'SH600008',
 'SH600009',
 'SH600010',
 'SH600011',
 'SH600012',
 'SH600015',
 'SH600016',
 'SH600017',
 'SH600018',
 'SH600019',
 'SH600020',
 'SH600021',
 'SH600022',
 'SH600026',
 'SH600027',
 'SH600028',
 'SH600029',
 'SH600030',
 'SH600031',
 'SH600033',
 'SH600035',
 'SH600036',
 'SH600037',
 'SH600048',
 'SH600050',
 'SH600058',
 'SH600060',
 'SH600062',
 'SH600066',
 'SH600068',
 'SH600078',
 'SH600085',
 'SH600087',
 'SH600088',
 'SH600096',
 'SH600098',
 'SH600100',
 'SH600102',
 'SH600104',
 'SH600108',
 'SH600110',
 'SH600117',
 'SH600118',
 'SH600123',
 'SH600125',
 'SH600132',
 'SH600143',
 'SH600150',
 'SH600151',
 'SH600153',
 'SH600160',
 'SH600161',
 'SH600170',
 'SH600171',
 'SH600177',
 'SH600183',
 'SH600188',
 'SH600190',
 'SH600196',
 'SH600208',
 'SH600210',
 'SH600219',
 'SH600220',
 'SH600221',
 'SH600236',
 'SH600256',
 'SH600266',
 'SH600269',
 'SH600270',
 'SH600271',

In [166]:
import pickle

# Load the pickled training dataset object from a machine-specific absolute path.
# NOTE(review): pickle.load can execute arbitrary code while unpickling — only
# load files from a trusted source.
with open('/home/cseadmin/mz/StockMtsPlatform/data/csi300/csi300_dl_train.pkl', 'rb') as f:
        dl_train = pickle.load(f)
# Turn the index of the object's `.data` attribute into regular columns;
# presumably a DataFrame indexed by (datetime, instrument) — TODO confirm.
dl_train_data =dl_train.data.reset_index()
# Instruments that have a row with datetime '2008-01-02' in that dataset.
ls1 = dl_train_data[dl_train_data['datetime'] == '2008-01-02']['instrument'].tolist()

In [167]:
# Display the instrument list taken from the pickled dataset for 2008-01-02.
ls1

['SH600000',
 'SH600001',
 'SH600004',
 'SH600005',
 'SH600006',
 'SH600007',
 'SH600008',
 'SH600009',
 'SH600010',
 'SH600011',
 'SH600012',
 'SH600015',
 'SH600016',
 'SH600017',
 'SH600018',
 'SH600019',
 'SH600020',
 'SH600021',
 'SH600022',
 'SH600026',
 'SH600027',
 'SH600028',
 'SH600029',
 'SH600030',
 'SH600031',
 'SH600033',
 'SH600035',
 'SH600036',
 'SH600037',
 'SH600048',
 'SH600050',
 'SH600058',
 'SH600066',
 'SH600068',
 'SH600085',
 'SH600087',
 'SH600088',
 'SH600096',
 'SH600098',
 'SH600100',
 'SH600102',
 'SH600104',
 'SH600108',
 'SH600109',
 'SH600110',
 'SH600111',
 'SH600115',
 'SH600117',
 'SH600118',
 'SH600125',
 'SH600143',
 'SH600150',
 'SH600151',
 'SH600153',
 'SH600158',
 'SH600161',
 'SH600170',
 'SH600177',
 'SH600183',
 'SH600188',
 'SH600190',
 'SH600196',
 'SH600208',
 'SH600210',
 'SH600219',
 'SH600220',
 'SH600221',
 'SH600236',
 'SH600256',
 'SH600266',
 'SH600269',
 'SH600270',
 'SH600271',
 'SH600282',
 'SH600307',
 'SH600308',
 'SH600309',

In [186]:
# Number of codes that appear in both lists.
len(set(ls1) & set(ls))

247

In [179]:
# Number of distinct instruments in ls1.
len(set(ls1))

276