Skip to content

Commit

Permalink
Merge pull request cupy#2053 from kmaehashi/fix-cusolver-threading
Browse files (browse the repository at this point in the history)
Avoid sharing handles between threads
  • Loading branch information
okuta committed Apr 1, 2019
1 parent 9084dda commit 45b198a
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 75 deletions.
4 changes: 4 additions & 0 deletions cupy/cuda/cupy_cusolver.h
Expand Up @@ -221,6 +221,10 @@ cusolverStatus_t cusolverSpCreate(...) {
return CUSOLVER_STATUS_SUCCESS;
}

/* No-op stand-in for cusolverSpDestroy: accepts any arguments, does nothing,
 * and always reports CUSOLVER_STATUS_SUCCESS.
 * NOTE(review): presumably only compiled when the real cuSOLVER header is
 * unavailable -- confirm against the surrounding preprocessor guards. */
cusolverStatus_t cusolverSpDestroy(...) {
return CUSOLVER_STATUS_SUCCESS;
}

/* No-op stand-in for cusolverSpScsrlsvqr: accepts any arguments, performs no
 * computation, and always reports CUSOLVER_STATUS_SUCCESS.
 * NOTE(review): presumably only compiled when the real cuSOLVER header is
 * unavailable -- confirm against the surrounding preprocessor guards. */
cusolverStatus_t cusolverSpScsrlsvqr(...) {
return CUSOLVER_STATUS_SUCCESS;
}
Expand Down
1 change: 1 addition & 0 deletions cupy/cuda/cusolver.pxd
Expand Up @@ -42,6 +42,7 @@ cpdef enum:
cpdef size_t create() except? 0
cpdef size_t spCreate() except? 0
cpdef destroy(size_t handle)
cpdef spDestroy(size_t handle)

###############################################################################
# Stream
Expand Down
8 changes: 8 additions & 0 deletions cupy/cuda/cusolver.pyx
Expand Up @@ -14,6 +14,7 @@ cdef extern from 'cupy_cusolver.h' nogil:
int cusolverDnCreate(Handle* handle)
int cusolverSpCreate(SpHandle* handle)
int cusolverDnDestroy(Handle handle)
int cusolverSpDestroy(SpHandle handle)

# Stream
int cusolverDnGetStream(Handle handle, driver.Stream* streamId)
Expand Down Expand Up @@ -235,6 +236,13 @@ cpdef destroy(size_t handle):
status = cusolverDnDestroy(<Handle>handle)
check_status(status)


cpdef spDestroy(size_t handle):
    """Destroy a cuSOLVER sparse handle.

    Args:
        handle (size_t): Handle value; it is cast to ``SpHandle`` and passed
            to ``cusolverSpDestroy``.

    The status code returned by the library call is handed to
    ``check_status``.
    """
    cdef int status
    # The library call does not touch Python objects, so the GIL is released
    # around it.
    with nogil:
        status = cusolverSpDestroy(<SpHandle>handle)
    check_status(status)


###############################################################################
# Stream
###############################################################################
Expand Down
5 changes: 5 additions & 0 deletions cupy/cuda/device.pxd
Expand Up @@ -4,6 +4,11 @@ cpdef size_t get_cusolver_handle() except? 0
cpdef size_t get_cusparse_handle() except? 0
cpdef str get_compute_capability()

cdef class Handle:
    # Raw library handle value; ``public`` makes it readable and writable
    # from Python code.
    cdef public size_t handle
    # Callable stored by the constructor and invoked with ``handle`` when
    # the object is deallocated.
    cdef object _destroy_func

cdef class Device:
cdef:
public int id
Expand Down
128 changes: 53 additions & 75 deletions cupy/cuda/device.pyx
@@ -1,12 +1,15 @@
# distutils: language = c++

import atexit
import threading

import six

from cupy.cuda cimport cublas
from cupy.cuda cimport cusparse
from cupy.cuda cimport runtime
from cupy.cuda import cublas
from cupy.cuda import cusparse
from cupy.cuda import runtime
from cupy.cuda import runtime
from cupy import util

try:
from cupy.cuda import cusolver
Expand All @@ -15,46 +18,48 @@ except ImportError:
cusolver_enabled = False


cdef object _thread_local = threading.local()

cdef dict _devices = {}
cdef dict _compute_capabilities = {}


cpdef int get_device_id() except? -1:
return runtime.getDevice()


cdef dict _cublas_handles = {}
cdef dict _cusolver_handles = {}
cdef dict _cusolver_sp_handles = {}
cdef dict _cusparse_handles = {}
cdef dict _compute_capabilities = {}
cpdef Device _get_device():
    """Return the cached ``Device`` object for the current device.

    The device ID is obtained from ``runtime.getDevice()``.  A ``Device`` is
    created the first time an ID is seen and stored in the module-level
    ``_devices`` dict, so later calls for the same ID reuse that object.
    """
    current_id = runtime.getDevice()
    if _devices.get(current_id, None) is None:
        # First request for this device ID: build and remember the object.
        _devices[current_id] = Device()
    return _devices[current_id]


cdef class Handle:
    """Owner of a raw library handle that destroys it on deallocation.

    Args:
        handle (size_t): Raw handle value to keep.
        destroy_func (callable): Called with ``handle`` as its only argument
            when this object is deallocated.
    """

    def __init__(self, handle, destroy_func):
        self.handle = handle
        self._destroy_func = destroy_func

    def __dealloc__(self):
        # __dealloc__ runs for every allocated instance, including ones whose
        # __init__ never ran or raised before storing the callable.  In that
        # case the attribute is still None and there is nothing to destroy;
        # calling it would raise inside the deallocator.
        destroy_func = self._destroy_func
        if destroy_func is not None:
            destroy_func(self.handle)


