From 2faca09585a98f9fafbb5ce61b563494c5eba8f5 Mon Sep 17 00:00:00 2001
From: Felicity Liao
Date: Thu, 10 Apr 2025 15:51:56 -0700
Subject: [PATCH] 3/n Support for multiple shards per table (#2877)

Summary:
This will be crucial for any non-TW sharding type going through dynamic sharding. It handles the case where the embedding dimension is not the same across shards. For example:
```
num_embeddings = 4
embedding_dim = 16

Table 0: CW sharded across ranks: [0, 1]
Table 1: CW sharded across rank: [0]

Table 0 shard 0 size: [4, 8]
Table 1 shard 0 size: [4, 16]
```
This requires `Table 0 shard 0` and `Table 1 shard 0` to be concatenated along dimension 1.

Main changes:

## All_to_all collective input/output tensor composition & processing
1. Concatenate the `local_input_tensor` for the `all_to_all` collective along dimension 1 instead of 0. Dim 1 (each shard's slice of the embedding dimension) varies from shard to shard, while dim 0 is consistently the same across all shards/tables, as it is the number of embeddings.
2. This means we need to **transpose** and properly process both the `local_input_tensor` and the `local_output_tensor` passed into the `all_to_all` collective, since the collective splits along dim 0.
3. Made a small optimization to the `local_output_tensor`: instead of growing it repeatedly via `torch.concat`, allocate it once as an empty tensor, since only the final dimensions are needed.

## Correct order of the `all_to_all` tensor output
To handle multiple shards per table, we need to properly store the **order** in which the `all_to_all` collective collects the tensors across ranks. The shards composing the `local_output_tensor` are ordered:
1. First by rank.
2. Then by table order in the EBC -> this can be inferred from the `module_sharding_plan`.
3. Finally by the shard order for this table.

* Since we can assume each rank only contains 1 shard per table, we only need to track 1. and 2. The return type of `shards_all_to_all` and the input type of `update_state_dict_post_resharding` are updated to be a flattened list in the above order.
* Also, the `shard_size` in dim 1 is stored alongside each shard name while composing the `local_output_tensor`, to avoid having to re-query it in `update_state_dict_post_resharding`.

This will ensure correct behavior in the CW sharding implementation/test in the next diff.
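As a minimal single-process illustration of the composition described above (not part of the patch: the shapes assume the `[num_embeddings, embedding_dim]` shard layout, the shard names are hypothetical, and the actual `all_to_all_single` exchange is skipped by pretending this rank receives its own input back):
```
import torch

num_embeddings = 4

# Local shards this rank would send; each is [num_embeddings, shard_width].
table_0_shard_0 = torch.randn(num_embeddings, 8)   # CW shard of a [4, 16] table
table_1_shard_0 = torch.randn(num_embeddings, 16)  # whole [4, 16] table

# Concatenate along dim 1 (dim 0 matches for all shards), then transpose so the
# variable dimension becomes dim 0, the dimension all_to_all_single splits on.
local_input_tensor = torch.cat(
    [table_0_shard_0, table_1_shard_0], dim=1
).T.contiguous()
input_splits = [8, 16]  # dim-1 sizes of the shards, in send order

# Pretend the collective sends everything back to this rank.
local_output_tensor = local_input_tensor
ordered_shard_names_and_lengths = [("table_0_shard_0", 8), ("table_1_shard_0", 16)]

# Slice the received tensor by the stored dim-1 lengths and transpose each
# slice back to [num_embeddings, shard_width].
slice_index = 0
for name, length in ordered_shard_names_and_lengths:
    shard = local_output_tensor[slice_index : slice_index + length].T
    slice_index += length
    print(name, tuple(shard.shape))  # (4, 8) then (4, 16)
```
The `.T` on each slice mirrors what `update_state_dict_post_resharding` does to recover the original `[num_embeddings, shard_width]` layout.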
Reviewed By: iamzainhuda

Differential Revision: D72486367
---
 torchrec/distributed/embeddingbag.py          |  2 +-
 .../distributed/sharding/dynamic_sharding.py  | 87 +++++++++++++------
 2 files changed, 61 insertions(+), 28 deletions(-)

diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 102ed2060..5315b079a 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -1526,7 +1526,7 @@ def update_shards(
 
         current_state = update_state_dict_post_resharding(
             state_dict=current_state,
-            shard_names_by_src_rank=local_shard_names_by_src_rank,
+            ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
             output_tensor=local_output_tensor,
             new_sharding_params=changed_sharding_params,
             curr_rank=dist.get_rank(),
diff --git a/torchrec/distributed/sharding/dynamic_sharding.py b/torchrec/distributed/sharding/dynamic_sharding.py
index 4e50c4f72..05ca485f2 100644
--- a/torchrec/distributed/sharding/dynamic_sharding.py
+++ b/torchrec/distributed/sharding/dynamic_sharding.py
@@ -27,14 +27,14 @@ def shards_all_to_all(
     changed_sharding_params: Dict[str, ParameterSharding],
     env: ShardingEnv,
     extend_shard_name: Callable[[str], str] = lambda x: x,
-) -> Tuple[List[str], torch.Tensor]:
+) -> Tuple[List[Tuple[str, int]], torch.Tensor]:
     """
     Performs an all-to-all communication to redistribute shards across ranks based on new sharding parameters.
     Assumes ranks are ordered in ParameterSharding.ranks.
 
     Args:
-        module (ShardedEmbeddingBagCollection): The module containing sharded tensors to be redistributed.
-            TODO: Update to support more modules
+        module (ShardedModule[Any, Any, Any, Any]): The module containing sharded tensors to be redistributed.
+            TODO: Update to support more modules; currently only supports ShardedEmbeddingBagCollection.
 
         state_dict (Dict[str, ShardedTensor]): The state dictionary containing the current sharded tensors.
 
@@ -47,8 +47,9 @@ def shards_all_to_all(
         extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict.
 
     Returns:
-        Tuple[List[str], torch.Tensor]: A tuple containing:
-            - A list of shard names that were sent from a specific rank to the current rank, ordered by rank, then shard order.
+        Tuple[List[Tuple[str, int]], torch.Tensor]: A tuple containing:
+            - A list of shard names and their corresponding shard sizes in dim 1 that were sent to the current rank.
+              This is a flattened and pruned nested list, which orders the shard names and sizes by rank, then by shard order.
             - The tensor containing all shards received by the current rank after the all-to-all operation.
     """
     if env.output_dtensor:
         raise
@@ -62,10 +63,12 @@ def shards_all_to_all(
     rank = dist.get_rank()
     input_splits_per_rank = [[0] * world_size for _ in range(world_size)]
     output_splits_per_rank = [[0] * world_size for _ in range(world_size)]
-    local_input_tensor = torch.empty([0], device=device)
-    local_output_tensor = torch.empty([0], device=device)
-    shard_names_by_src_rank = []
+    # 0 by default, as the current rank may be receiving 0 shards
+    num_embeddings_received = 0
+    output_tensor_tensor_count = 0
+    shard_names_to_lengths_by_src_rank = [[] for _ in range(world_size)]
+    local_table_to_input_tensor_by_dst_rank = [[] for _ in range(world_size)]
 
     for shard_name, param in changed_sharding_params.items():
         sharded_t = state_dict[extend_shard_name(shard_name)]
         assert param.ranks is not None
@@ -84,27 +87,52 @@ def shards_all_to_all(
             src_rank = src_ranks[i]
 
             shard_size = sharded_t.metadata().shards_metadata[i].shard_sizes
-            shard_size_dim_0 = shard_size[0]
-            input_splits_per_rank[src_rank][dst_rank] += shard_size_dim_0
-            output_splits_per_rank[dst_rank][src_rank] += shard_size_dim_0
+            shard_size_dim_1 = shard_size[1]
+            input_splits_per_rank[src_rank][dst_rank] += shard_size_dim_1
+            output_splits_per_rank[dst_rank][src_rank] += shard_size_dim_1
             if src_rank == rank:
                 local_shards = sharded_t.local_shards()
                 assert len(local_shards) == 1
-                local_input_tensor = torch.cat(
-                    (
-                        local_input_tensor,
-                        sharded_t.local_shards()[0].tensor,
-                    )
+                local_table_to_input_tensor_by_dst_rank[dst_rank].append(
+                    sharded_t.local_shards()[0].tensor
                 )
             if dst_rank == rank:
-                shard_names_by_src_rank.append(shard_name)
-                local_output_tensor = torch.cat(
-                    (local_output_tensor, torch.empty(shard_size, device=device))
+                shard_names_to_lengths_by_src_rank[src_rank].append(
+                    (shard_name, shard_size_dim_1)
                 )
+                # NOTE: num_embeddings_received only needs to be updated to the
+                # num_embeddings of the shards if this rank is actually
+                # receiving any tensors
+                if num_embeddings_received == 0:
+                    num_embeddings_received = shard_size[0]
+                else:
+                    # TODO: for 2D and row-wise, shard_sizes in dim 0 may be variable
+                    # For now, assume that shard_sizes in dim 0 are all the same
+                    assert num_embeddings_received == shard_size[0]
+                output_tensor_tensor_count += shard_size[1]
 
     local_input_splits = input_splits_per_rank[rank]
     local_output_splits = output_splits_per_rank[rank]
 
+    local_input_tensor = torch.empty([0], device=device)
+    for sub_l in local_table_to_input_tensor_by_dst_rank:
+        for shard_info in sub_l:
+            local_input_tensor = torch.cat(
+                (
+                    local_input_tensor,
+                    shard_info,
+                ),
+                dim=1,
+            )
+
+    # Shards were concatenated along dim 1 because dim 1 (each shard's slice of the
+    # embedding dimension) can differ between shards, while dim 0 (num_embeddings per
+    # table) is the same; transpose so all_to_all_single can split along dim 0.
+    local_output_tensor = torch.empty(
+        [output_tensor_tensor_count, num_embeddings_received], device=device
+    )
+    local_input_tensor = local_input_tensor.T.contiguous()
+
     assert sum(local_output_splits) == len(local_output_tensor)
     assert sum(local_input_splits) == len(local_input_tensor)
 
     dist.all_to_all_single(
@@ -115,12 +143,18 @@ def shards_all_to_all(
         group=dist.group.WORLD,
     )
 
-    return shard_names_by_src_rank, local_output_tensor
+    flattened_output_names_lengths = [
+        shard_info
+        for sub_l in shard_names_to_lengths_by_src_rank
+        for shard_info in sub_l
+    ]
+
+    return flattened_output_names_lengths, local_output_tensor
 
 
 def update_state_dict_post_resharding(
     state_dict: Dict[str, ShardedTensor],
-    shard_names_by_src_rank: List[str],
+    ordered_shard_names_and_lengths: List[Tuple[str, int]],
     output_tensor: torch.Tensor,
     new_sharding_params: Dict[str, ParameterSharding],
     curr_rank: int,
@@ -133,8 +167,9 @@ def update_state_dict_post_resharding(
 
     Args:
         state_dict (Dict[str, Any]): The state dict to be updated with new shard placements and local shards.
 
-        shard_names_by_src_rank (List[str]): A list of shard names that were sent from a specific rank to the
-            current rank, ordered by rank, then shard order.
+        ordered_shard_names_and_lengths (List[Tuple[str, int]]): A list of shard names and their corresponding
+            shard sizes in dim 1 that were sent to the current rank. This is a flattened and pruned nested list,
+            which orders the shard names and sizes by rank, then by shard order.
 
         output_tensor (torch.Tensor): The tensor containing the output data from the AllToAll operation.
 
@@ -149,16 +184,14 @@ def update_state_dict_post_resharding(
         Dict[str, ShardedTensor]: The updated state dictionary with new shard placements and local shards.
     """
     slice_index = 0
-    shard_names_by_src_rank
     shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {}
-    for shard_name in shard_names_by_src_rank:
-        shard_size = state_dict[extend_shard_name(shard_name)].size(0)
+    for shard_name, shard_size in ordered_shard_names_and_lengths:
         end_slice_index = slice_index + shard_size
         shard_name_to_local_output_tensor[shard_name] = output_tensor[
             slice_index:end_slice_index
-        ]
+        ].T
        slice_index = end_slice_index
 
     for shard_name, param in new_sharding_params.items():
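For reference, a tiny sketch (not part of the patch; the ranks, table names, and sizes below are made up) of the rank-major ordering contract that `update_state_dict_post_resharding` now relies on:
```
# Hypothetical received shards for a world size of 2: outer index = source rank,
# inner entries = (shard_name, shard_size in dim 1) in EBC table order.
shard_names_to_lengths_by_src_rank = [
    [("table_0", 8), ("table_1", 16)],  # received from rank 0
    [("table_2", 4)],                   # received from rank 1
]

# Same flattening as in shards_all_to_all: by rank, then by table order.
flattened_output_names_lengths = [
    shard_info
    for sub_l in shard_names_to_lengths_by_src_rank
    for shard_info in sub_l
]
assert flattened_output_names_lengths == [
    ("table_0", 8),
    ("table_1", 16),
    ("table_2", 4),
]

# These dim-1 lengths are exactly the row offsets used to slice the received
# output tensor (rows 0:8, 8:24, 24:28) before transposing each slice back.
```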