Set drop_last to ensure modulo-16 restriction for FP8 (#1189)
* Set drop_last to ensure the modulo-16 restriction for FP8
* Fix quality
* Use all eval samples for the non-FP8 case
ksivaman committed Mar 14, 2023
1 parent eac5d13 commit 41479fe
Showing 1 changed file with 6 additions and 2 deletions.
examples/nlp_example.py (6 additions, 2 deletions)
@@ -98,10 +98,14 @@ def collate_fn(examples):
 
     # Instantiate dataloaders.
     train_dataloader = DataLoader(
-        tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size
+        tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
     )
     eval_dataloader = DataLoader(
-        tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE
+        tokenized_datasets["validation"],
+        shuffle=False,
+        collate_fn=collate_fn,
+        batch_size=EVAL_BATCH_SIZE,
+        drop_last=(accelerator.mixed_precision == "fp8"),
     )
 
     return train_dataloader, eval_dataloader
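
Why drop_last matters here: FP8 kernels (for example, NVIDIA Transformer Engine GEMMs) require the participating dimensions, including the batch dimension, to be multiples of 16, and the smaller trailing batch a DataLoader normally yields violates that restriction. Below is a minimal, self-contained sketch of the effect; the 100-sample dataset is hypothetical, and only torch is assumed.

# Sketch: the trailing partial batch breaks the modulo-16 restriction.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100))  # hypothetical 100-sample dataset

# Without drop_last, the final batch holds 100 % 16 = 4 samples.
sizes = [len(batch[0]) for batch in DataLoader(dataset, batch_size=16)]
print(sizes)  # [16, 16, 16, 16, 16, 16, 4] -- the 4 violates modulo-16

# With drop_last=True, every batch is exactly 16 samples.
sizes = [len(batch[0]) for batch in DataLoader(dataset, batch_size=16, drop_last=True)]
print(sizes)  # [16, 16, 16, 16, 16, 16]

This is why the train loader now always sets drop_last=True, while the eval loader drops the trailing batch only when accelerator.mixed_precision == "fp8", so that every validation sample is still evaluated in the non-FP8 case.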
