# 拆分数据

In [None]:
import pandas as pd
from onekey_algo import get_param_in_cwd
import numpy as np
from onekey_algo.custom.components.comp1 import merge_results

# Build the fused feature table and write one feature CSV per cohort.
task = 'FULL'
label_data = pd.read_csv(get_param_in_cwd('label_file'))
# Any group other than 'train' / 'val' (including missing values) is treated as the test cohort.
label_data['group'] = label_data['group'].map(lambda x: x if x in ['train', 'val'] else 'test')
if task == 'FULL':
    clinic_features = pd.read_csv('data/clinic_sel.csv')
    path_features = pd.read_csv('features/Path_after_lasso.csv')
    rad_features = pd.read_csv('features/Habitat_after_lasso.csv')
    # Drop bookkeeping columns ('group', 'label', and 'prob10' for the pathology
    # table) so only ID + feature columns are handed to merge_results.
    features = merge_results(
        path_features[[c for c in path_features if c not in ['group', 'label', 'prob10']]],
        rad_features[[c for c in rad_features if c not in ['group', 'label']]],
        clinic_features[[c for c in clinic_features if c not in ['group', 'label']]],
    )
    print(clinic_features.shape, path_features.shape, rad_features.shape)
    features.to_csv(f'features/fusion_{task}_features.csv', index=False)
    display(features)
else:
    # Only the FULL fusion is implemented. Without this guard `features` would be
    # undefined (NameError below) or, in a re-run notebook, silently stale.
    raise ValueError(f"Unsupported task: {task!r}")

for ug in get_param_in_cwd('subsets'):
    # IDs belonging to this cohort; the inner join keeps only their feature rows.
    sug_group = label_data[label_data['group'] == ug]
    sub_features = pd.merge(features, sug_group[['ID']], on='ID', how='inner')
    print(ug, len(np.unique(sub_features['ID'])))
    sub_features.to_csv(f'features/{ug}.csv', index=False)

In [None]:
import os
from onekey_algo.fusion.MultiTransformer.run_model import train_categorical_model as clf_main
from collections import namedtuple
from onekey_algo import get_param_in_cwd

# Set training parameters.
train = r'features/train.csv'
val = r'features/val.csv'
tests = [r'features/test.csv']
target_file = get_param_in_cwd('label_file')
# Feature count = every column of the fused table except one (presumably 'ID',
# the merge key used when splitting) — TODO confirm no other non-feature column.
input_dim = features.shape[1] - 1
bags_size = 1
normalize = True
header = 0
# Number of training runs launched with identical hyper-parameters.
num_runs = 100


def _build_params():
    """Return a fresh keyword dict for one run of ``clf_main``.

    Built per call so each run gets its own ``gpus`` list object; the other
    values are read from the module-level settings above at call time.
    """
    return dict(train=train,
                val=val,
                tests=tests,
                target_file=target_file,
                j=0,
                input_dim=input_dim,
                bags_size=bags_size,
                normalize=normalize,
                header=header,
                gpus=[0],
                batch_size=16,
                model_name='Transformer',
                epochs=50,
                init_lr=0.1,
                optimizer='sgd',
                model_root=os.path.join(get_param_in_cwd('radio_dir'), '../models'),
                add_date=True,
                retrain='',
                iters_start=0,
                iters_verbose=1,
                save_per_epoch=False,
                pretrained=True)


# The argument record type depends only on the (fixed) parameter names,
# so create it once instead of on every iteration.
Args = namedtuple("Args", _build_params())
for i in range(num_runs):
    params = _build_params()
    # Train the model.
    clf_main(Args(**params))

In [None]:
# Show the label rows of the cohort currently held in `subset`.
# NOTE(review): `subset` is only bound by the evaluation loop further down
# (`for subset in get_param_in_cwd('subsets')`); run top-to-bottom, this cell raises NameError.
label_data.loc[label_data['group'].eq(subset)]

In [None]:
import pandas as pd
import os
from onekey_algo.custom.components import metrics
from onekey_algo.custom.components.comp1 import draw_roc
from onekey_algo import get_param_in_cwd
import numpy as np
import matplotlib.pyplot as plt

