Skip to content

Commit

Permalink
Added Windows support to C++ RPC Server
Browse files Browse the repository at this point in the history
  • Loading branch information
jmorrill committed Dec 3, 2019
1 parent 255d46a commit 49f5e87
Show file tree
Hide file tree
Showing 11 changed files with 746 additions and 318 deletions.
10 changes: 9 additions & 1 deletion CMakeLists.txt
@@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.2)
cmake_minimum_required(VERSION 3.9)
project(tvm C CXX)

# Utility functions
Expand Down Expand Up @@ -63,6 +63,7 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF)
tvm_option(USE_RANDOM "Build with random support" OFF)
tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF)
tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
tvm_option(USE_CXX_RPC "Build CXX RPC" OFF)

# include directories
include_directories(${CMAKE_INCLUDE_PATH})
Expand Down Expand Up @@ -275,6 +276,9 @@ add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
add_library(tvm_topi SHARED ${TOPI_SRCS})
add_library(tvm_runtime SHARED ${RUNTIME_SRCS})

if(USE_CXX_RPC STREQUAL "ON")
add_subdirectory("apps/cpp_rpc")
endif()

if(USE_RELAY_DEBUG)
message(STATUS "Building Relay in debug mode...")
Expand Down Expand Up @@ -405,6 +409,10 @@ endif(INSTALL_DEV)

