Skip to content

Commit

Permalink
Merge pull request #424 from saucam/txt_383
Browse files Browse the repository at this point in the history
Add translation pipeline parameter to return selected models and detected language
  • Loading branch information
davidmezzetti committed Feb 11, 2023
2 parents 679ba5f + 363ef72 commit 0ce31ee
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 12 deletions.
33 changes: 21 additions & 12 deletions src/python/txtai/pipeline/text/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, path="facebook/m2m100_418M", quantize=False, gpu=True, batch=
self.models = {}
self.ids = self.modelids()

def __call__(self, texts, target="en", source=None):
def __call__(self, texts, target="en", source=None, showmodels=False):
"""
Translates text from source language into target language.
Expand All @@ -74,22 +74,30 @@ def __call__(self, texts, target="en", source=None):
languages = self.detect(values) if not source else [source] * len(values)
unique = set(languages)

# Build list of (index, language, text)
values = [(x, lang, values[x]) for x, lang in enumerate(languages)]
# Build a dict from language to list of (index, text)
langdict = {}
for x, lang in enumerate(languages):
if lang not in langdict:
langdict[lang] = []
langdict[lang].append((x, values[x]))

results = {}
for language in unique:
# Get all text values for language
inputs = [(x, text) for x, lang, text in values if lang == language]
# Get all indices and text values for a language
inputs = langdict[language]

# Translate text in batches
outputs = []
for chunk in self.batch([text for _, text in inputs], self.batchsize):
outputs.extend(self.translate(chunk, language, target))
outputs.extend(self.translate(chunk, language, target, showmodels))

# Store output value
for y, (x, _) in enumerate(inputs):
results[x] = outputs[y].strip()
if showmodels:
model, op = outputs[y]
results[x] = (op.strip(), language, model)
else:
results[x] = outputs[y].strip()

# Return results in same order as input
results = [results[x] for x in sorted(results)]
Expand Down Expand Up @@ -133,7 +141,7 @@ def detect(self, texts):

return [x[0].split("__")[-1] for x in self.detector.predict(texts)[0]]

def translate(self, texts, source, target):
def translate(self, texts, source, target, showmodels=False):
"""
Translates text from source to target language.
Expand All @@ -151,7 +159,7 @@ def translate(self, texts, source, target):
return texts

# Load model and tokenizer
model, tokenizer = self.lookup(source, target)
path, model, tokenizer = self.lookup(source, target)

model.to(self.device)
indices = None
Expand All @@ -176,10 +184,11 @@ def translate(self, texts, source, target):
# Combine translations - handle splits on large text from tokenizer
results, last = [], -1
for x, i in enumerate(indices):
v = (path, translated[x]) if showmodels else translated[x]
if i == last:
results[-1] += translated[x]
results[-1] += v
else:
results.append(translated[x])
results.append(v)

last = i

Expand All @@ -202,7 +211,7 @@ def lookup(self, source, target):
if path not in self.models:
self.models[path] = self.load(path)

return self.models[path]
return (path,) + self.models[path]

def modelpath(self, source, target):
"""
Expand Down
29 changes: 29 additions & 0 deletions test/python/testpipeline/testtranslation.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,32 @@ def testNoTranslation(self):

# Validate no translation
self.assertEqual(text, translation)

def testTranslationWithShowmodels(self):
    """
    Round-trips a sentence through the translation pipeline with the
    showmodels flag set, checking the translated text, the reported source
    language and the reported model path for each direction.
    """

    pipeline = Translation()

    english = "This is a test translation into Spanish"

    # English -> Spanish: with showmodels=True the result unpacks into
    # (translated text, source language, model path)
    spanish, detected, model = pipeline(english, "es", showmodels=True)

    self.assertEqual(spanish, "Esta es una traducción de prueba al español")
    self.assertEqual(detected, "en")
    self.assertEqual(model, "Helsinki-NLP/opus-mt-en-es")

    # Spanish -> English: feed the previous output back through the pipeline
    roundtrip, detected, model = pipeline(spanish, "en", showmodels=True)

    self.assertEqual(roundtrip, english)
    self.assertEqual(detected, "es")
    self.assertEqual(model, "Helsinki-NLP/opus-mt-es-en")

0 comments on commit 0ce31ee

Please sign in to comment.