forked from JuliaGPU/CUDA.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CUTENSOR.jl
61 lines (45 loc) · 1.39 KB
/
CUTENSOR.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
module CUTENSOR

using ..APIUtils
using ..CUDA
using ..CUDA: CUstream, cudaDataType, libcutensor, @retry_reclaim, initialize_context
using CEnum: @cenum

# alias expected under the `_t` spelling by the generated bindings
const cudaDataType_t = cudaDataType

# core library (include order matters: later files use names defined by earlier ones)
include("libcutensor_common.jl")
include("error.jl")
include("libcutensor.jl")

# low-level wrappers
include("tensor.jl")
include("wrappers.jl")

# high-level integrations
include("interfaces.jl")

# Per-context pool of handles that were created but are currently not owned by any task.
const idle_handles = HandleCache{CuContext,Base.RefValue{cutensorHandle_t}}()

"""
    handle()

Return the cuTENSOR handle (a `Base.RefValue{cutensorHandle_t}`) belonging to the
current task and the currently active CUDA context.

Handles are memoized in task-local storage under the key `:CUTENSOR`, in a dictionary
keyed by context. On a miss, a handle is taken from `idle_handles` for that context, or
a new one is initialized with `cutensorInit`. A finalizer is then registered on the
current task that pushes the handle back into `idle_handles` once the task is finalized.
"""
function handle()
    active = CUDA.active_state()

    # per-task, per-context library state
    LibraryState = @NamedTuple{handle::Base.RefValue{cutensorHandle_t}}
    per_context = get!(() -> Dict{CuContext,LibraryState}(),
                       task_local_storage(), :CUTENSOR)::Dict{CuContext,LibraryState}

    # Slow path, kept out of line so the cached lookup stays small.
    @noinline function create_state(st)
        ref = pop!(idle_handles, st.context) do
            fresh = Ref{cutensorHandle_t}()
            cutensorInit(fresh)
            fresh
        end

        # Return the handle to the idle pool when this task is finalized. The
        # do-block given to `push!` is deliberately empty: nothing is torn down.
        finalizer(current_task()) do _
            push!(idle_handles, st.context, ref) do
                # CUTENSOR doesn't need to actively destroy its handle
            end
        end

        return (; handle=ref)
    end

    return get!(() -> create_state(active), per_context, active.context).handle
end

end