In [1]:
import json
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertModel
import numpy as np
import random
import torch
import time

from utils import loss_fun, MetricsCalculator
from GlobalPointer import GlobalPointer

In [2]:
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds NumPy, Python's ``random`` module and PyTorch. When CUDA is
    available, every visible GPU is seeded and cuDNN is switched to its
    deterministic, non-auto-tuned mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)  # CPU generator
    if torch.cuda.is_available():
        # Seed all visible GPUs, not only the current one, so that the result
        # does not depend on which device ends up being used.
        torch.cuda.manual_seed_all(seed)
        # Restrict cuDNN to deterministic convolution algorithms.
        torch.backends.cudnn.deterministic = True
        # The cuDNN auto-tuner benchmarks several algorithms and may select a
        # different one on each run; disable it for repeatable results.
        torch.backends.cudnn.benchmark = False


# Fix the seed once, before any data loader or model is created below.
RANDOM_SEED = 42
set_seed(RANDOM_SEED)

In [3]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Bare expression: the notebook displays the selected device.
device

device(type='cuda', index=0)

In [4]:
# Entity type abbreviations:
#   dis: disease
#   sym: clinical manifestation
#   pro: medical procedure
#   equ: medical equipment
#   dru: drug
#   ite: medical examination item
#   bod: body
#   dep: department
#   mic: microorganism
# Entity type name -> integer id (9 entity types).
ent2id = {
    "bod": 0,
    "dis": 1,
    "sym": 2,
    "mic": 3,
    "pro": 4,
    "ite": 5,
    "dep": 6,
    "dru": 7,
    "equ": 8,
}

# Reverse lookup: integer id -> entity type name.
id2ent = {type_id: type_name for type_name, type_id in ent2id.items()}
print(id2ent)

{0: 'bod', 1: 'dis', 2: 'sym', 3: 'mic', 4: 'pro', 5: 'ite', 6: 'dep', 7: 'dru', 8: 'equ'}


In [5]:
def load_data(path, label2id=None):
    """Load a JSON annotation file into a list of training samples.

    The file must hold a JSON list of records that each provide a ``'text'``
    string and an ``'entities'`` list whose items carry ``'start_idx'``,
    ``'end_idx'`` and ``'type'``.

    Args:
        path: Path of the JSON file. It is always decoded as UTF-8 so that the
            result does not depend on the platform's default encoding.
        label2id: Mapping from entity type name to integer id. Defaults to the
            module-level ``ent2id``.

    Returns:
        A list with one entry per record, shaped
        ``[text, (start, end, label_id), ...]``. Entities whose start index is
        greater than their end index are dropped.

    Raises:
        KeyError: If an entity type is missing from ``label2id``.
    """
    if label2id is None:
        label2id = ent2id

    with open(path, encoding='utf-8') as f:
        records = json.load(f)

    samples = []
    for record in records:
        sample = [record['text']]
        for entity in record['entities']:
            start, end, label = entity['start_idx'], entity['end_idx'], entity['type']
            # Skip malformed annotations with an inverted span.
            if start <= end:
                sample.append((start, end, label2id[label]))
        samples.append(sample)
    return samples


# Load the training and development splits; each sample has the form
# [text, (start, end, label_id), ...] as produced by load_data.
data_train = load_data('datasets/CMeEE_train.json')
data_dev = load_data('datasets/CMeEE_dev.json')
# Sanity check: split sizes and the first development sample.
print(len(data_train))
print(len(data_dev))
print(data_dev[0])

15000
5000
['对儿童SARST细胞亚群的研究表明，与成人SARS相比，儿童细胞下降不明显，证明上述推测成立。', (3, 9, 0), (19, 24, 1)]


In [6]:
class CustomDataset(Dataset):
    """Map-style dataset that serves samples from an in-memory sequence."""

    def __init__(self, data):
        # The sequence is stored by reference; nothing is copied or preprocessed.
        self.data = data

    def __len__(self):
        """Return the number of stored samples."""
        sample_count = len(self.data)
        return sample_count

    def __getitem__(self, item):
        """Return the sample stored at position ``item``."""
        sample = self.data[item]
        return sample


# Wrap the loaded splits so they can be fed to a DataLoader.
dataset_train = CustomDataset(data_train)
dataset_valid = CustomDataset(data_dev)
# Print only the first validation sample as a sanity check.
for i in dataset_valid:
    print(i)
    break

