diff --git a/torchrec/distributed/benchmark/benchmark_comms.py b/torchrec/distributed/benchmark/benchmark_comms.py
index 802e30793..456237fb6 100644
--- a/torchrec/distributed/benchmark/benchmark_comms.py
+++ b/torchrec/distributed/benchmark/benchmark_comms.py
@@ -483,11 +483,14 @@ def non_blocking_copy(
     num_concat: int,
     ctx: MultiProcessContext,
     preallocated: bool = False,
+    use_data_copy_stream: bool = True,
     **_kwargs: Dict[str, Any],
 ) -> None:
     with record_function("## setup ##"):
         main_stream = torch.cuda.current_stream()
-        data_copy_stream = torch.cuda.Stream()
+        data_copy_stream = (
+            torch.cuda.Stream() if use_data_copy_stream else nullcontext()
+        )
         irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
 
         # the host to device data transfer will block cuda execution without the `pin_memory()`
@@ -519,7 +522,8 @@ def non_blocking_copy(
 
     with record_function("## pre-comms compute ##"):
         # make sure the data copy is done before the pre-comms compute
-        main_stream.wait_stream(data_copy_stream)
+        if use_data_copy_stream:
+            main_stream.wait_stream(data_copy_stream)
         pre_comms = _compute(
             dim=dim, num_mul=num_mul, num_concat=1, ctx=ctx, x=device_data
         )
@@ -543,6 +547,24 @@ def preallocated_non_blocking_copy(
     )
 
 
+def blocking_copy(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
+) -> None:
+    return non_blocking_copy(
+        _batch_inputs=_batch_inputs,
+        dim=dim,
+        num_mul=num_mul,
+        num_concat=num_concat,
+        ctx=ctx,
+        use_data_copy_stream=False,
+    )
+
+
 # single-rank runner
 def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig) -> None:
     # Ensure GPUs are available and we have enough of them
@@ -576,6 +598,8 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
             func = non_blocking_copy
         case "preallocated_non_blocking_copy":
             func = preallocated_non_blocking_copy
+        case "blocking_copy":
+            func = blocking_copy
         case _:
             raise ValueError(f"Unknown benchmark name: {arg.name}")
 
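For context, below is a minimal, self-contained sketch of the two code paths the new `use_data_copy_stream` flag toggles between: staging the pinned host-to-device copy on a dedicated stream and synchronizing with `wait_stream`, versus issuing the copy on the main stream as the new `blocking_copy` baseline does. This is not code from `benchmark_comms.py`; the function and variable names are made up for illustration.

```python
import torch


def copy_then_compute(dim: int, use_data_copy_stream: bool = True) -> torch.Tensor:
    """Copy pinned host data to the GPU and run a small matmul on it."""
    device = torch.device("cuda")
    main_stream = torch.cuda.current_stream()
    copy_stream = torch.cuda.Stream() if use_data_copy_stream else None

    # Pinned host memory is what allows the H2D transfer to run asynchronously.
    host_data = torch.rand(dim, dim).pin_memory()

    if copy_stream is not None:
        # Non-blocking path: enqueue the copy on a side stream so the main
        # stream can keep executing unrelated kernels in the meantime.
        with torch.cuda.stream(copy_stream):
            device_data = host_data.to(device, non_blocking=True)
        # The compute below must not start before the copy has finished.
        main_stream.wait_stream(copy_stream)
    else:
        # Blocking path: the copy is enqueued on the main stream, so every
        # kernel issued afterwards implicitly waits for it.
        device_data = host_data.to(device, non_blocking=True)

    return device_data @ device_data  # stand-in for the benchmark's compute
```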