-
Notifications
You must be signed in to change notification settings - Fork 528
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add eval_drop_last flag to fix TE eval bug #1247
Conversation
25e64ee
to
d821d42
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm, this is not good because we should not be dropping evaluation data. Eval results should be exact.
Agreed with Daniel here |
Can we disable TE layers just for eval if they have this batch size requirement? Or turn off fp8 temporarily? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we doc for this? Agree dropping is bad
Closing as we are going with a different approach here |
Description
This PR introduces the
eval_drop_last
flag which enables thedrop_last
flag in ICL eval pytorch Dataloaders. This flag ensures that all dataset batches will be divisible byeval_batch_size
. This feature is necessary because TransformerEngine requires all inputs to be divisible by 8 and so we must pass in batches of size 8. Before, the eval dataloaders would return the remainder of the dataset size on the last batch which would result in an error.For example, if the dataset was of length 41 and the batch size was 8, the last batch would be of size
41 % 8 = 1
which would break TE. Now with thiseval_drop_last
flag enabled, we simply skip this last batch of size 1.Note: enabling this flag will result in different eval scores.
Testing
Unit Test:
test_icl_task_tokenizer_and_dataloader
Integration Test:
Before:
fp8-llama3-8b-metamath-4ep-4LEFPw
🔴Error Traceback:
fp8-llama3-8b-metamath-4ep-0uiOJb
✅llama3-8b-metamath-4ep-jaIcPX
with no skipped batchesIssues Fixed
https://databricks.atlassian.net/browse/RGENAI-165