Skip to content

Commit

Permalink
[WIP] put kernel compiled bits in a dict
Browse files Browse the repository at this point in the history
  • Loading branch information
gmarkall committed Oct 1, 2020
1 parent e5d76b9 commit 898a9da
Showing 1 changed file with 52 additions and 13 deletions.
65 changes: 52 additions & 13 deletions numba/cuda/compiler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import collections
import ctypes
import inspect
import os
Expand All @@ -17,6 +16,7 @@
from numba.core.dispatcher import OmittedArg
from numba.core.errors import NumbaDeprecationWarning
from numba.core.typing.typeof import Purpose, typeof
from collections import namedtuple, OrderedDict
from warnings import warn
import numba
from .cudadrv.devices import get_context
Expand Down Expand Up @@ -492,6 +492,18 @@ def _rebuild(cls, entry_name, ptx, linking, max_registers):
return cls(entry_name, ptx, linking, max_registers)


_kernel_def_fields = (
'entry_name',
'signature',
'type_annotation',
'func',
'call_helper'
)


_KernelDefinition = namedtuple("_KernelDefinition", _kernel_def_fields)


class _Kernel(serialize.ReduceMixin):
'''
CUDA Kernel specialized for a given set of argument types. When called, this
Expand All @@ -513,6 +525,8 @@ def __init__(self, py_func, argtypes, link, debug=False, inline=False,
self.max_registers = max_registers
self.opt = opt

self.definitions = {}

cc = get_current_device().compute_capability
self.compile(cc)

Expand All @@ -528,13 +542,11 @@ def compile(self, cc):
args,
debug=self.debug)

llvm_module = lib._final_module
name = kernel.name
pretty_name = cres.fndesc.qualname
signature = cres.signature
type_annotation = cres.type_annotation
call_helper = cres.call_helper
max_registers = self.max_registers

# initialize CUfunction
options = {
Expand All @@ -543,14 +555,17 @@ def compile(self, cc):
'opt': 3 if self.opt else 0
}

ptx = CachedPTX(pretty_name, str(llvm_module), options=options)
cufunc = CachedCUFunction(name, ptx, self.link, max_registers)
ptx = CachedPTX(pretty_name, str(lib._final_module), options=options)
cufunc = CachedCUFunction(name, ptx, self.link, self.max_registers)

# populate members
self.entry_name = name
self.signature = signature
self._type_annotation = type_annotation
self._func = cufunc
self.call_helper = call_helper
self.definitions[cc] = _KernelDefinition(
entry_name=name,
signature=signature,
type_annotation=type_annotation,
func=cufunc,
call_helper=call_helper
)

@property
def argument_types(self):
Expand Down Expand Up @@ -598,6 +613,30 @@ def __call__(self, *args, **kwargs):
stream=self.stream,
sharedmem=self.sharedmem)

@property
def _func(self):
    # CachedCUFunction compiled for the compute capability of the
    # currently active device.
    current_cc = get_current_device().compute_capability
    definition = self.definitions[current_cc]
    return definition.func

@property
def _type_annotation(self):
    # Type annotation for the definition compiled for the current
    # device's compute capability.  The previous implementation returned
    # the annotation of an arbitrary (insertion-order first) entry via
    # next(iter(...)), which is inconsistent with the sibling properties
    # (_func, entry_name, signature, call_helper) and raised a confusing
    # StopIteration when no definition existed.  Keying on the current
    # CC makes all accessors agree, and a missing definition now raises
    # KeyError like the siblings do.
    cc = get_current_device().compute_capability
    return self.definitions[cc].type_annotation

@property
def entry_name(self):
    # Name of the kernel entry point compiled for the current device.
    current_cc = get_current_device().compute_capability
    return self.definitions[current_cc].entry_name

@property
def call_helper(self):
    # Call helper belonging to the definition compiled for the current
    # device's compute capability.
    device_cc = get_current_device().compute_capability
    return self.definitions[device_cc].call_helper

@property
def signature(self):
    # Signature of the definition compiled for the current device's
    # compute capability.
    device_cc = get_current_device().compute_capability
    return self.definitions[device_cc].signature

def bind(self):
"""
Force binding to current CUDA context
Expand Down Expand Up @@ -650,7 +689,7 @@ def inspect_types(self, file=None):
if file is None:
file = sys.stdout

print("%s %s" % (self.entry_name, self.argument_types), file=file)
print("%s %s" % (self.entry_name, self.argtypes), file=file)
print('-' * 80, file=file)
print(self._type_annotation, file=file)
print('=' * 80, file=file)
Expand All @@ -670,7 +709,7 @@ def launch(self, args, griddim, blockdim, stream=0, sharedmem=0):
retr = [] # hold functors for writeback

kernelargs = []
for t, v in zip(self.argument_types, args):
for t, v in zip(self.argtypes, args):
self._prepare_args(t, v, stream, retr, kernelargs)

# Configure kernel
Expand Down Expand Up @@ -855,7 +894,7 @@ def __init__(self, py_func, sigs, targetoptions):

# A mapping of signatures to compile results
# Stopgap for _DispatcherBase
self.overloads = collections.OrderedDict()
self.overloads = OrderedDict()

_dispatcher.Dispatcher.__init__(self, self._tm.get_pointer(),
arg_count, self._fold_args, argnames,
Expand Down

0 comments on commit 898a9da

Please sign in to comment.