diff --git a/espresso/data/__init__.py b/espresso/data/__init__.py
index 1522105693..8edd372e5c 100644
--- a/espresso/data/__init__.py
+++ b/espresso/data/__init__.py
@@ -6,6 +6,7 @@
 from .asr_bucket_pad_length_dataset import FeatBucketPadLengthDataset, TextBucketPadLengthDataset
 from .asr_chain_dataset import AsrChainDataset, NumeratorGraphDataset
 from .asr_dataset import AsrDataset
+from .asr_k2_dataset import AsrK2Dataset
 from .asr_dictionary import AsrDictionary
 from .asr_xent_dataset import AliScpCachedDataset, AsrXentDataset
 from .feat_text_dataset import (
@@ -20,6 +21,7 @@
     "AsrChainDataset",
     "AsrDataset",
     "AsrDictionary",
+    "AsrK2Dataset",
     "AsrTextDataset",
     "AsrXentDataset",
     "FeatBucketPadLengthDataset",
diff --git a/espresso/data/asr_k2_dataset.py b/espresso/data/asr_k2_dataset.py
new file mode 100644
index 0000000000..1735ea8eca
--- /dev/null
+++ b/espresso/data/asr_k2_dataset.py
@@ -0,0 +1,276 @@
+# Copyright (c) Yiming Wang, Yiwen Shao
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import re
+from typing import Dict, List
+
+import numpy as np
+
+import torch
+
+from fairseq.data import FairseqDataset, data_utils
+
+import espresso.tools.utils as speech_utils
+
+
+logger = logging.getLogger(__name__)
+
+
+try:
+    from espresso.tools.lhotse.cut import CutSet
+except ImportError:
+    logger.warning("Please install Lhotse by `make lhotse` after entering espresso/tools")
+
+
+def collate(samples, pad_to_length=None, pad_to_multiple=1):
+    """Merge a list of samples into a mini-batch sorted by descending source length.
+
+    Args:
+        samples (List[dict]): samples with the keys produced by
+            `AsrK2Dataset.__getitem__` ("id", "utt_id", "source", "target")
+        pad_to_length (dict, optional): `{"source": length}`, forwarded to
+            `speech_utils.collate_frames` as the length to pad the source to
+        pad_to_multiple (int, optional): forwarded to `speech_utils.collate_frames`
+
+    Returns:
+        dict: the mini-batch described in `AsrK2Dataset.collater`, or an empty
+        dict if `samples` is empty
+    """
+    if not samples:
+        return {}
+
+    def merge(key, pad_to_length=None):
+        if key == "source":
+            return speech_utils.collate_frames(
+                [sample[key] for sample in samples], 0.0,
+                pad_to_length=pad_to_length,
+                pad_to_multiple=pad_to_multiple,
+            )
+        else:
+            raise ValueError("Invalid key.")
+
+    ids = torch.LongTensor([sample["id"] for sample in samples])
+    src_frames = merge(
+        "source",
+        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
+    )
+    # sort by descending source length
+    if pad_to_length is not None:
+        # NOTE(review): counts the frames that are not entirely zero, presumably because
+        # the samples may already be zero-padded in this case -- confirm
+        src_lengths = torch.IntTensor(
+            [sample["source"].ne(0.0).any(dim=1).int().sum() for sample in samples]
+        )
+    else:
+        src_lengths = torch.IntTensor([s["source"].size(0) for s in samples])
+    src_lengths, sort_order = src_lengths.sort(descending=True)
+    ids = ids.index_select(0, sort_order)
+    utt_id = [samples[i]["utt_id"] for i in sort_order.numpy()]
+    src_frames = src_frames.index_select(0, sort_order)
+    ntokens = src_lengths.sum().item()
+
+    target = None
+    if samples[0].get("target", None) is not None and len(samples[0]["target"]) > 0:
+        # reorder the list of samples to make things easier
+        # (no need to reorder every element in target)
+        samples = [samples[i] for i in sort_order.numpy()]
+
+        from torch.utils.data._utils.collate import default_collate
+
+        dataset_idx_to_batch_idx = {
+            sample["target"][0]["sequence_idx"]: batch_idx
+            for batch_idx, sample in enumerate(samples)
+        }
+
+        def update(d: Dict, **kwargs) -> Dict:
+            for key, value in kwargs.items():
+                d[key] = value
+            return d
+
+        # rewrite each supervision's dataset-level index as its position in the sorted batch
+        target = default_collate([
+            update(sup, sequence_idx=dataset_idx_to_batch_idx[sup["sequence_idx"]])
+            for sample in samples
+            for sup in sample["target"]
+        ])
+
+    batch = {
+        "id": ids,
+        "utt_id": utt_id,
+        "nsentences": len(samples),
+        "ntokens": ntokens,
+        "net_input": {
+            "src_tokens": src_frames,
+            "src_lengths": src_lengths,
+        },
+        "target": target,
+    }
+    return batch
+
+
+class AsrK2Dataset(FairseqDataset):
+    """
+    A K2 Dataset for ASR.
+
+    Args:
+        cuts (lhotse.CutSet): Lhotse CutSet to wrap
+        shuffle (bool, optional): shuffle dataset elements before batching
+            (default: True).
+        pad_to_multiple (int, optional): pad src lengths to a multiple of this value
+    """
+
+    def __init__(
+        self,
+        cuts: "CutSet",  # quoted: `CutSet` is undefined if the Lhotse import above failed
+        shuffle=True,
+        pad_to_multiple=1,
+    ):
+        self.cuts = cuts
+        self.cut_ids = list(self.cuts.ids)
+        self.src_sizes = np.array(
+            [cut.num_frames if cut.has_features else cut.num_samples for cut in cuts]
+        )
+        self.tgt_sizes = None
+        # assume all cuts have no supervisions if the first one does not
+        if len(cuts[self.cut_ids[0]].supervisions) > 0:
+            # take the size of the first supervision
+            self.tgt_sizes = np.array(
+                [
+                    round(
+                        cut.supervisions[0].trim(cut.duration).duration / cut.frame_shift
+                    ) for cut in cuts
+                ]
+            )
+        self.shuffle = shuffle
+        self.epoch = 1
+        self.sizes = (
+            np.vstack((self.src_sizes, self.tgt_sizes)).T
+            if self.tgt_sizes is not None
+            else self.src_sizes
+        )
+        self.pad_to_multiple = pad_to_multiple
+        self.feat_dim = self.cuts[self.cut_ids[0]].load_features().shape[1]
+
+    def __getitem__(self, index):
+        cut_id = self.cut_ids[index]
+        cut = self.cuts[cut_id]
+        features = torch.from_numpy(cut.load_features())
+
+        example = {
+            "id": index,
+            "utt_id": cut_id,
+            "source": features,
+            "target": [
+                {
+                    "sequence_idx": index,
+                    "text": sup.text,
+                    "start_frame": round(sup.start / cut.frame_shift),
+                    "num_frames": round(sup.duration / cut.frame_shift),
+                }
+                # CutSet's supervisions can exceed the cut, when the cut starts/ends in the middle
+                # of a supervision (they would have relative times e.g. -2 seconds start, meaning
+                # it started 2 seconds before the Cut starts). We use s.trim() to get rid of that
+                # property, ensuring the supervision time span does not exceed that of the cut.
+                for sup in (s.trim(cut.duration) for s in cut.supervisions)
+            ]
+        }
+        return example
+
+    def __len__(self):
+        return len(self.cuts)
+
+    def collater(self, samples, pad_to_length=None):
+        """Merge a list of samples to form a mini-batch.
+
+        Args:
+            samples (List[dict]): samples to collate
+            pad_to_length (dict, optional): a dictionary of
+                {"source": source_pad_to_length}
+                to indicate the max length to pad to in source.
+
+        Returns:
+            dict: a mini-batch with the following keys:
+
+                - `id` (LongTensor): example IDs, ordered by descending source length
+                - `utt_id` (List[str]): list of utterance ids
+                - `nsentences` (int): batch size
+                - `ntokens` (int): total number of tokens in the batch
+                - `net_input` (dict): the input to the Model, containing keys:
+
+                  - `src_tokens` (FloatTensor): a padded 3D Tensor of features in
+                    the source of shape `(bsz, src_len, feat_dim)`.
+                  - `src_lengths` (IntTensor): 1D Tensor of the unpadded
+                    lengths of each source sequence of shape `(bsz)`
+
+                - `target` (Dict[str, Any]): the supervisions of all samples in the
+                    batch merged by `default_collate`, or None if there are none
+        """
+        return collate(
+            samples, pad_to_length=pad_to_length, pad_to_multiple=self.pad_to_multiple,
+        )
+
+    def num_tokens(self, index):
+        """Return the number of frames in a sample. This value is used to
+        enforce ``--max-tokens`` during batching."""
+        return self.src_sizes[index]
+
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
+        return (
+            self.src_sizes[index],
+            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
+        )
+
+    def ordered_indices(self):
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
+        if self.shuffle:
+            indices = np.random.permutation(len(self)).astype(np.int64)
+        else:
+            indices = np.arange(len(self), dtype=np.int64)
+        # sort by target length, then source length
+        if self.tgt_sizes is not None:
+            indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
+        return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
+
+    @property
+    def supports_prefetch(self):
+        return False
+
+    def filter_indices_by_size(self, indices, max_sizes):
+        """Filter a list of sample indices. Remove those that are longer
+            than specified in max_sizes.
+
+        Args:
+            indices (np.array): original array of sample indices
+            max_sizes (int or list[int] or tuple[int]): max sample size,
+                can be defined separately for src and tgt (then list or tuple)
+
+        Returns:
+            np.array: filtered sample array
+            list: list of removed indices
+        """
+        return data_utils.filter_paired_dataset_indices_by_size(
+            self.src_sizes,
+            self.tgt_sizes,
+            indices,
+            max_sizes,
+        )
+
+    @property
+    def supports_fetch_outside_dataloader(self):
+        """Whether this dataset supports fetching outside the workers of the dataloader."""
+        return False
+
+    @property
+    def can_reuse_epoch_itr_across_epochs(self):
+        return False  # to avoid running out of CPU RAM
+
+    def set_epoch(self, epoch):
+        super().set_epoch(epoch)
+        self.epoch = epoch
diff --git a/espresso/tasks/speech_recognition_hybrid.py b/espresso/tasks/speech_recognition_hybrid.py
index a10b07ea52..d526a7bf9a 100644
--- a/espresso/tasks/speech_recognition_hybrid.py
+++ b/espresso/tasks/speech_recognition_hybrid.py
@@ -23,6 +23,7 @@
 from espresso.data import (
     AliScpCachedDataset,
     AsrChainDataset,
+    AsrK2Dataset,
     AsrXentDataset,
     AsrDictionary,
     AsrTextDataset,
@@ -74,6 +75,7 @@ class SpeechRecognitionHybridConfig(FairseqDataclass):
         },
     )
     feat_in_channels: int = field(default=1, metadata={"help": "feature input channels"})
