diff --git a/Jenkinsfile b/Jenkinsfile index 5782ad701d46..b59bc39edf03 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -293,55 +293,71 @@ pipeline { } } - stage('L2: ASR DALI dev run') { - when { - anyOf { - branch 'main' - changeRequest target: 'main' - } - } - failFast true - parallel { - stage('Speech to Text - DALI AudioToMelSpectrogramPreprocessor') { - steps { - sh 'python examples/asr/speech_to_text.py \ - model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \ - +model.train_ds.use_dali=True \ - model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \ - +model.validation_ds.use_dali=True \ - trainer.gpus=[0] \ - +trainer.fast_dev_run=True \ - exp_manager.exp_dir=examples/asr/speech_to_text_results' - sh 'rm -rf examples/asr/speech_to_text_results' - } - } - // TODO: This would fail due to an unnecessary torchaudio import. - // To be enabled once torchaudio is available in the container used for CI - // stage('Speech to Text - DALI AudioToMFCCPreprocessor') { - // steps { - // sh 'python examples/asr/speech_to_text.py \ - // model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \ - // +model.train_ds.use_dali=True \ - // model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \ - // +model.validation_ds.use_dali=True \ - // model.preprocessor._target_=nemo.collections.asr.modules.AudioToMFCCPreprocessor \ - // ~model.preprocessor.normalize \ - // ~model.preprocessor.features \ - // ~model.preprocessor.frame_splicing \ - // ~model.preprocessor.dither \ - // ~model.preprocessor.stft_conv \ - // +model.n_mels=64 \ - // +model.n_mfcc=64 \ - // trainer.gpus=[0] \ - // +trainer.fast_dev_run=True \ - // exp_manager.exp_dir=examples/asr/speech_to_text_results' - // sh 'rm -rf examples/asr/speech_to_text_results' - // } - // } - } - } + // TODO: Enable test after 21.08 container is used. + // stage('L2: ASR DALI dev run') { + // when { + // anyOf { + // branch 'main' + // changeRequest target: 'main' + // } + // } + // failFast true + // parallel { + // stage('Speech to Text - DALI AudioToMelSpectrogramPreprocessor') { + // steps { + // sh 'python examples/asr/speech_to_text.py \ + // model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \ + // +model.train_ds.use_dali=True \ + // model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \ + // +model.validation_ds.use_dali=True \ + // trainer.gpus=[0] \ + // +trainer.fast_dev_run=True \ + // exp_manager.exp_dir=examples/asr/speech_to_text_results' + // sh 'rm -rf examples/asr/speech_to_text_results' + // } + // } + // stage('Speech to Text BPE - DALI AudioToMelSpectrogramPreprocessor') { + // steps { + // sh 'python examples/asr/speech_to_text_bpe.py \ + // --config-path="conf/citrinet/" --config-name="config_bpe" \ + // model.tokenizer.dir="/home/TestData/asr_tokenizers/an4_wpe_128/" \ + // model.tokenizer.type="wpe" \ + // model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \ + // +model.train_ds.use_dali=True \ + // model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \ + // +model.validation_ds.use_dali=True \ + // trainer.gpus=[0] \ + // +trainer.fast_dev_run=True \ + // exp_manager.exp_dir=examples/asr/speech_to_text_wpe_results' + // sh 'rm -rf examples/asr/speech_to_text_wpe_results' + // } + // } + // // TODO: This would fail due to an unnecessary torchaudio import. 
+ // // To be enabled once torchaudio is available in the container used for CI + // // stage('Speech to Text - DALI AudioToMFCCPreprocessor') { + // // steps { + // // sh 'python examples/asr/speech_to_text.py \ + // // model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \ + // // +model.train_ds.use_dali=True \ + // // model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \ + // // +model.validation_ds.use_dali=True \ + // // model.preprocessor._target_=nemo.collections.asr.modules.AudioToMFCCPreprocessor \ + // // ~model.preprocessor.normalize \ + // // ~model.preprocessor.features \ + // // ~model.preprocessor.frame_splicing \ + // // ~model.preprocessor.dither \ + // // ~model.preprocessor.stft_conv \ + // // +model.n_mels=64 \ + // // +model.n_mfcc=64 \ + // // trainer.gpus=[0] \ + // // +trainer.fast_dev_run=True \ + // // exp_manager.exp_dir=examples/asr/speech_to_text_results' + // // sh 'rm -rf examples/asr/speech_to_text_results' + // // } + // // } + // } + // } -// TODO: UNCOMMENT TESTS AFTER 21.04 release (numba 0.53 min requirement) stage('L2: ASR RNNT dev run') { when { anyOf { diff --git a/nemo/collections/asr/data/audio_to_text.py b/nemo/collections/asr/data/audio_to_text.py index cdff89d76f92..308b7e13723d 100644 --- a/nemo/collections/asr/data/audio_to_text.py +++ b/nemo/collections/asr/data/audio_to_text.py @@ -83,6 +83,65 @@ def _speech_collate_fn(batch, pad_id): return audio_signal, audio_lengths, tokens, tokens_lengths +class ASRManifestProcessor: + """ + Class that processes a manifest json file containing paths to audio files, transcripts, and durations (in seconds). + Each new line is a different sample. Example below: + {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147} + ... + {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt": + "utterance_id", "ctm_utt": "en_4156", "side": "A"} + Args: + manifest_filepath: Path to manifest json as described above. Can be comma-separated paths. + parser: Str for a language specific preprocessor or a callable. + max_duration: If audio exceeds this length, do not include in dataset. + min_duration: If audio is less than this length, do not include in dataset. + max_utts: Limit number of utterances. + bos_id: Id of beginning of sequence symbol to append if not None. + eos_id: Id of end of sequence symbol to append if not None. + pad_id: Id of pad symbol. Defaults to 0. + """ + + def __init__( + self, + manifest_filepath: str, + parser: Union[str, Callable], + max_duration: Optional[float] = None, + min_duration: Optional[float] = None, + max_utts: int = 0, + bos_id: Optional[int] = None, + eos_id: Optional[int] = None, + pad_id: int = 0, + ): + self.parser = parser + + self.collection = collections.ASRAudioText( + manifests_files=manifest_filepath.split(','), + parser=parser, + min_duration=min_duration, + max_duration=max_duration, + max_number=max_utts, + ) + + self.eos_id = eos_id + self.bos_id = bos_id + self.pad_id = pad_id + + def process_text(self, index) -> (List[int], int): + sample = self.collection[index] + + t, tl = sample.text_tokens, len(sample.text_tokens) + + if self.bos_id is not None: + t = [self.bos_id] + t + tl += 1 + if self.eos_id is not None: + t = t + [self.eos_id] + tl += 1 + + return t, tl + + class _AudioTextDataset(Dataset): """ Dataset that loads tensors via a json file containing paths to audio files, transcripts, and durations (in seconds). 
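The ASRManifestProcessor added above factors manifest parsing and BOS/EOS handling out of the map-style dataset so the DALI pipelines further down in this diff can reuse it. A minimal sketch of exercising it on its own, assuming a hypothetical manifest path and a toy character vocabulary:

# Sketch only: the manifest path and character set below are placeholders.
from nemo.collections.asr.data.audio_to_text import ASRManifestProcessor
from nemo.collections.common.parts.preprocessing import parsers

labels = [" ", "a", "b", "c"]  # toy vocabulary for illustration
char_parser = parsers.make_parser(labels=labels, name='en')

processor = ASRManifestProcessor(
    manifest_filepath='/path/to/manifest.json',  # hypothetical manifest
    parser=char_parser,
    max_duration=16.7,
    min_duration=0.1,
    bos_id=None,
    eos_id=None,
    pad_id=0,
)

# process_text(index) returns (token_ids, length) for one manifest entry,
# prepending/appending bos_id/eos_id when they are not None.
tokens, token_length = processor.process_text(0)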
@@ -134,24 +193,21 @@ def __init__( eos_id: Optional[int] = None, pad_id: int = 0, ): - self.parser = parser - - self.collection = collections.ASRAudioText( - manifests_files=manifest_filepath.split(','), + self.manifest_processor = ASRManifestProcessor( + manifest_filepath=manifest_filepath, parser=parser, - min_duration=min_duration, max_duration=max_duration, - max_number=max_utts, + min_duration=min_duration, + max_utts=max_utts, + bos_id=bos_id, + eos_id=eos_id, + pad_id=pad_id, ) - self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor) self.trim = trim - self.eos_id = eos_id - self.bos_id = bos_id - self.pad_id = pad_id def __getitem__(self, index): - sample = self.collection[index] + sample = self.manifest_processor.collection[index] offset = sample.offset if offset is None: @@ -162,23 +218,17 @@ def __getitem__(self, index): ) f, fl = features, torch.tensor(features.shape[0]).long() - t, tl = sample.text_tokens, len(sample.text_tokens) - if self.bos_id is not None: - t = [self.bos_id] + t - tl += 1 - if self.eos_id is not None: - t = t + [self.eos_id] - tl += 1 + t, tl = self.manifest_processor.process_text(index) output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long() return output def __len__(self): - return len(self.collection) + return len(self.manifest_processor.collection) def _collate_fn(self, batch): - return _speech_collate_fn(batch, pad_id=self.pad_id) + return _speech_collate_fn(batch, pad_id=self.manifest_processor.pad_id) class AudioToCharDataset(_AudioTextDataset): @@ -1249,6 +1299,8 @@ class TarredAudioToBPEDataset(_TarredAudioToTextDataset): trim (bool): Whether to use trim silence from beginning and end of audio signal using librosa.effects.trim(). Defaults to False. + use_start_end_token: Boolean which dictates whether to add [BOS] and [EOS] + tokens to the beginning and end of each transcript, respectively. pad_id (id): Token used to pad when collating samples in batches. If this is None, pads using 0s. Defaults to None. diff --git a/nemo/collections/asr/data/audio_to_text_dali.py b/nemo/collections/asr/data/audio_to_text_dali.py index 4d0e6642b9ff..7cbbe64a06d5 100644 --- a/nemo/collections/asr/data/audio_to_text_dali.py +++ b/nemo/collections/asr/data/audio_to_text_dali.py @@ -13,13 +13,16 @@ # limitations under the License. import math +import operator from collections.abc import Iterator from typing import Callable, List, Optional, Union import torch from omegaconf import DictConfig +from nemo.collections.asr.data.audio_to_text import ASRManifestProcessor from nemo.collections.common.parts.preprocessing import parsers +from nemo.utils import logging, model_utils from nemo.utils.decorators import experimental try: @@ -34,8 +37,46 @@ __all__ = [ 'AudioToCharDALIDataset', + 'AudioToBPEDALIDataset', ] +""" +The minimum version below is required to access the "read_idxs" argument in +dali.fn.readers.nemo_asr +""" +__DALI_MINIMUM_VERSION__ = "1.4" + +DALI_INSTALLATION_MESSAGE = ( + "Could not import `nvidia.dali`.\n" + "Please install DALI by following the steps provided here - \n" + "https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html" +) + + +def is_dali_supported(min_version: str, verbose: bool = False) -> bool: + """ + Checks if DALI is installed and its version is >= min_version. + + Args: + min_version: A semver str that is the minimum requirement. + verbose: Whether to log the installation instructions if DALI is not found. + + Returns: + bool - True if DALI is available and satisfies the minimum version, False otherwise.
+ """ + module_available, _ = model_utils.check_lib_version( + 'nvidia.dali', checked_version=min_version, operator=operator.ge + ) + + # If DALI is not installed + if module_available is None: + if verbose: + logging.info(DALI_INSTALLATION_MESSAGE) + + return False + + return module_available + class DALIOutputs(object): def __init__(self, out_dict): @@ -70,7 +111,7 @@ def __len__(self): @experimental -class AudioToCharDALIDataset(Iterator): +class _AudioTextDALIDataset(Iterator): """ NVIDIA DALI pipeline that loads tensors via one or more manifest files where each line containing a sample descriptor in JSON, including audio files, transcripts, and durations (in seconds). @@ -79,24 +120,23 @@ class AudioToCharDALIDataset(Iterator): ... {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt": "utterance_id", "ctm_utt": "en_4156", "side": "A"} + Args: manifest_filepath: Path to manifest file with the format described above. Can be comma-separated paths. - labels: String containing all the possible characters to map to. - sample_rate (int): Sample rate to resample loaded audio to. + device (str): Determines the device type to be used for preprocessing. Allowed values are: 'cpu', 'gpu'. batch_size (int): Number of samples in a batch. + parser (str, callable): A str for an inbuilt parser, or a callable with signature f(str) -> List[int]. + sample_rate (int): Sample rate to resample loaded audio to. num_threads (int): Number of CPU processing threads to be created by the DALI pipeline. max_duration (float): Determines the maximum allowed duration, in seconds, of the loaded audio files. min_duration (float): Determines the minimum allowed duration, in seconds, of the loaded audio files. - blank_index (int): blank character index, default = -1 - unk_index (int): unk_character index, default = -1 - normalize (bool): whether to normalize transcript text (default): True bos_id (int): Id of beginning of sequence symbol to append if not None eos_id (int): Id of end of sequence symbol to append if not None + pad_id (int): Id used to pad the input. Defaults to 0 if not provided. trim (bool): If True, it will extract the nonsilent region of the loaded audio signal. shuffle (bool): If set to True, the dataset will shuffled after loading. drop_last (bool): If set to True, the last batch will be dropped if incomplete. This will be the case when the shard size is not divisible by the batch size. If set to False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. - device (str): Determines the device type to be used for preprocessing. Allowed values are: 'cpu', 'gpu'. device_id (int): Index of the GPU to be used (local_rank). Only applicable when device == 'gpu'. Defaults to 0. global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 1. 
@@ -108,20 +148,17 @@ def __init__( manifest_filepath: str, device: str, batch_size: int, - labels: Union[str, List[str]], + parser: Union[str, Callable], sample_rate: int = 16000, num_threads: int = 4, max_duration: float = 0.0, min_duration: float = 0.0, - blank_index: int = -1, - unk_index: int = -1, - normalize: bool = True, bos_id: Optional[int] = None, eos_id: Optional[int] = None, + pad_id: int = 0, trim: bool = False, - shuffle: bool = True, + shuffle: bool = False, drop_last: bool = False, - parser: Union[str, Callable] = 'en', device_id: int = 0, global_rank: int = 0, world_size: int = 1, @@ -151,14 +188,6 @@ def __init__( self.shard_id = None self.num_shards = None - self.labels = labels - if self.labels is None or len(self.labels) == 0: - raise ValueError(f"{self} expects non empty labels list") - - self.parser = parsers.make_parser( - labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize, - ) - self.eos_id = eos_id self.bos_id = bos_id self.sample_rate = sample_rate @@ -278,7 +307,7 @@ def __init__( self.pad_value = params['pad_value'] if 'pad_value' in params else 0.0 with self.pipe: - audio, transcript = dali.fn.nemo_asr_reader( + audio, indices = dali.fn.readers.nemo_asr( name="Reader", manifest_filepaths=manifest_filepath.split(','), dtype=dali.types.FLOAT, @@ -287,16 +316,14 @@ def __init__( min_duration=min_duration, max_duration=max_duration, read_sample_rate=False, - read_text=True, + read_text=False, + read_idxs=True, random_shuffle=shuffle, shard_id=self.shard_id, num_shards=self.num_shards, - pad_last_batch=False, + pad_last_batch=True, ) - transcript_len = dali.fn.shapes(dali.fn.reshape(transcript, shape=[-1])) - transcript = dali.fn.pad(transcript) - # Extract nonsilent region, if necessary if trim: # Need to extract non-silent region before moving to the GPU @@ -312,7 +339,7 @@ def __init__( # No preprocessing, the output is the audio signal audio_len = dali.fn.shapes(dali.fn.reshape(audio, shape=[-1])) audio = dali.fn.pad(audio) - self.pipe.set_outputs(audio, audio_len, transcript, transcript_len) + self.pipe.set_outputs(audio, audio_len, indices) else: # Additive gaussian noise (dither) if self.dither > 0.0: @@ -358,14 +385,14 @@ def __init__( # Pads feature dimension to be a multiple of `pad_to` and the temporal dimension to be as big as the largest sample (shape -1) spec = dali.fn.pad(spec, fill_value=self.pad_value, axes=(0, 1), align=(self.pad_to, 1), shape=(1, -1)) - self.pipe.set_outputs(spec, spec_len, transcript, transcript_len) + self.pipe.set_outputs(spec, spec_len, indices) # Building DALI pipeline self.pipe.build() if has_preprocessor: - output_names = ['processed_signal', 'processed_signal_len', 'transcript_raw', 'transcript_raw_len'] + output_names = ['processed_signal', 'processed_signal_len', 'manifest_indices'] else: - output_names = ['audio', 'audio_len', 'transcript_raw', 'transcript_raw_len'] + output_names = ['audio', 'audio_len', 'manifest_indices'] last_batch_policy = LastBatchPolicy.DROP if drop_last else LastBatchPolicy.PARTIAL self._iter = DALIPytorchIterator( @@ -387,6 +414,17 @@ def __len__(self): self.dataset = DummyDataset(self) # Used by NeMo + self.manifest_processor = ASRManifestProcessor( + manifest_filepath=manifest_filepath, + parser=parser, + max_duration=max_duration, + min_duration=min_duration, + max_utts=0, + bos_id=bos_id, + eos_id=eos_id, + pad_id=pad_id, + ) + def reset(self): self._iter.reset() @@ -407,8 +445,7 @@ def __next__(self): outputs = self._iter.next() assert len(outputs) 
== 1 dali_out = outputs[0] - text_raw_len = dali_out['transcript_raw_len'].numpy() - text_raw = dali_out['transcript_raw'].numpy() + manifest_indices = dali_out['manifest_indices'].numpy() out = {} out_names = ['processed_signal', 'processed_signal_len', 'audio', 'audio_len'] @@ -419,22 +456,17 @@ def __next__(self): text_tokens = [] text_tokens_len = [] max_len = 0 - batch_size = text_raw.shape[0] - for i, text in enumerate(text_raw): - n = text_raw_len[i][0] - tbytes = str(text[:n].tobytes(), encoding='utf8') - ttokens = self.parser(tbytes) - if self.bos_id is not None: - ttokens = [self.bos_id] + ttokens - if self.eos_id is not None: - ttokens = ttokens + [self.eos_id] - ttokens_len = len(ttokens) - text_tokens_len.append(ttokens_len) - text_tokens.append(ttokens) - if ttokens_len > max_len: - max_len = ttokens_len - - transcript_out = torch.zeros(batch_size, max_len, dtype=torch.long) + batch_size = manifest_indices.shape[0] + for i, manifest_index in enumerate(manifest_indices): + manifest_index = manifest_index[0] + text, text_length = self.manifest_processor.process_text(manifest_index) + + text_tokens_len.append(text_length) + text_tokens.append(text) + if text_length > max_len: + max_len = text_length + + transcript_out = torch.full([batch_size, max_len], fill_value=self.manifest_processor.pad_id, dtype=torch.long) for i, n in enumerate(text_tokens_len): transcript_out[i, :n] = torch.tensor(text_tokens[i], dtype=torch.long) transcript_len_out = torch.tensor(text_tokens_len, dtype=torch.long) @@ -442,3 +474,191 @@ def __next__(self): out['transcript'] = transcript_out out['transcript_len'] = transcript_len_out return DALIOutputs(out) + + +class AudioToCharDALIDataset(_AudioTextDALIDataset): + """ + Character based NVIDIA DALI pipeline that loads tensors via one or more manifest files where each line containing a + sample descriptor in JSON, including audio files, transcripts, and durations (in seconds). + Here's an example: + {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147} + ... + {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt": + "utterance_id", "ctm_utt": "en_4156", "side": "A"} + + Args: + manifest_filepath: Path to manifest file with the format described above. Can be comma-separated paths. + device (str): Determines the device type to be used for preprocessing. Allowed values are: 'cpu', 'gpu'. + batch_size (int): Number of samples in a batch. + labels (List[str]): String containing all the possible characters to map to. + sample_rate (int): Sample rate to resample loaded audio to. + num_threads (int): Number of CPU processing threads to be created by the DALI pipeline. + max_duration (float): Determines the maximum allowed duration, in seconds, of the loaded audio files. + min_duration (float): Determines the minimum allowed duration, in seconds, of the loaded audio files. + blank_index (int): blank character index, default = -1 + unk_index (int): unk_character index, default = -1 + normalize (bool): whether to normalize transcript text (default): True + bos_id (int): Id of beginning of sequence symbol to append if not None + eos_id (int): Id of end of sequence symbol to append if not None + pad_id (int): Id used to pad the input. Defaults to 0 if not provided. + trim (bool): If True, it will extract the nonsilent region of the loaded audio signal. + shuffle (bool): If set to True, the dataset will shuffled after loading. 
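With the reader now returning manifest indices instead of raw transcript bytes, __next__ above rebuilds the transcript batch on the host: each index is mapped back to token ids through ASRManifestProcessor.process_text() and written into a pad_id-filled tensor (previously a zero-filled buffer). The padding step in isolation, with made-up token ids:

# Standalone illustration of the batch padding performed in __next__; token ids are illustrative.
import torch

pad_id = 0
text_tokens = [[5, 2, 9], [7, 1], [3, 3, 3, 3]]  # variable-length token sequences
text_tokens_len = [len(t) for t in text_tokens]
max_len = max(text_tokens_len)

transcript = torch.full([len(text_tokens), max_len], fill_value=pad_id, dtype=torch.long)
for i, n in enumerate(text_tokens_len):
    transcript[i, :n] = torch.tensor(text_tokens[i], dtype=torch.long)
transcript_len = torch.tensor(text_tokens_len, dtype=torch.long)
# transcript has shape [3, 4]; rows shorter than max_len are right-padded with pad_id.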
+ drop_last (bool): If set to True, the last batch will be dropped if incomplete. This will be the case when the shard size is not divisible by the batch size. + If set to False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. + parser (str, callable): A str for an inbuilt parser, or a callable with signature f(str) -> List[int]. + device_id (int): Index of the GPU to be used (local_rank). Only applicable when device == 'gpu'. Defaults to 0. + global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. + world_size (int): Total number of processes, used for partitioning shards. Defaults to 1. + preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor. + """ + + def __init__( + self, + manifest_filepath: str, + device: str, + batch_size: int, + labels: Union[str, List[str]], + sample_rate: int = 16000, + num_threads: int = 4, + max_duration: float = 0.0, + min_duration: float = 0.0, + blank_index: int = -1, + unk_index: int = -1, + normalize: bool = True, + bos_id: Optional[int] = None, + eos_id: Optional[int] = None, + pad_id: int = 0, + trim: bool = False, + shuffle: bool = False, + drop_last: bool = False, + parser: Union[str, Callable] = 'en', + device_id: int = 0, + global_rank: int = 0, + world_size: int = 1, + preprocessor_cfg: DictConfig = None, + ): + self.labels = labels + + parser = parsers.make_parser( + labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize + ) + + super().__init__( + manifest_filepath=manifest_filepath, + device=device, + batch_size=batch_size, + sample_rate=sample_rate, + num_threads=num_threads, + max_duration=max_duration, + min_duration=min_duration, + bos_id=bos_id, + eos_id=eos_id, + pad_id=pad_id, + trim=trim, + shuffle=shuffle, + drop_last=drop_last, + parser=parser, + device_id=device_id, + global_rank=global_rank, + world_size=world_size, + preprocessor_cfg=preprocessor_cfg, + ) + + +class AudioToBPEDALIDataset(_AudioTextDALIDataset): + """ + Subword based NVIDIA DALI pipeline that loads tensors via one or more manifest files where each line containing a + sample descriptor in JSON, including audio files, transcripts, and durations (in seconds). + Here's an example: + {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147} + ... + {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt": + "utterance_id", "ctm_utt": "en_4156", "side": "A"} + + Args: + manifest_filepath: Path to manifest file with the format described above. Can be comma-separated paths. + tokenizer (TokenizerSpec): A TokenizerSpec implementation that wraps a tokenization implementation. + device (str): Determines the device type to be used for preprocessing. Allowed values are: 'cpu', 'gpu'. + batch_size (int): Number of samples in a batch. + sample_rate (int): Sample rate to resample loaded audio to. + num_threads (int): Number of CPU processing threads to be created by the DALI pipeline. + max_duration (float): Determines the maximum allowed duration, in seconds, of the loaded audio files. + min_duration (float): Determines the minimum allowed duration, in seconds, of the loaded audio files. + bos_id (int): Id of beginning of sequence symbol to append if not None. Injected from the tokenizer. + eos_id (int): Id of end of sequence symbol to append if not None. Injected from the tokenizer. 
+ pad_id (int): Id used to pad the input. Defaults to 0 if not provided. Injected from the tokenizer. + trim (bool): If True, it will extract the nonsilent region of the loaded audio signal. + shuffle (bool): If set to True, the dataset will be shuffled after loading. + drop_last (bool): If set to True, the last batch will be dropped if incomplete. This will be the case when the shard size is not divisible by the batch size. + If set to False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. + + device_id (int): Index of the GPU to be used (local_rank). Only applicable when device == 'gpu'. Defaults to 0. + global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. + world_size (int): Total number of processes, used for partitioning shards. Defaults to 1. + preprocessor_cfg (DictConfig): Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor. + use_start_end_token (bool): Boolean which dictates whether to add [BOS] and [EOS] tokens to the beginning and + end of each transcript, respectively. + """ + + def __init__( + self, + manifest_filepath: str, + tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec', + device: str, + batch_size: int, + sample_rate: int = 16000, + num_threads: int = 4, + max_duration: float = 0.0, + min_duration: float = 0.0, + trim: bool = False, + shuffle: bool = False, + drop_last: bool = False, + device_id: int = 0, + global_rank: int = 0, + world_size: int = 1, + preprocessor_cfg: DictConfig = None, + use_start_end_token: bool = True, + ): + if use_start_end_token and hasattr(tokenizer, 'bos_token'): + bos_id = tokenizer.bos_id + else: + bos_id = None + + if use_start_end_token and hasattr(tokenizer, 'eos_token'): + eos_id = tokenizer.eos_id + else: + eos_id = None + + if hasattr(tokenizer, 'pad_token'): + pad_id = tokenizer.pad_id + else: + pad_id = 0 + + class TokenizerWrapper: + def __init__(self, tokenizer): + self._tokenizer = tokenizer + + def __call__(self, text): + t = self._tokenizer.text_to_ids(text) + return t + + super().__init__( + manifest_filepath=manifest_filepath, + device=device, + batch_size=batch_size, + sample_rate=sample_rate, + num_threads=num_threads, + max_duration=max_duration, + min_duration=min_duration, + bos_id=bos_id, + eos_id=eos_id, + pad_id=pad_id, + trim=trim, + shuffle=shuffle, + drop_last=drop_last, + parser=TokenizerWrapper(tokenizer), + device_id=device_id, + global_rank=global_rank, + world_size=world_size, + preprocessor_cfg=preprocessor_cfg, + ) diff --git a/nemo/collections/asr/data/audio_to_text_dataset.py b/nemo/collections/asr/data/audio_to_text_dataset.py index dc1b3fdae883..6e98a59bd190 100644 --- a/nemo/collections/asr/data/audio_to_text_dataset.py +++ b/nemo/collections/asr/data/audio_to_text_dataset.py @@ -245,3 +245,47 @@ def get_dali_char_dataset( preprocessor_cfg=preprocessor_cfg, ) return dataset + + +def get_dali_bpe_dataset( + config: dict, + tokenizer, + shuffle: bool, + device_id: int, + global_rank: int, + world_size: int, + preprocessor_cfg: Optional[DictConfig] = None, +) -> audio_to_text_dali.AudioToBPEDALIDataset: + """ + Instantiates a Subword Encoding based AudioToBPEDALIDataset. + + Args: + config: Config of the AudioToBPEDALIDataset. + tokenizer: An implementation of NeMo TokenizerSpec. + shuffle: Bool flag whether to shuffle the dataset. + device_id: Index of the GPU to be used (local_rank). Only applicable when device == 'gpu'. Defaults to 0. + global_rank: Global rank of this device.
+ world_size: Global world size in the training method. + preprocessor_cfg: Preprocessor configuration. Supports AudioToMelSpectrogramPreprocessor and AudioToMFCCPreprocessor. + + Returns: + An instance of AudioToBPEDALIDataset. + """ + device = 'gpu' if torch.cuda.is_available() else 'cpu' + dataset = audio_to_text_dali.AudioToBPEDALIDataset( + manifest_filepath=config['manifest_filepath'], + tokenizer=tokenizer, + device=device, + batch_size=config['batch_size'], + sample_rate=config['sample_rate'], + max_duration=config.get('max_duration', None), + min_duration=config.get('min_duration', None), + trim=config.get('trim_silence', False), + use_start_end_token=config.get('use_start_end_token', True), + shuffle=shuffle, + device_id=device_id, + global_rank=global_rank, + world_size=world_size, + preprocessor_cfg=preprocessor_cfg, + ) + return dataset diff --git a/nemo/collections/asr/models/ctc_bpe_models.py b/nemo/collections/asr/models/ctc_bpe_models.py index f3bdcca1a1e9..b3fb05ae1658 100644 --- a/nemo/collections/asr/models/ctc_bpe_models.py +++ b/nemo/collections/asr/models/ctc_bpe_models.py @@ -195,6 +195,19 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): augmentor = None shuffle = config['shuffle'] + device = 'gpu' if torch.cuda.is_available() else 'cpu' + if config.get('use_dali', False): + device_id = self.local_rank if device == 'gpu' else None + dataset = audio_to_text_dataset.get_dali_bpe_dataset( + config=config, + tokenizer=self.tokenizer, + shuffle=shuffle, + device_id=device_id, + global_rank=self.global_rank, + world_size=self.world_size, + preprocessor_cfg=self._cfg.preprocessor, + ) + return dataset # Instantiate tarred dataset loader or normal dataset loader if config.get('is_tarred', False): diff --git a/nemo/collections/asr/parts/submodules/jasper.py b/nemo/collections/asr/parts/submodules/jasper.py index 35d4095ea40a..ec6402cef3bc 100644 --- a/nemo/collections/asr/parts/submodules/jasper.py +++ b/nemo/collections/asr/parts/submodules/jasper.py @@ -44,7 +44,7 @@ def tds_uniform_(tensor, mode='fan_in'): Normalized to - .. math:: - \text{bound} = \text{2} \times \sqrt{\frac{1}{\text{fan\_mode}}} + \\text{bound} = \\text{2} \\times \\sqrt{\\frac{1}{\\text{fan\\_mode}}} Args: tensor: an n-dimensional `torch.Tensor` @@ -67,7 +67,7 @@ def tds_normal_(tensor, mode='fan_in'): Normalized to - .. math:: - \text{bound} = \text{2} \times \sqrt{\frac{1}{\text{fan\_mode}}} + \\text{bound} = \\text{2} \\times \\sqrt{\\frac{1}{\\text{fan\\_mode}}} Args: tensor: an n-dimensional `torch.Tensor` diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index bedacac9febd..dd937e01d20a 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -520,7 +520,10 @@ def check_lib_version(lib_name: str, checked_version: str, operator) -> (Optiona - A string analysis of the check. """ try: - mod = __import__(lib_name) + if '.' in lib_name: + mod = import_class_by_path(lib_name) + else: + mod = __import__(lib_name) if hasattr(mod, '__version__'): lib_ver = version.Version(mod.__version__) diff --git a/tests/collections/asr/test_asr_datasets.py b/tests/collections/asr/test_asr_datasets.py index b2a3ccfd09fe..9bd4fb39b551 100644 --- a/tests/collections/asr/test_asr_datasets.py +++ b/tests/collections/asr/test_asr_datasets.py @@ -13,16 +13,49 @@ # limitations under the License.
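The get_dali_bpe_dataset() helper above is reached from the BPE CTC model's _setup_dataloader_from_config() whenever use_dali is set in the dataset config (see the ctc_bpe_models.py hunk). A hedged sketch of calling it directly; the manifest path, tokenizer vocabulary, and rank values are placeholders:

# Sketch only: paths and the tokenizer below stand in for an existing setup.
from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.common import tokenizers

tokenizer = tokenizers.AutoTokenizer(
    pretrained_model_name='bert-base-cased',
    vocab_file='/path/to/vocab.txt',  # hypothetical WPE vocabulary, as in the tests below
)

train_config = {
    'manifest_filepath': '/path/to/train_manifest.json',  # hypothetical manifest
    'sample_rate': 16000,
    'batch_size': 32,
    'max_duration': 16.7,
    'trim_silence': True,
    'use_start_end_token': True,
}

dataset = audio_to_text_dataset.get_dali_bpe_dataset(
    config=train_config,
    tokenizer=tokenizer,
    shuffle=True,
    device_id=0,            # local rank; only used when CUDA is available ('gpu' device)
    global_rank=0,
    world_size=1,
    preprocessor_cfg=None,  # without a preprocessor the pipeline outputs raw audio and lengths
)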
import copy +import json import os +import tempfile import pytest +import torch.cuda from omegaconf import OmegaConf from nemo.collections.asr.data.audio_to_text import TarredAudioToBPEDataset, TarredAudioToCharDataset +from nemo.collections.asr.data.audio_to_text_dali import ( + __DALI_MINIMUM_VERSION__, + AudioToBPEDALIDataset, + AudioToCharDALIDataset, + is_dali_supported, +) from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config from nemo.collections.common import tokenizers from nemo.utils import logging +try: + HAVE_DALI = is_dali_supported(__DALI_MINIMUM_VERSION__) +except (ImportError, ModuleNotFoundError): + HAVE_DALI = False + + +def decode_chars(tokens, token_length, mapping): + text = [] + tokens = tokens.cpu().numpy() + for idx in tokens: + text_token = mapping[idx] + text.append(text_token) + + text = text[:token_length] + text = ''.join(text) + return text + + +def decode_subwords(tokens, token_length, tokenizer: tokenizers.TokenizerSpec): + tokens = tokens.cpu().numpy() + tokens = tokens[:token_length] + text = tokenizer.ids_to_text(tokens) + return text + class TestASRDatasets: labels = [ @@ -125,3 +158,174 @@ def test_tarred_bpe_dataset(self, test_data_dir): for _ in ds_list_load: count += 1 assert count == 32 + + @pytest.mark.skipif(not HAVE_DALI, reason="NVIDIA DALI is not installed or incompatible version") + @pytest.mark.unit + def test_dali_char_dataset(self, test_data_dir): + manifest_path = os.path.abspath(os.path.join(test_data_dir, 'asr/an4_val.json')) + + num_samples = 10 + batch_size = 2 + device = 'gpu' if torch.cuda.is_available() else 'cpu' + texts = [] + + with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f: + with open(manifest_path, 'r') as m: + for ix, line in enumerate(m): + if ix >= num_samples: + break + + line = line.replace("tests/data/", "tests/.data/").replace("\n", "") + f.write(f"{line}\n") + + data = json.loads(line) + texts.append(data['text']) + + f.seek(0) + + dataset = AudioToCharDALIDataset( + manifest_filepath=f.name, + device=device, + batch_size=batch_size, + labels=self.labels, + max_duration=16.0, + parser='en', + shuffle=False, + ) + + assert len(dataset) == (num_samples // batch_size) # num batches + count = 0 + original_transcripts = [] + for batch in dataset: + transcripts = batch[2] # transcript index in DALIOutputs + transcripts_lengths = batch[3] # transcript length index in DALIOutputs + transcripts = [ + decode_chars(transcript, transcripts_length, mapping=self.labels) + for transcript, transcripts_length in zip(transcripts, transcripts_lengths) + ] + original_transcripts.extend(transcripts) + count += len(transcripts) + assert count == num_samples + + # Assert transcripts are correct + for text, og_transcript in zip(texts, original_transcripts): + assert text == og_transcript + + # Repeat, now with shuffle enabled + f.seek(0) + + dataset = AudioToCharDALIDataset( + manifest_filepath=f.name, + device=device, + batch_size=batch_size, + labels=self.labels, + max_duration=16.0, + parser='en', + shuffle=True, + ) + + assert len(dataset) == (num_samples // batch_size) # num batches + count = 0 + shuffled_transcripts = [] + for batch in dataset: + transcripts = batch[2] # transcript index in DALIOutputs + transcripts_lengths = batch[3] # transcript length index in DALIOutputs + transcripts = [ + decode_chars(transcript, transcripts_length, mapping=self.labels) + for transcript, transcripts_length in zip(transcripts, transcripts_lengths) + ] + 
shuffled_transcripts.extend(transcripts) + count += len(transcripts) + assert count == num_samples + + samples_changed = 0 + for orig, shuffled in zip(original_transcripts, shuffled_transcripts): + if orig != shuffled: + samples_changed += 1 + assert samples_changed > 1 # assume after shuffling at least 1 sample was displaced + + @pytest.mark.skipif(not HAVE_DALI, reason="NVIDIA DALI is not installed or incompatible version") + @pytest.mark.unit + def test_dali_bpe_dataset(self, test_data_dir): + manifest_path = os.path.abspath(os.path.join(test_data_dir, 'asr/an4_val.json')) + + num_samples = 10 + batch_size = 2 + device = 'gpu' if torch.cuda.is_available() else 'cpu' + texts = [] + + tokenizer_path = os.path.join(test_data_dir, "asr", "tokenizers", "an4_wpe_128", 'vocab.txt') + tokenizer = tokenizers.AutoTokenizer(pretrained_model_name='bert-base-cased', vocab_file=tokenizer_path) + + with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f: + with open(manifest_path, 'r') as m: + for ix, line in enumerate(m): + if ix >= num_samples: + break + + line = line.replace("tests/data/", "tests/.data/").replace("\n", "") + f.write(f"{line}\n") + + data = json.loads(line) + texts.append(data['text']) + + f.seek(0) + + dataset = AudioToBPEDALIDataset( + manifest_filepath=f.name, + tokenizer=tokenizer, + device=device, + batch_size=batch_size, + max_duration=16.0, + shuffle=False, + ) + + assert len(dataset) == (num_samples // batch_size) # num batches + count = 0 + original_transcripts = [] + for batch in dataset: + transcripts = batch[2] # transcript index in DALIOutputs + transcripts_lengths = batch[3] # transcript length index in DALIOutputs + transcripts = [ + decode_subwords(transcript, transcripts_length, tokenizer=tokenizer) + for transcript, transcripts_length in zip(transcripts, transcripts_lengths) + ] + original_transcripts.extend(transcripts) + count += len(transcripts) + assert count == num_samples + + # Assert transcripts are correct + for text, og_transcript in zip(texts, original_transcripts): + assert text == og_transcript + + # Repeat, now with shuffle enabled + f.seek(0) + + dataset = AudioToBPEDALIDataset( + manifest_filepath=f.name, + tokenizer=tokenizer, + device=device, + batch_size=batch_size, + max_duration=16.0, + shuffle=True, + ) + + assert len(dataset) == (num_samples // batch_size) # num batches + count = 0 + shuffled_transcripts = [] + for batch in dataset: + transcripts = batch[2] # transcript index in DALIOutputs + transcripts_lengths = batch[3] # transcript length index in DALIOutputs + transcripts = [ + decode_subwords(transcript, transcripts_length, tokenizer=tokenizer) + for transcript, transcripts_length in zip(transcripts, transcripts_lengths) + ] + shuffled_transcripts.extend(transcripts) + count += len(transcripts) + assert count == num_samples + + samples_changed = 0 + for orig, shuffled in zip(original_transcripts, shuffled_transcripts): + if orig != shuffled: + samples_changed += 1 + assert samples_changed > 1 # assume after shuffling at least 1 sample was displaced
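The HAVE_DALI guard used by these tests ultimately relies on the small model_utils.check_lib_version() change above, which now also resolves dotted module names. A brief sketch of that utility under both forms of module name; the version strings are examples:

# check_lib_version() returns (Optional[bool], str): True/False once the version was compared,
# or None when the module could not be imported at all.
import operator
from nemo.utils import model_utils

# Plain module name: resolved with __import__, as before this change.
ok, info = model_utils.check_lib_version('torch', checked_version='1.8.0', operator=operator.ge)
print(ok, info)

# Dotted module path: now resolved through import_class_by_path(), which is what
# is_dali_supported() depends on for 'nvidia.dali'.
ok, info = model_utils.check_lib_version('nvidia.dali', checked_version='1.4', operator=operator.ge)
print(ok, info)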