In [1]:
import copy
import json
import os
import random
import time

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import (DataLoader, Dataset)
from tqdm import tqdm
from transformers import (BertTokenizerFast, BertModel)

from model import GlobalPointer
from utils import loss_fun, MetricsCalculator

In [2]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Python's `random`, NumPy's legacy global RNG and PyTorch are all seeded with
    `seed`. When CUDA is available the current GPU is seeded as well and cuDNN
    is restricted to deterministic convolution algorithms.
    """
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)  # current GPU only
    # When True, cuDNN only uses deterministic convolution algorithms.
    torch.backends.cudnn.deterministic = True


seed = 2022
set_seed(seed)

In [3]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [4]:
# Load the annotated training corpus: a list of records with keys 'sent', 'ners' and 'spans'.
# Bound to `train_records` (not `train`) so that re-running this cell cannot shadow the
# `train()` training-loop function defined later in the notebook.
with open('datasets/train_ner.json', 'r', encoding='utf-8') as f:
    train_records = json.load(f)
print(train_records[0], end='\n\n')

# Hold out 10% of the records for validation; seeded so the split is reproducible.
train_data, valid_data = train_test_split(train_records, test_size=0.1, random_state=seed)

print(len(train_data))
print(len(valid_data))

{'sent': '胸 廓 对 称 ， 气 管 居 中 。 所 见 骨 骼 骨 质 结 构 完 整 。 双 肺 纹 理 清 晰 。 两 肺 门 影 不 大 。 心 影 横 径 增 大 ， 左 心 缘 饱 满 。 两 侧 膈 面 光 整 ， 两 侧 肋 膈 角 锐 利 。 1 . 两 肺 未 见 明 显 活 动 性 病 变 ， 随 诊 。 2 . 心 影 改 变 请 结 合 临 床 。', 'ners': [[0, 2, '器官组织', '胸廓'], [2, 4, '阴性表现', '对称'], [5, 7, '器官组织', '气管'], [7, 9, '阴性表现', '居中'], [12, 16, '器官组织', '骨骼骨质'], [16, 18, '属性', '结构'], [18, 20, '阴性表现', '完整'], [21, 23, '器官组织', '双肺'], [23, 25, '属性', '纹理'], [25, 27, '阴性表现', '清晰'], [28, 32, '器官组织', '两肺门影'], [32, 34, '阴性表现', '不大'], [35, 37, '器官组织', '心影'], [37, 39, '属性', '横径'], [39, 41, '阳性表现', '增大'], [42, 45, '器官组织', '左心缘'], [45, 47, '阳性表现', '饱满'], [48, 52, '器官组织', '两侧膈面'], [52, 54, '阴性表现', '光整'], [55, 60, '器官组织', '两侧肋膈角'], [60, 62, '阴性表现', '锐利'], [65, 67, '器官组织', '两肺'], [67, 69, '否定描述', '未见'], [69, 71, '修饰描述', '明显'], [71, 74, '修饰描述', '活动性'], [74, 76, '异常现象', '病变'], [82, 84, '器官组织', '心影'], [84, 86, '异常现象', '改变']], 'spans': [[1, 2, '器官组织', '胸廓'], [3, 4, '阴性表现', '对称'], [6, 7, '器官组织', '气管'], [8, 9, '阴性表现', '居中'], [13, 16, '器官组织', '骨骼骨质'], [

In [5]:
# Entity-type vocabulary: 'ent2id' maps a type name to its integer index and
# 'id2ent' maps the stringified index back to the name (JSON keys are strings).
with open('datasets/ent_map_id.json', mode='r', encoding='utf-8') as fh:
    ent_to_ot_id = json.load(fh)
ent_to_ot_id

{'ent2id': {'修饰描述': 0,
  '否定描述': 1,
  '器官组织': 2,
  '属性': 3,
  '异常现象': 4,
  '手术': 5,
  '指代': 6,
  '数量': 7,
  '期象': 8,
  '检查手段': 9,
  '测量值': 10,
  '疾病': 11,
  '病理分型': 12,
  '病理分期': 13,
  '病理分级': 14,
  '累及部位': 15,
  '阳性表现': 16,
  '阴性表现': 17},
 'id2ent': {'0': '修饰描述',
  '1': '否定描述',
  '2': '器官组织',
  '3': '属性',
  '4': '异常现象',
  '5': '手术',
  '6': '指代',
  '7': '数量',
  '8': '期象',
  '9': '检查手段',
  '10': '测量值',
  '11': '疾病',
  '12': '病理分型',
  '13': '病理分期',
  '14': '病理分级',
  '15': '累及部位',
  '16': '阳性表现',
  '17': '阴性表现'}}

In [6]:
class CustomDataset(Dataset):
    """Map-style dataset over pre-parsed NER records.

    Each record must provide the keys 'sent', 'ners' and 'spans'; items are
    returned re-keyed as 'text', 'tags' and 'spans'.
    """

    def __init__(self, sentences):
        # The caller's list is referenced, not copied.
        self._sentences = sentences

    def __len__(self):
        return len(self._sentences)

    def __getitem__(self, index):
        record = self._sentences[index]
        return dict(text=record['sent'], tags=record['ners'], spans=record['spans'])


train_dataset = CustomDataset(sentences=train_data)
valid_dataset = CustomDataset(sentences=valid_data)

# Peek at the first training example (iteration goes through __getitem__; the body runs at most once).
for sample in train_dataset:
    for key in ('text', 'tags', 'spans'):
        print(sample[key])
    break

胸 腰 椎 C T 平 扫 + 三 维 重 建 所 示 层 面 的 胸 腰 段 椎 体 生 理 曲 度 存 在 ， 椎 体 序 列 尚 规 则 ， 胸 1 2 椎 体 骨 质 连 续 性 中 断 、 椎 体 高 度 变 扁 呈 楔 形 改 变 、 密 度 不 均 ； 余 所 示 诸 椎 体 边 缘 及 椎 小 关 节 可 见 骨 质 增 生 、 硬 化 影 ， 部 分 变 尖 ， 椎 间 隙 无 明 显 狭 窄 ， 腰 3 / 4 、 腰 4 / 5 椎 间 盘 向 周 围 隆 起 ， 硬 膜 囊 受 压 ， 椎 管 未 见 明 显 狭 窄 。 1 . 胸 1 2 椎 体 压 缩 性 骨 折 ； 2 . 胸 腰 椎 退 行 性 变 ； 腰 3 / 4 、 腰 4 / 5 椎 间 盘 膨 出 。
[[0, 3, '器官组织', '胸腰椎'], [3, 7, '检查手段', 'CT平扫'], [8, 12, '检查手段', '三维重建'], [17, 22, '器官组织', '胸腰段椎体'], [22, 26, '属性', '生理曲度'], [26, 28, '阴性表现', '存在'], [29, 31, '器官组织', '椎体'], [31, 33, '属性', '序列'], [34, 36, '阴性表现', '规则'], [37, 44, '器官组织', '胸12椎体骨质'], [44, 47, '属性', '连续性'], [47, 49, '阳性表现', '中断'], [50, 52, '器官组织', '椎体'], [54, 56, '阳性表现', '变扁'], [57, 59, '修饰描述', '楔形'], [59, 61, '异常现象', '改变'], [62, 64, '属性', '密度'], [64, 66, '阳性表现', '不均'], [70, 73, '器官组织', '诸椎体'], [73, 75, '属性', '边缘'], [76, 80, '器官组织', '椎小关节'], [82, 86, '异常现象', '骨质增生'], [87, 90, '异常现象', '硬化影'], [91, 93, '指代', '部分'], [93, 95, '阳性表现', '变尖'], [96, 99, '器官组织', '椎间隙'], [99, 100, '否定描述',

In [7]:
# Fast tokenizer loaded from a local directory (its printed repr reports vocab_size=21128).
# NOTE(review): tokenizer and encoder are loaded from different sources; confirm that
# ./save_tokenizer/ was saved from the same checkpoint as the encoder below.
tokenizer_fast = BertTokenizerFast.from_pretrained('./save_tokenizer/')
print(tokenizer_fast)

# Pretrained encoder; the checkpoint's pretraining-head weights are unused by BertModel.
# The model cell fine-tunes a copy.deepcopy of this instance, so training never mutates it.
pretrained = BertModel.from_pretrained('hfl/chinese-roberta-wwm-ext-large')
print(pretrained.num_parameters())

PreTrainedTokenizerFast(name_or_path='./save_tokenizer/', vocab_size=21128, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})


Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext-large were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


325522432


In [8]:
def get_collate_fn(tokenizer, ent_map_id, max_len=512):
    """Build a DataLoader ``collate_fn`` that tokenizes a batch and builds GlobalPointer labels.

    The tokenizer and entity mapping are captured by closure, so the returned
    function only needs the batch.

    Args:
        tokenizer: tokenizer called with ``truncation=True, padding=True,
            return_tensors='pt'``; its output must expose ``input_ids``,
            ``attention_mask`` and ``token_type_ids``.
        ent_map_id: dict whose ``'ent2id'`` entry maps entity-type name -> class index.
        max_len: maximum sequence length including special tokens; longer texts
            are truncated by the tokenizer.

    Returns:
        ``collate_fn(batch)`` returning ``(input_ids, attention_mask, token_type_ids, labels)``.
        ``labels`` is an int64 tensor of shape ``[batch_size, ent_type_size, seq_len, seq_len]``
        with ``labels[b, t, s, e] = 1`` for each gold span of type ``t`` from token ``s`` to
        token ``e``. Spans that fall in a truncated-away tail are dropped.
    """
    ent2id = ent_map_id['ent2id']
    ent_type_size = len(ent2id)  # number of entity types

    def collate_fn(batch):
        sentences_list = [sentence['text'] for sentence in batch]
        # Span indices are used directly as token positions. In the sample data they equal the
        # character 'ners' start + 1 (room for [CLS]) with the same end, which assumes one token
        # per space-separated character.
        spans_list = [sentence['spans'] for sentence in batch]
        outputs = tokenizer(sentences_list, truncation=True, max_length=max_len, padding=True, return_tensors='pt')
        input_ids, attention_mask, token_type_ids = outputs.input_ids, outputs.attention_mask, outputs.token_type_ids

        batch_size, seq_len = input_ids.shape
        # Allocate directly as int64 instead of a float64 numpy buffer of the same size plus a copy.
        labels = torch.zeros((batch_size, ent_type_size, seq_len, seq_len), dtype=torch.long)
        for i, spans in enumerate(spans_list):
            # Sequence layout is [CLS] tok_1 .. tok_k [SEP]: the last real token sits at n_tokens - 2.
            # Anything beyond it was cut off by truncation and would otherwise index out of bounds.
            last_token = int(attention_mask[i].sum()) - 2
            for start, end, ent_type, _ in spans:
                if end > last_token:
                    continue
                labels[i, ent2id[ent_type], start, end] = 1
        return input_ids, attention_mask, token_type_ids, labels

    return collate_fn


# One DataLoader per split, each with its own collate_fn; only the training split is shuffled.
dataloader_train, dataloader_valid = (
    DataLoader(dataset=split_dataset,
               batch_size=4,
               shuffle=is_training_split,
               collate_fn=get_collate_fn(tokenizer_fast, ent_to_ot_id))
    for split_dataset, is_training_split in ((train_dataset, True), (valid_dataset, False))
)

# Inspect one training batch (the body runs once).
for input_ids, attention_mask, token_type_ids, labels in dataloader_train:
    for shown in (input_ids, input_ids.shape, attention_mask, token_type_ids, labels.shape):
        print(shown)
    break

tensor([[ 101, 4508, 4307, 5593, 1381, 1383,  123,  124,  155,  155,  190,  122,
          130,  155,  155,  190,  125,  127,  155,  155, 8024, 2340, 1383,  122,
          128,  155,  155,  190,  122,  127,  155,  155,  190,  125,  125,  155,
          155, 8024, 2284, 6956, 1331,  124,  155,  155,  511, 4508, 4307, 5593,
         1920, 2207,  510, 2501, 2578, 3633, 2382, 8024, 1726, 1898, 1772, 1258,
         8024, 1259, 5606, 1045, 3146,  511,  145,  146,  148,  151, 8038, 5593,
          860, 1079, 6224, 4157, 3340, 4307, 6117, 3837,  928, 1384,  511, 4508,
         4307, 5593, 1079, 6224, 3144,  702,  856, 1726, 1898, 5310, 5688, 8024,
         6804, 3926, 6226, 1156, 8024, 6804, 5357, 1045, 3146, 8024, 2340, 1383,
         6772, 1920, 4638, 5276,  122,  124,  155,  155,  190,  130,  155,  155,
         8024, 1381, 1383, 6772, 1920, 4638, 5276,  123,  129,  155,  155,  190,
          122,  127,  155,  155,  117,  145,  146,  148,  151, 8038, 5310, 5688,
         1079, 3313, 6224, 3

In [9]:
# Fine-tune a private deep copy of the encoder so `pretrained` itself is never updated.
# NOTE(review): 64 is presumably GlobalPointer's per-type head size; confirm in model.GlobalPointer.
model = GlobalPointer(copy.deepcopy(pretrained), len(ent_to_ot_id['ent2id']), 64)
model = model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-5)

In [10]:
# Model training
def train(model, dataloader, optimizer, device):
    """Run a single training epoch over `dataloader`.

    Each batch is moved to `device`, scored by `model`, and used for one
    optimizer step on `loss_fun(logits, labels)`. Every 100 steps (step 0
    excluded), entity precision/recall/F1 computed on the current batch only
    are printed.

    Args:
        model: module called as model(input_ids, attention_mask, token_type_ids).
        dataloader: yields (input_ids, attention_mask, token_type_ids, labels) tuples.
        optimizer: optimizer over the model's parameters.
        device: device every batch tensor is moved to.
    """
    model.train()
    log_interval = 100

    for step, batch in enumerate(dataloader):
        # labels: [batch_size, ent_type_size, seq_len, seq_len]
        input_ids, attention_mask, token_type_ids, labels = (tensor.to(device) for tensor in batch)

        # logits: [batch_size, ent_type_size, seq_len, seq_len]
        logits = model(input_ids, attention_mask, token_type_ids)
        loss = loss_fun(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step > 0 and step % log_interval == 0:
            batch_metrics = MetricsCalculator()  # entity precision / recall / F1
            batch_metrics.calc_confusion_matrix_ner(logits, labels)
            print('| step {:5d} | loss {:8.5f} | precision {:8.5f} | recall {:8.5f} | f1 {:8.5f} |'.format(
                step, loss.item(), batch_metrics.precision, batch_metrics.recall, batch_metrics.f1))

In [11]:
# Model evaluation
def evaluate(model, dataloader, device):
    """Score `model` on every batch of `dataloader` without tracking gradients.

    A single MetricsCalculator is fed every batch (presumably accumulating counts;
    confirm in utils), and its (precision, recall, f1) are returned.
    """
    model.eval()

    metrics = MetricsCalculator()
    with torch.no_grad():
        for batch in dataloader:
            *model_inputs, labels = batch
            # Only the model inputs are moved to `device`.
            # NOTE(review): unlike train(), labels stay on the CPU here; presumably
            # MetricsCalculator handles the device mix. Confirm.
            logits = model(*(tensor.to(device) for tensor in model_inputs))
            metrics.calc_confusion_matrix_ner(logits, labels)
    return metrics.precision, metrics.recall, metrics.f1

In [12]:
# Three epochs; after each, report validation metrics plus the F1 over the whole training set.
for epoch in range(3):
    started_at = time.time()
    train(model, dataloader_train, optimizer, device)
    train_f1 = evaluate(model, dataloader_train, device)[2]
    valid_precision, valid_recall, valid_f1 = evaluate(model, dataloader_valid, device)

    summary_template = ('| epoch: {:5d} | time: {:5.2f}s '
                        '| valid precision {:8.5f} '
                        '| valid recall {:8.5f} '
                        '| valid f1 {:8.5f} | train f1 {:8.5f} |')
    print('-' * 123)
    print(summary_template.format(epoch, time.time() - started_at,
                                  valid_precision, valid_recall, valid_f1, train_f1))
    print('-' * 123)

| step   100 | loss  1.45807 | precision  0.87500 | recall  0.79333 | f1  0.83217 |
| step   200 | loss  0.61724 | precision  0.96241 | recall  0.93431 | f1  0.94815 |
| step   300 | loss  0.33895 | precision  0.96522 | recall  0.96522 | f1  0.96522 |
| step   400 | loss  0.67916 | precision  0.95139 | recall  0.91946 | f1  0.93515 |
| step   500 | loss  0.20465 | precision  0.97059 | recall  0.96117 | f1  0.96585 |
| step   600 | loss  0.38906 | precision  0.95420 | recall  0.93284 | f1  0.94340 |
| step   700 | loss  0.41365 | precision  0.97248 | recall  0.92174 | f1  0.94643 |
| step   800 | loss  1.07702 | precision  0.90972 | recall  0.86184 | f1  0.88514 |
---------------------------------------------------------------------------------------------------------------------------
| epoch:     0 | time: 458.36s | valid precision  0.93032 | valid recall  0.94425 | valid f1  0.93723 | train f1  0.95313 |
--------------------------------------------------------------------------------

In [13]:
sentence_pred_all = []
PRED_THRESHOLD = 0.0  # a (type, start, end) logit above this value is decoded as an entity


def _despaced_char_index(spaced_text):
    """Map every position of `spaced_text` to its index once all spaces are removed.

    ``index_map[p]`` is the number of non-space characters in ``spaced_text[:p]``.
    A trailing extra entry makes an exclusive end offset of ``len(spaced_text)`` valid.
    Passing a tokenizer offset ``(s, e)`` on the spaced text through this map gives
    the matching ``[start, end)`` slice of the de-spaced text, whatever the token width.
    """
    index_map = []
    kept = 0
    for ch in spaced_text:
        index_map.append(kept)
        if ch != ' ':
            kept += 1
    index_map.append(kept)
    return index_map


with open('datasets/testB.conll_sent.conll', 'r', encoding='utf-8') as f:
    testB_sentences = f.readlines()

# Decode with dropout disabled, regardless of which cell last switched the model's mode.
model.eval()
with torch.no_grad():
    for line in tqdm(testB_sentences):  # one sentence per forward pass
        sentence = line.strip()
        sentence_pred = {"sent": sentence}
        output = tokenizer_fast([sentence], return_offsets_mapping=True, max_length=512, truncation=True, padding=True)
        input_ids = torch.tensor(output['input_ids'], dtype=torch.int64).to(device)
        token_type_ids = torch.tensor(output['token_type_ids'], dtype=torch.int64).to(device)
        attention_mask = torch.tensor(output['attention_mask'], dtype=torch.int64).to(device)

        # Offsets index the spaced input; re-express them on the sentence with spaces removed.
        raw_offsets = output["offset_mapping"][0]
        char_index = _despaced_char_index(sentence)
        offset_mapping = [(char_index[start], char_index[end]) for start, end in raw_offsets]
        # Special tokens ([CLS], [SEP]) carry an empty (0, 0) offset and cannot bound an entity.
        is_text_token = [end > 0 for _, end in raw_offsets]

        sentence = sentence.replace(' ', '')

        logits = model(input_ids, attention_mask, token_type_ids).cpu()
        ent_list = []
        for _, l, start, end in zip(*torch.where(logits > PRED_THRESHOLD)):
            start, end = start.item(), end.item()
            if start > end or not (is_text_token[start] and is_text_token[end]):
                continue  # reversed span, or one anchored on a special token
            ent_type = ent_to_ot_id['id2ent'][str(l.item())]
            char_start, char_end = offset_mapping[start][0], offset_mapping[end][1]
            ent_list.append([char_start, char_end, ent_type, sentence[char_start:char_end]])
        ent_list.sort(key=lambda ent: ent[0])
        sentence_pred['ners'] = ent_list
        sentence_pred_all.append(sentence_pred)

100%|██████████| 1000/1000 [01:14<00:00, 13.35it/s]


In [14]:
# First prediction: 'sent' is the stripped input line (still space-separated); each 'ners' entry is
# [start, end, type, text] with offsets into the sentence after its spaces are removed.
sentence_pred_all[0]

{'sent': '幽 门 : 呈 圆 形 , 开 闭 尚 可 , 粘 膜 皱 襞 光 滑 , 色 泽 淡 红 , 未 见 出 血 及 溃 疡 。',
 'ners': [[0, 2, '器官组织', '幽门'],
  [4, 6, '阴性表现', '圆形'],
  [7, 9, '属性', '开闭'],
  [10, 11, '阴性表现', '可'],
  [12, 16, '器官组织', '粘膜皱襞'],
  [16, 18, '阴性表现', '光滑'],
  [19, 21, '属性', '色泽'],
  [24, 26, '否定描述', '未见'],
  [26, 28, '阳性表现', '出血'],
  [29, 31, '阳性表现', '溃疡']]}

In [15]:
# Persist the predictions. ensure_ascii=False writes the Chinese text verbatim, so the file
# encoding is pinned to UTF-8 instead of the platform default; the output directory is created if missing.
os.makedirs('result_data', exist_ok=True)
with open('result_data/ner.json', 'w', encoding='utf-8') as fp:
    json.dump(sentence_pred_all, fp, ensure_ascii=False, indent=2)