From 0851db72deee619be9d980f1372279162c865aaa Mon Sep 17 00:00:00 2001 From: Theodore Vasiloudis Date: Wed, 15 May 2024 02:44:28 +0300 Subject: [PATCH] [DistPart] Fix corner case in dist partition which always led to an assertion error being triggered. (#7395) --- tools/distpartitioning/data_shuffle.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/tools/distpartitioning/data_shuffle.py b/tools/distpartitioning/data_shuffle.py index 5dda0139acee..e85bff5ecc4c 100644 --- a/tools/distpartitioning/data_shuffle.py +++ b/tools/distpartitioning/data_shuffle.py @@ -402,7 +402,6 @@ def exchange_feature( """ # type_ids for this feature subset on the current rank gids_feat = np.arange(gid_start, gid_end) - tids_feat = np.arange(type_id_start, type_id_end) local_idx = np.arange(0, type_id_end - type_id_start) feats_per_rank = [] @@ -473,12 +472,22 @@ def exchange_feature( ) # exchange actual data here. - if featdata_key != None: + logging.debug(f"Rank: {rank} {featdata_key.shape=}") + if featdata_key is not None: feat_dims_dtype = list(featdata_key.shape) + assert ( + len(featdata_key.shape) == 2 or len(featdata_key.shape) == 1 + ), f"We expect 1D or 2D tensors for features, got shape {featdata_key.shape}" + # When a feature is 2-dim, the shape should match the feature dimension. + if len(featdata_key.shape) == 2: + feature_dimension = feat_dims_dtype[1] + else: + feature_dimension = 0 feat_dims_dtype.append(DATA_TYPE_ID[featdata_key.dtype]) else: feat_dims_dtype = list(np.zeros((rank0_shape_len), dtype=np.int64)) feat_dims_dtype.append(DATA_TYPE_ID[torch.float32]) + feature_dimension = 0 logging.debug(f"Sending the feature shape information - {feat_dims_dtype}") all_dims_dtype = allgather_sizes( @@ -488,13 +497,18 @@ def exchange_feature( for idx in range(world_size): cond = partid_slice == (idx + local_part_id * world_size) gids_per_partid = gids_feat[cond] - tids_per_partid = tids_feat[cond] local_idx_partid = local_idx[cond] if gids_per_partid.shape[0] == 0: assert len(all_dims_dtype) % world_size == 0 dim_len = int(len(all_dims_dtype) / world_size) - rank0_shape = tuple(list(np.zeros((dim_len - 1), dtype=np.int32))) + rank0_shape = list(np.zeros((dim_len - 1), dtype=np.int32)) + assert ( + len(rank0_shape) == 2 or len(rank0_shape) == 1 + ), f"We expect 1D or 2D tensors for features, got shape {rank0_shape}" + # When a feature is 2-dim, the shape[1] (number of columns) should match the feature dimension. + if len(rank0_shape) == 2: + rank0_shape[1] = feature_dimension rank0_dtype = REV_DATA_TYPE_ID[ all_dims_dtype[(dim_len - 1) : (dim_len)][0] ]