Skip to content

Commit

Permalink
Support Windows platform (#106)
Browse files Browse the repository at this point in the history
* Fix compilation for Windows platform

* Use std::ios::binary flag when opening *.params file

* Remove tabs in code
  • Loading branch information
hcho3 committed Nov 7, 2019
1 parent 423dd1e commit 35ed4fa
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 14 deletions.
12 changes: 7 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@ cmake_minimum_required (VERSION 3.6)
include(cmake/Utils.cmake)
include(3rdparty/tvm/cmake/util/FindCUDA.cmake)

set_default_configuration_release()
msvc_use_static_runtime()

message(STATUS "CMAKE_BUILD_TYPE: " ${CMAKE_BUILD_TYPE})

# Option for Android on Arm --- has to come before project() function
option(ANDROID_BUILD "Build for Android target" OFF)
option(AAR_BUILD "Build Android Archive (AAR)" OFF)
Expand All @@ -20,6 +15,12 @@ endif(ANDROID_BUILD)

project(dlr)

# The following lines should be after project()
set_default_configuration_release()
msvc_use_static_runtime()
message(STATUS "CMAKE_BUILD_TYPE: " ${CMAKE_BUILD_TYPE})
set(CMAKE_LOCAL "${PROJECT_SOURCE_DIR}/cmake")

# CMAKE_CXX_STANDARD_INCLUDE_DIRECTORIES stuff should go after project() function
if(ANDROID_BUILD)
# Disable debugging info for Release build by setting -g level to 0. It will reduce libdlr.so size by a factor of 3.
Expand Down Expand Up @@ -310,6 +311,7 @@ if(NOT(ANDROID_BUILD OR AAR_BUILD))
string(REPLACE ".cc" "" __execname ${__srcname})
add_executable(${__execname} ${__srcpath})
target_link_libraries(${__execname} dlr gtest_main)
set_output_directory(${__execname} ${CMAKE_BINARY_DIR})
add_test(NAME ${__execname} COMMAND ${__execname})
message(STATUS "Added Test: " ${__execname})
endforeach()
Expand Down
2 changes: 1 addition & 1 deletion cmake/googletest-download.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ ExternalProject_Add(
GIT_REPOSITORY
https://github.com/google/googletest.git
GIT_TAG
release-1.8.0
release-1.8.1
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
Expand Down
29 changes: 25 additions & 4 deletions include/dlr.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@

/* special symbols for DLL library on Windows */
#ifdef __cplusplus
#if defined(_MSC_VER) || defined(_WIN32)
extern "C" __declspec(dllexport) { // Open extern "C" block on Windows
#else
extern "C" { // Open extern "C" block
#endif
#endif // __cplusplus

#if defined(_MSC_VER) || defined(_WIN32)
#define DLR_DLL __declspec(dllexport)
#else
#define DLR_DLL
#endif // defined(_MSC_VER) || defined(_WIN32)

/*! \brief major version */
#define DLR_MAJOR 1
/*! \brief minor version */
Expand Down Expand Up @@ -43,6 +45,7 @@ typedef void* DLRModelHandle;
\param dev_id Device ID.
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int CreateDLRModel(DLRModelHandle *handle,
const char *model_path,
int dev_type,
Expand All @@ -57,6 +60,7 @@ int CreateDLRModel(DLRModelHandle *handle,
\param use_nnapi Use NNAPI, 0 - false, 1 - true.
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int CreateDLRModelFromTFLite(DLRModelHandle *handle,
const char *model_path,
int threads,
Expand All @@ -68,13 +72,15 @@ int CreateDLRModelFromTFLite(DLRModelHandle *handle,
\param handle The model handle returned from CreateDLRModel().
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int DeleteDLRModel(DLRModelHandle* handle);

/*!
\brief Runs a DLR model.
\param handle The model handle returned from CreateDLRModel().
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int RunDLRModel(DLRModelHandle *handle);

/*!
Expand All @@ -83,6 +89,7 @@ int RunDLRModel(DLRModelHandle *handle);
\param num_inputs The pointer to save the number of inputs.
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int GetDLRNumInputs(DLRModelHandle* handle, int* num_inputs);

/*!
Expand All @@ -91,6 +98,7 @@ int GetDLRNumInputs(DLRModelHandle* handle, int* num_inputs);
\param num_weights The pointer to save the number of weights.
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int GetDLRNumWeights(DLRModelHandle* handle, int* num_weights);

/*!
Expand All @@ -100,6 +108,7 @@ int GetDLRNumWeights(DLRModelHandle* handle, int* num_weights);
\param input_name The pointer to save the name of the index-th input.
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int GetDLRInputName(DLRModelHandle* handle,
int index,
const char** input_name);
Expand All @@ -111,6 +120,7 @@ int GetDLRInputName(DLRModelHandle* handle,
\param input_name The pointer to save the name of the index-th weight.
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int GetDLRWeightName(DLRModelHandle* handle,
int index,
const char** weight_name);
Expand All @@ -124,6 +134,7 @@ int GetDLRWeightName(DLRModelHandle* handle,
\param dim The dimension of the input data.
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int SetDLRInput(DLRModelHandle* handle,
const char* name,
const int64_t* shape,
Expand All @@ -136,6 +147,7 @@ int SetDLRInput(DLRModelHandle* handle,
\param input The current value of the input will be copied to this buffer.
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int GetDLRInput(DLRModelHandle* handle,
const char* name,
float* input);
Expand All @@ -146,6 +158,7 @@ int GetDLRInput(DLRModelHandle* handle,
\param shape The pointer to save the shape of index-th output. This should be a pointer to an array of size "dim" from GetDLROutputSizeDim().
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int GetDLROutputShape(DLRModelHandle* handle,
int index,
int64_t* shape);
Expand All @@ -157,6 +170,7 @@ int GetDLROutputShape(DLRModelHandle* handle,
\param out The pointer to save the output data. This should be a pointer to an array of size "size" from GetDLROutputSizeDim().
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int GetDLROutput(DLRModelHandle* handle,
int index,
float* out);
Expand All @@ -166,6 +180,7 @@ int GetDLROutput(DLRModelHandle* handle,
\param num_outputs The pointer to save the number of outputs.
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int GetDLRNumOutputs(DLRModelHandle* handle, int* num_outputs);

/*!
Expand All @@ -176,6 +191,7 @@ int GetDLRNumOutputs(DLRModelHandle* handle, int* num_outputs);
\param dim The pointer to save the dimension of the index-th output.
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int GetDLROutputSizeDim(DLRModelHandle* handle,
int index,
int64_t* size,
Expand All @@ -184,6 +200,7 @@ int GetDLROutputSizeDim(DLRModelHandle* handle,
\brief Gets the last error message.
\return Null-terminated string containing the error message.
*/
DLR_DLL
const char* DLRGetLastError();

/*!
Expand All @@ -192,13 +209,15 @@ const char* DLRGetLastError();
\param name The pointer to save the null-terminated string containing the name.
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int GetDLRBackend(DLRModelHandle* handle, const char** name);

/*!
\brief Get DLR version
\param out The pointer to save the null-terminated string containing the version.
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int GetDLRVersion(const char** out);

/*!
Expand All @@ -207,6 +226,7 @@ int GetDLRVersion(const char** out);
\param threads number of threads
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int SetDLRNumThreads(DLRModelHandle* handle, int threads);

/*!
Expand All @@ -215,6 +235,7 @@ int SetDLRNumThreads(DLRModelHandle* handle, int threads);
\param use 0 to disable, 1 to enable
\return 0 for success, -1 for error. Call DLRGetLastError() to get the error message.
*/
DLR_DLL
int UseDLRCPUAffinity(DLRModelHandle* handle, int use);

/*! \} */
Expand Down
17 changes: 13 additions & 4 deletions src/dlr_tvm.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "dlr_tvm.h"

#include <iterator>
#include <stdlib.h>

#include <fstream>
Expand Down Expand Up @@ -51,7 +52,7 @@ void TVMModel::SetupTVMModule(std::vector<std::string> model_path) {
std::ifstream jstream(paths.model_json);
std::stringstream json_blob;
json_blob << jstream.rdbuf();
std::ifstream pstream(paths.params);
std::ifstream pstream(paths.params, std::ios::in | std::ios::binary);
std::stringstream param_blob;
param_blob << pstream.rdbuf();

Expand Down Expand Up @@ -175,19 +176,27 @@ const char* TVMModel::GetBackend() const {
return "tvm";
}

static inline int SetEnv(const char* key, const char* value) {
#ifdef _WIN32
return static_cast<int>(_putenv_s(key, value));
#else
return setenv(key, value, 1);
#endif // _WIN32
}

void TVMModel::SetNumThreads(int threads) {
if (threads > 0) {
setenv("TVM_NUM_THREADS", std::to_string(threads).c_str(), 1);
SetEnv("TVM_NUM_THREADS", std::to_string(threads).c_str());
LOG(INFO) << "Set Num Threads: " << threads;
}
}

void TVMModel::UseCPUAffinity(bool use) {
if (use) {
setenv("TVM_BIND_THREADS", "1", 1);
SetEnv("TVM_BIND_THREADS", "1");
LOG(INFO) << "CPU Affinity is enabled";
} else {
setenv("TVM_BIND_THREADS", "0", 1);
SetEnv("TVM_BIND_THREADS", "0");
LOG(INFO) << "CPU Affinity is disabled";
}
}
2 changes: 2 additions & 0 deletions tests/cpp/dlr_tflite/dlr_tflite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ TEST(TFLite, CreateDLRModel) {

int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
#ifndef _WIN32
testing::FLAGS_gtest_death_test_style = "threadsafe";
#endif // _WIN32
return RUN_ALL_TESTS();
}
2 changes: 2 additions & 0 deletions tests/cpp/dlr_treelite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ TEST(Treelite, Test1) {

int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
#ifndef _WIN32
testing::FLAGS_gtest_death_test_style = "threadsafe";
#endif // _WIN32
return RUN_ALL_TESTS();
}
2 changes: 2 additions & 0 deletions tests/cpp/dlr_tvm_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ TEST(TVM, Test1) {

int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
#ifndef _WIN32
testing::FLAGS_gtest_death_test_style = "threadsafe";
#endif // _WIN32
return RUN_ALL_TESTS();
}

0 comments on commit 35ed4fa

Please sign in to comment.