In [1]:
import json
import numpy as np
import torch
import random
from torch.utils.data import Dataset, DataLoader
from transformers import (BertModel, BertTokenizerFast)
from torch.nn.utils.rnn import pad_sequence
import time
import copy

from model import RawGlobalPointer, ERENet
from utils import sparse_multilabel_categorical_crossentropy, MetricsCalculator_CMeIE

In [2]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds NumPy, Python's ``random`` module and PyTorch. When CUDA is available,
    all visible GPUs are seeded and cuDNN is switched to deterministic,
    non-autotuned kernels.

    Args:
        seed: Integer seed shared by all generators.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)  # CPU generator
    if torch.cuda.is_available():
        # Seed every visible GPU rather than only the current one, so a
        # multi-GPU run is covered as well.
        torch.cuda.manual_seed_all(seed)
        # If True, cuDNN only uses deterministic convolution algorithms.
        torch.backends.cudnn.deterministic = True
        # Autotuning can select different kernels from one run to the next.
        torch.backends.cudnn.benchmark = False


# Single seed for the whole notebook, applied before the data and model cells run.
RANDOM_SEED = 42
set_seed(seed=RANDOM_SEED)

In [3]:
# Use the GPU when CUDA reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [4]:
def load_data(filename):
    """Read a JSON-lines annotation file into a list of samples.

    Each non-blank line must be a JSON object with a ``"text"`` field and a
    ``"spo_list"`` field. Every entry of ``spo_list`` is flattened into the tuple
    ``(subject, predicate, object, subject_type, object_type)``, where the object
    and its type are read from the nested ``"@value"`` key.

    Result layout:
        [{'text': text0, 'spo_list': [(s0, p0, o0, s0_t, o0_t), ...]},
         {'text': text1, 'spo_list': [(s1, p1, o1, s1_t, o1_t), ...]}, ...]

    Args:
        filename: Path of the UTF-8 encoded JSON-lines file.

    Returns:
        List of ``{'text': str, 'spo_list': list[tuple]}`` dictionaries, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the expected fields.
    """
    samples = []
    with open(filename, 'r', encoding='utf-8') as f:
        for raw_line in f:
            raw_line = raw_line.strip()
            # Blank lines (for example a trailing newline at the end of the file)
            # carry no record and would make json.loads raise.
            if not raw_line:
                continue
            record = json.loads(raw_line)
            samples.append({
                "text": record["text"],
                "spo_list": [(spo["subject"], spo["predicate"], spo["object"]["@value"], spo["subject_type"],
                              spo["object_type"]["@value"])
                             for spo in record["spo_list"]]
            })
    return samples

In [5]:
# Load the training and development splits with the same reader.
train_data, valid_data = (
    load_data(split_path)
    for split_path in ('datasets_CMeIE/CMeIE_train.jsonl', 'datasets_CMeIE/CMeIE_dev.jsonl')
)

In [6]:
# Build the relation vocabulary: each schema line becomes the key
# "<subject_type>_<predicate>_<object_type>" and is mapped to its line index.
with open('datasets_CMeIE/53_schemas.jsonl', 'r', encoding='utf-8') as schema_file:
    schema = {}
    for relation_id, raw_entry in enumerate(schema_file):
        entry = json.loads(raw_entry.rstrip())
        relation_key = "_".join([entry["subject_type"], entry["predicate"], entry["object_type"]])
        schema[relation_key] = relation_id
print(schema)  # mapping from relation key to id

# Inverse lookup: relation id -> relation key.
id2schema = {relation_id: relation_key for relation_key, relation_id in schema.items()}

