Skip to content

Commit

Permalink
Preserve original, unnormalized words in the alignment labels
Browse files Browse the repository at this point in the history
  • Loading branch information
flyingleafe committed Oct 16, 2023
1 parent 769833d commit 994d480
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 24 deletions.
4 changes: 2 additions & 2 deletions lhotse/workflows/forced_alignment/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
import logging
from typing import Generator, List, Optional, Union
from typing import Generator, List, Optional, Tuple, Union

import torch
from tqdm.auto import tqdm
Expand Down Expand Up @@ -34,7 +34,7 @@ def normalize_text(

@abc.abstractmethod
def align(
self, audio: torch.Tensor, transcript: Union[str, List[str]]
self, audio: torch.Tensor, transcript: Union[str, List[Tuple[str, str]]]
) -> List[AlignmentItem]:
pass

Check warning on line 39 in lhotse/workflows/forced_alignment/base.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/base.py#L39

Added line #L39 was not covered by tests

Expand Down
63 changes: 41 additions & 22 deletions lhotse/workflows/forced_alignment/mms_aligner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import re
from typing import List, Optional
from typing import List, Optional, Tuple

import torch
import torchaudio
Expand Down Expand Up @@ -40,35 +40,42 @@ def __init__(self, bundle_name: str, device: str = "cpu"):
def sample_rate(self) -> int:
return self.bundle.sample_rate

Check warning on line 41 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L41

Added line #L41 was not covered by tests

def normalize_text(self, text: str, language=None) -> List[str]:
# Add spaces between words for languages which do not have them
text = _spacify(text, language)
def normalize_text(self, text: str, language=None) -> List[Tuple[str, str]]:
# Split text into words (possibly with adjacent punctuation)
orig_words = _word_tokenize(text, language)

Check warning on line 45 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L45

Added line #L45 was not covered by tests

romanized = self._uroman(text, language=language)
romanized_l = romanized.lower().replace("’", "'")
romanized_no_punct = re.sub(self.discard_regex, "", romanized_l)
words = romanized_no_punct.strip().split()
sep = _safe_separator(text)
romanized_words = self._uroman(sep.join(orig_words), language=language).split(

Check warning on line 48 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L47-L48

Added lines #L47 - L48 were not covered by tests
sep
)
romanized_l = [w.lower().replace("’", "'") for w in romanized_words]
norm_words = [re.sub(self.discard_regex, "", w).strip() for w in romanized_l]
word_pairs = list(zip(orig_words, norm_words))

Check warning on line 53 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L51-L53

Added lines #L51 - L53 were not covered by tests

# Remove standalone dashes - aligner doesn't like them
return [w for w in words if w != "-"]
# Remove empty words and standalone dashes (aligner doesn't like them)
return [(orig, norm) for orig, norm in word_pairs if norm != "" and norm != "-"]

Check warning on line 56 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L56

Added line #L56 was not covered by tests

def align(self, audio: torch.Tensor, transcript: List[str]) -> List[AlignmentItem]:
def align(
self, audio: torch.Tensor, transcript: List[Tuple[str, str]]
) -> List[AlignmentItem]:
try:
with torch.inference_mode():
emission, _ = self.model(audio)
token_spans = self.aligner(emission[0], self.tokenizer(transcript))
token_spans = self.aligner(

Check warning on line 64 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L61-L64

Added lines #L61 - L64 were not covered by tests
emission[0], self.tokenizer([p[1] for p in transcript])
)
except Exception as e:
raise FailedToAlign from e

Check warning on line 68 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L67-L68

Added lines #L67 - L68 were not covered by tests

ratio = audio.shape[1] / emission.shape[1] / self.sample_rate
return [

Check warning on line 71 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L70-L71

Added lines #L70 - L71 were not covered by tests
AlignmentItem(
symbol=word,
symbol=orig_word,
start=round(ratio * t_spans[0].start, ndigits=8),
duration=round(ratio * (t_spans[-1].end - t_spans[0].start), ndigits=8),
score=_merge_score(t_spans),
)
for t_spans, word in zip(token_spans, transcript)
for t_spans, (orig_word, _) in zip(token_spans, transcript)
]


Expand All @@ -78,19 +85,19 @@ def _merge_score(tspans):
)


