adds chunk streaming mask support for Transformer
freewym committed Oct 10, 2022
1 parent 9944ec7 commit 5a47606
Showing 6 changed files with 116 additions and 58 deletions.
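In short, this commit adds three encoder options (chunk_size, chunk_left_window, chunk_right_window) plus a chunk_streaming_mask() helper, and makes the encoder's get_attn_mask() build a chunk streaming self-attention mask whenever chunk_size > 0; it also moves the parsing of encoder.transformer_context from the model builders into the encoder itself. A rough sketch of the new fields as the encoder reads them — the config object below is a stand-in for illustration, not the real SpeechTransformerConfig; only the three field names come from this commit:

from types import SimpleNamespace

# stand-in config object; only chunk_size, chunk_left_window and chunk_right_window
# are fields actually introduced by this commit (speech_transformer_config.py)
cfg = SimpleNamespace(encoder=SimpleNamespace())
cfg.encoder.chunk_size = 18         # > 0 switches get_attn_mask() to chunk streaming mode
cfg.encoder.chunk_left_window = 1   # past chunks each chunk may attend to
cfg.encoder.chunk_right_window = 0  # future chunks each chunk may attend to (0 = no look-ahead)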
22 changes: 2 additions & 20 deletions espresso/models/transformer/speech_transformer_base.py
@@ -126,18 +126,6 @@ def build_model(cls, cfg, task):
else:
transformer_encoder_input_size = task.feat_dim

encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple(
cfg.encoder.transformer_context,
type=int,
)
if encoder_transformer_context is not None:
assert len(encoder_transformer_context) == 2
for i in range(2):
assert encoder_transformer_context[i] is None or (
isinstance(encoder_transformer_context[i], int)
and encoder_transformer_context[i] >= 0
)

scheduled_sampling_rate_scheduler = ScheduledSamplingRateScheduler(
cfg.scheduled_sampling_probs,
cfg.start_scheduled_sampling_epoch,
@@ -147,7 +135,6 @@ def build_model(cls, cfg, task):
cfg,
pre_encoder=conv_layers,
input_size=transformer_encoder_input_size,
transformer_context=encoder_transformer_context,
)
decoder = cls.build_decoder(
cfg,
@@ -162,14 +149,9 @@ def set_num_updates(self, num_updates):
super().set_num_updates(num_updates)

@classmethod
def build_encoder(
cls, cfg, pre_encoder=None, input_size=83, transformer_context=None
):
def build_encoder(cls, cfg, pre_encoder=None, input_size=83):
return SpeechTransformerEncoderBase(
cfg,
pre_encoder=pre_encoder,
input_size=input_size,
transformer_context=transformer_context,
cfg, pre_encoder=pre_encoder, input_size=input_size
)

@classmethod
16 changes: 16 additions & 0 deletions espresso/models/transformer/speech_transformer_config.py
@@ -62,6 +62,22 @@ class SpeechEncoderConfig(SpeechEncDecBaseConfig):
layer_type: LAYER_TYPE_CHOICES = field(
default="transformer", metadata={"help": "layer type in encoder"}
)
chunk_size: int = field(
default=0,
metadata={"help": "chunk size of Transformer in chunk streaming mode if > 0"},
)
chunk_left_window: int = field(
default=0,
metadata={
"help": "number of chunks to the left of the current chunk in chunk streaming mode"
},
)
chunk_right_window: int = field(
default=0,
metadata={
"help": "number of chunks to the right of the current chunk in chunk streaming mode"
},
)
# config specific to Conformer
depthwise_conv_kernel_size: int = field(
default=31,
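For intuition, the three fields above partition the encoder input frames into consecutive chunks of chunk_size frames; each frame can attend to every frame in its own chunk, plus chunk_left_window chunks of past context and chunk_right_window chunks of future context. A minimal illustrative sketch of that arithmetic (not code from the commit, and ignoring the partial first/last chunk handled by the actual mask):

chunk_size, left_window, right_window = 18, 1, 0

def visible_range(t: int, seq_len: int) -> range:
    c = t // chunk_size  # chunk index of frame t
    lo = max(c - left_window, 0) * chunk_size
    hi = min((c + right_window + 1) * chunk_size, seq_len)
    return range(lo, hi)

print(visible_range(40, 100))  # range(18, 54): frame 40 is in chunk 2 and also sees chunk 1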
32 changes: 28 additions & 4 deletions espresso/models/transformer/speech_transformer_encoder.py
@@ -17,6 +17,7 @@
RelativePositionalEmbedding,
TransformerWithRelativePositionalEmbeddingEncoderLayerBase,
)
from fairseq.data import data_utils
from fairseq.distributed import fsdp_wrap
from fairseq.models.transformer import Linear, TransformerEncoderBase
from fairseq.modules import (
@@ -59,7 +60,6 @@ def __init__(
return_fc=False,
pre_encoder=None,
input_size=83,
transformer_context=None,
):
self.cfg = cfg
super(TransformerEncoderBase, self).__init__(None) # no src dictionary
@@ -159,7 +159,19 @@ def __init__(
else:
self.layer_norm = None

self.transformer_context = transformer_context
self.transformer_context = speech_utils.eval_str_nested_list_or_tuple(
cfg.encoder.transformer_context,
type=int,
)
if self.transformer_context is not None:
assert len(self.transformer_context) == 2
for i in range(2):
assert self.transformer_context[i] is None or (
isinstance(self.transformer_context[i], int)
and self.transformer_context[i] >= 0
)

self.num_updates = 0

def build_encoder_layer(
self, cfg, positional_embedding: Optional[RelativePositionalEmbedding] = None
@@ -183,6 +195,10 @@ def build_encoder_layer(
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
return layer

def set_num_updates(self, num_updates):
self.num_updates = num_updates
super().set_num_updates(num_updates)

def output_lengths(self, in_lengths):
return (
in_lengths
@@ -204,6 +220,16 @@ def get_attn_mask(self, in_lengths):
`attn_mask[tgt_i, src_j] = 1` means that when calculating the
embedding for `tgt_i`, we exclude (mask out) `src_j`.
"""
if self.cfg.encoder.chunk_size > 0:
with data_utils.numpy_seed(self.num_updates):
return ~speech_utils.chunk_streaming_mask(
in_lengths,
self.cfg.encoder.chunk_size,
left_window=self.cfg.encoder.chunk_left_window,
right_window=self.cfg.encoder.chunk_right_window,
always_partial_in_last=(not self.training),
)

if self.transformer_context is None or (
self.transformer_context[0] is None and self.transformer_context[1] is None
):
@@ -383,15 +409,13 @@ def __init__(
return_fc=False,
pre_encoder=None,
input_size=83,
transformer_context=None,
):
self.args = args
super().__init__(
SpeechTransformerConfig.from_namespace(args),
return_fc=return_fc,
pre_encoder=pre_encoder,
input_size=input_size,
transformer_context=transformer_context,
)

def build_encoder_layer(
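Two details of the get_attn_mask() change are easy to miss. chunk_streaming_mask() returns True where a source frame is visible, whereas get_attn_mask() documents the opposite convention (True means excluded), hence the negation with ~. And building the mask under data_utils.numpy_seed(self.num_updates) ties the np.random.rand() choice inside the helper (whether the first or the last chunk is partial) to the current update count, so it is deterministic within a training step while still varying across steps; set_num_updates() is overridden above to keep that counter current. A tiny sketch of the convention flip, with made-up values:

import torch

visible = torch.tensor([[True, True, False],
                        [True, True, False],
                        [False, False, True]])  # True == src frame visible (chunk_streaming_mask convention)
attn_mask = ~visible  # True == src frame masked out (get_attn_mask convention)
print(attn_mask)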
17 changes: 0 additions & 17 deletions espresso/models/transformer/speech_transformer_encoder_model.py
@@ -102,23 +102,10 @@ def build_model(cls, cfg, task):
else:
transformer_encoder_input_size = task.feat_dim

encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple(
cfg.encoder.transformer_context,
type=int,
)
if encoder_transformer_context is not None:
assert len(encoder_transformer_context) == 2
for i in range(2):
assert encoder_transformer_context[i] is None or (
isinstance(encoder_transformer_context[i], int)
and encoder_transformer_context[i] >= 0
)

encoder = cls.build_encoder(
cfg,
pre_encoder=conv_layers,
input_size=transformer_encoder_input_size,
transformer_context=encoder_transformer_context,
vocab_size=(
len(task.target_dictionary)
if task.target_dictionary is not None
@@ -139,14 +126,12 @@ def build_encoder(
cfg,
pre_encoder=None,
input_size=83,
transformer_context=None,
vocab_size=None,
):
return SpeechTransformerEncoderForPrediction(
cfg,
pre_encoder=pre_encoder,
input_size=input_size,
transformer_context=transformer_context,
vocab_size=vocab_size,
)

@@ -174,15 +159,13 @@ def __init__(
return_fc=False,
pre_encoder=None,
input_size=83,
transformer_context=None,
vocab_size=None,
):
super().__init__(
cfg,
return_fc=return_fc,
pre_encoder=pre_encoder,
input_size=input_size,
transformer_context=transformer_context,
)

self.fc_out = (
@@ -165,23 +165,10 @@ def build_model(cls, cfg, task):
else:
transformer_encoder_input_size = task.feat_dim

encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple(
cfg.encoder.transformer_context,
type=int,
)
if encoder_transformer_context is not None:
assert len(encoder_transformer_context) == 2
for i in range(2):
assert encoder_transformer_context[i] is None or (
isinstance(encoder_transformer_context[i], int)
and encoder_transformer_context[i] >= 0
)

encoder = cls.build_encoder(
cfg,
pre_encoder=conv_layers,
input_size=transformer_encoder_input_size,
transformer_context=encoder_transformer_context,
)
decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)
# fsdp_wrap is a no-op when --ddp-backend != fully_sharded
@@ -206,14 +193,11 @@ def build_embedding(cls, cfg, dictionary, embed_dim, path=None):
return emb

@classmethod
def build_encoder(
cls, cfg, pre_encoder=None, input_size=83, transformer_context=None
):
def build_encoder(cls, cfg, pre_encoder=None, input_size=83):
return SpeechTransformerEncoderBase(
cfg,
pre_encoder=pre_encoder,
input_size=input_size,
transformer_context=transformer_context,
)

@classmethod
69 changes: 69 additions & 0 deletions espresso/tools/utils.py
@@ -12,6 +12,7 @@

import numpy as np
import torch
import torch.nn.functional as F

try:
import kaldi_io
@@ -90,6 +91,74 @@ def sequence_mask(sequence_length, max_len=None):
return seq_range_expand < seq_length_expand


def chunk_streaming_mask(
sequence_length: torch.Tensor,
chunk_size: int,
left_window: int = 0,
right_window: int = 0,
always_partial_in_last: bool = False,
):
"""Returns a mask for chunk streaming Transformer models.
Args:
sequence_length (LongTensor): sequence lengths of shape `(batch)`
chunk_size (int): chunk size
left_window (int): how many left chunks can be seen (default: 0)
right_window (int): how many right chunks can be seen (default: 0)
always_partial_in_last (bool): if True, always make the last chunk the partial one;
otherwise randomly make either the first or the last chunk partial, so that the
model does not learn to emit EOS solely from seeing a partial-size chunk
(default: False)
Returns:
mask (BoolTensor): a mask tensor of shape `(tgt_len, src_len)`, where
`tgt_len` is the length of output and `src_len` is the length of input.
`mask[tgt_i, src_j] = True` means that when calculating the embedding
for `tgt_i`, we need `src_j`.

max_len = sequence_length.data.max()
chunk_start_idx = torch.arange(
0,
max_len,
chunk_size,
dtype=sequence_length.dtype,
device=sequence_length.device,
) # e.g. [0,18,36,54]
if not always_partial_in_last and np.random.rand() > 0.5:
# randomly make the first (rather than the last) chunk the partial one; if the last chunk
# were always the incomplete one, the model could learn to emit EOS from the partial chunk size alone
chunk_start_idx = max_len - chunk_start_idx
chunk_start_idx = chunk_start_idx.flip([0])
chunk_start_idx = chunk_start_idx[:-1]
chunk_start_idx = F.pad(chunk_start_idx, (1, 0))

start_pad = F.pad(chunk_start_idx, (1, 0))  # e.g. [0,0,18,36,54]
end_pad = F.pad(chunk_start_idx, (0, 1), value=max_len)  # e.g. [0,18,36,54,max_len]
seq_range = torch.arange(
0, max_len, dtype=sequence_length.dtype, device=sequence_length.device
)
idx = (
(seq_range.unsqueeze(-1) >= start_pad) & (seq_range.unsqueeze(-1) < end_pad)
).nonzero()[:, 1]  # (max_len,): index into start_pad/end_pad of the chunk containing each frame
seq_range_expand = seq_range.unsqueeze(0).expand(max_len, -1) # max_len x max_len

idx_left = idx - left_window
idx_left[idx_left < 0] = 0
boundary_left = start_pad[idx_left] # max_len
mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)

idx_right = idx + right_window
idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx)
boundary_right = end_pad[idx_right] # max_len
mask_right = seq_range_expand < boundary_right.unsqueeze(-1)

return mask_left & mask_right


def convert_padding_direction(
src_frames,
src_lengths,
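Finally, a small usage sketch of the new helper on a toy input; the import path is assumed from the file location, and the expected output in the comment is hand-computed rather than taken from the commit:

import torch

from espresso.tools.utils import chunk_streaming_mask  # path assumed from espresso/tools/utils.py

# one utterance of 6 frames, 2-frame chunks, one chunk of left context, no look-ahead;
# always_partial_in_last=True avoids the random first/last-chunk flip, so the result is deterministic
lengths = torch.LongTensor([6])
mask = chunk_streaming_mask(
    lengths, chunk_size=2, left_window=1, right_window=0, always_partial_in_last=True
)
print(mask.long())
# hand-computed expectation (1 == src frame visible to tgt frame):
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1]])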