{'疾病_预防_其他': 0, '疾病_阶段_其他': 1, '疾病_就诊科室_其他': 2, '其他_同义词_其他': 3, '疾病_辅助治疗_其他治疗': 4, '疾病_化疗_其他治疗': 5, '疾病_放射治疗_其他治疗': 6, '其他治疗_同义词_其他治疗': 7, '疾病_手术治疗_手术治疗': 8, '手术治疗_同义词_手术治疗': 9, '疾病_实验室检查_检查': 10, '疾病_影像学检查_检查': 11, '疾病_辅助检查_检查': 12, '疾病_组织学检查_检查': 13, '检查_同义词_检查': 14, '疾病_内窥镜检查_检查': 15, '疾病_筛查_检查': 16, '疾病_多发群体_流行病学': 17, '疾病_发病率_流行病学': 18, '疾病_发病年龄_流行病学': 19, '疾病_多发地区_流行病学': 20, '疾病_发病性别倾向_流行病学': 21, '疾病_死亡率_流行病学': 22, '疾病_多发季节_流行病学': 23, '疾病_传播途径_流行病学': 24, '流行病学_同义词_流行病学': 25, '疾病_同义词_疾病': 26, '疾病_并发症_疾病': 27, '疾病_病理分型_疾病': 28, '疾病_相关（导致）_疾病': 29, '疾病_鉴别诊断_疾病': 30, '疾病_相关（转化）_疾病': 31, '疾病_相关（症状）_疾病': 32, '疾病_临床表现_症状': 33, '疾病_治疗后症状_症状': 34, '疾病_侵及周围组织转移的症状_症状': 35, '症状_同义词_症状': 36, '疾病_病因_社会学': 37, '疾病_高危因素_社会学': 38, '疾病_风险评估因素_社会学': 39, '疾病_病史_社会学': 40, '疾病_遗传因素_社会学': 41, '社会学_同义词_社会学': 42, '疾病_发病机制_社会学': 43, '疾病_病理生理_社会学': 44, '疾病_药物治疗_药物': 45, '药物_同义词_药物': 46, '疾病_发病部位_部位': 47, '疾病_转移部位_部位': 48, '疾病_外侵部位_部位': 49, '部位_同义词_部位': 50, '疾病_预后状况_预后': 51, '疾病_预后生存率_预后': 52}


In [7]:
class CustomDataset(Dataset):
    """Map-style dataset over a list of ``{'text', 'spo_list'}`` records."""

    def __init__(self, items):
        # Records are kept as-is; __getitem__ builds a fresh dict for each access.
        self._items = items

    def __len__(self):
        return len(self._items)

    def __getitem__(self, index):
        record = self._items[index]
        # Expose only the two fields used downstream, in a fixed order.
        return {field: record[field] for field in ('text', 'spo_list')}


train_dataset, valid_dataset = CustomDataset(items=train_data), CustomDataset(items=valid_data)

# Show the first validation record (if there is one) to check the sample layout.
first_valid_sample = next(iter(valid_dataset), None)
if first_valid_sample is not None:
    print(first_valid_sample)

{'text': '急性胰腺炎@有研究显示，进行早期 ERCP （24 小时内）可以降低梗阻性胆总管结石患者的并发症发生率和死亡率； 但是，对于无胆总管梗阻的胆汁性急性胰腺炎患者，不需要进行早期 ERCP。', 'spo_list': [('急性胰腺炎', '影像学检查', 'ERCP', '疾病', '检查')]}


In [8]:
# The tokenizer and the encoder are both loaded from the same checkpoint name.
checkpoint_name = 'hfl/chinese-roberta-wwm-ext'

tokenizer_fast = BertTokenizerFast.from_pretrained(checkpoint_name)
print(tokenizer_fast)

pretrained = BertModel.from_pretrained(checkpoint_name)
print(pretrained.num_parameters())

Downloading:   0%|          | 0.00/19.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/110k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/269k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/689 [00:00<?, ?B/s]

PreTrainedTokenizerFast(name_or_path='hfl/chinese-roberta-wwm-ext', vocab_size=21128, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})


Downloading:   0%|          | 0.00/412M [00:00<?, ?B/s]

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


102267648


In [10]:
def search(pattern, sequence):
    """Find the first occurrence of ``pattern`` as a contiguous run inside ``sequence``.

    Args:
        pattern: List of items to look for.
        sequence: List that is scanned from left to right.

    Returns:
        ``(start, end)`` with both indices inclusive when the pattern is found,
        otherwise ``-1``. An empty pattern is reported as not found, because it
        has no meaningful span (it would otherwise produce ``(0, -1)``).
    """
    n = len(pattern)
    if n == 0:
        return -1
    # A match cannot start later than len(sequence) - n.
    for i in range(len(sequence) - n + 1):
        if sequence[i:i + n] == pattern:
            return i, i + n - 1
    return -1


