diff --git a/examples/mobvoihotwords/local/data_prep.py b/examples/mobvoihotwords/local/data_prep.py index 83007ebefd..821c228a8e 100644 --- a/examples/mobvoihotwords/local/data_prep.py +++ b/examples/mobvoihotwords/local/data_prep.py @@ -8,11 +8,24 @@ import logging import os import sys +from typing import List from concurrent.futures import ProcessPoolExecutor from pathlib import Path import numpy as np +from fairseq.data.data_utils import numpy_seed + +try: + # TODO use pip install once it's available + from espresso.tools.lhotse import ( + CutSet, Mfcc, MfccConfig, LilcomFilesWriter, SupervisionSet, WavAugmenter + ) + from espresso.tools.lhotse.manipulation import combine + from espresso.tools.lhotse.recipes.mobvoihotwords import download_and_untar, prepare_mobvoihotwords +except ImportError: + raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools") + logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", @@ -31,7 +44,15 @@ def get_parser(): parser.add_argument("--data-dir", default="data", type=str, help="data directory") parser.add_argument("--seed", default=1, type=int, help="random seed") parser.add_argument( - "--nj", default=1, type=int, help="number of jobs for features extraction" + "--num-jobs", default=1, type=int, help="number of jobs for features extraction" + ) + parser.add_argument( + "--max-remaining-duration", default=0.3, type=float, + help="not split if the left-over duration is less than this many seconds" + ) + parser.add_argument( + "--overlap-duration", default=0.3, type=float, + help="overlap between adjacent segments while splitting negative recordings" ) # fmt: on @@ -39,14 +60,6 @@ def get_parser(): def main(args): - try: - # TODO use pip install once it's available - from espresso.tools.lhotse import CutSet, Mfcc, MfccConfig, LilcomFilesWriter, WavAugmenter - from espresso.tools.lhotse.manipulation import combine - from espresso.tools.lhotse.recipes.mobvoihotwords import 
download_and_untar, prepare_mobvoihotwords - except ImportError: - raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools") - root_dir = Path(args.data_dir) corpus_dir = root_dir / "MobvoiHotwords" output_dir = root_dir @@ -68,36 +81,46 @@ def main(args): np.random.seed(args.seed) # equivalent to Kaldi's mfcc_hires config mfcc = Mfcc(config=MfccConfig(num_mel_bins=40, num_ceps=40, low_freq=20, high_freq=-400)) - num_jobs = args.nj for partition, manifests in mobvoihotwords_manifests.items(): cut_set = CutSet.from_manifests( recordings=manifests["recordings"], supervisions=manifests["supervisions"], ) sampling_rate = next(iter(cut_set)).sampling_rate - with ProcessPoolExecutor(num_jobs) as ex: + with ProcessPoolExecutor(args.num_jobs) as ex: if "train" in partition: - # original set - with LilcomFilesWriter(f"{output_dir}/feats_{partition}_orig") as storage: - cut_set_orig = cut_set.compute_and_store_features( + # split negative recordings into smaller chunks with lengths sampled from + # length distribution of positive recordings + pos_durs = get_positive_durations(manifests["supervisions"]) + with numpy_seed(args.seed): + cut_set = keep_positives_and_split_negatives( + cut_set, + pos_durs, + max_remaining_duration=args.max_remaining_duration, + overlap_duration=args.overlap_duration, + ) + # "clean" set + with LilcomFilesWriter(f"{output_dir}/feats_{partition}_clean") as storage: + cut_set_clean = cut_set.compute_and_store_features( extractor=mfcc, storage=storage, augmenter=None, executor=ex, ) - # augmented with reverbration - with LilcomFilesWriter(f"{output_dir}/feats_{partition}_rev") as storage: - cut_set_rev = cut_set.compute_and_store_features( - extractor=mfcc, - storage=storage, - augmenter=WavAugmenter(effect_chain=reverb()), - excutor=ex, - ) + # augmented with reverberation + with LilcomFilesWriter(f"{output_dir}/feats_{partition}_rev") as storage: + with numpy_seed(args.seed): + cut_set_rev = 
cut_set.compute_and_store_features( +                            extractor=mfcc, +                            storage=storage, +                            augmenter=WavAugmenter(effect_chain=reverb()), +                            executor=ex, +                        ) cut_set_rev = CutSet.from_cuts( cut.with_id("rev-" + cut.id) for cut in cut_set_rev.cuts ) # augmented with speed perturbation -                with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp1.1") as storage: +                with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp1.1") as storage: cut_set_sp1p1 = cut_set.compute_and_store_features( extractor=mfcc, storage=storage, @@ -109,7 +132,7 @@ def main(args): cut_set_sp1p1 = CutSet.from_cuts( cut.with_id("sp1.1-" + cut.id) for cut in cut_set_sp1p1.cuts ) -                with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp0.9") as storage: +                with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp0.9") as storage: cut_set_sp0p9 = cut_set.compute_and_store_features( extractor=mfcc, storage=storage, @@ -121,9 +144,9 @@ def main(args): cut_set_sp0p9 = CutSet.from_cuts( cut.with_id("sp0.9-" + cut.id) for cut in cut_set_sp0p9.cuts ) -                # combine the original and augmented sets together +                # combine the clean and augmented sets together cut_set = combine( -                    cut_set_orig, cut_set_rev, cut_set_sp1p1, cut_set_sp0p9 +                    cut_set_clean, cut_set_rev, cut_set_sp1p1, cut_set_sp0p9 ) else:  # no augmentations for dev and test sets with LilcomFilesWriter(f"{output_dir}/feats_{partition}") as storage: @@ -137,6 +160,80 @@ def main(args): cut_set.to_json(output_dir / f"cuts_{partition}.json.gz") +def get_positive_durations(sup_set: SupervisionSet) -> List[float]: +    """ +    Get duration values of all positive recordings, assuming Supervision.text is +    "FREETEXT" for all negative recordings, and SupervisionSegment.duration +    equals the corresponding Recording.duration. 
+    """ +    return [sup.duration for sup in sup_set.filter(lambda seg: seg.text != "FREETEXT")] + + +def keep_positives_and_split_negatives( +    cut_set: CutSet, +    durations: List[float], +    max_remaining_duration: float = 0.3, +    overlap_duration: float = 0.3, +) -> CutSet: +    """ +    Returns a new CutSet where all the positives are directly taken from the original +    input cut set, and the negatives are obtained by splitting original negatives +    into shorter chunks of random lengths drawn from the given length distribution +    (here it is the empirical distribution of the positive recordings). There can +    be overlap between chunks. + +    Args: +        cut_set (CutSet): original input cut set +        durations (list[float]): list of durations to sample from +        max_remaining_duration (float, optional): not split if the left-over +            duration is less than this many seconds (default: 0.3). +        overlap_duration (float, optional): overlap between adjacent segments +            (default: 0.3) + +    Returns: +        CutSet: a new cut set after split +    """ +    assert max_remaining_duration >= 0.0 and overlap_duration >= 0.0 +    new_cuts = [] +    for cut in cut_set: +        assert len(cut.supervisions) == 1 +        if cut.supervisions[0].text != "FREETEXT":  # keep the positive as it is +            new_cuts.append(cut) +        else: +            this_offset = cut.start +            this_offset_relative = this_offset - cut.start +            remaining_duration = cut.duration +            this_dur = durations[np.random.randint(len(durations))] +            while remaining_duration > this_dur + max_remaining_duration: +                new_cut = cut.truncate( +                    offset=this_offset_relative, duration=this_dur, preserve_id=True +                ) +                new_cut = new_cut.with_id( +                    "{id}-{s:07d}-{e:07d}".format( +                        id=new_cut.id, +                        s=int(round(100 * this_offset_relative)), +                        e=int(round(100 * (this_offset_relative + this_dur))) +                    ) +                ) +                new_cuts.append(new_cut) +                this_offset += this_dur - overlap_duration +                this_offset_relative = this_offset - cut.start +                remaining_duration -= this_dur - overlap_duration +                this_dur = 
durations[np.random.randint(len(durations))] + + new_cut = cut.truncate(offset=this_offset_relative, preserve_id=True) + new_cut = new_cut.with_id( + "{id}-{s:07d}-{e:07d}".format( + id=new_cut.id, + s=int(round(100 * this_offset_relative)), + e=int(round(100 * cut.duration)) + ) + ) + new_cuts.append(new_cut) + + return CutSet.from_cuts(new_cuts) + + def reverb(*args, **kwargs): """ Returns a reverb effect for wav augmentation.