def _spacify(text: str, language: Optional[str] = None) -> str:
def _word_tokenize(text: str, language: Optional[str] = None) -> List[str]:
"""
Add spaces between words for languages which do not have them.
"""

# TODO: maybe add some simplistic auto-language detection?
# many dataset recipes might not provide proper language tags to supervisions
if language is None:
return text
return text.split()

Check warning on line 96 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L95-L96

Added lines #L95 - L96 were not covered by tests

language = _normalize_language(language)
if language not in LANGUAGES_WITHOUT_SPACES:
return text
return text.split()

Check warning on line 100 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L98-L100

Added lines #L98 - L100 were not covered by tests

if language == "zh":
if not is_module_available("jieba"):
Expand All @@ -101,7 +108,7 @@ def _spacify(text: str, language: Optional[str] = None) -> str:

import jieba

Check warning on line 109 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L109

Added line #L109 was not covered by tests

return " ".join(jieba.cut(text))
return jieba.lcut(text)

Check warning on line 111 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L111

Added line #L111 was not covered by tests

elif language == "ja":
if not is_module_available("nagisa"):
Expand All @@ -112,7 +119,7 @@ def _spacify(text: str, language: Optional[str] = None) -> str:

import nagisa

Check warning on line 120 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L120

Added line #L120 was not covered by tests

return " ".join(nagisa.tagging(text).words)
return nagisa.tagging(text).words

Check warning on line 122 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L122

Added line #L122 was not covered by tests

elif language == "ko":
if not is_module_available("kss"):
Expand All @@ -123,7 +130,7 @@ def _spacify(text: str, language: Optional[str] = None) -> str:

import kss

Check warning on line 131 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L131

Added line #L131 was not covered by tests

return " ".join(kss.split_morphemes(text, return_pos=False))
return kss.split_morphemes(text, return_pos=False)

Check warning on line 133 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L133

Added line #L133 was not covered by tests

elif language == "th":

Check warning on line 135 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L135

Added line #L135 was not covered by tests
# `pythainlp` is alive and much better, but it is a huge package bloated with dependencies
Expand All @@ -136,10 +143,9 @@ def _spacify(text: str, language: Optional[str] = None) -> str:
from tltk import nlp

Check warning on line 143 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L143

Added line #L143 was not covered by tests

pieces = nlp.pos_tag(text)
words = [
return [

Check warning on line 146 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L145-L146

Added lines #L145 - L146 were not covered by tests
word if word != "<s/>" else " " for piece in pieces for word, _ in piece
]
return " ".join(words)

else:
logging.warning(

Check warning on line 151 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L151

Added line #L151 was not covered by tests
Expand All @@ -163,3 +169,16 @@ def _normalize_language(language: str) -> str:
except tag_parser.LanguageTagError:

Check warning on line 169 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L168-L169

Added lines #L168 - L169 were not covered by tests
# If it fails, try to parse the language name.
return Language.find(language).language

Check warning on line 171 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L171

Added line #L171 was not covered by tests


def _safe_separator(text):
"""
Returns a separator that is not present in the text.
"""
special_symbols = "#$%^&~_"
i = 0
while special_symbols[i] in text and i < len(special_symbols):
i += 1

Check warning on line 181 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L178-L181

Added lines #L178 - L181 were not covered by tests

# better use space than just fail
return special_symbols[i] if i < len(special_symbols) else " "

Check warning on line 184 in lhotse/workflows/forced_alignment/mms_aligner.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/mms_aligner.py#L184

Added line #L184 was not covered by tests

0 comments on commit 994d480

Please sign in to comment.