diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 102ed2060..5315b079a 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -1526,7 +1526,7 @@ def update_shards(
 
     current_state = update_state_dict_post_resharding(
         state_dict=current_state,
-        shard_names_by_src_rank=local_shard_names_by_src_rank,
+        ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
         output_tensor=local_output_tensor,
         new_sharding_params=changed_sharding_params,
         curr_rank=dist.get_rank(),
diff --git a/torchrec/distributed/sharding/dynamic_sharding.py b/torchrec/distributed/sharding/dynamic_sharding.py
index 4e50c4f72..05ca485f2 100644
--- a/torchrec/distributed/sharding/dynamic_sharding.py
+++ b/torchrec/distributed/sharding/dynamic_sharding.py
@@ -27,14 +27,14 @@ def shards_all_to_all(
     changed_sharding_params: Dict[str, ParameterSharding],
     env: ShardingEnv,
     extend_shard_name: Callable[[str], str] = lambda x: x,
-) -> Tuple[List[str], torch.Tensor]:
+) -> Tuple[List[Tuple[str, int]], torch.Tensor]:
     """
     Performs an all-to-all communication to redistribute shards across ranks based on new sharding parameters.
     Assumes ranks are ordered in ParameterSharding.ranks.
 
     Args:
-        module (ShardedEmbeddingBagCollection): The module containing sharded tensors to be redistributed.
-            TODO: Update to support more modules
+        module (ShardedModule[Any, Any, Any, Any]): The module containing sharded tensors to be redistributed.
+            TODO: Update to support more modules, currently only supports ShardedEmbeddingBagCollection.
 
         state_dict (Dict[str, ShardedTensor]): The state dictionary containing the current sharded tensors.
 
@@ -47,8 +47,9 @@ def shards_all_to_all(
         extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict.
 
     Returns:
-        Tuple[List[str], torch.Tensor]: A tuple containing:
-            - A list of shard names that were sent from a specific rank to the current rank, ordered by rank, then shard order.
+        Tuple[List[Tuple[str, int]], torch.Tensor]: A tuple containing:
+            - A list of shard names and their corresponding shard sizes in dim 1 that were sent to the current rank.
+              This is a flattened and pruned nested list, which orders the shard names and sizes by rank, then shard order.
             - The tensor containing all shards received by the current rank after the all-to-all operation.
""" if env.output_dtensor: @@ -62,10 +63,12 @@ def shards_all_to_all( rank = dist.get_rank() input_splits_per_rank = [[0] * world_size for _ in range(world_size)] output_splits_per_rank = [[0] * world_size for _ in range(world_size)] - local_input_tensor = torch.empty([0], device=device) - local_output_tensor = torch.empty([0], device=device) - shard_names_by_src_rank = [] + # 0 by default, as current rank may be recieving 0 shards + num_embeddings_received = 0 + output_tensor_tensor_count = 0 + shard_names_to_lengths_by_src_rank = [[] for _ in range(world_size)] + local_table_to_input_tensor_by_dst_rank = [[] for _ in range(world_size)] for shard_name, param in changed_sharding_params.items(): sharded_t = state_dict[extend_shard_name(shard_name)] assert param.ranks is not None @@ -84,27 +87,52 @@ def shards_all_to_all( src_rank = src_ranks[i] shard_size = sharded_t.metadata().shards_metadata[i].shard_sizes - shard_size_dim_0 = shard_size[0] - input_splits_per_rank[src_rank][dst_rank] += shard_size_dim_0 - output_splits_per_rank[dst_rank][src_rank] += shard_size_dim_0 + shard_size_dim_1 = shard_size[1] + input_splits_per_rank[src_rank][dst_rank] += shard_size_dim_1 + output_splits_per_rank[dst_rank][src_rank] += shard_size_dim_1 if src_rank == rank: local_shards = sharded_t.local_shards() assert len(local_shards) == 1 - local_input_tensor = torch.cat( - ( - local_input_tensor, - sharded_t.local_shards()[0].tensor, - ) + local_table_to_input_tensor_by_dst_rank[dst_rank].append( + sharded_t.local_shards()[0].tensor ) if dst_rank == rank: - shard_names_by_src_rank.append(shard_name) - local_output_tensor = torch.cat( - (local_output_tensor, torch.empty(shard_size, device=device)) + shard_names_to_lengths_by_src_rank[src_rank].append( + (shard_name, shard_size_dim_1) ) + # NOTE: Only need to update num_embeddings_received to be the + # num_embeddings of shards if this rank is actually recieving + # any tensors + if num_embeddings_received == 0: + num_embeddings_received = shard_size[0] + else: + # TODO: for 2D and row-wise, shard_sizes in dim 0 may be variable + # For now, assume that shard_sizes in dim 0 are all the same + assert num_embeddings_received == shard_size[0] + output_tensor_tensor_count += shard_size[1] local_input_splits = input_splits_per_rank[rank] local_output_splits = output_splits_per_rank[rank] + local_input_tensor = torch.empty([0], device=device) + for sub_l in local_table_to_input_tensor_by_dst_rank: + for shard_info in sub_l: + local_input_tensor = torch.cat( + ( + local_input_tensor, + shard_info, + ), + dim=1, + ) + + # Transposing the Tensors - because we are concatenating them along dimension 1 + # This is because dim 0 size may be different for different shards + # whereas dim 1 size is the same for all shards as dim 1 size = num_embeddings per table + local_output_tensor = torch.empty( + [output_tensor_tensor_count, num_embeddings_received], device=device + ) + local_input_tensor = local_input_tensor.T.contiguous() + assert sum(local_output_splits) == len(local_output_tensor) assert sum(local_input_splits) == len(local_input_tensor) dist.all_to_all_single( @@ -115,12 +143,18 @@ def shards_all_to_all( group=dist.group.WORLD, ) - return shard_names_by_src_rank, local_output_tensor + flattened_output_names_lengths = [ + shard_info + for sub_l in shard_names_to_lengths_by_src_rank + for shard_info in sub_l + ] + + return flattened_output_names_lengths, local_output_tensor def update_state_dict_post_resharding( state_dict: Dict[str, ShardedTensor], - 
+    ordered_shard_names_and_lengths: List[Tuple[str, int]],
     output_tensor: torch.Tensor,
     new_sharding_params: Dict[str, ParameterSharding],
     curr_rank: int,
@@ -133,8 +167,9 @@ def update_state_dict_post_resharding(
     Args:
         state_dict (Dict[str, Any]): The state dict to be updated with new shard placements and local shards.
 
-        shard_names_by_src_rank (List[str]): A list of shard names that were sent from a specific rank to the
-            current rank, ordered by rank, then shard order.
+        ordered_shard_names_and_lengths (List[Tuple[str, int]]): A list of shard names and their corresponding
+            shard sizes in dim 1 that were sent to the current rank. This is a flattened and pruned nested list,
+            which orders the shard names and sizes by rank, then shard order.
 
         output_tensor (torch.Tensor): The tensor containing the output data from the AllToAll operation.
 
@@ -149,16 +184,14 @@
     Returns:
         Dict[str, ShardedTensor]: The updated state dictionary with new shard placements and local shards.
     """
 
     slice_index = 0
-    shard_names_by_src_rank
     shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {}
-    for shard_name in shard_names_by_src_rank:
-        shard_size = state_dict[extend_shard_name(shard_name)].size(0)
+    for shard_name, shard_size in ordered_shard_names_and_lengths:
         end_slice_index = slice_index + shard_size
         shard_name_to_local_output_tensor[shard_name] = output_tensor[
             slice_index:end_slice_index
-        ]
+        ].T
         slice_index = end_slice_index
 
     for shard_name, param in new_sharding_params.items():
diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py
index 021ecc1a1..260e0cc78 100644
--- a/torchrec/distributed/tests/test_dynamic_sharding.py
+++ b/torchrec/distributed/tests/test_dynamic_sharding.py
@@ -26,6 +26,7 @@
 from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
 
 from torchrec.distributed.sharding_plan import (
+    column_wise,
     construct_module_sharding_plan,
     get_module_to_default_sharders,
     table_wise,
@@ -79,6 +80,23 @@
     return embedding_bag_config
 
 
+def generate_rank_placements(
+    world_size: int,
+    num_tables: int,
+    ranks_per_tables: List[int],
+) -> List[List[int]]:
+    # Cannot include old/new rank generation with hypothesis library due to dependency on world_size
+    placements = []
+    max_rank = world_size - 1
+    if ranks_per_tables == [0]:
+        ranks_per_tables = [random.randint(1, max_rank) for _ in range(num_tables)]
+    for i in range(num_tables):
+        ranks_per_table = ranks_per_tables[i]
+        placement = sorted(random.sample(range(world_size), ranks_per_table))
+        placements.append(placement)
+    return placements
+
+
 def create_test_initial_state_dict(
     sharded_module_type: nn.Module,
     num_tables: int,
@@ -379,19 +397,73 @@
     ) -> None:
         # Tests EBC dynamic sharding implementation for TW
 
+        # Table wise can only have 1 rank allocated per table:
+        ranks_per_tables = [1 for _ in range(num_tables)]
         # Cannot include old/new rank generation with hypothesis library due to depedency on world_size
-        old_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]
-        new_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]
+        old_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables)
+        new_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables)
+
+        while new_ranks == old_ranks:
+            new_ranks = generate_rank_placements(
+                world_size, num_tables, ranks_per_tables
+            )
+        per_param_sharding = {}
+        new_per_param_sharding = {}
+
+        # Construct parameter shardings
+        for i in range(num_tables):
+            per_param_sharding[table_name(i)] = table_wise(rank=old_ranks[i][0])
+            new_per_param_sharding[table_name(i)] = table_wise(rank=new_ranks[i][0])
+
+        self._run_ebc_resharding_test(
+            per_param_sharding,
+            new_per_param_sharding,
+            num_tables,
+            world_size,
+            data_type,
+        )
 
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    @given(  # pyre-ignore
+        num_tables=st.sampled_from([2, 3, 4]),
+        data_type=st.sampled_from([DataType.FP32, DataType.FP16]),
+        world_size=st.sampled_from([3, 4]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
+    def test_dynamic_sharding_ebc_cw(
+        self,
+        num_tables: int,
+        data_type: DataType,
+        world_size: int,
+    ) -> None:
+        # Tests EBC dynamic sharding implementation for CW
+
+        # Force the ranks per table to be consistent
+        ranks_per_tables = [
+            random.randint(1, world_size - 1) for _ in range(num_tables)
+        ]
+
+        old_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables)
+        new_ranks = generate_rank_placements(world_size, num_tables, ranks_per_tables)
+
+        # Cannot include old/new rank generation with hypothesis library due to dependency on world_size
         while new_ranks == old_ranks:
-            new_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]
+            old_ranks = generate_rank_placements(
+                world_size, num_tables, ranks_per_tables
+            )
+            new_ranks = generate_rank_placements(
+                world_size, num_tables, ranks_per_tables
+            )
         per_param_sharding = {}
         new_per_param_sharding = {}
 
         # Construct parameter shardings
         for i in range(num_tables):
-            per_param_sharding[table_name(i)] = table_wise(rank=old_ranks[i])
-            new_per_param_sharding[table_name(i)] = table_wise(rank=new_ranks[i])
+            per_param_sharding[table_name(i)] = column_wise(ranks=old_ranks[i])
+            new_per_param_sharding[table_name(i)] = column_wise(ranks=new_ranks[i])
 
         self._run_ebc_resharding_test(
             per_param_sharding,
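Reviewer note (sketch, not part of the patch): the core layout change in shards_all_to_all is that shards kept on a rank share dim 0 (num_embeddings, assumed equal across tables for now per the TODO) but have different dim 1 widths, so they are concatenated along dim 1 and then transposed, which lets all_to_all_single split along dim 0 using the dim-1 widths. A minimal single-process illustration with made-up shard widths follows; num_embeddings, shard_widths, and the tensor names are hypothetical and only mirror the patch's variables.

# Illustration only: send-side layout assumed by shards_all_to_all.
import torch

num_embeddings = 8  # dim 0, assumed equal for all shards sent by this rank
shard_widths = [4, 2, 6]  # hypothetical dim-1 sizes of the shards being sent

shards = [torch.randn(num_embeddings, w) for w in shard_widths]

# Concatenate along dim 1 (dim 0 matches), then transpose so each shard
# becomes a contiguous block of rows in the buffer handed to all_to_all_single.
local_input_tensor = torch.cat(shards, dim=1).T.contiguous()
assert local_input_tensor.shape == (sum(shard_widths), num_embeddings)

# The input splits passed to all_to_all_single are therefore the dim-1 widths.
input_splits = shard_widths
assert sum(input_splits) == len(local_input_tensor)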
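On the receive side, update_state_dict_post_resharding walks ordered_shard_names_and_lengths, slices the received buffer by each shard's dim-1 length, and transposes each slice back to [num_embeddings, shard_dim_1]. The standalone sketch below mirrors that loop with hypothetical shard names and sizes (the real call gets them from shards_all_to_all).

# Illustration only: recovering per-shard tensors from the all-to-all output.
import torch

num_embeddings = 8
# Hypothetical (shard_name, dim-1 length) pairs, ordered by src rank, then shard order.
ordered_shard_names_and_lengths = [("table_0", 4), ("table_1", 2)]

# The received buffer is [sum of dim-1 lengths, num_embeddings].
output_tensor = torch.randn(
    sum(length for _, length in ordered_shard_names_and_lengths), num_embeddings
)

slice_index = 0
shard_name_to_local_output_tensor = {}
for shard_name, shard_size in ordered_shard_names_and_lengths:
    end_slice_index = slice_index + shard_size
    # Transpose back so the local shard is [num_embeddings, shard_dim_1].
    shard_name_to_local_output_tensor[shard_name] = output_tensor[
        slice_index:end_slice_index
    ].T
    slice_index = end_slice_index

assert shard_name_to_local_output_tensor["table_0"].shape == (num_embeddings, 4)
assert shard_name_to_local_output_tensor["table_1"].shape == (num_embeddings, 2)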
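The new test helper generate_rank_placements samples, for each table, a sorted list of distinct ranks; the CW test then passes each list to column_wise(ranks=...), while the TW test uses single-rank lists. A standalone sketch of the same sampling loop, with hypothetical world_size and ranks_per_tables values:

# Illustration only: what generate_rank_placements produces for a CW run.
import random

world_size = 4
num_tables = 2
ranks_per_tables = [2, 3]  # hypothetical: distinct ranks requested per table

placements = [
    sorted(random.sample(range(world_size), n)) for n in ranks_per_tables
]
print(placements)  # e.g. [[1, 3], [0, 2, 3]]: one sorted rank list per table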