In [124]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm

import torch
from torchmetrics import F1Score, R2Score

In [125]:
# Settings
save = False  # when True, each code's voted predictions are written to a CSV in the 'muilt_model' folder
classify = True  # True: report classification metrics (accuracy / F1); False: report R2 only
y_n = 3  # number of classes used by the metric helpers
# Root folder; every entry accepted by train_folder_filter is read as one model's prediction folder.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
multi_model_folder = r'D:\code\forecast_model\notebook\20240904_test_data_2'
# How class scores become a predicted class when a prediction file has no 'predict' column:
# 'threshold' compares the scores with the model's threshold.txt; any other value takes the per-row maximum.
mode = 'threshold' # 'threshold' or 'max'

In [126]:
def train_folder_filter(folder):
    """Decide whether a directory entry should be read as a model prediction folder.

    Args:
        folder: Name of an entry inside the multi-model root folder.

    Returns:
        False for '.csv' files, for the 'muilt_model' output folder and for
        any name containing '5y'; True for everything else.
    """
    is_csv_file = folder.endswith('.csv')
    is_vote_output = folder == 'muilt_model'
    # Manual exclusion: runs whose name contains '5y' are left out.
    is_excluded_run = '5y' in folder
    return not (is_csv_file or is_vote_output or is_excluded_run)

In [127]:
def cal_balance_acc(y_pred, y_true, num_classes=None):
    """Balanced accuracy: the unweighted mean of the per-class recalls.

    Args:
        y_pred: Predicted class indices, as a pandas Series or a torch tensor.
        y_true: True class indices, as a pandas Series or a torch tensor.
        num_classes: Number of classes (labels are 0 .. num_classes - 1).
            Defaults to the notebook-level ``y_n`` setting.

    Returns:
        The balanced accuracy as a Python float. Classes that never occur in
        ``y_true`` have an undefined recall (0 / 0) and are left out of the
        mean instead of turning the whole score into NaN. If no class occurs
        at all (e.g. empty input) the result is NaN.
    """
    # Convert each argument on its own so that mixed Series / tensor input works.
    if isinstance(y_pred, pd.Series):
        y_pred = torch.tensor(y_pred.values)
    if isinstance(y_true, pd.Series):
        y_true = torch.tensor(y_true.values)
    if num_classes is None:
        num_classes = y_n

    recall_values = []
    for label in range(num_classes):
        is_label = y_true == label
        support = is_label.sum()  # true positives + false negatives
        if support == 0:
            # Recall is undefined for a class without any true samples.
            continue
        true_positives = (is_label & (y_pred == label)).sum()
        recall_values.append(true_positives / support)

    if not recall_values:
        return float('nan')

    # Balanced accuracy = mean recall over the classes that are present.
    return torch.mean(torch.stack(recall_values)).item()

def class_accuracy(y_pred, y_true, num_classes=None):
    """Per-class accuracy: for each class, the share of its true samples predicted correctly.

    Args:
        y_pred: Predicted class indices, as a pandas Series or a torch tensor.
        y_true: True class indices, as a pandas Series or a torch tensor.
        num_classes: Number of classes (labels are 0 .. num_classes - 1).
            Defaults to the notebook-level ``y_n`` setting.

    Returns:
        A list of ``num_classes`` floats. A class with no samples in
        ``y_true`` gets 0.0.
    """
    # Convert each argument on its own so that mixed Series / tensor input works.
    if isinstance(y_pred, pd.Series):
        y_pred = torch.tensor(y_pred.values)
    if isinstance(y_true, pd.Series):
        y_true = torch.tensor(y_true.values)
    if num_classes is None:
        num_classes = y_n

    is_correct = y_pred == y_true  # computed once, shared by all classes
    class_acc = []
    for label in range(num_classes):
        is_label = y_true == label
        class_total = is_label.sum().item()
        # A correct prediction of class `label` is a true sample of `label` that was hit.
        class_correct = (is_label & is_correct).sum().item()
        class_acc.append(class_correct / class_total if class_total > 0 else 0.0)

    return class_acc

def class_f1_score(y_pred, y_true):
    """Per-class F1 scores computed with torchmetrics' multiclass F1Score.

    Args:
        y_pred: Predicted class indices (pandas Series or torch tensor).
        y_true: True class indices (same kind of container as ``y_pred``).

    Returns:
        A list with one F1 value per class (``y_n`` entries).
    """
    if isinstance(y_pred, pd.Series):
        y_pred, y_true = (torch.tensor(column.values) for column in (y_pred, y_true))

    # average='none' keeps one score per class instead of reducing to a single number.
    per_class_metric = F1Score(task='multiclass', num_classes=y_n, average='none').to(y_pred.device)
    per_class_metric.update(y_pred, y_true)
    return per_class_metric.compute().tolist()

