Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Adding support for compiled model #247

Merged
merged 3 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 1 addition & 17 deletions gallery/how-to-guides/add-new-operator-template-based.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,26 +245,10 @@ def demo_usage():
# ---------------------
# If you are interested in the generated source code, here it is:

# sphinx_gallery_start_ignore
a = hidet.randn([1, 2, 2], dtype='float16', device='cuda')
b = hidet.randn([1, 2, 2], dtype='float16', device='cuda')
op = BatchMatmulFp16Op(a, b)
c = op.get_output(0)
func = op.task_func
import os

relative_path = os.path.relpath(
func.src_path, os.path.dirname(hidet.utils.hidet_cache_dir())
)
source_path = func.src_path
# sphinx_gallery_end_ignore

# we hide the code to get the source path for simplicity
print('Generated source path (relative to hidet cache root): \n{}'.format(relative_path))
print()
print('Generated source code:')
with open(source_path, 'r') as f:
print(f.read())
print(op.task_func.source(color=True))

# %%
# Summary
Expand Down
101 changes: 101 additions & 0 deletions include/hidet/runtime/memory_planner.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#pragma once
#include <list>
#include <unordered_map>
#include <cstdint>
#include <hidet/runtime/common.h>

struct Region {
int64_t start;
int64_t size;
};

struct MemoryPlanner {
std::list<Region> regions;
std::unordered_map<int64_t, int64_t> size_map;
void print() {
for (auto region : regions) {
printf("[%ld %ld] ", region.start, region.size);
}
printf("\n");
}
};

static MemoryPlanner memory_planner;
//
//int max_segments = 0;

static void memory_planner_init() {
memory_planner.size_map.clear();
memory_planner.regions.clear();
memory_planner.regions.push_back({0, -1});
}

static int64_t memory_planner_allocate(int64_t size) {
// max_segments = std::max(max_segments, (int)memory_planner.regions.size());
// printf("%d (%d)\n", (int)memory_planner.regions.size(), max_segments);
// memory_planner.print();
for (auto it = memory_planner.regions.begin(); it != memory_planner.regions.end(); ++it) {
if (it->size >= size) {
auto region = *it;
if (region.size > size) {
memory_planner.regions.insert(it, {region.start + size, region.size - size});
}
memory_planner.regions.erase(it);
auto ret = region.start;
memory_planner.size_map[ret] = size;
return ret;
} else if (it->size == -1) {
int64_t start = it->start;
it->start += size;
memory_planner.size_map[start] = size;
return start;
}
}
assert(false);
return 0;
}

static void memory_planner_free(int64_t ptr) {
// max_segments = std::max(max_segments, (int)memory_planner.regions.size());
// printf("%d (%d)\n", (int)memory_planner.regions.size(), max_segments);
// memory_planner.print();
int64_t start = ptr;
int64_t size = memory_planner.size_map[ptr];
auto it = memory_planner.regions.begin();
while(it != memory_planner.regions.end() && it->start <= start)
it++;
if(it == memory_planner.regions.begin()) {
if(start + size == it->start) {
it->start = start;
if(it->size != -1) {
it->size += size;
}
} else {
memory_planner.regions.insert(it, {start, size});
}
} else {
auto pit = it;
pit--;
if(start + size == it->start && start == pit->start + pit->size) {
it->start = pit->start;
if(it->size != -1) {
it->size += pit->size + size;
}
memory_planner.regions.erase(pit);
} else if (start + size == it->start){
it->start = start;
if (it->size != -1) {
it->size += size;
}
} else if (start == pit->start + pit->size) {
pit->size += size;
} else {
memory_planner.regions.insert(it, {start, size});
}
}
}

static int64_t memory_planner_used() {
auto riter = memory_planner.regions.rbegin();
return riter->start;
}
2 changes: 1 addition & 1 deletion python/hidet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@
from .ir.expr import symbol_var

from .runtime.device import Device, device
from .runtime.model import save_model, load_model

