In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import jieba
import torch
import pickle
import torch.nn as nn
import torch.optim as optim
import pandas as pd

# import sys
# sys.path.append('/home/shencj/workspace/code/nlp/Frame/ark-nlp/')

from ark_nlp.nn import Ernie
from ark_nlp.dataset import TMDataset
from ark_nlp.factory.task import TMTask, SequenceClassificationTask
from ark_nlp.factory.optimizer import get_default_bert_optimizer
from ark_nlp.processor.tokenizer.transfomer import PairTokenizer

In [2]:
import random
import numpy as np

In [3]:
def set_seed(seed):
    """
    Seed every random number generator this notebook relies on.

    :param seed: value passed unchanged to each seeding function

    :return: None
    """
    # Same four calls as before, in the same order: stdlib `random`,
    # `torch.manual_seed`, NumPy's global RNG, `torch.cuda.manual_seed_all`.
    seeders_ = (
        random.seed,
        torch.manual_seed,
        np.random.seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed_ in seeders_:
        apply_seed_(seed)

In [4]:
# Seed all RNGs up front so the stochastic steps below start from a fixed state.
set_seed(2021)

In [5]:
# Data locations.
# The dataset directory defaults to the path used by the original author, but
# can be overridden with the CHIP2021_TASK1_DIR environment variable so the
# notebook runs on other machines without editing this cell.
data_dir = os.environ.get(
    'CHIP2021_TASK1_DIR',
    '/home/shencj/workspace/data/medical/CHIP2021/Task1',
)

train_data_path = os.path.join(data_dir, 'train.jsonl')
test_data_path = os.path.join(data_dir, 'testa.txt')
# Relative path: resolved against the current working directory.
submit_data_path = 'submit.txt'

### 一、数据读入与处理

#### 1. 数据读入

In [6]:
import numpy as np
import pandas as pd
import copy
import json
import codecs

# from utils import get_entity_bios
from ark_nlp.dataset import BaseDataset


def get_task_data(data_path):
    """
    Read a JSON-lines dialogue file and flatten it to one row per labelled entity.

    Each non-blank line is parsed as a JSON object with a 'dialog_id' and a
    'dialog_info' list of turns; every turn has a 'text' and a 'ner' list whose
    items carry 'mention', 'range' and 'attr'.

    :param data_path: path to the UTF-8 encoded JSON-lines file

    :return: DataFrame with columns text_a, text_b, start_idx, end_idx,
             label_b, label, dialog_id (empty, but with these columns, when
             the file yields no labelled entity)
    """
    columns_ = ['text_a', 'text_b', 'start_idx', 'end_idx', 'label_b', 'label', 'dialog_id']
    data_list = []

    # Iterate the file directly: the previous `f.readlines(f)` passed the file
    # object as the size hint and only worked because codecs ignores it.
    with open(data_path, mode='r', encoding='utf8') as f:
        for line_ in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line_.strip():
                continue

            dialogue_ = json.loads(line_)
            turns_ = dialogue_['dialog_info']

            for content_idx_, contents_ in enumerate(turns_):
                terms_ = contents_['ner']
                if len(terms_) == 0:
                    continue

                # Context = previous turn + current turn + next turn (when they
                # exist). It is the same for every entity of the turn, so it is
                # built once here.
                text_a_ = contents_['text']
                if content_idx_ + 1 < len(turns_):
                    text_a_ = text_a_ + turns_[content_idx_ + 1]['text']
                if content_idx_ - 1 >= 0:
                    text_a_ = turns_[content_idx_ - 1]['text'] + text_a_

                for term_ in terms_:
                    # Entities without an attribute label are not usable as
                    # training samples.
                    if term_['attr'] == '':
                        continue

                    data_list.append({
                        'dialog_id': dialogue_['dialog_id'],
                        'text_a': text_a_,
                        'text_b': term_['mention'],
                        'start_idx': term_['range'][0],
                        # range[1] - 1: presumably turns an exclusive end offset
                        # into an inclusive one — TODO confirm against the data spec.
                        'end_idx': term_['range'][1] - 1,
                        'label_b': term_['mention'],
                        'label': term_['attr'],
                    })

    # Passing `columns` fixes the column order and keeps the schema even when
    # no row was collected (the former `.loc` selection raised KeyError then).
    return pd.DataFrame(data_list, columns=columns_)

In [7]:
# Flatten the training file into one row per labelled entity.
data_df = get_task_data(train_data_path)

In [8]:
# Sanity check: (rows, columns) of the flattened training frame.
data_df.shape

(118976, 7)

In [9]:
from sklearn.model_selection import train_test_split

# Split by dialogue id rather than by row, so that all entities of one dialogue
# land on the same side of the split.
# The ids are sorted first: the iteration order of a `set` of strings is not
# stable across interpreter runs (hash randomisation), which would make the
# split irreproducible even with a fixed seed.
dialog_ids = sorted(set(data_df['dialog_id'].tolist()))

# An explicit random_state keeps the split identical when this cell is re-run.
X_train, X_dev = train_test_split(dialog_ids, random_state=2021)

# `isin` performs hashed membership tests; the previous per-row `x in list`
# lookup scanned the whole id list for every row.
# `reset_index` is assigned instead of applied in place, so no filtered slice
# is mutated.
train_data_df = data_df[data_df['dialog_id'].isin(X_train)].reset_index(drop=True)
dev_data_df = data_df[data_df['dialog_id'].isin(X_dev)].reset_index(drop=True)

In [10]:
# Wrap the two frames in ark_nlp TMDataset objects (train first, then dev).
tm_train_dataset, tm_dev_dataset = (
    TMDataset(train_data_df),
    TMDataset(dev_data_df),
)

#### 2. 词典创建和生成分词器

In [11]:
# 可以先创建词典，再加载入分词器
# 也可以使用分词器自动加载
# bert_vocab = transformers.AutoTokenizer.from_pretrained('nghuyong/ernie-1.0')
# tokenizer = TransfomerTokenizer(bert_vocab, max_seq_len=30)

In [12]:
# Build the PairTokenizer from the 'nghuyong/ernie-1.0' vocabulary.
# NOTE(review): max_seq_len=180 is presumably the token budget for the combined
# text_a/text_b input — confirm against the ark_nlp PairTokenizer documentation.
tokenizer = PairTokenizer(vocab='nghuyong/ernie-1.0', max_seq_len=180)

#### 3. ID化

In [13]:
# Run convert_to_ids with the shared tokenizer on both splits (train, then dev).
for dataset_ in (tm_train_dataset, tm_dev_dataset):
    dataset_.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [14]:
from transformers import BertConfig

# Load the pretrained 'nghuyong/ernie-1.0' configuration and set the number of
# output labels to the size of the training set's cat2id mapping.
bert_config = BertConfig.from_pretrained('nghuyong/ernie-1.0', 
                                         num_labels=len(tm_train_dataset.cat2id))

#### 2. 模型创建

In [15]:
# Ask PyTorch to release cached GPU memory it is not using before the model is loaded.
torch.cuda.empty_cache()

In [16]:
# Instantiate the Ernie module from the pretrained 'nghuyong/ernie-1.0' checkpoint
# using the config built above. As the log output of this cell reports,
# classifier.weight / classifier.bias are newly initialised and must be trained.
dl_module = Ernie.from_pretrained('nghuyong/ernie-1.0',
                                  config=bert_config)

Some weights of the model checkpoint at nghuyong/ernie-1.0 were not used when initializing Ernie: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing Ernie from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Ernie from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Ernie were not initialized from the model checkpoint at nghuyong/ernie-1.0 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this mode

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [17]:
# Training hyper-parameters: number of epochs and batch size.
# NOTE(review): `num_epoches` is not referenced by the visible `model.fit` call,
# which hard-codes `epochs=1` — confirm which value is intended.
num_epoches = 10
batch_size = 32

In [18]:
# Optimizer for the Ernie module, built by ark_nlp's get_default_bert_optimizer.
optimizer = get_default_bert_optimizer(dl_module) 

#### 2. 任务创建

In [19]:
# Bundle the module and optimizer into an ark_nlp SequenceClassificationTask.
# NOTE(review): 'ce' presumably selects cross-entropy loss, and cuda_device=1
# presumably pins training to GPU index 1 (which would need at least two GPUs)
# — TODO confirm against the ark_nlp task documentation.
model = SequenceClassificationTask(dl_module, optimizer, 'ce', cuda_device=1)

#### 3. 训练

In [20]:
# Train on the training split and evaluate on the dev split.
# NOTE(review): epochs is hard-coded to 1 here rather than taken from
# `num_epoches`; the recorded output below corresponds to this single epoch.
model.fit(tm_train_dataset, 
          tm_dev_dataset,
          lr=2e-5,
          epochs=1, 
          batch_size=batch_size
         )

  4%|▎         | 100/2797 [00:46<21:02,  2.14it/s]

[99/2797],train loss is:0.922122


  7%|▋         | 200/2797 [01:33<20:25,  2.12it/s]

[199/2797],train loss is:0.826494


 11%|█         | 300/2797 [02:21<19:43,  2.11it/s]

[299/2797],train loss is:0.768934


 14%|█▍        | 400/2797 [03:08<18:58,  2.11it/s]

[399/2797],train loss is:0.728550


 18%|█▊        | 500/2797 [03:56<18:14,  2.10it/s]

[499/2797],train loss is:0.698540


 21%|██▏       | 600/2797 [04:44<17:26,  2.10it/s]

[599/2797],train loss is:0.679917


 25%|██▌       | 700/2797 [05:31<16:41,  2.09it/s]

[699/2797],train loss is:0.662679


 29%|██▊       | 800/2797 [06:19<16:02,  2.08it/s]

[799/2797],train loss is:0.651655


 32%|███▏      | 900/2797 [07:07<15:23,  2.05it/s]

[899/2797],train loss is:0.638450


 36%|███▌      | 1000/2797 [07:56<14:37,  2.05it/s]

[999/2797],train loss is:0.630506


 39%|███▉      | 1100/2797 [08:44<13:31,  2.09it/s]

[1099/2797],train loss is:0.622484


 43%|████▎     | 1200/2797 [09:32<12:48,  2.08it/s]

[1199/2797],train loss is:0.613460


 46%|████▋     | 1300/2797 [10:20<12:18,  2.03it/s]

[1299/2797],train loss is:0.605594


 50%|█████     | 1400/2797 [11:09<11:14,  2.07it/s]

[1399/2797],train loss is:0.600252


 54%|█████▎    | 1500/2797 [11:57<10:22,  2.08it/s]

[1499/2797],train loss is:0.593298


 57%|█████▋    | 1600/2797 [12:45<09:35,  2.08it/s]

[1599/2797],train loss is:0.588580


 61%|██████    | 1700/2797 [13:34<08:46,  2.08it/s]

[1699/2797],train loss is:0.584810


 64%|██████▍   | 1800/2797 [14:23<08:10,  2.03it/s]

[1799/2797],train loss is:0.581599


 68%|██████▊   | 1900/2797 [15:12<07:09,  2.09it/s]

[1899/2797],train loss is:0.578146


 72%|███████▏  | 2000/2797 [16:01<06:32,  2.03it/s]

[1999/2797],train loss is:0.575369


 75%|███████▌  | 2100/2797 [16:50<05:36,  2.07it/s]

[2099/2797],train loss is:0.571160


 79%|███████▊  | 2200/2797 [17:38<04:46,  2.08it/s]

[2199/2797],train loss is:0.567543


 82%|████████▏ | 2300/2797 [18:26<03:58,  2.09it/s]

[2299/2797],train loss is:0.564077


 86%|████████▌ | 2400/2797 [19:15<03:23,  1.95it/s]

[2399/2797],train loss is:0.561753


 89%|████████▉ | 2500/2797 [20:05<02:27,  2.01it/s]

[2499/2797],train loss is:0.559353


 93%|█████████▎| 2600/2797 [20:53<01:34,  2.08it/s]

[2599/2797],train loss is:0.556748


 97%|█████████▋| 2700/2797 [21:41<00:46,  2.08it/s]

[2699/2797],train loss is:0.554015


100%|██████████| 2797/2797 [22:28<00:00,  2.07it/s]


epoch:[0],train loss is:0.551978 

classification_report: 
               precision    recall  f1-score   support

         不标注       0.65      0.73      0.69      5731
          其他       0.47      0.29      0.36      1737
          阳性       0.88      0.85      0.87     18605
          阴性       0.62      0.74      0.68      3430

    accuracy                           0.78     29503
   macro avg       0.66      0.65      0.65     29503
weighted avg       0.78      0.78      0.78     29503

confusion_matrix_: 
 [[ 4169    96  1237   229]
 [  237   504   420   576]
 [ 1766   279 15824   736]
 [  213   199   468  2550]]
test loss is:0.531644,test acc is:0.781175,f1_score is:0.647538


In [21]:
# model.fit(tm_train_dataset, 
#           tm_dev_dataset,
#           lr=2e-5,
#           epochs=1, 
#           batch_size=batch_size
#          )

<br>

### 四、模型验证与保存

#### 1. 模型验证

In [22]:
from ark_nlp.factory.predictor import TMPredictor

In [23]:
# Build a predictor from the trained module, the tokenizer used for training,
# and the training set's category-to-id mapping.
tm_predictor_instance = TMPredictor(model.module, tokenizer, tm_train_dataset.cat2id)

In [24]:
# Smoke-test the predictor on one hand-written [context, mention] pair.
# With return_proba=True the result is a list of (label, probability) pairs,
# as shown in this cell's recorded output.
tm_predictor_instance.predict_one_sample(['医生:38.5摄氏度以上需要药物降温;医生:打得什么药物？;患者:没看;医生:不是所有生病发烧就非要输液打针;患者:打屁股的;医生:要弄清楚原因才能用药;', '38.5'], 
                                         return_proba=True)

[('不标注', 0.5857797265052795),
 ('阳性', 0.40820521116256714),
 ('其他', 0.0032178503461182117),
 ('阴性', 0.0027971782255917788)]

#### 2. 测试结果输出

In [25]:
from tqdm import tqdm

In [26]:
# Predict an attribute for every annotated entity of the test set and collect
# the updated dialogues, in file order, for the submission file.
submit_result = []

# Read the file by iteration: the previous `f.readlines(f)` passed the file
# object as the size hint and only worked because codecs ignores it.
# Blank lines are dropped so a trailing newline cannot crash json.loads.
# A list is kept (rather than streaming) so tqdm can display the total.
with open(test_data_path, mode='r', encoding='utf8') as f:
    reader = [line_ for line_ in f if line_.strip()]

for dialogue_ in tqdm(reader):
    dialogue_ = json.loads(dialogue_)
    turns_ = dialogue_['dialog_info']

    for content_idx_, contents_ in enumerate(turns_):
        terms_ = contents_['ner']
        if len(terms_) == 0:
            continue

        # Same context window as in training: previous + current + next turn
        # (when they exist). Built once per turn, not once per entity.
        text_a_ = contents_['text']
        if content_idx_ + 1 < len(turns_):
            text_a_ = text_a_ + turns_[content_idx_ + 1]['text']
        if content_idx_ - 1 >= 0:
            text_a_ = turns_[content_idx_ - 1]['text'] + text_a_

        for term_ in terms_:
            # `term_` is the dict stored inside `dialogue_`, so writing to it
            # updates the record appended to `submit_result`. The existing
            # 'attr' value is no longer read, so test files without that key
            # do not raise KeyError.
            term_['attr'] = tm_predictor_instance.predict_one_sample([text_a_, term_['mention']])[0]

    submit_result.append(dialogue_)

100%|██████████| 2000/2000 [06:38<00:00,  5.02it/s]


In [27]:
# Write one JSON object per line. The encoding is set explicitly because
# ensure_ascii=False emits raw non-ASCII characters; without it the file would
# use the platform's default encoding and could fail or be mis-encoded.
with open(submit_data_path, 'w', encoding='utf8') as output_data:
    for json_content in submit_result:
        output_data.write(json.dumps(json_content, ensure_ascii=False) + '\n')