In [1]:
import warnings
warnings.filterwarnings("ignore")

import re
import os
import jieba
import torch
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split

from ark_nlp.model.ner.span_bert import SpanBert
from ark_nlp.model.ner.span_bert import SpanBertConfig
from ark_nlp.model.ner.span_bert import Dataset
from ark_nlp.model.ner.span_bert import Task
from ark_nlp.factory.optimizer import get_w2ner_model_optimizer as get_default_model_optimizer
from ark_nlp.factory.lr_scheduler import get_default_cosine_schedule_with_warmup
from ark_nlp.model.ner.span_bert import Tokenizer
from ark_nlp.factory.utils.seed import set_seed
from ark_nlp.nn.layer.pooler_block import PoolerStartLogits, PoolerEndLogits
from transformers import AutoModel, AutoModelForPreTraining, AutoTokenizer, BertPreTrainedModel

In [2]:
set_seed(42)
tqdm.pandas(desc="inference")

In [3]:
def E_trans_to_C(string):
    E_pun = u',.!?[]()<>"\''
    C_pun = u'，。！？【】（）《》“‘'
    table= {ord(f):ord(t) for f,t in zip(E_pun,C_pun)}
    return string.translate(table)

In [4]:
test = pd.read_csv("data/test.csv", sep="\t")
train = pd.read_csv("data/train.csv", sep="\t")

In [5]:
test["text"] = test["text"].apply(lambda line: E_trans_to_C(re.sub("[\(《：；→，。、\-”]+$", "", line.strip())))
train["text"] = train["text"].apply(lambda line: E_trans_to_C(re.sub("[\(《：→；，。、\-”]+$", "", line.strip())))
train["tag"] = train["tag"].apply(lambda x: [E_trans_to_C(i) for i in eval(str(x))])

In [6]:
train["entities"] = train.progress_apply(lambda row: [["LOC", *i.span()] for tag in row["tag"] for i in re.finditer(tag, row["text"])], axis=1)

inference: 100%|██████████| 6000/6000 [00:00<00:00, 19546.85it/s]


In [7]:
datalist = []

for _, row in train.iterrows():
    entity_labels = []
    for _type, _start_idx, _end_idx in row["entities"]:
        entity_labels.append({
            'start_idx': _start_idx,
            'end_idx': _end_idx,
            'type': _type,
            'entity': row["text"][_start_idx: _end_idx]
    })

    datalist.append({
        'text': row["text"],
        'label': entity_labels
    })

In [8]:
data = pd.DataFrame(datalist)
train_data_df, dev_data_df = train_test_split(data, test_size=0.3)

In [9]:
tta = pd.read_csv("data/tta.csv", sep="\t")
tta["text"] = tta["text"].apply(lambda line: E_trans_to_C(re.sub("[\(《：→；，。、\-”]+$", "", line.strip())))
tta["tag"] = tta["tag"].apply(lambda x: [E_trans_to_C(i) for i in eval(str(x))])
tta["entities"] = tta.progress_apply(lambda row: [["LOC", *i.span()] for tag in row["tag"] for i in re.finditer(tag, row["text"])], axis=1)

tta_datalist = []

for _, row in train.iterrows():
    entity_labels = []
    for _type, _start_idx, _end_idx in row["entities"]:
        entity_labels.append({
            'start_idx': _start_idx,
            'end_idx': _end_idx,
            'type': _type,
            'entity': row["text"][_start_idx: _end_idx]
    })

    tta_datalist.append({
        'text': row["text"],
        'label': entity_labels
    })

tta_data = pd.DataFrame(tta_datalist)
train_data_df = pd.concat([train_data_df, tta_data]).reset_index(drop=True)

inference: 100%|██████████| 2628/2628 [00:00<00:00, 20239.83it/s]


In [10]:
train_data_df = train_data_df.loc[:,['text', 'label']]
train_data_df['label'] = train_data_df['label'].apply(lambda x: str(x))
dev_data_df = dev_data_df.loc[:,['text', 'label']]
dev_data_df['label'] = dev_data_df['label'].apply(lambda x: str(x))

In [11]:
ner_train_dataset = Dataset(train_data_df)
ner_dev_dataset = Dataset(dev_data_df)

In [12]:
tokenizer = Tokenizer(vocab='roberta-base-finetuned-cluener2020-chinese', max_seq_len=52)

In [13]:
ner_train_dataset.convert_to_ids(tokenizer)
ner_dev_dataset.convert_to_ids(tokenizer)