def cal_weighted_f1(y_pred, y_true):
    """Weighted-average multiclass F1 computed with torchmetrics' F1Score.

    Args:
        y_pred: Predicted class indices (pandas Series or torch tensor).
        y_true: True class indices (same kind of container as ``y_pred``).

    Returns:
        The weighted F1 score as a Python float.
    """
    series_input = isinstance(y_pred, pd.Series)
    preds = torch.tensor(y_pred.values) if series_input else y_pred
    target = torch.tensor(y_true.values) if series_input else y_true

    weighted_metric = F1Score(task='multiclass', num_classes=y_n, average='weighted').to(preds.device)
    score = weighted_metric(preds, target)
    return score.unsqueeze(0).item()
    

def cal_variance_weighted_r2(y_pred, y_true):
    """R2 score with torchmetrics' R2Score using multioutput='variance_weighted'.

    Args:
        y_pred: Predicted values (pandas Series or torch tensor).
        y_true: Target values (same kind of container as ``y_pred``).

    Returns:
        The R2 score as a Python float.
    """
    if isinstance(y_pred, pd.Series):
        y_pred, y_true = (torch.tensor(column.values) for column in (y_pred, y_true))

    metric = R2Score(multioutput='variance_weighted').to(y_pred.device)
    score = metric(y_pred, y_true)
    return score.item()

In [128]:
# Codes (symbols) to evaluate, discovered from the first entry that passes train_folder_filter.
model_folders = [entry for entry in os.listdir(multi_model_folder) if train_folder_filter(entry)]
sample_file = os.path.join(multi_model_folder, model_folders[0])
print(sample_file)

# Only '.csv' files whose name contains exactly two underscores are considered;
# the part before the first underscore is taken as the code.
codes = {
    file_name.split('_')[0]
    for file_name in os.listdir(sample_file)
    if file_name.endswith('.csv') and file_name.count('_') == 2
}
print(len(codes))
codes

D:\code\forecast_model\notebook\20240904_test_data_2\once_test_t0_datas_159509_100_0_T4x2_fp16
1


{'159509'}

In [129]:
# Read every model's predictions (plus the shared labels) for each code,
# then combine the models with a per-row majority vote.
muilt_predict_folder = os.path.join(multi_model_folder, 'muilt_model')
os.makedirs(muilt_predict_folder, exist_ok=True)

code_frames = []  # one frame per code; concatenated once after the loop
for code in tqdm(codes):
    data = None
    for model_folder in os.listdir(multi_model_folder):
        if not train_folder_filter(model_folder):
            continue

        name = model_folder
        model_path = os.path.join(multi_model_folder, model_folder)
        for symbol_predict_file in os.listdir(model_path):
            # NOTE(review): a plain prefix match — a code that is a prefix of another
            # code would also match the other code's file.
            if not (symbol_predict_file.endswith('.csv') and symbol_predict_file.startswith(code)):
                continue

            _data = pd.read_csv(os.path.join(model_path, symbol_predict_file))
            if 'predict' in _data.columns:
                # The file already holds final class predictions.
                if data is None:
                    data = _data.rename(columns={'predict': name})
                else:
                    # Index-aligned; assumes every model's file has the same rows in the
                    # same order — TODO confirm.
                    data[name] = _data['predict']
            else:
                # The file holds per-class scores in the columns after the first two.
                if mode == 'threshold':
                    # threshold.txt: one comma-separated line with one threshold per class.
                    threshold_file = os.path.join(model_path, 'threshold.txt')
                    with open(threshold_file, 'r') as f:
                        thresholds = [float(i) for i in f.readline().strip().split(',')]

                    thresholds = torch.tensor(thresholds)
                    comb = torch.tensor([0, 1, 2])

                    softmax_predictions = torch.from_numpy(_data.iloc[:, 2:].values)
                    thresholded_predictions_int = (softmax_predictions > thresholds).int()

                    # argmax over 0/1 flags: the first class whose score exceeds its
                    # threshold, or class 0 when none does. Converted to a NumPy array
                    # so that pandas stores a plain integer column.
                    y_pred = torch.argmax(thresholded_predictions_int[:, comb], dim=1).numpy()
                else:
                    # Highest score wins; the column label is turned into a class index
                    # by its position, minus the two leading non-score columns.
                    y_pred = _data.iloc[:, 2:].idxmax(axis=1).apply(lambda x: list(_data).index(x) - 2)

                if data is None:
                    # Copy the two leading columns so that adding prediction columns does
                    # not write into a slice of `_data`.
                    data = _data.iloc[:, :2].copy()
                data[name] = y_pred

            # Only the first matching file of each model folder is used.
            break

    if data is None:
        # No model folder contained a prediction file for this code.
        print(f'no prediction file found for {code}, skipped')
        continue

    # Majority vote across the model columns (ties go to the smallest class index).
    data['predict'] = data.iloc[:, 2:].apply(lambda row: np.bincount(row).argmax(), axis=1)

    # Write the voted predictions to file.
    if save:
        _begin, _end = data.iloc[0, 0], data.iloc[-1, 0]
        data.loc[:, ['timestamp', 'target', 'predict']].to_csv(os.path.join(muilt_predict_folder, f'{code}_{_begin}_{_end}.csv'), index=False)

    code_frames.append(data)

