diff --git a/torchrec/distributed/benchmark/benchmark_set_sharding_context_post_a2a.py b/torchrec/distributed/benchmark/benchmark_set_sharding_context_post_a2a.py
new file mode 100644
index 000000000..dc393bcc0
--- /dev/null
+++ b/torchrec/distributed/benchmark/benchmark_set_sharding_context_post_a2a.py
@@ -0,0 +1,105 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+#!/usr/bin/env python3
+
+from typing import Any, List
+
+import click
+import torch
+from torchrec.distributed.benchmark.benchmark_utils import benchmark_func
+from torchrec.distributed.embedding import EmbeddingCollectionContext
+from torchrec.distributed.embedding_sharding import _set_sharding_context_post_a2a
+from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+
+
+def _set_sharding_context_post_a2a_previous(
+    kjts: List[KeyedJaggedTensor],
+    ctx: EmbeddingCollectionContext,
+) -> None:
+    for kjt, sharding_context in zip(kjts, getattr(ctx, "sharding_contexts", [])):
+        if (
+            hasattr(sharding_context, "batch_size_per_rank_per_feature")
+            and kjt.variable_stride_per_key()
+            and kjt.stride_per_key_per_rank()
+        ):
+            sharding_context.batch_size_per_rank_per_feature = [
+                [
+                    kjt.stride_per_key_per_rank()[i][j]
+                    for i in range(len(kjt.stride_per_key_per_rank()))
+                ]
+                for j in range(len(kjt.stride_per_key_per_rank()[0]))
+            ]
+
+
+# buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_set_sharding_context_post_a2a -- --num_list=0 --num_keys=0 | grep set_sharding_context_post_a2a
+
+
+@click.command()
+@click.option("--num_list", default=100)
+@click.option("--num_keys", default=100)
+def main(
+    num_list: int,
+    num_keys: int,
+) -> None:
+    if num_list == 0 and num_keys == 0:
+        for num_list in [100, 1000, 10000]:
+            for num_keys in [10, 100]:
+                op_bench(num_list, num_keys, _set_sharding_context_post_a2a_previous)
+                op_bench(num_list, num_keys, _set_sharding_context_post_a2a)
+    else:
+        op_bench(num_list, num_keys, _set_sharding_context_post_a2a_previous)
+        op_bench(num_list, num_keys, _set_sharding_context_post_a2a)
+
+
+def op_bench(
+    num_list: int,
+    num_keys: int,
+    func_to_benchmark: Any,  # pyre-ignore[2]
+) -> None:
+    kjts = [
+        KeyedJaggedTensor(
+            keys=["dummy_id"] * num_keys,
+            values=torch.IntTensor([1] * num_keys),
+            lengths=torch.IntTensor([1] * num_keys),
+            stride_per_key_per_rank=[[1]] * num_keys,
+        )
+        for _ in range(num_list)
+    ]
+    for kjt in kjts:
+        kjt._variable_stride_per_key = True
+    ctx = EmbeddingCollectionContext(
+        sharding_contexts=[
+            SequenceShardingContext(batch_size_per_rank_per_feature=[])
+            for _ in range(num_list)
+        ]
+    )
+
+    bench_inputs = []
+
+    result = benchmark_func(
+        name=f"{func_to_benchmark.__name__}-{num_list}-{num_keys}",
+        bench_inputs=bench_inputs,
+        prof_inputs=bench_inputs,
+        num_benchmarks=10,
+        num_profiles=2,
+        profile_dir=".",
+        world_size=1,
+        func_to_benchmark=func_to_benchmark,
+        benchmark_func_kwargs={"kjts": kjts, "ctx": ctx},
+        rank=0,
+        pre_gpu_load=0,
+        device_type="cpu",
+    )
+    print(result)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py
index 22af274d6..a138d5c13 100644
--- a/torchrec/distributed/benchmark/benchmark_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_utils.py
@@ -135,6 +135,8 @@ class BenchmarkResult:
 
     def __str__(self) -> str:
         runtime = f"Runtime (P90): {self.runtime_percentile(90):.2f} ms"
+        if len(self.mem_stats) == 0:
+            return f"{self.short_name: <{35}} | {runtime}"
         mem_alloc = (
             f"Peak Memory alloc (P90): {self.max_mem_alloc_percentile(90)/1000:.2f} GB"
         )
@@ -749,11 +751,18 @@ def benchmark_func(
             func_to_benchmark(bench_inputs, **benchmark_func_kwargs)
             end[i].record()
     elif device_type == "cpu":
-        times = timeit.repeat(
-            lambda: func_to_benchmark(bench_inputs, **benchmark_func_kwargs),
-            number=1,
-            repeat=num_benchmarks,
-        )
+        if bench_inputs is None or len(bench_inputs) == 0:
+            times = timeit.repeat(
+                lambda: func_to_benchmark(**benchmark_func_kwargs),
+                number=1,
+                repeat=num_benchmarks,
+            )
+        else:
+            times = timeit.repeat(
+                lambda: func_to_benchmark(bench_inputs, **benchmark_func_kwargs),
+                number=1,
+                repeat=num_benchmarks,
+            )
 
     if device_type == "cuda":
         if rank == -1:
diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py
index 0f37e71a1..98fa2d15f 100644
--- a/torchrec/distributed/embedding_sharding.py
+++ b/torchrec/distributed/embedding_sharding.py
@@ -662,12 +662,10 @@ def _set_sharding_context_post_a2a(
             and kjt.variable_stride_per_key()
             and kjt.stride_per_key_per_rank()
         ):
+            strides = kjt.stride_per_key_per_rank()
             sharding_context.batch_size_per_rank_per_feature = [
-                [
-                    kjt.stride_per_key_per_rank()[i][j]
-                    for i in range(len(kjt.stride_per_key_per_rank()))
-                ]
-                for j in range(len(kjt.stride_per_key_per_rank()[0]))
+                [strides[i][j] for i in range(len(strides))]
+                for j in range(len(strides[0]))
             ]
 
 
diff --git a/torchrec/distributed/tests/test_embedding_sharding.py b/torchrec/distributed/tests/test_embedding_sharding.py
index be2783e61..466cf1a16 100644
--- a/torchrec/distributed/tests/test_embedding_sharding.py
+++ b/torchrec/distributed/tests/test_embedding_sharding.py
@@ -14,8 +14,10 @@
 from unittest.mock import MagicMock
 
 import hypothesis.strategies as st
+import torch
 from hypothesis import given, settings
 
+from torchrec.distributed.embedding import EmbeddingCollectionContext
 from torchrec.distributed.embedding_lookup import EmbeddingComputeKernel
@@ -24,6 +26,7 @@
     _get_grouping_fused_params,
     _get_weighted_avg_cache_load_factor,
     _prefetch_and_cached,
+    _set_sharding_context_post_a2a,
     group_tables,
 )
@@ -31,7 +34,9 @@
     GroupedEmbeddingConfig,
     ShardedEmbeddingTable,
 )
+from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
 from torchrec.modules.embedding_configs import DataType, PoolingType
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 
 class TestGetWeightedAverageCacheLoadFactor(unittest.TestCase):
@@ -489,3 +494,41 @@ def test_use_one_tbe_per_table(
             _get_table_names_by_groups(tables),
             [["table_0", "table_2", "table_4"], ["table_1", "table_1"], ["table_3"]],
         )
+
+    def test_set_sharding_context_post_a2a(self) -> None:
+        kjts = [
+            KeyedJaggedTensor(
+                keys=["dummy_id", "video_id", "owner_id", "xray_concepts", "dummy_id2"],
+                values=torch.IntTensor([1] * 10),
+                lengths=torch.IntTensor([1] * 10),
+                stride_per_key_per_rank=[
+                    [1, 2],
+                    [1, 2],
+                    [2, 3],
+                    [5, 7],
+                    [3, 4],
+                ],
+            ),
+            KeyedJaggedTensor(
+                keys=["dummy_id", "video_id", "owner_id", "xray_concepts", "dummy_id2"],
+                values=torch.IntTensor([1] * 10),
+                lengths=torch.IntTensor([1] * 10),
+                stride_per_key_per_rank=[[3, 1], [5, 2], [7, 3], [1, 2], [6, 8]],
+            ),
+        ]
+        for kjt in kjts:
+            kjt._variable_stride_per_key = True
+
+        ctx = EmbeddingCollectionContext(
+            sharding_contexts=[
+                SequenceShardingContext(batch_size_per_rank_per_feature=[]),
+                SequenceShardingContext(batch_size_per_rank_per_feature=[]),
+            ]
+        )
+        results = [
+            [[1, 1, 2, 5, 3], [2, 2, 3, 7, 4]],
+            [[3, 5, 7, 1, 6], [1, 2, 3, 2, 8]],
+        ]
+        _set_sharding_context_post_a2a(kjts, ctx)
+        for context, result in zip(ctx.sharding_contexts, results):
+            self.assertEqual(context.batch_size_per_rank_per_feature, result)
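
Reviewer note (not part of the diff): the core change in embedding_sharding.py only hoists the repeated kjt.stride_per_key_per_rank() call into a local `strides` list; the nested comprehension itself is unchanged and is simply a transpose of the per-key-per-rank strides into a per-rank-per-key layout. A minimal standalone sketch, using plain Python lists and the data from the new unit test (no TorchRec imports), shows what the comprehension computes:

# Stand-in for kjt.stride_per_key_per_rank(): one row per key, one column per rank.
strides = [[1, 2], [1, 2], [2, 3], [5, 7], [3, 4]]

# Optimized form from the diff: the accessor result is evaluated once and reused.
batch_size_per_rank_per_feature = [
    [strides[i][j] for i in range(len(strides))]
    for j in range(len(strides[0]))
]

# The comprehension is a transpose; zip(*strides) is an equivalent idiom.
assert batch_size_per_rank_per_feature == [list(row) for row in zip(*strides)]
assert batch_size_per_rank_per_feature == [[1, 1, 2, 5, 3], [2, 2, 3, 7, 4]]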
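
Reviewer note (not part of the diff): the new benchmark passes an empty bench_inputs list and routes all arguments through benchmark_func_kwargs, which is why benchmark_utils.py now branches on empty bench_inputs in the CPU path. A simplified sketch of that branch, with a hypothetical run_cpu_benchmark wrapper standing in for the affected section of benchmark_func:

import timeit
from typing import Any, Callable, List, Optional

def run_cpu_benchmark(
    func_to_benchmark: Callable[..., Any],
    bench_inputs: Optional[List[Any]],
    num_benchmarks: int,
    **benchmark_func_kwargs: Any,
) -> List[float]:
    # With no bench_inputs (as in the op-level benchmark above), call the target
    # with kwargs only; otherwise keep the original positional bench_inputs call.
    if bench_inputs is None or len(bench_inputs) == 0:
        return timeit.repeat(
            lambda: func_to_benchmark(**benchmark_func_kwargs),
            number=1,
            repeat=num_benchmarks,
        )
    return timeit.repeat(
        lambda: func_to_benchmark(bench_inputs, **benchmark_func_kwargs),
        number=1,
        repeat=num_benchmarks,
    )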