In [24]:
from simpletransformers.classification import (
    MultiLabelClassificationModel,
    MultiLabelClassificationArgs,
)
import pandas as pd
import numpy
import logging
from sklearn.model_selection import train_test_split
import torch

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)
cuda_available = torch.cuda.is_available()
print(cuda_available)


True


In [26]:
# 读取数据至pandas
df = pd.read_excel("./data/信立泰 高血压-诊断定义.xlsx")
# df.drop(columns=["备注"], inplace=True)
df.replace(r"[^\u4e00-\u9fa5]", "", regex=True, inplace=True)  # 去除所有非中文
df["原始诊断"].fillna("", inplace=True)
df


Unnamed: 0,序号,原始诊断,处方张数,高血压,冠心病,糖尿病,血脂异常,卒中,慢性肾病,心力衰竭,高尿酸
0,1,高血压,81364,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,2,无诊断,57580,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,3,高血压病,31515,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,4,冠状动脉粥样硬化性心脏病,26092,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
4,5,冠心病,8964,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...
350761,350759,左足趾感染,1,,,,,,,,
350762,350760,左足趾骨骨折,1,,,,,,,,
350763,350762,坐骨神经痛高血压,1,,,,,,,,
350764,350763,坐骨神经痛腰椎病,1,,,,,,,,


In [27]:
# 准备数据格式

df_label = df[df["高血压"].notna()]  # Filter出打过标签的行
train_df = pd.DataFrame()
train_df["text"] = df_label["原始诊断"]
df_label = df_label.iloc[:, 3:].astype("int")
train_df["labels"] = df_label.values.tolist()  # 将多列标签数据转换为单列list
df_label.sum()


高血压     308
冠心病     224
糖尿病     103
血脂异常     89
卒中       78
慢性肾病     89
心力衰竭     64
高尿酸      28
dtype: int64

In [28]:
# 训练集和验证集分离
train_df, eval_df = train_test_split(train_df, test_size=0.2, random_state=1)
train_df, eval_df


