In [1]:
from dataset import read_atis

# Load the ATIS 'test' split. Later cells read its 'utterance' and 'language'
# columns and call .iterrows() on it, so this is presumably a pandas
# DataFrame -- TODO confirm against dataset.read_atis.
test = read_atis('test')

In [2]:
# Vocabulary of the test split: every distinct whitespace-delimited token
# found across all utterances.
words = {
    token
    for utterance in test['utterance']
    for token in utterance.split()
}

In [3]:
# Number of distinct whitespace-separated tokens collected from the utterances.
len(words)

2781

In [4]:
from functools import cache
from easynmt import EasyNMT


# Load the 'm2m_100_418M' translation model through EasyNMT on the GPU.
# NOTE(review): max_length=1 looks suspicious -- if EasyNMT treats it as a
# per-sentence token limit, inputs could be truncated to a single token.
# Confirm this is intended for word-by-word translation.
model = EasyNMT('m2m_100_418M', device='cuda', max_length=1)


@cache
def translate(word: str, source_lang: str, target_lang: str) -> str:
    """Translate ``word`` from ``source_lang`` into ``target_lang``.

    The call is forwarded to the module-level ``model``. Results are memoised
    per (word, source_lang, target_lang) combination by ``functools.cache``,
    so a repeated request does not invoke the model again.
    """
    translated = model.translate(
        word,
        source_lang=source_lang,
        target_lang=target_lang,
    )
    return translated

In [5]:
from tqdm import tqdm

from utils import load_config

languages = load_config()['languages']

# translations[source][target] starts as an empty dict for every ordered pair
# of distinct languages; it is filled with {word: translated word} later.
translations = {
    source: {target: {} for target in languages if target != source}
    for source in languages
}

In [6]:
# Display the nested source -> target -> {} structure before it is filled.
translations

{'en': {'de': {}, 'es': {}, 'fr': {}, 'ja': {}, 'pt': {}, 'zh': {}},
 'de': {'en': {}, 'es': {}, 'fr': {}, 'ja': {}, 'pt': {}, 'zh': {}},
 'es': {'en': {}, 'de': {}, 'fr': {}, 'ja': {}, 'pt': {}, 'zh': {}},
 'fr': {'en': {}, 'de': {}, 'es': {}, 'ja': {}, 'pt': {}, 'zh': {}},
 'ja': {'en': {}, 'de': {}, 'es': {}, 'fr': {}, 'pt': {}, 'zh': {}},
 'pt': {'en': {}, 'de': {}, 'es': {}, 'fr': {}, 'ja': {}, 'zh': {}},
 'zh': {'en': {}, 'de': {}, 'es': {}, 'fr': {}, 'ja': {}, 'pt': {}}}

In [7]:
# For each utterance, translate every token from the utterance's own language
# into each of the other configured languages and record the result.
for _idx, row in tqdm(test.iterrows(), desc='translating', total=len(test)):
    source_lang = row['language']
    tokens = row['utterance'].split()

    for target_lang in languages:
        if target_lang == source_lang:
            # Nothing to translate into the utterance's own language.
            continue

        for token in tokens:
            # `translate` is memoised, so tokens seen before are not re-sent
            # to the model.
            translations[source_lang][target_lang][token] = translate(
                token,
                source_lang=source_lang,
                target_lang=target_lang,
            )

translating: 100%|██████████| 5285/5285 [31:33<00:00,  2.79it/s]  


In [9]:
import os

import torch

# Destination of the serialized {source_lang: {target_lang: {word: translation}}}
# mapping built above.
TRANSLATIONS_PATH = 'data/atis_test_translations/translations.pt'

# torch.save does not create missing parent directories. Create the directory
# first so the long translation run is not lost to a failed write on a fresh
# checkout.
os.makedirs(os.path.dirname(TRANSLATIONS_PATH), exist_ok=True)

torch.save(translations, TRANSLATIONS_PATH)