diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py
index 1236a8a13..cb2f99828 100644
--- a/torchrec/distributed/benchmark/benchmark_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_utils.py
@@ -492,9 +492,9 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
 def benchmark(
     name: str,
     model: torch.nn.Module,
-    warmup_inputs: List[KeyedJaggedTensor],
-    bench_inputs: List[KeyedJaggedTensor],
-    prof_inputs: List[KeyedJaggedTensor],
+    warmup_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
+    bench_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
+    prof_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
     world_size: int,
     output_dir: str,
     num_benchmarks: int,
diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index f703fe8ec..c25ef7ffa 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -36,6 +36,12 @@
     torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
     )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
+    )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu"
+    )
 except OSError:
     pass
 
@@ -164,6 +170,21 @@ def _all_keys_used_once(
     return len(key_set) == len(group_set) == len(flat_keys) == len(flat_groups)
 
 
+@torch.fx.wrap
+def permute_multi_embedding(
+    keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
+) -> List[torch.Tensor]:
+    keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
+    permutes, in_lengths, out_lengths = _multi_remap_to_groups(keys, lengths, groups)
+    permuted_values = torch.ops.fbgemm.permute_multi_embedding(
+        values,
+        permutes,
+        in_lengths,
+        out_lengths,
+    )
+    return permuted_values
+
+
 @torch.fx.wrap
 def _fbgemm_permute_pooled_embs(
     keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
@@ -240,6 +261,82 @@ def _remap_to_groups(
     return permute, inv_permute, offsets, inv_offsets, splits
 
 
+def _multi_remap_to_groups(
+    keys: List[List[str]],
+    key_lengths: List[List[int]],
+    groups: List[List[str]],
+) -> Tuple[List[int], List[int], List[int]]:
+    """
+    Given the keys and per-key lengths of each input tensor and the desired output groups,
+    returns the flattened permutes (6 ints per entry:
+    [input_tensor_idx, output_tensor_idx, input_start, output_start, length, jump]),
+    the total length of each input tensor, and the total length of each output group.
+    """
+    # key => (tensor_idx, key_idx)
+    key_map: Dict[str, Tuple[int, int]] = {
+        key: (tensor_idx, key_idx)
+        for tensor_idx, tensor in enumerate(keys)
+        for key_idx, key in enumerate(tensor)
+    }
+
+    # [offsets per tensor]
+    in_offsets: List[List[int]] = [[] for _ in key_lengths]
+    for i, tensor in enumerate(key_lengths):
+        in_offsets[i] = _cumsum(tensor)
+    in_lengths: List[int] = [sum(lengths) for lengths in key_lengths]
+
+    # set total_permutes as the jump stop sign
+    total_permutes: int = sum(len(tensor) for tensor in groups)
+    out_lengths: List[int] = [0] * len(groups)
+
+    # [input_tensor_idx, output_tensor_idx, input_start, output_start, length, jump]
+    permute_param = 6
+    permutes: List[int] = [0] * (total_permutes * permute_param)
+
+    # record the last seen index, so that we can make the jump from last_seen to current
+    last_seen: Dict[str, int] = {}
+    permute_idx = 0
+    for output_tensor_idx, output_tensor in enumerate(groups):
+        output_start = 0
+        for output_key in output_tensor:
+            input_tensor_idx, input_key_idx = key_map[output_key]
+            input_start = in_offsets[input_tensor_idx][input_key_idx]
+            length = key_lengths[input_tensor_idx][input_key_idx]
+
+            # add jump data
+            if output_key not in last_seen:
+                jump = 0  # don't need to jump yet
+                # positive as a potential jump start
+                last_seen[output_key] = permute_idx
+            else:
+                prev = last_seen[output_key]
+                if prev >= 0:  # positive ==> it's a jump start
+                    # jump to current idx, positive as the jump start
+                    permutes[prev * permute_param + 5] = permute_idx
+                else:  # it's already in a jump sequence, mark as negative
+                    permutes[-prev * permute_param + 5] = -permute_idx
+                # mark last_seen negative since it's already in jump
+                last_seen[output_key] = -permute_idx
+                # it's a potential jump stop
+                jump = -total_permutes
+
+            permutes[permute_idx * permute_param : permute_idx * permute_param + 6] = [
+                input_tensor_idx,
+                output_tensor_idx,
+                input_start,
+                output_start,
+                length,
+                jump,
+            ]
+            permute_idx += 1
+            output_start += length
+        out_lengths[output_tensor_idx] = output_start
+
+    return (
+        permutes,
+        in_lengths,
+        out_lengths,
+    )
+
+
 def _values_string(values: torch.Tensor, start: int, end: int) -> str:
     size = values.size()
     if len(size) == 1:
diff --git a/torchrec/sparse/tests/jagged_tensor_benchmark.py b/torchrec/sparse/tests/jagged_tensor_benchmark.py
index 1745910ea..5d13b6e6b 100644
--- a/torchrec/sparse/tests/jagged_tensor_benchmark.py
+++ b/torchrec/sparse/tests/jagged_tensor_benchmark.py
@@ -21,6 +21,7 @@
     _regroup_keyed_tensors,
     KeyedJaggedTensor,
     KeyedTensor,
+    permute_multi_embedding,
 )
 
 from torchrec.sparse.tests.utils import build_groups, build_kts
@@ -40,6 +41,7 @@ def bench(
     run_backward: bool,
     fn: Callable[..., List[torch.Tensor]],
     fn_kwargs: Dict[str, Any],
+    output_dir: str = "",
 ) -> None:
 
     # initial call
@@ -49,8 +51,8 @@ def wrapped_func(
         model: torch.nn.Module,  # not used
         bench_inputs: List[KeyedJaggedTensor],  # not used
         fn: Callable[..., List[torch.Tensor]],
-        fn_kwargs: Dict[str, Any],
        run_backward: bool,
+        **kwargs: Dict[str, Any],
     ) -> None:
         result = fn(**fn_kwargs)
         if run_backward:
@@ -64,26 +66,27 @@ def wrapped_func(
             loss = torch.nn.functional.l1_loss(pred, labels)
             loss.sum().backward()
 
+    model = DummyModel()
+    setattr(model, "forward", lambda kwargs: fn(**kwargs))
     if device_type == "cuda":
         result = benchmark(
             name=name,
-            model=DummyModel(),
-            warmup_inputs=[],
+            model=model,
+            warmup_inputs=[fn_kwargs] * 10,
             bench_inputs=[],
-            prof_inputs=[],
+            prof_inputs=[fn_kwargs] * 10,
             world_size=1,
-            output_dir="",
+            output_dir=output_dir,
             num_benchmarks=20,
             func_to_benchmark=functools.partial(
                 wrapped_func, fn=fn, run_backward=run_backward, fn_kwargs=fn_kwargs
             ),
             benchmark_func_kwargs={},
             rank=0,
-            enable_logging=False,
+            enable_logging=True,
         )
 
     else:  # cpu
-        model = DummyModel()
         times = timeit.repeat(
             lambda: wrapped_func(
                 model=model,
@@ -160,6 +163,12 @@
     default=2,
     help="Total num of regrouping",
 )
+@click.option(
+    "--profile",
+    type=str,
+    default="",
+    help="profile output directory",
+)
 def main(
     cuda_matrix: bool,
     run_backward: bool,
@@ -170,6 +179,7 @@
     dim_sparse: int,
     batch_size: int,
     n_groups: int,
+    profile: str,
 ) -> None:
     if cuda_matrix:
         n_denses = [64, 128, 256, 512, 1024]
@@ -184,54 +194,69 @@
 
     for device_type in device_types:
         for batch_size in batch_sizes:
-            for n_dense, n_sparse in zip(n_denses, n_sparses):
-
-                device = torch.device(device_type)
-                kts = build_kts(
-                    n_dense,
-                    n_sparse,
-                    dim_dense,
-                    dim_sparse,
-                    batch_size,
-                    device,
-                    run_backward,
-                )
-                labels = torch.randint(
-                    0, 1, (batch_size,), device=torch.device(device_type)
-                ).float()
-                groups = build_groups(kts, n_groups)
-                bench(
-                    "[fallback] _regroup_keyed_tenors",
-                    labels,
-                    batch_size,
-                    n_dense + n_sparse,
-                    device_type,
-                    run_backward,
-                    _regroup_keyed_tensors,
-                    {"keyed_tensors": kts, "groups": groups},
-                )
-                bench(
-                    "[prod] KeyedTensor.regroup",
-                    labels,
-                    batch_size,
-                    n_dense + n_sparse,
-                    device_type,
-                    run_backward,
-                    KeyedTensor.regroup,
-                    {"keyed_tensors": kts, "groups": groups},
-                )
-                bench(
-                    "[prod] KTRegroupAsDict",
-                    labels,
-                    batch_size,
-                    n_dense + n_sparse,
-                    device_type,
-                    run_backward,
-                    KTRegroupAsDict(
-                        groups=groups, keys=[str(i) for i in range(n_groups)]
-                    ),
-                    {"keyed_tensors": kts},
-                )
+            for duplicates in [False, True]:
+                for n_dense, n_sparse in zip(n_denses, n_sparses):
+                    dup = "_dup" if duplicates else ""
+                    device = torch.device(device_type)
+                    kts = build_kts(
+                        n_dense,
+                        n_sparse,
+                        dim_dense,
+                        dim_sparse,
+                        batch_size,
+                        device,
+                        run_backward,
+                    )
+                    labels = torch.randint(
+                        0, 1, (batch_size,), device=torch.device(device_type)
+                    ).float()
+                    groups = build_groups(kts, n_groups, duplicates=duplicates)
+                    bench(
+                        "_regroup_keyed_tensors" + dup,
+                        labels,
+                        batch_size,
+                        n_dense + n_sparse,
+                        device_type,
+                        run_backward,
+                        _regroup_keyed_tensors,
+                        {"keyed_tensors": kts, "groups": groups},
+                        profile,
+                    )
+                    bench(
+                        "KeyedTensor.regroup" + dup,
+                        labels,
+                        batch_size,
+                        n_dense + n_sparse,
+                        device_type,
+                        run_backward,
+                        KeyedTensor.regroup,
+                        {"keyed_tensors": kts, "groups": groups},
+                        profile,
+                    )
+                    bench(
+                        "KTRegroupAsDict" + dup,
+                        labels,
+                        batch_size,
+                        n_dense + n_sparse,
+                        device_type,
+                        run_backward,
+                        KTRegroupAsDict(
+                            groups=groups, keys=[str(i) for i in range(n_groups)]
+                        ),
+                        {"keyed_tensors": kts},
+                        profile,
+                    )
+                    bench(
+                        "permute_multi_embs" + dup,
+                        labels,
+                        batch_size,
+                        n_dense + n_sparse,
+                        device_type,
+                        run_backward,
+                        permute_multi_embedding,
+                        {"keyed_tensors": kts, "groups": groups},
+                        profile,
+                    )
 
 
 if __name__ == "__main__":
diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py
index 9efeb444c..5e49f3516 100644
--- a/torchrec/sparse/tests/test_jagged_tensor.py
+++ b/torchrec/sparse/tests/test_jagged_tensor.py
@@ -16,6 +16,7 @@
 from torch.testing import FileCheck
 from torchrec.fx import symbolic_trace
 from torchrec.sparse.jagged_tensor import (
+    _multi_remap_to_groups,
     _regroup_keyed_tensors,
     ComputeJTDictToKJT,
     ComputeKJTToJTDict,
@@ -1374,6 +1375,170 @@ def test_permute_vb(self) -> None:
         )
         self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
 
+    def test_multi_remap_to_group(self) -> None:
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        ref_permutes = [
+            [0, 0, 0, 0, 3, 4],  # f1, jump to 4, as a start
+            [1, 0, 0, 3, 5, 0],  # f3
+            [0, 1, 3, 0, 4, 0],  # f2
+            [1, 2, 5, 0, 6, 0],  # f4
+            [0, 2, 0, 6, 3, -6],  # f1 jump to 6, as in a jump sequence
+            [2, 2, 0, 9, 8, 0],  # f6
+            [0, 3, 0, 0, 3, -8],  # f1 jump stop, as out of boundary
+            [1, 3, 11, 3, 7, 0],  # f5
+        ]
+        self.assertEqual(permutes, [i for p in ref_permutes for i in p])
+        self.assertEqual(in_lengths, [7, 18, 8])
+        self.assertEqual(out_lengths, [8, 4, 17, 10])
+
["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True) + for lens in lengths + ] + permutes, in_lengths, out_lengths = _multi_remap_to_groups( + keys, lengths, groups + ) + refs = [[] for _ in groups] + for i in range(len(permutes) // 6): + in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6] + refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + outputs = torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_lengths, out_lengths + ) + for out, ref in zip(outputs, refs): + self.assertTrue(torch.allclose(out, ref)) + + def test_multi_permute_forward_meta(self) -> None: + batch_size = 5 + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(batch_size, sum(lens), device="meta", requires_grad=True) + for lens in lengths + ] + permutes, in_lengths, out_lengths = _multi_remap_to_groups( + keys, lengths, groups + ) + refs = [[] for _ in groups] + for i in range(len(permutes) // 6): + in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6] + refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + outputs = torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_lengths, out_lengths + ) + for out, ref in zip(outputs, refs): + self.assertEqual(out.shape, ref.shape) + + def test_multi_permute_forward_gpu(self) -> None: + batch_size = 5 + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True) + for lens in lengths + ] + permutes, in_lengths, out_lengths = _multi_remap_to_groups( + keys, lengths, groups + ) + refs = [[] for _ in groups] + for i in range(len(permutes) // 6): + in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6] + refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + outputs = torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_lengths, out_lengths + ) + for out, ref in zip(outputs, refs): + self.assertTrue(torch.allclose(out, ref)) + + def test_multi_permute_backward_cpu(self) -> None: + batch_size = 5 + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True) + for lens in lengths + ] + ref_values = [v.detach() for v in values] + for v in ref_values: + v.requires_grad = True + permutes, in_lengths, out_lengths = _multi_remap_to_groups( + keys, lengths, groups + ) + refs = [[] for _ in groups] + for i in range(len(permutes) // 6): + in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6] + refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + outputs = torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_lengths, out_lengths + ) + for out, ref in zip(outputs, refs): + self.assertTrue(torch.allclose(out, ref)) + + ref_loss, loss = refs[0].sum(), outputs[0].sum() + for i in range(1, len(refs)): + ref_loss += (i + 1.1) * refs[i].sum() + loss += (i + 1.1) * 
outputs[i].sum() + ref_loss.backward() + loss.backward() + for val, ref in zip(values, ref_values): + val_grad, ref_grad = val.grad, ref.grad + assert isinstance(val_grad, torch.Tensor) + self.assertTrue(torch.allclose(val_grad, ref_grad)) + + def test_multi_permute_backward_gpu(self) -> None: + batch_size = 2048 + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[96, 256], [512, 128, 768], [1024]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True) + for lens in lengths + ] + ref_values = [v.detach() for v in values] + for v in ref_values: + v.requires_grad = True + permutes, in_lengths, out_lengths = _multi_remap_to_groups( + keys, lengths, groups + ) + refs = [[] for _ in groups] + for i in range(len(permutes) // 6): + in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6] + refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + outputs = torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_lengths, out_lengths + ) + for out, ref in zip(outputs, refs): + self.assertTrue(torch.allclose(out, ref)) + + ref_loss, loss = refs[0].sum(), outputs[0].sum() + for i in range(1, len(refs)): + ref_loss += (i + 1.1) * refs[i].sum() + loss += (i + 1.1) * outputs[i].sum() + ref_loss.backward() + loss.backward() + for val, ref in zip(values, ref_values): + val_grad, ref_grad = val.grad, ref.grad + assert isinstance(val_grad, torch.Tensor) + self.assertTrue(torch.allclose(val_grad, ref_grad)) + def test_permute_duplicates(self) -> None: values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0]) @@ -1650,8 +1815,6 @@ def test_string_vb(self) -> None: stride_per_key_per_rank=stride_per_key_per_rank, ) - print(str(jag_tensor)) - self.assertEqual( str(jag_tensor), """\
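
For reference (not part of the patch): a minimal sketch of how the flattened `permutes` layout produced by `_multi_remap_to_groups` maps input slices onto output groups, mirroring the pure-PyTorch reference path used in the tests above. The helper name `regroup_slices` is illustrative only; the user-facing entry point added by the patch is `permute_multi_embedding(keyed_tensors, groups)`, which dispatches to the fused `torch.ops.fbgemm.permute_multi_embedding` op instead of slicing in Python.

import torch
from typing import List


def regroup_slices(
    values: List[torch.Tensor],  # one 2D values tensor per input KeyedTensor
    permutes: List[int],  # flat list, 6 ints per copied slice
    num_groups: int,
) -> List[torch.Tensor]:
    # Each 6-int record is [input_tensor_idx, output_tensor_idx, input_start,
    # output_start, length, jump]; the jump field chains duplicate uses of the
    # same input key for the fused kernel and is ignored in this forward-only
    # reference.
    chunks: List[List[torch.Tensor]] = [[] for _ in range(num_groups)]
    for i in range(len(permutes) // 6):
        in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
        chunks[out_idx].append(values[in_idx][:, in_start : in_start + length])
    return [torch.cat(c, dim=1) for c in chunks]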