Skip to content

Commit

Permalink
Support for aligning in parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
flyingleafe committed Oct 16, 2023
1 parent ecc9cab commit 39f89c6
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 7 deletions.
68 changes: 62 additions & 6 deletions lhotse/workflows/forced_alignment/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import abc
import logging
from typing import Generator, List, Optional, Tuple, Union
import multiprocessing
from concurrent.futures import ProcessPoolExecutor
from functools import partial, partialmethod
from itertools import chain
from typing import Dict, Generator, List, Optional, Tuple, Type, Union

import torch
from tqdm.auto import tqdm
Expand Down Expand Up @@ -74,15 +78,67 @@ def __call__(self, cut: Cut, normalize: bool = True) -> Cut:

class ForcedAlignmentProcessor:
"""
TODO: Make multiprocessing, like in VAD.
TODO: Too much copy-paste between the VAD and FA processors; the shared logic needs to be abstracted out.
"""

def __init__(self, aligner: ForcedAligner, verbose: bool = False):
self.aligner = aligner
_aligners: Dict[Optional[int], ForcedAligner] = {}

def __init__(
self,
aligner_kls: Type[ForcedAligner],
bundle_name: str,
num_jobs: int = 1,
device: torch.device = torch.device("cpu"),
verbose: bool = False,
):
self._make_aligner = partial(

Check warning on line 94 in lhotse/workflows/forced_alignment/base.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/base.py#L94

Added line #L94 was not covered by tests
aligner_kls, bundle_name=bundle_name, device=torch.device(device)
)
self.num_jobs = num_jobs
self.verbose = verbose

Check warning on line 98 in lhotse/workflows/forced_alignment/base.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/base.py#L97-L98

Added lines #L97 - L98 were not covered by tests

def _init_aligner(self):
# Build an aligner for the current process and register it in `_aligners`.
# This method is handed to ProcessPoolExecutor as its `initializer` in
# `__call__`, so it is meant to run once in each worker process.
# The registry is keyed by PID so that `_process_cut` can later look up
# the aligner that belongs to the process it is running in.
pid = multiprocessing.current_process().pid
# `_make_aligner` is the partial created in `__init__`
# (aligner class + bundle_name + device); calling it constructs the aligner.
self._aligners[pid] = self._make_aligner()

Check warning on line 102 in lhotse/workflows/forced_alignment/base.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/base.py#L101-L102

Added lines #L101 - L102 were not covered by tests

def _process_cut(self, cut: Cut, normalize: bool = True) -> Cut:
# Align a single cut using the aligner registered for the current process.
# `__call__` maps this method over the cuts via `executor.map`.
#
# :param cut: the cut to align.
# :param normalize: forwarded unchanged to the aligner's `normalize` argument.
# :return: whatever the aligner call returns for this cut.
#
# The lookup below raises KeyError if `_init_aligner` has not stored an
# aligner under this process's PID.
pid = multiprocessing.current_process().pid
aligner = self._aligners[pid]
return aligner(cut, normalize=normalize)

Check warning on line 107 in lhotse/workflows/forced_alignment/base.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/base.py#L105-L107

Added lines #L105 - L107 were not covered by tests

def __call__(
self, cuts: CutSet, normalize: bool = True
) -> Generator[Cut, None, None]:
for cut in tqdm(cuts, desc="Aligning", disable=not self.verbose):
yield self.aligner(cut, normalize=normalize)
if self.num_jobs == 1:
aligner = self._make_aligner()
for cut in tqdm(cuts, desc="Aligning", disable=not self.verbose):
yield aligner(cut, normalize=normalize)

Check warning on line 115 in lhotse/workflows/forced_alignment/base.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/base.py#L112-L115

Added lines #L112 - L115 were not covered by tests

else:
pool = ProcessPoolExecutor(

Check warning on line 118 in lhotse/workflows/forced_alignment/base.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/base.py#L118

Added line #L118 was not covered by tests
max_workers=self.num_jobs,
initializer=self._init_aligner,
mp_context=multiprocessing.get_context("spawn"),
)

with pool as executor:
try:
res = executor.map(

Check warning on line 126 in lhotse/workflows/forced_alignment/base.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/base.py#L124-L126

Added lines #L124 - L126 were not covered by tests
partial(self._process_cut, normalize=normalize), cuts
)
for cut in tqdm(

Check warning on line 129 in lhotse/workflows/forced_alignment/base.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/base.py#L129

Added line #L129 was not covered by tests
res, desc="Aligning", total=len(cuts), disable=not self.verbose
):
yield cut

Check warning on line 132 in lhotse/workflows/forced_alignment/base.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/base.py#L132

Added line #L132 was not covered by tests
except KeyboardInterrupt as exc: # pragma: no cover
pool.shutdown(wait=False)
if self.verbose:
print("Forced alignment interrupted by the user.")
raise exc
except Exception as exc: # pragma: no cover
pool.shutdown(wait=False)
raise RuntimeError(
"Forced alignment failed. Please report this issue."
) from exc
finally:
self._aligners.clear()

Check warning on line 144 in lhotse/workflows/forced_alignment/base.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/base.py#L144

Added line #L144 was not covered by tests
8 changes: 7 additions & 1 deletion lhotse/workflows/forced_alignment/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def align_with_torchaudio(
bundle_name: str = "WAV2VEC2_ASR_BASE_960H",
device: str = "cpu",
normalize_text: bool = True,
num_jobs: int = 1,
verbose: bool = False,
) -> Generator[MonoCut, None, None]:
"""
Use a pretrained model from torchaudio (such as Wav2Vec2) to perform forced
Expand Down Expand Up @@ -57,6 +59,10 @@ def align_with_torchaudio(
"""
AlignerClass = __get_aligner_class(bundle_name)
processor = ForcedAlignmentProcessor(

Check warning on line 61 in lhotse/workflows/forced_alignment/workflow.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/workflow.py#L60-L61

Added lines #L60 - L61 were not covered by tests
AlignerClass(bundle_name, device=device), verbose=True
AlignerClass,
bundle_name,
device=device,
num_jobs=num_jobs,
verbose=verbose,
)
return processor(cuts, normalize=normalize_text)

Check warning on line 68 in lhotse/workflows/forced_alignment/workflow.py

View check run for this annotation

Codecov / codecov/patch

lhotse/workflows/forced_alignment/workflow.py#L68

Added line #L68 was not covered by tests

0 comments on commit 39f89c6

Please sign in to comment.