Skip to content

Commit

Permalink
[DistPart] Fix corner case in dist partition which always led to an a…
Browse files Browse the repository at this point in the history
…ssertion error being triggered. (#7395)
  • Loading branch information
thvasilo committed May 14, 2024
1 parent 6475057 commit 0851db7
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions tools/distpartitioning/data_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,6 @@ def exchange_feature(
"""
# type_ids for this feature subset on the current rank
gids_feat = np.arange(gid_start, gid_end)
tids_feat = np.arange(type_id_start, type_id_end)
local_idx = np.arange(0, type_id_end - type_id_start)

feats_per_rank = []
Expand Down Expand Up @@ -473,12 +472,22 @@ def exchange_feature(
)

# exchange actual data here.
if featdata_key != None:
logging.debug(f"Rank: {rank} {featdata_key.shape=}")
if featdata_key is not None:
feat_dims_dtype = list(featdata_key.shape)
assert (
len(featdata_key.shape) == 2 or len(featdata_key.shape) == 1
), f"We expect 1D or 2D tensors for features, got shape {featdata_key.shape}"
# When a feature is 2-dim, the shape should match the feature dimension.
if len(featdata_key.shape) == 2:
feature_dimension = feat_dims_dtype[1]
else:
feature_dimension = 0
feat_dims_dtype.append(DATA_TYPE_ID[featdata_key.dtype])
else:
feat_dims_dtype = list(np.zeros((rank0_shape_len), dtype=np.int64))
feat_dims_dtype.append(DATA_TYPE_ID[torch.float32])
feature_dimension = 0

logging.debug(f"Sending the feature shape information - {feat_dims_dtype}")
all_dims_dtype = allgather_sizes(
Expand All @@ -488,13 +497,18 @@ def exchange_feature(
for idx in range(world_size):
cond = partid_slice == (idx + local_part_id * world_size)
gids_per_partid = gids_feat[cond]
tids_per_partid = tids_feat[cond]
local_idx_partid = local_idx[cond]

if gids_per_partid.shape[0] == 0:
assert len(all_dims_dtype) % world_size == 0
dim_len = int(len(all_dims_dtype) / world_size)
rank0_shape = tuple(list(np.zeros((dim_len - 1), dtype=np.int32)))
rank0_shape = list(np.zeros((dim_len - 1), dtype=np.int32))
assert (
len(rank0_shape) == 2 or len(rank0_shape) == 1
), f"We expect 1D or 2D tensors for features, got shape {rank0_shape}"
# When a feature is 2-dim, the shape[1] (number of columns) should match the feature dimension.
if len(rank0_shape) == 2:
rank0_shape[1] = feature_dimension
rank0_dtype = REV_DATA_TYPE_ID[
all_dims_dtype[(dim_len - 1) : (dim_len)][0]
]
Expand Down

0 comments on commit 0851db7

Please sign in to comment.