142 changes: 136 additions & 6 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -1667,7 +1667,7 @@ def gpu_preproc(x: StageOut) -> StageOut:

        sdd = SparseDataDistUtil[ModelInput](
            model=sharded_model_pipelined,
-           stream=torch.cuda.Stream(),
+           data_dist_stream=torch.cuda.Stream(),
            apply_jit=False,
        )

@@ -1695,7 +1695,7 @@ def gpu_preproc(x: StageOut) -> StageOut:
            PipelineStage(
                name="start_sparse_data_dist",
                runnable=sdd.start_sparse_data_dist,
-               stream=sdd.stream,
+               stream=sdd.data_dist_stream,
                fill_callback=sdd.wait_sparse_data_dist,
            ),
        ]
@@ -1744,7 +1744,7 @@ def gpu_preproc(x: StageOut) -> StageOut:

        sdd = SparseDataDistUtil[ModelInput](
            model=sharded_model_pipelined,
-           stream=torch.cuda.Stream(),
+           data_dist_stream=torch.cuda.Stream(),
            apply_jit=False,
        )

@@ -1762,7 +1762,7 @@ def gpu_preproc(x: StageOut) -> StageOut:
            PipelineStage(
                name="start_sparse_data_dist",
                runnable=sdd.start_sparse_data_dist,
-               stream=sdd.stream,
+               stream=sdd.data_dist_stream,
                fill_callback=sdd.wait_sparse_data_dist,
            ),
        ]
@@ -1860,7 +1860,7 @@ def test_model_detach(self) -> None:

        sdd = SparseDataDistUtil[ModelInput](
            model=sharded_model_pipelined,
-           stream=torch.cuda.Stream(),
+           data_dist_stream=torch.cuda.Stream(),
            apply_jit=False,
        )

@@ -1873,7 +1873,7 @@ def test_model_detach(self) -> None:
            PipelineStage(
                name="start_sparse_data_dist",
                runnable=sdd.start_sparse_data_dist,
-               stream=sdd.stream,
+               stream=sdd.data_dist_stream,
                fill_callback=sdd.wait_sparse_data_dist,
            ),
        ]
@@ -1964,3 +1964,133 @@ def test_model_detach(self) -> None:
        # Check pipeline exhausted
        preproc_input = pipeline.progress(dataloader)
        self.assertIsNone(preproc_input)

    @unittest.skipIf(
        not torch.cuda.is_available(),
        "Not enough GPUs, this test requires at least one GPU",
    )
    @settings(max_examples=4, deadline=None)
    # pyre-ignore[56]
    @given(
        sharding_type=st.sampled_from(
            [
                ShardingType.TABLE_WISE.value,
            ]
        ),
        kernel_type=st.sampled_from(
            [
                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
            ]
        ),
        cache_precision=st.sampled_from(
            [
                DataType.FP16,
                DataType.FP32,
            ]
        ),
        load_factor=st.sampled_from(
            [
                0.2,
                0.4,
            ]
        ),
    )
    def test_pipelining_prefetch(
        self,
        sharding_type: str,
        kernel_type: str,
        cache_precision: DataType,
        load_factor: float,
    ) -> None:
        model = self._setup_model()

        fused_params = {
            "cache_load_factor": load_factor,
            "cache_precision": cache_precision,
            "stochastic_rounding": False,  # disable non-deterministic behavior when converting fp32<->fp16
        }
        fused_params_pipelined = {
            **fused_params,
            "prefetch_pipeline": True,
        }

        sharded_model, optim = self._generate_sharded_model_and_optimizer(
            model, sharding_type, kernel_type, fused_params
        )
        (
            sharded_model_pipelined,
            optim_pipelined,
        ) = self._generate_sharded_model_and_optimizer(
            model, sharding_type, kernel_type, fused_params_pipelined
        )

        copy_state_dict(
            sharded_model.state_dict(), sharded_model_pipelined.state_dict()
        )

        num_batches = 12
        data = self._generate_data(
            num_batches=num_batches,
            batch_size=32,
        )

        non_pipelined_outputs = []
        for batch in data:
            batch = batch.to(self.device)
            optim.zero_grad()
            loss, pred = sharded_model(batch)
            loss.backward()
            optim.step()
            non_pipelined_outputs.append(pred)

        def gpu_preproc(x: StageOut) -> StageOut:
            return x

        sdd = SparseDataDistUtil[ModelInput](
            model=sharded_model_pipelined,
            data_dist_stream=torch.cuda.Stream(),
            apply_jit=False,
            prefetch_stream=torch.cuda.Stream(),
        )

        pipeline_stages = [
            PipelineStage(
                name="data_copy",
                runnable=partial(get_h2d_func, device=self.device),
                stream=torch.cuda.Stream(),
            ),
            PipelineStage(
                name="start_sparse_data_dist",
                runnable=sdd.start_sparse_data_dist,
                stream=sdd.data_dist_stream,
                fill_callback=sdd.wait_sparse_data_dist,
            ),
            PipelineStage(
                name="prefetch",
                runnable=sdd.prefetch,
                # pyre-ignore
                stream=sdd.prefetch_stream,
                fill_callback=sdd.load_prefetch,
            ),
        ]
        pipeline = StagedTrainPipeline(
            pipeline_stages=pipeline_stages, compute_stream=torch.cuda.current_stream()
        )
        dataloader = iter(data)

        pipelined_out = []
        num_batches_processed = 0

        while model_in := pipeline.progress(dataloader):
            num_batches_processed += 1
            optim_pipelined.zero_grad()
            loss, pred = sharded_model_pipelined(model_in)
            loss.backward()
            optim_pipelined.step()
            pipelined_out.append(pred)

        self.assertEqual(num_batches_processed, num_batches)

        self.assertEqual(len(pipelined_out), len(non_pipelined_outputs))
        for out, ref_out in zip(pipelined_out, non_pipelined_outputs):
            torch.testing.assert_close(out, ref_out)
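For call sites outside of this test, the API change amounts to renaming the `stream` keyword of `SparseDataDistUtil` to `data_dist_stream`, plus an optional `prefetch_stream` when a prefetch stage is used. A minimal migration sketch follows; the import paths are assumed to match the ones used by this test module, and `sharded_model` is a placeholder for an already-sharded, pipelined model:

```python
import torch

# Import paths assumed to mirror the test module's imports.
from torchrec.distributed.train_pipeline import SparseDataDistUtil
from torchrec.distributed.test_utils.test_model import ModelInput

sdd = SparseDataDistUtil[ModelInput](
    model=sharded_model,                    # placeholder: an already-sharded model
    data_dist_stream=torch.cuda.Stream(),   # formerly passed as `stream=`
    apply_jit=False,
    prefetch_stream=torch.cuda.Stream(),    # only needed when adding a prefetch stage
)
```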
49 changes: 11 additions & 38 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -33,6 +33,7 @@
from torchrec.distributed.train_pipeline.utils import (
    _override_input_dist_forwards,
    _pipeline_detach_model,
+   _prefetch_embeddings,
    _rewrite_model,
    _start_data_dist,
    _start_embedding_lookup,
@@ -1101,46 +1102,18 @@ def _prefetch(self, batch: Optional[In]) -> None:
                batch.record_stream(
                    torch.get_device_module(self._device).current_stream()
                )
+               data_per_pipelined_module = _prefetch_embeddings(
+                   batch,
+                   self._context,
+                   self._pipelined_modules,
+                   self._device,
+                   self._stream_context,
+                   self._data_dist_stream,
+                   self._default_stream,
+               )
                for sharded_module in self._pipelined_modules:
                    forward = sharded_module.forward
                    assert isinstance(forward, PrefetchPipelinedForward)

-                   assert forward._name in self._context.input_dist_tensors_requests
-                   request = self._context.input_dist_tensors_requests.pop(
-                       forward._name
-                   )
-                   assert isinstance(request, Awaitable)
-                   with record_function("## wait_sparse_data_dist ##"):
-                       # Finish waiting on the dist_stream,
-                       # in case some delayed stream scheduling happens during the wait() call.
-                       with self._stream_context(self._data_dist_stream):
-                           data = request.wait()

-                   # Make sure that both result of input_dist and context
-                   # are properly transferred to the current stream.
-                   module_context = self._context.module_contexts[forward._name]
-                   if self._data_dist_stream is not None:
-                       torch.get_device_module(
-                           self._device
-                       ).current_stream().wait_stream(self._data_dist_stream)
-                       cur_stream = torch.get_device_module(
-                           self._device
-                       ).current_stream()
-
-                       assert isinstance(
-                           data, (torch.Tensor, Multistreamable)
-                       ), f"{type(data)} must implement Multistreamable interface"
-                       data.record_stream(cur_stream)
-                       data.record_stream(self._default_stream)
-
-                       module_context.record_stream(cur_stream)
-                       module_context.record_stream(self._default_stream)

-                   sharded_module.prefetch(
-                       ctx=module_context,
-                       dist_input=data,
-                       forward_stream=self._default_stream,
-                   )
+                   data = data_per_pipelined_module[forward._name]
                    self._context.module_input_post_prefetch[forward._name] = data
                    self._context.module_contexts_post_prefetch[forward._name] = (
                        self._context.module_contexts.pop(forward._name)
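The removed block is not dropped outright: the per-module wait/record-stream/prefetch logic now lives in the shared `_prefetch_embeddings` helper imported at the top of the file, and `_prefetch` only stores the returned per-module data into the post-prefetch context slots. The helper's real implementation sits in `torchrec/distributed/train_pipeline/utils.py` and is not shown in this diff; the sketch below is only an approximation reconstructed from the removed inline code and the call site, with the argument types and return shape assumed:

```python
from typing import Any, Dict, List

import torch
from torch.profiler import record_function


def _prefetch_embeddings_sketch(
    batch: Any,
    context: Any,             # pipeline context holding input_dist requests and module contexts
    pipelined_modules: List[Any],
    device: torch.device,
    stream_context: Any,      # e.g. torch.cuda.stream on CUDA devices
    data_dist_stream: Any,    # stream the sparse data dist was launched on
    default_stream: Any,      # stream the main forward/backward runs on
) -> Dict[str, Any]:
    """Approximation of the extracted helper: wait for each module's input dist,
    make the result visible to the consuming streams, run the module prefetch,
    and return the awaited data keyed by the forward's name."""
    data_per_pipelined_module: Dict[str, Any] = {}
    for sharded_module in pipelined_modules:
        forward = sharded_module.forward
        request = context.input_dist_tensors_requests.pop(forward._name)
        with record_function("## wait_sparse_data_dist ##"):
            # Finish waiting on the dist stream, in case delayed stream
            # scheduling happens during the wait() call.
            with stream_context(data_dist_stream):
                data = request.wait()

        # Make both the input_dist result and the module context visible to
        # the current stream and the default (compute) stream.
        module_context = context.module_contexts[forward._name]
        if data_dist_stream is not None:
            cur_stream = torch.get_device_module(device).current_stream()
            cur_stream.wait_stream(data_dist_stream)
            data.record_stream(cur_stream)
            data.record_stream(default_stream)
            module_context.record_stream(cur_stream)
            module_context.record_stream(default_stream)

        # Kick off the embedding prefetch for this module.
        sharded_module.prefetch(
            ctx=module_context,
            dist_input=data,
            forward_stream=default_stream,
        )
        data_per_pipelined_module[forward._name] = data
    return data_per_pipelined_module
```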