# Merge all codes at once instead of growing a frame inside the loop.
all_data = pd.concat(code_frames, ignore_index=True) if code_frames else pd.DataFrame()

all_data

100%|██████████| 1/1 [00:00<00:00,  1.24it/s]


Unnamed: 0,timestamp,target,once_test_t0_datas_159509_100_0_T4x2_fp16,once_test_t0_datas_159509_100_1_T4x2_fp16,once_test_t0_datas_159509_100_2_T4x2_fp16,once_test_t0_datas_159509_20_0_T4x2_fp16,once_test_t0_datas_159509_20_1_T4x2_fp16,once_test_t0_datas_159509_20_2_T4x2_fp16,once_test_t0_datas_159509_50_0_T4x2_fp16,once_test_t0_datas_159509_50_1_T4x2_fp16,once_test_t0_datas_159509_50_2_T4x2_fp16,predict
0,1724639712,1,0,0,0,0,0,1,0,0,0,0
1,1724639715,1,0,0,0,0,0,1,0,0,0,0
2,1724639718,1,0,0,0,0,0,1,0,0,0,0
3,1724639721,1,0,0,0,0,0,1,0,0,0,0
4,1724639724,1,0,0,0,0,0,1,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...
14331,1724999829,2,0,0,2,2,0,2,1,0,2,0
14332,1724999832,2,0,0,2,2,0,2,2,0,2,2
14333,1724999835,0,0,0,2,2,0,2,2,0,2,2
14334,1724999838,0,0,0,2,2,0,2,1,0,2,0


In [130]:
# Score every prediction column (acc / F1): the voted column is relabelled 'multi',
# and every column after the first two is treated as a prediction to evaluate.
all_data.rename(columns={'predict': 'multi'}, inplace=True)
model_names = all_data.columns[2:].tolist()
model_names

['once_test_t0_datas_159509_100_0_T4x2_fp16',
 'once_test_t0_datas_159509_100_1_T4x2_fp16',
 'once_test_t0_datas_159509_100_2_T4x2_fp16',
 'once_test_t0_datas_159509_20_0_T4x2_fp16',
 'once_test_t0_datas_159509_20_1_T4x2_fp16',
 'once_test_t0_datas_159509_20_2_T4x2_fp16',
 'once_test_t0_datas_159509_50_0_T4x2_fp16',
 'once_test_t0_datas_159509_50_1_T4x2_fp16',
 'once_test_t0_datas_159509_50_2_T4x2_fp16',
 'multi']

In [131]:
# Score each prediction column against the target: one row of metrics per column.
target_column = all_data['target']
metric_rows = []
for model_name in model_names:
    pred_column = all_data[model_name]
    if classify:
        per_class_acc = class_accuracy(pred_column, target_column)
        per_class_f1 = class_f1_score(pred_column, target_column)

        # Column order: per-class accuracy, per-class F1, then the overall scores.
        # The number of per-class columns follows y_n instead of being fixed at three.
        row = {f'acc_{k}': per_class_acc[k] for k in range(y_n)}
        row.update({f'f1_{k}': per_class_f1[k] for k in range(y_n)})
        row['acc'] = cal_balance_acc(pred_column, target_column)
        row['f1'] = cal_weighted_f1(pred_column, target_column)
    else:
        row = {'r2': cal_variance_weighted_r2(pred_column, target_column)}
    metric_rows.append(row)

