## Metrics

汇总二分类任务的常见评价指标，例如：AUC、ROC 曲线、ACC（准确率）、敏感性、特异性、精确度、召回率、PPV、NPV、F1。

各指标的具体介绍可参考：https://blog.csdn.net/sunflower_sara/article/details/81214897

In [None]:
import os
import pandas as pd
from datetime import datetime
from onekey_algo import get_param_in_cwd

# Output folders for ROC figures ('img') and result tables ('results').
os.makedirs('img', exist_ok=True)
os.makedirs('results', exist_ok=True)

# Sample -> cohort table. Only the 'ID' and 'group' columns are kept; every
# group other than 'train' / 'val' is relabelled as 'test'.
group_info = pd.read_csv(get_param_in_cwd('label_file'))[['ID', 'group']]
group_info['group'] = group_info['group'].where(group_info['group'].isin(['train', 'val']), 'test')
display(group_info['group'].value_counts())
group_info

In [None]:
import pandas as pd
import numpy  as np
import re
from onekey_algo.custom.components import metrics
from onekey_algo.custom.components.comp1 import draw_roc, normalize_df
from onekey_algo.custom.components.ugly import drop_error
from matplotlib import pyplot as plt

def get_group(x):
    """Infer a cohort name from the base name of path ``x``.

    Names starting with 'train' or 'val' map to that cohort. Any other name
    maps to its first two '_'-separated tokens (e.g. 'test_1_xxx' -> 'test_1').
    """
    name = os.path.basename(x)
    for prefix in ('train', 'val'):
        if name.startswith(prefix):
            return prefix
    return '_'.join(name.split('_')[:2])

def get_log(log_path, map2gz:bool = True):
    """Load a header-less, tab-separated prediction log.

    The columns are read as ``fname``, ``pred_score``, ``pred_label`` and
    ``gt``. An ``ID`` column is added from the name of the directory that
    contains ``fname``. When ``map2gz`` is true, '.nii.gz' is appended to it,
    presumably to match the IDs of the label file this notebook merges on.
    """
    suffix = '.nii.gz' if map2gz else ''
    log_ = pd.read_csv(log_path, sep='\t', names=['fname', 'pred_score', 'pred_label', 'gt'])
    log_['ID'] = [f"{os.path.basename(os.path.dirname(p))}{suffix}" for p in log_['fname']]
    return log_

def map_mn(x):
    """Prettify a model identifier for plot labels and titles.

    The substitutions are applied in order, each one to the result of the
    previous one, e.g. 'resnet50' -> 'ResNet50'.
    """
    for raw, pretty in (('densen', 'DenseN'), ('resnet', 'ResNet'),
                        ('vgg', 'VGG'), ('inception_v3', 'InceptionV3')):
        x = x.replace(raw, pretty)
    return x

# Evaluate every (modality, model, cohort) combination:
#   1. load the train + valid prediction logs of the epoch chosen in epoch_mapping,
#   2. attach each sample's cohort from group_info,
#   3. compute Acc / AUC (+95% CI) / Sensitivity / Specificity / PPV / NPV and the
#      threshold returned by metrics.analysis_pred_binary(use_youden=True),
#   4. draw one ROC figure per model (all cohorts) and one per cohort (all models).
all_log_ = []       # every scored cohort frame; concatenated by the next cell
metrics_dfs = []    # one metrics table per modality
# Optional manual exclusions: fnames listed here are removed from the matching
# cohort before scoring. Nothing in this cell fills these sets.
sel_idx = {'test': set(), 'val': set()}
# Epoch whose prediction log is evaluated, per model and per modality.
epoch_mapping = {'resnet50': {'2.5D': 59, 'P': 19, 'V': 18, 'dwi': 19, 't1': 18, 't2': 17, },
                 'vgg19': {'2.5D': 61, 'P': 19, 'V': 18, 'dwi': 19, 't1': 18, 't2': 17, },
                 'resnet101': {'2.5D': 61, 'P': 10, 'V': 13, 'dwi': 5, 't1': 10, 't2': 12, },
                 'densenet121': {'2.5D': 55, 'P': 14, 'V': 12, 'dwi': 15, 't1': 17, 't2': 11, },
                 'densenet201': {'2.5D': 53, 'P': 14, 'V': 12, 'dwi': 15, 't1': 17, 't2': 11, },
                 'inception_v3': {'2.5D': 62, 'Pathomics': 0, 'V': 12, 'dwi': 15, 't1': 17, 't2': 11, },
                 'CrossFormer': {'2.5D': 84, 'Pathomics': 0, 'V': 18, 'dwi': 19, 't1': 18, 't2': 17, },
                 'TwinsSVT': {'2.5D': 61, 'P': 10, 'V': 13, 'dwi': 5, 't1': 10, 't2': 12, },
                 'SimpleViT': {'Pathomics': 1, 'P': 14, 'V': 12, 'dwi': 15, 't1': 17, 't2': 11, },
                 'ResNet50': {'3D': 14, 'P': 19, 'V': 18, 'dwi': 19, 't1': 18, 't2': 17, },
                 'ShuffleNet': {'3D': 18, 'P': 19, 'V': 18, 'dwi': 19, 't1': 18, 't2': 17, },
                 'ResNet101': {'3D': 35, 'P': 10, 'V': 13, 'dwi': 5, 't1': 10, 't2': 12, },
                 'DenseNet121': {'3D': 26, 'P': 14, 'V': 12, 'dwi': 15, 't1': 17, 't2': 11, }}

