In [1]:
import torch
import os
from datasets import load_dataset

from transformers import AutoTokenizer

# Optional Hugging Face access token from the environment (None when unset).
token = os.getenv("HF_TOKEN")

# Tokenizer for the 'hfl/rbt6' checkpoint.
tokenizer = AutoTokenizer.from_pretrained("hfl/rbt6")

# No `split` argument is given, so this loads every split of the corpus.
dataset = load_dataset(path='peoples_daily_ner', token=token)

class Dataset(torch.utils.data.Dataset):
    """One split of the People's Daily NER corpus as (tokens, ner_tags) pairs.

    Tag ids index ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'],
    so e.g. a three-character person name is tagged 1, 2, 2.
    """

    def __init__(self, split, max_len=512 - 2, token=None):
        """Load ``split`` and drop sentences longer than ``max_len`` tokens.

        Args:
            split: split name passed to ``load_dataset`` (e.g. 'train',
                'validation').
            max_len: maximum number of tokens kept per sentence. The default
                510 = 512 - 2 leaves room for the [CLS] and [SEP] tokens the
                tokenizer adds (assumes a 512-position encoder).
            token: optional Hugging Face access token forwarded to
                ``load_dataset``; None keeps the library's default behaviour.
        """
        # Online load. For an offline copy use
        # ``datasets.load_from_disk(data_path)[split]`` instead.
        dataset = load_dataset(path='peoples_daily_ner', split=split, token=token)

        def fits(data):
            # True when the sentence still fits once special tokens are added.
            return len(data['tokens']) <= max_len

        self.dataset = dataset.filter(fits)

    def __len__(self):
        """Number of sentences kept after length filtering."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return ``(tokens, ner_tags)`` for sentence ``i``."""
        row = self.dataset[i]
        return row['tokens'], row['ner_tags']

# Collate function: tokenize a batch and align tag ids with the sub-tokens.
def collate_fn(data):
    """Turn a list of (tokens, ner_tags) samples into model inputs and labels.

    Args:
        data: list of ``(tokens, labels)`` pairs as yielded by ``Dataset``.

    Returns:
        ``(inputs, labels)``. ``inputs`` is the tokenizer's padded batch
        (``input_ids``, ``attention_mask``, ...). ``labels`` is a
        ``LongTensor`` with the same ``[batch, lens]`` shape as ``input_ids``.
        Positions that do not come from an input word ([CLS], [SEP], padding)
        get label 7, which lies outside the real tag ids 0-6.
    """
    tokens = [sample[0] for sample in data]
    tags = [sample[1] for sample in data]

    inputs = tokenizer.batch_encode_plus(tokens,
                                         truncation=True,
                                         padding=True,
                                         return_tensors='pt',
                                         is_split_into_words=True)

    # Padded length of this batch, special tokens included.
    lens = inputs['input_ids'].shape[1]

    labels = []
    for i, sample_tags in enumerate(tags):
        try:
            word_ids = inputs.word_ids(batch_index=i)
        except (AttributeError, ValueError):
            # Slow (non-fast) tokenizers cannot report word ids.
            word_ids = None
        if word_ids is None:
            # Fallback: assume exactly one sub-token per input token.
            row = ([7] + list(sample_tags) + [7] * lens)[:lens]
        else:
            # Map each sub-token back to its source word. This keeps labels
            # aligned even when a token is dropped (e.g. whitespace) or split
            # into several sub-tokens.
            row = [7 if w is None else sample_tags[w] for w in word_ids]
        labels.append(row)

    return inputs, torch.LongTensor(labels)


# Training data loader. It must wrap the torch-style `Dataset('train')`, which
# yields the (tokens, labels) pairs `collate_fn` expects. The `dataset` object
# returned by `load_dataset` above holds every split keyed by split name, so it
# cannot be indexed by sample position.
loader = torch.utils.data.DataLoader(dataset=Dataset('train'),
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True,
                                     pin_memory=True)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [4]:
from transformers import AutoModel

# Shared 'hfl/rbt6' encoder. It lives at module level; Model only registers it
# as a submodule while fine-tuning is switched on.
pretrained = AutoModel.from_pretrained(pretrained_model_name_or_path='hfl/rbt6')

