diff --git a/autoparallel/compute_estimation.py b/autoparallel/compute_estimation.py index c104a42e..ba5a5f90 100644 --- a/autoparallel/compute_estimation.py +++ b/autoparallel/compute_estimation.py @@ -169,9 +169,10 @@ def _get_sharded_shape_stride(spec): if placement.is_shard(): dim = placement.dim new_tensor_shape[dim] = (new_tensor_shape[dim] + mesh_size - 1) // mesh_size - new_tensor_stride[dim] = ( - new_tensor_stride[dim] + mesh_size - 1 - ) // mesh_size + if dim - 1 > 0: + new_tensor_stride[dim - 1] = ( + new_tensor_stride[dim - 1] + mesh_size - 1 + ) // mesh_size return new_tensor_shape, new_tensor_stride