result_data = pd.DataFrame(metric_rows, index=model_names)

# Column means over the individual models only; the voted 'multi' row is found by
# label rather than assumed to be the last row.
is_vote_row = result_data.index == 'multi'
mean_row = result_data[~is_vote_row].mean().to_frame('mean').T

# Final layout: individual models, their mean (labelled 'mean'), then the voted row.
result_data = pd.concat([result_data[~is_vote_row], mean_row, result_data[is_vote_row]])

result_data

Unnamed: 0,acc_0,acc_1,acc_2,f1_0,f1_1,f1_2,acc,f1
once_test_t0_datas_159509_100_0_T4x2_fp16,0.814362,0.180783,0.179395,0.394372,0.232828,0.288876,0.391513,0.298802
once_test_t0_datas_159509_100_1_T4x2_fp16,0.471692,0.430916,0.297417,0.348487,0.353836,0.412518,0.400009,0.382064
once_test_t0_datas_159509_100_2_T4x2_fp16,0.38975,0.316834,0.346293,0.294661,0.297022,0.425089,0.350959,0.360807
once_test_t0_datas_159509_20_0_T4x2_fp16,0.471692,0.332186,0.551652,0.394517,0.356079,0.584351,0.451844,0.479754
once_test_t0_datas_159509_20_1_T4x2_fp16,0.471692,0.430916,0.297417,0.348487,0.353836,0.412518,0.400009,0.382064
once_test_t0_datas_159509_20_2_T4x2_fp16,0.257747,0.177607,0.70536,0.27649,0.219892,0.622854,0.380238,0.435578
once_test_t0_datas_159509_50_0_T4x2_fp16,0.490763,0.417152,0.426548,0.397778,0.366086,0.521518,0.444821,0.451589
once_test_t0_datas_159509_50_1_T4x2_fp16,0.484207,0.428534,0.398084,0.384934,0.375464,0.494097,0.436942,0.437279
once_test_t0_datas_159509_50_2_T4x2_fp16,0.38975,0.316834,0.346293,0.294661,0.297022,0.425089,0.350959,0.360807
0,0.471295,0.336863,0.394273,0.348265,0.316896,0.465212,0.40081,0.398749


In [116]:
# Keep the current metrics table under a run-specific name and display it.
# NOTE(review): this binds a second name to the same object (no copy). The output saved
# with this cell lists different model folders than the metrics cell above, so it appears
# to come from an earlier run with another configuration — confirm before relying on it.
_10y_simple = result_data
_10y_simple

Unnamed: 0,acc_0,acc_1,acc_2,f1_0,f1_1,f1_2,acc,f1
test_t0_datas_10y_simple_109_T4x2_fp16,0.608113,0.549417,0.904335,0.551046,0.558212,0.912864,0.687288,0.856589
test_t0_datas_10y_simple_123_T4x2_fp16,0.606291,0.584203,0.894915,0.546486,0.556693,0.908888,0.695137,0.852734
test_t0_datas_10y_simple_150_T4x2_fp16,0.575993,0.541436,0.914113,0.553849,0.553267,0.916251,0.677181,0.85934
test_t0_datas_10y_simple_42_T4x2_fp16,0.599172,0.569675,0.899412,0.549124,0.551014,0.910518,0.68942,0.853938
test_t0_datas_10y_simple_55_T4x2_fp16,0.604801,0.557193,0.905596,0.55754,0.556566,0.91365,0.689197,0.857699
0,0.598874,0.560385,0.903674,0.551609,0.555151,0.912434,0.687644,0.85606
multi,0.593543,0.555351,0.910774,0.559282,0.560454,0.915852,0.686556,0.85998


In [107]:
# Keep the current metrics table under a run-specific name and display it.
# NOTE(review): this binds a second name to the same object (no copy). The output saved
# with this cell lists different model folders than the metrics cell above, so it appears
# to come from an earlier run with another configuration — confirm before relying on it.
_10y_each_sample = result_data
_10y_each_sample

