Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Option] Add debug_cache_tuning option #120

Merged
merged 3 commits into from
Feb 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
jinja2<3.1.0
sphinx
sphinx-gallery
sphinx-copybutton
Expand Down
10 changes: 9 additions & 1 deletion python/hidet/graph/ir/flow_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,15 @@ def forward(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]:
for node_output, symbolic_output in zip(node_outputs, node.outputs):
tensor_map[symbolic_output] = node_output

outputs = [tensor_map[x] for x in self.outputs]
outputs = []
for graph_output in self.outputs:
if graph_output in tensor_map:
outputs.append(tensor_map[graph_output])
elif graph_output.storage is not None:
outputs.append(graph_output) # constant output, not the graph input or produced by any operator
else:
raise RuntimeError('Graph output {} is not produced by any operator.'.format(graph_output.signature()))

GraphForwardContext.current()._trigger_after_graph(self, inputs, outputs)
return outputs[0] if len(outputs) == 1 else outputs

Expand Down
13 changes: 8 additions & 5 deletions python/hidet/graph/ops/schedules/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import os
import time
from typing import List, Optional
import shutil
import numpy as np
from tqdm import tqdm

Expand All @@ -22,6 +23,7 @@
from hidet.ir.task import Task
from hidet.utils import TableBuilder, strict_zip, error_tolerance
from hidet.graph.tensor import randn, zeros, ones, Tensor
from hidet.option import get_option
from .common import Schedule


Expand Down Expand Up @@ -124,12 +126,9 @@ def resolve_ir_modules(
# compiled_funcs: List[Optional[CompiledFunction]] = batch_build_ir_modules(
# build_instances, parallel=parallel, verbose=verbose
# )
resolve_dir = os.path.join(output_dir, 'resolve')
compiled_funcs: List[Optional[CompiledFunction]] = build_ir_module_batch(
ir_modules,
func_name=func_name,
output_dir=os.path.join(output_dir, 'resolve'),
parallel=parallel,
verbose=verbose,
ir_modules, func_name=func_name, output_dir=resolve_dir, parallel=parallel, verbose=verbose
)
dummy_inputs = dummy_inputs_from_task(ir_modules[0].task, target_device)
best_latency = 1e30
Expand Down Expand Up @@ -176,6 +175,10 @@ def resolve_ir_modules(
best_latency = latency
best_ir_module = ir_module

# remove the resolve directory
if not get_option('debug_cache_tuning'):
shutil.rmtree(resolve_dir)

# generate summary
headers = ['idx'] + [v[0] for v in (schedules[0].keys() + schedules[0].derived_keys())] + ['Error', 'latency']
with TableBuilder(headers=headers) as tb:
Expand Down
8 changes: 7 additions & 1 deletion python/hidet/graph/ops/schedules/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Union, Sequence, TypeVar, Any, Dict, List, Optional
import os
import itertools
import shutil
from tqdm import tqdm
import numpy as np
from hidet.ir.func import IRModule
Expand Down Expand Up @@ -143,8 +144,9 @@ def tune(template_func, task: Task, target_device: str, working_dir: str) -> IRM
return ir_modules[0]

# build ir modules into compiled functions
tuning_dir = os.path.join(working_dir, 'tuning')
compiled_funcs: List[Optional[CompiledFunction]] = build_ir_module_batch(
ir_modules, func_name=task.name, output_dir=os.path.join(working_dir, 'tuning'), parallel=True, verbose=True
ir_modules, func_name=task.name, output_dir=tuning_dir, parallel=True, verbose=True
)
assert len(compiled_funcs) == len(ir_modules)
if any(f is None for f in compiled_funcs):
Expand All @@ -163,6 +165,10 @@ def tune(template_func, task: Task, target_device: str, working_dir: str) -> IRM
latency = 1e30
latencies.append(latency)

# remove tuning directory
if not hidet.option.get_option('debug_cache_tuning'):
shutil.rmtree(tuning_dir)

# generate summary
summary = _generate_summary(ir_modules_kwargs, latencies)
with open(os.path.join(working_dir, 'tuning_summary.txt'), 'w') as f:
Expand Down
39 changes: 31 additions & 8 deletions python/hidet/option.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def register_option(
name: str,
type_hint: str,
description: str,
defalut_value: Any,
default_value: Any,
normalizer: Optional[Callable[[Any], Any]] = None,
choices: Optional[Iterable[Any]] = None,
checker: Optional[Callable[[Any], bool]] = None,
Expand All @@ -49,7 +49,7 @@ def register_option(
if name in registered_options:
raise KeyError(f'Option {name} has already been registered.')
registered_options[name] = OptionRegistry(
name, type_hint, description, defalut_value, normalizer, choices, checker
name, type_hint, description, default_value, normalizer, choices, checker
)
return OptionRegistry

Expand All @@ -65,24 +65,24 @@ def register_hidet_options():
type_hint='Tuple[int, int, int]',
description='The (warmup, number, repeat) parameters for benchmarking. '
'The benchmarking will run warmup + number * repeat times.',
defalut_value=(3, 10, 3),
default_value=(3, 10, 3),
).register_option(
name='search_space', #
type_hint='int',
description='The search space level.',
defalut_value=0,
default_value=0,
choices=[0, 1, 2],
).register_option(
name='cache_operator',
type_hint='bool',
description='Whether to enable operator cache on disk.',
defalut_value=True,
default_value=True,
choices=[True, False],
).register_option(
name='cache_dir',
type_hint='path',
description='The directory to store the cache.',
defalut_value=os.path.abspath(
default_value=os.path.abspath(
os.path.join(git_utils.git_repo_root(), '.hidet_cache') # developer mode
if git_utils.in_git_repo()
else os.path.join(os.path.expanduser('~'), '.hidet', 'cache') # user mode
Expand All @@ -91,15 +91,21 @@ def register_hidet_options():
).register_option(
name='parallel_build',
type_hint='bool',
defalut_value=True,
default_value=True,
description='Whether to build operators in parallel.',
choices=[True, False],
).register_option(
name='save_lower_ir',
type_hint='bool',
defalut_value=False,
default_value=False,
description='Whether to save the IR when lower an IRModule to the operator cache.',
choices=[True, False],
).register_option(
name='debug_cache_tuning',
type_hint='bool',
default_value=False,
description='Whether to cache the generated kernels during tuning.',
choices=[True, False],
)


Expand Down Expand Up @@ -472,3 +478,20 @@ def get_save_lower_ir() -> bool:
Whether to save the lower IR.
"""
return OptionContext.current().get_option('save_lower_ir')


def debug_cache_tuning(enabled: bool = True):
"""
Whether to cache the generated kernels during tuning.

.. note::

This option is only used for debugging purpose. It will generate a lot of files in the cache directory
and take a lot of disk space.

Parameters
----------
enabled: bool
Whether to debug cache tuning.
"""
OptionContext.current().set_option('debug_cache_tuning', enabled)