In [1]:
from transformers import AutoTokenizer

#加载编码器
tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-ro',
                                          use_fast=True)

print(tokenizer)

#编码试算
tokenizer.batch_encode_plus(
    [['Hello, this one sentence!', 'This is another sentence.']])

PreTrainedTokenizer(name_or_path='Helsinki-NLP/opus-mt-en-ro', vocab_size=59543, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'})


{'input_ids': [[125, 778, 3, 63, 141, 9191, 23, 187, 32, 716, 9191, 2, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

In [2]:
from datasets import load_dataset, load_from_disk

#加载数据
#dataset = load_dataset(path='wmt16', name='ro-en')
dataset = load_from_disk('datas/wmt16/ro-en')

#采样,数据量太大了跑不动
dataset['train'] = dataset['train'].shuffle(1).select(range(20000))
dataset['validation'] = dataset['validation'].shuffle(1).select(range(200))
dataset['test'] = dataset['test'].shuffle(1).select(range(200))


#数据预处理
def preprocess_function(data):
    #取出数据中的en和ro
    en = [ex['en'] for ex in data['translation']]
    ro = [ex['ro'] for ex in data['translation']]

    #源语言直接编码就行了
    data = tokenizer.batch_encode_plus(en, max_length=128, truncation=True)

    #目标语言在特殊模块中编码
    with tokenizer.as_target_tokenizer():
        data['labels'] = tokenizer.batch_encode_plus(
            ro, max_length=128, truncation=True)['input_ids']

    return data


dataset = dataset.map(function=preprocess_function,
                      batched=True,
                      batch_size=1000,
                      num_proc=4,
                      remove_columns=['translation'])

print(dataset['train'][0])

dataset

Loading cached shuffled indices for dataset at datas/wmt16/ro-en/train/cache-91525599c6b01037.arrow
Loading cached shuffled indices for dataset at datas/wmt16/ro-en/validation/cache-4a013ce783f1228a.arrow
Loading cached shuffled indices for dataset at datas/wmt16/ro-en/test/cache-c5fe17364be6807b.arrow


 

Loading cached processed dataset at datas/wmt16/ro-en/train/cache-1075e0b8ac206327.arrow


 

Loading cached processed dataset at datas/wmt16/ro-en/train/cache-9fe8e553b9907402.arrow


 

Loading cached processed dataset at datas/wmt16/ro-en/train/cache-fc42acdf7f6ebd0d.arrow


 

Loading cached processed dataset at datas/wmt16/ro-en/train/cache-196974f9ef2169b9.arrow


 

Loading cached processed dataset at datas/wmt16/ro-en/validation/cache-5906b054d2fcfe3c.arrow


 

Loading cached processed dataset at datas/wmt16/ro-en/validation/cache-fd904127e528ee0c.arrow


 

Loading cached processed dataset at datas/wmt16/ro-en/validation/cache-4e6ed6815ec94f7a.arrow


 

Loading cached processed dataset at datas/wmt16/ro-en/validation/cache-7154642fbead0298.arrow


 

Loading cached processed dataset at datas/wmt16/ro-en/test/cache-905606071ab5b8be.arrow


 

Loading cached processed dataset at datas/wmt16/ro-en/test/cache-75c11d7064f6a327.arrow


 

Loading cached processed dataset at datas/wmt16/ro-en/test/cache-b1c2b329c4e2d34e.arrow


 

Loading cached processed dataset at datas/wmt16/ro-en/test/cache-8670277652e40452.arrow


{'input_ids': [460, 354, 3794, 12, 10677, 20, 5046, 14, 4, 2546, 37, 8, 397, 5551, 30, 10113, 37, 3501, 19814, 18, 8465, 20, 4, 44690, 782, 2, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [902, 576, 2946, 76, 10815, 17, 5098, 14997, 5, 559, 1140, 43, 2434, 6624, 27, 50, 337, 19216, 46, 22174, 17, 2317, 121, 16825, 2, 0]}


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 20000
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 200
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 200
    })
})

In [3]:
#这个函数和下面这个工具类等价,但我也是模仿实现的,不确定有没有出入
#from transformers import DataCollatorForSeq2Seq
#DataCollatorForSeq2Seq(tokenizer, model=model)

import torch


#数据批处理函数
def collate_fn(data):
    #求最长的label
    max_length = max([len(i['labels']) for i in data])

    #把所有的label都补pad到最长
    for i in data:
        pads = [-100] * (max_length - len(i['labels']))
        i['labels'] = i['labels'] + pads

    #把多个数据整合成一个tensor
    data = tokenizer.pad(
        encoded_inputs=data,
        padding=True,
        max_length=None,
        pad_to_multiple_of=None,
        return_tensors='pt',
    )

    #定义decoder_input_ids
    data['decoder_input_ids'] = torch.full_like(data['labels'],
                                                tokenizer.get_vocab()['<pad>'],
                                                dtype=torch.long)
    data['decoder_input_ids'][:, 1:] = data['labels'][:, :-1]
    data['decoder_input_ids'][data['decoder_input_ids'] ==
                              -100] = tokenizer.get_vocab()['<pad>']

    return data


data = [{
    'input_ids': [21603, 10, 37, 3719, 13],
    'attention_mask': [1, 1, 1, 1, 1],
    'labels': [10455, 120, 80]
}, {
    'input_ids': [21603, 10, 7086, 8408, 563],
    'attention_mask': [1, 1, 1, 1, 1],
    'labels': [301, 53, 4074, 1669]
}]

collate_fn(data)['decoder_input_ids']

tensor([[59542, 10455,   120,    80],
        [59542,   301,    53,  4074]])

In [4]:
import torch

#数据加载器
loader = torch.utils.data.DataLoader(
    dataset=dataset['train'],
    batch_size=8,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
)

for i, data in enumerate(loader):
    break

for k, v in data.items():
    print(k, v.shape, v[:2])

len(loader)

input_ids torch.Size([8, 51]) tensor([[  363,    63,    32,    51,   154,  1574,  5352,    14,     4,  2196,
            14,   456,     8,  3562,    18,  1603,     4,  2196,   123, 16109,
         23241,   350,  1994,     2,     0, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542],
        [ 7429,     7,    11, 12663,    35, 21169,  6268,    40,  1289, 17749,
            56,   682,   198, 39728,    13,    47, 14297,     3,  1571,    45,
          1834, 37194,    37, 10567,    13,     4,  9307,  1080,  6677, 32510,
             7,  4608,    40, 42822,  1084,   340,   193,     4, 13310,  1000,
           174,   183,  4944,    37,   311,  1634,   439,     2,     0, 59542,
         59542]])
attention_mask torch.Size([8, 51]) tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 0, 0, 0, 0, 0, 0,

2500

In [5]:
from transformers import AutoModelForSeq2SeqLM, MarianModel

#加载模型
#model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-en-ro')


#定义下游任务模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pretrained = MarianModel.from_pretrained(
            'Helsinki-NLP/opus-mt-en-ro')

        self.register_buffer('final_logits_bias',
                             torch.zeros(1, tokenizer.vocab_size))

        self.fc = torch.nn.Linear(512, tokenizer.vocab_size, bias=False)

        #加载预训练模型的参数
        parameters = AutoModelForSeq2SeqLM.from_pretrained(
            'Helsinki-NLP/opus-mt-en-ro')
        self.fc.load_state_dict(parameters.lm_head.state_dict())

        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels, decoder_input_ids):
        logits = self.pretrained(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 decoder_input_ids=decoder_input_ids)
        logits = logits.last_hidden_state

        logits = self.fc(logits) + self.final_logits_bias

        loss = self.criterion(logits.flatten(end_dim=1), labels.flatten())

        return {'loss': loss, 'logits': logits}


