From 4bc15d667d870a179c48e1164703b882c6e9629d Mon Sep 17 00:00:00 2001
From: Ivan Kobzarev
Date: Mon, 17 Jun 2024 09:27:58 -0700
Subject: [PATCH] Enable inductor compilation for EBC-VB (#2125)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2125

Enable inductor compilation tests for the variable-batch (VB) path, and add non-VB test coverage for column-wise (CW) sharding. (Non-VB inductor compilation needs more changes before it can land.)

Reviewed By: PaulZhang12

Differential Revision: D58672604
---
 .../tests/test_pt2_multiprocess.py | 79 +++++++++++++++----
 1 file changed, 64 insertions(+), 15 deletions(-)

diff --git a/torchrec/distributed/tests/test_pt2_multiprocess.py b/torchrec/distributed/tests/test_pt2_multiprocess.py
index e0de78909..d309280ef 100644
--- a/torchrec/distributed/tests/test_pt2_multiprocess.py
+++ b/torchrec/distributed/tests/test_pt2_multiprocess.py
@@ -19,6 +19,7 @@
 import torchrec
 import torchrec.pt2.checks
 from hypothesis import given, settings, strategies as st, Verbosity
+from torch._dynamo.testing import reduce_to_scalar_loss
 from torchrec.distributed.embedding import EmbeddingCollectionSharder
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.fbgemm_qcomm_codec import QCommsConfig
@@ -56,6 +57,7 @@
 from torchrec.pt2.utils import kjt_for_pt2_tracing
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor

+
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
@@ -139,17 +141,22 @@ def sharding_types(self, compute_device_type: str) -> List[str]:


 def _gen_model(test_model_type: _ModelType, mi: TestModelInfo) -> torch.nn.Module:
+    emb_dim: int = max(t.embedding_dim for t in mi.tables)
     if test_model_type == _ModelType.EBC:

         class M_ebc(torch.nn.Module):
             def __init__(self, ebc: EmbeddingBagCollection) -> None:
                 super().__init__()
                 self._ebc = ebc
+                self._linear = torch.nn.Linear(
+                    mi.num_float_features, emb_dim, device=mi.dense_device
+                )

-            def forward(self, x: KeyedJaggedTensor) -> torch.Tensor:
+            def forward(self, x: KeyedJaggedTensor, y: torch.Tensor) -> torch.Tensor:
                 kt: KeyedTensor = self._ebc(x)
                 v = kt.values()
-                return torch.sigmoid(torch.mean(v, dim=1))
+                y = self._linear(y)
+                return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1))

         return M_ebc(
             EmbeddingBagCollection(
@@ -164,10 +171,15 @@ class M_fpebc(torch.nn.Module):
             def __init__(self, fpebc: FeatureProcessedEmbeddingBagCollection) -> None:
                 super().__init__()
                 self._fpebc = fpebc
+                self._linear = torch.nn.Linear(
+                    mi.num_float_features, emb_dim, device=mi.dense_device
+                )

-            def forward(self, x: KeyedJaggedTensor) -> torch.Tensor:
+            def forward(self, x: KeyedJaggedTensor, y: torch.Tensor) -> torch.Tensor:
                 kt: KeyedTensor = self._fpebc(x)
-                return kt.values()
+                v = kt.values()
+                y = self._linear(y)
+                return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1))

         return M_fpebc(
             FeatureProcessedEmbeddingBagCollection(
@@ -187,9 +199,13 @@ def __init__(self, ec: EmbeddingCollection) -> None:
                 super().__init__()
                 self._ec = ec

-            def forward(self, x: KeyedJaggedTensor) -> List[JaggedTensor]:
+            def forward(
+                self, x: KeyedJaggedTensor, y: torch.Tensor
+            ) -> List[JaggedTensor]:
                 d: Dict[str, JaggedTensor] = self._ec(x)
-                return list(d.values())
+                v = torch.stack(d.values(), dim=0).sum(dim=0)
+                y = self._linear(y)
+                return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1))

         return M_ec(
             EmbeddingCollection(
@@ -307,6 +323,7 @@ def _test_compile_rank_fn(
             # pyre-ignore
             sharders=sharders,
             device=device,
+            init_data_parallel=False,
         )

         if input_type == _InputType.VARIABLE_BATCH:
@@ -336,19 +353,27 @@ def _test_compile_rank_fn(
         local_model_input = local_model_inputs[0].to(device)

         kjt = local_model_input.idlist_features
+        ff = local_model_input.float_features
+        ff.requires_grad = True
         kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb)

+        compile_input_ff = ff.clone().detach()
+
         torchrec.distributed.comm_ops.set_use_sync_collectives(True)
         torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True)

         dmp.train(True)

-        eager_out = dmp(kjt_ft)
+        eager_out = dmp(kjt_ft, ff)
+
+        eager_loss = reduce_to_scalar_loss(eager_out)
+        eager_loss.backward()

         if torch_compile_backend is None:
             return

         ##### COMPILE #####
+        run_compile_backward: bool = torch_compile_backend in ["aot_eager", "inductor"]
         with dynamo_skipfiles_allow("torchrec"):
             torch._dynamo.config.capture_scalar_outputs = True
             torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -357,8 +382,14 @@ def _test_compile_rank_fn(
                 backend=torch_compile_backend,
                 fullgraph=True,
             )
-            compile_out = opt_fn(kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb))
-            torch.testing.assert_close(eager_out, compile_out)
+            compile_out = opt_fn(
+                kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb), compile_input_ff
+            )
+            torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)
+            if run_compile_backward:
+                loss = reduce_to_scalar_loss(compile_out)
+                loss.backward()
+
         ##### COMPILE END #####

         ##### NUMERIC CHECK #####
@@ -368,9 +399,20 @@ def _test_compile_rank_fn(
                 local_model_input = local_model_inputs[1 + i].to(device)
                 kjt = local_model_input.idlist_features
                 kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb)
-                eager_out_i = dmp(kjt_ft)
-                compile_out_i = opt_fn(kjt_ft)
-                torch.testing.assert_close(eager_out_i, compile_out_i)
+                ff = local_model_input.float_features
+                ff.requires_grad = True
+                eager_out_i = dmp(kjt_ft, ff)
+                eager_loss_i = reduce_to_scalar_loss(eager_out_i)
+                eager_loss_i.backward()
+
+                compile_input_ff = ff.detach().clone()
+                compile_out_i = opt_fn(kjt_ft, ff)
+                torch.testing.assert_close(
+                    eager_out_i, compile_out_i, atol=1e-3, rtol=1e-3
+                )
+                if run_compile_backward:
+                    loss_i = torch._dynamo.testing.reduce_to_scalar_loss(compile_out_i)
+                    loss_i.backward()
         ##### NUMERIC CHECK END #####
@@ -396,14 +438,14 @@ def disable_cuda_tf32(self) -> bool:
                     ShardingType.TABLE_WISE.value,
                     _InputType.SINGLE_BATCH,
                     _ConvertToVariableBatch.TRUE,
-                    "eager",
+                    "inductor",
                 ),
                 (
                     _ModelType.EBC,
                     ShardingType.COLUMN_WISE.value,
                     _InputType.SINGLE_BATCH,
                     _ConvertToVariableBatch.TRUE,
-                    "eager",
+                    "inductor",
                 ),
                 (
                     _ModelType.EBC,
@@ -412,6 +454,13 @@ def disable_cuda_tf32(self) -> bool:
                     _ConvertToVariableBatch.FALSE,
                     "eager",
                 ),
+                (
+                    _ModelType.EBC,
+                    ShardingType.COLUMN_WISE.value,
+                    _InputType.SINGLE_BATCH,
+                    _ConvertToVariableBatch.FALSE,
+                    "eager",
+                ),
             ]
         ),
     )
@@ -424,7 +473,7 @@ def test_compile_multiprocess(
             str,
             _InputType,
             _ConvertToVariableBatch,
-            str,
+            Optional[str],
         ],
     ) -> None:
         model_type, sharding_type, input_type, tovb, compile_backend = (
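
Note (not part of the patch): the test hunks above all follow one pattern: run the module eagerly, reduce its output to a scalar loss with reduce_to_scalar_loss and call backward, then compile the same module with torch.compile and compare the compiled output against the eager one within a tolerance. Below is a minimal standalone sketch of that pattern. The toy module, shapes, and the "inductor" backend choice are illustrative assumptions only, not code from the patch; "aot_eager" is a lighter-weight backend for a quick check.

    # Minimal sketch of the eager-vs-compiled check used in the test above.
    # The module and shapes are toy stand-ins, not the DMP-sharded model.
    import torch
    from torch._dynamo.testing import reduce_to_scalar_loss

    class ToyModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self._linear = torch.nn.Linear(8, 4)

        def forward(self, y: torch.Tensor) -> torch.Tensor:
            return torch.mean(self._linear(y), dim=1)

    model = ToyModel()
    ff = torch.randn(16, 8, requires_grad=True)
    # Keep a detached copy so the compiled run sees a fresh leaf tensor.
    compile_input_ff = ff.clone().detach().requires_grad_(True)

    # Eager reference: forward, reduce to a scalar loss, backward.
    eager_out = model(ff)
    reduce_to_scalar_loss(eager_out).backward()

    # Compiled path: same computation, output compared within a tolerance,
    # then the same scalar-loss backward ("inductor" assumed available).
    opt_fn = torch.compile(model, backend="inductor", fullgraph=True)
    compile_out = opt_fn(compile_input_ff)
    torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)
    reduce_to_scalar_loss(compile_out).backward()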