Skip to content

Commit

Permalink
Add a text translation example (#1283)
Browse files Browse the repository at this point in the history
  • Loading branch information
Desjajja committed Aug 7, 2023
1 parent d9fee4f commit 398d229
Showing 1 changed file with 120 additions and 0 deletions.
120 changes: 120 additions & 0 deletions examples/text_translation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import numpy as np
import torch
from datasets import load_dataset
from torchtext.data.metrics import bleu_score
from transformers import AutoTokenizer, T5ForConditionalGeneration

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner

tokenizer = AutoTokenizer.from_pretrained('t5-small')


class MMT5ForTranslation(BaseModel):

def __init__(self, model):
super().__init__()
self.model = model

def forward(self, label, input_ids, attention_mask, mode):
if mode == 'loss':
output = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=label)
return {'loss': output.loss}
elif mode == 'predict':
output = self.model.generate(input_ids)
return output, label


def post_process(preds, labels):
preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
labels = torch.where(labels != -100, labels, tokenizer.pad_token_id)
labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
decoded_preds = [pred.split() for pred in preds]
decoded_labels = [[label.split()] for label in labels]
return decoded_preds, decoded_labels


class Accuracy(BaseMetric):

def process(self, data_batch, data_samples):
outputs, labels = data_samples
decoded_preds, decoded_labels = post_process(outputs, labels)
score = bleu_score(decoded_preds, decoded_labels)
prediction_lens = torch.tensor([
torch.count_nonzero(pred != tokenizer.pad_token_id)
for pred in outputs
],
dtype=torch.float64)

gen_len = torch.mean(prediction_lens).item()
self.results.append({
'gen_len': gen_len,
'bleu': score,
})

def compute_metrics(self, results):
return dict(
gen_len=np.mean([item['gen_len'] for item in results]),
bleu_score=np.mean([item['bleu'] for item in results]),
)


def collate_fn(data):
prefix = 'translate English to French: '
input_sequences = [prefix + item['translation']['en'] for item in data]
target_sequences = [item['translation']['fr'] for item in data]
input_dict = tokenizer(
input_sequences,
padding='longest',
return_tensors='pt',
)

label = tokenizer(
target_sequences,
padding='longest',
return_tensors='pt',
).input_ids
label[label ==
tokenizer.pad_token_id] = -100 # ignore contribution to loss
return dict(
label=label,
input_ids=input_dict.input_ids,
attention_mask=input_dict.attention_mask)


def main():
model = T5ForConditionalGeneration.from_pretrained('t5-small')

books = load_dataset('opus_books', 'en-fr')
books = books['train'].train_test_split(test_size=0.2)
train_set, test_set = books['train'], books['test']

train_loader = dict(
batch_size=16,
dataset=train_set,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=collate_fn)
test_loader = dict(
batch_size=32,
dataset=test_set,
sampler=dict(type='DefaultSampler', shuffle=False),
collate_fn=collate_fn)
runner = Runner(
model=MMT5ForTranslation(model),
train_dataloader=train_loader,
val_dataloader=test_loader,
optim_wrapper=dict(optimizer=dict(type=torch.optim.Adam, lr=2e-5)),
train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
val_cfg=dict(),
work_dir='t5_work_dir',
val_evaluator=dict(type=Accuracy))

runner.train()


if __name__ == '__main__':
main()

0 comments on commit 398d229

Please sign in to comment.