Skip to content

Commit

Permalink
Merge pull request #444 from saucam/txt_423
Browse files Browse the repository at this point in the history
Add custom detection function for language detection
  • Loading branch information
davidmezzetti committed Mar 11, 2023
2 parents 71022c8 + e96bc93 commit 8ecfd01
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -7,3 +7,4 @@ __pycache__/
.coverage
.coverage.*
*.pyc
.vscode/
27 changes: 25 additions & 2 deletions src/python/txtai/pipeline/text/translation.py
Expand Up @@ -26,7 +26,7 @@ class Translation(HFModel):
# Default language detection model
DEFAULT_LANG_DETECT = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz"

def __init__(self, path="facebook/m2m100_418M", quantize=False, gpu=True, batch=64, langdetect=DEFAULT_LANG_DETECT, findmodels=True):
def __init__(self, path="facebook/m2m100_418M", quantize=False, gpu=True, batch=64, langdetect=None, findmodels=True):
"""
Constructs a new language translation pipeline.
Expand Down Expand Up @@ -125,6 +125,29 @@ def detect(self, texts):
list of languages
"""

# Backwards compatible to load fasttext model
if not self.langdetect or isinstance(self.langdetect, str):
return self.defaultdetect(texts)

# Call external language detector
return self.langdetect(texts)

def defaultdetect(self, texts):
"""
Detects the language for each element in texts.
Default path that uses fasttext.
Args:
texts: list of text
Returns:
list of languages
"""

path = self.langdetect
if not path:
path = self.DEFAULT_LANG_DETECT

if not FASTTEXT:
raise ImportError('Language detection is not available - install "pipeline" extra to enable')

Expand All @@ -133,7 +156,7 @@ def detect(self, texts):
fasttext.FastText.eprint = lambda x: None

# Load language detection model
path = cached_download(self.langdetect, legacy_cache_layout=True)
path = cached_download(path, legacy_cache_layout=True)
self.detector = fasttext.load_model(path)

# Transform texts to format expected by language detection model
Expand Down
26 changes: 26 additions & 0 deletions test/python/testpipeline/testtranslation.py
Expand Up @@ -13,6 +13,32 @@ class TestTranslation(unittest.TestCase):
Translation tests.
"""

def testDetect(self):
    """
    Checks that language detection with a default-constructed pipeline
    reports "en" for an English sentence.
    """

    # No langdetect argument is supplied here
    pipeline = Translation()

    detected = pipeline.detect(["This is a test language detection."])

    self.assertListEqual(detected, ["en"])

def testDetectWithCustomFunc(self):
    """
    Checks that language detection works when a custom callable is passed
    as the langdetect argument.
    """

    def always_english(texts):
        # Stand-in detector: one "en" label per input element
        return ["en" for _ in texts]

    pipeline = Translation(langdetect=always_english)

    sample = ["This is a test language detection."]
    detected = pipeline.detect(sample)

    self.assertListEqual(detected, ["en"])

def testLongTranslation(self):
"""
Test a translation longer than max tokenization length
Expand Down

0 comments on commit 8ecfd01

Please sign in to comment.