['对儿童SARST细胞亚群的研究表明，与成人SARS相比，儿童细胞下降不明显，证明上述推测成立。', (3, 9, 0), (19, 24, 1)]


In [7]:
# Checkpoint name used for both the tokenizer and the encoder.
model_name = 'hfl/chinese-roberta-wwm-ext'

# A "fast" tokenizer is needed because the collate function below asks for
# return_offsets_mapping=True.
tokenizer_fast = BertTokenizerFast.from_pretrained(model_name)
print(tokenizer_fast)
# Bare encoder (BertModel); it is passed to GlobalPointer further down.
pretrained = BertModel.from_pretrained(model_name)
print(pretrained.num_parameters())

PreTrainedTokenizerFast(name_or_path='hfl/chinese-roberta-wwm-ext', vocab_size=21128, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})


Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


102267648


In [8]:
def get_collate_fn(tokenizer, max_len=512, ent_type_size=None):
    """Build a ``collate_fn`` for ``DataLoader`` (a closure carrying the settings).

    Args:
        tokenizer: Callable tokenizer that accepts a list of sentences together
            with ``max_length``, ``truncation``, ``padding`` and
            ``return_offsets_mapping`` and returns a mapping with the keys
            ``input_ids``, ``attention_mask``, ``token_type_ids`` and
            ``offset_mapping``.
        max_len: Maximum token length; longer sentences are truncated, shorter
            batches are padded to the longest sentence in the batch.
        ent_type_size: Number of entity types, i.e. the size of the second
            label dimension. Defaults to ``len(ent2id)``.

    Returns:
        A function mapping a list of samples ``[text, (start, end, label), ...]``
        to ``(input_ids, attention_mask, token_type_ids, labels)``, where
        ``labels`` is an int64 tensor of shape
        ``[batch_size, ent_type_size, seq_len, seq_len]`` holding 1 at
        ``[b, label, start_token, end_token]`` for every entity.
    """

    def collate_fn(data):
        batch_size = len(data)
        sentences_list = [sample[0] for sample in data]  # one text per sample
        entities_list = [sample[1:] for sample in data]  # one entity list per sample

        # length > max_len  ==> truncated
        # length <= max_len ==> padded to the longest sentence of this batch
        outputs = tokenizer(sentences_list, max_length=max_len, truncation=True, padding=True,
                            return_offsets_mapping=True)
        input_ids = torch.tensor(outputs["input_ids"], dtype=torch.int64)
        attention_mask = torch.tensor(outputs["attention_mask"], dtype=torch.int64)
        token_type_ids = torch.tensor(outputs["token_type_ids"], dtype=torch.int64)
        offset_mapping = outputs["offset_mapping"]  # (char_start, char_end) for each token

        num_types = len(ent2id) if ent_type_size is None else ent_type_size
        seq_len = input_ids.shape[1]
        # Allocate the labels directly as int64 instead of building a float64
        # NumPy array and converting it afterwards.
        labels = torch.zeros((batch_size, num_types, seq_len, seq_len), dtype=torch.int64)

        for row in range(batch_size):
            # Character position -> token index (special tokens included in the count).
            start_mapping = {}
            end_mapping = {}
            for token_idx, (char_start, char_end) in enumerate(offset_mapping[row]):
                # (0, 0) marks special tokens such as '[CLS]', '[SEP]' and '[PAD]'.
                if (char_start, char_end) == (0, 0):
                    continue
                start_mapping[char_start] = token_idx
                # Offsets are half-open, so char_end - 1 is the token's last character.
                end_mapping[char_end - 1] = token_idx

            for start, end, label in entities_list[row]:
                # Entities that were truncated away or do not line up with
                # token boundaries are skipped.
                if start in start_mapping and end in end_mapping:
                    labels[row, label, start_mapping[start], end_mapping[end]] = 1
        return input_ids, attention_mask, token_type_ids, labels

    return collate_fn


# Only the training loader shuffles; both loaders use a batch size of 16 and a
# collate function built around the fast tokenizer.
train_loader = DataLoader(dataset_train, batch_size=16, shuffle=True, collate_fn=get_collate_fn(tokenizer_fast))
valid_loader = DataLoader(dataset_valid, batch_size=16, shuffle=False, collate_fn=get_collate_fn(tokenizer_fast))
# Inspect the first validation batch: input ids, the shapes of the three
# input tensors and the shape of the label tensor.
for j in valid_loader:
    print(j[0])
    print(j[0].shape, j[1].shape, j[2].shape)
    print(j[-1].shape)
    break