def get_log(epoch, log_root=None):
    """Load the train and validation prediction logs written for one epoch.

    Reads ``../train/Epoch-{epoch}_spec.csv`` and ``../valid/Epoch-{epoch}_spec.csv``
    relative to the viz directory and stacks them row-wise (train rows first).

    Args:
        epoch: Epoch number embedded in the CSV file names.
        log_root: Directory the ``../train`` / ``../valid`` paths are resolved
            against. Defaults to the module-level ``root`` (the previous,
            implicit behaviour), so existing ``get_log(epoch)`` calls still work.

    Returns:
        DataFrame with columns ``['ID', 'label-0', 'label-1']``. The per-file
        row indices are kept, so index values may repeat.
    """
    # Conditional expression: the global `root` is only looked up when no root is passed.
    base = root if log_root is None else log_root
    frames = [pd.read_csv(os.path.join(base, f'../{split}/Epoch-{epoch}_spec.csv'))
              for split in ('train', 'valid')]
    log_ = pd.concat(frames, axis=0)
    # Positional rename: assumes each spec file holds exactly three columns
    # (ID + two class scores) — assignment raises otherwise.
    log_.columns = ['ID', 'label-0', 'label-1']
    return log_

# Evaluate saved Transformer prediction logs per cohort: save predictions,
# draw ROC curves and tabulate binary-classification metrics.
model_root = os.path.join(get_param_in_cwd('radio_dir'), '../models')
# NOTE(review): this reads a fixed split file rather than get_param_in_cwd('label_file')
# used when the features were split — confirm both describe the same cohorts.
label_data = pd.read_csv(r'split_info/label-RND-7.csv')
# Any group other than 'train' / 'val' (including missing values) is treated as the test cohort.
label_data['group'] = label_data['group'].map(lambda x:x if x in ['train', 'val'] else 'test')
# Score cut-off for the hard `pred_label` column (independent of the Youden threshold in the table).
pred_threshold = 0.02
# Set to True to only report runs whose val and test AUC both exceed 0.75.
filter_by_auc = False
groups = []
all_ids = set()
for s in ['201748']:
    # `root` is the module-level name get_log() resolves its CSV paths against.
    root = os.path.join(model_root, s, 'Transformer', 'viz')
    for epoch in range(42, 43):
        metrics_df = []
        all_gt = []
        all_pred = []
        # The epoch log is identical for every cohort, so read it once per epoch.
        epoch_log = get_log(epoch)
        all_ids |= set(epoch_log['ID'])
        for subset in get_param_in_cwd('subsets'):
            # Implicit join on the shared columns (expected to be just 'ID'),
            # keeping only rows labelled with this cohort.
            sub_group = pd.merge(epoch_log, label_data[label_data['group'] == subset])
            sub_group['pred_label'] = (sub_group['label-1'].astype(float) > pred_threshold).astype(int)
            sub_group['group'] = subset
            groups.append(sub_group)
            sub_group[['ID', 'label-0', 'label-1']].to_csv(f'results/Fusion_Transformer_{subset}.csv', index=False)
            all_gt.append(np.array(sub_group['label']))
            all_pred.append(np.array(sub_group['label-1']))
            acc, auc, ci, tpr, tnr, ppv, npv, _, _, _, thres = metrics.analysis_pred_binary(
                np.array(sub_group['label']), np.array(sub_group['label-1']))
            ci = f"{ci[0]:.4f}-{ci[1]:.4f}"
            metrics_df.append([acc, auc, ci, tpr, tnr, ppv, npv, thres, subset])
        draw_roc(all_gt, all_pred, labels=get_param_in_cwd('subsets'), title="Model: Transformer")
        plt.savefig('img/Fusion_Transformer_roc.svg', bbox_inches='tight')
        plt.show()
        metrics_df = pd.DataFrame(metrics_df,
                                  columns=['Acc', 'AUC', '95% CI', 'Sensitivity', 'Specificity', 'PPV', 'NPV', 'Youden', 'Cohort'])
        if not filter_by_auc:
            report = True
        else:
            auc_by_cohort = dict(zip(metrics_df['Cohort'], metrics_df['AUC']))
            report = auc_by_cohort['val'] > 0.75 and auc_by_cohort['test'] > 0.75
        if report:
            print(s, epoch)
            display(metrics_df)

In [None]:
# Stack every per-cohort prediction frame collected by the evaluation loop
# and save the ID / label / cohort columns.
# NOTE(review): `groups` accumulates across every run/epoch evaluated above,
# so IDs repeat if more than one was evaluated.
group_info = pd.concat(groups)
group_info.loc[:, ['ID', 'label', 'group']].to_csv('joinit_group.csv', index=False)
group_info

In [None]:
# Drop label rows whose ID appears in the prediction logs (all_ids) but did not
# end up in group_info, then save the remaining rows.
unmatched_ids = all_ids - set(group_info['ID'])
label_data.loc[~label_data['ID'].isin(unmatched_ids)].to_csv('group.csv', index=False)

In [None]:
# Preview: label rows minus IDs that are in the prediction logs but absent from group_info.
label_data.loc[~label_data['ID'].isin(all_ids.difference(group_info['ID']))]