forked from pytorch/pytorch
Force synced KJT to trace unbacked SymInt (pytorch#108960)
Summary:
The basic concept behind this diff is to modify Dynamo's tracing behavior when it encounters a KeyedJaggedTensor that is synced (i.e., has `_length_per_key` and `_offset_per_key` populated). These fields are lists of integers; ordinarily, Dynamo will optimistically try to specialize on integers, but for KJTs we know that these integers will definitely vary from run to run. Furthermore, we would ordinarily also specialize these integers if they are 0/1, yet we frequently expect feature lengths in KJTs to be 0 or 1.

The fix is to detect KJTs and treat these integers as *unbacked integers*. This is NOT a universally sound optimization: when treating these integers as unbacked, we never report them as equal to zero or one. In return, we always generate graphs that generalize no matter the length of values on features. This is enough to trace through the APS sparse arch, torchrec_dlrm, and some small split-cat examples.

The special integer behavior is triggered by a dynamically scoped `force_unspec_int_unbacked_size_like` variable on TracingContext, which we set when we wrap a KJT. There are probably other ways to do this, but this was simple and worked.

Test Plan:
```
buck2 test mode/dev-nosan //pytorch/benchmark/fb/test_gpu:run_test_gpu
```

From aakhundov:
1. First, build feed_lower_benchmark:
```
buck2 build --show-output mode/opt -c python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true hpc/new/models/feed/benchmark:feed_lower_benchmark
```
2. Then run the lowering of the model with it:
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 TORCH_LOGS="output_code,graph_code" TORCH_COMPILE_DEBUG=1 ../buck-out/v2/gen/fbcode/79c6b019ee0f9469/hpc/new/models/feed/benchmark/__feed_lower_benchmark__/feed_lower_benchmark.par --load=manifold://ig_inference_model/tree/user/facebook/fblearner/predictor/960999465/60/gpu_lowering/input.predictor --skip-trt --skip-ait --sync-mode=0 --enable-aot-inductor --lower-presets="ig_stories" --gpu-trace
```
Cf. https://docs.google.com/document/d/1yD30xYrdmM8r2HTdmXnZTg0-MHVexfVrAa0294m1AUE/edit?pli=1#heading=h.qiv3fp7e6zg0

From torchrec: https://www.internalfb.com/intern/wiki/Torchrec/Development/Testing_production_models/

From ge0405:
- baseline (without this diff): f477293168
- with this diff: f477292363

Reviewed By: voznesenskym

Differential Revision: D49019987
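To make the flag mechanism concrete, here is a minimal, self-contained sketch of the dynamically scoped pattern described above. The flag name and TracingContext come from the summary; the toy context, `wrap_int`, and `wrap_kjt` below are hypothetical stand-ins, not Dynamo's actual wrapping code:

```
# Sketch only: ToyTracingContext, wrap_int, and wrap_kjt are hypothetical.
# The real flag lives on Dynamo's TracingContext.
import contextlib


class ToyTracingContext:
    def __init__(self):
        self.force_unspec_int_unbacked_size_like = False


TRACING_CTX = ToyTracingContext()


@contextlib.contextmanager
def force_unbacked_size_like():
    # Dynamic scoping: save, set, then restore the flag, so it only
    # affects integers wrapped while this KJT is being wrapped.
    saved = TRACING_CTX.force_unspec_int_unbacked_size_like
    TRACING_CTX.force_unspec_int_unbacked_size_like = True
    try:
        yield
    finally:
        TRACING_CTX.force_unspec_int_unbacked_size_like = saved


def wrap_int(value):
    if TRACING_CTX.force_unspec_int_unbacked_size_like:
        # Unbacked: never specialized and never assumed to be 0 or 1,
        # so the traced graph generalizes over feature lengths.
        return f"unbacked_symint(was {value})"
    # Default behavior: specialize on the observed constant.
    return f"constant({value})"


def wrap_kjt(length_per_key, offset_per_key):
    with force_unbacked_size_like():
        return [wrap_int(n) for n in length_per_key + offset_per_key]


print(wrap_int(1))                  # constant(1): specialized as usual
print(wrap_kjt([1, 5], [0, 1, 6]))  # all unbacked inside the KJT
```

The save/restore in the context manager is what makes the scope dynamic: only integers wrapped while the KJT itself is being wrapped become unbacked, so ordinary integers elsewhere in the program still specialize as usual.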
Commit f28a6b7 · 1 parent a6b153b · 4 changed files with 247 additions and 5 deletions.
```
@@ -0,0 +1,204 @@
# Owner(s): ["module: dynamo"]
import sys
import unittest
from typing import Dict, List

import torch
import torch._dynamo.test_case
from torch import nn

from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import NoTest

try:
    from torchrec.datasets.random import RandomRecDataset
    from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

    HAS_TORCHREC = True
except ImportError:
    HAS_TORCHREC = False


class BucketizeMod(torch.nn.Module):
    def __init__(self, feature_boundaries: Dict[str, List[float]]):
        super().__init__()
        self.bucket_w = torch.nn.ParameterDict()
        self.boundaries_dict = {}
        for key, boundaries in feature_boundaries.items():
            self.bucket_w[key] = torch.nn.Parameter(
                torch.empty([len(boundaries) + 1]).fill_(1.0),
                requires_grad=True,
            )
            buf = torch.tensor(boundaries, requires_grad=False)
            self.register_buffer(
                f"{key}_boundaries",
                buf,
                persistent=False,
            )
            self.boundaries_dict[key] = buf

    def forward(self, features: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
        weights_list = []
        for key, boundaries in self.boundaries_dict.items():
            jt = features[key]
            bucketized = torch.bucketize(jt.weights(), boundaries)
            # doesn't super matter I guess
            # hashed = torch.ops.fb.index_hash(bucketized, seed=0, modulo=len(boundaries))
            hashed = bucketized
            weights = torch.gather(self.bucket_w[key], dim=0, index=hashed)
            weights_list.append(weights)
        return KeyedJaggedTensor(
            keys=features.keys(),
            values=features.values(),
            weights=torch.cat(weights_list),
            lengths=features.lengths(),
            offsets=features.offsets(),
            stride=features.stride(),
            length_per_key=features.length_per_key(),
        )


if not HAS_TORCHREC:
    print("torchrec not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


@unittest.skipIf(not HAS_TORCHREC, "these tests require torchrec")
class TorchRecTests(TestCase):
    def test_pooled(self):
        tables = [
            (nn.EmbeddingBag(2000, 8), ["a0", "b0"]),
            (nn.EmbeddingBag(2000, 8), ["a1", "b1"]),
            (nn.EmbeddingBag(2000, 8), ["b2"]),
        ]

        embedding_groups = {
            "a": ["a0", "a1"],
            "b": ["b0", "b1", "b2"],
        }

        counter = CompileCounter()

        @torch.compile(backend=counter, fullgraph=True, dynamic=True)
        def f(id_list_features: KeyedJaggedTensor):
            id_list_jt_dict: Dict[str, JaggedTensor] = id_list_features.to_dict()
            pooled_embeddings = {}
            # TODO: run feature processor
            for emb_module, feature_names in tables:
                features_dict = id_list_jt_dict
                for feature_name in feature_names:
                    f = features_dict[feature_name]
                    pooled_embeddings[feature_name] = emb_module(
                        f.values(), f.offsets()
                    )

            pooled_embeddings_by_group = {}
            for group_name, group_embedding_names in embedding_groups.items():
                group_embeddings = [
                    pooled_embeddings[name] for name in group_embedding_names
                ]
                pooled_embeddings_by_group[group_name] = torch.cat(
                    group_embeddings, dim=1
                )

            return pooled_embeddings_by_group

        dataset = RandomRecDataset(
            keys=["a0", "a1", "b0", "b1", "b2"],
            batch_size=4,
            hash_size=2000,
            ids_per_feature=3,
            num_dense=0,
        )
        di = iter(dataset)

        # unsync should work

        d1 = next(di).sparse_features.unsync()
        d2 = next(di).sparse_features.unsync()
        d3 = next(di).sparse_features.unsync()

        r1 = f(d1)
        r2 = f(d2)
        r3 = f(d3)

        # all three unsynced batches share a single compiled graph
        self.assertEqual(counter.frame_count, 1)
        counter.frame_count = 0

        # sync should work too

        d1 = next(di).sparse_features.sync()
        d2 = next(di).sparse_features.sync()
        d3 = next(di).sparse_features.sync()

        r1 = f(d1)
        r2 = f(d2)
        r3 = f(d3)

        # synced KJTs also share one graph: the cached length/offset ints
        # are traced as unbacked SymInts rather than specialized
        self.assertEqual(counter.frame_count, 1)

        # export only works with unsync

        gm = torch._dynamo.export(f)(next(di).sparse_features.unsync()).graph_module
        gm.print_readable()

        self.assertEqual(gm(d1), r1)
        self.assertEqual(gm(d2), r2)
        self.assertEqual(gm(d3), r3)

    def test_bucketize(self):
        mod = BucketizeMod({"f1": [0.0, 0.5, 1.0]})
        features = KeyedJaggedTensor.from_lengths_sync(
            keys=["f1"],
            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
            weights=torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]),
        ).unsync()

        def f(x):
            # This is a trick to populate the computed cache and instruct
            # ShapeEnv that they're all sizey
            x.to_dict()
            return mod(x)

        torch._dynamo.export(f, aten_graph=True)(features).graph_module.print_readable()

    @unittest.expectedFailure
    def test_simple(self):
        jag_tensor1 = KeyedJaggedTensor(
            values=torch.tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            keys=["index_0", "index_1"],
            lengths=torch.tensor([0, 0, 1, 1, 1, 3]),
        ).sync()

        # ordinarily, this would trigger one specialization
        self.assertEqual(jag_tensor1.length_per_key(), [1, 5])

        counter = CompileCounter()

        @torch._dynamo.optimize(counter, nopython=True)
        def f(jag_tensor):
            # The indexing here requires more symbolic reasoning
            # and doesn't work right now
            return jag_tensor["index_0"].values().sum()

        f(jag_tensor1)

        self.assertEqual(counter.frame_count, 1)

        jag_tensor2 = KeyedJaggedTensor(
            values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            keys=["index_0", "index_1"],
            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
        ).sync()

        f(jag_tensor2)

        self.assertEqual(counter.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
```
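As background for the sync/unsync distinction the tests above rely on, here is a short sketch against the torchrec API already used in the diff (the concrete tensors are made up for illustration; it requires torchrec installed):

```
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two keys, three "batch" positions each: lengths has 2 * 3 entries.
kjt = KeyedJaggedTensor(
    keys=["f0", "f1"],
    values=torch.arange(6, dtype=torch.float32),
    lengths=torch.tensor([2, 0, 1, 1, 1, 1]),
)

# sync() computes and caches _length_per_key / _offset_per_key as plain
# Python ints -- exactly the fields this diff forces Dynamo to trace as
# unbacked SymInts instead of specializing on.
synced = kjt.sync()
print(synced.length_per_key())  # [3, 3] for this made-up input

# unsync() drops the cached lists, so tracing never sees those ints.
unsynced = synced.unsync()
```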