Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Add an option to print correctness report in hidet backend of torch dynamo #36

Merged
merged 6 commits into from
Dec 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/hidet/graph/frontend/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .availability import available, dynamo_available

if available():
from .interpreter import ImportedTorchModule, from_torch
from .interpreter import Interpreter, from_torch
from . import register_functions
from . import register_modules
from . import register_methods
Expand Down
47 changes: 38 additions & 9 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import torch
import hidet.option
from hidet.ir.type import data_type
from hidet.graph.ir.flow_graph import FlowGraph
from hidet.graph.transforms import PassContext, optimize
from .utils import serialize_output, deserialize_output
Expand All @@ -19,6 +20,7 @@ def __init__(self):
self._use_fp16_reduction: bool = False
self._use_cuda_graph: bool = True
self._print_input_graph: bool = False
self._correctness_report: bool = False

def __getitem__(self, item: str):
assert isinstance(item, str)
Expand Down Expand Up @@ -72,6 +74,12 @@ def print_input_graph(self, flag=True):
"""
self._print_input_graph = flag

def correctness_report(self, flag=True):
"""
Whether to check correctness and print report error
"""
self._correctness_report = flag


dynamo_config = DynamoConfig()

Expand Down Expand Up @@ -157,7 +165,7 @@ def onnx2hidet_backend(subgraph):
def hidet_backend(subgraph):
from hidet import Tensor
from torch._dynamo.optimizations.subgraph import SubGraph
from .interpreter import ImportedTorchModule
from .interpreter import Interpreter
from .utils import symbol_like_torch

assert isinstance(subgraph, SubGraph)
Expand All @@ -168,30 +176,51 @@ def hidet_backend(subgraph):
if dynamo_config['print_input_graph']:
subgraph.model.graph.print_tabular()

symbolic_inputs: List[Tensor] = []
# get the interpreter for the subgraph
assert isinstance(subgraph.model, torch.fx.GraphModule)
graph_module: torch.fx.GraphModule = subgraph.model
interpreter: Interpreter = hidet.frontend.from_torch(graph_module)

# prepare dummy and symbolic inputs for correctness and flow graph construction
symbolic_inputs: List[Tensor] = [] # for flow graph construction
for example_input in subgraph.example_inputs:
if isinstance(example_input, torch.Tensor):
symbolic_inputs.append(symbol_like_torch(example_input))
symbolic_input = symbol_like_torch(example_input)
symbolic_inputs.append(symbolic_input)
else:
raise ValueError('hidet_backend: only support torch.Tensor as example input')

if dynamo_config['correctness_report']:
# check correctness using random inputs
logger.info('start to check correctness')
dummy_inputs: List[Tensor] = [] # for correctness check
for symbolic_input in symbolic_inputs:
if data_type(symbolic_input.dtype).is_integer():
dummy_input = hidet.zeros_like(symbolic_input)
else:
dummy_input = hidet.randn_like(symbolic_input)
dummy_inputs.append(dummy_input)
report: str = interpreter.forward_with_check(*dummy_inputs)
logger.info('finish checking correctness')
print(report)

logger.info('hidet: symbolic inputs: ')
for symbolic_input in symbolic_inputs:
logger.info('hidet: %s', symbolic_input.signature())

assert isinstance(subgraph.model, torch.fx.GraphModule)
graph_module: torch.fx.GraphModule = subgraph.model
imported_module: ImportedTorchModule = hidet.frontend.from_torch(graph_module)

output = imported_module(*symbolic_inputs)
# symbolic run to get flow graph
output = interpreter(*symbolic_inputs)
output_format, output_tensors = serialize_output(output)
flow_graph: FlowGraph = hidet.trace_from(output_tensors, inputs=symbolic_inputs)

executor = generate_executor(flow_graph)

def wrapper(*args: Tensor):
outputs: Sequence[torch.Tensor] = executor(*args)
ret = deserialize_output(output_format, outputs)
return ret[0]
return ret

logger.info('finish generating the executor')

return wrapper

Expand Down