[Option] Add option module #6

Merged · 3 commits · Nov 3, 2022
1 change: 1 addition & 0 deletions docs/source/python_api/index.rst
@@ -6,6 +6,7 @@ Python API
:caption: Python API

root
option
driver
ir/index
graph/index
6 changes: 6 additions & 0 deletions docs/source/python_api/option.rst
@@ -0,0 +1,6 @@
hidet.option
------------

.. automodule:: hidet.option
:members:
:autosummary:
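The new `hidet.option` module centralizes configuration that this PR removes from `hidet.graph` and `hidet.utils`. A minimal usage sketch based on the calls visible in this diff (the `cache_operator` setter name is inferred from the `'cache_operator'` option key and is an assumption):

```python
import hidet

# set the kernel-tuning search space; replaces the old hidet.space_level(2)
hidet.option.search_space(2)

# options are read back by key, as driver.py does internally
assert hidet.option.get_option('search_space') == 2

# overrides can be scoped to a block, as flow_graph.py does for tunable tasks
with hidet.option.context():
    hidet.option.parallel_build(False)    # applies only inside this block
    hidet.option.cache_operator(False)    # assumed setter for the 'cache_operator' key
```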
6 changes: 3 additions & 3 deletions gallery/tutorials/run-onnx-model.py
@@ -142,11 +142,11 @@ def bench_hidet_graph(graph: hidet.FlowGraph):
# %%
# Optimize FlowGraph
# ------------------
# To optimize the model, we set the level of operator schedule space to 2 with :func:`hidet.space_level`. We also
# To optimize the model, we set the level of operator schedule space to 2 with :func:`hidet.option.search_space`. We also
# conduct graph level optimizations with :func:`hidet.graph.optimize`.

# set the search space level for kernel tuning
hidet.space_level(2)
# set the search space level for kernel tuning
hidet.option.search_space(2)

# optimize the flow graph, such as operator fusion
with hidet.graph.PassContext() as ctx:
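Putting the two tutorial changes together, the optimization section now reads roughly as follows (a sketch assuming `graph` is a `hidet.FlowGraph` traced earlier in the tutorial):

```python
import hidet

# set the search space level for kernel tuning
hidet.option.search_space(2)

# optimize the flow graph, such as operator fusion
with hidet.graph.PassContext() as ctx:
    graph_opt = hidet.graph.optimize(graph)
```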
4 changes: 1 addition & 3 deletions python/hidet/__init__.py
@@ -2,6 +2,7 @@
Hidet is an open-source DNN inference framework based on compilation.
"""
import sys
from . import option
from . import ir
from . import backend
from . import utils
@@ -18,12 +19,9 @@
from .graph import ops
from .graph import empty, randn, zeros, ones, full, randint, symbol, array, from_torch
from .graph import empty_like, randn_like, zeros_like, ones_like, symbol_like, full_like, randint_like
from .graph import space_level, get_space_level, profile_config, get_profile_config, cache_operator
from .graph import trace_from, load_graph, save_graph
from .graph import jit

from .utils import hidet_set_cache_root as set_cache_root

from .lang import script, script_module

sys.setrecursionlimit(10000)
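For downstream users, the removed re-exports map onto the option module. A rough migration sketch; the `cache_operator` and `cache_dir` setter names are inferred from the option keys used in `driver.py` and are assumptions:

```python
import hidet

# old: hidet.space_level(2)
hidet.option.search_space(2)

# old: space = hidet.get_space_level()
space = hidet.option.get_option('search_space')

# old: hidet.cache_operator(True)
hidet.option.cache_operator(True)           # assumed setter for the 'cache_operator' key

# old: hidet.set_cache_root('/tmp/hidet-cache')
hidet.option.cache_dir('/tmp/hidet-cache')  # assumed setter for the 'cache_dir' key
```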
89 changes: 25 additions & 64 deletions python/hidet/driver.py
@@ -3,11 +3,12 @@
import multiprocessing
import logging
from hashlib import sha256

from hidet import option
from hidet.transforms import lower, PassContext, SaveIRInstrument, ProfileInstrument
from hidet.backend import codegen, compile_source, load_task_func, load_lib_func
from hidet.utils import hidet_cache_dir
from hidet.utils.py import cyan, green
from hidet.ir.task import Task, TaskContext
from hidet.ir.task import Task
from hidet.ir.func import IRModule
from hidet.ir.type import FuncType
from hidet.runtime.module import compiled_task_cache, CompiledFunction
@@ -16,50 +17,17 @@
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())

cache_disabled = False


def disable_cache(disable: bool = False):
global cache_disabled
cache_disabled = not disable


def build_task(
task: Task,
space_level=0,
target_device='cuda',
warmup=3,
number=10,
repeat=3,
use_cache=True,
cache_dir=None,
load=True,
) -> Optional[CompiledFunction]:
def build_task(task: Task, target_device='cuda', load=True) -> Optional[CompiledFunction]:
"""
Build a task into a compiled function.

Parameters
----------
task: Task
The task to be built.
space_level: int
The space level of the schedule space. Candidates are 0, 1, 2, 3. Space level 0 indicates to use the default
schedule without tuning. Space level 1 indicates to search in a small search space. Space level 2 indicates to
search in the largest search space. Larger search space leads to better performance, but also takes more time to
tune.
target_device: str
The target device. Candidates are 'cuda' and 'cpu'.
warmup: int
The number of warmup runs when benchmarking different schedules.
number: int
The number of runs per repeat when benchmarking different schedules.
repeat: int
The number of repeats when benchmarking different schedules.
use_cache: bool
Whether to use cache on disk.
cache_dir: str
The cache directory. The default is None, which means to use the default cache directory:
**hidet_cache_dir/ops**.
load: bool
Whether to load the compiled function. If False, the compiled function will not be loaded, and None is returned.
Otherwise, the compiled function is loaded and returned.
@@ -68,26 +36,27 @@ def build_task(
compiled_func:
When load is True, the compiled function is returned. Otherwise, None is returned.
"""
# pylint: disable=too-many-arguments, too-many-locals
task_string: str = str(task)
compiled_func: Optional[CompiledFunction] = None

space_level = option.get_option('search_space')
op_cache_dir = os.path.join(option.get_option('cache_dir'), './ops')
use_cache = option.get_option('cache_operator')

# check in-memory cache
if compiled_task_cache.contains(target_device, space_level, task_string):
if load:
compiled_func = compiled_task_cache.get(target_device, space_level, task_string)
else:
# check on-disk cache
if cache_dir is None:
cache_dir = os.path.join(hidet_cache_dir(), 'ops')
config_str = f'{target_device}_space_{space_level}'
task_hash = sha256(task_string.encode()).hexdigest()[:16]
task_dir = os.path.join(cache_dir, config_str, task.name, task_hash)
task_dir = os.path.join(op_cache_dir, config_str, task.name, task_hash)
src_path = os.path.join(task_dir, 'source.cu')
lib_path = os.path.join(task_dir, 'lib.so')

# use previously generated library when available
if not cache_disabled and use_cache and os.path.exists(lib_path):
if use_cache and os.path.exists(lib_path):
logger.debug(f"Load cached task binary {green(task.name)} from path: \n{cyan(lib_path)}")
if load:
compiled_func = load_task_func(lib_path, task)
@@ -100,15 +69,16 @@
with open(os.path.join(task_dir, 'task.txt'), 'w') as f:
f.write(task_string)
# implement task
with TaskContext(space_level, warmup, number, repeat, resolve_out_dir=task_dir):
ir_module = task.implement(target=target_device)
ir_module = task.implement(target=target_device, working_dir=task_dir)
# lower ir module
with PassContext(
instruments=[
SaveIRInstrument(out_dir=os.path.join('./outs/ir', task.name, task_hash)),
ProfileInstrument(log_file=os.path.join('./outs/ir', task.name, task_hash, 'lower_time.txt')),
if option.get_option('save_lower_ir'):
instruments = [
SaveIRInstrument(out_dir=os.path.join(task_dir, './ir')),
ProfileInstrument(log_file=os.path.join(task_dir, './lower_time.txt')),
]
):
else:
instruments = []
with PassContext(instruments=instruments):
ir_module = lower(ir_module)
# code generation
codegen(ir_module, src_out_path=src_path)
@@ -122,24 +92,15 @@


def _build_task_job(args):
task, space_level, target_device, warmup, number, repeat, use_cache, cache_dir, load = args
build_task(task, space_level, target_device, warmup, number, repeat, use_cache, cache_dir, load)
task, target_device, dumped_options = args
option.restore_options(dumped_options)
build_task(task, target_device, load=False)


def build_batch_task(
tasks: List[Task],
space_level: int,
target_device: str = 'cuda',
warmup: int = 3,
number: int = 10,
repeat: int = 3,
parallel=True,
use_cache=True,
cache_dir=None,
):
# pylint: disable=too-many-arguments
jobs = [(task, space_level, target_device, warmup, number, repeat, use_cache, cache_dir, False) for task in tasks]
if parallel and len(tasks) > 1:
def build_batch_task(tasks: List[Task], target_device: str = 'cuda'):
dumped_options = option.dump_options()
jobs = [(task, target_device, dumped_options) for task in tasks]
if option.get_option('parallel_build') and len(jobs) > 1:
with multiprocessing.Pool() as pool:
pool.map(_build_task_job, jobs)
else:
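The `dump_options`/`restore_options` pair in `_build_task_job` is needed because `multiprocessing.Pool` workers do not see option changes made in the parent after the pool's processes start (and under the 'spawn' start method they begin from a fresh interpreter). A self-contained sketch of the pattern, using a hypothetical `_options` dict in place of hidet's real storage:

```python
import multiprocessing

# hypothetical stand-in for hidet's option storage
_options = {'search_space': 0}

def dump_options() -> dict:
    # snapshot the current options so they can be pickled into each job
    return dict(_options)

def restore_options(dumped: dict) -> None:
    # called at the start of every worker job
    _options.update(dumped)

def _job(args):
    payload, dumped = args
    restore_options(dumped)          # the worker now sees the parent's options
    return payload * _options['search_space']

if __name__ == '__main__':
    _options['search_space'] = 2     # parent mutates an option at runtime
    jobs = [(x, dump_options()) for x in range(4)]
    with multiprocessing.Pool() as pool:
        print(pool.map(_job, jobs))  # [0, 2, 4, 6]
```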
1 change: 0 additions & 1 deletion python/hidet/graph/__init__.py
@@ -14,7 +14,6 @@

from .tensor import array, randn, empty, zeros, ones, symbol, randint, randn_like, empty_like, zeros_like, ones_like
from .tensor import symbol_like, randint_like, from_torch, full, full_like
from .operator import space_level, get_space_level, profile_config, get_profile_config, cache_operator
from .ir import trace_from, load_graph, save_graph
from .transforms import optimize
from .modules import nn
32 changes: 9 additions & 23 deletions python/hidet/graph/ir/flow_graph.py
@@ -6,6 +6,7 @@
from collections import defaultdict

import hidet.graph.operator
from hidet import option
from hidet.graph.tensor import Tensor, empty_like
from hidet.graph.operator import Operator
from hidet.utils.doc import Doc, NewLine, Text, doc_join
@@ -100,38 +101,23 @@ def build(self):
tasks = []
tunable_tasks = []
task_keys = set()
space_level = hidet.get_space_level()
profile_config = hidet.get_profile_config()
search_space = hidet.option.get_option('search_space')
for node in self.nodes:
if node.task_func is None:
# if space_level == 0 or 'implement_cuda' not in node.task.__class__.__dict__:
task_key = hash(str(node.task))
if task_key in task_keys:
continue
task_keys.add(task_key)
# if node.task.fast_implement(space_level):
if space_level == 0 or 'implement_cuda' not in node.task.__class__.__dict__:
if search_space == 0 or 'implement_cuda' not in node.task.__class__.__dict__:
tasks.append(node.task)
else:
tunable_tasks.append(node.task)
hidet.driver.build_batch_task(
tasks,
space_level,
warmup=profile_config.warmup,
number=profile_config.number,
repeat=profile_config.repeat,
parallel=True,
)
# hidet.driver.build_batch_task(tasks, space_level, warmup=profile_config.warmup,
# number=profile_config.number, repeat=profile_config.repeat, parallel=False)
hidet.driver.build_batch_task(
tunable_tasks,
space_level,
warmup=profile_config.warmup,
number=profile_config.number,
repeat=profile_config.repeat,
parallel=False,
)

hidet.driver.build_batch_task(tasks)

with option.context():
hidet.option.parallel_build(False)
hidet.driver.build_batch_task(tunable_tasks) # build tunable tasks one by one

def forward(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]:
"""Run the computation graph.
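The `option.context()` block above scopes the `parallel_build(False)` override to the tunable tasks: building them one by one keeps their tuning benchmarks from contending for the device. A usage sketch (assuming `parallel_build` defaults to True):

```python
from hidet import option

with option.context():
    option.parallel_build(False)    # applies only inside this block
    assert option.get_option('parallel_build') is False
    # ... build tunable tasks sequentially here ...

# the previous value is restored on exit
assert option.get_option('parallel_build') is True
```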