diff --git a/README_fairseq.md b/README_fairseq.md
index 56ec16cdab..70e98fe395 100644
--- a/README_fairseq.md
+++ b/README_fairseq.md
@@ -112,7 +112,7 @@ and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more example
 
 # Requirements and Installation
 
-* [PyTorch](http://pytorch.org/) version >= 1.4.0
+* [PyTorch](http://pytorch.org/) version >= 1.5.0
 * Python version >= 3.6
 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
 * **To install fairseq** and develop locally:
diff --git a/espresso/criterions/label_smoothed_cross_entropy_v2.py b/espresso/criterions/label_smoothed_cross_entropy_v2.py
index cb29b46db0..96285fe2b1 100644
--- a/espresso/criterions/label_smoothed_cross_entropy_v2.py
+++ b/espresso/criterions/label_smoothed_cross_entropy_v2.py
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass, field
-from omegaconf import II
 
 import logging
 import numpy as np
@@ -16,6 +15,7 @@ from fairseq.data import data_utils
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
 from fairseq.dataclass.utils import gen_parser_from_dataclass
+from omegaconf import II
 
 
 logger = logging.getLogger(__name__)
@@ -26,7 +26,7 @@
 
 @dataclass
 class LabelSmoothedCrossEntropyV2CriterionConfig(FairseqDataclass):
-    sentence_avg: bool = II("params.optimization.sentence_avg")
+    sentence_avg: bool = II("optimization.sentence_avg")
     label_smoothing: float = field(
         default=0.0,
         metadata={
@@ -85,7 +85,7 @@ def temporal_label_smoothing_prob_mask(
     prob_mask[:, :, padding_index] = 0  # clear cumulative count on <pad>
    prob_mask = prob_mask.float()  # convert to float
     sum_prob = prob_mask.sum(-1, keepdim=True)
-    sum_prob[sum_prob.squeeze(-1).eq(0.)] = 1.  # to deal with the "division by 0" problem
+    sum_prob[sum_prob.squeeze(-1).eq(0.0)] = 1.0  # to deal with the "division by 0" problem
     prob_mask = prob_mask.div_(sum_prob).view(-1, prob_mask.size(-1))
     return prob_mask
@@ -109,8 +109,8 @@ def label_smoothed_nll_loss(
         raise ValueError("Unsupported smoothing type: {}".format(smoothing_type))
     if ignore_index is not None:
         pad_mask = target.eq(ignore_index)
-        nll_loss.masked_fill_(pad_mask, 0.)
-        smooth_loss.masked_fill_(pad_mask, 0.)
+        nll_loss.masked_fill_(pad_mask, 0.0)
+        smooth_loss.masked_fill_(pad_mask, 0.0)
     else:
         nll_loss = nll_loss.squeeze(-1)
         smooth_loss = smooth_loss.squeeze(-1)
@@ -118,7 +118,7 @@ def label_smoothed_nll_loss(
         nll_loss = nll_loss.sum()
         smooth_loss = smooth_loss.sum()
     eps_i = epsilon / lprobs.size(-1) if smoothing_type == "uniform" else epsilon
-    loss = (1. - epsilon) * nll_loss + eps_i * smooth_loss
+    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
     return loss, nll_loss
@@ -126,9 +126,15 @@ def label_smoothed_nll_loss(
 class LabelSmoothedCrossEntropyV2Criterion(LabelSmoothedCrossEntropyCriterion):
 
     def __init__(
-        self, task, sentence_avg, label_smoothing, smoothing_type,
-        print_training_sample_interval, unigram_pseudo_count,
-        ignore_prefix_size=0, report_accuracy=False,
+        self,
+        task,
+        sentence_avg,
+        label_smoothing,
+        smoothing_type,
+        print_training_sample_interval,
+        unigram_pseudo_count,
+        ignore_prefix_size=0,
+        report_accuracy=False,
     ):
         super().__init__(
             task, sentence_avg, label_smoothing,
@@ -149,7 +155,7 @@ def __init__(
     @classmethod
     def add_args(cls, parser):
         """Add criterion-specific arguments to the parser."""
-        dc = getattr(cls, '__dataclass', None)
+        dc = getattr(cls, "__dataclass", None)
         if dc is not None:
             gen_parser_from_dataclass(parser, dc())
@@ -212,8 +218,13 @@ def compute_loss(
             padding_index=self.padding_idx,
         ) if smoothing_type == "temporal" else None
         loss, nll_loss = label_smoothed_nll_loss(
-            lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce,
-            smoothing_type=smoothing_type, prob_mask=prob_mask,
+            lprobs,
+            target,
+            self.eps,
+            ignore_index=self.padding_idx,
+            reduce=reduce,
+            smoothing_type=smoothing_type,
+            prob_mask=prob_mask,
             unigram_tensor=self.unigram_tensor,
         )
         return loss, nll_loss, lprobs
diff --git a/espresso/criterions/lf_mmi_loss.py b/espresso/criterions/lf_mmi_loss.py
index 997b1a4228..dd00e91888 100644
--- a/espresso/criterions/lf_mmi_loss.py
+++ b/espresso/criterions/lf_mmi_loss.py
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass, field
-from omegaconf import II
 
 import logging
 import math
@@ -14,6 +13,7 @@ from fairseq.criterions import FairseqCriterion, register_criterion
 from fairseq.dataclass import FairseqDataclass
 from fairseq.logging import metrics
+from omegaconf import II
 
 
 logger = logging.getLogger(__name__)
@@ -21,9 +21,9 @@
 
 @dataclass
 class LatticeFreeMMICriterionConfig(FairseqDataclass):
-    sentence_avg: bool = II("params.optimization.sentence_avg")
+    sentence_avg: bool = II("optimization.sentence_avg")
     denominator_fst_path: str = field(
-        default=None, metadata={"help": "path to the denominator fst file"}
+        default="???", metadata={"help": "path to the denominator fst file"}
     )
     leaky_hmm_coefficient: float = field(
         default=1.0e-05,
@@ -215,10 +215,11 @@ def compute_loss(self, net_output, sample, reduce=True):
     def reduce_metrics(cls, logging_outputs) -> None:
         """Aggregate logging outputs from data parallel training."""
         loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
-        nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
+        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
         ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
         sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
 
+        # we divide by log(2) to convert the loss from base e to base 2
         metrics.log_scalar(
             "loss", loss_sum / sample_size / math.log(2), sample_size, round=7
         )
diff --git a/espresso/data/asr_chain_dataset.py b/espresso/data/asr_chain_dataset.py
index 2ed83741fd..5b88580fc9 100644
--- a/espresso/data/asr_chain_dataset.py
+++ b/espresso/data/asr_chain_dataset.py
@@ -12,7 +12,7 @@
 
 import torch
 
-from fairseq.data import data_utils, FairseqDataset
+from fairseq.data import FairseqDataset, data_utils
 
 import espresso.tools.utils as speech_utils
@@ -48,12 +48,15 @@ def merge(key, pad_to_length=None):
         raise ValueError("Invalid key.")
 
     id = torch.LongTensor([s["id"] for s in samples])
-    src_frames = merge("source", pad_to_length=pad_to_length["source"] if pad_to_length is not None else None)
+    src_frames = merge(
+        "source",
+        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
+    )
     # sort by descending source length
     if pad_to_length is not None or src_bucketed:
-        src_lengths = torch.IntTensor([
-            s["source"].ne(0.0).any(dim=1).int().sum() for s in samples
-        ])
+        src_lengths = torch.IntTensor(
+            [s["source"].ne(0.0).any(dim=1).int().sum() for s in samples]
+        )
     else:
         src_lengths = torch.IntTensor([s["source"].size(0) for s in samples])
     src_lengths, sort_order = src_lengths.sort(descending=True)
@@ -134,8 +137,7 @@ def filter_and_reorder(self, indices):
         assert isinstance(indices, (list, np.ndarray))
         indices = np.array(indices)
         assert all(indices < len(self.utt_ids)) and all(indices >= 0)
-        assert len(np.unique(indices)) == len(indices), \
-            "Duplicate elements in indices."
+        assert len(np.unique(indices)) == len(indices), "Duplicate elements in indices."
         self.utt_ids = [self.utt_ids[i] for i in indices]
         self.rxfiles = [self.rxfiles[i] for i in indices]
         self.numerator_graphs = [self.numerator_graphs[i] for i in indices]
@@ -172,8 +174,15 @@ class AsrChainDataset(FairseqDataset):
     """
 
     def __init__(
-        self, src, src_sizes, tgt=None, tgt_sizes=None, text=None, shuffle=True,
-        num_buckets=0, pad_to_multiple=1,
+        self,
+        src,
+        src_sizes,
+        tgt=None,
+        tgt_sizes=None,
+        text=None,
+        shuffle=True,
+        num_buckets=0,
+        pad_to_multiple=1,
     ):
         self.src = src
         self.tgt = tgt
@@ -196,10 +205,15 @@ def __init__(
                 "Removed {} examples due to empty numerator graphs or missing entries, "
                 "{} remaining".format(num_removed, num_after_matching)
             )
-        self.sizes = np.vstack((self.src_sizes, self.tgt_sizes)).T if self.tgt_sizes is not None else self.src_sizes
+        self.sizes = (
+            np.vstack((self.src_sizes, self.tgt_sizes)).T
+            if self.tgt_sizes is not None
+            else self.src_sizes
+        )
         if num_buckets > 0:
             from espresso.data import FeatBucketPadLengthDataset
+
             self.src = FeatBucketPadLengthDataset(
                 self.src,
                 sizes=self.src_sizes,
@@ -215,8 +229,7 @@ def __init__(
             num_tokens = np.vectorize(self.num_tokens, otypes=[np.long])
             self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
             self.buckets = [
-                (None, num_tokens)
-                for num_tokens in np.unique(self.bucketed_num_tokens)
+                (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
             ]
         else:
             self.buckets = None
@@ -293,7 +306,7 @@ def collater(self, samples, pad_to_length=None):
         Args:
             samples (List[dict]): samples to collate
             pad_to_length (dict, optional): a dictionary of
-                {'source': source_pad_to_length}
+                {"source": source_pad_to_length}
                 to indicate the max length to pad to in source and target respectively.
 
         Returns:
@@ -327,7 +340,10 @@ def num_tokens(self, index):
     def size(self, index):
         """Return an example's size as a float or tuple. This value is used when
         filtering a dataset with ``--max-positions``."""
-        return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
+        return (
+            self.src_sizes[index],
+            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
+        )
 
     def ordered_indices(self):
         """Return an ordered list of indices. Batches will be constructed based
@@ -339,9 +355,7 @@ def ordered_indices(self):
         if self.buckets is None:
             # sort by target length, then source length
             if self.tgt_sizes is not None:
-                indices = indices[
-                    np.argsort(self.tgt_sizes[indices], kind="mergesort")
-                ]
+                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
             return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
         else:
             # sort by bucketed_num_tokens, which is padded_src_len
@@ -358,7 +372,7 @@ def prefetch(self, indices):
         self.src.prefetch(indices)
 
     def filter_indices_by_size(self, indices, max_sizes):
-        """ Filter a list of sample indices. Remove those that are longer
+        """Filter a list of sample indices. Remove those that are longer
         than specified in max_sizes.
 
         Args:
diff --git a/espresso/data/asr_dataset.py b/espresso/data/asr_dataset.py
index b0dc9259ac..0246ccb883 100644
--- a/espresso/data/asr_dataset.py
+++ b/espresso/data/asr_dataset.py
@@ -7,7 +7,7 @@
 import numpy as np
 import torch
 
-from fairseq.data import data_utils, FairseqDataset
+from fairseq.data import FairseqDataset, data_utils
 
 import espresso.tools.utils as speech_utils
@@ -48,14 +48,15 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
     id = torch.LongTensor([s["id"] for s in samples])
     src_frames = merge(
-        "source", left_pad=left_pad_source,
+        "source",
+        left_pad=left_pad_source,
         pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
     )
     # sort by descending source length
     if pad_to_length is not None or src_bucketed:
-        src_lengths = torch.IntTensor([
-            s["source"].ne(0.0).any(dim=1).int().sum() for s in samples
-        ])
+        src_lengths = torch.IntTensor(
+            [s["source"].ne(0.0).any(dim=1).int().sum() for s in samples]
+        )
     else:
         src_lengths = torch.IntTensor([s["source"].size(0) for s in samples])
     src_lengths, sort_order = src_lengths.sort(descending=True)
@@ -68,7 +69,9 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
     if samples[0].get("target", None) is not None:
         target = merge(
             "target", left_pad=left_pad_target,
-            pad_to_length=pad_to_length["target"] if pad_to_length is not None else None,
+            pad_to_length=pad_to_length["target"]
+            if pad_to_length is not None
+            else None,
         )
         target = target.index_select(0, sort_order)
         ntokens = sum(s["target"].ne(pad_idx).int().sum().item() for s in samples)
@@ -82,7 +85,9 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
             "target",
             left_pad=left_pad_target,
             move_eos_to_beginning=True,
-            pad_to_length=pad_to_length["target"] if pad_to_length is not None else None,
+            pad_to_length=pad_to_length["target"]
+            if pad_to_length is not None
+            else None,
         )
     else:
         ntokens = src_lengths.sum().item()
@@ -104,7 +109,9 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
         "target_raw_text": target_raw_text,
     }
     if prev_output_tokens is not None:
-        batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(0, sort_order)
+        batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
+            0, sort_order
+        )
 
     if samples[0].get("constraints", None) is not None:
         # Collate the packed constraints across the samples, padding to
@@ -112,7 +119,7 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
         lens = [sample.get("constraints").size(0) for sample in samples]
         constraints = torch.zeros((len(samples), max(lens))).long()
         for i, sample in enumerate(samples):
-            constraints[i, 0:lens[i]] = samples[i].get("constraints")
+            constraints[i, 0: lens[i]] = samples[i].get("constraints")
         batch["constraints"] = constraints
 
     return batch
@@ -141,19 +148,25 @@ class AsrDataset(FairseqDataset):
         num_buckets (int, optional): if set to a value greater than 0, then
             batches will be bucketed into the given number of batch shapes.
         src_lang_id (int, optional): source language ID, if set, the collated batch
-            will contain a field 'src_lang_id' in 'net_input' which indicates the
+            will contain a field "src_lang_id" in "net_input" which indicates the
            source language of the samples.
         tgt_lang_id (int, optional): target language ID, if set, the collated batch
-            will contain a field 'tgt_lang_id' which indicates the target language
+            will contain a field "tgt_lang_id" which indicates the target language
             of the samples.
         pad_to_multiple (int, optional): pad src/tgt lengths to a multiple of this value
     """

     def __init__(
-        self, src, src_sizes,
-        tgt=None, tgt_sizes=None, dictionary=None,
-        left_pad_source=False, left_pad_target=False,
-        shuffle=True, input_feeding=True,
+        self,
+        src,
+        src_sizes,
+        tgt=None,
+        tgt_sizes=None,
+        dictionary=None,
+        left_pad_source=False,
+        left_pad_target=False,
+        shuffle=True,
+        input_feeding=True,
         constraints=None,
         num_buckets=0,
         src_lang_id=None,
@@ -175,10 +188,15 @@ def __init__(
         self.tgt_lang_id = tgt_lang_id
         if self.tgt is not None:
             self._match_src_tgt()
-        self.sizes = np.vstack((self.src_sizes, self.tgt_sizes)).T if self.tgt_sizes is not None else self.src_sizes
+        self.sizes = (
+            np.vstack((self.src_sizes, self.tgt_sizes)).T
+            if self.tgt_sizes is not None
+            else self.src_sizes
+        )
         if num_buckets > 0:
             from espresso.data import FeatBucketPadLengthDataset, TextBucketPadLengthDataset
+
             self.src = FeatBucketPadLengthDataset(
                 self.src,
                 sizes=self.src_sizes,
@@ -204,8 +222,7 @@ def __init__(
             num_tokens = np.vectorize(self.num_tokens, otypes=[np.long])
             self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
             self.buckets = [
-                (None, num_tokens)
-                for num_tokens in np.unique(self.bucketed_num_tokens)
+                (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
             ]
         else:
             self.buckets = None
@@ -261,7 +278,7 @@ def collater(self, samples, pad_to_length=None):
         Args:
             samples (List[dict]): samples to collate
             pad_to_length (dict, optional): a dictionary of
-                {'source': source_pad_to_length, 'target': target_pad_to_length}
+                {"source": source_pad_to_length, "target": target_pad_to_length}
                 to indicate the max length to pad to in source and target respectively.
 
         Returns:
@@ -309,13 +326,13 @@ def collater(self, samples, pad_to_length=None):
             src_tokens = res["net_input"]["src_tokens"]
             bsz = src_tokens.size(0)
             if self.src_lang_id is not None:
-                res["net_input"]["src_lang_id"] = torch.LongTensor(
-                    [[self.src_lang_id]]
-                ).expand(bsz, 1).to(src_tokens)
+                res["net_input"]["src_lang_id"] = (
+                    torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
+                )
             if self.tgt_lang_id is not None:
-                res["tgt_lang_id"] = torch.LongTensor(
-                    [[self.tgt_lang_id]]
-                ).expand(bsz, 1).to(src_tokens)
+                res["tgt_lang_id"] = (
+                    torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
+                )
         return res
 
     def num_tokens(self, index):
@@ -326,7 +343,10 @@ def num_tokens(self, index):
     def size(self, index):
         """Return an example's size as a float or tuple. This value is used when
         filtering a dataset with ``--max-positions``."""
-        return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
+        return (
+            self.src_sizes[index],
+            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
+        )
 
     def ordered_indices(self):
         """Return an ordered list of indices. Batches will be constructed based
@@ -338,9 +358,7 @@ def ordered_indices(self):
         if self.buckets is None:
             # sort by target length, then source length
             if self.tgt_sizes is not None:
-                indices = indices[
-                    np.argsort(self.tgt_sizes[indices], kind="mergesort")
-                ]
+                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
             return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
         else:
             # sort by bucketed_num_tokens, which is padded_src_len
@@ -357,7 +375,7 @@ def prefetch(self, indices):
         self.src.prefetch(indices)
 
     def filter_indices_by_size(self, indices, max_sizes):
-        """ Filter a list of sample indices. Remove those that are longer
+        """Filter a list of sample indices. Remove those that are longer
         than specified in max_sizes.
 
         Args:
diff --git a/espresso/data/asr_dictionary.py b/espresso/data/asr_dictionary.py
index 3c0a079125..4cdd44dac3 100644
--- a/espresso/data/asr_dictionary.py
+++ b/espresso/data/asr_dictionary.py
@@ -3,10 +3,13 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import torch
+from argparse import Namespace
+from typing import Union
 
+import torch
 from fairseq.data import Dictionary, encoders
 from fairseq.file_io import PathManager
+from omegaconf import DictConfig
 
 # will automatically load modules defined from there
 from espresso.data import encoders as encoders_espresso
@@ -24,8 +27,9 @@ def __init__(
         space="<space>",
         extra_special_symbols=None,
     ):
-        self.bos_word, self.unk_word, self.pad_word, self.eos_word, self.space_word = \
-            bos, unk, pad, eos, space
+        self.bos_word, self.unk_word, self.pad_word, self.eos_word, self.space_word = (
+            bos, unk, pad, eos, space
+        )
         self.symbols = []
         self.count = []
         self.indices = {}
@@ -78,12 +82,13 @@ def load(cls, f, f_non_lang_syms=None):
             except UnicodeError:
                 raise Exception(
                     "Incorrect encoding detected in {}, please "
-                    "rebuild the dataset".format(f)
+                    "rebuild the dataset".format(fd)
                 )
 
             for sym in non_lang_syms:
-                assert d.index(sym) != d.unk(), \
-                    "{} in {} is not in the dictionary".format(sym, f_non_lang_syms)
+                assert (
+                    d.index(sym) != d.unk()
+                ), "{} in {} is not in the dictionary".format(sym, f_non_lang_syms)
             d.non_lang_syms = non_lang_syms
             return d
@@ -94,17 +99,19 @@ def dummy_sentence(self, length):
         t[-1] = self.eos()
         return t
 
-    def build_tokenizer(self, args):
-        self.tokenizer = encoders.build_tokenizer(args)
+    def build_tokenizer(self, cfg: Union[DictConfig, Namespace]):
+        self.tokenizer = encoders.build_tokenizer(cfg)
 
-    def build_bpe(self, args):
-        if args.bpe == "characters_asr":
+    def build_bpe(self, cfg: Union[DictConfig, Namespace]):
+        if (
+            (isinstance(cfg, DictConfig) and cfg._name == "characters_asr")
+            or (isinstance(cfg, Namespace) and getattr(cfg, "bpe", None) == "characters_asr")
+        ):
             self.bpe = encoders.build_bpe(
-                args, space_symbol=self.space_word, ends_with_space=True,
-                non_lang_syms=self.non_lang_syms,
+                cfg, space_symbol=self.space_word, non_lang_syms=self.non_lang_syms
             )
         else:
-            self.bpe = encoders.build_bpe(args)
+            self.bpe = encoders.build_bpe(cfg)
 
     def wordpiece_encode(self, x):
         if self.tokenizer is not None:
diff --git a/espresso/data/asr_xent_dataset.py b/espresso/data/asr_xent_dataset.py
index e831345f9f..47ed464f39 100644
--- a/espresso/data/asr_xent_dataset.py
+++ b/espresso/data/asr_xent_dataset.py
@@ -12,7 +12,7 @@
 import torch
 import torch.nn.functional as F
 
-from fairseq.data import data_utils, FairseqDataset
+from fairseq.data import FairseqDataset, data_utils
 
 import espresso.tools.utils as speech_utils
@@ -111,18 +111,26 @@ def chunking(src_item, tgt_item, tgt_start):
             s["source"] = src_item[: label_delay]
 
     if pad_to_length is not None or src_bucketed:
-        src_lengths = torch.IntTensor([
-            s["source"].ne(0.0).any(dim=1).int().sum() for s in samples
-        ])
+        src_lengths = torch.IntTensor(
+            [s["source"].ne(0.0).any(dim=1).int().sum() for s in samples]
+        )
     else:
         src_lengths = torch.IntTensor([s["source"].size(0) for s in samples])
 
     id = torch.LongTensor([s["id"] for s in samples])
     utt_id = [s["utt_id"] for s in samples]
-    src_frames = merge("source", pad_to_length=pad_to_length["source"] if pad_to_length is not None else None)
+    src_frames = merge(
+        "source",
+        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
+    )
     target = None
     if samples[0].get("target", None) is not None:
-        target = merge("target", pad_to_length=pad_to_length["target"] if pad_to_length is not None else None)
+        target = merge(
+            "target",
+            pad_to_length=pad_to_length["target"]
+            if pad_to_length is not None
+            else None,
+        )
         ntokens = sum(s["target"].ne(pad_idx).int().sum().item() for s in samples)
     else:
         ntokens = src_lengths.sum().item()
@@ -181,14 +189,25 @@ def chunking(src_item, tgt_item, tgt_start):
                 s["source"] = ori_source[i].new_zeros(
                     chunk_width + chunk_left_context + chunk_right_context, ori_source[i].size(1)
                 )
-                s["target"] = ori_target[i].new_full((chunk_width,), pad_idx) \
-                    if ori_target[i] is not None else None
-    src_frames = merge("source", pad_to_length=pad_to_length["source"] if pad_to_length is not None else None)
+                s["target"] = (
+                    ori_target[i].new_full((chunk_width,), pad_idx)
+                    if ori_target[i] is not None
+                    else None
+                )
+    src_frames = merge(
+        "source",
+        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
+    )
     src_chunk_lengths = torch.IntTensor([s["source"].size(0) for s in samples])
     target = None
     if samples[0].get("target", None) is not None:
-        target = merge("target", pad_to_length=pad_to_length["target"] if pad_to_length is not None else None)
+        target = merge(
+            "target",
+            pad_to_length=pad_to_length["target"]
+            if pad_to_length is not None
+            else None,
+        )
         ntokens = sum(s["target"].ne(pad_idx).int().sum().item() for s in samples)
     else:
         ntokens = src_lengths.sum().item()
@@ -218,8 +237,12 @@ class AliScpCachedDataset(torch.utils.data.Dataset):
     """
 
     def __init__(
-        self, utt_ids: List[str], rxfiles: List[str], utt2num_frames: Optional[List[int]] = None,
-        ordered_prefetch=False, cache_size=327680,
+        self,
+        utt_ids: List[str],
+        rxfiles: List[str],
+        utt2num_frames: Optional[List[int]] = None,
+        ordered_prefetch=False,
+        cache_size=327680,
     ):
         super().__init__()
         assert len(utt_ids) == len(rxfiles)
@@ -277,8 +300,7 @@ def filter_and_reorder(self, indices):
         assert isinstance(indices, (list, np.ndarray))
         indices = np.array(indices)
         assert all(indices < len(self.utt_ids)) and all(indices >= 0)
-        assert len(np.unique(indices)) == len(indices), \
-            "Duplicate elements in indices."
+        assert len(np.unique(indices)) == len(indices), "Duplicate elements in indices."
         self.utt_ids = [self.utt_ids[i] for i in indices]
         self.rxfiles = [self.rxfiles[i] for i in indices]
         self.sizes = self.sizes[indices]
@@ -288,9 +310,9 @@ def filter_and_reorder(self, indices):
     def __getitem__(self, i):
         self.check_index(i)
         if i not in self.cache_index:
-            assert self.start_pos_for_next_cache < \
-                len(self.ordered_indices), \
-                "Position for next cache starting beyond the end of ordered_indices."
+            assert (
+                self.start_pos_for_next_cache < len(self.ordered_indices)
+            ), "Position for next cache starting beyond the end of ordered_indices."
             try:
                 pos_start = self.ordered_indices.index(
                     i, self.start_pos_for_next_cache,
@@ -304,8 +326,7 @@ def __getitem__(self, i):
             pos_end = min(
                 pos_start + self.cache_size, len(self.ordered_indices),
             )
-            self.start_pos_for_next_cache = pos_end \
-                if self.ordered_prefetch else 0
+            self.start_pos_for_next_cache = pos_end if self.ordered_prefetch else 0
             total_size = 0
             for idx in self.ordered_indices[pos_start: pos_end]:
                 total_size += self.sizes[idx]
@@ -358,9 +379,21 @@ class AsrXentDataset(FairseqDataset):
     """
 
    def __init__(
-        self, src, src_sizes, tgt: Optional[AliScpCachedDataset] = None, tgt_sizes=None, text=None,
-        shuffle=True, num_buckets=0, pad_to_multiple=1, seed=1, chunk_width=None,
-        chunk_left_context=None, chunk_right_context=None, label_delay=0, random_chunking=True,
+        self,
+        src,
+        src_sizes,
+        tgt: Optional[AliScpCachedDataset] = None,
+        tgt_sizes=None,
+        text=None,
+        shuffle=True,
+        num_buckets=0,
+        pad_to_multiple=1,
+        seed=1,
+        chunk_width=None,
+        chunk_left_context=None,
+        chunk_right_context=None,
+        label_delay=0,
+        random_chunking=True,
     ):
         self.src = src
         self.tgt = tgt
@@ -375,8 +408,10 @@ def __init__(
         assert chunk_left_context >= 0 and chunk_right_context >= 0
         self.chunk_left_context = chunk_left_context
         self.chunk_right_context = chunk_right_context
-        assert (label_delay < 0 and -label_delay <= chunk_right_context) or \
-            (label_delay >= 0 and (chunk_width is None or label_delay < chunk_width))
+        assert (
+            (label_delay < 0 and -label_delay <= chunk_right_context)
+            or (label_delay >= 0 and (chunk_width is None or label_delay < chunk_width))
+        )
         self.label_delay = label_delay
         self.random_chunking = random_chunking
         if self.tgt is not None:
@@ -385,7 +420,11 @@ def __init__(
             changed = self._match_src_text()
             if self.tgt is not None and changed:
                 self._match_src_tgt()
-        self.sizes = np.vstack((self.src_sizes, self.tgt_sizes)).T if self.tgt_sizes is not None else self.src_sizes
+        self.sizes = (
+            np.vstack((self.src_sizes, self.tgt_sizes)).T
+            if self.tgt_sizes is not None
+            else self.src_sizes
+        )
 
         if chunk_width is not None:
             # remove those whose lengths are shorter than chunk_size
@@ -406,6 +445,7 @@ def __init__(
         if num_buckets > 0:
             from fairseq.data import BucketPadLengthDataset
             from espresso.data import FeatBucketPadLengthDataset
+
             self.src = FeatBucketPadLengthDataset(
                 self.src,
                 sizes=self.src_sizes,
@@ -431,8 +471,7 @@ def __init__(
             num_tokens = np.vectorize(self.num_tokens, otypes=[np.long])
             self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
             self.buckets = [
-                (None, num_tokens)
-                for num_tokens in np.unique(self.bucketed_num_tokens)
+                (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
             ]
         else:
             self.buckets = None
@@ -511,7 +550,7 @@ def collater(self, samples, pad_to_length=None):
         Args:
             samples (List[dict]): samples to collate
             pad_to_length (dict, optional): a dictionary of
-                {'source': source_pad_to_length, 'target': target_pad_to_length}
+                {"source": source_pad_to_length, "target": target_pad_to_length}
                 to indicate the max length to pad to in source and target
                 respectively.
@@ -571,9 +610,7 @@ def ordered_indices(self):
         if self.buckets is None:
             # sort by target length, then source length
             if self.tgt_sizes is not None:
-                indices = indices[
-                    np.argsort(self.tgt_sizes[indices], kind="mergesort")
-                ]
+                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
             return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
         else:
             # sort by bucketed_num_tokens, which is padded_src_len
@@ -591,7 +628,7 @@ def prefetch(self, indices):
         self.tgt.prefetch(indices)
 
     def filter_indices_by_size(self, indices, max_sizes):
-        """ Filter a list of sample indices. Remove those that are longer
+        """Filter a list of sample indices. Remove those that are longer
         than specified in max_sizes.
 
         Args:
diff --git a/espresso/data/encoders/characters_asr.py b/espresso/data/encoders/characters_asr.py
index ef424150f6..0bd9a48d00 100644
--- a/espresso/data/encoders/characters_asr.py
+++ b/espresso/data/encoders/characters_asr.py
@@ -3,23 +3,24 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-
+from dataclasses import dataclass
 from typing import List, Optional
 
 from fairseq.data.encoders import register_bpe
+from fairseq.dataclass import FairseqDataclass
 
 from espresso.tools.utils import tokenize
 
 
-@register_bpe('characters_asr')
-class CharactersAsr(object):
+@dataclass
+class CharactersAsrConfig(FairseqDataclass):
+    pass
 
-    @staticmethod
-    def add_args(parser):
-        pass
 
+@register_bpe("characters_asr", dataclass=CharactersAsrConfig)
+class CharactersAsr(object):
     def __init__(
-        self, args, space_symbol="<space>", ends_with_space=True,
+        self, cfg, space_symbol="<space>", ends_with_space=True,
         non_lang_syms: Optional[List[str]] = None,
     ):
         self.space_symbol = space_symbol
diff --git a/espresso/data/feat_text_dataset.py b/espresso/data/feat_text_dataset.py
index 0dce559762..64a43eba8d 100644
--- a/espresso/data/feat_text_dataset.py
+++ b/espresso/data/feat_text_dataset.py
@@ -30,8 +30,12 @@ class FeatScpDataset(torch.utils.data.Dataset):
     """
 
     def __init__(
-        self, utt_ids: List[str], rxfiles: List[str], utt2num_frames: Optional[List[int]] = None,
-        seed=1, specaugment_config: Optional[str] = None,
+        self,
+        utt_ids: List[str],
+        rxfiles: List[str],
+        utt2num_frames: Optional[List[int]] = None,
+        seed=1,
+        specaugment_config: Optional[str] = None,
     ):
         super().__init__()
         assert len(utt_ids) == len(rxfiles)
diff --git a/espresso/dump_posteriors.py b/espresso/dump_posteriors.py
index ee0348426d..214172fe41 100755
--- a/espresso/dump_posteriors.py
+++ b/espresso/dump_posteriors.py
@@ -12,14 +12,17 @@
 import logging
 import os
 import sys
+from argparse import Namespace
 
 import numpy as np
 import torch
 
 from fairseq import checkpoint_utils, options, tasks, utils
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
 from fairseq.logging import progress_bar
 from fairseq.logging.meters import StopwatchMeter
+from omegaconf import DictConfig
 
 try:
     import kaldi_io
@@ -27,12 +30,16 @@
     raise ImportError("Please install kaldi_io with: pip install kaldi_io")
 
 
-def main(args):
-    assert args.path is not None, "--path required for decoding!"
-    return _main(args, sys.stderr)
+def main(cfg: DictConfig):
+    if isinstance(cfg, Namespace):
+        cfg = convert_namespace_to_omegaconf(cfg)
 
-def _main(args, output_file):
+    assert cfg.common_eval.path is not None, "--path required for decoding!"
+    return _main(cfg, sys.stderr)
+
+
+def _main(cfg, output_file):
     logging.basicConfig(
         format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
         datefmt="%Y-%m-%d %H:%M:%S",
@@ -41,41 +48,41 @@
     )
     logger = logging.getLogger("espresso.dump_posteriors")
 
-    print_options_meaning_changes(args, logger)
+    print_options_meaning_changes(cfg, logger)
 
-    utils.import_user_module(args)
+    utils.import_user_module(cfg.common)
 
-    if args.max_tokens is None and args.batch_size is None:
-        args.max_tokens = 12000
-    logger.info(args)
+    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
+        cfg.dataset.max_tokens = 12000
+    logger.info(cfg)
 
     # Fix seed for stochastic decoding
-    if args.seed is not None and not args.no_seed_provided:
-        np.random.seed(args.seed)
-        utils.set_torch_seed(args.seed)
+    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
+        np.random.seed(cfg.common.seed)
+        utils.set_torch_seed(cfg.common.seed)
 
-    use_cuda = torch.cuda.is_available() and not args.cpu
+    use_cuda = torch.cuda.is_available() and not cfg.common.cpu
 
     # Load dataset split
-    task = tasks.setup_task(args)
-    task.load_dataset(args.gen_subset)
+    task = tasks.setup_task(cfg.task)
+    task.load_dataset(cfg.dataset.gen_subset)
 
-    overrides = ast.literal_eval(args.model_overrides)
+    overrides = ast.literal_eval(cfg.common_eval.model_overrides)
 
     # Load ensemble
-    logger.info("loading model(s) from {}".format(args.path))
+    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
     models, _model_args = checkpoint_utils.load_model_ensemble(
-        utils.split_paths(args.path),
+        utils.split_paths(cfg.common_eval.path),
         arg_overrides=overrides,
         task=task,
-        suffix=getattr(args, "checkpoint_suffix", ""),
-        strict=(args.checkpoint_shard_count == 1),
-        num_shards=args.checkpoint_shard_count,
+        suffix=cfg.checkpoint.checkpoint_suffix,
+        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
+        num_shards=cfg.checkpoint.checkpoint_shard_count,
     )
 
     # Load state prior for cross-entropy trained systems decoding
-    if args.state_prior_file is not None:
-        prior = torch.from_numpy(kaldi_io.read_vec_flt(args.state_prior_file))
+    if cfg.generation.state_prior_file is not None:
+        prior = torch.from_numpy(kaldi_io.read_vec_flt(cfg.generation.state_prior_file))
     else:
         prior = []
@@ -83,11 +90,11 @@ def _main(args, output_file):
     for model in models:
         if model is None:
             continue
-        if args.fp16:
+        if cfg.common.fp16:
             model.half()
-        if use_cuda and not args.pipeline_model_parallel:
+        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
             model.cuda()
-        model.prepare_for_inference_(args)
+        model.prepare_for_inference_(cfg)
         if isinstance(prior, list) and getattr(model, "state_prior", None) is not None:
             prior.append(model.state_prior.unsqueeze(0))
@@ -98,7 +105,7 @@ def _main(args, output_file):
         prior = None
 
     if prior is not None:
-        if args.fp16:
+        if cfg.common.fp16:
             prior = prior.half()
         if use_cuda:
             prior = prior.cuda()
@@ -108,31 +115,30 @@ def _main(args, output_file):
 
     # Load dataset (possibly sharded)
     itr = task.get_batch_iterator(
-        dataset=task.dataset(args.gen_subset),
-        max_tokens=args.max_tokens,
-        max_sentences=args.batch_size,
+        dataset=task.dataset(cfg.dataset.gen_subset),
+        max_tokens=cfg.dataset.max_tokens,
+        max_sentences=cfg.dataset.batch_size,
         max_positions=utils.resolve_max_positions(
-            task.max_positions(),
-            *[model.max_positions() if hasattr(model, "encoder")
-              else (None, model.max_positions()) for model in models]
+            task.max_positions(), *[m.max_positions() for m in models]
         ),
-        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
-        required_batch_size_multiple=args.required_batch_size_multiple,
-        num_shards=args.num_shards,
-        shard_id=args.shard_id,
-        num_workers=args.num_workers,
-        data_buffer_size=args.data_buffer_size,
+        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
+        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
+        seed=cfg.common.seed,
+        num_shards=cfg.distributed_training.distributed_world_size,
+        shard_id=cfg.distributed_training.distributed_rank,
+        num_workers=cfg.dataset.num_workers,
+        data_buffer_size=cfg.dataset.data_buffer_size,
     ).next_epoch_itr(shuffle=False)
     progress = progress_bar.progress_bar(
         itr,
-        log_format=args.log_format,
-        log_interval=args.log_interval,
-        default_log_format=("tqdm" if not args.no_progress_bar else "none"),
+        log_format=cfg.common.log_format,
+        log_interval=cfg.common.log_interval,
+        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
     )
 
     # Initialize generator
     gen_timer = StopwatchMeter()
-    generator = task.build_generator(models, args)
+    generator = task.build_generator(models, cfg.generation)
 
     # Generate and dump
     num_sentences = 0
@@ -153,12 +159,12 @@ def _main(args, output_file):
                 out_lengths = (~padding_mask).long().sum(dim=1).cpu() if padding_mask is not None else None
                 num_processed_frames = sample["ntokens"]
                 gen_timer.stop(num_processed_frames)
-                num_sentences += sample["nsentences"] if "nsentences" in sample else sample['id'].numel()
+                num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
 
                 if out_lengths is not None:
                     for i in range(sample["nsentences"]):
                         length = out_lengths[i]
-                        kaldi_io.write_mat(f, lprobs[i, :length, :].cpu().numpy(), key=sample["utt_id"][i])
+                        kaldi_io.write_mat(f, lprobs[i, : length, :].cpu().numpy(), key=sample["utt_id"][i])
                 else:
                     for i in range(sample["nsentences"]):
                         kaldi_io.write_mat(f, lprobs[i, :, :].cpu().numpy(), key=sample["utt_id"][i])
@@ -189,9 +195,9 @@ def _main(args, output_file):
                 num_sentences += len(utt_id)
                 for j in range(len(utt_id)):
                     truncated_length = models[0].output_lengths(
-                        task.dataset(args.gen_subset).src_sizes[id[j]]
+                        task.dataset(cfg.dataset.gen_subset).src_sizes[id[j]]
                     )  # length is after possible subsampling by the model
-                    mat = whole_lprobs[j, :truncated_length, :]
+                    mat = whole_lprobs[j, : truncated_length, :]
                     kaldi_io.write_mat(f, mat.numpy(), key=utt_id[j])
 
     logger.info("Dumped {} utterances ({} frames) in {:.1f}s ({:.2f} sentences/s, {:.2f} frames/s)".format(
@@ -200,7 +206,7 @@ def _main(args, output_file):
     return
 
 
-def print_options_meaning_changes(args, logger):
+def print_options_meaning_changes(cfg, logger):
     """Options that have different meanings than those in the translation task
     are explained here.
     """
@@ -209,12 +215,6 @@ def cli_main():
     parser = options.get_generation_parser(default_task="speech_recognition_hybrid")
-    parser.add_argument("--apply-log-softmax", action="store_true",
-                        help="Apply log-softmax to the neural network outputs for some "
-                        "systems, e.g., Xent. Otherwise use the raw outputs")
-    parser.add_argument("--state-prior-file", default=None, type=str, metavar="FILE",
-                        help="state prior file. If provided, use this file instead of "
-                        "that from the checkpoint")
     args = options.parse_args_and_arch(parser)
     main(args)
diff --git a/espresso/models/external_language_model.py b/espresso/models/external_language_model.py
index dd418ff973..c1f04307f3 100644
--- a/espresso/models/external_language_model.py
+++ b/espresso/models/external_language_model.py
@@ -56,9 +56,10 @@ def __init__(self, wordlm, subword_dict, oov_penalty=1e-4, open_vocab=True):
         assert isinstance(wordlm, FairseqLanguageModel)
         self.lm_decoder = wordlm.decoder
-        assert hasattr(self.lm_decoder, 'masked_copy_incremental_state') and \
-            callable(self.lm_decoder.masked_copy_incremental_state), \
-            'The wrapped decoder should implement masked_copy_incremental_state()'
+        assert (
+            hasattr(self.lm_decoder, "masked_copy_incremental_state")
+            and callable(self.lm_decoder.masked_copy_incremental_state)
+        ), "The wrapped decoder should implement masked_copy_incremental_state()"
         self.oov_penalty = oov_penalty
         self.open_vocab = open_vocab
         self.zero = 1e-10  # a sufficiently small value to avoid the log(0) issue
@@ -76,7 +77,7 @@ def __init__(self, wordlm, subword_dict, oov_penalty=1e-4, open_vocab=True):
         self.subword_vocab_size = len(subword_dict)
 
         def tokenizer(x):
-            return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(' ')
+            return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(" ")
         self.lexroot = lexical_prefix_tree(word_dict, subword_dict, tokenizer)
 
         def max_out_degree(node):
@@ -92,18 +93,17 @@ def max_out_degree(node):
 
     @torch.no_grad()
     def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
-        assert incremental_state is not None, \
-            'this model is for incremental decoding only'
+        assert incremental_state is not None, "this model is for incremental decoding only"
 
         prev_output_tokens = prev_output_tokens[:, -1:]
         bsz = prev_output_tokens.size(0)
 
         batch_space_mask = prev_output_tokens.squeeze(-1).eq(self.subword_space_idx)
 
-        cached_state = self.lm_decoder.get_incremental_state(incremental_state, 'cached_state')
+        cached_state = self.lm_decoder.get_incremental_state(incremental_state, "cached_state")
 
         if cached_state is None:  # it is the first time step
             assert (prev_output_tokens == self.subword_eos_idx).all(), \
-                'expecting the input to the first time step to be <eos>'
+                "expecting the input to the first time step to be <eos>"
             w = prev_output_tokens.new_full([bsz, 1], self.word_eos_idx)
             lm_probs = self.lm_decoder.get_normalized_probs(
                 self.lm_decoder(w, incremental_state=incremental_state),
@@ -112,8 +112,8 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
             cumsum_probs = torch.cumsum(lm_probs, dim=-1)  # B x 1 x V
             nodes = [self.lexroot] * bsz
         else:
-            cumsum_probs = self.get_incremental_state(incremental_state, 'cumsum_probs')
-            nodes = self.get_incremental_state(incremental_state, 'nodes')
+            cumsum_probs = self.get_incremental_state(incremental_state, "cumsum_probs")
+            nodes = self.get_incremental_state(incremental_state, "nodes")
             assert len(nodes) == bsz
             w = prev_output_tokens.new([
                 node.word_idx if node is not None and node.word_idx >= 0 else
@@ -129,8 +129,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
             self.lm_decoder.masked_copy_incremental_state(
                 incremental_state, old_cached_state, batch_space_mask,
             )  # restore those not masked
-            cumsum_probs[batch_space_mask] = \
+            cumsum_probs[batch_space_mask] = (
                 torch.cumsum(lm_probs, dim=-1)[batch_space_mask]
+            )
 
             tokens_list = prev_output_tokens.squeeze(-1).tolist()
             for i in range(bsz):
                 if tokens_list[i] == self.subword_space_idx:
@@ -142,8 +143,8 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
                 else:  # no path in the tree
                     nodes[i] = None
 
-        self.set_incremental_state(incremental_state, 'cumsum_probs', cumsum_probs)
-        self.set_incremental_state(incremental_state, 'nodes', nodes)
+        self.set_incremental_state(incremental_state, "cumsum_probs", cumsum_probs)
+        self.set_incremental_state(incremental_state, "nodes", nodes)
 
         # initialize out_probs (B x 1 x V)
         if self.open_vocab:
@@ -155,8 +156,10 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
             # set the probability of emitting <space> to 0 if prev_output_tokens
             # is <space> or <eos>, and that of emitting <eos> to 0 if
             # prev_output_tokens is not <space>
-            batch_space_eos_mask = batch_space_mask | \
-                prev_output_tokens.squeeze(-1).eq(self.subword_eos_idx)
+            batch_space_eos_mask = (
+                batch_space_mask | prev_output_tokens.squeeze(-1).eq(self.subword_eos_idx)
+            )
             out_probs[batch_space_eos_mask, :, self.subword_space_idx] = self.zero
             out_probs[~batch_space_mask, :, self.subword_eos_idx] = self.zero
             # set transition probability to 1 for those whose node is out of the
@@ -164,13 +167,13 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
             batch_node_none_mask = batch_space_mask.new(
                 [node is None for node in nodes]
             )
-            out_probs[batch_node_none_mask] = 1.
+            out_probs[batch_node_none_mask] = 1.0
         else:  # set out_probs to 0
             out_probs = cumsum_probs.new_full([bsz, 1, self.subword_vocab_size], self.zero)
 
         # compute parent probabilities for those whose node is not None
-        sum_probs = cumsum_probs.new_full([bsz, 1], 1.)  # default for root node
+        sum_probs = cumsum_probs.new_full([bsz, 1], 1.0)  # default for root node
         left_ranges, right_ranges, batch_node_not_root_mask = [], [], []
         for node in nodes:
             if node is not None and node.word_set is not None:
@@ -243,8 +246,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
         out_logprobs = out_probs.clamp(min=self.zero).log_()
 
         # assign log-probs of emitting word <eos> to that of emitting subword <eos>
-        out_logprobs[batch_space_mask, :, self.subword_eos_idx] = \
+        out_logprobs[batch_space_mask, :, self.subword_eos_idx] = (
             lm_probs.log_()[batch_space_mask, :, self.word_eos_idx]
+        )
 
         # note that here we return log-probs rather than logits, and the second
         # element is None, which is usually a tensor of attention weights in
@@ -254,16 +258,16 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
     def reorder_incremental_state(self, incremental_state, new_order):
         super().reorder_incremental_state(incremental_state, new_order)
 
-        cumsum_probs = self.get_incremental_state(incremental_state, 'cumsum_probs')
+        cumsum_probs = self.get_incremental_state(incremental_state, "cumsum_probs")
         if cumsum_probs is not None:
             new_cumsum_probs = cumsum_probs.index_select(0, new_order)
-            self.set_incremental_state(incremental_state, 'cumsum_probs', new_cumsum_probs)
+            self.set_incremental_state(incremental_state, "cumsum_probs", new_cumsum_probs)
 
-        nodes = self.get_incremental_state(incremental_state, 'nodes')
+        nodes = self.get_incremental_state(incremental_state, "nodes")
         if nodes is not None:
             new_order_list = new_order.tolist()
             new_nodes = [nodes[i] for i in new_order_list]
-            self.set_incremental_state(incremental_state, 'nodes', new_nodes)
+            self.set_incremental_state(incremental_state, "nodes", new_nodes)
 
     def get_normalized_probs(self, net_output, log_probs, sample=None):
         """Get normalized probabilities (or log probs) from a net's output."""
@@ -301,9 +305,10 @@ def __init__(
         assert isinstance(wordlm, FairseqLanguageModel)
         self.wordlm_decoder = wordlm.decoder
-        assert hasattr(self.wordlm_decoder, 'masked_copy_incremental_state') and \
-            callable(self.wordlm_decoder.masked_copy_incremental_state), \
-            'The wrapped decoder should implement masked_copy_incremental_state()'
+        assert (
+            hasattr(self.wordlm_decoder, "masked_copy_incremental_state") and
+            callable(self.wordlm_decoder.masked_copy_incremental_state)
+        ), "The wrapped decoder should implement masked_copy_incremental_state()"
         assert isinstance(subwordlm, FairseqLanguageModel)
         self.subwordlm_decoder = subwordlm.decoder
         self.subwordlm_weight = subwordlm_weight
@@ -323,13 +328,12 @@ def __init__(
         self.subword_vocab_size = len(subword_dict)
 
         def tokenizer(x):
-            return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(' ')
+            return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(" ")
         self.lexroot = lexical_prefix_tree(word_dict, subword_dict, tokenizer)
 
     @torch.no_grad()
     def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
-        assert incremental_state is not None, \
-            'this model is for incremental decoding only'
+        assert incremental_state is not None, "this model is for incremental decoding only"
 
         prev_output_tokens = prev_output_tokens[:, -1:]
         bsz = prev_output_tokens.size(0)
@@ -337,16 +341,16 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
         batch_not_space_mask = ~batch_space_mask
 
         wordlm_cached_state = self.wordlm_decoder.get_incremental_state(
-            incremental_state, 'cached_state',
+            incremental_state, "cached_state",
         )
         subwordlm_cached_state = self.subwordlm_decoder.get_incremental_state(
-            incremental_state, 'cached_state',
+            incremental_state, "cached_state",
         )
 
         if wordlm_cached_state is None:  # it is the first time step
             assert subwordlm_cached_state is None
             assert (prev_output_tokens == self.subword_eos_idx).all(), \
-                'expecting the input to the first time step to be <eos>'
+                "expecting the input to the first time step to be <eos>"
             w = prev_output_tokens.new_full([bsz, 1], self.word_eos_idx)
             wordlm_logprobs = self.wordlm_decoder.get_normalized_probs(
                 self.wordlm_decoder(w, incremental_state=incremental_state),
@@ -362,10 +366,10 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
             subword_cumlogprobs = out_logprobs.new_zeros(sw.size())
             nodes = [self.lexroot] * bsz
         else:
-            wordlm_logprobs = self.get_incremental_state(incremental_state, 'wordlm_logprobs')
-            out_logprobs = self.get_incremental_state(incremental_state, 'out_logprobs')
-            subword_cumlogprobs = self.get_incremental_state(incremental_state, 'subword_cumlogprobs')
-            nodes = self.get_incremental_state(incremental_state, 'nodes')
+            wordlm_logprobs = self.get_incremental_state(incremental_state, "wordlm_logprobs")
+            out_logprobs = self.get_incremental_state(incremental_state, "out_logprobs")
+            subword_cumlogprobs = self.get_incremental_state(incremental_state, "subword_cumlogprobs")
+            nodes = self.get_incremental_state(incremental_state, "nodes")
             assert len(nodes) == bsz
             w = prev_output_tokens.new([
                 node.word_idx if node is not None and node.word_idx >= 0 else
@@ -403,15 +407,17 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
                 batch_is_child_mask.append(False)
         token_idx = prev_output_tokens.new(token_idx).unsqueeze(-1)  # b x 1 x 1
         if self.open_vocab:
-            subword_cumlogprobs[batch_space_mask] = 0.
+            subword_cumlogprobs[batch_space_mask] = 0.0
             assert batch_not_space_mask.sum().item() == len(token_idx)
-            subword_cumlogprobs[batch_not_space_mask] += \
+            subword_cumlogprobs[batch_not_space_mask] += (
                 out_logprobs[batch_not_space_mask].gather(-1, token_idx).squeeze(-1)
+            )
         else:
-            subword_cumlogprobs[~batch_is_child_mask] = 0.
+            subword_cumlogprobs[~batch_is_child_mask] = 0.0
             assert batch_is_child_mask.sum().item() == len(token_idx)
-            subword_cumlogprobs[batch_is_child_mask] += \
+            subword_cumlogprobs[batch_is_child_mask] += (
                 out_logprobs[batch_is_child_mask].gather(-1, token_idx).squeeze(-1)
+            )
 
         out_logprobs = self.subwordlm_decoder.get_normalized_probs(
             self.subwordlm_decoder(prev_output_tokens, incremental_state=incremental_state),
@@ -423,9 +429,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
         batch_oov_mask = batch_not_space_mask & ~batch_is_child_mask
         out_logprobs[batch_oov_mask] = self.logzero
 
-        self.set_incremental_state(incremental_state, 'wordlm_logprobs', wordlm_logprobs)
-        self.set_incremental_state(incremental_state, 'subword_cumlogprobs', subword_cumlogprobs)
-        self.set_incremental_state(incremental_state, 'nodes', nodes)
+        self.set_incremental_state(incremental_state, "wordlm_logprobs", wordlm_logprobs)
+        self.set_incremental_state(incremental_state, "subword_cumlogprobs", subword_cumlogprobs)
+        self.set_incremental_state(incremental_state, "nodes", nodes)
 
         # apply word-level probabilies for emitting <space>
         w = prev_output_tokens.new([
@@ -443,16 +449,18 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
         # set the probability of emitting <space> to 0 if prev_output_tokens is
         # <space> or <eos>, and that of emitting <eos> to 0 if prev_output_tokens
         # is not <space>
-        batch_space_eos_mask = batch_space_mask | \
-            prev_output_tokens.squeeze(-1).eq(self.subword_eos_idx)
+        batch_space_eos_mask = (
+            batch_space_mask | prev_output_tokens.squeeze(-1).eq(self.subword_eos_idx)
+        )
         out_logprobs[batch_space_eos_mask, :, self.subword_space_idx] = self.logzero
         out_logprobs[~batch_space_mask, :, self.subword_eos_idx] = self.logzero
 
         # add log-probs of emitting word <eos> to that of emitting subword <eos>
-        out_logprobs[batch_space_mask, :, self.subword_eos_idx] += \
+        out_logprobs[batch_space_mask, :, self.subword_eos_idx] += (
             wordlm_logprobs[batch_space_mask, :, self.word_eos_idx]
+        )
 
-        self.set_incremental_state(incremental_state, 'out_logprobs', out_logprobs)
+        self.set_incremental_state(incremental_state, "out_logprobs", out_logprobs)
 
         # note that here we return log-probs rather than logits, and the second
         # element is None, which is usually a tensor of attention weights in
@@ -462,17 +470,17 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
     def reorder_incremental_state(self, incremental_state, new_order):
         super().reorder_incremental_state(incremental_state, new_order)
 
-        for state_name in ['wordlm_logprobs', 'out_logprobs', 'subword_cumlogprobs']:
+        for state_name in ["wordlm_logprobs", "out_logprobs", "subword_cumlogprobs"]:
             state = self.get_incremental_state(incremental_state, state_name)
             if state is not None:
                 new_state = state.index_select(0, new_order)
                 self.set_incremental_state(incremental_state, state_name, new_state)
 
-        nodes = self.get_incremental_state(incremental_state, 'nodes')
+        nodes = self.get_incremental_state(incremental_state, "nodes")
         if nodes is not None:
             new_order_list = new_order.tolist()
             new_nodes = [nodes[i] for i in new_order_list]
-            self.set_incremental_state(incremental_state, 'nodes', new_nodes)
+            self.set_incremental_state(incremental_state, "nodes", new_nodes)
 
     def get_normalized_probs(self, net_output, log_probs, sample=None):
         """Get normalized probabilities (or log probs) from a net's output."""
diff --git a/espresso/models/lstm_lm.py b/espresso/models/lstm_lm.py
index 11e0764cae..7d0f7e7218 100644
--- a/espresso/models/lstm_lm.py
+++ b/espresso/models/lstm_lm.py
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass, field
-from omegaconf import II
 from typing import Optional
 
 from fairseq import utils
@@ -15,6 +14,7 @@
     register_model_architecture,
 )
 from fairseq.models.lstm import Embedding
+from omegaconf import II
 
 from espresso.models.speech_lstm import SpeechLSTMDecoder
 from espresso.tasks.speech_recognition import SpeechRecognitionEspressoTask
@@ -44,7 +44,7 @@ class LSTMLanguageModelEspressoConfig(FairseqDataclass):
     decoder_out_embed_dim: int = field(
         default=650, metadata={"help": "decoder output embedding dimension"}
     )
-    decoder_rnn_residual: lambda x: utils.eval_bool(x) = field(
+    decoder_rnn_residual: bool = field(
         default=False,
         metadata={
            "help": "create residual connections for rnn decoder layers "
@@ -59,7 +59,7 @@ class LSTMLanguageModelEspressoConfig(FairseqDataclass):
             "Must be used with adaptive_loss criterion"
         },
     )
-    share_embed: lambda x: utils.eval_bool(x) = field(
+    share_embed: bool = field(
         default=False, metadata={"help": "share input and output embeddings"}
    )
     is_wordlm: bool = field(
@@ -79,18 +79,59 @@ class LSTMLanguageModelEspressoConfig(FairseqDataclass):
         metadata={"help": "dropout probability for decoder output"}
     )
     # TODO common var add to parent
-    add_bos_token: bool = II("task.add_bos_token")
     tokens_per_sample: int = II("task.tokens_per_sample")
     max_target_positions: Optional[int] = II("task.max_target_positions")
-    tpu: bool = II("params.common.tpu")
+    tpu: bool = II("common.tpu")
 
 
-@register_model("lstm_lm_espresso", dataclass=LSTMLanguageModelEspressoConfig)
+@register_model("lstm_lm_espresso")
 class LSTMLanguageModelEspresso(FairseqLanguageModel):
     def __init__(self, decoder, args):
         super().__init__(decoder)
         self.is_wordlm = args.is_wordlm
 
+    @staticmethod
+    def add_args(parser):
+        """Add model-specific arguments to the parser."""
+        # fmt: off
+        parser.add_argument("--dropout", type=float, metavar="D",
+                            help="dropout probability")
+        parser.add_argument("--decoder-embed-dim", type=int, metavar="N",
+                            help="decoder embedding dimension")
+        parser.add_argument("--decoder-embed-path", type=str, metavar="STR",
+                            help="path to pre-trained decoder embedding")
+        parser.add_argument("--decoder-freeze-embed", action="store_true",
+                            help="freeze decoder embeddings")
+        parser.add_argument("--decoder-hidden-size", type=int, metavar="N",
+                            help="decoder hidden size")
+        parser.add_argument("--decoder-layers", type=int, metavar="N",
+                            help="number of decoder layers")
+        parser.add_argument("--decoder-out-embed-dim", type=int, metavar="N",
+                            help="decoder output embedding dimension")
+        parser.add_argument("--decoder-rnn-residual",
+                            type=lambda x: utils.eval_bool(x),
+                            help="create residual connections for rnn decoder "
+                            "layers (starting from the 2nd layer), i.e., the actual "
+                            "output of such layer is the sum of its input and output")
+        parser.add_argument("--adaptive-softmax-cutoff", metavar="EXPR",
+                            help="comma separated list of adaptive softmax cutoff points. "
+                            "Must be used with adaptive_loss criterion")
+        parser.add_argument("--share-embed",
+                            type=lambda x: utils.eval_bool(x),
+                            help="share input and output embeddings")
+        parser.add_argument("--is-wordlm", action="store_true",
+                            help="whether it is word LM or subword LM. Only "
+                            "relevant for ASR decoding with LM, and it determines "
+                            "how the underlying decoder instance gets the dictionary "
+                            "from the task instance when calling cls.build_model()")
+
+        # Granular dropout settings (if not specified these default to --dropout)
+        parser.add_argument("--decoder-dropout-in", type=float, metavar="D",
+                            help="dropout probability for decoder input embedding")
+        parser.add_argument("--decoder-dropout-out", type=float, metavar="D",
+                            help="dropout probability for decoder output")
+        # fmt: on
+
     @classmethod
     def build_model(cls, args, task):
         """Build a new model instance."""
@@ -100,7 +141,9 @@ def build_model(cls, args, task):
         if getattr(args, "max_target_positions", None) is not None:
             max_target_positions = args.max_target_positions
         else:
-            max_target_positions = getattr(args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS)
+            max_target_positions = getattr(
+                args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
+            )
 
         def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
             num_embeddings = len(dictionary)
@@ -121,9 +164,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
         pretrained_decoder_embed = None
         if args.decoder_embed_path:
             pretrained_decoder_embed = load_pretrained_embedding_from_file(
-                args.decoder_embed_path,
-                dictionary,
-                args.decoder_embed_dim
+                args.decoder_embed_path, dictionary, args.decoder_embed_dim
             )
         # one last double check of parameter combinations
         if args.share_embed and (
@@ -150,7 +191,8 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
             share_input_output_embed=args.share_embed,
             adaptive_softmax_cutoff=(
                 utils.eval_str_list(args.adaptive_softmax_cutoff, type=int)
-                if args.criterion == "adaptive_loss" else None
+                if args.criterion == "adaptive_loss"
+                else None
             ),
             max_target_positions=max_target_positions,
         )
@@ -181,39 +223,33 @@ def lstm_lm_wsj(args):
 
 @register_model_architecture("lstm_lm_espresso", "lstm_lm_librispeech")
 def lstm_lm_librispeech(args):
-    args.dropout = 0.0
-    args.decoder_embed_dim = 800
-    args.decoder_hidden_size = 800
-    args.decoder_layers = 4
-    args.decoder_out_embed_dim = 800
-    args.decoder_dropout_in = args.dropout
-    args.decoder_dropout_out = args.dropout
-    args.share_embed = True
+    args.dropout = getattr(args, "dropout", 0.0)
+    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 800)
+    args.decoder_hidden_size = getattr(args, "decoder_hidden_size", 800)
+    args.decoder_layers = getattr(args, "decoder_layers", 4)
+    args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 800)
+    args.share_embed = getattr(args, "share_embed", True)
     base_lm_architecture(args)
 
 
 @register_model_architecture("lstm_lm_espresso", "lstm_lm_swbd")
 def lstm_lm_swbd(args):
-    args.dropout = 0.3
-    args.decoder_embed_dim = 1800
-    args.decoder_hidden_size = 1800
-    args.decoder_layers = 3
-    args.decoder_out_embed_dim = 1800
-    args.decoder_dropout_in = args.dropout
-    args.decoder_dropout_out = args.dropout
-    args.share_embed = True
+    args.dropout = getattr(args, "dropout", 0.3)
+    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1800)
+    args.decoder_hidden_size = getattr(args, "decoder_hidden_size", 1800)
+    args.decoder_layers = getattr(args, "decoder_layers", 3)
+    args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 1800)
+    args.share_embed = getattr(args, "share_embed", True)
     base_lm_architecture(args)
 
 
 @register_model_architecture("lstm_lm_espresso", "lstm_wordlm_wsj")
 def lstm_wordlm_wsj(args):
-    args.dropout = 0.35
-    args.decoder_embed_dim = 1200
-    args.decoder_hidden_size = 1200
-    args.decoder_layers = 3
-    args.decoder_out_embed_dim = 1200
-    args.decoder_dropout_in = args.dropout
-    args.decoder_dropout_out = args.dropout
-    args.share_embed = True
+    args.dropout = getattr(args, "dropout", 0.35)
+    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1200)
+    args.decoder_hidden_size = getattr(args, "decoder_hidden_size", 1200)
+    args.decoder_layers = getattr(args, "decoder_layers", 3)
+    args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 1200)
+    args.share_embed = getattr(args, "share_embed", True)
     args.is_wordlm = True
     base_lm_architecture(args)
diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py
index a644b2536d..4764ffd89e 100644
--- a/espresso/models/speech_lstm.py
+++ b/espresso/models/speech_lstm.py
@@ -15,8 +15,8 @@ from fairseq.models import (
     FairseqDecoder,
     FairseqEncoder,
-    FairseqIncrementalDecoder,
     FairseqEncoderDecoderModel,
+    FairseqIncrementalDecoder,
     register_model,
     register_model_architecture,
 )
@@ -135,8 +135,12 @@ def build_model(cls, args, task):
         # make sure that all args are properly defaulted (in case there are any new ones)
         base_architecture(args)
 
-        max_source_positions = getattr(args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS)
-        max_target_positions = getattr(args, "max_target_positions", DEFAULT_MAX_TARGET_POSITIONS)
+        max_source_positions = getattr(
+            args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS
+        )
+        max_target_positions = getattr(
+            args, "max_target_positions", DEFAULT_MAX_TARGET_POSITIONS
+        )
 
         def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
             num_embeddings = len(dictionary)
@@ -201,7 +205,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
             dropout_out=args.encoder_rnn_dropout_out,
             bidirectional=args.encoder_rnn_bidirectional,
             residual=args.encoder_rnn_residual,
-            src_bucketed=(getattr(task.args, "num_batch_buckets", 0) > 0),
+            src_bucketed=(getattr(task.cfg, "num_batch_buckets", 0) > 0),
             max_source_positions=max_source_positions,
         )
         decoder = SpeechLSTMDecoder(
@@ -221,7 +225,8 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
             share_input_output_embed=args.share_decoder_input_output_embed,
             adaptive_softmax_cutoff=(
                 utils.eval_str_list(args.adaptive_softmax_cutoff, type=int)
-                if args.criterion == "adaptive_loss" else None
+                if args.criterion == "adaptive_loss"
+                else None
             ),
             max_target_positions=max_target_positions,
             scheduled_sampling_rate_scheduler=scheduled_sampling_rate_scheduler,
@@ -247,8 +252,10 @@ def forward(
     ):
         encoder_out = self.encoder(src_tokens, src_lengths=src_lengths)
         decoder_out = self.decoder(
-            prev_output_tokens, encoder_out=encoder_out,
-            incremental_state=incremental_state, epoch=epoch,
+            prev_output_tokens,
+            encoder_out=encoder_out,
+            incremental_state=incremental_state,
+            epoch=epoch,
         )
         return decoder_out
@@ -260,14 +267,16 @@ def max_positions(self):
         """Maximum length supported by the model."""
         return (
             self.encoder.max_positions(),
-            self.decoder.max_positions() if self.pretrained_lm is None else
-            min(self.decoder.max_positions(), self.pretrained_lm.max_positions()),
+            self.decoder.max_positions() if self.pretrained_lm is
None + else min(self.decoder.max_positions(), self.pretrained_lm.max_positions()), ) def max_decoder_positions(self): """Maximum length supported by the decoder.""" - return self.decoder.max_positions() if self.pretrained_lm is None else \ - min(self.decoder.max_positions(), self.pretrained_lm.max_positions()) + return ( + self.decoder.max_positions() if self.pretrained_lm is None + else min(self.decoder.max_positions(), self.pretrained_lm.max_positions()) + ) class ConvBNReLU(nn.Module): @@ -327,29 +336,46 @@ def forward(self, src, src_lengths): class SpeechLSTMEncoder(FairseqEncoder): """LSTM encoder.""" def __init__( - self, conv_layers_before=None, input_size=83, hidden_size=512, - num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, - residual=False, left_pad=False, padding_value=0., src_bucketed=False, + self, + conv_layers_before=None, + input_size=83, + hidden_size=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + bidirectional=False, + residual=False, + left_pad=False, + padding_value=0.0, + src_bucketed=False, max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS, ): super().__init__(None) # no src dictionary self.conv_layers_before = conv_layers_before self.num_layers = num_layers - self.dropout_in_module = FairseqDropout(dropout_in, module_name=self.__class__.__name__) - self.dropout_out_module = FairseqDropout(dropout_out, module_name=self.__class__.__name__) + self.dropout_in_module = FairseqDropout( + dropout_in, module_name=self.__class__.__name__ + ) + self.dropout_out_module = FairseqDropout( + dropout_out, module_name=self.__class__.__name__ + ) self.bidirectional = bidirectional self.hidden_size = hidden_size self.residual = residual self.max_source_positions = max_source_positions - self.lstm = nn.ModuleList([ - LSTM( - input_size=input_size if layer == 0 else 2 * hidden_size if self.bidirectional else hidden_size, - hidden_size=hidden_size, - bidirectional=bidirectional, - ) - for layer in range(num_layers) - ]) + self.lstm = nn.ModuleList( + [ + LSTM( + input_size=input_size if layer == 0 + else 2 * hidden_size if self.bidirectional + else hidden_size, + hidden_size=hidden_size, + bidirectional=bidirectional, + ) + for layer in range(num_layers) + ] + ) self.left_pad = left_pad self.padding_value = padding_value self.src_bucketed = src_bucketed @@ -359,8 +385,10 @@ def __init__( self.output_units *= 2 def output_lengths(self, in_lengths): - return in_lengths if self.conv_layers_before is None \ + return ( + in_lengths if self.conv_layers_before is None else self.conv_layers_before.output_lengths(in_lengths) + ) def forward( self, @@ -392,8 +420,10 @@ def forward( if self.conv_layers_before is not None: x, src_lengths, padding_mask = self.conv_layers_before(src_tokens, src_lengths) else: - x, padding_mask = src_tokens, \ + x, padding_mask = ( + src_tokens, ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1)) + ) bsz, seqlen = x.size(0), x.size(1) @@ -422,7 +452,9 @@ def forward( packed_outs, (_, _) = self.lstm[i](packed_x, (h0, c0)) # unpack outputs and apply dropout - x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_value*1.0) + x, _ = nn.utils.rnn.pad_packed_sequence( + packed_outs, padding_value=self.padding_value * 1.0 + ) if i < len(self.lstm) - 1: # not applying dropout for the last layer x = self.dropout_out_module(x) x = x + prev_x if self.residual and i > 0 else x @@ -432,7 +464,8 @@ def forward( return EncoderOut( encoder_out=x, # T x B x C - encoder_padding_mask=encoder_padding_mask if 
encoder_padding_mask.any() else None, # T x B + encoder_padding_mask=encoder_padding_mask if encoder_padding_mask.any() + else None, # T x B encoder_embedding=None, encoder_states=None, src_tokens=None, @@ -469,16 +502,32 @@ def max_positions(self): class SpeechLSTMDecoder(FairseqIncrementalDecoder): """LSTM decoder.""" def __init__( - self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512, - num_layers=1, dropout_in=0.1, dropout_out=0.1, encoder_output_units=0, - attn_type=None, attn_dim=0, need_attn=False, residual=False, pretrained_embed=None, - share_input_output_embed=False, adaptive_softmax_cutoff=None, + self, + dictionary, + embed_dim=512, + hidden_size=512, + out_embed_dim=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + encoder_output_units=0, + attn_type=None, + attn_dim=0, + need_attn=False, + residual=False, + pretrained_embed=None, + share_input_output_embed=False, + adaptive_softmax_cutoff=None, max_target_positions=DEFAULT_MAX_TARGET_POSITIONS, scheduled_sampling_rate_scheduler=None, ): super().__init__(dictionary) - self.dropout_in_module = FairseqDropout(dropout_in, module_name=self.__class__.__name__) - self.dropout_out_module = FairseqDropout(dropout_out, module_name=self.__class__.__name__) + self.dropout_in_module = FairseqDropout( + dropout_in, module_name=self.__class__.__name__ + ) + self.dropout_out_module = FairseqDropout( + dropout_out, module_name=self.__class__.__name__ + ) self.hidden_size = hidden_size self.share_input_output_embed = share_input_output_embed if attn_type is None or attn_type.lower() == "none": @@ -500,13 +549,16 @@ def __init__( self.encoder_output_units = encoder_output_units - self.layers = nn.ModuleList([ - LSTMCell( - input_size=encoder_output_units + (embed_dim if layer == 0 else hidden_size), - hidden_size=hidden_size, - ) - for layer in range(num_layers) - ]) + self.layers = nn.ModuleList( + [ + LSTMCell( + input_size=encoder_output_units + + (embed_dim if layer == 0 else hidden_size), + hidden_size=hidden_size, + ) + for layer in range(num_layers) + ] + ) if attn_type is None or attn_type.lower() == "none": self.attention = None @@ -527,7 +579,10 @@ def __init__( if adaptive_softmax_cutoff is not None: # setting adaptive_softmax dropout to dropout_out for now but can be redefined self.adaptive_softmax = AdaptiveSoftmax( - num_embeddings, hidden_size, adaptive_softmax_cutoff, dropout=dropout_out, + num_embeddings, + hidden_size, + adaptive_softmax_cutoff, + dropout=dropout_out, ) elif not self.share_input_output_embed: self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) @@ -641,8 +696,10 @@ def extract_features( zero_state = x.new_zeros(bsz, self.hidden_size) prev_hiddens = [zero_state for i in range(self.num_layers)] prev_cells = [zero_state for i in range(self.num_layers)] - input_feed = x.new_zeros(bsz, self.encoder_output_units) \ - if encoder_out is not None else None + input_feed = ( + x.new_zeros(bsz, self.encoder_output_units) if encoder_out is not None + else None + ) attn_scores = x.new_zeros(srclen, seqlen, bsz) if encoder_out is not None else None outs = [] @@ -746,7 +803,9 @@ def get_cached_state( assert prev_cells_ is not None prev_hiddens = [prev_hiddens_[i] for i in range(self.num_layers)] prev_cells = [prev_cells_[j] for j in range(self.num_layers)] - input_feed = cached_state["input_feed"] # can be None for decoder-only language models + input_feed = cached_state[ + "input_feed" + ] # can be None for decoder-only language models return prev_hiddens, prev_cells, input_feed 
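The cached-state logic above (prev_hiddens, prev_cells, input_feed) is exactly what beam search has to reorder between decoding steps, which is what the reorder_incremental_state method below does. A minimal, self-contained sketch of that reordering pattern, assuming the usual fairseq convention that new_order maps each new beam entry back to the hypothesis it was expanded from (illustrative only, not part of this patch; names are hypothetical):

    import torch

    def reorder_cached_state(cached_state, new_order):
        # Reorder every cached tensor along its batch dimension.
        def reorder(t, batch_dim):
            return t.index_select(batch_dim, new_order) if t is not None else None
        return {
            # prev_hiddens / prev_cells are stacked as (num_layers, bsz, hidden_size)
            "prev_hiddens": reorder(cached_state["prev_hiddens"], 1),
            "prev_cells": reorder(cached_state["prev_cells"], 1),
            # input_feed is (bsz, encoder_output_units); None for decoder-only LMs
            "input_feed": reorder(cached_state["input_feed"], 0),
        }

    # e.g. a beam of 3 hypotheses where hypothesis 0 was expanded twice:
    # new_order = torch.LongTensor([0, 0, 1])
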
def reorder_incremental_state( @@ -767,7 +826,7 @@ def reorder_incremental_state( "prev_hiddens": torch.stack(prev_hiddens), "prev_cells": torch.stack(prev_cells), "input_feed": input_feed, - } + }, ) self.set_incremental_state(incremental_state, "cached_state", cached_state_new), return @@ -777,8 +836,9 @@ def masked_copy_incremental_state(self, incremental_state, another_cached_state, assert another_cached_state is None or len(another_cached_state) == 0 return prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state) - another_prev_hiddens, another_prev_cells, another_input_feed = \ + another_prev_hiddens, another_prev_cells, another_input_feed = ( another_cached_state[0], another_cached_state[1], another_cached_state[2] + ) def mask_copy_state(state: Optional[Tensor], another_state: Optional[Tensor]): if state is None: @@ -807,7 +867,7 @@ def mask_copy_state(state: Optional[Tensor], another_state: Optional[Tensor]): "prev_hiddens": torch.stack(prev_hiddens_new), "prev_cells": torch.stack(prev_cells_new), "input_feed": input_feed_new, - } + }, ) self.set_incremental_state(incremental_state, "cached_state", cached_state_new) diff --git a/espresso/models/speech_lstm_encoder_model.py b/espresso/models/speech_lstm_encoder_model.py index c7eed3833b..565baa2cd6 100644 --- a/espresso/models/speech_lstm_encoder_model.py +++ b/espresso/models/speech_lstm_encoder_model.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from argparse import Namespace import logging from typing import Optional @@ -18,6 +19,7 @@ ) from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.lstm import Linear +from omegaconf import DictConfig from espresso.models.speech_lstm import ConvBNReLU, SpeechLSTMEncoder import espresso.tools.utils as speech_utils @@ -72,7 +74,9 @@ def build_model(cls, args, task): """Build a new model instance.""" # make sure that all args are properly defaulted (in case there are any new ones) base_architecture(args) - max_source_positions = getattr(args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS) + max_source_positions = getattr( + args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS + ) out_channels = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) @@ -106,7 +110,7 @@ def build_model(cls, args, task): dropout_out=args.encoder_rnn_dropout_out, bidirectional=args.encoder_rnn_bidirectional, residual=args.encoder_rnn_residual, - src_bucketed=(getattr(task.args, "num_batch_buckets", 0) > 0), + src_bucketed=(getattr(task.cfg, "num_batch_buckets", 0) > 0), num_targets=getattr(task, "num_targets", None), # targets for encoder-only model chunk_width=getattr(task, "chunk_width", None), chunk_left_context=getattr(task, "chunk_left_context", 0), @@ -140,43 +144,81 @@ def state_dict(self): state_dict["state_prior"] = self.state_prior return state_dict - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg: Optional[DictConfig] = None, + args: Optional[Namespace] = None, + ): state_dict_subset = state_dict.copy() self.state_prior = state_dict.get("state_prior", None) if "state_prior" in state_dict: self.state_prior = state_dict["state_prior"] del state_dict_subset["state_prior"] - super().load_state_dict(state_dict_subset, strict=strict, args=args) + 
super().load_state_dict( + state_dict_subset, strict=strict, model_cfg=model_cfg, args=args + ) class SpeechChunkLSTMEncoder(SpeechLSTMEncoder): """LSTM encoder.""" def __init__( - self, conv_layers_before=None, input_size=83, hidden_size=512, - num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, - residual=False, left_pad=False, padding_value=0., src_bucketed=False, - num_targets=None, chunk_width=20, chunk_left_context=0, training_stage=True, + self, + conv_layers_before=None, + input_size=83, + hidden_size=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + bidirectional=False, + residual=False, + left_pad=False, + padding_value=0.0, + src_bucketed=False, + num_targets=None, + chunk_width=20, + chunk_left_context=0, + training_stage=True, max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS, ): super().__init__( - conv_layers_before=conv_layers_before, input_size=input_size, hidden_size=hidden_size, - num_layers=num_layers, dropout_in=dropout_in, dropout_out=dropout_out, - bidirectional=bidirectional, residual=residual, left_pad=left_pad, - padding_value=padding_value, src_bucketed=src_bucketed, max_source_positions=max_source_positions, + conv_layers_before=conv_layers_before, + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout_in=dropout_in, + dropout_out=dropout_out, + bidirectional=bidirectional, + residual=residual, + left_pad=left_pad, + padding_value=padding_value, + src_bucketed=src_bucketed, + max_source_positions=max_source_positions, + ) + receptive_field_radius = ( + sum(conv.padding[0] for conv in conv_layers_before.convolutions) + if conv_layers_before is not None + else 0 ) - receptive_field_radius = sum(conv.padding[0] for conv in conv_layers_before.convolutions) \ - if conv_layers_before is not None else 0 assert chunk_width is None or chunk_width > 0 - assert (conv_layers_before is None and chunk_left_context >= 0) or \ - (conv_layers_before is not None and chunk_left_context >= receptive_field_radius) + assert ( + (conv_layers_before is None and chunk_left_context >= 0) + or (conv_layers_before is not None and chunk_left_context >= receptive_field_radius) + ) self.out_chunk_begin = self.output_lengths(chunk_left_context + 1) - 1 - self.out_chunk_end = self.output_lengths(chunk_left_context + chunk_width) \ - if chunk_width is not None else None + self.out_chunk_end = ( + self.output_lengths(chunk_left_context + chunk_width) if chunk_width is not None + else None + ) self.training_stage = training_stage # only for encoder-only model - self.fc_out = Linear(self.output_units, num_targets, dropout=self.dropout_out_module.p) \ - if num_targets is not None else None + self.fc_out = ( + Linear(self.output_units, num_targets, dropout=self.dropout_out_module.p) + if num_targets is not None + else None + ) def forward( self, @@ -214,7 +256,8 @@ def forward( return EncoderOut( encoder_out=x, # T x B x C - encoder_padding_mask=encoder_padding_mask if encoder_padding_mask.any() else None, # T x B + encoder_padding_mask=encoder_padding_mask if encoder_padding_mask.any() + else None, # T x B encoder_embedding=None, encoder_states=None, src_tokens=None, diff --git a/espresso/models/speech_tdnn.py b/espresso/models/speech_tdnn.py index 8d22a5e2ea..214a4d0a73 100644 --- a/espresso/models/speech_tdnn.py +++ b/espresso/models/speech_tdnn.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from argparse import Namespace import logging from typing import Optional @@ -21,6 +22,7 @@ from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.lstm import Linear from fairseq.modules import FairseqDropout +from omegaconf import DictConfig import espresso.tools.utils as speech_utils @@ -125,13 +127,21 @@ def state_dict(self): state_dict["state_prior"] = self.state_prior return state_dict - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg: Optional[DictConfig] = None, + args: Optional[Namespace] = None, + ): state_dict_subset = state_dict.copy() self.state_prior = state_dict.get("state_prior", None) if "state_prior" in state_dict: self.state_prior = state_dict["state_prior"] del state_dict_subset["state_prior"] - super().load_state_dict(state_dict_subset, strict=strict, args=args) + super().load_state_dict( + state_dict_subset, strict=strict, model_cfg=model_cfg, args=args + ) class TdnnBNReLU(nn.Module): @@ -192,8 +202,12 @@ def __init__( dilations = [dilations] * num_layers else: assert len(dilations) == num_layers - self.dropout_in_module = FairseqDropout(dropout_in, module_name=self.__class__.__name__) - self.dropout_out_module = FairseqDropout(dropout_out, module_name=self.__class__.__name__) + self.dropout_in_module = FairseqDropout( + dropout_in, module_name=self.__class__.__name__ + ) + self.dropout_out_module = FairseqDropout( + dropout_out, module_name=self.__class__.__name__ + ) self.residual = residual self.tdnn = nn.ModuleList([ @@ -206,15 +220,22 @@ def __init__( ]) receptive_field_radius = sum(layer.padding for layer in self.tdnn) - assert chunk_width is None or (chunk_width > 0 and chunk_left_context >= receptive_field_radius) + assert ( + chunk_width is None + or (chunk_width > 0 and chunk_left_context >= receptive_field_radius) + ) if ( chunk_width is not None and chunk_width > 0 and chunk_left_context > receptive_field_radius ): - logger.warning("chunk_{{left,right}}_context can be reduced to {}".format(receptive_field_radius)) + logger.warning( + "chunk_{{left,right}}_context can be reduced to {}".format(receptive_field_radius) + ) self.out_chunk_begin = self.output_lengths(chunk_left_context + 1) - 1 - self.out_chunk_end = self.output_lengths(chunk_left_context + chunk_width) \ - if chunk_width is not None else None + self.out_chunk_end = ( + self.output_lengths(chunk_left_context + chunk_width) if chunk_width is not None + else None + ) self.training_stage = training_stage self.fc_out = Linear(hidden_sizes[-1], output_size, dropout=self.dropout_out_module.p) diff --git a/espresso/models/speech_transformer.py b/espresso/models/speech_transformer.py index 5386b72999..a381b1a77b 100644 --- a/espresso/models/speech_transformer.py +++ b/espresso/models/speech_transformer.py @@ -43,7 +43,7 @@ logger = logging.getLogger(__name__) -@register_model('speech_transformer') +@register_model("speech_transformer") class SpeechTransformerModel(TransformerModel): """ Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017) @@ -246,7 +246,9 @@ def __init__(self, args, conv_layers_before=None, input_size=83, transformer_con super(TransformerEncoder, self).__init__(None) # no src dictionary self.register_buffer("version", torch.Tensor([3])) - self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) self.encoder_layerdrop = 
args.encoder_layerdrop embed_dim = args.encoder_embed_dim @@ -257,7 +259,7 @@ def __init__(self, args, conv_layers_before=None, input_size=83, transformer_con self.embed_positions = ( PositionalEmbedding( - self.output_lengths(args.max_source_positions), + self.output_lengths(self.max_source_positions), embed_dim, 0, learned=args.encoder_learned_pos, @@ -297,8 +299,10 @@ def __init__(self, args, conv_layers_before=None, input_size=83, transformer_con self.transformer_context = transformer_context def output_lengths(self, in_lengths): - return in_lengths if self.conv_layers_before is None \ + return ( + in_lengths if self.conv_layers_before is None else self.conv_layers_before.output_lengths(in_lengths) + ) def get_attn_mask(self, in_lengths): """ @@ -360,8 +364,10 @@ def forward( if self.conv_layers_before is not None: x, src_lengths, encoder_padding_mask = self.conv_layers_before(src_tokens, src_lengths) else: - x, encoder_padding_mask = src_tokens, \ + x, encoder_padding_mask = ( + src_tokens, ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1)) + ) x = self.dropout_module(x) if self.fc0 is not None: @@ -579,6 +585,15 @@ def base_architecture(args): args.no_scale_embedding = getattr(args, "no_scale_embedding", False) args.layernorm_embedding = getattr(args, "layernorm_embedding", False) args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.checkpoint_activations = getattr(args, "checkpoint_activations", False) + + args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) + args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) @register_model_architecture("speech_transformer", "speech_transformer_wsj") diff --git a/espresso/models/speech_transformer_encoder_model.py b/espresso/models/speech_transformer_encoder_model.py index 989a359865..9fccf68134 100644 --- a/espresso/models/speech_transformer_encoder_model.py +++ b/espresso/models/speech_transformer_encoder_model.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from argparse import Namespace import logging from typing import Optional @@ -18,6 +19,7 @@ ) from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.transformer import Linear +from omegaconf import DictConfig from espresso.models.speech_lstm import ConvBNReLU from espresso.models.speech_transformer import SpeechTransformerEncoder @@ -70,6 +72,11 @@ def add_args(parser): "can be None or a tuple of two non-negative integers/None") parser.add_argument("--no-token-positional-embeddings", action="store_true", help="if set, disables positional embeddings (outside self attention)") + parser.add_argument("--layernorm-embedding", action="store_true", + help="add layernorm to embedding") + parser.add_argument("--checkpoint-activations", action="store_true", + help="checkpoint activations at each layer, which saves GPU " + "memory usage at the cost of some additional compute") # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) parser.add_argument("--encoder-layerdrop", type=float, metavar="D", default=0, help="LayerDrop probability for encoder") @@ -191,13 +198,21 @@ def state_dict(self): state_dict["state_prior"] = self.state_prior return state_dict - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg: Optional[DictConfig] = None, + args: Optional[Namespace] = None, + ): state_dict_subset = state_dict.copy() self.state_prior = state_dict.get("state_prior", None) if "state_prior" in state_dict: self.state_prior = state_dict["state_prior"] del state_dict_subset["state_prior"] - super().load_state_dict(state_dict_subset, strict=strict, args=args) + super().load_state_dict( + state_dict_subset, strict=strict, model_cfg=model_cfg, args=args + ) class SpeechChunkTransformerEncoder(SpeechTransformerEncoder): @@ -210,19 +225,30 @@ def __init__( args, conv_layers_before=conv_layers_before, input_size=input_size, transformer_context=transformer_context, ) - receptive_field_radius = sum(conv.padding[0] for conv in conv_layers_before.convolutions) \ - if conv_layers_before is not None else 0 + receptive_field_radius = ( + sum(conv.padding[0] for conv in conv_layers_before.convolutions) + if conv_layers_before is not None + else 0 + ) assert chunk_width is None or chunk_width > 0 - assert (conv_layers_before is None and chunk_left_context >= 0) or \ - (conv_layers_before is not None and chunk_left_context >= receptive_field_radius) + assert ( + (conv_layers_before is None and chunk_left_context >= 0) + or (conv_layers_before is not None and chunk_left_context >= receptive_field_radius) + ) self.out_chunk_begin = self.output_lengths(chunk_left_context + 1) - 1 - self.out_chunk_end = self.output_lengths(chunk_left_context + chunk_width) \ - if chunk_width is not None else None + self.out_chunk_end = ( + self.output_lengths(chunk_left_context + chunk_width) + if chunk_width is not None + else None + ) self.training_stage = training_stage # only for encoder-only model - self.fc_out = Linear(args.encoder_embed_dim, num_targets, dropout=self.dropout_module.p) \ - if num_targets is not None else None + self.fc_out = ( + Linear(args.encoder_embed_dim, num_targets, dropout=self.dropout_module.p) + if num_targets is not None + else None + ) def forward( self, @@ -360,6 +386,13 @@ def base_architecture(args): ) args.adaptive_input = getattr(args, "adaptive_input", False) args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.checkpoint_activations = 
getattr(args, "checkpoint_activations", False) + + args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) @register_model_architecture("speech_transformer_encoder_model", "speech_transformer_encoder_model_wsj") diff --git a/espresso/models/tensorized_lookahead_language_model.py b/espresso/models/tensorized_lookahead_language_model.py index 00cabc3f34..6be342a380 100644 --- a/espresso/models/tensorized_lookahead_language_model.py +++ b/espresso/models/tensorized_lookahead_language_model.py @@ -57,9 +57,10 @@ def __init__(self, super().__init__(word_lm.decoder.dictionary) self.lm_decoder: FairseqIncrementalDecoder = word_lm.decoder - assert hasattr(self.lm_decoder, 'masked_copy_incremental_state') and \ - callable(self.lm_decoder.masked_copy_incremental_state), \ - 'The wrapped decoder should implement masked_copy_incremental_state()' + assert ( + hasattr(self.lm_decoder, "masked_copy_incremental_state") + and callable(self.lm_decoder.masked_copy_incremental_state) + ), "The wrapped decoder should implement masked_copy_incremental_state()" self.oov_penalty = oov_penalty self.open_vocab = open_vocab @@ -76,7 +77,7 @@ def __init__(self, self.subword_vocab_size = len(subword_dict) def tokenizer(x: str) -> List[str]: - return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(' ') + return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(" ") self.tree = TensorizedPrefixTree.build(word_dict, subword_dict, tokenizer) assert self.tree.max_out_degree() <= self.subword_vocab_size @@ -86,7 +87,7 @@ def forward(self, prev_output_tokens: torch.Tensor, # Z_Tokens[Batch, SeqLength] encoder_out=None, incremental_state: Dict[str, Any] = None): - assert incremental_state is not None, 'This model is for incremental decoding only' + assert incremental_state is not None, "This model is for incremental decoding only" prev_output_tokens = prev_output_tokens[:, -1:] # Z_Tokens[Batch, Len=1] bsz = prev_output_tokens.size(0) @@ -95,11 +96,11 @@ def forward(self, # Move the batched state to the next state according to the automaton batch_space_mask = prev_output_tokens.squeeze(-1).eq(self.subword_space_idx) # B[Batch] - cached_state = self.lm_decoder.get_incremental_state(incremental_state, 'cached_state') + cached_state = self.lm_decoder.get_incremental_state(incremental_state, "cached_state") if cached_state is None: # First step assert (prev_output_tokens == self.subword_eos_idx).all(), \ - 'expecting the input to the first time step to be ' + "expecting the input to the first time step to be " w: torch.Tensor = prev_output_tokens.new_full([bsz, 1], self.word_eos_idx) # Z[Batch, Len=1] lm_probs: torch.Tensor = self.lm_decoder.get_normalized_probs( self.lm_decoder(w, incremental_state=incremental_state), @@ -110,9 +111,9 @@ def forward(self, else: # Not the first step cumsum_probs: torch.Tensor = self.get_incremental_state( - incremental_state, 'cumsum_probs', + incremental_state, "cumsum_probs", ) # R[Batch, 1, Vocab] - nodes: torch.Tensor = self.get_incremental_state(incremental_state, 'nodes') # Z_NodeId[Batch] + nodes: torch.Tensor = self.get_incremental_state(incremental_state, "nodes") # Z_NodeId[Batch] assert nodes.size(0) == bsz w: torch.Tensor = self.tree.word_idx[nodes].unsqueeze(1) # Z[Batch, 
Len=1] w[w < 0] = self.word_unk_idx @@ -139,8 +140,8 @@ def forward(self, all_children = self.tree.children[nodes, :] # Z[Batch, PossibleChildren] - self.set_incremental_state(incremental_state, 'cumsum_probs', cumsum_probs) - self.set_incremental_state(incremental_state, 'nodes', nodes) + self.set_incremental_state(incremental_state, "cumsum_probs", cumsum_probs) + self.set_incremental_state(incremental_state, "nodes", nodes) # Compute probabilities # initialize out_probs [Batch, 1, Vocab] @@ -161,7 +162,7 @@ def forward(self, # set transition probability to 1 for those whose node is out of the # tree, i.e. node is None (case 4 in Eqn. 15) - out_probs[nodes.eq(self.tree.none_id)] = 1. + out_probs[nodes.eq(self.tree.none_id)] = 1.0 else: # set out_probs to 0 out_probs = cumsum_probs.new_full([bsz, 1, self.subword_vocab_size], self.zero) @@ -226,15 +227,15 @@ def forward(self, def reorder_incremental_state(self, incremental_state, new_order): super().reorder_incremental_state(incremental_state, new_order) - cumsum_probs = self.get_incremental_state(incremental_state, 'cumsum_probs') + cumsum_probs = self.get_incremental_state(incremental_state, "cumsum_probs") if cumsum_probs is not None: new_cumsum_probs = cumsum_probs.index_select(0, new_order) - self.set_incremental_state(incremental_state, 'cumsum_probs', new_cumsum_probs) + self.set_incremental_state(incremental_state, "cumsum_probs", new_cumsum_probs) - nodes = self.get_incremental_state(incremental_state, 'nodes') + nodes = self.get_incremental_state(incremental_state, "nodes") if nodes is not None: new_nodes = nodes.index_select(0, new_order) - self.set_incremental_state(incremental_state, 'nodes', new_nodes) + self.set_incremental_state(incremental_state, "nodes", new_nodes) def get_normalized_probs(self, net_output, log_probs, sample=None): """Get normalized probabilities (or log probs) from a net's output.""" diff --git a/espresso/modules/__init__.py b/espresso/modules/__init__.py index 1e6b35acd2..08e3bb3ca8 100644 --- a/espresso/modules/__init__.py +++ b/espresso/modules/__init__.py @@ -7,6 +7,6 @@ __all__ = [ - 'BahdanauAttention', - 'LuongAttention', + "BahdanauAttention", + "LuongAttention", ] diff --git a/espresso/modules/speech_attention.py b/espresso/modules/speech_attention.py index fcba56e913..73b4132e69 100644 --- a/espresso/modules/speech_attention.py +++ b/espresso/modules/speech_attention.py @@ -54,8 +54,8 @@ def reset_parameters(self): self.value_proj.weight.data.uniform_(-0.1, 0.1) nn.init.uniform_(self.v, -0.1, 0.1) if self.normalize: - nn.init.constant_(self.b, 0.) - nn.init.constant_(self.g, math.sqrt(1. 
/ self.embed_dim)) + nn.init.constant_(self.b, 0.0) + nn.init.constant_(self.g, math.sqrt(1.0 / self.embed_dim)) def forward(self, query, value, key_padding_mask=None, state=None): # projected_query: 1 x bsz x embed_dim @@ -71,9 +71,11 @@ def forward(self, query, value, key_padding_mask=None, state=None): attn_scores = self.v * torch.tanh(projected_query + key).sum(dim=2) if key_padding_mask is not None: - attn_scores = attn_scores.float().masked_fill_( - key_padding_mask, float('-inf'), - ).type_as(attn_scores) # FP16 support: cast to float and back + attn_scores = ( + attn_scores.float() + .masked_fill_(key_padding_mask, float("-inf")) + .type_as(attn_scores) + ) # FP16 support: cast to float and back attn_scores = F.softmax(attn_scores, dim=0) # srclen x bsz @@ -99,7 +101,7 @@ def __init__(self, query_dim, value_dim, embed_dim=None, scale=True): def reset_parameters(self): self.value_proj.weight.data.uniform_(-0.1, 0.1) if self.scale: - nn.init.constant_(self.g, 1.) + nn.init.constant_(self.g, 1.0) def forward(self, query, value, key_padding_mask=None, state=None): query = query.unsqueeze(1) # bsz x 1 x query_dim @@ -110,9 +112,11 @@ def forward(self, query, value, key_padding_mask=None, state=None): attn_scores = self.g * attn_scores if key_padding_mask is not None: - attn_scores = attn_scores.float().masked_fill_( - key_padding_mask, float('-inf'), - ).type_as(attn_scores) # FP16 support: cast to float and back + attn_scores = ( + attn_scores.float() + .masked_fill_(key_padding_mask, float("-inf")) + .type_as(attn_scores) + ) # FP16 support: cast to float and back attn_scores = F.softmax(attn_scores, dim=0) # srclen x bsz diff --git a/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py b/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py index e8e9198600..c3f1141dbc 100644 --- a/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py +++ b/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass, field -from omegaconf import II from typing import List import torch.optim.lr_scheduler @@ -13,6 +12,7 @@ from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.optim.lr_scheduler import register_lr_scheduler from fairseq.optim.lr_scheduler.reduce_lr_on_plateau import ReduceLROnPlateau +from omegaconf import II, DictConfig @dataclass @@ -40,7 +40,7 @@ class ReduceLROnPlateauV2Config(FairseqDataclass): warmup_init_lr: float = field( default=-1, metadata={ - "help": "initial learning rate during warmup phase; default is args.lr" + "help": "initial learning rate during warmup phase; default is cfg.lr" }, ) final_lr_scale: float = field( @@ -52,25 +52,30 @@ class ReduceLROnPlateauV2Config(FairseqDataclass): metadata={"help": "start to reduce lr from the specified epoch"}, ) # TODO common vars at parent class - lr: List[float] = II("params.optimization.lr") + lr: List[float] = II("optimization.lr") + maximize_best_checkpoint_metric: bool = II("checkpoint.maximize_best_checkpoint_metric") @register_lr_scheduler("reduce_lr_on_plateau_v2", dataclass=ReduceLROnPlateauV2Config) class ReduceLROnPlateauV2(ReduceLROnPlateau): """Decay the LR by a factor every time the validation loss plateaus, starting - from the epoch specified as args.start_reduce_lr_epoch. + from the epoch specified as cfg.start_reduce_lr_epoch. We also support specifying a final lr which will be kept until the max number of epochs is reached. 
""" - def __init__(self, args, optimizer): - super().__init__(args, optimizer) + def __init__(self, cfg: DictConfig, fairseq_optimizer): + super().__init__(cfg, fairseq_optimizer) + self.cfg = cfg self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - self.optimizer.optimizer, patience=args.lr_patience, factor=args.lr_shrink, - mode="max" if args.maximize_best_checkpoint_metric else "min", - threshold=args.lr_threshold, min_lr=args.final_lr_scale * args.lr[0] + self.optimizer.optimizer, + patience=cfg.lr_patience, + factor=cfg.lr_shrink, + mode="max" if cfg.maximize_best_checkpoint_metric else "min", + threshold=cfg.lr_threshold, + min_lr=cfg.final_lr_scale * cfg.lr[0], ) @classmethod @@ -80,8 +85,8 @@ def add_args(cls, parser): gen_parser_from_dataclass(parser, dc()) def step(self, epoch, val_loss=None): - if epoch < self.args.start_reduce_lr_epoch: + if epoch < self.cfg.start_reduce_lr_epoch: self.lr_scheduler.last_epoch = epoch - self.optimizer.set_lr(self.args.lr[0]) + self.optimizer.set_lr(self.cfg.lr[0]) return self.optimizer.get_lr() return super().step(epoch, val_loss) diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index a08f101207..5c982f4645 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -14,15 +14,16 @@ import math import os import sys +from argparse import Namespace import numpy as np - import torch - from fairseq import checkpoint_utils, options, tasks, utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.models import FairseqLanguageModel +from omegaconf import DictConfig from espresso.models.external_language_model import MultiLevelLanguageModel from espresso.models.tensorized_lookahead_language_model import TensorizedLookaheadLanguageModel @@ -30,17 +31,22 @@ from espresso.tools.utils import plot_attention, sequence_mask -def main(args): - assert args.path is not None, "--path required for recognition!" - assert not args.sampling or args.nbest == args.beam, \ - "--sampling requires --nbest to be equal to --beam" +def main(cfg: DictConfig): + + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + assert cfg.common_eval.path is not None, "--path required for recognition!" 
+ assert ( + not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam + ), "--sampling requires --nbest to be equal to --beam" - if args.results_path is not None: - os.makedirs(args.results_path, exist_ok=True) - output_path = os.path.join(args.results_path, "decode.log") + if cfg.common_eval.results_path is not None: + os.makedirs(cfg.common_eval.results_path, exist_ok=True) + output_path = os.path.join(cfg.common_eval.results_path, "decode.log") with open(output_path, "w", buffering=1, encoding="utf-8") as h: - return _main(args, h) - return _main(args, sys.stdout) + return _main(cfg, h) + return _main(cfg, sys.stdout) def get_symbols_to_strip_from_output(generator): @@ -50,7 +56,7 @@ def get_symbols_to_strip_from_output(generator): return {generator.eos, generator.pad} -def _main(args, output_file): +def _main(cfg, output_file): logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", @@ -61,53 +67,53 @@ def _main(args, output_file): if output_file is not sys.stdout: # also print to stdout logger.addHandler(logging.StreamHandler(sys.stdout)) - print_options_meaning_changes(args, logger) + print_options_meaning_changes(cfg, logger) - utils.import_user_module(args) + utils.import_user_module(cfg.common) - if args.max_tokens is None and args.batch_size is None: - args.max_tokens = 12000 - logger.info(args) + if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: + cfg.dataset.max_tokens = 12000 + logger.info(cfg) # Fix seed for stochastic decoding - if args.seed is not None and not args.no_seed_provided: - np.random.seed(args.seed) - utils.set_torch_seed(args.seed) + if cfg.common.seed is not None and not cfg.generation.no_seed_provided: + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) - use_cuda = torch.cuda.is_available() and not args.cpu + use_cuda = torch.cuda.is_available() and not cfg.common.cpu # Load dataset split - task = tasks.setup_task(args) - task.load_dataset(args.gen_subset) + task = tasks.setup_task(cfg.task) + task.load_dataset(cfg.dataset.gen_subset) # Set dictionary dictionary = task.target_dictionary - overrides = ast.literal_eval(args.model_overrides) + overrides = ast.literal_eval(cfg.common_eval.model_overrides) # Load ensemble - logger.info("loading model(s) from {}".format(args.path)) + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, _model_args = checkpoint_utils.load_model_ensemble( - utils.split_paths(args.path), + utils.split_paths(cfg.common_eval.path), arg_overrides=overrides, task=task, - suffix=getattr(args, "checkpoint_suffix", ""), - strict=(args.checkpoint_shard_count == 1), - num_shards=args.checkpoint_shard_count, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, ) - if args.lm_path is not None: - overrides["data"] = args.data + if cfg.generation.lm_path is not None: + overrides["data"] = cfg.task.data try: lms, _ = checkpoint_utils.load_model_ensemble( - utils.split_paths(args.lm_path), - arg_overrides=overrides, - task=None, + utils.split_paths(cfg.generation.lm_path), arg_overrides=overrides, task=None, ) except: - logger.warning(f"Failed to load language model! Please make sure that the language model dict is the same " - f"as target dict and is located in the data dir ({args.data})") + logger.warning( + f"Failed to load language model! 
Please make sure that the language model dict is the same " + f"as target dict and is located in the data dir ({cfg.task.data})" + ) raise assert len(lms) == 1 or len(lms) == 2 # Multi-level LM expects two LMs @@ -122,61 +128,60 @@ def _main(args, output_file): if i > 0 and isinstance(lms[i - 1], FairseqLanguageModel): lms[i - 1] = MultiLevelLanguageModel( m, lms[i - 1], - subwordlm_weight=args.subwordlm_weight, - oov_penalty=args.oov_penalty, - open_vocab=not args.disable_open_vocab, + subwordlm_weight=cfg.generation.subwordlm_weight, + oov_penalty=cfg.generation.oov_penalty, + open_vocab=not cfg.generation.disable_open_vocab, ) del lms[i] logger.info("LM fusion with Multi-level LM") else: lms[i] = TensorizedLookaheadLanguageModel( m, dictionary, - oov_penalty=args.oov_penalty, - open_vocab=not args.disable_open_vocab, + oov_penalty=cfg.generation.oov_penalty, + open_vocab=not cfg.generation.disable_open_vocab, ) logger.info("LM fusion with Look-ahead Word LM") else: assert isinstance(m, FairseqLanguageModel) logger.info("LM fusion with Subword LM") - if args.lm_weight != 0.0: - logger.info("using LM fusion with lm-weight={:.2f}".format(args.lm_weight)) + if cfg.generation.lm_weight != 0.0: + logger.info("using LM fusion with lm-weight={:.2f}".format(cfg.generation.lm_weight)) # Optimize ensemble for generation for model in chain(models, lms): if model is None: continue - if args.fp16: + if cfg.common.fp16: model.half() - if use_cuda and not args.pipeline_model_parallel: + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() - model.prepare_for_inference_(args) + model.prepare_for_inference_(cfg) # Load dataset (possibly sharded) itr = task.get_batch_iterator( - dataset=task.dataset(args.gen_subset), - max_tokens=args.max_tokens, - max_sentences=args.batch_size, + dataset=task.dataset(cfg.dataset.gen_subset), + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( - task.max_positions(), - *[model.max_positions() if hasattr(model, "encoder") - else (None, model.max_positions()) for model in models] + task.max_positions(), *[m.max_positions() for m in models] ), - ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, - required_batch_size_multiple=args.required_batch_size_multiple, - num_shards=args.num_shards, - shard_id=args.shard_id, - num_workers=args.num_workers, - data_buffer_size=args.data_buffer_size, + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_shards=cfg.distributed_training.distributed_world_size, + shard_id=cfg.distributed_training.distributed_rank, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, - default_log_format=("tqdm" if not args.no_progress_bar else "none"), + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), ) # Initialize generator - if args.match_source_len: + if cfg.generation.match_source_len: logger.warning( "The option match_source_len is not applicable to speech recognition. Ignoring it." 
) @@ -184,18 +189,20 @@ def _main(args, output_file): extra_gen_cls_kwargs = { "lm_model": lms[0], - "lm_weight": args.lm_weight, - "eos_factor": args.eos_factor, + "lm_weight": cfg.generation.lm_weight, + "eos_factor": cfg.generation.eos_factor, } - args.score_reference = False # not applicable for ASR - temp_val = args.print_alignment - args.print_alignment = False # not applicable for ASR - generator = task.build_generator(models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs) - args.print_alignment = temp_val + cfg.generation.score_reference = False # not applicable for ASR + temp_val = cfg.generation.print_alignment + cfg.generation.print_alignment = False # not applicable for ASR + generator = task.build_generator( + models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs + ) + cfg.generation.print_alignment = temp_val # Handle tokenization and BPE - tokenizer = task.build_tokenizer(args) - bpe = task.build_bpe(args) + tokenizer = task.build_tokenizer(cfg.tokenizer) + bpe = task.build_bpe(cfg.bpe) def decode_fn(x): if bpe is not None: @@ -204,8 +211,8 @@ def decode_fn(x): x = tokenizer.decode(x) return x - # Generate and compute WER - scorer = wer.Scorer(dictionary, wer_output_filter=args.wer_output_filter) + scorer = wer.Scorer(dictionary, wer_output_filter=cfg.task.wer_output_filter) + num_sentences = 0 has_target = True wps_meter = TimeMeter() @@ -215,20 +222,26 @@ def decode_fn(x): continue prefix_tokens = None - if args.prefix_size > 0: - prefix_tokens = sample["target"][:, :args.prefix_size] + if cfg.generation.prefix_size > 0: + prefix_tokens = sample["target"][:, : cfg.generation.prefix_size] constraints = None if "constraints" in sample: constraints = sample["constraints"] gen_timer.start() - hypos = task.inference_step(generator, models, sample, prefix_tokens=prefix_tokens, constraints=constraints) + hypos = task.inference_step( + generator, + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, + ) num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) gen_timer.stop(num_generated_tokens) # obtain nonpad mask of encoder output to plot attentions - if args.print_alignment: + if cfg.generation.print_alignment: net_input = sample["net_input"] src_tokens = net_input["src_tokens"] output_lengths = models[0].encoder.output_lengths(net_input["src_lengths"]) @@ -241,19 +254,19 @@ def decode_fn(x): # Retrieve the original sentences if has_target: target_str = sample["target_raw_text"][i] - if not args.quiet: + if not cfg.common_eval.quiet: detok_target_str = decode_fn(target_str) print("T-{}\t{}".format(utt_id, detok_target_str), file=output_file) # Process top predictions - for j, hypo in enumerate(hypos[i][:args.nbest]): + for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]): hypo_str = dictionary.string( hypo["tokens"].int().cpu(), bpe_symbol=None, extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) # not removing bpe at this point detok_hypo_str = decode_fn(hypo_str) - if not args.quiet: + if not cfg.common_eval.quiet: score = hypo["score"] / math.log(2) # convert to base 2 print("H-{}\t{}\t{}".format(utt_id, detok_hypo_str, score), file=output_file) @@ -261,9 +274,9 @@ def decode_fn(x): if j == 0: # src_len x tgt_len attention = hypo["attention"][nonpad_idxs[i]].float().cpu() \ - if args.print_alignment and hypo["attention"] is not None else None - if args.print_alignment and attention is not None: - save_dir = os.path.join(args.results_path, "attn_plots") + if cfg.generation.print_alignment and 
hypo["attention"] is not None else None + if cfg.generation.print_alignment and attention is not None: + save_dir = os.path.join(cfg.common_eval.results_path, "attn_plots") os.makedirs(save_dir, exist_ok=True) plot_attention(attention, detok_hypo_str, utt_id, save_dir) scorer.add_prediction(utt_id, hypo_str) @@ -277,26 +290,26 @@ def decode_fn(x): logger.info("NOTE: hypothesis and token scores are output in base 2") logger.info("Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format( num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg)) - if args.print_alignment: + if cfg.generation.print_alignment: logger.info("Saved attention plots in " + save_dir) if has_target: - scorer.add_ordered_utt_list(task.datasets[args.gen_subset].tgt.utt_ids) + scorer.add_ordered_utt_list(task.datasets[cfg.dataset.gen_subset].tgt.utt_ids) fn = "decoded_char_results.txt" - with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f: + with open(os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8") as f: f.write(scorer.print_char_results()) logger.info("Decoded char results saved as " + f.name) fn = "decoded_results.txt" - with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f: + with open(os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8") as f: f.write(scorer.print_results()) logger.info("Decoded results saved as " + f.name) if has_target: - header = "Recognize {} with beam={}: ".format(args.gen_subset, args.beam) + header = "Recognize {} with beam={}: ".format(cfg.dataset.gen_subset, cfg.generation.beam) fn = "wer" - with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f: + with open(os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8") as f: res = "WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%".format( *(scorer.wer())) logger.info(header + res) @@ -304,7 +317,7 @@ def decode_fn(x): logger.info("WER saved in " + f.name) fn = "cer" - with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f: + with open(os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8") as f: res = "CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%".format( *(scorer.cer())) logger.info(" " * len(header) + res) @@ -312,34 +325,23 @@ def decode_fn(x): logger.info("CER saved in " + f.name) fn = "aligned_results.txt" - with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f: + with open(os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8") as f: f.write(scorer.print_aligned_results()) logger.info("Aligned results saved as " + f.name) return scorer -def print_options_meaning_changes(args, logger): +def print_options_meaning_changes(cfg, logger): """Options that have different meanings than those in the translation task are explained here. """ logger.info("--max-tokens is the maximum number of input frames in a batch") - if args.print_alignment: + if cfg.generation.print_alignment: logger.info("--print-alignment has been set to plot attentions") def cli_main(): parser = options.get_generation_parser(default_task="speech_recognition_espresso") - parser.add_argument("--eos-factor", default=None, type=float, metavar="F", - help="only consider emitting EOS if its score is no less " - "than the specified factor of the best candidate score") - parser.add_argument("--subwordlm-weight", default=0.8, type=float, metavar="W", - help="subword LM weight relative to word LM. 
Only relevant " - "to MultiLevelLanguageModel as an external LM") - parser.add_argument("--oov-penalty", default=1e-4, type=float, - help="oov penalty with the pretrained external LM") - parser.add_argument("--disable-open-vocab", action="store_true", - help="whether open vocabulary mode is enabled with the " - "pretrained external LM") args = options.parse_args_and_arch(parser) assert args.results_path is not None, "please specify --results-path" main(args) diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 1c2b2b4f37..9391ba72e5 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -8,13 +8,16 @@ Train a new model on one or across multiple GPUs. """ +import argparse import logging import math import os import sys +from typing import Dict, Optional, Any, List, Tuple, Callable import numpy as np import torch + from fairseq import ( checkpoint_utils, distributed_utils, @@ -24,8 +27,10 @@ utils, ) from fairseq.data import iterators +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer +from omegaconf import DictConfig from fairseq.trainer import Trainer @@ -38,90 +43,89 @@ logger = logging.getLogger("espresso.speech_train") -def main(args): - utils.import_user_module(args) +def main(cfg: DictConfig) -> None: + if isinstance(cfg, argparse.Namespace): + cfg = convert_namespace_to_omegaconf(cfg) - assert ( - args.max_tokens is not None or args.batch_size is not None - ), "Must specify batch size either with --max-tokens or --batch-size" + utils.import_user_module(cfg.common) + assert cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None, \ + "Must specify batch size either with --max-tokens or --batch-size" metrics.reset() - np.random.seed(args.seed) - utils.set_torch_seed(args.seed) + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) - if distributed_utils.is_master(args): - checkpoint_utils.verify_checkpoint_directory(args.save_dir) + if distributed_utils.is_master(cfg.distributed_training): + checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir) # Print args - logger.info(args) + logger.info(cfg) # Setup task, e.g., translation, language modeling, etc. - task = tasks.setup_task(args) - + task = tasks.setup_task(cfg.task) + # Handle tokenization and BPE + task.build_tokenizer(cfg.tokenizer) + task.build_bpe(cfg.bpe) # Load valid dataset (we load training data below, based on the latest checkpoint) - for valid_sub_split in args.valid_subset.split(","): + for valid_sub_split in cfg.dataset.valid_subset.split(","): task.load_dataset(valid_sub_split, combine=False, epoch=1) # Build model and criterion - model = task.build_model(args) - criterion = task.build_criterion(args) + model = task.build_model(cfg.model) + criterion = task.build_criterion(cfg.criterion) logger.info(model) - logger.info("task: {} ({})".format(args.task, task.__class__.__name__)) - logger.info("model: {} ({})".format(args.arch, model.__class__.__name__)) - logger.info( - "criterion: {} ({})".format(args.criterion, criterion.__class__.__name__) - ) + logger.info("task: {} ({})".format(cfg.task._name, task.__class__.__name__)) + logger.info("model: {} ({})".format(cfg.model._name, model.__class__.__name__)) logger.info( - "num. model params: {} (num. 
trained: {})".format( - sum(p.numel() for p in model.parameters()), - sum(p.numel() for p in model.parameters() if p.requires_grad), - ) + "criterion: {} ({})".format(cfg.criterion._name, criterion.__class__.__name__) ) + logger.info("num. model params: {} (num. trained: {})".format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + )) # (optionally) Configure quantization - if args.quantization_config_path is not None: + if cfg.common.quantization_config_path is not None: quantizer = quantization_utils.Quantizer( - config_path=args.quantization_config_path, - max_epoch=args.max_epoch, - max_update=args.max_update, + config_path=cfg.common.quantization_config_path, + max_epoch=cfg.optimization.max_epoch, + max_update=cfg.optimization.max_update, ) else: quantizer = None # Build trainer - if args.model_parallel_size == 1: - trainer = Trainer(args, task, model, criterion, quantizer) + if cfg.common.model_parallel_size == 1: + trainer = Trainer(cfg, task, model, criterion, quantizer) else: - trainer = MegatronTrainer(args, task, model, criterion) + trainer = MegatronTrainer(cfg, task, model, criterion) - logger.info( - "training on {} devices (GPUs/TPUs)".format(args.distributed_world_size) - ) - logger.info( - "max input frames per GPU = {} and max sentences per GPU = {}".format( - args.max_tokens, args.batch_size - ) - ) + logger.info("training on {} devices (GPUs/TPUs)".format(cfg.distributed_training.distributed_world_size)) + logger.info("max tokens per GPU = {} and batch size per GPU = {}".format( + cfg.dataset.max_tokens, + cfg.dataset.batch_size, + )) # Load the latest checkpoint if one is available and restore the # corresponding train iterator extra_state, epoch_itr = checkpoint_utils.load_checkpoint( - args, + cfg.checkpoint, trainer, # don't cache epoch iterators for sharded datasets disable_iterator_cache=task.has_sharded_data("train"), ) - # Train until the learning rate gets too small - max_epoch = args.max_epoch or math.inf + max_epoch = cfg.optimization.max_epoch or math.inf lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() - - while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch: + while ( + lr > cfg.optimization.min_lr + and epoch_itr.next_epoch_idx <= max_epoch + ): # train for one epoch - valid_losses, should_stop = train(args, trainer, task, epoch_itr) + valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) if should_stop: break @@ -139,15 +143,15 @@ def main(args): logger.info("done training in {:.1f} seconds".format(train_meter.sum)) -def should_stop_early(args, valid_loss): +def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool: # skip check if no validation was done in the current epoch if valid_loss is None: return False - if args.patience <= 0: + if cfg.checkpoint.patience <= 0: return False def is_better(a, b): - return a > b if args.maximize_best_checkpoint_metric else a < b + return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b prev_best = getattr(should_stop_early, "best", None) if prev_best is None or is_better(valid_loss, prev_best): @@ -156,42 +160,41 @@ def is_better(a, b): return False else: should_stop_early.num_runs += 1 - if should_stop_early.num_runs >= args.patience: - logger.info( - "early stop since valid performance hasn't improved for last {} runs".format( - args.patience - ) - ) + if should_stop_early.num_runs >= cfg.checkpoint.patience: + logger.info("early stop since valid performance hasn't 
improved for last {} runs".format(cfg.checkpoint.patience)) return True else: return False @metrics.aggregate("train") -def train(args, trainer, task, epoch_itr): +def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) -> Tuple[List[Optional[float]], bool]: """Train the model for one epoch and return validation losses.""" # Initialize data iterator itr = epoch_itr.next_epoch_itr( - fix_batches_to_gpus=args.fix_batches_to_gpus, - shuffle=(epoch_itr.next_epoch_idx > args.curriculum), + fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus, + shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum), ) update_freq = ( - args.update_freq[epoch_itr.epoch - 1] - if epoch_itr.epoch <= len(args.update_freq) - else args.update_freq[-1] + cfg.optimization.update_freq[epoch_itr.epoch - 1] + if epoch_itr.epoch <= len(cfg.optimization.update_freq) + else cfg.optimization.update_freq[-1] ) itr = iterators.GroupedIterator(itr, update_freq) - if getattr(args, "tpu", False): + if getattr(cfg.common, "tpu", False): itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, epoch=epoch_itr.epoch, tensorboard_logdir=( - args.tensorboard_logdir if distributed_utils.is_master(args) else None + cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None + ), + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + wandb_project=( + cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None ), - default_log_format=("tqdm" if not args.no_progress_bar else "simple"), ) trainer.begin_epoch(epoch_itr.epoch) @@ -199,8 +202,7 @@ def train(args, trainer, task, epoch_itr): if hasattr(trainer.criterion, "set_epoch"): trainer.criterion.set_epoch(epoch_itr.epoch) - valid_losses = [None] - valid_subsets = args.valid_subset.split(",") + valid_subsets = cfg.dataset.valid_subset.split(",") should_stop = False num_updates = trainer.get_num_updates() for i, samples in enumerate(progress): @@ -212,7 +214,7 @@ def train(args, trainer, task, epoch_itr): if log_output is not None: # not OOM, overflow, ... 
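# Editor's note: the surrounding hunk replaces flat `args.<name>` lookups with
# grouped `cfg.<group>.<name>` lookups on an omegaconf DictConfig. A minimal,
# hypothetical illustration of that access pattern (the values below are made
# up; the real config is produced by fairseq's convert_namespace_to_omegaconf):
from omegaconf import OmegaConf

cfg_sketch = OmegaConf.create({
    "common": {"log_interval": 100, "seed": 1, "tpu": False},
    "dataset": {"max_tokens": 26000, "batch_size": None},
    "optimization": {"max_epoch": 30, "update_freq": [1]},
})
assert cfg_sketch.common.log_interval == 100   # replaces args.log_interval
assert cfg_sketch.dataset.max_tokens == 26000  # replaces args.max_tokens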
# log mid-epoch stats num_updates = trainer.get_num_updates() - if num_updates % args.log_interval == 0: + if num_updates % cfg.common.log_interval == 0: stats = get_training_stats(metrics.get_smoothed_values("train_inner")) progress.log(stats, tag="train_inner", step=num_updates) @@ -220,13 +222,13 @@ def train(args, trainer, task, epoch_itr): # the end-of-epoch stats will still be preserved metrics.reset_meters("train_inner") - # update the state prior stored in the model for cross-entropy training + # update the state prior stored in the model for cross-entropy training of hybrid systems if hasattr(task, "update_state_prior"): task.update_state_prior(trainer.get_model()) end_of_epoch = not itr.has_next() valid_losses, should_stop = validate_and_save( - args, trainer, task, epoch_itr, valid_subsets, end_of_epoch + cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch ) if should_stop: @@ -242,84 +244,87 @@ def train(args, trainer, task, epoch_itr): return valid_losses, should_stop -def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch): +def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, valid_subsets: List[str], end_of_epoch: bool) -> Tuple[List[Optional[float]], bool]: num_updates = trainer.get_num_updates() - max_update = args.max_update or math.inf + max_update = cfg.optimization.max_update or math.inf do_save = ( - (end_of_epoch and epoch_itr.epoch % args.save_interval == 0) + (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0) or num_updates >= max_update or ( - args.save_interval_updates > 0 + cfg.checkpoint.save_interval_updates > 0 and num_updates > 0 - and num_updates % args.save_interval_updates == 0 - and num_updates >= args.validate_after_updates + and num_updates % cfg.checkpoint.save_interval_updates == 0 + and num_updates >= cfg.dataset.validate_after_updates ) ) do_validate = ( (not end_of_epoch and do_save) # validate during mid-epoch saves - or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) + or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0) or num_updates >= max_update or ( - args.validate_interval_updates > 0 + cfg.dataset.validate_interval_updates > 0 and num_updates > 0 - and num_updates % args.validate_interval_updates == 0 + and num_updates % cfg.dataset.validate_interval_updates == 0 ) - ) and not args.disable_validation + ) and not cfg.dataset.disable_validation # Validate valid_losses = [None] if do_validate: - valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) + valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) # Stopping conditions should_stop = ( - should_stop_early(args, valid_losses[0]) + should_stop_early(cfg, valid_losses[0]) or num_updates >= max_update or ( - args.stop_time_hours > 0 - and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours + cfg.optimization.stop_time_hours > 0 + and trainer.cumulative_training_time() / (60 * 60) > cfg.optimization.stop_time_hours ) ) # Save checkpoint if do_save or should_stop: logger.info("begin save checkpoint") - checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) + checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch_itr, valid_losses[0]) return valid_losses, should_stop -def get_training_stats(stats): +def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]: stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0) return stats -def validate(args, trainer, 
task, epoch_itr, subsets): +def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, subsets: List[str]) -> List[Optional[float]]: """Evaluate the model on the validation set(s) and return the losses.""" - if args.fixed_validation_seed is not None: + if cfg.dataset.fixed_validation_seed is not None: # set fixed seed for every validation - utils.set_torch_seed(args.fixed_validation_seed) + utils.set_torch_seed(cfg.dataset.fixed_validation_seed) trainer.begin_valid_epoch(epoch_itr.epoch) valid_losses = [] for subset in subsets: - logger.info('begin validation on "{}" subset'.format(subset)) + logger.info("begin validation on '{}' subset".format(subset)) # Initialize data iterator itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) - if getattr(args, "tpu", False): + if cfg.common.tpu: itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, epoch=epoch_itr.epoch, prefix=f"valid on '{subset}' subset", tensorboard_logdir=( - args.tensorboard_logdir if distributed_utils.is_master(args) else None + cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None + ), + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + wandb_project=( + cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None ), - default_log_format=("tqdm" if not args.no_progress_bar else "simple"), ) # create a new root metrics aggregator so validation metrics @@ -329,20 +334,20 @@ def validate(args, trainer, task, epoch_itr, subsets): trainer.valid_step(sample) # log validation stats - stats = get_valid_stats(args, trainer, agg.get_smoothed_values()) + stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values()) progress.print(stats, tag=subset, step=trainer.get_num_updates()) - valid_losses.append(stats[args.best_checkpoint_metric]) + valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric]) return valid_losses -def get_valid_stats(args, trainer, stats): +def get_valid_stats(cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]) -> Dict[str, Any]: stats["num_updates"] = trainer.get_num_updates() if hasattr(checkpoint_utils.save_checkpoint, "best"): - key = "best_{0}".format(args.best_checkpoint_metric) - best_function = max if args.maximize_best_checkpoint_metric else min + key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric) + best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min stats[key] = best_function( - checkpoint_utils.save_checkpoint.best, stats[args.best_checkpoint_metric] + checkpoint_utils.save_checkpoint.best, stats[cfg.checkpoint.best_checkpoint_metric] ) return stats @@ -354,16 +359,19 @@ def print_options_meaning_changes(args): logger.info("--max-tokens is the maximum number of input frames in a batch") -def cli_main(modify_parser=None): +def cli_main(modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None) -> None: parser = options.get_training_parser() args = options.parse_args_and_arch(parser, modify_parser=modify_parser) print_options_meaning_changes(args) + + cfg = convert_namespace_to_omegaconf(args) + if args.profile: with torch.cuda.profiler.profile(): with torch.autograd.profiler.emit_nvtx(): - distributed_utils.call_main(args, main) + distributed_utils.call_main(cfg, main) else: - distributed_utils.call_main(args, main) + 
distributed_utils.call_main(cfg, main) if __name__ == "__main__": diff --git a/espresso/tasks/language_modeling_for_asr.py b/espresso/tasks/language_modeling_for_asr.py index 4af8008b15..5417359103 100644 --- a/espresso/tasks/language_modeling_for_asr.py +++ b/espresso/tasks/language_modeling_for_asr.py @@ -6,6 +6,7 @@ import logging import os from dataclasses import dataclass, field +from typing import Optional import torch @@ -22,7 +23,7 @@ @dataclass class LanguageModelingForASRConfig(LanguageModelingConfig): - dict: str = field(default=None, metadata={"help": "path to the dictionary"}) + dict: Optional[str] = field(default=None, metadata={"help": "path to the dictionary"}) @register_task("language_modeling_for_asr", dataclass=LanguageModelingForASRConfig) diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 5df4594564..58ed02b1cf 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -8,13 +8,17 @@ import json import logging import os +from dataclasses import dataclass, field +from typing import Optional import torch from fairseq import utils from fairseq.data import BaseWrapperDataset, ConcatDataset +from fairseq.dataclass import FairseqDataclass from fairseq.logging import metrics from fairseq.tasks import FairseqTask, register_task +from omegaconf import II, DictConfig from espresso.data import ( AsrDataset, @@ -27,12 +31,75 @@ logger = logging.getLogger(__name__) +@dataclass +class SpeechRecognitionEspressoConfig(FairseqDataclass): + data: Optional[str] = field( + default=None, metadata={"help": "path to data directory"} + ) + dict: Optional[str] = field(default=None, metadata={"help": "path to the dictionary"}) + non_lang_syms: Optional[str] = field( + default=None, + metadata={ + "help": "path to a file listing non-linguistic symbols, e.g., " + "etc. One entry per line. To be filtered out when calculating WER/CER" + }, + ) + word_dict: Optional[str] = field( + default=None, + metadata={"help": "path to the word dictionary. Only relevant for decoding"}, + ) + wer_output_filter: Optional[str] = field( + default=None, + metadata={"help": "path to wer_output_filter file for WER evaluation"}, + ) + max_source_positions: Optional[int] = field( + default=1024, metadata={"help": "max number of tokens in the source sequence"} + ) + max_target_positions: Optional[int] = field( + default=1024, metadata={"help": "max number of tokens in the target sequence"} + ) + upsample_primary: int = field( + default=1, metadata={"help": "amount to upsample primary dataset"}, + ) + num_batch_buckets: Optional[int] = field( + default=0, + metadata={ + "help": "if >0, then bucket source and target lengths into N " + "buckets and pad accordingly; this is useful on TPUs " + "to minimize the number of compilations" + }, + ) + feat_in_channels: int = field(default=1, metadata={"help": "feature input channels"}) + specaugment_config: Optional[str] = field( + default=None, + metadata={ + "help": "SpecAugment config string. If not None and not empty, " + "then apply SpecAugment. Should be an evaluatable expression of " + "a python dict. See speech_tools.specaug_interpolate.specaug() for " + "all allowed arguments. 
Argments not appearing in this string " + "will take on their default values" + }, + ) + # TODO common vars below add to parent + seed: int = II("common.seed") + data_buffer_size: int = II("dataset.data_buffer_size") + tpu: bool = II("common.tpu") + train_subset: str = II("dataset.train_subset") + gen_subset: str = II("dataset.gen_subset") + required_seq_len_multiple: int = II("dataset.required_seq_len_multiple") + + def get_asr_dataset_from_json( - data_path, split, tgt_dict, - combine, upsample_primary, - num_buckets=0, shuffle=True, + data_path, + split, + tgt_dict, + combine, + upsample_primary, + num_buckets=0, + shuffle=True, pad_to_multiple=1, - seed=1, specaugment_config=None, + seed=1, + specaugment_config=None, ): """ Parse data json and create dataset. @@ -58,7 +125,9 @@ def get_asr_dataset_from_json( if k > 0: break else: - raise FileNotFoundError("Dataset not found: {}".format(data_json_path)) + raise FileNotFoundError( + "Dataset not found: {}".format(data_json_path) + ) with open(data_json_path, "rb") as f: loaded_json = json.load(f, object_pairs_hook=OrderedDict) @@ -97,8 +166,9 @@ def get_asr_dataset_from_json( tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: for i in range(1, len(src_datasets)): - assert feat_dim == src_datasets[i].feat_dim, \ - "feature dimension does not match across multiple json files" + assert ( + feat_dim == src_datasets[i].feat_dim + ), "feature dimension does not match across multiple json files" sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) @@ -109,8 +179,10 @@ def get_asr_dataset_from_json( tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None return AsrDataset( - src_dataset, src_dataset.sizes, - tgt_dataset, tgt_dataset_sizes, + src_dataset, + src_dataset.sizes, + tgt_dataset, + tgt_dataset_sizes, tgt_dict, left_pad_source=False, left_pad_target=False, @@ -120,7 +192,7 @@ def get_asr_dataset_from_json( ) -@register_task("speech_recognition_espresso") +@register_task("speech_recognition_espresso", dataclass=SpeechRecognitionEspressoConfig) class SpeechRecognitionEspressoTask(FairseqTask): """ Transcribe from speech (source) to token text (target). @@ -144,40 +216,6 @@ class SpeechRecognitionEspressoTask(FairseqTask): :prog: """ - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser.""" - # fmt: off - parser.add_argument("data", help="path to data directory") - parser.add_argument("--dict", default=None, type=str, - help="path to the dictionary") - parser.add_argument("--non-lang-syms", default=None, type=str, - help="path to a file listing non-linguistic symbols, e.g., " - "etc. One entry per line. To be filtered out when calculating WER/CER.") - parser.add_argument("--word-dict", default=None, type=str, - help="path to the word dictionary. 
Only relevant for decoding") - parser.add_argument("--wer-output-filter", default=None, type=str, - help="path to wer_output_filter file for WER evaluation") - parser.add_argument("--max-source-positions", default=1024, type=int, metavar="N", - help="max number of frames in the source sequence") - parser.add_argument("--max-target-positions", default=1024, type=int, metavar="N", - help="max number of tokens in the target sequence") - parser.add_argument("--upsample-primary", default=1, type=int, - help="amount to upsample primary dataset") - parser.add_argument("--num-batch-buckets", default=0, type=int, metavar="N", - help="if >0, then bucket source and target lengths into N " - "buckets and pad accordingly; this is useful on TPUs " - "to minimize the number of compilations") - parser.add_argument("--feat-in-channels", default=1, type=int, metavar="N", - help="feature input channels") - parser.add_argument("--specaugment-config", default=None, type=str, metavar="EXPR", - help="SpecAugment config string. If not None and not empty, " - "then apply SpecAugment. Should be an evaluatable expression of " - "a python dict. See speech_tools.specaug_interpolate.specaug() for " - "all allowed arguments. Argments not appearing in this string " - "will take on their default values") - # fmt: off - @classmethod def load_dictionary(cls, filename, non_lang_syms=None): """Load the dictionary from the filename @@ -195,14 +233,12 @@ def build_dictionary( """ raise NotImplementedError - def __init__(self, args, tgt_dict, word_dict=None): - super().__init__(args) + def __init__(self, cfg: DictConfig, tgt_dict, word_dict=None): + super().__init__(cfg) self.tgt_dict = tgt_dict - self.tgt_dict.build_tokenizer(args) - self.tgt_dict.build_bpe(args) self.word_dict = word_dict - self.feat_in_channels = args.feat_in_channels - self.specaugment_config = args.specaugment_config + self.feat_in_channels = cfg.feat_in_channels + self.specaugment_config = cfg.specaugment_config torch.backends.cudnn.deterministic = True # Compansate for the removel of :func:`torch.rand()` from # :func:`fairseq.distributed_utils.distributed_init()` by fairseq, @@ -210,23 +246,23 @@ def __init__(self, args, tgt_dict, word_dict=None): torch.rand(1) @classmethod - def setup_task(cls, args, **kwargs): + def setup_task(cls, cfg: DictConfig, **kwargs): """Setup the task (e.g., load dictionaries). Args: - args (argparse.Namespace): parsed command-line arguments + cfg (omegaconf.DictConfig): parsed command-line arguments """ # load dictionaries - dict_path = os.path.join(args.data, "dict.txt") if args.dict is None else args.dict - tgt_dict = cls.load_dictionary(dict_path, non_lang_syms=args.non_lang_syms) + dict_path = os.path.join(cfg.data, "dict.txt") if cfg.dict is None else cfg.dict + tgt_dict = cls.load_dictionary(dict_path, non_lang_syms=cfg.non_lang_syms) logger.info("dictionary: {} types".format(len(tgt_dict))) - if args.word_dict is not None: - word_dict = cls.load_dictionary(args.word_dict) + if cfg.word_dict is not None: + word_dict = cls.load_dictionary(cfg.word_dict) logger.info("word dictionary: {} types".format(len(word_dict))) - return cls(args, tgt_dict, word_dict) + return cls(cfg, tgt_dict, word_dict) else: - return cls(args, tgt_dict) + return cls(cfg, tgt_dict) def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. 
@@ -234,21 +270,23 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - paths = utils.split_paths(self.args.data) + paths = utils.split_paths(self.cfg.data) assert len(paths) > 0 - if split != getattr(self.args, "train_subset", None): + if split != self.cfg.train_subset: # if not training data set, use the first shard for valid and test paths = paths[:1] data_path = paths[(epoch - 1) % len(paths)] self.datasets[split] = get_asr_dataset_from_json( - data_path, split, self.tgt_dict, + data_path, + split, + self.tgt_dict, combine=combine, - upsample_primary=self.args.upsample_primary, - num_buckets=self.args.num_batch_buckets, - shuffle=(split != getattr(self.args, "gen_subset", None)), - pad_to_multiple=self.args.required_seq_len_multiple, - seed=self.args.seed, + upsample_primary=self.cfg.upsample_primary, + num_buckets=self.cfg.num_batch_buckets, + shuffle=(split != self.cfg.gen_subset), + pad_to_multiple=self.cfg.required_seq_len_multiple, + seed=self.cfg.seed, specaugment_config=self.specaugment_config, ) @@ -271,13 +309,17 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): return AsrDataset( - src_tokens, src_lengths, dictionary=self.target_dictionary, constraints=constraints, + src_tokens, + src_lengths, + dictionary=self.target_dictionary, + constraints=constraints, ) - def build_model(self, args): - model = super().build_model(args) + def build_model(self, cfg: DictConfig): + model = super().build_model(cfg) # build the greedy decoder for validation with WER from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder + self.decoder_for_validation = SimpleGreedyDecoder( [model], self.target_dictionary, for_validation=True, ) @@ -304,13 +346,25 @@ def reduce_metrics(self, logging_outputs, criterion): def max_positions(self): """Return the max sentence length allowed by the task.""" - return (self.args.max_source_positions, self.args.max_target_positions) + return (self.cfg.max_source_positions, self.cfg.max_target_positions) @property def target_dictionary(self): """Return the target :class:`~fairseq.data.AsrDictionary`.""" return self.tgt_dict + def build_tokenizer(self, cfg: DictConfig): + """Build the pre-tokenizer for this task.""" + self.tgt_dict.build_tokenizer(cfg) + # the instance is built within self.tgt_dict + return self.tgt_dict.tokenizer + + def build_bpe(self, cfg: DictConfig): + """Build the tokenizer for this task.""" + self.tgt_dict.build_bpe(cfg) + # the instance is built within self.tgt_dict + return self.tgt_dict.bpe + @property def word_dictionary(self): """Return the target :class:`~fairseq.data.AsrDictionary`.""" @@ -319,7 +373,7 @@ def word_dictionary(self): def _inference_with_wer(self, decoder, sample, model): from espresso.tools import wer - scorer = wer.Scorer(self.target_dictionary, wer_output_filter=self.args.wer_output_filter) + scorer = wer.Scorer(self.target_dictionary, wer_output_filter=self.cfg.wer_output_filter) tokens, lprobs, _ = decoder.decode([model], sample) pred = tokens[:, 1:].data.cpu() # bsz x len target = sample["target"] diff --git a/espresso/tasks/speech_recognition_hybrid.py b/espresso/tasks/speech_recognition_hybrid.py index 6eedeaa843..a10b07ea52 100644 --- a/espresso/tasks/speech_recognition_hybrid.py +++ b/espresso/tasks/speech_recognition_hybrid.py @@ -8,13 +8,17 @@ import json import logging import os +from dataclasses import dataclass, field 
+from typing import Optional import torch from fairseq import utils from fairseq.data import BaseWrapperDataset, ConcatDataset - +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.configs import GenerationConfig from fairseq.tasks import FairseqTask, register_task +from omegaconf import II, DictConfig from espresso.data import ( AliScpCachedDataset, @@ -35,14 +39,129 @@ logger = logging.getLogger(__name__) +@dataclass +class SpeechRecognitionHybridConfig(FairseqDataclass): + data: Optional[str] = field( + default=None, metadata={"help": "path to data directory"} + ) + dict: Optional[str] = field(default=None, metadata={"help": "path to the dictionary"}) + non_lang_syms: Optional[str] = field( + default=None, + metadata={ + "help": "path to a file listing non-linguistic symbols, e.g., " + "etc. One entry per line. To be filtered out when calculating WER/CER" + }, + ) + wer_output_filter: Optional[str] = field( + default=None, + metadata={"help": "path to wer_output_filter file for WER evaluation"}, + ) + max_source_positions: Optional[int] = field( + default=1024, metadata={"help": "max number of tokens in the source sequence"} + ) + max_target_positions: Optional[int] = field( + default=1024, metadata={"help": "max number of tokens in the target sequence"} + ) + upsample_primary: int = field( + default=1, metadata={"help": "amount to upsample primary dataset"}, + ) + num_batch_buckets: Optional[int] = field( + default=0, + metadata={ + "help": "if >0, then bucket source and target lengths into N " + "buckets and pad accordingly; this is useful on TPUs " + "to minimize the number of compilations" + }, + ) + feat_in_channels: int = field(default=1, metadata={"help": "feature input channels"}) + specaugment_config: Optional[str] = field( + default=None, + metadata={ + "help": "SpecAugment config string. If not None and not empty, " + "then apply SpecAugment. Should be an evaluatable expression of " + "a python dict. See speech_tools.specaug_interpolate.specaug() for " + "all allowed arguments. Argments not appearing in this string " + "will take on their default values" + }, + ) + num_targets: int = field( + default=3000, + metadata={"help": "number of targets for training (e.g., num pdf-ids)"}, + ) + initial_state_prior_file: Optional[str] = field( + default=None, + metadata={ + "help": "path to the file containing initial state prior. Only relevant " + "with cross-entropy training" + }, + ) + state_prior_update_interval: Optional[int] = field( + default=None, + metadata={ + "help": "state prior estimate will be updated every this number of updates " + "during training. If None, then use the initial value estimated from the " + "alignments. Only relevant with cross-entropy training" + }, + ) + state_prior_update_smoothing: Optional[float] = field( + default=0.1, + metadata={ + "help": "smoothing factor while updating state prior estimate. Only " + "relevant with cross-entropy training" + }, + ) + chunk_width: Optional[int] = field( + default=None, + metadata={ + "help": "chunk width for train/test data. Only relevant with chunk-wise " + "training (including both cross-entropy and Lattice-free MMI). 
" + "Do utterance-wise training/test if not specified" + }, + ) + chunk_left_context: Optional[int] = field( + default=0, + metadata={"help": "number of frames appended to the left of a chunk"}, + ) + chunk_right_context: Optional[int] = field( + default=0, + metadata={"help": "number of frames appended to the right of a chunk"}, + ) + label_delay: Optional[int] = field( + default=0, + metadata={ + "help": "offet of alignments as prediction labels. Maybe useful " + "in archs such as asymmetric convolution, unidirectional LSTM, etc. " + "It can be negative. Only relevant with chunk-wise cross-entropy training" + }, + ) + # TODO common vars below add to parent + seed: int = II("common.seed") + data_buffer_size: int = II("dataset.data_buffer_size") + tpu: bool = II("common.tpu") + train_subset: str = II("dataset.train_subset") + valid_subset: str = II("dataset.valid_subset") + gen_subset: str = II("dataset.gen_subset") + required_seq_len_multiple: int = II("dataset.required_seq_len_multiple") + criterion_name: str = II("criterion._name") + max_epoch: int = II("optimization.max_epoch") # to determine whether in trainig stage + + def get_asr_dataset_from_json( - data_path, split, dictionary, - combine, upsample_primary, - num_buckets=0, shuffle=True, + data_path, + split, + dictionary, + combine, + upsample_primary, + num_buckets=0, + shuffle=True, pad_to_multiple=1, lf_mmi=True, - seed=1, specaugment_config=None, - chunk_width=None, chunk_left_context=None, chunk_right_context=None, label_delay=0, + seed=1, + specaugment_config=None, + chunk_width=None, + chunk_left_context=None, + chunk_right_context=None, + label_delay=0, ): """ Parse data json and create dataset. @@ -72,7 +191,9 @@ def get_asr_dataset_from_json( if k > 0: break else: - raise FileNotFoundError("Dataset not found: {}".format(data_json_path)) + raise FileNotFoundError( + "Dataset not found: {}".format(data_json_path) + ) with open(data_json_path, "rb") as f: loaded_json = json.load(f, object_pairs_hook=OrderedDict) @@ -103,9 +224,12 @@ def get_asr_dataset_from_json( else: # cross-entropy if len(alignments) > 0: assert len(utt_ids) == len(alignments) - tgt_datasets.append(AliScpCachedDataset( - utt_ids, alignments, utt2num_frames=utt2num_frames, ordered_prefetch=True - )) + tgt_datasets.append( + AliScpCachedDataset( + utt_ids, alignments, utt2num_frames=utt2num_frames, + ordered_prefetch=True, + ) + ) if len(text) > 0: assert len(utt_ids) == len(text) @@ -127,8 +251,9 @@ def get_asr_dataset_from_json( text_dataset = text_datasets[0] if len(text_datasets) > 0 else None else: for i in range(1, len(src_datasets)): - assert feat_dim == src_datasets[i].feat_dim, \ - "feature dimension does not match across multiple json files" + assert ( + feat_dim == src_datasets[i].feat_dim + ), "feature dimension does not match across multiple json files" sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) @@ -144,8 +269,10 @@ def get_asr_dataset_from_json( tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None if lf_mmi: return AsrChainDataset( - src_dataset, src_dataset.sizes, - tgt_dataset, tgt_dataset_sizes, + src_dataset, + src_dataset.sizes, + tgt_dataset, + tgt_dataset_sizes, text=text_dataset, num_buckets=num_buckets, shuffle=shuffle, @@ -153,19 +280,24 @@ def get_asr_dataset_from_json( ) else: return AsrXentDataset( - src_dataset, src_dataset.sizes, - tgt_dataset, tgt_dataset_sizes, + src_dataset, + src_dataset.sizes, + tgt_dataset, + 
tgt_dataset_sizes, text=text_dataset, num_buckets=num_buckets, shuffle=shuffle, pad_to_multiple=pad_to_multiple, - seed=seed, chunk_width=chunk_width, - chunk_left_context=chunk_left_context, chunk_right_context=chunk_right_context, - label_delay=label_delay, random_chunking=(split == "train" and chunk_width is not None), + seed=seed, + chunk_width=chunk_width, + chunk_left_context=chunk_left_context, + chunk_right_context=chunk_right_context, + label_delay=label_delay, + random_chunking=(split == "train" and chunk_width is not None), ) -@register_task("speech_recognition_hybrid") +@register_task("speech_recognition_hybrid", dataclass=SpeechRecognitionHybridConfig) class SpeechRecognitionHybridTask(FairseqTask): """ Hybrid speech recognition with lattice-free MMI or cross-entropy loss. @@ -192,64 +324,6 @@ class SpeechRecognitionHybridTask(FairseqTask): :prog: """ - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser.""" - # fmt: off - parser.add_argument("data", help="path to data directory") - parser.add_argument("--dict", default=None, type=str, - help="path to the dictionary") - parser.add_argument("--non-lang-syms", default=None, type=str, - help="path to a file listing non-linguistic symbols, e.g., " - "etc. One entry per line. To be filtered out when calculating WER/CER.") - parser.add_argument("--wer-output-filter", default=None, type=str, - help="path to wer_output_filter file for WER evaluation") - parser.add_argument("--max-source-positions", default=1024, type=int, metavar="N", - help="max number of frames in the source sequence") - parser.add_argument("--max-target-positions", default=1024, type=int, metavar="N", - help="max number of tokens in the target sequence") - parser.add_argument("--upsample-primary", default=1, type=int, - help="amount to upsample primary dataset") - parser.add_argument("--num-batch-buckets", default=0, type=int, metavar="N", - help="if >0, then bucket source and target lengths into N " - "buckets and pad accordingly; this is useful on TPUs " - "to minimize the number of compilations") - parser.add_argument("--feat-in-channels", default=1, type=int, metavar="N", - help="feature input channels") - parser.add_argument("--specaugment-config", default=None, type=str, metavar="EXPR", - help="SpecAugment config string. If not None and not empty, " - "then apply SpecAugment. Should be an evaluatable expression of " - "a python dict. See speech_tools.specaug_interpolate.specaug() for " - "all allowed arguments. Argments not appearing in this string " - "will take on their default values") - - parser.add_argument("--num-targets", type=int, metavar="N", - help="number of targets for training (e.g., num pdf-ids)") - parser.add_argument("--initial-state-prior-file", default=None, type=str, metavar="FILE", - help="path to the file containing initial state prior. Only relevant " - "with cross-entropy training") - parser.add_argument("--state-prior-update-interval", default=None, type=int, metavar="N", - help="state prior estimate will be updated every this " - "number of updates during training. If None, then use " - "the initial value estimated from the alignments. Only relevant with " - "cross-entropy training") - parser.add_argument("--state-prior-update-smoothing", default=0.1, type=float, metavar="D", - help="smoothing factor while updating state prior estimate. Only " - "relevant with cross-entropy training") - parser.add_argument("--chunk-width", default=None, type=int, metavar="D", - help="chunk width for train/test data. 
Only relevant with chunk-wise " - "training (including both cross-entropy and Lattice-free MMI). " - "Do utterance-wise training/test if not specified") - parser.add_argument("--chunk-left-context", default=0, type=int, metavar="D", - help="number of frames appended to the left of a chunk") - parser.add_argument("--chunk-right-context", default=0, type=int, metavar="D", - help="number of frames appended to the right of a chunk") - parser.add_argument("--label-delay", default=0, type=int, metavar="D", - help="offet of alignments as prediction labels. Maybe useful " - "in archs such as asymmetric convolution, unidirectional LSTM, etc. " - "It can be negative. Only relevant with chunk-wise cross-entropy training") - # fmt: off - @classmethod def load_dictionary(cls, filename, non_lang_syms=None): """Load the dictionary from the filename @@ -265,51 +339,55 @@ def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding """ raise NotImplementedError - def __init__(self, args, dictionary): - super().__init__(args) + def __init__(self, cfg: DictConfig, dictionary): + super().__init__(cfg) self.dictionary = dictionary - self.feat_in_channels = args.feat_in_channels - self.specaugment_config = args.specaugment_config - self.num_targets = args.num_targets - self.training_stage = hasattr(args, "valid_subset") + self.feat_in_channels = cfg.feat_in_channels + self.specaugment_config = cfg.specaugment_config + self.num_targets = cfg.num_targets + self.training_stage = (cfg.max_epoch > 0) # a hack # the following attributes are related to state_prior estimate self.initial_state_prior = None - if args.initial_state_prior_file is not None: # only relevant for Xent training, used in models - self.initial_state_prior = kaldi_io.read_vec_flt(args.initial_state_prior_file) + if cfg.initial_state_prior_file is not None: # only relevant for Xent training, used in models + self.initial_state_prior = kaldi_io.read_vec_flt(cfg.initial_state_prior_file) self.initial_state_prior = torch.from_numpy(self.initial_state_prior) - assert self.initial_state_prior.size(0) == self.num_targets, \ - "length of initial_state_prior ({}) != num_targets ({})".format( - self.initial_state_prior.size(0), self.num_targets - ) - self.state_prior_update_interval = args.state_prior_update_interval + assert ( + self.initial_state_prior.size(0) == self.num_targets + ), "length of initial_state_prior ({}) != num_targets ({})".format( + self.initial_state_prior.size(0), self.num_targets + ) + self.state_prior_update_interval = cfg.state_prior_update_interval if self.state_prior_update_interval is None and self.initial_state_prior is not None: logger.info("state prior will not be updated during training") - self.state_prior_update_smoothing = args.state_prior_update_smoothing + self.state_prior_update_smoothing = cfg.state_prior_update_smoothing self.averaged_state_post = None # state poterior will be saved here before commited as new state prior # the following 4 options are for chunk-wise training/test (including Xent and LF-MMI) - self.chunk_width = args.chunk_width - self.chunk_left_context = args.chunk_left_context - self.chunk_right_context = args.chunk_right_context - self.label_delay = args.label_delay # only for chunk-wise Xent training + self.chunk_width = cfg.chunk_width + self.chunk_left_context = cfg.chunk_left_context + self.chunk_right_context = cfg.chunk_right_context + self.label_delay = cfg.label_delay # only for chunk-wise Xent training torch.backends.cudnn.deterministic = True @classmethod - def 
setup_task(cls, args, **kwargs): + def setup_task(cls, cfg: DictConfig, **kwargs): """Setup the task (e.g., load dictionaries). Args: - args (argparse.Namespace): parsed command-line arguments + cfg (omegaconf.DictConfig): parsed command-line arguments """ # load dictionaries - dict_path = args.dict - dictionary = cls.load_dictionary(dict_path, non_lang_syms=args.non_lang_syms) if \ - dict_path is not None else None + dict_path = cfg.dict + dictionary = ( + cls.load_dictionary(dict_path, non_lang_syms=cfg.non_lang_syms) + if dict_path is not None + else None + ) if dictionary is not None: logger.info("dictionary: {} types".format(len(dictionary))) - return cls(args, dictionary) + return cls(cfg, dictionary) def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. @@ -317,24 +395,30 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - paths = utils.split_paths(self.args.data) + paths = utils.split_paths(self.cfg.data) assert len(paths) > 0 - if split != getattr(self.args, "train_subset", None): + if split != self.cfg.train_subset: # if not training data set, use the first shard for valid and test paths = paths[:1] data_path = paths[(epoch - 1) % len(paths)] self.datasets[split] = get_asr_dataset_from_json( - data_path, split, self.dictionary, + data_path, + split, + self.dictionary, combine=combine, - upsample_primary=self.args.upsample_primary, - num_buckets=self.args.num_batch_buckets, - shuffle=(split != getattr(self.args, "gen_subset", None)), - pad_to_multiple=self.args.required_seq_len_multiple, - lf_mmi=(self.args.criterion == "lattice_free_mmi"), - seed=self.args.seed, specaugment_config=self.specaugment_config, - chunk_width=None if self.training_stage and split in self.args.valid_subset.split(",") else self.chunk_width, - chunk_left_context=self.chunk_left_context, chunk_right_context=self.chunk_right_context, + upsample_primary=self.cfg.upsample_primary, + num_buckets=self.cfg.num_batch_buckets, + shuffle=(split != self.cfg.gen_subset), + pad_to_multiple=self.cfg.required_seq_len_multiple, + lf_mmi=(self.cfg.criterion_name == "lattice_free_mmi"), + seed=self.cfg.seed, + specaugment_config=self.specaugment_config, + chunk_width=None if self.training_stage + and split in self.cfg.valid_subset.split(",") + else self.chunk_width, + chunk_left_context=self.chunk_left_context, + chunk_right_context=self.chunk_right_context, label_delay=self.label_delay, ) @@ -346,14 +430,15 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): else: self.feat_dim = src_dataset.feat_dim - def build_generator(self, models, args): - if args.score_reference: - args.score_reference = False + def build_generator(self, models, cfg: GenerationConfig): + if cfg.score_reference: + cfg.score_reference = False logger.warning( "--score-reference is not applicable to speech recognition, ignoring it." 
) from espresso.tools.generate_log_probs_for_decoding import GenerateLogProbsForDecoding - apply_log_softmax = getattr(args, "apply_log_softmax", False) + + apply_log_softmax = getattr(cfg, "apply_log_softmax", False) return GenerateLogProbsForDecoding(models, apply_log_softmax=apply_log_softmax) def build_dataset_for_inference(self, src_tokens, src_lengths): @@ -387,7 +472,7 @@ def update_state_prior(self, model): def max_positions(self): """Return the max sentence length allowed by the task.""" - return (self.args.max_source_positions, self.args.max_target_positions) + return (self.cfg.max_source_positions, self.cfg.max_target_positions) @property def target_dictionary(self): diff --git a/espresso/tools/compute_wer.py b/espresso/tools/compute_wer.py index 8555e0995e..7a56fed337 100755 --- a/espresso/tools/compute_wer.py +++ b/espresso/tools/compute_wer.py @@ -14,27 +14,27 @@ logging.basicConfig( - format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, stream=sys.stderr, ) -logger = logging.getLogger('espresso.tools.compute_wer') +logger = logging.getLogger("espresso.tools.compute_wer") def get_parser(): parser = argparse.ArgumentParser( - description='Compute WER from text') + description="Compute WER from text") # fmt: off - parser.add_argument('--non-lang-syms', default=None, type=str, - help='path to a file listing non-linguistic symbols, ' - 'e.g., etc. One entry per line.') - parser.add_argument('--wer-output-filter', default=None, type=str, - help='path to wer_output_filter file for WER evaluation') - parser.add_argument('ref_text', type=str, - help='path to the reference text file') - parser.add_argument('hyp_text', type=str, - help='path to the hypothesis text file') + parser.add_argument("--non-lang-syms", default=None, type=str, + help="path to a file listing non-linguistic symbols, " + "e.g., etc. One entry per line.") + parser.add_argument("--wer-output-filter", default=None, type=str, + help="path to wer_output_filter file for WER evaluation") + parser.add_argument("ref_text", type=str, + help="path to the reference text file") + parser.add_argument("hyp_text", type=str, + help="path to the hypothesis text file") # fmt: on @@ -44,36 +44,36 @@ def get_parser(): def main(args): non_lang_syms = [] if args.non_lang_syms is not None: - with open(args.non_lang_syms, 'r', encoding='utf-8') as f: + with open(args.non_lang_syms, "r", encoding="utf-8") as f: non_lang_syms = [x.rstrip() for x in f.readlines()] word_filters = [] if args.wer_output_filter is not None: - with open(args.wer_output_filter, 'r', encoding='utf-8') as f: + with open(args.wer_output_filter, "r", encoding="utf-8") as f: for line in f: line = line.strip() - if line.startswith('#!') or line == '': + if line.startswith("#!") or line == "": continue - elif line.startswith('s/'): - m = re.match(r's/(\S+)/(\w*)/g', line) + elif line.startswith("s/"): + m = re.match(r"s/(\S+)/(\w*)/g", line) assert m is not None word_filters.append([m.group(1), m.group(2)]) - elif line.startswith('s:'): - m = re.match(r's:(\S+):(\w*):g', line) + elif line.startswith("s:"): + m = re.match(r"s:(\S+):(\w*):g", line) assert m is not None word_filters.append([m.group(1), m.group(2)]) else: - logger.warning('Unsupported pattern: "{}". Ignoring it.'.format(line)) + logger.warning("Unsupported pattern: '{}'. 
Ignoring it.".format(line)) refs = {} - with open(args.ref_text, 'r', encoding='utf-8') as f: + with open(args.ref_text, "r", encoding="utf-8") as f: for line in f: utt_id, text = line.strip().split(None, 1) assert utt_id not in refs, utt_id refs[utt_id] = text wer_counter = Counter() - with open(args.hyp_text, 'r', encoding='utf-8') as f: + with open(args.hyp_text, "r", encoding="utf-8") as f: for line in f: utt_id, text = line.strip().split(None, 1) assert utt_id in refs, utt_id @@ -91,19 +91,19 @@ def main(args): _, _, counter = edit_distance(ref_list, hyp_list) wer_counter += counter - assert wer_counter['words'] > 0 + assert wer_counter["words"] > 0 wer = float( - wer_counter['sub'] + wer_counter['ins'] + wer_counter['del'] - ) / wer_counter['words'] * 100 - sub = float(wer_counter['sub']) / wer_counter['words'] * 100 - ins = float(wer_counter['ins']) / wer_counter['words'] * 100 - dlt = float(wer_counter['del']) / wer_counter['words'] * 100 + wer_counter["sub"] + wer_counter["ins"] + wer_counter["del"] + ) / wer_counter["words"] * 100 + sub = float(wer_counter["sub"]) / wer_counter["words"] * 100 + ins = float(wer_counter["ins"]) / wer_counter["words"] * 100 + dlt = float(wer_counter["del"]) / wer_counter["words"] * 100 - print('WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%, #words={:d}'.format( - wer, sub, ins, dlt, wer_counter['words'])) + print("WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%, #words={:d}".format( + wer, sub, ins, dlt, wer_counter["words"])) -if __name__ == '__main__': +if __name__ == "__main__": parser = get_parser() args = parser.parse_args() main(args) diff --git a/espresso/tools/estimate_initial_state_prior_from_alignments.py b/espresso/tools/estimate_initial_state_prior_from_alignments.py index a1d1111069..9f48da1efd 100755 --- a/espresso/tools/estimate_initial_state_prior_from_alignments.py +++ b/espresso/tools/estimate_initial_state_prior_from_alignments.py @@ -13,7 +13,7 @@ try: import kaldi_io except ImportError: - raise ImportError('Please install kaldi_io with: pip install kaldi_io') + raise ImportError("Please install kaldi_io with: pip install kaldi_io") logging.basicConfig( diff --git a/espresso/tools/lexical_prefix_tree.py b/espresso/tools/lexical_prefix_tree.py index 79a0e8f381..d97281404a 100644 --- a/espresso/tools/lexical_prefix_tree.py +++ b/espresso/tools/lexical_prefix_tree.py @@ -24,8 +24,8 @@ def lexical_prefix_tree( Return: root (Node): the root of the prefix tree, where each node has the fields: - ('children': Dict[int,Node], 'word_idx': int, 'word_set': Tuple[int]). - 'children' is subword_idx -> node, and 'word_set' is (first-1, last), + ("children": Dict[int,Node], "word_idx": int, "word_set": Tuple[int]). + "children" is subword_idx -> node, and "word_set" is (first-1, last), where [first, last] is the range of the word indexes (inclusive) in the word dictionary who share the same prefix at that node. We assume words in the word dictionary are in lexical order. 
@@ -43,8 +43,11 @@ def __init__(self, children={}, word_idx=-1, word_set=None): for widx in range(len(word_dict)): if widx not in special_symbols: # skip , , # tokenize a word into a list of subwords - subwords = subword_tokenizer(word_dict[widx]) \ - if subword_tokenizer is not None else list(word_dict[widx]) + subwords = ( + subword_tokenizer(word_dict[widx]) + if subword_tokenizer is not None + else list(word_dict[widx]) + ) if any(subword_dict.index(s) == subword_dict.unk() for s in subwords): # skip words containing any unknown subwords continue diff --git a/espresso/tools/text2token.py b/espresso/tools/text2token.py index 455d8ebe1a..2f31693650 100755 --- a/espresso/tools/text2token.py +++ b/espresso/tools/text2token.py @@ -12,20 +12,27 @@ def get_parser(): parser = argparse.ArgumentParser( - description='Convert transcripts into tokens and write them to stdout') + description="Convert transcripts into tokens and write them to stdout" + ) # fmt: off - parser.add_argument('--skip-ncols', default=0, type=int, - help='skip first n columns') - parser.add_argument('--space', default='', type=str, - help='space symbol') - parser.add_argument('--ends-with-space', default=True, type=bool, - help='Whether to append to the end of each ' - 'tokenized sentence.') - parser.add_argument('--non-lang-syms', default=None, type=str, - help='path to a file listing non-linguistic symbols, ' - 'e.g., etc. One entry per line.') - parser.add_argument('text', type=str, nargs='?', - help='input text') + parser.add_argument( + "--skip-ncols", default=0, type=int, help="skip first n columns" + ) + parser.add_argument( + "--space", default="", type=str, help="space symbol" + ) + parser.add_argument( + "--ends-with-space", default=True, type=bool, + help="whether to append to the end of each tokenized sentence." + ) + parser.add_argument( + "--non-lang-syms", default=None, type=str, + help="path to a file listing non-linguistic symbols, " + "e.g., etc. One entry per line." 
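# Editor's sketch: a toy re-statement of the Node layout described in the
# docstring above, avoiding the mutable-default pitfall (children={}). The
# subword ids and word index below are made up and only illustrate how the
# (first - 1, last) word_set range could grow as words are inserted in
# lexical order; the real tree is built from fairseq dictionaries.
class _ToyNode:
    def __init__(self, word_idx=-1, word_set=None):
        self.children = {}        # subword_idx -> _ToyNode
        self.word_idx = word_idx  # >= 0 only if a complete word ends here
        self.word_set = word_set  # (first - 1, last): word-index range sharing this prefix

root = _ToyNode()
widx, subword_ids = 7, [3, 5]  # hypothetical word index and its subword tokenization
node = root
for sid in subword_ids:
    node = node.children.setdefault(sid, _ToyNode())
    # words arrive in increasing index order, so only the upper bound grows
    node.word_set = (widx - 1, widx) if node.word_set is None else (node.word_set[0], widx)
node.word_idx = widx  # mark the end of a complete word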
+ ) + parser.add_argument( + "text", type=str, nargs="?", help="input text" + ) # fmt: on return parser @@ -34,29 +41,29 @@ def get_parser(): def main(args): nls = None if args.non_lang_syms is not None: - with open(args.non_lang_syms, 'r', encoding='utf-8') as f: + with open(args.non_lang_syms, "r", encoding="utf-8") as f: nls = [x.rstrip() for x in f.readlines()] - with (open(args.text, 'r', encoding='utf-8') if args.text else sys.stdin) as f: + with (open(args.text, "r", encoding="utf-8") if args.text else sys.stdin) as f: for line in f: entry = line.rstrip().split() tokenized = tokenize( - ' '.join(entry[args.skip_ncols:]), + " ".join(entry[args.skip_ncols:]), space=args.space, non_lang_syms=nls, ) if args.skip_ncols > 0: if args.ends_with_space: - print(' '.join(entry[:args.skip_ncols]) + ' ' + tokenized + ' ' + args.space) + print(" ".join(entry[: args.skip_ncols]) + " " + tokenized + " " + args.space) else: - print(' '.join(entry[:args.skip_ncols]) + ' ' + tokenized) + print(" ".join(entry[: args.skip_ncols]) + " " + tokenized) else: if args.ends_with_space: - print(tokenized + ' ' + args.space) + print(tokenized + " " + args.space) else: print(tokenized) -if __name__ == '__main__': +if __name__ == "__main__": parser = get_parser() args = parser.parse_args() main(args) diff --git a/espresso/tools/text2vocabulary.py b/espresso/tools/text2vocabulary.py index 047d439575..2a054e4f90 100755 --- a/espresso/tools/text2vocabulary.py +++ b/espresso/tools/text2vocabulary.py @@ -12,50 +12,56 @@ logging.basicConfig( - format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, stream=sys.stderr, ) -logger = logging.getLogger('espresso.tools.text2vocabulary') +logger = logging.getLogger("espresso.tools.text2vocabulary") def get_parser(): - parser = argparse.ArgumentParser( - description='Create a vocabulary from text files') + parser = argparse.ArgumentParser(description="Create a vocabulary from text files") # fmt: off - parser.add_argument('--skip-ncols', default=0, type=int, - help='skip first n columns') - parser.add_argument('--cutoff', default=0, type=int, - help='cut-off frequency') - parser.add_argument('--vocabsize', default=20000, type=int, - help='vocabulary size') - parser.add_argument('--exclude', type=str, default=None, - help='space separated, list of excluding words, ' - 'e.g., etc.') - parser.add_argument('--vocab', type=str, default=None, - help='path to the vocabulary file. If not None, calculate' - 'OOV stats with the provided vocabulary and output the ' - 'same vocabulary with word counts') - parser.add_argument('--valid-text', type=str, default=None, - help='path to the validation text file') - parser.add_argument('--test-text', type=str, default=None, - help='colon separated paths to the test text file(s)') - parser.add_argument('text_files', nargs='*', - help='input text files') + parser.add_argument( + "--skip-ncols", default=0, type=int, help="skip first n columns" + ) + parser.add_argument( + "--cutoff", default=0, type=int, help="cut-off frequency" + ) + parser.add_argument( + "--vocabsize", default=20000, type=int, help="vocabulary size" + ) + parser.add_argument( + "--exclude", type=str, default=None, + help="space separated, list of excluding words, e.g., etc." + ) + parser.add_argument( + "--vocab", type=str, default=None, + help="path to the vocabulary file. 
If not None, calculate OOV stats with " + "the provided vocabulary and output the same vocabulary with word counts" + ) + parser.add_argument( + "--valid-text", type=str, default=None, help="path to the validation text file" + ) + parser.add_argument( + "--test-text", type=str, default=None, + help="colon separated paths to the test text file(s)" + ) + parser.add_argument("text_files", nargs="*", help="input text files") # fmt: on return parser def main(args): - exclude = args.exclude.split(' ') if args.exclude is not None else [] + exclude = args.exclude.split(" ") if args.exclude is not None else [] if len(args.text_files) == 0: - args.text_files.append('-') + args.text_files.append("-") counter = Counter() for fn in args.text_files: - with (open(fn, 'r', encoding='utf-8') if fn != '-' else sys.stdin) as f: + with (open(fn, "r", encoding="utf-8") if fn != "-" else sys.stdin) as f: for line in f: tokens = line.rstrip().split()[args.skip_ncols:] tokens = [tok for tok in tokens if tok not in exclude] @@ -75,8 +81,8 @@ def main(args): most_common = most_common[:cutoff_point] vocab_set = set(list(zip(*most_common))[0]) else: - logger.info('using the provided vocabulary:') - with open(args.vocab, 'r', encoding='utf-8') as f: + logger.info("using the provided vocabulary:") + with open(args.vocab, "r", encoding="utf-8") as f: vocab_set = set([line.rstrip().split()[0] for line in f]) most_common = [] for word in vocab_set: @@ -85,46 +91,46 @@ def main(args): # words in vocabulary are lexically sorted for w, c in sorted(most_common, key=lambda x: x[0]): - print('{} {:d}'.format(w, c)) + print("{} {:d}".format(w, c)) - oov_rate = 1. - float(invocab_count) / total_count - logger.info('training set:') - logger.info(' total #tokens={:d}'.format(total_count)) - logger.info(' OOV rate={:.2f}%'.format(oov_rate * 100)) + oov_rate = 1.0 - float(invocab_count) / total_count + logger.info("training set:") + logger.info(" total #tokens={:d}".format(total_count)) + logger.info(" OOV rate={:.2f}%".format(oov_rate * 100)) if args.vocab is None: - logger.info(' cutoff frequency={:d}'.format(cutoff_freq)) + logger.info(" cutoff frequency={:d}".format(cutoff_freq)) if args.valid_text is not None: total_count = 0 invocab_count = 0 - with open(args.valid_text, 'r', encoding='utf-8') as f: + with open(args.valid_text, "r", encoding="utf-8") as f: for line in f: tokens = line.rstrip().split()[args.skip_ncols:] tokens = [tok for tok in tokens if tok not in exclude] total_count += len(tokens) invocab_count += len([tok for tok in tokens if tok in vocab_set]) - oov_rate = 1. - float(invocab_count) / total_count - logger.info('validation set:') - logger.info(' total #tokens={:d}'.format(total_count)) - logger.info(' OOV rate={:.2f}%'.format(oov_rate * 100)) + oov_rate = 1.0 - float(invocab_count) / total_count + logger.info("validation set:") + logger.info(" total #tokens={:d}".format(total_count)) + logger.info(" OOV rate={:.2f}%".format(oov_rate * 100)) if args.test_text is not None: for k, path in enumerate(args.test_text.split(os.pathsep)): total_count = 0 invocab_count = 0 - with open(path, 'r', encoding='utf-8') as f: + with open(path, "r", encoding="utf-8") as f: for line in f: tokens = line.rstrip().split()[args.skip_ncols:] tokens = [tok for tok in tokens if tok not in exclude] total_count += len(tokens) invocab_count += len([tok for tok in tokens if tok in vocab_set]) - oov_rate = 1. 
- float(invocab_count) / total_count - logger.info('test set{}:'.format(k) if k > 0 else 'test set:') - logger.info(' total #tokens={:d}'.format(total_count)) - logger.info(' OOV rate={:.2f}%'.format(oov_rate * 100)) + oov_rate = 1.0 - float(invocab_count) / total_count + logger.info("test set{}:".format(k) if k > 0 else "test set:") + logger.info(" total #tokens={:d}".format(total_count)) + logger.info(" OOV rate={:.2f}%".format(oov_rate * 100)) -if __name__ == '__main__': +if __name__ == "__main__": parser = get_parser() args = parser.parse_args() main(args) diff --git a/espresso/tools/utils.py b/espresso/tools/utils.py index 19fe51ab5e..0ee0b728d1 100644 --- a/espresso/tools/utils.py +++ b/espresso/tools/utils.py @@ -13,15 +13,15 @@ from fairseq import utils -def tokenize(sent, space='', non_lang_syms=None): +def tokenize(sent, space="", non_lang_syms=None): assert isinstance(sent, str) - sent = ' '.join(sent.strip().split()) + sent = " ".join(sent.strip().split()) match_pos = [] if non_lang_syms is not None: assert isinstance(non_lang_syms, list) if len(non_lang_syms) > 0: - prog = re.compile('|'.join(map(re.escape, non_lang_syms))) + prog = re.compile("|".join(map(re.escape, non_lang_syms))) matches = prog.finditer(sent) for match in matches: match_pos.append([match.start(), match.end()]) @@ -34,8 +34,8 @@ def tokenize(sent, space='', non_lang_syms=None): i = end_pos tokens.extend([token for token in sent[i:]]) - tokens = [space if token == ' ' else token for token in tokens] - return ' '.join(tokens) + tokens = [space if token == " " else token for token in tokens] + return " ".join(tokens) def collate_frames(values, pad_value=0.0, left_pad=False, pad_to_length=None, pad_to_multiple=1): @@ -119,7 +119,7 @@ def plot_attention(attention, hypo_sent, utt_id, save_dir): """ try: import matplotlib as mpl - mpl.use('Agg') + mpl.use("Agg") import matplotlib.pyplot as plt except ImportError: raise ImportError( @@ -131,8 +131,8 @@ def plot_attention(attention, hypo_sent, utt_id, save_dir): attn = attention.data.numpy() plt.matshow(attn) plt.title(hypo_sent, fontsize=8) - filename = os.path.join(save_dir, utt_id + '.pdf') - plt.savefig(filename, bbox_inches='tight') + filename = os.path.join(save_dir, utt_id + ".pdf") + plt.savefig(filename, bbox_inches="tight") plt.close() @@ -149,8 +149,8 @@ def edit_distance(ref, hyp): dist: edit distance matrix of size len(ref) x len(hyp) steps: list of edit steps counter: object of collections.Counter containing counts of - reference words ('words'), number of correct words ('corr'), - substitutions ('sub'), insertions ('ins'), deletions ('del'). + reference words ("words"), number of correct words ("corr"), + substitutions ("sub"), insertions ("ins"), deletions ("del"). 
""" assert isinstance(ref, list) and isinstance(hyp, list) @@ -182,23 +182,23 @@ def edit_distance(ref, hyp): i >= 1 and j >= 1 and dist[i][j] == dist[i - 1][j - 1] and ref[i - 1] == hyp[j - 1] ): - steps.append('corr') + steps.append("corr") i, j = i - 1, j - 1 elif i >= 1 and j >= 1 and dist[i][j] == dist[i - 1][j - 1] + 1: assert ref[i - 1] != hyp[j - 1] - steps.append('sub') + steps.append("sub") i, j = i - 1, j - 1 elif j >= 1 and dist[i][j] == dist[i][j - 1] + 1: - steps.append('ins') + steps.append("ins") j = j - 1 else: assert i >= 1 and dist[i][j] == dist[i - 1][j] + 1 - steps.append('del') + steps.append("del") i = i - 1 steps = steps[::-1] counter = Counter( - {'words': len(ref), 'corr': 0, 'sub': 0, 'ins': 0, 'del': 0} + {"words": len(ref), "corr": 0, "sub": 0, "ins": 0, "del": 0} ) counter.update(steps) @@ -212,7 +212,7 @@ def aligned_print(ref, hyp, steps): Args: ref: list of words obtained by splitting reference sentence string hyp: list of words obtained by splitting hypothesis sentence string - steps: list of edit steps with elements 'corr', 'sub', 'ins' or 'del'. + steps: list of edit steps with elements "corr", "sub", "ins" or "del". Return: out_str: aligned reference and hypothesis string with edit steps. @@ -223,70 +223,76 @@ def aligned_print(ref, hyp, steps): if len(steps) == 0: # in case both ref and hyp are empty assert len(ref) == 0 and len(hyp) == 0 - out_str = 'REF: \nHYP: \nSTP: \nWER: {:.2f}%\n\n'.format(0.) + out_str = "REF: \nHYP: \nSTP: \nWER: {:.2f}%\n\n".format(0.0) return out_str - out_str = 'REF: ' + out_str = "REF: " for i in range(len(steps)): - delim = ' ' if i < len(steps) - 1 else '\n' - if steps[i] == 'sub': - ref_idx = i - steps[:i].count('ins') - hyp_idx = i - steps[:i].count('del') + delim = " " if i < len(steps) - 1 else "\n" + if steps[i] == "sub": + ref_idx = i - steps[: i].count("ins") + hyp_idx = i - steps[: i].count("del") if len(ref[ref_idx]) < len(hyp[hyp_idx]): - out_str += ref[ref_idx] + \ - ' ' * (len(hyp[hyp_idx]) - len(ref[ref_idx])) + delim + out_str += ( + ref[ref_idx] + " " * (len(hyp[hyp_idx]) - len(ref[ref_idx])) + delim + ) else: out_str += ref[ref_idx] + delim - elif steps[i] == 'ins': - idx = i - steps[:i].count('del') - out_str += ' ' * len(hyp[idx]) + delim + elif steps[i] == "ins": + idx = i - steps[: i].count("del") + out_str += " " * len(hyp[idx]) + delim else: - assert steps[i] == 'del' or steps[i] == 'corr' - idx = i - steps[:i].count('ins') + assert steps[i] == "del" or steps[i] == "corr" + idx = i - steps[: i].count("ins") out_str += ref[idx] + delim - out_str += 'HYP: ' + out_str += "HYP: " for i in range(len(steps)): - delim = ' ' if i < len(steps) - 1 else '\n' - if steps[i] == 'sub': - ref_idx = i - steps[:i].count('ins') - hyp_idx = i - steps[:i].count('del') + delim = " " if i < len(steps) - 1 else "\n" + if steps[i] == "sub": + ref_idx = i - steps[: i].count("ins") + hyp_idx = i - steps[: i].count("del") if len(ref[ref_idx]) > len(hyp[hyp_idx]): - out_str += hyp[hyp_idx] + \ - ' ' * (len(ref[ref_idx]) - len(hyp[hyp_idx])) + delim + out_str += ( + hyp[hyp_idx] + " " * (len(ref[ref_idx]) - len(hyp[hyp_idx])) + + delim + ) else: out_str += hyp[hyp_idx] + delim - elif steps[i] == 'del': - idx = i - steps[:i].count('ins') - out_str += ' ' * len(ref[idx]) + delim + elif steps[i] == "del": + idx = i - steps[: i].count("ins") + out_str += " " * len(ref[idx]) + delim else: - assert steps[i] == 'ins' or steps[i] == 'corr' - idx = i - steps[:i].count('del') + assert steps[i] == "ins" or steps[i] == "corr" + idx = i - 
steps[: i].count("del") out_str += hyp[idx] + delim - out_str += 'STP: ' + out_str += "STP: " for i in range(len(steps)): - delim = ' ' if i < len(steps) - 1 else '\n' - if steps[i] == 'sub': - ref_idx = i - steps[:i].count('ins') - hyp_idx = i - steps[:i].count('del') + delim = " " if i < len(steps) - 1 else "\n" + if steps[i] == "sub": + ref_idx = i - steps[: i].count("ins") + hyp_idx = i - steps[: i].count("del") if len(ref[ref_idx]) > len(hyp[hyp_idx]): - out_str += 'S' + ' ' * (len(ref[ref_idx]) - 1) + delim + out_str += "S" + " " * (len(ref[ref_idx]) - 1) + delim else: - out_str += 'S' + ' ' * (len(hyp[hyp_idx]) - 1) + delim - elif steps[i] == 'ins': - idx = i - steps[:i].count('del') - out_str += 'I' + ' ' * (len(hyp[idx]) - 1) + delim + out_str += "S" + " " * (len(hyp[hyp_idx]) - 1) + delim + elif steps[i] == "ins": + idx = i - steps[: i].count("del") + out_str += "I" + " " * (len(hyp[idx]) - 1) + delim else: - assert steps[i] == 'del' or steps[i] == 'corr' - idx = i - steps[:i].count('ins') - sym = 'D' if steps[i] == 'del' else ' ' - out_str += sym + ' ' * (len(ref[idx]) - 1) + delim + assert steps[i] == "del" or steps[i] == "corr" + idx = i - steps[: i].count("ins") + sym = "D" if steps[i] == "del" else " " + out_str += sym + " " * (len(ref[idx]) - 1) + delim counter = Counter(steps) - wer = float(counter['sub'] + counter['ins'] + counter['del']) / len(ref) \ - * 100 if len(ref) > 0 else 0. - out_str += 'WER: ' + '{:.2f}%'.format(wer) + '\n' - out_str += '\n' + wer = ( + float(counter["sub"] + counter["ins"] + counter["del"]) / len(ref) * 100 + if len(ref) > 0 + else 0.0 + ) + out_str += "WER: " + "{:.2f}%".format(wer) + "\n" + out_str += "\n" return out_str diff --git a/espresso/tools/wer.py b/espresso/tools/wer.py index 5314682d94..839a9b0c24 100644 --- a/espresso/tools/wer.py +++ b/espresso/tools/wer.py @@ -32,52 +32,54 @@ def reset(self): def parse_wer_output_filter(self, wer_output_filter): if wer_output_filter: - with open(PathManager.get_local_path(wer_output_filter), 'r', encoding='utf-8') as f: + with open(PathManager.get_local_path(wer_output_filter), "r", encoding="utf-8") as f: for line in f: line = line.strip() - if line.startswith('#!') or line == '': + if line.startswith("#!") or line == "": continue - elif line.startswith('s/'): - m = re.match(r's/(.+)/(.*)/g', line) + elif line.startswith("s/"): + m = re.match(r"s/(.+)/(.*)/g", line) assert m is not None self.word_filters.append([m.group(1), m.group(2)]) - elif line.startswith('s:'): - m = re.match(r's:(.+):(.*):g', line) + elif line.startswith("s:"): + m = re.match(r"s:(.+):(.*):g", line) assert m is not None self.word_filters.append([m.group(1), m.group(2)]) else: - logger.warning('Unsupported pattern: "{}". Ignoring it'.format(line)) + logger.warning("Unsupported pattern: '{}'. 
Ignoring it".format(line)) def add_prediction(self, utt_id, pred): if not isinstance(utt_id, str): - raise TypeError('utt_id must be a string(got {})'.format(type(utt_id))) + raise TypeError("utt_id must be a string(got {})".format(type(utt_id))) if not isinstance(pred, str): - raise TypeError('pred must be a string(got {})'.format(type(pred))) + raise TypeError("pred must be a string(got {})".format(type(pred))) - assert utt_id not in self.char_results, \ - 'Duplicated utterance id detected: {}'.format(utt_id) - self.char_results[utt_id] = pred + '\n' + assert ( + utt_id not in self.char_results + ), "Duplicated utterance id detected: {}".format(utt_id) + self.char_results[utt_id] = pred + "\n" pred_words = self.dictionary.wordpiece_decode(pred) - assert utt_id not in self.results, \ - 'Duplicated utterance id detected: {}'.format(utt_id) - self.results[utt_id] = pred_words + '\n' + assert ( + utt_id not in self.results + ), "Duplicated utterance id detected: {}".format(utt_id) + self.results[utt_id] = pred_words + "\n" def add_evaluation(self, utt_id, ref, pred): if not isinstance(utt_id, str): - raise TypeError('utt_id must be a string(got {})'.format(type(utt_id))) + raise TypeError("utt_id must be a string(got {})".format(type(utt_id))) if not isinstance(ref, str): - raise TypeError('ref must be a string (got {})'.format(type(ref))) + raise TypeError("ref must be a string (got {})".format(type(ref))) if not isinstance(pred, str): - raise TypeError('pred must be a string(got {})'.format(type(pred))) + raise TypeError("pred must be a string(got {})".format(type(pred))) # filter out any non_lang_syms from ref and pred - non_lang_syms = getattr(self.dictionary, 'non_lang_syms', None) + non_lang_syms = getattr(self.dictionary, "non_lang_syms", None) assert non_lang_syms is None or isinstance(non_lang_syms, list) if non_lang_syms is not None and len(non_lang_syms) > 0: ref_list, pred_list = ref.strip().split(), pred.strip().split() - ref = ' '.join([x for x in ref_list if x not in non_lang_syms]) - pred = ' '.join([x for x in pred_list if x not in non_lang_syms]) + ref = " ".join([x for x in ref_list if x not in non_lang_syms]) + pred = " ".join([x for x in pred_list if x not in non_lang_syms]) # char level counts _, _, counter = speech_utils.edit_distance( @@ -99,45 +101,48 @@ def add_evaluation(self, utt_id, ref, pred): ref_word_list, pred_word_list, ) self.word_counter += counter - assert utt_id not in self.aligned_results, \ - 'Duplicated utterance id detected: {}'.format(utt_id) + assert ( + utt_id not in self.aligned_results + ), "Duplicated utterance id detected: {}".format(utt_id) self.aligned_results[utt_id] = speech_utils.aligned_print( ref_word_list, pred_word_list, steps, ) def cer(self): - assert self.char_counter['words'] > 0 + assert self.char_counter["words"] > 0 cer = float( - self.char_counter['sub'] + self.char_counter['ins'] + self.char_counter['del'] - ) / self.char_counter['words'] * 100 - sub = float(self.char_counter['sub']) / self.char_counter['words'] * 100 - ins = float(self.char_counter['ins']) / self.char_counter['words'] * 100 - dlt = float(self.char_counter['del']) / self.char_counter['words'] * 100 + self.char_counter["sub"] + self.char_counter["ins"] + self.char_counter["del"] + ) / self.char_counter["words"] * 100 + sub = float(self.char_counter["sub"]) / self.char_counter["words"] * 100 + ins = float(self.char_counter["ins"]) / self.char_counter["words"] * 100 + dlt = float(self.char_counter["del"]) / self.char_counter["words"] * 100 return cer, sub, ins, 
dlt def wer(self): - assert self.word_counter['words'] > 0 + assert self.word_counter["words"] > 0 wer = float( - self.word_counter['sub'] + self.word_counter['ins'] + self.word_counter['del'] - ) / self.word_counter['words'] * 100 - sub = float(self.word_counter['sub']) / self.word_counter['words'] * 100 - ins = float(self.word_counter['ins']) / self.word_counter['words'] * 100 - dlt = float(self.word_counter['del']) / self.word_counter['words'] * 100 + self.word_counter["sub"] + self.word_counter["ins"] + self.word_counter["del"] + ) / self.word_counter["words"] * 100 + sub = float(self.word_counter["sub"]) / self.word_counter["words"] * 100 + ins = float(self.word_counter["ins"]) / self.word_counter["words"] * 100 + dlt = float(self.word_counter["del"]) / self.word_counter["words"] * 100 return wer, sub, ins, dlt def tot_word_error(self): - return self.word_counter['sub'] + self.word_counter['ins'] + \ - self.word_counter['del'] + return ( + self.word_counter["sub"] + self.word_counter["ins"] + self.word_counter["del"] + ) def tot_word_count(self): - return self.word_counter['words'] + return self.word_counter["words"] def tot_char_error(self): - return self.char_counter['sub'] + self.char_counter['ins'] + \ - self.char_counter['del'] + return ( + self.char_counter["sub"] + self.char_counter["ins"] + self.char_counter["del"] + ) def tot_char_count(self): - return self.char_counter['words'] + return self.char_counter["words"] def add_ordered_utt_list(self, *args): if len(args) == 1 and isinstance(args[0], list): # aleady a list of utterance ids @@ -145,7 +150,7 @@ def add_ordered_utt_list(self, *args): return self.ordered_utt_list = [] for text_file in args: - with open(PathManager.get_local_path(text_file), 'r', encoding='utf-8') as f: + with open(PathManager.get_local_path(text_file), "r", encoding="utf-8") as f: one_utt_list = [line.strip().split()[0] for line in f] self.ordered_utt_list.extend(one_utt_list) if len(self.char_results): @@ -156,34 +161,34 @@ def add_ordered_utt_list(self, *args): assert set(self.ordered_utt_list) == set(self.aligned_results.keys()) def print_char_results(self): - res = '' + res = "" if self.ordered_utt_list is not None: assert set(self.ordered_utt_list) == set(self.char_results.keys()) for utt_id in self.ordered_utt_list: - res += utt_id + ' ' + self.char_results[utt_id] + res += utt_id + " " + self.char_results[utt_id] else: for utt_id in self.char_results: - res += utt_id + ' ' + self.char_results[utt_id] + res += utt_id + " " + self.char_results[utt_id] return res def print_results(self): - res = '' + res = "" if self.ordered_utt_list is not None: assert set(self.ordered_utt_list) == set(self.results.keys()) for utt_id in self.ordered_utt_list: - res += utt_id + ' ' + self.results[utt_id] + res += utt_id + " " + self.results[utt_id] else: for utt_id in self.results: - res += utt_id + ' ' + self.results[utt_id] + res += utt_id + " " + self.results[utt_id] return res def print_aligned_results(self): - res = '' + res = "" if self.ordered_utt_list is not None: assert set(self.ordered_utt_list) == set(self.aligned_results.keys()) for utt_id in self.ordered_utt_list: - res += utt_id + '\n' + self.aligned_results[utt_id] + res += utt_id + "\n" + self.aligned_results[utt_id] else: for utt_id in self.aligned_results: - res += utt_id + '\n' + self.aligned_results[utt_id] + res += utt_id + "\n" + self.aligned_results[utt_id] return res diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 13fc0e98e1..786acca27d 100755 --- 
a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -153,7 +153,7 @@ if [ ${stage} -le 4 ]; then for dataset in $test_set; do test_paths="$test_paths $lmdatadir/$dataset.tokens"; done test_paths=$(echo $test_paths | awk '{$1=$1;print}' | tr ' ' ',') ${decode_cmd} $lmdatadir/log/preprocess.log \ - python3 ../../fairseq_cli/preprocess.py --user-dir espresso --task language_modeling_for_asr \ + python3 ../../fairseq_cli/preprocess.py --task language_modeling_for_asr \ --workers 50 --srcdict $lmdict --only-source \ --trainpref $lmdatadir/train.tokens \ --validpref $lmdatadir/$valid_set.tokens \ @@ -173,7 +173,7 @@ if [ ${stage} -le 5 ]; then mkdir -p $lmdir/log log_file=$lmdir/log/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 \ --task language_modeling_for_asr --dict $lmdict \ --log-interval $((16000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 32000 --batch-size 1024 --curriculum 1 \ @@ -194,7 +194,7 @@ if [ ${stage} -le 6 ]; then test_set_array=($test_set) for i in $(seq 0 $num); do log_file=$lmdir/log/evaluation_${test_set_array[$i]}.log - python3 ../../fairseq_cli/eval_lm.py $lmdatadir --user-dir espresso --cpu \ + python3 ../../fairseq_cli/eval_lm.py $lmdatadir --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset ${gen_set_array[$i]} \ --max-tokens 40960 --batch-size 1536 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file @@ -245,7 +245,7 @@ if [ ${stage} -le 8 ]; then opts="$opts --max-epoch 30 --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 10" fi fi - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 \ --log-interval $((8000/ngpus/update_freq)) --log-format simple --print-training-sample-interval $((4000/ngpus/update_freq)) \ --num-workers 0 --data-buffer-size 0 --max-tokens 26000 --batch-size 24 --curriculum 1 --empty-cache-freq 50 \ --valid-subset $valid_subset --batch-size-valid 48 --ddp-backend no_c10d --update-freq $update_freq \ @@ -276,7 +276,7 @@ if [ ${stage} -le 9 ]; then for dataset in $test_set; do decode_dir=$dir/decode_$dataset${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ - --task speech_recognition_espresso --user-dir espresso --max-tokens 15000 --batch-size 24 \ + --task speech_recognition_espresso --max-tokens 15000 --batch-size 24 \ --num-shards 1 --shard-id 0 --dict $dict --bpe sentencepiece --sentencepiece-model ${sentencepiece_model}.model \ --gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 60 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \ diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 95ab5403ad..f3be25a986 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -191,7 +191,7 @@ if [ $stage -le 3 ]; then test_paths= && for dataset in $test_set; do test_paths="$test_paths $lmdatadir/$dataset.tokens"; done test_paths=$(echo $test_paths | awk '{$1=$1;print}' | tr ' ' ',') ${decode_cmd} $lmdatadir/log/preprocess.log \ - python3 ../../fairseq_cli/preprocess.py --user-dir espresso --task 
language_modeling_for_asr \ + python3 ../../fairseq_cli/preprocess.py --task language_modeling_for_asr \ --workers 50 --srcdict $lmdict --only-source \ --trainpref $lmdatadir/train.tokens \ --validpref $lmdatadir/$valid_set.tokens \ @@ -211,7 +211,7 @@ if [ $stage -le 4 ]; then mkdir -p $lmdir/log log_file=$lmdir/log/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 \ --task language_modeling_for_asr --dict $lmdict \ --log-interval $((1000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 25600 --batch-size 1024 \ @@ -232,7 +232,7 @@ if [ $stage -le 5 ]; then test_set_array=($test_set) for i in $(seq 0 $num); do log_file=$lmdir/log/evaluation_${test_set_array[$i]}.log - python3 ../../fairseq_cli/eval_lm.py $lmdatadir --user-dir espresso --cpu \ + python3 ../../fairseq_cli/eval_lm.py $lmdatadir --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset ${gen_set_array[$i]} \ --max-tokens 40960 --batch-size 1536 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file @@ -284,7 +284,7 @@ if [ $stage -le 7 ]; then opts="$opts --max-epoch 35 --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 14" fi fi - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 \ --log-interval $((3000/ngpus/update_freq)) --log-format simple --print-training-sample-interval $((4000/ngpus/update_freq)) \ --num-workers 0 --data-buffer-size 0 --max-tokens 26000 --batch-size 48 --curriculum 2 --empty-cache-freq 50 \ --valid-subset $valid_subset --batch-size-valid 64 --ddp-backend no_c10d --update-freq $update_freq \ @@ -313,7 +313,7 @@ if [ $stage -le 8 ]; then for dataset in $test_set; do decode_dir=$dir/decode_${dataset}${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ - --task speech_recognition_espresso --user-dir espresso --max-tokens 24000 --batch-size 48 \ + --task speech_recognition_espresso --max-tokens 24000 --batch-size 48 \ --num-shards 1 --shard-id 0 --dict $dict --bpe sentencepiece --sentencepiece-model ${sentencepiece_model}.model \ --non-lang-syms $nlsyms --gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 35 --max-len-a 0.1 --max-len-b 0 --lenpen 1.0 \ diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index ae3ea5fe75..c0bba6d613 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -158,7 +158,7 @@ if [ ${stage} -le 3 ]; then echo "$0: binarizing char text..." mkdir -p $lmdatadir/log ${decode_cmd} $lmdatadir/log/preprocess.log \ - python3 ../../fairseq_cli/preprocess.py --user-dir espresso --task language_modeling_for_asr \ + python3 ../../fairseq_cli/preprocess.py --task language_modeling_for_asr \ --workers 30 --srcdict $lmdict --only-source \ --trainpref $lmdatadir/train.tokens \ --validpref $lmdatadir/$valid_set.tokens \ @@ -168,7 +168,7 @@ if [ ${stage} -le 3 ]; then echo "$0: binarizing word text..." 
mkdir -p $wordlmdatadir/log ${decode_cmd} $wordlmdatadir/log/preprocess.log \ - python3 ../../fairseq_cli/preprocess.py --user-dir espresso --task language_modeling_for_asr \ + python3 ../../fairseq_cli/preprocess.py --task language_modeling_for_asr \ --workers 30 --srcdict $wordlmdict --only-source \ --trainpref $wordlmdatadir/train \ --validpref $wordlmdatadir/$valid_set \ @@ -189,7 +189,7 @@ if [ ${stage} -le 4 ] && ! $use_wordlm; then mkdir -p $lmdir/log log_file=$lmdir/log/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 \ --task language_modeling_for_asr --dict $lmdict \ --log-interval $((4000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 25600 --batch-size 128 \ @@ -206,7 +206,7 @@ if [ ${stage} -le 5 ] && ! $use_wordlm; then echo "Stage 5: char LM Evaluation" for gen_subset in valid test; do log_file=$lmdir/log/evaluation_$gen_subset.log - python3 ../../fairseq_cli/eval_lm.py $lmdatadir --user-dir espresso --cpu \ + python3 ../../fairseq_cli/eval_lm.py $lmdatadir --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset $gen_subset \ --max-tokens 192000 --batch-size 256 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file @@ -219,7 +219,7 @@ if [ ${stage} -le 6 ] && $use_wordlm; then mkdir -p $wordlmdir/log log_file=$wordlmdir/log/train.log [ -f $wordlmdir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $wordlmdatadir --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $wordlmdatadir --seed 1 \ --task language_modeling_for_asr --dict $wordlmdict \ --log-interval $((4000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 6400 --batch-size 256 \ @@ -237,7 +237,7 @@ if [ ${stage} -le 7 ] && $use_wordlm; then echo "Stage 7: word LM Evaluation" for gen_subset in valid test; do log_file=$wordlmdir/log/evaluation_$gen_subset.log - python3 ../../fairseq_cli/eval_lm.py $wordlmdatadir --user-dir espresso --cpu \ + python3 ../../fairseq_cli/eval_lm.py $wordlmdatadir --cpu \ --task language_modeling_for_asr --dict $wordlmdict --gen-subset $gen_subset \ --max-tokens 12800 --batch-size 512 --sample-break-mode eos \ --path $wordlmdir/$lm_checkpoint 2>&1 | tee $log_file @@ -283,7 +283,7 @@ if [ ${stage} -le 9 ]; then opts="$opts --lr-shrink 0.5 --start-reduce-lr-epoch 11" opts="$opts --scheduled-sampling-probs 0.5 --start-scheduled-sampling-epoch 6" fi - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 \ --log-interval $((800/ngpus/update_freq)) --log-format simple --print-training-sample-interval $((2000/ngpus/update_freq)) \ --num-workers 0 --data-buffer-size 0 --max-tokens 24000 --batch-size 32 --curriculum 2 --empty-cache-freq 50 \ --valid-subset $valid_subset --batch-size-valid 64 --ddp-backend no_c10d --update-freq $update_freq \ @@ -316,7 +316,7 @@ if [ ${stage} -le 10 ]; then for dataset in $valid_set $test_set; do decode_dir=$dir/decode_$dataset${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ - --task speech_recognition_espresso --user-dir espresso 
--max-tokens 20000 --batch-size 32 \ + --task speech_recognition_espresso --max-tokens 20000 --batch-size 32 \ --num-shards 1 --shard-id 0 --dict $dict --bpe characters_asr --non-lang-syms $nlsyms \ --gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 50 --max-len-a 0.2 --max-len-b 0 --lenpen 1.0 \ diff --git a/examples/asr_wsj/run_chain_e2e.sh b/examples/asr_wsj/run_chain_e2e.sh index 3d524db80e..c63240af0b 100755 --- a/examples/asr_wsj/run_chain_e2e.sh +++ b/examples/asr_wsj/run_chain_e2e.sh @@ -185,7 +185,7 @@ if [ ${stage} -le 6 ]; then log_file=$dir/log/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" update_freq=1 - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/chain_e2e --task speech_recognition_hybrid --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/chain_e2e --task speech_recognition_hybrid --seed 1 \ --log-interval $((200/ngpus/update_freq)) --log-format simple \ --num-workers 0 --data-buffer-size 0 --max-tokens 120000 --batch-size 128 --curriculum 1 --empty-cache-freq 50 \ --valid-subset $valid_subset --batch-size-valid 128 --ddp-backend no_c10d --update-freq $update_freq \ @@ -212,7 +212,7 @@ if [ ${stage} -le 7 ]; then for lmtype in tgpr bd_tgpr; do graph_dir=$tree_dir/graph_${lmtype} $decode_cmd $queue_opt JOB=1:$nj $dir/decode_${lmtype}_${data_affix}/log/decode.JOB.log \ - dump_posteriors.py data/chain_e2e --cpu --task speech_recognition_hybrid --user-dir espresso \ + dump_posteriors.py data/chain_e2e --cpu --task speech_recognition_hybrid \ --max-tokens 120000 --batch-size 128 --num-shards 1 --shard-id 0 --num-targets $num_targets \ --gen-subset $dataset.JOB \ --max-source-positions 9999 --path $path \| \ diff --git a/examples/asr_wsj/run_chain_e2e_bichar.sh b/examples/asr_wsj/run_chain_e2e_bichar.sh index 8829f28e6b..23ead11f71 100755 --- a/examples/asr_wsj/run_chain_e2e_bichar.sh +++ b/examples/asr_wsj/run_chain_e2e_bichar.sh @@ -185,7 +185,7 @@ if [ ${stage} -le 6 ]; then log_file=$dir/log/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" update_freq=1 - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/chain_e2e_bichar --task speech_recognition_hybrid --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/chain_e2e_bichar --task speech_recognition_hybrid --seed 1 \ --log-interval $((200/ngpus/update_freq)) --log-format simple \ --num-workers 0 --data-buffer-size 0 --max-tokens 120000 --batch-size 128 --curriculum 1 --empty-cache-freq 50 \ --valid-subset $valid_subset --batch-size-valid 128 --ddp-backend no_c10d --update-freq $update_freq \ @@ -212,7 +212,7 @@ if [ ${stage} -le 7 ]; then for lmtype in tgpr bd_tgpr; do graph_dir=$tree_dir/graph_${lmtype} $decode_cmd $queue_opt JOB=1:$nj $dir/decode_${lmtype}_${data_affix}/log/decode.JOB.log \ - dump_posteriors.py data/chain_e2e_bichar --cpu --task speech_recognition_hybrid --user-dir espresso \ + dump_posteriors.py data/chain_e2e_bichar --cpu --task speech_recognition_hybrid \ --max-tokens 120000 --batch-size 128 --num-shards 1 --shard-id 0 --num-targets $num_targets \ --gen-subset $dataset.JOB \ --max-source-positions 9999 --path $path \| \ diff --git a/examples/asr_wsj/run_xent.sh b/examples/asr_wsj/run_xent.sh index d66b937a3c..e58fdc466a 100755 --- a/examples/asr_wsj/run_xent.sh +++ b/examples/asr_wsj/run_xent.sh @@ -165,7 +165,7 @@ if [ ${stage} -le 5 ]; then log_file=$dir/log/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" update_freq=1 - 
CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/xent --task speech_recognition_hybrid --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/xent --task speech_recognition_hybrid --seed 1 \ --log-interval $((100/ngpus/update_freq)) --log-format simple \ --num-workers 0 --data-buffer-size 0 --max-tokens 160000 --batch-size 256 --empty-cache-freq 50 \ --valid-subset $valid_subset --batch-size-valid 256 --ddp-backend no_c10d --update-freq $update_freq \ @@ -192,10 +192,10 @@ if [ ${stage} -le 6 ]; then for lmtype in tgpr bd_tgpr; do graph_dir=exp/$gmm/graph_${lmtype} $decode_cmd $queue_opt JOB=1:$nj $dir/decode_${lmtype}_${data_affix}/log/decode.JOB.log \ - dump_posteriors.py data/xent --cpu --task speech_recognition_hybrid --user-dir espresso \ + dump_posteriors.py data/xent --cpu --task speech_recognition_hybrid \ --max-tokens 256000 --batch-size 256 --num-shards 1 --shard-id 0 --num-targets $num_targets \ --gen-subset $dataset.JOB --chunk-width 150 --chunk-left-context 10 --chunk-right-context 10 --label-delay -3 \ - --max-source-positions 9999 --path $path --apply-log-softmax \| \ + --max-source-positions 9999 --path $path --apply-log-softmax True \| \ latgen-faster-mapped --max-active=7000 --min-active=20 --beam=15 --lattice-beam=8 --acoustic-scale=0.1 \ --allow-partial=true --word-symbol-table="$graph_dir/words.txt" \ exp/$gmm/final.mdl $graph_dir/HCLG.fst ark:- "ark:|gzip -c >$dir/decode_${lmtype}_${data_affix}/lat.JOB.gz" || exit 1 diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 3bdc6d16d4..1db625f6e2 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -797,6 +797,47 @@ class GenerationConfig(FairseqDataclass): default=False, metadata={"help": "if set, dont use seed for initializing random generators"}, ) + # for espresso.speech_recognize.py + eos_factor: Optional[float] = field( + default=None, + metadata={ + "help": "only consider emitting EOS if its score is no less than " + "the specified factor of the best candidate score" + }, + ) + subwordlm_weight: Optional[float] = field( + default=0.8, + metadata={ + "help": "subword LM weight relative to word LM. Only relevant to " + "MultiLevelLanguageModel as an external LM" + }, + ) + oov_penalty: Optional[float] = field( + default=1e-4, + metadata={"help": "oov penalty with the pretrained external LM"}, + ) + disable_open_vocab: Optional[bool] = field( + default=False, + metadata={ + "help": "whether open vocabulary mode is enabled with the " + "pretrained external LM" + }, + ) + # for espresso.dump_posteriors.py + apply_log_softmax: Optional[bool] = field( + default=False, + metadata={ + "help": "apply log-softmax to the neural network outputs for Xent " + "hybrid systems; otherwise use the raw outputs" + }, + ) + state_prior_file: Optional[str] = field( + default=None, + metadata={ + "help": "state prior file. 
If provided, use this file instead of " + "that from the checkpoint" + }, + ) @dataclass diff --git a/tests/espresso/test_speech_utils.py b/tests/espresso/test_speech_utils.py index 96173be1e9..d00b95c885 100644 --- a/tests/espresso/test_speech_utils.py +++ b/tests/espresso/test_speech_utils.py @@ -13,7 +13,6 @@ import torch from espresso.data import AsrDictionary - import espresso.tools.utils as utils @@ -91,8 +90,11 @@ def test_speech_tokenizer(self): tensor, extra_symbols_to_ignore={self.dictionary.pad()} ) expected_tokens = " ".join( - [token if self.dictionary.index(token) != self.dictionary.unk() else - self.dictionary.unk_word for token in tokens.split(" ")] + [ + token if self.dictionary.index(token) != self.dictionary.unk() + else self.dictionary.unk_word + for token in tokens.split(" ") + ] ) self.assertEqual(reconstructed_tokens, expected_tokens)
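Below is a minimal, hypothetical usage sketch that is not part of the patch: it shows how the edit-distance helpers reformatted above are typically combined to report WER. It assumes espresso is importable and relies only on the `edit_distance()` / `aligned_print()` signatures and the "words"/"sub"/"ins"/"del" counter keys visible in the hunks for espresso/tools/utils.py and espresso/tools/wer.py.

```python
# Hypothetical standalone example (not from the patch): exercising the
# edit-distance helpers whose formatting is cleaned up above.
import espresso.tools.utils as speech_utils

ref = "the cat sat on the mat".split()
hyp = "the cat sit on mat".split()

# edit_distance() returns the DP matrix, the list of alignment steps, and a
# collections.Counter with keys "words", "corr", "sub", "ins" and "del".
_, steps, counter = speech_utils.edit_distance(ref, hyp)

# Same WER formula as Scorer.wer() in espresso/tools/wer.py.
wer = (
    float(counter["sub"] + counter["ins"] + counter["del"]) / counter["words"] * 100
    if counter["words"] > 0
    else 0.0
)
print("WER: {:.2f}%".format(wer))

# aligned_print() renders the REF/HYP/STP alignment string used in decoding logs.
print(speech_utils.aligned_print(ref, hyp, steps))
```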