diff --git a/torchrec/distributed/train_pipeline/__init__.py b/torchrec/distributed/train_pipeline/__init__.py
index 89c2fe690..9e9d3bd73 100644
--- a/torchrec/distributed/train_pipeline/__init__.py
+++ b/torchrec/distributed/train_pipeline/__init__.py
@@ -14,6 +14,7 @@
     StagedTrainPipeline,  # noqa
     TrainPipeline,  # noqa
     TrainPipelineBase,  # noqa
+    TrainPipelinePT2,  # noqa
     TrainPipelineSparseDist,  # noqa
 )
 from torchrec.distributed.train_pipeline.utils import (  # noqa
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
index b41de4346..a48e3ca17 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -17,6 +17,7 @@
 import torch
 from hypothesis import given, settings, strategies as st, Verbosity
 from torch import nn, optim
+from torch._dynamo.testing import reduce_to_scalar_loss
 from torchrec.distributed import DistributedModelParallel
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
@@ -45,6 +46,7 @@
     PrefetchTrainPipelineSparseDist,
     StagedTrainPipeline,
     TrainPipelineBase,
+    TrainPipelinePT2,
     TrainPipelineSemiSync,
     TrainPipelineSparseDist,
 )
@@ -63,10 +65,13 @@
     ShardingPlan,
     ShardingType,
 )
 
-from torchrec.modules.embedding_configs import DataType
+from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
+from torchrec.modules.embedding_modules import EmbeddingBagCollection
 
 from torchrec.optim.keyed import KeyedOptimizerWrapper
 from torchrec.optim.optimizers import in_backward_optimizer_filter
+from torchrec.pt2.utils import kjt_for_pt2_tracing
+from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
 
 from torchrec.streamable import Pipelineable
@@ -93,6 +98,7 @@ def __init__(self) -> None:
         super().__init__()
         self.model = nn.Linear(10, 1)
         self.loss_fn = nn.BCEWithLogitsLoss()
+        self._dummy_setting: str = "dummy"
 
     def forward(
         self, model_input: ModelInputSimple
@@ -156,6 +162,154 @@ def test_equal_to_non_pipelined(self) -> None:
         self.assertTrue(torch.isclose(pred_gpu.cpu(), pred))
 
 
+class TrainPipelinePT2Test(unittest.TestCase):
+    def setUp(self) -> None:
+        self.device = torch.device("cuda:0")
+        torch.backends.cudnn.allow_tf32 = False
+        torch.backends.cuda.matmul.allow_tf32 = False
+
+    def gen_eb_conf_list(self, is_weighted: bool = False) -> List[EmbeddingBagConfig]:
+        weighted_prefix = "weighted_" if is_weighted else ""
+
+        return [
+            EmbeddingBagConfig(
+                num_embeddings=256,
+                embedding_dim=12,
+                name=weighted_prefix + "table_0",
+                feature_names=[weighted_prefix + "f0"],
+            ),
+            EmbeddingBagConfig(
+                num_embeddings=256,
+                embedding_dim=12,
+                name=weighted_prefix + "table_1",
+                feature_names=[weighted_prefix + "f1"],
+            ),
+        ]
+
+    def gen_model(
+        self, device: torch.device, ebc_list: List[EmbeddingBagConfig]
+    ) -> nn.Module:
+        class M_ebc(torch.nn.Module):
+            def __init__(self, vle: EmbeddingBagCollection) -> None:
+                super().__init__()
+                self.model = vle
+
+            def forward(self, x: KeyedJaggedTensor) -> List[JaggedTensor]:
+                kt: KeyedTensor = self.model(x)
+                return list(kt.to_dict().values())
+
+        return M_ebc(
+            EmbeddingBagCollection(
+                device=device,
+                tables=ebc_list,
+            )
+        )
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_equal_to_non_pipelined(self) -> None:
+        model_cpu = TestModule()
+        model_gpu = TestModule().to(self.device)
+        model_gpu.load_state_dict(model_cpu.state_dict())
+        optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01)
+        optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)
+        data = [
+            ModelInputSimple(
+                float_features=torch.rand((10,)),
+                label=torch.randint(2, (1,), dtype=torch.float32),
+            )
+            for b in range(5)
+        ]
+        dataloader = iter(data)
+        pipeline = TrainPipelinePT2(model_gpu, optimizer_gpu, self.device)
+
+        for batch in data[:-1]:
+            optimizer_cpu.zero_grad()
+            loss, pred = model_cpu(batch)
+            loss.backward()
+            optimizer_cpu.step()
+
+            pred_gpu = pipeline.progress(dataloader)
+
+            self.assertEqual(pred_gpu.device, self.device)
+            self.assertTrue(torch.isclose(pred_gpu.cpu(), pred))
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_pre_compile_fn(self) -> None:
+        model_cpu = TestModule()
+        model_gpu = TestModule().to(self.device)
+        model_gpu.load_state_dict(model_cpu.state_dict())
+        optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)
+        data = [
+            ModelInputSimple(
+                float_features=torch.rand((10,)),
+                label=torch.randint(2, (1,), dtype=torch.float32),
+            )
+            for b in range(5)
+        ]
+
+        def pre_compile_fn(model: nn.Module) -> None:
+            model._dummy_setting = "dummy modified"
+
+        dataloader = iter(data)
+        pipeline = TrainPipelinePT2(
+            model_gpu, optimizer_gpu, self.device, pre_compile_fn=pre_compile_fn
+        )
+        self.assertEqual(model_gpu._dummy_setting, "dummy")
+        for _ in range(len(data)):
+            pipeline.progress(dataloader)
+        self.assertEqual(model_gpu._dummy_setting, "dummy modified")
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
+        cpu = torch.device("cpu:0")
+        eb_conf_list = self.gen_eb_conf_list()
+        eb_conf_list_weighted = self.gen_eb_conf_list(is_weighted=True)
+
+        model_cpu = self.gen_model(cpu, eb_conf_list)
+        model_gpu = self.gen_model(self.device, eb_conf_list).to(self.device)
+
+        _, local_model_inputs = ModelInput.generate(
+            batch_size=10,
+            world_size=4,
+            num_float_features=8,
+            tables=eb_conf_list,
+            weighted_tables=eb_conf_list_weighted,
+            variable_batch_size=False,
+        )
+
+        model_gpu.load_state_dict(model_cpu.state_dict())
+        optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01)
+        optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)
+
+        data = [i.idlist_features for i in local_model_inputs]
+        dataloader = iter(data)
+        pipeline = TrainPipelinePT2(
+            model_gpu, optimizer_gpu, self.device, input_transformer=kjt_for_pt2_tracing
+        )
+
+        for batch in data[:-1]:
+            optimizer_cpu.zero_grad()
+            loss, pred = model_cpu(batch)
+            loss = reduce_to_scalar_loss(loss)
+            loss.backward()
+            pred_gpu = pipeline.progress(dataloader)
+
+            self.assertEqual(pred_gpu.device, self.device)
+            torch.testing.assert_close(pred_gpu.cpu(), pred)
+
+
 class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(
diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 90dcb0aaf..1f48444d3 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -10,6 +10,7 @@
 import abc
 import logging
 from collections import deque
+from dataclasses import dataclass
 from typing import (
     Any,
     Callable,
@@ -70,6 +71,25 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         pass
 
 
+@dataclass
+class TorchCompileConfig:
+    """
+    Configs for torch.compile.
+
+    fullgraph: bool = False, whether to require compiling the whole graph or not
+    dynamic: bool = False, whether to use dynamic shapes or not
+    backend: str = "inductor", which torch.compile backend to use (e.g. "inductor" or "aot_eager")
+    compile_on_iter: int = 3, the iteration on which to compile the model.
+        This is useful when we want to profile the first few training iterations eagerly
+        and only switch to the compiled model from iteration #3 onwards.
+    """
+
+    fullgraph: bool = False
+    dynamic: bool = False
+    backend: str = "inductor"
+    compile_on_iter: int = 3
+
+
 class TrainPipelineBase(TrainPipeline[In, Out]):
     """
     This class runs training iterations using a pipeline of two stages, each as a CUDA
@@ -138,6 +158,82 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         return output
 
 
+class TrainPipelinePT2(TrainPipelineBase[In, Out]):
+    """
+    This pipeline uses the PT2 compiler to compile the model and runs it in a single stream (the default).
+
+    Args:
+        model (torch.nn.Module): model to pipeline.
+        optimizer (torch.optim.Optimizer): optimizer to use.
+        device (torch.device): device where the model is run.
+        compile_configs (TorchCompileConfig): configs for compiling the model.
+        pre_compile_fn (Callable[[torch.nn.Module], None]): Optional callable to execute before compiling the model.
+        post_compile_fn (Callable[[torch.nn.Module], None]): Optional callable to execute after compiling the model.
+        input_transformer (Callable[[In], In]): transforms the input before passing it to the model.
+            This is useful when we want to transform KJT parameters for PT2 tracing.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        optimizer: torch.optim.Optimizer,
+        device: torch.device,
+        compile_configs: Optional[TorchCompileConfig] = None,
+        pre_compile_fn: Optional[Callable[[torch.nn.Module], None]] = None,
+        post_compile_fn: Optional[Callable[[torch.nn.Module], None]] = None,
+        input_transformer: Optional[Callable[[In], In]] = None,
+    ) -> None:
+        self._model = model
+        self._optimizer = optimizer
+        self._device = device
+        self._compile_configs: TorchCompileConfig = (
+            compile_configs or TorchCompileConfig()
+        )
+        self._pre_compile_fn = pre_compile_fn
+        self._post_compile_fn = post_compile_fn
+        self._input_transformer = input_transformer
+        self._iter = 0
+        self._cur_batch: Optional[In] = None
+
+    def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        cc = self._compile_configs
+
+        with record_function("## load_batch ##"):
+            cur_batch = next(dataloader_iter)
+
+        if self._input_transformer:
+            cur_batch = self._input_transformer(cur_batch)
+
+        with record_function("## copy_batch_to_gpu ##"):
+            self._cur_batch = _to_device(cur_batch, self._device, non_blocking=False)
+
+        if self._model.training:
+            with record_function("## zero_grad ##"):
+                self._optimizer.zero_grad()
+
+        with record_function("## forward ##"):
+            if self._iter == cc.compile_on_iter:
+                logger.info("Compiling model...")
+                if self._pre_compile_fn:
+                    self._pre_compile_fn(self._model)
+                self._model.compile(
+                    fullgraph=cc.fullgraph, dynamic=cc.dynamic, backend=cc.backend
+                )
+                if self._post_compile_fn:
+                    self._post_compile_fn(self._model)
+
+            (losses, output) = self._model(self._cur_batch)
+            self._iter += 1
+
+        if self._model.training:
+            with record_function("## backward ##"):
+                torch.sum(losses).backward()
+
+            with record_function("## optimizer ##"):
+                self._optimizer.step()
+
+        return output
+
+
 class TrainPipelineSparseDist(TrainPipeline[In, Out]):
     """
     This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
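
Usage note: the sketch below illustrates how the new TrainPipelinePT2 and TorchCompileConfig might be wired together once this patch lands. It is a minimal, hypothetical example and not code from this diff: ToyModel, its dummy loss, and the tensor batches are assumptions chosen to satisfy the (loss, output) return contract of TrainPipelineBase-style pipelines and the requirement that batches support .to(device).

# Hypothetical usage sketch for TrainPipelinePT2 (illustrative only; not part of this patch).
import torch
from torch import nn, optim

from torchrec.distributed.train_pipeline import TrainPipelinePT2
from torchrec.distributed.train_pipeline.train_pipelines import TorchCompileConfig


class ToyModel(nn.Module):
    """Assumed toy module: forward returns a (loss, prediction) tuple, as the pipeline expects."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, batch: torch.Tensor):
        pred = self.linear(batch)
        # Dummy loss so the pipeline's backward()/optimizer.step() have something to update.
        loss = pred.pow(2).mean()
        return loss, pred


device = torch.device("cuda:0")
model = ToyModel().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)

pipeline = TrainPipelinePT2(
    model,
    optimizer,
    device,
    # Run the first two iterations eagerly (e.g. for warm-up or profiling),
    # then torch.compile the model on iteration 2.
    compile_configs=TorchCompileConfig(backend="inductor", compile_on_iter=2),
)

# Plain tensors work as batches here because they implement .to(device=..., non_blocking=...).
data = [torch.rand(8, 10) for _ in range(5)]
dataloader = iter(data)
for _ in range(len(data)):
    pred = pipeline.progress(dataloader)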