def get_collate_fn(tokenizer, max_len=512):
    """Return a ``collate_fn`` for ``DataLoader`` with ``tokenizer`` and ``max_len`` bound by closure.

    The returned function takes a list of ``{'text', 'spo_list'}`` samples and returns
    ``(input_ids, attention_mask, token_type_ids, entity_labels, head_labels,
    tail_labels, texts, spo_lists)``.

    All label positions index into the tokenized sentence (special tokens included):
        entity_labels: [batch_size, 2, max_spans, 2]; channel 0 holds subject
            (start, end) spans, channel 1 holds object (start, end) spans.
        head_labels: [batch_size, len(schema), max_pairs, 2]; per relation id,
            (subject start, object start).
        tail_labels: [batch_size, len(schema), max_pairs, 2]; per relation id,
            (subject end, object end).
    Empty label sets receive a single (0, 0) entry and shorter lists are zero-padded.

    A triple is dropped when its subject or object token ids cannot be found in the
    (possibly truncated) sentence; only the first occurrence of an entity is used.
    The function relies on the module-level ``schema`` mapping and ``search`` helper.

    Args:
        tokenizer: Tokenizer called on the batch of texts and on each entity string.
        max_len: Maximum sequence length; longer sentences are truncated.
    """

    def collate_fn(data):
        batch_size = len(data)

        texts = [sample['text'] for sample in data]
        encoder_text = tokenizer(texts, max_length=max_len, truncation=True, padding=True, return_tensors='pt')
        # Read the tensors by key: unpacking ``.values()`` positionally would silently
        # depend on the order in which the tokenizer emits its fields.
        input_ids = encoder_text['input_ids']
        token_type_ids = encoder_text['token_type_ids']
        attention_mask = encoder_text['attention_mask']

        spo_lists = [sample['spo_list'] for sample in data]
        entity_labels, head_labels, tail_labels = [], [], []

        for i in range(batch_size):
            entity_labels_temp = [set(), set()]
            head_labels_temp = [set() for _ in range(len(schema))]  # one set per relation id
            tail_labels_temp = [set() for _ in range(len(schema))]
            spoes = set()
            # Convert the row to a list once instead of twice per triple.
            sentence_ids = input_ids[i].tolist()

            # example: (s0, p0, o0, s0_t, o0_t)
            for s, p, o, s_t, o_t in spo_lists[i]:
                s = tokenizer(s, add_special_tokens=False, return_attention_mask=False, return_token_type_ids=False)[
                    'input_ids']
                o = tokenizer(o, add_special_tokens=False, return_attention_mask=False, return_token_type_ids=False)[
                    'input_ids']
                p = schema[s_t + "_" + p + "_" + o_t]  # id of this relation in the schema mapping
                # First/last position of the subject tokens (no special tokens) inside the
                # sentence tokens (with special tokens).
                s_range = search(s, sentence_ids)
                # Same for the object tokens.
                o_range = search(o, sentence_ids)
                if s_range != -1 and o_range != -1:
                    spoes.add((*s_range, p, *o_range))  # the set keeps each located triple once
            for sh, se, p, oh, oe in spoes:
                # Distinct subject spans (start, end) of this sentence -> entity_labels_temp[0].
                entity_labels_temp[0].add((sh, se))
                # Distinct object spans (start, end) of this sentence -> entity_labels_temp[1].
                entity_labels_temp[1].add((oh, oe))
                # Per relation: (subject start, object start).
                head_labels_temp[p].add((sh, oh))
                # Per relation: (subject end, object end).
                tail_labels_temp[p].add((se, oe))

            # torch.tensor needs at least one pair per set to produce a [k, 2] tensor.
            for label in entity_labels_temp + head_labels_temp + tail_labels_temp:
                if not label:
                    label.add((0, 0))

            entity_labels_temp = [torch.tensor(list(spans)) for spans in entity_labels_temp]  # set -> list -> tensor
            # pad_sequence pads dimension 0 of each tensor; the remaining dimensions must match.
            # After the transpose: entity_labels_temp.shape=[longest sequence, 2, 2]
            entity_labels_temp = torch.transpose(pad_sequence(entity_labels_temp, batch_first=True), 0,
                                                 1)
            entity_labels.append(entity_labels_temp)

            head_labels_temp = [torch.tensor(list(pairs)) for pairs in head_labels_temp]
            head_labels_temp = torch.transpose(pad_sequence(head_labels_temp, batch_first=True), 0, 1)
            head_labels.append(head_labels_temp)
            tail_labels_temp = [torch.tensor(list(pairs)) for pairs in tail_labels_temp]
            tail_labels_temp = torch.transpose(pad_sequence(tail_labels_temp, batch_first=True), 0, 1)
            tail_labels.append(tail_labels_temp)

        # entity_labels.shape=[batch_size, 2, longest sequence, 2]
        entity_labels = torch.transpose(pad_sequence(entity_labels, batch_first=True), 1, 2)
        # head_labels.shape=[batch_size, len(schema), longest sequence, 2]
        head_labels = torch.transpose(pad_sequence(head_labels, batch_first=True), 1, 2)
        tail_labels = torch.transpose(pad_sequence(tail_labels, batch_first=True), 1, 2)
        return input_ids, attention_mask, token_type_ids, entity_labels, head_labels, tail_labels, texts, spo_lists

    return collate_fn


