
[Dynamo] Refactoring code for Hidet remote compilation #369

Merged
File: python/hidet/graph/frontend/torch/dynamo_backends.py (170 changes: 86 additions & 84 deletions)
@@ -10,28 +10,65 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # pylint: disable=no-name-in-module
-from typing import List, Callable, Sequence, Union
+from typing import List, Sequence, Union
 import logging
 import torch
 import hidet.option
 from hidet import Tensor
+from hidet.ir import dtypes
 from hidet.ir.type import DataType
 from hidet.ir.expr import SymbolVar
-from hidet.runtime import CompiledGraph
 from hidet.graph.flow_graph import FlowGraph
 from hidet.graph.transforms import PassContext, optimize
+from hidet.runtime import CompiledGraph
 from hidet.cuda.graph import CudaGraphCreationError
-from hidet.ir import dtypes
-from .utils import serialize_output, deserialize_output, resolve_save_dir_multigraph
 from .dynamo_config import dynamo_config
+
+from .interpreter import Interpreter
+from .utils import serialize_output, deserialize_output, resolve_save_dir_multigraph
+from .utils import symbol_like_torch

 logger = logging.getLogger(__name__)


-def generate_executor(flow_graph: FlowGraph) -> Callable:
+def get_flow_graph(interpreter, example_inputs) -> FlowGraph:
+    # prepare dummy and symbolic inputs for correctness and flow graph construction
+    inputs: List[Union[Tensor, SymbolVar, int, bool, float]] = []  # for flow graph construction
+    for example_input in example_inputs:
+        if isinstance(example_input, torch.Tensor):
+            symbolic_input = symbol_like_torch(example_input)
+            inputs.append(symbolic_input)
+        elif isinstance(example_input, (int, bool, float)):
+            inputs.append(example_input)
+        elif isinstance(example_input, torch.SymInt):
+            from torch.fx.experimental.symbolic_shapes import SymNode
+
+            node: SymNode = example_input.node
+            try:
+                inputs.append(node.pytype(example_input))
+            except RuntimeError:
Review thread on the "except RuntimeError:" line above:

Collaborator: should we be concerned here? RuntimeError seems to be quite generic, so it might not be just the error that you think you can handle. I suggest we log a warning down below if it does not happen too often. If possible, I prefer that you check the pre-condition that would cause the runtime error and, instead of try-except, use:

    if not (the condition that would trigger the runtime error):
        inputs.append(...)
    else:
        # is a symbolic scalar input

Collaborator: I see that this part is ported from the original code, so you might not know why. Please ignore it for now.

Member: Hi @xinli-git, FYI, we (Allan and I) have tried that, but found it is hard to identify the condition that would trigger the RuntimeError.
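For future readers, a minimal sketch of the precondition-style check suggested above. It assumes SymNode.has_hint() (defined in torch.fx.experimental.symbolic_shapes) is an adequate stand-in for the condition that makes node.pytype(example_input) raise; an untested suggestion, not part of this PR:

    node: SymNode = example_input.node
    if node.has_hint():
        # a concrete value backs this SymInt: specialize it to a plain scalar,
        # as node.pytype(example_input) does on the non-raising path
        inputs.append(node.pytype(example_input))
    else:
        # unbacked, purely symbolic scalar input
        pytype2dtype = {int: dtypes.int32, float: dtypes.float32, bool: dtypes.boolean}
        inputs.append(hidet.symbol_var(name=str(example_input), dtype=pytype2dtype[node.pytype]))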

+                # is a symbolic scalar input
+                pytype2dtype = {int: dtypes.int32, float: dtypes.float32, bool: dtypes.boolean}
+                inputs.append(hidet.symbol_var(name=str(example_input), dtype=pytype2dtype[node.pytype]))
+        else:
+            raise ValueError(f"hidet_backend: unexpected example input {example_input}, type {type(example_input)}")

+    logger.info('hidet: inputs: ')
+    for arg in inputs:
+        if isinstance(arg, hidet.Tensor):
+            logger.info('hidet: %s', arg.signature())
+        else:
+            logger.info('hidet: %s', arg)
+
+    output = interpreter(*inputs)
+    output_format, output_tensors = serialize_output(output)
+    input_tensors = [x for x in inputs if isinstance(x, hidet.Tensor)]
+
+    return (hidet.trace_from(output_tensors, inputs=input_tensors), inputs, output_format)


+def get_compiled_graph(flow_graph: FlowGraph):
     use_fp16 = dynamo_config['use_fp16']
     use_fp16_reduction = dynamo_config['use_fp16_reduction']
-    use_cuda_graph = dynamo_config['use_cuda_graph']
     use_attention = dynamo_config['use_attention']
     search_space = dynamo_config['search_space']
     parallel_k = dynamo_config['parallel_k']
@@ -55,21 +92,14 @@ def generate_executor(flow_graph: FlowGraph) -> Callable:
     logger.info('finish optimizing the flow graph')

     logger.info('schedule search space: %d', search_space)
-
-    def preprocess_inputs(inputs: Sequence[torch.Tensor]) -> List[hidet.Tensor]:
-        torch_inputs: List[torch.Tensor] = []
-        for x in inputs:
-            if not x.is_contiguous():
-                # warnings.warn_once('Hidet received a non-contiguous torch input tensor, converting it to contiguous')
-                x = x.contiguous()
-            torch_inputs.append(x)
-        hidet_inputs: List[hidet.Tensor] = [hidet.from_torch(tensor) for tensor in torch_inputs]
-        return hidet_inputs
-
     logger.info('start to build the optimized computation graph')
     cgraph: CompiledGraph = graph_opt.build(space=search_space)
     logger.info('finish building computation graph')
+    return cgraph
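As context for the refactoring: get_flow_graph traces the interpreted subgraph into a hidet FlowGraph, and get_compiled_graph optimizes and builds it. A minimal standalone sketch of the same trace / optimize / build pipeline through hidet's public API (the op, shapes, and device are illustrative assumptions, not taken from this PR):

    import hidet

    # symbolic input tensor and a concrete weight
    x = hidet.symbol(shape=[8, 16], dtype='float32', device='cuda')
    w = hidet.randn([16, 32], device='cuda')
    y = hidet.ops.relu(hidet.ops.matmul(x, w))

    flow_graph = hidet.trace_from(y, inputs=[x])      # what get_flow_graph produces
    with hidet.graph.PassContext():
        graph_opt = hidet.graph.optimize(flow_graph)  # graph-level passes
    cgraph = graph_opt.build(space=0)                 # space 0: smallest tuning space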


+def get_wrapper(cgraph: CompiledGraph, inputs, output_format):
+    use_cuda_graph = dynamo_config['use_cuda_graph']
     if use_cuda_graph:
         try:
             runner = cgraph.cuda_graph()
@@ -78,20 +108,49 @@ def preprocess_inputs(inputs: Sequence[torch.Tensor]) -> List[hidet.Tensor]:
     else:
         runner = cgraph

+    def preprocess_inputs(inputs: Sequence[torch.Tensor]) -> List[hidet.Tensor]:
+        torch_inputs: List[torch.Tensor] = []
+        for x in inputs:
+            if not x.is_contiguous():
+                # warnings.warn_once('Hidet received a non-contiguous torch input tensor, converting it to contiguous')
+                x = x.contiguous()
+            torch_inputs.append(x)
+        hidet_inputs: List[hidet.Tensor] = [hidet.from_torch(tensor) for tensor in torch_inputs]
+        return hidet_inputs
+
     def run(*inputs: torch.Tensor):
         hidet_inputs = preprocess_inputs(inputs)
         hidet_outputs: List[hidet.Tensor] = runner.run_async(hidet_inputs)
         torch_outputs: List[torch.Tensor] = [tensor.torch() for tensor in hidet_outputs]
         return torch_outputs

-    return run
+    def wrapper(*args: Tensor):
+        tensor_args = []
+        for param, arg in zip(inputs, args):
+            if isinstance(param, Tensor):
+                tensor_args.append(arg)
+            elif isinstance(param, SymbolVar):
+                dtype = param.type
+                assert isinstance(dtype, DataType)
+                if dtype.name == 'int32':
+                    from hidet.ffi import runtime_api
+
+                    runtime_api.set_symbol_value(param.name, int(arg))
+                else:
+                    raise ValueError(f'hidet_backend: unsupported symbolic dtype {dtype}. We only support int32 now.')
+            else:
+                # ignore constant
+                pass
+        outputs: Sequence[torch.Tensor] = run(*tensor_args)
+        ret = deserialize_output(output_format, outputs)
+        return ret
+
+    logger.info('finish generating the executor')
+
+    return wrapper
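A hedged sketch of the symbolic-scalar mechanism the wrapper relies on: runtime_api.set_symbol_value (the call used above) binds the value of a named int32 symbol before the compiled graph executes. It assumes hidet.symbol accepts a SymbolVar as a dimension; names and shapes are illustrative:

    import hidet
    from hidet.ffi import runtime_api

    n = hidet.symbol_var(name='n')  # int32 symbol, like those created for SymInt inputs above
    x = hidet.symbol(shape=[n, 16], dtype='float32', device='cuda')
    flow_graph = hidet.trace_from(hidet.ops.relu(x), inputs=[x])
    cgraph = hidet.graph.optimize(flow_graph).build(space=0)

    runtime_api.set_symbol_value('n', 8)  # bind n = 8 for this call
    outputs = cgraph.run_async([hidet.randn([8, 16], device='cuda')])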

-def hidet_backend(graph_module, example_inputs):
-    from hidet import Tensor
-    from .interpreter import Interpreter
-    from .utils import symbol_like_torch
-
+def hidet_backend(graph_module, example_inputs):
     assert isinstance(graph_module, torch.fx.GraphModule)

     logger.info('received a subgraph with %d nodes to optimize', len(graph_module.graph.nodes))
@@ -105,27 +164,6 @@ def hidet_backend(graph_module, example_inputs):
     # get the interpreter for the subgraph
     interpreter: Interpreter = hidet.frontend.from_torch(graph_module)

-    # prepare dummy and symbolic inputs for correctness and flow graph construction
-    inputs: List[Union[Tensor, SymbolVar, int, bool, float]] = []  # for flow graph construction
-    for example_input in example_inputs:
-        if isinstance(example_input, torch.Tensor):
-            symbolic_input = symbol_like_torch(example_input)
-            inputs.append(symbolic_input)
-        elif isinstance(example_input, (int, bool, float)):
-            inputs.append(example_input)
-        elif isinstance(example_input, torch.SymInt):
-            from torch.fx.experimental.symbolic_shapes import SymNode
-
-            node: SymNode = example_input.node
-            try:
-                inputs.append(node.pytype(example_input))
-            except RuntimeError:
-                # is a symbolic scalar input
-                pytype2dtype = {int: dtypes.int32, float: dtypes.float32, bool: dtypes.boolean}
-                inputs.append(hidet.symbol_var(name=str(example_input), dtype=pytype2dtype[node.pytype]))
-        else:
-            raise ValueError(f'hidet_backend: unexpected example input {example_input}, type {type(example_input)}')
-
     if dynamo_config['correctness_report']:
         # check correctness using random inputs
         def wrapper(*args):
@@ -135,45 +173,9 @@ def wrapper(*args):
             return output

         return wrapper
-    else:
-        logger.info('hidet: inputs: ')
-        for arg in inputs:
-            if isinstance(arg, hidet.Tensor):
-                logger.info('hidet: %s', arg.signature())
-            else:
-                logger.info('hidet: %s', arg)
-
-        # symbolic run to get flow graph
-        output = interpreter(*inputs)
-        output_format, output_tensors = serialize_output(output)
-        input_tensors = [x for x in inputs if isinstance(x, hidet.Tensor)]
-        flow_graph: FlowGraph = hidet.trace_from(output_tensors, inputs=input_tensors)
-
-        executor = generate_executor(flow_graph)
-
-        def wrapper(*args: Tensor):
-            tensor_args = []
-            for param, arg in zip(inputs, args):
-                if isinstance(param, Tensor):
-                    tensor_args.append(arg)
-                elif isinstance(param, SymbolVar):
-                    dtype = param.type
-                    assert isinstance(dtype, DataType)
-                    if dtype.name == 'int32':
-                        from hidet.ffi import runtime_api
-
-                        runtime_api.set_symbol_value(param.name, int(arg))
-                    else:
-                        raise ValueError(
-                            f'hidet_backend: unsupported symbolic dtype {dtype}. We only support int32 now.'
-                        )
-                else:
-                    # ignore constant
-                    pass
-            outputs: Sequence[torch.Tensor] = executor(*tensor_args)
-            ret = deserialize_output(output_format, outputs)
-            return ret
-
-        logger.info('finish generating the executor')
-
-        return wrapper
+
+    flow_graph, inputs, output_format = get_flow_graph(interpreter, example_inputs)
+
+    cgraph = get_compiled_graph(flow_graph)
+
+    return get_wrapper(cgraph, inputs, output_format)
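For completeness, the usual entry point that exercises hidet_backend: importing hidet registers a torch.compile backend named 'hidet' (standard usage from the hidet documentation; a CUDA device is assumed and the model is illustrative, not part of this diff):

    import torch
    import hidet  # registers the 'hidet' dynamo backend

    model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU()).cuda()
    model_opt = torch.compile(model, backend='hidet')

    x = torch.randn(8, 16, device='cuda')
    y = model_opt(x)  # the first call hands each FX subgraph to hidet_backend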