17 changes: 12 additions & 5 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -456,16 +456,23 @@ def _get_sharded_local_buckets_for_zero_collision(

for table in embedding_tables:
total_num_buckets = none_throws(table.total_num_buckets)
assert (
total_num_buckets % world_size == 0
), f"total_num_buckets={total_num_buckets} must be divisible by world_size={world_size}"
assert (
table.total_num_buckets
and table.num_embeddings % table.total_num_buckets == 0
), f"Table size '{table.num_embeddings}' must be divisible by num_buckets '{table.total_num_buckets}'"
bucket_offset_start = total_num_buckets // world_size * local_rank
extra_local_buckets = int(local_rank < (total_num_buckets % world_size))
extra_bucket_padding = (
(total_num_buckets % world_size)
if local_rank >= (total_num_buckets % world_size)
else 0
)
bucket_offset_start = (
total_num_buckets // world_size + extra_local_buckets
) * local_rank + extra_bucket_padding
bucket_offset_end = min(
total_num_buckets, total_num_buckets // world_size * (local_rank + 1)
total_num_buckets,
(total_num_buckets // world_size + extra_local_buckets) * (local_rank + 1)
+ extra_bucket_padding,
)
bucket_size = (
table.num_embeddings + total_num_buckets - 1
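The updated offset math no longer requires `total_num_buckets` to divide evenly by `world_size`: the first `total_num_buckets % world_size` ranks each take one extra bucket, and later ranks shift their start past that remainder. A minimal standalone sketch of the same arithmetic (the function name is illustrative, not part of TorchRec):

```python
def bucket_range(total_num_buckets: int, world_size: int, local_rank: int):
    """Mirror the updated offset math: low ranks absorb the leftover buckets,
    higher ranks are shifted past them by extra_bucket_padding."""
    extra_local_buckets = int(local_rank < (total_num_buckets % world_size))
    extra_bucket_padding = (
        (total_num_buckets % world_size)
        if local_rank >= (total_num_buckets % world_size)
        else 0
    )
    bucket_offset_start = (
        total_num_buckets // world_size + extra_local_buckets
    ) * local_rank + extra_bucket_padding
    bucket_offset_end = min(
        total_num_buckets,
        (total_num_buckets // world_size + extra_local_buckets) * (local_rank + 1)
        + extra_bucket_padding,
    )
    return bucket_offset_start, bucket_offset_end


# 10 buckets over 8 ranks: ranks 0-1 own 2 buckets, ranks 2-7 own 1.
# -> [(0, 2), (2, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10)]
print([bucket_range(10, 8, rank) for rank in range(8)])
```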
6 changes: 5 additions & 1 deletion torchrec/distributed/embedding_kernel.py
@@ -99,9 +99,13 @@ def create_virtual_table_global_metadata(
# Otherwise it will only set correct size on current rank and
# virtual PMT will trigger recalc for the correct global size/offset.
# NOTE this currently only works for row-wise sharding
my_rank_shard_size = metadata.shards_metadata[my_rank].shard_sizes[0]
for rank, shard_metadata in enumerate(metadata.shards_metadata):
if use_param_size_as_rows: # respect the param size and treat it as rows
curr_rank_rows = param.size()[0] # pyre-ignore[16]
# The param size only has information for my_rank. To correctly calculate
# the size for other ranks, scale it by each rank's shard size relative to
# my_rank's shard size.
curr_rank_rows = (param.size()[0] * metadata.shards_metadata[rank].shard_sizes[0]) // my_rank_shard_size # pyre-ignore[16]
else:
curr_rank_rows = (
weight_count_per_rank[rank] if weight_count_per_rank is not None else 1
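Because `param.size()[0]` only reflects the rows materialized on `my_rank`, the new code infers every other rank's row count by scaling that local size by the ratio of the planned shard sizes. A hedged sketch of that proportion, with plain integers standing in for `param.size()[0]` and `shard_sizes[0]` (helper name is illustrative):

```python
def rows_per_rank(my_rank_param_rows: int, shard_row_sizes: list, my_rank: int) -> list:
    """Scale the locally observed row count to all ranks, preserving the ratio
    of the planned row-wise shard sizes (integer division, as in the diff)."""
    my_rank_shard_size = shard_row_sizes[my_rank]
    return [
        (my_rank_param_rows * size) // my_rank_shard_size for size in shard_row_sizes
    ]


# Planned row-wise shard sizes [20, 20, 10, 10]; rank 2 actually holds 30 rows,
# so the inferred per-rank row counts become [60, 60, 30, 30].
print(rows_per_rank(30, [20, 20, 10, 10], my_rank=2))
```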
34 changes: 33 additions & 1 deletion torchrec/distributed/planner/enumerators.py
@@ -38,6 +38,10 @@
ShardingType,
)
from torchrec.modules.embedding_configs import DataType
from torchrec.modules.embedding_modules import (
EmbeddingBagCollection,
EmbeddingCollection,
)
from torchrec.modules.embedding_tower import EmbeddingTower, EmbeddingTowerCollection


@@ -178,7 +182,7 @@ def enumerate(
# skip for other device groups
if device_group and device_group != self._compute_device:
continue

num_buckets = self._get_num_buckets(name, child_module)
sharding_options_per_table: List[ShardingOption] = []

for sharding_type in self._filter_sharding_types(
@@ -200,6 +204,7 @@
sharding_type=sharding_type,
col_wise_shard_dim=col_wise_shard_dim,
device_memory_sizes=self._device_memory_sizes,
num_buckets=num_buckets,
)
except ZeroDivisionError as e:
# Re-raise with additional context about the table and module
@@ -264,6 +269,33 @@ def enumerate(
self._last_stored_search_space = copy.deepcopy(sharding_options)
return sharding_options

def _get_num_buckets(self, parameter: str, module: nn.Module) -> Optional[int]:
"""
Get the number of buckets for the given embedding table.

Args:
parameter (str): name of the embedding table.
module (nn.Module): module to be sharded.

Returns:
Optional[int]: Number of buckets for the table, or None if the module is not an EmbeddingBagCollection/EmbeddingCollection, the table is not found, or the table does not use a virtual table.
"""
# If module is neither an EmbeddingBagCollection nor an EmbeddingCollection, return None
if isinstance(module, EmbeddingBagCollection):
embedding_configs = module.embedding_bag_configs()
elif isinstance(module, EmbeddingCollection):
embedding_configs = module.embedding_configs()
else:
return None

# Find the embedding config for the table whose name matches the parameter
for config in embedding_configs:
if config.name == parameter and config.use_virtual_table:
return config.total_num_buckets

# If table with matching name not found, return None
return None

@property
def last_stored_search_space(self) -> Optional[List[ShardingOption]]:
# NOTE: This is the last search space stored by enumerate(...), do not use
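For `_get_num_buckets` to return a value, the table's config has to be a virtual table with `total_num_buckets` set; any other module type or config falls through to None. A small sketch of the config shape this expects, mirroring the test setup further down (values are illustrative):

```python
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# A bucketed virtual table: looking up "table_0" on this module yields 10 buckets,
# whereas a config without use_virtual_table=True would yield None.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0",
            num_embeddings=100,
            embedding_dim=20,
            feature_names=["feature_0"],
            total_num_buckets=10,
            use_virtual_table=True,
        )
    ]
)
```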
197 changes: 196 additions & 1 deletion torchrec/distributed/planner/tests/test_enumerators.py
@@ -18,7 +18,10 @@
EmbeddingTowerSharder,
)
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.embeddingbag import (
EmbeddingBagCollection,
EmbeddingBagCollectionSharder,
)
from torchrec.distributed.mc_embeddingbag import (
ManagedCollisionEmbeddingBagCollectionSharder,
)
@@ -45,13 +48,27 @@
[[17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [11, 80]],
]

EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS = [
[[20, 20], [20, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20]],
[[22, 40], [22, 40], [11, 40], [11, 40], [11, 40], [11, 40], [11, 40], [11, 40]],
[[24, 60], [24, 60], [12, 60], [12, 60], [12, 60], [12, 60], [12, 60], [12, 60]],
[[26, 80], [26, 80], [13, 80], [13, 80], [13, 80], [13, 80], [13, 80], [13, 80]],
]

EXPECTED_RW_SHARD_OFFSETS = [
[[0, 0], [13, 0], [26, 0], [39, 0], [52, 0], [65, 0], [78, 0], [91, 0]],
[[0, 0], [14, 0], [28, 0], [42, 0], [56, 0], [70, 0], [84, 0], [98, 0]],
[[0, 0], [15, 0], [30, 0], [45, 0], [60, 0], [75, 0], [90, 0], [105, 0]],
[[0, 0], [17, 0], [34, 0], [51, 0], [68, 0], [85, 0], [102, 0], [119, 0]],
]

EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS = [
[[0, 0], [20, 0], [40, 0], [50, 0], [60, 0], [70, 0], [80, 0], [90, 0]],
[[0, 0], [22, 0], [44, 0], [55, 0], [66, 0], [77, 0], [88, 0], [99, 0]],
[[0, 0], [24, 0], [48, 0], [60, 0], [72, 0], [84, 0], [96, 0], [108, 0]],
[[0, 0], [26, 0], [52, 0], [65, 0], [78, 0], [91, 0], [104, 0], [117, 0]],
]


def get_expected_cache_aux_size(rows: int) -> int:
# 0.2 is the hardcoded cache load factor assumed in this test
@@ -101,6 +118,48 @@ def get_expected_cache_aux_size(rows: int) -> int:
],
]

EXPECTED_VIRTUAL_TABLE_RW_SHARD_STORAGE_WITH_BUCKETS = [
[
Storage(hbm=165888, ddr=0),
Storage(hbm=165888, ddr=0),
Storage(hbm=165888, ddr=0),
Storage(hbm=165888, ddr=0),
Storage(hbm=165888, ddr=0),
Storage(hbm=165888, ddr=0),
Storage(hbm=165888, ddr=0),
Storage(hbm=165888, ddr=0),
],
[
Storage(hbm=1001472, ddr=0),
Storage(hbm=1001472, ddr=0),
Storage(hbm=1001472, ddr=0),
Storage(hbm=1001472, ddr=0),
Storage(hbm=1001472, ddr=0),
Storage(hbm=1001472, ddr=0),
Storage(hbm=1001472, ddr=0),
Storage(hbm=1001472, ddr=0),
],
[
Storage(hbm=1003520, ddr=0),
Storage(hbm=1003520, ddr=0),
Storage(hbm=1003520, ddr=0),
Storage(hbm=1003520, ddr=0),
Storage(hbm=1003520, ddr=0),
Storage(hbm=1003520, ddr=0),
Storage(hbm=1003520, ddr=0),
Storage(hbm=1003520, ddr=0),
],
[
Storage(hbm=2648064, ddr=0),
Storage(hbm=2648064, ddr=0),
Storage(hbm=2648064, ddr=0),
Storage(hbm=2648064, ddr=0),
Storage(hbm=2648064, ddr=0),
Storage(hbm=2648064, ddr=0),
Storage(hbm=2648064, ddr=0),
Storage(hbm=2648064, ddr=0),
],
]

EXPECTED_UVM_CACHING_RW_SHARD_STORAGE = [
[
@@ -145,6 +204,48 @@ def get_expected_cache_aux_size(rows: int) -> int:
],
]

EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS = [
[
Storage(hbm=166352, ddr=1600),
Storage(hbm=166352, ddr=1600),
Storage(hbm=166120, ddr=800),
Storage(hbm=166120, ddr=800),
Storage(hbm=166120, ddr=800),
Storage(hbm=166120, ddr=800),
Storage(hbm=166120, ddr=800),
Storage(hbm=166120, ddr=800),
],
[
Storage(hbm=1002335, ddr=3520),
Storage(hbm=1002335, ddr=3520),
Storage(hbm=1001904, ddr=1760),
Storage(hbm=1001904, ddr=1760),
Storage(hbm=1001904, ddr=1760),
Storage(hbm=1001904, ddr=1760),
Storage(hbm=1001904, ddr=1760),
Storage(hbm=1001904, ddr=1760),
],
[
Storage(hbm=1004845, ddr=5760),
Storage(hbm=1004845, ddr=5760),
Storage(hbm=1004183, ddr=2880),
Storage(hbm=1004183, ddr=2880),
Storage(hbm=1004183, ddr=2880),
Storage(hbm=1004183, ddr=2880),
Storage(hbm=1004183, ddr=2880),
Storage(hbm=1004183, ddr=2880),
],
[
Storage(hbm=2649916, ddr=8320),
Storage(hbm=2649916, ddr=8320),
Storage(hbm=2648990, ddr=4160),
Storage(hbm=2648990, ddr=4160),
Storage(hbm=2648990, ddr=4160),
Storage(hbm=2648990, ddr=4160),
Storage(hbm=2648990, ddr=4160),
Storage(hbm=2648990, ddr=4160),
],
]

EXPECTED_TWRW_SHARD_SIZES = [
[[25, 20], [25, 20], [25, 20], [25, 20]],
@@ -248,6 +349,16 @@ def compute_kernels(
return [EmbeddingComputeKernel.FUSED.value]


class VirtualTableRWSharder(EmbeddingBagCollectionSharder):
def sharding_types(self, compute_device_type: str) -> List[str]:
return [ShardingType.ROW_WISE.value]

def compute_kernels(
self, sharding_type: str, compute_device_type: str
) -> List[str]:
return [EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value]


class UVMCachingRWSharder(EmbeddingBagCollectionSharder):
def sharding_types(self, compute_device_type: str) -> List[str]:
return [ShardingType.ROW_WISE.value]
@@ -357,6 +468,27 @@ def setUp(self) -> None:
min_partition=40, pooling_factors=[2, 1, 3, 7]
),
}
self._virtual_table_constraints = {
"table_0": ParameterConstraints(
min_partition=20,
compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
),
"table_1": ParameterConstraints(
min_partition=20,
pooling_factors=[1, 3, 5],
compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
),
"table_2": ParameterConstraints(
min_partition=20,
pooling_factors=[8, 2],
compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
),
"table_3": ParameterConstraints(
min_partition=40,
pooling_factors=[2, 1, 3, 7],
compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value],
),
}
self.num_tables = 4
tables = [
EmbeddingBagConfig(
@@ -367,6 +499,17 @@
)
for i in range(self.num_tables)
]
tables_with_buckets = [
EmbeddingBagConfig(
num_embeddings=100 + i * 10,
embedding_dim=20 + i * 20,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
total_num_buckets=10,
use_virtual_table=True,
)
for i in range(self.num_tables)
]
weighted_tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
@@ -377,6 +520,9 @@
for i in range(4)
]
self.model = TestSparseNN(tables=tables, weighted_tables=[])
self.model_with_buckets = EmbeddingBagCollection(
tables=tables_with_buckets,
)
self.enumerator = EmbeddingEnumerator(
topology=Topology(
world_size=self.world_size,
@@ -386,6 +532,15 @@
batch_size=self.batch_size,
constraints=self.constraints,
)
self.virtual_table_enumerator = EmbeddingEnumerator(
topology=Topology(
world_size=self.world_size,
compute_device=self.compute_device,
local_world_size=self.local_world_size,
),
batch_size=self.batch_size,
constraints=self._virtual_table_constraints,
)
self.tower_model = TestTowerSparseNN(
tables=tables, weighted_tables=weighted_tables
)
@@ -514,6 +669,26 @@ def test_rw_sharding(self) -> None:
EXPECTED_RW_SHARD_STORAGE[i],
)

def test_virtual_table_rw_sharding_with_buckets(self) -> None:
sharding_options = self.virtual_table_enumerator.enumerate(
self.model_with_buckets,
[cast(ModuleSharder[torch.nn.Module], VirtualTableRWSharder())],
)
for i, sharding_option in enumerate(sharding_options):
self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value)
self.assertEqual(
[shard.size for shard in sharding_option.shards],
EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[i],
)
self.assertEqual(
[shard.offset for shard in sharding_option.shards],
EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[i],
)
self.assertEqual(
[shard.storage for shard in sharding_option.shards],
EXPECTED_VIRTUAL_TABLE_RW_SHARD_STORAGE_WITH_BUCKETS[i],
)

def test_uvm_caching_rw_sharding(self) -> None:
sharding_options = self.enumerator.enumerate(
self.model,
@@ -535,6 +710,26 @@ def test_uvm_caching_rw_sharding(self) -> None:
EXPECTED_UVM_CACHING_RW_SHARD_STORAGE[i],
)

def test_uvm_caching_rw_sharding_with_buckets(self) -> None:
sharding_options = self.enumerator.enumerate(
self.model_with_buckets,
[cast(ModuleSharder[torch.nn.Module], UVMCachingRWSharder())],
)
for i, sharding_option in enumerate(sharding_options):
self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value)
self.assertEqual(
[shard.size for shard in sharding_option.shards],
EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[i],
)
self.assertEqual(
[shard.offset for shard in sharding_option.shards],
EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[i],
)
self.assertEqual(
[shard.storage for shard in sharding_option.shards],
EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS[i],
)

def test_twrw_sharding(self) -> None:
sharding_options = self.enumerator.enumerate(
self.model, [cast(ModuleSharder[torch.nn.Module], TWRWSharder())]
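As a cross-check on the bucketed expectations above: `table_0` has 100 rows split into 10 buckets (bucket size 10) over `world_size=8`, so two ranks receive 2 buckets and six receive 1. A short sketch re-deriving `EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[0]` and `EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[0]` (helper name is illustrative):

```python
def expected_bucketed_shards(num_embeddings, total_num_buckets, world_size, embedding_dim):
    """Re-derive bucket-aligned row-wise shard sizes/offsets: ceil-divided bucket
    size, with the leftover buckets assigned to the lowest ranks."""
    bucket_size = (num_embeddings + total_num_buckets - 1) // total_num_buckets
    remainder = total_num_buckets % world_size
    sizes, offsets, offset = [], [], 0
    for rank in range(world_size):
        buckets = total_num_buckets // world_size + int(rank < remainder)
        sizes.append([buckets * bucket_size, embedding_dim])
        offsets.append([offset, 0])
        offset += buckets * bucket_size
    return sizes, offsets


sizes, offsets = expected_bucketed_shards(100, 10, 8, 20)
# sizes   -> [[20, 20], [20, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20]]
# offsets -> [[0, 0], [20, 0], [40, 0], [50, 0], [60, 0], [70, 0], [80, 0], [90, 0]]
```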