In [6]:
import torch
from transformers import MarianMTModel, MarianTokenizer
import json
from pathlib import Path


# Module-level device setting (0 is presumably a CUDA device index — TODO confirm).
# NOTE(review): the visible cells never read this variable; the BackTranslator
# instance below is created with an explicit device='cpu'.
device = 0


class BackTranslator:
    """Paraphrase texts by translating them into a pivot language and back.

    Two MarianMT checkpoints are loaded: ``tmp_lang_model_name`` is applied to
    the input texts first, and ``src_lang_model_name`` is then applied to the
    intermediate translations to bring them back to the original language.
    """

    def __init__(self, src_lang_model_name, tmp_lang_model_name, device, batch_size=16):
        """Load both translation models onto ``device``.

        Args:
            src_lang_model_name: Checkpoint name used for the second (return) pass.
            tmp_lang_model_name: Checkpoint name used for the first (pivot) pass.
            device: Anything accepted by ``Tensor.to`` / ``Module.to``.
            batch_size: Number of texts translated per batch; must be >= 1.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Validate before the (slow) model loading so that bad arguments fail fast.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        self.device = device
        self.batch_size = batch_size
        self.tmp_lang_tokenizer, self.tmp_lang_model = self.load_model(tmp_lang_model_name)
        self.src_lang_tokenizer, self.src_lang_model = self.load_model(src_lang_model_name)

    def load_model(self, model_name):
        """Return ``(tokenizer, model)`` for ``model_name`` with the model moved to ``self.device``."""
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        model = model.to(self.device)
        return tokenizer, model

    def dict_to_tensor(self, x):
        """Convert every value of ``x`` to an int64 tensor on ``self.device``.

        The mapping is updated in place and also returned. Values may be
        nested lists of ints or tensors.
        """
        # Snapshot the keys so that re-assigning values cannot disturb the iteration.
        for key in list(x.keys()):
            # as_tensor builds an integer tensor directly (no float32 round-trip)
            # and also accepts values that are tensors already.
            x[key] = torch.as_tensor(x[key], dtype=torch.long).to(self.device)
        return x

    def translate(self, texts, model, tokenizer):
        """Translate a batch of ``texts`` with ``model`` and return the decoded strings.

        Args:
            texts: Sequence of strings forming one batch.
            model: Object exposing ``generate(**inputs)``.
            tokenizer: Callable tokenizer that also exposes ``batch_decode``.

        Returns:
            List of translated strings; ``[]`` when ``texts`` is empty.
        """
        if len(texts) == 0:
            return []

        # The tokenizer's __call__ replaces the deprecated prepare_seq2seq_batch:
        # pad to the longest item of the batch and truncate over-long inputs.
        encoded = tokenizer(list(texts), return_tensors="pt", padding=True, truncation=True)
        encoded = self.dict_to_tensor(encoded)
        # Inference only: no gradients are needed.
        with torch.no_grad():
            generated = model.generate(**encoded)
        return tokenizer.batch_decode(generated, skip_special_tokens=True)

    def back_translate(self, texts, show_progress=True):
        """Back-translate ``texts`` in batches of ``self.batch_size``.

        Args:
            texts: Sliceable sequence of strings.
            show_progress: When true (default), print ``'***'`` after each batch.

        Returns:
            List with the back-translated strings of all batches, in input order.
        """
        results = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start:start + self.batch_size]
            pivot_texts = self.translate(batch, self.tmp_lang_model, self.tmp_lang_tokenizer)
            results.extend(
                self.translate(pivot_texts, self.src_lang_model, self.src_lang_tokenizer)
            )
            if show_progress:
                print('***')
        return results

In [7]:
# English -> French -> English round trip, run on the CPU in batches of three texts.
back_translator = BackTranslator(
    device='cpu',
    batch_size=3,
    # First pass: English to French (the pivot language).
    tmp_lang_model_name='Helsinki-NLP/opus-mt-en-fr',
    # Second pass: French back to English.
    src_lang_model_name='Helsinki-NLP/opus-mt-fr-en',
)
    

In [8]:



# Location of the IMDB sequence-classification dataset, relative to the notebook.
data_dir = Path("../../data/datasets/public/sequence_classification/imdb/")
# read_text opens and closes the file itself (no dangling handle) and decodes it
# explicitly as UTF-8 instead of relying on the platform default encoding.
data = json.loads((data_dir / 'train.json').read_text(encoding='utf-8'))

In [9]:
# Keep only the 'content' field of every loaded record.
texts = [record['content'] for record in data]

In [None]:

# Back-translate every text; back_translate prints one '***' line per processed
# batch, which is what produces the long run of asterisks in this cell's output.
back_texts = back_translator.back_translate(texts)
# print(back_texts)

***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***
***


In [None]:
# Display how many back-translated texts were produced
# (presumably one per input text — TODO confirm against len(texts)).
len(back_texts)