# Load and Use a Previously Trained Romanian-to-English T5 Model

Some initialization: select the device and set the random seed.

In [1]:
import torch
import numpy as np
from transformers import set_seed

# set to True to use the gpu (if there is one available)
use_gpu = True

# use CUDA only when it is both requested and actually available
if use_gpu and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'device: {device.type}')

# random seed (set to None to skip seeding)
seed = 42
if seed is not None:
    print(f'random seed: {seed}')
    set_seed(seed)

device: cuda
random seed: 42


In [2]:
# language codes used as keys into each example's `translation` dict
source_lang, target_lang = 'ro', 'en'

# token limits: source inputs are truncated to max_source_length tokens,
# and generation is capped at max_target_length
max_source_length, max_target_length = 1024, 128

# prepended to every source sentence to tell T5 which task to perform
task_prefix = 'translate Romanian to English: '

num_beams = 1     # beam width passed to model.generate (1 = no beam search)
batch_size = 100  # examples per batch in Dataset.map

Load the tokenizer and model from the directory where you saved them after fine-tuning:

In [3]:
import os

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Directory holding the saved fine-tuned model and tokenizer. Set the
# T5_OUTPUT_DIR environment variable to point elsewhere instead of editing this
# cell; otherwise make sure the default is a valid path on your machine!
output_dir = os.environ.get('T5_OUTPUT_DIR', '/media/data2/t5-translation-example')

# local_files_only=True: load from the local directory only, never download
tokenizer = AutoTokenizer.from_pretrained(output_dir, local_files_only=True)
model = AutoModelForSeq2SeqLM.from_pretrained(output_dir, local_files_only=True)

# inference only: move to the selected device and make eval mode explicit
model = model.to(device).eval()

Load only the test split of the WMT16 Romanian–English dataset from the Hugging Face Hub:

In [4]:
from datasets import load_dataset

# WMT16, Romanian-English configuration; only the test split is returned
test_ds = load_dataset(path='wmt16', name='ro-en', split='test')
test_ds

Reusing dataset wmt16 (/home/marco/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/9e0038fe4cc117bd474d2774032cc133e355146ed0a47021b2040ca9db4645c0)


Dataset({
    features: ['translation'],
    num_rows: 1999
})

In [5]:
# index the row first so only one example is read, not the whole column
test_ds[0]['translation']

{'en': 'UN Chief Says There Is No Military Solution in Syria',
 'ro': 'Șeful ONU declară că nu există soluții militare în Siria'}

Implement the `translate` function and apply it, batch by batch, to the test split:

In [6]:
def translate(batch):
    """Translate a batch of parallel examples and pair them with the gold text.

    Meant to be passed to ``Dataset.map(..., batched=True)``: ``batch`` is a
    mapping whose ``"translation"`` entry is a list of ``{lang_code: sentence}``
    dicts.

    Reads the notebook globals ``tokenizer``, ``model``, ``device``,
    ``source_lang``, ``target_lang``, ``task_prefix``, ``max_source_length``,
    ``max_target_length`` and ``num_beams``.

    Returns:
        ``{'reference': [...], 'prediction': [...]}``: the gold target-language
        sentences and the model's decoded translations, in batch order.
    """
    pairs = batch["translation"]

    # source-language sentences, each prefixed with the T5 task instruction
    sources = [pair[source_lang] for pair in pairs]
    prompts = [task_prefix + text for text in sources]

    # pad to the longest prompt in the batch and truncate over-long ones
    enc = tokenizer(
        prompts,
        max_length=max_source_length,
        truncation=True,
        padding=True,
        return_tensors='pt',
    )

    # tensors are moved to the selected device right at the call site
    generated_ids = model.generate(
        input_ids=enc.input_ids.to(device),
        attention_mask=enc.attention_mask.to(device),
        num_beams=num_beams,
        max_length=max_target_length,
    )

    # token ids -> text, dropping special tokens
    predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    references = [pair[target_lang] for pair in pairs]
    return dict(reference=references, prediction=predictions)

In [7]:
# apply `translate` to the test split in batches; dropping the original columns
# leaves `results` with only the 'reference' and 'prediction' columns
results = test_ds.map(
    function=translate,
    batched=True,
    batch_size=batch_size,
    remove_columns=test_ds.column_names,
)

results.to_pandas()

  0%|          | 0/20 [00:00<?, ?ba/s]

Unnamed: 0,reference,prediction
0,UN Chief Says There Is No Military Solution in...,UN chief says no military solutions in Syria
1,Secretary-General Ban Ki-moon says his respons...,Secretary-General Ban Ki-moon says his respons...
2,"The U.N. chief again urged all parties, includ...","The UN chief again called on all parties, incl..."
3,Ban told a news conference Wednesday that he p...,Ban said at a conference Wednesday that he pla...
4,He expressed regret that divisions in the coun...,Ban expressed regret that divisions in the cou...
...,...,...
1994,Money is enough.,There is no money.
1995,Sometimes I'm ashamed to take my money from th...,Sometimes I am ashamed to raise money from the...
1996,At the end of the office I will report everyth...,"At the end of my term of office, I will report..."
1997,"One month I happened to collect 30,000 lei wit...",It happened to rise in one month and 30.000 le...


Compute the corpus-level BLEU score with SacreBLEU:

In [8]:
from datasets import load_metric

# NOTE(review): `load_metric` is deprecated in newer `datasets` releases in favor
# of `evaluate.load('sacrebleu')`; kept here to avoid adding a new dependency.
metric = load_metric('sacrebleu')

# sacrebleu accepts several references per prediction, so each gold sentence is
# wrapped in its own single-element list. Passing whole columns in one call
# avoids iterating the dataset row by row and leaking loop variables.
metric.compute(
    predictions=results['prediction'],
    references=[[ref] for ref in results['reference']],
)

{'score': 33.39227034561738,
 'counts': [31267, 18123, 11587, 7627],
 'totals': [47851, 45852, 43855, 41859],
 'precisions': [65.34241708637228,
  39.52499345721015,
  26.421160643028163,
  18.22069327982035],
 'bp': 1.0,
 'sys_len': 47851,
 'ref_len': 47562}