[Operator] Add float16 precision matrix multiplication (#26)
* adding f16 op

use refactor chain

add bench code

add a new pass; fix some bugs

* adding matmul

* .

* .

* .

* add a new tune module

* .

* .

* allow packing iter vars in for loops in hidet script

* .

* .

* .

* great!

* .

* .

* .

* delete experiment dir

* fix format & lint

* add forward context

* add another forward instrument

* .
yaoyaoding committed Dec 13, 2022
1 parent 7b55f9d commit a8e0ee2
Showing 32 changed files with 1,418 additions and 135 deletions.
5 changes: 0 additions & 5 deletions CMakeLists.txt
@@ -2,11 +2,6 @@ cmake_minimum_required(VERSION 3.19)

project(hidet C CXX CUDA)

# common configs
set(CMAKE_C_COMPILER_LAUNCHER ccache)
set(CMAKE_CXX_COMPILER_LAUNCHER ccache)
set(CMAKE_CUDA_COMPILER_LAUNCHER ccache)

# config hidet
if(EXISTS "${CMAKE_BINARY_DIR}/config.cmake")
include(${CMAKE_BINARY_DIR}/config.cmake)
17 changes: 11 additions & 6 deletions python/hidet/backend/build.py
@@ -13,6 +13,7 @@

from hidet.libinfo import get_include_dirs
from hidet.ir.type import FuncType
from hidet.ir.func import IRModule
from hidet.transforms import PassContext, lower
from hidet.runtime import CompiledFunction
from hidet.ffi import PackedFunc
@@ -95,6 +96,10 @@ def compile_source(src_path: str, out_lib_path: str, keep_ptx=False) -> None:
# shared cuda runtime library is used (.so), instead of static one (.a). used to reduce binary size.
'--cudart',
'shared',
# suppress warning #177, like: "warning #177-D: variable "xxx" was declared but never referenced"
# see https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#generic-tool-options-diag-suppress
'--diag-suppress',
'177',
# generate shared library (lib.so).
'--shared',
# the source path.
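
For context, here is a hedged sketch (not part of this diff) of the kind of nvcc invocation these options produce; source.cu and lib.so are placeholders, and the real argument list, including architecture flags, is assembled inside compile_source:

    import subprocess

    # Illustrative only: mirrors the options discussed above.
    subprocess.run(
        ['nvcc',
         '--cudart', 'shared',      # link the shared CUDA runtime to keep lib.so small
         '--diag-suppress', '177',  # silence "declared but never referenced" warnings
         '--shared',                # emit a shared library
         'source.cu', '-o', 'lib.so'],
        check=True,
    )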
@@ -217,11 +222,11 @@ def __init__(self, ir_module, output_dir, keep_ir=False, nvcc_keep=True, verbose
verbose: bool
Reserved.
"""
self.ir_module = ir_module
self.output_dir = output_dir
self.keep_ir = keep_ir
self.nvcc_keep = nvcc_keep
self.verbose = verbose
self.ir_module: IRModule = ir_module
self.output_dir: str = output_dir
self.keep_ir: bool = keep_ir
self.nvcc_keep: bool = nvcc_keep
self.verbose: bool = verbose


def build_ir_module_job(build_instance: BuildInstance) -> Optional[str]:
@@ -283,7 +288,7 @@ def batch_build_ir_modules(build_instances, parallel=True, verbose=False) -> Lis
Returns
-------
funcs: List[Optional[CompiledFunction]]
funcs:
The compiled functions, in the same order as build_instances.
When the build for a build instance failed, None for that instance is returned.
"""
14 changes: 13 additions & 1 deletion python/hidet/backend/codegen.py
@@ -1,4 +1,5 @@
from typing import Optional
import os
import numpy as np
from hidet.ir.dialects.pattern import AnyExpr
from hidet.ir.type import DataType, PointerType, TensorPointerType, ReferenceType, TensorType, TypeNode, FuncType
@@ -273,7 +274,12 @@ def visit_RightShift(self, e: RightShift):
def visit_TensorElement(self, e: TensorElement):
if e.protected:
raise ValueError('The protected reading of tensor element should be lowered in lower_protect_access pass.')
return self(e.base) + doc_join(['[' + self(idx) + ']' for idx in e.indices], '')
base_doc = self(e.base)
index_doc = doc_join(['[' + self(idx) + ']' for idx in e.indices], '')
if isinstance(e.base, Address):
return Text('(') + base_doc + Text(')') + index_doc
else:
return base_doc + index_doc
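# Note (added for clarity, not part of the diff): in C/CUDA the [] operator binds tighter
# than unary &, so emitting "&x[0][i]" would be parsed as "&(x[0][i])"; wrapping an Address
# base in parentheses, "(&x[0])[i]", applies the index to the pointer as intended.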

def visit_IfThenElse(self, e: IfThenElse):
return '(' + self(e.cond) + ' ? ' + self(e.then_expr) + ' : ' + self(e.else_expr) + ')'
@@ -297,6 +303,9 @@ def visit_Call(self, e: Call):
func_name = Text(self.canonize_funcname(func_name))
if func.kind == 'cuda_kernel':

if isinstance(func.attrs['cuda_block_dim'], int) and func.attrs['cuda_block_dim'] > 1024:
raise ValueError('CUDA block dimension cannot be larger than 1024.')
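# Note (added for clarity, not part of the diff): current CUDA GPUs cap a thread block at
# 1024 threads, so a launch requesting a larger block dimension is rejected during code generation.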

def dim3_str(dims):
if isinstance(dims, (int, Expr)):
return self(dims)
@@ -570,6 +579,9 @@ def codegen(ir_module: IRModule, src_out_path: Optional[str] = None) -> str:
doc = gen(ir_module)
code = str(doc)
if src_out_path is not None:
dir_path = os.path.dirname(src_out_path)
if not os.path.exists(dir_path):
os.makedirs(dir_path)
with open(src_out_path, 'w') as f:
f.write(code)
return code
163 changes: 135 additions & 28 deletions python/hidet/driver.py
@@ -1,13 +1,17 @@
from typing import List, Optional
import subprocess
from typing import List, Optional, Sequence, Tuple
import os
import multiprocessing
import logging
from hashlib import sha256
import psutil
from tqdm import tqdm

from hidet import option
from hidet.transforms import lower, PassContext, SaveIRInstrument, ProfileInstrument
from hidet.backend import codegen, compile_source, load_task_func, load_lib_func
from hidet.utils.py import cyan, green
from hidet.backend.build import CompilationFailed
from hidet.utils.py import cyan, green, Timer
from hidet.ir.task import Task
from hidet.ir.func import IRModule, Function
from hidet.ir.type import FuncType
Expand Down Expand Up @@ -92,52 +96,66 @@ def build_task(task: Task, target_device='cuda', load=True) -> Optional[Compiled


def _build_task_job(args):
task, target_device, dumped_options = args
option.restore_options(dumped_options)
build_task(task, target_device, load=False)
try:
task, target_device, dumped_options = args
option.restore_options(dumped_options)
build_task(task, target_device, load=False)
return True
except CompilationFailed as e:
if option.get_option('parallel_build'):
return False
else:
raise e


def build_batch_task(tasks: List[Task], target_device: str = 'cuda'):
def build_task_batch(tasks: List[Task], target_device: str = 'cuda', raise_on_error: bool = True):
dumped_options = option.dump_options()
jobs = [(task, target_device, dumped_options) for task in tasks]
if option.get_option('parallel_build') and len(jobs) > 1:
with multiprocessing.Pool() as pool:
pool.map(_build_task_job, jobs)
status_list = list(pool.map(_build_task_job, jobs))
else:
map(_build_task_job, jobs)
status_list = list(map(_build_task_job, jobs))
if not all(status_list) and raise_on_error:
msg = ['Failed to build {} tasks:'.format(sum(1 for s in status_list if not s))]
for task, status in zip(tasks, status_list):
if not status:
msg.append(f' {task.signature()}')
msg.append('Please turn off parallel build to see the error message:')
msg.append(' hidet.option.parallel_build(False)')
raise RuntimeError('\n'.join(msg))
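
# Hedged usage sketch (not part of this diff): to see the underlying compiler error rather
# than the summary raised above, disable parallel builds before rebuilding, e.g.:
#
#     import hidet
#     hidet.option.parallel_build(False)
#     build_task_batch(tasks, target_device='cuda', raise_on_error=True)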


def build_ir_module(
ir_module: IRModule,
func_name: str,
keep_ptx=False,
working_dir='./outs',
verbose=False,
func_type: Optional[FuncType] = None,
output_dir='./outs/ir_module',
save_ir: bool = True,
profile_pass: bool = True,
load: bool = True,
use_hash_dir: bool = True,
):
module_string = str(ir_module)
module_hash = sha256(module_string.encode()).hexdigest()[:16]
working_dir = os.path.join(working_dir, 'ir_module', module_hash)
src_path = os.path.join(working_dir, 'source.cu')
lib_path = os.path.join(working_dir, 'lib.so')
if use_hash_dir:
hash_dir = sha256(str(ir_module).encode()).hexdigest()[:16]
output_dir = os.path.join(output_dir, hash_dir)

if verbose:
print(f'Compiling {src_path}')
src_path = os.path.join(output_dir, 'source.cu')
lib_path = os.path.join(output_dir, 'lib.so')

# lower ir module
with PassContext(
instruments=[
SaveIRInstrument(out_dir=working_dir),
ProfileInstrument(log_file=os.path.join(working_dir, 'lower_time.txt')),
]
):
instruments = []
if save_ir:
instruments.append(SaveIRInstrument(out_dir=os.path.join(output_dir, './ir')))
if profile_pass:
instruments.append(ProfileInstrument(log_file=os.path.join(output_dir, './lower_time.txt')))
with PassContext(instruments=instruments):
ir_module = lower(ir_module)

# code generation
codegen(ir_module, src_out_path=src_path)

# compile source code
compile_source(src_path, out_lib_path=lib_path, keep_ptx=keep_ptx)
compile_source(src_path, out_lib_path=lib_path, keep_ptx=False)

# get function type
func: Function = ir_module.lookup(func_name)
@@ -147,5 +165,94 @@ def build_ir_module(
else:
func_type = FuncType.from_func(func)

# load function
return load_lib_func(lib_path, func_name, func_type=func_type)
if load:
# load function
return load_lib_func(lib_path, func_name, func_type=func_type)
else:
return lib_path, func_name, func_type


def _build_ir_module_job(args) -> Optional[Tuple[str, str, FuncType]]:
ir_module, func_name, output_dir, dumped_options = args
option.restore_options(dumped_options)
try:
return build_ir_module(
ir_module, func_name, output_dir, save_ir=False, profile_pass=False, load=False, use_hash_dir=False
)
except subprocess.CalledProcessError:
print('Failed to launch a subprocess to compile the lowered source code via nvcc.')
return None
except CompilationFailed:
print('Failed to compile the lowered source code via nvcc.')
return None


def build_ir_module_batch(
ir_modules: Sequence[IRModule], func_name: str, output_dir: str, parallel=True, verbose=False
) -> List[Optional[CompiledFunction]]:
"""
Build a batch of ir modules.
Parameters
----------
ir_modules: Sequence[IRModule]
A sequence of ir modules to build.
func_name: str
The name of the function to load after building.
output_dir: str
The output directory to save the compiled library and source code (lib.so and source.cu).
parallel: bool
Whether to build in parallel. Default True.
verbose: bool
Whether to show the progress and summary. Default False.
Returns
-------
funcs:
The compiled functions, in the same order as ir_modules.
When the build for an ir module fails, None is returned for that module.
"""
with Timer() as timer:
dumped_options = option.dump_options()
jobs = [
(ir_module, func_name, os.path.join(output_dir, str(idx)), dumped_options)
for idx, ir_module in enumerate(ir_modules)
]
build_results = []
if parallel:
# Set the affinity of the current process. Some packages such as numpy change the affinity of the current process,
# which might limit the parallelism of compilation.
os.sched_setaffinity(0, range(os.cpu_count()))

# the maximum number of processes is limited by the number of cores and memory
mem_for_worker = 1.5 * 1024 * 1024 * 1024 # 1.5 GiB
num_workers = min(max(int(psutil.virtual_memory().available // mem_for_worker), 1), psutil.cpu_count())
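# Hedged worked example (illustrative numbers, not from the repo): with 32 GiB of available
# memory and 16 cores, available // mem_for_worker == 21, so num_workers == min(max(21, 1), 16) == 16;
# with only 4 GiB available, available // mem_for_worker == 2, so num_workers == 2.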

with multiprocessing.Pool(processes=num_workers) as pool:
for build_result in tqdm(
pool.imap(_build_ir_module_job, jobs), desc='Compiling', total=len(jobs), disable=not verbose
):
build_results.append(build_result)
else:
# sequential build
build_results = list(map(_build_ir_module_job, jobs))

# load compiled functions
funcs: List[Optional[CompiledFunction]] = []
for build_result in build_results:
if build_result is not None:
lib_path, func_name, func_type = build_result
funcs.append(load_lib_func(lib_path, func_name, func_type))
else:
funcs.append(None)
if verbose:
print(
'Batch build {} modules within {:.3f} seconds, on average {:.1f} seconds per module.'.format(
len(jobs), timer.elapsed_seconds(), timer.elapsed_seconds() / len(jobs)
)
)
return funcs
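
As a usage illustration, here is a hedged sketch of how the new batch API might be called; the module list, kernel name, and output directory below are hypothetical, and only the signature comes from this diff:

    from hidet.driver import build_ir_module_batch

    # `candidate_modules` stands in for a list of IRModule variants, e.g. produced by a tuner.
    funcs = build_ir_module_batch(
        candidate_modules,
        func_name='matmul_f16',        # hypothetical kernel name
        output_dir='./outs/tuning',    # hypothetical output directory
        parallel=True,
        verbose=True,
    )
    # Failed builds come back as None, in the same order as the inputs.
    usable = [f for f in funcs if f is not None]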
2 changes: 1 addition & 1 deletion python/hidet/graph/__init__.py
@@ -14,7 +14,7 @@

from .tensor import array, randn, empty, zeros, ones, symbol, randint, randn_like, empty_like, zeros_like, ones_like
from .tensor import symbol_like, randint_like, from_torch, full, full_like
from .ir import trace_from, load_graph, save_graph
from .ir import trace_from, load_graph, save_graph, forward_context
from .transforms import optimize
from .modules import nn
from .jit import jit
2 changes: 1 addition & 1 deletion python/hidet/graph/ir/__init__.py
@@ -1,5 +1,5 @@
from . import flow_graph
from . import functors

from .flow_graph import FlowGraph, Tensor, Operator, trace_from, load_graph, save_graph
from .flow_graph import FlowGraph, Tensor, Operator, trace_from, load_graph, save_graph, forward_context
from .functors import GraphRewriter, GraphVisitor
