In [8]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [9]:
#!pip install evaluate
#!pip install seqeval

In [10]:
"""
* 参考案例，使用指定的数据集，编写代码实现ner模型训练和推理。
  https://huggingface.co/datasets/doushabao4766/msra_ner_k_V3
* 完成预测结果的实体抽取。
  输入：“双方确定了今后发展中美关系的指导方针。”
  输出：[{"entity":"ORG","content":"中"},{"entity":"ORG","content":"美"}]
"""
from transformers import AutoModelForTokenClassification,AutoTokenizer,TrainingArguments,Trainer
import torch
import evaluate  # pip install evaluate
import seqeval   # pip install seqeval
from datasets import load_dataset
from transformers import DataCollatorForTokenClassification
import numpy as np

In [11]:
model = AutoModelForTokenClassification.from_pretrained('google-bert/bert-base-chinese', num_labels=7)
tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-chinese')

config.json:   0%|          | 0.00/624 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/412M [00:00<?, ?B/s]

Some weights of BertForTokenClassification were not initialized from the model checkpoint at google-bert/bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/110k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/269k [00:00<?, ?B/s]

In [12]:
## 实体映射数据集词典准备
entites = ['O'] + list({'PER', 'ORG', 'LOC'})
tags = ['O']
for entity in entites[1:]:
    tags.append('B-'+entity)
    tags.append('I-'+entity)
print(tags)

['O', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC']


In [13]:
# 定义回调函数处理数据
def process_data(items):
    input_data = {} 
    max_length = 512  # 模型支持的最大长度
    # 生成iput_ids, token_type_ids, attention_mask, labels
    input_ids = []
    for tokens in items['tokens']:
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        # 截断 token_ids 到最大长度
        token_ids = token_ids[:max_length]
        input_ids.append(token_ids)
    input_data['input_ids'] = input_ids
    input_data['token_type_ids'] = [[0]*len(token_ids) for token_ids in input_ids]
    input_data['attention_mask'] = [[1]*len(token_ids) for token_ids in input_ids]
    # 对标签进行同样的截断操作
    input_data['labels'] = [labels[:max_length] for labels in items['ner_tags']]
    return input_data

In [14]:
# 加载hf中dataset
ds = load_dataset('doushabao4766/msra_ner_k_V3')
ds1 = ds.map(process_data, batched=True)  # batched 每次传入自定义方法样本数量多个，加快处理速度
ds1.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])

README.md:   0%|          | 0.00/697 [00:00<?, ?B/s]

(…)-00000-of-00001-42717a92413393f9.parquet:   0%|          | 0.00/13.9M [00:00<?, ?B/s]

(…)-00000-of-00001-8899cab5fdab45bc.parquet:   0%|          | 0.00/946k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/45001 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3443 [00:00<?, ? examples/s]

Map:   0%|          | 0/45001 [00:00<?, ? examples/s]

Map:   0%|          | 0/3443 [00:00<?, ? examples/s]

In [None]:
# 模型训练
num_labels = len(tags)
id2label = {i: label for i, label in enumerate(tags)}
label2id = {v: k for k, v in id2label.items()}
model = AutoModelForTokenClassification.from_pretrained('google-bert/bert-base-chinese',
                                                        num_labels=num_labels,
                                                        id2label=id2label,
                                                        label2id=label2id)

args = TrainingArguments(
    output_dir="ner_train",  # 模型训练工作目录（tensorboard，临时模型存盘文件，日志）
    num_train_epochs = 3,    # 训练 epoch
    save_safetensors=False,  # 设置False保存文件可以通过torch.load加载
    per_device_train_batch_size=32,  # 训练批次
    per_device_eval_batch_size=32,
    report_to='tensorboard',  # 训练输出记录，不写的话会默认到XX网站里,所以要写上
    eval_strategy="epoch",
)

# metric 方法
def compute_metric(result):
    # result 是一个tuple (predicts, labels)
    
    # 获取评估对象
    seqeval = evaluate.load('seqeval')
    predicts,labels = result
    predicts = np.argmax(predicts, axis=2)
    
    # 准备评估数据
    predicts = [[tags[p] for p,l in zip(ps,ls) if l != -100]
                 for ps,ls in zip(predicts,labels)]
    labels = [[tags[l] for p,l in zip(ps,ls) if l != -100]
                 for ps,ls in zip(predicts,labels)]
    results = seqeval.compute(predictions=predicts, references=labels)

    return results

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding=True)

trainer = Trainer(
    model,
    args,
    train_dataset=ds1['train'],
    eval_dataset=ds1['test'],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metric,
)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at google-bert/bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


In [16]:
trainer.train()

