<a href="https://colab.research.google.com/github/haeggee/error-detection-mt/blob/main/dataset/backtranslation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Backtranslating the WMT'21 Dataset
A short notebook that translates all MT sentences in the WMT'21 train and dev sets back to English. We use the ML50 multilingual translation model (mBART-50) available on Huggingface. This is the same model family that was used for the original translation from En to (Ja, Zh, Cs, De); here we use its many-to-one (X → En) checkpoint. For reference, see

https://huggingface.co/facebook/mbart-large-50-many-to-one-mmt

https://arxiv.org/abs/2008.00401

In [None]:
# %pip (unlike !pip) installs into the environment of the running kernel.
# transformers/sentencepiece are pinned to the versions this notebook was run with;
# pickle5 backports pickle protocol 5, which is needed to read the prepared .pkl files.
%pip install -q transformers==4.8.2 sentencepiece==0.1.96 pickle5

In [None]:
# -- Uses our prepared .pkl files for the train and dev sets, e.g. upload dataset.zip and unzip it:
# !unzip -qq dataset

In [None]:
from google.colab import drive, files

# True: write the results to Google Drive; False: download them through the browser.
use_drive = True
if use_drive:
  drive.mount('/content/drive')

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import pickle5 as pickle
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# mBART-50 many-to-one checkpoint: translates from any of the 50 languages into English.
MODEL = "facebook/mbart-large-50-many-to-one-mmt"
batch_size = 8

In [None]:
filename_train = "dataset/wmt21_multi_train.pkl"
filename_dev = "dataset/wmt21_multi_dev.pkl"

# `with` closes the file handles once unpickling is done (plain open() leaked them).
# NOTE: pickle.load can execute arbitrary code -- only load trusted, self-prepared files.
with open(filename_train, 'rb') as f:
  dataset_train = pickle.load(f)
with open(filename_dev, 'rb') as f:
  dataset_dev = pickle.load(f)

# CustomDataset looks rows up with .loc[index], so both splits need a 0..n-1 index.
dataset_train = dataset_train.reset_index(drop=True)
dataset_dev = dataset_dev.reset_index(drop=True)

dataset_train.head(3)

Unnamed: 0,id,src,mt,list_scores,avg_scores,critical,language_pair
0,9845,Are you one of the 74 stalkers watching my tal...,"Jsi jeden z 74 sledovatelů , kteří sledují mou...","[0, 1, 0]",NOT,0,en-cs
1,38,I believe we now have a fully-referenced article.,"Domnívám se , že nyní máme plně odkazovaný člá...","[0, 0, 0]",NOT,0,en-cs
2,9117,"FUCK YOU, GO SUCK SOME DOG COCK! YOU FUCKING D...","Do prdele , pojď sehnat nějakou děvku , ty seš...","[1, 1, 1]",ERR,1,en-cs


In [None]:
# Length of every MT output in *characters* (train rows first, then dev rows).
len_mt = [len(mt) for split in (dataset_train, dataset_dev) for mt in split.mt]

In [None]:
# Longest MT output, measured in characters (not tokens); shown as the cell's output.
max(len_mt)

402