Unnamed: 0,acc_0,acc_1,acc_2,f1_0,f1_1,f1_2,acc,f1
test_t0_datas_10y_each_sample_109_T4x2_fp16,0.597517,0.585021,0.886245,0.53431,0.53776,0.903859,0.689594,0.846109
test_t0_datas_10y_each_sample_123_T4x2_fp16,0.590066,0.506446,0.912801,0.546835,0.544435,0.914929,0.669771,0.856997
test_t0_datas_10y_each_sample_150_T4x2_fp16,0.594205,0.566605,0.896551,0.538242,0.549896,0.908568,0.685787,0.851273
test_t0_datas_10y_each_sample_42_T4x2_fp16,0.570199,0.516268,0.919768,0.559409,0.540778,0.918126,0.668745,0.860527
test_t0_datas_10y_each_sample_55_T4x2_fp16,0.597682,0.505013,0.91631,0.55891,0.546804,0.917021,0.673002,0.859975
0,0.589934,0.535871,0.906335,0.547541,0.543934,0.912501,0.67738,0.854976
multi,0.581457,0.520565,0.921012,0.564494,0.552683,0.919407,0.674345,0.862883


In [97]:
# Keep the current metrics table under a run-specific name and display it.
# NOTE(review): this binds a second name to the same object (no copy). The output saved
# with this cell lists different model folders than the metrics cell above, so it appears
# to come from an earlier run with another configuration — confirm before relying on it.
_5y_2label_100 = result_data
_5y_2label_100

Unnamed: 0,acc_0,acc_1,acc_2,f1_0,f1_1,f1_2,acc,f1
once_test_t0_datas_5y_2label_100_0_T4x2_fp16,0.643341,0.533038,0.375462,0.376565,0.324232,0.513325,0.51728,0.468147
once_test_t0_datas_5y_2label_100_1_T4x2_fp16,0.646497,0.528011,0.366484,0.372245,0.321431,0.505242,0.513664,0.461318
once_test_t0_datas_5y_2label_100_2_T4x2_fp16,0.27007,0.226718,0.548942,0.204404,0.165922,0.634329,0.348577,0.507886
0,0.519969,0.429255,0.430296,0.317738,0.270528,0.550965,0.45984,0.479117
multi,0.651415,0.448647,0.397792,0.358854,0.314535,0.533808,0.499285,0.478811


In [88]:
# Keep the current metrics table under a run-specific name and display it.
# NOTE(review): this binds a second name to the same object (no copy). The output saved
# with this cell lists different model folders than the metrics cell above, so it appears
# to come from an earlier run with another configuration — confirm before relying on it.
_5y_2label_50 = result_data
_5y_2label_50

Unnamed: 0,acc_0,acc_1,acc_2,f1_0,f1_1,f1_2,acc,f1
once_test_t0_datas_5y_2label_50_0_T4x2_fp16,0.679144,0.493641,0.540427,0.333113,0.321158,0.669779,0.571071,0.603819
once_test_t0_datas_5y_2label_50_1_T4x2_fp16,0.615625,0.565094,0.529264,0.337112,0.30798,0.660143,0.569994,0.595293
once_test_t0_datas_5y_2label_50_2_T4x2_fp16,0.275344,0.225736,0.539712,0.156923,0.131919,0.654566,0.346931,0.556418
0,0.523371,0.428157,0.536468,0.275716,0.253686,0.661496,0.495998,0.585176
multi,0.657025,0.470214,0.536853,0.31481,0.3155,0.66628,0.554697,0.598584


In [79]:
# Keep the current metrics table under a run-specific name and display it.
# NOTE(review): this binds a second name to the same object (no copy). The output saved
# with this cell lists different model folders than the metrics cell above, so it appears
# to come from an earlier run with another configuration — confirm before relying on it.
_5y_2label_20 = result_data
_5y_2label_20

Unnamed: 0,acc_0,acc_1,acc_2,f1_0,f1_1,f1_2,acc,f1
once_test_t0_datas_5y_2label_20_0_T4x2_fp16,0.616185,0.555979,0.749079,0.323457,0.317198,0.834048,0.640414,0.778173
once_test_t0_datas_5y_2label_20_1_T4x2_fp16,0.822933,0.670281,0.380475,0.21024,0.227993,0.544026,0.624563,0.508602
once_test_t0_datas_5y_2label_20_2_T4x2_fp16,0.221461,0.147421,0.84423,0.161278,0.118744,0.877673,0.404371,0.797573
0,0.553526,0.457894,0.657928,0.231658,0.221312,0.751916,0.556449,0.694783
multi,0.648148,0.493552,0.718612,0.277989,0.316839,0.815441,0.620104,0.758921
