forked from JuliaGPU/CUDA.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CUSOLVER.jl
90 lines (73 loc) · 2.5 KB
/
CUSOLVER.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
module CUSOLVER

using ..APIUtils

using ..CUDA
using ..CUDA: CUstream, cuComplex, cuDoubleComplex, libraryPropertyType, cudaDataType
using ..CUDA: libcusolver, @allowscalar, assertscalar, unsafe_free!, @retry_reclaim

using ..CUBLAS: cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasDiagType_t
using ..CUSPARSE: cusparseMatDescr_t

using CEnum

# core library
include("libcusolver_common.jl")
include("error.jl")
include("libcusolver.jl")

# low-level wrappers
include("util.jl")
include("wrappers.jl")

# high-level integrations
include("linalg.jl")


# thread cache for task-local library handles
#
# Slot `tid` holds the handle of the task currently running on thread `tid`, or
# `nothing`. Slots are cleared by the device/task-switch hooks registered in
# `__init__`. On a miss, the handle is fetched from (or created in)
# `task_local_storage()`.
const thread_dense_handles = Vector{Union{Nothing,cusolverDnHandle_t}}()
const thread_sparse_handles = Vector{Union{Nothing,cusolverSpHandle_t}}()

# Number of cache slots needed so that every thread id is covered.
# `Threads.nthreads()` only counts the default threadpool, so on Julia >= 1.9
# (interactive threads) `Threads.threadid()` can exceed it. `Threads.maxthreadid()`
# covers all pools; fall back to `nthreads()` on Julia versions that lack it.
_thread_slot_count() =
    isdefined(Threads, :maxthreadid) ? Threads.maxthreadid() : Threads.nthreads()

# Shared implementation of `dense_handle`/`sparse_handle`.
#
# Returns the handle stored under `(:CUSOLVER, kind, context())` in the current
# task's local storage. On first use in this task/context it is created with
# `create()`, and a finalizer on the current task calls `destroy(handle)` inside
# that context, provided the context is still valid. The per-thread `cache` is a
# fast path in front of the task-local lookup.
function _cached_handle(create, cache::Vector{Union{Nothing,T}}, kind::Symbol,
                        destroy) where {T}
    tid = Threads.threadid()
    # Bounds-checked on purpose: a thread id outside the cache (e.g. a thread that
    # did not exist when `__init__` ran) skips the fast path instead of indexing
    # out of bounds.
    if checkbounds(Bool, cache, tid)
        cached = @inbounds cache[tid]
        cached === nothing || return cached
    end

    ctx = context()
    handle = get!(task_local_storage(), (:CUSOLVER, kind, ctx)) do
        new_handle = create()
        finalizer(current_task()) do task
            CUDA.isvalid(ctx) || return
            context!(ctx) do
                destroy(new_handle)
            end
        end
        new_handle
    end::T

    # Re-read the thread id. If creating the handle yielded, this task may have
    # migrated, and the slot of the thread we run on *now* is the one that belongs
    # to us; writing to the old `tid` could hand our handle to another task.
    tid = Threads.threadid()
    if checkbounds(Bool, cache, tid)
        @inbounds cache[tid] = handle
    end
    return handle
end

"""
    dense_handle()

Return the cuSOLVER dense handle (created with `cusolverDnCreate`) for the current
task and CUDA context, creating it on first use. New handles are bound to
`CuStreamPerThread()`. A handle is destroyed with `cusolverDnDestroy` when its owning
task is finalized, provided the context is still valid.
"""
function dense_handle()
    return _cached_handle(thread_dense_handles, :dense, cusolverDnDestroy) do
        handle = cusolverDnCreate()
        cusolverDnSetStream(handle, CuStreamPerThread())
        handle
    end
end

"""
    sparse_handle()

Return the cuSOLVER sparse handle (created with `cusolverSpCreate`) for the current
task and CUDA context, creating it on first use. New handles are bound to
`CuStreamPerThread()`. A handle is destroyed with `cusolverSpDestroy` when its owning
task is finalized, provided the context is still valid.
"""
function sparse_handle()
    return _cached_handle(thread_sparse_handles, :sparse, cusolverSpDestroy) do
        handle = cusolverSpCreate()
        cusolverSpSetStream(handle, CuStreamPerThread())
        handle
    end
end

# Forget the calling thread's cached handles. The next lookup goes through
# task-local storage for the (possibly different) task/context.
function _reset_thread_handles()
    tid = Threads.threadid()
    if checkbounds(Bool, thread_dense_handles, tid)
        thread_dense_handles[tid] = nothing
    end
    if checkbounds(Bool, thread_sparse_handles, tid)
        thread_sparse_handles[tid] = nothing
    end
    return
end

function __init__()
    nslots = _thread_slot_count()
    resize!(thread_dense_handles, nslots)
    fill!(thread_dense_handles, nothing)
    resize!(thread_sparse_handles, nslots)
    fill!(thread_sparse_handles, nothing)

    # A cached handle is only valid for the task and device it was fetched for.
    CUDA.atdeviceswitch(_reset_thread_handles)
    CUDA.attaskswitch(_reset_thread_handles)
end

end