From 39f89c64143dfff3c565160600cbef704c70f87c Mon Sep 17 00:00:00 2001
From: Dmitrii Mukhutdinov
Date: Mon, 16 Oct 2023 07:35:03 +0000
Subject: [PATCH] Support for aligning in parallel

---
 lhotse/workflows/forced_alignment/base.py     | 68 +++++++++++++++++--
 lhotse/workflows/forced_alignment/workflow.py |  8 ++-
 2 files changed, 69 insertions(+), 7 deletions(-)

diff --git a/lhotse/workflows/forced_alignment/base.py b/lhotse/workflows/forced_alignment/base.py
index 3b92c7f2c..e4ba66840 100644
--- a/lhotse/workflows/forced_alignment/base.py
+++ b/lhotse/workflows/forced_alignment/base.py
@@ -1,6 +1,10 @@
 import abc
 import logging
-from typing import Generator, List, Optional, Tuple, Union
+import multiprocessing
+from concurrent.futures import ProcessPoolExecutor
+from functools import partial, partialmethod
+from itertools import chain
+from typing import Dict, Generator, List, Optional, Tuple, Type, Union
 
 import torch
 from tqdm.auto import tqdm
@@ -74,15 +78,67 @@ def __call__(self, cut: Cut, normalize: bool = True) -> Cut:
 
 class ForcedAlignmentProcessor:
     """
-    TODO: Make multiprocessing, like in VAD.
+    TODO: Too much copypaste between VAD and FA processors, need to abstract it out.
     """
 
-    def __init__(self, aligner: ForcedAligner, verbose: bool = False):
-        self.aligner = aligner
+    _aligners: Dict[Optional[int], ForcedAligner] = {}
+
+    def __init__(
+        self,
+        aligner_kls: Type[ForcedAligner],
+        bundle_name: str,
+        num_jobs: int = 1,
+        device: torch.device = torch.device("cpu"),
+        verbose: bool = False,
+    ):
+        self._make_aligner = partial(
+            aligner_kls, bundle_name=bundle_name, device=torch.device(device)
+        )
+        self.num_jobs = num_jobs
         self.verbose = verbose
 
+    def _init_aligner(self):
+        pid = multiprocessing.current_process().pid
+        self._aligners[pid] = self._make_aligner()
+
+    def _process_cut(self, cut: Cut, normalize: bool = True) -> Cut:
+        pid = multiprocessing.current_process().pid
+        aligner = self._aligners[pid]
+        return aligner(cut, normalize=normalize)
+
     def __call__(
         self, cuts: CutSet, normalize: bool = True
     ) -> Generator[Cut, None, None]:
-        for cut in tqdm(cuts, desc="Aligning", disable=not self.verbose):
-            yield self.aligner(cut, normalize=normalize)
+        if self.num_jobs == 1:
+            aligner = self._make_aligner()
+            for cut in tqdm(cuts, desc="Aligning", disable=not self.verbose):
+                yield aligner(cut, normalize=normalize)
+
+        else:
+            pool = ProcessPoolExecutor(
+                max_workers=self.num_jobs,
+                initializer=self._init_aligner,
+                mp_context=multiprocessing.get_context("spawn"),
+            )
+
+            with pool as executor:
+                try:
+                    res = executor.map(
+                        partial(self._process_cut, normalize=normalize), cuts
+                    )
+                    for cut in tqdm(
+                        res, desc="Aligning", total=len(cuts), disable=not self.verbose
+                    ):
+                        yield cut
+                except KeyboardInterrupt as exc:  # pragma: no cover
+                    pool.shutdown(wait=False)
+                    if self.verbose:
+                        print("Forced alignment interrupted by the user.")
+                    raise exc
+                except Exception as exc:  # pragma: no cover
+                    pool.shutdown(wait=False)
+                    raise RuntimeError(
+                        "Forced alignment failed. Please report this issue."
+                    ) from exc
+                finally:
+                    self._aligners.clear()
diff --git a/lhotse/workflows/forced_alignment/workflow.py b/lhotse/workflows/forced_alignment/workflow.py
index 5ee6e71c8..e6d7e28d5 100644
--- a/lhotse/workflows/forced_alignment/workflow.py
+++ b/lhotse/workflows/forced_alignment/workflow.py
@@ -27,6 +27,8 @@ def align_with_torchaudio(
     bundle_name: str = "WAV2VEC2_ASR_BASE_960H",
     device: str = "cpu",
     normalize_text: bool = True,
+    num_jobs: int = 1,
+    verbose: bool = False,
 ) -> Generator[MonoCut, None, None]:
     """
     Use a pretrained model from torchaudio (such as Wav2Vec2) to perform forced
@@ -57,6 +59,10 @@ def align_with_torchaudio(
     """
     AlignerClass = __get_aligner_class(bundle_name)
     processor = ForcedAlignmentProcessor(
-        AlignerClass(bundle_name, device=device), verbose=True
+        AlignerClass,
+        bundle_name,
+        device=device,
+        num_jobs=num_jobs,
+        verbose=verbose,
     )
     return processor(cuts, normalize=normalize_text)
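
A minimal usage sketch for the new num_jobs argument, assuming align_with_torchaudio is importable from lhotse.workflows and using hypothetical manifest paths:

    from lhotse import CutSet
    from lhotse.workflows import align_with_torchaudio

    # Hypothetical input manifest; any CutSet with recordings and supervisions works.
    cuts = CutSet.from_file("cuts.jsonl.gz")

    # Align in 4 worker processes; each worker builds its own aligner
    # in the pool initializer before processing cuts.
    aligned = CutSet.from_cuts(
        align_with_torchaudio(cuts, num_jobs=4, verbose=True)
    )
    aligned.to_file("cuts_aligned.jsonl.gz")

With num_jobs=1 the processor keeps the previous single-process behaviour and no worker pool is created.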