In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader

from transformers import AutoModel, AutoTokenizer, AdamW
from datasets import load_dataset, load_from_disk
import evaluate
import math

In [None]:
path = "huggingface/models/opus-mt-de-en"
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModel.from_pretrained(path)

In [3]:
def collate_fn(data):
    de = [example['translation']['de'] for example in data]
    en = [example['translation']['en'] for example in data]
    data = tokenizer.batch_encode_plus(de, padding=True, truncation=True, max_length=128, return_tensors='pt')

    with tokenizer.as_target_tokenizer():
        data['labels'] = tokenizer.batch_encode_plus(en, padding=True, truncation=True, max_length=128, return_tensors='pt')['input_ids']

    data['decoder_input_ids'] = torch.full_like(data['labels'], tokenizer.get_vocab()['<pad>'])
    data['decoder_input_ids'][:,1:] = data['labels'][:,:-1]

    return data

In [4]:
dataset = load_from_disk("huggingface/datasets/wmt16/de-en")
train_dataloader = DataLoader(dataset['train'], batch_size=32, shuffle=True, 
                            drop_last=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(dataset['validation'], batch_size=32, shuffle=True, 
                            drop_last=True, collate_fn=collate_fn)
test_dataloader = DataLoader(dataset['test'], batch_size=32, shuffle=True, 
                            drop_last=True, collate_fn=collate_fn)

In [5]:
class Model(nn.Module):
    def __init__(self) :
        super().__init__()
        self.backbone = AutoModel.from_pretrained(path)
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(512, tokenizer.vocab_size)
    
    def forward(self, input_ids, attention_mask, decoder_input_ids):
        out = self.backbone(input_ids, attention_mask, decoder_input_ids)
        out = out.last_hidden_state
        out = self.fc(self.dropout(out))

        return out
        

In [None]:
epochs = 1
model = Model()
optimizer = AdamW([
    {"params": model.backbone.parameters(), 'lr': 2e-5},
    {"params": model.fc.parameters(), 'lr': 5e-4}
])
criterion = nn.CrossEntropyLoss()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [7]:
def train():
    model.train()
    model.to(device)
    for epoch in range(epochs):
        epoch_loss = 0
        for i, data in enumerate(train_dataloader):
            input_ids = data['input_ids'].to(device)
            attention_mask = data['attention_mask'].to(device)
            decoder_input_ids = data['decoder_input_ids'].to(device)
            labels = data['labels'].to(device)
            
            out = model(input_ids, attention_mask, decoder_input_ids)
            output_dim = out.shape[-1]
            out = out.view(-1, output_dim)
            labels = labels.view(-1)
            
            loss = criterion(out, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            if i % 100 == 0:
                loss_item = epoch_loss / (i + 1)
                print('epoch:{}, idx:{}, loss:{}, PPL:{}'.format(epoch+1, i, loss_item, math.exp(loss_item)))
    


In [None]:
train()

In [25]:
def translate(model, dataloader):
    model.eval()
    model = model.to(device)
    predictions = []
    references = []
    for data in dataloader:
        input_ids = data['input_ids'].to(device)
        attention_mask = data['attention_mask'].to(device)
        decoder_input_ids = data['decoder_input_ids'].to(device)
        labels = data['labels'].to(device)
        
        out = model(input_ids, attention_mask, decoder_input_ids)
        pred = tokenizer.batch_decode(out.argmax(dim=2), skip_special_tokens=True)
        label = tokenizer.batch_decode(labels, skip_special_tokens=True)
        predictions.extend(pred)
        references.extend(label)
    
    return predictions, references

In [29]:
def compute_bleu(predictions, references):
    references = [[i] for i in references]
    metric = evaluate.load('bleu')
    metric_out = metric.compute(predictions=predictions, references=references)
    print(metric_out)
    return metric_out

In [None]:
predictions, references = translate(model, valid_dataloader)
valid_bleu = compute_bleu(predictions, references)
predictions, references = translate(model, test_dataloader)
test_bleu = compute_bleu(predictions, references)

{'bleu': 0.29646412251432813, 'precisions': [0.6205257331394213, 0.3691747296528173, 0.2305722422994998, 0.14624795056123094], 'brevity_penalty': 1.0, 'length_ratio': 1.0033540237395187, 'translation_length': 46069, 'reference_length': 45915}

{'bleu': 0.3444770324148899, 'precisions': [0.6559922215600791, 0.4178318802434611, 0.2764209361054416, 0.18585313371870385], 'brevity_penalty': 1.0, 'length_ratio': 1.0028623553095117, 'translation_length': 63766, 'reference_length': 63584}


In [31]:
predictions[0], references[0]

("He is one of the most's moststt guyshuuse who who will have a lot of experience, played a under all conditions and against all attacks.",
 "He is one of the game's loveliest blokes, who will bring a wealth of experience having done it in all conditions and against all attacks.")