Skip to content

Commit

Permalink
Expose storage mode in allocator
Browse files Browse the repository at this point in the history
  • Loading branch information
jvkersch committed Nov 16, 2023
1 parent e8f16bf commit 08ea259
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
8 changes: 5 additions & 3 deletions src/MetalKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ import Adapt
export MetalBackend

struct MetalBackend <: KA.GPU
storage::MTL.MTLResourceOptions
end
MetalBackend(;storage=DefaultStorageMode) = MetalBackend(storage)

KA.allocate(::MetalBackend, ::Type{T}, dims::Tuple) where T = MtlArray{T}(undef, dims)
KA.zeros(::MetalBackend, ::Type{T}, dims::Tuple) where T = Metal.zeros(T, dims)
KA.ones(::MetalBackend, ::Type{T}, dims::Tuple) where T = Metal.ones(T, dims)
KA.allocate(backend::MetalBackend, ::Type{T}, dims::Dims{N}) where {T,N} = MtlArray{T,N,backend.storage}(undef, dims)
KA.zeros(backend::MetalBackend, ::Type{T}, dims::Tuple) where T = Metal.zeros(T, dims; storage=backend.storage)
KA.ones(backend::MetalBackend, ::Type{T}, dims::Tuple) where T = Metal.ones(T, dims; storage=backend.storage)

KA.get_backend(::MtlArray) = MetalBackend()
KA.synchronize(::MetalBackend) = synchronize()
Expand Down
2 changes: 1 addition & 1 deletion src/array.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# host array

export MtlArray, MtlVector, MtlMatrix, MtlVecOrMat, mtl
export MtlArray, MtlVector, MtlMatrix, MtlVecOrMat, mtl, DefaultStorageMode

function hasfieldcount(@nospecialize(dt))
try
Expand Down

0 comments on commit 08ea259

Please sign in to comment.