# Both loaders share the batch size and collate function; only the training split is shuffled.
loader_kwargs = dict(batch_size=16, collate_fn=get_collate_fn(tokenizer_fast))
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)

# Print the tensor shapes of one training batch as a sanity check.
for batch in train_loader:
    for caption, position in (('input_ids.shape', 0),
                              ('entity_labels.shape', 3),
                              ('head_labels.shape', 4),
                              ('tail_labels.shape', 5)):
        print(caption, batch[position].shape)
    break

input_ids.shape torch.Size([16, 207])
entity_labels.shape torch.Size([16, 2, 11, 2])
head_labels.shape torch.Size([16, 53, 9, 2])
tail_labels.shape torch.Size([16, 53, 8, 2])


In [11]:
# Build the three scoring heads on top of the encoder's hidden size, wrap them in
# ERENet, and create the optimizer over all of the network's parameters.
hidden_size = pretrained.config.hidden_size
mention_detect = RawGlobalPointer(hidden_size, 2, 64).to(device)  # entity types are not extracted: the 2 heads only mark subject and object spans
s_o_head = RawGlobalPointer(hidden_size, len(schema), 64, RoPE=False, tril_mask=False).to(
    device)  # NOTE(review): original note reads "no need to set tril_mask=False", yet the flag is passed -- confirm intent
s_o_tail = RawGlobalPointer(hidden_size, len(schema), 64, RoPE=False, tril_mask=False).to(
    device)  # NOTE(review): original note reads "no need to set tril_mask=False", yet the flag is passed -- confirm intent
# The encoder is deep-copied, presumably so that `pretrained` keeps its original weights -- confirm.
net = ERENet(copy.deepcopy(pretrained), mention_detect, s_o_head, s_o_tail).to(device)

optimizer = torch.optim.AdamW(net.parameters(), lr=2e-5)

