-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a text translation example (#1283)
- Loading branch information
Showing
1 changed file
with
120 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import numpy as np | ||
import torch | ||
from datasets import load_dataset | ||
from torchtext.data.metrics import bleu_score | ||
from transformers import AutoTokenizer, T5ForConditionalGeneration | ||
|
||
from mmengine.evaluator import BaseMetric | ||
from mmengine.model import BaseModel | ||
from mmengine.runner import Runner | ||
|
||
tokenizer = AutoTokenizer.from_pretrained('t5-small') | ||
|
||
|
||
class MMT5ForTranslation(BaseModel): | ||
|
||
def __init__(self, model): | ||
super().__init__() | ||
self.model = model | ||
|
||
def forward(self, label, input_ids, attention_mask, mode): | ||
if mode == 'loss': | ||
output = self.model( | ||
input_ids=input_ids, | ||
attention_mask=attention_mask, | ||
labels=label) | ||
return {'loss': output.loss} | ||
elif mode == 'predict': | ||
output = self.model.generate(input_ids) | ||
return output, label | ||
|
||
|
||
def post_process(preds, labels): | ||
preds = tokenizer.batch_decode(preds, skip_special_tokens=True) | ||
labels = torch.where(labels != -100, labels, tokenizer.pad_token_id) | ||
labels = tokenizer.batch_decode(labels, skip_special_tokens=True) | ||
decoded_preds = [pred.split() for pred in preds] | ||
decoded_labels = [[label.split()] for label in labels] | ||
return decoded_preds, decoded_labels | ||
|
||
|
||
class Accuracy(BaseMetric): | ||
|
||
def process(self, data_batch, data_samples): | ||
outputs, labels = data_samples | ||
decoded_preds, decoded_labels = post_process(outputs, labels) | ||
score = bleu_score(decoded_preds, decoded_labels) | ||
prediction_lens = torch.tensor([ | ||
torch.count_nonzero(pred != tokenizer.pad_token_id) | ||
for pred in outputs | ||
], | ||
dtype=torch.float64) | ||
|
||
gen_len = torch.mean(prediction_lens).item() | ||
self.results.append({ | ||
'gen_len': gen_len, | ||
'bleu': score, | ||
}) | ||
|
||
def compute_metrics(self, results): | ||
return dict( | ||
gen_len=np.mean([item['gen_len'] for item in results]), | ||
bleu_score=np.mean([item['bleu'] for item in results]), | ||
) | ||
|
||
|
||
def collate_fn(data): | ||
prefix = 'translate English to French: ' | ||
input_sequences = [prefix + item['translation']['en'] for item in data] | ||
target_sequences = [item['translation']['fr'] for item in data] | ||
input_dict = tokenizer( | ||
input_sequences, | ||
padding='longest', | ||
return_tensors='pt', | ||
) | ||
|
||
label = tokenizer( | ||
target_sequences, | ||
padding='longest', | ||
return_tensors='pt', | ||
).input_ids | ||
label[label == | ||
tokenizer.pad_token_id] = -100 # ignore contribution to loss | ||
return dict( | ||
label=label, | ||
input_ids=input_dict.input_ids, | ||
attention_mask=input_dict.attention_mask) | ||
|
||
|
||
def main(): | ||
model = T5ForConditionalGeneration.from_pretrained('t5-small') | ||
|
||
books = load_dataset('opus_books', 'en-fr') | ||
books = books['train'].train_test_split(test_size=0.2) | ||
train_set, test_set = books['train'], books['test'] | ||
|
||
train_loader = dict( | ||
batch_size=16, | ||
dataset=train_set, | ||
sampler=dict(type='DefaultSampler', shuffle=True), | ||
collate_fn=collate_fn) | ||
test_loader = dict( | ||
batch_size=32, | ||
dataset=test_set, | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
collate_fn=collate_fn) | ||
runner = Runner( | ||
model=MMT5ForTranslation(model), | ||
train_dataloader=train_loader, | ||
val_dataloader=test_loader, | ||
optim_wrapper=dict(optimizer=dict(type=torch.optim.Adam, lr=2e-5)), | ||
train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1), | ||
val_cfg=dict(), | ||
work_dir='t5_work_dir', | ||
val_evaluator=dict(type=Accuracy)) | ||
|
||
runner.train() | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |