[dynamo] byteir backend switch to torch==2.12 (#315)
- as titled
- fix FX graph cache failures
- add a debug backend

---------

Signed-off-by: huangchenhui.yellow <huangchenhui.yellow@bytedance.com>
YellowHCH committed Jun 6, 2024
1 parent 0df1d45 commit 53b0948
Showing 7 changed files with 263 additions and 35 deletions.
@@ -26,6 +26,7 @@ declare_mlir_python_sources(TorchFrontendPythonSources.TopLevel
byteir_backend/compiled_function.py
byteir_backend/compiler.py
byteir_backend/config.py
byteir_backend/debug.py
byteir_backend/inner_compile.py
byteir_backend/utils.py
byteir_backend/byteir_fusible_pattern.py
@@ -7,6 +7,12 @@ def byteir(*args, **kwargs):

return byteir_compiler(*args, **kwargs)

@register_backend
def byteir_debug(*args, **kwargs):
from .debug import debug_backend

return debug_backend(*args, **kwargs)

def set_cache_dir(path: str):
from .compilation_cache import ByteIRFxGraphCache

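For context, a minimal usage sketch of the two registered backends. It assumes, as torch._dynamo's register_backend does by default, that each backend is looked up by its function name, and that importing the byteir_backend package has already run the registrations above; the toy module and shapes are illustrative only, and actually compiling requires the byteir toolchain to be installed.

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0

# Standard path: compile with the byteir backend.
compiled = torch.compile(Toy(), backend="byteir")

# Debug path: runs both the byteir-compiled function and the eager FX graph,
# asserts their outputs match, and returns the eager result.
debugged = torch.compile(Toy(), backend="byteir_debug")

out = debugged(torch.randn(4, 8))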
@@ -21,24 +21,18 @@
import logging
import shutil
from copy import copy
import dataclasses

from typing import Optional, Any, Callable, Dict, List, Sequence, Tuple, Union
from filelock import FileLock

import torch
from torch._inductor.codecache import (
BypassFxGraphCache,
LOCK_TIMEOUT,
sha256_hash,
OrderedSetHolder,
write_atomic,
_reduce_fake_tensor,
_reduce_symint,
)
from torch._dynamo.utils import counters
from torch.fx.experimental.symbolic_shapes import ShapeEnv, has_hint, hint_int
from torch._subclasses.fake_tensor import extract_tensor_metadata, FakeTensor
from torch.fx.experimental.symbolic_shapes import ShapeEnv, has_hint, hint_int, SYMPY_INTERP
from torch._subclasses.fake_tensor import FakeTensor
from torch.cuda.memory import caching_allocator_alloc, caching_allocator_delete

try:
@@ -47,10 +41,22 @@
...

from .compiled_function import (CompiledArtifact, ByteIRFunction)
from .utils import (
dump_tensors_meta_info,
BypassFxGraphCache,
OrderedSetHolder,
TensorMetadata,
extract_tensor_metadata,
maybe_get_fake_mode,
_reduce_fake_tensor,
_reduce_symint,
sha256_hash,
)

log = logging.getLogger(__name__)



def get_system_info() -> Dict[str, Any]:
try:
system: Dict[str, Any] = {
@@ -88,7 +94,7 @@ def __init__(
example_inputs: List[torch.Tensor],
fx_kwargs: Dict[str, Any],
):
self.gm = gm
self.gm = gm.__str__()
self.example_inputs = example_inputs

# Order kwargs so hashing is stable to changes in kwarg order.
@@ -102,13 +108,12 @@ def __init__(
else:
self.fx_kwargs[k] = fx_kwargs[k]

# 'Deterministic algorithms' can affect codegen via lowering to cuda kernels.
self.deterministic_algorithms_settings = (
torch.are_deterministic_algorithms_enabled(),
torch.is_deterministic_algorithms_warn_only_enabled(),
torch.utils.deterministic.
fill_uninitialized_memory, # type: ignore[attr-defined]
)
# # 'Deterministic algorithms' can affect codegen via lowering to cuda kernels.
# self.deterministic_algorithms_settings = (
# torch.are_deterministic_algorithms_enabled(),
# torch.is_deterministic_algorithms_warn_only_enabled(),
# byteir_backend.deterministic.fill_uninitialized_memory, # type: ignore[attr-defined]
# )

# Global settings affecting matmul codegen.
self.cuda_matmul_settings = (
@@ -146,6 +151,9 @@ def get_str(obj) -> str:
for k, v in obj.items():
h = ByteIRFxGraphCachePickler.get_hash(v)
lines.append(f"[{h}] {attr}[{k}]: {get_str(v)}")
elif isinstance(obj, torch.fx.GraphModule):
h = ByteIRFxGraphCachePickler.get_hash(obj.__str__())
lines.append(f"[{h}] {attr}: {get_str(obj)}")
else:
h = ByteIRFxGraphCachePickler.get_hash(obj)
lines.append(f"[{h}] {attr}: {get_str(obj)}")
@@ -288,11 +296,34 @@ def _get_shape_env() -> Optional[ShapeEnv]:
"""
Helper to get the shape env from the tracing context.
"""
ctx = torch._guards.TracingContext.try_get()
ctx = torch._guards.TracingContext.get()
if not ctx:
return None
return ctx.fake_mode.shape_env

@staticmethod
def _produce_guards_expression(shape_env, placeholders, ignore_static=True):
"""
Expected to be used with evaluate_guards_expression(). Produces the guards
for the given placeholders and returns a string expression to be evaluated
by evaluate_guards_expression given concrete values for the placeholders.
"""
from torch._dynamo.source import LocalSource
arg_names = [f"t{i}" for i in range(len(placeholders))]
guards = shape_env.produce_guards(placeholders, [LocalSource(a) for a in arg_names], ignore_static=ignore_static)
if guards:
return " and ".join(guards)
return None

@staticmethod
def _evaluate_guards_expression(code, args):
"""
Expected to be used with produce_guards_expression(). Evaluates an expression
generated by produce_guards_expression for the given concrete args.
"""
arg_names = [f"t{i}" for i in range(len(args))]
return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))})

@staticmethod
def _lookup_compiled_artifact(
key: str,
@@ -335,8 +366,8 @@ def _lookup_compiled_artifact(
# affect the current env, e.g., cause the creation of new guards,
# so we evaluate with the hints instead of the symbols.
hit = bool(
shape_env.evaluate_guards_expression(candidate.guards_expr,
hints))
ByteIRFxGraphCache._evaluate_guards_expression(
candidate.guards_expr, hints))
log.debug(
"fx graph cache key %s evaluating guards [%s] with values %s => hit=%s",
key,
@@ -354,8 +385,8 @@ def _lookup_compiled_artifact(
# Now re-evaluate with the symints to add any guards to the current env.
if artifact.guards_expr:
check = bool(
shape_env.evaluate_guards_expression(artifact.guards_expr,
symints))
ByteIRFxGraphCache._evaluate_guards_expression(
artifact.guards_expr, symints))
assert check is True
log.debug("fx graph cache key %s post-load guards: %s", key,
shape_env.guards)
@@ -377,8 +408,8 @@ def _save_compiled_artifact(key: str, compiled_artifact: CompiledArtifact,
shape_env = ByteIRFxGraphCache._get_shape_env()
assert shape_env is not None
symints = ByteIRFxGraphCache._filter_symints(example_inputs)
compiled_artifact.guards_expr = shape_env.produce_guards_expression(
symints)
compiled_artifact.guards_expr = ByteIRFxGraphCache._produce_guards_expression(
shape_env, symints)

try:
# FIXME compiled_artifact is not serializable.
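The two static helpers added above mirror ShapeEnv's produce/evaluate pair: _produce_guards_expression turns the shape guards into a plain Python expression over placeholder names t0, t1, ..., and _evaluate_guards_expression evaluates that expression against concrete values at cache-lookup time. A minimal sketch of the evaluation step, with a hypothetical guard string and hint value (the real helper passes SYMPY_INTERP as the globals so sympy functions appearing in guards resolve):

# Hypothetical guard string of the kind _produce_guards_expression might emit
# for a single symbolic input dimension; the hint value is illustrative only.
guards_expr = "L['t0'] >= 2 and L['t0'] <= 4096"
hints = [512]

arg_names = [f"t{i}" for i in range(len(hints))]
hit = bool(eval(guards_expr, {}, {"L": dict(zip(arg_names, hints))}))
print(hit)  # True -> the cached artifact is valid for these concrete shapes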
@@ -0,0 +1,42 @@
import functools
import logging
from typing import Optional, Any, Callable, Dict, List, Sequence, Tuple, Union

import torch

from .compiler import byteir_compiler

log = logging.getLogger(__name__)


def debug_backend(gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor]):
"""
compare results between byteir compiled function and eager mode graph.
"""
_opt_gm = byteir_compiler(gm, example_inputs)

def f(*inputs):
opt_inputs = []
for inp in inputs:
_opt_inp = torch.empty_strided(size=inp.size(),
stride=inp.stride(),
storage_offset=inp.storage_offset())
opt_inputs.append(_opt_inp.copy_(inp))

eager_inputs = inputs
eager_res = gm(*eager_inputs)
opt_res = _opt_gm(*opt_inputs)

# compare results
# TODO: check meta info as well as numerical values.
try:
torch.testing.assert_close(eager_res, opt_res)
except Exception as e:
log.error(f"******* debug backend fail *******")
raise e

print(f"******* debug backend pass *******")
return eager_res

return f
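For reference, torch.testing.assert_close (used above) recurses into matching nested sequences of tensors, so the eager and compiled result structures are compared element-wise with default tolerances; a small self-contained illustration with made-up values:

import torch

eager_res = (torch.ones(2, 3), [torch.arange(4.0)])
opt_res = (torch.ones(2, 3), [torch.arange(4.0)])
torch.testing.assert_close(eager_res, opt_res)  # same structure, same values: passes

opt_res_bad = (torch.ones(2, 3) + 1e-2, [torch.arange(4.0)])
try:
    torch.testing.assert_close(eager_res, opt_res_bad)
except AssertionError:
    print("mismatch detected")  # what the debug backend logs before re-raising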
@@ -9,16 +9,13 @@
from torch._dynamo import (
utils as dynamo_utils, )
from torch._dynamo.utils import counters
from torch._dynamo.utils import detect_fake_mode
from torch.cuda.memory import caching_allocator_alloc, caching_allocator_delete
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch._subclasses.fake_tensor import (
FakeTensorMode,
FakeTensor,
FakeTensorConverter,
TensorMetadata,
extract_tensor_metadata,
maybe_get_fake_mode,
unset_fake_temporarily,
)

import torch_frontend
@@ -40,7 +37,13 @@
ByteIRFunction,
)
from .utils import (
dump_tensors_meta_info, )
dump_tensors_meta_info,
BypassFxGraphCache,
OrderedSetHolder,
TensorMetadata,
extract_tensor_metadata,
maybe_get_fake_mode,
)
from . import config

log = logging.getLogger(__name__)
@@ -49,15 +52,15 @@
BACKEND_LEGAL_OPS = ["aten.max.dim"]


@dynamo_utils.dynamo_timed(phase_name="byteir_compile")
#@dynamo_utils.dynamo_timed(phase_name="byteir_compile")
def inner_compile(gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
workdir: str = None,
compiler_type: str = "forward",
**kwargs) -> CompiledArtifact:

graph_id = next(g_graph_counter)
log.debug(f"byteir compiling {compiler_type} graph {graph_id}")
log.info(f"byteir compiling {compiler_type} graph {graph_id}")

if workdir is None:
key = compiled_fx_graph_hash(gm, example_inputs, kwargs)
Expand All @@ -75,9 +78,10 @@ def inner_compile(gm: torch.fx.GraphModule,
fxg_dir_name = f"fx_graph_{compiler_type}_{graph_id}"
fx_graph_folder = f"{workdir}/{fxg_dir_name}/"
os.makedirs(fx_graph_folder, exist_ok=True)
with unset_fake_temporarily():
with maybe_disable_fake_tensor_mode():
gm.to_folder(folder=fx_graph_folder, module_name="FxModule")
with FakeTensorMode(allow_non_fake_inputs=True):
with detect_fake_mode(example_inputs):
#with FakeTensorMode(allow_non_fake_inputs=True):
fake_outs = gm(*example_inputs)
dump_tensors_meta_info(
example_inputs,
Expand Down Expand Up @@ -124,7 +128,7 @@ def byteir_fx_compiler(gm: torch.fx.GraphModule,
log.info(
f"########################### {'FORWARD' if not is_backward else 'BACKWARD'} ###########################"
)
log.info(torch._guards.TracingContext.try_get())
log.info(torch._guards.TracingContext.get())

if config.byteir_not_use_cache:
compiled_artifact = inner_compile(gm, example_inputs)
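The change above from an unconditional FakeTensorMode(allow_non_fake_inputs=True) to detect_fake_mode(example_inputs) reuses whichever fake-tensor mode the example inputs already carry instead of creating a fresh one. A small sketch of that lookup, assuming the inputs are already fake tensors (as they are when dynamo hands them to a backend):

import torch
from torch._dynamo.utils import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode

mode = FakeTensorMode(allow_non_fake_inputs=True)
fake_inputs = [mode.from_tensor(torch.randn(2, 3))]

found = detect_fake_mode(fake_inputs)  # returns the mode attached to the inputs
assert found is mode

with found:
    fake_out = torch.relu(fake_inputs[0])  # metadata-only execution, no real compute
print(fake_out.shape)  # torch.Size([2, 3])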
@@ -3,7 +3,7 @@
is_symbol_binding_fx_node,
find_symbol_binding_fx_nodes
)
from torch.fx.experimental.sym_node import (
from torch.fx.experimental.symbolic_shapes import (
magic_methods,
method_to_operator,
)