In [None]:
# Reference: Hugging Face H4 organization page (models and datasets) — https://huggingface.co/HuggingFaceH4


In [None]:
# Reference: curated list of instruction-tuning datasets — https://github.com/yaodongC/awesome-instruction-dataset


In [1]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from easynmt import EasyNMT
from optimum.bettertransformer import BetterTransformer
from datasets import load_from_disk, load_dataset
import os
import pandas as pd

class Translator:
    """Translate English text into Russian with one of several MT backends.

    Supported ``model_name`` values:

    * ``"facebook/nllb-200-3.3B"`` - Hugging Face NLLB-200 seq2seq model.
    * ``"facebook/wmt21-dense-24-wide-en-x"`` - Hugging Face WMT21 en->X model.
    * ``"opus-mt"`` - EasyNMT wrapper around OPUS-MT models.

    The model is loaded eagerly in ``__init__``. Calling the instance
    (``translator(text)``) translates one string and returns one string.
    """

    # Hub ids loaded through transformers (AutoModelForSeq2SeqLM + AutoTokenizer).
    _HF_MODELS = ("facebook/nllb-200-3.3B", "facebook/wmt21-dense-24-wide-en-x")
    # Names handed straight to EasyNMT.
    _EASYNMT_MODELS = ("opus-mt",)

    def __init__(self,
        model_name: str,
        device = 'cuda'
    ):
        """
        Args:
            model_name: one of the supported names listed in the class docstring.
            device: device the Hugging Face models and their tokenized inputs are
                moved to (a string such as ``'cuda'`` or a ``torch.device``).
                Not passed to the ``"opus-mt"`` (EasyNMT) backend.

        Raises:
            ValueError: if ``model_name`` is not supported.
        """
        self.model_name = model_name
        self.device = device
        self.model = None
        self.tokenizer = None
        self.init()

    def init(self):
        """Load the model (and, for Hugging Face models, the tokenizer).

        Raises:
            ValueError: if ``self.model_name`` is not one of the supported names.
        """
        print("Init model.")
        if self.model_name in self._HF_MODELS:
            self._load_hf_model()
        elif self.model_name in self._EASYNMT_MODELS:
            self.model = EasyNMT(self.model_name)
        else:
            # Fail fast: otherwise `self.model` stays None and the error only
            # surfaces later as a KeyError inside `translate`.
            raise ValueError(
                f"Unsupported model_name {self.model_name!r}; "
                f"expected one of {self._HF_MODELS + self._EASYNMT_MODELS}"
            )
        print("Model is initialized.")

    def _load_hf_model(self):
        """Load a Hugging Face seq2seq model and tokenizer prepared for inference."""
        model = AutoModelForSeq2SeqLM.from_pretrained(
            self.model_name,
            use_auth_token=True,
        )
        # Convert to optimum's BetterTransformer implementation for faster inference.
        model = BetterTransformer.transform(model)
        model.eval()
        # NOTE(review): `generate` is reached through attribute delegation to the
        # wrapped module, so the compiled forward may not actually be used during
        # generation - worth confirming it is worth the compile cost.
        model = torch.compile(model)
        self.model = model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            use_auth_token=True,
        )

    def translate(self, text: str):
        """Translate one English string into Russian with the configured backend.

        Empty or whitespace-only input (e.g. an empty ``context`` field) is
        returned unchanged without calling the model, since seq2seq models may
        otherwise generate spurious text for it.
        """
        if not text or not text.strip():
            return text

        func_map = {
            "facebook/nllb-200-3.3B": self.nllb_translate,
            "opus-mt": self.opusmt_translate,
            "facebook/wmt21-dense-24-wide-en-x": self.wmt21_translate
        }

        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            return func_map[self.model_name](text)

    def __call__(self, text: str):
        """Alias for :meth:`translate`."""
        return self.translate(text=text)

    def nllb_translate(self, text: str):
        """Translate with NLLB-200, forcing Russian (Cyrillic script) as the target."""
        inputs = self.tokenizer(text, return_tensors="pt")
        inputs = self.to_device(inputs=inputs)
        # NLLB picks the output language from a forced BOS language token.
        # `convert_tokens_to_ids` works across transformers versions, whereas
        # `lang_code_to_id` was deprecated and dropped in later releases.
        target_lang_id = self.tokenizer.convert_tokens_to_ids("rus_Cyrl")
        translated_tokens = self.model.generate(
            **inputs,
            forced_bos_token_id=target_lang_id,
        )
        return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    def opusmt_translate(self, text: str):
        """Translate en -> ru through EasyNMT's OPUS-MT backend."""
        return self.model.translate(
            text,
            source_lang="en" ,
            target_lang='ru'
        )

    def wmt21_translate(self, text):
        """Translate with the WMT21 en->X model, forcing Russian as the target."""
        inputs = self.tokenizer(text, return_tensors="pt")
        inputs = self.to_device(inputs=inputs)
        translated_tokens = self.model.generate(
            **inputs,
            forced_bos_token_id=self.tokenizer.get_lang_id("ru"),
        )
        return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    def to_device(self, inputs):
        """Move every tensor in ``inputs`` to ``self.device`` (in place) and return it."""
        for key in inputs.keys():
            inputs[key] = inputs[key].to(self.device)
        return inputs

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# NOTE(review): exploratory probe - constructs an EasyNMT model from the WMT21
# Hugging Face hub id and discards the result (it is not assigned to anything).
# Presumably checking whether EasyNMT can load this model directly; confirm the
# cell is still needed, since it may download/load a large model for nothing.
EasyNMT("facebook/wmt21-dense-24-wide-en-x")

In [7]:
# Pick the translation backend. Observed runtimes (presumably for translating
# a 50-example Dolly sample):
#   facebook/nllb-200-3.3B            -> 2 min 40 sec
#   facebook/wmt21-dense-24-wide-en-x -> 6 min 5 sec
#   opus-mt                           -> 45 sec
# model_name = "facebook/nllb-200-3.3B"
# model_name = "facebook/wmt21-dense-24-wide-en-x"
model_name = "opus-mt"
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
translator = Translator(model_name=model_name, device=device)

Init model.
Model is initialized.


In [8]:
# Quick smoke test of the loaded translator on a trivial sentence.
translator(text="hello world")



'Приветствую мир'

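### databricks/databricks-dolly-15k

Translate a sample of the Dolly instruction dataset (`instruction`, `context` and `response` fields) from English into Russian with the model selected above.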

In [9]:
# Dolly 15k: records with instruction / context / response / category fields.
data = load_dataset(path="databricks/databricks-dolly-15k")
data


100%|██████████| 1/1 [00:00<00:00, 642.41it/s]


DatasetDict({
    train: Dataset({
        features: ['instruction', 'context', 'response', 'category'],
        num_rows: 15014
    })
})

In [10]:
# Peek at one raw training record to see which fields need translating.
data["train"][0]


{'instruction': 'When did Virgin Australia start operating?',
 'context': "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3] It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]",
 'response': 'Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.',
 'category': 'closed_qa'}

In [None]:
# Translate a small Dolly sample with the selected model and save each original
# field next to its translation as a CSV for side-by-side quality review.
N_EXAMPLES = 50  # number of Dolly training examples to translate

base_folder = "/home/kosenko/deepspeed/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/datasets/translations_examples/"
model_path = model_name.replace("/", "_")  # hub ids contain "/", flatten for a file name
save_path = os.path.join(base_folder, f"{model_path}.csv")

# Refuse to overwrite the output of an earlier run.
assert not os.path.isfile(save_path), f'File {save_path} exists'

fields = ["context", "instruction", "response"]

# Columns: the original fields first, then one "<field>_translated" per field.
dataset_map = {field: [] for field in fields}
dataset_map.update({f'{field}_translated': [] for field in fields})

# `select` bounds the sample to exactly N_EXAMPLES rows; `min` guards against
# a split that is shorter than N_EXAMPLES.
sample = data["train"].select(range(min(N_EXAMPLES, len(data["train"]))))
for i, example in enumerate(sample):
    print("Progress ",i)
    for field in fields:
        text = example[field]
        print(f"Field name: {field}")
        print("Original: ", text)
        translated = translator(text=text)
        print("Translated: ", translated)
        dataset_map[field].append(text)
        dataset_map[f'{field}_translated'].append(translated)
        print()
    print("==" * 100)

pd.DataFrame(data=dataset_map).to_csv(save_path, index=False)


In [18]:
# Worker processes must be started with "spawn" rather than "fork": the CUDA
# runtime cannot be re-initialised in forked children. `force=True` replaces
# any start method that was already set in this kernel.
torch.multiprocessing.set_start_method(method='spawn', force=True)


In [19]:
def test(example):
    """Translate every field in `fields` of one dataset example with `translator`.

    Adds a "<field>_translated" key per field and returns the updated example.
    Returning it is required: when the mapped function returns None,
    `Dataset.map` returns the dataset unchanged, discarding the translations.
    """
    for field in fields:
        example[f'{field}_translated'] = translator(text=example[field])
    return example


# NOTE(review): with num_proc=3 each worker process needs its own copy of
# `translator` (presumably re-created from a pickle, including its model);
# confirm this works and fits in GPU memory. num_proc=1 is the safe fallback.
translated_sample = data["train"].select(range(50)).map(test, num_proc=3)
translated_sample

                                                               