cpdef size_t get_cublas_handle() except? 0:
dev_id = get_device_id()
ret = _cublas_handles.get(dev_id, None)
if ret is not None:
return ret
return Device().cublas_handle
return _get_device().cublas_handle


cpdef size_t get_cusolver_handle() except? 0:
dev_id = get_device_id()
ret = _cusolver_handles.get(dev_id, None)
if ret is not None:
return ret
return Device().cusolver_handle
return _get_device().cusolver_handle


cpdef get_cusolver_sp_handle():
dev_id = get_device_id()
if dev_id in _cusolver_sp_handles:
return _cusolver_sp_handles[dev_id]
return Device().cusolver_sp_handle
return _get_device().cusolver_sp_handle


cpdef size_t get_cusparse_handle() except? 0:
dev_id = get_device_id()
ret = _cusparse_handles.get(dev_id, None)
if ret is not None:
return ret
return Device().cusparse_handle
return _get_device().cusparse_handle


cpdef str get_compute_capability():
Expand Down Expand Up @@ -146,6 +151,19 @@ cdef class Device:
_compute_capabilities[self.id] = cc
return cc

def _get_handle(self, name, create_func, destroy_func):
    """Return this device's handle from a per-thread cache, creating it once.

    Args:
        name (str): Attribute name on the module-level ``_thread_local``
            object under which the ``{device id: Handle}`` dict is kept.
        create_func (callable): Called with no arguments, with this device
            entered as a context manager, to create a new raw handle.
        destroy_func (callable): Passed to ``Handle`` together with the new
            raw handle.

    Returns:
        The raw handle value for ``self.id`` in the calling thread.
    """
    per_device = getattr(_thread_local, name, None)
    if per_device is None:
        # First use of this handle kind in the calling thread.
        per_device = {}
        setattr(_thread_local, name, per_device)

    cached = per_device.get(self.id, None)
    if cached is not None:
        return cached.handle

    with self:
        raw_handle = create_func()
        # Keep the wrapper object in the cache; callers only get the raw
        # value.
        per_device[self.id] = Handle(raw_handle, destroy_func)
    return raw_handle

@property
def cublas_handle(self):
"""The cuBLAS handle for this device.
Expand All @@ -154,12 +172,8 @@ cdef class Device:
itself is different.
"""
if self.id in _cublas_handles:
return _cublas_handles[self.id]
with self:
handle = cublas.create()
_cublas_handles[self.id] = handle
return handle
return self._get_handle(
'cublas_handles', cublas.create, cublas.destroy)

@property
def cusolver_handle(self):
Expand All @@ -169,15 +183,8 @@ cdef class Device:
itself is different.
"""
if not cusolver_enabled:
raise RuntimeError(
'Current cupy only supports cusolver in CUDA 8.0')
if self.id in _cusolver_handles:
return _cusolver_handles[self.id]
with self:
handle = cusolver.create()
_cusolver_handles[self.id] = handle
return handle
return self._get_handle(
'cusolver_handles', cusolver.create, cusolver.destroy)

@property
def cusolver_sp_handle(self):
Expand All @@ -187,15 +194,8 @@ cdef class Device:
itself is different.
"""
if not cusolver_enabled:
raise RuntimeError(
'Current cupy only supports cusolver in CUDA 8.0')
if self.id in _cusolver_sp_handles:
return _cusolver_sp_handles[self.id]
with self:
handle = cusolver.spCreate()
_cusolver_sp_handles[self.id] = handle
return handle
return self._get_handle(
'cusolver_sp_handles', cusolver.spCreate, cusolver.spDestroy)

@property
def cusparse_handle(self):
Expand All @@ -205,12 +205,8 @@ cdef class Device:
itself is different.
"""
if self.id in _cusparse_handles:
return _cusparse_handles[self.id]
with self:
handle = cusparse.create()
_cusparse_handles[self.id] = handle
return handle
return self._get_handle(
'cusparse_sp_handles', cusparse.create, cusparse.destroy)

@property
def mem_info(self):
Expand Down Expand Up @@ -253,21 +249,3 @@ def from_pointer(ptr):
"""
attrs = runtime.pointerGetAttributes(ptr)
return Device(attrs.device)


@atexit.register
def destroy_cublas_handles():
    """Destroys the cuBLAS handles for all devices.

    Registered with ``atexit``.  Calls ``cublas.destroy`` on every value of
    the module-level ``_cublas_handles`` dict, then rebinds that name to an
    empty dict.
    """
    global _cublas_handles
    # dict.itervalues() exists only on Python 2; six.itervalues works on both
    # major versions and matches destroy_cusparse_handles.
    for handle in six.itervalues(_cublas_handles):
        cublas.destroy(handle)
    _cublas_handles = {}


@atexit.register
def destroy_cusparse_handles():
    """Release every cached cuSPARSE handle.

    Registered with ``atexit``.  Each value stored in the module-level
    ``_cusparse_handles`` dict is passed to ``cusparse.destroy``; afterwards
    the name is rebound to a fresh empty dict.
    """
    global _cusparse_handles
    for cached_handle in six.itervalues(_cusparse_handles):
        cusparse.destroy(cached_handle)
    _cusparse_handles = {}

0 comments on commit 45b198a

Please sign in to comment.