from .graph import Tensor, Operator, Module, FlowGraph
from .graph import nn
from .graph import ops
from .graph import empty, randn, zeros, ones, full, randint, symbol, asarray, from_torch
from .graph import empty_like, randn_like, zeros_like, ones_like, symbol_like, full_like
from .graph import trace_from, load_graph, save_graph
from .graph import jit
from .graph import from_dlpack
from .graph import frontend
from .graph.ops import arange, linspace
Expand Down
4 changes: 1 addition & 3 deletions python/hidet/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .codegen import codegen
from .build import compile_source, load_task_func, load_lib_func

# from .build import compile_source, load_task_func, BuildInstance, batch_build_ir_modules, load_lib_func
from .build import compile_source
62 changes: 7 additions & 55 deletions python/hidet/backend/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,7 @@

import hidet.cuda
from hidet.libinfo import get_include_dirs
from hidet.ir.type import FuncType
from hidet.runtime import CompiledFunction
from hidet.ffi import PackedFunc
from hidet.ffi.ffi import library_paths
from hidet.ffi.shared_lib import SharedLibrary
from hidet.ir.task import Task # pylint: disable=unused-import


class CompilationFailed(Exception):
Expand All @@ -47,7 +42,8 @@ class SourceCompiler:
def compile(self, src_path: str, out_lib_path: str, options: Optional[Dict[str, str]] = None) -> None:
raise NotImplementedError()

def run_compile_command(self, command: str, src_path, out_lib_path: str):
@staticmethod
def run_compile_command(command: str, src_path, out_lib_path: str):
try:
# the directory to store the library "lib.so"
out_lib_dir = os.path.dirname(out_lib_path)
Expand Down Expand Up @@ -122,6 +118,8 @@ def compile(self, src_path: str, out_lib_path: str, options: Optional[Dict[str,
*['-I{}'.format(include_dir) for include_dir in self.include_dirs],
# the library directories.
*['-L{}'.format(library_dir) for library_dir in self.library_dirs],
# optimize host side code via -O3
'-O3',
# enable openmp support for cpu kernels
'-Xcompiler -fopenmp',
# the target PTX and SASS version.
Expand Down Expand Up @@ -168,7 +166,9 @@ def __init__(self):
self.include_dirs: List[str] = get_include_dirs()
self.library_dirs: List[str] = [os.path.dirname(library_paths['hidet_runtime'])]

def _resolve_gcc_path(self):
@staticmethod
@functools.lru_cache(maxsize=None)
def _resolve_gcc_path():
path: Optional[str] = shutil.which('g++')
if path is not None:
return path
Expand Down Expand Up @@ -222,51 +222,3 @@ def compile_source(src_path: str, out_lib_path: str) -> None:
compiler = GCC()

compiler.compile(src_path, out_lib_path)


def load_task_func(lib_path: str, task) -> CompiledFunction:
"""
Load task's entry function from dynamic linked library.

Parameters
----------
lib_path: str
The dynamic library path.
task: Task
The task that corresponds to the dynamic library.

Returns
-------
ret: CompiledFunction
The loaded function that can be directly called in python.
"""
try:
lib = SharedLibrary(lib_path)
except OSError as e:
print("Removed the file '{}'".format(lib_path))
os.remove(lib_path)
raise e
func_name = 'hidet_launch'
param_types = [param.type for param in task.params]
packed_func = PackedFunc(param_types=param_types, c_func_pointer=lib[func_name])

potential_src_path = os.path.join(os.path.dirname(lib_path), 'source.cu')
if os.path.isfile(potential_src_path):
src_path = potential_src_path
else:
src_path = None

return CompiledFunction(name=task.name, packed_func=packed_func, lib_path=lib_path, src_path=src_path)


def load_lib_func(
lib_path: str, func_name: str, func_type: FuncType, src_path: Optional[str] = None
) -> CompiledFunction:
try:
lib = SharedLibrary(lib_path)
except OSError as e:
print("Removed the file '{}'".format(lib_path))
os.remove(lib_path)
raise e
packed_func = PackedFunc(param_types=list(func_type.param_types), c_func_pointer=lib[func_name])
return CompiledFunction(name=func_name, packed_func=packed_func, lib_path=lib_path, src_path=src_path)