## Quickstart

Install the library with pip:

In [None]:
import os
import json
from typing import Union, List, Dict

import transformers
import torch
from tqdm.auto import tqdm
!pip install utils


def _select_device(device_selection):
    selected = device_selection.lower()
    if selected == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif selected == "cpu":
        device = torch.device("cpu")
    elif selected == "gpu":
        device = torch.device("cuda")
    else:
        device = torch.device(selected)

    return device


def _resolve_lang_codes(source: str, target: str, model_family: str):
    """Convert `source`/`target` language names or codes into model language codes.

    Each value is first looked up (under a few spellings) in the
    language -> code map of `model_family`. If no spelling matches, the value
    is assumed to already be a language code. The final value must be one of
    the family's available codes.

    Returns:
        A ``(source_code, target_code)`` tuple.

    Raises:
        ValueError: if either value does not resolve to an available code.
    """

    def error_message(variable, value):
        return f'Your {variable}="{value}" is not valid. Please run `print(mt.available_languages())` to see which languages are available.'

    lang_code_map = utils.get_lang_code_map(model_family)
    # Fetched once and reused for both checks below.
    valid_codes = utils.available_codes(model_family)

    def to_code(value):
        # Exact spelling first, so multi-word names like "Scottish Gaelic"
        # resolve (str.capitalize() would turn them into "Scottish gaelic").
        # Then the capitalize() lookup ("arabic" -> "Arabic"), then title
        # case for lowercase multi-word input ("scottish gaelic").
        for candidate in (value, value.capitalize(), value.title()):
            if candidate in lang_code_map:
                return lang_code_map[candidate]
        # Not a known language name: assume it's already a code.
        return value

    source = to_code(source)
    target = to_code(target)

    # If the code is not valid, raise an error.
    if source not in valid_codes:
        raise ValueError(error_message("source", source))
    if target not in valid_codes:
        raise ValueError(error_message("target", target))

    return source, target


def _resolve_tokenizer(model_family):
    """Return the Hugging Face tokenizer class used by `model_family`.

    Supported families are "mbart50" (``MBart50TokenizerFast``) and "m2m100"
    (``M2M100Tokenizer``).

    Raises:
        ValueError: if `model_family` is not one of the supported families.
    """
    tokenizer_classes = {
        "mbart50": transformers.MBart50TokenizerFast,
        "m2m100": transformers.M2M100Tokenizer,
    }
    try:
        return tokenizer_classes[model_family]
    except KeyError:
        raise ValueError(
            f"{model_family} is not a valid value for model_family. Please choose model_family to be equal to one of the following values: {list(tokenizer_classes.keys())}"
        ) from None


def _resolve_transformers_model(model_family):
    """Return the conditional-generation model class used by `model_family`.

    Supported families are "mbart50" (``MBartForConditionalGeneration``) and
    "m2m100" (``M2M100ForConditionalGeneration``).

    Raises:
        ValueError: if `model_family` is not one of the supported families.
    """
    model_classes = {
        "mbart50": transformers.MBartForConditionalGeneration,
        "m2m100": transformers.M2M100ForConditionalGeneration,
    }
    # Every mapped value is a class, so None unambiguously means "unknown".
    resolved = model_classes.get(model_family)
    if resolved is None:
        message = f"{model_family} is not a valid value for model_family. Please choose model_family to be equal to one of the following values: {list(model_classes.keys())}"
        raise ValueError(message)
    return resolved


def _infer_model_family(model_or_path):
    di = {
        "facebook/mbart-large-50-many-to-many-mmt": "mbart50",
        "facebook/m2m100_418M": "m2m100",
        "facebook/m2m100_1.2B": "m2m100",
    }

    if model_or_path in di:
        return di[model_or_path]
    else:
        error_msg = f'Unable to infer the model_family from "{model_or_path}". Try explicitly setting the value of model_family to "mbart50" or "m2m100".'
        raise ValueError(error_msg)


def _infer_model_or_path(model_or_path):
    di = {
        "mbart50": "facebook/mbart-large-50-many-to-many-mmt",
        "m2m100": "facebook/m2m100_418M",
        "m2m100-small": "facebook/m2m100_418M",
        "m2m100-medium": "facebook/m2m100_1.2B",
    }

    return di.get(model_or_path, model_or_path)


class TranslationModel:
    """Multilingual translation wrapper around an "mbart50" or "m2m100" model.

    Holds a Hugging Face tokenizer and conditional-generation model for one
    model family, and exposes `translate`, language/code lookups, and
    save/load helpers.
    """

    def __init__(
        self,
        model_or_path: str = "m2m100",
        tokenizer_path: str = None,
        device: str = "auto",
        model_family: str = None,
        model_options: dict = None,
        tokenizer_options: dict = None,
    ):
        """
        *Instantiates a multilingual transformer model for translation.*
        {{params}}
        {{model_or_path}} The path or the name of the model. Equivalent to the first argument of `AutoModel.from_pretrained()`. You can also specify shorthands ("mbart50", "m2m100", "m2m100-small" and "m2m100-medium"). A path ending in ".pt" is loaded with `torch.load()` instead.
        {{tokenizer_path}} The path to the tokenizer. By default, it will be set to `model_or_path`.
        {{device}} "cpu", "gpu" or "auto" (or any other string accepted by `torch.device()`). If it's set to "auto", will try to select a GPU when available or else fall back to CPU.
        {{model_family}} Either "mbart50" or "m2m100". By default, it will be inferred based on `model_or_path`. Needs to be explicitly set if `model_or_path` is a path.
        {{model_options}} The keyword arguments passed to the model's `from_pretrained()`, which is a transformer for conditional generation. Not used when loading a ".pt" file.
        {{tokenizer_options}} The keyword arguments passed to the model's tokenizer.
        """
        model_or_path = _infer_model_or_path(model_or_path)
        self.model_or_path = model_or_path
        self.device = _select_device(device)

        # Resolve default values
        tokenizer_path = tokenizer_path or self.model_or_path
        model_options = model_options or {}
        tokenizer_options = tokenizer_options or {}
        self.model_family = model_family or _infer_model_family(model_or_path)

        # Load the tokenizer
        TokenizerFast = _resolve_tokenizer(self.model_family)
        self._tokenizer = TokenizerFast.from_pretrained(
            tokenizer_path, **tokenizer_options
        )

        # Load the model either from a saved torch model or from transformers.
        if model_or_path.endswith(".pt"):
            # SECURITY: torch.load unpickles arbitrary Python objects; only
            # load ".pt" files you created yourself or otherwise trust.
            # NOTE(review): recent torch releases default to
            # weights_only=True, which refuses a whole pickled nn.Module;
            # confirm against the torch version in use.
            self._transformers_model = torch.load(
                model_or_path, map_location=self.device
            ).eval()
        else:
            ModelForConditionalGeneration = _resolve_transformers_model(
                self.model_family
            )
            self._transformers_model = (
                ModelForConditionalGeneration.from_pretrained(
                    self.model_or_path, **model_options
                )
                .to(self.device)
                .eval()
            )

    def translate(
        self,
        text: Union[str, List[str]],
        source: str,
        target: str,
        batch_size: int = 32,
        verbose: bool = False,
        generation_options: dict = None,
    ) -> Union[str, List[str]]:
        """
        *Translates a string or a list of strings from a source to a target language.*
        {{params}}
        {{text}} The content you want to translate.
        {{source}} The language of the original text.
        {{target}} The language of the translated text.
        {{batch_size}} The number of samples to load at once. If set to `None`, it will process everything at once.
        {{verbose}} Whether to display the progress bar for every batch processed.
        {{generation_options}} The keyword arguments passed to `model.generate()`, where `model` is the underlying transformers model. The dictionary you pass is not modified.
        Note:
        - Run `print(mt.available_languages())` to see what's available.
        - A smaller value is preferred for `batch_size` if your (video) RAM is limited.
        """
        # Work on a copy. The setdefault() below would otherwise write this
        # call's forced_bos_token_id into the caller's dict and silently pin
        # the target language for any later call that reuses it.
        generation_options = dict(generation_options or {})

        source, target = _resolve_lang_codes(source, target, self.model_family)
        self._tokenizer.src_lang = source

        # Normalise to a list; remember whether to unwrap the result at the end.
        single_input = isinstance(text, str)
        if single_input:
            text = [text]

        if batch_size is None:
            # One batch for everything. DataLoader rejects batch_size=0, so
            # an empty list keeps size 1 (and simply yields no batches).
            batch_size = max(len(text), 1)

        # Steer generation toward `target` via its language token, unless the
        # caller supplied their own forced_bos_token_id.
        generation_options.setdefault(
            "forced_bos_token_id", self._tokenizer.lang_code_to_id[target]
        )

        data_loader = torch.utils.data.DataLoader(text, batch_size=batch_size)
        output_text = []

        with torch.no_grad():
            for batch in tqdm(data_loader, disable=not verbose):
                encoded = self._tokenizer(batch, return_tensors="pt", padding=True)
                encoded.to(self.device)

                generated_tokens = self._transformers_model.generate(
                    **encoded, **generation_options
                ).cpu()

                decoded = self._tokenizer.batch_decode(
                    generated_tokens, skip_special_tokens=True
                )

                output_text.extend(decoded)

        # A single input string gets a single output string back.
        if single_input and len(output_text) == 1:
            return output_text[0]

        return output_text

    def get_transformers_model(self):
        """
        *Retrieve the underlying transformers model (mBART-50 or M2M100, depending on `model_family`).*
        """
        return self._transformers_model

    def get_tokenizer(self):
        """
        *Retrieve the underlying huggingface tokenizer (mBART-50 or M2M100, depending on `model_family`).*
        """
        return self._tokenizer

    def available_languages(self) -> List[str]:
        """
        *Returns all the available languages for a given `dlt.TranslationModel`
        instance.*
        """
        return utils.available_languages(self.model_family)

    def available_codes(self) -> List[str]:
        """
        *Returns all the available codes for a given `dlt.TranslationModel`
        instance.*
        """
        return utils.available_codes(self.model_family)

    def get_lang_code_map(self) -> Dict[str, str]:
        """
        *Returns the language -> codes dictionary for a given `dlt.TranslationModel`
        instance.*
        """
        return utils.get_lang_code_map(self.model_family)

    def save_obj(self, path: str = "saved_model") -> None:
        """
        *Saves your model as a torch object and save your tokenizer.*
        {{params}}
        {{path}} The directory where you want to save your model and tokenizer. It will contain `weights.pt`, the tokenizer files and `dlt_config.json`.
        """
        os.makedirs(path, exist_ok=True)
        torch.save(self._transformers_model, os.path.join(path, "weights.pt"))
        self._tokenizer.save_pretrained(path)

        dlt_config = dict(model_family=self.model_family)
        # Context manager guarantees the config is flushed and the handle closed.
        with open(os.path.join(path, "dlt_config.json"), "w", encoding="utf-8") as f:
            json.dump(dlt_config, f)

    @classmethod
    def load_obj(cls, path: str = "saved_model", **kwargs):
        """
        *Initialize `dlt.TranslationModel` from the torch object and tokenizer
        saved with `dlt.TranslationModel.save_obj`*
        {{params}}
        {{path}} The directory where your torch model and tokenizer are stored
        Any extra keyword arguments override the saved config and are passed to the constructor.
        """
        with open(os.path.join(path, "dlt_config.json"), encoding="utf-8") as f:
            config_prev = json.load(f)
        config_prev.update(kwargs)
        return cls(
            model_or_path=os.path.join(path, "weights.pt"),
            tokenizer_path=path,
            **config_prev,
        )

Collecting utils
  Downloading utils-1.0.1-py2.py3-none-any.whl (21 kB)
Installing collected packages: utils
Successfully installed utils-1.0.1


In [None]:
!pip install -q dl-translate

[K     |████████████████████████████████| 3.4 MB 5.3 MB/s 
[K     |████████████████████████████████| 1.2 MB 52.1 MB/s 
[K     |████████████████████████████████| 596 kB 50.2 MB/s 
[K     |████████████████████████████████| 61 kB 450 kB/s 
[K     |████████████████████████████████| 3.3 MB 34.0 MB/s 
[K     |████████████████████████████████| 895 kB 41.1 MB/s 
[?25h

To translate some text:

In [None]:
import dl_translate as dlt

# Instantiate the default translation model; weights are downloaded on first use.
mt = dlt.TranslationModel()

text_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
# Languages can be given through the `dlt.lang` constants.
mt.translate(text_hi, source=dlt.lang.HINDI, target=dlt.lang.ENGLISH)

Downloading:   0%|          | 0.00/3.54M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/272 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/908 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.80G [00:00<?, ?B/s]

'UN chief says there is no military solution in Syria'

Above, you can see that `dlt.lang` contains variables representing each of the available languages with auto-complete support. Alternatively, you can specify the language (e.g. "Arabic") or the language code (e.g. "fr" for French):

In [None]:
text_ar = ' سوف اسافر '
# Plain language names ("Arabic") and language codes ("en") are accepted as well.
mt.translate(text_ar, source="Arabic", target="en")

'I will travel.'

If you want to verify whether a language is available, you can check it:

In [None]:
# Inspect which languages and codes the loaded model supports.
print(mt.available_languages())  # All languages that you can use
print(mt.available_codes())  # Code corresponding to each language accepted
print(mt.get_lang_code_map())  # Dictionary of lang -> code

('Afrikaans', 'Amharic', 'Arabic', 'Asturian', 'Azerbaijani', 'Bashkir', 'Belarusian', 'Bulgarian', 'Bengali', 'Breton', 'Bosnian', 'Catalan', 'Valencian', 'Cebuano', 'Czech', 'Welsh', 'Danish', 'German', 'Greek', 'English', 'Spanish', 'Estonian', 'Persian', 'Fulah', 'Finnish', 'French', 'Western Frisian', 'Irish', 'Gaelic', 'Scottish Gaelic', 'Galician', 'Gujarati', 'Hausa', 'Hebrew', 'Hindi', 'Croatian', 'Haitian', 'Haitian Creole', 'Hungarian', 'Armenian', 'Indonesian', 'Igbo', 'Iloko', 'Icelandic', 'Italian', 'Japanese', 'Javanese', 'Georgian', 'Kazakh', 'Khmer', 'Central Khmer', 'Kannada', 'Korean', 'Luxembourgish', 'Letzeburgesch', 'Ganda', 'Lingala', 'Lao', 'Lithuanian', 'Latvian', 'Malagasy', 'Macedonian', 'Malayalam', 'Mongolian', 'Marathi', 'Malay', 'Burmese', 'Nepali', 'Dutch', 'Flemish', 'Norwegian', 'Northern Sotho', 'Occitan', 'Oriya', 'Panjabi', 'Punjabi', 'Polish', 'Pushto', 'Pashto', 'Portuguese', 'Romanian', 'Moldavian', 'Moldovan', 'Russian', 'Sindhi', 'Sinhala', 'Si

## Usage

### Selecting a device

When you load the model, you can specify the device:
```python
mt = dlt.TranslationModel(device="auto")
```

By default, the value will be `device="auto"`, which means it will use a GPU if possible. You can also explicitly set `device="cpu"` or `device="gpu"`, or some other strings accepted by [`torch.device()`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device). __In general, it is recommended to use a GPU if you want a reasonable processing time.__

Let's check what we originally loaded:

In [None]:
# The torch.device chosen when the model was loaded.
mt.device

device(type='cuda')

### Loading from a path

By default, `dlt.TranslationModel` will download the model from the [huggingface repo](https://huggingface.co/facebook/m2m100_418M) and cache it. However, you are free to load from a path:
```python
mt = dlt.TranslationModel("/path/to/your/model/directory/", model_family="mbart50")
```
Make sure that your tokenizer is also stored in the same directory if you use this approach.


### Using a different model

You can also choose another model that has [a similar format](https://huggingface.co/models?filter=mbart-50). In those cases, it's preferable to specify the model family:
```python
mt = dlt.TranslationModel("facebook/mbart-large-50-one-to-many-mmt", model_family="mbart50")
mt = dlt.TranslationModel("facebook/m2m100_1.2B", model_family="m2m100")
```
Note that the available languages will change if you do this, so you will not be able to leverage `dlt.lang` or `dlt.utils`.


### Breaking down into sentences

It is not recommended to use extremely long texts as it takes more time to process. Instead, you can try to break them down into sentences with the help of `nltk`. First install the library with `pip install nltk`, then run:

In [None]:
import nltk
# Download the "punkt" resource used for the sentence splitting below.
nltk.download("punkt")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
text = "Mr. Smith went to his favorite cafe. There, he met his friend Dr. Doe."
sents = nltk.tokenize.sent_tokenize(text, "english")  # nltk takes its own language name here, not dlt.lang.ENGLISH
# Translate sentence by sentence, then re-join into a single string.
" ".join(mt.translate(sents, source=dlt.lang.ENGLISH, target='Arabic'))

'سميث ذهب إلى مقهى المفضل لديه. هناك، التقى صديقه الدكتور دو.'

### Setting a `batch_size` and verbosity when calling `dlt.TranslationModel.translate`

It's possible to set a batch size (i.e. the number of elements processed at once) for `mt.translate` and whether you want to see the progress bar or not:

In [None]:
# batch_size sets how many sentences go through each generate() call; verbose shows a progress bar.
mt.translate(sents, source=dlt.lang.ENGLISH, target='Arabic', batch_size=32, verbose=True)

  0%|          | 0/1 [00:00<?, ?it/s]

['سميث ذهب إلى مقهى المفضل لديه.', 'هناك، التقى صديقه الدكتور دو.']

If you set `batch_size=None`, it will compute the entire `text` at once rather than splitting into "chunks". We recommend lowering `batch_size` if you do not have a lot of RAM or VRAM and run into CUDA out-of-memory errors. Set a higher value if you are using a high-end GPU and the VRAM is not fully utilized.


### `dlt.utils` module

An alternative to `mt.available_languages()` is the `dlt.utils` module. You can use it to find out which languages and codes are available:


In [None]:
# Look up languages/codes by model family name ('mbart50') rather than through `mt`.
print(dlt.utils.available_languages('mbart50'))  # All languages that you can use
print(dlt.utils.available_codes('mbart50'))  # Code corresponding to each language accepted
print(dlt.utils.get_lang_code_map('mbart50'))  # Dictionary of lang -> code

('Arabic', 'Czech', 'German', 'English', 'Spanish', 'Estonian', 'Finnish', 'French', 'Gujarati', 'Hindi', 'Italian', 'Japanese', 'Kazakh', 'Korean', 'Lithuanian', 'Latvian', 'Burmese', 'Nepali', 'Dutch', 'Romanian', 'Russian', 'Sinhala', 'Turkish', 'Vietnamese', 'Chinese', 'Afrikaans', 'Azerbaijani', 'Bengali', 'Persian', 'Hebrew', 'Croatian', 'Indonesian', 'Georgian', 'Khmer', 'Macedonian', 'Malayalam', 'Mongolian', 'Marathi', 'Polish', 'Pashto', 'Portuguese', 'Swedish', 'Swahili', 'Tamil', 'Telugu', 'Thai', 'Tagalog', 'Ukrainian', 'Urdu', 'Xhosa', 'Galician', 'Slovene')
('ar_AR', 'cs_CZ', 'de_DE', 'en_XX', 'es_XX', 'et_EE', 'fi_FI', 'fr_XX', 'gu_IN', 'hi_IN', 'it_IT', 'ja_XX', 'kk_KZ', 'ko_KR', 'lt_LT', 'lv_LV', 'my_MM', 'ne_NP', 'nl_XX', 'ro_RO', 'ru_RU', 'si_LK', 'tr_TR', 'vi_VN', 'zh_CN', 'af_ZA', 'az_AZ', 'bn_IN', 'fa_IR', 'he_IL', 'hr_HR', 'id_ID', 'ka_GE', 'km_KH', 'mk_MK', 'ml_IN', 'mn_MN', 'mr_IN', 'pl_PL', 'ps_AF', 'pt_XX', 'sv_SE', 'sw_KE', 'ta_IN', 'te_IN', 'th_TH', 'tl_XX

## Advanced

The following section assumes you have knowledge of PyTorch and Huggingface Transformers.

### Saving and loading

If you wish to reduce the loading time of the translation model, you can use `save_obj`:


In [None]:
# Writes weights.pt, the tokenizer files and dlt_config.json into ./saved_model.
mt.save_obj("saved_model")


Then later you can reload it with `load_obj`:

In [None]:
%%time
# Reload the model saved in the "saved_model" directory (timed by the cell magic).
mt = dlt.TranslationModel.load_obj('saved_model')

CPU times: user 570 ms, sys: 810 ms, total: 1.38 s
Wall time: 1.39 s



**Warning:** Only use this if you are certain the torch module saved in `saved_model/weights.pt` can be correctly loaded. Indeed, it is possible that the `huggingface`, `torch` or some other dependencies change between when you called `save_obj` and `load_obj`, and that might break your code. Thus, it is recommended to only run `load_obj` in the same environment/session as `save_obj`. **Note this method might be deprecated in the future once there's no speed benefit in loading this way.**


### Interacting with underlying model and tokenizer

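When initializing `mt`, you can pass in arguments for the underlying transformers model and tokenizer. Depending on the model family, they are passed respectively to `MBartForConditionalGeneration.from_pretrained` / `M2M100ForConditionalGeneration.from_pretrained` and `MBart50TokenizerFast.from_pretrained` / `M2M100Tokenizer.from_pretrained`: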

```python
mt = dlt.TranslationModel(
    model_options=dict(
        state_dict=...,
        cache_dir=...,
        ...
    ),
    tokenizer_options=dict(
        tokenizer_file=...,
        eos_token=...,
        ...
    )
)
```

You can also access the underlying `transformers` model and `tokenizer`:

In [None]:
# Access the wrapped Hugging Face model and tokenizer directly.
bart = mt.get_transformers_model()
tokenizer = mt.get_tokenizer()

print(tokenizer)
print(bart)

See the [huggingface docs](https://huggingface.co/transformers/master/model_doc/mbart.html) for more information.


### `bart_model.generate()` keyword arguments

When running `mt.translate`, you can also give a `generation_options` dictionary that is passed as keyword arguments to the underlying `bart_model.generate()` method:

In [None]:
# generation_options is forwarded as keyword arguments to the underlying model's generate().
mt.translate(
    sents,
    source=dlt.lang.ENGLISH,
    target=dlt.lang.SPANISH,
    generation_options=dict(num_beams=5, max_length=128)
)

Learn more in the [huggingface docs](https://huggingface.co/transformers/main_classes/model.html#transformers.generation_utils.GenerationMixin.generate).


## Acknowledgement

`dl-translate` is built on top of Huggingface's implementation of two models created by Facebook AI Research.

1. The multilingual BART finetuned on many-to-many translation of over 50 languages, which is [documented here](https://huggingface.co/transformers/master/model_doc/mbart.html). The original paper was written by Tang et al. from Facebook AI Research; you can [find it here](https://arxiv.org/pdf/2008.00401.pdf) and cite it using the following:
    ```
    @article{tang2020multilingual,
        title={Multilingual translation with extensible multilingual pretraining and finetuning},
        author={Tang, Yuqing and Tran, Chau and Li, Xian and Chen, Peng-Jen and Goyal, Naman and Chaudhary, Vishrav and Gu, Jiatao and Fan, Angela},
        journal={arXiv preprint arXiv:2008.00401},
        year={2020}
    }
    ```
2. The transformer model published in [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Fan et al., which supports over 100 languages. You can cite it as follows:
   ```
   @misc{fan2020englishcentric,
        title={Beyond English-Centric Multilingual Machine Translation}, 
        author={Angela Fan and Shruti Bhosale and Holger Schwenk and Zhiyi Ma and Ahmed El-Kishky and Siddharth Goyal and Mandeep Baines and Onur Celebi and Guillaume Wenzek and Vishrav Chaudhary and Naman Goyal and Tom Birch and Vitaliy Liptchinsky and Sergey Edunov and Edouard Grave and Michael Auli and Armand Joulin},
        year={2020},
        eprint={2010.11125},
        archivePrefix={arXiv},
        primaryClass={cs.CL}
    }
   ```

`dlt` is a wrapper with useful `utils` to save you time. For huggingface's `transformers`, the following snippet is shown as an example:
```python
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
article_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا."

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

# translate Hindi to French
tokenizer.src_lang = "hi_IN"
encoded_hi = tokenizer(article_hi, return_tensors="pt")
generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"])
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
# => "Le chef de l 'ONU affirme qu 'il n 'y a pas de solution militaire en Syria."

# translate Arabic to English
tokenizer.src_lang = "ar_AR"
encoded_ar = tokenizer(article_ar, return_tensors="pt")
generated_tokens = model.generate(**encoded_ar, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
# => "The Secretary-General of the United Nations says there is no military solution in Syria."
```

With `dlt`, you can run:
```python
import dl_translate as dlt

article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
article_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا."

mt = dlt.TranslationModel()
translated_fr = mt.translate(article_hi, source=dlt.lang.HINDI, target=dlt.lang.FRENCH)
translated_en = mt.translate(article_ar, source=dlt.lang.ARABIC, target=dlt.lang.ENGLISH)
```

Notice you don't have to think about tokenizers, conditional generation, pretrained models, and regional codes; you can just tell the model what to translate!

If you are experienced with `huggingface`'s ecosystem, then you should be familiar enough with the example above that you wouldn't need this library. However, if you've never heard of huggingface or mBART, then I hope using this library will give you enough motivation to [learn more about them](https://github.com/huggingface/transformers) :)