In [12]:
def extract_spoes(logits1, logits2, logits3, texts, tokenizer, id2predicate):
    """Decode (subject, predicate, object) string triples from the three logit tensors.

    A triple is emitted when a subject span and an object span both score above 0
    in ``logits1`` and at least one relation scores above 0 both for their start
    positions (``logits2``) and for their end positions (``logits3``).

    Args:
        logits1: Entity scores indexed ``[batch, 2, start, end]``; channel 0 is
            read as subjects, channel 1 as objects.
        logits2: Relation scores indexed ``[batch, relation, subject start, object start]``.
        logits3: Relation scores indexed ``[batch, relation, subject end, object end]``.
        texts: Raw sentences of the batch, in batch order.
        tokenizer: Callable such that ``tokenizer(texts, return_offsets_mapping=True)
            ['offset_mapping']`` gives, per sentence, one ``(char_start, char_end)``
            pair per token (special tokens included).
        id2predicate: Mapping from relation id to the predicate label to emit.

    Returns:
        A list with one set per sentence; each set holds
        ``(subject_text, predicate, object_text)`` tuples.
    """
    # Copy logits1 before masking: on CPU, ``.numpy()`` shares memory with the
    # tensor, so masking in place would overwrite the caller's logits.
    logits1 = logits1.detach().cpu().numpy().copy()
    logits2 = logits2.detach().cpu().numpy()
    logits3 = logits3.detach().cpu().numpy()
    batch_size = logits1.shape[0]
    offset_mapping = tokenizer(texts, return_offsets_mapping=True)['offset_mapping']
    # The texts are tokenized without padding here, so each offset list only covers
    # the sentence's own tokens. The last entity position of sentence b is the token
    # just before its final special token.
    last_valid = [len(offsets) - 2 for offsets in offset_mapping]

    # Push the first and last position of the padded sequence ('[CLS]' and, for the
    # longest sentence, '[SEP]') to minus infinity.
    logits1[:, :, [0, -1]] -= np.inf
    logits1[:, :, :, [0, -1]] -= np.inf
    subjects, objects = [set() for _ in range(batch_size)], [set() for _ in range(batch_size)]
    for b, l, h, t in zip(*np.where(logits1 > 0.0)):  # threshold fixed at 0.0
        # Shorter sentences have their own '[SEP]' and padding before the last column.
        # Spans touching those positions have no text: they would decode to junk
        # strings or index past the end of the offset list.
        if h > last_valid[b] or t > last_valid[b]:
            continue
        if l == 0:  # channel 0 -> subject span, any other channel -> object span
            subjects[b].add((h, t))
        else:
            objects[b].add((h, t))

    spoes = [set() for _ in range(batch_size)]
    for b in range(batch_size):
        offset_mapping_b = offset_mapping[b]
        text_b = texts[b]
        # Try every subject/object combination (Cartesian product) of the sentence.
        for sh, st in subjects[b]:
            for oh, ot in objects[b]:
                p1s = np.where(logits2[b, :, sh, oh] > 0.0)[0]  # threshold fixed at 0.0
                p2s = np.where(logits3[b, :, st, ot] > 0.0)[0]  # threshold fixed at 0.0
                # Keep relations p with head score S(s_h, o_h | p) > 0 and tail score S(s_t, o_t | p) > 0.
                ps = set(p1s) & set(p2s)
                for p in ps:
                    sht_str = text_b[offset_mapping_b[sh][0]: offset_mapping_b[st][1]]
                    oht_str = text_b[offset_mapping_b[oh][0]: offset_mapping_b[ot][1]]
                    spoes[b].add((sht_str, id2predicate[p], oht_str))  # predicted (subject, predicate, object)
    return spoes