tensor([[ 101, 2190, 1036,  ...,    0,    0,    0],
        [ 101, 4777, 4955,  ...,    0,    0,    0],
        [ 101, 1728, 5445,  ...,    0,    0,    0],
        ...,
        [ 101,  123,  119,  ...,    0,    0,    0],
        [ 101,  124,  119,  ...,    0,    0,    0],
        [ 101,  125,  119,  ...,    0,    0,    0]])
torch.Size([16, 122]) torch.Size([16, 122]) torch.Size([16, 122])
torch.Size([16, 9, 122, 122])


In [9]:
# GlobalPointer head on top of the pretrained encoder, one output channel per entity type.
# NOTE(review): the third argument (64) is presumably the per-type projection size — confirm in GlobalPointer.py.
model = GlobalPointer(pretrained, len(ent2id), 64).to(device)
# All parameters returned by model.parameters() share a single learning rate of 2e-5.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

In [10]:
# Model training
def train(model, dataloader, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Every 100th step (step 0 excluded) a progress line is printed with the
    loss and the precision / recall / F1 of the current batch only.

    Args:
        model: Module called as ``model(input_ids, attention_mask, token_type_ids)``.
        dataloader: Iterable of ``(input_ids, attention_mask, token_type_ids, labels)`` batches.
        optimizer: Optimizer that updates the model's parameters.
        device: Device every batch tensor is moved to.
    """
    log_template = '| step {:5d} | loss {:8.5f} | precision {:8.5f} | recall {:8.5f} | f1 {:8.5f} |'
    model.train()

    for idx, batch in enumerate(dataloader):
        input_ids, attention_mask, token_type_ids, labels = batch
        # Move the whole batch to the target device;
        # labels.shape = [batch_size, ent_type_size, seq_len, seq_len]
        input_ids, attention_mask, token_type_ids, labels = (
            tensor.to(device) for tensor in (input_ids, attention_mask, token_type_ids, labels)
        )

        # logits.shape = [batch_size, ent_type_size, seq_len, seq_len]
        logits = model(input_ids, attention_mask, token_type_ids)
        loss = loss_fun(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Nothing to report on step 0 or between multiples of 100.
        if idx == 0 or idx % 100 != 0:
            continue

        # Fresh calculator per report: precision, recall and F1 of this batch.
        batch_metrics = MetricsCalculator()
        batch_metrics.calc_confusion_matrix(logits, labels)
        print(log_template.format(idx,
                                  loss.item(),
                                  batch_metrics.precision,
                                  batch_metrics.recall,
                                  batch_metrics.f1))

In [11]:
# Model evaluation
def evaluate(model, dataloader, device):
    """Score ``model`` on every batch of ``dataloader`` without tracking gradients.

    Args:
        model: Module called as ``model(input_ids, attention_mask, token_type_ids)``.
        dataloader: Iterable of ``(input_ids, attention_mask, token_type_ids, labels)`` batches.
        device: Device the three input tensors are moved to.

    Returns:
        ``(precision, recall, f1)`` as exposed by a single ``MetricsCalculator``
        that has been fed every batch.
    """
    model.eval()
    metrics = MetricsCalculator()

    with torch.no_grad():
        for input_ids, attention_mask, token_type_ids, labels in dataloader:
            # logits.shape = [batch_size, ent_type_size, seq_len, seq_len]
            logits = model(
                input_ids.to(device),
                attention_mask.to(device),
                token_type_ids.to(device),
            )
            # The labels are handed over exactly as the dataloader yielded them.
            metrics.calc_confusion_matrix(logits, labels)

    return metrics.precision, metrics.recall, metrics.f1

In [12]:
# Train for three epochs; after each one, score the model on both the
# training and the validation loader and print a summary line.
for epoch in range(3):
    epoch_start_time = time.time()
    train(model, train_loader, optimizer, device)
    # Scoring the training loader is a second full pass over it; only F1 is kept.
    _, _, train_f1 = evaluate(model, train_loader, device)
    valid_precision, valid_recall, valid_f1 = evaluate(model, valid_loader, device)
    print('-' * 123)
    # The reported time covers training plus both evaluation passes.
    print('| epoch: {:5d} | time: {:5.2f}s '
          '| valid precision {:8.5f} '
          '| valid recall {:8.5f} '
          '| valid f1 {:8.5f} | train f1 {:8.5f} |'.format(epoch,
                                                           time.time() - epoch_start_time,
                                                           valid_precision,
                                                           valid_recall,
                                                           valid_f1,
                                                           train_f1))
    print('-' * 123)

| step   100 | loss  1.19852 | precision  0.50000 | recall  0.02326 | f1  0.04444 |
| step   200 | loss  0.68973 | precision  0.64444 | recall  0.45312 | f1  0.53211 |
| step   300 | loss  0.80105 | precision  0.78571 | recall  0.17742 | f1  0.28947 |
| step   400 | loss  0.79979 | precision  0.60000 | recall  0.25424 | f1  0.35714 |
| step   500 | loss  0.66815 | precision  0.70000 | recall  0.50909 | f1  0.58947 |
| step   600 | loss  0.53522 | precision  0.86364 | recall  0.54286 | f1  0.66667 |
| step   700 | loss  0.64121 | precision  0.72340 | recall  0.46575 | f1  0.56667 |
| step   800 | loss  0.63160 | precision  0.65385 | recall  0.54839 | f1  0.59649 |
| step   900 | loss  1.01873 | precision  0.78846 | recall  0.44086 | f1  0.56552 |
---------------------------------------------------------------------------------------------------------------------------
| epoch:     0 | time: 279.43s | valid precision  0.65952 | valid recall  0.59908 | valid f1  0.62785 | train f1  0.6834

In [13]:
all_ent_list = []

# Do not rely on an earlier cell having left the model in evaluation mode.
model.eval()

# Read the test set through a context manager so the file handle is closed,
# and decode it explicitly as UTF-8.
with open('./datasets/CMeEE_test.json', encoding='utf-8') as f:
    test_records = json.load(f)

for d in test_records:
    text = d["text"]

    output = tokenizer_fast([text], return_offsets_mapping=True, max_length=512, truncation=True, padding=True)
    input_ids = torch.tensor(output['input_ids'], dtype=torch.int64).to(device)
    token_type_ids = torch.tensor(output['token_type_ids'], dtype=torch.int64).to(device)
    attention_mask = torch.tensor(output['attention_mask'], dtype=torch.int64).to(device)
    # One text per call, so only the first row of offsets is needed.
    offsets = output['offset_mapping'][0]

    one_ent_list = {'text': text, 'entities': []}
    with torch.no_grad():
        # logits.shape = [1, ent_type_size, seq_len, seq_len]
        logits = model(input_ids, attention_mask, token_type_ids).cpu()

    # Decision threshold is 0.0. Convert the index tensors to plain ints so
    # they can be used as list indices and dictionary keys.
    hit_indices = [index.tolist() for index in torch.where(logits > 0.0)]
    for _, l, start, end in zip(*hit_indices):
        start_offset = tuple(offsets[start])
        end_offset = tuple(offsets[end])
        # (0, 0) marks a special token ('[CLS]', '[SEP]', '[PAD]'); a span
        # anchored on one has no valid character range (e.g. end_idx == -1).
        if start_offset == (0, 0) or end_offset == (0, 0):
            continue
        ent_type = id2ent[l]
        # First character of the start token, one past the last character of the end token.
        ent_char_span = [start_offset[0], end_offset[1]]
        ent_text = text[ent_char_span[0]: ent_char_span[1]]
        one_ent_list['entities'].append({"start_idx": ent_char_span[0],
                                         "end_idx": ent_char_span[1] - 1,  # offsets are half-open
                                         "type": ent_type,
                                         "entity": ent_text})
    all_ent_list.append(one_ent_list)  # one text is predicted per iteration

In [14]:
# Show the predicted entities for the first test text.
print(all_ent_list[0])

{'text': '六、新生儿疾病筛查的发展趋势自1961年开展苯丙酮尿症筛查以来，随着医学技术的发展，符合进行新生儿疾病筛查标准的疾病也在不断增加，无论在新生儿疾病筛查的病种，还是在新生儿疾病筛查的技术方法上，都有了非常显著的进步。', 'entities': [{'start_idx': 22, 'end_idx': 26, 'type': 'dis', 'entity': '苯丙酮尿症'}, {'start_idx': 2, 'end_idx': 8, 'type': 'pro', 'entity': '新生儿疾病筛查'}, {'start_idx': 22, 'end_idx': 28, 'type': 'pro', 'entity': '苯丙酮尿症筛查'}, {'start_idx': 68, 'end_idx': 74, 'type': 'pro', 'entity': '新生儿疾病筛查'}, {'start_idx': 82, 'end_idx': 88, 'type': 'pro', 'entity': '新生儿疾病筛查'}]}
