28 changes: 26 additions & 2 deletions torchrec/distributed/benchmark/benchmark_comms.py
@@ -483,11 +483,14 @@ def non_blocking_copy(
     num_concat: int,
     ctx: MultiProcessContext,
     preallocated: bool = False,
+    use_data_copy_stream: bool = True,
     **_kwargs: Dict[str, Any],
 ) -> None:
     with record_function("## setup ##"):
         main_stream = torch.cuda.current_stream()
-        data_copy_stream = torch.cuda.Stream()
+        data_copy_stream = (
+            torch.cuda.Stream() if use_data_copy_stream else nullcontext()
+        )
         irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
 
         # the host to device data transfer will block cuda execution without the `pin_memory()`
@@ -519,7 +522,8 @@ def non_blocking_copy(
 
     with record_function("## pre-comms compute ##"):
         # make sure the data copy is done before the pre-comms compute
-        main_stream.wait_stream(data_copy_stream)
+        if use_data_copy_stream:
+            main_stream.wait_stream(data_copy_stream)
         pre_comms = _compute(
             dim=dim, num_mul=num_mul, num_concat=1, ctx=ctx, x=device_data
         )
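
The change above lets the same function run with or without a dedicated copy stream. What the dedicated-stream path exercises, as a minimal standalone sketch (the device, tensor names, and sizes are illustrative assumptions, not taken from the benchmark): a pinned host tensor copied with non_blocking=True on a side stream can overlap with compute on the main stream, and wait_stream is the point where the main stream is ordered after the copy before consuming the data.

import torch

# minimal sketch, assuming a CUDA device is available; names are illustrative
device = torch.device("cuda")
main_stream = torch.cuda.current_stream()
copy_stream = torch.cuda.Stream()

host_data = torch.rand(1024, 1024).pin_memory()  # pinned memory allows a truly async H2D copy

with torch.cuda.stream(copy_stream):
    # the copy is enqueued on the side stream and can overlap main-stream work
    device_data = host_data.to(device, non_blocking=True)

main_stream.wait_stream(copy_stream)  # order the main stream after the copy
out = device_data @ device_data       # safe to consume device_data here
torch.cuda.synchronize()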
@@ -543,6 +547,24 @@ def preallocated_non_blocking_copy(
     )
 
 
+def blocking_copy(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
+) -> None:
+    return non_blocking_copy(
+        _batch_inputs=_batch_inputs,
+        dim=dim,
+        num_mul=num_mul,
+        num_concat=num_concat,
+        ctx=ctx,
+        use_data_copy_stream=False,
+    )
+
+
 # single-rank runner
 def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig) -> None:
     # Ensure GPUs are available and we have enough of them
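
For contrast, the new blocking_copy case reuses non_blocking_copy with use_data_copy_stream=False, so nullcontext() stands in for the copy-stream context and the H2D copy presumably lands on the current (main) stream. Roughly what that baseline measures, again as an illustrative sketch rather than the benchmark's exact code:

import torch

device = torch.device("cuda")
host_data = torch.rand(1024, 1024).pin_memory()

# no dedicated copy stream: the H2D copy is enqueued on the current (main) stream
device_data = host_data.to(device, non_blocking=True)

# queued behind the copy on the same stream, so copy and compute cannot
# overlap; only the CPU side avoids blocking
out = device_data @ device_data
torch.cuda.synchronize()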
@@ -576,6 +598,8 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
             func = non_blocking_copy
         case "preallocated_non_blocking_copy":
             func = preallocated_non_blocking_copy
+        case "blocking_copy":
+            func = blocking_copy
         case _:
             raise ValueError(f"Unknown benchmark name: {arg.name}")
 