Skip to content

Commit

Permalink
Replaced artifacts with JLL artifacts
Browse files Browse the repository at this point in the history
Using extended platform selection based on platform augmentation tags, i.e. https://pkgdocs.julialang.org/v1.7/artifacts/#Extending-Platform-Selection - adapted for use with JLL Artifacts.
  • Loading branch information
stemann committed Sep 10, 2022
1 parent f79d863 commit 0111241
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 88 deletions.
37 changes: 37 additions & 0 deletions .pkg/platform_augmentation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
using Libdl, Base.BinaryPlatforms

function augment_platform!(p::Platform, tag::Union{String,Nothing} = nothing)
if tag === nothing
return p
end

if tag === "cuda"
# If this platform object already has a `cuda` tag set, don't augment
if haskey(p, "cuda")
return p
end

# Open libcuda explicitly, so it gets `dlclose()`'ed after we're done
try
dlopen("libcuda") do lib
# find symbol to ask for driver version; if we can't find it, just silently continue
cuDriverGetVersion = dlsym(lib, "cuDriverGetVersion"; throw_error=false)
if cuDriverGetVersion !== nothing
# Interrogate CUDA driver for driver version:
driverVersion = Ref{Cint}()
ccall(cuDriverGetVersion, UInt32, (Ptr{Cint},), driverVersion)

# Store only the major version
p["cuda"] = div(driverVersion, 1000)
end
end
catch
end

# Return possibly-altered `Platform` object
return p
else
@warn "Unexpected tag: $tag"
return p
end
end
19 changes: 19 additions & 0 deletions .pkg/select_artifacts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using TOML, Artifacts, Base.BinaryPlatforms

import ONNXRuntime_jll

include("platform_augmentation.jl")

artifacts_toml = Artifacts.find_artifacts_toml(pathof(ONNXRuntime_jll))

# Get "target triplet" from ARGS, if given (defaulting to the host triplet otherwise)
target_triplet = get(ARGS, 1, Base.BinaryPlatforms.host_triplet())

# Augment this platform object with any special tags we require
platform = augment_platform!(HostPlatform(parse(Platform, target_triplet)))

# Select all downloadable artifacts that match that platform
artifacts = select_downloadable_artifacts(artifacts_toml; platform)

# Output the result to `stdout` as a TOML dictionary
TOML.print(stdout, artifacts)
48 changes: 0 additions & 48 deletions Artifacts.toml

This file was deleted.

3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ version = "0.3.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
ONNXRuntime_jll = "09e6dd1b-8208-5c7e-a336-6e9061773d0b"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[compat]
ArgCheck = "2"
Expand Down
28 changes: 25 additions & 3 deletions src/ONNXRunTime.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
module ONNXRunTime
using Requires:@require

using Artifacts
using LazyArtifacts
using ONNXRuntime_jll

using Requires: @require

function _perm(arr::AbstractArray{T,N}) where {T,N}
ntuple(i->N+1-i, N)
Expand All @@ -11,11 +16,28 @@ function reversedims_lazy(arr)
PermutedDimsArray(arr, _perm(arr))
end

include("capi.jl")
include("highlevel.jl")
const EXECUTION_PROVIDERS = [:cpu, :cuda]

const artifact_dir_map = Dict{Symbol, String}()

include("../.pkg/platform_augmentation.jl")

function __init__()
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" include("cuda.jl")

# Workaround/replacement for Artifacts.@artifact_str using the local Artifacts.toml
function artifact_dir(m::Module, artifact_name::String, p::Platform)
artifacts_toml = find_artifacts_toml(pathof(m))
h = artifact_hash(artifact_name, artifacts_toml; platform = p)
path = artifact_path(h)
return path
end

artifact_dir_map[:cpu] = artifact_dir(ONNXRuntime_jll, "ONNXRuntime", HostPlatform())
artifact_dir_map[:cuda] = artifact_dir(ONNXRuntime_jll, "ONNXRuntime", augment_platform!(HostPlatform(), "cuda"))
end

include("capi.jl")
include("highlevel.jl")

end #module
42 changes: 7 additions & 35 deletions src/capi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,19 @@ This module closely follows the offical onnxruntime [C-API](https://github.com/m
See [here](https://github.com/microsoft/onnxruntime-inference-examples/blob/d031f879c9a8d33c8b7dc52c5bc65fe8b9e3960d/c_cxx/fns_candy_style_transfer/fns_candy_style_transfer.c) for a C code example.
"""
module CAPI
using ONNXRunTime: reversedims_lazy

using ONNXRunTime: EXECUTION_PROVIDERS, artifact_dir_map, reversedims_lazy

using ONNXRuntime_jll

using DocStringExtensions
using Libdl
using CEnum: @cenum
using ArgCheck
using LazyArtifacts
using Pkg.Artifacts: artifact_path, ensure_artifact_installed, find_artifacts_toml

const LIB_CPU = Ref(C_NULL)
const LIB_CUDA = Ref(C_NULL)

const EXECUTION_PROVIDERS = [:cpu, :cuda]

# For model_path on windows ONNX uses wchar_t while on linux + mac char is used.
# Other strings use char on any platform it seems
# https://github.com/microsoft/onnxruntime/issues/9568#issuecomment-952951564
Expand All @@ -45,36 +44,9 @@ end

function make_lib!(execution_provider)
@argcheck execution_provider in EXECUTION_PROVIDERS
artifact_name = if execution_provider === :cpu
"onnxruntime_cpu"
elseif execution_provider === :cuda
"onnxruntime_gpu"
else
error("Unreachable")
end
artifacts_toml = find_artifacts_toml(joinpath(@__DIR__ , "ONNXRunTime.jl"))
h = artifact_hash(artifact_name, artifacts_toml)
if h === nothing
msg = """
Unsupported execution_provider = $(repr(execution_provider)) for
this architectur.
"""
error(msg)
end
ensure_artifact_installed(artifact_name, artifacts_toml)
root = artifact_path(h)
@check isdir(root)
dir = joinpath(root, only(readdir(root)))
@check isdir(dir)
libname = if Sys.iswindows()
"onnxruntime.dll"
elseif Sys.isapple()
"libonnxruntime.dylib"
else
"libonnxruntime.so"
end
path = joinpath(dir, "lib", libname)
@check isfile(path)
path = ONNXRuntime_jll.libonnxruntime_path
rel_path = joinpath(basename(dirname(path)), basename(path))
path = joinpath(artifact_dir_map[execution_provider], rel_path)
set_lib!(path, execution_provider)
end

Expand Down
3 changes: 1 addition & 2 deletions src/highlevel.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using ArgCheck
using LazyArtifacts
using DataStructures: OrderedDict
using DocStringExtensions
################################################################################
Expand All @@ -11,7 +10,7 @@ end


using .CAPI
using .CAPI: juliatype, EXECUTION_PROVIDERS
using .CAPI: juliatype
export InferenceSession, load_inference

"""
Expand Down

0 comments on commit 0111241

Please sign in to comment.