[Frontend] Dynamic shape fx trace #294

Merged · Jul 12, 2023 · 24 commits
Changes from 18 commits
53 changes: 40 additions & 13 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
@@ -15,6 +15,7 @@
import torch
import hidet.option
from hidet.ir.type import data_type
from hidet.ir.expr import is_constant
from hidet.graph.flow_graph import FlowGraph
from hidet.graph.transforms import PassContext, optimize
from hidet.runtime import CompiledGraph
@@ -103,40 +104,66 @@ def hidet_backend(graph_module, example_inputs):
interpreter: Interpreter = hidet.frontend.from_torch(graph_module)

# prepare dummy and symbolic inputs for correctness and flow graph construction
symbolic_inputs: List[Tensor] = [] # for flow graph construction
# unfortunately, when dynamic=True in torch.compile, there may exist other non-tensor parameters
# in example inputs
Member:
For those dynamic shapes, I am wondering whether these scalar parameters act as the shapes of the input tensors. If that's the case, we can ignore those scalar parameters.

Say a torch model gives us

sample_inputs = [tensor(['m', 'n']), 'm', 'n']

We can declare the symbol variables for 'm' and 'n' (when we define the symbol tensor) and ignore the 'm' and 'n' scalar parameters.

Member:
Any clue on this?
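For reference, a minimal sketch of that suggestion (illustrative only, not taken from this PR): hidet can already declare named symbolic dimensions, so once the symbol tensor carries 'm' and 'n', the scalar entries would be redundant.

import hidet

# Illustrative: declare a symbol tensor whose dims 'm' and 'n' stand in for the
# scalar entries that dynamo appends to example_inputs when dynamic=True.
sym = hidet.symbol(shape=['m', 'n'], dtype='float32', device='cpu')
print(sym.shape)  # (m, n): the symbolic dims, so no separate scalar parameters are needed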

inputs = [] # for flow graph construction
for example_input in example_inputs:
if isinstance(example_input, torch.Tensor):
symbolic_input = symbol_like_torch(example_input)
symbolic_inputs.append(symbolic_input)
inputs.append(symbolic_input)
elif isinstance(example_input, int):
inputs.append(example_input)
elif isinstance(example_input, torch.SymInt):
try:
inputs.append(int(example_input))
except Exception as e:
raise ValueError(f"hidet_backend: free symbolic example input {example_input}") from e
else:
raise ValueError('hidet_backend: only support torch.Tensor as example input')
raise ValueError(f'hidet_backend: unexpected example input {example_input}, type {type(example_input)}')

if dynamo_config['correctness_report']:
# check correctness using random inputs
logger.info('start to check correctness')
dummy_inputs: List[Tensor] = [] # for correctness check
for symbolic_input in symbolic_inputs:
if data_type(symbolic_input.dtype).is_integer():
dummy_input = hidet.zeros_like(symbolic_input)
# there exist some symbolic shapes; currently we don't support this option,
# as there is no principled way to get concrete shapes from symbolic shapes at this stage,
# since some models (e.g. resnet) require the image to be above a certain size.
if any(not all(is_constant(s) for s in t.shape) for t in inputs if isinstance(t, hidet.Tensor)):
raise ValueError("hidet_backend: cannot print correctness report with dynamic=True")
dummy_inputs = [] # for correctness check
for arg in inputs:
if isinstance(arg, hidet.Tensor):
if data_type(arg.dtype).is_integer():
dummy_input = hidet.zeros_like(arg)
else:
dummy_input = hidet.randn_like(arg)
else:
dummy_input = hidet.randn_like(symbolic_input)
dummy_input = arg
dummy_inputs.append(dummy_input)
report: str = interpreter.forward_with_check(*dummy_inputs)
logger.info('finish checking correctness')
print(report)

logger.info('hidet: symbolic inputs: ')
for symbolic_input in symbolic_inputs:
logger.info('hidet: %s', symbolic_input.signature())
logger.info('hidet: inputs: ')
for arg in inputs:
if isinstance(arg, hidet.Tensor):
logger.info('hidet: %s', arg.signature())
else:
logger.info('hidet: %s', arg)

# symbolic run to get flow graph
output = interpreter(*symbolic_inputs)
output = interpreter(*inputs)
output_format, output_tensors = serialize_output(output)
flow_graph: FlowGraph = hidet.trace_from(output_tensors, inputs=symbolic_inputs)
input_tensors = list(filter(lambda x: isinstance(x, hidet.Tensor), inputs))
# essentially, I think this is a bug in torch._inductor
# the example inputs have instances of torch.SymInt (when dynamic=True), while the inputs to the compiled model
# are torch.Tensors.
input_map = [isinstance(x, hidet.Tensor) for x in inputs]
flow_graph: FlowGraph = hidet.trace_from(output_tensors, inputs=input_tensors)

executor = generate_executor(flow_graph)

def wrapper(*args: Tensor):
args = [t for (t, is_hidet_tensor) in zip(args, input_map) if is_hidet_tensor]
outputs: Sequence[torch.Tensor] = executor(*args)
ret = deserialize_output(output_format, outputs)
return ret
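For context, a usage sketch of the new dynamic-shape path (the toy model and shapes are assumptions, not part of this PR; it assumes hidet registers its 'hidet' dynamo backend on import, as its docs describe). The wrapper above filters the non-tensor entries (the torch.SymInt sizes) out of the runtime arguments before invoking the executor, which is why input_map is recorded alongside input_tensors.

import torch
import hidet  # noqa: F401  (importing hidet registers the 'hidet' backend with dynamo)

class Net(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x + 1.0)

compiled = torch.compile(Net(), backend='hidet', dynamic=True)

# The first call traces and compiles; a later call with a different batch size
# reuses the same flow graph because the batch dimension stays symbolic.
print(compiled(torch.randn(4, 8)).shape)  # torch.Size([4, 8])
print(compiled(torch.randn(7, 8)).shape)  # torch.Size([7, 8])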
51 changes: 46 additions & 5 deletions python/hidet/graph/frontend/torch/interpreter.py
@@ -149,6 +149,22 @@ def tensor_from_torch(tensor: torch.Tensor) -> Tensor:
return hidet.graph.tensor.from_torch(tensor)


def is_torch_path(name: str) -> bool:
name = name.split(".")
if len(name) > 0:
return name[0] == "torch"
return False


def belong_to_torch(code_obj) -> bool:
belongs = False
if hasattr(code_obj, "__module__") and code_obj.__module__ is not None:
belongs |= is_torch_path(code_obj.__module__)
if not belongs and hasattr(code_obj, "__package__") and code_obj.__package__ is not None:
belongs |= is_torch_path(code_obj.__package__)
return belongs


class HidetModule:
def __init__(self, torch_module: torch.nn.Module):
self.mod: torch.nn.Module = torch_module
@@ -182,6 +198,15 @@ def __init__(self, graph_module: torch.fx.GraphModule):
self.torch_modules: Dict[str, torch.nn.Module] = dict(graph_module.named_modules())
self.hidet_modules: Dict[str, HidetModule] = {}

# basically dynamo further wraps some builtin functions with annoying local wrapper functions,
# which get dispatched incorrectly
self.ignore_funcs: Dict[str, Callable] = {
# see torch._dynamo.variables.lists.SizeVariable.get_item_dyn
# this signifies that the target of getitem is a torch.Size; we overload torch.Tensor.size to
# return a list, so this method needs to be overloaded in the interpreter as well
'_dynamo_get_item_lambda': lambda target, index: target[index]
}

self._check_support()

def __call__(self, *args):
@@ -210,8 +235,10 @@ def _check_support(self):
if torch_cls not in Registry.registered_modules:
not_supported.add(torch_cls)
elif node.op == "call_function":
if node.target not in Registry.registered_functions:
target_fn = self._lookup_function(node.target)
if target_fn is None:
not_supported.add(node.target)

if len(not_supported) > 0:
lines = []
lines.append("The following modules/functions are not supported by hidet yet:")
@@ -233,6 +260,20 @@ def _lookup_hidet_method(self, torch_method):
raise NotImplementedError(f"hidet: method {method_name} is not supported yet.")
return Registry.registered_methods[torch_method]

def _lookup_function(self, code_obj):
if code_obj.__name__ in self.ignore_funcs:
return self.ignore_funcs[code_obj.__name__]
if belong_to_torch(code_obj):
if code_obj in Registry.registered_functions:
return Registry.registered_functions[code_obj]
else:
return None
else:
# this branch handles all the other cases, such as getitem, operator.add, etc.
# since the inputs are all hidet tensors, applying this function should resolve to
# the actual traced implementation
return code_obj

@staticmethod
def _callable_info(f: Callable) -> Tuple[str, str, int]:
if inspect.ismethod(f):
@@ -315,13 +356,13 @@ def load_arg(a, env):
attr = getattr(attr, atom)
hidet_env[node.name] = tensor_from_torch(attr) if isinstance(attr, torch.Tensor) else attr
elif node.op == "call_function":
hidet_func = Registry.registered_functions[node.target]
exec_func = self._lookup_function(node.target)
hidet_args = load_arg(node.args, hidet_env)
hidet_kwargs = load_arg(node.kwargs, hidet_env)
try:
hidet_env[node.name] = hidet_func(*hidet_args, **hidet_kwargs)
hidet_env[node.name] = exec_func(*hidet_args, **hidet_kwargs)
except Exception as e:
self._raise_exception(e, node.target, hidet_func, hidet_args, hidet_kwargs)
self._raise_exception(e, node.target, exec_func, hidet_args, hidet_kwargs)
elif node.op == "call_method":
args = load_arg(node.args, hidet_env)
kwargs = load_arg(node.kwargs, hidet_env)
@@ -403,7 +444,7 @@ def load_arg(a, env):
torch_kwargs = load_arg(node.kwargs, torch_env)
torch_env[node.name] = torch_func(*torch_args, **torch_kwargs)

hidet_func = Registry.registered_functions[torch_func]
hidet_func = self._lookup_function(node.target)
hidet_args = load_arg(node.args, hidet_env)
hidet_kwargs = load_arg(node.kwargs, hidet_env)

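A quick sanity check of the dispatch rule introduced here (the import path is inferred from this file and should be treated as an assumption): callables living under the torch namespace must have a registered hidet mapping, while everything else (operator.add, getitem, ...) is applied to the hidet arguments directly.

import operator
import torch

from hidet.graph.frontend.torch.interpreter import belong_to_torch

print(belong_to_torch(torch.nn.functional.relu))  # True  -> must appear in Registry.registered_functions
print(belong_to_torch(operator.add))              # False -> falls through and runs on the hidet tensors directly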
8 changes: 4 additions & 4 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -145,12 +145,12 @@ def bilinear(x_1: Tensor, x_2: Tensor, weight: Tensor, bias: Optional[Tensor]):
@register_function(operator.add)
@register_function(torch.ops.aten.add.Tensor)
def add(x: Tensor, y: Tensor):
return ops.add(x, y)
return x + y


@register_function(operator.iadd)
def iadd(x: Tensor, y: Tensor):
return ops.add(x, y)
return x + y
Member:
So the x and y could be DynInt?
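A minimal sketch of the scenario the question points at (names and shapes are illustrative assumptions): after this change the operands reaching add/iadd can be symbolic shape integers rather than tensors, and Python's + dispatches correctly for both, while ops.add expects tensors.

import hidet

a = hidet.symbol(shape=['m', 4], dtype='float32', device='cpu')
m = a.shape[0]        # a symbolic Int (an expression), e.g. what Tensor.size() now yields
print(m + 1)          # a symbolic expression rather than a concrete int
print((a + a).shape)  # still an elementwise tensor add: (m, 4)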



@register_function(torch.sin)
@@ -362,7 +362,7 @@ def zeros(*size, out=None, dtype=None, layout=None, device=None, pin_memory=Fals

@register_function(torch.ones)
def ones(
*size: Union[int, Sequence[int]],
*size: Union[Int, Sequence[Int]],
out: Optional[Tensor] = None,
dtype: Optional[torch.dtype] = None,
layout: Optional[torch.layout] = None,
Expand All @@ -381,7 +381,7 @@ def ones(
if isinstance(size[0], (list, tuple)):
size = size[0]

shape = [int(v) for v in size]
shape = [v if isinstance(v, hidet.ir.Expr) else int(v) for v in size]
if dtype is None:
dtype = torch.get_default_dtype()

9 changes: 7 additions & 2 deletions python/hidet/graph/frontend/torch/register_methods.py
@@ -14,7 +14,7 @@
from typing import List, Union
import torch

from hidet.ir.type import DataType
from hidet.ir.type import DataType, Int
from hidet.graph.tensor import Tensor
from hidet.graph import ops
from hidet.runtime.device import instantiate_device
@@ -130,7 +130,7 @@ def tensor_view(self: Tensor, *args) -> Tensor:
else:
if len(args) == 1 and isinstance(args[0], (list, tuple)):
args = args[0]
dst_shape = [int(arg) for arg in args]
dst_shape = list(args)
return ops.reshape(self, dst_shape)


@@ -161,6 +161,11 @@ def tensor_split(self: Tensor, split_size, dim=0) -> List[Tensor]:
return ops.split(self, axis=dim, parts_or_sections=parts)


@register_method(torch.Tensor.size)
def tensor_size(self: Tensor) -> List[Int]:
return self.shape


@register_method(torch.Tensor.chunk)
def tensor_chunk(self: Tensor, chunks, dim=0) -> List[Tensor]:
dim_size = self.shape[dim]
15 changes: 13 additions & 2 deletions python/hidet/graph/frontend/torch/utils.py
@@ -112,8 +112,19 @@ def device_from_torch(torch_device) -> Device:
def symbol_like_torch(tensor) -> Tensor:
import hidet
import torch

if isinstance(tensor, torch.Tensor):
from torch._subclasses.fake_tensor import FakeTensor

if isinstance(tensor, FakeTensor):
# this should be fine for now; torch wraps around the sympy library
symbolic_shape = []
for s in tensor.shape:
try:
i = int(s)
except Exception: # pylint: disable=broad-except
i = str(s)
symbolic_shape.append(i)
return hidet.symbol(shape=symbolic_shape, dtype=dtype_from_torch(tensor.dtype).name, device=tensor.device.type)
elif isinstance(tensor, torch.Tensor):
return hidet.symbol(
shape=list(tensor.shape), dtype=dtype_from_torch(tensor.dtype).name, device=tensor.device.type
)
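As a small check of the fallback path above (module path inferred from this file; a sketch, not from the PR): a regular torch.Tensor keeps its concrete shape, while a FakeTensor with SymInt sizes gets those sizes recorded as named symbolic dims.

import torch

from hidet.graph.frontend.torch.utils import symbol_like_torch

sym = symbol_like_torch(torch.randn(2, 3))
print(sym.shape)  # (2, 3): concrete dims for a regular tensor
print(sym.dtype)  # float32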
10 changes: 9 additions & 1 deletion python/hidet/graph/graph_utils/instruments/debug_instrument.py
@@ -19,10 +19,11 @@
class GraphForwardDebugInstrument(GraphForwardInstrument):
_template = '{:>5} {:>30} {:>3} {:<25} {:>8} {:>8} {:>8} {:>10} {:>10} {:>10} {:>10}'

def __init__(self, output_dir='./outs/debug', print_summary=False, dump_outputs=False):
def __init__(self, output_dir='./outs/debug', print_summary=False, dump_outputs=False, dump_op=False):
self.output_dir: str = output_dir
self.print_summary: bool = print_summary
self.dump_outputs: bool = dump_outputs
self.dump_op: bool = dump_op

self.debugging: bool = False
self.summary_file: Optional[str] = None
@@ -141,6 +142,13 @@ def after_operator(self, op: Operator, inputs: List[Tensor], outputs: List[Tenso
with open(array_path, 'w') as f:
with np.printoptions(precision=8, edgeitems=30, linewidth=512):
f.write(str(array))
if self.dump_op:
op_path = os.path.join(
self.output_dir, '{}_{}{}.txt'.format(self.operator_idx, op.name, f'_def{idx}' if idx > 0 else '')
)
with open(op_path, 'w') as f:
f.write('Operator:\n{}\n'.format(op))
f.write('Task:\n{}\n'.format(op.task))

with open(self.summary_file, 'a') as f:
f.write('\n'.join(lines) + '\n')
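A hedged usage sketch of the new flag (only the constructor signature is taken from this diff; how the instrument is attached to a graph forward pass is outside it): with dump_op=True the instrument also writes each operator and its Task definition next to the dumped outputs.

from hidet.graph.graph_utils.instruments.debug_instrument import GraphForwardDebugInstrument

# Dump a per-operator summary plus the operator/Task definitions under ./outs/debug.
instrument = GraphForwardDebugInstrument(output_dir='./outs/debug', print_summary=True, dump_op=True)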
6 changes: 3 additions & 3 deletions python/hidet/graph/ops/create.py
@@ -21,7 +21,7 @@

class FullTask(Task):
def __init__(
self, shape: Sequence[int], value: Union[int, float, bool, Constant, Expr], dtype: Union[DataType, str]
self, shape: Sequence[Int], value: Union[int, float, bool, Constant, Expr], dtype: Union[DataType, str]
):
dtype: DataType = data_type(dtype)
value: Constant = dtype(value) if isinstance(value, (int, float, bool)) else value
@@ -123,12 +123,12 @@ def infer_dtype(self, start, stop, step):
class FullOp(Operator):
def __init__(
self,
shape: Sequence[int],
shape: Sequence[Int],
value: Union[float, int, bool, Constant, Tensor],
dtype: Optional[DataType] = None,
device: Union[Device, str] = 'cpu',
):
shape = [int(v) for v in shape]
shape = list(shape)
device: Device = instantiate_device(device)

if isinstance(value, Tensor):
21 changes: 3 additions & 18 deletions python/hidet/graph/ops/normalize/layers.py
@@ -9,21 +9,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hidet

from ..utils import Tensor, normalize_dim
from ..arithmetic import rsqrt
from .norm import normalize
from .norm_f16 import normalize_f16


def resolve_norm_func(dtype):
if dtype == hidet.float32:
return normalize
elif dtype == hidet.float16:
return normalize_f16
else:
raise NotImplementedError("normalize function for dtype {} is not implemented".format(dtype))


def batch_norm_infer(x: Tensor, running_mean: Tensor, running_var: Tensor, epsilon=1e-5, axis=1) -> Tensor:
@@ -58,8 +46,7 @@ def instance_norm(x: Tensor, epsilon: float = 1e-5, accumulate_dtype: str = 'flo
The normalized tensor.
"""
dims = [dim for dim in range(2, len(x.shape))]
norm_func = resolve_norm_func(x.dtype)
return norm_func(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)
return normalize(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)


def layer_norm(x: Tensor, num_last_dims: int = 1, epsilon: float = 1e-5, accumulate_dtype: str = 'float32') -> Tensor:
@@ -82,9 +69,8 @@ def layer_norm(x: Tensor, num_last_dims: int = 1, epsilon: float = 1e-5, accumul
ret: Tensor
The normalized tensor.
"""
norm_func = resolve_norm_func(x.dtype)
dims = list(range(len(x.shape) - num_last_dims, len(x.shape)))
return norm_func(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)
return normalize(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)


def group_norm(x: Tensor, num_groups, epsilon: float = 1e-5, accumulate_dtype: str = 'float32'):
@@ -119,7 +105,6 @@ def group_norm(x: Tensor, num_groups, epsilon: float = 1e-5, accumulate_dtype: s

x = x.reshape(new_shape)
dims = list(range(2, len(x.shape)))
norm_func = resolve_norm_func(x.dtype)
normed = norm_func(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)
normed = normalize(x, axis=dims, epsilon=epsilon, accumulate_dtype=accumulate_dtype)

return normed.reshape(x_shape)