[WIP] on-disk caching of CUDA kernels #8079

Closed
wants to merge 12 commits
12 changes: 10 additions & 2 deletions numba/core/caching.py
@@ -392,7 +392,14 @@ def rebuild(self, target_context, payload):
"""
Returns the unserialized CompileResult
"""
return compiler.CompileResult._rebuild(target_context, *payload)
# Huge bodge: dispatch on the target context type until a dedicated kernel cache implementation exists
from numba.cuda.target import CUDATargetContext
if isinstance(target_context, CUDATargetContext):
from numba.cuda.dispatcher import _Kernel
# The reduced kernel payload is a dict rather than a tuple, so unpack it as keyword arguments
return _Kernel._rebuild(**payload)
else:
return compiler.CompileResult._rebuild(target_context, *payload)

def check_cachable(self, cres):
"""
@@ -401,7 +408,8 @@ def check_cachable(self, cres):
cannot_cache = None
if any(not x.can_cache for x in cres.lifted):
cannot_cache = "as it uses lifted code"
elif cres.library.has_dynamic_globals:
# Hack, probably need a KernelCacheImpl
elif hasattr(cres, 'library') and cres.library.has_dynamic_globals:
cannot_cache = ("as it uses dynamic globals "
"(such as ctypes pointers and large global arrays)")
if cannot_cache:
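The comment in check_cachable above already hints at the cleaner design: a dedicated cache implementation for kernels rather than branching inside the CompileResult one. Below is a minimal sketch of what that could look like, assuming Numba's internal _CacheImpl interface in numba.core.caching and the _Kernel._reduce_states/_rebuild pair added later in this PR; the class name and how it would be registered are hypothetical.

from numba.core.caching import _CacheImpl


class KernelCacheImpl(_CacheImpl):
    """Hypothetical cache implementation dedicated to CUDA kernels."""

    def reduce(self, kernel):
        # _Kernel serializes itself to a plain dict of state.
        return kernel._reduce_states()

    def rebuild(self, target_context, payload):
        # The kernel carries its own code library, so the target context is
        # not needed to reconstruct it.
        from numba.cuda.dispatcher import _Kernel
        return _Kernel._rebuild(**payload)

    def check_cachable(self, kernel):
        # Kernels never contain lifted code, and once reduced to PTX and
        # cubins they hold no dynamic globals, so caching is always allowed.
        return True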
33 changes: 31 additions & 2 deletions numba/cuda/codegen.py
@@ -5,7 +5,7 @@
from numba.core import config, serialize
from numba.core.codegen import Codegen, CodeLibrary
from numba.core.errors import NumbaInvalidConfigWarning
from .cudadrv import devices, driver, nvvm
from .cudadrv import devices, driver, nvvm, runtime

import ctypes
import numpy as np
@@ -291,18 +291,37 @@ def finalize(self):

self._finalized = True

def serialize_using_object_code(self):
# Inspired by CPUCodeLibrary's version
self._ensure_finalized()
data = (self._ptx_cache, self._cubin_cache)
return (self.name, 'object', data)

@classmethod
def _unserialize(cls, codegen, state):
# Inspired by CPUCodeLibrary's version
name, kind, data = state
if kind != 'object':
raise ValueError('Can only unserialize from object code state')
self = codegen.create_library(name)
assert isinstance(self, cls)

self._ptx_cache, self._cubin_cache = data
# TODO: check whether any other state needs restoring
return self

def _reduce_states(self):
"""
Reduce the instance for serialization. We retain the PTX and cubins,
but loaded functions are discarded. They are recreated when needed
after deserialization.
"""

if self._linking_files:
msg = ('cannot pickle CUDACodeLibrary function with additional '
'libraries to link against')
raise RuntimeError(msg)
return dict(
codegen=self._codegen,
codegen=None,
name=self.name,
entry_name=self._entry_name,
module=self._module,
@@ -337,6 +356,8 @@ def _rebuild(cls, codegen, name, entry_name, module, linking_libraries,
instance._max_registers = max_registers
instance._nvvm_options = nvvm_options

return instance


class JITCUDACodegen(Codegen):
"""
@@ -360,3 +381,11 @@ def _create_empty_module(self, name):

def _add_module(self, module):
pass

def magic_tuple(self):
"""
Return a tuple unambiguously describing the codegen behaviour.
"""
ctx = devices.get_context()
cc = ctx.device.compute_capability
return (runtime.runtime.get_version(), cc)
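magic_tuple gives the cache a way to invalidate entries when the CUDA runtime or the target device changes. A rough sketch of how an on-disk index key could fold it in, assuming this PR's JITCUDACodegen.magic_tuple; build_cache_key is a hypothetical helper, not part of the change.

import hashlib
import pickle


def build_cache_key(codegen, sig):
    # Combine the runtime version and compute capability with the signature
    # so cubins built for one toolchain/device are never reused on another.
    runtime_version, compute_capability = codegen.magic_tuple()
    raw = pickle.dumps((runtime_version, compute_capability, str(sig)))
    return hashlib.sha256(raw).hexdigest()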
14 changes: 11 additions & 3 deletions numba/cuda/decorators.py
@@ -12,7 +12,7 @@


def jit(func_or_sig=None, device=False, inline=False, link=[], debug=None,
opt=True, **kws):
opt=True, cache=False, **kws):
"""
JIT compile a python function conforming to the CUDA Python specification.
If a signature is supplied, then a function is returned that takes a
@@ -102,6 +102,9 @@ def _jit(func):

disp = CUDADispatcher(func, targetoptions=targetoptions)

if cache:
disp.enable_caching()

if device:
disp.compile_device(argtypes)
else:
@@ -122,7 +125,7 @@ def autojitwrapper(func):
else:
def autojitwrapper(func):
return jit(func, device=device, debug=debug, opt=opt,
link=link, **kws)
link=link, cache=cache, **kws)

return autojitwrapper
# func_or_sig is a function
@@ -138,7 +141,12 @@ def autojitwrapper(func):
targetoptions['fastmath'] = fastmath
targetoptions['device'] = device
targetoptions['extensions'] = extensions
return CUDADispatcher(func_or_sig, targetoptions=targetoptions)
disp = CUDADispatcher(func_or_sig, targetoptions=targetoptions)

if cache:
disp.enable_caching()

return disp


def declare_device(name, sig):
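From user code the new cache keyword is meant to behave like its CPU counterpart. A minimal usage sketch, assuming the cached kernel ends up in the usual __pycache__ directory next to the source file; the kernel body is illustrative only.

import numpy as np
from numba import cuda


@cuda.jit(cache=True)
def axpy(r, a, x, y):
    i = cuda.grid(1)
    if i < r.size:
        r[i] = a * x[i] + y[i]


n = 1024
x = np.arange(n, dtype=np.float32)
y = np.ones(n, dtype=np.float32)
r = np.zeros(n, dtype=np.float32)
# The first run compiles and writes the kernel to disk; subsequent
# interpreter sessions should load the cached cubin instead of recompiling.
axpy[n // 256, 256](r, 2.0, x, y)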
118 changes: 89 additions & 29 deletions numba/cuda/dispatcher.py
@@ -65,6 +65,9 @@ def __init__(self, py_func, argtypes, link=None, debug=False,
self.lineinfo = lineinfo
self.extensions = extensions or []

# Kernels don't have any lifted code
self.lifted = []

nvvm_options = {
'debug': self.debug,
'lineinfo': self.lineinfo,
@@ -87,6 +90,11 @@ def __init__(self, py_func, argtypes, link=None, debug=False,
filename, linenum,
max_registers)

# Needed for caching to get codegen
self.target_context = tgt_ctx
self.fndesc = cres.fndesc
self.environment = cres.environment

if not link:
link = []

@@ -106,13 +114,57 @@ def __init__(self, py_func, argtypes, link=None, debug=False,
self._codelibrary = lib
self.call_helper = cres.call_helper

# Pretend there are no referenced environments. TODO: Are there any for
# kernels?
self._referenced_environments = []
# What is reload_init for? Is it only used by parfors?
self.reload_init = []

@property
def codegen(self):
return self.target_context.codegen()

@property
def library(self):
# This exists because of the naming discrepancy between _Kernel
# (_codelibrary) and CompileResult (library); one of them should
# probably be renamed.
return self._codelibrary

@property
def type_annotation(self):
# Another hack; should probably just change the name
return self._type_annotation

def _find_referenced_environments(self):
# Another hack
return self._referenced_environments

@property
def argument_types(self):
return tuple(self.signature.args)

def _reduce(self):
# TODO: why is _reduce called here rather than _reduce_states?
return self._reduce_states()

# First attempt, inspired by CompileResult._reduce; unreachable after the
# return above and kept only for reference:

libdata = self.library.serialize_using_object_code()
# Make it (un)picklable efficiently
typeann = str(self.type_annotation)
fndesc = self.fndesc
# Those don't need to be pickled and may fail
fndesc.typemap = fndesc.calltypes = None
# Include all referenced environments
referenced_envs = self._find_referenced_environments()
return (libdata, self.fndesc, self.environment, self.signature,
self.objectmode, self.lifted, typeann, self.reload_init,
tuple(referenced_envs))

@classmethod
def _rebuild(cls, cooperative, name, argtypes, codelibrary, link, debug,
lineinfo, call_helper, extensions):
def _rebuild(cls, cooperative, name, signature, codelibrary,
debug, lineinfo, call_helper, extensions):
"""
Rebuild an instance.
"""
@@ -122,7 +174,7 @@ def _rebuild(cls, cooperative, name, argtypes, codelibrary, link, debug,
# populate members
instance.cooperative = cooperative
instance.entry_name = name
instance.argument_types = tuple(argtypes)
instance.signature = signature
instance._type_annotation = None
instance._codelibrary = codelibrary
instance.debug = debug
@@ -140,7 +192,7 @@ def _reduce_states(self):
Stream information is discarded.
"""
return dict(cooperative=self.cooperative, name=self.entry_name,
argtypes=self.argtypes, codelibrary=self.codelibrary,
signature=self.signature, codelibrary=self._codelibrary,
debug=self.debug, lineinfo=self.lineinfo,
call_helper=self.call_helper, extensions=self.extensions)

@@ -721,23 +773,47 @@ def compile(self, sig):
'''
argtypes, return_type = sigutils.normalize_signature(sig)
assert return_type is None or return_type == types.none

# Do we already have an in-memory compiled kernel?
if self.specialized:
return next(iter(self.overloads.values()))
else:
kernel = self.overloads.get(argtypes)
if kernel is None:
if not self._can_compile:
raise RuntimeError("Compilation disabled")
kernel = _Kernel(self.py_func, argtypes,
**self.targetoptions)
# Inspired by _DispatcherBase.add_overload, but differs slightly
# because we're inserting a _Kernel object instead of a compiled
# function.
if kernel is not None:
return kernel

# Can we load from the disk cache?
kernel = self._cache.load_overload(sig, self.targetctx)
if kernel is not None:
self._cache_hits[sig] += 1

# This should be refactored into an add_overload function; it
# duplicates the code for the "have to compile" path below
c_sig = [a._code for a in argtypes]
self._insert(c_sig, kernel, cuda=True)
self.overloads[argtypes] = kernel

kernel.bind()
return kernel

self._cache_misses[sig] += 1

# We need to compile a new kernel
if not self._can_compile:
raise RuntimeError("Compilation disabled")

kernel = _Kernel(self.py_func, argtypes, **self.targetoptions)

# Inspired by _DispatcherBase.add_overload, but differs slightly
# because we're inserting a _Kernel object instead of a compiled
# function.
c_sig = [a._code for a in argtypes]
self._insert(c_sig, kernel, cuda=True)
self.overloads[argtypes] = kernel

kernel.bind()

self._cache.save_overload(sig, kernel)

return kernel

def inspect_llvm(self, signature=None):
@@ -834,19 +910,3 @@ def ptx(self):
Expand Down Expand Up @@ -834,19 +910,3 @@ def ptx(self):
def bind(self):
for defn in self.overloads.values():
defn.bind()

@classmethod
def _rebuild(cls, py_func, targetoptions):
"""
Rebuild an instance.
"""
instance = cls(py_func, targetoptions)
return instance

def _reduce_states(self):
"""
Reduce the instance for serialization.
Compiled definitions are discarded.
"""
return dict(py_func=self.py_func,
targetoptions=self.targetoptions)
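The restructured compile method above records hits and misses in _cache_hits and _cache_misses, which gives a quick way to check that the disk cache is actually used. A small sketch of such a check, relying on those private counters exactly as they appear in this diff, so suitable for experimentation only.

import numpy as np
from numba import cuda


@cuda.jit(cache=True)
def scale(out, x):
    i = cuda.grid(1)
    if i < out.size:
        out[i] = 2.0 * x[i]


x = np.arange(256, dtype=np.float32)
out = np.empty_like(x)
scale[1, 256](out, x)

# In a fresh process the first call should register a cache miss and save
# the kernel; re-running the same script should register a hit for this
# signature and skip compilation.
print(dict(scale._cache_hits), dict(scale._cache_misses))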
6 changes: 6 additions & 0 deletions numba/cuda/target.py
@@ -106,6 +106,12 @@ def load_additional_registries(self):
def codegen(self):
return self._internal_codegen

def get_executable(self, library, fndesc, env):
# Hack for caching: the compile result has no executable for CUDA anyway,
# so there is nothing meaningful to return here.
return None

@property
def target_data(self):
return self._target_data
Expand Down