Speech to text translation utilizing 3-way data (#1099)
AmirHussein96 committed Aug 17, 2023
1 parent aa073f6 commit c80fc07
Showing 6 changed files with 626 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/corpus.rst
@@ -123,6 +123,8 @@ a CLI tool that creates the manifests given a corpus directory.
- :func:`lhotse.recipes.prepare_himia`
* - ICSI
- :func:`lhotse.recipes.prepare_icsi`
* - IWSLT22_Ta
- :func:`lhotse.recipes.prepare_iwslt22_ta`
* - KeSpeech
- :func:`lhotse.recipes.prepare_kespeech`
* - L2 Arctic
1 change: 1 addition & 0 deletions lhotse/bin/modes/recipes/__init__.py
@@ -40,6 +40,7 @@
from .hifitts import *
from .himia import *
from .icsi import *
from .iwslt22_ta import *
from .kespeech import *
from .l2_arctic import *
from .libricss import *
57 changes: 57 additions & 0 deletions lhotse/bin/modes/recipes/iwslt22_ta.py
@@ -0,0 +1,57 @@
from typing import Optional, Sequence, Union

import click

from lhotse.bin.modes import prepare
from lhotse.recipes.iwslt22_ta import prepare_iwslt22_ta
from lhotse.utils import Pathlike


@prepare.command(context_settings=dict(show_default=True))
@click.argument("corpus_dir", type=click.Path(exists=True, dir_okay=True))
@click.argument("splits", type=click.Path(exists=True, dir_okay=True))
@click.argument("output_dir", type=click.Path())
@click.option(
"-j",
"--num-jobs",
type=int,
default=1,
help="How many threads to use (can give good speed-ups with slow disks).",
)
@click.option(
    "--normalize-text",
    is_flag=True,
    default=False,
    help="Whether to perform additional text cleaning and normalization from https://aclanthology.org/2022.iwslt-1.29.pdf.",
)
@click.option(
"--langs",
default="",
help="Comma-separated list of language abbreviations for source and target languages",
)
def iwslt22_ta(
corpus_dir: Pathlike,
splits: Pathlike,
output_dir: Pathlike,
normalize_text: bool,
langs: str,
num_jobs: int,
):
"""
IWSLT_2022 data preparation.
\b
This is conversational telephone speech collected as 8kHz-sampled data.
The catalog number LDC2022E01 corresponds to the train, dev, and test1
splits of the iwslt2022 shared task.
To obtaining this data your institution needs to have an LDC subscription.
You also should download the predined splits with
git clone https://github.com/kevinduh/iwslt22-dialect.git
"""
langs_list = langs.split(",")
prepare_iwslt22_ta(
corpus_dir,
splits,
output_dir=output_dir,
num_jobs=num_jobs,
clean=normalize_text,
langs=langs_list,
)
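
For reference, a minimal sketch of calling the recipe function wrapped by this CLI directly from Python; the paths and the language abbreviations are illustrative assumptions, not part of this diff. The keyword arguments mirror the CLI options above one-to-one (--num-jobs, --normalize-text, --langs).

from lhotse.recipes.iwslt22_ta import prepare_iwslt22_ta

# Hypothetical paths: point these at the LDC2022E01 release and the cloned
# https://github.com/kevinduh/iwslt22-dialect repository.
manifests = prepare_iwslt22_ta(
    "/data/LDC2022E01",        # corpus_dir
    "/data/iwslt22-dialect",   # splits: predefined train/dev/test1 lists
    output_dir="data/manifests",
    num_jobs=4,
    clean=True,                # apply the IWSLT 2022 text normalization
    langs=["aeb", "eng"],      # assumed source/target abbreviations
)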
203 changes: 203 additions & 0 deletions lhotse/dataset/speech_translation.py
@@ -0,0 +1,203 @@
# Copyright 2023 Johns Hopkins (authors: Amir Hussein)

from typing import Callable, Dict, List, Union

import torch
from torch.utils.data.dataloader import default_collate

from lhotse.cut import CutSet
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.dataset.speech_recognition import validate_for_asr
from lhotse.utils import compute_num_frames, ifnone
from lhotse.workarounds import Hdf5MemoryIssueFix


class K2Speech2TextTranslationDataset(torch.utils.data.Dataset):
"""
The PyTorch Dataset for the speech translation task using k2 library.
This dataset expects to be queried with lists of cut IDs,
for which it loads features and automatically collates/batches them.
To use it with a PyTorch DataLoader, set ``batch_size=None``
and provide a :class:`SimpleCutSampler` sampler.
Each item in this dataset is a dict of:
.. code-block::
{
'inputs': float tensor with shape determined by :attr:`input_strategy`:
- single-channel:
- features: (B, T, F)
- audio: (B, T)
- multi-channel: currently not supported
'supervisions': [
{
'sequence_idx': Tensor[int] of shape (S,)
                'text': List[str] of len S
'tgt_text': List[str] of len S
# For feature input strategies
'start_frame': Tensor[int] of shape (S,)
'num_frames': Tensor[int] of shape (S,)
# For audio input strategies
'start_sample': Tensor[int] of shape (S,)
'num_samples': Tensor[int] of shape (S,)
# Optionally, when return_cuts=True
'cut': List[AnyCut] of len S
}
]
}
Dimension symbols legend:
* ``B`` - batch size (number of Cuts)
    * ``S`` - number of supervision segments (greater than or equal to B, as each Cut may have multiple supervisions)
* ``T`` - number of frames of the longest Cut
* ``F`` - number of features
The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset.
"""

def __init__(
self,
return_cuts: bool = False,
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
input_strategy: BatchIO = PrecomputedFeatures(),
):
"""
        K2Speech2TextTranslationDataset constructor.
:param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
objects used to create that batch.
:param cut_transforms: A list of transforms to be applied on each sampled batch,
before converting cuts to an input representation (audio/features).
Examples: cut concatenation, noise cuts mixing, etc.
:param input_transforms: A list of transforms to be applied on each sampled batch,
after the cuts are converted to audio/features.
Examples: normalization, SpecAugment, etc.
:param input_strategy: Converts cuts into a collated batch of audio/features.
By default, reads pre-computed features from disk.
"""
super().__init__()
# Initialize the fields
self.return_cuts = return_cuts
self.cut_transforms = ifnone(cut_transforms, [])
self.input_transforms = ifnone(input_transforms, [])
self.input_strategy = input_strategy

# This attribute is a workaround to constantly growing HDF5 memory
# throughout the epoch. It regularly closes open file handles to
# reset the internal HDF5 caches.
self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100)

def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
"""
Return a new batch, with the batch size automatically determined using the constraints
of max_frames and max_cuts.
"""
validate_for_asr(cuts)
self.hdf5_fix.update()

# Sort the cuts by duration so that the first one determines the batch time dimensions.
cuts = cuts.sort_by_duration(ascending=False)

# Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
# the supervision boundaries.
for tnfm in self.cut_transforms:
cuts = tnfm(cuts)

# Sort the cuts again after transforms
cuts = cuts.sort_by_duration(ascending=False)

# Get a tensor with batched feature matrices, shape (B, T, F)
# Collation performs auto-padding, if necessary.
input_tpl = self.input_strategy(cuts)
if len(input_tpl) == 3:
# An input strategy with fault tolerant audio reading mode.
# "cuts" may be a subset of the original "cuts" variable,
            # that only has cuts for which we successfully read the audio.
inputs, _, cuts = input_tpl
else:
inputs, _ = input_tpl

# Get a dict of tensors that encode the positional information about supervisions
# in the batch of feature matrices. The tensors are named "sequence_idx",
# "start_frame/sample" and "num_frames/samples".
supervision_intervals = self.input_strategy.supervision_intervals(cuts)

# Apply all available transforms on the inputs, i.e. either audio or features.
# This could be feature extraction, global MVN, SpecAugment, etc.
segments = torch.stack(list(supervision_intervals.values()), dim=1)
for tnfm in self.input_transforms:
inputs = tnfm(inputs, supervision_segments=segments)
batch = {
"inputs": inputs,
"supervisions": default_collate(
[
{
"text": supervision.text,
"tgt_text": supervision.custom["translated_text"],
}
for sequence_idx, cut in enumerate(cuts)
for supervision in cut.supervisions
]
),
}
# Update the 'supervisions' field with sequence_idx and start/num frames/samples
batch["supervisions"].update(supervision_intervals)
if self.return_cuts:
batch["supervisions"]["cut"] = [
cut for cut in cuts for sup in cut.supervisions
]

has_word_alignments = all(
s.alignment is not None and "word" in s.alignment
for c in cuts
for s in c.supervisions
)
if has_word_alignments:
# TODO: might need to refactor BatchIO API to move the following conditional logic
# into these objects (e.g. use like: self.input_strategy.convert_timestamp(),
# that returns either num_frames or num_samples depending on the strategy).
words, starts, ends = [], [], []
frame_shift = cuts[0].frame_shift
sampling_rate = cuts[0].sampling_rate
if frame_shift is None:
try:
frame_shift = self.input_strategy.extractor.frame_shift
except AttributeError:
                    raise ValueError(
                        "Can't determine the frame_shift -- it is not present in either the cuts or the input_strategy."
                    )
for c in cuts:
for s in c.supervisions:
words.append([aliword.symbol for aliword in s.alignment["word"]])
starts.append(
[
compute_num_frames(
aliword.start,
frame_shift=frame_shift,
sampling_rate=sampling_rate,
)
for aliword in s.alignment["word"]
]
)
ends.append(
[
compute_num_frames(
aliword.end,
frame_shift=frame_shift,
sampling_rate=sampling_rate,
)
for aliword in s.alignment["word"]
]
)
batch["supervisions"]["word"] = words
batch["supervisions"]["word_start"] = starts
batch["supervisions"]["word_end"] = ends

return batch
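
As a usage note, a minimal sketch of wiring this dataset into a PyTorch DataLoader the way the docstring suggests (``batch_size=None`` plus a SimpleCutSampler); the manifest path is an assumption.

from torch.utils.data import DataLoader

from lhotse import CutSet
from lhotse.dataset import SimpleCutSampler
from lhotse.dataset.speech_translation import K2Speech2TextTranslationDataset

# Hypothetical manifest; each supervision must carry the translation in
# supervision.custom["translated_text"], which __getitem__ reads above.
cuts = CutSet.from_file("data/manifests/cuts_train.jsonl.gz")

dataset = K2Speech2TextTranslationDataset(return_cuts=True)
# The sampler yields whole CutSets; batch_size=None lets the dataset collate them itself.
sampler = SimpleCutSampler(cuts, max_duration=200.0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=2)

for batch in loader:
    feats = batch["inputs"]                         # (B, T, F) with PrecomputedFeatures
    src_text = batch["supervisions"]["text"]        # source-language transcripts
    tgt_text = batch["supervisions"]["tgt_text"]    # target-language translations
    break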
1 change: 1 addition & 0 deletions lhotse/recipes/__init__.py
@@ -37,6 +37,7 @@
from .hifitts import download_hifitts, prepare_hifitts
from .himia import download_himia, prepare_himia
from .icsi import download_icsi, prepare_icsi
from .iwslt22_ta import prepare_iwslt22_ta
from .kespeech import prepare_kespeech
from .l2_arctic import prepare_l2_arctic
from .libricss import download_libricss, prepare_libricss
362 changes: 362 additions & 0 deletions lhotse/recipes/iwslt22_ta.py
(diff not rendered)
