3 changes: 3 additions & 0 deletions torchrec/distributed/planner/constants.py
@@ -42,6 +42,9 @@
WEIGHTED_KERNEL_MULTIPLIER: float = 1.1 # empirical studies
DP_ELEMENTWISE_KERNELS_PERF_FACTOR: float = 9.22 # empirical studies

# TODO: This can be hardware dependent, need more empirical results to verify
A2A_INVERSE_BANDWITH_COEFFICIENT: float = 1.0  # empirical studies


def kernel_bw_lookup(
compute_device: str,
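A note on the new constant: the estimator divides the all-to-all message size by the bandwidth reported for the collective and then multiplies by this coefficient, so values above 1.0 would model effective A2A bandwidth falling short of the table's figure (today's value of 1.0 is a no-op). A minimal sketch of that scaling with illustrative numbers; the helper and its inputs are hypothetical, not part of this diff:

```python
A2A_INVERSE_BANDWITH_COEFFICIENT: float = 1.0  # as defined above

def a2a_latency_sec(message_bytes: int, comms_bw_bytes_per_sec: float) -> float:
    # Ideal transfer time (size / bandwidth), rescaled by the empirical coefficient.
    return (message_bytes / comms_bw_bytes_per_sec) * A2A_INVERSE_BANDWITH_COEFFICIENT

# e.g. a 64 MB all-to-all payload over an assumed 50 GB/s effective bandwidth:
print(a2a_latency_sec(64_000_000, 50e9))  # 0.00128 s, i.e. ~1.3 ms
```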
74 changes: 74 additions & 0 deletions torchrec/distributed/planner/shard_estimators.py
@@ -18,6 +18,7 @@
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.logger import _torchrec_method_logger
from torchrec.distributed.planner.constants import (
A2A_INVERSE_BANDWITH_COEFFICIENT,
BATCHED_COPY_PERF_FACTOR,
BIGINT_DTYPE,
DP_ELEMENTWISE_KERNELS_PERF_FACTOR,
@@ -461,6 +462,49 @@ def _get_expected_cache_prefetch_time(
prefetch_bytes = expected_cache_fetches * emb_dim * table_data_type_size
return prefetch_bytes / hbm_to_ddr_mem_bw

@classmethod
def _input_dist_expected_latency(
cls,
batch_sizes: List[int],
world_size: int,
local_world_size: int,
num_poolings: List[float],
input_lengths: List[float],
fwd_a2a_comm_data_type_size: float,
comms_bandwidths: GeneralizedCommsBandwidth,
) -> float:
"""
        Calculates the expected latency of the all-to-all (A2A) input dist.

Args:
            batch_sizes (List[int]): The batch size for each input feature.
world_size (int): The total number of devices in the distributed setup.
local_world_size (int): The number of devices on a single host.
num_poolings (List[float]): Number of poolings per sample for each input feature.
input_lengths (List[float]): Average number of lookups per input feature.
fwd_a2a_comm_data_type_size (float): Data type size (in bytes) for forward all-to-all communication.
comms_bandwidths (GeneralizedCommsBandwidth): Object to query communication bandwidths.

Returns:
float: The expected latency (in seconds) for input distribution.
"""
batch_inputs = sum(
[x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)]
)
input_read_size = math.ceil(
batch_inputs * world_size * fwd_a2a_comm_data_type_size
)

comms_bw = comms_bandwidths.get_bw(
world_size=world_size,
local_world_size=local_world_size,
collective_type=CollectiveType.ALL_TO_ALL,
)
        message_time = input_read_size / comms_bw
        input_dist_latency = message_time * A2A_INVERSE_BANDWITH_COEFFICIENT

return input_dist_latency
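
To sanity-check the formula, the computation replays standalone; every number below is an illustrative assumption, and `comms_bw` stands in for what `GeneralizedCommsBandwidth.get_bw` would return for `ALL_TO_ALL`:

```python
import math

# Assumed workload: two features, per-feature batch size 512, one pooling
# per sample, average 20 and 40 lookups per feature, int64 indices (8 bytes).
batch_sizes = [512, 512]
num_poolings = [1.0, 1.0]
input_lengths = [20.0, 40.0]
world_size = 16
fwd_a2a_comm_data_type_size = 8.0

batch_inputs = sum(
    x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)
)  # 30_720 pooled lookups
input_read_size = math.ceil(batch_inputs * world_size * fwd_a2a_comm_data_type_size)

comms_bw = 25e9  # assumed effective A2A bandwidth, bytes/sec
input_dist_latency = (input_read_size / comms_bw) * 1.0  # coefficient currently 1.0
print(f"{input_dist_latency * 1e6:.0f} us")  # ~157 us
```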

@classmethod
def _get_tw_sharding_perf(
cls,
@@ -551,6 +595,15 @@ def _get_tw_sharding_perf(
hbm_to_ddr_mem_bw, expected_cache_fetches, emb_dim, table_data_type_size
)

input_dist_comms = cls._input_dist_expected_latency(
batch_sizes=batch_sizes,
world_size=world_size,
local_world_size=local_world_size,
num_poolings=num_poolings,
input_lengths=input_lengths,
fwd_a2a_comm_data_type_size=input_data_type_size,
comms_bandwidths=comms_bandwidths,
)
# in order of model parallel execution, starting with:
# BWD DP -> BWD MP ... FWD MP -> FWD DP
return Perf(
@@ -559,6 +612,7 @@
bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel,
bwd_comms=bwd_comms,
prefetch_compute=prefetch_compute,
input_dist_comms=input_dist_comms,
)
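
The same wiring repeats verbatim in the row-wise and table-wise row-wise estimators below. For context on where the new field lands, a hedged sketch of a `Perf`-like record follows; the real `Perf` lives in `torchrec.distributed.planner.types`, and whether it weights components before summing is not visible in this diff:

```python
from dataclasses import dataclass

@dataclass
class PerfSketch:
    fwd_compute: float
    fwd_comms: float
    bwd_compute: float
    bwd_comms: float
    prefetch_compute: float = 0.0
    input_dist_comms: float = 0.0  # the field this change starts populating

    @property
    def total(self) -> float:
        # Assumption: a plain sum; the production Perf may aggregate differently.
        return (
            self.fwd_compute
            + self.fwd_comms
            + self.bwd_compute
            + self.bwd_comms
            + self.prefetch_compute
            + self.input_dist_comms
        )
```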

@classmethod
@@ -658,13 +712,23 @@ def _get_rw_sharding_perf(
emb_dim,
table_data_type_size,
)
input_dist_comms = cls._input_dist_expected_latency(
batch_sizes=batch_sizes,
world_size=world_size,
local_world_size=local_world_size,
num_poolings=num_poolings,
input_lengths=input_lengths,
fwd_a2a_comm_data_type_size=input_data_type_size,
comms_bandwidths=comms_bandwidths,
)

return Perf(
fwd_compute=fwd_compute,
fwd_comms=fwd_comms,
bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel,
bwd_comms=bwd_comms + bwd_batched_copy,
prefetch_compute=prefetch_compute,
input_dist_comms=input_dist_comms,
)

@classmethod
@@ -790,13 +854,23 @@ def _get_twrw_sharding_perf(
emb_dim,
table_data_type_size,
)
input_dist_comms = cls._input_dist_expected_latency(
batch_sizes=batch_sizes,
world_size=world_size,
local_world_size=local_world_size,
num_poolings=num_poolings,
input_lengths=input_lengths,
fwd_a2a_comm_data_type_size=input_data_type_size,
comms_bandwidths=comms_bandwidths,
)

return Perf(
fwd_compute=fwd_compute,
fwd_comms=fwd_comms,
bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel,
bwd_comms=bwd_comms + bwd_batched_copy,
prefetch_compute=prefetch_compute,
input_dist_comms=input_dist_comms,
)

@classmethod