sel_models = ['SimpleViT', 'inception_v3']
positive_label = 1  # ground-truth class treated as positive when scoring


def _load_model_log(model_root, model, epoch):
    """Return the concatenated train + valid prediction logs of `model` at `epoch`."""
    return pd.concat([get_log(os.path.join(model_root, model, f"train/Epoch-{epoch}.txt")),
                      get_log(os.path.join(model_root, model, f"valid/Epoch-{epoch}.txt"))], axis=0)


def _prepare_cohort(sub_group):
    """Return a copy of one cohort's predictions with 'label-1' / 'label-0' columns.

    'pred_score' is the score of the *predicted* label, so it is flipped to
    1 - score whenever the predicted label is not 1. 'label-1' is then passed
    through normalize_df(method='minmax') with every other column excluded.
    """
    # Work on an independent frame: sub_group is a boolean-indexed slice of val_log,
    # and assigning columns into it triggers pandas' SettingWithCopy hazard.
    sub_group = sub_group.copy()
    sub_group['label-1'] = np.where(sub_group['pred_label'] == 1,
                                    sub_group['pred_score'], 1 - sub_group['pred_score'])
    # NOTE(review): normalization is done per cohort, so the reported thresholds are
    # cohort-specific — confirm this is intended.
    sub_group = normalize_df(sub_group, not_norm=[c for c in sub_group.columns if c != 'label-1'], method='minmax')
    # Derive label-0 from the *normalized* label-1 so the two columns stay complementary.
    sub_group['label-0'] = 1 - sub_group['label-1']
    return sub_group


model_root = get_param_in_cwd('model_root')
ug_groups = list(get_param_in_cwd('subsets'))
for modal in get_param_in_cwd('mtype', ['Pathomics']):
    metric_results = []
    all_preds = []        # model-major: model0 x every cohort, then model1 x every cohort, ...
    all_gts = []
    all_model_names = []
    for model in sel_models:
        epoch = epoch_mapping[model][modal]
        val_log = pd.merge(_load_model_log(model_root, model, epoch), group_info, on='ID', how='inner')
        val_log['model'] = f"{model}_{modal}"
        all_pred = []
        all_gt = []
        for g in ug_groups:
            sub_group = val_log[val_log['group'] == g]
            if g in sel_idx:
                sub_group = sub_group[~sub_group['fname'].isin(sel_idx[g])]
            # Print the epoch actually loaded (from epoch_mapping).
            print(modal, epoch, g, sub_group.shape)
            sub_group = _prepare_cohort(sub_group)
            all_log_.append(sub_group)

            pred_score = np.array(sub_group['label-1'])
            gt = [1 if gt_ == positive_label else 0 for gt_ in np.array(sub_group['gt'])]
            acc, auc, ci, tpr, tnr, ppv, npv, _, _, _, thres = metrics.analysis_pred_binary(gt, pred_score, use_youden=True)
            ci = f"{ci[0]:.4f}-{ci[1]:.4f}"
            metric_results.append([model, acc, auc, ci, tpr, tnr, ppv, npv, thres, modal, g])
            all_pred.append(pred_score)
            all_gt.append(gt)
        # ROC of this model on every cohort.
        draw_roc(all_gt, all_pred, labels=list(ug_groups), title=f"Modal: {modal}, Model: {map_mn(model)}")
        plt.savefig(f'img/{modal}_{model}_roc.svg', bbox_inches='tight')
        plt.show()
        # Collect into the cross-model summary.
        all_preds.extend(all_pred)
        all_gts.extend(all_gt)
        all_model_names.append(model)

    # ROC of every model on one cohort: with model-major storage, a stride of
    # len(ug_groups) starting at the cohort's index picks that cohort from each model.
    n_groups = len(ug_groups)
    for gi, g in enumerate(ug_groups):
        draw_roc(all_gts[gi::n_groups], all_preds[gi::n_groups],
                 labels=[map_mn(m) for m in all_model_names],
                 title=f"Modal {modal}, Cohort {g}")
        # Include the modality so figures from different modalities do not overwrite each other.
        plt.savefig(f'img/DTL_{modal}_{g}_roc.svg', bbox_inches='tight')
        plt.show()
    metrics_df = pd.DataFrame(metric_results,
                              columns=['ModelName', 'Acc', 'AUC', '95% CI', 'Sensitivity', 'Specificity', 'PPV', 'NPV',
                                       'Youden', 'Modal', 'Cohort'])
    display(metrics_df)
    metrics_dfs.append(metrics_df)
pd.concat(metrics_dfs, axis=0)

In [None]:
# Export the per-sample SimpleViT predictions of every evaluated cohort.
all_logs = pd.concat(all_log_, axis=0)
sel_log = all_logs.loc[all_logs['model'].str.contains('SimpleViT')]
export_cols = ['ID', 'label-1', 'pred_label', 'gt']
# NOTE(review): absolute, machine-specific output path — consider deriving it from the model_root setting.
export_path = 'E:/20240601-BiQiu/Pathomics/models/SimpleViT/ALL_DL_PREDICTIONS.csv'
sel_log[export_cols].to_csv(export_path, index=False)
sel_log