diff --git a/torchrec/modules/regroup.py b/torchrec/modules/regroup.py
new file mode 100644
index 000000000..d9e34abb6
--- /dev/null
+++ b/torchrec/modules/regroup.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#!/usr/bin/env python3
+
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torchrec.sparse.jagged_tensor import (
+    _all_keys_used_once,
+    _desugar_keyed_tensors,
+    _remap_to_groups,
+    KeyedTensor,
+)
+
+
+@torch.fx.wrap
+def _concat_values(kts: List[KeyedTensor], dim: int) -> torch.Tensor:
+    return torch.cat([kt.values() for kt in kts], dim=dim)
+
+
+@torch.fx.wrap
+def _permuted_values(
+    kts: List[KeyedTensor], remap: List[Tuple[int, str]], dim: int
+) -> torch.Tensor:
+    embedding_dicts = [kt.to_dict() for kt in kts]
+    values = [embedding_dicts[idx][key] for (idx, key) in remap]
+    return torch.cat(values, dim=dim)
+
+
+@torch.fx.wrap
+def _build_dict(
+    keys: List[str], values: torch.Tensor, splits: List[int], dim: int
+) -> Dict[str, torch.Tensor]:
+    return {
+        key: tensor for key, tensor in zip(keys, torch.split(values, splits, dim=dim))
+    }
+
+
+class KTRegroupAsDict(torch.nn.Module):
+    """
+    KTRegroupAsDict is an nn.Module that mirrors the behavior of the static method KeyedTensor.regroup_as_dict().
+
+    The advantage of using this module is that it caches the regrouping logic after the first batch.
+
+    Args:
+        groups (List[List[str]]): features per output group
+        keys (List[str]): key of each output group
+
+    Example::
+
+        keys = ['object', 'user']
+        groups = [['f1', 'f2'], ['f3']]
+        regroup_module = KTRegroupAsDict(groups, keys)
+
+
+        tensor_list = [torch.randn(2, 4), torch.randn(2, 8), torch.randn(2, 2)]
+        kts = [KeyedTensor.from_tensor_list(['f1', 'f2', 'f3'], tensor_list)]
+        out = regroup_module(kts)
+
+    """
+
+    def __init__(self, groups: List[List[str]], keys: List[str]) -> None:
+        super().__init__()
+        torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}")
+        assert len(groups) == len(keys), "Groups and keys should have same length"
+        self._groups = groups
+        self._keys = keys
+        self._is_inited = False
+
+        # cached values populated on first forward call
+        self.device: Optional[torch.device] = None
+        self._dim: int = 1
+        self._use_fbgemm_regroup: bool = False
+        self._splits: List[int] = []
+        self._idx_key_pairs: List[Tuple[int, str]] = []
+        self._permute_tensor: Optional[torch.Tensor] = None
+        self._inv_permute_tensor: Optional[torch.Tensor] = None
+        self._offsets_tensor: Optional[torch.Tensor] = None
+        self._inv_offsets_tensor: Optional[torch.Tensor] = None
+
+    def _init_fbgemm_regroup(self, kts: List[KeyedTensor]) -> None:
+        self._use_fbgemm_regroup = True
+        keys, lengths, values = _desugar_keyed_tensors(kts)
+        permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups(
+            keys, lengths, self._groups
+        )
+        # no need to pin_memory() or to(..., non_blocking=True) since this occurs only once
+        self._permute_tensor = permute.to(self.device)
+        self._inv_permute_tensor = inv_permute.to(self.device)
+        self._offsets_tensor = offsets.to(self.device)
+        self._inv_offsets_tensor = inv_offsets.to(self.device)
+        self._splits = splits
+
+    def _init_regroup(self, kts: List[KeyedTensor]) -> None:
+        lengths = [kt.length_per_key() for kt in kts]
+        indices = [kt._key_indices() for kt in kts]
+
+        key_to_idx: Dict[str, int] = {}
+        for i, kt in enumerate(kts):
+            for key in kt.keys():
+                if key in key_to_idx:
+                    raise RuntimeError(
+                        f"Duplicate key {key} found in KeyedTensors, undefined behavior"
+                    )
+                key_to_idx[key] = i
+
+        splits: List[int] = []
+        idx_key_pairs: List[Tuple[int, str]] = []
+        for group in self._groups:
+            group_length = 0
+            for name in group:
+                idx_key_pairs.append((key_to_idx[name], name))
+                group_length += lengths[key_to_idx[name]][
+                    indices[key_to_idx[name]][name]
+                ]
+            splits.append(group_length)
+
+        self._splits = splits
+        self._idx_key_pairs = idx_key_pairs
+
+    def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
+        if not self._is_inited:
+            assert len(keyed_tensors) > 0, "Empty list provided"
+            assert all(
+                kt.device == keyed_tensors[0].device for kt in keyed_tensors
+            ), "All inputs should be on the same device."
+            self.device = keyed_tensors[0].device
+            assert all(
+                kt.key_dim() == keyed_tensors[0].key_dim() for kt in keyed_tensors
+            ), "All inputs should have the same key_dim"
+            self._dim = keyed_tensors[0].key_dim()
+
+            if _all_keys_used_once(keyed_tensors, self._groups) and self._dim == 1:
+                self._init_fbgemm_regroup(keyed_tensors)
+            else:
+                self._init_regroup(keyed_tensors)
+            self._is_inited = True
+
+        if self._use_fbgemm_regroup:
+            values = _concat_values(keyed_tensors, self._dim)
+            permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
+                values,
+                self._offsets_tensor,
+                self._permute_tensor,
+                self._inv_offsets_tensor,
+                self._inv_permute_tensor,
+            )
+        else:
+            permuted_values = _permuted_values(
+                keyed_tensors, self._idx_key_pairs, self._dim
+            )
+
+        return _build_dict(self._keys, permuted_values, self._splits, self._dim)
diff --git a/torchrec/modules/tests/test_regroup.py b/torchrec/modules/tests/test_regroup.py
new file mode 100644
index 000000000..4f00b99c1
--- /dev/null
+++ b/torchrec/modules/tests/test_regroup.py
@@ -0,0 +1,134 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+
+import torch
+import torch.fx
+
+from torchrec.modules.regroup import KTRegroupAsDict
+from torchrec.sparse.jagged_tensor import _all_keys_used_once, KeyedTensor
+from torchrec.sparse.tests.utils import build_groups, build_kts
+
+
+class KTRegroupAsDictTest(unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.kts = build_kts(
+            dense_features=20,
+            sparse_features=20,
+            dim_dense=64,
+            dim_sparse=128,
+            batch_size=128,
+            device=torch.device("cpu"),
+            run_backward=True,
+        )
+        self.num_groups = 2
+        self.keys = ["user", "object"]
+        self.labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float()
+
+    def test_regroup_backward_skips_and_duplicates(self) -> None:
+        groups = build_groups(
+            kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True
+        )
+        assert _all_keys_used_once(self.kts, groups) is False
+
+        regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
+        tensor_groups = regroup_module(self.kts)
+        pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred0, self.labels).sum()
+        actual_kt_0_grad, actual_kt_1_grad = torch.autograd.grad(
+            loss, [self.kts[0].values(), self.kts[1].values()]
+        )
+
+        # clear grads so we can reuse the inputs
+        self.kts[0].values().grad = None
+        self.kts[1].values().grad = None
+
+        tensor_groups = KeyedTensor.regroup_as_dict(
+            keyed_tensors=self.kts, groups=groups, keys=self.keys
+        )
+        pred1 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred1, self.labels).sum()
+        expected_kt_0_grad, expected_kt_1_grad = torch.autograd.grad(
+            loss, [self.kts[0].values(), self.kts[1].values()]
+        )
+
+        self.assertTrue(torch.allclose(pred0, pred1))
+        self.assertTrue(torch.allclose(actual_kt_0_grad, expected_kt_0_grad))
+        self.assertTrue(torch.allclose(actual_kt_1_grad, expected_kt_1_grad))
+
+    def test_regroup_backward(self) -> None:
+        groups = build_groups(
+            kts=self.kts, num_groups=self.num_groups, skips=False, duplicates=False
+        )
+        assert _all_keys_used_once(self.kts, groups) is True
+
+        regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
+        tensor_groups = regroup_module(self.kts)
+        pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred0, self.labels).sum()
+        actual_kt_0_grad, actual_kt_1_grad = torch.autograd.grad(
+            loss, [self.kts[0].values(), self.kts[1].values()]
+        )
+
+        # clear grads so we can reuse the inputs
+        self.kts[0].values().grad = None
+        self.kts[1].values().grad = None
+
+        tensor_groups = KeyedTensor.regroup_as_dict(
+            keyed_tensors=self.kts, groups=groups, keys=self.keys
+        )
+        pred1 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred1, self.labels).sum()
+        expected_kt_0_grad, expected_kt_1_grad = torch.autograd.grad(
+            loss, [self.kts[0].values(), self.kts[1].values()]
+        )
+
+        self.assertTrue(torch.allclose(pred0, pred1))
+        self.assertTrue(torch.allclose(actual_kt_0_grad, expected_kt_0_grad))
+        self.assertTrue(torch.allclose(actual_kt_1_grad, expected_kt_1_grad))
+
+    def test_fx_and_jit_regroup(self) -> None:
+        groups = build_groups(
+            kts=self.kts, num_groups=self.num_groups, skips=False, duplicates=False
+        )
+        assert _all_keys_used_once(self.kts, groups) is True
+
+        regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
+        # first pass caches the regrouping plan
+        regroup_module(self.kts)
+
+        # now trace
+        gm = torch.fx.symbolic_trace(regroup_module)
+        jit_gm = torch.jit.script(gm)
+
+        out = jit_gm(self.kts)
+        eager_out = regroup_module(self.kts)
+        for key in out.keys():
+            self.assertTrue(torch.allclose(out[key], eager_out[key]))
+
+    def test_fx_and_jit_regroup_skips_and_duplicates(self) -> None:
+        groups = build_groups(
+            kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True
+        )
+        assert _all_keys_used_once(self.kts, groups) is False
+
+        regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys)
+        # first pass caches the regrouping plan
+        regroup_module(self.kts)
+
+        # now trace
+        gm = torch.fx.symbolic_trace(regroup_module)
+        jit_gm = torch.jit.script(gm)
+
+        out = jit_gm(self.kts)
+        eager_out = regroup_module(self.kts)
+        for key in out.keys():
+            self.assertTrue(torch.allclose(out[key], eager_out[key]))
diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 3db2a61be..a326a6962 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -650,6 +650,10 @@ def to_padded_dense_weights(
             self.weights(), [self.offsets()], [N], padding_value
         )
 
+    @property
+    def device(self) -> torch.device:
+        return self._values.device
+
     def lengths(self) -> torch.Tensor:
         _lengths = _maybe_compute_lengths(self._lengths, self._offsets)
         self._lengths = _lengths
@@ -2570,6 +2574,10 @@ def values(self) -> torch.Tensor:
     def key_dim(self) -> int:
         return self._key_dim
 
+    @property
+    def device(self) -> torch.device:
+        return self._values.device
+
     def offset_per_key(self) -> List[int]:
         _offset_per_key = _maybe_compute_offset_per_key_kt(
             self._length_per_key,
diff --git a/torchrec/sparse/tests/jagged_tensor_benchmark.py b/torchrec/sparse/tests/jagged_tensor_benchmark.py
index ae6368f0c..1745910ea 100644
--- a/torchrec/sparse/tests/jagged_tensor_benchmark.py
+++ b/torchrec/sparse/tests/jagged_tensor_benchmark.py
@@ -16,6 +16,7 @@
 import torch
 
 from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult
+from torchrec.modules.regroup import KTRegroupAsDict
 from torchrec.sparse.jagged_tensor import (
     _regroup_keyed_tensors,
     KeyedJaggedTensor,
@@ -53,7 +54,10 @@ def wrapped_func(
 ) -> None:
     result = fn(**fn_kwargs)
     if run_backward:
-        vectors = [tensor.sum(dim=1) for tensor in result]
+        if isinstance(result, dict):
+            vectors = [tensor.sum(dim=1) for tensor in result.values()]
+        else:
+            vectors = [tensor.sum(dim=1) for tensor in result]
         pred = vectors[0]
         for vector in vectors[1:]:
             pred.mul(vector)
@@ -216,6 +220,18 @@ def main(
         KeyedTensor.regroup,
         {"keyed_tensors": kts, "groups": groups},
     )
+    bench(
+        "[prod] KTRegroupAsDict",
+        labels,
+        batch_size,
+        n_dense + n_sparse,
+        device_type,
+        run_backward,
+        KTRegroupAsDict(
+            groups=groups, keys=[str(i) for i in range(n_groups)]
+        ),
+        {"keyed_tensors": kts},
+    )
 
 
 if __name__ == "__main__":
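
Note: the snippet below is a rough usage sketch and is not part of the patch above; the feature names, dimensions, and batch size are invented for illustration. It shows the intended calling pattern for KTRegroupAsDict: construct the module once, let the first forward call inspect the inputs and cache either the fbgemm permute plan (every key used exactly once and key_dim == 1) or the index/split fallback, then reuse the module for later batches with the same key layout.

import torch
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedTensor

# Two pooled-embedding KeyedTensors with hypothetical feature names and dims.
kt_a = KeyedTensor.from_tensor_list(["f1", "f2"], [torch.randn(8, 16), torch.randn(8, 32)])
kt_b = KeyedTensor.from_tensor_list(["f3"], [torch.randn(8, 64)])

# Construct once; groups/keys mirror the module docstring.
regroup = KTRegroupAsDict(groups=[["f1", "f3"], ["f2"]], keys=["user", "object"])

# First call initializes and caches the regrouping plan (the fbgemm path here,
# since every key is used exactly once and key_dim == 1).
out = regroup([kt_a, kt_b])
assert set(out.keys()) == {"user", "object"}
assert out["user"].shape == (8, 16 + 64)  # f1 and f3 concatenated along dim 1

# Later batches must present the same keys and per-key dims; the cached plan is reused.
out = regroup([kt_a, kt_b])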