# 定义下游模型
class Model(torch.nn.Module):
    """GRU + linear token-classification head on top of the shared encoder.

    By default the encoder is frozen: ``forward`` calls the module-level
    ``pretrained`` under ``torch.no_grad()``, and it is not one of this
    module's submodules. ``fine_tuneing(True)`` registers it as
    ``self.pretrained`` and unfreezes it so it trains together with the head.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept unchanged (saved models are whole pickles).
        self.tuneing = False    # True -> the encoder is trained with the head
        self.pretrained = None  # holds the encoder only while fine-tuning

        # Downstream head: GRU over the 768-dim encoder states, then a linear
        # layer to 8 classes (tag ids 0-6 plus the special/padding label 7).
        self.rnn = torch.nn.GRU(768, 768, batch_first=True)
        self.fc = torch.nn.Linear(768, 8)

    def forward(self, inputs):
        """Return per-token class probabilities of shape [b, lens, 8].

        ``inputs`` is unpacked as keyword arguments into the encoder call
        (e.g. the tokenizer batch with input_ids / attention_mask).
        """
        if not self.tuneing:
            # Frozen mode: shared module-level encoder, no gradient tracking.
            with torch.no_grad():
                hidden = pretrained(**inputs).last_hidden_state
        else:
            # Fine-tuning mode: the registered encoder is part of this model.
            hidden = self.pretrained(**inputs).last_hidden_state

        hidden, _ = self.rnn(hidden)
        # NOTE(review): this returns probabilities, not logits. If the training
        # loop feeds them to CrossEntropyLoss, softmax is effectively applied
        # twice - confirm against the training code.
        return torch.softmax(self.fc(hidden), dim=2)

    def fine_tuneing(self, tuneing):
        """Switch between the frozen-encoder mode and joint fine-tuning.

        Args:
            tuneing: truthy to unfreeze the module-level ``pretrained``, put it
                in train mode and register it as ``self.pretrained``; falsy to
                freeze it, put it in eval mode and unregister it.
        """
        self.tuneing = tuneing
        enabled = bool(tuneing)
        for param in pretrained.parameters():
            param.requires_grad_(enabled)
        # train(False) is exactly eval().
        pretrained.train(enabled)
        self.pretrained = pretrained if enabled else None


# Downstream model; starts in frozen-encoder mode.
model = Model()


In [6]:
import torch

model_path = './model/ner.model'


def _format_tags(input_id, tag):
    """Render one tag row for display.

    Tag 0 is shown as '·'; any other position is shown as the decoded token
    followed by its tag id (e.g. '吉5').
    """
    parts = []
    for token_id, tag_id in zip(input_id, tag):
        if tag_id == 0:
            parts.append('·')
        else:
            parts.append(tokenizer.decode(token_id) + str(tag_id.item()))
    return ''.join(parts)


def predict(path=None, split='validation', batch_size=32):
    """Load a saved model and print gold vs. predicted tags for one batch.

    For each sentence of one shuffled batch from ``split`` this prints the
    decoded sentence, the gold tag row, the predicted tag row and a separator.

    Args:
        path: file written by ``torch.save(model)``; defaults to the
            module-level ``model_path``.
        split: dataset split passed to ``Dataset``.
        batch_size: number of sentences loaded and printed.
    """
    if path is None:
        path = model_path

    # The file stores a whole pickled Model (not a state_dict), so full
    # unpickling is required (weights_only=False; newer torch defaults to True).
    # Unpickling can run arbitrary code - only load checkpoints you trust.
    # map_location lets a model saved on GPU load on a CPU-only machine.
    model_load = torch.load(path, map_location='cpu', weights_only=False)
    model_load.eval()

    # drop_last=False so a split smaller than batch_size still yields a batch.
    loader_test = torch.utils.data.DataLoader(dataset=Dataset(split),
                                              batch_size=batch_size,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=False)

    # Only one (randomly chosen) batch is shown.
    inputs, labels = next(iter(loader_test))

    with torch.no_grad():
        # [b, lens] -> [b, lens, 8] -> [b, lens]
        outs = model_load(inputs).argmax(dim=2)

    for i in range(len(labels)):
        # Keep only real (non-padding) positions.
        select = inputs['attention_mask'][i] == 1
        input_id = inputs['input_ids'][i, select]
        out = outs[i, select]
        label = labels[i, select]

        # Original sentence.
        print(tokenizer.decode(input_id).replace(' ', ''))

        # Gold tags, then predicted tags.
        for tag in [label, out]:
            print(_format_tags(input_id, tag))
        print('==========================')


predict()

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


[CLS]来自吉林各地的优秀青年兴业领头人、下岗青工创业标兵代表和数百名青年在长春汇聚一堂，畅谈如何响应团省委的号召，通过自己的奋斗，再创新业。[SEP]
[CLS]7··吉5林6······························长5春6···········团3省4委4·················[SEP]7
[CLS]7···································································[SEP]7
[CLS]由于经营有方，生意日益红火，这家饭店几乎天天顾客爆满，经营规模不断扩大，由一层楼变成了两层楼，又在市区内开了一家分店，为近百人解决了就业问题，但饭店从业人员中几乎没有一名一汽下岗职工和职工子弟，大部分都是外来的农民工。[SEP]
[CLS]7·····················································································一3汽4······················[SEP]7
[CLS]7·············································································································[SEP]7
[CLS]戴、亨二人实力相当，前4局战成2∶2平。[SEP]
[CLS]7戴1·亨1·················[SEP]7
[CLS]7····················[SEP]7
[CLS]六十年代初，吴先生到了北京，还到我家做客。[SEP]
[CLS]7······吴1····北5京6········[SEP]7
[CLS]7·····················[SEP]7
[CLS]预计，17日晚上到18日白天，江南、华南、贵州、重庆等地有小到中雨，其中江南中北部、华南西部、贵州等地的部分地区有大到暴雨，局部地区有大暴雨，华北、东北有小到中雨，局部地区有大到暴雨并有短时雷雨大风或冰雹。[SEP]
[CLS]7···············江5南6·华5南6·贵5州6·重5庆6··········江5南6····华5南6···贵5州6····