In [2]:
!pip install ark_nlp

Collecting ark_nlp
  Downloading ark-nlp-0.0.9.tar.gz (78 kB)
[K     |████████████████████████████████| 78 kB 2.7 kB/s eta 0:00:03
Collecting zhon>=1.1.5
  Downloading zhon-1.1.5.tar.gz (99 kB)
[K     |████████████████████████████████| 99 kB 3.8 kB/s eta 0:00:02
Building wheels for collected packages: ark-nlp, zhon
  Building wheel for ark-nlp (setup.py) ... [?25ldone
[?25h  Created wheel for ark-nlp: filename=ark_nlp-0.0.9-py3-none-any.whl size=174427 sha256=35880c32e3563dde7380ace8391926dafaf057a0cacb5b0651801f1ce8cf1509
  Stored in directory: /root/.cache/pip/wheels/63/4c/22/b3283dcd140244b3e8527bbd2769ea9b06ac4e5d0ddeaf4381
  Building wheel for zhon (setup.py) ... [?25ldone
[?25h  Created wheel for zhon: filename=zhon-1.1.5-py3-none-any.whl size=84292 sha256=b539b80f7d7b0e53332cd40450c7332a27810410a4ecf0523420e1b1c4a0bc5e
  Stored in directory: /root/.cache/pip/wheels/d0/56/17/2675c4c7413a72bf173062e8d0a16503e3b2607745aa84988d
Successfully built ark-nlp zhon
Installing collec

In [3]:
import warnings
warnings.filterwarnings("ignore")

import os
import json
import jieba
import torch
import pickle
import codecs
import torch.nn as nn
import torch.optim as optim
import pandas as pd

from ark_nlp.model.re.casrel_bert import CasRelBert
from ark_nlp.model.re.casrel_bert import CasRelBertConfig
from ark_nlp.model.re.casrel_bert import Dataset
from ark_nlp.model.re.casrel_bert import Task
from ark_nlp.model.re.casrel_bert import get_default_model_optimizer
from ark_nlp.model.re.casrel_bert import Tokenizer
from ark_nlp.factory.loss_function import CasRelLoss

In [4]:
# Directory paths: CMeIE development and training splits (one JSON object per line)

dev_data_path = '../mydata/data_origin/220602_0902-cblue-nlp-医疗nlp打榜/CMeIE/CMeIE_dev.jsonl'
train_data_path = '../mydata/data_origin/220602_0902-cblue-nlp-医疗nlp打榜/CMeIE/CMeIE_train.jsonl'

### 一、数据读入与处理

#### 1. 数据读入

In [5]:
# Load the training split into a DataFrame with columns 'text' and 'label'.
# Each label entry is
#   [subject, subject_start, subject_end, 'predicate@object_type', object, object_start, object_end]
# where start/end are inclusive character offsets of the FIRST occurrence of the span in the text.
train_data_list = []

# Built-in open() in text mode splits records only on newline sequences, which is what a
# one-JSON-object-per-line file needs.
with open(train_data_path, mode='r', encoding='utf8') as f:
    for line_ in f:
        line_ = line_.strip()
        if not line_:
            # Tolerate blank lines (e.g. a trailing empty line) instead of failing in json.loads.
            continue
        sample_ = json.loads(line_)
        text_ = sample_['text']
        labels_ = []
        for triple_ in sample_['spo_list']:
            subject_ = triple_['subject']
            object_ = triple_['object']['@value']
            # Locate each span once; str.index raises ValueError if the span is not in the text.
            subject_start_ = text_.index(subject_)
            object_start_ = text_.index(object_)
            labels_.append([
                subject_,
                subject_start_,
                subject_start_ + len(subject_) - 1,
                triple_['predicate'] + '@' + triple_['object_type']['@value'],
                object_,
                object_start_,
                object_start_ + len(object_) - 1,
            ])
        train_data_list.append({'text': text_, 'label': labels_})

train_df = pd.DataFrame(train_data_list)

In [6]:
# Load the development split into a DataFrame with columns 'text' and 'label'.
# Each label entry is
#   [subject, subject_start, subject_end, 'predicate@object_type', object, object_start, object_end]
# where start/end are inclusive character offsets of the FIRST occurrence of the span in the text.
dev_data_list = []
counter = 0  # total number of triples read from the development split

# Built-in open() in text mode splits records only on newline sequences, which is what a
# one-JSON-object-per-line file needs.
with open(dev_data_path, mode='r', encoding='utf8') as f:
    for line_ in f:
        line_ = line_.strip()
        if not line_:
            # Tolerate blank lines (e.g. a trailing empty line) instead of failing in json.loads.
            continue
        sample_ = json.loads(line_)
        text_ = sample_['text']
        labels_ = []
        for triple_ in sample_['spo_list']:
            subject_ = triple_['subject']
            object_ = triple_['object']['@value']
            # Locate each span once; str.index raises ValueError if the span is not in the text.
            subject_start_ = text_.index(subject_)
            object_start_ = text_.index(object_)
            labels_.append([
                subject_,
                subject_start_,
                subject_start_ + len(subject_) - 1,
                triple_['predicate'] + '@' + triple_['object_type']['@value'],
                object_,
                object_start_,
                object_start_ + len(object_) - 1,
            ])
            counter += 1
        dev_data_list.append({'text': text_, 'label': labels_})

dev_df = pd.DataFrame(dev_data_list)

In [7]:
# Wrap both frames in Dataset objects. The dev dataset reuses the training dataset's
# `categories` and is constructed with is_train=False.
re_train_dataset = Dataset(train_df)
re_dev_dataset = Dataset(dev_df, categories=re_train_dataset.categories, is_train=False)

#### 2. 词典创建和生成分词器

In [8]:
# Tokenizer built from the 'nghuyong/ernie-1.0' vocabulary with max_seq_len=128.
tokenizer = Tokenizer(vocab='nghuyong/ernie-1.0', max_seq_len=128)

Downloading:   0%|          | 0.00/62.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/420 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/89.0k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

#### 3. ID化

In [9]:
# Convert both datasets to ids with the shared tokenizer (training set first, then dev set).
for dataset_ in (re_train_dataset, re_dev_dataset):
    dataset_.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [10]:
# num_labels is set to the number of entries in the training dataset's cat2id mapping.
bert_config = CasRelBertConfig.from_pretrained(
    'nghuyong/ernie-1.0',
    num_labels=len(re_train_dataset.cat2id),
)

#### 2. 模型创建

In [11]:
# Instantiate CasRelBert from the 'nghuyong/ernie-1.0' checkpoint using the config built above.
dl_module = CasRelBert.from_pretrained(
    'nghuyong/ernie-1.0',
    config=bert_config,
)

Downloading:   0%|          | 0.00/383M [00:00<?, ?B/s]

Some weights of the model checkpoint at nghuyong/ernie-1.0 were not used when initializing CasRelBert: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing CasRelBert from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CasRelBert from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of CasRelBert were not initialized from the model checkpoint at nghuyong/ernie-1.0 and are newly initialized: ['obj_heads_linear.weight', 'sub_heads_linear.weight', '

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [12]:
# Use the library's default optimizer for this module.
optimizer = get_default_model_optimizer(dl_module)

#### 2. 任务创建

In [13]:
# Prefer the second GPU (index 1, as originally hard-coded) when it exists; otherwise fall back
# to index 0 so the notebook also runs on single-GPU machines.
# NOTE(review): assumes Task interprets cuda_device as a CUDA device index — confirm against ark_nlp.
cuda_device = 1 if torch.cuda.device_count() > 1 else 0
model = Task(dl_module, optimizer, CasRelLoss(), cuda_device=cuda_device)

#### 3. 训练

In [14]:
# Train on the training dataset, passing the dev dataset as the second positional argument:
# learning rate 4e-5, 20 epochs, batch size 16.
model.fit(re_train_dataset, re_dev_dataset, lr=4e-5, epochs=20, batch_size=16)

[99/897],train loss is:10.906454
[199/897],train loss is:6.213885
[299/897],train loss is:4.473707
[399/897],train loss is:3.559849
[499/897],train loss is:2.998150
[599/897],train loss is:2.614500
[699/897],train loss is:2.337759
[799/897],train loss is:2.126076
epoch:[0],train loss is:1.965480 

pred_triples:  set()
gold_triples:  {('急性胰腺炎', '影像学检查@检查', 'ERCP')}
pred_triples:  set()
gold_triples:  {('广泛性焦虑症', '鉴别诊断@疾病', '社交性焦虑'), ('广泛性焦虑症', '鉴别诊断@疾病', '分离性焦虑')}
pred_triples:  set()
gold_triples:  {('骨性关节炎', '发病部位@部位', '关节'), ('骨性关节炎', '发病部位@部位', '踝关节'), ('骨性关节炎', '病因@社会学', '创伤'), ('骨性关节炎', '发病部位@部位', '腕关节'), ('骨性关节炎', '病因@社会学', '结晶性关节病')}
pred_triples:  set()
gold_triples:  {('胆囊穿孔', '死亡率@流行病学', '30%')}
pred_triples:  set()
gold_triples:  {('乙型肝炎', '预防@其他', '不献血'), ('乙型肝炎', '预防@其他', '不捐献器官或精液')}
pred_triples:  set()
gold_triples:  {('心绞痛', '药物治疗@药物', 'β受体阻滞剂')}
pred_triples:  set()
gold_triples:  {('感染性心内膜炎', '病因@社会学', '链球菌')}
pred_triples:  set()
gold_triples:  {('稳定型缺血性心脏疾病', '药物治疗

[799/897],train loss is:0.151825
epoch:[5],train loss is:0.151835 

pred_triples:  {('梗阻性胆总管结石', '影像学检查@检查', 'ERCP'), ('胆汁性急性胰腺炎', '影像学检查@检查', 'ERCP')}
gold_triples:  {('急性胰腺炎', '影像学检查@检查', 'ERCP')}
pred_triples:  {('焦虑情绪', '临床表现@症状', '躯体不适症状')}
gold_triples:  {('广泛性焦虑症', '鉴别诊断@疾病', '社交性焦虑'), ('广泛性焦虑症', '鉴别诊断@疾病', '分离性焦虑')}
pred_triples:  {('骨性关节炎', '发病部位@部位', '踝关节'), ('骨性关节炎', '病因@社会学', '创伤'), ('骨性关节炎', '发病部位@部位', '腕关节'), ('骨性关节炎', '发病部位@部位', '关节炎@在其他关节（如踝关节'), ('骨性关节炎', '病因@社会学', '结晶性关节病')}
gold_triples:  {('骨性关节炎', '发病部位@部位', '关节'), ('骨性关节炎', '发病部位@部位', '踝关节'), ('骨性关节炎', '病因@社会学', '创伤'), ('骨性关节炎', '发病部位@部位', '腕关节'), ('骨性关节炎', '病因@社会学', '结晶性关节病')}
pred_triples:  set()
gold_triples:  {('胆囊穿孔', '死亡率@流行病学', '30%')}
pred_triples:  set()
gold_triples:  {('乙型肝炎', '预防@其他', '不献血'), ('乙型肝炎', '预防@其他', '不捐献器官或精液')}
pred_triples:  {('抗心绞痛', '药物治疗@药物', 'β受体阻滞剂'), ('抗心绞痛', '药物治疗@药物', '抗心绞痛药物 * 抗心绞痛药物的主要目标是减少心绞痛症状，改善生活质量。稳定型缺血性心脏疾病@ * β受体阻滞剂')}
gold_triples:  {('心绞痛', '药物治疗@药物', 'β受体阻滞剂')}
pred_tri

correct_num: 4536, predict_num: 7997, gold_num: 10613
precision: 0.5672127047642795, recall: 0.42740035805144233, f1_score: 0.48747984949423906
[99/897],train loss is:0.077638
[199/897],train loss is:0.076752
[299/897],train loss is:0.076927
[399/897],train loss is:0.076696
[499/897],train loss is:0.077731
[599/897],train loss is:0.077617
[699/897],train loss is:0.078343
[799/897],train loss is:0.078826
epoch:[10],train loss is:0.079621 

pred_triples:  {('梗阻性胆总管结石', '影像学检查@检查', 'ERCP'), ('急性胰腺炎', '影像学检查@检查', 'ERCP')}
gold_triples:  {('急性胰腺炎', '影像学检查@检查', 'ERCP')}
pred_triples:  set()
gold_triples:  {('广泛性焦虑症', '鉴别诊断@疾病', '社交性焦虑'), ('广泛性焦虑症', '鉴别诊断@疾病', '分离性焦虑')}
pred_triples:  {('骨性关节炎', '发病部位@部位', '关节'), ('骨性关节炎', '发病部位@部位', '踝关节'), ('骨性关节炎', '病因@社会学', '创伤'), ('骨性关节炎', '发病部位@部位', '腕关节'), ('骨性关节炎', '病因@社会学', '结晶性关节病')}
gold_triples:  {('骨性关节炎', '发病部位@部位', '关节'), ('骨性关节炎', '发病部位@部位', '踝关节'), ('骨性关节炎', '病因@社会学', '创伤'), ('骨性关节炎', '发病部位@部位', '腕关节'), ('骨性关节炎', '病因@社会学', '结晶性关节病')}
pred_tri

correct_num: 5242, predict_num: 9590, gold_num: 10613
precision: 0.5466110531803905, recall: 0.49392254781870826, f1_score: 0.5189328317077827
[99/897],train loss is:0.048820
[199/897],train loss is:0.048794
[299/897],train loss is:0.050013
[399/897],train loss is:0.049265
[499/897],train loss is:0.049134
[599/897],train loss is:0.050565
[699/897],train loss is:0.051257
[799/897],train loss is:0.051187
epoch:[14],train loss is:0.051470 

pred_triples:  {('梗阻性胆总管结石', '影像学检查@检查', 'ERCP')}
gold_triples:  {('急性胰腺炎', '影像学检查@检查', 'ERCP')}
pred_triples:  set()
gold_triples:  {('广泛性焦虑症', '鉴别诊断@疾病', '社交性焦虑'), ('广泛性焦虑症', '鉴别诊断@疾病', '分离性焦虑')}
pred_triples:  {('骨性关节炎', '发病部位@部位', '关节'), ('骨性关节炎', '病因@社会学', '创伤'), ('骨性关节炎', '发病部位@部位', '腕关节'), ('骨性关节炎', '发病部位@部位', '踝关节')}
gold_triples:  {('骨性关节炎', '发病部位@部位', '关节'), ('骨性关节炎', '发病部位@部位', '踝关节'), ('骨性关节炎', '病因@社会学', '创伤'), ('骨性关节炎', '发病部位@部位', '腕关节'), ('骨性关节炎', '病因@社会学', '结晶性关节病')}
pred_triples:  {('胆囊穿孔', '死亡率@流行病学', '30%'), ('胆囊炎', '死亡率@流行病学', '30%')

correct_num: 5284, predict_num: 9509, gold_num: 10613
precision: 0.5556840887580129, recall: 0.49787995854140676, f1_score: 0.5251963025045634
[99/897],train loss is:0.031814
[199/897],train loss is:0.034828
[299/897],train loss is:0.036303
[399/897],train loss is:0.035919
[499/897],train loss is:0.037244
[599/897],train loss is:0.037039
[699/897],train loss is:0.036890
[799/897],train loss is:0.036648
epoch:[18],train loss is:0.036869 

pred_triples:  {('梗阻性胆总管结石', '影像学检查@检查', 'ERCP')}
gold_triples:  {('急性胰腺炎', '影像学检查@检查', 'ERCP')}
pred_triples:  set()
gold_triples:  {('广泛性焦虑症', '鉴别诊断@疾病', '社交性焦虑'), ('广泛性焦虑症', '鉴别诊断@疾病', '分离性焦虑')}
pred_triples:  {('骨性关节炎', '发病部位@部位', '关节'), ('骨性关节炎', '病因@社会学', '创伤'), ('骨性关节炎', '发病部位@部位', '腕关节'), ('骨性关节炎', '发病部位@部位', '踝关节')}
gold_triples:  {('骨性关节炎', '发病部位@部位', '关节'), ('骨性关节炎', '发病部位@部位', '踝关节'), ('骨性关节炎', '病因@社会学', '创伤'), ('骨性关节炎', '发病部位@部位', '腕关节'), ('骨性关节炎', '病因@社会学', '结晶性关节病')}
pred_triples:  {('胆囊穿孔', '死亡率@流行病学', '30%')}
gold_triples:  {('胆囊穿孔', '

<br>

### 四、模型预测

In [15]:
from tqdm import tqdm
from ark_nlp.model.re.casrel_bert import Predictor

# Build the predictor from the trained module, the tokenizer and the training set's cat2id mapping.
casrel_re_predictor_instance = Predictor(
    model.module,
    tokenizer,
    re_train_dataset.cat2id,
)

In [16]:
# Output file for the predictions, followed by the two input files (schema list and test split).
output_data_path = './CMeIE_test.jsonl'
schemas_data_path = '../mydata/data_origin/220602_0902-cblue-nlp-医疗nlp打榜/CMeIE/53_schemas.jsonl'
test_data_path = '../mydata/data_origin/220602_0902-cblue-nlp-医疗nlp打榜/CMeIE/CMeIE_test.jsonl'

In [17]:
# Run the predictor on every test sample; result[i] holds the prediction for line i of the
# test file (no lines are skipped, so indices stay aligned with the file).
result = []

with open(test_data_path, 'r', encoding='utf-8') as f:
    for line_ in f:
        # Parse with json.loads instead of eval(): eval executes arbitrary expressions found in
        # the data file and fails on JSON literals such as true/false/null.
        text_ = json.loads(line_)['text']
        result.append(casrel_re_predictor_instance.predict_one_sample(text_))

In [18]:
def _build_spo_list(triples, predicate2subject):
    """Convert predicted triples into de-duplicated SPO dicts for the output file.

    Args:
        triples: iterable of predictions; element 0 is used as the subject, element 1 as the
            'predicate@object_type' key and the last element as the object text.
        predicate2subject: dict mapping 'predicate@object_type' to the schema's subject_type.

    Returns:
        A list of dicts (first-seen order, duplicates removed) with keys 'predicate',
        'subject', 'subject_type', 'object' and 'object_type'. Triples whose relation is not
        in predicate2subject, or whose object is '' or '[UNK]', are dropped.
    """
    spo_list = []
    for triple_ in triples:
        relation_ = triple_[1]
        object_ = triple_[-1]
        if relation_ not in predicate2subject or object_ in ('', '[UNK]'):
            continue
        spo_ = {
            'predicate': relation_.split('@')[0],
            "subject": triple_[0],
            'subject_type': predicate2subject[relation_],
            "object": {"@value": object_},
            'object_type': {"@value": relation_.split('@')[-1]},
        }
        if spo_ not in spo_list:
            spo_list.append(spo_)
    return spo_list


# Map 'predicate@object_type' -> subject_type from the schema file (one JSON object per line).
predicate2subject = {}
with open(schemas_data_path, 'r', encoding='utf-8') as fs:
    for schema_line_ in fs:
        schema_line_ = schema_line_.strip()
        if not schema_line_:
            continue
        schema_ = json.loads(schema_line_)
        predicate2subject[schema_['predicate'] + '@' + schema_['object_type']] = schema_['subject_type']

with open(test_data_path, 'r', encoding='utf-8') as f:
    lines = f.readlines()

# `result` is indexed by test-file line number, so the two must line up exactly; checking
# before opening the output file avoids truncating an existing submission on a mismatch.
if len(result) != len(lines):
    raise ValueError(
        f"result has {len(result)} predictions but the test file has {len(lines)} lines; "
        "re-run the prediction cell"
    )

with open(output_data_path, 'w', encoding='utf-8') as fw:
    # Wrapping the list (not the enumerate/zip iterator) lets tqdm show a total.
    for jsonstr, triples_ in zip(tqdm(lines), result):
        # assumes 'text' is a str — TODO confirm against the test file
        sentence = json.loads(jsonstr)['text']
        new = _build_spo_list(triples_, predicate2subject)

        if new:
            # 'Combined' is True when the text contains at least two '。' characters.
            combined_ = sentence.count('。') >= 2
            for item in new:
                item['Combined'] = combined_
        else:
            # No usable prediction: emit a single all-empty placeholder entry.
            new = [{
                "Combined": '',
                "predicate": '',
                "subject": '',
                "subject_type": '',
                "object": {"@value": ""},
                "object_type": {"@value": ""},
            }]

        pred_dict = {
            "text": sentence,
            "spo_list": new,
        }
        fw.write(json.dumps(pred_dict, ensure_ascii=False) + '\n')

4482it [00:00, 18378.52it/s]
