Fix #75, #102, #131. Add Back Translation Aug
makcedward committed Aug 8, 2020
1 parent 32beaa6 commit 9826d89
Showing 7 changed files with 182 additions and 0 deletions.
7 changes: 7 additions & 0 deletions docs/augmenter/word/back_translation.rst
@@ -0,0 +1,7 @@
nlpaug.augmenter.word\.back_translation
========================================

.. automodule:: nlpaug.augmenter.word.back_translation
   :members:
   :inherited-members:
   :show-inheritance:
1 change: 1 addition & 0 deletions docs/augmenter/word/word.rst
@@ -5,6 +5,7 @@ Word Augmenter
   :maxdepth: 6

   ./antonym
   ./back_translation
   ./context_word_embs
   ./random
   ./spelling
1 change: 1 addition & 0 deletions nlpaug/augmenter/word/__init__.py
@@ -8,3 +8,4 @@
from nlpaug.augmenter.word.synonym import *
from nlpaug.augmenter.word.antonym import *
from nlpaug.augmenter.word.split import *
from nlpaug.augmenter.word.back_translation import *
82 changes: 82 additions & 0 deletions nlpaug/augmenter/word/back_translation.py
@@ -0,0 +1,82 @@
"""
Augmenter that apply operation (word level) to textual input based on back translation.
"""

import string
import os

from nlpaug.augmenter.word import WordAugmenter
import nlpaug.model.lang_models as nml

BACK_TRANSLATION_MODELS = {}


def init_back_translation_model(from_model_name, from_model_checkpt, to_model_name, to_model_checkpt,
                                tokenizer_name, bpe_name, device, force_reload=False):
    global BACK_TRANSLATION_MODELS

    # Cache models by translation pair so repeated augmenter instances share weights.
    model_name = '_'.join([from_model_name, to_model_name])
    if model_name in BACK_TRANSLATION_MODELS and not force_reload:
        return BACK_TRANSLATION_MODELS[model_name]
    model = nml.Fairseq(from_model_name=from_model_name, from_model_checkpt=from_model_checkpt,
                        to_model_name=to_model_name, to_model_checkpt=to_model_checkpt,
                        tokenizer_name=tokenizer_name, bpe_name=bpe_name, device=device)

    BACK_TRANSLATION_MODELS[model_name] = model
    return model


class BackTranslationAug(WordAugmenter):
    # https://arxiv.org/pdf/1511.06709.pdf
    """
    Augmenter that leverages two translation models for augmentation. For example, if the source is English, this
    augmenter translates it to German and then back to English. For details, you may visit
    https://towardsdatascience.com/data-augmentation-in-nlp-2801a34dfc28

    :param str from_model_name: Model that translates the source text. Verified values are 'transformer.wmt19.en-de',
        'transformer.wmt19.de-en', 'transformer.wmt19.en-ru' and 'transformer.wmt19.ru-en'.
    :param str to_model_name: Model that translates back to the source language. Verified values are
        'transformer.wmt19.en-de', 'transformer.wmt19.de-en', 'transformer.wmt19.en-ru' and 'transformer.wmt19.ru-en'.
    :param str from_model_checkpt: Checkpoint file of from_model_name. Default value is 'model1.pt'.
    :param str to_model_checkpt: Checkpoint file of to_model_name. Default value is 'model1.pt'.
    :param str tokenizer: Default value is 'moses'.
    :param str bpe: Default value is 'fastbpe'.
    :param str device: Use either 'cpu' or 'cuda'. Default value is None, in which case GPU is used if available.
    :param bool force_reload: Force reload of the translation models into memory when initializing the class.
        Default value is False; keep it as False if performance is a consideration.
    :param str name: Name of this augmenter.

    >>> import nlpaug.augmenter.word as naw
    >>> aug = naw.BackTranslationAug(from_model_name='transformer.wmt19.en-de',
    ...                              to_model_name='transformer.wmt19.de-en')
    """

    def __init__(self, from_model_name, to_model_name, from_model_checkpt='model1.pt', to_model_checkpt='model1.pt',
                 tokenizer='moses', bpe='fastbpe', name='BackTranslationAug', device=None, force_reload=False,
                 verbose=0):
        super().__init__(
            # TODO: does not support include_detail
            action='substitute', name=name, aug_p=None, aug_min=None, aug_max=None, tokenizer=None,
            device=device, verbose=verbose, include_detail=False)

        self.model = self.get_model(
            from_model_name=from_model_name, from_model_checkpt=from_model_checkpt,
            to_model_name=to_model_name, to_model_checkpt=to_model_checkpt,
            tokenizer_name=tokenizer, bpe_name=bpe, device=device, force_reload=force_reload
        )
        self.device = self.model.device

    def substitute(self, data):
        augmented_text = self.model.predict(data)
        return augmented_text

    @classmethod
    def get_model(cls, from_model_name, from_model_checkpt, to_model_name, to_model_checkpt,
                  tokenizer_name, bpe_name, device='cuda', force_reload=False):
        return init_back_translation_model(
            from_model_name, from_model_checkpt, to_model_name, to_model_checkpt,
            tokenizer_name, bpe_name, device, force_reload
        )

    @classmethod
    def clear_cache(cls):
        global BACK_TRANSLATION_MODELS
        BACK_TRANSLATION_MODELS = {}
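
For reference, a minimal usage sketch of the new augmenter (assuming the WMT19 checkpoints download successfully via torch.hub; the output shown is illustrative, since the paraphrase depends on the checkpoints):

import nlpaug.augmenter.word as naw

# English -> German -> English round trip; the first call downloads both checkpoints.
aug = naw.BackTranslationAug(
    from_model_name='transformer.wmt19.en-de',
    to_model_name='transformer.wmt19.de-en')
print(aug.augment('The quick brown fox jumps over the lazy dog'))
# e.g. 'The fast brown fox jumps over the lazy dog'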
1 change: 1 addition & 0 deletions nlpaug/model/lang_models/__init__.py
@@ -5,3 +5,4 @@
from nlpaug.model.lang_models.gpt2 import *
from nlpaug.model.lang_models.distilbert import *
from nlpaug.model.lang_models.roberta import *
from nlpaug.model.lang_models.fairseq import *
50 changes: 50 additions & 0 deletions nlpaug/model/lang_models/fairseq.py
@@ -0,0 +1,50 @@
try:
    import torch
except ImportError:
    # No installation required if not using this function
    pass

from nlpaug.model.lang_models import LanguageModels
from nlpaug.util.selection.filtering import *


class Fairseq(LanguageModels):
    def __init__(self, from_model_name, from_model_checkpt, to_model_name, to_model_checkpt,
                 tokenizer_name='moses', bpe_name='fastbpe', device='cuda'):
        super().__init__(device, temperature=None, top_k=None, top_p=None)

        try:
            import torch
            import fairseq
            if device is None:
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.device = device
        except ImportError:
            raise ImportError(
                'Missing torch or fairseq library. Install torch by following '
                'https://pytorch.org/get-started/locally/ and fairseq from https://github.com/pytorch/fairseq')

        self.from_model_name = from_model_name
        self.from_model_checkpt = from_model_checkpt
        self.to_model_name = to_model_name
        self.to_model_checkpt = to_model_checkpt
        self.tokenizer_name = tokenizer_name
        self.bpe_name = bpe_name

        # TODO: enhance to support custom models. https://github.com/pytorch/fairseq/tree/master/examples/translation
        self.from_model = torch.hub.load(
            github='pytorch/fairseq', model=from_model_name,
            checkpoint_file=from_model_checkpt,
            tokenizer=tokenizer_name, bpe=bpe_name)
        self.to_model = torch.hub.load(
            github='pytorch/fairseq', model=to_model_name,
            checkpoint_file=to_model_checkpt,
            tokenizer=tokenizer_name, bpe=bpe_name)

        self.from_model.eval()
        self.to_model.eval()
        if self.device == 'cuda':
            self.from_model.cuda()
            self.to_model.cuda()

    def predict(self, text):
        # Source -> pivot language -> source round trip.
        translated_text = self.from_model.translate(text)
        back_translated_text = self.to_model.translate(translated_text)

        return back_translated_text
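
The round trip performed by predict is equivalent to chaining fairseq's two torch.hub models directly, which may help when debugging checkpoint downloads (a sketch, assuming torch and fairseq are installed; the sample sentence is illustrative):

import torch

# Same loading path as Fairseq.__init__ above.
en_de = torch.hub.load(github='pytorch/fairseq', model='transformer.wmt19.en-de',
                       checkpoint_file='model1.pt', tokenizer='moses', bpe='fastbpe')
de_en = torch.hub.load(github='pytorch/fairseq', model='transformer.wmt19.de-en',
                       checkpoint_file='model1.pt', tokenizer='moses', bpe='fastbpe')
en_de.eval()
de_en.eval()

pivot = en_de.translate('Machine learning is fun.')   # English -> German
paraphrase = de_en.translate(pivot)                   # German -> English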
40 changes: 40 additions & 0 deletions test/augmenter/word/test_back_translation.py
@@ -0,0 +1,40 @@
import unittest
import os
from dotenv import load_dotenv

import nlpaug.augmenter.word as naw


class TestBackTranslationAug(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        env_config_path = os.path.abspath(os.path.join(
            os.path.dirname(__file__), '..', '..', '..', '.env'))
        load_dotenv(env_config_path)

        cls.text = 'The quick brown fox jumps over the lazy dog'

        cls.model_names = [{
            'from_model_name': 'transformer.wmt19.en-ru',
            'from_model_checkpt': 'model1.pt',
            'to_model_name': 'transformer.wmt19.ru-en',
            'to_model_checkpt': 'model1.pt'
        }, {
            'from_model_name': 'transformer.wmt19.en-de',
            'from_model_checkpt': 'model1.pt',
            'to_model_name': 'transformer.wmt19.de-en',
            'to_model_checkpt': 'model1.pt'
        }]

    def test_back_translation(self):
        for model_name in self.model_names:
            aug = naw.BackTranslationAug(
                from_model_name=model_name['from_model_name'], from_model_checkpt=model_name['from_model_checkpt'],
                to_model_name=model_name['to_model_name'], to_model_checkpt=model_name['to_model_checkpt'])
            augmented_text = aug.augment(self.text)
            aug.clear_cache()
            self.assertNotEqual(self.text, augmented_text)

        self.assertTrue(len(self.model_names) > 1)
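
To exercise the new test in isolation (a sketch, assuming the dotted test path resolves from the repository root; the WMT19 checkpoints are large, so the first run spends most of its time downloading):

import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    'test.augmenter.word.test_back_translation.TestBackTranslationAug')
unittest.TextTestRunner(verbosity=2).run(suite)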
