In this notebook, we build a class for running inference with a JoeyNMT model.

We assume that you have your best JoeyNMT model saved on Google Drive and want to use it for inference.

In [1]:
# # install Joeynmt
# !pip install joeynmt==2.3.0

In [2]:
# from google.colab import drive
# drive.mount('/content/drive')

In [3]:
# # copy the model from drive
# !cp -R /content/drive/MyDrive/mt-dyu-fr/dyu_fr ./models

## (Optional) Download model from Hugging Face

In [11]:
# # Remember to run `huggingface-cli login` before you run the code below
# from huggingface_hub import snapshot_download

# HF_REPO_NAME = "MelioAI/dyu-fr-joeynmt"
# local_dir = "../saved_model/lean_model"

# snapshot_download(
#     repo_id=HF_REPO_NAME,
#     # allow_patterns=["*.md", "*.json"],
#     # ignore_patterns="vocab.json",
#     local_dir=local_dir
# )

In [4]:
import torch
from joeynmt.config import load_config, parse_global_args
from joeynmt.prediction import predict, prepare


class JoeyNMTModel:
    """
    Wrapper that loads a JoeyNMT model once and translates single sentences.

    The model and the test dataset are built in ``__init__`` and reused by
    every call to :meth:`translate`.

    :param config_path: Path to YAML config file
    :param n_best: return this many hypotheses, <= beam (currently only 1)
    :param seed: value passed to ``torch.manual_seed`` before the model is built
    """
    def __init__(self, config_path: str, n_best: int = 1, seed: int = 42):
        # Seed torch before anything else is built.
        torch.manual_seed(seed)
        cfg = load_config(config_path)
        args = parse_global_args(cfg, rank=0, mode="translate")
        # Override only test.n_best; both levels are rebuilt with _replace
        # instead of being modified in place.
        self.args = args._replace(test=args.test._replace(n_best=n_best))
        # build model: prepare() returns four values, of which only the first
        # (the model) and the last (the test dataset) are kept.
        self.model, _, _, self.test_data = prepare(self.args, rank=0, mode="translate")

    def _translate_data(self):
        """
        Run ``predict`` on whatever is currently held in ``self.test_data``.

        :return: tuple ``(hypotheses, trg_tokens, trg_scores)`` taken from
            positions 3-5 of the six values returned by ``predict``.
        """
        _, _, hypotheses, trg_tokens, trg_scores, _ = predict(
            model=self.model,
            data=self.test_data,
            compute_loss=False,
            device=self.args.device,
            rank=0,
            n_gpu=self.args.n_gpu,
            normalization="none",
            num_workers=self.args.num_workers,
            args=self.args.test,
            autocast=self.args.autocast,
        )
        return hypotheses, trg_tokens, trg_scores

    def translate(self, sentence: str) -> list:
        """
        Translate the given sentence.

        The dataset cache is reset on every exit path, so a failed call does
        not leave its sentence behind for the next call.

        :param sentence: Sentence to be translated (surrounding whitespace is
            stripped before it is handed to the dataset)
        :return:
        - translations: (list of str) possible translations of the sentence.
        :raises AssertionError: if the number of hypotheses does not equal
            ``len(self.test_data) * n_best``.
        """
        try:
            self.test_data.set_item(sentence.strip())
            translations, _, _ = self._translate_data()
            # Must be checked before the cache is cleared, while
            # len(self.test_data) still reflects this request.
            expected = len(self.test_data) * self.args.test.n_best
            assert len(translations) == expected, (
                f"expected {expected} hypotheses, got {len(translations)}"
            )
            return translations
        finally:
            # Always clear the cached sentence, even when prediction or the
            # sanity check above fails.
            self.test_data.reset_cache()


In [6]:
# load the model
# Change this to the path to your model config file
config_path = "../saved_model/lean_model/config_local.yaml"
model = JoeyNMTModel(config_path=config_path, n_best=1)

2024-05-07 15:14:14,250 - INFO - joeynmt.data - Building tokenizer...
2024-05-07 15:14:14,262 - INFO - joeynmt.tokenizers - dyu tokenizer: SentencePieceTokenizer(level=bpe, lowercase=False, normalize=False, filter_by_length=(-1, 100), pretokenizer=none, tokenizer=SentencePieceProcessor, nbest_size=5, alpha=0.0)
2024-05-07 15:14:14,263 - INFO - joeynmt.tokenizers - fr tokenizer: SentencePieceTokenizer(level=bpe, lowercase=False, normalize=False, filter_by_length=(-1, 100), pretokenizer=none, tokenizer=SentencePieceProcessor, nbest_size=5, alpha=0.0)
2024-05-07 15:14:14,263 - INFO - joeynmt.data - Building vocabulary...
2024-05-07 15:14:14,398 - INFO - joeynmt.data - Data loaded.
2024-05-07 15:14:14,398 - INFO - joeynmt.data - Train dataset: None
2024-05-07 15:14:14,399 - INFO - joeynmt.data - Valid dataset: None
2024-05-07 15:14:14,399 - INFO - joeynmt.data -  Test dataset: StreamDataset(split=test, len=0, src_lang="dyu", trg_lang="fr", has_trg=False, random_subset=-1, has_src_prompt=Fa

In [7]:
# Show the object's repr to confirm that the wrapper was created.
model

<__main__.JoeyNMTModel at 0x173d12650>

In [8]:
# Translate a single sentence; the call returns a list of hypothesis strings.
model.translate(sentence="i tɔgɔ bi cogodɔ")

2024-05-07 15:14:21,128 - INFO - joeynmt.prediction - Predicting 1 example(s)... (Beam search with beam_size=5, beam_alpha=1.0, n_best=1, min_output_length=1, max_output_length=100, return_prob='none', generate_unk=True, repetition_penalty=-1, no_repeat_ngram_size=-1)
2024-05-07 15:14:21,262 - INFO - joeynmt.prediction - Generation took 0.1326[sec].


['c’est ce qu’est pas']