Training with "facebook/wav2vec2-base" as the backbone consistently fails with the following error:
```
1020it [01:35, 10.71it/s]
Starting epoch 0 ...
Traceback (most recent call last):
  File "/scratch/jiranzotmp/trabajo/ICASSP2023_argumentation/software/SHAS/src/supervised_hybrid/train.py", line 365, in <module>
    train(args)
  File "/scratch/jiranzotmp/trabajo/ICASSP2023_argumentation/software/SHAS/src/supervised_hybrid/train.py", line 147, in train
    logits = sfc_model(wav2vec_hidden, out_mask)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/scratch/jiranzotmp/trabajo/ICASSP2023_argumentation/software/SHAS/src/supervised_hybrid/models.py", line 41, in forward
    x = self.transformer(x, src_key_padding_mask=attention_mask)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 198, in forward
    output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 336, in forward
    x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/normalization.py", line 189, in forward
    return F.layer_norm(
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/functional.py", line 2347, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: Given normalized_shape=[1024], expected input with shape [*, 1024], but got input of size[28, 999, 768]
```
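The error itself is a plain LayerNorm shape mismatch: a LayerNorm configured for a last dimension of 1024 receives a tensor whose last dimension is 768. A minimal standalone repro, using the tensor sizes from the traceback above:

```python
import torch

# LayerNorm sized for a hidden dimension of 1024
layer_norm = torch.nn.LayerNorm(1024)

# hidden states with a last dimension of 768, as in the traceback
hidden = torch.randn(28, 999, 768)

layer_norm(hidden)  # raises the same RuntimeError as above
```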
Training with the default "facebook/wav2vec2-xls-r-300m" in the same setup runs without issues.
Could this have something to do with the fact that wav2vec2-base uses "do_stable_layer_norm": false, whereas facebook/wav2vec2-xls-r-300m uses "do_stable_layer_norm": true?
My first guess was that the assumption made here might not hold when "do_stable_layer_norm" is false:
SHAS/src/supervised_hybrid/models.py, line 80 at commit 418b5e6:

```python
wav2vec_model.encoder.layer_norm = torch.nn.Identity()
```
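For reference, the two backbones' configs can be compared directly with transformers (this only reads the published model configs):

```python
from transformers import Wav2Vec2Config

base = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base")
xlsr = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-xls-r-300m")

print(base.do_stable_layer_norm, base.hidden_size)  # False 768
print(xlsr.do_stable_layer_norm, xlsr.hidden_size)  # True 1024
```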
I will let you know if I find any additional information about this.
EDIT:
Actually, it was something much simpler: the wav2vec2-base model has a different hidden dimension (768 instead of 1024). Changing constants.py seems to fix everything:
https://github.com/mt-upc/SHAS/blob/main/src/supervised_hybrid/constants.py#L4
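For anyone else hitting this, a sketch of the kind of one-line change I mean; the name HIDDEN_SIZE is illustrative, check line 4 of constants.py for the actual identifier:

```python
# src/supervised_hybrid/constants.py (sketch; the real constant name may differ)
# default, matching facebook/wav2vec2-xls-r-300m:
# HIDDEN_SIZE = 1024
# for facebook/wav2vec2-base, whose hidden dimension is 768:
HIDDEN_SIZE = 768
```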
Feel free to close the issue if you think this is obvious.