From edcb3fc356b3fe65051d859086311448004c0acb Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Thu, 17 Jul 2025 17:11:17 -0700
Subject: [PATCH] Fix stride computation formula used during compute
 estimation

It turns out the previous PR,
https://github.com/pytorch-labs/autoparallel/pull/37, was not correct: it
divided the wrong dim's stride. This PR instead divides the stride of the
dim to the left of the one being sharded, which is what really happens.

Note: the fact that we have this util at all is worrying. Why don't we
just use DTensor to propagate?
---
 autoparallel/compute_estimation.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/autoparallel/compute_estimation.py b/autoparallel/compute_estimation.py
index c104a42e..ba5a5f90 100644
--- a/autoparallel/compute_estimation.py
+++ b/autoparallel/compute_estimation.py
@@ -169,9 +169,10 @@ def _get_sharded_shape_stride(spec):
         if placement.is_shard():
             dim = placement.dim
             new_tensor_shape[dim] = (
                 new_tensor_shape[dim] + mesh_size - 1
             ) // mesh_size
-            new_tensor_stride[dim] = (
-                new_tensor_stride[dim] + mesh_size - 1
-            ) // mesh_size
+            if dim - 1 >= 0:
+                new_tensor_stride[dim - 1] = (
+                    new_tensor_stride[dim - 1] + mesh_size - 1
+                ) // mesh_size
     return new_tensor_shape, new_tensor_stride
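
As a sanity check on the new formula, here is a minimal self-contained
sketch of the corrected logic. It models the inputs as plain lists and
ints rather than the real DTensorSpec/placement objects, and the
`sharded_shape_stride` helper below is illustrative only, not part of
the codebase:

    def sharded_shape_stride(shape, stride, shard_dim, mesh_size):
        # Simplified model of _get_sharded_shape_stride for a single
        # Shard(shard_dim) placement on a mesh dim of size `mesh_size`.
        new_shape = list(shape)
        new_stride = list(stride)
        # Ceil-divide the sharded dim's size across the mesh.
        new_shape[shard_dim] = (new_shape[shard_dim] + mesh_size - 1) // mesh_size
        # In a contiguous layout, the dim to the left strides over the
        # sharded dim, so its stride shrinks by the same factor.
        if shard_dim - 1 >= 0:
            new_stride[shard_dim - 1] = (
                new_stride[shard_dim - 1] + mesh_size - 1
            ) // mesh_size
        return new_shape, new_stride

    # A contiguous (4, 6) tensor has strides (6, 1). Sharding dim 1
    # across 2 ranks gives local shape (4, 3), and dim 0 now strides
    # over 3 elements instead of 6.
    assert sharded_shape_stride([4, 6], [6, 1], 1, 2) == ([4, 3], [3, 1])
    # Sharding dim 0 shrinks only the shape; there is no dim to its
    # left, so no stride changes.
    assert sharded_shape_stride([4, 6], [6, 1], 0, 2) == ([2, 6], [6, 1])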