In [1]:
import torch
import os

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

from transformers import AutoTokenizer

# Load the tokenizer that matches the 'hfl/rbt6' checkpoint.
tokenizer = AutoTokenizer.from_pretrained('hfl/rbt6')

print(tokenizer)

# Tokenization smoke test on two pre-split (character-level) sentences.
# max_length is given explicitly: this tokenizer reports no usable
# model_max_length, so truncation=True alone is a no-op and only emits a
# "Default to no truncation" warning.
tokenizer.batch_encode_plus(
    [[
        '海', '钓', '比', '赛', '地', '点', '在', '厦', '门', '与', '金', '门', '之', '间',
        '的', '海', '域', '。'
    ],
     [
         '这', '座', '依', '山', '傍', '水', '的', '博', '物', '馆', '由', '国', '内', '一',
         '流', '的', '设', '计', '师', '主', '持', '设', '计', '，', '整', '个', '建', '筑',
         '群', '精', '美', '而', '恢', '宏', '。'
     ]],
    truncation=True,
    max_length=512,
    padding=True,
    return_tensors='pt',
    is_split_into_words=True)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


BertTokenizerFast(name_or_path='hfl/rbt6', vocab_size=21128, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}


{'input_ids': tensor([[ 101, 3862, 7157, 3683, 6612, 1765, 4157, 1762, 1336, 7305,  680, 7032,
         7305,  722, 7313, 4638, 3862, 1818,  511,  102,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0],
        [ 101, 6821, 2429,  898, 2255,  988, 3717, 4638, 1300, 4289, 7667, 4507,
         1744, 1079,  671, 3837, 4638, 6392, 6369, 2360,  712, 2898, 6392, 6369,
         8024, 3146,  702, 2456, 5029, 5408, 5125, 5401, 5445, 2612, 2131,  511,
          102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [2]:
from datasets import load_dataset


class Dataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(tokens, ner_tags)`` pairs.

    Wraps one split of the 'lansinuote/peoples-daily-ner' dataset and drops
    every example that has more than ``max_length - 2`` tokens, leaving room
    for the two special tokens the tokenizer adds around each sentence.

    Args:
        split: name of the split passed to ``load_dataset`` (the notebook
            uses 'train' and 'validation').
        max_length: maximum encoded sequence length, special tokens included.
            Defaults to 512, which reproduces the original ``512 - 2`` filter.
    """

    def __init__(self, split, max_length=512):
        # Tag names as listed by the original author (presumably index == tag id):
        #names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']

        # Load the split through the datasets library.
        dataset = load_dataset(path='lansinuote/peoples-daily-ner',
                               split=split)

        # Reserve two positions for the special tokens.
        max_tokens = max_length - 2

        # Filter out sentences that are too long.
        def f(data):
            return len(data['tokens']) <= max_tokens

        dataset = dataset.filter(f)

        self.dataset = dataset

    def __len__(self):
        """Return the number of examples kept after filtering."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return ``(tokens, ner_tags)`` for example ``i``."""
        # Look the row up once instead of once per field.
        row = self.dataset[i]

        return row['tokens'], row['ner_tags']


# Build the training split; `dataset` is reused by the DataLoader cell below.
dataset = Dataset('train')

# Pull the first example to inspect its structure.
tokens, labels = dataset[0]

# Notebook display: (number of examples, first token list, first tag list).
len(dataset), tokens, labels

(20852,
 ['海',
  '钓',
  '比',
  '赛',
  '地',
  '点',
  '在',
  '厦',
  '门',
  '与',
  '金',
  '门',
  '之',
  '间',
  '的',
  '海',
  '域',
  '。'],
 [0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 5, 6, 0, 0, 0, 0, 0, 0])

In [3]:
# Batch collation function
def collate_fn(data, max_length=512):
    """Tokenize a list of ``(tokens, labels)`` samples and align the labels.

    Args:
        data: list of ``(tokens, labels)`` pairs as produced by ``Dataset``.
        max_length: upper bound on the encoded length passed to the tokenizer.
            Without it ``truncation=True`` has no effect for this tokenizer
            (it only emits a warning). Defaults to 512; the DataLoader calls
            this function with ``data`` only, so the default is what applies.

    Returns:
        ``(inputs, labels)`` where ``inputs`` is the tokenizer output (padded
        to the longest sequence in the batch, as PyTorch tensors) and
        ``labels`` is a ``LongTensor`` with one row per sample, each row as
        long as the padded input. Label 7 is used for the leading position
        and for every position after the sentence's own labels.
    """
    tokens = [sample[0] for sample in data]
    labels = [sample[1] for sample in data]

    inputs = tokenizer.batch_encode_plus(tokens,
                                         truncation=True,
                                         max_length=max_length,
                                         padding=True,
                                         return_tensors='pt',
                                         is_split_into_words=True)

    # Padded sequence length of this batch.
    lens = inputs['input_ids'].shape[1]

    aligned = []
    for tags in labels:
        # 7 in front for the first position, then over-pad with 7 and cut
        # back to the batch length so every row ends up exactly `lens` long.
        row = [7] + tags + [7] * lens
        aligned.append(row[:lens])

    return inputs, torch.LongTensor(aligned)


# Data loader over the training set: shuffled batches of 16, with the last
# incomplete batch dropped.
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

# Grab a single sample batch. `inputs` and `labels` stay bound after the loop
# and are reused by later cells for shape checks.
for i, (inputs, labels) in enumerate(loader):
    break

# Number of batches per epoch, then the first sentence decoded back to text
# and its aligned label row.
print(len(loader))
print(tokenizer.decode(inputs['input_ids'][0]))
print(labels[0])

# Shape of every tensor in the tokenizer output.
for k, v in inputs.items():
    print(k, v.shape)

1303
[CLS] 国 家 银 行 主 要 通 过 对 货 币 投 放 量 和 利 息 率 的 控 制 ， 达 到 稳 定 本 国 货 币 、 增 加 外 汇 储 备 、 抑 制 通 货 膨 胀 的 目 的 。 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
tensor([7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
input_ids torch.Size([16, 94])
token_type_ids torch.Size([16, 94])
attention_mask torch.Size([16, 94])


In [4]:
from transformers import AutoModel

# Load the pretrained encoder and freeze all of its parameters.
pretrained = AutoModel.from_pretrained('hfl/rbt6')
pretrained.requires_grad_(False)

# Parameter count, printed in units of 10,000.
print(sum(i.numel() for i in pretrained.parameters()) / 10000)

# Trial forward pass on the sample batch from the loader cell.
#[b, lens] -> [b, lens, 768]
pretrained(**inputs).last_hidden_state.shape

5974.0416


torch.Size([16, 94, 768])

In [5]:
# Downstream token-classification model
class Model(torch.nn.Module):
    """Pretrained encoder followed by a GRU and a linear tagging head.

    The encoder is the module-level ``pretrained`` object, so every ``Model``
    instance shares (and, when loading a state dict, overwrites) the same
    encoder weights.

    ``forward`` returns raw, unnormalized scores of shape ``[b, lens, 8]``.
    ``torch.nn.CrossEntropyLoss`` expects unnormalized logits and applies
    log-softmax itself, so a softmax here would be applied twice and flatten
    the gradients. Callers that only take ``argmax`` are unaffected; call
    ``.softmax(dim=2)`` on the result if probabilities are needed.
    """

    def __init__(self):
        super().__init__()
        # Whether the encoder is currently being fine-tuned.
        self.tuneing = False
        self.pretrained = pretrained

        self.rnn = torch.nn.GRU(768, 768, batch_first=True)
        self.fc = torch.nn.Linear(768, 8)

    def forward(self, inputs):
        """Map tokenizer output ``inputs`` to per-token class logits.

        Args:
            inputs: mapping unpacked as keyword arguments into the encoder.

        Returns:
            Tensor of shape ``[b, lens, 8]`` holding unnormalized scores.
        """
        #[b, lens] -> [b, lens, 768]
        out = self.pretrained(**inputs).last_hidden_state
        out, _ = self.rnn(out)

        #[b, lens, 768] -> [b, lens, 8]
        return self.fc(out)

    def fine_tuneing(self, tuneing):
        """Enable or disable gradient updates for the encoder.

        Args:
            tuneing: True to train the encoder together with the head,
                False to keep the encoder frozen.
        """
        self.tuneing = tuneing
        self.pretrained.requires_grad_(tuneing)


# Instantiate the downstream model; `model` is the global used by train().
model = Model()

# Shape check on the sample batch; no gradients are needed for this.
with torch.no_grad():
    print(model(inputs).shape)

torch.Size([16, 94, 8])


In [6]:
# Flatten predictions and labels, and drop the padded positions
def reshape_and_remove_pad(outs, labels, attention_mask):
    """Flatten per-token outputs and labels, keeping only attended positions.

    Args:
        outs: tensor of shape ``[b, lens, num_classes]``.
        labels: tensor of shape ``[b, lens]``.
        attention_mask: tensor of shape ``[b, lens]``; positions equal to 1
            are kept, everything else is discarded.

    Returns:
        ``(outs, labels)`` with shapes ``[c, num_classes]`` and ``[c]``, where
        ``c`` is the number of positions whose mask value is 1.
    """
    # Flatten so the loss can be computed per token. The class dimension is
    # read from the tensor instead of being hard-coded.
    #[b, lens, num_classes] -> [b*lens, num_classes]
    outs = outs.reshape(-1, outs.shape[-1])
    #[b, lens] -> [b*lens]
    labels = labels.reshape(-1)

    # Ignore the results computed for padded positions.
    #[b, lens] -> [b*lens - pad]
    select = attention_mask.reshape(-1) == 1

    return outs[select], labels[select]


# Quick check with random scores and an all-ones mask: nothing is removed,
# so the result has 2*3 = 6 rows.
reshape_and_remove_pad(torch.randn(2, 3, 8), torch.ones(2, 3),
                       torch.ones(2, 3))

(tensor([[-1.2455,  0.6611, -0.9850, -0.7210, -0.2470, -0.1299, -1.0160,  0.0246],
         [ 1.0440,  2.8605,  1.2175,  0.5834,  0.1893,  0.4375,  0.6651, -0.7127],
         [ 0.8248,  0.9940, -0.5361,  1.3210,  0.2208,  0.2666,  1.6294,  1.4698],
         [ 0.4734,  1.0645,  1.2387, -0.4024,  0.6151, -1.4682, -0.2201, -1.8569],
         [ 0.1573, -1.2819, -1.1607, -0.3261,  0.1492, -0.2599, -1.1425, -0.3315],
         [-1.2250, -0.4709, -0.5910,  0.4635,  0.1000, -0.0412, -0.0479,  0.9608]]),
 tensor([1., 1., 1., 1., 1., 1.]))

In [7]:
# Count correct predictions and totals
def get_correct_and_total_count(labels, outs):
    """Count matches between predicted classes and labels.

    Args:
        labels: 1-D tensor of target class ids, shape ``[n]``.
        outs: 2-D tensor of per-class scores, shape ``[n, num_classes]``.

    Returns:
        ``(correct, total, correct_content, total_content)``: the first pair
        covers every position, the second pair only positions whose label is
        not 0.
    """
    #[n, num_classes] -> [n]
    predicted = outs.argmax(dim=1)
    hits = predicted == labels

    # Label 0 dominates the data, so accuracy over everything is easily
    # inflated; the second pair of counts leaves those positions out.
    non_zero = labels != 0
    kept_labels = labels[non_zero]

    return (hits.sum().item(), len(labels), hits[non_zero].sum().item(),
            len(kept_labels))


# Quick check with all-ones labels and random scores; the counts vary per run.
get_correct_and_total_count(torch.ones(16), torch.randn(16, 8))

(3, 16, 3, 16)

In [8]:
# Training
def train(epochs):
    """Train the global ``model`` on the global ``loader``.

    The learning rate is 2e-5 when the encoder is being fine-tuned
    (``model.tuneing`` is true) and 5e-4 otherwise. Progress is printed every
    200 steps, and the state dict is saved after every epoch.

    Args:
        epochs: number of passes over ``loader``.
    """
    lr = 2e-5 if model.tuneing else 5e-4

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    model.train()

    # Create the checkpoint directory up front: torch.save does not create
    # missing parent directories, and failing after a full epoch is costly.
    save_path = 'model/命名实体识别_中文.model'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    for epoch in range(epochs):
        for step, (inputs, labels) in enumerate(loader):
            # Forward pass
            #[b, lens] -> [b, lens, 8]
            outs = model(inputs)

            # Flatten outs and labels, and drop the padded positions
            #outs -> [b, lens, 8] -> [c, 8]
            #labels -> [b, lens] -> [c]
            outs, labels = reshape_and_remove_pad(outs, labels,
                                                  inputs['attention_mask'])

            # Gradient step
            loss = criterion(outs, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if step % 200 == 0:
                counts = get_correct_and_total_count(labels, outs)

                accuracy = counts[0] / counts[1]
                accuracy_content = counts[2] / counts[3]

                print(epoch, step, len(loader), loss.item(), accuracy,
                      accuracy_content)

        torch.save(model.state_dict(), save_path)


# Stage 1: keep the encoder frozen and report the number of trainable
# parameters (in units of 10,000).
model.fine_tuneing(False)
print(sum(p.numel() for p in model.parameters() if p.requires_grad) / 10000)
# The training call is left commented out in the notebook:
#train(1)

354.9704


In [9]:
# Stage 2: unfreeze the encoder and report the number of trainable
# parameters (in units of 10,000).
model.fine_tuneing(True)
print(sum(p.numel() for p in model.parameters() if p.requires_grad) / 10000)
# The training call is left commented out in the notebook:
#train(5)

6329.012


In [10]:
# Evaluation
def test():
    """Evaluate the saved checkpoint on a sample of the validation split.

    Loads the state dict from 'model/命名实体识别_中文.model', runs the first
    5 shuffled batches of 128 validation examples, and prints two accuracies:
    over all non-padded positions, and over positions whose label is not 0.
    """
    model_load = Model()
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model_load.load_state_dict(torch.load('model/命名实体识别_中文.model'))
    # Model() wraps the shared global encoder, which train() leaves in
    # training mode; switch to eval so dropout is disabled while scoring.
    model_load.eval()

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                              batch_size=128,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    correct = 0
    total = 0

    correct_content = 0
    total_content = 0

    for step, (inputs, labels) in enumerate(loader_test):
        # Only a handful of batches are scored to keep this quick.
        if step == 5:
            break
        print(step)

        with torch.no_grad():
            #[b, lens] -> [b, lens, 8]
            outs = model_load(inputs)

        # Flatten outs and labels, and drop the padded positions
        #outs -> [b, lens, 8] -> [c, 8]
        #labels -> [b, lens] -> [c]
        outs, labels = reshape_and_remove_pad(outs, labels,
                                              inputs['attention_mask'])

        counts = get_correct_and_total_count(labels, outs)
        correct += counts[0]
        total += counts[1]
        correct_content += counts[2]
        total_content += counts[3]

    print(correct / total, correct_content / total_content)


# Run the evaluation; prints the batch index per step and the two accuracies.
test()

0
1
2
3
4
0.9903343782654127 0.9527783878761257


In [11]:
# Prediction demo
def predict():
    """Print gold and predicted tags for one batch of validation sentences.

    Loads the state dict from 'model/命名实体识别_中文.model', takes one
    shuffled batch of 32 validation examples and, for each sentence, prints
    the decoded text followed by the label line and the prediction line.
    Positions tagged 0 are shown as '·'; every other position is shown as the
    decoded token followed by its tag id.
    """
    model_load = Model()
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model_load.load_state_dict(torch.load('model/命名实体识别_中文.model'))
    # Model() wraps the shared global encoder, which train() leaves in
    # training mode; switch to eval so dropout is disabled while predicting.
    model_load.eval()

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    # Only the first batch is needed.
    inputs, labels = next(iter(loader_test))

    with torch.no_grad():
        #[b, lens] -> [b, lens, 8] -> [b, lens]
        outs = model_load(inputs).argmax(dim=2)

    # Iterate over the actual batch size rather than a hard-coded 32.
    for i in range(len(labels)):
        # Drop the padded positions
        select = inputs['attention_mask'][i] == 1
        input_id = inputs['input_ids'][i, select]
        out = outs[i, select]
        label = labels[i, select]

        # Print the original sentence
        print(tokenizer.decode(input_id).replace(' ', ''))

        # Print the tags: gold labels first, then predictions
        for tag in [label, out]:
            s = ''
            for j in range(len(tag)):
                if tag[j] == 0:
                    s += '·'
                    continue
                s += tokenizer.decode(input_id[j])
                s += str(tag[j].item())

            print(s)
        print('==========================')


# Run the prediction demo; prints sentence / gold tags / predicted tags.
predict()

[CLS]此案经海口市中级人民法院一审，以受贿罪判处辛业江有期徒刑5年，依法追缴其非法所得。[SEP]
[CLS]7···海3口4市4中4级4人4民4法4院4·········辛1业2江2·················[SEP]7
[CLS]7···海5口4市4中4级4人4民4法4院4·········辛1业2江2·················[SEP]7
[CLS]谁知此水通天，不觉流进银河，而且直至宝光四射的斗牛宫。[SEP]
[CLS]7···········银5河6··········斗5牛6宫6·[SEP]7
[CLS]7·······················斗5牛6宫6·[SEP]7
[CLS][UNK]让世界少一点痛苦，为人间添一份幸福[UNK]。[SEP]
[CLS]7····················[SEP]7
[CLS]7····················[SEP]7
[CLS]这种经历使我得以较深刻地了解打工生活的甘苦。[SEP]
[CLS]7······················[SEP]7
[CLS]7······················[SEP]7
[CLS]1979年到1997年11月，欧盟对中国产品反倾销立案调查达71起，给中国出口造成很大损失。[SEP]
[CLS]7···············欧3盟4·中5国6···············中5国6·········[SEP]7
[CLS]7···············欧3盟4·中5国6···············中5国6·········[SEP]7
[CLS]在这一点上，高敏是经受考验、令人折服的。[SEP]
[CLS]7······高1敏2············[SEP]7
[CLS]7······高1敏2············[SEP]7
[CLS]美国电信管理部门、司法部门的官员对这起兼并反应谨慎。[SEP]
[CLS]7美5国6························[SEP]7
[CLS]7美5国6························[SEP]7
[CLS]四是不准出现村务不公开的村，确保100％的村实行财务公开、政策公开、政务公开，接受群众监督。[SEP]
[CLS]7············