Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 30 additions & 48 deletions c_src/dll_loader/adbc_dll_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,38 @@
#include <winbase.h>
#include <wchar.h>

#define NIF(NAME) ERL_NIF_TERM NAME(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])

// Helper for returning `{:error, msg}` from NIF.
ERL_NIF_TERM error(ErlNifEnv *env, const char *msg)
{
ERL_NIF_TERM atom = enif_make_atom(env, "error");
ERL_NIF_TERM msg_term = enif_make_string(env, msg, ERL_NIF_LATIN1);
return enif_make_tuple2(env, atom, msg_term);
ERL_NIF_TERM adbc_dll_unused(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) {
(void)(env);
(void)(argc);
(void)(argv);
return enif_make_int(env, 0);
}

// Helper for returning `{:ok, term}` from NIF.
ERL_NIF_TERM ok(ErlNifEnv *env)
{
return enif_make_atom(env, "ok");
}
int upgrade(ErlNifEnv *env, void **priv_data, void **old_priv_data, ERL_NIF_TERM load_info) {
// Silence "unused var" warnings.
(void)(env);
(void)(priv_data);
(void)(old_priv_data);
(void)(load_info);

NIF(add_dll_directory) {
static bool path_updated = false;
if (path_updated) return ok(env);
return 0;
}

int load(ErlNifEnv *,void **,ERL_NIF_TERM) {
wchar_t dll_path_c[65536];
char err_msg[128] = { '\0' };
HMODULE hm = NULL;
if (GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, (LPCWSTR)&add_dll_directory, &hm) == 0) {
int ret = GetLastError();
snprintf(err_msg, sizeof(err_msg) - 1, "GetModuleHandle failed, error = %d\r\n", ret);
return error(env, err_msg);

if (GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, (LPCWSTR)&load, &hm) == 0) {
int ret = GetLastError();
printf("GetModuleHandle failed, error = %d\n", ret);
return 1;
}

if (GetModuleFileNameW(hm, (LPWSTR)dll_path_c, sizeof(dll_path_c)) == 0) {
int ret = GetLastError();
snprintf(err_msg, sizeof(err_msg) - 1, "GetModuleFileName failed, error = %d\r\n", ret);
return error(env, err_msg);
int ret = GetLastError();
printf("GetModuleFileName failed, error = %d\n", ret);
return 1;
}

std::wstring dll_path = dll_path_c;
Expand All @@ -64,44 +63,27 @@ NIF(add_dll_directory) {

SetDefaultDllDirectories(LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_USER_DIRS);
DLL_DIRECTORY_COOKIE ret = AddDllDirectory(directory_pcwstr);

if (ret == 0) {
DWORD last_error = GetLastError();
LPTSTR error_text = nullptr;
FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, HRESULT_FROM_WIN32(last_error), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPTSTR)&error_text, 0, NULL);

if (error_text != nullptr) {
ERL_NIF_TERM ret_term = error(env, error_text);
LocalFree(error_text);
return ret_term;
printf("Error: %s\n", error_text);
LocalFree(error_text);
} else {
ERL_NIF_TERM ret_term = error(env, "error happened when adding adbc driver runtime path, but cannot get formatted error message");
return ret_term;
printf("Error: error happened when adding adbc driver runtime path, but cannot get formatted error message\n");
}
}
path_updated = true;
return ok(env);
}

int upgrade(ErlNifEnv *env, void **priv_data, void **old_priv_data, ERL_NIF_TERM load_info) {
// Silence "unused var" warnings.
(void)(env);
(void)(priv_data);
(void)(old_priv_data);
(void)(load_info);

return 0;
}
return 1;
}

int load(ErlNifEnv *,void **,ERL_NIF_TERM) {
return 0;
}

#define F(NAME, ARITY) \
{ \
#NAME, ARITY, NAME, 0 \
}

static ErlNifFunc nif_functions[] = {
F(add_dll_directory, 0)
{"__unused__", 0, adbc_dll_unused, 0}
};

ERL_NIF_INIT(Elixir.Adbc.DLLLoaderNif, nif_functions, load, NULL, upgrade, NULL);
8 changes: 2 additions & 6 deletions lib/adbc_dll_loader_nif.ex
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,13 @@ defmodule Adbc.DLLLoaderNif do
File.mkdir_p!(Path.join(priv_dir, "bin"))
path = :filename.join(priv_dir, ~c"adbc_dll_loader")
:erlang.load_nif(path, 0)
add_dll_directory()

_ ->
:ok
end
end

def add_dll_directory do
case :os.type() do
{:win32, _} -> :erlang.nif_error(:not_loaded)
_ -> :ok
end
def __unused__ do
0
end
end