In [13]:
# Model training
def train(model, dataloader, optimizer, tokenizer, id2predicate, device):
    """Run one pass over ``dataloader``, updating ``model`` after every batch.

    Every 100th step (step 0 excluded) the current batch is decoded and its loss,
    precision, recall and F1 are printed.
    """
    model.train()

    for idx, batch in enumerate(dataloader):
        (input_ids, attention_mask, token_type_ids,
         entity_labels, head_labels, tail_labels, texts, spo_lists) = batch

        # Move the tensors to the target device.
        # input_ids.shape=[batch_size, seq_len]
        input_ids, attention_mask, token_type_ids = (
            tensor.to(device) for tensor in (input_ids, attention_mask, token_type_ids))
        entity_labels, head_labels, tail_labels = (
            tensor.to(device) for tensor in (entity_labels, head_labels, tail_labels))

        # logits1.shape=[batch_size, 2, seq_len, seq_len]
        # logits2.shape=[batch_size, len(schema), seq_len, seq_len]
        # logits3.shape=[batch_size, len(schema), seq_len, seq_len]
        logits1, logits2, logits3 = model(input_ids, attention_mask, token_type_ids)

        loss1 = sparse_multilabel_categorical_crossentropy(y_true=entity_labels, y_pred=logits1)
        loss2 = sparse_multilabel_categorical_crossentropy(y_true=head_labels, y_pred=logits2)
        loss3 = sparse_multilabel_categorical_crossentropy(y_true=tail_labels, y_pred=logits3)
        # Average the entity, head and tail losses so they are optimised jointly.
        loss = (loss1 + loss2 + loss3) / 3

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report only on steps 100, 200, ...
        if idx == 0 or idx % 100 != 0:
            continue
        y_pred = extract_spoes(logits1, logits2, logits3, texts, tokenizer, id2predicate)
        mc = MetricsCalculator_CMeIE()  # precision, recall and F1 for this batch
        mc.calc_confusion_matrix(y_pred, spo_lists)
        report = '| step {:5d} | loss {:9.5f} | precision {:8.5f} | recall {:8.5f} | f1 {:8.5f} |'
        print(report.format(idx, loss.item(), mc.precision, mc.recall, mc.f1))

In [14]:
# Model validation
def evaluate(model, dataloader, tokenizer, id2predicate, device):
    """Decode every batch of ``dataloader`` and return ``(precision, recall, f1)``.

    The model is put in eval mode and gradients are disabled; one metrics
    calculator accumulates over all batches.
    """
    model.eval()

    mc = MetricsCalculator_CMeIE()  # accumulates precision, recall and F1
    with torch.no_grad():
        for batch in dataloader:
            # The three label tensors are not needed for evaluation.
            input_ids, attention_mask, token_type_ids, _, _, _, texts, spo_lists = batch

            # Move the model inputs to the target device.
            input_ids, attention_mask, token_type_ids = (
                tensor.to(device) for tensor in (input_ids, attention_mask, token_type_ids))

            logits1, logits2, logits3 = model(input_ids, attention_mask, token_type_ids)
            mc.calc_confusion_matrix(
                extract_spoes(logits1, logits2, logits3, texts, tokenizer, id2predicate),
                spo_lists)
    return mc.precision, mc.recall, mc.f1

In [15]:
# Train for 10 epochs; after each one, score the validation split and print a summary row.
for epoch in range(10):
    epoch_start_time = time.time()
    train(net, train_loader, optimizer, tokenizer_fast, id2schema, device)
    valid_metrics = evaluate(net, valid_loader, tokenizer_fast, id2schema, device)

    separator = '-' * 100
    summary = ('| epoch: {:5d} | time: {:5.2f}s '
               '| valid precision {:8.5f} '
               '| valid recall {:8.5f} '
               '| valid f1 {:8.5f} |')
    print(separator)
    # valid_metrics unpacks to precision, recall, f1 in that order.
    print(summary.format(epoch, time.time() - epoch_start_time, *valid_metrics))
    print(separator)

| step   100 | loss 339.16815 | precision  0.00000 | recall  0.00000 | f1  0.00000 |
| step   200 | loss 151.01361 | precision  0.00000 | recall  0.00000 | f1  0.00000 |
| step   300 | loss 157.45125 | precision  0.00000 | recall  0.00000 | f1  0.00000 |
| step   400 | loss 139.54654 | precision  0.50000 | recall  0.02000 | f1  0.03846 |
| step   500 | loss  85.37960 | precision  1.00000 | recall  0.02500 | f1  0.04878 |
| step   600 | loss 103.00769 | precision  1.00000 | recall  0.05263 | f1  0.10000 |
| step   700 | loss  79.83633 | precision  0.50000 | recall  0.05405 | f1  0.09756 |
| step   800 | loss  87.48330 | precision  0.63636 | recall  0.15217 | f1  0.24561 |
----------------------------------------------------------------------------------------------------
| epoch:     0 | time: 295.65s | valid precision  0.62902 | valid recall  0.20465 | valid f1  0.30883 |
----------------------------------------------------------------------------------------------------
| step   100 |