Preserve original, unnormalized words in the alignment labels
flyingleafe committed Oct 16, 2023
1 parent 448c081 commit ecc9cab
Showing 2 changed files with 43 additions and 24 deletions.
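In effect, normalize_text now returns (original, normalized) word pairs instead of bare normalized strings, and align feeds only the normalized halves to the tokenizer while labelling each AlignmentItem with the original word. A rough sketch of the new contract (the example words are made up, not taken from the diff):

    # normalize_text("Привет, мир!", language="ru") might now yield pairs such as
    pairs = [("Привет,", "privet"), ("мир!", "mir")]

    # align() tokenizes only the normalized halves ...
    tokenizer_input = [norm for _, norm in pairs]   # ["privet", "mir"]
    # ... but keeps the original words as the AlignmentItem symbols
    symbols = [orig for orig, _ in pairs]           # ["Привет,", "мир!"]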
4 changes: 2 additions & 2 deletions lhotse/workflows/forced_alignment/base.py
@@ -1,6 +1,6 @@
 import abc
 import logging
-from typing import Generator, List, Optional, Union
+from typing import Generator, List, Optional, Tuple, Union
 
 import torch
 from tqdm.auto import tqdm
@@ -34,7 +34,7 @@ def normalize_text(
 
     @abc.abstractmethod
     def align(
-        self, audio: torch.Tensor, transcript: Union[str, List[str]]
+        self, audio: torch.Tensor, transcript: Union[str, List[Tuple[str, str]]]
     ) -> List[AlignmentItem]:
         pass
 
63 changes: 41 additions & 22 deletions lhotse/workflows/forced_alignment/mms_aligner.py
@@ -1,6 +1,6 @@
 import logging
 import re
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import torch
 import torchaudio
@@ -40,35 +40,42 @@ def __init__(self, bundle_name: str, device: str = "cpu"):
     def sample_rate(self) -> int:
         return self.bundle.sample_rate
 
-    def normalize_text(self, text: str, language=None) -> List[str]:
-        # Add spaces between words for languages which do not have them
-        text = _spacify(text, language)
+    def normalize_text(self, text: str, language=None) -> List[Tuple[str, str]]:
+        # Split text into words (possibly with adjacent punctuation)
+        orig_words = _word_tokenize(text, language)
 
-        romanized = self._uroman(text, language=language)
-        romanized_l = romanized.lower().replace("’", "'")
-        romanized_no_punct = re.sub(self.discard_regex, "", romanized_l)
-        words = romanized_no_punct.strip().split()
+        sep = _safe_separator(text)
+        romanized_words = self._uroman(sep.join(orig_words), language=language).split(
+            sep
+        )
+        romanized_l = [w.lower().replace("’", "'") for w in romanized_words]
+        norm_words = [re.sub(self.discard_regex, "", w).strip() for w in romanized_l]
+        word_pairs = list(zip(orig_words, norm_words))
 
-        # Remove standalone dashes - aligner doesn't like them
-        return [w for w in words if w != "-"]
+        # Remove empty words and standalone dashes (aligner doesn't like them)
+        return [(orig, norm) for orig, norm in word_pairs if norm != "" and norm != "-"]
 
-    def align(self, audio: torch.Tensor, transcript: List[str]) -> List[AlignmentItem]:
+    def align(
+        self, audio: torch.Tensor, transcript: List[Tuple[str, str]]
+    ) -> List[AlignmentItem]:
         try:
             with torch.inference_mode():
                 emission, _ = self.model(audio)
-                token_spans = self.aligner(emission[0], self.tokenizer(transcript))
+                token_spans = self.aligner(
+                    emission[0], self.tokenizer([p[1] for p in transcript])
+                )
         except Exception as e:
             raise FailedToAlign from e
 
         ratio = audio.shape[1] / emission.shape[1] / self.sample_rate
         return [
             AlignmentItem(
-                symbol=word,
+                symbol=orig_word,
                 start=round(ratio * t_spans[0].start, ndigits=8),
                 duration=round(ratio * (t_spans[-1].end - t_spans[0].start), ndigits=8),
                 score=_merge_score(t_spans),
             )
-            for t_spans, word in zip(token_spans, transcript)
+            for t_spans, (orig_word, _) in zip(token_spans, transcript)
         ]


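The pairing relies on a separator round-trip: the original words are joined with a symbol that does not occur in the text, romanized in a single uroman pass, and split back so every original word lines up with its romanized counterpart. A minimal sketch of the idea, with a stand-in romanize callable in place of self._uroman:

    def pair_words(orig_words, romanize, sep="#"):
        # assumes `sep` occurs in none of the words (see _safe_separator below)
        romanized_words = romanize(sep.join(orig_words)).split(sep)
        return list(zip(orig_words, romanized_words))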
@@ -78,19 +85,19 @@ def _merge_score(tspans):
     )
 
 
-def _spacify(text: str, language: Optional[str] = None) -> str:
+def _word_tokenize(text: str, language: Optional[str] = None) -> List[str]:
     """
     Add spaces between words for languages which do not have them.
     """
 
     # TODO: maybe add some simplistic auto-language detection?
     # many dataset recipes might not provide proper language tags to supervisions
     if language is None:
-        return text
+        return text.split()
 
     language = _normalize_language(language)
     if language not in LANGUAGES_WITHOUT_SPACES:
-        return text
+        return text.split()
 
     if language == "zh":
         if not is_module_available("jieba"):
@@ -101,7 +108,7 @@ def _spacify(text: str, language: Optional[str] = None) -> str:
 
         import jieba
 
-        return " ".join(jieba.cut(text))
+        return jieba.lcut(text)
 
     elif language == "ja":
         if not is_module_available("nagisa"):
@@ -112,7 +119,7 @@ def _spacify(text: str, language: Optional[str] = None) -> str:
 
         import nagisa
 
-        return " ".join(nagisa.tagging(text).words)
+        return nagisa.tagging(text).words
 
     elif language == "ko":
         if not is_module_available("kss"):
@@ -123,7 +130,7 @@ def _spacify(text: str, language: Optional[str] = None) -> str:
 
         import kss
 
-        return " ".join(kss.split_morphemes(text, return_pos=False))
+        return kss.split_morphemes(text, return_pos=False)
 
     elif language == "th":
         # `pythainlp` is alive and much better, but it is a huge package bloated with dependencies
@@ -136,10 +143,9 @@ def _spacify(text: str, language: Optional[str] = None) -> str:
         from tltk import nlp
 
         pieces = nlp.pos_tag(text)
-        words = [
+        return [
             word if word != "<s/>" else " " for piece in pieces for word, _ in piece
         ]
-        return " ".join(words)
 
     else:
         logging.warning(
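Every branch of _word_tokenize now returns a list of words rather than a space-joined string. An illustrative call, where the Chinese segmentation is whatever jieba happens to produce and is shown only as a plausible example:

    _word_tokenize("hello world")              # -> ["hello", "world"]
    _word_tokenize("你好世界", language="zh")   # -> e.g. ["你好", "世界"] via jieba.lcut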
@@ -163,3 +169,16 @@ def _normalize_language(language: str) -> str:
     except tag_parser.LanguageTagError:
         # If it fails, try to parse the language name.
         return Language.find(language).language
+
+
+def _safe_separator(text):
+    """
+    Returns a separator that is not present in the text.
+    """
+    special_symbols = "#$%^&~_"
+    i = 0
+    while i < len(special_symbols) and special_symbols[i] in text:  # check the bound before indexing
+        i += 1
+
+    # better use space than just fail
+    return special_symbols[i] if i < len(special_symbols) else " "
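Hypothetical usage of the new helper with the candidate symbols "#$%^&~_":

    _safe_separator("hello world")     # -> "#"  (first candidate, absent from the text)
    _safe_separator("tag #1 is 50%")   # -> "$"  ("#" already occurs in the text)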
