# Fine-tuning a transformer model for translation

These are study notes on datawhale's [learn-nlp-with-transformers](https://github.com/datawhalechina/learn-nlp-with-transformers/blob/main/docs/%E7%AF%87%E7%AB%A04-%E4%BD%BF%E7%94%A8Transformers%E8%A7%A3%E5%86%B3NLP%E4%BB%BB%E5%8A%A1/4.6-%E7%94%9F%E6%88%90%E4%BB%BB%E5%8A%A1-%E6%9C%BA%E5%99%A8%E7%BF%BB%E8%AF%91.md).

In this notebook, we show how to use models from the Transformers library to solve a translation task in natural language processing. We use the WMT dataset, one of the most commonly used datasets for translation.

## Setting up the environment

In [1]:
! pip install datasets transformers sacrebleu==1.5.1 sentencepiece



In [2]:
model_checkpoint = "Helsinki-NLP/opus-mt-en-ro" 
# Choose a model checkpoint


## Loading the data

In [3]:
from datasets import load_dataset, load_metric

raw_datasets = load_dataset("wmt16", "ro-en")

metric = load_metric("sacrebleu")


Reusing dataset wmt16 (/root/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/0d9fb3e814712c785176ad8cdb9f465fbe6479000ee6546725db30ad8a8b5f8a)


-----
**Problem:** AttributeError: module 'sacrebleu' has no attribute 'DEFAULT_TOKENIZER'

**Solution:** pip install sacrebleu==1.5.1

In [4]:
raw_datasets


DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 610320
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 1999
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 1999
    })
})

In [5]:
raw_datasets["train"][0]
# We can see that each English (en) sentence is paired with a Romanian (ro) sentence


{'translation': {'en': 'Membership of Parliament: see Minutes',
  'ro': 'Componenţa Parlamentului: a se vedea procesul-verbal'}}

Random samples:

In [6]:
import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=5):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
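
The index-picking loop in `show_random_elements` keeps drawing random integers until it has collected `num_examples` distinct ones. As a minimal sketch of the same idea, the standard library's `random.sample` draws distinct indices in a single call. The helper name `pick_unique_indices` and its `seed` parameter are hypothetical and are not part of the original notebook.

```python
import random

def pick_unique_indices(dataset_size, num_examples=5, seed=None):
    """Return `num_examples` distinct indices in the range [0, dataset_size)."""
    if num_examples > dataset_size:
        raise ValueError("Can't pick more elements than there are in the dataset.")
    rng = random.Random(seed)  # passing a seed makes the selection reproducible
    # random.sample never returns the same element twice, so no retry loop is needed.
    return rng.sample(range(dataset_size), num_examples)

# 610320 is the size of the training split reported above.
picks = pick_unique_indices(610320, num_examples=5, seed=0)
print(picks)
```

The returned list could be used in place of `picks` inside `show_random_elements`, for example `dataset[picks]`.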


In [7]:
show_random_elements(raw_datasets["train"])

Unnamed: 0,translation
0,"{'en': 'However, this can only be brought to pass if South Africa manages to take real action on its own behalf by setting out a proper disarmament policy, which is thought to be the real foundation on which to base the development project in the country.', 'ro': 'Însă acest lucru nu se poate realiza decât dacă Africa de Sud va adopta, la rândul său, măsuri concrete prin elaborarea unei politici adecvate de dezarmare, considerată a fi fundamentul real pe care se poate construi proiectul de dezvoltare a ţării.'}"
1,"{'en': 'I would like to know how the European Commission intends to approach these negotiations and this draft convention, and on the basis of what mandate it will act on behalf of all of us, so that tomorrow, in the area of domestic work, the European Union can set an example and that we, too, can give expression to the values of the European Union.', 'ro': 'Aș vrea să știu cum intenționează Comisia Europeană să abordeze aceste negocieri și proiectul de convenție aferent și pe baza cărui mandat va acționa în numele nostru, al tuturor, astfel încât, în domeniul muncii casnice, Uniunea Europeană să devină mâine un exemplu demn de a fi urmat și astfel încât noi, la rândul nostru, să putem da viață valorilor Uniunii Europene.'}"
2,"{'en': 'The application relates to 2 840 job losses in the company Dell in the counties of Limerick, Clare and North Tipperary and in the city of Limerick, of which 2 400 were targeted for assistance.', 'ro': 'Cererea se referă la 2 840 de disponibilizări în compania Dell în districtele Limerick, Clare şi North Tipperary şi în oraşul Limerick, dintre care 2 400 au fost vizaţi pentru asistenţă.'}"
3,"{'en': 'Ms Rivasi', 'ro': 'Dna Rivasi'}"
4,"{'en': 'The Poles and the Germans also have a bit of tidying up to do.', 'ro': 'Polonezii şi germanii au şi ei de făcut puţină curăţenie.'}"


In [8]:
metric

Metric(name: "sacrebleu", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='sequence'), length=-1, id='references')}, usage: """
Produces BLEU scores along with its sufficient statistics
from a source against one or more references.

Args:
    predictions: The system stream (a sequence of segments).
    references: A list of one or more reference streams (each a sequence of segments).
    smooth_method: The smoothing method to use. (Default: 'exp').
    smooth_value: The smoothing value. Only valid for 'floor' and 'add-k'. (Defaults: floor: 0.1, add-k: 1).
    tokenize: Tokenization method to use for BLEU. If not provided, defaults to 'zh' for Chinese, 'ja-mecab' for
        Japanese and '13a' (mteval) otherwise.
    lowercase: Lowercase the data. If True, enables case-insensitivity. (Default: False).
    force: Insist that your tokenized input is actually detokenized.

Returns:
    'score': BLEU score,
    'counts'

In [9]:
fake_preds = ["hello there", "general kenobi"]
fake_labels = [["hello there"], ["general kenobi"]]
metric.compute(predictions=fake_preds, references=fake_labels)

{'bp': 1.0,
 'counts': [4, 2, 0, 0],
 'precisions': [100.0, 100.0, 0.0, 0.0],
 'ref_len': 4,
 'score': 0.0,
 'sys_len': 4,
 'totals': [4, 2, 0, 0]}
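
The `counts` and `totals` fields above can be reproduced by hand, which also helps explain why the score is 0.0 even though the predictions match the references exactly: two-word sentences contain no 3-grams or 4-grams, so the higher-order precisions are 0. The following is a minimal sketch of the n-gram bookkeeping only, not of sacrebleu's implementation. It assumes plain whitespace tokenization and a single reference per sentence, and the helper names `ngram_counts` and `corpus_ngram_stats` are hypothetical.

```python
from collections import Counter

def ngram_counts(tokens, n):
    """Count every n-gram of order n in a list of tokens."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def corpus_ngram_stats(predictions, references, max_order=4):
    """Return (matches, totals) per n-gram order, summed over all sentence pairs."""
    matches = [0] * max_order
    totals = [0] * max_order
    for pred, refs in zip(predictions, references):
        pred_tokens = pred.split()    # assumption: whitespace tokenization
        ref_tokens = refs[0].split()  # assumption: only the first reference is used
        for n in range(1, max_order + 1):
            pred_ngrams = ngram_counts(pred_tokens, n)
            ref_ngrams = ngram_counts(ref_tokens, n)
            totals[n - 1] += sum(pred_ngrams.values())
            # Counter intersection keeps the minimum count, i.e. clipped matches.
            matches[n - 1] += sum((pred_ngrams & ref_ngrams).values())
    return matches, totals

fake_preds = ["hello there", "general kenobi"]
fake_labels = [["hello there"], ["general kenobi"]]
matches, totals = corpus_ngram_stats(fake_preds, fake_labels)
print(matches)  # [4, 2, 0, 0]
print(totals)   # [4, 2, 0, 0]
```

With longer sentences the 3-gram and 4-gram totals become non-zero, and a perfect match would then be expected to score much higher.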

## Preprocessing the data

In [10]:
from transformers import AutoTokenizer
# Requires `sentencepiece`: pip install sentencepiece
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)



In [11]:
if "mbart" in model_checkpoint:
    tokenizer.src_lang = "en-XX"
    tokenizer.tgt_lang = "ro-RO"


In [12]:
tokenizer("Hello, this one sentence!")


{'input_ids': [125, 778, 3, 63, 141, 9191, 23, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}

In [13]:
tokenizer(["Hello, this one sentence!", "This is another sentence."])


{'input_ids': [[125, 778, 3, 63, 141, 9191, 23, 0], [187, 32, 716, 9191, 2, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}

In [14]:
with tokenizer.as_target_tokenizer():
    print(tokenizer("Hello, this one sentence!"))
    model_input = tokenizer("Hello, this one sentence!")
    tokens = tokenizer.convert_ids_to_tokens(model_input['input_ids'])
    # Print the tokens to inspect the special tokens
    print('tokens: {}'.format(tokens))


{'input_ids': [10334, 1204, 3, 15, 8915, 27, 452, 59, 29579, 581, 23, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
tokens: ['▁Hel', 'lo', ',', '▁', 'this', '▁o', 'ne', '▁se', 'nten', 'ce', '!', '</s>']
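
In the printed tokens, `▁` (U+2581) is the SentencePiece marker for the start of a word, and `</s>` is the end-of-sequence special token. Note that the target-side tokenizer splits this English sentence into 12 pieces, compared with 8 ids on the source side, presumably because the target side is meant for Romanian text. As a minimal sketch of how such pieces map back to text, the hypothetical helper `pieces_to_text` below drops special tokens and turns the marker back into spaces. It only illustrates the convention and is not the tokenizer's own decoding routine.

```python
SPIECE_UNDERLINE = "\u2581"  # the '▁' character that marks the start of a word

def pieces_to_text(pieces, special_tokens=("</s>",)):
    """Join SentencePiece-style pieces into a plain string."""
    kept = [p for p in pieces if p not in special_tokens]
    # Concatenate the pieces, then turn every word-start marker into a space.
    return "".join(kept).replace(SPIECE_UNDERLINE, " ").strip()

tokens = ['▁Hel', 'lo', ',', '▁', 'this', '▁o', 'ne', '▁se', 'nten', 'ce', '!', '</s>']
text = pieces_to_text(tokens)
print(text)  # Hello, this one sentence!
```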


In [15]:
if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "translate English to Romanian: "
else:
    prefix = ""

In [16]:
max_input_length = 128
max_target_length = 128
source_lang = "en"
target_lang = "ro"

def preprocess_function(examples):
    inputs = [prefix + ex[source_lang] for ex in examples["translation"]]
    targets = [ex[target_lang] for ex in examples["translation"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [17]:
preprocess_function(raw_datasets['train'][:2])

{'input_ids': [[393, 4462, 14, 1137, 53, 216, 28636, 0], [24385, 14, 28636, 14, 4646, 4622, 53, 216, 28636, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': [[42140, 494, 1750, 53, 8, 59, 903, 3543, 9, 15202, 0], [36199, 6612, 9, 15202, 122, 568, 35788, 21549, 53, 8, 59, 903, 3543, 9, 15202, 0]]}
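
Because `preprocess_function` is called on a slice or a batch, `examples["translation"]` is a list of `{'en': ..., 'ro': ...}` dictionaries rather than a single dictionary. The sketch below isolates just the text-extraction step, without the tokenizer, on two sentence pairs taken from the outputs shown earlier. The function name `split_pairs` is hypothetical, and the prefix is passed explicitly to show what a T5 checkpoint would see.

```python
def split_pairs(examples, source_lang="en", target_lang="ro", prefix=""):
    """Split a batch of translation dicts into source strings and target strings."""
    inputs = [prefix + ex[source_lang] for ex in examples["translation"]]
    targets = [ex[target_lang] for ex in examples["translation"]]
    return inputs, targets

# A batch in the same layout as raw_datasets["train"][:2]
# (the sentence pairs are copied from the samples printed above).
batch = {
    "translation": [
        {"en": "Membership of Parliament: see Minutes",
         "ro": "Componenţa Parlamentului: a se vedea procesul-verbal"},
        {"en": "Ms Rivasi", "ro": "Dna Rivasi"},
    ]
}

inputs, targets = split_pairs(batch, prefix="translate English to Romanian: ")
print(inputs[1])   # translate English to Romanian: Ms Rivasi
print(targets[1])  # Dna Rivasi
```

For the `Helsinki-NLP/opus-mt-en-ro` checkpoint used here the prefix is empty, so the inputs are the raw English sentences.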

In [18]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)


## Fine-tuning the transformer model

In [19]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)



In [20]:
batch_size = 3
args = Seq2SeqTrainingArguments(
    "test-translation",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    predict_with_generate=True,
    fp16=False,
)


In [21]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
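
The tokenized examples have different lengths, so the data collator has to pad each batch to a common length before it can be turned into tensors. The following is a simplified sketch of that dynamic padding idea in plain Python, not the `DataCollatorForSeq2Seq` implementation. It assumes inputs are padded with the tokenizer's pad id and labels with -100, the value that PyTorch's cross-entropy loss ignores by default. The function name `pad_batch` and the value of `PAD_TOKEN_ID` are placeholders; in the notebook the pad id would come from `tokenizer.pad_token_id`.

```python
PAD_TOKEN_ID = 1  # placeholder value, not the real pad id of this tokenizer

def pad_batch(features, pad_token_id, label_pad_token_id=-100):
    """Pad every field to the longest sequence in the batch."""
    max_input = max(len(f["input_ids"]) for f in features)
    max_label = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_in = max_input - len(f["input_ids"])
        n_lab = max_label - len(f["labels"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_in)
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_in)  # 0 = ignore
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_lab)
    return batch

# The two preprocessed examples printed by preprocess_function above.
features = [
    {"input_ids": [393, 4462, 14, 1137, 53, 216, 28636, 0],
     "attention_mask": [1] * 8,
     "labels": [42140, 494, 1750, 53, 8, 59, 903, 3543, 9, 15202, 0]},
    {"input_ids": [24385, 14, 28636, 14, 4646, 4622, 53, 216, 28636, 0],
     "attention_mask": [1] * 10,
     "labels": [36199, 6612, 9, 15202, 122, 568, 35788, 21549, 53, 8, 59,
                903, 3543, 9, 15202, 0]},
]

padded = pad_batch(features, PAD_TOKEN_ID)
print(padded["input_ids"][0])  # first example padded from 8 to 10 ids
print(padded["labels"][0])     # first label row padded from 11 to 16 with -100
```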


In [22]:
import numpy as np

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result
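
Two steps in `compute_metrics` are easy to miss: the -100 placeholders in the labels are swapped back to the pad id before decoding, and `gen_len` is the average number of non-pad tokens per prediction. The sketch below mirrors both steps in plain Python on nested lists instead of NumPy arrays. The helper names, the token ids, and `PAD_TOKEN_ID` are all hypothetical values chosen for illustration.

```python
PAD_TOKEN_ID = 1  # placeholder; the notebook uses tokenizer.pad_token_id

def restore_pad_tokens(labels, pad_token_id, ignore_index=-100):
    """Replace the loss-masking value with the pad id so the ids can be decoded."""
    return [[pad_token_id if t == ignore_index else t for t in row] for row in labels]

def mean_generated_length(preds, pad_token_id):
    """Average number of non-pad tokens per prediction (the gen_len value)."""
    lengths = [sum(1 for t in row if t != pad_token_id) for row in preds]
    return sum(lengths) / len(lengths)

# Hypothetical label and prediction ids for a batch of two sentences.
labels = [[42140, 494, 0, -100, -100],
          [36199, 6612, 9, 15202, 0]]
preds = [[42140, 494, 0, PAD_TOKEN_ID, PAD_TOKEN_ID],
         [36199, 6612, 9, 15202, 0]]

print(restore_pad_tokens(labels, PAD_TOKEN_ID))
print(mean_generated_length(preds, PAD_TOKEN_ID))  # 4.0
```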


In [23]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)


In [None]:
trainer.train()


The following columns in the training set  don't have a corresponding argument in `MarianMTModel.forward` and have been ignored: translation.
***** Running training *****
  Num examples = 610320
  Num Epochs = 1
  Instantaneous batch size per device = 3
  Total train batch size (w. parallel, distributed & accumulation) = 3
  Gradient Accumulation steps = 1
  Total optimization steps = 203440
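
The step count in this log follows from the dataset size and the batch size: 610320 examples at 3 per batch give 203440 batches, and with one epoch and no gradient accumulation each batch is one optimization step. The following is a rough sketch of that arithmetic, assuming a single device. The function names are hypothetical, and the real `Trainer` logic covers more cases (for example a fixed `max_steps`).

```python
def ceil_div(a, b):
    """Integer division rounded up."""
    return -(-a // b)

def total_optimization_steps(num_examples, per_device_batch_size, num_devices=1,
                             gradient_accumulation_steps=1, num_epochs=1):
    """Estimate the number of optimizer updates for a full training run."""
    batches_per_epoch = ceil_div(num_examples, per_device_batch_size * num_devices)
    # Several batches are merged into one update when gradients are accumulated.
    steps_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)
    return steps_per_epoch * num_epochs

steps = total_optimization_steps(610320, 3)
print(steps)  # 203440
```

At the same learning rate, a larger batch size would cut the number of steps proportionally; for example a batch size of 16 gives 38145 steps per epoch.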


Epoch,Training Loss,Validation Loss


Saving model checkpoint to test-translation/checkpoint-500
Configuration saved in test-translation/checkpoint-500/config.json
Model weights saved in test-translation/checkpoint-500/pytorch_model.bin
tokenizer config file saved in test-translation/checkpoint-500/tokenizer_config.json
Special tokens file saved in test-translation/checkpoint-500/special_tokens_map.json
Saving model checkpoint to test-translation/checkpoint-1000
Configuration saved in test-translation/checkpoint-1000/config.json
Model weights saved in test-translation/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in test-translation/checkpoint-1000/tokenizer_config.json
Special tokens file saved in test-translation/checkpoint-1000/special_tokens_map.json
Saving model checkpoint to test-translation/checkpoint-1500
Configuration saved in test-translation/checkpoint-1500/config.json
Model weights saved in test-translation/checkpoint-1500/pytorch_model.bin
tokenizer config file saved in test-translation/checkpo