Skip to content

Commit

Permalink
Fix pipline dataloader when batch elements contain tuple (#565)
Browse files Browse the repository at this point in the history
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
  • Loading branch information
3 people authored Aug 25, 2023
1 parent 0b7a760 commit c69bd1f
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion deepspeed/runtime/pipe/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,8 @@ def _exec_load_micro_batch(self, buffer_id):
loaded = batch[1]
if torch.is_tensor(batch[1]):
loaded = batch[1].to(self.device)
elif isinstance(batch[1], tuple):
# XXX: torch 1.6.0 DataLoader will auto convert tuple to list
elif isinstance(batch[1], (tuple, list)):
loaded = []
for x in batch[1]:
assert torch.is_tensor(x)
Expand Down

0 comments on commit c69bd1f

Please sign in to comment.