[Frontend] Dynamic shape fx trace (#294)
Enable the option torch.compile(..., dynamic=True); a short usage sketch follows the list below.

- Convert torch FakeTensor inputs into hidet symbolic Tensors.
- There may be a bug in torch.dynamo, so we filter/pre-process inputs in
both the example inputs and in the wrapped function.
- Alter the graph interpreter to support non-torch functions, such as the
builtins add, getitem, etc., removing the dependence on registered functions for them.
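
A minimal usage sketch of what this enables (hedged: assumes hidet is installed so its dynamo backend is visible to torch, and that a CUDA device is available; the model is a stand-in):

import torch

model = torch.nn.Linear(16, 16).cuda().eval()
# with dynamic=True, dynamo passes FakeTensor example inputs with symbolic
# shapes, which the backend now converts into hidet symbolic tensors
compiled = torch.compile(model, backend='hidet', dynamic=True)

y_small = compiled(torch.randn(4, 16, device='cuda'))
y_large = compiled(torch.randn(32, 16, device='cuda'))  # different batch size, same traced graph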

---------

Co-authored-by: Allan Lin <allan.lin@centml.ai>
Co-authored-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
3 people committed Jul 12, 2023
1 parent 02d9a10 commit 9d51c74
Showing 15 changed files with 207 additions and 125 deletions.
50 changes: 36 additions & 14 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
@@ -10,11 +10,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=no-name-in-module
from typing import List, Callable, Sequence
from typing import List, Callable, Sequence, Union
import logging
import torch
import hidet.option
from hidet.ir.type import data_type
from hidet.ir.expr import is_constant
from hidet.graph.flow_graph import FlowGraph
from hidet.graph.transforms import PassContext, optimize
from hidet.runtime import CompiledGraph
@@ -103,40 +104,61 @@ def hidet_backend(graph_module, example_inputs):
interpreter: Interpreter = hidet.frontend.from_torch(graph_module)

# prepare dummy and symbolic inputs for correctness and flow graph construction
symbolic_inputs: List[Tensor] = [] # for flow graph construction
inputs: List[Union[Tensor, int, bool, float]] = [] # for flow graph construction
for example_input in example_inputs:
if isinstance(example_input, torch.Tensor):
symbolic_input = symbol_like_torch(example_input)
symbolic_inputs.append(symbolic_input)
inputs.append(symbolic_input)
elif isinstance(example_input, (int, bool, float)):
inputs.append(example_input)
elif isinstance(example_input, torch.SymInt):
try:
inputs.append(int(example_input))
except Exception as e:
raise ValueError(f"hidet_backend: free symbolic example input {example_input}") from e
else:
raise ValueError('hidet_backend: only support torch.Tensor as example input')
raise ValueError(f'hidet_backend: unexpected example input {example_input}, type {type(example_input)}')

if dynamo_config['correctness_report']:
# check correctness using random inputs
logger.info('start to check correctness')
dummy_inputs: List[Tensor] = [] # for correctness check
for symbolic_input in symbolic_inputs:
if data_type(symbolic_input.dtype).is_integer():
dummy_input = hidet.zeros_like(symbolic_input)
# some example inputs have symbolic shapes; we currently don't support this option
# because there is no principled way to get concrete shapes from symbolic shapes at this stage,
# since some models (e.g. resnet) require the image to be above a certain size.
if any(not all(is_constant(s) for s in t.shape) for t in inputs if isinstance(t, hidet.Tensor)):
raise ValueError("hidet_backend: cannot print correctness report with dynamic=True")
dummy_inputs = [] # for correctness check
for arg in inputs:
if isinstance(arg, hidet.Tensor):
if data_type(arg.dtype).is_integer():
dummy_input = hidet.zeros_like(arg)
else:
dummy_input = hidet.randn_like(arg)
else:
dummy_input = hidet.randn_like(symbolic_input)
dummy_input = arg
dummy_inputs.append(dummy_input)
report: str = interpreter.forward_with_check(*dummy_inputs)
logger.info('finish checking correctness')
print(report)

logger.info('hidet: symbolic inputs: ')
for symbolic_input in symbolic_inputs:
logger.info('hidet: %s', symbolic_input.signature())
logger.info('hidet: inputs: ')
for arg in inputs:
if isinstance(arg, hidet.Tensor):
logger.info('hidet: %s', arg.signature())
else:
logger.info('hidet: %s', arg)

# symbolic run to get flow graph
output = interpreter(*symbolic_inputs)
output = interpreter(*inputs)
output_format, output_tensors = serialize_output(output)
flow_graph: FlowGraph = hidet.trace_from(output_tensors, inputs=symbolic_inputs)
input_tensors = [x for x in inputs if isinstance(x, hidet.Tensor)]
input_tensor_indices = [i for (i, x) in enumerate(inputs) if isinstance(x, hidet.Tensor)]
flow_graph: FlowGraph = hidet.trace_from(output_tensors, inputs=input_tensors)

executor = generate_executor(flow_graph)

def wrapper(*args: Tensor):
args = [args[i] for i in input_tensor_indices]
outputs: Sequence[torch.Tensor] = executor(*args)
ret = deserialize_output(output_format, outputs)
return ret
51 changes: 46 additions & 5 deletions python/hidet/graph/frontend/torch/interpreter.py
@@ -149,6 +149,22 @@ def tensor_from_torch(tensor: torch.Tensor) -> Tensor:
return hidet.graph.tensor.from_torch(tensor)


def is_torch_path(name: str) -> bool:
name = name.split(".")
if len(name) > 0:
return name[0] == "torch"
return False


def belong_to_torch(code_obj) -> bool:
belongs = False
if hasattr(code_obj, "__module__") and code_obj.__module__ is not None:
belongs |= is_torch_path(code_obj.__module__)
if not belongs and hasattr(code_obj, "__package__") and code_obj.__package__ is not None:
belongs |= is_torch_path(code_obj.__package__)
return belongs


class HidetModule:
def __init__(self, torch_module: torch.nn.Module):
self.mod: torch.nn.Module = torch_module
@@ -204,6 +220,15 @@ def __init__(self, graph_module: torch.fx.GraphModule):
self.torch_modules: Dict[str, torch.nn.Module] = dict(graph_module.named_modules())
self.hidet_modules: Dict[str, HidetModule] = {}

# dynamo further wraps some builtin functions with local helper functions,
# which would otherwise get dispatched incorrectly
self.ignore_funcs: Dict[str, Callable] = {
# see torch._dynamo.variables.lists.SizeVariable.get_item_dyn
# this signifies that the target of getitem is a torch.Size; we overload torch.Tensor.size to
# return a list, so this method needs to be overloaded in the interpreter as well
'_dynamo_get_item_lambda': lambda target, index: target[index]
}

self._check_support()

def __call__(self, *args):
@@ -232,8 +257,10 @@ def _check_support(self):
if torch_cls not in Registry.registered_modules:
not_supported.add(torch_cls)
elif node.op == "call_function":
if node.target not in Registry.registered_functions:
target_fn = self._lookup_function(node.target)
if target_fn is None:
not_supported.add(node.target)

if len(not_supported) > 0:
lines = []
lines.append("The following modules/functions are not supported by hidet yet:")
@@ -255,6 +282,20 @@ def _lookup_hidet_method(self, torch_method):
raise NotImplementedError(f"hidet: method {method_name} is not supported yet.")
return Registry.registered_methods[torch_method]

def _lookup_function(self, code_obj):
if code_obj.__name__ in self.ignore_funcs:
return self.ignore_funcs[code_obj.__name__]
if belong_to_torch(code_obj):
if code_obj in Registry.registered_functions:
return Registry.registered_functions[code_obj]
else:
return None
else:
# this branch handles all the other cases, such as getitem, operator.add, etc.
# since the inputs are all hidet tensors, applying this function should resolve to
# the actual traced implementation
return code_obj

@staticmethod
def _callable_info(f: Callable) -> Tuple[str, str, int]:
if inspect.ismethod(f):
@@ -337,13 +378,13 @@ def load_arg(a, env):
attr = getattr(attr, atom)
hidet_env[node.name] = tensor_from_torch(attr) if isinstance(attr, torch.Tensor) else attr
elif node.op == "call_function":
hidet_func = Registry.registered_functions[node.target]
exec_func = self._lookup_function(node.target)
hidet_args = load_arg(node.args, hidet_env)
hidet_kwargs = load_arg(node.kwargs, hidet_env)
try:
hidet_env[node.name] = hidet_func(*hidet_args, **hidet_kwargs)
hidet_env[node.name] = exec_func(*hidet_args, **hidet_kwargs)
except Exception as e:
self._raise_exception(e, node.target, hidet_func, hidet_args, hidet_kwargs)
self._raise_exception(e, node.target, exec_func, hidet_args, hidet_kwargs)
elif node.op == "call_method":
args = load_arg(node.args, hidet_env)
kwargs = load_arg(node.kwargs, hidet_env)
@@ -425,7 +466,7 @@ def load_arg(a, env):
torch_kwargs = load_arg(node.kwargs, torch_env)
torch_env[node.name] = torch_func(*torch_args, **torch_kwargs)

hidet_func = Registry.registered_functions[torch_func]
hidet_func = self._lookup_function(node.target)
hidet_args = load_arg(node.args, hidet_env)
hidet_kwargs = load_arg(node.kwargs, hidet_env)

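Roughly, the new dispatch rule is: callables from the torch namespace still go through the registry, while anything else (builtins such as operator.add or getitem) is called directly on the traced hidet values. A standalone sketch of why the direct call suffices (not part of the diff; relies on hidet's operator overloads):

import operator
import hidet

x = hidet.randn([4, 8], device='cpu')
y = operator.add(x, x)                   # resolves via Tensor.__add__, no registry entry needed
d = operator.getitem(list(x.shape), 0)   # builtin getitem on a plain Python list just works
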
8 changes: 4 additions & 4 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -146,12 +146,12 @@ def bilinear(x_1: Tensor, x_2: Tensor, weight: Tensor, bias: Optional[Tensor]):
@register_function(operator.add)
@register_function(torch.ops.aten.add.Tensor)
def add(x: Tensor, y: Tensor):
return ops.add(x, y)
return x + y


@register_function(operator.iadd)
def iadd(x: Tensor, y: Tensor):
return ops.add(x, y)
return x + y


@register_function(torch.sin)
@@ -363,7 +363,7 @@ def zeros(*size, out=None, dtype=None, layout=None, device=None, pin_memory=Fals

@register_function(torch.ones)
def ones(
*size: Union[int, Sequence[int]],
*size: Union[Int, Sequence[Int]],
out: Optional[Tensor] = None,
dtype: Optional[torch.dtype] = None,
layout: Optional[torch.layout] = None,
@@ -382,7 +382,7 @@ def ones(
if isinstance(size[0], (list, tuple)):
size = size[0]

shape = [int(v) for v in size]
shape = [v if isinstance(v, hidet.ir.Expr) else int(v) for v in size]
if dtype is None:
dtype = torch.get_default_dtype()

12 changes: 10 additions & 2 deletions python/hidet/graph/frontend/torch/register_methods.py
@@ -14,7 +14,7 @@
from typing import List, Union
import torch

from hidet.ir.type import DataType
from hidet.ir.type import DataType, Int
from hidet.graph.tensor import Tensor
from hidet.graph import ops
from hidet.runtime.device import instantiate_device
@@ -130,7 +130,7 @@ def tensor_view(self: Tensor, *args) -> Tensor:
else:
if len(args) == 1 and isinstance(args[0], (list, tuple)):
args = args[0]
dst_shape = [int(arg) for arg in args]
dst_shape = list(args)
return ops.reshape(self, dst_shape)


@@ -161,6 +161,14 @@ def tensor_split(self: Tensor, split_size, dim=0) -> List[Tensor]:
return ops.split(self, axis=dim, parts_or_sections=parts)


@register_method(torch.Tensor.size)
def tensor_size(self: Tensor, dim=None) -> List[Int]:
if dim is None:
return self.shape
else:
return self.shape[dim]


@register_method(torch.Tensor.chunk)
def tensor_chunk(self: Tensor, chunks, dim=0) -> List[Tensor]:
dim_size = self.shape[dim]
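For context on the new torch.Tensor.size overload: patterns like the hypothetical module below are common in traced graphs, and with dynamic=True the value of x.size(0) reaching torch.ones is a symbolic hidet expression rather than a Python int (sketch only; Pad is a made-up example, not from this repo):

import torch

class Pad(torch.nn.Module):
    # x.size(0) yields the (possibly symbolic) batch dim, which flows into
    # torch.ones; the updated ones() keeps symbolic dims as hidet expressions
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ones = torch.ones(x.size(0), 1, device=x.device, dtype=x.dtype)
        return torch.cat([x, ones], dim=1)
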
15 changes: 13 additions & 2 deletions python/hidet/graph/frontend/torch/utils.py
@@ -112,8 +112,19 @@ def device_from_torch(torch_device) -> Device:
def symbol_like_torch(tensor) -> Tensor:
import hidet
import torch

if isinstance(tensor, torch.Tensor):
from torch._subclasses.fake_tensor import FakeTensor

if isinstance(tensor, FakeTensor):
# this should be fine for now; torch wraps around the sympy library
symbolic_shape = []
for s in tensor.shape:
try:
i = int(s)
except Exception: # pylint: disable=broad-except
i = str(s)
symbolic_shape.append(i)
return hidet.symbol(shape=symbolic_shape, dtype=dtype_from_torch(tensor.dtype).name, device=tensor.device.type)
elif isinstance(tensor, torch.Tensor):
return hidet.symbol(
shape=list(tensor.shape), dtype=dtype_from_torch(tensor.dtype).name, device=tensor.device.type
)
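The net effect of the FakeTensor branch, sketched with a plain hidet call ('s0' stands in for the dynamo shape symbol that str(s) produces above; not part of the diff):

import hidet

# dynamic dims come through as their string names, static dims as plain ints
sym = hidet.symbol(shape=['s0', 3, 224, 224], dtype='float32', device='cpu')
print(sym.shape)   # the first dimension stays a symbolic variable
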
10 changes: 9 additions & 1 deletion python/hidet/graph/graph_utils/instruments/debug_instrument.py
@@ -19,10 +19,11 @@
class GraphForwardDebugInstrument(GraphForwardInstrument):
_template = '{:>5} {:>30} {:>3} {:<25} {:>8} {:>8} {:>8} {:>10} {:>10} {:>10} {:>10}'

def __init__(self, output_dir='./outs/debug', print_summary=False, dump_outputs=False):
def __init__(self, output_dir='./outs/debug', print_summary=False, dump_outputs=False, dump_op=False):
self.output_dir: str = output_dir
self.print_summary: bool = print_summary
self.dump_outputs: bool = dump_outputs
self.dump_op: bool = dump_op

self.debugging: bool = False
self.summary_file: Optional[str] = None
@@ -141,6 +142,13 @@ def after_operator(self, op: Operator, inputs: List[Tensor], outputs: List[Tenso
with open(array_path, 'w') as f:
with np.printoptions(precision=8, edgeitems=30, linewidth=512):
f.write(str(array))
if self.dump_op:
op_path = os.path.join(
self.output_dir, '{}_{}{}.txt'.format(self.operator_idx, op.name, f'_def{idx}' if idx > 0 else '')
)
with open(op_path, 'w') as f:
f.write('Operator:\n{}\n'.format(op))
f.write('Task:\n{}\n'.format(op.task))

with open(self.summary_file, 'a') as f:
f.write('\n'.join(lines) + '\n')
6 changes: 3 additions & 3 deletions python/hidet/graph/ops/create.py
@@ -21,7 +21,7 @@

class FullTask(Task):
def __init__(
self, shape: Sequence[int], value: Union[int, float, bool, Constant, Expr], dtype: Union[DataType, str]
self, shape: Sequence[Int], value: Union[int, float, bool, Constant, Expr], dtype: Union[DataType, str]
):
dtype: DataType = data_type(dtype)
value: Constant = dtype(value) if isinstance(value, (int, float, bool)) else value
@@ -123,12 +123,12 @@ def infer_dtype(self, start, stop, step):
class FullOp(Operator):
def __init__(
self,
shape: Sequence[int],
shape: Sequence[Int],
value: Union[float, int, bool, Constant, Tensor],
dtype: Optional[DataType] = None,
device: Union[Device, str] = 'cpu',
):
shape = [int(v) for v in shape]
shape = list(shape)
device: Device = instantiate_device(device)

if isinstance(value, Tensor):
21 changes: 3 additions & 18 deletions python/hidet/graph/ops/normalize/layers.py
@@ -9,21 +9,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hidet

from ..utils import Tensor, normalize_dim
from ..arithmetic import rsqrt
from .norm import normalize
from .norm_f16 import normalize_f16


def resolve_norm_func(dtype):
if dtype == hidet.float32:
return normalize
elif dtype == hidet.float16:
return normalize_f16
else:
raise NotImplementedError("normalize function for dtype {} is not implemented".format(dtype))


def batch_norm_infer(x: Tensor, running_mean: Tensor, running_var: Tensor, epsilon=1e-5, axis=1) -> Tensor:
@@ -58,8 +46,7 @@ def instance_norm(x: Tensor, epsilon: float = 1e-5, accumulate_dtype: str = 'flo
The normalized tensor.
"""
dims = [dim for dim in range(2, len(x.shape))]
norm_func = resolve_norm_func(x.dtype)
return norm_func(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)
return normalize(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)


def layer_norm(x: Tensor, num_last_dims: int = 1, epsilon: float = 1e-5, accumulate_dtype: str = 'float32') -> Tensor:
@@ -82,9 +69,8 @@ def layer_norm(x: Tensor, num_last_dims: int = 1, epsilon: float = 1e-5, accumul
ret: Tensor
The normalized tensor.
"""
norm_func = resolve_norm_func(x.dtype)
dims = list(range(len(x.shape) - num_last_dims, len(x.shape)))
return norm_func(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)
return normalize(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)


def group_norm(x: Tensor, num_groups, epsilon: float = 1e-5, accumulate_dtype: str = 'float32'):
@@ -119,7 +105,6 @@ def group_norm(x: Tensor, num_groups, epsilon: float = 1e-5, accumulate_dtype: s

x = x.reshape(new_shape)
dims = list(range(2, len(x.shape)))
norm_func = resolve_norm_func(x.dtype)
normed = norm_func(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)
normed = normalize(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)

return normed.reshape(x_shape)
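
With the dtype-based dispatch removed, a quick sanity sketch of the unified path (assumes a CUDA device for the float16 case; not part of the diff):

import hidet

x = hidet.randn([2, 16, 768], dtype='float16', device='cuda')
y = hidet.ops.layer_norm(x, num_last_dims=1)   # float16 now goes through the same normalize op as float32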
