In [1]:
import torch
import numpy as np
import random
import json
from sklearn.model_selection import train_test_split
from torch.utils.data import (DataLoader, Dataset)
from transformers import (BertTokenizerFast, BertModel)
import time
import copy

from model import CustomRelation
from utils import multilabel_categorical_crossentropy, MetricsCalculator

In [2]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch. When CUDA is available, every visible GPU is seeded and cuDNN is
    configured for deterministic algorithm selection.

    Args:
        seed (int): Seed value shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU generator
    if torch.cuda.is_available():
        # Seed all visible GPUs (not only the current one) so that multi-GPU
        # runs are reproducible as well.
        torch.cuda.manual_seed_all(seed)
        # If True, cuDNN only uses deterministic convolution algorithms.
        torch.backends.cudnn.deterministic = True
        # Disable the cuDNN auto-tuner: with benchmarking enabled, cuDNN may
        # select a different algorithm from one run to the next.
        torch.backends.cudnn.benchmark = False


seed = 2022
set_seed(seed)

In [3]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [4]:
# Load the annotated relation data (the JSON file holds a list of samples).
# The list is deliberately NOT named `train`: a later cell defines a function
# called `train`, and sharing the name means that re-running this cell would
# replace the function with a list and break the training loop.
with open('datasets/train_rel.json', 'r', encoding='utf-8') as f:
    all_samples = json.load(f)

# Hold out 10% of the samples for validation; the split is seeded so it is
# the same on every run.
train_data, valid_data = train_test_split(all_samples, test_size=0.1, random_state=seed)

print(len(train_data))
print(len(valid_data))

3557
396


In [5]:
class CustomDataset(Dataset):
    """Map-style dataset over a sequence of preprocessed sentence records.

    Each record is expected to provide the keys ``'sentence_text'``,
    ``'position_ids'``, ``'relation_label'`` and ``'relation_idx'``.
    """

    def __init__(self, sentences):
        # Keep a reference to the records; nothing is copied or converted here.
        self._sentences = sentences

    def __len__(self):
        """Return the number of sentence records."""
        return len(self._sentences)

    def __getitem__(self, index):
        """Return the record at ``index`` as a dict with the text under ``'text'``.

        The remaining three fields are passed through under their original keys.
        """
        record = self._sentences[index]
        sample = {'text': record['sentence_text']}
        for key in ('position_ids', 'relation_label', 'relation_idx'):
            sample[key] = record[key]
        return sample


# Wrap the training and validation splits in dataset objects.
train_dataset, valid_dataset = (CustomDataset(sentences=split) for split in (train_data, valid_data))

In [6]:
# Load the fast tokenizer from the 'save_tokenizer/' path and show its configuration.
# NOTE(review): presumably this tokenizer was saved by an earlier preprocessing step — confirm.
tokenizer_fast = BertTokenizerFast.from_pretrained('save_tokenizer/')
print(tokenizer_fast)

# Load the pretrained encoder, resize its token embeddings to the number of tokens
# reported by the tokenizer, and print the model's parameter count.
pretrained = BertModel.from_pretrained('hfl/chinese-roberta-wwm-ext-large')
pretrained.resize_token_embeddings(len(tokenizer_fast))
print(pretrained.num_parameters())

PreTrainedTokenizerFast(name_or_path='save_tokenizer/', vocab_size=21128, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})


Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext-large were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


325559296


