Commit b3ed99c

code adaptation/changes according to the commits on Oct 18-Nov 3, 2020 (lots of changes, mostly for adapting to hydra configs and code formatting)

freewym committed Nov 4, 2020
1 parent c17beab commit b3ed99c
Showing 41 changed files with 1,563 additions and 1,024 deletions.
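Most of the diff below follows one pattern: per-criterion options move into `FairseqDataclass`-based configs, defaults that used to interpolate `params.optimization.*` now interpolate `optimization.*` (the `params.` prefix is gone from the Hydra config tree), and the dataclass is handed to `register_criterion`. A minimal sketch of that pattern, with a hypothetical `ExampleCriterionConfig`/`ExampleCriterion` that is not part of this commit:

```python
from dataclasses import dataclass, field

from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II


@dataclass
class ExampleCriterionConfig(FairseqDataclass):
    # interpolated from the top-level Hydra config; note there is no "params." prefix
    sentence_avg: bool = II("optimization.sentence_avg")
    label_smoothing: float = field(
        default=0.0, metadata={"help": "epsilon for label smoothing"}
    )


@register_criterion("example_criterion", dataclass=ExampleCriterionConfig)
class ExampleCriterion(FairseqCriterion):
    # constructor arguments are matched against the dataclass field names
    def __init__(self, task, sentence_avg, label_smoothing):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.eps = label_smoothing

    # forward()/reduce_metrics() omitted in this sketch
```

Registering the dataclass also lets `add_args` delegate to `gen_parser_from_dataclass`, as the criterion diffs below show.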
2 changes: 1 addition & 1 deletion README_fairseq.md
@@ -112,7 +112,7 @@ and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more example

# Requirements and Installation

* [PyTorch](http://pytorch.org/) version >= 1.4.0
* [PyTorch](http://pytorch.org/) version >= 1.5.0
* Python version >= 3.6
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* **To install fairseq** and develop locally:
35 changes: 23 additions & 12 deletions espresso/criterions/label_smoothed_cross_entropy_v2.py
@@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from omegaconf import II
import logging
import numpy as np

@@ -16,6 +15,7 @@
from fairseq.data import data_utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass
from omegaconf import II


logger = logging.getLogger(__name__)
@@ -26,7 +26,7 @@

@dataclass
class LabelSmoothedCrossEntropyV2CriterionConfig(FairseqDataclass):
sentence_avg: bool = II("params.optimization.sentence_avg")
sentence_avg: bool = II("optimization.sentence_avg")
label_smoothing: float = field(
default=0.0,
metadata={
@@ -85,7 +85,7 @@ def temporal_label_smoothing_prob_mask(
prob_mask[:, :, padding_index] = 0 # clear cumulative count on <pad>
prob_mask = prob_mask.float() # convert to float
sum_prob = prob_mask.sum(-1, keepdim=True)
sum_prob[sum_prob.squeeze(-1).eq(0.)] = 1. # to deal with the "division by 0" problem
sum_prob[sum_prob.squeeze(-1).eq(0.0)] = 1.0 # to deal with the "division by 0" problem
prob_mask = prob_mask.div_(sum_prob).view(-1, prob_mask.size(-1))
return prob_mask
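The reformatted line above only changes `0.`/`1.` to `0.0`/`1.0`; the guard itself replaces all-zero row sums with 1 so the in-place division leaves those rows at zero instead of producing NaNs. A toy illustration with made-up shapes, not code from this commit:

```python
import torch

# (batch=1, time=2, vocab=3); the second time step is all padding
prob_mask = torch.tensor([[[2.0, 0.0, 2.0], [0.0, 0.0, 0.0]]])
sum_prob = prob_mask.sum(-1, keepdim=True)
sum_prob[sum_prob.squeeze(-1).eq(0.0)] = 1.0  # avoid dividing by zero
prob_mask = prob_mask.div_(sum_prob).view(-1, prob_mask.size(-1))
print(prob_mask)  # [[0.5, 0.0, 0.5], [0.0, 0.0, 0.0]] -- no NaNs
```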

@@ -109,26 +109,32 @@ def label_smoothed_nll_loss(
raise ValueError("Unsupported smoothing type: {}".format(smoothing_type))
if ignore_index is not None:
pad_mask = target.eq(ignore_index)
nll_loss.masked_fill_(pad_mask, 0.)
smooth_loss.masked_fill_(pad_mask, 0.)
nll_loss.masked_fill_(pad_mask, 0.0)
smooth_loss.masked_fill_(pad_mask, 0.0)
else:
nll_loss = nll_loss.squeeze(-1)
smooth_loss = smooth_loss.squeeze(-1)
if reduce:
nll_loss = nll_loss.sum()
smooth_loss = smooth_loss.sum()
eps_i = epsilon / lprobs.size(-1) if smoothing_type == "uniform" else epsilon
loss = (1. - epsilon) * nll_loss + eps_i * smooth_loss
loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
return loss, nll_loss
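The reflowed literals do not change the loss itself: with uniform smoothing, the gold label keeps `1 - epsilon` of the probability mass and the rest is spread evenly over the vocabulary, which is what `eps_i = epsilon / lprobs.size(-1)` encodes. A standalone numeric check of that equivalence on a toy vocabulary (not from this commit):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, eps = 5, 0.1
lprobs = F.log_softmax(torch.randn(1, vocab), dim=-1)
target = torch.tensor([[2]])

nll_loss = -lprobs.gather(dim=-1, index=target)   # -log p(gold)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)   # -sum_v log p(v)
eps_i = eps / vocab
loss = (1.0 - eps) * nll_loss + eps_i * smooth_loss

# same number as cross-entropy against the smoothed target distribution
q = torch.full((1, vocab), eps_i)
q[0, 2] = 1.0 - eps + eps_i
print(torch.allclose(loss.sum(), -(q * lprobs).sum()))  # True
```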


@register_criterion("label_smoothed_cross_entropy_v2", dataclass=LabelSmoothedCrossEntropyV2CriterionConfig)
class LabelSmoothedCrossEntropyV2Criterion(LabelSmoothedCrossEntropyCriterion):

def __init__(
self, task, sentence_avg, label_smoothing, smoothing_type,
print_training_sample_interval, unigram_pseudo_count,
ignore_prefix_size=0, report_accuracy=False,
self,
task,
sentence_avg,
label_smoothing,
smoothing_type,
print_training_sample_interval,
unigram_pseudo_count,
ignore_prefix_size=0,
report_accuracy=False,
):
super().__init__(
task, sentence_avg, label_smoothing,
@@ -149,7 +155,7 @@ def __init__(
@classmethod
def add_args(cls, parser):
"""Add criterion-specific arguments to the parser."""
dc = getattr(cls, '__dataclass', None)
dc = getattr(cls, "__dataclass", None)
if dc is not None:
gen_parser_from_dataclass(parser, dc())

@@ -212,8 +218,13 @@ def compute_loss(
padding_index=self.padding_idx,
) if smoothing_type == "temporal" else None
loss, nll_loss = label_smoothed_nll_loss(
lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce,
smoothing_type=smoothing_type, prob_mask=prob_mask,
lprobs,
target,
self.eps,
ignore_index=self.padding_idx,
reduce=reduce,
smoothing_type=smoothing_type,
prob_mask=prob_mask,
unigram_tensor=self.unigram_tensor,
)
return loss, nll_loss, lprobs
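With the options living in `LabelSmoothedCrossEntropyV2CriterionConfig`, `add_args` above just delegates to `gen_parser_from_dataclass`, which derives command-line flags from the dataclass field names and their `metadata` help strings. Roughly, and only as an illustration of the mechanism (not code from this commit):

```python
import argparse

from espresso.criterions.label_smoothed_cross_entropy_v2 import (
    LabelSmoothedCrossEntropyV2CriterionConfig,
)
from fairseq.dataclass.utils import gen_parser_from_dataclass

parser = argparse.ArgumentParser()
# label_smoothing becomes --label-smoothing, smoothing_type becomes --smoothing-type, etc.
gen_parser_from_dataclass(parser, LabelSmoothedCrossEntropyV2CriterionConfig())
args = parser.parse_args(["--label-smoothing", "0.1"])
print(args.label_smoothing)  # expected: 0.1
```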
9 changes: 5 additions & 4 deletions espresso/criterions/lf_mmi_loss.py
@@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from omegaconf import II
import logging
import math

@@ -14,16 +13,17 @@
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.logging import metrics
from omegaconf import II


logger = logging.getLogger(__name__)


@dataclass
class LatticeFreeMMICriterionConfig(FairseqDataclass):
sentence_avg: bool = II("params.optimization.sentence_avg")
sentence_avg: bool = II("optimization.sentence_avg")
denominator_fst_path: str = field(
default=None, metadata={"help": "path to the denominator fst file"}
default="???", metadata={"help": "path to the denominator fst file"}
)
leaky_hmm_coefficient: float = field(
default=1.0e-05,
@@ -215,10 +215,11 @@ def compute_loss(self, net_output, sample, reduce=True):
def reduce_metrics(cls, logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

# we divide by log(2) to convert the loss from base e to base 2
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=7
)
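The `denominator_fst_path` default changing from `None` to `"???"` uses OmegaConf's marker for a mandatory value: reading the field without providing a value raises an error rather than silently passing `None` downstream. A small sketch of that behavior (illustrative only; the config class and path below are hypothetical):

```python
from dataclasses import dataclass, field

from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue


@dataclass
class DenominatorConfig:
    # "???" marks the field as mandatory in OmegaConf/Hydra
    denominator_fst_path: str = field(
        default="???", metadata={"help": "path to the denominator fst file"}
    )


cfg = OmegaConf.structured(DenominatorConfig)
try:
    _ = cfg.denominator_fst_path
except MissingMandatoryValue:
    print("denominator_fst_path must be set explicitly")

cfg.denominator_fst_path = "exp/chain/den.fst"  # hypothetical path
print(cfg.denominator_fst_path)
```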
50 changes: 32 additions & 18 deletions espresso/data/asr_chain_dataset.py
@@ -12,7 +12,7 @@

import torch

from fairseq.data import data_utils, FairseqDataset
from fairseq.data import FairseqDataset, data_utils

import espresso.tools.utils as speech_utils

@@ -48,12 +48,15 @@ def merge(key, pad_to_length=None):
raise ValueError("Invalid key.")

id = torch.LongTensor([s["id"] for s in samples])
src_frames = merge("source", pad_to_length=pad_to_length["source"] if pad_to_length is not None else None)
src_frames = merge(
"source",
pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
)
# sort by descending source length
if pad_to_length is not None or src_bucketed:
src_lengths = torch.IntTensor([
s["source"].ne(0.0).any(dim=1).int().sum() for s in samples
])
src_lengths = torch.IntTensor(
[s["source"].ne(0.0).any(dim=1).int().sum() for s in samples]
)
else:
src_lengths = torch.IntTensor([s["source"].size(0) for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
@@ -134,8 +137,7 @@ def filter_and_reorder(self, indices):
assert isinstance(indices, (list, np.ndarray))
indices = np.array(indices)
assert all(indices < len(self.utt_ids)) and all(indices >= 0)
assert len(np.unique(indices)) == len(indices), \
"Duplicate elements in indices."
assert len(np.unique(indices)) == len(indices), "Duplicate elements in indices."
self.utt_ids = [self.utt_ids[i] for i in indices]
self.rxfiles = [self.rxfiles[i] for i in indices]
self.numerator_graphs = [self.numerator_graphs[i] for i in indices]
@@ -172,8 +174,15 @@ class AsrChainDataset(FairseqDataset):
"""

def __init__(
self, src, src_sizes, tgt=None, tgt_sizes=None, text=None, shuffle=True,
num_buckets=0, pad_to_multiple=1,
self,
src,
src_sizes,
tgt=None,
tgt_sizes=None,
text=None,
shuffle=True,
num_buckets=0,
pad_to_multiple=1,
):
self.src = src
self.tgt = tgt
@@ -196,10 +205,15 @@ def __init__(
"Removed {} examples due to empty numerator graphs or missing entries, "
"{} remaining".format(num_removed, num_after_matching)
)
self.sizes = np.vstack((self.src_sizes, self.tgt_sizes)).T if self.tgt_sizes is not None else self.src_sizes
self.sizes = (
np.vstack((self.src_sizes, self.tgt_sizes)).T
if self.tgt_sizes is not None
else self.src_sizes
)

if num_buckets > 0:
from espresso.data import FeatBucketPadLengthDataset

self.src = FeatBucketPadLengthDataset(
self.src,
sizes=self.src_sizes,
@@ -215,8 +229,7 @@ def __init__(
num_tokens = np.vectorize(self.num_tokens, otypes=[np.long])
self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
self.buckets = [
(None, num_tokens)
for num_tokens in np.unique(self.bucketed_num_tokens)
(None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
]
else:
self.buckets = None
@@ -293,7 +306,7 @@ def collater(self, samples, pad_to_length=None):
Args:
samples (List[dict]): samples to collate
pad_to_length (dict, optional): a dictionary of
{'source': source_pad_to_length}
{"source": source_pad_to_length}
to indicate the max length to pad to in source and target respectively.
Returns:
Expand Down Expand Up @@ -327,7 +340,10 @@ def num_tokens(self, index):
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
return (
self.src_sizes[index],
self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
)

def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
@@ -339,9 +355,7 @@ def ordered_indices(self):
if self.buckets is None:
# sort by target length, then source length
if self.tgt_sizes is not None:
indices = indices[
np.argsort(self.tgt_sizes[indices], kind="mergesort")
]
indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
else:
# sort by bucketed_num_tokens, which is padded_src_len
@@ -358,7 +372,7 @@ def prefetch(self, indices):
self.src.prefetch(indices)

def filter_indices_by_size(self, indices, max_sizes):
""" Filter a list of sample indices. Remove those that are longer
"""Filter a list of sample indices. Remove those that are longer
than specified in max_sizes.
Args:
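The reflow of `ordered_indices` above keeps the original trick: because `kind="mergesort"` is a stable sort, sorting by target length first and by source length second yields an order that is primarily by source length with ties broken by target length. A toy check of that property (array values are made up):

```python
import numpy as np

src_sizes = np.array([10, 10, 7, 10])
tgt_sizes = np.array([3, 1, 5, 2])

indices = np.arange(len(src_sizes))
indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
indices = indices[np.argsort(src_sizes[indices], kind="mergesort")]

# utterance 2 (source length 7) comes first; the three length-10 utterances
# follow in order of target length (1, 2, 3) thanks to the stable first sort
print(indices)  # [2 1 3 0]
```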
