Conformer OOM fix #549
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -71,7 +71,10 @@ def forward(self, x): | |
|
|
||
| class Subsample(nn.Module): | ||
|
|
||
| def __init__(self, encoder_dim: int = 0, input_dropout_rate: float = 0.0): | ||
| def __init__(self, | ||
| encoder_dim: int = 0, | ||
| input_dropout_rate: float = 0.0, | ||
| num_bins: int = 80): | ||
| super().__init__() | ||
| self.encoder_dim = encoder_dim | ||
| self.input_dropout_rate = input_dropout_rate | ||
|
|
@@ -81,7 +84,10 @@ def __init__(self, encoder_dim: int = 0, input_dropout_rate: float = 0.0): | |
| self.conv2 = Conv2dSubsampling( | ||
| input_channels=encoder_dim, output_channels=encoder_dim) | ||
|
|
||
| self.linear = nn.LazyLinear(out_features=self.encoder_dim, bias=True) | ||
| self.linear = nn.Linear( | ||
| in_features=self.encoder_dim * num_bins // 4, | ||
| out_features=self.encoder_dim, | ||
| bias=True) | ||
| self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim) | ||
| self.dropout = nn.Dropout(p=self.input_dropout_rate) | ||
|
|
||
|
|
@@ -123,6 +129,7 @@ def __init__(self, | |
| self.kernel = nn.Parameter( | ||
| torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) | ||
| self.bias = nn.Parameter(torch.zeros(output_channels)) | ||
| self.register_buffer('paddings_kernel', torch.ones([1, 1, 1])) | ||
|
|
||
| def get_same_padding(self, input_shape): | ||
| in_height, in_width = input_shape[2:] | ||
|
|
@@ -162,15 +169,11 @@ def forward(self, inputs, paddings): | |
| input_length = paddings.shape[1] | ||
| stride = self.filter_stride[0] | ||
| pad_len = (input_length + stride - 1) // stride * stride - input_length | ||
| padded_paddings = torch.cat([ | ||
| paddings[:, None, :], | ||
| torch.zeros( | ||
| size=(paddings.shape[0], 1, pad_len), device=paddings.device) | ||
| ], | ||
| dim=2) | ||
| padded_paddings = F.pad( | ||
| paddings[:, None, :], (0, pad_len), mode='constant', value=0) | ||
| out_padding = F.conv1d( | ||
| input=padded_paddings, | ||
| weight=torch.ones([1, 1, 1], device=paddings.device), | ||
| weight=self.paddings_kernel, | ||
| stride=self.filter_stride[:1]) | ||
| out_padding = out_padding.squeeze(dim=1) | ||
| outputs = outputs * (1 - out_padding[:, None, :, None]) | ||
|
|
@@ -184,11 +187,15 @@ def __init__(self, config: ConformerConfig): | |
| self.config = config | ||
|
|
||
| self.ln = LayerNorm(dim=config.encoder_dim) | ||
| self.linear1 = nn.LazyLinear( | ||
| self.linear1 = nn.Linear( | ||
| in_features=config.encoder_dim, | ||
| out_features=config.encoder_dim * config.feed_forward_expansion_factor, | ||
| bias=True) | ||
| self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate) | ||
| self.linear2 = nn.LazyLinear(out_features=config.encoder_dim, bias=True) | ||
| self.linear2 = nn.Linear( | ||
| in_features=config.encoder_dim * config.feed_forward_expansion_factor, | ||
| out_features=config.encoder_dim, | ||
| bias=True) | ||
|
|
||
| if config.feed_forward_residual_dropout_rate is None: | ||
| feed_forward_residual_dropout_rate = 0.1 | ||
|
|
@@ -253,217 +260,32 @@ def forward(self, inputs): | |
| return inputs * scale | ||
|
|
||
|
|
||
| class MHSAwithQS(nn.MultiheadAttention): | ||
| # pylint: disable=locally-disabled, use-a-generator, line-too-long, invalid-name | ||
| class MHSAwithQS(nn.Module): | ||
|
|
||
| def __init__(self, config: ConformerConfig): | ||
| super().__init__( | ||
| embed_dim=config.encoder_dim, | ||
| num_heads=config.num_attention_heads, | ||
| dropout=config.attention_dropout_rate, | ||
| bias=True, | ||
| batch_first=True) | ||
| super().__init__() | ||
| self.embed_dim = config.encoder_dim | ||
| self.num_heads = config.num_attention_heads | ||
| self.dropout = config.attention_dropout_rate | ||
| self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim) | ||
| self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim) | ||
| self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads) | ||
|
|
||
| def _scaled_in_proj_weight(self): | ||
|
Member
OK, I assume that the new implementation is numerically equivalent; SDPA is probably the right bet.
Contributor
Author
I also tried the default multi-head self-attention without the query scaler, and it still couldn't run successfully without adjusting the attention backends. If this is useful for your debugging, I can create a separate branch with this setup?
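(For context, "adjusting the attention backends" presumably means constraining which kernel `F.scaled_dot_product_attention` is allowed to pick, e.g. forcing the memory-efficient one. A minimal sketch with made-up shapes, not code from this PR; newer PyTorch releases expose the same control via `torch.nn.attention.sdpa_kernel`:)

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes, chosen only to illustrate the context manager.
q = k = v = torch.randn(8, 4, 1024, 64, device='cuda', dtype=torch.float16)

# Restrict SDPA to the memory-efficient kernel (flash and math paths disabled).
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=False, enable_mem_efficient=True):
  out = F.scaled_dot_product_attention(q, k, v)
```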
Member
The previous implementation was quite long, and SDPA is maintained in PyTorch, so this actually makes things better.
Contributor
Author
The previous implementation was long because it extended nn.MultiheadAttention by changing the in-projection weights/biases in the forward pass, but the attention itself was still managed entirely by PyTorch.
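(As a sanity check on the numerical-equivalence assumption, the SDPA path can be compared against stock `nn.MultiheadAttention` when the query scaler and dropout are left out. A hedged sketch with arbitrary sizes and tolerance, deliberately ignoring `QueryScaler`, and not part of the PR:)

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
embed_dim, num_heads, batch, seq = 64, 4, 2, 7

# Reference: stock nn.MultiheadAttention, eval mode so dropout is off.
mha = nn.MultiheadAttention(
    embed_dim, num_heads, dropout=0.0, batch_first=True).eval()

x = torch.randn(batch, seq, embed_dim)
key_padding_mask = torch.zeros(batch, seq, dtype=torch.bool)
key_padding_mask[1, -2:] = True  # mark two positions as padding

ref, _ = mha(x, x, x, key_padding_mask=key_padding_mask, need_weights=False)

# Manual SDPA path mirroring the new forward(), reusing the same weights.
q, k, v = F.linear(x, mha.in_proj_weight, mha.in_proj_bias).split(embed_dim, dim=2)
q = q.view(batch, seq, num_heads, -1).transpose(1, 2)
k = k.view(batch, seq, num_heads, -1).transpose(1, 2)
v = v.view(batch, seq, num_heads, -1).transpose(1, 2)
out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=~key_padding_mask[:, None, None],  # True = may attend
    dropout_p=0.0,
).transpose(1, 2).reshape(batch, seq, embed_dim)
out = mha.out_proj(out)

print(torch.allclose(ref, out, atol=1e-5))  # expected: True
```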
||
| # Scale the query projection weight. | ||
| qs_input = self.in_proj_weight[:self.embed_dim].view( | ||
| self.num_heads, self.embed_dim // self.num_heads, -1).transpose(1, 2) | ||
| in_proj_queryW_scaled = self.qs(qs_input).transpose( | ||
| 1, 2).view(*self.in_proj_weight[:self.embed_dim].shape) | ||
| in_proj_weight = torch.cat( | ||
| [in_proj_queryW_scaled, self.in_proj_weight[self.embed_dim:]]) | ||
| return in_proj_weight | ||
|
|
||
| def _scaled_in_proj_bias(self): | ||
| # Scale the query bias. | ||
| in_proj_queryb_scaled = self.qs(self.in_proj_bias[:self.embed_dim].view( | ||
| self.num_heads, self.embed_dim // self.num_heads)).view(-1) | ||
| in_proj_bias = torch.cat( | ||
| [in_proj_queryb_scaled, self.in_proj_bias[self.embed_dim:]]) | ||
| return in_proj_bias | ||
|
|
||
| def forward(self, | ||
| query, | ||
| key, | ||
| value, | ||
| key_padding_mask=None, | ||
| need_weights: bool = True, | ||
| attn_mask=None, | ||
| average_attn_weights: bool = True): | ||
| r""" | ||
| Args: | ||
| query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` | ||
| or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, | ||
| :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. | ||
| Queries are compared against key-value pairs to produce the output. | ||
| See "Attention Is All You Need" for more details. | ||
| key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` | ||
| or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, | ||
| :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. | ||
| See "Attention Is All You Need" for more details. | ||
| value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when | ||
| ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source | ||
| sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. | ||
| See "Attention Is All You Need" for more details. | ||
| key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` | ||
| to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. | ||
| Binary and byte masks are supported. | ||
| For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for | ||
| the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. | ||
| need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. | ||
| Default: ``True``. | ||
| attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape | ||
| :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, | ||
| :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be | ||
| broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. | ||
| Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the | ||
| corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the | ||
| corresponding position is not allowed to attend. For a float mask, the mask values will be added to | ||
| the attention weight. | ||
| average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across | ||
| heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an | ||
| effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) | ||
|
|
||
| Outputs: | ||
| - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, | ||
| :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, | ||
| where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the | ||
| embedding dimension ``embed_dim``. | ||
| - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, | ||
| returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or | ||
| :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and | ||
| :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per | ||
| head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. | ||
|
|
||
| .. note:: | ||
| `batch_first` argument is ignored for unbatched inputs. | ||
| """ | ||
| is_batched = query.dim() == 3 | ||
| if key_padding_mask is not None: | ||
| _kpm_dtype = key_padding_mask.dtype | ||
| if _kpm_dtype != torch.bool and not torch.is_floating_point( | ||
| key_padding_mask): | ||
| raise AssertionError( | ||
| "only bool and floating types of key_padding_mask are supported") | ||
| why_not_fast_path = '' | ||
| if not is_batched: | ||
| why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" | ||
| elif query is not key or key is not value: | ||
| # When lifting this restriction, don't forget to either | ||
| # enforce that the dtypes all match or test cases where | ||
| # they don't! | ||
| why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" | ||
| elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: | ||
| why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" | ||
| elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype: | ||
| # this case will fail anyway, but at least they'll get a useful error message. | ||
| why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" | ||
| elif self.training: | ||
| why_not_fast_path = "training is enabled" | ||
| elif not self.batch_first: | ||
| why_not_fast_path = "batch_first was not True" | ||
| elif self.bias_k is not None: | ||
| why_not_fast_path = "self.bias_k was not None" | ||
| elif self.bias_v is not None: | ||
| why_not_fast_path = "self.bias_v was not None" | ||
| elif self.dropout: | ||
| why_not_fast_path = f"dropout was {self.dropout}, required zero" | ||
| elif self.add_zero_attn: | ||
| why_not_fast_path = "add_zero_attn was enabled" | ||
| elif not self._qkv_same_embed_dim: | ||
| why_not_fast_path = "_qkv_same_embed_dim was not True" | ||
| elif attn_mask is not None: | ||
| why_not_fast_path = "attn_mask was not None" | ||
| elif query.is_nested and key_padding_mask is not None: | ||
| why_not_fast_path = "key_padding_mask is not supported with NestedTensor input" | ||
| elif self.num_heads % 2 == 1: | ||
| why_not_fast_path = "num_heads is odd" | ||
| elif torch.is_autocast_enabled(): | ||
| why_not_fast_path = "autocast is enabled" | ||
|
|
||
| if not why_not_fast_path: | ||
| tensor_args = ( | ||
| query, | ||
| key, | ||
| value, | ||
| self.in_proj_weight, | ||
| self.in_proj_bias, | ||
| self.out_proj.weight, | ||
| self.out_proj.bias, | ||
| ) | ||
| # We have to use list comprehensions below because TorchScript does not support | ||
| # generator expressions. | ||
| if torch.overrides.has_torch_function(tensor_args): | ||
| why_not_fast_path = "some Tensor argument has_torch_function" | ||
| elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device)) | ||
| for x in tensor_args]): | ||
| why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" | ||
| elif torch.is_grad_enabled() and any( | ||
| [x is not None and x.requires_grad for x in tensor_args]): | ||
| why_not_fast_path = ( | ||
| "grad is enabled and at least one of query or the " | ||
| "input/output projection weights or biases requires_grad") | ||
| if not why_not_fast_path: | ||
| # Scale the query bias parameter and the query projection weight. | ||
| in_proj_weight = self._scaled_in_proj_weight() | ||
| in_proj_bias = self._scaled_in_proj_bias() | ||
| return torch._native_multi_head_attention( | ||
| query, | ||
| key, | ||
| value, | ||
| self.embed_dim, | ||
| self.num_heads, | ||
| in_proj_weight, | ||
| in_proj_bias, | ||
| self.out_proj.weight, | ||
| self.out_proj.bias, | ||
| key_padding_mask if key_padding_mask is not None else attn_mask, | ||
| need_weights, | ||
| average_attn_weights, | ||
| 1 if key_padding_mask is not None else | ||
| 0 if attn_mask is not None else None) | ||
| any_nested = query.is_nested or key.is_nested or value.is_nested | ||
| assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " + | ||
| f"The fast path was not hit because {why_not_fast_path}") | ||
|
|
||
| if self.batch_first and is_batched: | ||
| # make sure that the transpose op does not affect the "is" property | ||
| if key is value: | ||
| if query is key: | ||
| query = key = value = query.transpose(1, 0) | ||
| else: | ||
| query, key = [x.transpose(1, 0) for x in (query, key)] | ||
| value = key | ||
| else: | ||
| query, key, value = [x.transpose(1, 0) for x in (query, key, value)] | ||
|
|
||
| if not self._qkv_same_embed_dim: | ||
| attn_output, attn_output_weights = F.multi_head_attention_forward( | ||
| query, key, value, self.embed_dim, self.num_heads, | ||
| self.in_proj_weight, self.in_proj_bias, | ||
| self.bias_k, self.bias_v, self.add_zero_attn, | ||
| self.dropout, self.out_proj.weight, self.out_proj.bias, | ||
| training=self.training, | ||
| key_padding_mask=key_padding_mask, need_weights=need_weights, | ||
| attn_mask=attn_mask, use_separate_proj_weight=True, | ||
| q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, | ||
| v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights) | ||
| else: | ||
| # Scale the query bias parameter and the query projection weight. | ||
| in_proj_weight = self._scaled_in_proj_weight() | ||
| in_proj_bias = self._scaled_in_proj_bias() | ||
| attn_output, attn_output_weights = F.multi_head_attention_forward( | ||
| query, key, value, self.embed_dim, self.num_heads, | ||
| in_proj_weight, in_proj_bias, | ||
| self.bias_k, self.bias_v, self.add_zero_attn, | ||
| self.dropout, self.out_proj.weight, self.out_proj.bias, | ||
| training=self.training, | ||
| key_padding_mask=key_padding_mask, need_weights=need_weights, | ||
| attn_mask=attn_mask, average_attn_weights=average_attn_weights) | ||
| if self.batch_first and is_batched: | ||
| return attn_output.transpose(1, 0), attn_output_weights | ||
| else: | ||
| return attn_output, attn_output_weights | ||
| def forward(self, inputs, key_padding_mask=None): | ||
| batch_size, seq_len, embed_dim = inputs.shape | ||
| q, k, v = self.in_proj(inputs).split(self.embed_dim, dim=2) | ||
| q = self.qs(q.view(batch_size, seq_len, self.num_heads, -1)).transpose(1, 2) | ||
| k = k.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) | ||
| v = v.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) | ||
| out = F.scaled_dot_product_attention( | ||
| query=q, | ||
| key=k, | ||
| value=v, | ||
| attn_mask=~key_padding_mask[:, None, None], | ||
| dropout_p=self.dropout, | ||
| ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim) | ||
| out = self.out_proj(out) | ||
| return out | ||
|
|
||
|
|
||
| class MultiHeadedSelfAttention(nn.Module): | ||
|
|
@@ -483,12 +305,9 @@ def __init__(self, config: ConformerConfig): | |
|
|
||
| def forward(self, outputs, paddings): | ||
| outputs = self.ln(outputs) | ||
| outputs, _ = self.self_attention( | ||
| query=outputs, | ||
| key=outputs, | ||
| value=outputs, | ||
| key_padding_mask=paddings==1, | ||
| need_weights=False, | ||
| outputs = self.self_attention( | ||
| outputs, | ||
| key_padding_mask=paddings == 1, | ||
| ) | ||
| outputs = self.dropout(outputs) | ||
| return outputs | ||
|
|
@@ -504,18 +323,29 @@ def __init__(self, config: ConformerConfig): | |
| self.register_buffer('running_var', running_var) | ||
| self.scale = nn.Parameter(torch.zeros(config.encoder_dim)) | ||
| self.bias = nn.Parameter(torch.zeros(config.encoder_dim)) | ||
| self.register_buffer('momentum', | ||
| torch.FloatTensor([config.batch_norm_momentum])) | ||
| self.register_buffer('epsilon', | ||
| torch.FloatTensor([config.batch_norm_epsilon])) | ||
|
|
||
| self.register_buffer('dim', torch.FloatTensor([config.encoder_dim])) | ||
| # self.momentum = config.batch_norm_momentum | ||
| # self.epsilon = config.batch_norm_epsilon | ||
| # self.dim = config.encoder_dim | ||
| self.momentum = config.batch_norm_momentum | ||
| self.epsilon = config.batch_norm_epsilon | ||
|
|
||
| def forward(self, inputs, input_paddings): | ||
| #inputs: NHD | ||
| #padding: NH | ||
| """ | ||
| Alternatively: | ||
| inputs[input_paddings==0] = F.batch_norm( | ||
| input = inputs[input_paddings==0], | ||
| running_mean = self.running_mean, | ||
| running_var = self.running_var, | ||
| weight = 1+self.scale, | ||
| bias = self.bias, | ||
| training = self.training, | ||
| momentum=1-self.momentum, | ||
| eps=self.epsilon | ||
| ) | ||
| inputs.masked_fill(input_paddings[...,None] != 0, 0) | ||
| return inputs | ||
| """ | ||
| mask = 1 - input_paddings[:, :, None] | ||
| if self.training: | ||
| count = mask.sum() | ||
|
|
@@ -627,7 +457,9 @@ def __init__(self, config: ConformerConfig): | |
| else: | ||
| input_dropout_rate = config.input_dropout_rate | ||
| self.subsample = Subsample( | ||
| encoder_dim=config.encoder_dim, input_dropout_rate=input_dropout_rate) | ||
| encoder_dim=config.encoder_dim, | ||
| input_dropout_rate=input_dropout_rate, | ||
| num_bins=preprocessing_config.num_bins) | ||
| self.conformers = nn.ModuleList( | ||
| [ConformerBlock(config) for _ in range(config.num_encoder_layers)]) | ||
|
|
||
|
|
||
what's the reasoning for this?
Each of the two subsampling layers reduces the number of mel-spectrogram features by half.
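(Spelling out the shape math, a sketch with a hypothetical encoder_dim; num_bins=80 is the default added in this PR: two 'same'-padded stride-2 subsampling convolutions shrink the frequency axis to num_bins // 4, so the flattened conv output has a statically known width of encoder_dim * num_bins // 4, which is what lets nn.LazyLinear be replaced by a regular nn.Linear once num_bins is available from the preprocessing config.)

```python
# Illustrative shape math only; encoder_dim value is hypothetical,
# num_bins=80 is the default added in this PR.
encoder_dim, num_bins = 512, 80

freq = num_bins
for _ in range(2):          # two Conv2dSubsampling layers, stride 2 each
  freq = (freq + 1) // 2    # 'same' padding -> ceil(freq / 2)
assert freq == num_bins // 4          # 80 -> 40 -> 20

# The conv output (batch, time // 4, encoder_dim channels, freq bins) is
# flattened over the last two axes before the projection, so:
in_features = encoder_dim * freq
assert in_features == encoder_dim * num_bins // 4   # 512 * 20 == 10240
```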