In [7]:
def get_collate_fn(tokenizer, max_len=512):
    """Build the ``collate_fn`` used by the DataLoaders.

    A closure is used so that the tokenizer and the maximum sequence length
    can be supplied to a function that the DataLoader calls with the batch only.

    Args:
        tokenizer: Callable tokenizer; it is invoked on the list of texts and
            must return an object exposing ``input_ids``, ``attention_mask``
            and ``token_type_ids`` tensors.
        max_len (int): Maximum tokenized length; longer texts are truncated.

    Returns:
        callable: ``collate_fn(batch)`` producing the tuple
        ``(input_ids, attention_mask, token_type_ids, position_ids, labels,
        labels_mask, relations_idx)``.
    """

    def collate_fn(batch):
        """Tokenize and pad one batch of dataset samples.

        ``position_ids`` is cut or zero-padded to the tokenized width, the
        per-sample relation fields are padded to the longest relation list in
        the batch, and ``relations_idx`` is returned as a NumPy array of shape
        ``[batch_size, relation_max_len, 4]``.
        """
        texts = [sample['text'] for sample in batch]
        # add_special_tokens=False: '[CLS]', '[SEP]' and other special tokens
        # were already inserted during data preprocessing.
        outputs = tokenizer(texts, truncation=True, max_length=max_len, padding=True,
                            add_special_tokens=False, return_tensors='pt')
        input_ids, attention_mask, token_type_ids = outputs.input_ids, outputs.attention_mask, outputs.token_type_ids
        seq_len = input_ids.shape[1]

        # Custom position ids. They are cut to seq_len as well as zero-padded:
        # when the tokenizer truncates a text to max_len, the untruncated ids
        # would otherwise be longer than input_ids and make the rows ragged.
        position_ids = []
        for sample in batch:
            pos_ids = list(sample['position_ids'][:seq_len])
            position_ids.append(pos_ids + [0] * (seq_len - len(pos_ids)))
        position_ids = torch.tensor(position_ids, dtype=torch.long)

        # Longest relation list in this batch (0 if no sample has relations).
        relation_max_len = max((len(sample['relation_label']) for sample in batch), default=0)

        labels, labels_mask, relations_idx = [], [], []
        for sample in batch:
            # 1 means the relation exists, 0 means it does not.
            r_label = [int(label == '属性') for label in sample['relation_label']]
            n_pad = relation_max_len - len(r_label)
            labels.append(r_label + [0] * n_pad)
            # The mask marks real entries with 1 and padding with 0.
            labels_mask.append([1] * len(r_label) + [0] * n_pad)
            # Build a new list so the dataset's own relation_idx is not mutated.
            relations_idx.append(list(sample['relation_idx']) + [[0, 0, 0, 0] for _ in range(n_pad)])
        # labels.shape=[batch_size, relation_max_len]
        labels = torch.tensor(labels, dtype=torch.long)
        # labels_mask.shape=[batch_size, relation_max_len]
        labels_mask = torch.tensor(labels_mask, dtype=torch.long)
        # relations_idx.shape=[batch_size, relation_max_len, 4]; the explicit
        # dtype and reshape keep that layout even when relation_max_len is 0.
        relations_idx = np.array(relations_idx, dtype=np.int64).reshape(len(batch), relation_max_len, 4)

        return input_ids, attention_mask, token_type_ids, position_ids, labels, labels_mask, relations_idx

    return collate_fn


# Training batches are drawn with shuffle=True; validation batches keep dataset order.
dataloader_train = DataLoader(train_dataset, batch_size=4, shuffle=True,
                              collate_fn=get_collate_fn(tokenizer_fast))
dataloader_valid = DataLoader(valid_dataset, batch_size=4, shuffle=False,
                              collate_fn=get_collate_fn(tokenizer_fast))

# Sanity check: print the shape of every element of the first training batch, then stop.
# This task has a single relation type. With n_type relation types the shapes would be
# label.shape=[batch_size, n_type, relation_max_len] and
# relations_idx.shape=[batch_size, n_type, relation_max_len, 4].
for batch in dataloader_train:
    field_titles = ('input_ids.shape:',
                    'attention_mask.shape:',
                    'token_type_ids.shape:',
                    'position_ids.shape.shape:',
                    'label.shape:',
                    'labels_mask.shape:',
                    'relations_idx.shape:')
    for field_title, field in zip(field_titles, batch):
        print(field_title, field.shape)
    break

input_ids.shape: torch.Size([4, 247])
attention_mask.shape: torch.Size([4, 247])
token_type_ids.shape: torch.Size([4, 247])
position_ids.shape.shape: torch.Size([4, 247])
label.shape: torch.Size([4, 1482])
labels_mask.shape: torch.Size([4, 1482])
relations_idx.shape: (4, 1482, 4)


