Skip to content

Commit

Permalink
Lhotse/K2 support
Browse files Browse the repository at this point in the history
  • Loading branch information
freewym committed Nov 4, 2020
1 parent b3ed99c commit 2b230b6
Show file tree
Hide file tree
Showing 4 changed files with 298 additions and 1 deletion.
2 changes: 2 additions & 0 deletions espresso/data/__init__.py
Expand Up @@ -6,6 +6,7 @@
from .asr_bucket_pad_length_dataset import FeatBucketPadLengthDataset, TextBucketPadLengthDataset
from .asr_chain_dataset import AsrChainDataset, NumeratorGraphDataset
from .asr_dataset import AsrDataset
from .asr_k2_dataset import AsrK2Dataset
from .asr_dictionary import AsrDictionary
from .asr_xent_dataset import AliScpCachedDataset, AsrXentDataset
from .feat_text_dataset import (
Expand All @@ -20,6 +21,7 @@
"AsrChainDataset",
"AsrDataset",
"AsrDictionary",
"AsrK2Dataset",
"AsrTextDataset",
"AsrXentDataset",
"FeatBucketPadLengthDataset",
Expand Down
260 changes: 260 additions & 0 deletions espresso/data/asr_k2_dataset.py
@@ -0,0 +1,260 @@
# Copyright (c) Yiming Wang, Yiwen Shao
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import re
from typing import Dict, List

import numpy as np

import torch

from fairseq.data import FairseqDataset, data_utils

import espresso.tools.utils as speech_utils


logger = logging.getLogger(__name__)


try:
from espresso.tools.lhotse.cut import CutSet
except ImportError:
logger.warning("Please install Lhotse by `make lhotse` after entering espresso/tools")


def collate(samples, pad_to_length=None, pad_to_multiple=1):
if len(samples) == 0:
return {}

def merge(key, pad_to_length=None):
if key == "source":
return speech_utils.collate_frames(
[sample[key] for sample in samples], 0.0,
pad_to_length=pad_to_length,
pad_to_multiple=pad_to_multiple,
)
else:
raise ValueError("Invalid key.")

id = torch.LongTensor([sample["id"] for sample in samples])
src_frames = merge(
"source",
pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
)
# sort by descending source length
if pad_to_length is not None:
src_lengths = torch.IntTensor(
[sample["source"].ne(0.0).any(dim=1).int().sum() for sample in samples]
)
else:
src_lengths = torch.IntTensor([s["source"].size(0) for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
utt_id = [samples[i]["utt_id"] for i in sort_order.numpy()]
src_frames = src_frames.index_select(0, sort_order)
ntokens = src_lengths.sum().item()

target = None
if samples[0].get("target", None) is not None and len(samples[0].target) > 0:
# reorder the list of samples to make things easier
# (no need to reorder every element in target)
samples = [samples[i] for i in sort_order.numpy()]

from torch.utils.data._utils.collate import default_collate

dataset_idx_to_batch_idx = {
sample["target"][0]["sequence_idx"]: batch_idx
for batch_idx, sample in enumerate(samples)
}

def update(d: Dict, **kwargs) -> Dict:
for key, value in kwargs.items():
d[key] = value
return d

target = default_collate([
update(sup, sequence_idx=dataset_idx_to_batch_idx[sup["sequence_idx"]])
for sample in samples
for sup in sample["target"]
])

batch = {
"id": id,
"utt_id": utt_id,
"nsentences": len(samples),
"ntokens": ntokens,
"net_input": {
"src_tokens": src_frames,
"src_lengths": src_lengths,
},
"target": target,
}
return batch


class AsrK2Dataset(FairseqDataset):
"""
A K2 Dataset for ASR.
Args:
cuts (lhotse.CutSet): Lhotse CutSet to wrap
shuffle (bool, optional): shuffle dataset elements before batching
(default: True).
pad_to_multiple (int, optional): pad src lengths to a multiple of this value
"""

def __init__(
self,
cuts: CutSet,
shuffle=True,
pad_to_multiple=1,
):
self.cuts = cuts
self.cut_ids = list(self.cuts.ids)
self.src_sizes = np.array(
[cut.num_frames if cut.has_features else cut.num_samples for cut in cuts]
)
self.tgt_sizes = None
# assume all cuts have no supervisions if the first one does not
if len(cuts[self.cut_ids[0]].supervisions) > 0:
# take the size of the first supervision
self.tgt_sizes = np.array(
[
round(
cut.supervisions[0].trim(cut.duration).duration / cut.frame_shift
) for cut in cuts
]
)
self.shuffle = shuffle
self.epoch = 1
self.sizes = (
np.vstack((self.src_sizes, self.tgt_sizes)).T
if self.tgt_sizes is not None
else self.src_sizes
)
self.pad_to_multiple = pad_to_multiple
self.feat_dim = self.cuts[self.cut_ids[0]].load_features().shape(1)

def __getitem__(self, index):
cut_id = self.cut_ids[index]
cut = self.cuts[cut_id]
features = torch.from_numpy(cut.load_features())

example = {
"id": index,
"utt_id": cut_id,
"source": features,
"target": [
{
"sequence_idx": index,
"text": sup.text,
"start_frame": round(sup.start / cut.frame_shift),
"num_frames": round(sup.duration / cut.frame_shift),
}
# CutSet's supervisions can exceed the cut, when the cut starts/ends in the middle
# of a supervision (they would have relative times e.g. -2 seconds start, meaning
# it started 2 seconds before the Cut starts). We use s.trim() to get rid of that
# property, ensuring the supervision time span does not exceed that of the cut.
for sup in (s.trim(cut.duration) for s in cut.supervisions)
]
}
return example

def __len__(self):
return len(self.cuts)

def collater(self, samples, pad_to_length=None):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
pad_to_length (dict, optional): a dictionary of
{"source": source_pad_to_length}
to indicate the max length to pad to in source and target respectively.
Returns:
dict: a mini-batch with the following keys:
- `id` (LongTensor): example IDs in the original input order
- `utt_id` (List[str]): list of utterance ids
- `nsentences` (int): batch size
- `ntokens` (int): total number of tokens in the batch
- `net_input` (dict): the input to the Model, containing keys:
- `src_tokens` (FloatTensor): a padded 3D Tensor of features in
the source of shape `(bsz, src_len, feat_dim)`.
- `src_lengths` (IntTensor): 1D Tensor of the unpadded
lengths of each source sequence of shape `(bsz)`
- `target` (List[Dict[str, Any]]): an List representing a batch of
supervisions
"""
return collate(
samples, pad_to_length=pad_to_length, pad_to_multiple=self.pad_to_multiple,
)

def num_tokens(self, index):
"""Return the number of frames in a sample. This value is used to
enforce ``--max-tokens`` during batching."""
return self.src_sizes[index]

def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return (
self.src_sizes[index],
self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
)

def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
indices = np.random.permutation(len(self)).astype(np.int64)
else:
indices = np.arange(len(self), dtype=np.int64)
# sort by target length, then source length
if self.tgt_sizes is not None:
indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]

@property
def supports_prefetch(self):
return False

def filter_indices_by_size(self, indices, max_sizes):
"""Filter a list of sample indices. Remove those that are longer
than specified in max_sizes.
Args:
indices (np.array): original array of sample indices
max_sizes (int or list[int] or tuple[int]): max sample size,
can be defined separately for src and tgt (then list or tuple)
Returns:
np.array: filtered sample array
list: list of removed indices
"""
return data_utils.filter_paired_dataset_indices_by_size(
self.src_sizes,
self.tgt_sizes,
indices,
max_sizes,
)

@property
def supports_fetch_outside_dataloader(self):
"""Whether this dataset supports fetching outside the workers of the dataloader."""
return False

@property
def can_reuse_epoch_itr_across_epochs(self):
return False # to avoid running out of CPU RAM

def set_epoch(self, epoch):
super().set_epoch(epoch)
self.epoch = epoch
29 changes: 29 additions & 0 deletions espresso/tasks/speech_recognition_hybrid.py
Expand Up @@ -23,6 +23,7 @@
from espresso.data import (
AliScpCachedDataset,
AsrChainDataset,
AsrK2Dataset,
AsrXentDataset,
AsrDictionary,
AsrTextDataset,
Expand Down Expand Up @@ -74,6 +75,7 @@ class SpeechRecognitionHybridConfig(FairseqDataclass):
},
)
feat_in_channels: int = field(default=1, metadata={"help": "feature input channels"})
use_k2_dataset: bool = field(default=False, metadata={"help": "if True use K2 dataset"})
specaugment_config: Optional[str] = field(
default=None,
metadata={
Expand Down Expand Up @@ -146,6 +148,21 @@ class SpeechRecognitionHybridConfig(FairseqDataclass):
max_epoch: int = II("optimization.max_epoch") # to determine whether in trainig stage


def get_k2_dataset_from_json(data_path, split, shuffle=True, pad_to_multiple=1, seed=1):
try:
from espresso.tools.lhotse.cut import CutSet
except ImportError:
raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools")

data_json_path = os.path.join(data_path, "cuts_{}.json".format(split))
if not os.path.isfile(data_json_path):
raise FileNotFoundError("Dataset not found: {}".format(data_json_path))

cut_set = CutSet.from_json(data_json_path)
logger.info("{} {} examples".format(data_json_path, len(cut_set)))
return AsrK2Dataset(cut_set, shuffle=shuffle, pad_to_multiple=pad_to_multiple)


def get_asr_dataset_from_json(
data_path,
split,
Expand Down Expand Up @@ -343,6 +360,7 @@ def __init__(self, cfg: DictConfig, dictionary):
super().__init__(cfg)
self.dictionary = dictionary
self.feat_in_channels = cfg.feat_in_channels
self.use_k2_dataset = cfg.use_k2_dataset
self.specaugment_config = cfg.specaugment_config
self.num_targets = cfg.num_targets
self.training_stage = (cfg.max_epoch > 0) # a hack
Expand Down Expand Up @@ -402,6 +420,17 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
paths = paths[:1]
data_path = paths[(epoch - 1) % len(paths)]

if self.use_k2_dataset:
self.datasets[split] = get_k2_dataset_from_json(
data_path,
split,
shuffle=(split != self.cfg.gen_subset),
pad_to_multiple=self.cfg.required_seq_len_multiple,
seed=self.cfg.seed,
)
self.feat_dim = self.datasets[split].feat_dim
return

self.datasets[split] = get_asr_dataset_from_json(
data_path,
split,
Expand Down
8 changes: 7 additions & 1 deletion espresso/tools/Makefile
@@ -1,5 +1,5 @@
KALDI =
PYTHON_DIR = ~/anaconda3/bin
PYTHON_DIR = /export/b03/ywang/anaconda3/bin

CXX ?= g++

Expand Down Expand Up @@ -30,6 +30,7 @@ kaldi:
endif

clean: openfst_cleaned
rm -rf lhotse
rm -rf pychain
rm -rf kaldi

Expand Down Expand Up @@ -79,3 +80,8 @@ pychain:
export PATH=$(PYTHON_DIR):$$PATH && \
cd pychain/openfst_binding && python3 setup.py install && \
cd ../pytorch_binding && python3 setup.py install

.PHONY: lhotse
lhotse:
test -d lhotse || git clone https://github.com/lhotse-speech/lhotse.git
export PATH=$(PYTHON_DIR):$$PATH && cd lhotse && pip install -e '.[dev]'

0 comments on commit 2b230b6

Please sign in to comment.