Commit
code adaptation/changes according to the upstream fairseq commits from Oct 2-15, 2020
freewym committed Nov 5, 2020
1 parent eb7356b commit 7996195
Showing 33 changed files with 282 additions and 421 deletions.
espresso/criterions/__init__.py (4 changes: 2 additions & 2 deletions)
@@ -10,5 +10,5 @@
 # automatically import any Python files in the criterions/ directory
 for file in os.listdir(os.path.dirname(__file__)):
     if not file.startswith("_") and not file.startswith(".") and file.endswith(".py"):
-        criterion_name = file[: file.find(".py")]
-        importlib.import_module("espresso.criterions." + criterion_name)
+        file_name = file[: file.find(".py")]
+        importlib.import_module("espresso.criterions." + file_name)
espresso/criterions/cross_entropy_v2.py (9 changes: 1 addition & 8 deletions)
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.

 from dataclasses import dataclass, field
-from omegaconf import II
 import logging
 import numpy as np

@@ -14,7 +13,6 @@
 from fairseq.criterions import register_criterion
 from fairseq.criterions.cross_entropy import CrossEntropyCriterion, CrossEntropyCriterionConfig
 from fairseq.data import data_utils
-from fairseq.dataclass.utils import gen_parser_from_dataclass


 logger = logging.getLogger(__name__)
@@ -30,7 +28,7 @@ class CrossEntropyV2CriterionConfig(CrossEntropyCriterionConfig):
     )


-@register_criterion("cross_entropy_v2")
+@register_criterion("cross_entropy_v2", dataclass=CrossEntropyV2CriterionConfig)
 class CrossEntropyV2Criterion(CrossEntropyCriterion):

     def __init__(self, task, sentence_avg, print_training_sample_interval):
@@ -41,11 +39,6 @@ def __init__(self, task, sentence_avg, print_training_sample_interval):
         self.epoch = 1
         self.prev_num_updates = -1

-    @staticmethod
-    def add_args(parser):
-        """Add criterion-specific arguments to the parser. Optionally register config store"""
-        gen_parser_from_dataclass(parser, CrossEntropyV2CriterionConfig())
-
     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample; periodically print out
         randomly sampled predictions from the training set.
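Note: the pattern in this file recurs across the criterions in this commit: the per-criterion add_args override is removed and the config dataclass is handed to register_criterion instead, so the base FairseqCriterion can generate CLI arguments and build the criterion from the parsed config. A minimal sketch of that registration pattern, assuming fairseq's late-2020 dataclass API; the criterion name and the my_scale field are illustrative, not part of this commit:

from dataclasses import dataclass, field

from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class ExampleCriterionConfig(FairseqDataclass):
    # each field becomes a CLI argument via the base class's add_args
    my_scale: float = field(default=1.0, metadata={"help": "illustrative scaling factor"})


@register_criterion("example_criterion", dataclass=ExampleCriterionConfig)
class ExampleCriterion(FairseqCriterion):
    def __init__(self, task, my_scale):
        # the field name matches the constructor argument name, which is how
        # the criterion is built from the parsed configuration
        super().__init__(task)
        self.my_scale = my_scale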
espresso/criterions/label_smoothed_cross_entropy_v2.py (51 changes: 36 additions & 15 deletions)
@@ -14,8 +14,8 @@
 from fairseq.criterions import register_criterion
 from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
 from fairseq.data import data_utils
-from fairseq.dataclass.data_class import DDP_BACKEND_CHOICES
-from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass, gen_parser_from_dataclass
+from fairseq.dataclass import ChoiceEnum, FairseqDataclass
+from fairseq.dataclass.utils import gen_parser_from_dataclass


 logger = logging.getLogger(__name__)
@@ -27,13 +27,24 @@
 @dataclass
 class LabelSmoothedCrossEntropyV2CriterionConfig(FairseqDataclass):
     sentence_avg: bool = II("params.optimization.sentence_avg")
-    ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend")
     label_smoothing: float = field(
         default=0.0,
         metadata={
             "help": "epsilon for label smoothing, 0 means no label smoothing"
         },
     )
+    report_accuracy: bool = field(
+        default=False,
+        metadata={
+            "help": "report accuracy metric"
+        },
+    )
+    ignore_prefix_size: int = field(
+        default=0,
+        metadata={
+            "help": "ignore first N tokens"
+        },
+    )
     print_training_sample_interval: int = field(
         default=500,
         metadata={
@@ -111,14 +122,18 @@ def label_smoothed_nll_loss(
     return loss, nll_loss


-@register_criterion("label_smoothed_cross_entropy_v2")
+@register_criterion("label_smoothed_cross_entropy_v2", dataclass=LabelSmoothedCrossEntropyV2CriterionConfig)
 class LabelSmoothedCrossEntropyV2Criterion(LabelSmoothedCrossEntropyCriterion):

     def __init__(
         self, task, sentence_avg, label_smoothing, smoothing_type,
         print_training_sample_interval, unigram_pseudo_count,
+        ignore_prefix_size=0, report_accuracy=False,
     ):
-        super().__init__(task, sentence_avg, label_smoothing)
+        super().__init__(
+            task, sentence_avg, label_smoothing,
+            ignore_prefix_size=ignore_prefix_size, report_accuracy=report_accuracy,
+        )

         self.dictionary = task.target_dictionary
         self.smoothing_type = smoothing_type
@@ -131,10 +146,12 @@ def __init__(
         self.unigram_tensor.div_(self.unigram_tensor.sum())
         self.prev_num_updates = -1

-    @staticmethod
-    def add_args(parser):
-        """Add criterion-specific arguments to the parser. Optionally register config store"""
-        gen_parser_from_dataclass(parser, LabelSmoothedCrossEntropyV2CriterionConfig())
+    @classmethod
+    def add_args(cls, parser):
+        """Add criterion-specific arguments to the parser."""
+        dc = getattr(cls, '__dataclass', None)
+        if dc is not None:
+            gen_parser_from_dataclass(parser, dc())

     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample; periodically print out
@@ -159,6 +176,10 @@ def forward(self, model, sample, reduce=True):
             "nsentences": sample["target"].size(0),
             "sample_size": sample_size,
         }
+        if self.report_accuracy:
+            n_correct, total = self.compute_accuracy(model, net_output, sample)
+            logging_output["n_correct"] = utils.item(n_correct.data)
+            logging_output["total"] = utils.item(total.data)

         if (
             hasattr(model, "num_updates") and model.training and
@@ -168,7 +189,7 @@ def forward(self, model, sample, reduce=True):
         ):  # print a randomly sampled result every print_interval updates
             self.prev_num_updates = model.num_updates
             target = model.get_targets(sample, net_output)
-            pred = lprobs.argmax(-1).cpu()  # bsz x len
+            pred = lprobs.view(target.size(0), -1, lprobs.size(-1)).argmax(-1).cpu()  # bsz x len
             assert pred.size() == target.size()
             with data_utils.numpy_seed(model.num_updates):
                 i = np.random.randint(0, len(sample["id"]))
@@ -184,14 +205,14 @@
     def compute_loss(
         self, model, net_output, sample, reduce=True, smoothing_type="uniform"
     ):
-        lprobs = model.get_normalized_probs(net_output, log_probs=True)
-        target = model.get_targets(sample, net_output)
+        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
+        bsz = sample["target"].size(0)
         prob_mask = temporal_label_smoothing_prob_mask(
-            lprobs, target, padding_index=self.padding_idx,
+            lprobs.view(bsz, -1, lprobs.size(-1)), target.view(bsz, -1),
+            padding_index=self.padding_idx,
         ) if smoothing_type == "temporal" else None
         loss, nll_loss = label_smoothed_nll_loss(
-            lprobs.view(-1, lprobs.size(-1)), target.view(-1, 1), self.eps,
-            ignore_index=self.padding_idx, reduce=reduce,
+            lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce,
             smoothing_type=smoothing_type, prob_mask=prob_mask,
             unigram_tensor=self.unigram_tensor,
         )
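A note on the .view(...) calls introduced in this file: after the switch to the base class's get_lprobs_and_target, lprobs and target appear to arrive flattened over batch and time, so the batch dimension has to be restored wherever per-utterance structure is needed (the temporal smoothing mask and the printed training sample). A small self-contained sketch of that shape bookkeeping, with illustrative sizes:

import torch

bsz, tgt_len, vocab = 2, 5, 100
# assumed flattened shapes, as if returned by get_lprobs_and_target
lprobs = torch.randn(bsz * tgt_len, vocab)
target = torch.randint(0, vocab, (bsz * tgt_len,))

# restore the bsz x len layout when per-utterance structure is needed
lprobs_3d = lprobs.view(bsz, -1, lprobs.size(-1))  # bsz x len x vocab
target_2d = target.view(bsz, -1)  # bsz x len
pred = lprobs_3d.argmax(-1).cpu()  # bsz x len, matches target_2d.size()
assert pred.size() == target_2d.size()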
espresso/criterions/lf_mmi_loss.py (23 changes: 8 additions & 15 deletions)
@@ -12,8 +12,7 @@

 from fairseq import utils
 from fairseq.criterions import FairseqCriterion, register_criterion
-from fairseq.dataclass.data_class import DDP_BACKEND_CHOICES
-from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass, gen_parser_from_dataclass
+from fairseq.dataclass import FairseqDataclass
 from fairseq.logging import metrics


@@ -23,7 +22,6 @@
 @dataclass
 class LatticeFreeMMICriterionConfig(FairseqDataclass):
     sentence_avg: bool = II("params.optimization.sentence_avg")
-    ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend")
     denominator_fst_path: str = field(
         default=None, metadata={"help": "path to the denominator fst file"}
     )
@@ -131,12 +129,12 @@ def backward(ctx, objf_grad):
         return input_grad, None, None, None, None


-@register_criterion("lattice_free_mmi")
+@register_criterion("lattice_free_mmi", dataclass=LatticeFreeMMICriterionConfig)
 class LatticeFreeMMICriterion(FairseqCriterion):

     def __init__(
-        self, task, sentence_avg, denominator_fst_path,
-        leaky_hmm_coefficient, xent_regularize, output_l2_regularize,
+        self, task, sentence_avg, denominator_fst_path, leaky_hmm_coefficient,
+        xent_regularization_coefficient, output_l2_regularization_coefficient,
     ):
         super().__init__(task)
         try:
@@ -152,13 +150,8 @@ def __init__(
         den_fst = simplefst.StdVectorFst.read(denominator_fst_path)
         self.den_graph = ChainGraph(den_fst, initial_mode="leaky", final_mode="ones")
         self.leaky_hmm_coefficient = leaky_hmm_coefficient
-        self.xent_regularize = xent_regularize
-        self.output_l2_regularize = output_l2_regularize
-
-    @staticmethod
-    def add_args(parser):
-        """Add criterion-specific arguments to the parser. Optionally register config store"""
-        gen_parser_from_dataclass(parser, LatticeFreeMMICriterionConfig())
+        self.xent_regularize = xent_regularization_coefficient
+        self.output_l2_regularize = output_l2_regularization_coefficient

     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
@@ -218,8 +211,8 @@ def compute_loss(self, net_output, sample, reduce=True):

         return loss, nll_loss

-    @staticmethod
-    def reduce_metrics(logging_outputs) -> None:
+    @classmethod
+    def reduce_metrics(cls, logging_outputs) -> None:
         """Aggregate logging outputs from data parallel training."""
         loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
         nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
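The constructor renames above (xent_regularize to xent_regularization_coefficient, output_l2_regularize to output_l2_regularization_coefficient) presumably line the parameters up with the config's field names, since the dataclass-driven builder instantiates the criterion by matching constructor arguments to config fields. A hedged, standalone sketch of that matching; this is a simplification with illustrative names, not fairseq's actual builder:

import inspect
from dataclasses import dataclass, fields


@dataclass
class ExampleConfig:
    sentence_avg: bool = False
    xent_regularization_coefficient: float = 0.0


class ExampleCriterion:
    def __init__(self, task, sentence_avg, xent_regularization_coefficient):
        self.task = task
        self.sentence_avg = sentence_avg
        self.xent_regularize = xent_regularization_coefficient


def build_from_config(cls, cfg, task):
    # pass along only the config fields whose names match constructor parameters
    params = inspect.signature(cls.__init__).parameters
    kwargs = {f.name: getattr(cfg, f.name) for f in fields(cfg) if f.name in params}
    return cls(task=task, **kwargs)


criterion = build_from_config(ExampleCriterion, ExampleConfig(), task=None)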
espresso/criterions/subsampled_cross_entropy_with_accuracy.py (12 changes: 3 additions & 9 deletions)
@@ -11,7 +11,6 @@

 from fairseq.criterions import register_criterion
 from fairseq.criterions.cross_entropy import CrossEntropyCriterion, CrossEntropyCriterionConfig
-from fairseq.dataclass.utils import gen_parser_from_dataclass
 from fairseq.logging import metrics


@@ -23,7 +22,7 @@ class SubsampledCrossEntropyWithAccuracyCriterionConfig(CrossEntropyCriterionConfig):
     pass


-@register_criterion("subsampled_cross_entropy_with_accuracy")
+@register_criterion("subsampled_cross_entropy_with_accuracy", dataclass=SubsampledCrossEntropyWithAccuracyCriterionConfig)
 class SubsampledCrossEntropyWithAccuracyCriterion(CrossEntropyCriterion):

     def __init__(self, task, sentence_avg):
@@ -34,11 +33,6 @@ def __init__(self, task, sentence_avg):
         self.transpose_net_output = getattr(task, "transpose_net_output", True)
         self.state_prior_update_interval = getattr(task, "state_prior_update_interval", None)

-    @staticmethod
-    def add_args(parser):
-        """Add task-specific arguments to the parser. optionaly register config store"""
-        gen_parser_from_dataclass(parser, SubsampledCrossEntropyWithAccuracyCriterionConfig())
-
     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
@@ -107,8 +101,8 @@ def compute_loss(self, model, net_output, sample, reduce=True):

         return loss, num_corr, num_tot, state_post

-    @staticmethod
-    def reduce_metrics(logging_outputs) -> None:
+    @classmethod
+    def reduce_metrics(cls, logging_outputs) -> None:
         """Aggregate logging outputs from data parallel training."""
         CrossEntropyCriterion.reduce_metrics(logging_outputs)
         num_corr = sum(log.get("num_corr", 0) for log in logging_outputs)
espresso/data/asr_chain_dataset.py (5 changes: 5 additions & 0 deletions)
@@ -377,6 +377,11 @@ def filter_indices_by_size(self, indices, max_sizes):
             max_sizes,
         )

+    @property
+    def supports_fetch_outside_dataloader(self):
+        """Whether this dataset supports fetching outside the workers of the dataloader."""
+        return False
+
     @property
     def can_reuse_epoch_itr_across_epochs(self):
         return False  # to avoid running out of CPU RAM
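The new supports_fetch_outside_dataloader property (added identically to the other ASR datasets below) signals to calling code, which is not shown in this diff, whether items may be fetched by direct indexing outside the DataLoader worker processes. A hedged sketch of how a caller might honor such a flag; the guard below is illustrative and not fairseq's actual logic:

from torch.utils.data import DataLoader


def fetch_example_batch(dataset, collate_fn):
    # indexing dataset[0] here runs in the main process; datasets whose items
    # should only be materialized inside DataLoader workers opt out via the flag
    if getattr(dataset, "supports_fetch_outside_dataloader", True):
        return collate_fn([dataset[0]])
    # otherwise pull one batch through a worker-backed loader instead
    loader = DataLoader(dataset, batch_size=1, num_workers=1, collate_fn=collate_fn)
    return next(iter(loader))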
espresso/data/asr_dataset.py (5 changes: 5 additions & 0 deletions)
@@ -376,6 +376,11 @@ def filter_indices_by_size(self, indices, max_sizes):
             max_sizes,
         )

+    @property
+    def supports_fetch_outside_dataloader(self):
+        """Whether this dataset supports fetching outside the workers of the dataloader."""
+        return False
+
     @property
     def can_reuse_epoch_itr_across_epochs(self):
         return False  # to avoid running out of CPU RAM
espresso/data/asr_dictionary.py (2 changes: 1 addition & 1 deletion)
@@ -71,7 +71,7 @@ def load(cls, f, f_non_lang_syms=None):
         if f_non_lang_syms is not None:
             assert isinstance(f_non_lang_syms, str)
             try:
-                with PathManager.open(f_non_lang_syms, "r", encoding="utf-8") as fd:
+                with open(PathManager.get_local_path(f_non_lang_syms), "r", encoding="utf-8") as fd:
                     non_lang_syms = [x.rstrip() for x in fd.readlines()]
             except FileNotFoundError as fnfe:
                 raise fnfe
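The one-line change above replaces PathManager.open with a plain open on PathManager.get_local_path which, as I understand the PathManager API, resolves a possibly remote URI to a local file path (fetching it first if necessary) before reading. A short sketch of the resulting access pattern; the file name is illustrative:

from fairseq.file_io import PathManager

# get_local_path is assumed to return a local filesystem path, copying the
# resource down first when the input points at remote storage
local_path = PathManager.get_local_path("non_lang_syms.txt")
with open(local_path, "r", encoding="utf-8") as fd:
    non_lang_syms = [x.rstrip() for x in fd.readlines()]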
espresso/data/asr_xent_dataset.py (5 changes: 5 additions & 0 deletions)
@@ -610,6 +610,11 @@ def filter_indices_by_size(self, indices, max_sizes):
             max_sizes,
         )

+    @property
+    def supports_fetch_outside_dataloader(self):
+        """Whether this dataset supports fetching outside the workers of the dataloader."""
+        return False
+
     @property
     def can_reuse_epoch_itr_across_epochs(self):
         return False  # to avoid running out of CPU RAM
espresso/data/encoders/__init__.py (4 changes: 2 additions & 2 deletions)
@@ -11,5 +11,5 @@
 # automatically import any Python files in the encoders/ directory
 for file in os.listdir(os.path.dirname(__file__)):
     if not file.startswith("_") and not file.startswith(".") and file.endswith(".py"):
-        module = file[:file.find(".py")]
-        importlib.import_module("espresso.data.encoders." + module)
+        file_name = file[: file.find(".py")]
+        importlib.import_module("espresso.data.encoders." + file_name)