diff --git a/CMakeLists.txt b/CMakeLists.txt index bf18ffc9e856..ed905547cbae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.2) +cmake_minimum_required(VERSION 3.9) project(tvm C CXX) # Utility functions @@ -63,6 +63,7 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF) tvm_option(USE_RANDOM "Build with random support" OFF) tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF) tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) +tvm_option(USE_CXX_RPC "Build CXX RPC" OFF) # include directories include_directories(${CMAKE_INCLUDE_PATH}) @@ -275,6 +276,9 @@ add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS}) add_library(tvm_topi SHARED ${TOPI_SRCS}) add_library(tvm_runtime SHARED ${RUNTIME_SRCS}) +if(USE_CXX_RPC STREQUAL "ON") + add_subdirectory("apps/cpp_rpc") +endif() if(USE_RELAY_DEBUG) message(STATUS "Building Relay in debug mode...") @@ -405,6 +409,10 @@ endif(INSTALL_DEV) # More target definitions if(MSVC) + set_property(TARGET tvm PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) + set_property(TARGET tvm_topi PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) + set_property(TARGET tvm_runtime PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) + set_property(TARGET nnvm_compiler PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS) target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS) target_compile_definitions(nnvm_compiler PRIVATE -DNNVM_EXPORTS) diff --git a/apps/cpp_rpc/CMakeLists.txt b/apps/cpp_rpc/CMakeLists.txt new file mode 100644 index 000000000000..61c40c1affe6 --- /dev/null +++ b/apps/cpp_rpc/CMakeLists.txt @@ -0,0 +1,23 @@ +set(TVM_RPC_SOURCES + main.cc + rpc_env.cc + rpc_server.cc +) + +if(WIN32) + list(APPEND TVM_RPC_SOURCES win32_process.cc) +endif() + +# Set output to same directory as the other TVM libs +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) +add_executable(tvm_rpc ${TVM_RPC_SOURCES}) +set_property(TARGET tvm_rpc PROPERTY INTERPROCEDURAL_OPTIMIZATION_RELEASE TRUE) +target_compile_definitions(tvm_rpc PUBLIC -DNOMINMAX) +target_include_directories( + tvm_rpc + PUBLIC "../../include" + PUBLIC DLPACK_PATH + PUBLIC DMLC_PATH +) + +target_link_libraries(tvm_rpc tvm_runtime) \ No newline at end of file diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc index 3cf2ed6a5d59..f37cf56d39d0 100644 --- a/apps/cpp_rpc/main.cc +++ b/apps/cpp_rpc/main.cc @@ -21,10 +21,12 @@ * \file rpc_server.cc * \brief RPC Server for TVM. */ -#include -#include -#include +#include +#include +#include +#if defined(__linux__) || defined(__ANDROID__) #include +#endif #include #include #include @@ -35,11 +37,15 @@ #include "../../src/common/socket.h" #include "rpc_server.h" +#if defined(_WIN32) +#include "win32_process.h" +#endif + using namespace std; using namespace tvm::runtime; using namespace tvm::common; -static const string kUSAGE = \ +static const string kUsage = \ "Command line usage\n" \ " server - Start the server\n" \ "--host - The hostname of the server, Default=0.0.0.0\n" \ @@ -73,13 +79,16 @@ struct RpcServerArgs { string key; string custom_addr; bool silent = false; +#if defined(WIN32) + std::string mmap_path; +#endif }; /*! * \brief PrintArgs print the contents of RpcServerArgs * \param args RpcServerArgs structure */ -void PrintArgs(struct RpcServerArgs args) { +void PrintArgs(const RpcServerArgs& args) { LOG(INFO) << "host = " << args.host; LOG(INFO) << "port = " << args.port; LOG(INFO) << "port_end = " << args.port_end; @@ -89,6 +98,7 @@ void PrintArgs(struct RpcServerArgs args) { LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False")); } +#if defined(__linux__) || defined(__ANDROID__) /*! * \brief CtrlCHandler, exits if Ctrl+C is pressed * \param s signal @@ -109,7 +119,7 @@ void HandleCtrlC() { sigIntHandler.sa_flags = 0; sigaction(SIGINT, &sigIntHandler, nullptr); } - +#endif /*! * \brief GetCmdOption Parse and find the command option. * \param argc arg counter @@ -129,7 +139,7 @@ string GetCmdOption(int argc, char* argv[], string option, bool key = false) { } // We assume "=" is the end of option. CHECK_EQ(*option.rbegin(), '='); - cmd = arg.substr(arg.find("=") + 1); + cmd = arg.substr(arg.find('=') + 1); return cmd; } } @@ -156,41 +166,41 @@ bool ValidateTracker(string &tracker) { * \brief ParseCmdArgs parses the command line arguments. * \param argc arg counter * \param argv arg values - * \param args, the output structure which holds the parsed values + * \param args the output structure which holds the parsed values */ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { - string silent = GetCmdOption(argc, argv, "--silent", true); + const string silent = GetCmdOption(argc, argv, "--silent", true); if (!silent.empty()) { args.silent = true; // Only errors and fatal is logged dmlc::InitLogging("--minloglevel=2"); } - string host = GetCmdOption(argc, argv, "--host="); + const string host = GetCmdOption(argc, argv, "--host="); if (!host.empty()) { if (!ValidateIP(host)) { LOG(WARNING) << "Wrong host address format."; - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; exit(1); } args.host = host; } - string port = GetCmdOption(argc, argv, "--port="); + const string port = GetCmdOption(argc, argv, "--port="); if (!port.empty()) { if (!IsNumber(port) || stoi(port) > 65535) { LOG(WARNING) << "Wrong port number."; - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; exit(1); } args.port = stoi(port); } - string port_end = GetCmdOption(argc, argv, "--port_end="); + const string port_end = GetCmdOption(argc, argv, "--port_end="); if (!port_end.empty()) { if (!IsNumber(port_end) || stoi(port_end) > 65535) { LOG(WARNING) << "Wrong port_end number."; - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; exit(1); } args.port_end = stoi(port_end); @@ -200,26 +210,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { if (!tracker.empty()) { if (!ValidateTracker(tracker)) { LOG(WARNING) << "Wrong tracker address format."; - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; exit(1); } args.tracker = tracker; } - string key = GetCmdOption(argc, argv, "--key="); + const string key = GetCmdOption(argc, argv, "--key="); if (!key.empty()) { args.key = key; } - string custom_addr = GetCmdOption(argc, argv, "--custom_addr="); + const string custom_addr = GetCmdOption(argc, argv, "--custom_addr="); if (!custom_addr.empty()) { if (!ValidateIP(custom_addr)) { LOG(WARNING) << "Wrong custom address format."; - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; exit(1); } args.custom_addr = custom_addr; } +#if defined(WIN32) + const string mmap_path = GetCmdOption(argc, argv, "--child_proc="); + if(!mmap_path.empty()) { + args.mmap_path = mmap_path; + dmlc::InitLogging("--minloglevel=0"); + } +#endif + } /*! @@ -229,17 +247,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { * \return result of operation. */ int RpcServer(int argc, char * argv[]) { - struct RpcServerArgs args; + RpcServerArgs args; /* parse the command line args */ ParseCmdArgs(argc, argv, args); PrintArgs(args); - // Ctrl+C handler LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop."; +#if defined(__linux__) || defined(__ANDROID__) + // Ctrl+C handler HandleCtrlC(); - tvm::runtime::RPCServerCreate(args.host, args.port, args.port_end, args.tracker, - args.key, args.custom_addr, args.silent); +#endif + +#if defined(WIN32) + if(!args.mmap_path.empty()) { + int ret = 0; + + try { + ChildProcSocketHandler(args.mmap_path); + } catch (const std::exception&) { + ret = -1; + } + + return ret; + } +#endif + + RPCServerCreate(args.host, args.port, args.port_end, args.tracker, + args.key, args.custom_addr, args.silent); return 0; } @@ -251,15 +286,18 @@ int RpcServer(int argc, char * argv[]) { */ int main(int argc, char * argv[]) { if (argc <= 1) { - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; return 0; } + // Runs WSAStartup on Win32, no-op on POSIX + Socket::Startup(); + if (0 == strcmp(argv[1], "server")) { - RpcServer(argc, argv); - } else { - LOG(INFO) << kUSAGE; + return RpcServer(argc, argv); } + LOG(INFO) << kUsage; + return 0; } diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index 44f848dc749e..c4f77d3ffdfe 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -20,77 +20,74 @@ * \file rpc_env.cc * \brief Server environment of the RPC. */ +#include #include -#include -#ifndef _MSC_VER -#include +#ifndef _WIN32 #include +#include #include #else #include +#include +namespace { + int mkdir(const char* path, int /* ignored */) { return _mkdir(path); } +} #endif +#include #include -#include #include #include -#include +#include +#include -#include "rpc_env.h" #include "../../src/common/util.h" #include "../../src/runtime/file_util.h" +#include "rpc_env.h" + +namespace { +#if defined(__linux__) || defined(__ANDROID__) + const std::string untar_cmd = "tar -C "; +#elif defined(_WIN32) + const std::string untar_cmd = "wsl tar -C "; +#endif +}// Anonymous namespace namespace tvm { namespace runtime { - RPCEnv::RPCEnv() { - #if defined(__linux__) || defined(__ANDROID__) - base_ = "./rpc"; - mkdir(&base_[0], 0777); + base_ = "./rpc"; + mkdir(base_.c_str(), 0777); + TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([](TVMArgs args, TVMRetValue* rv) { + static RPCEnv env; + *rv = env.GetPath(args[0]); + }); - TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") - .set_body([](TVMArgs args, TVMRetValue* rv) { - static RPCEnv env; - *rv = env.GetPath(args[0]); - }); - - TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") - .set_body([](TVMArgs args, TVMRetValue *rv) { - static RPCEnv env; - std::string file_name = env.GetPath(args[0]); - *rv = Load(&file_name, ""); - LOG(INFO) << "Load module from " << file_name << " ..."; - }); - #else - LOG(FATAL) << "Only support RPC in linux environment"; - #endif + TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module").set_body([](TVMArgs args, TVMRetValue* rv) { + static RPCEnv env; + std::string file_name = env.GetPath(args[0]); + *rv = Load(&file_name, ""); + LOG(INFO) << "Load module from " << file_name << " ..."; + }); } /*! - * \brief GetPath To get the workpath from packed function - * \param name The file name + * \brief GetPath To get the work path from packed function + * \param file_name The file name * \return The full path of file. */ -std::string RPCEnv::GetPath(std::string file_name) { +std::string RPCEnv::GetPath(const std::string& file_name) const { // we assume file_name has "/" means file_name is the exact path // and does not create /.rpc/ - if (file_name.find("/") != std::string::npos) { - return file_name; - } else { - return base_ + "/" + file_name; - } + return file_name.find('/') != std::string::npos ? file_name : base_ + "/" + file_name; } /*! * \brief Remove The RPC Environment cleanup function */ -void RPCEnv::CleanUp() { - #if defined(__linux__) || defined(__ANDROID__) - CleanDir(&base_[0]); - int ret = rmdir(&base_[0]); - if (ret != 0) { - LOG(WARNING) << "Remove directory " << base_ << " failed"; - } - #else - LOG(FATAL) << "Only support RPC in linux environment"; - #endif +void RPCEnv::CleanUp() const { + CleanDir(base_); + const int ret = rmdir(base_.c_str()); + if (ret != 0) { + LOG(WARNING) << "Remove directory " << base_ << " failed"; + } } /*! @@ -98,53 +95,54 @@ void RPCEnv::CleanUp() { * \param dirname The root directory name * \return vector Files in directory. */ -std::vector ListDir(const std::string &dirname) { +std::vector ListDir(const std::string& dirname) { std::vector vec; - #ifndef _MSC_VER - DIR *dp = opendir(dirname.c_str()); - if (dp == nullptr) { - int errsv = errno; - LOG(FATAL) << "ListDir " << dirname <<" error: " << strerror(errsv); - } - dirent *d; - while ((d = readdir(dp)) != nullptr) { - std::string filename = d->d_name; - if (filename != "." && filename != "..") { - std::string f = dirname; - if (f[f.length() - 1] != '/') { - f += '/'; - } - f += d->d_name; - vec.push_back(f); +#ifndef _WIN32 + DIR* dp = opendir(dirname.c_str()); + if (dp == nullptr) { + int errsv = errno; + LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv); + } + dirent* d; + while ((d = readdir(dp)) != nullptr) { + std::string filename = d->d_name; + if (filename != "." && filename != "..") { + std::string f = dirname; + if (f[f.length() - 1] != '/') { + f += '/'; } + f += d->d_name; + vec.push_back(f); } - closedir(dp); - #else - WIN32_FIND_DATA fd; - std::string pattern = dirname + "/*"; - HANDLE handle = FindFirstFile(pattern.c_str(), &fd); - if (handle == INVALID_HANDLE_VALUE) { - int errsv = GetLastError(); - LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv); - } - do { - if (fd.cFileName != "." && fd.cFileName != "..") { - std::string f = dirname; - char clast = f[f.length() - 1]; - if (f == ".") { - f = fd.cFileName; - } else if (clast != '/' && clast != '\\') { - f += '/'; - f += fd.cFileName; - } - vec.push_back(f); + } + closedir(dp); +#elif defined(_WIN32) + WIN32_FIND_DATAA fd; + const std::string pattern = dirname + "/*"; + HANDLE handle = FindFirstFileA(pattern.c_str(), &fd); + if (handle == INVALID_HANDLE_VALUE) { + const int errsv = GetLastError(); + LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv); + } + do { + std::string filename = fd.cFileName; + if (filename != "." && filename != "..") { + std::string f = dirname; + if (f[f.length() - 1] != '/') { + f += '/'; } - } while (FindNextFile(handle, &fd)); - FindClose(handle); - #endif + f += filename; + vec.push_back(f); + } + } while (FindNextFileA(handle, &fd)); + FindClose(handle); +#else + LOG(FATAL) << "Operating system not supported"; +#endif return vec; } +#if defined(__linux__) || defined(__ANDROID__) /*! * \brief LinuxShared Creates a linux shared library * \param output The output file name @@ -152,35 +150,66 @@ std::vector ListDir(const std::string &dirname) { * \param options The compiler options * \param cc The compiler */ -void LinuxShared(const std::string output, +void LinuxShared(const std::string output, const std::vector &files, - std::string options = "", + std::string options = "", std::string cc = "g++") { - std::string cmd = cc; - cmd += " -shared -fPIC "; - cmd += " -o " + output; - for (auto f = files.begin(); f != files.end(); ++f) { - cmd += " " + *f; - } - cmd += " " + options; - std::string err_msg; - auto executed_status = common::Execute(cmd, &err_msg); - if (executed_status) { - LOG(FATAL) << err_msg; - } + std::string cmd = cc; + cmd += " -shared -fPIC "; + cmd += " -o " + output; + for (const auto& file : files) { + cmd += " " + file; + } + cmd += " " + options; + std::string err_msg; + const auto executed_status = common::Execute(cmd, &err_msg); + if (executed_status) { + LOG(FATAL) << err_msg; + } } +#endif + +#ifdef _WIN32 +/*! + * \brief WindowsShared Creates a Windows shared library + * \param output The output file name + * \param files The files for building + * \param options The compiler options + * \param cc The compiler + */ +void WindowsShared(const std::string& output, + const std::vector& files, + const std::string& options = "", + const std::string& cc = "clang") { + std::string cmd = cc; + cmd += " -fuse-ld=lld-link -Wl,/EXPORT:__tvm_main__ -shared "; + cmd += " -o " + output; + for (const auto& file : files) { + cmd += " " + file; + } + cmd += " " + options; + std::string err_msg; + const auto executed_status = common::Execute(cmd, &err_msg); + if (executed_status) { + printf("compile error: %s\n", err_msg.c_str()); + LOG(FATAL) << err_msg; + } +} +#endif /*! * \brief CreateShared Creates a shared library * \param output The output file name * \param files The files for building */ -void CreateShared(const std::string output, const std::vector &files) { - #if defined(__linux__) || defined(__ANDROID__) - LinuxShared(output, files); - #else - LOG(FATAL) << "Do not support creating shared library"; - #endif +void CreateShared(const std::string& output, const std::vector& files) { +#if defined(__linux__) || defined(__ANDROID__) + LinuxShared(output, files); +#elif defined(_WIN32) + WindowsShared(output, files); +#else + LOG(FATAL) << "Operating system not supported"; +#endif } /*! @@ -193,61 +222,52 @@ void CreateShared(const std::string output, const std::vector &file * \param fmt The format of file * \return Module The loaded module */ -Module Load(std::string *fileIn, const std::string fmt) { - std::string file = *fileIn; +Module Load(std::string *fileIn, const std::string& fmt) { + const std::string& file = *fileIn; if (common::EndsWith(file, ".so")) { - return Module::LoadFromFile(file, fmt); + return Module::LoadFromFile(file, fmt); } - #if defined(__linux__) || defined(__ANDROID__) - std::string file_name = file + ".so"; - if (common::EndsWith(file, ".o")) { - std::vector files; - files.push_back(file); - CreateShared(file_name, files); - } else if (common::EndsWith(file, ".tar")) { - std::string tmp_dir = "./rpc/tmp/"; - mkdir(&tmp_dir[0], 0777); - std::string cmd = "tar -C " + tmp_dir + " -zxf " + file; - std::string err_msg; - int executed_status = common::Execute(cmd, &err_msg); - if (executed_status) { - LOG(FATAL) << err_msg; - } - CreateShared(file_name, ListDir(tmp_dir)); - CleanDir(tmp_dir); - rmdir(&tmp_dir[0]); - } else { - file_name = file; + std::string file_name = file + ".so"; + if (common::EndsWith(file, ".o")) { + std::vector files; + files.push_back(file); + CreateShared(file_name, files); + } else if (common::EndsWith(file, ".tar")) { + const std::string tmp_dir = "./rpc/tmp/"; + mkdir(tmp_dir.c_str(), 0777); + + const std::string cmd = untar_cmd + tmp_dir + " -zxf " + file; + + std::string err_msg; + const int executed_status = common::Execute(cmd, &err_msg); + if (executed_status) { + LOG(FATAL) << err_msg; } - *fileIn = file_name; - return Module::LoadFromFile(file_name, fmt); - #else - LOG(FATAL) << "Do not support creating shared library"; - #endif + CreateShared(file_name, ListDir(tmp_dir)); + CleanDir(tmp_dir); + (void)rmdir(tmp_dir.c_str()); + } else { + file_name = file; + } + *fileIn = file_name; + return Module::LoadFromFile(file_name, fmt); } /*! * \brief CleanDir Removes the files from the directory * \param dirname The name of the directory */ -void CleanDir(const std::string &dirname) { - #if defined(__linux__) || defined(__ANDROID__) - DIR *dp = opendir(dirname.c_str()); - dirent *d; - while ((d = readdir(dp)) != nullptr) { - std::string filename = d->d_name; - if (filename != "." && filename != "..") { - filename = dirname + "/" + d->d_name; - int ret = std::remove(&filename[0]); - if (ret != 0) { - LOG(WARNING) << "Remove file " << filename << " failed"; - } - } +void CleanDir(const std::string& dirname) { + auto files = ListDir(dirname); + for (const auto& filename : files) { + std::string file_path = dirname + "/"; + file_path += filename; + const int ret = std::remove(filename.c_str()); + if (ret != 0) { + LOG(WARNING) << "Remove file " << filename << " failed"; } - #else - LOG(FATAL) << "Only support RPC in linux environment"; - #endif + } } } // namespace runtime diff --git a/apps/cpp_rpc/rpc_env.h b/apps/cpp_rpc/rpc_env.h index 82409bae81a1..d046f6ecb480 100644 --- a/apps/cpp_rpc/rpc_env.h +++ b/apps/cpp_rpc/rpc_env.h @@ -40,7 +40,7 @@ namespace runtime { * \param file The format of file * \return Module The loaded module */ -Module Load(std::string *path, const std::string fmt = ""); +Module Load(std::string *path, const std::string& fmt = ""); /*! * \brief CleanDir Removes the files from the directory @@ -62,11 +62,11 @@ struct RPCEnv { * \param name The file name * \return The full path of file. */ - std::string GetPath(std::string file_name); + std::string GetPath(const std::string& file_name) const; /*! * \brief The RPC Environment cleanup function */ - void CleanUp(); + void CleanUp() const; private: /*! diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index b35a63bd67dc..f586b8f1faf6 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -22,24 +22,27 @@ * \brief RPC Server implementation. */ #include - #if defined(__linux__) || defined(__ANDROID__) #include #include #endif -#include -#include -#include -#include #include +#include +#include +#include #include -#include "rpc_server.h" -#include "rpc_env.h" -#include "rpc_tracker_client.h" +#include "../../src/common/socket.h" #include "../../src/runtime/rpc/rpc_session.h" #include "../../src/runtime/rpc/rpc_socket_impl.h" -#include "../../src/common/socket.h" +#include "rpc_env.h" +#include "rpc_server.h" +#include "rpc_tracker_client.h" +#if defined(_WIN32) +#include "win32_process.h" +#endif + +using namespace std::chrono; namespace tvm { namespace runtime { @@ -49,7 +52,7 @@ namespace runtime { * \param status status value */ #if defined(__linux__) || defined(__ANDROID__) -static pid_t waitPidEintr(int *status) { +static pid_t waitPidEintr(int* status) { pid_t pid = 0; while ((pid = waitpid(-1, status, 0)) == -1) { if (errno == EINTR) { @@ -76,34 +79,32 @@ class RPCServer { public: /*! * \brief Constructor. - */ - RPCServer(const std::string &host, - int port, - int port_end, - const std::string &tracker_addr, - const std::string &key, - const std::string &custom_addr) { - // Init the values - host_ = host; - port_ = port; - port_end_ = port_end; - tracker_addr_ = tracker_addr; - key_ = key; - custom_addr_ = custom_addr; + */ + RPCServer(std::string host, int port, int port_end, std::string tracker_addr, + std::string key, std::string custom_addr) : + host_(std::move(host)), port_(port), my_port_(0), port_end_(port_end), + tracker_addr_(std::move(tracker_addr)), key_(std::move(key)), + custom_addr_(std::move(custom_addr)) + { + } /*! * \brief Destructor. - */ + */ ~RPCServer() { - // Free the resources - tracker_sock_.Close(); - listen_sock_.Close(); + try { + // Free the resources + tracker_sock_.Close(); + listen_sock_.Close(); + } catch(...) { + + } } /*! * \brief Start Creates the RPC listen process and execution. - */ + */ void Start() { listen_sock_.Create(); my_port_ = listen_sock_.TryBindHost(host_, port_, port_end_); @@ -130,102 +131,95 @@ class RPCServer { tracker.TryConnect(); // step 2: wait for in-coming connections AcceptConnection(&tracker, &conn, &addr, &opts); - } - catch (const char* msg) { + } catch (const char* msg) { LOG(WARNING) << "Socket exception: " << msg; // close tracker resource tracker.Close(); continue; - } - catch (std::exception& e) { - // Other errors + } catch (const std::exception& e) { + // close tracker resource + tracker.Close(); LOG(WARNING) << "Exception standard: " << e.what(); continue; } int timeout = GetTimeOutFromOpts(opts); - #if defined(__linux__) || defined(__ANDROID__) - // step 3: serving - if (timeout != 0) { - const pid_t timer_pid = fork(); - if (timer_pid == 0) { - // Timer process - sleep(timeout); - exit(0); - } +#if defined(__linux__) || defined(__ANDROID__) + // step 3: serving + if (timeout != 0) { + const pid_t timer_pid = fork(); + if (timer_pid == 0) { + // Timer process + sleep(timeout); + exit(0); + } - const pid_t worker_pid = fork(); - if (worker_pid == 0) { - // Worker process - ServerLoopProc(conn, addr); - exit(0); - } + const pid_t worker_pid = fork(); + if (worker_pid == 0) { + // Worker process + ServerLoopProc(conn, addr); + exit(0); + } - int status = 0; - const pid_t finished_first = waitPidEintr(&status); - if (finished_first == timer_pid) { - kill(worker_pid, SIGKILL); - } else if (finished_first == worker_pid) { - kill(timer_pid, SIGKILL); - } else { - LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue."; - } + int status = 0; + const pid_t finished_first = waitPidEintr(&status); + if (finished_first == timer_pid) { + kill(worker_pid, SIGKILL); + } else if (finished_first == worker_pid) { + kill(timer_pid, SIGKILL); + } else { + LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue."; + } - int status_second = 0; - waitPidEintr(&status_second); + int status_second = 0; + waitPidEintr(&status_second); - // Logging. - if (finished_first == timer_pid) { - LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout - << "), Process status = " << status_second; - } else if (finished_first == worker_pid) { - LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second; - } - } else { - auto pid = fork(); - if (pid == 0) { - ServerLoopProc(conn, addr); - exit(0); - } - // Wait for the result - int status = 0; - wait(&status); - LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status; + // Logging. + if (finished_first == timer_pid) { + LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout + << "), Process status = " << status_second; + } else if (finished_first == worker_pid) { + LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second; } - #else - // step 3: serving - std::future proc(std::async(std::launch::async, - &RPCServer::ServerLoopProc, this, conn, addr)); - // wait until server process finish or timeout - if (timeout != 0) { - // Autoterminate after timeout - proc.wait_for(std::chrono::seconds(timeout)); - } else { - // Wait for the result - proc.get(); + } else { + auto pid = fork(); + if (pid == 0) { + ServerLoopProc(conn, addr); + exit(0); } - #endif + // Wait for the result + int status = 0; + wait(&status); + LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status; + } +#elif defined(WIN32) + auto start_time = high_resolution_clock::now(); + try { + SpawnRPCChild(conn.sockfd, seconds(timeout)); + } catch (const std::exception&) { + + } + auto dur = high_resolution_clock::now() - start_time; + + LOG(INFO) << "Serve Time " << duration_cast(dur).count() << "ms"; +#endif // close from our side. LOG(INFO) << "Socket Connection Closed"; conn.Close(); } } - /*! * \brief AcceptConnection Accepts the RPC Server connection. * \param tracker Tracker details. - * \param conn New connection information. + * \param conn_sock New connection information. * \param addr New connection address information. * \param opts Parsed options for socket * \param ping_period Timeout for select call waiting */ - void AcceptConnection(TrackerClient* tracker, - common::TCPSocket* conn_sock, - common::SockAddr* addr, - std::string* opts, - int ping_period = 2) { - std::set old_keyset; + void AcceptConnection(TrackerClient* tracker, common::TCPSocket* conn_sock, + common::SockAddr* addr, std::string* opts, int ping_period = 2) { + std::set old_keyset; std::string matchkey; // Report resource to tracker and get key @@ -236,7 +230,7 @@ class RPCServer { common::TCPSocket conn = listen_sock_.Accept(addr); int code = kRPCMagic; - CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); + CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); if (code != kRPCMagic) { conn.Close(); LOG(FATAL) << "Client connected is not TVM RPC server"; @@ -265,15 +259,15 @@ class RPCServer { std::string arg0; ssin >> arg0; if (arg0 != expect_header) { - code = kRPCMismatch; - CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); - conn.Close(); - LOG(WARNING) << "Mismatch key from" << addr->AsString(); - continue; + code = kRPCMismatch; + CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); + conn.Close(); + LOG(WARNING) << "Mismatch key from" << addr->AsString(); + continue; } else { code = kRPCSuccess; CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); - keylen = server_key.length(); + keylen = int(server_key.length()); CHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); CHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen); LOG(INFO) << "Connection success " << addr->AsString(); @@ -289,25 +283,23 @@ class RPCServer { * \param sock The socket information * \param addr The socket address information */ - void ServerLoopProc(common::TCPSocket sock, common::SockAddr addr) { - // Server loop - auto env = RPCEnv(); - RPCServerLoop(sock.sockfd); - LOG(INFO) << "Finish serving " << addr.AsString(); - env.CleanUp(); + static void ServerLoopProc(common::TCPSocket sock, common::SockAddr addr) { + // Server loop + const auto env = RPCEnv(); + RPCServerLoop(int(sock.sockfd)); + LOG(INFO) << "Finish serving " << addr.AsString(); + env.CleanUp(); } /*! * \brief GetTimeOutFromOpts Parse and get the timeout option. * \param opts The option string - * \param timeout value after parsing. */ - int GetTimeOutFromOpts(std::string opts) { - std::string cmd; - std::string option = "-timeout="; + int GetTimeOutFromOpts(const std::string& opts) const { + const std::string option = "-timeout="; if (opts.find(option) == 0) { - cmd = opts.substr(opts.find_last_of(option) + 1); + const std::string cmd = opts.substr(opts.find_last_of(option) + 1); CHECK(common::IsNumber(cmd)) << "Timeout is not valid"; return std::stoi(cmd); } @@ -325,35 +317,45 @@ class RPCServer { common::TCPSocket tracker_sock_; }; +#if defined(WIN32) +/*! +* \brief ServerLoopFromChild The Server loop process. +* \param socket The socket information +*/ +void ServerLoopFromChild(SOCKET socket) { + // Server loop + tvm::common::TCPSocket sock(socket); + const auto env = RPCEnv(); + RPCServerLoop(int(sock.sockfd)); + + sock.Close(); + env.CleanUp(); +} +#endif + /*! * \brief RPCServerCreate Creates the RPC Server. * \param host The hostname of the server, Default=0.0.0.0 * \param port The port of the RPC, Default=9090 * \param port_end The end search port of the RPC, Default=9199 - * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" + * \param tracker_addr The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" * \param key The key used to identify the device type in tracker. Default="" * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \param silent Whether run in silent mode. Default=True */ -void RPCServerCreate(std::string host, - int port, - int port_end, - std::string tracker_addr, - std::string key, - std::string custom_addr, - bool silent) { +void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr, + std::string key, std::string custom_addr, bool silent) { if (silent) { // Only errors and fatal is logged dmlc::InitLogging("--minloglevel=2"); } // Start the rpc server - RPCServer rpc(host, port, port_end, tracker_addr, key, custom_addr); + RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key), std::move(custom_addr)); rpc.Start(); } -TVM_REGISTER_GLOBAL("rpc._ServerCreate") -.set_body([](TVMArgs args, TVMRetValue* rv) { - RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); - }); +TVM_REGISTER_GLOBAL("rpc._ServerCreate").set_body([](TVMArgs args, TVMRetValue* rv) { + RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); +}); } // namespace runtime } // namespace tvm diff --git a/apps/cpp_rpc/rpc_server.h b/apps/cpp_rpc/rpc_server.h index 205182e4449a..db7c89d823dd 100644 --- a/apps/cpp_rpc/rpc_server.h +++ b/apps/cpp_rpc/rpc_server.h @@ -30,6 +30,15 @@ namespace tvm { namespace runtime { +#if defined(WIN32) +/*! + * \brief ServerLoopFromChild The Server loop process. + * \param sock The socket information + * \param addr The socket address information + */ +void ServerLoopFromChild(SOCKET socket); +#endif + /*! * \brief RPCServerCreate Creates the RPC Server. * \param host The hostname of the server, Default=0.0.0.0 @@ -40,13 +49,13 @@ namespace runtime { * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \param silent Whether run in silent mode. Default=True */ -TVM_DLL void RPCServerCreate(std::string host = "", - int port = 9090, - int port_end = 9099, - std::string tracker_addr = "", - std::string key = "", - std::string custom_addr = "", - bool silent = true); +void RPCServerCreate(std::string host = "", + int port = 9090, + int port_end = 9099, + std::string tracker_addr = "", + std::string key = "", + std::string custom_addr = "", + bool silent = true); } // namespace runtime } // namespace tvm #endif // TVM_APPS_CPP_RPC_SERVER_H_ diff --git a/apps/cpp_rpc/win32_process.cc b/apps/cpp_rpc/win32_process.cc new file mode 100644 index 000000000000..4af222a4a8dd --- /dev/null +++ b/apps/cpp_rpc/win32_process.cc @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include "win32_process.h" +#include "rpc_server.h" + +using namespace std::chrono; +using namespace tvm::runtime; + +namespace { +// The prefix path for the memory mapped file used to store IPC information +const std::string kMemoryMapPrefix = "/MAPPED_FILE/TVM_RPC"; +// Used to construct unique names for named resources in the parent process +const std::string kParent = "parent"; +// Used to construct unique names for named resources in the child process +const std::string kChild = "child"; +// The timeout of the WIN32 events, in the parent and the child +const milliseconds kEventTimeout(2000); + +// Used to create unique WIN32 mmap paths and event names +int child_counter_ = 0; + +/*! + * \brief HandleDeleter Deleter for UniqueHandle smart pointer + * \param handle The WIN32 HANDLE to manage + */ +struct HandleDeleter { + void operator()(HANDLE handle) const { + if (handle != INVALID_HANDLE_VALUE && handle != nullptr) { + CloseHandle(handle); + } + } +}; + +/*! + * \brief UniqueHandle Smart pointer to manage a WIN32 HANDLE + */ +using UniqueHandle = std::unique_ptr; + +/*! + * \brief MakeUniqueHandle Helper method to construct a UniqueHandle + * \param handle The WIN32 HANDLE to manage + */ +UniqueHandle MakeUniqueHandle(HANDLE handle) { + if (handle == INVALID_HANDLE_VALUE || handle == nullptr) { + return nullptr; + } + + return UniqueHandle(handle); +} + +/*! + * \brief GetSocket Gets the socket info from the parent process and duplicates the socket + * \param mmap_path The path to the memory mapped info set by the parent + */ +SOCKET GetSocket(const std::string& mmap_path) { + WSAPROTOCOL_INFO protocol_info; + + const std::string parent_event_name = mmap_path + kParent; + const std::string child_event_name = mmap_path + kChild; + + // Open the events + UniqueHandle parent_file_mapping_event; + if ((parent_file_mapping_event = MakeUniqueHandle(OpenEventA(SYNCHRONIZE, false, parent_event_name.c_str()))) == nullptr) { + LOG(FATAL) << "OpenEvent() failed: " << GetLastError(); + } + + UniqueHandle child_file_mapping_event; + if ((child_file_mapping_event = MakeUniqueHandle(OpenEventA(EVENT_MODIFY_STATE, false, child_event_name.c_str()))) == nullptr) { + LOG(FATAL) << "OpenEvent() failed: " << GetLastError(); + } + + // Wait for the parent to set the event, notifying WSAPROTOCOL_INFO is ready to be read + if (WaitForSingleObject(parent_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) { + LOG(FATAL) << "WaitForSingleObject() failed: " << GetLastError(); + } + + const UniqueHandle file_map = MakeUniqueHandle(OpenFileMappingA(FILE_MAP_READ | FILE_MAP_WRITE, + false, + mmap_path.c_str())); + if (!file_map) { + LOG(INFO) << "CreateFileMapping() failed: " << GetLastError(); + } + + void* map_view = MapViewOfFile(file_map.get(), + FILE_MAP_READ | FILE_MAP_WRITE, + 0, 0, 0); + + SOCKET sock_duplicated = INVALID_SOCKET; + + if (map_view != nullptr) { + memcpy(&protocol_info, map_view, sizeof(WSAPROTOCOL_INFO)); + UnmapViewOfFile(map_view); + + // Creates the duplicate socket, that was created in the parent + sock_duplicated = WSASocket(FROM_PROTOCOL_INFO, + FROM_PROTOCOL_INFO, + FROM_PROTOCOL_INFO, + &protocol_info, + 0, + 0); + + // Let the parent know we are finished dupicating the socket + SetEvent(child_file_mapping_event.get()); + } else { + LOG(FATAL) << "MapViewOfFile() failed: " << GetLastError(); + } + + return sock_duplicated; +} +}// Anonymous namespace + +namespace tvm { +namespace runtime { +/*! + * \brief SpawnRPCChild Spawns a child process with a given timeout to run + * \param fd The client socket to duplicate in the child + * \param timeout The time in seconds to wait for the child to complete before termination + */ +void SpawnRPCChild(SOCKET fd, seconds timeout) { + STARTUPINFOA startup_info; + + memset(&startup_info, 0, sizeof(startup_info)); + startup_info.cb = sizeof(startup_info); + + std::string file_map_path = kMemoryMapPrefix + std::to_string(child_counter_++); + + const std::string parent_event_name = file_map_path + kParent; + const std::string child_event_name = file_map_path + kChild; + + // Create an event to let the child know the socket info was set to the mmap file + UniqueHandle parent_file_mapping_event; + if ((parent_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, parent_event_name.c_str()))) == nullptr) { + LOG(FATAL) << "CreateEvent for parent file mapping failed"; + } + + UniqueHandle child_file_mapping_event; + // An event to let the parent know the socket info was read from the mmap file + if ((child_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, child_event_name.c_str()))) == nullptr) { + LOG(FATAL) << "CreateEvent for child file mapping failed"; + } + + char current_executable[MAX_PATH]; + + // Get the full path of the current executable + GetModuleFileNameA(nullptr, current_executable, MAX_PATH); + + std::string child_command_line = current_executable; + child_command_line += " server --child_proc="; + child_command_line += file_map_path; + + // CreateProcessA requires a non const char*, so we copy our std::string + std::unique_ptr command_line_ptr(new char[child_command_line.size() + 1]); + strcpy(command_line_ptr.get(), child_command_line.c_str()); + + PROCESS_INFORMATION child_process_info; + if (CreateProcessA(nullptr, + command_line_ptr.get(), + nullptr, + nullptr, + false, + CREATE_NO_WINDOW, + nullptr, + nullptr, + &startup_info, + &child_process_info)) { + // Child process and thread handles must be closed, so wrapped in RAII + auto child_process_handle = MakeUniqueHandle(child_process_info.hProcess); + auto child_process_thread_handle = MakeUniqueHandle(child_process_info.hThread); + + WSAPROTOCOL_INFO protocol_info; + // Get info needed to duplicate the socket + if (WSADuplicateSocket(fd, + child_process_info.dwProcessId, + &protocol_info) == SOCKET_ERROR) { + LOG(FATAL) << "WSADuplicateSocket(): failed. Error =" << WSAGetLastError(); + } + + // Create a mmap file to store the info needed for duplicating the SOCKET in the child proc + UniqueHandle file_map = MakeUniqueHandle(CreateFileMappingA(INVALID_HANDLE_VALUE, + nullptr, + PAGE_READWRITE, + 0, + sizeof(WSAPROTOCOL_INFO), + file_map_path.c_str())); + if (!file_map) { + LOG(INFO) << "CreateFileMapping() failed: " << GetLastError(); + } + + if (GetLastError() == ERROR_ALREADY_EXISTS) { + LOG(FATAL) << "CreateFileMapping(): mapping file already exists"; + } else { + void* map_view = MapViewOfFile(file_map.get(), FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, 0); + + if (map_view != nullptr) { + memcpy(map_view, &protocol_info, sizeof(WSAPROTOCOL_INFO)); + UnmapViewOfFile(map_view); + + // Let child proc know the mmap file is ready to be read + SetEvent(parent_file_mapping_event.get()); + + // Wait for the child to finish reading mmap file + if (WaitForSingleObject(child_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) { + TerminateProcess(child_process_handle.get(), 0); + LOG(FATAL) << "WaitForSingleObject for child file mapping timed out. Terminating child process."; + } + } else { + TerminateProcess(child_process_handle.get(), 0); + LOG(FATAL) << "MapViewOfFile() failed: " << GetLastError(); + } + } + + const DWORD process_timeout = timeout.count() + ? uint32_t(duration_cast(timeout).count()) + : INFINITE; + + // Wait for child process to exit, or hit configured timeout + if (WaitForSingleObject(child_process_handle.get(), process_timeout) != WAIT_OBJECT_0) { + LOG(INFO) << "Child process timeout. Terminating."; + TerminateProcess(child_process_handle.get(), 0); + } + } else { + LOG(INFO) << "Create child process failed: " << GetLastError(); + } +} +/*! + * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket + * \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent + */ +void ChildProcSocketHandler(const std::string& mmap_path) { + SOCKET socket; + const auto last_thread_priority = GetThreadPriority(GetCurrentThread()); + + // Set high thread priority to avoid the thread scheduler from + // interfering with any measurements in the RPC server. + SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_TIME_CRITICAL); + + try { + if ((socket = GetSocket(mmap_path)) != INVALID_SOCKET) { + tvm::runtime::ServerLoopFromChild(socket); + } + else { + LOG(FATAL) << "GetSocket() failed"; + } + } catch (...) { + // Restore thread priority + SetThreadPriority(GetCurrentThread(), last_thread_priority); + throw; + } +} +} // namespace runtime +} // namespace tvm \ No newline at end of file diff --git a/apps/cpp_rpc/win32_process.h b/apps/cpp_rpc/win32_process.h new file mode 100644 index 000000000000..7d1a27680ed3 --- /dev/null +++ b/apps/cpp_rpc/win32_process.h @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + /*! + * \file win32_process.h + * \brief Win32 process code to mimic a POSIX fork() + */ +#ifndef TVM_APPS_CPP_RPC_WIN32_PROCESS_H_ +#define TVM_APPS_CPP_RPC_WIN32_PROCESS_H_ +#include +#include +namespace tvm { +namespace runtime { +/*! + * \brief SpawnRPCChild Spawns a child process with a given timeout to run + * \param fd The client socket to duplicate in the child + * \param timeout The time in seconds to wait for the child to complete before termination + */ +void SpawnRPCChild(SOCKET fd, std::chrono::seconds timeout); +/*! + * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket + * \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent + */ +void ChildProcSocketHandler(const std::string& mmap_path); +} // namespace runtime +} // namespace tvm +#endif // TVM_APPS_CPP_RPC_WIN32_PROCESS_H_ \ No newline at end of file diff --git a/src/common/ring_buffer.h b/src/common/ring_buffer.h index f548acf1846b..1ce4a88a83b3 100644 --- a/src/common/ring_buffer.h +++ b/src/common/ring_buffer.h @@ -63,7 +63,7 @@ class RingBuffer { size_t ncopy = head_ptr_ + bytes_available_ - old_size; memcpy(&ring_[0] + old_size, &ring_[0], ncopy); } - } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity) { + } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity && bytes_available_ > 0) { // shrink too large temporary buffer to avoid out of memory on some embedded devices size_t old_bytes = bytes_available_; diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index cb59e723251b..2b7edbd8d45d 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -34,8 +34,13 @@ class SockChannel final : public RPCChannel { explicit SockChannel(common::TCPSocket sock) : sock_(sock) {} ~SockChannel() { - if (!sock_.BadSocket()) { - sock_.Close(); + try { + // BadSocket can throw + if (!sock_.BadSocket()) { + sock_.Close(); + } + } catch (...) { + } } size_t Send(const void* data, size_t size) final { @@ -100,7 +105,8 @@ Module RPCClientConnect(std::string url, int port, std::string key) { return CreateRPCModule(RPCConnect(url, port, "client:" + key)); } -void RPCServerLoop(int sockfd) { +// TVM_DLL needed for MSVC +TVM_DLL void RPCServerLoop(int sockfd) { common::TCPSocket sock( static_cast(sockfd)); RPCSession::Create(