add pad-vocab-size-to argument and tests
#255
```diff
@@ -217,6 +217,9 @@ def __init__(self, num_embeddings, embedding_dim,

     def forward(self, input_):
+        if torch.any(input_ >= self.num_embeddings):
+            raise ValueError(f"There is an input id in the input that is greater than the highest possible input id.\nInput: {input_}\nnum_embeddings: {self.num_embeddings}")
+
         if self.tensor_model_parallel_size > 1:
             # Build the mask.
             input_mask = (input_ < self.vocab_start_index) | \
```
Contributor
Killing the training at run time because the input is broken? This we can't afford to support. Additionally, this assert can't be acted upon: the operator would have no idea how to fix it, since you're not including the offending input id and the sample id, making it impossible to act upon. If there is a need to validate data before the training, it should happen separately from the training. The worst-case scenario that can be supported is probably to skip the bad input, i.e. do the checking at the dataloader retrieval stage. But this still feels like bad design in the software.
Contributor
The forward should be as lean as possible and not need to check anything, so that it can run as fast as possible.
Member
I don't agree; we need this sanity check so that we're not doing something wrong. Right now the model would silently bypass this issue if you use TP>1. IMO it should kill the training, as we are doing something VERY bad. (Essentially, if you use TP=1, this does get killed when you call `F.embedding`.)
Contributor
You don't do data sanity checks in `forward`. Which case are you guarding against: all of the data is completely borked, or only some inputs are invalid?
Contributor
And practically, let's run an imaginary scenario: you started the training and this assert fires 5 days in. What do you do? I fail to see how this assert is actionable.
Contributor
Sorry, unfortunately I don't have the resources to dive into this right now, as I have to finish lots of things before the launch, so I trust you will do the best thing you can, and if things break during the training I will ping you and you will know what to do. It's not great to have this sort of last-minute change that hasn't been thoroughly tested, but what can you do; it's not the only last-minute change, e.g. the whole bf16 optimizer was rewritten last week.
Contributor
Thomas and I discussed this; he helped me clear up an important misunderstanding and he will post an update. Thank you for working on this, @SaulLu and @DanielHesslow - my apologies that I can't be involved at a deeper level at the moment.
Member
Okay, talked to @stas00:
Contributor
Yes, so the idea is to do a check at the dataloader level and warn when skipping a bad sample. If it's a lot of samples, but not the majority, we continue training while someone fixes the data. Of course, if the data is very broken then it will end up skipping them all and we can't train. But the point is: don't crash the training unless you have to.
Contributor
And additionally, we can pre-check all our data outside of the training so that it doesn't have any bad samples. This PR is mainly for future users; in our training we should make sure it never hits this assert.
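A minimal sketch of the dataloader-level check discussed in this thread: validate token ids when a sample is retrieved, log a warning that includes the sample index and the offending ids (so the failure is actionable), and skip the sample instead of killing the run. The wrapper class, the `vocab_size` argument, and the skip strategy below are illustrative assumptions, not code from this PR.

```python
import logging

import torch
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)


class TokenRangeCheckingDataset(Dataset):
    """Wraps a dataset of token-id tensors and skips samples whose ids fall
    outside [0, vocab_size), warning with the sample index and offending ids."""

    def __init__(self, dataset, vocab_size):
        self.dataset = dataset
        self.vocab_size = vocab_size

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        bad = (sample < 0) | (sample >= self.vocab_size)
        if torch.any(bad):
            logger.warning(
                "Skipping sample %d: token ids %s are outside [0, %d)",
                idx, sample[bad].tolist(), self.vocab_size,
            )
            # Fall back to a neighbouring sample; a real implementation might
            # instead filter indices up front or drop the sample in a custom
            # collate_fn.
            return self.dataset[(idx + 1) % len(self.dataset)]
        return sample
```

If only a small fraction of samples are bad, training continues while the data is fixed; if most samples are bad, the warnings make that obvious immediately rather than five days in.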
```diff
@@ -225,7 +228,9 @@ def forward(self, input_):
             masked_input = input_.clone() - self.vocab_start_index
             masked_input[input_mask] = 0
         else:
+            # input_ is guaranteed to be in the range [0:self.vocab_end_index - self.vocab_start_index] thanks to the first check
             masked_input = input_
+
         # Get the embeddings.
         output_parallel = F.embedding(masked_input, self.weight,
                                       self.padding_idx, self.max_norm,
```
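For context on the "silently bypass" concern in the thread above: in Megatron-style vocab-parallel embeddings, an id that is `>= num_embeddings` falls outside every rank's `[vocab_start_index, vocab_end_index)` shard, gets masked to index 0, and has its partial result zeroed before the cross-rank reduction, so the lookup quietly returns a zero vector; with TP=1 and no masking, the same id makes `F.embedding` raise an index error. Below is a toy, single-process sketch of that masking arithmetic (the "ranks" are simulated in a loop; all names are illustrative, not the actual Megatron code):

```python
import torch
import torch.nn.functional as F

vocab_size, hidden = 8, 4
full_weight = torch.randn(vocab_size, hidden)


def tp_embedding_lookup(input_ids, tp_size):
    """Simulate a vocab-parallel embedding lookup on tp_size 'ranks' in one
    process: each rank masks ids outside its shard, looks up its slice of the
    weight, zeroes the masked rows, and the partial results are summed
    (standing in for the all-reduce)."""
    per_rank = vocab_size // tp_size
    output = torch.zeros(input_ids.shape[0], hidden)
    for rank in range(tp_size):
        start, end = rank * per_rank, (rank + 1) * per_rank
        mask = (input_ids < start) | (input_ids >= end)
        masked = input_ids.clone() - start
        masked[mask] = 0
        partial = F.embedding(masked, full_weight[start:end])
        partial[mask, :] = 0.0  # out-of-shard rows contribute nothing
        output += partial
    return output


ids = torch.tensor([3, 9])  # 9 is out of vocab (>= 8)
print(tp_embedding_lookup(ids, tp_size=2)[1])  # all zeros: the bad id is silently swallowed
# With tp_size=1 and no masking, F.embedding(ids, full_weight) raises an IndexError instead.
```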
IMO, I liked the naming `vocab` because it emphasizes the fact that we're really choosing our tokenizer's vocabulary size.
Actually, that works for me; I just didn't want to indicate we were modifying the tokenizer, but really just the embedding layer. But in Meg-DS they seem to use `padded_vocab_size`, so your solution makes more sense.
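Since the flag's purpose came up in the naming discussion: `--pad-vocab-size-to` pads the embedding matrix, not the tokenizer's vocabulary. As a hedged illustration only (a guess at the intent using hypothetical names, not the PR's actual implementation), an explicit target would override the usual round-up-to-a-multiple padding roughly like this:

```python
def padded_vocab_size(orig_vocab_size,
                      make_vocab_size_divisible_by=128,
                      tensor_model_parallel_size=1,
                      pad_vocab_size_to=None):
    """Return the size of the (padded) embedding table.

    If pad_vocab_size_to is given, use it directly (it must at least hold the
    tokenizer's vocab); otherwise round the vocab size up to a multiple of
    make_vocab_size_divisible_by * tensor_model_parallel_size.
    Hypothetical helper for illustration, not the PR's code.
    """
    if pad_vocab_size_to is not None:
        if pad_vocab_size_to < orig_vocab_size:
            raise ValueError(
                f"pad_vocab_size_to={pad_vocab_size_to} is smaller than the "
                f"tokenizer vocab size {orig_vocab_size}")
        return pad_vocab_size_to
    multiple = make_vocab_size_divisible_by * tensor_model_parallel_size
    return ((orig_vocab_size + multiple - 1) // multiple) * multiple


print(padded_vocab_size(50257))                           # 50304
print(padded_vocab_size(50257, pad_vocab_size_to=50432))  # 50432
```

If the padded size exceeds the tokenizer's vocab, ids in the padded region would still index valid embedding rows without tripping the runtime check, which is one more argument for pre-checking the data against the tokenizer's own vocabulary.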