Epoch,Training Loss,Validation Loss,Loc,Org,Per,Overall Precision,Overall Recall,Overall F1,Overall Accuracy
1,0.0317,0.028218,"{'precision': 0.9512020093290277, 'recall': 0.929523141654979, 'f1': 0.9402376307856003, 'number': 2852}","{'precision': 0.9044368600682594, 'recall': 0.8815701929474384, 'f1': 0.8928571428571429, 'number': 1503}","{'precision': 0.8492753623188406, 'recall': 0.8878787878787879, 'f1': 0.8681481481481482, 'number': 1320}",0.914062,0.907137,0.910586,0.991807
2,0.0139,0.028152,"{'precision': 0.9521295318549806, 'recall': 0.9484572230014026, 'f1': 0.950289829615317, 'number': 2852}","{'precision': 0.9390797148412184, 'recall': 0.9640718562874252, 'f1': 0.9514116874589625, 'number': 1503}","{'precision': 0.8715728715728716, 'recall': 0.9151515151515152, 'f1': 0.8928307464892832, 'number': 1320}",0.929289,0.944846,0.937003,0.99277
3,0.0057,0.030974,"{'precision': 0.9578761061946902, 'recall': 0.9488078541374474, 'f1': 0.9533204157125242, 'number': 2852}","{'precision': 0.9479921000658328, 'recall': 0.9580838323353293, 'f1': 0.9530112508272667, 'number': 1503}","{'precision': 0.8809523809523809, 'recall': 0.925, 'f1': 0.9024390243902439, 'number': 1320}",0.936649,0.945727,0.941166,0.993528


Downloading builder script:   0%|          | 0.00/6.34k [00:00<?, ?B/s]

Trainer is attempting to log a value of "{'precision': 0.9512020093290277, 'recall': 0.929523141654979, 'f1': 0.9402376307856003, 'number': 2852}" of type <class 'dict'> for key "eval/LOC" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'precision': 0.9044368600682594, 'recall': 0.8815701929474384, 'f1': 0.8928571428571429, 'number': 1503}" of type <class 'dict'> for key "eval/ORG" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'precision': 0.8492753623188406, 'recall': 0.8878787878787879, 'f1': 0.8681481481481482, 'number': 1320}" of type <class 'dict'> for key "eval/PER" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'precision': 0.9521295318549806, 'recall': 0.9484572230014026,

TrainOutput(global_step=4221, training_loss=0.021817263170769645, metrics={'train_runtime': 2071.7816, 'train_samples_per_second': 65.163, 'train_steps_per_second': 2.037, 'total_flos': 9713864313512304.0, 'train_loss': 0.021817263170769645, 'epoch': 3.0})

In [17]:
# 模型保存
model.save_pretrained('/kaggle/working/bert-ner-model')
tokenizer.save_pretrained('/kaggle/working/bert-ner-model')

('/kaggle/working/bert-ner-model/tokenizer_config.json',
 '/kaggle/working/bert-ner-model/special_tokens_map.json',
 '/kaggle/working/bert-ner-model/vocab.txt',
 '/kaggle/working/bert-ner-model/added_tokens.json',
 '/kaggle/working/bert-ner-model/tokenizer.json')

In [38]:
# 模型推理 - 常规方法 model()
model = AutoModelForTokenClassification.from_pretrained('/kaggle/working/bert-ner-model')
tokenizer = AutoTokenizer.from_pretrained('/kaggle/working/bert-ner-model')
def predict(text):
    inputs = tokenizer(text, return_tensors='pt',add_special_tokens=False)
    with torch.no_grad():
        logits = model(**inputs).logits
    predictions = torch.argmax(logits, dim=-1)
    labels = [[tags[i] for i in prediction] for prediction in predictions]
    # 按照以下格式返回 输出：[{"entity":"ORG","content":"中"},{"entity":"ORG","content":"美"}]
    labels = [{"entity":label.split("-")[-1],"content":text[i]} for i,label in enumerate(labels[0]) if label != "O"]
    return labels
print(predict('双方确定了今后发展中美关系的指导方针'))

[{'entity': 'LOC', 'content': '中'}, {'entity': 'LOC', 'content': '美'}]


In [45]:
# 模型推理 - trainer.predict
text= '双方确定了今后发展中美关系的指导方针'
input_d = tokenizer(text,add_special_tokens=False)
result = trainer.predict([input_d])  #预测dataset没有label标签，所以label_ids=None，需用predictions计算
predictions = torch.argmax(torch.tensor(result.predictions), dim=-1)
labels = [[tags[i] for i in prediction] for prediction in predictions]
labels = [{"entity":label.split("-")[-1],"content":text[i]} for i,label in enumerate(labels[0]) if label != "O"]
print(labels)

[{'entity': 'LOC', 'content': '中'}, {'entity': 'LOC', 'content': '美'}]