In [8]:
# Build the relation model on a deep copy of the pretrained encoder, so the `pretrained`
# object itself is not shared with the model, and move the model to the selected device.
# NOTE(review): the positional arguments 64 and True are defined by model.CustomRelation,
# which is not visible here — confirm their meaning against that module.
model = CustomRelation(copy.deepcopy(pretrained), 64, True).to(device)
# AdamW over all model parameters with a learning rate of 2e-5.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

In [9]:
# Model training
def train(model, dataloader, optimizer, device):
    """Run one pass over ``dataloader``, taking an optimizer step after every batch.

    Every 100th step (excluding step 0) the loss and the precision / recall / F1
    of the current batch are printed.

    Args:
        model: Relation model; put into training mode here.
        dataloader: Iterable of 7-tuples ``(input_ids, attention_mask,
            token_type_ids, position_ids, labels, labels_mask, relations_idx)``.
        optimizer: Optimizer updating the parameters of ``model``.
        device: Device the tensor inputs and labels are moved to.
    """
    model.train()

    for idx, batch in enumerate(dataloader):
        input_ids, attention_mask, token_type_ids, position_ids, labels, labels_mask, relations_idx = batch
        # Move the tensors to the target device; position_ids and relations_idx
        # are handed to the model exactly as the loader produced them.
        # labels.shape=[batch_size, relation_max_len]
        input_ids, attention_mask, token_type_ids, labels_mask, labels = (
            tensor.to(device)
            for tensor in (input_ids, attention_mask, token_type_ids, labels_mask, labels)
        )

        # logits.shape=[batch_size, relation_max_len]
        logits = model(input_ids, attention_mask, token_type_ids, position_ids, relations_idx, labels_mask)
        loss = multilabel_categorical_crossentropy(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if idx > 0 and idx % 100 == 0:
            # A fresh calculator per report: precision, recall and F1 of the
            # entity relations cover this batch only.
            mc = MetricsCalculator()
            mc.calc_confusion_matrix_rel(logits, labels)
            print(f'| step {idx:5d} | loss {loss.item():8.5f} | precision {mc.precision:8.5f} '
                  f'| recall {mc.recall:8.5f} | f1 {mc.f1:8.5f} |')

In [10]:
# Model validation
def evaluate(model, dataloader, device):
    """Score ``model`` on every batch of ``dataloader`` without tracking gradients.

    A single ``MetricsCalculator`` is fed the logits and labels of every batch,
    so the returned metrics cover the whole dataloader.

    Args:
        model: Relation model; put into evaluation mode here.
        dataloader: Iterable of 7-tuples ``(input_ids, attention_mask,
            token_type_ids, position_ids, labels, labels_mask, relations_idx)``.
        device: Device the tensor inputs are moved to.

    Returns:
        tuple: ``(precision, recall, f1)`` as reported by the calculator.
    """
    model.eval()

    mc = MetricsCalculator()
    with torch.no_grad():
        for batch in dataloader:
            input_ids, attention_mask, token_type_ids, position_ids, labels, labels_mask, relations_idx = batch
            # Move the model inputs to the target device; labels are passed to
            # the metrics calculator as they come from the loader.
            input_ids, attention_mask, token_type_ids, labels_mask = (
                tensor.to(device)
                for tensor in (input_ids, attention_mask, token_type_ids, labels_mask)
            )
            # logits.shape=[batch_size, relation_max_len]
            logits = model(input_ids, attention_mask, token_type_ids, position_ids, relations_idx, labels_mask)

            mc.calc_confusion_matrix_rel(logits, labels)
    return mc.precision, mc.recall, mc.f1

In [11]:
# Train for three epochs. After each epoch, report precision / recall / F1 on the
# validation split together with F1 on the training split and the elapsed time.
for epoch in range(3):
    epoch_start_time = time.time()
    train(model, dataloader_train, optimizer, device)
    _, _, train_f1 = evaluate(model, dataloader_train, device)
    valid_precision, valid_recall, valid_f1 = evaluate(model, dataloader_valid, device)
    print('-' * 123)
    print(f'| epoch: {epoch:5d} | time: {time.time() - epoch_start_time:5.2f}s '
          f'| valid precision {valid_precision:8.5f} '
          f'| valid recall {valid_recall:8.5f} '
          f'| valid f1 {valid_f1:8.5f} | train f1 {train_f1:8.5f} |')
    print('-' * 123)

| step   100 | loss  3.81342 | precision  0.93407 | recall  0.94444 | f1  0.93923 |
| step   200 | loss  2.71993 | precision  0.98750 | recall  0.96341 | f1  0.97531 |
| step   300 | loss  3.84497 | precision  0.96040 | recall  0.98980 | f1  0.97487 |
| step   400 | loss  3.19924 | precision  0.95312 | recall  0.84722 | f1  0.89706 |
| step   500 | loss  2.29865 | precision  0.94231 | recall  0.98990 | f1  0.96552 |
| step   600 | loss  3.48252 | precision  0.96296 | recall  0.94891 | f1  0.95588 |
| step   700 | loss  2.94238 | precision  0.93269 | recall  0.98980 | f1  0.96040 |
| step   800 | loss  2.43509 | precision  0.96341 | recall  0.96341 | f1  0.96341 |
---------------------------------------------------------------------------------------------------------------------------
| epoch:     0 | time: 245.81s | valid precision  0.96787 | valid recall  0.96549 | valid f1  0.96668 | train f1  0.97087 |
--------------------------------------------------------------------------------

In [13]:
all_rel_predict = []

# Read all candidate sentences first so the file is closed before inference starts.
with open('my_data/ner_rel_predict.json', 'r', encoding='utf-8') as f:
    ners_sentences = json.load(f)

# Inference must not depend on a previous cell having left the model in eval mode,
# and it needs no autograd graph: set eval mode and disable gradient tracking here.
model.eval()
with torch.no_grad():
    for ners_sentence in ners_sentences:  # predict one sentence at a time
        if not ners_sentence['relation']:  # 'relation' may be an empty list
            all_rel_predict.append([])  # nothing to classify: the prediction is []
            continue

        # add_special_tokens=False: '[CLS]', '[SEP]' and other special tokens
        # were already inserted during data preprocessing.
        outputs = tokenizer_fast(ners_sentence['sentence_text'], add_special_tokens=False, return_tensors='pt')
        input_ids = outputs.input_ids.to(device)
        attention_mask = outputs.attention_mask.to(device)
        token_type_ids = outputs.token_type_ids.to(device)
        # position_ids.shape=[1, seq_len]
        position_ids = torch.tensor(ners_sentence['position_ids'])[None, :]
        # relations_idx gets a leading batch axis of size 1 in front of the
        # per-candidate index entries.
        relations_idx = np.array(ners_sentence['relation_idx'])[None, :]
        relation = np.array(ners_sentence['relation'])
        # Flatten to 1-D instead of a bare squeeze, so a sentence with a single
        # candidate still yields a 1-D tensor rather than a 0-d one.
        logits = model(input_ids, attention_mask, token_type_ids, position_ids, relations_idx).reshape(-1).cpu()
        predict = torch.where(logits > 0)[0].tolist()  # decision threshold is 0.0
        last_result = relation[predict]
        last_result = [rel + '\t' + '属性' for rel in last_result]
        all_rel_predict.append(last_result)

In [14]:
# Display the predictions stored for the first entry of the input file.
all_rel_predict[0]



["[6, 14, '器官组织', '右肺下叶内基底段']\t[14, 17, '异常现象', '小结节']\t属性",
 "[37, 45, '器官组织', '左肺下叶前基底段']\t[47, 50, '异常现象', '钙化灶']\t属性",
 "[45, 47, '修饰描述', '点状']\t[47, 50, '异常现象', '钙化灶']\t属性",
 "[64, 66, '器官组织', '胆囊']\t[70, 72, '异常现象', '结石']\t属性"]