
Commit 9cbbf6d

[Device] Add support for mixed cpu and cuda kernels in the same flow graph (#270)

yaoyaoding committed Jun 5, 2023
1 parent ca607f9 commit 9cbbf6d
Showing 35 changed files with 747 additions and 598 deletions.

include/hidet/runtime/memory_planner.h (17 additions, 14 deletions)
```diff
@@ -1,5 +1,6 @@
 #pragma once
 #include <list>
+#include <vector>
 #include <unordered_map>
 #include <cstdint>
 #include <hidet/runtime/common.h>
@@ -12,23 +13,22 @@ struct Region {
 struct MemoryPlanner {
     std::list<Region> regions;
     std::unordered_map<int64_t, int64_t> size_map;
-    void print() {
-        for (auto region : regions) {
-            printf("[%ld %ld] ", region.start, region.size);
-        }
-        printf("\n");
-    }
 };
 
-static MemoryPlanner memory_planner;
+static std::vector<MemoryPlanner> memory_planners;
 
-static void memory_planner_init() {
-    memory_planner.size_map.clear();
-    memory_planner.regions.clear();
-    memory_planner.regions.push_back({0, -1});
+static void memory_planner_init(int idx) {
+    if(memory_planners.size() <= idx) {
+        memory_planners.resize(idx + 1);
+    }
+    memory_planners[idx].size_map.clear();
+    memory_planners[idx].regions.clear();
+    memory_planners[idx].regions.push_back({0, -1});
 }
 
-static int64_t memory_planner_allocate(int64_t size) {
+static int64_t memory_planner_allocate(int idx, int64_t size) {
+    MemoryPlanner &memory_planner = memory_planners[idx];
+
     if(size == 0) {
         return -1;
     }
@@ -55,7 +55,9 @@ static int64_t memory_planner_allocate(int64_t size) {
     return 0;
 }
 
-static void memory_planner_free(int64_t ptr) {
+static void memory_planner_free(int idx, int64_t ptr) {
+    MemoryPlanner &memory_planner = memory_planners[idx];
+
     if(ptr == -1) {
         return;
     }
@@ -96,7 +98,8 @@ static void memory_planner_free(int64_t ptr) {
     }
 }
 
-static int64_t memory_planner_used() {
+static int64_t memory_planner_used(int idx) {
+    MemoryPlanner &memory_planner = memory_planners[idx];
     auto riter = memory_planner.regions.rbegin();
     return riter->start;
 }
```

python/hidet/cuda/graph.py (43 additions, 39 deletions)
```diff
@@ -10,10 +10,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # pylint: disable=no-name-in-module, c-extension-no-member
-from typing import List, Sequence, Optional
+from typing import List, Sequence, Optional, Any, Callable
 from cuda import cudart
 from cuda.cudart import cudaGraphExec_t
-from hidet.graph.tensor import Tensor, zeros_like, randn_like
+from hidet.graph.tensor import Tensor
 from hidet.runtime.storage import MemoryPool, CudaMemoryAPI, memory_pool
 from hidet.runtime.device import Device
 from hidet.utils import same_list, exiting
@@ -22,6 +22,10 @@
 from .memory import memcpy_async
 
 
+class CudaGraphCreationError(Exception):
+    pass
+
+
 class FreezableMemoryAPI(CudaMemoryAPI):
     def __init__(self, device: Device):
         super().__init__(device)
@@ -81,64 +85,64 @@ def instantiate(self) -> cudaGraphExec_t:
 
 class CudaGraph:
     """
-    A CUDA graph that executes a :class:`~hidet.graph.ir.flow_graph.FlowGraph` on the GPU.
-
-    You can create the CUDA graph by calling :meth:`~hidet.graph.ir.flow_graph.FlowGraph.cuda_graph`.
+    Create a cuda graph to capture and replay the execution of a series of cuda kernels launched in a function.
+
+    The graph is created by calling the constructor with the following arguments:
 
     Parameters
     ----------
-    flow_graph: FlowGraph
-        The flow graph to be executed.
+    f_create_inputs: Callable[[], List[Tensor]]
+        A function that creates the input tensors of the graph. This function is called before f_run.
+
+    f_run: Callable[[List[Tensor]], List[Tensor]]
+        A function that runs the graph. Only the cuda kernels launched in this function will be captured.
+        Rerunning this function must launch the same cuda kernels in the same order. The input tensors of
+        this function will be the output tensors of the f_create_inputs function.
+
+    ref_objs: Any
+        The objects that should keep alive during the lifetime of the cuda graph. It may contain the weight
+        tensors that are used in the graph.
     """
 
-    def __init__(self, flow_graph):
-        from hidet.graph.flow_graph import FlowGraph
-
-        flow_graph: FlowGraph
-
+    def __init__(
+        self,
+        f_create_inputs: Callable[[], List[Tensor]],
+        f_run: Callable[[List[Tensor]], List[Tensor]],
+        ref_objs: List[Any],
+    ):
         self._memory_api: FreezableMemoryAPI = FreezableMemoryAPI(Device('cuda', current_device()))
         self._memory_pool: MemoryPool = MemoryPool(
             memory_api=self._memory_api, block_size=4096, max_reserve_size=10 * 1024**3
         )
         self._graph_capture: CudaGraphCapture = CudaGraphCapture()
-        self._flow_graph: FlowGraph = flow_graph
-        self._inputs: List[Tensor]
-        self._outputs: List[Tensor]
+        self._inputs: List[Tensor] = []
+        self._outputs: List[Tensor] = []
+        self._ref_objs: List[Any] = ref_objs
 
         with memory_pool(self._memory_pool):
-            # update the nodes and inputs of the flow graph
-            flow_graph.update_nodes()
-
-            # prepare the dummpy inputs
-            inputs = []
-            for tensor in flow_graph.inputs:
-                if tensor.is_symbolic():
-                    if tensor.dtype.is_float():
-                        inputs.append(randn_like(tensor, device='cuda'))
-                    else:
-                        inputs.append(zeros_like(tensor, device='cuda'))
-                else:
-                    inputs.append(tensor)
-            self._inputs = inputs
-
-            # run and capture the graph execution
-            flow_graph.forward(*self._inputs)  # warm up, avoid memory allocation during capturing
-            flow_graph.forward(*self._inputs)
+            # create the input tensors
+            self._inputs = f_create_inputs()
+
+            # warmup the run function
+            num_warmup = 2
+            for _ in range(num_warmup):
+                f_run(self._inputs)
+
+            # capture the cuda graph
             self._memory_api.freeze()
             with self._graph_capture:
-                outputs = flow_graph.forward(*self._inputs)
+                self._outputs = f_run(self._inputs)
 
-            # process the outputs
-            self._outputs = [outputs] if isinstance(outputs, Tensor) else outputs
-
-        # instantiate the captured graph
+        # instantiate the cuda graph
         self._graph_exec: cudaGraphExec_t = self._graph_capture.instantiate()
 
     def __call__(self, *inputs: Tensor):
         if len(inputs) == 0:
-            return self.run()
+            self.run()
         else:
-            return self.run(inputs)
+            self.run(inputs)
+        if len(self.outputs) == 1:
+            return self.outputs[0]
+        else:
+            return self.outputs
 
     def __del__(self, is_shutting_down=exiting.is_exiting):
         if is_shutting_down():
```
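On the Python side, `CudaGraph` no longer takes a `FlowGraph` and fabricates dummy inputs itself; it captures whatever cuda kernels an arbitrary `f_run` callable launches, with `f_create_inputs` supplying the input tensors and `ref_objs` pinning objects (such as weight tensors) for the graph's lifetime. That inversion is what lets a flow graph containing both cpu and cuda kernels hand only its cuda portion to the capture machinery. A usage sketch under stated assumptions: the constructor signature comes from this diff, while the workload, shapes, and the `hidet.randn` / `hidet.ops.relu` call sites are illustrative.

```python
# Illustrative use of the refactored CudaGraph constructor. Only
# f_create_inputs / f_run / ref_objs come from this diff; the relu
# workload and shapes are assumptions for the sketch.
from typing import List

import hidet
from hidet.cuda.graph import CudaGraph
from hidet.graph.tensor import Tensor


def f_create_inputs() -> List[Tensor]:
    # Called once, before capture, to allocate the graph's input tensors.
    return [hidet.randn([1, 3, 224, 224], device='cuda')]


def f_run(inputs: List[Tensor]) -> List[Tensor]:
    # Must launch the same cuda kernels in the same order on every call;
    # only the kernels launched here are recorded into the cuda graph.
    return [hidet.ops.relu(inputs[0])]


g = CudaGraph(f_create_inputs=f_create_inputs, f_run=f_run, ref_objs=[])
y = g()  # replay the captured kernels and return the output tensor
```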