In [14]:
class SpanDependenceBert(BertPreTrainedModel):
    """
    基于BERT指针的命名实体模型(end指针依赖start指针的结果)

    Args:
        config: 模型的配置对象
        bert_trained (:obj:`bool`, optional): 预训练模型的参数是否可训练
    """  # noqa: ignore flake8"

    def __init__(
        self,
        config,
        encoder_trained=True
    ):
        super(SpanDependenceBert, self).__init__(config)

        self.num_labels = config.num_labels

        self.bert = AutoModel.from_pretrained("./outputs/roberta-finetuned-cosine")

        for param in self.bert.parameters():
            param.requires_grad = encoder_trained

        self.start_fc = PoolerStartLogits(config.hidden_size, self.num_labels)
        self.end_fc = PoolerEndLogits(config.hidden_size + 1, self.num_labels)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        **kwargs
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
            output_hidden_states=True
        ).hidden_states

        sequence_output = outputs[-1]

        sequence_output = self.dropout(sequence_output)

        start_logits = self.start_fc(sequence_output)

        label_logits = F.softmax(start_logits, -1)
        label_logits = torch.argmax(label_logits, -1).unsqueeze(2).float()

        end_logits = self.end_fc(sequence_output, label_logits)

        return (start_logits, end_logits)

In [15]:
config = SpanBertConfig.from_pretrained('./outputs/roberta-finetuned-cosine', num_labels=len(ner_train_dataset.cat2id))

In [16]:
torch.cuda.empty_cache()

In [17]:
dl_module = SpanDependenceBert.from_pretrained('./outputs/roberta-finetuned-cosine', config=config)

Some weights of the model checkpoint at ./outputs/roberta-finetuned-cosine were not used when initializing BertModel: ['encoder.weight_hh_l0', 'convLayer.convs.2.bias', 'encoder.weight_ih_l0_reverse', 'encoder.weight_hh_l0_reverse', 'predictor.linear.bias', 'encoder.bias_ih_l0', 'encoder.bias_hh_l0_reverse', 'predictor.mlp_rel.linear.weight', 'cln.weight', 'predictor.mlp2.linear.weight', 'encoder.weight_ih_l0', 'convLayer.base.1.weight', 'encoder.bias_hh_l0', 'reg_embs.weight', 'convLayer.convs.1.weight', 'predictor.mlp1.linear.bias', 'dis_embs.weight', 'cln.bias', 'convLayer.convs.2.weight', 'predictor.mlp_rel.linear.bias', 'convLayer.convs.0.weight', 'cln.bias_dense.weight', 'encoder.bias_ih_l0_reverse', 'convLayer.convs.1.bias', 'predictor.mlp1.linear.weight', 'predictor.linear.weight', 'predictor.mlp2.linear.bias', 'cln.weight_dense.weight', 'convLayer.base.1.bias', 'convLayer.convs.0.bias', 'predictor.biaffine.weight']
- This IS expected if you are initializing BertModel from the 

In [18]:
# 设置运行次数
num_epoches = 30
batch_size = 256
# 注意lr衰减轮次的设定
show_step = len(ner_train_dataset) // batch_size + 2
t_total = len(ner_train_dataset) // batch_size * num_epoches

In [19]:
optimizer = get_default_model_optimizer(dl_module, lr=1e-2, bert_lr=5e-5, weight_decay=0.01)
scheduler = get_default_cosine_schedule_with_warmup(optimizer, t_total, warmup_ratio=0.1)

In [20]:
model = Task(dl_module, optimizer, 'ce', cude_device=2, scheduler=None, grad_clip=10.0, ema_decay=0.995, fgm_attack=True, save_path="outputs/roberta-finetuned-spanbert")

In [21]:
model.fit(
    ner_train_dataset,
    ner_dev_dataset,
    lr=2e-4,
    epochs=num_epoches,
    batch_size=batch_size,
    show_step=show_step
)

100%|██████████| 40/40 [00:53<00:00,  1.33s/it]


epoch:[0],train loss is:0.150960 

eval_info:  {'acc': 0.786046511627907, 'recall': 0.8848167539267016, 'f1': 0.832512315270936}
entity_info:  {'LOC': {'acc': 0.786, 'recall': 0.8848, 'f1': 0.8325}}


100%|██████████| 40/40 [00:54<00:00,  1.35s/it]


epoch:[1],train loss is:0.027976 

eval_info:  {'acc': 0.7467811158798283, 'recall': 0.9109947643979057, 'f1': 0.820754716981132}
entity_info:  {'LOC': {'acc': 0.7468, 'recall': 0.911, 'f1': 0.8208}}


100%|██████████| 40/40 [00:53<00:00,  1.34s/it]


epoch:[2],train loss is:0.017292 

