Commit b9bacad
[Option] Add option module (#6)
yaoyaoding committed Nov 3, 2022
1 parent 0829f80 commit b9bacad
Showing 26 changed files with 626 additions and 494 deletions.
docs/source/python_api/index.rst (1 change: 1 addition & 0 deletions)

@@ -6,6 +6,7 @@ Python API
    :caption: Python API
 
    root
+   option
    driver
    ir/index
    graph/index
docs/source/python_api/option.rst (6 changes: 6 additions & 0 deletions)

@@ -0,0 +1,6 @@
+hidet.option
+------------
+
+.. automodule:: hidet.option
+   :members:
+   :autosummary:
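A minimal usage sketch of the newly documented module, based only on the entry points visible in this commit (`search_space`, `parallel_build`, `get_option`, and the scoped `context()` manager); it assumes `context()` restores previous values on exit, as its use in `flow_graph.py` below suggests:

```python
import hidet

# set the kernel-tuning search space globally (0 disables tuning)
hidet.option.search_space(2)
assert hidet.option.get_option('search_space') == 2

# temporarily override an option inside a scope
with hidet.option.context():
    hidet.option.parallel_build(False)
    # ... builds launched here run serially ...
```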
gallery/tutorials/run-onnx-model.py (6 changes: 3 additions & 3 deletions)

@@ -142,11 +142,11 @@ def bench_hidet_graph(graph: hidet.FlowGraph):
 # %%
 # Optimize FlowGraph
 # ------------------
-# To optimize the model, we set the level of operator schedule space to 2 with :func:`hidet.space_level`. We also
+# To optimize the model, we set the level of operator schedule space to 2 with :func:`hidet.option.search_space`. We also
 # conduct graph level optimizations with :func:`hidet.graph.optimize`.
 
-# set the search space level for kernel tuning
-hidet.space_level(2)
+# set the search space level for kernel tuning,
+hidet.option.search_space(2)
 
 # optimize the flow graph, such as operator fusion
 with hidet.graph.PassContext() as ctx:
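For context, a hedged end-to-end sketch of the workflow the tutorial now describes: kernel tuning via `hidet.option.search_space` combined with graph-level optimization via `hidet.graph.optimize`. The tiny symbolic graph is illustrative only; it assumes `hidet.symbol` accepts a shape list, that `Tensor` supports the `+` operator, and that `hidet.trace_from` accepts the resulting symbolic output:

```python
import hidet

hidet.option.search_space(2)  # kernel level: search a larger schedule space when tuning

x = hidet.symbol([1, 8])                 # symbolic input tensor (assumed signature)
y = x + x                                # a tiny symbolic computation, illustrative only
graph = hidet.trace_from(y)              # trace the computation into a FlowGraph

graph_opt = hidet.graph.optimize(graph)  # graph level: operator fusion and other passes
```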
python/hidet/__init__.py (4 changes: 1 addition & 3 deletions)

@@ -2,6 +2,7 @@
 Hidet is an open-source DNN inference framework based on compilation.
 """
 import sys
+from . import option
 from . import ir
 from . import backend
 from . import utils

@@ -18,12 +19,9 @@
 from .graph import ops
 from .graph import empty, randn, zeros, ones, full, randint, symbol, array, from_torch
 from .graph import empty_like, randn_like, zeros_like, ones_like, symbol_like, full_like, randint_like
-from .graph import space_level, get_space_level, profile_config, get_profile_config, cache_operator
 from .graph import trace_from, load_graph, save_graph
 from .graph import jit
 
-from .utils import hidet_set_cache_root as set_cache_root
-
 from .lang import script, script_module
 
 sys.setrecursionlimit(10000)
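The dropped re-exports mean user code migrates from the old top-level helpers to `hidet.option`. A migration sketch showing only the correspondence that this commit itself confirms (the tutorial diff above renames `hidet.space_level` to `hidet.option.search_space`; other old/new name pairs are not assumed):

```python
import hidet

# before this commit:
#   hidet.space_level(2)
#   level = hidet.get_space_level()

# after this commit:
hidet.option.search_space(2)
level = hidet.option.get_option('search_space')
```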
python/hidet/driver.py (89 changes: 25 additions & 64 deletions)

@@ -3,11 +3,12 @@
 import multiprocessing
 import logging
 from hashlib import sha256
+
+from hidet import option
 from hidet.transforms import lower, PassContext, SaveIRInstrument, ProfileInstrument
 from hidet.backend import codegen, compile_source, load_task_func, load_lib_func
-from hidet.utils import hidet_cache_dir
 from hidet.utils.py import cyan, green
-from hidet.ir.task import Task, TaskContext
+from hidet.ir.task import Task
 from hidet.ir.func import IRModule
 from hidet.ir.type import FuncType
 from hidet.runtime.module import compiled_task_cache, CompiledFunction

@@ -16,50 +17,17 @@
 logger.setLevel(logging.INFO)
 logger.addHandler(logging.StreamHandler())
 
-cache_disabled = False
-
-
-def disable_cache(disable: bool = False):
-    global cache_disabled
-    cache_disabled = not disable
-
-
-def build_task(
-    task: Task,
-    space_level=0,
-    target_device='cuda',
-    warmup=3,
-    number=10,
-    repeat=3,
-    use_cache=True,
-    cache_dir=None,
-    load=True,
-) -> Optional[CompiledFunction]:
+def build_task(task: Task, target_device='cuda', load=True) -> Optional[CompiledFunction]:
     """
     Build a task into a compiled function.
 
     Parameters
     ----------
     task: Task
         The task to be built.
-    space_level: int
-        The space level of the schedule space. Candidates are 0, 1, 2, 3. Space level 0 indicates to use the default
-        schedule without tuning. Space level 1 indicates to search in a small search space. Space level 2 indicates to
-        search in the largest search space. Larger search space leads to better performance, but also takes more time to
-        tune.
     target_device: str
         The target device. Candidates are 'cuda' and 'cpu'.
-    warmup: int
-        The number of warmup runs when benchmarking different schedules.
-    number: int
-        The number of runs per repeat when benchmarking different schedules.
-    repeat: int
-        The number of repeats when benchmarking different schedules.
-    use_cache: bool
-        Whether to use cache on disk.
-    cache_dir: str
-        The cache directory. The default is None, which means to use the default cache directory:
-        **hidet_cache_dir/ops**.
     load: bool
         Whether to load the compiled function. If False, the compiled function will not be loaded, and None is returned.
         Otherwise, the compiled function is loaded and returned.

@@ -68,26 +36,27 @@ def build_task(
     compiled_func:
         When load is True, the compiled function is returned. Otherwise, None is returned.
     """
-    # pylint: disable=too-many-arguments, too-many-locals
     task_string: str = str(task)
     compiled_func: Optional[CompiledFunction] = None
 
+    space_level = option.get_option('search_space')
+    op_cache_dir = os.path.join(option.get_option('cache_dir'), './ops')
+    use_cache = option.get_option('cache_operator')
+
     # check in-memory cache
     if compiled_task_cache.contains(target_device, space_level, task_string):
         if load:
             compiled_func = compiled_task_cache.get(target_device, space_level, task_string)
     else:
         # check on-disk cache
-        if cache_dir is None:
-            cache_dir = os.path.join(hidet_cache_dir(), 'ops')
         config_str = f'{target_device}_space_{space_level}'
         task_hash = sha256(task_string.encode()).hexdigest()[:16]
-        task_dir = os.path.join(cache_dir, config_str, task.name, task_hash)
+        task_dir = os.path.join(op_cache_dir, config_str, task.name, task_hash)
         src_path = os.path.join(task_dir, 'source.cu')
         lib_path = os.path.join(task_dir, 'lib.so')
 
         # use previously generated library when available
-        if not cache_disabled and use_cache and os.path.exists(lib_path):
+        if use_cache and os.path.exists(lib_path):
             logger.debug(f"Load cached task binary {green(task.name)} from path: \n{cyan(lib_path)}")
             if load:
                 compiled_func = load_task_func(lib_path, task)

@@ -100,15 +69,16 @@ def build_task(
             with open(os.path.join(task_dir, 'task.txt'), 'w') as f:
                 f.write(task_string)
             # implement task
-            with TaskContext(space_level, warmup, number, repeat, resolve_out_dir=task_dir):
-                ir_module = task.implement(target=target_device)
+            ir_module = task.implement(target=target_device, workding_dir=task_dir)
             # lower ir module
-            with PassContext(
-                instruments=[
-                    SaveIRInstrument(out_dir=os.path.join('./outs/ir', task.name, task_hash)),
-                    ProfileInstrument(log_file=os.path.join('./outs/ir', task.name, task_hash, 'lower_time.txt')),
-                ]
-            ):
+            if option.get_option('save_lower_ir'):
+                instruments = [
+                    SaveIRInstrument(out_dir=os.path.join(task_dir, './ir')),
+                    ProfileInstrument(log_file=os.path.join(task_dir, './lower_time.txt')),
+                ]
+            else:
+                instruments = []
+            with PassContext(instruments=instruments):
                 ir_module = lower(ir_module)
             # code generation
             codegen(ir_module, src_out_path=src_path)

@@ -122,24 +92,15 @@
 
 
 def _build_task_job(args):
-    task, space_level, target_device, warmup, number, repeat, use_cache, cache_dir, load = args
-    build_task(task, space_level, target_device, warmup, number, repeat, use_cache, cache_dir, load)
+    task, target_device, dumped_options = args
+    option.restore_options(dumped_options)
+    build_task(task, target_device, load=False)
 
 
-def build_batch_task(
-    tasks: List[Task],
-    space_level: int,
-    target_device: str = 'cuda',
-    warmup: int = 3,
-    number: int = 10,
-    repeat: int = 3,
-    parallel=True,
-    use_cache=True,
-    cache_dir=None,
-):
-    # pylint: disable=too-many-arguments
-    jobs = [(task, space_level, target_device, warmup, number, repeat, use_cache, cache_dir, False) for task in tasks]
-    if parallel and len(tasks) > 1:
+def build_batch_task(tasks: List[Task], target_device: str = 'cuda'):
+    dumped_options = option.dump_options()
+    jobs = [(task, target_device, dumped_options) for task in tasks]
+    if option.get_option('parallel_build') and len(jobs) > 1:
         with multiprocessing.Pool() as pool:
             pool.map(_build_task_job, jobs)
     else:
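The dump/restore pair exists because worker processes started by `multiprocessing.Pool` do not necessarily inherit the parent's option state (for example, under the spawn start method), so options are serialized once in the parent and re-applied in each worker. A standalone sketch of the pattern; the `dump_options`/`restore_options` calls are the ones introduced here, while the job body and names are illustrative:

```python
import multiprocessing

import hidet
from hidet import option


def _job(args):
    name, dumped_options = args
    option.restore_options(dumped_options)  # re-apply parent-process options in the worker
    return name, option.get_option('search_space')


if __name__ == '__main__':
    hidet.option.search_space(2)
    dumped = option.dump_options()  # snapshot options once, in the parent process
    jobs = [(f'task-{i}', dumped) for i in range(4)]
    with multiprocessing.Pool() as pool:
        print(pool.map(_job, jobs))
```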
python/hidet/graph/__init__.py (1 change: 0 additions & 1 deletion)

@@ -14,7 +14,6 @@
 
 from .tensor import array, randn, empty, zeros, ones, symbol, randint, randn_like, empty_like, zeros_like, ones_like
 from .tensor import symbol_like, randint_like, from_torch, full, full_like
-from .operator import space_level, get_space_level, profile_config, get_profile_config, cache_operator
 from .ir import trace_from, load_graph, save_graph
 from .transforms import optimize
 from .modules import nn
python/hidet/graph/ir/flow_graph.py (32 changes: 9 additions & 23 deletions)

@@ -6,6 +6,7 @@
 from collections import defaultdict
 
 import hidet.graph.operator
+from hidet import option
 from hidet.graph.tensor import Tensor, empty_like
 from hidet.graph.operator import Operator
 from hidet.utils.doc import Doc, NewLine, Text, doc_join

@@ -100,38 +101,23 @@ def build(self):
         tasks = []
         tunable_tasks = []
         task_keys = set()
-        space_level = hidet.get_space_level()
-        profile_config = hidet.get_profile_config()
+        search_space = hidet.option.get_option('search_space')
         for node in self.nodes:
             if node.task_func is None:
-                # if space_level == 0 or 'implement_cuda' not in node.task.__class__.__dict__:
                 task_key = hash(str(node.task))
                 if task_key in task_keys:
                     continue
                 task_keys.add(task_key)
-                # if node.task.fast_implement(space_level):
-                if space_level == 0 or 'implement_cuda' not in node.task.__class__.__dict__:
+                if search_space == 0 or 'implement_cuda' not in node.task.__class__.__dict__:
                     tasks.append(node.task)
                 else:
                     tunable_tasks.append(node.task)
-        hidet.driver.build_batch_task(
-            tasks,
-            space_level,
-            warmup=profile_config.warmup,
-            number=profile_config.number,
-            repeat=profile_config.repeat,
-            parallel=True,
-        )
-        # hidet.driver.build_batch_task(tasks, space_level, warmup=profile_config.warmup,
-        #                               number=profile_config.number, repeat=profile_config.repeat, parallel=False)
-        hidet.driver.build_batch_task(
-            tunable_tasks,
-            space_level,
-            warmup=profile_config.warmup,
-            number=profile_config.number,
-            repeat=profile_config.repeat,
-            parallel=False,
-        )
+
+        hidet.driver.build_batch_task(tasks)
+
+        with option.context():
+            hidet.option.parallel_build(False)
+            hidet.driver.build_batch_task(tunable_tasks)  # build tunable tasks one by one
 
     def forward(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]:
         """Run the computation graph.
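The scoped override used in `build()` above, shown in isolation: tunable tasks are built with `parallel_build` disabled so each one is tuned on its own, and `option.context()` confines the override to the block. This sketch assumes `context()` restores the previous value on exit (which is what this call site implies) and that `parallel_build` defaults to True:

```python
import hidet
from hidet import option

hidet.option.parallel_build(True)  # assumed global default

with option.context():
    hidet.option.parallel_build(False)          # applies only inside this block
    print(option.get_option('parallel_build'))  # False

print(option.get_option('parallel_build'))      # True again after the block
```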
