Permalink
Browse files

fix warpctc plugin work with new version warpctc (#4530)

  • Loading branch information...
1 parent 6ef3058 commit b24c6c5077221a6686e85433973ba8ba6aa927c7 @yajiedesign yajiedesign committed with piiswrong Jan 7, 2017
Showing with 16 additions and 2 deletions.
  1. +14 −1 CMakeLists.txt
  2. +2 −1 plugin/warpctc/warpctc-inl.h
View
@@ -21,6 +21,8 @@ include(mshadow/cmake/Utils.cmake)
include(mshadow/cmake/Cuda.cmake)
set(mxnet_LINKER_LIBS "")
+set(mxnet_LINKER_LIBS_DEBUG "")
+set(mxnet_LINKER_LIBS_RELEASE "")
list(APPEND mxnet_LINKER_LIBS ${mshadow_LINKER_LIBS})
include_directories("include")
@@ -173,7 +175,11 @@ list(APPEND SOURCE ${MSHADOWSOURCE})
if(USE_PLUGINS_WARPCTC)
set(WARPCTC_INCLUDE "" CACHE PATH "WARPCTC include")
- set(WARPCTC_LIB "" CACHE FILEPATH "WARPCTC lib")
+ set(WARPCTC_LIB_DEBUG "" CACHE FILEPATH "WARPCTC lib")
+ set(WARPCTC_LIB_RELEASE "" CACHE FILEPATH "WARPCTC lib")
+ set(mxnet_LINKER_LIBS_RELEASE ${WARPCTC_LIB_RELEASE})
+ set(mxnet_LINKER_LIBS_DEBUG ${WARPCTC_LIB_DEBUG})
+
include_directories(SYSTEM ${WARPCTC_INCLUDE})
list(APPEND mxnet_LINKER_LIBS ${WARPCTC_LIB})
mxnet_source_group("Include\\plugin\\warpctc" GLOB "plugin/warpctc/*.h")
@@ -274,8 +280,15 @@ else()
add_library(mxnet SHARED ${SOURCE})
endif()
target_link_libraries(mxnet ${mxnet_LINKER_LIBS})
+
+if(USE_PLUGINS_WARPCTC)
+ target_link_libraries(mxnet debug ${mxnet_LINKER_LIBS_DEBUG})
+ target_link_libraries(mxnet optimized ${mxnet_LINKER_LIBS_RELEASE})
+endif()
+
target_link_libraries(mxnet dmlccore)
+
if(MSVC AND USE_MXNET_LIB_NAMING)
set_target_properties(mxnet PROPERTIES OUTPUT_NAME "libmxnet")
@@ -120,7 +120,7 @@ class WarpCTCOp : public Operator {
TBlob data = in_data[warpctc_enum::kData];
TBlob label = in_data[warpctc_enum::kLabel];
CHECK_EQ(data.shape_.ndim(), 2) << "input data shape should be 2 (t*n, p)";
- ctcComputeInfo info; // please build warp-ctc with commit 5bfb46e (cd warp-ctc && git checkout 5bfb46e) NOLINT(*)
+ ctcOptions info; //please updated to latest baidu/warp-ctc NOLINT(*)
if (data.dev_mask_ == cpu::kDevMask) {
info.loc = CTC_CPU;
info.num_threads = 1;
@@ -132,6 +132,7 @@ class WarpCTCOp : public Operator {
#endif
LOG(FATAL) << "Unknown device type " << data.dev_mask_;
}
+ info.blank_label = 0;
int T = param_.input_length;
int minibatch = data.shape_[0] / T;

0 comments on commit b24c6c5

Please sign in to comment.