105 changes: 105 additions & 0 deletions torchrec/distributed/benchmark/benchmark_set_sharding_context_post_a2a.py
@@ -0,0 +1,105 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


from typing import Any, List

import click
import torch
from torchrec.distributed.benchmark.benchmark_utils import benchmark_func
from torchrec.distributed.embedding import EmbeddingCollectionContext
from torchrec.distributed.embedding_sharding import _set_sharding_context_post_a2a
from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


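# Baseline implementation of _set_sharding_context_post_a2a as it existed before this
# change: kjt.stride_per_key_per_rank() is re-evaluated for every element of the nested
# comprehension. Kept here so the benchmark can compare it against the optimized version.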
def _set_sharding_context_post_a2a_previous(
kjts: List[KeyedJaggedTensor],
ctx: EmbeddingCollectionContext,
) -> None:
for kjt, sharding_context in zip(kjts, getattr(ctx, "sharding_contexts", [])):
if (
hasattr(sharding_context, "batch_size_per_rank_per_feature")
and kjt.variable_stride_per_key()
and kjt.stride_per_key_per_rank()
):
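            # Transpose the [key][rank] stride table into [rank][key] batch sizes,
            # calling kjt.stride_per_key_per_rank() again for every element.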
sharding_context.batch_size_per_rank_per_feature = [
[
kjt.stride_per_key_per_rank()[i][j]
for i in range(len(kjt.stride_per_key_per_rank()))
]
for j in range(len(kjt.stride_per_key_per_rank()[0]))
]


# buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_set_sharding_context_post_a2a -- --num_list=0 --num_keys=0 | grep set_sharding_context_post_a2a


@click.command()
@click.option("--num_list", default=100)
@click.option("--num_keys", default=100)
def main(
num_list: int,
num_keys: int,
) -> None:
if num_list == 0 and num_keys == 0:
for num_list in [100, 1000, 10000]:
for num_keys in [10, 100]:
op_bench(num_list, num_keys, _set_sharding_context_post_a2a_previous)
op_bench(num_list, num_keys, _set_sharding_context_post_a2a)
else:
op_bench(num_list, num_keys, _set_sharding_context_post_a2a_previous)
op_bench(num_list, num_keys, _set_sharding_context_post_a2a)


def op_bench(
num_list: int,
num_keys: int,
func_to_benchmark: Any, # pyre-ignore[2]
) -> None:
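    # Build num_list single-feature KJTs with num_keys keys each and time
    # func_to_benchmark filling ctx.sharding_contexts from their strides.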
kjts = [
KeyedJaggedTensor(
keys=["dummy_id"] * num_keys,
values=torch.IntTensor([1] * num_keys),
lengths=torch.IntTensor([1] * num_keys),
stride_per_key_per_rank=[[1]] * num_keys,
)
for _ in range(num_list)
]
for kjt in kjts:
kjt._variable_stride_per_key = True
ctx = EmbeddingCollectionContext(
sharding_contexts=[
SequenceShardingContext(batch_size_per_rank_per_feature=[])
for _ in range(num_list)
]
)

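    # bench_inputs is left empty so benchmark_func's CPU path invokes
    # func_to_benchmark with keyword arguments only (kjts and ctx).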
bench_inputs = []

result = benchmark_func(
name=f"{func_to_benchmark.__name__}-{num_list}-{num_keys}",
bench_inputs=bench_inputs,
prof_inputs=bench_inputs,
num_benchmarks=10,
num_profiles=2,
profile_dir=".",
world_size=1,
func_to_benchmark=func_to_benchmark,
benchmark_func_kwargs={"kjts": kjts, "ctx": ctx},
rank=0,
pre_gpu_load=0,
device_type="cpu",
)
print(result)


if __name__ == "__main__":
main()
19 changes: 14 additions & 5 deletions torchrec/distributed/benchmark/benchmark_utils.py
@@ -135,6 +135,8 @@ class BenchmarkResult:

def __str__(self) -> str:
runtime = f"Runtime (P90): {self.runtime_percentile(90):.2f} ms"
if len(self.mem_stats) == 0:
return f"{self.short_name: <{35}} | {runtime}"
mem_alloc = (
f"Peak Memory alloc (P90): {self.max_mem_alloc_percentile(90)/1000:.2f} GB"
)
@@ -749,11 +751,18 @@ def benchmark_func(
func_to_benchmark(bench_inputs, **benchmark_func_kwargs)
end[i].record()
elif device_type == "cpu":
times = timeit.repeat(
lambda: func_to_benchmark(bench_inputs, **benchmark_func_kwargs),
number=1,
repeat=num_benchmarks,
)
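            # If no positional bench_inputs were provided, call the target with
            # keyword arguments only.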
if bench_inputs is None or len(bench_inputs) == 0:
times = timeit.repeat(
lambda: func_to_benchmark(**benchmark_func_kwargs),
number=1,
repeat=num_benchmarks,
)
else:
times = timeit.repeat(
lambda: func_to_benchmark(bench_inputs, **benchmark_func_kwargs),
number=1,
repeat=num_benchmarks,
)

if device_type == "cuda":
if rank == -1:
8 changes: 3 additions & 5 deletions torchrec/distributed/embedding_sharding.py
@@ -662,12 +662,10 @@ def _set_sharding_context_post_a2a(
and kjt.variable_stride_per_key()
and kjt.stride_per_key_per_rank()
):
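            # Read stride_per_key_per_rank() once and transpose it from
            # [key][rank] to [rank][key].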
strides = kjt.stride_per_key_per_rank()
sharding_context.batch_size_per_rank_per_feature = [
[
kjt.stride_per_key_per_rank()[i][j]
for i in range(len(kjt.stride_per_key_per_rank()))
]
for j in range(len(kjt.stride_per_key_per_rank()[0]))
[strides[i][j] for i in range(len(strides))]
for j in range(len(strides[0]))
]


43 changes: 43 additions & 0 deletions torchrec/distributed/tests/test_embedding_sharding.py
@@ -14,8 +14,10 @@
from unittest.mock import MagicMock

import hypothesis.strategies as st
import torch

from hypothesis import given, settings
from torchrec.distributed.embedding import EmbeddingCollectionContext

from torchrec.distributed.embedding_lookup import EmbeddingComputeKernel

@@ -24,14 +26,17 @@
_get_grouping_fused_params,
_get_weighted_avg_cache_load_factor,
_prefetch_and_cached,
_set_sharding_context_post_a2a,
group_tables,
)

from torchrec.distributed.embedding_types import (
GroupedEmbeddingConfig,
ShardedEmbeddingTable,
)
from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
from torchrec.modules.embedding_configs import DataType, PoolingType
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class TestGetWeightedAverageCacheLoadFactor(unittest.TestCase):
@@ -489,3 +494,41 @@ def test_use_one_tbe_per_table(
_get_table_names_by_groups(tables),
[["table_0", "table_2", "table_4"], ["table_1", "table_1"], ["table_3"]],
)

def test_set_sharding_context_post_a2a(self) -> None:
kjts = [
KeyedJaggedTensor(
keys=["dummy_id", "video_id", "owner_id", "xray_concepts", "dummy_id2"],
values=torch.IntTensor([1] * 10),
lengths=torch.IntTensor([1] * 10),
stride_per_key_per_rank=[
[1, 2],
[1, 2],
[2, 3],
[5, 7],
[3, 4],
],
),
KeyedJaggedTensor(
keys=["dummy_id", "video_id", "owner_id", "xray_concepts", "dummy_id2"],
values=torch.IntTensor([1] * 10),
lengths=torch.IntTensor([1] * 10),
stride_per_key_per_rank=[[3, 1], [5, 2], [7, 3], [1, 2], [6, 8]],
),
]
for kjt in kjts:
kjt._variable_stride_per_key = True

ctx = EmbeddingCollectionContext(
sharding_contexts=[
SequenceShardingContext(batch_size_per_rank_per_feature=[]),
SequenceShardingContext(batch_size_per_rank_per_feature=[]),
]
)
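        # Expected output: the transpose of each KJT's stride_per_key_per_rank,
        # i.e. one inner list per rank holding that rank's stride for every key.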
results = [
[[1, 1, 2, 5, 3], [2, 2, 3, 7, 4]],
[[3, 5, 7, 1, 6], [1, 2, 3, 2, 8]],
]
_set_sharding_context_post_a2a(kjts, ctx)
for context, result in zip(ctx.sharding_contexts, results):
self.assertEqual(context.batch_size_per_rank_per_feature, result)