From e314653b8bf3fabcf1f4f0bedcb6812debffe44d Mon Sep 17 00:00:00 2001 From: Dennis van der Staay Date: Tue, 14 May 2024 21:41:46 -0700 Subject: [PATCH] RegroupAsDict module Summary: Currently, we have KT.regroup as a functional call. Issue with this two fold: (1) we don't caching values we effectively know after first batch, leading to marginally higher cpu computation (2) this values look like unbacked SymInt in PT2 IR and most graph captures. Reality is they are known. So while a user change, we are adding a new module, to leverage these above insights. Benchmark (fwd+backward) [fallback] _regroup_keyed_tenors | B: 512 | F: 80 | device: cuda | Runtime (P90): 3.3 ms | Memory (P90): 72.0 [prod] KeyedTensor.regroup | B: 512 | F: 80 | device: cuda | Runtime (P90): 2.8 ms | Memory (P90): 72.0 [prod] KTRegroupAsDict | B: 512 | F: 80 | device: cuda | Runtime (P90): 2.3 ms | Memory (P90): 72.0 [fallback] _regroup_keyed_tenors | B: 512 | F: 160 | device: cuda | Runtime (P90): 7.7 ms | Memory (P90): 144.0 [prod] KeyedTensor.regroup | B: 512 | F: 160 | device: cuda | Runtime (P90): 4.6 ms | Memory (P90): 144.0 [prod] KTRegroupAsDict | B: 512 | F: 160 | device: cuda | Runtime (P90): 3.9 ms | Memory (P90): 144.0 [fallback] _regroup_keyed_tenors | B: 512 | F: 320 | device: cuda | Runtime (P90): 10.8 ms | Memory (P90): 288.0 [prod] KeyedTensor.regroup | B: 512 | F: 320 | device: cuda | Runtime (P90): 7.5 ms | Memory (P90): 288.0 [prod] KTRegroupAsDict | B: 512 | F: 320 | device: cuda | Runtime (P90): 9.9 ms | Memory (P90): 288.0 [fallback] _regroup_keyed_tenors | B: 512 | F: 640 | device: cuda | Runtime (P90): 22.7 ms | Memory (P90): 576.0 [prod] KeyedTensor.regroup | B: 512 | F: 640 | device: cuda | Runtime (P90): 13.8 ms | Memory (P90): 576.0 [prod] KTRegroupAsDict | B: 512 | F: 640 | device: cuda | Runtime (P90): 18.6 ms | Memory (P90): 576.0 [fallback] _regroup_keyed_tenors | B: 512 | F: 1280 | device: cuda | Runtime (P90): 58.0 ms | Memory (P90): 1152.0 [prod] KeyedTensor.regroup | B: 512 | F: 1280 | device: cuda | Runtime (P90): 27.9 ms | Memory (P90): 1152.0 [prod] KTRegroupAsDict | B: 512 | F: 1280 | device: cuda | Runtime (P90): 25.7 ms | Memory (P90): 1152.0 [fallback] _regroup_keyed_tenors | B: 1024 | F: 80 | device: cuda | Runtime (P90): 3.3 ms | Memory (P90): 144.0 [prod] KeyedTensor.regroup | B: 1024 | F: 80 | device: cuda | Runtime (P90): 3.3 ms | Memory (P90): 144.0 [prod] KTRegroupAsDict | B: 1024 | F: 80 | device: cuda | Runtime (P90): 3.3 ms | Memory (P90): 144.0 [fallback] _regroup_keyed_tenors | B: 1024 | F: 160 | device: cuda | Runtime (P90): 6.6 ms | Memory (P90): 288.0 [prod] KeyedTensor.regroup | B: 1024 | F: 160 | device: cuda | Runtime (P90): 6.4 ms | Memory (P90): 288.0 [prod] KTRegroupAsDict | B: 1024 | F: 160 | device: cuda | Runtime (P90): 4.1 ms | Memory (P90): 288.0 [fallback] _regroup_keyed_tenors | B: 1024 | F: 320 | device: cuda | Runtime (P90): 15.0 ms | Memory (P90): 576.0 [prod] KeyedTensor.regroup | B: 1024 | F: 320 | device: cuda | Runtime (P90): 8.0 ms | Memory (P90): 576.0 [prod] KTRegroupAsDict | B: 1024 | F: 320 | device: cuda | Runtime (P90): 8.0 ms | Memory (P90): 576.0 [fallback] _regroup_keyed_tenors | B: 1024 | F: 640 | device: cuda | Runtime (P90): 23.6 ms | Memory (P90): 1152.0 [prod] KeyedTensor.regroup | B: 1024 | F: 640 | device: cuda | Runtime (P90): 19.3 ms | Memory (P90): 1152.0 [prod] KTRegroupAsDict | B: 1024 | F: 640 | device: cuda | Runtime (P90): 13.6 ms | Memory (P90): 1152.0 [fallback] _regroup_keyed_tenors | B: 1024 | F: 1280 | device: cuda | Runtime (P90): 55.7 ms | Memory (P90): 2304.0 [prod] KeyedTensor.regroup | B: 1024 | F: 1280 | device: cuda | Runtime (P90): 28.4 ms | Memory (P90): 2304.0 [prod] KTRegroupAsDict | B: 1024 | F: 1280 | device: cuda | Runtime (P90): 26.8 ms | Memory (P90): 2304.0 [fallback] _regroup_keyed_tenors | B: 2048 | F: 80 | device: cuda | Runtime (P90): 3.6 ms | Memory (P90): 288.0 [prod] KeyedTensor.regroup | B: 2048 | F: 80 | device: cuda | Runtime (P90): 3.5 ms | Memory (P90): 288.0 [prod] KTRegroupAsDict | B: 2048 | F: 80 | device: cuda | Runtime (P90): 3.6 ms | Memory (P90): 288.0 [fallback] _regroup_keyed_tenors | B: 2048 | F: 160 | device: cuda | Runtime (P90): 7.0 ms | Memory (P90): 576.0 [prod] KeyedTensor.regroup | B: 2048 | F: 160 | device: cuda | Runtime (P90): 6.4 ms | Memory (P90): 576.0 [prod] KTRegroupAsDict | B: 2048 | F: 160 | device: cuda | Runtime (P90): 4.6 ms | Memory (P90): 576.0 [fallback] _regroup_keyed_tenors | B: 2048 | F: 320 | device: cuda | Runtime (P90): 11.2 ms | Memory (P90): 1152.0 [prod] KeyedTensor.regroup | B: 2048 | F: 320 | device: cuda | Runtime (P90): 8.2 ms | Memory (P90): 1152.0 [prod] KTRegroupAsDict | B: 2048 | F: 320 | device: cuda | Runtime (P90): 8.8 ms | Memory (P90): 1152.0 [fallback] _regroup_keyed_tenors | B: 2048 | F: 640 | device: cuda | Runtime (P90): 23.9 ms | Memory (P90): 2304.0 [prod] KeyedTensor.regroup | B: 2048 | F: 640 | device: cuda | Runtime (P90): 20.6 ms | Memory (P90): 2304.0 [prod] KTRegroupAsDict | B: 2048 | F: 640 | device: cuda | Runtime (P90): 14.6 ms | Memory (P90): 2304.0 [fallback] _regroup_keyed_tenors | B: 2048 | F: 1280 | device: cuda | Runtime (P90): 54.5 ms | Memory (P90): 4608.0 [prod] KeyedTensor.regroup | B: 2048 | F: 1280 | device: cuda | Runtime (P90): 28.3 ms | Memory (P90): 4608.0 [prod] KTRegroupAsDict | B: 2048 | F: 1280 | device: cuda | Runtime (P90): 25.7 ms | Memory (P90): 4608.0 [fallback] _regroup_keyed_tenors | B: 4096 | F: 80 | device: cuda | Runtime (P90): 3.3 ms | Memory (P90): 576.0 [prod] KeyedTensor.regroup | B: 4096 | F: 80 | device: cuda | Runtime (P90): 2.7 ms | Memory (P90): 576.0 [prod] KTRegroupAsDict | B: 4096 | F: 80 | device: cuda | Runtime (P90): 2.3 ms | Memory (P90): 576.0 [fallback] _regroup_keyed_tenors | B: 4096 | F: 160 | device: cuda | Runtime (P90): 5.8 ms | Memory (P90): 1152.0 [prod] KeyedTensor.regroup | B: 4096 | F: 160 | device: cuda | Runtime (P90): 4.4 ms | Memory (P90): 1152.0 [prod] KTRegroupAsDict | B: 4096 | F: 160 | device: cuda | Runtime (P90): 3.9 ms | Memory (P90): 1152.0 [fallback] _regroup_keyed_tenors | B: 4096 | F: 320 | device: cuda | Runtime (P90): 11.1 ms | Memory (P90): 2304.0 [prod] KeyedTensor.regroup | B: 4096 | F: 320 | device: cuda | Runtime (P90): 7.8 ms | Memory (P90): 2304.0 [prod] KTRegroupAsDict | B: 4096 | F: 320 | device: cuda | Runtime (P90): 7.0 ms | Memory (P90): 2304.0 [fallback] _regroup_keyed_tenors | B: 4096 | F: 640 | device: cuda | Runtime (P90): 23.9 ms | Memory (P90): 4608.0 [prod] KeyedTensor.regroup | B: 4096 | F: 640 | device: cuda | Runtime (P90): 14.5 ms | Memory (P90): 4608.0 [prod] KTRegroupAsDict | B: 4096 | F: 640 | device: cuda | Runtime (P90): 13.3 ms | Memory (P90): 4608.0 [fallback] _regroup_keyed_tenors | B: 4096 | F: 1280 | device: cuda | Runtime (P90): 64.0 ms | Memory (P90): 9216.0 [prod] KeyedTensor.regroup | B: 4096 | F: 1280 | device: cuda | Runtime (P90): 26.9 ms | Memory (P90): 9216.0 [prod] KTRegroupAsDict | B: 4096 | F: 1280 | device: cuda | Runtime (P90): 25.1 ms | Memory (P90): 9216.0 Reviewed By: PaulZhang12 Differential Revision: D57312926 --- torchrec/modules/regroup.py | 158 ++++++++++++++++++ torchrec/modules/tests/test_regroup.py | 134 +++++++++++++++ torchrec/sparse/jagged_tensor.py | 8 + .../sparse/tests/jagged_tensor_benchmark.py | 18 +- 4 files changed, 317 insertions(+), 1 deletion(-) create mode 100644 torchrec/modules/regroup.py create mode 100644 torchrec/modules/tests/test_regroup.py diff --git a/torchrec/modules/regroup.py b/torchrec/modules/regroup.py new file mode 100644 index 000000000..d9e34abb6 --- /dev/null +++ b/torchrec/modules/regroup.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +#!/usr/bin/env python3 + +from typing import Dict, List, Optional, Tuple + +import torch +from torchrec.sparse.jagged_tensor import ( + _all_keys_used_once, + _desugar_keyed_tensors, + _remap_to_groups, + KeyedTensor, +) + + +@torch.fx.wrap +def _concat_values(kts: List[KeyedTensor], dim: int) -> torch.Tensor: + return torch.cat([kt.values() for kt in kts], dim=dim) + + +@torch.fx.wrap +def _permuted_values( + kts: List[KeyedTensor], remap: List[Tuple[int, str]], dim: int +) -> torch.Tensor: + embedding_dicts = [kt.to_dict() for kt in kts] + values = [embedding_dicts[idx][key] for (idx, key) in remap] + return torch.cat(values, dim=dim) + + +@torch.fx.wrap +def _build_dict( + keys: List[str], values: torch.Tensor, splits: List[int], dim: int +) -> Dict[str, torch.Tensor]: + return { + key: tensor for key, tensor in zip(keys, torch.split(values, splits, dim=dim)) + } + + +class KTRegroupAsDict(torch.nn.Module): + """ + KTRegroupAsDict is a nn.Module that mirrors beahvior of static method KeyedTensor.regroup_as_dict() + + The advantage of using this module it caches the regrouping logic after first batch. + + Args: + groups (List[List[str]]): features per output group + keys (List[str]): key of each output group + + Example:: + + keys = ['object', 'user'] + groups = [['f1', 'f2'], ['f3']] + regroup_module = KTRegroupAsDict(groups, keys) + + + tensor_list = [torch.randn(2, 4), torch.randn(2, 8), torch.randn(2, 2)] + kts = [KeyedTensor.from_tensor_list(['f1', 'f2', 'f3' ], tensor_list)] + out = regroup_module(kts) + + """ + + def __init__(self, groups: List[List[str]], keys: List[str]) -> None: + super().__init__() + torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}") + assert len(groups) == len(keys), "Groups and keys should have same length" + self._groups = groups + self._keys = keys + self._is_inited = False + + # cached values populated on first forward call + self.device: Optional[torch.device] = None + self._concat_dim: int = 1 + self._use_fbgemm_regroup: bool = False + self._splits: List[int] = [] + self._idx_key_pairs: List[Tuple[int, str]] = [] + self._permute_tensor: Optional[torch.Tensor] = None + self._inv_permute_tensor: Optional[torch.Tensor] = None + self._offsets_tensor: Optional[torch.Tensor] = None + self._inv_offsets_tensor: Optional[torch.Tensor] = None + + def _init_fbgemm_regroup(self, kts: List[KeyedTensor]) -> None: + self._use_fbgemm_regroup = True + keys, lengths, values = _desugar_keyed_tensors(kts) + permute, inv_permute, offsets, inv_offsets, splits = _remap_to_groups( + keys, lengths, self._groups + ) + # no need to pin_memory() or to(..., non_blocking=True) since occurs only once + self._permute_tensor = permute.to(self.device) + self._inv_permute_tensor = inv_permute.to(self.device) + self._offsets_tensor = offsets.to(self.device) + self._inv_offsets_tensor = inv_offsets.to(self.device) + self._splits = splits + + def _init_regroup(self, kts: List[KeyedTensor]) -> None: + lengths = [kt.length_per_key() for kt in kts] + indices = [kt._key_indices() for kt in kts] + + key_to_idx: dict[str, int] = {} + for i, kt in enumerate(kts): + for key in kt.keys(): + if key in key_to_idx: + raise RuntimeError( + f"Duplicate key {key} found in KeyedTensors, undefined behavior" + ) + key_to_idx[key] = i + + splits: List[int] = [] + idx_key_pairs: List[Tuple[int, str]] = [] + for group in self._groups: + group_length = 0 + for name in group: + idx_key_pairs.append((key_to_idx[name], name)) + group_length += lengths[key_to_idx[name]][ + indices[key_to_idx[name]][name] + ] + splits.append(group_length) + + self._splits = splits + self._idx_key_pairs = idx_key_pairs + + def forward(self, keyed_tensors: List[KeyedTensor]) -> Dict[str, torch.Tensor]: + if not self._is_inited: + assert len(keyed_tensors) > 0, "Empty list provided" + assert all( + kt.device == keyed_tensors[0].device for kt in keyed_tensors + ), "All inputs should be on the same device." + self.device = keyed_tensors[0].device + assert all( + kt.key_dim() == keyed_tensors[0].key_dim() for kt in keyed_tensors + ), "All inputs should have the same key_dim" + self._dim = keyed_tensors[0].key_dim() + + if _all_keys_used_once(keyed_tensors, self._groups) and self._dim == 1: + self._init_fbgemm_regroup(keyed_tensors) + else: + self._init_regroup(keyed_tensors) + self._is_inited = True + + if self._use_fbgemm_regroup: + values = _concat_values(keyed_tensors, self._dim) + permuted_values = torch.ops.fbgemm.permute_pooled_embs_auto_grad( + values, + self._offsets_tensor, + self._permute_tensor, + self._inv_offsets_tensor, + self._inv_permute_tensor, + ) + else: + permuted_values = _permuted_values( + keyed_tensors, self._idx_key_pairs, self._dim + ) + + return _build_dict(self._keys, permuted_values, self._splits, self._dim) diff --git a/torchrec/modules/tests/test_regroup.py b/torchrec/modules/tests/test_regroup.py new file mode 100644 index 000000000..4f00b99c1 --- /dev/null +++ b/torchrec/modules/tests/test_regroup.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +import torch +import torch.fx + +from torchrec.modules.regroup import KTRegroupAsDict +from torchrec.sparse.jagged_tensor import _all_keys_used_once, KeyedTensor +from torchrec.sparse.tests.utils import build_groups, build_kts + + +class KTRegroupAsDictTest(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + self.kts = build_kts( + dense_features=20, + sparse_features=20, + dim_dense=64, + dim_sparse=128, + batch_size=128, + device=torch.device("cpu"), + run_backward=True, + ) + self.num_groups = 2 + self.keys = ["user", "object"] + self.labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float() + + def test_regroup_backward_skips_and_duplicates(self) -> None: + groups = build_groups( + kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True + ) + assert _all_keys_used_once(self.kts, groups) is False + + regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys) + tensor_groups = regroup_module(self.kts) + pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred0, self.labels).sum() + actual_kt_0_grad, actual_kt_1_grad = torch.autograd.grad( + loss, [self.kts[0].values(), self.kts[1].values()] + ) + + # clear grads so can reuse inputs + self.kts[0].values().grad = None + self.kts[1].values().grad = None + + tensor_groups = KeyedTensor.regroup_as_dict( + keyed_tensors=self.kts, groups=groups, keys=self.keys + ) + pred1 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred1, self.labels).sum() + expected_kt_0_grad, expected_kt_1_grad = torch.autograd.grad( + loss, [self.kts[0].values(), self.kts[1].values()] + ) + + torch.allclose(pred0, pred1) + torch.allclose(actual_kt_0_grad, expected_kt_0_grad) + torch.allclose(actual_kt_1_grad, expected_kt_1_grad) + + def test_regroup_backward(self) -> None: + groups = build_groups( + kts=self.kts, num_groups=self.num_groups, skips=False, duplicates=False + ) + assert _all_keys_used_once(self.kts, groups) is True + + regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys) + tensor_groups = regroup_module(self.kts) + pred0 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred0, self.labels).sum() + actual_kt_0_grad, actual_kt_1_grad = torch.autograd.grad( + loss, [self.kts[0].values(), self.kts[1].values()] + ) + + # clear grads so can reuse inputs + self.kts[0].values().grad = None + self.kts[1].values().grad = None + + tensor_groups = KeyedTensor.regroup_as_dict( + keyed_tensors=self.kts, groups=groups, keys=self.keys + ) + pred1 = tensor_groups["user"].sum(dim=1).mul(tensor_groups["object"].sum(dim=1)) + loss = torch.nn.functional.l1_loss(pred1, self.labels).sum() + expected_kt_0_grad, expected_kt_1_grad = torch.autograd.grad( + loss, [self.kts[0].values(), self.kts[1].values()] + ) + + torch.allclose(pred0, pred1) + torch.allclose(actual_kt_0_grad, expected_kt_0_grad) + torch.allclose(actual_kt_1_grad, expected_kt_1_grad) + + def test_fx_and_jit_regroup(self) -> None: + groups = build_groups( + kts=self.kts, num_groups=self.num_groups, skips=False, duplicates=False + ) + assert _all_keys_used_once(self.kts, groups) is True + + regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys) + # first pass + regroup_module(self.kts) + + # now trace + gm = torch.fx.symbolic_trace(regroup_module) + jit_gm = torch.jit.script(gm) + + out = jit_gm(self.kts) + eager_out = regroup_module(self.kts) + for key in out.keys(): + torch.allclose(out[key], eager_out[key]) + + def test_fx_and_jit_regroup_skips_and_duplicates(self) -> None: + groups = build_groups( + kts=self.kts, num_groups=self.num_groups, skips=True, duplicates=True + ) + assert _all_keys_used_once(self.kts, groups) is False + + regroup_module = KTRegroupAsDict(groups=groups, keys=self.keys) + # first pass + regroup_module(self.kts) + + # now trace + gm = torch.fx.symbolic_trace(regroup_module) + jit_gm = torch.jit.script(gm) + + out = jit_gm(self.kts) + eager_out = regroup_module(self.kts) + for key in out.keys(): + torch.allclose(out[key], eager_out[key]) diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 3db2a61be..a326a6962 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -650,6 +650,10 @@ def to_padded_dense_weights( self.weights(), [self.offsets()], [N], padding_value ) + @property + def device(self) -> torch.device: + return self._values.device + def lengths(self) -> torch.Tensor: _lengths = _maybe_compute_lengths(self._lengths, self._offsets) self._lengths = _lengths @@ -2570,6 +2574,10 @@ def values(self) -> torch.Tensor: def key_dim(self) -> int: return self._key_dim + @property + def device(self) -> torch.device: + return self._values.device + def offset_per_key(self) -> List[int]: _offset_per_key = _maybe_compute_offset_per_key_kt( self._length_per_key, diff --git a/torchrec/sparse/tests/jagged_tensor_benchmark.py b/torchrec/sparse/tests/jagged_tensor_benchmark.py index ae6368f0c..1745910ea 100644 --- a/torchrec/sparse/tests/jagged_tensor_benchmark.py +++ b/torchrec/sparse/tests/jagged_tensor_benchmark.py @@ -16,6 +16,7 @@ import torch from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult +from torchrec.modules.regroup import KTRegroupAsDict from torchrec.sparse.jagged_tensor import ( _regroup_keyed_tensors, KeyedJaggedTensor, @@ -53,7 +54,10 @@ def wrapped_func( ) -> None: result = fn(**fn_kwargs) if run_backward: - vectors = [tensor.sum(dim=1) for tensor in result] + if isinstance(result, dict): + vectors = [tensor.sum(dim=1) for tensor in result.values()] + else: + vectors = [tensor.sum(dim=1) for tensor in result] pred = vectors[0] for vector in vectors[1:]: pred.mul(vector) @@ -216,6 +220,18 @@ def main( KeyedTensor.regroup, {"keyed_tensors": kts, "groups": groups}, ) + bench( + "[prod] KTRegroupAsDict", + labels, + batch_size, + n_dense + n_sparse, + device_type, + run_backward, + KTRegroupAsDict( + groups=groups, keys=[str(i) for i in range(n_groups)] + ), + {"keyed_tensors": kts}, + ) if __name__ == "__main__":