# More target definitions
if(MSVC)
set_property(TARGET tvm PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
set_property(TARGET tvm_topi PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
set_property(TARGET tvm_runtime PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
set_property(TARGET nnvm_compiler PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS)
target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS)
target_compile_definitions(nnvm_compiler PRIVATE -DNNVM_EXPORTS)
Expand Down
23 changes: 23 additions & 0 deletions apps/cpp_rpc/CMakeLists.txt
@@ -0,0 +1,23 @@
set(TVM_RPC_SOURCES
main.cc
rpc_env.cc
rpc_server.cc
)

if(WIN32)
list(APPEND TVM_RPC_SOURCES win32_process.cc)
endif()

# Set output to same directory as the other TVM libs
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
add_executable(tvm_rpc ${TVM_RPC_SOURCES})
set_property(TARGET tvm_rpc PROPERTY INTERPROCEDURAL_OPTIMIZATION_RELEASE TRUE)
target_compile_definitions(tvm_rpc PUBLIC -DNOMINMAX)
target_include_directories(
tvm_rpc
PUBLIC "../../include"
PUBLIC DLPACK_PATH
PUBLIC DMLC_PATH
)

target_link_libraries(tvm_rpc tvm_runtime)
92 changes: 65 additions & 27 deletions apps/cpp_rpc/main.cc
Expand Up @@ -21,10 +21,12 @@
* \file rpc_server.cc
* \brief RPC Server for TVM.
*/
#include <stdlib.h>
#include <signal.h>
#include <stdio.h>
#include <cstdlib>
#include <csignal>
#include <cstdio>
#if defined(__linux__) || defined(__ANDROID__)
#include <unistd.h>
#endif
#include <dmlc/logging.h>
#include <iostream>
#include <cstring>
Expand All @@ -35,11 +37,15 @@
#include "../../src/common/socket.h"
#include "rpc_server.h"

#if defined(_WIN32)
#include "win32_process.h"
#endif

using namespace std;
using namespace tvm::runtime;
using namespace tvm::common;

static const string kUSAGE = \
static const string kUsage = \
"Command line usage\n" \
" server - Start the server\n" \
"--host - The hostname of the server, Default=0.0.0.0\n" \
Expand Down Expand Up @@ -73,13 +79,16 @@ struct RpcServerArgs {
string key;
string custom_addr;
bool silent = false;
#if defined(WIN32)
std::string mmap_path;
#endif
};

/*!
* \brief PrintArgs print the contents of RpcServerArgs
* \param args RpcServerArgs structure
*/
void PrintArgs(struct RpcServerArgs args) {
void PrintArgs(const RpcServerArgs& args) {
LOG(INFO) << "host = " << args.host;
LOG(INFO) << "port = " << args.port;
LOG(INFO) << "port_end = " << args.port_end;
Expand All @@ -89,6 +98,7 @@ void PrintArgs(struct RpcServerArgs args) {
LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False"));
}

#if defined(__linux__) || defined(__ANDROID__)
/*!
* \brief CtrlCHandler, exits if Ctrl+C is pressed
* \param s signal
Expand All @@ -109,7 +119,7 @@ void HandleCtrlC() {
sigIntHandler.sa_flags = 0;
sigaction(SIGINT, &sigIntHandler, nullptr);
}

#endif
/*!
* \brief GetCmdOption Parse and find the command option.
* \param argc arg counter
Expand All @@ -129,7 +139,7 @@ string GetCmdOption(int argc, char* argv[], string option, bool key = false) {
}
// We assume "=" is the end of option.
CHECK_EQ(*option.rbegin(), '=');
cmd = arg.substr(arg.find("=") + 1);
cmd = arg.substr(arg.find('=') + 1);
return cmd;
}
}
Expand All @@ -156,41 +166,41 @@ bool ValidateTracker(string &tracker) {
* \brief ParseCmdArgs parses the command line arguments.
* \param argc arg counter
* \param argv arg values
* \param args, the output structure which holds the parsed values
* \param args the output structure which holds the parsed values
*/
void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
string silent = GetCmdOption(argc, argv, "--silent", true);
const string silent = GetCmdOption(argc, argv, "--silent", true);
if (!silent.empty()) {
args.silent = true;
// Only errors and fatal is logged
dmlc::InitLogging("--minloglevel=2");
}

string host = GetCmdOption(argc, argv, "--host=");
const string host = GetCmdOption(argc, argv, "--host=");
if (!host.empty()) {
if (!ValidateIP(host)) {
LOG(WARNING) << "Wrong host address format.";
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
exit(1);
}
args.host = host;
}

string port = GetCmdOption(argc, argv, "--port=");
const string port = GetCmdOption(argc, argv, "--port=");
if (!port.empty()) {
if (!IsNumber(port) || stoi(port) > 65535) {
LOG(WARNING) << "Wrong port number.";
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
exit(1);
}
args.port = stoi(port);
}

string port_end = GetCmdOption(argc, argv, "--port_end=");
const string port_end = GetCmdOption(argc, argv, "--port_end=");
if (!port_end.empty()) {
if (!IsNumber(port_end) || stoi(port_end) > 65535) {
LOG(WARNING) << "Wrong port_end number.";
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
exit(1);
}
args.port_end = stoi(port_end);
Expand All @@ -200,26 +210,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
if (!tracker.empty()) {
if (!ValidateTracker(tracker)) {
LOG(WARNING) << "Wrong tracker address format.";
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
exit(1);
}
args.tracker = tracker;
}

string key = GetCmdOption(argc, argv, "--key=");
const string key = GetCmdOption(argc, argv, "--key=");
if (!key.empty()) {
args.key = key;
}

string custom_addr = GetCmdOption(argc, argv, "--custom_addr=");
const string custom_addr = GetCmdOption(argc, argv, "--custom_addr=");
if (!custom_addr.empty()) {
if (!ValidateIP(custom_addr)) {
LOG(WARNING) << "Wrong custom address format.";
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
exit(1);
}
args.custom_addr = custom_addr;
}
#if defined(WIN32)
const string mmap_path = GetCmdOption(argc, argv, "--child_proc=");
if(!mmap_path.empty()) {
args.mmap_path = mmap_path;
dmlc::InitLogging("--minloglevel=0");
}
#endif

}

/*!
Expand All @@ -229,17 +247,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
* \return result of operation.
*/
int RpcServer(int argc, char * argv[]) {
struct RpcServerArgs args;
RpcServerArgs args;

/* parse the command line args */
ParseCmdArgs(argc, argv, args);
PrintArgs(args);

// Ctrl+C handler
LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop.";
#if defined(__linux__) || defined(__ANDROID__)
// Ctrl+C handler
HandleCtrlC();
tvm::runtime::RPCServerCreate(args.host, args.port, args.port_end, args.tracker,
args.key, args.custom_addr, args.silent);
#endif

#if defined(WIN32)
if(!args.mmap_path.empty()) {
int ret = 0;

try {
ChildProcSocketHandler(args.mmap_path);
} catch (const std::exception&) {
ret = -1;
}

return ret;
}
#endif

RPCServerCreate(args.host, args.port, args.port_end, args.tracker,
args.key, args.custom_addr, args.silent);
return 0;
}

Expand All @@ -251,15 +286,18 @@ int RpcServer(int argc, char * argv[]) {
*/
int main(int argc, char * argv[]) {
if (argc <= 1) {
LOG(INFO) << kUSAGE;
LOG(INFO) << kUsage;
return 0;
}

// Runs WSAStartup on Win32, no-op on POSIX
Socket::Startup();

if (0 == strcmp(argv[1], "server")) {
RpcServer(argc, argv);
} else {
LOG(INFO) << kUSAGE;
return RpcServer(argc, argv);
}

LOG(INFO) << kUsage;

return 0;
}

0 comments on commit 49f5e87

Please sign in to comment.