+    use_k2_dataset: bool = field(default=False, metadata={"help": "if True use K2 dataset"})
     specaugment_config: Optional[str] = field(
         default=None,
         metadata={
@@ -146,6 +148,29 @@ class SpeechRecognitionHybridConfig(FairseqDataclass):
     max_epoch: int = II("optimization.max_epoch")  # to determine whether in trainig stage
 
 
+def get_k2_dataset_from_json(data_path, split, shuffle=True, pad_to_multiple=1, seed=1):
+    """Build an AsrK2Dataset from the Lhotse cuts stored in `<data_path>/cuts_<split>.json`.
+
+    `seed` is accepted but not used by this function.
+
+    Raises:
+        ImportError: if Lhotse cannot be imported
+        FileNotFoundError: if the cuts JSON file does not exist
+    """
+    try:
+        from espresso.tools.lhotse.cut import CutSet
+    except ImportError as e:
+        raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools") from e
+
+    data_json_path = os.path.join(data_path, "cuts_{}.json".format(split))
+    if not os.path.isfile(data_json_path):
+        raise FileNotFoundError("Dataset not found: {}".format(data_json_path))
+
+    cut_set = CutSet.from_json(data_json_path)
+    logger.info("{} {} examples".format(data_json_path, len(cut_set)))
+    return AsrK2Dataset(cut_set, shuffle=shuffle, pad_to_multiple=pad_to_multiple)
+
+
 def get_asr_dataset_from_json(
     data_path,
     split,
@@ -343,6 +368,7 @@ def __init__(self, cfg: DictConfig, dictionary):
         super().__init__(cfg)
         self.dictionary = dictionary
         self.feat_in_channels = cfg.feat_in_channels
+        self.use_k2_dataset = cfg.use_k2_dataset
         self.specaugment_config = cfg.specaugment_config
         self.num_targets = cfg.num_targets
         self.training_stage = (cfg.max_epoch > 0)  # a hack
@@ -402,6 +428,17 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
             paths = paths[:1]
         data_path = paths[(epoch - 1) % len(paths)]
 
+        if self.use_k2_dataset:
+            self.datasets[split] = get_k2_dataset_from_json(
+                data_path,
+                split,
+                shuffle=(split != self.cfg.gen_subset),
+                pad_to_multiple=self.cfg.required_seq_len_multiple,
+                seed=self.cfg.seed,
+            )
+            self.feat_dim = self.datasets[split].feat_dim
+            return
+
         self.datasets[split] = get_asr_dataset_from_json(
             data_path,
             split,
diff --git a/espresso/tools/Makefile b/espresso/tools/Makefile
index 81ca5fb0fc..ecef83a85d 100644
--- a/espresso/tools/Makefile
+++ b/espresso/tools/Makefile
@@ -30,6 +30,7 @@ kaldi:
 endif
 
 clean: openfst_cleaned
+	rm -rf lhotse
 	rm -rf pychain
 	rm -rf kaldi
 
@@ -79,3 +80,8 @@ pychain:
 	export PATH=$(PYTHON_DIR):$$PATH && \
 	cd pychain/openfst_binding && python3 setup.py install && \
 	cd ../pytorch_binding && python3 setup.py install
+
+.PHONY: lhotse
+lhotse:
+	test -d lhotse || git clone https://github.com/lhotse-speech/lhotse.git
+	export PATH=$(PYTHON_DIR):$$PATH && cd lhotse && pip install -e '.[dev]'