In [None]:
class CustomDataset(Dataset):
    """Map-style dataset that tokenizes the MT output (`mt` column) of a WMT'21 frame.

    Each item is a `(token_ids, attention_mask)` pair of 1-D tensors, padded and
    truncated to `maxlen`, for feeding an mBART-50 many-to-one model.

    Args:
        data: pandas DataFrame with an `mt` column and, optionally, a
            `language_pair` column such as "en-cs". Rows are looked up with
            `.loc[index]`, so the frame needs a 0..n-1 index.
        with_labels: unused; kept for backward compatibility with existing calls.
        trans_model: Huggingface checkpoint whose tokenizer is loaded.
        maxlen: padding/truncation length in tokens. The default 402 is the
            longest MT output measured in characters, used as a generous bound.
    """

    # mBART-50 language codes for the target side of the WMT'21 pairs. That side
    # (the MT output) is the *source* language of the backtranslation.
    LANG_CODES = {'cs': 'cs_CZ', 'de': 'de_DE', 'ja': 'ja_XX', 'zh': 'zh_CN'}

    def __init__(self, data, with_labels=True, trans_model='facebook/mbart-large-50-many-to-one-mmt', maxlen=402):
        self.data = data  # pandas dataframe
        self.tokenizer = AutoTokenizer.from_pretrained(trans_model)
        self.maxlen = maxlen
        # Source-language code the tokenizer was loaded with. It is None for tokenizers
        # without a `src_lang`; it is also the fallback for unknown language pairs.
        self.default_src_lang = getattr(self.tokenizer, 'src_lang', None)

    def __len__(self):
        return len(self.data)

    def _src_lang(self, index):
        """Return the mBART-50 code for the language of row `index`'s MT sentence."""
        if 'language_pair' in self.data.columns:
            target = str(self.data.loc[index, 'language_pair']).split('-')[-1]
            return self.LANG_CODES.get(target, self.default_src_lang)
        return self.default_src_lang

    def __getitem__(self, index):
        sent = str(self.data.loc[index, 'mt'])
        if self.default_src_lang is not None:
            # mBART-50 prepends the source-language code to the input ids. It must name
            # the language of the MT sentence (cs/de/ja/zh), not the English default.
            self.tokenizer.src_lang = self._src_lang(index)
        # Tokenize the single MT sentence into token ids and an attention mask.
        encoded = self.tokenizer(sent, padding='max_length',  # Pad to max_length
                                       truncation=True,  # Truncate to max_length
                                       max_length=self.maxlen,
                                       return_tensors='pt')  # Return torch.Tensor objects
        # The tokenizer returns a batch of one; drop that leading dimension.
        token_ids = encoded['input_ids'].squeeze(0)
        attn_masks = encoded['attention_mask'].squeeze(0)
        return token_ids, attn_masks

In [None]:
# Creating instances of training and validation set.
# The checkpoint is passed by keyword: the second positional parameter is `with_labels`.
train_set = CustomDataset(dataset_train, trans_model=MODEL)
val_set = CustomDataset(dataset_dev, trans_model=MODEL)
# Creating instances of training and validation dataloaders.
# shuffle must stay False: generated sentences are written back to the DataFrame
# rows in loader order (the `btr` column).
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=False, num_workers=2)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=461.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1508.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=5069051.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=649.0, style=ProgressStyle(description_…




In [None]:
# Buffers for the backtranslated sentences, filled in DataLoader order.
backtrans_train, backtrans_dev = [], []

# Separate tokenizer instance, used here only to decode generated ids back to text.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL).to(device)

In [None]:
# Start from an empty buffer so re-running this cell does not append duplicates.
backtrans_train = []
for seq, attn in tqdm(train_loader):
  seq, attn = seq.to(device), attn.to(device)
  # Pass the mask explicitly so the encoder ignores the max_length padding.
  generated_tokens = model.generate(seq, attention_mask=attn, num_beams=4, early_stopping=True)
  sents = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
  backtrans_train.extend(sents)

To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)
100%|██████████| 3734/3734 [1:27:03<00:00,  1.40s/it]


In [None]:
# Start from an empty buffer so re-running this cell does not append duplicates.
backtrans_dev = []
for seq, attn in tqdm(val_loader):
  seq, attn = seq.to(device), attn.to(device)
  # Pass the mask explicitly so the encoder ignores the max_length padding.
  generated_tokens = model.generate(seq, attention_mask=attn, num_beams=4, early_stopping=True)
  sents = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
  backtrans_dev.extend(sents)

100%|██████████| 500/500 [11:32<00:00,  1.39s/it]


In [None]:
# Loaders do not shuffle, so the i-th generated sentence belongs to the i-th row;
# pandas raises if a list's length does not match its frame.
dataset_train['btr'], dataset_dev['btr'] = backtrans_train, backtrans_dev

In [None]:
if not use_drive:
  # Write the pickles locally, then trigger a browser download of each one.
  dataset_train.to_pickle('wmt21_multi_btr_train.pkl')
  dataset_dev.to_pickle('wmt21_multi_btr_dev.pkl')
  files.download('wmt21_multi_btr_train.pkl')
  files.download('wmt21_multi_btr_dev.pkl')
else:
  # Persist straight to the mounted Google Drive.
  dataset_train.to_pickle('/content/drive/MyDrive/wmt21_multi_btr_train.pkl')
  dataset_dev.to_pickle('/content/drive/MyDrive/wmt21_multi_btr_dev.pkl')