model = Model()

#统计参数量
print(sum(i.numel() for i in model.parameters()) / 10000)

out = model(**data)

out['loss'], out['logits'].shape

Some weights of the model checkpoint at Helsinki-NLP/opus-mt-en-ro were not used when initializing MarianModel: ['final_logits_bias']
- This IS expected if you are initializing MarianModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MarianModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


10563.4816


(tensor(1.5629, grad_fn=<NllLossBackward0>), torch.Size([8, 69, 59543]))

In [6]:
from datasets import load_metric

#加载评价函数
metric = load_metric(path='sacrebleu')

#试算
metric.compute(predictions=['hello there', 'general kenobi'],
               references=[['hello there'], ['general kenobi']])

{'score': 0.0,
 'counts': [4, 2, 0, 0],
 'totals': [4, 2, 0, 0],
 'precisions': [100.0, 100.0, 0.0, 0.0],
 'bp': 1.0,
 'sys_len': 4,
 'ref_len': 4}

In [7]:
#测试
def test():
    model.eval()

    #数据加载器
    loader_test = torch.utils.data.DataLoader(
        dataset=dataset['test'],
        batch_size=8,
        collate_fn=collate_fn,
        shuffle=True,
        drop_last=True,
    )

    predictions = []
    references = []
    for i, data in enumerate(loader_test):
        #计算
        with torch.no_grad():
            out = model(**data)

        pred = tokenizer.batch_decode(out['logits'].argmax(dim=2))
        label = tokenizer.batch_decode(data['decoder_input_ids'])
        predictions.extend(pred)
        references.extend(label)

        if i % 2 == 0:
            print(i)
            input_ids = tokenizer.decode(data['input_ids'][0])

            print('input_ids=', input_ids)
            print('pred=', pred[0])
            print('label=', label[0])

        if i == 10:
            break

    references = [[j] for j in references]
    metric_out = metric.compute(predictions=predictions, references=references)
    print(metric_out)


test()

0
input_ids= ▁Unfortunately for▁them, a RATP control▁team▁showed▁up.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
pred= Din păcate, ei, a-a făcut apariția o echipă de control RA RATP.,,,, - Din Din Din Din Din Din Din Din Din Din Din Din Din Din Din Din Din Din Din Din Din Din Din Din Din Din Din
label= <pad> Din nefericire pentru ei, și-a făcut apariția o echipă de controlori RATP.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
2
input_ids= You are a▁great▁power▁only▁if you▁have▁solutions.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><p

In [8]:
from transformers import AdamW
from transformers.optimization import get_scheduler


#训练
def train():
    optimizer = AdamW(model.parameters(), lr=2e-5)
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    model.train()
    for i, data in enumerate(loader):
        out = model(**data)
        loss = out['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        optimizer.zero_grad()
        model.zero_grad()

        if i % 50 == 0:
            out = out['logits'].argmax(dim=2)
            correct = (data['decoder_input_ids'] == out).sum().item()
            total = data['decoder_input_ids'].shape[1] * 8
            accuracy = correct / total
            del correct
            del total

            predictions = []
            references = []
            for j in range(8):
                pred = tokenizer.decode(out[j])
                label = tokenizer.decode(data['decoder_input_ids'][j])
                predictions.append(pred)
                references.append([label])

            metric_out = metric.compute(predictions=predictions,
                                        references=references)

            lr = optimizer.state_dict()['param_groups'][0]['lr']

            print(i, loss.item(), accuracy, metric_out, lr)

    torch.save(model, 'models/7.翻译.model')


train()



0 1.9611096382141113 0.0 {'score': 6.2696211486540125, 'counts': [127, 82, 54, 32], 'totals': [363, 355, 347, 339], 'precisions': [34.98622589531681, 23.098591549295776, 15.561959654178674, 9.43952802359882], 'bp': 0.33776683830627435, 'sys_len': 363, 'ref_len': 757} 1.9992e-05
50 1.214351773262024 0.0013736263736263737 {'score': 4.345771608240105, 'counts': [178, 122, 89, 65], 'totals': [663, 655, 647, 639], 'precisions': [26.84766214177979, 18.625954198473284, 13.75579598145286, 10.172143974960877], 'bp': 0.267199776786503, 'sys_len': 663, 'ref_len': 1538} 1.9592e-05
100 0.7556186318397522 0.006818181818181818 {'score': 12.965422783170721, 'counts': [171, 130, 100, 77], 'totals': [360, 352, 344, 336], 'precisions': [47.5, 36.93181818181818, 29.069767441860463, 22.916666666666668], 'bp': 0.3943345747424834, 'sys_len': 360, 'ref_len': 695} 1.9192000000000002e-05
150 0.7815563082695007 0.01201923076923077 {'score': 11.647085644686467, 'counts': [155, 117, 93, 74], 'totals': [345, 337, 3

1400 0.9339891672134399 0.00436046511627907 {'score': 6.273109997311125, 'counts': [197, 139, 106, 81], 'totals': [607, 599, 591, 583], 'precisions': [32.45469522240527, 23.20534223706177, 17.93570219966159, 13.893653516295025], 'bp': 0.30139275785881325, 'sys_len': 607, 'ref_len': 1335} 8.792e-06
1450 0.7789385318756104 0.0 {'score': 10.002563834271076, 'counts': [141, 105, 78, 56], 'totals': [340, 332, 324, 316], 'precisions': [41.470588235294116, 31.626506024096386, 24.074074074074073, 17.72151898734177], 'bp': 0.36572179669341776, 'sys_len': 340, 'ref_len': 682} 8.392e-06
1500 0.9499245882034302 0.003968253968253968 {'score': 8.773482938347657, 'counts': [173, 125, 92, 69], 'totals': [443, 435, 427, 419], 'precisions': [39.05191873589165, 28.735632183908045, 21.54566744730679, 16.46778042959427], 'bp': 0.34926695473651137, 'sys_len': 443, 'ref_len': 909} 7.992e-06
1550 0.8153000473976135 0.0021929824561403508 {'score': 9.020920286090554, 'counts': [164, 117, 84, 59], 'totals': [403

In [9]:
model = torch.load('models/7.翻译.model')
test()

0
input_ids= ▁Two▁brothers from Tomești▁swallowed▁sleeping▁pills.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
pred= Doi fraţiăţiți din din Tomești au înghiţitghițit somnifere.,u:: - Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi Doi
label= <pad> Doi frățiori din Tomești au înghițit somnifere.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
2
input_ids= ▁Whatever that▁does,▁it▁won't be▁pleasant.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><p