diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index ee91b60c4eb..267f7809ef0 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -231,6 +231,9 @@ def find_batch_size(data): Returns: `int`: The batch size. """ + if isinstance(data, (tuple, list, Mapping)) and (len(data) == 0): + raise ValueError(f"Cannot find the batch size from empty {type(data)}.") + if isinstance(data, (tuple, list)): return find_batch_size(data[0]) elif isinstance(data, Mapping):