@@ -71,7 +71,10 @@ def forward(self, x):

class Subsample(nn.Module):

def __init__(self, encoder_dim: int = 0, input_dropout_rate: float = 0.0):
def __init__(self,
encoder_dim: int = 0,
input_dropout_rate: float = 0.0,
num_bins: int = 80):
super().__init__()
self.encoder_dim = encoder_dim
self.input_dropout_rate = input_dropout_rate
@@ -81,7 +84,10 @@ def __init__(self, encoder_dim: int = 0, input_dropout_rate: float = 0.0):
self.conv2 = Conv2dSubsampling(
input_channels=encoder_dim, output_channels=encoder_dim)

self.linear = nn.LazyLinear(out_features=self.encoder_dim, bias=True)
self.linear = nn.Linear(
in_features=self.encoder_dim * num_bins // 4,
Member: What's the reasoning for this?

Contributor Author: Each of the two subsampling layers reduces the number of mel-spectrogram features by half.
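For concreteness, a minimal sketch of the shape arithmetic with hypothetical values (encoder_dim = 512 is a placeholder; num_bins = 80 is the default above): two stride-2 subsampling layers cut the frequency axis to num_bins // 4, which is where the in_features above comes from.

# Hypothetical values, illustration only.
encoder_dim, num_bins = 512, 80
freq_after_subsampling = num_bins // 2 // 2          # 80 -> 40 -> 20
in_features = encoder_dim * freq_after_subsampling   # 512 * 20 = 10240
assert in_features == encoder_dim * num_bins // 4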

out_features=self.encoder_dim,
bias=True)
self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim)
self.dropout = nn.Dropout(p=self.input_dropout_rate)

@@ -123,6 +129,7 @@ def __init__(self,
self.kernel = nn.Parameter(
torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape)))
self.bias = nn.Parameter(torch.zeros(output_channels))
self.register_buffer('paddings_kernel', torch.ones([1, 1, 1]))

def get_same_padding(self, input_shape):
in_height, in_width = input_shape[2:]
@@ -162,15 +169,11 @@ def forward(self, inputs, paddings):
input_length = paddings.shape[1]
stride = self.filter_stride[0]
pad_len = (input_length + stride - 1) // stride * stride - input_length
padded_paddings = torch.cat([
paddings[:, None, :],
torch.zeros(
size=(paddings.shape[0], 1, pad_len), device=paddings.device)
],
dim=2)
padded_paddings = F.pad(
paddings[:, None, :], (0, pad_len), mode='constant', value=0)
out_padding = F.conv1d(
input=padded_paddings,
weight=torch.ones([1, 1, 1], device=paddings.device),
weight=self.paddings_kernel,
stride=self.filter_stride[:1])
out_padding = out_padding.squeeze(dim=1)
outputs = outputs * (1 - out_padding[:, None, :, None])
@@ -184,11 +187,15 @@ def forward(self, inputs):
self.config = config

self.ln = LayerNorm(dim=config.encoder_dim)
self.linear1 = nn.LazyLinear(
self.linear1 = nn.Linear(
in_features=config.encoder_dim,
out_features=config.encoder_dim * config.feed_forward_expansion_factor,
bias=True)
self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate)
self.linear2 = nn.LazyLinear(out_features=config.encoder_dim, bias=True)
self.linear2 = nn.Linear(
in_features=config.encoder_dim * config.feed_forward_expansion_factor,
out_features=config.encoder_dim,
bias=True)

if config.feed_forward_residual_dropout_rate is None:
feed_forward_residual_dropout_rate = 0.1
@@ -253,217 +260,32 @@ def forward(self, inputs):
return inputs * scale


class MHSAwithQS(nn.MultiheadAttention):
# pylint: disable=locally-disabled, use-a-generator, line-too-long, invalid-name
class MHSAwithQS(nn.Module):

def __init__(self, config: ConformerConfig):
super().__init__(
embed_dim=config.encoder_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout_rate,
bias=True,
batch_first=True)
super().__init__()
self.embed_dim = config.encoder_dim
self.num_heads = config.num_attention_heads
self.dropout = config.attention_dropout_rate
self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim)
self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim)
self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads)

def _scaled_in_proj_weight(self):
Member: OK, I assume the new implementation is numerically equivalent; SDPA is probably the right bet.

Contributor Author: I also tried the default multi-head self-attention without the query scaler, and it still couldn't run successfully without adjusting the attention backends. If this is useful for your debugging, I can create a separate branch with this setup?

Member: So the previous implementation was quite long; SDPA is something that's maintained in PyTorch, so this actually makes things better.

Contributor Author (chandramouli-sastry, Oct 19, 2023): The previous implementation was long because it extended nn.MultiheadAttention by changing the in-projection weights/biases in the forward pass, but the attention was still managed entirely by PyTorch.
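The note above about adjusting the attention backends presumably refers to PyTorch's SDPA backend selection. As a hedged sketch (PyTorch 2.x context manager, not part of this PR), one way to pin a single backend while debugging:

# Sketch only (assumes a CUDA device): pin SDPA to the math backend while debugging.
import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

q = k = v = torch.randn(2, 4, 6, 16, device='cuda')  # (batch, heads, seq, head_dim)
with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)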

# Scale the query projection weight.
qs_input = self.in_proj_weight[:self.embed_dim].view(
self.num_heads, self.embed_dim // self.num_heads, -1).transpose(1, 2)
in_proj_queryW_scaled = self.qs(qs_input).transpose(
1, 2).view(*self.in_proj_weight[:self.embed_dim].shape)
in_proj_weight = torch.cat(
[in_proj_queryW_scaled, self.in_proj_weight[self.embed_dim:]])
return in_proj_weight

def _scaled_in_proj_bias(self):
# Scale the query bias.
in_proj_queryb_scaled = self.qs(self.in_proj_bias[:self.embed_dim].view(
self.num_heads, self.embed_dim // self.num_heads)).view(-1)
in_proj_bias = torch.cat(
[in_proj_queryb_scaled, self.in_proj_bias[self.embed_dim:]])
return in_proj_bias
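The two helpers above bake the query scale into the in-projection parameters. As a quick check of the numerical-equivalence assumption, here is a standalone sketch (plain tensors standing in for QueryScaler and the in-projection, not the module itself) showing that scaling the query rows of the projection matches scaling the projected per-head queries, which is what the new SDPA-based forward does:

# Standalone sketch: per-head query scaling can be applied either to the
# projection parameters (old approach) or to the projected queries (new
# approach); both yield the same q.
import torch

torch.manual_seed(0)
embed_dim, num_heads = 8, 2
head_dim = embed_dim // num_heads
x = torch.randn(3, 5, embed_dim)        # (batch, seq, embed)
w_q = torch.randn(embed_dim, embed_dim)
b_q = torch.randn(embed_dim)
scale = torch.randn(head_dim)           # stands in for the learned QueryScaler scale

# Old approach: scale the query rows of the in-projection weight and bias.
w_scaled = (w_q.view(num_heads, head_dim, embed_dim) *
            scale[None, :, None]).view(embed_dim, embed_dim)
b_scaled = (b_q.view(num_heads, head_dim) * scale).view(-1)
q_old = x @ w_scaled.T + b_scaled

# New approach: project first, then scale the per-head queries.
q_new = ((x @ w_q.T + b_q).view(3, 5, num_heads, head_dim) * scale).view(3, 5, embed_dim)

assert torch.allclose(q_old, q_new, atol=1e-6)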

def forward(self,
query,
key,
value,
key_padding_mask=None,
need_weights: bool = True,
attn_mask=None,
average_attn_weights: bool = True):
r"""
Args:
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
:math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
Queries are compared against key-value pairs to produce the output.
See "Attention Is All You Need" for more details.
key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
:math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
See "Attention Is All You Need" for more details.
value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
See "Attention Is All You Need" for more details.
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
Binary and byte masks are supported.
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
Default: ``True``.
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
the attention weight.
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)

Outputs:
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
embedding dimension ``embed_dim``.
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.

.. note::
`batch_first` argument is ignored for unbatched inputs.
"""
is_batched = query.dim() == 3
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(
key_padding_mask):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported")
why_not_fast_path = ''
if not is_batched:
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
elif self.training:
why_not_fast_path = "training is enabled"
elif not self.batch_first:
why_not_fast_path = "batch_first was not True"
elif self.bias_k is not None:
why_not_fast_path = "self.bias_k was not None"
elif self.bias_v is not None:
why_not_fast_path = "self.bias_v was not None"
elif self.dropout:
why_not_fast_path = f"dropout was {self.dropout}, required zero"
elif self.add_zero_attn:
why_not_fast_path = "add_zero_attn was enabled"
elif not self._qkv_same_embed_dim:
why_not_fast_path = "_qkv_same_embed_dim was not True"
elif attn_mask is not None:
why_not_fast_path = "attn_mask was not None"
elif query.is_nested and key_padding_mask is not None:
why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
elif self.num_heads % 2 == 1:
why_not_fast_path = "num_heads is odd"
elif torch.is_autocast_enabled():
why_not_fast_path = "autocast is enabled"

if not why_not_fast_path:
tensor_args = (
query,
key,
value,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
)
# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device))
for x in tensor_args]):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any(
[x is not None and x.requires_grad for x in tensor_args]):
why_not_fast_path = (
"grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad")
if not why_not_fast_path:
# Scale the query bias parameter and the query projection weight.
in_proj_weight = self._scaled_in_proj_weight()
in_proj_bias = self._scaled_in_proj_bias()
return torch._native_multi_head_attention(
query,
key,
value,
self.embed_dim,
self.num_heads,
in_proj_weight,
in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
key_padding_mask if key_padding_mask is not None else attn_mask,
need_weights,
average_attn_weights,
1 if key_padding_mask is not None else
0 if attn_mask is not None else None)
any_nested = query.is_nested or key.is_nested or value.is_nested
assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
f"The fast path was not hit because {why_not_fast_path}")

if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
if key is value:
if query is key:
query = key = value = query.transpose(1, 0)
else:
query, key = [x.transpose(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask, use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights)
else:
# Scale the query bias parameter and the query projection weight.
in_proj_weight = self._scaled_in_proj_weight()
in_proj_bias = self._scaled_in_proj_bias()
attn_output, attn_output_weights = F.multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
in_proj_weight, in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask, average_attn_weights=average_attn_weights)
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights
else:
return attn_output, attn_output_weights
def forward(self, inputs, key_padding_mask=None):
batch_size, seq_len, embed_dim = inputs.shape
q, k, v = self.in_proj(inputs).split(self.embed_dim, dim=2)
q = self.qs(q.view(batch_size, seq_len, self.num_heads, -1)).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
out = F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=~key_padding_mask[:, None, None],
dropout_p=self.dropout,
).transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
out = self.out_proj(out)
return out
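For reference, a minimal usage sketch of the mask convention in this forward (hypothetical shapes, illustration only): key_padding_mask is True at padded positions, so the inverted, broadcast mask handed to scaled_dot_product_attention is True wherever a key may be attended.

# Hypothetical shapes, illustration only.
import torch
import torch.nn.functional as F

batch, seq, num_heads, head_dim = 2, 6, 4, 16
paddings = torch.zeros(batch, seq)
paddings[:, 4:] = 1                              # last two frames are padding
key_padding_mask = paddings == 1                 # (N, S) bool, True = ignore
attn_mask = ~key_padding_mask[:, None, None]     # (N, 1, 1, S): True = may attend

q = k = v = torch.randn(batch, num_heads, seq, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)                                 # torch.Size([2, 4, 6, 16])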


class MultiHeadedSelfAttention(nn.Module):
@@ -483,12 +305,9 @@ def __init__(self, config: ConformerConfig):

def forward(self, outputs, paddings):
outputs = self.ln(outputs)
outputs, _ = self.self_attention(
query=outputs,
key=outputs,
value=outputs,
key_padding_mask=paddings==1,
need_weights=False,
outputs = self.self_attention(
outputs,
key_padding_mask=paddings == 1,
)
outputs = self.dropout(outputs)
return outputs
@@ -504,18 +323,29 @@ def __init__(self, config: ConformerConfig):
self.register_buffer('running_var', running_var)
self.scale = nn.Parameter(torch.zeros(config.encoder_dim))
self.bias = nn.Parameter(torch.zeros(config.encoder_dim))
self.register_buffer('momentum',
torch.FloatTensor([config.batch_norm_momentum]))
self.register_buffer('epsilon',
torch.FloatTensor([config.batch_norm_epsilon]))

self.register_buffer('dim', torch.FloatTensor([config.encoder_dim]))
# self.momentum = config.batch_norm_momentum
# self.epsilon = config.batch_norm_epsilon
# self.dim = config.encoder_dim
self.momentum = config.batch_norm_momentum
self.epsilon = config.batch_norm_epsilon

def forward(self, inputs, input_paddings):
#inputs: NHD
#padding: NH
"""
Alternatively:
inputs[input_paddings==0] = F.batch_norm(
input = inputs[input_paddings==0],
running_mean = self.running_mean,
running_var = self.running_var,
weight = 1+self.scale,
bias = self.bias,
training = self.training,
momentum=1-self.momentum,
eps=self.epsilon
)
inputs.masked_fill(input_paddings[...,None] != 0, 0)
return inputs
"""
mask = 1 - input_paddings[:, :, None]
if self.training:
count = mask.sum()
@@ -627,7 +457,9 @@ def __init__(self, config: ConformerConfig):
else:
input_dropout_rate = config.input_dropout_rate
self.subsample = Subsample(
encoder_dim=config.encoder_dim, input_dropout_rate=input_dropout_rate)
encoder_dim=config.encoder_dim,
input_dropout_rate=input_dropout_rate,
num_bins=preprocessing_config.num_bins)
self.conformers = nn.ModuleList(
[ConformerBlock(config) for _ in range(config.num_encoder_layers)])
