In [1]:
import json
import numpy as np
import torch
import random
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from torch.nn.utils.rnn import pad_sequence
import pandas as pd

from model import RawGlobalPointer, ERENet
from utils import sparse_multilabel_categorical_crossentropy, MetricsCalculator_bdci

In [2]:
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Covers NumPy, Python's ``random`` module and PyTorch. When CUDA is
    available the current GPU is seeded as well and cuDNN is restricted to
    deterministic convolution algorithms.

    Args:
        seed: Integer seed handed to every generator.
    """
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)  # generator of the current GPU
    # When True, cuDNN only uses deterministic convolution algorithms.
    torch.backends.cudnn.deterministic = True


# Seed all generators once, before any data loading or model construction below.
RANDOM_SEED = 42
set_seed(RANDOM_SEED)

In [3]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [4]:
# rel2id.json stores a two-element array: [id -> relation name, relation name -> id].
with open('datasets_bdci/rel2id.json', 'r', encoding='utf-8') as f:
    relation_maps = json.load(f)

id_to_rel, rel_to_id = relation_maps
print(id_to_rel)
print(rel_to_id)

{'0': '部件故障', '1': '性能故障', '2': '检测工具', '3': '组成'}
{'部件故障': 0, '性能故障': 1, '检测工具': 2, '组成': 3}


In [5]:
def split_sentence(text, spo_list, sign='。', max_len=510):
    """Split a long text at ``sign`` into chunks of at most ``max_len`` characters.

    The text is cut into sentences at every ``sign`` (the delimiter stays
    attached to the sentence it ends). Consecutive sentences are then packed
    greedily into chunks whose total length does not exceed ``max_len``; a
    single sentence longer than ``max_len`` becomes a chunk of its own and is
    left for the caller to split further.

    Each relation triple is attached to the chunk that contains it and its
    character offsets are rewritten relative to the start of that chunk.

    Compared with the previous implementation:
      * sentences are kept in a list instead of being used as dict keys, so a
        sentence that occurs twice in ``text`` is no longer silently dropped
        (which shifted every later offset);
      * chunks are always numbered from 0, so an over-long first sentence can
        no longer push the chunk index past the end of the label list;
      * a triple is placed using the full extent of head *and* tail. A triple
        whose two entities land in different chunks cannot be expressed in
        chunk-local offsets and is dropped instead of producing negative or
        out-of-range positions.

    Args:
        text: The text to split.
        spo_list: Iterable of dicts shaped like
            ``{'h': {'name', 'pos'}, 't': {'name', 'pos'}, 'relation'}`` where
            ``pos`` is a ``[start, end]`` character span with exclusive end.
        sign: Sentence delimiter.
        max_len: Maximum number of characters per chunk.

    Returns:
        pandas.DataFrame with one row per chunk (index named ``group``) and the
        columns ``join_text``, ``join_text_len``, ``text_cumsum_len`` (running
        total of chunk lengths) and ``spo_list`` (list of
        ``(h_name, h_pos, t_name, t_pos, relation)`` tuples in chunk-local
        offsets).
    """
    pieces = text.split(sign)  # str.split removes the delimiter itself
    sentences = [piece + sign for piece in pieces[:-1]]
    if pieces[-1] != '':
        # The text does not end with ``sign``: the trailing piece has no delimiter
        sentences.append(pieces[-1])

    # Greedy packing: open a new chunk as soon as the next sentence would overflow
    chunks = []
    running_len = 0
    for sentence in sentences:
        if chunks and running_len + len(sentence) <= max_len:
            chunks[-1].append(sentence)
            running_len += len(sentence)
        else:
            chunks.append([sentence])
            running_len = len(sentence)

    join_text = [''.join(chunk) for chunk in chunks]
    join_text_len = [len(chunk_text) for chunk_text in join_text]

    # Absolute [start, end) character range of every chunk inside ``text``
    chunk_bounds = []
    offset = 0
    for length in join_text_len:
        chunk_bounds.append((offset, offset + length))
        offset += length

    spo_list_split = [[] for _ in chunks]
    for spo in spo_list:
        h_pos, t_pos = spo['h']['pos'], spo['t']['pos']
        first = min(h_pos[0], t_pos[0])
        last = max(h_pos[1], t_pos[1])
        for chunk_idx, (start, end) in enumerate(chunk_bounds):
            if last <= end:
                if first >= start:
                    spo_list_split[chunk_idx].append((
                        spo['h']['name'], [h_pos[0] - start, h_pos[1] - start],
                        spo['t']['name'], [t_pos[0] - start, t_pos[1] - start],
                        spo['relation']))
                # else: the triple starts in an earlier chunk -> not representable
                break

    result = pd.DataFrame(
        {'join_text': join_text, 'join_text_len': join_text_len},
        index=pd.RangeIndex(len(join_text), name='group'),
    )
    result['text_cumsum_len'] = result['join_text_len'].cumsum()
    result['spo_list'] = spo_list_split
    return result


def _triple_extent(spo):
    """Return the (start, end) character range covered by both entities of a triple."""
    return min(spo[1][0], spo[3][0]), max(spo[1][1], spo[3][1])


def _choose_cut(extents, max_len):
    """Pick a cut offset in (0, max_len] that avoids splitting a triple where possible."""

    def blockers(pos):
        # Starts of the triples that would be torn apart by cutting at ``pos``
        return [start for start, end in extents if start < pos < end]

    # Preferred: cut right behind the last triple that still fits
    for end in sorted({end for _, end in extents if 0 < end <= max_len}, reverse=True):
        if not blockers(end):
            return end

    # Otherwise back off from max_len to the start of whichever triple is in the way
    cut = max_len
    blocking = blockers(cut)
    while blocking and cut > 0:
        cut = min(blocking)
        blocking = blockers(cut)
    # cut <= 0 means a triple starting at the very beginning spans past max_len:
    # no clean cut exists, so fall back to a hard cut
    return cut if cut > 0 else max_len


def split_one(text, spo_list, sen_lst, max_len=510):
    """Split a long text without usable punctuation, guided by entity positions.

    Pieces of at most ``max_len`` characters are appended to ``sen_lst`` as
    ``[piece_text, piece_spo_list]``; triple offsets are rewritten relative to
    the piece they end up in. A text that already fits is appended unchanged
    together with ``spo_list`` as passed in.

    Compared with the previous implementation:
      * a text is never lost: when no cut right behind a triple is available
        the cut backs off to the nearest position that splits no triple, and
        only if none exists is the text cut hard at ``max_len``;
      * ``max_len`` applies to every piece (the old recursion fell back to the
        default);
      * exactly one cut is made per pass, so no piece is emitted twice;
      * a triple's extent is measured over both entities regardless of which
        one comes first in the text.
    A triple torn apart by a forced hard cut cannot be represented in either
    piece and is dropped.

    Args:
        text: The text to split.
        spo_list: List of ``(h_name, [h_start, h_end], t_name, [t_start, t_end], relation)``
            tuples, e.g. ``('蓄电池', [496, 499], '电压低', [499, 502], '部件故障')``.
        sen_lst: Output list, extended in place.
        max_len: Maximum number of characters per piece.

    Raises:
        ValueError: If ``text`` needs splitting but ``max_len`` is smaller than 1.
    """
    if len(text) <= max_len:
        sen_lst.append([text, spo_list])
        return
    if max_len < 1:
        raise ValueError('max_len must be a positive integer')

    while len(text) > max_len:
        # Same ordering idea as before: triples sorted by where they end
        ordered = sorted(spo_list, key=lambda spo: _triple_extent(spo)[1])
        extents = [_triple_extent(spo) for spo in ordered]
        cut = _choose_cut(extents, max_len)

        kept, carried = [], []
        for spo, (start, end) in zip(ordered, extents):
            if end <= cut:
                kept.append(tuple(spo))
            elif start >= cut:
                # Re-base the offsets onto the remainder of the text
                carried.append((spo[0], [spo[1][0] - cut, spo[1][1] - cut],
                                spo[2], [spo[3][0] - cut, spo[3][1] - cut], spo[-1]))
            # else: torn apart by a forced hard cut -> dropped

        sen_lst.append([text[:cut], kept])
        text, spo_list = text[cut:], carried

    sen_lst.append([text, spo_list])

In [6]:
def load_data(filename, max_len=510):
    """Read a JSON-lines dataset and split over-long texts into model-sized samples.

    Every non-blank line of ``filename`` is a JSON object with a ``text`` field
    and a ``spo_list`` of ``{'h': {'name', 'pos'}, 't': {'name', 'pos'}, 'relation'}``
    dicts. Texts of at most ``max_len`` characters become one sample. Longer
    texts are split at '。' (after treating '；' as '。') with ``split_sentence``;
    texts that offer no usable '。', and chunks that are still too long, are
    split around entity positions with ``split_one``.

    Args:
        filename: Path of the JSON-lines file.
        max_len: Maximum number of characters per sample. The default of 510 is
            presumably 512 minus the two special tokens - confirm against the
            tokenizer settings.

    Returns:
        List of ``{'text': str, 'spo_list': [(h_name, h_pos, t_name, t_pos, relation), ...]}``.
    """

    def as_tuples(raw_spo_list):
        # Flatten the nested JSON triple into the tuple layout used everywhere else
        return [(spo['h']['name'], spo['h']['pos'], spo['t']['name'], spo['t']['pos'], spo['relation'])
                for spo in raw_spo_list]

    D = []
    with open(filename, 'r', encoding='utf-8') as f:
        for raw_line in f:
            if not raw_line.strip():
                continue  # tolerate blank lines instead of failing in json.loads
            line = json.loads(raw_line)

            if len(line["text"]) <= max_len:
                D.append({"text": line["text"], "spo_list": as_tuples(line["spo_list"])})
                continue

            # '；' and '。' have equal priority as split points; the replacement
            # keeps the length, so entity offsets stay valid
            line_text = line['text'].replace('；', '。')
            first_stop = line_text.find('。')

            if first_stop == -1 or first_stop == len(line_text) - 1:
                # No '。' at all, or the first one is the final character:
                # nothing to split on, so cut around the entities instead
                pieces = []
                split_one(line_text, as_tuples(line["spo_list"]), pieces, max_len)
                D.extend({"text": piece_text, "spo_list": piece_spos} for piece_text, piece_spos in pieces)
                continue

            chunks = split_sentence(line_text, line['spo_list'], max_len=max_len)
            for _, chunk in chunks.iterrows():
                if len(chunk["join_text"]) <= max_len:
                    D.append({"text": chunk["join_text"], "spo_list": chunk["spo_list"]})
                else:
                    # A single sentence can still exceed the limit after splitting at '。'
                    pieces = []
                    split_one(chunk["join_text"], chunk["spo_list"], pieces, max_len)
                    D.extend({"text": piece_text, "spo_list": piece_spos} for piece_text, piece_spos in pieces)
    return D


data = load_data('datasets_bdci/train_bdci.json')
# Show the first sample and the number of samples after splitting
print(data[0], len(data), sep='\n')

{'text': '62号汽车故障报告综合情况:故障现象:加速后，丢开油门，发动机熄火。', 'spo_list': [('发动机', [28, 31], '熄火', [31, 33], '部件故障')]}
1559


In [7]:
class CustomDataset(Dataset):
    """Map-style dataset over the samples produced by ``load_data``.

    Each lookup returns a fresh dict holding only the ``text`` and
    ``spo_list`` entries of the stored sample.
    """

    def __init__(self, items):
        # Sequence of dicts, each providing at least 'text' and 'spo_list'
        self._items = items

    def __len__(self):
        return len(self._items)

    def __getitem__(self, index):
        sample = self._items[index]
        return {field: sample[field] for field in ('text', 'spo_list')}

In [8]:
# NOTE(review): despite the variable name, use_fast=False requests the slow tokenizer
# (the printed repr below reports is_fast=False).
tokenizer_fast = AutoTokenizer.from_pretrained('junnyu/uer_large', use_fast=False)
# '[SP]' is the placeholder that padding_space() substitutes for literal spaces
tokenizer_fast.add_tokens(new_tokens=['[SP]'])
print(tokenizer_fast)

pretrained = AutoModel.from_pretrained('junnyu/uer_large')
# Resize the embedding matrix to the tokenizer's new length after adding '[SP]'
pretrained.resize_token_embeddings(len(tokenizer_fast))
print(pretrained.num_parameters())

PreTrainedTokenizer(name_or_path='junnyu/uer_large', vocab_size=21128, model_max_len=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})


Some weights of the model checkpoint at junnyu/uer_large were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


325523456


In [9]:
def padding_space(d):
    """Turn a sentence into a list of characters, replacing each ' ' with '[SP]'.

    Args:
        d: The sentence as a string.

    Returns:
        List with one string per character of ``d``; spaces appear as '[SP]'.
    """
    return ['[SP]' if char == ' ' else char for char in d]


def get_collate_fn(tokenizer, max_len=512):
    """Build the ``collate_fn`` for the DataLoader (a closure over ``tokenizer`` / ``max_len``).

    Compared with the previous implementation the encoder outputs are looked
    up by key instead of relying on the order of ``.values()`` (zeros are
    substituted when the tokenizer returns no ``token_type_ids``), and the
    number of relation channels follows ``len(rel_to_id)`` instead of a
    hard-coded four.

    Args:
        tokenizer: Callable tokenizer; called with pre-split character lists.
        max_len: Maximum token length passed to the tokenizer (with truncation).

    Returns:
        A function mapping a list of ``{'text', 'spo_list'}`` samples to the tuple
        ``(input_ids, attention_mask, token_type_ids, entity_labels,
        head_labels, tail_labels, texts, spo_lists)`` where

        * ``entity_labels`` is ``[batch, 2, n, 2]``: channel 0 holds subject
          ``(start, last)`` spans, channel 1 object spans;
        * ``head_labels`` / ``tail_labels`` are ``[batch, n_relations, n, 2]``:
          ``(subject_start, object_start)`` and ``(subject_last, object_last)``;
        * ``texts`` are the character lists and ``spo_lists`` the raw triples.
    """

    def stack_relations(pairs_per_relation):
        # A relation without any pair still needs one row; (0, 0) equals the
        # zero padding pad_sequence adds (presumably ignored by the loss -
        # confirm in utils.sparse_multilabel_categorical_crossentropy)
        rows = [torch.tensor(pairs if pairs else [(0, 0)]) for pairs in pairs_per_relation]
        # [n_relations, n, 2] -> [n, n_relations, 2] so the batch can be padded along n
        return torch.transpose(pad_sequence(rows, batch_first=True), 0, 1)

    def collate_fn(data):
        texts = [padding_space(sample['text']) for sample in data]
        spo_lists = [sample['spo_list'] for sample in data]

        encoder_text = tokenizer(texts, padding=True, max_length=max_len, truncation=True,
                                 is_split_into_words=True, return_tensors='pt')
        input_ids = encoder_text['input_ids']
        attention_mask = encoder_text['attention_mask']
        if 'token_type_ids' in encoder_text:
            token_type_ids = encoder_text['token_type_ids']
        else:
            token_type_ids = torch.zeros_like(input_ids)

        n_relations = len(rel_to_id)
        entity_labels, head_labels, tail_labels = [], [], []
        for spo_list in spo_lists:
            subjects, objects = [], []
            heads = [[] for _ in range(n_relations)]
            tails = [[] for _ in range(n_relations)]
            # NOTE(review): positions are raw character offsets. If the tokenizer
            # prepends [CLS], token k+1 corresponds to character k; extract_spoes
            # decodes with the same unshifted convention - confirm this is intended.
            for _, s_pos, _, o_pos, relation in spo_list:
                rel_id = rel_to_id[relation]
                # Spans are stored with an inclusive last index, hence the -1
                subjects.append((s_pos[0], s_pos[1] - 1))
                objects.append((o_pos[0], o_pos[1] - 1))
                heads[rel_id].append((s_pos[0], o_pos[0]))
                tails[rel_id].append((s_pos[1] - 1, o_pos[1] - 1))
            if not spo_list:
                # Sample without triples: a single padding row per channel
                subjects.append((0, 0))
                objects.append((0, 0))

            # [2, n, 2] -> [n, 2, 2]
            entity_labels.append(torch.transpose(torch.tensor([subjects, objects]), 0, 1))
            head_labels.append(stack_relations(heads))
            tail_labels.append(stack_relations(tails))

        # Pad along n across the batch, then move the channel axis back in front of n
        entity_labels = torch.transpose(pad_sequence(entity_labels, batch_first=True), 1, 2)
        head_labels = torch.transpose(pad_sequence(head_labels, batch_first=True), 1, 2)
        tail_labels = torch.transpose(pad_sequence(tail_labels, batch_first=True), 1, 2)
        return input_ids, attention_mask, token_type_ids, entity_labels, head_labels, tail_labels, texts, spo_lists

    return collate_fn


train_loader = DataLoader(
    CustomDataset(items=data),
    batch_size=4,
    shuffle=True,
    collate_fn=get_collate_fn(tokenizer_fast),
)
# Smoke test: shapes of input_ids, entity_labels and head_labels for the first batch
for i in train_loader:
    for field_idx in (0, 3, 4):
        print(i[field_idx].shape)
    break

torch.Size([4, 361])
torch.Size([4, 2, 9, 2])
torch.Size([4, 4, 9, 2])


In [10]:
hidden_size = pretrained.config.hidden_size

# Entity detector with two channels (subject / object); entity types are not predicted
mention_detect = RawGlobalPointer(hidden_size, 2, 64).to(device)

# Two identically configured relation heads, one channel per relation, built in the
# order head -> tail; RoPE and the lower-triangle mask are both switched off
s_o_head, s_o_tail = (
    RawGlobalPointer(hidden_size, len(id_to_rel), 64, RoPE=False, tril_mask=False).to(device)
    for _ in range(2)
)

net = ERENet(pretrained, mention_detect, s_o_head, s_o_tail).to(device)

optimizer = torch.optim.AdamW(net.parameters(), lr=2e-5)

In [11]:
def extract_spoes(logits1, logits2, logits3, texts, id2predicate):
    """Decode relation triples from the three score tensors of the model.

    An entity span is accepted when its score in ``logits1`` is positive
    (channel 0 = subject, any other channel = object). For every
    subject/object pair a relation is emitted when both the
    (subject_start, object_start) score in ``logits2`` and the
    (subject_last, object_last) score in ``logits3`` are positive.

    Compared with the previous implementation the entity scores are copied
    before masking (``Tensor.numpy()`` shares memory with CPU tensors, so the
    caller's logits used to be overwritten with ``-inf``), relations are
    emitted in ascending id order, and the span strings are built once per
    pair instead of once per relation.

    Args:
        logits1: Tensor ``[batch, 2, seq_len, seq_len]`` of entity span scores.
        logits2: Tensor ``[batch, n_relations, seq_len, seq_len]`` of head-pair scores.
        logits3: Tensor ``[batch, n_relations, seq_len, seq_len]`` of tail-pair scores.
        texts: Per-sample character sequences; spans are sliced from them using
            the predicted indices as-is.
        id2predicate: Mapping from the relation id as a string to its name.

    Returns:
        One list per sample of ``(h_name, (h_start, h_end), t_name, (t_start, t_end), relation)``
        tuples with exclusive end offsets.
    """
    entity_scores = logits1.detach().cpu().numpy().copy()  # copy: masked in place below
    head_scores = logits2.detach().cpu().numpy()
    tail_scores = logits3.detach().cpu().numpy()
    batch_size = entity_scores.shape[0]

    # Spans touching the first or last sequence position (special tokens
    # '[CLS]' / '[SEP]') can never be entities
    entity_scores[:, :, [0, -1], :] = -np.inf
    entity_scores[:, :, :, [0, -1]] = -np.inf

    subjects = [[] for _ in range(batch_size)]
    objects = [[] for _ in range(batch_size)]
    for b, channel, start, last in zip(*np.where(entity_scores > 0.0)):  # threshold 0.0
        # Entity types are not predicted: the channel only tells subject from object
        bucket = subjects if channel == 0 else objects
        bucket[b].append((int(start), int(last)))

    spoes = [[] for _ in range(batch_size)]
    for b in range(batch_size):
        chars = texts[b]
        for sh, st in subjects[b]:
            subject_text = ''.join(chars[sh: st + 1])
            for oh, ot in objects[b]:
                by_head = np.where(head_scores[b, :, sh, oh] > 0.0)[0]  # threshold 0.0
                by_tail = np.where(tail_scores[b, :, st, ot] > 0.0)[0]  # threshold 0.0
                relation_ids = np.intersect1d(by_head, by_tail)  # sorted, unique
                if relation_ids.size == 0:
                    continue
                object_text = ''.join(chars[oh: ot + 1])
                for rel_id in relation_ids:
                    spoes[b].append((subject_text, (sh, st + 1), object_text, (oh, ot + 1),
                                     id2predicate[str(int(rel_id))]))
    return spoes

In [12]:
# Model training
def train(model, dataloader, optimizer, id2predicate, device):
    """Run one training epoch and log batch metrics every 100 steps.

    The three losses (entity spans, relation head pairs, relation tail pairs)
    are averaged so entity and relation detection share one backbone update.
    Every 100th step the triples decoded from the current batch are scored
    against its gold triples and printed together with the loss.

    The previous version decoded the batch twice at each logging step and
    threw the first result away (using the global ``id_to_rel`` rather than the
    ``id2predicate`` argument); that redundant call is removed.

    Args:
        model: Network returning ``(logits1, logits2, logits3)`` for
            ``(input_ids, attention_mask, token_type_ids)``.
        dataloader: Yields the 8-tuples produced by the collate function.
        optimizer: Optimizer stepping the parameters of ``model``.
        id2predicate: Mapping from relation id (as string) to relation name.
        device: Device the tensors are moved to.
    """
    model.train()

    for idx, batch in enumerate(dataloader, start=1):
        (input_ids, attention_mask, token_type_ids,
         entity_labels, head_labels, tail_labels, texts, spo_lists) = batch

        # Move the tensors to the target device
        # input_ids.shape=[batch_size, seq_len]
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        entity_labels = entity_labels.to(device)
        head_labels = head_labels.to(device)
        tail_labels = tail_labels.to(device)

        # logits1.shape=[batch_size, 2, seq_len, seq_len]
        # logits2.shape=[batch_size, len(schema), seq_len, seq_len]
        # logits3.shape=[batch_size, len(schema), seq_len, seq_len]
        logits1, logits2, logits3 = model(input_ids, attention_mask, token_type_ids)

        loss1 = sparse_multilabel_categorical_crossentropy(y_true=entity_labels, y_pred=logits1)
        loss2 = sparse_multilabel_categorical_crossentropy(y_true=head_labels, y_pred=logits2)
        loss3 = sparse_multilabel_categorical_crossentropy(y_true=tail_labels, y_pred=logits3)
        # Averaging lets entities and relations share information through the encoder
        loss = (loss1 + loss2 + loss3) / 3
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if idx % 100 == 0:
            y_pred = extract_spoes(logits1, logits2, logits3, texts, id2predicate)
            mc = MetricsCalculator_bdci()  # precision, recall and F1 for this batch only
            mc.calc_confusion_matrix(y_pred, spo_lists)
            print('| step {:5d} | loss {:9.5f} | precision {:8.5f} | recall {:8.5f} | f1 {:8.5f} |'.format(
                idx, loss.item(), mc.precision, mc.recall, mc.f1))

In [13]:
# Model evaluation
def evaluate(model, dataloader, id2predicate, device):
    """Score the model on ``dataloader`` and return ``(precision, recall, f1)``.

    The model is switched to eval mode and run without gradient tracking. The
    triples decoded from every batch are accumulated in a single
    ``MetricsCalculator_bdci`` whose final precision, recall and F1 are returned.

    Args:
        model: Network returning ``(logits1, logits2, logits3)``.
        dataloader: Yields the 8-tuples produced by the collate function.
        id2predicate: Mapping from relation id (as string) to relation name.
        device: Device the input tensors are moved to.
    """
    model.eval()

    metrics = MetricsCalculator_bdci()
    with torch.no_grad():
        for batch in dataloader:
            # The label tensors in the middle of the tuple are not needed here
            input_ids, attention_mask, token_type_ids, _, _, _, texts, spo_lists = batch

            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            token_type_ids = token_type_ids.to(device)

            scores = model(input_ids, attention_mask, token_type_ids)
            logits1, logits2, logits3 = scores
            predicted = extract_spoes(logits1, logits2, logits3, texts, id2predicate)

            metrics.calc_confusion_matrix(predicted, spo_lists)
    return metrics.precision, metrics.recall, metrics.f1

In [14]:
# Train for 30 epochs, printing a separator line with the epoch number before each one
for epoch in range(30):
    print('{0}{1}{0}'.format('-' * 50, epoch))
    train(net, train_loader, optimizer, id_to_rel, device)


--------------------------------------------------0--------------------------------------------------
| step   100 | loss  22.28444 | precision  0.00000 | recall  0.00000 | f1  0.00000 |
| step   200 | loss  22.22444 | precision  0.00000 | recall  0.00000 | f1  0.00000 |
| step   300 | loss  23.07591 | precision  0.33333 | recall  0.06667 | f1  0.11111 |
--------------------------------------------------1--------------------------------------------------
| step   100 | loss  27.81215 | precision  0.50000 | recall  0.06250 | f1  0.11111 |
| step   200 | loss  27.93113 | precision  0.00000 | recall  0.00000 | f1  0.00000 |
| step   300 | loss  10.15987 | precision  0.21429 | recall  0.13043 | f1  0.16216 |
--------------------------------------------------2--------------------------------------------------
| step   100 | loss  19.36986 | precision  1.00000 | recall  0.13333 | f1  0.23529 |
| step   200 | loss   1.13598 | precision  0.50000 | recall  0.21429 | f1  0.30000 |
| step   300 |