(                                                    text  \
 27950                        膜性肾病肾病综合征高血压病级极高危银屑病低蛋白血症高甘   
 2263   肿瘤指标异常混合性痴呆高血压级很高危组型糖尿病胆囊切除术后状态慢性肾功能不全周围动脉硬化闭塞...   
 121                                                 肺部阴影   
 49                                               不稳定型心绞痛   
 305                               冠状动脉粥样硬化性心脏病高脂血症原发性高血压   
 ...                                                  ...   
 144                                                肾移植术后   
 69481            动脉粥样硬化并高脂血症高血压关节痛冠心病脑血管病前列腺增生头晕重度骨质疏松失眠   
 72                                                   脑出血   
 235                                                 重症肺炎   
 37                                                  高血压级   
 
                          labels  
 27950  [1, 0, 0, 1, 0, 0, 0, 0]  
 2263   [1, 0, 1, 0, 0, 1, 0, 1]  
 121    [0, 0, 0, 0, 0, 0, 0, 0]  
 49     [0, 1, 0, 0, 0, 0, 0, 0]  
 305    [1, 1, 0, 1, 0, 0, 0, 0]  
 ...                         ...  
 144    [0, 0, 0, 0, 

In [34]:
# 准备模型可选参数
model_args = MultiLabelClassificationArgs(num_train_epochs=50, overwrite_output_dir=True)


In [35]:
# 根据预训练模型生成模型
model = MultiLabelClassificationModel(
    "bert",
    "./data/chinese_wwm_ext_pytorch",
    num_labels=8,
    use_cuda=cuda_available,
    args=model_args,
)


Some weights of the model checkpoint at ./data/chinese_wwm_ext_pytorch were not used when initializing BertForMultiLabelSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMultiLabelSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMultiLabelSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMultiLabelSequenceClassificat

In [36]:
# 训练模型
model.train_model(train_df)


INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
100%|██████████| 2/2 [00:02<00:00,  1.46s/it]
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_train_bert_128_0_546
Epoch:   0%|          | 0/50 [00:00<?, ?it/s]INFO:simpletransformers.classification.classification_model:   Starting fine-tuning.
Epochs 0/50. Running Loss:    0.6102: 100%|██████████| 69/69 [00:06<00:00,  9.89it/s]
Epochs 1/50. Running Loss:    0.2255: 100%|██████████| 69/69 [00:06<00:00,  9.98it/s]
Epochs 2/50. Running Loss:    0.1966: 100%|██████████| 69/69 [00:06<00:00, 10.07it/s]
Epochs 3/50. Running Loss:    0.0696: 100%|██████████| 69/69 [00:06<00:00, 10.13it/s]
Epochs 4/50. Running Loss:    0.0323: 100%|██████████| 69/69 [00:06<00:00, 10.06it/s]
Epochs 5/50. Running Loss:    0.3323: 100%|██████████| 69/69 [00:06<00:00, 10.24it/s]
Epochs 6/50. Running Loss:    0.0175: 100%|██████████| 69/69 [00

(3450, 0.04051075134449976)

In [37]:
# 验证模型
result, model_outputs, wrong_predictions = model.eval_model(eval_df)
result


INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_128_0_137
Running Evaluation: 100%|██████████| 18/18 [00:01<00:00, 14.46it/s]


{'LRAP': 0.9777256401343993, 'eval_loss': 0.14981819308983782}

In [38]:
# 使用模型进行预测
df_nolabel = df[~df["高血压"].notna()]  #  没有人为标签过的数据
pred_text = df_nolabel["原始诊断"].sample(n=10).values.tolist()  #  随机抽取10条数据的原始诊断字段
predictions, raw_outputs = model.predict(pred_text)

labels = df_label.columns.values.tolist()
for pred in predictions:
    for i in range(len(pred)):
        if pred[i] == 1:
            pred[i] = labels[i]

pred_text, predictions


INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_128_0_10
100%|██████████| 2/2 [00:00<00:00, 32.34it/s]


(['慢性肾衰竭腹膜透析评估',
  '脑梗死大脑前动脉狭窄高血压级',
  '肾移植术后高血压肝功能异常血脂异常',
  '高血压高脂血症糖耐量受损高尿酸血症冠状动脉粥样硬化性心脏病',
  '高血压原因待查低钾血症应检尽检核酸检测',
  '高血压级肾功能不全蛋白尿冠状动脉粥样硬化性心脏病高尿酸血症',
  '肾病综合征肾性骨营养不良肾性骨病高血压高脂血症',
  '高血压碘治疗后',
  '不稳定性心绞痛二度房室传导阻滞高血压级冠状动脉支架植入术后状态冠状动脉粥样硬化性心脏病心功能级分级',
  '动脉粥样硬化并高脂血症脑梗塞痤疮'],
 [[0, 0, 0, 0, 0, '慢性肾病', 0, 0],
  ['高血压', 0, 0, 0, '卒中', 0, 0, 0],
  ['高血压', 0, 0, '血脂异常', 0, '慢性肾病', 0, 0],
  ['高血压', '冠心病', '糖尿病', '血脂异常', 0, 0, 0, 0],
  ['高血压', 0, 0, 0, 0, 0, 0, 0],
  ['高血压', 0, '糖尿病', 0, 0, '慢性肾病', 0, '高尿酸'],
  ['高血压', 0, 0, '血脂异常', 0, '慢性肾病', 0, 0],
  ['高血压', 0, 0, 0, 0, 0, 0, 0],
  ['高血压', '冠心病', 0, 0, 0, 0, '心力衰竭', 0],
  [0, '冠心病', 0, '血脂异常', '卒中', 0, 0, 0]])

In [16]:
# 使用模型进行原始数据所有的预测
list_info = []
list_pred = []
labels = df_label.columns.values.tolist()
for index, row in df.head(1000).iterrows():
    pred_text = row["原始诊断"]
    if pd.isna((pred_text)):  # 如果原始诊断为na则标签均为0
        list_pred.append([0] * 8)
    elif pd.isna(row["高血压"]):  # 如果没有人为标记的条目使用模型预测
        predictions, raw_outputs = model.predict([pred_text])
        list_pred.append(predictions[0])
    else:  # 反之，则保留原人为标记结果
        list_pred.append(row[3:].tolist())
    list_info.append(row[:3])

df_pred = pd.DataFrame(list_pred, columns=labels)
df_info = pd.DataFrame(list_info, columns=["序号", "原始诊断", "处方张数"])
df_combined = pd.concat([df_info, df_pred], axis=1)
df_combined


INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_128_0_1
100%|██████████| 1/1 [00:00<00:00, 32.34it/s]
INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_128_0_1
100%|██████████| 1/1 [00:00<00:00, 71.63it/s]
INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_128_0_1
100%|██████████| 1/1 [00:00<00:00, 71.46it/s]
INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classi

Unnamed: 0,序号,原始诊断,处方张数,高血压,冠心病,糖尿病,血脂异常,卒中,慢性肾病,心力衰竭,高尿酸
0,1,高血压,81364,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,2,无诊断,57580,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,3,高血压病,31515,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,4,冠状动脉粥样硬化性心脏病,26092,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
4,5,冠心病,8964,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...
995,994,上感,55,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
996,998,高尿酸血症高血压高脂血症,55,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0
997,991,膝关节痛,55,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
998,996,多囊肾,55,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0


In [39]:
# 将模型预测应用到所有数据

pred_text = df["原始诊断"].values.tolist()
predictions, raw_outputs = model.predict(pred_text)
pred_text, predictions


INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.
INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_128_0_350766
100%|██████████| 43846/43846 [16:37<00:00, 43.95it/s]


(['高血压',
  '无诊断',
  '高血压病',
  '冠状动脉粥样硬化性心脏病',
  '冠心病',
  '脑梗死',
  '糖尿病',
  '原发性高血压',
  '型糖尿病',
  '高血压级',
  '尿毒症',
  '高血压级',
  '高血压病',
  '高血压高脂血症',
  '高脂血症',
  '冠状动脉性心脏病',
  '心律失常',
  '肾病综合征',
  '心脏病',
  '冠心病心梗',
  '原发性高血压',
  '肾病',
  '冠状动脉支架植入后状态',
  '不稳定性心绞痛',
  '良性高血压',
  '脑血管病',
  '特发性原发性高血压',
  '心力衰竭',
  '高血压冠心病',
  '肺部感染',
  '心房颤动',
  '高血压级',
  '膜性肾病',
  '慢性肾炎综合征',
  '脑梗死后遗症',
  '肺炎',
  '冠状动脉粥样硬化',
  '高血压级',
  '高血压冠状动脉粥样硬化性心脏病',
  '心功能不全',
  '蛋白尿',
  '偏瘫脑梗死高血压高血脂冠心病',
  '糖尿病高血压',
  '慢性肾脏病期',
  '系统性红斑狼疮',
  '高血压糖尿病',
  '缺血性脑血管病',
  '慢性肾小球肾炎',
  '急性心肌梗死',
  '不稳定型心绞痛',
  '高尿酸血症',
  '冠心病高血压',
  '高血压病高脂血症',
  '头晕',
  '高血压冠状动脉粥样硬化性心脏病高脂血症',
  '冠状动脉粥样硬化性心脏病高血压',
  '陈旧性心肌梗死',
  '冠状动脉粥样硬化性心脏病',
  '高血压级',
  '急性冠脉综合征',
  '型糖尿病不伴有并发症',
  '脑梗塞',
  '糖尿病伴多个并发症',
  '慢性肾衰竭',
  '头晕和眩晕',
  '慢性肾功能不全',
  '高血压病冠心病',
  '睡眠障碍',
  '精神分裂症',
  '肾移植状态',
  '腔隙性脑梗死',
  '高血压病级极高危',
  '脑出血',
  '高血压高脂血症动脉粥样硬化',
  '后循环缺血',
  '糖尿病伴并发症其他特指的',
  '高血压慢性病',
  '脑血管病后遗症',
  '慢性肾炎',
  '腹膜透析',
  '冠状动脉粥样硬化性心脏病冠状动脉支架植入后状态',


In [40]:
labels = df_label.columns.values.tolist()
df_info = df.iloc[:, :3]
df_pred = pd.DataFrame(predictions, columns=labels)
df_combined = pd.concat([df_info, df_pred], axis=1)
df_combined.to_csv("./labeled_data.csv", index=False, encoding="utf_8_sig")
df_combined


Unnamed: 0,序号,原始诊断,处方张数,高血压,冠心病,糖尿病,血脂异常,卒中,慢性肾病,心力衰竭,高尿酸
0,1,高血压,81364,1,0,0,0,0,0,0,0
1,2,无诊断,57580,0,0,0,0,0,0,0,0
2,3,高血压病,31515,1,0,0,0,0,0,0,0
3,4,冠状动脉粥样硬化性心脏病,26092,0,1,0,0,0,0,0,0
4,5,冠心病,8964,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...
350761,350759,左足趾感染,1,0,0,0,0,0,0,0,0
350762,350760,左足趾骨骨折,1,0,0,0,0,0,0,0,0
350763,350762,坐骨神经痛高血压,1,1,0,0,0,0,0,0,0
350764,350763,坐骨神经痛腰椎病,1,0,0,0,0,0,0,0,0
