adds an option to include EOS in transducer model training;
rewrites masked_copy_cached_state() to make it clearer and more general;
removes clone_cached_state();
adapts code to the changes introduced by the commits of Nov 2, 2022
freewym committed Nov 4, 2022
1 parent db4eeb2 commit 8ec673f
Showing 15 changed files with 254 additions and 116 deletions.
3 changes: 2 additions & 1 deletion espresso/criterions/ctc_loss.py
@@ -12,10 +12,11 @@
 import torch.nn.functional as F
 from omegaconf import II

-from fairseq import metrics, utils
+from fairseq import utils
 from fairseq.criterions import FairseqCriterion, register_criterion
 from fairseq.data import data_utils
 from fairseq.dataclass import FairseqDataclass
+from fairseq.logging import metrics
 from fairseq.tasks import FairseqTask

 logger = logging.getLogger(__name__)
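The import reshuffle above tracks the relocation of fairseq's `metrics` module into `fairseq.logging`; both import paths appear in this diff. Code that must run against either layout could use a fallback like the following (an illustration, not part of this commit):

    try:
        from fairseq.logging import metrics  # newer layout, as used above
    except ImportError:
        from fairseq import metrics  # older layout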
17 changes: 12 additions & 5 deletions espresso/criterions/transducer_loss.py
@@ -11,10 +11,11 @@
 import torch
 from omegaconf import II

-from fairseq import metrics, utils
+from fairseq import utils
 from fairseq.criterions import FairseqCriterion, register_criterion
 from fairseq.data import data_utils
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
+from fairseq.logging import metrics
 from fairseq.tasks import FairseqTask

 logger = logging.getLogger(__name__)
@@ -36,6 +37,7 @@ class TransducerLossCriterionConfig(FairseqDataclass):
         default="torchaudio",
         metadata={"help": "choice of loss backend (native or torchaudio)"},
     )
+    include_eos: bool = II("task.include_eos_in_transducer_loss")


 @register_criterion("transducer_loss", dataclass=TransducerLossCriterionConfig)
@@ -64,6 +66,7 @@ def __init__(self, cfg: TransducerLossCriterionConfig, task: FairseqTask):
             )
         self.rnnt_loss = rnnt_loss

+        self.include_eos = cfg.include_eos
         self.dictionary = task.target_dictionary
         self.prev_num_updates = -1

@@ -73,13 +76,15 @@ def forward(self, model, sample, reduce=True):
         )  # B x T x U x V, B

         if "target_lengths" in sample:
-            target_lengths = (
-                sample["target_lengths"].int() - 1
-            )  # Note: ensure EOS is excluded
+            target_lengths = sample["target_lengths"].int()
+            if not self.include_eos:
+                target_lengths -= 1  # excludes EOS
         else:
             target_lengths = (
                 (
                     (sample["target"] != self.pad_idx)
+                    if self.include_eos
+                    else (sample["target"] != self.pad_idx)
                     & (sample["target"] != self.eos_idx)
                 )
                 .sum(-1)
@@ -124,7 +129,9 @@ def forward(self, model, sample, reduce=True):

         loss = self.rnnt_loss(
             net_output,
-            sample["target"][:, :-1].int().contiguous(),  # exclude the last EOS column
+            (sample["target"] if self.include_eos else sample["target"][:, :-1])
+            .int()
+            .contiguous(),
             encoder_out_lengths.int(),
             target_lengths,
             blank=self.blank_idx,
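To see what the new `include_eos` flag changes, here is a minimal standalone sketch (hypothetical toy tensors; `pad_idx`/`eos_idx` stand in for the dictionary indices used above):

    import torch

    pad_idx, eos_idx = 1, 2
    target = torch.tensor([[4, 5, 6, 2], [4, 5, 2, 1]])  # B x U, EOS-terminated rows

    for include_eos in (False, True):
        if include_eos:
            lengths = (target != pad_idx).sum(-1).int()  # EOS counts toward U
            loss_target = target  # keep the EOS column
        else:
            lengths = ((target != pad_idx) & (target != eos_idx)).sum(-1).int()
            loss_target = target[:, :-1]  # drop the last (EOS) column
        print(include_eos, lengths.tolist(), loss_target.tolist())
    # lengths: False -> [3, 2], True -> [4, 3]. With include_eos the model is also
    # trained to emit EOS, instead of EOS being stripped from targets and lengths.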
17 changes: 16 additions & 1 deletion espresso/data/asr_dataset.py
@@ -21,6 +21,7 @@ def collate(
     left_pad_source=True,
     left_pad_target=False,
     input_feeding=True,
+    maybe_bos_idx=None,
     pad_to_length=None,
     pad_to_multiple=1,
     src_bucketed=False,
@@ -89,11 +90,16 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
             prev_output_tokens = merge(
                 "target",
                 left_pad=left_pad_target,
-                move_eos_to_beginning=True,
+                move_eos_to_beginning=(maybe_bos_idx is None),
                 pad_to_length=pad_to_length["target"]
                 if pad_to_length is not None
                 else None,
             )
+            if maybe_bos_idx is not None:
+                all_bos_vec = prev_output_tokens.new_full((1, 1), maybe_bos_idx).expand(
+                    len(samples), 1
+                )
+                prev_output_tokens = torch.cat([all_bos_vec, prev_output_tokens], dim=1)
     else:
         ntokens = src_lengths.sum().item()

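As a standalone illustration of the branch above (hypothetical indices: 0 for BOS, 2 for EOS, 1 for padding):

    import torch

    bos_idx = 0
    target = torch.tensor([[4, 5, 6, 2], [4, 5, 2, 1]])  # B x U, EOS kept in place

    # With maybe_bos_idx set, EOS stays inside the target and a BOS column is
    # prepended, so prev_output_tokens is B x (U + 1) and EOS itself gets predicted.
    all_bos_vec = target.new_full((1, 1), bos_idx).expand(target.size(0), 1)
    prev_output_tokens = torch.cat([all_bos_vec, target], dim=1)
    print(prev_output_tokens.tolist())  # [[0, 4, 5, 6, 2], [0, 4, 5, 2, 1]]

With `maybe_bos_idx=None` the default path is unchanged: `merge()` moves the final EOS to the beginning, keeping the width at U.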
@@ -148,6 +154,10 @@ class AsrDataset(FairseqDataset):
             (default: True).
         input_feeding (bool, optional): create a shifted version of the targets
             to be passed into the model for teacher forcing (default: True).
+        prepend_bos_as_input_feeding (bool, optional): if True, prepend the BOS
+            symbol to the target (instead of moving EOS to its beginning) to form
+            the input feeding; currently only used for transducer model training,
+            where EOS is retained in the target when evaluating the loss (default: False).
         constraints (Tensor, optional): 2d tensor with a concatenated, zero-
             delimited list of constraints for each sentence.
         num_buckets (int, optional): if set to a value greater than 0, then
@@ -176,6 +186,7 @@ def __init__(
         left_pad_target=False,
         shuffle=True,
         input_feeding=True,
+        prepend_bos_as_input_feeding=False,
         constraints=None,
         num_buckets=0,
         src_lang_id=None,
@@ -193,6 +204,7 @@
         self.left_pad_target = left_pad_target
         self.shuffle = shuffle
         self.input_feeding = input_feeding
+        self.prepend_bos_as_input_feeding = prepend_bos_as_input_feeding
         self.constraints = constraints
         self.src_lang_id = src_lang_id
         self.tgt_lang_id = tgt_lang_id
@@ -334,6 +346,9 @@ def collater(self, samples, pad_to_length=None):
             left_pad_source=self.left_pad_source,
             left_pad_target=self.left_pad_target,
             input_feeding=self.input_feeding,
+            maybe_bos_idx=self.dictionary.bos()
+            if self.prepend_bos_as_input_feeding
+            else None,
             pad_to_length=pad_to_length,
             pad_to_multiple=self.pad_to_multiple,
             src_bucketed=(self.buckets is not None),
3 changes: 1 addition & 2 deletions espresso/data/feat_text_dataset.py
@@ -57,7 +57,7 @@ def __init__(
     ):
         super().__init__()
         assert len(utt_ids) == len(rxfiles)
-        self.dtype = np.float
+        self.dtype = float
         self.utt_ids = utt_ids
         self.rxfiles = rxfiles
         self.size = len(utt_ids)  # number of utterances
@@ -338,7 +338,6 @@ def __init__(
         self, utt_ids: List[str], texts: List[str], dictionary=None, append_eos=True
     ):
         super().__init__()
-        self.dtype = np.float
         self.dictionary = dictionary
         self.append_eos = append_eos
         self.read_text(utt_ids, texts, dictionary)
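For context (not part of the diff): `np.float` was deprecated in NumPy 1.20 and removed in 1.24, and it had always been an alias of the Python builtin `float`, so the substitution preserves behavior:

    import numpy as np

    # The builtin resolves to the same dtype that np.float used to alias.
    assert np.dtype(float) == np.float64
    x = np.zeros(3, dtype=float)  # valid on all NumPy versions
    print(x.dtype)  # float64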
13 changes: 8 additions & 5 deletions espresso/models/external_language_model.py
@@ -9,8 +9,9 @@

 from espresso.data import AsrDictionary
 from espresso.tools.lexical_prefix_tree import lexical_prefix_tree
-from espresso.tools.utils import clone_cached_state, tokenize
+from espresso.tools.utils import tokenize
 from fairseq.models import FairseqIncrementalDecoder, FairseqLanguageModel
+from fairseq.utils import apply_to_sample


 class RawOutExternalLanguageModelBase(FairseqLanguageModel):
@@ -125,8 +126,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
             ).unsqueeze(
                 -1
             )  # B x 1
-            old_cached_state = clone_cached_state(
-                self.lm_decoder.get_cached_state(incremental_state)
+            old_cached_state = apply_to_sample(
+                torch.clone,
+                self.lm_decoder.get_cached_state(incremental_state),
             )
             # recompute cumsum_probs from inter-word transition probabilities
             # only for those whose prev_output_token is <space>
).unsqueeze(
-1
) # B x 1
old_wordlm_cached_state = clone_cached_state(
self.wordlm_decoder.get_cached_state(incremental_state)
old_wordlm_cached_state = apply_to_sample(
torch.clone,
self.wordlm_decoder.get_cached_state(incremental_state),
)

# recompute wordlm_logprobs from inter-word transition probabilities
Expand Down
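The removed `clone_cached_state()` helper deep-cloned a nested cached-state structure; `fairseq.utils.apply_to_sample` generalizes this by applying a function to every tensor inside a nested container. A simplified sketch of the idea (not fairseq's actual implementation):

    import torch

    def apply_to_sample_sketch(f, sample):
        # Apply f to every Tensor inside nested dicts/lists/tuples,
        # preserving the container structure.
        if torch.is_tensor(sample):
            return f(sample)
        if isinstance(sample, dict):
            return {k: apply_to_sample_sketch(f, v) for k, v in sample.items()}
        if isinstance(sample, (list, tuple)):
            return type(sample)(apply_to_sample_sketch(f, x) for x in sample)
        return sample

    cached_state = {"prev_hiddens": [torch.randn(2, 4)], "input_feed": torch.randn(2, 4)}
    old_cached_state = apply_to_sample_sketch(torch.clone, cached_state)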
32 changes: 10 additions & 22 deletions espresso/models/speech_lstm.py
@@ -1017,28 +1017,16 @@ def masked_copy_cached_state(
             src_cached_state[2],
         )

-        def masked_copy_state(state: Optional[Tensor], src_state: Optional[Tensor]):
-            if state is None:
-                assert src_state is None
-                return None
-            else:
-                assert (
-                    state.size(0) == mask.size(0)
-                    and src_state is not None
-                    and state.size() == src_state.size()
-                )
-                state[mask, ...] = src_state[mask, ...]
-                return state
-
-        prev_hiddens = [
-            masked_copy_state(p, src_p)
-            for (p, src_p) in zip(prev_hiddens, src_prev_hiddens)
-        ]
-        prev_cells = [
-            masked_copy_state(p, src_p)
-            for (p, src_p) in zip(prev_cells, src_prev_cells)
-        ]
-        input_feed = masked_copy_state(input_feed, src_input_feed)
+        mask = mask.unsqueeze(1)
+        prev_hiddens = speech_utils.apply_to_sample_pair(
+            lambda x, y, z=mask: torch.where(z, x, y), src_prev_hiddens, prev_hiddens
+        )
+        prev_cells = speech_utils.apply_to_sample_pair(
+            lambda x, y, z=mask: torch.where(z, x, y), src_prev_cells, prev_cells
+        )
+        input_feed = speech_utils.apply_to_sample_pair(
+            lambda x, y, z=mask: torch.where(z, x, y), src_input_feed, input_feed
+        )

        cached_state_new = torch.jit.annotate(
            Dict[str, Optional[Tensor]],
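The rewrite replaces in-place advanced indexing (`state[mask, ...] = src_state[mask, ...]`) with an out-of-place `torch.where` over a broadcast mask, applied uniformly through `apply_to_sample_pair` (which, judging from its use here, walks two nested structures in lockstep). A standalone sketch of the core operation, with made-up shapes:

    import torch

    B, H = 3, 5
    state = torch.zeros(B, H)
    src_state = torch.ones(B, H)
    mask = torch.tensor([True, False, True])  # which batch rows to take from src

    # Old style: in-place masked copy.
    old = state.clone()
    old[mask, ...] = src_state[mask, ...]

    # New style: unsqueeze the mask to B x 1 so it broadcasts over H,
    # then select per element without mutation.
    new = torch.where(mask.unsqueeze(1), src_state, state)

    assert torch.equal(old, new)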
8 changes: 5 additions & 3 deletions espresso/models/tensorized_lookahead_language_model.py
@@ -10,8 +10,9 @@

 from espresso.data import AsrDictionary
 from espresso.models.external_language_model import RawOutExternalLanguageModelBase
 from espresso.tools.tensorized_prefix_tree import TensorizedPrefixTree
-from espresso.tools.utils import clone_cached_state, tokenize
+from espresso.tools.utils import tokenize
 from fairseq.models import FairseqIncrementalDecoder, FairseqLanguageModel
+from fairseq.utils import apply_to_sample


 class TensorizedLookaheadLanguageModel(RawOutExternalLanguageModelBase):
@@ -131,8 +132,9 @@ def forward(
         w: torch.Tensor = self.tree.word_idx[nodes].unsqueeze(1)  # Z[Batch, Len=1]
         w[w < 0] = self.word_unk_idx

-        old_cached_state = clone_cached_state(
-            self.lm_decoder.get_cached_state(incremental_state)
+        old_cached_state = apply_to_sample(
+            torch.clone,
+            self.lm_decoder.get_cached_state(incremental_state),
         )
         # recompute cumsum_probs from inter-word transition probabilities
         # only for those whose prev_output_token is <space>
38 changes: 14 additions & 24 deletions espresso/models/transformer/speech_transformer_decoder.py
@@ -12,6 +12,7 @@
 import torch.nn.functional as F
 from torch import Tensor

+import espresso.tools.utils as speech_utils
 from espresso.models.transformer import SpeechTransformerConfig
 from espresso.modules import (
     RelativePositionalEmbedding,
@@ -430,34 +431,23 @@ def masked_copy_cached_state(
                 F.pad(src_p, (0, 1)) for src_p in src_prev_key_padding_mask
             ]

-            def masked_copy_state(state: Optional[Tensor], src_state: Optional[Tensor]):
-                if state is None:
-                    assert src_state is None
-                    return None
-                else:
-                    assert (
-                        state.size(0) == mask.size(0)
-                        and src_state is not None
-                        and state.size() == src_state.size()
-                    )
-                    state[mask, ...] = src_state[mask, ...]
-                    return state
-
-            prev_key = [
-                masked_copy_state(p, src_p) for (p, src_p) in zip(prev_key, src_prev_key)
-            ]
-            prev_value = [
-                masked_copy_state(p, src_p)
-                for (p, src_p) in zip(prev_value, src_prev_value)
-            ]
+            kv_mask = mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
+            prev_key = speech_utils.apply_to_sample_pair(
+                lambda x, y, z=kv_mask: torch.where(z, x, y), src_prev_key, prev_key
+            )
+            prev_value = speech_utils.apply_to_sample_pair(
+                lambda x, y, z=kv_mask: torch.where(z, x, y), src_prev_value, prev_value
+            )
             if prev_key_padding_mask is None:
                 prev_key_padding_mask = src_prev_key_padding_mask
             else:
                 assert src_prev_key_padding_mask is not None
-                prev_key_padding_mask = [
-                    masked_copy_state(p, src_p)
-                    for (p, src_p) in zip(prev_key_padding_mask, src_prev_key_padding_mask)
-                ]
+                pad_mask = mask.unsqueeze(1)
+                prev_key_padding_mask = speech_utils.apply_to_sample_pair(
+                    lambda x, y, z=pad_mask: torch.where(z, x, y),
+                    src_prev_key_padding_mask,
+                    prev_key_padding_mask,
+                )

             cached_state = torch.jit.annotate(
                 Dict[str, Optional[Tensor]],
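The same `torch.where` pattern as in the LSTM decoder is used here; only the mask shapes differ, matching the cached tensors (per the surrounding code, cached keys/values are assumed to be B x heads x T x head_dim, and key padding masks B x T). A sketch with made-up sizes:

    import torch

    B, heads, T, head_dim = 2, 4, 7, 16
    mask = torch.tensor([True, False])  # B: rows to take from the src cache

    kv_mask = mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)  # B x 1 x 1 x 1
    prev_key = torch.zeros(B, heads, T, head_dim)
    src_prev_key = torch.ones(B, heads, T, head_dim)
    merged_key = torch.where(kv_mask, src_prev_key, prev_key)

    pad_mask = mask.unsqueeze(1)  # B x 1
    prev_key_padding_mask = torch.zeros(B, T, dtype=torch.bool)
    src_prev_key_padding_mask = torch.ones(B, T, dtype=torch.bool)
    merged_padding = torch.where(pad_mask, src_prev_key_padding_mask, prev_key_padding_mask)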
13 changes: 10 additions & 3 deletions espresso/models/transformer/speech_transformer_encoder.py
@@ -9,6 +9,7 @@

 import torch
 import torch.nn as nn
+from torch import Tensor

 import espresso.tools.utils as speech_utils
 from espresso.models.transformer import SpeechTransformerConfig
@@ -314,7 +315,12 @@ def forward_scriptable(
             src_tokens,
             ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1)),
         )
-        has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any()
+        has_pads = (
+            torch.tensor(src_tokens.device.type == "xla") or encoder_padding_mask.any()
+        )
+        # Torchscript doesn't handle bool Tensor correctly, so we need to work around.
+        if torch.jit.is_scripting():
+            has_pads = torch.tensor(1) if has_pads else torch.tensor(0)

         if self.fc0 is not None:
             x = self.dropout_module(x)
@@ -330,8 +336,9 @@
             x = self.quant_noise(x)

         # account for padding while computing the representation
-        if has_pads:
-            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
+        x = x * (
+            1 - encoder_padding_mask.unsqueeze(-1).type_as(x) * has_pads.type_as(x)
+        )

         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
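Two things happen above: `has_pads` becomes a Tensor so TorchScript sees one consistent type, and the padding correction is applied unconditionally but scaled by `has_pads`, removing the Python branch on a Tensor. When there are no pads the correction term multiplies by zero and `x` is unchanged. A standalone sketch (hypothetical sizes):

    import torch

    B, T, C = 2, 5, 8
    x = torch.randn(B, T, C)
    encoder_padding_mask = torch.zeros(B, T, dtype=torch.bool)
    encoder_padding_mask[1, 3:] = True  # row 1 has two padded frames

    has_pads = torch.tensor(1) if bool(encoder_padding_mask.any()) else torch.tensor(0)

    # Branch-free form: padded positions are zeroed iff has_pads == 1.
    x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x) * has_pads.type_as(x))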