eval_info:  {'acc': 0.7927927927927928, 'recall': 0.9214659685863874, 'f1': 0.8523002421307506}
entity_info:  {'LOC': {'acc': 0.7928, 'recall': 0.9215, 'f1': 0.8523}}


100%|██████████| 40/40 [00:53<00:00,  1.34s/it]


epoch:[3],train loss is:0.015044 

eval_info:  {'acc': 0.7990867579908676, 'recall': 0.9162303664921466, 'f1': 0.8536585365853657}
entity_info:  {'LOC': {'acc': 0.7991, 'recall': 0.9162, 'f1': 0.8537}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[4],train loss is:0.012262 

eval_info:  {'acc': 0.7807017543859649, 'recall': 0.9319371727748691, 'f1': 0.8496420047732697}
entity_info:  {'LOC': {'acc': 0.7807, 'recall': 0.9319, 'f1': 0.8496}}


100%|██████████| 40/40 [00:54<00:00,  1.35s/it]


epoch:[5],train loss is:0.009795 

eval_info:  {'acc': 0.8301886792452831, 'recall': 0.9214659685863874, 'f1': 0.8734491315136477}
entity_info:  {'LOC': {'acc': 0.8302, 'recall': 0.9215, 'f1': 0.8734}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[6],train loss is:0.009199 

eval_info:  {'acc': 0.8130841121495327, 'recall': 0.9109947643979057, 'f1': 0.8592592592592593}
entity_info:  {'LOC': {'acc': 0.8131, 'recall': 0.911, 'f1': 0.8593}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[7],train loss is:0.007964 

eval_info:  {'acc': 0.8110599078341014, 'recall': 0.9214659685863874, 'f1': 0.8627450980392156}
entity_info:  {'LOC': {'acc': 0.8111, 'recall': 0.9215, 'f1': 0.8627}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[8],train loss is:0.007751 

eval_info:  {'acc': 0.7882882882882883, 'recall': 0.9162303664921466, 'f1': 0.847457627118644}
entity_info:  {'LOC': {'acc': 0.7883, 'recall': 0.9162, 'f1': 0.8475}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[9],train loss is:0.007667 

eval_info:  {'acc': 0.8148148148148148, 'recall': 0.9214659685863874, 'f1': 0.8648648648648648}
entity_info:  {'LOC': {'acc': 0.8148, 'recall': 0.9215, 'f1': 0.8649}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[10],train loss is:0.007375 

eval_info:  {'acc': 0.7918552036199095, 'recall': 0.9162303664921466, 'f1': 0.8495145631067962}
entity_info:  {'LOC': {'acc': 0.7919, 'recall': 0.9162, 'f1': 0.8495}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[11],train loss is:0.006950 

eval_info:  {'acc': 0.7873303167420814, 'recall': 0.9109947643979057, 'f1': 0.8446601941747572}
entity_info:  {'LOC': {'acc': 0.7873, 'recall': 0.911, 'f1': 0.8447}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[12],train loss is:0.006558 

eval_info:  {'acc': 0.7797356828193832, 'recall': 0.9267015706806283, 'f1': 0.8468899521531101}
entity_info:  {'LOC': {'acc': 0.7797, 'recall': 0.9267, 'f1': 0.8469}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[13],train loss is:0.006125 

eval_info:  {'acc': 0.7719298245614035, 'recall': 0.9214659685863874, 'f1': 0.8400954653937949}
entity_info:  {'LOC': {'acc': 0.7719, 'recall': 0.9215, 'f1': 0.8401}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[14],train loss is:0.006061 

eval_info:  {'acc': 0.7729257641921398, 'recall': 0.9267015706806283, 'f1': 0.8428571428571427}
entity_info:  {'LOC': {'acc': 0.7729, 'recall': 0.9267, 'f1': 0.8429}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[15],train loss is:0.006453 

eval_info:  {'acc': 0.7629310344827587, 'recall': 0.9267015706806283, 'f1': 0.8368794326241136}
entity_info:  {'LOC': {'acc': 0.7629, 'recall': 0.9267, 'f1': 0.8369}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[16],train loss is:0.007121 

eval_info:  {'acc': 0.7927927927927928, 'recall': 0.9214659685863874, 'f1': 0.8523002421307506}
entity_info:  {'LOC': {'acc': 0.7928, 'recall': 0.9215, 'f1': 0.8523}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[17],train loss is:0.006414 

eval_info:  {'acc': 0.7857142857142857, 'recall': 0.9214659685863874, 'f1': 0.8481927710843373}
entity_info:  {'LOC': {'acc': 0.7857, 'recall': 0.9215, 'f1': 0.8482}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[18],train loss is:0.006651 

eval_info:  {'acc': 0.7847533632286996, 'recall': 0.9162303664921466, 'f1': 0.8454106280193238}
entity_info:  {'LOC': {'acc': 0.7848, 'recall': 0.9162, 'f1': 0.8454}}


100%|██████████| 40/40 [00:54<00:00,  1.35s/it]


epoch:[19],train loss is:0.006041 

eval_info:  {'acc': 0.7847533632286996, 'recall': 0.9162303664921466, 'f1': 0.8454106280193238}
entity_info:  {'LOC': {'acc': 0.7848, 'recall': 0.9162, 'f1': 0.8454}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[20],train loss is:0.005669 

eval_info:  {'acc': 0.7822222222222223, 'recall': 0.9214659685863874, 'f1': 0.8461538461538461}
entity_info:  {'LOC': {'acc': 0.7822, 'recall': 0.9215, 'f1': 0.8462}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[21],train loss is:0.005656 

eval_info:  {'acc': 0.7937219730941704, 'recall': 0.9267015706806283, 'f1': 0.8550724637681159}
entity_info:  {'LOC': {'acc': 0.7937, 'recall': 0.9267, 'f1': 0.8551}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[22],train loss is:0.005401 

eval_info:  {'acc': 0.8119266055045872, 'recall': 0.9267015706806283, 'f1': 0.8655256723716381}
entity_info:  {'LOC': {'acc': 0.8119, 'recall': 0.9267, 'f1': 0.8655}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[23],train loss is:0.005581 

eval_info:  {'acc': 0.8036529680365296, 'recall': 0.9214659685863874, 'f1': 0.8585365853658535}
entity_info:  {'LOC': {'acc': 0.8037, 'recall': 0.9215, 'f1': 0.8585}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[24],train loss is:0.005568 

eval_info:  {'acc': 0.8165137614678899, 'recall': 0.9319371727748691, 'f1': 0.8704156479217604}
entity_info:  {'LOC': {'acc': 0.8165, 'recall': 0.9319, 'f1': 0.8704}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[25],train loss is:0.005165 

eval_info:  {'acc': 0.8202764976958525, 'recall': 0.9319371727748691, 'f1': 0.8725490196078431}
entity_info:  {'LOC': {'acc': 0.8203, 'recall': 0.9319, 'f1': 0.8725}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[26],train loss is:0.005118 

eval_info:  {'acc': 0.7946428571428571, 'recall': 0.9319371727748691, 'f1': 0.8578313253012048}
entity_info:  {'LOC': {'acc': 0.7946, 'recall': 0.9319, 'f1': 0.8578}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[27],train loss is:0.004777 

eval_info:  {'acc': 0.8, 'recall': 0.9214659685863874, 'f1': 0.856447688564477}
entity_info:  {'LOC': {'acc': 0.8, 'recall': 0.9215, 'f1': 0.8564}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[28],train loss is:0.005030 

eval_info:  {'acc': 0.7652173913043478, 'recall': 0.9214659685863874, 'f1': 0.8361045130641329}
entity_info:  {'LOC': {'acc': 0.7652, 'recall': 0.9215, 'f1': 0.8361}}


100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


epoch:[29],train loss is:0.005025 

eval_info:  {'acc': 0.7619047619047619, 'recall': 0.9214659685863874, 'f1': 0.834123222748815}
entity_info:  {'LOC': {'acc': 0.7619, 'recall': 0.9215, 'f1': 0.8341}}


In [28]:
import ark_nlp.model.ner.span_bert as span
import imp
imp.reload(span)

<module 'ark_nlp.model.ner.span_bert' from '/data/lpzhang/ner/ark_nlp/model/ner/span_bert/__init__.py'>

In [22]:
# from ark_nlp.model.ner.span_bert import Predictor

In [29]:
ner_predictor_instance = span.Predictor(model.module, tokenizer, ner_train_dataset.cat2id)

In [32]:
predict_results = []
tta_data = []

for _line in tqdm(test["text"].tolist()):
    label = set()
    for _preditc in ner_predictor_instance.predict_one_sample(_line):
        label.add(_preditc["entity"][:-1])
    
    label = list(label)
    if len(label) > 0:
        tta_data.append([_line, label])

    predict_results.append(label)

100%|██████████| 2657/2657 [00:25<00:00, 104.58it/s]


In [33]:
with open('spanbert_submit.txt', 'w', encoding='utf-8') as f:
    f.write("tag\n")
    for _result in predict_results:
       f.write(f"{str(_result)}\n")

In [26]:
# tta_data = pd.DataFrame(tta_data, columns=["text", "tag"])
# tta_data.to_csv("data/tta.csv", index=False, encoding="utf-8", sep="\t")