[Dynamo] Refactoring code for Hidet remote compilation (#369)
Hidet remote compilation needs to access the FlowGraph in order to hash it, so FlowGraph construction has been moved into a separate function (get_flow_graph).

Remote compilation also needs to access the CompiledGraph, since it is sent from the server to the client, so building it has been moved into a separate function (get_compiled_graph).

Since the remote compilation client receives the CompiledGraph, the executor and wrapper function that use this cgraph have been moved into a separate function (get_wrapper).
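As a rough sketch of the resulting control flow (the function names are the ones introduced in this diff; the client/server split described above is the motivation for the refactoring, not part of this commit), hidet_backend now composes the three pieces like this:

    # Sketch: composition of the three extracted functions (see hidet_backend below).
    # A remote-compilation client could run get_flow_graph locally to hash the
    # FlowGraph, let the server run get_compiled_graph, and then wrap the
    # CompiledGraph it receives with get_wrapper.
    interpreter = hidet.frontend.from_torch(graph_module)
    flow_graph, inputs, output_format = get_flow_graph(interpreter, example_inputs)
    cgraph = get_compiled_graph(flow_graph)
    return get_wrapper(cgraph, inputs, output_format)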
destefy committed Oct 25, 2023
1 parent f1232eb commit 0d4ce21
Showing 1 changed file with 86 additions and 84 deletions.
170 changes: 86 additions & 84 deletions in python/hidet/graph/frontend/torch/dynamo_backends.py
@@ -10,28 +10,65 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # pylint: disable=no-name-in-module
-from typing import List, Callable, Sequence, Union
+from typing import List, Sequence, Union
 import logging
 import torch
 import hidet.option
+from hidet import Tensor
+from hidet.ir import dtypes
 from hidet.ir.type import DataType
 from hidet.ir.expr import SymbolVar
+from hidet.runtime import CompiledGraph
 from hidet.graph.flow_graph import FlowGraph
 from hidet.graph.transforms import PassContext, optimize
-from hidet.runtime import CompiledGraph
 from hidet.cuda.graph import CudaGraphCreationError
-from hidet.ir import dtypes
-from .utils import serialize_output, deserialize_output, resolve_save_dir_multigraph
 from .dynamo_config import dynamo_config
 
+from .interpreter import Interpreter
+from .utils import serialize_output, deserialize_output, resolve_save_dir_multigraph
+from .utils import symbol_like_torch
+
 logger = logging.getLogger(__name__)
 
 
-def generate_executor(flow_graph: FlowGraph) -> Callable:
+def get_flow_graph(interpreter, example_inputs) -> FlowGraph:
+    # prepare dummy and symbolic inputs for correctness and flow graph construction
+    inputs: List[Union[Tensor, SymbolVar, int, bool, float]] = []  # for flow graph construction
+    for example_input in example_inputs:
+        if isinstance(example_input, torch.Tensor):
+            symbolic_input = symbol_like_torch(example_input)
+            inputs.append(symbolic_input)
+        elif isinstance(example_input, (int, bool, float)):
+            inputs.append(example_input)
+        elif isinstance(example_input, torch.SymInt):
+            from torch.fx.experimental.symbolic_shapes import SymNode
+
+            node: SymNode = example_input.node
+            try:
+                inputs.append(node.pytype(example_input))
+            except RuntimeError:
+                # is a symbolic scalar input
+                pytype2dtype = {int: dtypes.int32, float: dtypes.float32, bool: dtypes.boolean}
+                inputs.append(hidet.symbol_var(name=str(example_input), dtype=pytype2dtype[node.pytype]))
+        else:
+            raise ValueError(f"hidet_backend: unexpected example input {example_input}, type {type(example_input)}")
+
+    logger.info('hidet: inputs: ')
+    for arg in inputs:
+        if isinstance(arg, hidet.Tensor):
+            logger.info('hidet: %s', arg.signature())
+        else:
+            logger.info('hidet: %s', arg)
+
+    output = interpreter(*inputs)
+    output_format, output_tensors = serialize_output(output)
+    input_tensors = [x for x in inputs if isinstance(x, hidet.Tensor)]
+
+    return (hidet.trace_from(output_tensors, inputs=input_tensors), inputs, output_format)
+
+
+def get_compiled_graph(flow_graph: FlowGraph):
     use_fp16 = dynamo_config['use_fp16']
     use_fp16_reduction = dynamo_config['use_fp16_reduction']
-    use_cuda_graph = dynamo_config['use_cuda_graph']
     use_attention = dynamo_config['use_attention']
     search_space = dynamo_config['search_space']
     parallel_k = dynamo_config['parallel_k']
@@ -55,21 +92,14 @@ def generate_executor(flow_graph: FlowGraph) -> Callable:
     logger.info('finish optimizing the flow graph')
 
     logger.info('schedule search space: %d', search_space)
-
-    def preprocess_inputs(inputs: Sequence[torch.Tensor]) -> List[hidet.Tensor]:
-        torch_inputs: List[torch.Tensor] = []
-        for x in inputs:
-            if not x.is_contiguous():
-                # warnings.warn_once('Hidet received a non-contiguous torch input tensor, converting it to contiguous')
-                x = x.contiguous()
-            torch_inputs.append(x)
-        hidet_inputs: List[hidet.Tensor] = [hidet.from_torch(tensor) for tensor in torch_inputs]
-        return hidet_inputs
-
     logger.info('start to build the optimized computation graph')
     cgraph: CompiledGraph = graph_opt.build(space=search_space)
     logger.info('finish building computation graph')
-
+    return cgraph
+
+
+def get_wrapper(cgraph: CompiledGraph, inputs, output_format):
+    use_cuda_graph = dynamo_config['use_cuda_graph']
     if use_cuda_graph:
         try:
             runner = cgraph.cuda_graph()
@@ -78,20 +108,49 @@ def preprocess_inputs(inputs: Sequence[torch.Tensor]) -> List[hidet.Tensor]:
     else:
         runner = cgraph
 
+    def preprocess_inputs(inputs: Sequence[torch.Tensor]) -> List[hidet.Tensor]:
+        torch_inputs: List[torch.Tensor] = []
+        for x in inputs:
+            if not x.is_contiguous():
+                # warnings.warn_once('Hidet received a non-contiguous torch input tensor, converting it to contiguous')
+                x = x.contiguous()
+            torch_inputs.append(x)
+        hidet_inputs: List[hidet.Tensor] = [hidet.from_torch(tensor) for tensor in torch_inputs]
+        return hidet_inputs
+
     def run(*inputs: torch.Tensor):
         hidet_inputs = preprocess_inputs(inputs)
         hidet_outputs: List[hidet.Tensor] = runner.run_async(hidet_inputs)
         torch_outputs: List[torch.Tensor] = [tensor.torch() for tensor in hidet_outputs]
         return torch_outputs
 
-    return run
+    def wrapper(*args: Tensor):
+        tensor_args = []
+        for param, arg in zip(inputs, args):
+            if isinstance(param, Tensor):
+                tensor_args.append(arg)
+            elif isinstance(param, SymbolVar):
+                dtype = param.type
+                assert isinstance(dtype, DataType)
+                if dtype.name == 'int32':
+                    from hidet.ffi import runtime_api
+
+                    runtime_api.set_symbol_value(param.name, int(arg))
+                else:
+                    raise ValueError(f'hidet_backend: unsupported symbolic dtype {dtype}. We only support int32 now.')
+            else:
+                # ignore constant
+                pass
+        outputs: Sequence[torch.Tensor] = run(*tensor_args)
+        ret = deserialize_output(output_format, outputs)
+        return ret
+
+    logger.info('finish generating the executor')
+
+    return wrapper
 
 
-def hidet_backend(graph_module, example_inputs):
-    from hidet import Tensor
-    from .interpreter import Interpreter
-    from .utils import symbol_like_torch
-
+def hidet_backend(graph_module, example_inputs):
     assert isinstance(graph_module, torch.fx.GraphModule)
 
     logger.info('received a subgraph with %d nodes to optimize', len(graph_module.graph.nodes))
@@ -105,27 +164,6 @@ def hidet_backend(graph_module, example_inputs):
     # get the interpreter for the subgraph
     interpreter: Interpreter = hidet.frontend.from_torch(graph_module)
 
-    # prepare dummy and symbolic inputs for correctness and flow graph construction
-    inputs: List[Union[Tensor, SymbolVar, int, bool, float]] = []  # for flow graph construction
-    for example_input in example_inputs:
-        if isinstance(example_input, torch.Tensor):
-            symbolic_input = symbol_like_torch(example_input)
-            inputs.append(symbolic_input)
-        elif isinstance(example_input, (int, bool, float)):
-            inputs.append(example_input)
-        elif isinstance(example_input, torch.SymInt):
-            from torch.fx.experimental.symbolic_shapes import SymNode
-
-            node: SymNode = example_input.node
-            try:
-                inputs.append(node.pytype(example_input))
-            except RuntimeError:
-                # is a symbolic scalar input
-                pytype2dtype = {int: dtypes.int32, float: dtypes.float32, bool: dtypes.boolean}
-                inputs.append(hidet.symbol_var(name=str(example_input), dtype=pytype2dtype[node.pytype]))
-        else:
-            raise ValueError(f'hidet_backend: unexpected example input {example_input}, type {type(example_input)}')
-
     if dynamo_config['correctness_report']:
         # check correctness using random inputs
         def wrapper(*args):
@@ -135,45 +173,9 @@ def wrapper(*args):
             return output
 
         return wrapper
-    else:
-        logger.info('hidet: inputs: ')
-        for arg in inputs:
-            if isinstance(arg, hidet.Tensor):
-                logger.info('hidet: %s', arg.signature())
-            else:
-                logger.info('hidet: %s', arg)
-
-        # symbolic run to get flow graph
-        output = interpreter(*inputs)
-        output_format, output_tensors = serialize_output(output)
-        input_tensors = [x for x in inputs if isinstance(x, hidet.Tensor)]
-        flow_graph: FlowGraph = hidet.trace_from(output_tensors, inputs=input_tensors)
-
-        executor = generate_executor(flow_graph)
-
-        def wrapper(*args: Tensor):
-            tensor_args = []
-            for param, arg in zip(inputs, args):
-                if isinstance(param, Tensor):
-                    tensor_args.append(arg)
-                elif isinstance(param, SymbolVar):
-                    dtype = param.type
-                    assert isinstance(dtype, DataType)
-                    if dtype.name == 'int32':
-                        from hidet.ffi import runtime_api
-
-                        runtime_api.set_symbol_value(param.name, int(arg))
-                    else:
-                        raise ValueError(
-                            f'hidet_backend: unsupported symbolic dtype {dtype}. We only support int32 now.'
-                        )
-                else:
-                    # ignore constant
-                    pass
-            outputs: Sequence[torch.Tensor] = executor(*tensor_args)
-            ret = deserialize_output(output_format, outputs)
-            return ret
-
-        logger.info('finish generating the executor')
+
+    flow_graph, inputs, output_format = get_flow_graph(interpreter, example_inputs)
 
-        return wrapper
+    cgraph = get_compiled_graph(flow_graph)
+
+    return get_wrapper(cgraph, inputs, output_format)
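
For reference, a minimal usage sketch of the backend this file implements, assuming hidet is installed and makes the 'hidet' backend available to torch.compile on import (the model and input below are illustrative, not from this commit):

    import torch
    import hidet  # importing hidet registers the 'hidet' dynamo backend

    model = torch.nn.Linear(16, 16).cuda().eval()
    x = torch.randn(4, 16, device='cuda')

    # hidet_backend above receives the traced subgraphs of this model
    model_opt = torch.compile(model, backend='hidet')
    y = model_opt(x)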
