79 changes: 64 additions & 15 deletions torchrec/distributed/tests/test_pt2_multiprocess.py
@@ -19,6 +19,7 @@
import torchrec
import torchrec.pt2.checks
from hypothesis import given, settings, strategies as st, Verbosity
from torch._dynamo.testing import reduce_to_scalar_loss
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.fbgemm_qcomm_codec import QCommsConfig
@@ -56,6 +57,7 @@
from torchrec.pt2.utils import kjt_for_pt2_tracing
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor


try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
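
Note on the new import at the top of this diff: reduce_to_scalar_loss comes from torch._dynamo.testing and is what the updated test uses to turn a model output into a scalar it can call backward() on. A minimal sketch of its typical use, assuming nothing beyond stock PyTorch (the Linear model and random input below are placeholders, not part of the test):

import torch
from torch._dynamo.testing import reduce_to_scalar_loss

model = torch.nn.Linear(8, 4)
inputs = torch.randn(2, 8, requires_grad=True)

out = model(inputs)
# For a floating-point tensor this is roughly out.sum() / out.numel();
# lists, tuples, and dicts of tensors are reduced recursively.
loss = reduce_to_scalar_loss(out)
loss.backward()  # populates inputs.grad and the model's parameter grads
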
@@ -139,17 +141,22 @@ def sharding_types(self, compute_device_type: str) -> List[str]:


def _gen_model(test_model_type: _ModelType, mi: TestModelInfo) -> torch.nn.Module:
emb_dim: int = max(t.embedding_dim for t in mi.tables)
if test_model_type == _ModelType.EBC:

class M_ebc(torch.nn.Module):
def __init__(self, ebc: EmbeddingBagCollection) -> None:
super().__init__()
self._ebc = ebc
self._linear = torch.nn.Linear(
mi.num_float_features, emb_dim, device=mi.dense_device
)

def forward(self, x: KeyedJaggedTensor) -> torch.Tensor:
def forward(self, x: KeyedJaggedTensor, y: torch.Tensor) -> torch.Tensor:
kt: KeyedTensor = self._ebc(x)
v = kt.values()
return torch.sigmoid(torch.mean(v, dim=1))
y = self._linear(y)
return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1))

return M_ebc(
EmbeddingBagCollection(
@@ -164,10 +171,15 @@ class M_fpebc(torch.nn.Module):
def __init__(self, fpebc: FeatureProcessedEmbeddingBagCollection) -> None:
super().__init__()
self._fpebc = fpebc
self._linear = torch.nn.Linear(
mi.num_float_features, emb_dim, device=mi.dense_device
)

def forward(self, x: KeyedJaggedTensor) -> torch.Tensor:
def forward(self, x: KeyedJaggedTensor, y: torch.Tensor) -> torch.Tensor:
kt: KeyedTensor = self._fpebc(x)
return kt.values()
v = kt.values()
y = self._linear(y)
return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1))

return M_fpebc(
FeatureProcessedEmbeddingBagCollection(
@@ -187,9 +199,13 @@ def __init__(self, ec: EmbeddingCollection) -> None:
super().__init__()
self._ec = ec

def forward(self, x: KeyedJaggedTensor) -> List[JaggedTensor]:
def forward(
self, x: KeyedJaggedTensor, y: torch.Tensor
) -> List[JaggedTensor]:
d: Dict[str, JaggedTensor] = self._ec(x)
return list(d.values())
v = torch.stack(d.values(), dim=0).sum(dim=0)
y = self._linear(y)
return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1))

return M_ec(
EmbeddingCollection(
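
All three test modules (M_ebc, M_fpebc, M_ec) now accept a dense float tensor y next to the KeyedJaggedTensor and combine the pooled sparse output with a Linear projection of y, so both the sparse and dense paths contribute to the scalar loss used for backward. Below is a condensed, standalone analogue of that forward pattern; it substitutes a plain nn.EmbeddingBag for torchrec's sharded collections, and every name in it is illustrative rather than taken from the test:

import torch
from torch import nn

class ToySparseDense(nn.Module):
    # Illustrative stand-in for the M_ebc / M_fpebc / M_ec pattern in this diff.
    def __init__(self, num_embeddings: int, emb_dim: int, num_float_features: int) -> None:
        super().__init__()
        self._emb = nn.EmbeddingBag(num_embeddings, emb_dim, mode="sum")
        self._linear = nn.Linear(num_float_features, emb_dim)

    def forward(self, ids: torch.Tensor, offsets: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        v = self._emb(ids, offsets)   # pooled sparse features, shape [B, emb_dim]
        y = self._linear(y)           # dense projection, shape [B, emb_dim]
        return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1))  # shape [B]

m = ToySparseDense(num_embeddings=100, emb_dim=16, num_float_features=8)
ids = torch.tensor([1, 2, 4, 5, 4, 3])
offsets = torch.tensor([0, 2, 4])             # batch of 3 bags
dense = torch.randn(3, 8, requires_grad=True)
out = m(ids, offsets, dense)                  # shape [3]
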
@@ -307,6 +323,7 @@ def _test_compile_rank_fn(
# pyre-ignore
sharders=sharders,
device=device,
init_data_parallel=False,
)

if input_type == _InputType.VARIABLE_BATCH:
@@ -336,19 +353,27 @@ def _test_compile_rank_fn(
local_model_input = local_model_inputs[0].to(device)

kjt = local_model_input.idlist_features
ff = local_model_input.float_features
ff.requires_grad = True
kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb)

compile_input_ff = ff.clone().detach()

torchrec.distributed.comm_ops.set_use_sync_collectives(True)
torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True)

dmp.train(True)

eager_out = dmp(kjt_ft)
eager_out = dmp(kjt_ft, ff)

eager_loss = reduce_to_scalar_loss(eager_out)
eager_loss.backward()

if torch_compile_backend is None:
return

##### COMPILE #####
run_compile_backward: bool = torch_compile_backend in ["aot_eager", "inductor"]
with dynamo_skipfiles_allow("torchrec"):
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -357,8 +382,14 @@ def _test_compile_rank_fn(
backend=torch_compile_backend,
fullgraph=True,
)
compile_out = opt_fn(kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb))
torch.testing.assert_close(eager_out, compile_out)
compile_out = opt_fn(
kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb), compile_input_ff
)
torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)
if run_compile_backward:
loss = reduce_to_scalar_loss(compile_out)
loss.backward()

##### COMPILE END #####
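
The compile section above follows a parity-check pattern: run the model eagerly, backpropagate a reduced scalar loss, compile the same callable with torch.compile(fullgraph=True), compare outputs with explicit tolerances, and run backward through the compiled graph only for backends that go through AOT autograd (aot_eager, inductor). A minimal sketch of that pattern on a plain dense model (all names below are placeholders, not the distributed setup used in the test):

import torch
from torch._dynamo.testing import reduce_to_scalar_loss

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))
x = torch.randn(4, 8, requires_grad=True)

# Eager reference pass and backward.
eager_out = model(x)
reduce_to_scalar_loss(eager_out).backward()

# Compiled pass on a detached clone of the input, compared with tolerances.
compile_backend = "aot_eager"  # the test parametrizes this; None would skip compilation
opt_model = torch.compile(model, backend=compile_backend, fullgraph=True)
compile_x = x.clone().detach().requires_grad_(True)
compile_out = opt_model(compile_x)
torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)

# Backward through the compiled graph only for backends that support it.
if compile_backend in ["aot_eager", "inductor"]:
    reduce_to_scalar_loss(compile_out).backward()
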

##### NUMERIC CHECK #####
@@ -368,9 +399,20 @@ def _test_compile_rank_fn(
local_model_input = local_model_inputs[1 + i].to(device)
kjt = local_model_input.idlist_features
kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb)
eager_out_i = dmp(kjt_ft)
compile_out_i = opt_fn(kjt_ft)
torch.testing.assert_close(eager_out_i, compile_out_i)
ff = local_model_input.float_features
ff.requires_grad = True
eager_out_i = dmp(kjt_ft, ff)
eager_loss_i = reduce_to_scalar_loss(eager_out_i)
eager_loss_i.backward()

compile_input_ff = ff.detach().clone()
compile_out_i = opt_fn(kjt_ft, ff)
torch.testing.assert_close(
eager_out_i, compile_out_i, atol=1e-3, rtol=1e-3
)
if run_compile_backward:
loss_i = torch._dynamo.testing.reduce_to_scalar_loss(compile_out_i)
loss_i.backward()
##### NUMERIC CHECK END #####


@@ -396,14 +438,14 @@ def disable_cuda_tf32(self) -> bool:
ShardingType.TABLE_WISE.value,
_InputType.SINGLE_BATCH,
_ConvertToVariableBatch.TRUE,
"eager",
"inductor",
),
(
_ModelType.EBC,
ShardingType.COLUMN_WISE.value,
_InputType.SINGLE_BATCH,
_ConvertToVariableBatch.TRUE,
"eager",
"inductor",
),
(
_ModelType.EBC,
@@ -412,6 +454,13 @@ def disable_cuda_tf32(self) -> bool:
_ConvertToVariableBatch.FALSE,
"eager",
),
(
_ModelType.EBC,
ShardingType.COLUMN_WISE.value,
_InputType.SINGLE_BATCH,
_ConvertToVariableBatch.FALSE,
"eager",
),
]
),
)
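
On the parametrization above: the table-wise and column-wise variable-batch cases now run under the inductor backend instead of eager, a column-wise case without variable-batch conversion is added, and the backend slot is typed Optional[str] in the test signature that follows, with None meaning the run stops after the eager pass. A rough sketch of how such a tuple-based hypothesis parametrization is consumed, written from the visible signature only (the strategy name, tuple values, and test name here are illustrative, not from the file):

from typing import Optional, Tuple

from hypothesis import given, settings, strategies as st

@settings(deadline=None)
@given(
    case=st.sampled_from(
        [
            ("EBC", "table_wise", "single_batch", True, "inductor"),
            ("EBC", "column_wise", "single_batch", False, None),
        ]
    )
)
def test_compile_cases(case: Tuple[str, str, str, bool, Optional[str]]) -> None:
    model_type, sharding_type, input_type, tovb, compile_backend = case
    if compile_backend is None:
        pass  # eager-only run, mirroring the early return in _test_compile_rank_fn
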
@@ -424,7 +473,7 @@ def test_compile_multiprocess(
str,
_InputType,
_ConvertToVariableBatch,
str,
Optional[str],
],
) -> None:
model_type, sharding_type, input_type, tovb, compile_backend = (