diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..e53e362
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,6 @@
+[submodule "3rd/json11"]
+	path = 3rd/json11
+	url = https://github.com/dropbox/json11
+[submodule "3rd/spdlog"]
+	path = 3rd/spdlog
+	url = https://github.com/gabime/spdlog
diff --git a/3rd/json11 b/3rd/json11
new file mode 160000
index 0000000..2df9473
--- /dev/null
+++ b/3rd/json11
@@ -0,0 +1 @@
+Subproject commit 2df9473fb3605980db55ecddf34392a2e832ad35
diff --git a/3rd/spdlog b/3rd/spdlog
new file mode 160000
index 0000000..592ea36
--- /dev/null
+++ b/3rd/spdlog
@@ -0,0 +1 @@
+Subproject commit 592ea36a86a9c9049b433d9e44256d04333d8e52
diff --git a/README.md b/README.md
index 0342692..34a25c7 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,16 @@
-# cusim
-cuda implementaion of w2v and lda
+### How to install
+
+
+```shell
+# clone repo and submodules
+git clone git@github.com:js1010/cusim.git && cd cusim && git submodule update --init
+
+# install requirements
+pip install -r requirements.txt
+
+# generate proto
+python -m grpc_tools.protoc --python_out cusim/ --proto_path cusim/proto/ config.proto
+
+# install
+python setup.py install
+```
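The snippet below sketches the intended end-to-end usage once the package is installed, condensed from `examples/example1.py` later in this PR; the file paths and option values are placeholders, not fixed defaults.

```python
from cusim import IoUtils, CuLDA

# preprocess a line-based text stream into keys.txt / token.h5
iou = IoUtils(opt={"chunk_lines": 10000, "num_threads": 8})
iou.convert_stream_to_h5("data.stream.txt", 5, "processed/")

# train LDA on the preprocessed data and save the model
lda = CuLDA(opt={
  "data_path": "data.stream.txt",
  "processed_data_dir": "processed/",
  "skip_preprocess": True,  # data already converted above
})
lda.train_model()
lda.save_model("lda.h5")
```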
diff --git a/cpp/include/culda/cuda_lda_kernels.cuh b/cpp/include/culda/cuda_lda_kernels.cuh
new file mode 100644
index 0000000..02dbb37
--- /dev/null
+++ b/cpp/include/culda/cuda_lda_kernels.cuh
@@ -0,0 +1,121 @@
+// Copyright (c) 2021 Jisang Yoon
+// All rights reserved.
+//
+// This source code is licensed under the Apache 2.0 license found in the
+// LICENSE file in the root directory of this source tree.
+#pragma once
+#include "utils/cuda_utils_kernels.cuh"
+
+
+namespace cusim {
+
+// reference: http://web.science.mq.edu.au/~mjohnson/code/digamma.c
+__inline__ __device__
+float Digamma(float x) {
+  float result = 0.0f, xx, xx2, xx4;
+  for ( ; x < 7.0f; ++x)
+    result -= 1.0f / x;
+  x -= 0.5f;
+  xx = 1.0f / x;
+  xx2 = xx * xx;
+  xx4 = xx2 * xx2;
+  result += logf(x) + 1.0f / 24.0f * xx2
+    - 7.0f / 960.0f * xx4 + 31.0f / 8064.0f * xx4 * xx2
+    - 127.0f / 30720.0f * xx4 * xx4;
+  return result;
+}
+
+__global__ void EstepKernel(
+  const int* cols, const int* indptr, const bool* vali,
+  const int num_cols, const int num_indptr,
+  const int num_topics, const int num_iters,
+  float* gamma, float* new_gamma, float* phi,
+  const float* alpha, const float* beta,
+  float* grad_alpha, float* new_beta,
+  float* train_losses, float* vali_losses, int* mutex) {
+
+  // per-block workspace: each block owns one row of gamma / new_gamma / phi
+  float* _gamma = gamma + num_topics * blockIdx.x;
+  float* _new_gamma = new_gamma + num_topics * blockIdx.x;
+  float* _phi = phi + num_topics * blockIdx.x;
+  float* _grad_alpha = grad_alpha + num_topics * blockIdx.x;
+
+  for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) {
+    int beg = indptr[i], end = indptr[i + 1];
+    // initialize gamma (float division: doc length / num_topics)
+    for (int j = threadIdx.x; j < num_topics; j += blockDim.x)
+      _gamma[j] = alpha[j] + (end - beg) / static_cast<float>(num_topics);
+    __syncthreads();
+
+    // iterate E step
+    for (int j = 0; j < num_iters; ++j) {
+      // initialize new gamma
+      for (int k = threadIdx.x; k < num_topics; k += blockDim.x)
+        _new_gamma[k] = 0.0f;
+      __syncthreads();
+
+      // compute phi from gamma
+      for (int k = beg; k < end; ++k) {
+        const int w = cols[k];
+        const bool _vali = vali[k];
+
+        // compute phi (validation tokens only contribute at the last iteration)
+        if (not _vali or j + 1 == num_iters) {
+          for (int l = threadIdx.x; l < num_topics; l += blockDim.x)
+            _phi[l] = beta[w * num_topics + l] * expf(Digamma(_gamma[l]));
+          __syncthreads();
+
+          // normalize phi and add it to new gamma and new beta
+          float phi_sum = ReduceSum(_phi, num_topics);
+
+          for (int l = threadIdx.x; l < num_topics; l += blockDim.x) {
+            _phi[l] /= phi_sum;
+            if (not _vali) _new_gamma[l] += _phi[l];
+          }
+          __syncthreads();
+        }
+
+        if (j + 1 == num_iters) {
+          // lock the w-th row of new_beta before accumulating into it
+          if (threadIdx.x == 0) {
+            while (atomicCAS(&mutex[w], 0, 1)) {}
+          }
+
+          __syncthreads();
+          for (int l = threadIdx.x; l < num_topics; l += blockDim.x) {
+            if (not _vali) new_beta[w * num_topics + l] += _phi[l];
+            _phi[l] *= beta[w * num_topics + l];
+          }
+          __syncthreads();
+
+          // release lock
+          if (threadIdx.x == 0) mutex[w] = 0;
+          __syncthreads();
+
+          float p = fmaxf(EPS, ReduceSum(_phi, num_topics));
+          if (threadIdx.x == 0) {
+            if (_vali)
+              vali_losses[blockIdx.x] += logf(p);
+            else
+              train_losses[blockIdx.x] += logf(p);
+          }
+        }
+        __syncthreads();
+      }
+
+      // update gamma
+      for (int k = threadIdx.x; k < num_topics; k += blockDim.x)
+        _gamma[k] = _new_gamma[k] + alpha[k];
+      __syncthreads();
+    }
+    float gamma_sum = ReduceSum(_gamma, num_topics);
+    for (int j = threadIdx.x; j < num_topics; j += blockDim.x)
+      _grad_alpha[j] += (Digamma(_gamma[j]) - Digamma(gamma_sum));
+
+    __syncthreads();
+  }
+}
+
+}  // namespace cusim
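For reference, the per-document update that `EstepKernel` parallelizes is the standard variational E step. A plain numpy transcription for checking the math (function name and shapes are mine; the validation-token flags are omitted):

```python
import numpy as np
from scipy.special import digamma

def e_step_reference(word_ids, beta, alpha, num_iters):
  """Single-document variational E step matching EstepKernel's math.

  word_ids: token word indices of one document
  beta:     (num_words, num_topics) topic-word distribution
  alpha:    (num_topics,) Dirichlet prior
  """
  num_topics = alpha.shape[0]
  gamma = alpha + len(word_ids) / num_topics
  for _ in range(num_iters):
    new_gamma = np.zeros_like(gamma)
    for w in word_ids:
      phi = beta[w] * np.exp(digamma(gamma))  # unnormalized phi
      phi /= phi.sum()
      new_gamma += phi
    gamma = new_gamma + alpha
  return gamma
```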
diff --git a/cpp/include/culda/culda.hpp b/cpp/include/culda/culda.hpp
new file mode 100644
index 0000000..3a126cc
--- /dev/null
+++ b/cpp/include/culda/culda.hpp
@@ -0,0 +1,88 @@
+// Copyright (c) 2021 Jisang Yoon
+// All rights reserved.
+//
+// This source code is licensed under the Apache 2.0 license found in the
+// LICENSE file in the root directory of this source tree.
+#pragma once
+#include <thrust/copy.h>
+#include <thrust/fill.h>
+#include <thrust/host_vector.h>
+#include <thrust/device_vector.h>
+
+#include <omp.h>
+#include <memory>
+#include <string>
+#include <fstream>
+#include <iterator>
+#include <utility>
+#include <random>
+#include <numeric>
+#include <functional>
+#include <vector>
+#include <cmath>
+#include <chrono>  // NOLINT
+
+#include "json11.hpp"
+#include "utils/log.hpp"
+#include "utils/types.hpp"
+
+namespace cusim {
+
+
+// reference: https://people.math.sc.edu/Burkardt/cpp_src/asa121/asa121.cpp
+inline float Trigamma(float x) {
+  const float a = 0.0001f;
+  const float b = 5.0f;
+  const float b2 = 0.1666666667f;
+  const float b4 = -0.03333333333f;
+  const float b6 = 0.02380952381f;
+  const float b8 = -0.03333333333f;
+  float value = 0, y = 0, z = x;
+  if (x <= a) return 1.0f / x / x;
+  while (z < b) {
+    value += 1.0f / z / z;
+    z++;
+  }
+  y = 1.0f / z / z;
+  value += 0.5f * y + (1.0f
+    + y * (b2
+    + y * (b4
+    + y * (b6
+    + y * b8)))) / z;
+  return value;
+}
+
+
+class CuLDA {
+ public:
+  CuLDA();
+  ~CuLDA();
+  bool Init(std::string opt_path);
+  void LoadModel(float* alpha, float* beta,
+      float* grad_alpha, float* new_beta, const int num_words);
+  std::pair<float, float> FeedData(
+      const int* indices, const int* indptr, const bool* vali,
+      const int num_indices, const int num_indptr, const int num_iters);
+  void Pull();
+  void Push();
+  int GetBlockCnt();
+
+ private:
+  DeviceInfo dev_info_;
+  json11::Json opt_;
+  std::shared_ptr<spdlog::logger> logger_;
+  thrust::device_vector<float> dev_alpha_, dev_beta_;
+  thrust::device_vector<float> dev_grad_alpha_, dev_new_beta_;
+  thrust::device_vector<float> dev_gamma_, dev_new_gamma_, dev_phi_;
+  thrust::device_vector<int> dev_mutex_;
+
+  float *alpha_, *beta_, *grad_alpha_, *new_beta_;
+  int block_cnt_, block_dim_;
+  int num_topics_, num_words_;
+};
+
+}  // namespace cusim
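`Trigamma` above is the ASA121 recipe: a recurrence until the argument reaches 5, then a Bernoulli-series tail. A quick way to validate it (and the accumulation line in particular) is to port the same steps to Python and compare against scipy; the tolerance below is an assumption, not a measured bound:

```python
from scipy.special import polygamma

def trigamma_ref(x):
  """Python port of the Trigamma above, for checking against scipy."""
  if x <= 1e-4:
    return 1.0 / x / x
  value, z = 0.0, x
  while z < 5.0:           # psi'(x) = psi'(x + 1) + 1 / x^2
    value += 1.0 / z / z
    z += 1.0
  y = 1.0 / z / z
  b2, b4, b6, b8 = 1.0 / 6, -1.0 / 30, 1.0 / 42, -1.0 / 30
  value += 0.5 * y + (1.0 + y * (b2 + y * (b4 + y * (b6 + y * b8)))) / z
  return value

for v in [0.3, 1.0, 2.5, 8.0]:
  assert abs(trigamma_ref(v) - polygamma(1, v)) < 1e-4
```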
diff --git a/cpp/include/utils/cuda_utils_kernels.cuh b/cpp/include/utils/cuda_utils_kernels.cuh
new file mode 100644
index 0000000..026da7f
--- /dev/null
+++ b/cpp/include/utils/cuda_utils_kernels.cuh
@@ -0,0 +1,172 @@
+// Copyright (c) 2021 Jisang Yoon
+// All rights reserved.
+//
+// This source code is licensed under the Apache 2.0 license found in the
+// LICENSE file in the root directory of this source tree.
+#pragma once
+#include <cuda_runtime.h>
+#include <cublas_v2.h>
+
+#include <thrust/copy.h>
+#include <thrust/fill.h>
+#include <thrust/host_vector.h>
+#include <thrust/device_vector.h>
+
+#include <stdexcept>
+#include <sstream>
+#include <string>
+#include <utility>
+#include "utils/types.hpp"
+
+namespace cusim {
+
+// Error Checking utilities, checks status codes from cuda calls
+// and throws exceptions on failure (which cython can proxy back to python)
+#define CHECK_CUDA(code) { checkCuda((code), __FILE__, __LINE__); }
+inline void checkCuda(cudaError_t code, const char *file, int line) {
+  if (code != cudaSuccess) {
+    std::stringstream err;
+    err << "Cuda Error: " << cudaGetErrorString(code)
+        << " (" << file << ":" << line << ")";
+    throw std::runtime_error(err.str());
+  }
+}
+
+inline const char* cublasGetErrorString(cublasStatus_t status) {
+  switch (status) {
+    case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
+    case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
+    case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
+    case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
+    case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
+    case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
+    case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
+    case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
+  }
+  return "Unknown";
+}
+
+#define CHECK_CUBLAS(code) { checkCublas((code), __FILE__, __LINE__); }
+inline void checkCublas(cublasStatus_t code, const char * file, int line) {
+  if (code != CUBLAS_STATUS_SUCCESS) {
+    std::stringstream err;
+    err << "cublas error: " << cublasGetErrorString(code)
+        << " (" << file << ":" << line << ")";
+    throw std::runtime_error(err.str());
+  }
+}
+
+inline DeviceInfo GetDeviceInfo() {
+  DeviceInfo ret;
+  CHECK_CUDA(cudaGetDevice(&ret.devId));
+  cudaDeviceProp prop;
+  CHECK_CUDA(cudaGetDeviceProperties(&prop, ret.devId));
+  ret.mp_cnt = prop.multiProcessorCount;
+  ret.major = prop.major;
+  ret.minor = prop.minor;
+  ret.cores = -1;  // stays -1 when the architecture is not listed below
+  // reference: https://stackoverflow.com/a/32531982
+  switch (ret.major) {
+    case 2:  // Fermi
+      if (ret.minor == 1)
+        ret.cores = ret.mp_cnt * 48;
+      else
+        ret.cores = ret.mp_cnt * 32;
+      break;
+    case 3:  // Kepler
+      ret.cores = ret.mp_cnt * 192;
+      break;
+    case 5:  // Maxwell
+      ret.cores = ret.mp_cnt * 128;
+      break;
+    case 6:  // Pascal
+      if (ret.minor == 1 or ret.minor == 2)
+        ret.cores = ret.mp_cnt * 128;
+      else if (ret.minor == 0)
+        ret.cores = ret.mp_cnt * 64;
+      else
+        ret.unknown = true;
+      break;
+    case 7:  // Volta and Turing
+      if (ret.minor == 0 or ret.minor == 5)
+        ret.cores = ret.mp_cnt * 64;
+      else
+        ret.unknown = true;
+      break;
+    case 8:  // Ampere
+      if (ret.minor == 0)
+        ret.cores = ret.mp_cnt * 64;
+      else if (ret.minor == 6)
+        ret.cores = ret.mp_cnt * 128;
+      else
+        ret.unknown = true;
+      break;
+    default:
+      ret.unknown = true;
+      break;
+  }
+  if (ret.cores == -1) ret.cores = ret.mp_cnt * 128;
+  return ret;
+}
+
+__inline__ __device__
+float warp_reduce_sum(float val) {
+#if __CUDACC_VER_MAJOR__ >= 9
+  // __shfl_down is deprecated with cuda 9+. use newer variants
+  unsigned int active = __activemask();
+  #pragma unroll
+  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
+    val += __shfl_down_sync(active, val, offset);
+  }
+#else
+  #pragma unroll
+  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
+    val += __shfl_down(val, offset);
+  }
+#endif
+  return val;
+}
+
+__inline__ __device__
+float ReduceSum(const float* vec, const int length) {
+
+  static __shared__ float shared[32];
+
+  // figure out the warp / position inside the warp
+  int warp = threadIdx.x / WARP_SIZE;
+  int lane = threadIdx.x % WARP_SIZE;
+
+  // partial sum
+  float val = 0.0f;
+  for (int i = threadIdx.x; i < length; i += blockDim.x)
+    val += vec[i];
+  val = warp_reduce_sum(val);
+
+  // write out the partial reduction to shared memory if appropriate
+  if (lane == 0) {
+    shared[warp] = val;
+  }
+  __syncthreads();
+
+  // if we don't have multiple warps, we're done
+  if (blockDim.x <= WARP_SIZE) {
+    return shared[0];
+  }
+
+  // otherwise reduce again in the first warp
+  val = (threadIdx.x < blockDim.x / WARP_SIZE) ? shared[lane]: 0.0f;
+  if (warp == 0) {
+    val = warp_reduce_sum(val);
+    // broadcast back to shared memory
+    if (threadIdx.x == 0) {
+      shared[0] = val;
+    }
+  }
+  __syncthreads();
+  return shared[0];
+}
+
+}  // namespace cusim
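`warp_reduce_sum` is the usual shuffle-down tree: after log2(32) halving steps, lane 0 holds the warp total. A small Python simulation of that access pattern, purely illustrative:

```python
import random

def warp_reduce_sum_sim(vals):
  """Simulate the shuffle-down tree: at each offset, lane i adds the
  value held by lane i + offset; after all steps lane 0 has the sum."""
  vals = list(vals)          # one value per lane
  offset = len(vals) // 2    # 16, 8, 4, 2, 1 for a 32-lane warp
  while offset > 0:
    for lane in range(len(vals)):
      src = lane + offset
      if src < len(vals):    # out-of-range shuffles leave val unchanged
        vals[lane] += vals[src]
    offset //= 2
  return vals[0]

lanes = [random.random() for _ in range(32)]
assert abs(warp_reduce_sum_sim(lanes) - sum(lanes)) < 1e-9
```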
diff --git a/cpp/include/utils/ioutils.hpp b/cpp/include/utils/ioutils.hpp
new file mode 100644
index 0000000..756b4b2
--- /dev/null
+++ b/cpp/include/utils/ioutils.hpp
@@ -0,0 +1,53 @@
+// Copyright (c) 2021 Jisang Yoon
+// All rights reserved.
+//
+// This source code is licensed under the Apache 2.0 license found in the
+// LICENSE file in the root directory of this source tree.
+#pragma once
+
+#include <omp.h>
+#include <memory>
+#include <string>
+#include <fstream>
+#include <iostream>
+#include <utility>
+#include <algorithm>
+#include <cctype>
+#include <vector>
+#include <unordered_map>
+#include <mutex>  // NOLINT
+
+#include "json11.hpp"
+#include "utils/log.hpp"
+
+namespace cusim {
+
+class IoUtils {
+ public:
+  IoUtils();
+  ~IoUtils();
+  bool Init(std::string opt_path);
+  int LoadStreamFile(std::string filepath);
+  std::pair<int, int> ReadStreamForVocab(int num_lines, int num_threads);
+  std::pair<int, int> TokenizeStream(int num_lines, int num_threads);
+  void GetWordVocab(int min_count, std::string keys_path);
+  void GetToken(int* rows, int* cols, int* indptr);
+
+ private:
+  void ParseLine(std::string line, std::vector<std::string>& line_vec);
+  void ParseLineImpl(std::string line, std::vector<std::string>& line_vec);
+
+  std::vector<std::vector<int>> cols_;
+  std::vector<int> indptr_;
+  std::mutex global_lock_;
+  std::ifstream stream_fin_;
+  json11::Json opt_;
+  std::shared_ptr<spdlog::logger> logger_;
+  std::unordered_map<std::string, int> word_idmap_, word_count_;
+  std::vector<std::string> word_list_;
+  int num_lines_, remain_lines_;
+};  // class IoUtils
+
+}  // namespace cusim
+#include "culda/culda.hpp" +#include "culda/cuda_lda_kernels.cuh" + +namespace cusim { + +CuLDA::CuLDA() { + logger_ = CuSimLogger().get_logger(); + dev_info_ = GetDeviceInfo(); + if (dev_info_.unknown) DEBUG0("Unknown device type"); + INFO("cuda device info, major: {}, minor: {}, multi processors: {}, cores: {}", + dev_info_.major, dev_info_.minor, dev_info_.mp_cnt, dev_info_.cores); +} + +CuLDA::~CuLDA() {} + +bool CuLDA::Init(std::string opt_path) { + std::ifstream in(opt_path.c_str()); + if (not in.is_open()) return false; + + std::string str((std::istreambuf_iterator(in)), + std::istreambuf_iterator()); + std::string err_cmt; + auto _opt = json11::Json::parse(str, err_cmt); + if (not err_cmt.empty()) return false; + opt_ = _opt; + CuSimLogger().set_log_level(opt_["c_log_level"].int_value()); + num_topics_ = opt_["num_topics"].int_value(); + block_dim_ = opt_["block_dim"].int_value(); + block_cnt_ = opt_["hyper_threads"].number_value() * (dev_info_.cores / block_dim_); + INFO("num_topics: {}, block_dim: {}, block_cnt: {}", num_topics_, block_dim_, block_cnt_); + return true; +} + +void CuLDA::LoadModel(float* alpha, float* beta, + float* grad_alpha, float* new_beta, int num_words) { + num_words_ = num_words; + DEBUG("copy model({} x {})", num_words_, num_topics_); + dev_alpha_.resize(num_topics_); + dev_beta_.resize(num_topics_ * num_words_); + thrust::copy(alpha, alpha + num_topics_, dev_alpha_.begin()); + thrust::copy(beta, beta + num_topics_ * num_words_, dev_beta_.begin()); + alpha_ = alpha; beta_ = beta; + + // resize device vector + grad_alpha_ = grad_alpha; + new_beta_ = new_beta; + dev_grad_alpha_.resize(num_topics_ * block_cnt_); + dev_new_beta_.resize(num_topics_ * num_words_); + // copy to device + thrust::copy(grad_alpha_, grad_alpha_ + block_cnt_ * num_topics_, dev_grad_alpha_.begin()); + thrust::copy(new_beta_, new_beta_ + num_words_ * num_topics_, dev_new_beta_.begin()); + dev_gamma_.resize(num_topics_ * block_cnt_); + dev_new_gamma_.resize(num_topics_ * block_cnt_); + dev_phi_.resize(num_topics_ * block_cnt_); + + // set mutex + dev_mutex_.resize(num_words_); + std::vector host_mutex(num_words_, 0); + thrust::copy(host_mutex.begin(), host_mutex.end(), dev_mutex_.begin()); + + CHECK_CUDA(cudaDeviceSynchronize()); +} + +std::pair CuLDA::FeedData( + const int* cols, const int* indptr, const bool* vali, + const int num_cols, const int num_indptr, const int num_iters) { + + // copy feed data to GPU memory + thrust::device_vector dev_cols(num_cols); + thrust::device_vector dev_indptr(num_indptr + 1); + thrust::device_vector dev_vali(num_cols); + thrust::device_vector dev_train_losses(block_cnt_, 0.0f); + thrust::device_vector dev_vali_losses(block_cnt_, 0.0f); + thrust::copy(cols, cols + num_cols, dev_cols.begin()); + thrust::copy(indptr, indptr + num_indptr + 1, dev_indptr.begin()); + thrust::copy(vali, vali + num_cols, dev_vali.begin()); + CHECK_CUDA(cudaDeviceSynchronize()); + DEBUG0("copy feed data to GPU memory"); + + // run E step in GPU + EstepKernel<<>>( + thrust::raw_pointer_cast(dev_cols.data()), + thrust::raw_pointer_cast(dev_indptr.data()), + thrust::raw_pointer_cast(dev_vali.data()), + num_cols, num_indptr, num_topics_, num_iters, + thrust::raw_pointer_cast(dev_gamma_.data()), + thrust::raw_pointer_cast(dev_new_gamma_.data()), + thrust::raw_pointer_cast(dev_phi_.data()), + thrust::raw_pointer_cast(dev_alpha_.data()), + thrust::raw_pointer_cast(dev_beta_.data()), + thrust::raw_pointer_cast(dev_grad_alpha_.data()), + thrust::raw_pointer_cast(dev_new_beta_.data()), 
diff --git a/cpp/src/culda/culda.cu b/cpp/src/culda/culda.cu
new file mode 100644
index 0000000..9217902
--- /dev/null
+++ b/cpp/src/culda/culda.cu
@@ -0,0 +1,135 @@
+// Copyright (c) 2021 Jisang Yoon
+// All rights reserved.
+//
+// This source code is licensed under the Apache 2.0 license found in the
+// LICENSE file in the root directory of this source tree.
+#include "culda/culda.hpp"
+#include "culda/cuda_lda_kernels.cuh"
+
+namespace cusim {
+
+CuLDA::CuLDA() {
+  logger_ = CuSimLogger().get_logger();
+  dev_info_ = GetDeviceInfo();
+  if (dev_info_.unknown) DEBUG0("Unknown device type");
+  INFO("cuda device info, major: {}, minor: {}, multi processors: {}, cores: {}",
+       dev_info_.major, dev_info_.minor, dev_info_.mp_cnt, dev_info_.cores);
+}
+
+CuLDA::~CuLDA() {}
+
+bool CuLDA::Init(std::string opt_path) {
+  std::ifstream in(opt_path.c_str());
+  if (not in.is_open()) return false;
+
+  std::string str((std::istreambuf_iterator<char>(in)),
+                  std::istreambuf_iterator<char>());
+  std::string err_cmt;
+  auto _opt = json11::Json::parse(str, err_cmt);
+  if (not err_cmt.empty()) return false;
+  opt_ = _opt;
+  CuSimLogger().set_log_level(opt_["c_log_level"].int_value());
+  num_topics_ = opt_["num_topics"].int_value();
+  block_dim_ = opt_["block_dim"].int_value();
+  block_cnt_ = opt_["hyper_threads"].number_value() * (dev_info_.cores / block_dim_);
+  INFO("num_topics: {}, block_dim: {}, block_cnt: {}", num_topics_, block_dim_, block_cnt_);
+  return true;
+}
+
+void CuLDA::LoadModel(float* alpha, float* beta,
+    float* grad_alpha, float* new_beta, int num_words) {
+  num_words_ = num_words;
+  DEBUG("copy model({} x {})", num_words_, num_topics_);
+  dev_alpha_.resize(num_topics_);
+  dev_beta_.resize(num_topics_ * num_words_);
+  thrust::copy(alpha, alpha + num_topics_, dev_alpha_.begin());
+  thrust::copy(beta, beta + num_topics_ * num_words_, dev_beta_.begin());
+  alpha_ = alpha; beta_ = beta;
+
+  // resize device vectors for the M-step statistics
+  grad_alpha_ = grad_alpha;
+  new_beta_ = new_beta;
+  dev_grad_alpha_.resize(num_topics_ * block_cnt_);
+  dev_new_beta_.resize(num_topics_ * num_words_);
+  // copy to device
+  thrust::copy(grad_alpha_, grad_alpha_ + block_cnt_ * num_topics_, dev_grad_alpha_.begin());
+  thrust::copy(new_beta_, new_beta_ + num_words_ * num_topics_, dev_new_beta_.begin());
+  dev_gamma_.resize(num_topics_ * block_cnt_);
+  dev_new_gamma_.resize(num_topics_ * block_cnt_);
+  dev_phi_.resize(num_topics_ * block_cnt_);
+
+  // set mutex
+  dev_mutex_.resize(num_words_);
+  std::vector<int> host_mutex(num_words_, 0);
+  thrust::copy(host_mutex.begin(), host_mutex.end(), dev_mutex_.begin());
+
+  CHECK_CUDA(cudaDeviceSynchronize());
+}
+
+std::pair<float, float> CuLDA::FeedData(
+    const int* cols, const int* indptr, const bool* vali,
+    const int num_cols, const int num_indptr, const int num_iters) {
+
+  // copy feed data to GPU memory
+  thrust::device_vector<int> dev_cols(num_cols);
+  thrust::device_vector<int> dev_indptr(num_indptr + 1);
+  thrust::device_vector<bool> dev_vali(num_cols);
+  thrust::device_vector<float> dev_train_losses(block_cnt_, 0.0f);
+  thrust::device_vector<float> dev_vali_losses(block_cnt_, 0.0f);
+  thrust::copy(cols, cols + num_cols, dev_cols.begin());
+  thrust::copy(indptr, indptr + num_indptr + 1, dev_indptr.begin());
+  thrust::copy(vali, vali + num_cols, dev_vali.begin());
+  CHECK_CUDA(cudaDeviceSynchronize());
+  DEBUG0("copy feed data to GPU memory");
+
+  // run E step on GPU
+  EstepKernel<<<block_cnt_, block_dim_>>>(
+    thrust::raw_pointer_cast(dev_cols.data()),
+    thrust::raw_pointer_cast(dev_indptr.data()),
+    thrust::raw_pointer_cast(dev_vali.data()),
+    num_cols, num_indptr, num_topics_, num_iters,
+    thrust::raw_pointer_cast(dev_gamma_.data()),
+    thrust::raw_pointer_cast(dev_new_gamma_.data()),
+    thrust::raw_pointer_cast(dev_phi_.data()),
+    thrust::raw_pointer_cast(dev_alpha_.data()),
+    thrust::raw_pointer_cast(dev_beta_.data()),
+    thrust::raw_pointer_cast(dev_grad_alpha_.data()),
+    thrust::raw_pointer_cast(dev_new_beta_.data()),
+    thrust::raw_pointer_cast(dev_train_losses.data()),
+    thrust::raw_pointer_cast(dev_vali_losses.data()),
+    thrust::raw_pointer_cast(dev_mutex_.data()));
+  CHECK_CUDA(cudaDeviceSynchronize());
+  DEBUG0("run E step in GPU");
+
+  // pull loss
+  std::vector<float> train_losses(block_cnt_), vali_losses(block_cnt_);
+  thrust::copy(dev_train_losses.begin(), dev_train_losses.end(), train_losses.begin());
+  thrust::copy(dev_vali_losses.begin(), dev_vali_losses.end(), vali_losses.begin());
+  CHECK_CUDA(cudaDeviceSynchronize());
+  DEBUG0("pull loss values");
+
+  // accumulate
+  float train_loss = std::accumulate(train_losses.begin(), train_losses.end(), 0.0f);
+  float vali_loss = std::accumulate(vali_losses.begin(), vali_losses.end(), 0.0f);
+  return {train_loss, vali_loss};
+}
+
+void CuLDA::Pull() {
+  thrust::copy(dev_grad_alpha_.begin(), dev_grad_alpha_.end(), grad_alpha_);
+  thrust::copy(dev_new_beta_.begin(), dev_new_beta_.end(), new_beta_);
+  CHECK_CUDA(cudaDeviceSynchronize());
+}
+
+void CuLDA::Push() {
+  thrust::copy(alpha_, alpha_ + num_topics_, dev_alpha_.begin());
+  thrust::copy(grad_alpha_, grad_alpha_ + block_cnt_ * num_topics_, dev_grad_alpha_.begin());
+  thrust::copy(beta_, beta_ + num_words_ * num_topics_, dev_beta_.begin());
+  thrust::copy(new_beta_, new_beta_ + num_words_ * num_topics_, dev_new_beta_.begin());
+  CHECK_CUDA(cudaDeviceSynchronize());
+}
+
+int CuLDA::GetBlockCnt() {
+  return block_cnt_;
+}
+
+}  // namespace cusim
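`FeedData` expects one CSR-like batch per call: `cols` holds word ids, `indptr` has one offset per document plus one, and `vali` flags held-out tokens that are only scored at the final iteration. An illustrative batch through the Python binding, assuming a `CuLDA` instance `lda` whose model is already loaded as in `pyculda.py` below:

```python
import numpy as np

num_docs, num_tokens = 3, 8
# CSR layout: cols[indptr[i]:indptr[i + 1]] are the word ids of document i
cols = np.array([0, 2, 5, 1, 1, 3, 4, 0], dtype=np.int32)
indptr = np.array([0, 3, 4, 8], dtype=np.int32)  # num_docs + 1 entries
# tokens flagged True are held out for validation
vali = np.random.uniform(size=(num_tokens,)) < 0.2

train_loss, vali_loss = lda.obj.feed_data(cols, indptr, vali, 5)
```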
+#include "utils/ioutils.hpp" + +namespace cusim { + +IoUtils::IoUtils() { + logger_ = CuSimLogger().get_logger(); +} + +IoUtils::~IoUtils() {} + +bool IoUtils::Init(std::string opt_path) { + std::ifstream in(opt_path.c_str()); + if (not in.is_open()) return false; + + std::string str((std::istreambuf_iterator(in)), + std::istreambuf_iterator()); + std::string err_cmt; + auto _opt = json11::Json::parse(str, err_cmt); + if (not err_cmt.empty()) return false; + opt_ = _opt; + CuSimLogger().set_log_level(opt_["c_log_level"].int_value()); + return true; +} + +void IoUtils::ParseLine(std::string line, std::vector& ret) { + ParseLineImpl(line, ret); +} + + +void IoUtils::ParseLineImpl(std::string line, std::vector& ret) { + ret.clear(); + int n = line.size(); + std::string element; + for (int i = 0; i < n; ++i) { + if (line[i] == ' ') { + ret.push_back(element); + element.clear(); + } else { + element += std::tolower(line[i]); + } + } + if (element.size() > 0) { + ret.push_back(element); + } +} + +int IoUtils::LoadStreamFile(std::string filepath) { + INFO("read gensim file to generate vocabulary: {}", filepath); + if (stream_fin_.is_open()) stream_fin_.close(); + stream_fin_.open(filepath.c_str()); + int count = 0; + std::string line; + while (getline(stream_fin_, line)) + count++; + stream_fin_.close(); + stream_fin_.open(filepath.c_str()); + num_lines_ = count; + remain_lines_ = num_lines_; + INFO("number of lines: {}", num_lines_); + return count; +} + +std::pair IoUtils::TokenizeStream(int num_lines, int num_threads) { + int read_lines = std::min(num_lines, remain_lines_); + if (not read_lines) return {0, 0}; + remain_lines_ -= read_lines; + cols_.clear(); + cols_.resize(read_lines); + indptr_.resize(read_lines); + std::fill(indptr_.begin(), indptr_.end(), 0); + #pragma omp parallel num_threads(num_threads) + { + std::string line; + std::vector line_vec; + #pragma omp for schedule(dynamic, 4) + for (int i = 0; i < read_lines; ++i) { + // get line thread-safely + { + std::unique_lock lock(global_lock_); + getline(stream_fin_, line); + } + + // seems to be bottle-neck + ParseLine(line, line_vec); + + // tokenize + for (auto& word: line_vec) { + if (not word_idmap_.count(word)) continue; + cols_[i].push_back(word_idmap_[word]); + } + } + } + int cumsum = 0; + for (int i = 0; i < read_lines; ++i) { + cumsum += cols_[i].size(); + indptr_[i] = cumsum; + } + return {read_lines, indptr_[read_lines - 1]}; +} + +void IoUtils::GetToken(int* rows, int* cols, int* indptr) { + int n = cols_.size(); + for (int i = 0; i < n; ++i) { + int beg = i == 0? 
diff --git a/cpp/src/utils/log.cc b/cpp/src/utils/log.cc
new file mode 100644
index 0000000..ddfcb0c
--- /dev/null
+++ b/cpp/src/utils/log.cc
@@ -0,0 +1,38 @@
+// Copyright (c) 2020 Jisang Yoon
+// All rights reserved.
+//
+// This source code is licensed under the Apache 2.0 license found in the
+// LICENSE file in the root directory of this source tree.
+
+// reference: https://github.com/kakao/buffalo/blob/5f571c2c7d8227e6625c6e538da929e4db11b66d/lib/misc/log.cc
+#include "utils/log.hpp"
+
+
+namespace cusim {
+int CuSimLogger::global_logging_level_ = 2;
+
+CuSimLogger::CuSimLogger() {
+  spdlog::set_pattern("[%^%-8l%$] %Y-%m-%d %H:%M:%S %v");
+  logger_ = spdlog::default_logger();
+}
+
+std::shared_ptr<spdlog::logger>& CuSimLogger::get_logger() {
+  return logger_;
+}
+
+void CuSimLogger::set_log_level(int level) {
+  global_logging_level_ = level;
+  switch (level) {
+    case 0: spdlog::set_level(spdlog::level::off); break;
+    case 1: spdlog::set_level(spdlog::level::warn); break;
+    case 2: spdlog::set_level(spdlog::level::info); break;
+    case 3: spdlog::set_level(spdlog::level::debug); break;
+    default: spdlog::set_level(spdlog::level::trace); break;
+  }
+}
+
+int CuSimLogger::get_log_level() {
+  return global_logging_level_;
+}
+
+}  // namespace cusim
diff --git a/cuda_setup.py b/cuda_setup.py
new file mode 100644
index 0000000..5ff76b6
--- /dev/null
+++ b/cuda_setup.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2020 Jisang Yoon
+# All rights reserved.
+#
+# This source code is licensed under the Apache 2.0 license found in the
+# LICENSE file in the root directory of this source tree.
+ +# Adapted from https://github.com/rmcgibbo/npcuda-example and +# https://github.com/cupy/cupy/blob/master/cupy_setup_build.py +# pylint: disable=fixme,access-member-before-definition +# pylint: disable=attribute-defined-outside-init,arguments-differ +import logging +import os +import sys + +from distutils import ccompiler, errors, msvccompiler, unixccompiler +from setuptools.command.build_ext import build_ext as setuptools_build_ext + +HALF_PRECISION = False + +def find_in_path(name, path): + "Find a file in a search path" + # adapted fom http://code.activestate.com/ + # recipes/52224-find-a-file-given-a-search-path/ + for _dir in path.split(os.pathsep): + binpath = os.path.join(_dir, name) + if os.path.exists(binpath): + return os.path.abspath(binpath) + return None + + +def locate_cuda(): + """Locate the CUDA environment on the system + If a valid cuda installation is found + this returns a dict with keys 'home', 'nvcc', 'include', + and 'lib64' and values giving the absolute path to each directory. + Starts by looking for the CUDAHOME env variable. + If not found, everything is based on finding + 'nvcc' in the PATH. + If nvcc can't be found, this returns None + """ + nvcc_bin = 'nvcc' + if sys.platform.startswith("win"): + nvcc_bin = 'nvcc.exe' + + # check env variables CUDA_HOME, CUDAHOME, CUDA_PATH. + found = False + for env_name in ['CUDA_PATH', 'CUDAHOME', 'CUDA_HOME']: + if env_name not in os.environ: + continue + found = True + home = os.environ[env_name] + nvcc = os.path.join(home, 'bin', nvcc_bin) + break + if not found: + # otherwise, search the PATH for NVCC + nvcc = find_in_path(nvcc_bin, os.environ['PATH']) + if nvcc is None: + logging.warning('The nvcc binary could not be located in your ' + '$PATH. Either add it to ' + 'your path, or set $CUDA_HOME to enable CUDA extensions') + return None + home = os.path.dirname(os.path.dirname(nvcc)) + + cudaconfig = {'home': home, + 'nvcc': nvcc, + 'include': os.path.join(home, 'include'), + 'lib64': os.path.join(home, 'lib64')} + post_args = [ + "-arch=sm_52", + "-gencode=arch=compute_52,code=sm_52", + "-gencode=arch=compute_60,code=sm_60", + "-gencode=arch=compute_61,code=sm_61", + "-gencode=arch=compute_70,code=sm_70", + "-gencode=arch=compute_75,code=sm_75", + "-gencode=arch=compute_80,code=sm_80", + "-gencode=arch=compute_86,code=sm_86", + "-gencode=arch=compute_86,code=compute_86", + '--ptxas-options=-v', '-O2'] + if HALF_PRECISION: + post_args = [flag for flag in post_args if "52" not in flag] + + if sys.platform == "win32": + cudaconfig['lib64'] = os.path.join(home, 'lib', 'x64') + post_args += ['-Xcompiler', '/MD', '-std=c++14', "-Xcompiler", "/openmp"] + if HALF_PRECISION: + post_args += ["-Xcompiler", "/D HALF_PRECISION"] + else: + post_args += ['-c', '--compiler-options', "'-fPIC'", + "--compiler-options", "'-std=c++14'"] + if HALF_PRECISION: + post_args += ["--compiler-options", "'-D HALF_PRECISION'"] + for k, val in cudaconfig.items(): + if not os.path.exists(val): + logging.warning('The CUDA %s path could not be located in %s', k, val) + return None + + cudaconfig['post_args'] = post_args + return cudaconfig + + +# This code to build .cu extensions with nvcc is taken from cupy: +# https://github.com/cupy/cupy/blob/master/cupy_setup_build.py +class _UnixCCompiler(unixccompiler.UnixCCompiler): + src_extensions = list(unixccompiler.UnixCCompiler.src_extensions) + src_extensions.append('.cu') + + def _compile(self, obj, src, ext, cc_args, extra_postargs, pp_opts): + # For sources other than CUDA C ones, just call the 
super class method. + if os.path.splitext(src)[1] != '.cu': + return unixccompiler.UnixCCompiler._compile( + self, obj, src, ext, cc_args, extra_postargs, pp_opts) + + # For CUDA C source files, compile them with NVCC. + _compiler_so = self.compiler_so + try: + nvcc_path = CUDA['nvcc'] + post_args = CUDA['post_args'] + # TODO? base_opts = build.get_compiler_base_options() + self.set_executable('compiler_so', nvcc_path) + + return unixccompiler.UnixCCompiler._compile( + self, obj, src, ext, cc_args, post_args, pp_opts) + finally: + self.compiler_so = _compiler_so + + +class _MSVCCompiler(msvccompiler.MSVCCompiler): + _cu_extensions = ['.cu'] + + src_extensions = list(unixccompiler.UnixCCompiler.src_extensions) + src_extensions.extend(_cu_extensions) + + def _compile_cu(self, sources, output_dir=None, macros=None, + include_dirs=None, debug=0, extra_preargs=None, + extra_postargs=None, depends=None): + # Compile CUDA C files, mainly derived from UnixCCompiler._compile(). + macros, objects, extra_postargs, pp_opts, _build = \ + self._setup_compile(output_dir, macros, include_dirs, sources, + depends, extra_postargs) + + compiler_so = CUDA['nvcc'] + cc_args = self._get_cc_args(pp_opts, debug, extra_preargs) + post_args = CUDA['post_args'] + + for obj in objects: + try: + src, _ = _build[obj] + except KeyError: + continue + try: + self.spawn([compiler_so] + cc_args + [src, '-o', obj] + post_args) + except errors.DistutilsExecError as e: + raise errors.CompileError(str(e)) + + return objects + + def compile(self, sources, **kwargs): + # Split CUDA C sources and others. + cu_sources = [] + other_sources = [] + for source in sources: + if os.path.splitext(source)[1] == '.cu': + cu_sources.append(source) + else: + other_sources.append(source) + + # Compile source files other than CUDA C ones. + other_objects = msvccompiler.MSVCCompiler.compile( + self, other_sources, **kwargs) + + # Compile CUDA C sources. + cu_objects = self._compile_cu(cu_sources, **kwargs) + + # Return compiled object filenames. + return other_objects + cu_objects + + +class CudaBuildExt(setuptools_build_ext): + """Custom `build_ext` command to include CUDA C source files.""" + + def run(self): + if CUDA is not None: + def wrap_new_compiler(func): + def _wrap_new_compiler(*args, **kwargs): + try: + return func(*args, **kwargs) + except errors.DistutilsPlatformError: + if sys.platform != 'win32': + CCompiler = _UnixCCompiler + else: + CCompiler = _MSVCCompiler + return CCompiler( + None, kwargs['dry_run'], kwargs['force']) + return _wrap_new_compiler + ccompiler.new_compiler = wrap_new_compiler(ccompiler.new_compiler) + # Intentionally causes DistutilsPlatformError in + # ccompiler.new_compiler() function to hook. + self.compiler = 'nvidia' + + setuptools_build_ext.run(self) + + +CUDA = locate_cuda() +assert CUDA is not None +BUILDEXT = CudaBuildExt if CUDA else setuptools_build_ext diff --git a/cusim/.gitignore b/cusim/.gitignore new file mode 100644 index 0000000..19f06b3 --- /dev/null +++ b/cusim/.gitignore @@ -0,0 +1 @@ +config_pb2.py diff --git a/cusim/__init__.py b/cusim/__init__.py new file mode 100644 index 0000000..24fe984 --- /dev/null +++ b/cusim/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2021 Jisang Yoon +# All rights reserved. +# +# This source code is licensed under the Apache 2.0 license found in the +# LICENSE file in the root directory of this source tree. 
+from cusim.ioutils import IoUtils +from cusim.culda import CuLDA diff --git a/cusim/aux.py b/cusim/aux.py new file mode 100644 index 0000000..4a1c2c5 --- /dev/null +++ b/cusim/aux.py @@ -0,0 +1,337 @@ +# Copyright (c) 2021 Jisang Yoon +# All rights reserved. +# +# This source code is licensed under the Apache 2.0 license found in the +# LICENSE file in the root directory of this source tree. +import os +import re +import sys +import json +import time +import logging +import logging.handlers +import numpy as np +import jsmin +from google.protobuf.json_format import Parse, MessageToDict + +# get_logger and Option refer to +# https://github.com/kakao/buffalo/blob/ +# 5f571c2c7d8227e6625c6e538da929e4db11b66d/buffalo/misc/aux.py +def get_logger(name=__file__, level=2): + if level == 1: + level = logging.WARNING + elif level == 2: + level = logging.INFO + elif level == 3: + level = logging.DEBUG + logger = logging.getLogger(name) + if logger.handlers: + return logger + logger.setLevel(level) + sh0 = logging.StreamHandler() + sh0.setLevel(level) + formatter = logging.Formatter('[%(levelname)-8s] %(asctime)s ' + '[%(filename)s] [%(funcName)s:%(lineno)d]' + '%(message)s', '%Y-%m-%d %H:%M:%S') + sh0.setFormatter(formatter) + logger.addHandler(sh0) + return logger + +# This function helps you to read non-standard json strings. +# - Handles json string with c++ style inline comments +# - Handles json string with trailing commas. +def load_json_string(cont): + # (1) Removes comment. + # Refer to https://plus.google.com/+DouglasCrockfordEsq/posts/RK8qyGVaGSr + cont = jsmin.jsmin(cont) + + # (2) Removes trailing comma. + cont = re.sub(",[ \t\r\n]*}", "}", cont) + cont = re.sub(",[ \t\r\n]*" + r"\]", "]", cont) + + return json.loads(cont) + + +# function read json file from filename +def load_json_file(fname): + with open(fname, "r") as fin: + ret = load_json_string(fin.read()) + return ret + +# use protobuf to restrict field and types +def get_opt_as_proto(raw, proto_type=None): + assert proto_type is not None + proto = proto_type() + # convert raw to proto + Parse(json.dumps(Option(raw)), proto) + err = [] + assert proto.IsInitialized(err), \ + f"some required fields are missing in proto {err}\n {proto}" + return proto + +def proto_to_dict(proto): + return MessageToDict(proto, \ + including_default_value_fields=True, \ + preserving_proto_field_name=True) + +def copy_proto(proto): + newproto = type(proto)() + Parse(json.dumps(proto_to_dict(proto)), newproto) + return newproto + +class Option(dict): + def __init__(self, *args, **kwargs): + args = [arg if isinstance(arg, dict) + else load_json_file(arg) for arg in args] + super().__init__(*args, **kwargs) + for arg in args: + if isinstance(arg, dict): + for k, val in arg.items(): + if isinstance(val, dict): + self[k] = Option(val) + else: + self[k] = val + if kwargs: + for k, val in kwargs.items(): + if isinstance(val, dict): + self[k] = Option(val) + else: + self[k] = val + + def __getattr__(self, attr): + return self.get(attr) + + def __setattr__(self, key, value): + self.__setitem__(key, value) + + def __setitem__(self, key, value): + super().__setitem__(key, value) + self.__dict__.update({key: value}) + + def __delattr__(self, item): + self.__delitem__(item) + + def __delitem__(self, key): + super().__delitem__(key) + del self.__dict__[key] + + def __getstate__(self): + return vars(self) + + def __setstate__(self, state): + vars(self).update(state) + +# reference: https://github.com/tensorflow/tensorflow/blob/ +# 
85c8b2a817f95a3e979ecd1ed95bff1dc1335cff/tensorflow/python/ +# keras/utils/generic_utils.py#L483 +class Progbar: + # pylint: disable=too-many-branches,too-many-statements,invalid-name + # pylint: disable=blacklisted-name,no-else-return + """Displays a progress bar. + Arguments: + target: Total number of steps expected, None if unknown. + width: Progress bar width on screen. + verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) + stateful_metrics: Iterable of string names of metrics that should *not* be + averaged over time. Metrics in this list will be displayed as-is. All + others will be averaged by the progbar before display. + interval: Minimum visual progress update interval (in seconds). + unit_name: Display name for step counts (usually "step" or "sample"). + """ + + def __init__(self, + target, + width=30, + verbose=1, + interval=0.05, + stateful_metrics=None, + unit_name='step'): + self.target = target + self.width = width + self.verbose = verbose + self.interval = interval + self.unit_name = unit_name + if stateful_metrics: + self.stateful_metrics = set(stateful_metrics) + else: + self.stateful_metrics = set() + + self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and + sys.stdout.isatty()) or + 'ipykernel' in sys.modules or + 'posix' in sys.modules or + 'PYCHARM_HOSTED' in os.environ) + self._total_width = 0 + self._seen_so_far = 0 + # We use a dict + list to avoid garbage collection + # issues found in OrderedDict + self._values = {} + self._values_order = [] + self._start = time.time() + self._last_update = 0 + + self._time_after_first_step = None + + def update(self, current, values=None, finalize=None): + """Updates the progress bar. + Arguments: + current: Index of current step. + values: List of tuples: `(name, value_for_last_step)`. If `name` is in + `stateful_metrics`, `value_for_last_step` will be displayed as-is. + Else, an average of the metric over time will be displayed. + finalize: Whether this is the last update for the progress bar. If + `None`, defaults to `current >= self.target`. + """ + if finalize is None: + if self.target is None: + finalize = False + else: + finalize = current >= self.target + + values = values or [] + for k, v in values: + if k not in self._values_order: + self._values_order.append(k) + if k not in self.stateful_metrics: + # In the case that progress bar doesn't have a target value in the first + # epoch, both on_batch_end and on_epoch_end will be called, which will + # cause 'current' and 'self._seen_so_far' to have the same value. Force + # the minimal value to 1 here, otherwise stateful_metric will be 0s. + value_base = max(current - self._seen_so_far, 1) + if k not in self._values: + self._values[k] = [v * value_base, value_base] + else: + self._values[k][0] += v * value_base + self._values[k][1] += value_base + else: + # Stateful metrics output a numeric value. This representation + # means "take an average from a single value" but keeps the + # numeric formatting. 
+        self._values[k] = [v, 1]
+    self._seen_so_far = current
+
+    now = time.time()
+    info = ' - %.0fs' % (now - self._start)
+    if self.verbose == 1:
+      if now - self._last_update < self.interval and not finalize:
+        return
+
+      prev_total_width = self._total_width
+      if self._dynamic_display:
+        sys.stdout.write('\b' * prev_total_width)
+        sys.stdout.write('\r')
+      else:
+        sys.stdout.write('\n')
+
+      if self.target is not None:
+        numdigits = int(np.log10(self.target)) + 1
+        bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target)
+        prog = float(current) / self.target
+        prog_width = int(self.width * prog)
+        if prog_width > 0:
+          bar += ('=' * (prog_width - 1))
+          if current < self.target:
+            bar += '>'
+          else:
+            bar += '='
+        bar += ('.' * (self.width - prog_width))
+        bar += ']'
+      else:
+        bar = '%7d/Unknown' % current
+
+      self._total_width = len(bar)
+      sys.stdout.write(bar)
+
+      time_per_unit = self._estimate_step_duration(current, now)
+
+      if self.target is None or finalize:
+        if time_per_unit >= 1 or time_per_unit == 0:
+          info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
+        elif time_per_unit >= 1e-3:
+          info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name)
+        else:
+          info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name)
+      else:
+        eta = time_per_unit * (self.target - current)
+        if eta > 3600:
+          eta_format = '%d:%02d:%02d' % (eta // 3600,
+                                         (eta % 3600) // 60, eta % 60)
+        elif eta > 60:
+          eta_format = '%d:%02d' % (eta // 60, eta % 60)
+        else:
+          eta_format = '%ds' % eta
+
+        info = ' - ETA: %s' % eta_format
+
+      for k in self._values_order:
+        info += ' - %s:' % k
+        if isinstance(self._values[k], list):
+          avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
+          if abs(avg) > 1e-3:
+            info += ' %.4f' % avg
+          else:
+            info += ' %.4e' % avg
+        else:
+          info += ' %s' % self._values[k]
+
+      self._total_width += len(info)
+      if prev_total_width > self._total_width:
+        info += (' ' * (prev_total_width - self._total_width))
+
+      if finalize:
+        info += '\n'
+
+      sys.stdout.write(info)
+      sys.stdout.flush()
+
+    elif self.verbose == 2:
+      if finalize:
+        numdigits = int(np.log10(self.target)) + 1
+        count = ('%' + str(numdigits) + 'd/%d') % (current, self.target)
+        info = count + info
+        for k in self._values_order:
+          info += ' - %s:' % k
+          avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
+          if avg > 1e-3:
+            info += ' %.4f' % avg
+          else:
+            info += ' %.4e' % avg
+        info += '\n'
+
+        sys.stdout.write(info)
+        sys.stdout.flush()
+
+    self._last_update = now
+
+  def add(self, n, values=None):
+    self.update(self._seen_so_far + n, values)
+
+  def _estimate_step_duration(self, current, now):
+    """Estimate the duration of a single step.
+    Given the step number `current` and the corresponding time `now`
+    this function returns an estimate for how long a single step
+    takes. If this is called before one step has been completed
+    (i.e. `current == 0`) then zero is given as an estimate. The duration
+    estimate ignores the duration of the (assumed to be non-representative)
+    first step for estimates when more steps are available (i.e. `current>1`).
+    Arguments:
+      current: Index of current step.
+      now: The current time.
+    Returns: Estimate of the duration of a single step.
+    """
+    if current:
+      # there are a few special scenarios here:
+      # 1) somebody is calling the progress bar without ever supplying step 1
+      # 2) somebody is calling the progress bar and supplies step one multiple
+      #    times, e.g. as part of a finalizing call
+      # in these cases, we just fall back to the simple calculation
+      if self._time_after_first_step is not None and current > 1:
+        time_per_unit = (now - self._time_after_first_step) / (current - 1)
+      else:
+        time_per_unit = (now - self._start) / current
+
+      if current == 1:
+        self._time_after_first_step = now
+      return time_per_unit
+    else:
+      return 0
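A short usage sketch for `Progbar` (the `update`/`add` API mirrors keras; metrics listed in `stateful_metrics` display as-is, everything else is a running average):

```python
import time
from cusim.aux import Progbar

pbar = Progbar(100, unit_name="doc", stateful_metrics=["vali_loss"])
for step in range(1, 101):
  time.sleep(0.01)  # stand-in for one training chunk
  pbar.update(step, values=[("train_loss", 1.0 / step), ("vali_loss", 0.5)])
```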
diff --git a/cusim/culda/__init__.py b/cusim/culda/__init__.py
new file mode 100644
index 0000000..e27fb6a
--- /dev/null
+++ b/cusim/culda/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) 2021 Jisang Yoon
+# All rights reserved.
+#
+# This source code is licensed under the Apache 2.0 license found in the
+# LICENSE file in the root directory of this source tree.
+from cusim.culda.pyculda import CuLDA
diff --git a/cusim/culda/bindings.cc b/cusim/culda/bindings.cc
new file mode 100644
index 0000000..f85b2d5
--- /dev/null
+++ b/cusim/culda/bindings.cc
@@ -0,0 +1,111 @@
+// Copyright (c) 2021 Jisang Yoon
+// All rights reserved.
+//
+// This source code is licensed under the Apache 2.0 license found in the
+// LICENSE file in the root directory of this source tree.
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+#include <pybind11/stl.h>
+
+#include <iostream>
+#include "culda/culda.hpp"
+
+namespace py = pybind11;
+
+typedef py::array_t<float, py::array::c_style | py::array::forcecast> float_array;
+typedef py::array_t<int, py::array::c_style | py::array::forcecast> int_array;
+typedef py::array_t<bool, py::array::c_style | py::array::forcecast> bool_array;
+
+class CuLDABind {
+ public:
+  CuLDABind() {}
+
+  bool Init(std::string opt_path) {
+    return obj_.Init(opt_path);
+  }
+
+  void LoadModel(py::object& alpha, py::object& beta,
+      py::object& grad_alpha, py::object& new_beta) {
+    // check shape of alpha and beta
+    float_array _alpha(alpha);
+    float_array _beta(beta);
+    auto alpha_buffer = _alpha.request();
+    auto beta_buffer = _beta.request();
+    if (alpha_buffer.ndim != 1 or beta_buffer.ndim != 2 or
+        alpha_buffer.shape[0] != beta_buffer.shape[1]) {
+      throw std::runtime_error("invalid alpha or beta");
+    }
+
+    // check shape of grad alpha and new beta
+    float_array _grad_alpha(grad_alpha);
+    float_array _new_beta(new_beta);
+    auto grad_alpha_buffer = _grad_alpha.request();
+    auto new_beta_buffer = _new_beta.request();
+    if (grad_alpha_buffer.ndim != 2 or
+        new_beta_buffer.ndim != 2 or
+        grad_alpha_buffer.shape[1] != new_beta_buffer.shape[1]) {
+      throw std::runtime_error("invalid grad_alpha or new_beta");
+    }
+
+    int num_words = beta_buffer.shape[0];
+
+    return obj_.LoadModel(_alpha.mutable_data(0),
+                          _beta.mutable_data(0),
+                          _grad_alpha.mutable_data(0),
+                          _new_beta.mutable_data(0), num_words);
+  }
+
+  std::pair<float, float> FeedData(py::object& cols, py::object& indptr,
+                                   py::object& vali, const int num_iters) {
+    int_array _cols(cols);
+    int_array _indptr(indptr);
+    bool_array _vali(vali);
+    auto cols_buffer = _cols.request();
+    auto indptr_buffer = _indptr.request();
+    auto vali_buffer = _vali.request();
+    if (cols_buffer.ndim != 1 or indptr_buffer.ndim != 1 or vali_buffer.ndim != 1
+        or cols_buffer.shape[0] != vali_buffer.shape[0]) {
+      throw std::runtime_error("invalid cols or indptr");
+    }
+    int num_cols = cols_buffer.shape[0];
+    int num_indptr = indptr_buffer.shape[0] - 1;
+    return obj_.FeedData(_cols.data(0), _indptr.data(0), _vali.data(0),
+                         num_cols, num_indptr, num_iters);
+  }
+
+  void Pull() {
+    obj_.Pull();
+  }
+
+  void Push() {
+    obj_.Push();
+  }
+
+  int GetBlockCnt() {
+    return obj_.GetBlockCnt();
+  }
+
+ private:
+  cusim::CuLDA obj_;
+};
+
+PYBIND11_PLUGIN(culda_bind) {
+  py::module m("CuLDABind");
+
+  py::class_<CuLDABind>(m, "CuLDABind")
+    .def(py::init())
+    .def("init", &CuLDABind::Init, py::arg("opt_path"))
+    .def("load_model", &CuLDABind::LoadModel,
+         py::arg("alpha"), py::arg("beta"),
+         py::arg("grad_alpha"), py::arg("new_beta"))
+    .def("feed_data", &CuLDABind::FeedData,
+         py::arg("cols"), py::arg("indptr"), py::arg("vali"),
+         py::arg("num_iters"))
+    .def("pull", &CuLDABind::Pull)
+    .def("push", &CuLDABind::Push)
+    .def("get_block_cnt", &CuLDABind::GetBlockCnt)
+    .def("__repr__",
+      [](const CuLDABind &a) {
+        return "<CuLDABind>";
+      });
+  return m.ptr();
+}
&CuLDABind::LoadModel, + py::arg("alpha"), py::arg("beta"), + py::arg("grad_alpha"), py::arg("new_beta")) + .def("feed_data", &CuLDABind::FeedData, + py::arg("cols"), py::arg("indptr"), py::arg("vali"), py::arg("num_iters")) + .def("pull", &CuLDABind::Pull) + .def("push", &CuLDABind::Push) + .def("get_block_cnt", &CuLDABind::GetBlockCnt) + .def("__repr__", + [](const CuLDABind &a) { + return ""; + } + ); + return m.ptr(); +} diff --git a/cusim/culda/pyculda.py b/cusim/culda/pyculda.py new file mode 100644 index 0000000..e0fa3d9 --- /dev/null +++ b/cusim/culda/pyculda.py @@ -0,0 +1,167 @@ +# Copyright (c) 2021 Jisang Yoon +# All rights reserved. +# +# This source code is licensed under the Apache 2.0 license found in the +# LICENSE file in the root directory of this source tree. + +# pylint: disable=no-name-in-module,too-few-public-methods,no-member +import os +from os.path import join as pjoin + +import json +import tempfile + +import h5py +import numpy as np +from scipy.special import polygamma as pg + +from cusim import aux, IoUtils +from cusim.culda.culda_bind import CuLDABind +from cusim.config_pb2 import CuLDAConfigProto + +EPS = 1e-10 + +class CuLDA: + def __init__(self, opt=None): + self.opt = aux.get_opt_as_proto(opt or {}, CuLDAConfigProto) + self.logger = aux.get_logger("culda", level=self.opt.py_log_level) + + tmp = tempfile.NamedTemporaryFile(mode='w', delete=False) + opt_content = json.dumps(aux.proto_to_dict(self.opt), indent=2) + tmp.write(opt_content) + tmp.close() + + self.logger.info("opt: %s", opt_content) + self.obj = CuLDABind() + assert self.obj.init(bytes(tmp.name, "utf8")), f"failed to load {tmp.name}" + os.remove(tmp.name) + + self.words, self.num_words, self.num_docs = None, None, None + self.alpha, self.beta, self.grad_alpha, self.new_beta = \ + None, None, None, None + + def preprocess_data(self): + if self.opt.skip_preprocess: + return + iou = IoUtils() + if not self.opt.processed_data_dir: + self.opt.processed_data_dir = tempfile.TemporaryDirectory().name + iou.convert_stream_to_h5(self.opt.data_path, self.opt.word_min_count, + self.opt.processed_data_dir) + + def init_model(self): + # load voca + data_dir = self.opt.processed_data_dir + self.logger.info("load key from %s", pjoin(data_dir, "keys.txt")) + with open(pjoin(data_dir, "keys.txt"), "rb") as fin: + self.words = [line.strip() for line in fin] + self.num_words = len(self.words) + + # count number of docs + h5f = h5py.File(pjoin(data_dir, "token.h5"), "r") + self.num_docs = h5f["indptr"].shape[0] - 1 + h5f.close() + + self.logger.info("number of words: %d, docs: %d", + self.num_words, self.num_docs) + + # random initialize alpha and beta + np.random.seed(self.opt.seed) + self.alpha = np.random.uniform( \ + size=(self.opt.num_topics,)).astype(np.float32) + self.beta = np.random.uniform( \ + size=(self.num_words, self.opt.num_topics)).astype(np.float32) + self.beta /= np.sum(self.beta, axis=0)[None, :] + self.logger.info("alpha %s, beta %s initialized", + self.alpha.shape, self.beta.shape) + + # zero initialize grad alpha and new beta + block_cnt = self.obj.get_block_cnt() + self.grad_alpha = np.zeros(shape=(block_cnt, self.opt.num_topics), + dtype=np.float32) + self.new_beta = np.zeros(shape=self.beta.shape, dtype=np.float32) + self.logger.info("grad alpha %s, new beta %s initialized", + self.grad_alpha.shape, self.new_beta.shape) + + # push it to gpu + self.obj.load_model(self.alpha, self.beta, self.grad_alpha, self.new_beta) + + def train_model(self): + self.preprocess_data() + self.init_model() + h5f 
diff --git a/cusim/ioutils/__init__.py b/cusim/ioutils/__init__.py
new file mode 100644
index 0000000..61fc1fe
--- /dev/null
+++ b/cusim/ioutils/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) 2021 Jisang Yoon
+# All rights reserved.
+#
+# This source code is licensed under the Apache 2.0 license found in the
+# LICENSE file in the root directory of this source tree.
+from cusim.ioutils.pyioutils import IoUtils
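For orientation, this is what the converted `token.h5` produced by `IoUtils.convert_stream_to_h5` (see `pyioutils.py` below) contains; the dataset names follow that file, the path is a placeholder:

```python
import h5py

with h5py.File("processed/token.h5", "r") as h5f:
  cols = h5f["cols"][:]      # word id of every token
  rows = h5f["rows"][:]      # document id of every token
  indptr = h5f["indptr"][:]  # CSR offsets: doc i owns cols[indptr[i]:indptr[i + 1]]
  vali = h5f["vali"][:]      # uniform(0, 1) draws used for train/vali splits
  num_docs = indptr.shape[0] - 1
```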
diff --git a/cusim/ioutils/bindings.cc b/cusim/ioutils/bindings.cc
new file mode 100644
index 0000000..28fbbc8
--- /dev/null
+++ b/cusim/ioutils/bindings.cc
@@ -0,0 +1,74 @@
+// Copyright (c) 2021 Jisang Yoon
+// All rights reserved.
+//
+// This source code is licensed under the Apache 2.0 license found in the
+// LICENSE file in the root directory of this source tree.
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+#include <pybind11/stl.h>
+
+#include <iostream>
+#include "utils/ioutils.hpp"
+
+namespace py = pybind11;
+
+typedef py::array_t<float, py::array::c_style | py::array::forcecast> float_array;
+typedef py::array_t<int, py::array::c_style | py::array::forcecast> int_array;
+
+class IoUtilsBind {
+ public:
+  IoUtilsBind() {}
+
+  bool Init(std::string opt_path) {
+    return obj_.Init(opt_path);
+  }
+
+  int LoadStreamFile(std::string filepath) {
+    return obj_.LoadStreamFile(filepath);
+  }
+
+  std::pair<int, int> ReadStreamForVocab(int num_lines, int num_threads) {
+    return obj_.ReadStreamForVocab(num_lines, num_threads);
+  }
+
+  std::pair<int, int> TokenizeStream(int num_lines, int num_threads) {
+    return obj_.TokenizeStream(num_lines, num_threads);
+  }
+
+  void GetWordVocab(int min_count, std::string keys_path) {
+    obj_.GetWordVocab(min_count, keys_path);
+  }
+
+  void GetToken(py::object& rows, py::object& cols, py::object& indptr) {
+    int_array _rows(rows);
+    int_array _cols(cols);
+    int_array _indptr(indptr);
+    obj_.GetToken(_rows.mutable_data(0), _cols.mutable_data(0), _indptr.mutable_data(0));
+  }
+
+ private:
+  cusim::IoUtils obj_;
+};
+
+PYBIND11_PLUGIN(ioutils_bind) {
+  py::module m("IoUtilsBind");
+
+  py::class_<IoUtilsBind>(m, "IoUtilsBind")
+    .def(py::init())
+    .def("init", &IoUtilsBind::Init, py::arg("opt_path"))
+    .def("load_stream_file", &IoUtilsBind::LoadStreamFile, py::arg("filepath"))
+    .def("read_stream_for_vocab", &IoUtilsBind::ReadStreamForVocab,
+         py::arg("num_lines"), py::arg("num_threads"))
+    .def("tokenize_stream", &IoUtilsBind::TokenizeStream,
+         py::arg("num_lines"), py::arg("num_threads"))
+    .def("get_word_vocab", &IoUtilsBind::GetWordVocab,
+         py::arg("min_count"), py::arg("keys_path"))
+    .def("get_token", &IoUtilsBind::GetToken,
+         py::arg("rows"), py::arg("cols"), py::arg("indptr"))
+    .def("__repr__",
+      [](const IoUtilsBind &a) {
+        return "<IoUtilsBind>";
+      });
+  return m.ptr();
+}
diff --git a/cusim/ioutils/pyioutils.py b/cusim/ioutils/pyioutils.py
new file mode 100644
index 0000000..5bce9b7
--- /dev/null
+++ b/cusim/ioutils/pyioutils.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2021 Jisang Yoon
+# All rights reserved.
+#
+# This source code is licensed under the Apache 2.0 license found in the
+# LICENSE file in the root directory of this source tree.
+ +# pylint: disable=no-name-in-module,too-few-public-methods,no-member +import os +from os.path import join as pjoin + +import json +import tempfile + +import h5py +import numpy as np + +from cusim import aux +from cusim.ioutils.ioutils_bind import IoUtilsBind +from cusim.config_pb2 import IoUtilsConfigProto + +class IoUtils: + def __init__(self, opt=None): + self.opt = aux.get_opt_as_proto(opt or {}, IoUtilsConfigProto) + self.logger = aux.get_logger("ioutils", level=self.opt.py_log_level) + + tmp = tempfile.NamedTemporaryFile(mode='w', delete=False) + opt_content = json.dumps(aux.proto_to_dict(self.opt), indent=2) + tmp.write(opt_content) + tmp.close() + + self.logger.info("opt: %s", opt_content) + self.obj = IoUtilsBind() + assert self.obj.init(bytes(tmp.name, "utf8")), f"failed to load {tmp.name}" + os.remove(tmp.name) + + def load_stream_vocab(self, filepath, min_count, keys_path): + full_num_lines = self.obj.load_stream_file(filepath) + pbar = aux.Progbar(full_num_lines, unit_name="line", + stateful_metrics=["word_count"]) + processed = 0 + while True: + read_lines, word_count = \ + self.obj.read_stream_for_vocab( + self.opt.chunk_lines, self.opt.num_threads) + processed += read_lines + pbar.update(processed, values=[("word_count", word_count)]) + if processed == full_num_lines: + break + self.obj.get_word_vocab(min_count, keys_path) + + def convert_stream_to_h5(self, filepath, min_count, out_dir, + chunk_indices=10000, seed=777): + np.random.seed(seed) + os.makedirs(out_dir, exist_ok=True) + keys_path = pjoin(out_dir, "keys.txt") + token_path = pjoin(out_dir, "token.h5") + self.logger.info("save key and token to %s, %s", + keys_path, token_path) + self.load_stream_vocab(filepath, min_count, keys_path) + full_num_lines = self.obj.load_stream_file(filepath) + pbar = aux.Progbar(full_num_lines, unit_name="line") + processed = 0 + h5f = h5py.File(token_path, "w") + rows = h5f.create_dataset("rows", shape=(chunk_indices,), + maxshape=(None,), dtype=np.int32, + chunks=(chunk_indices,)) + cols = h5f.create_dataset("cols", shape=(chunk_indices,), + maxshape=(None,), dtype=np.int32, + chunks=(chunk_indices,)) + vali = h5f.create_dataset("vali", shape=(chunk_indices,), + maxshape=(None,), dtype=np.float32, + chunks=(chunk_indices,)) + indptr = h5f.create_dataset("indptr", shape=(full_num_lines + 1,), + dtype=np.int32, chunks=True) + processed, offset = 1, 0 + indptr[0] = 0 + while True: + read_lines, data_size = self.obj.tokenize_stream( + self.opt.chunk_lines, self.opt.num_threads) + _rows = np.empty(shape=(data_size,), dtype=np.int32) + _cols = np.empty(shape=(data_size,), dtype=np.int32) + _indptr = np.empty(shape=(read_lines,), dtype=np.int32) + self.obj.get_token(_rows, _cols, _indptr) + rows.resize((offset + data_size,)) + rows[offset:offset + data_size] = _rows + (processed - 1) + cols.resize((offset + data_size,)) + cols[offset:offset + data_size] = _cols + vali.resize((offset + data_size,)) + vali[offset:offset + data_size] = \ + np.random.uniform(size=(data_size,)).astype(np.float32) + indptr[processed:processed + read_lines] = _indptr + offset + offset += data_size + processed += read_lines + pbar.update(processed - 1) + if processed == full_num_lines + 1: + break + h5f.close() diff --git a/cusim/proto/config.proto b/cusim/proto/config.proto new file mode 100644 index 0000000..10b3820 --- /dev/null +++ b/cusim/proto/config.proto @@ -0,0 +1,33 @@ +// Copyright (c) 2020 Jisang Yoon +// All rights reserved. 
diff --git a/cusim/proto/config.proto b/cusim/proto/config.proto
new file mode 100644
index 0000000..10b3820
--- /dev/null
+++ b/cusim/proto/config.proto
@@ -0,0 +1,33 @@
+// Copyright (c) 2020 Jisang Yoon
+// All rights reserved.
+//
+// This source code is licensed under the Apache 2.0 license found in the
+// LICENSE file in the root directory of this source tree.
+
+syntax = "proto2";
+
+message IoUtilsConfigProto {
+  optional int32 py_log_level = 1 [default = 2];
+  optional int32 c_log_level = 2 [default = 2];
+  optional int32 chunk_lines = 3 [default = 100000];
+  optional int32 num_threads = 4 [default = 4];
+}
+
+message CuLDAConfigProto {
+  required string data_path = 7;
+
+  optional int32 py_log_level = 1 [default = 2];
+  optional int32 c_log_level = 2 [default = 2];
+
+  optional int32 num_topics = 3 [default = 10];
+  optional int32 block_dim = 4 [default = 32];
+  optional int32 hyper_threads = 5 [default = 10];
+  optional string processed_data_dir = 6;
+  optional bool skip_preprocess = 8;
+  optional int32 word_min_count = 9 [default = 5];
+  optional int32 batch_size = 10 [default = 100000];
+  optional int32 epochs = 11 [default = 10];
+  optional int32 num_iters_in_e_step = 12 [default = 5];
+  optional double vali_p = 13 [default = 0.2];
+  optional int32 seed = 14 [default = 777];
+}
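A small illustration of how these proto2 defaults behave, assuming `CuLDAConfigProto` is generated into `cusim/config_pb2.py` alongside `IoUtilsConfigProto`; the values are illustrative:

```python
# Illustrative only: unset optional fields fall back to their schema defaults.
from cusim.config_pb2 import CuLDAConfigProto  # assumed generated module

opt = CuLDAConfigProto()
opt.data_path = "corpus.txt"   # required field, must be set before serializing
opt.num_topics = 50            # override the default of 10
print(opt.epochs, opt.vali_p)  # -> 10 0.2, straight from the defaults above
```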
diff --git a/examples/example1.py b/examples/example1.py
new file mode 100644
index 0000000..f9362cb
--- /dev/null
+++ b/examples/example1.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2021 Jisang Yoon
+# All rights reserved.
+#
+# This source code is licensed under the Apache 2.0 license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pylint: disable=no-name-in-module,logging-format-truncated
+import os
+import subprocess
+import fire
+
+import h5py
+import numpy as np
+from gensim import downloader as api
+from cusim import aux, IoUtils, CuLDA
+
+LOGGER = aux.get_logger()
+DOWNLOAD_PATH = "./res"
+# DATASET = "wiki-english-20171001"
+DATASET = "quora-duplicate-questions"
+DATA_PATH = f"./res/{DATASET}.stream.txt"
+LDA_PATH = f"./res/{DATASET}-lda.h5"
+PROCESSED_DATA_DIR = f"./res/{DATASET}-converted"
+MIN_COUNT = 5
+TOPK = 10
+
+def download():
+  if os.path.exists(DATA_PATH):
+    LOGGER.info("%s already exists", DATA_PATH)
+    return
+  api.BASE_DIR = DOWNLOAD_PATH
+  filepath = api.load(DATASET, return_path=True)
+  LOGGER.info("filepath: %s", filepath)
+  cmd = ["gunzip", "-c", filepath, ">", DATA_PATH]
+  cmd = " ".join(cmd)
+  LOGGER.info("cmd: %s", cmd)
+  subprocess.call(cmd, shell=True)
+
+def run_io():
+  download()
+  iou = IoUtils(opt={"chunk_lines": 10000, "num_threads": 8})
+  iou.convert_stream_to_h5(DATA_PATH, MIN_COUNT, PROCESSED_DATA_DIR)
+
+
+def run_lda():
+  opt = {
+    "data_path": DATA_PATH,
+    "processed_data_dir": PROCESSED_DATA_DIR,
+    "skip_preprocess": True,
+  }
+  lda = CuLDA(opt)
+  lda.train_model()
+  lda.save_model(LDA_PATH)
+  h5f = h5py.File(LDA_PATH, "r")
+  beta = h5f["beta"][:]
+  word_list = h5f["keys"][:]
+  num_topics = h5f["alpha"].shape[0]
+  for i in range(num_topics):
+    print("=" * 50)
+    print(f"topic {i + 1}")
+    words = np.argsort(-beta.T[i])[:TOPK]
+    print("-" * 50)
+    for j in range(TOPK):
+      word = word_list[words[j]].decode("utf8")
+      prob = beta[words[j], i]
+      print(f"rank {j + 1}. word: {word}, prob: {prob}")
+  h5f.close()
+
+if __name__ == "__main__":
+  fire.Fire()
diff --git a/examples/requirements.txt b/examples/requirements.txt
new file mode 100644
index 0000000..728d5c2
--- /dev/null
+++ b/examples/requirements.txt
@@ -0,0 +1,2 @@
+fire
+gensim
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..bfe001f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+jsmin
+numpy
+pandas
+pybind11
+protobuf==3.10.0
+grpcio-tools==1.27.1
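The example is dispatched through `fire.Fire()`, so each module-level function becomes a subcommand (e.g. `python examples/example1.py run_io`, then `run_lda`). A minimal sketch of the same dispatch pattern, with a made-up function:

```python
# fire_sketch.py -- minimal python-fire dispatch, mirroring example1.py
import fire

def greet(name="world"):
  """Invoked as: python fire_sketch.py greet --name=cusim"""
  return f"hello {name}"

if __name__ == "__main__":
  fire.Fire()  # exposes module-level callables as CLI commands
```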
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..512d262
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,203 @@
+# Copyright (c) 2021 Jisang Yoon
+# All rights reserved.
+#
+# This source code is licensed under the Apache 2.0 license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pylint: disable=fixme,too-few-public-methods
+# reference: https://github.com/kakao/buffalo/blob/
+# 5f571c2c7d8227e6625c6e538da929e4db11b66d/setup.py
+"""cusim
+"""
+import os
+import sys
+import glob
+import pathlib
+import platform
+import sysconfig
+import subprocess
+from setuptools import setup, Extension
+
+import pybind11
+import numpy as np
+from cuda_setup import CUDA, BUILDEXT
+
+
+DOCLINES = __doc__.split("\n")
+
+# TODO: Python3 Support
+if sys.version_info[:3] < (3, 6):
+  raise RuntimeError("Python version 3.6 or later required.")
+
+assert platform.system() == 'Linux'  # TODO: MacOS
+with open("requirements.txt", "r") as fin:
+  INSTALL_REQUIRES = [line.strip() for line in fin]
+
+MAJOR = 0
+MINOR = 0
+MICRO = 0
+RELEASE = True
+STAGE = {True: '', False: 'b'}.get(RELEASE)
+VERSION = f'{MAJOR}.{MINOR}.{MICRO}{STAGE}'
+STATUS = {False: 'Development Status :: 4 - Beta',
+          True: 'Development Status :: 5 - Production/Stable'}
+
+CLASSIFIERS = """{status}
+Programming Language :: C++
+Programming Language :: Python :: 3.6
+Operating System :: POSIX :: Linux
+Operating System :: Unix
+Operating System :: MacOS
+License :: OSI Approved :: Apache Software License""".format( \
+  status=STATUS.get(RELEASE))
+CLIB_DIR = os.path.join(sysconfig.get_path('purelib'), 'cusim')
+LIBRARY_DIRS = [CLIB_DIR]
+
+
+def get_extend_compile_flags():
+  flags = ['-march=native']
+  return flags
+
+
+class CMakeExtension(Extension):
+  extension_type = 'cmake'
+
+  def __init__(self, name):
+    super().__init__(name, sources=[])
+
+
+extend_compile_flags = get_extend_compile_flags()
+extra_compile_args = ['-fopenmp', '-std=c++14', '-ggdb', '-O3'] + \
+  extend_compile_flags
+util_srcs = glob.glob("cpp/src/utils/*.cc")
+extensions = [
+  Extension("cusim.ioutils.ioutils_bind",
+            sources=util_srcs + [ \
+              "cusim/ioutils/bindings.cc",
+              "3rd/json11/json11.cpp"],
+            language="c++",
+            extra_compile_args=extra_compile_args,
+            extra_link_args=["-fopenmp"],
+            extra_objects=[],
+            include_dirs=[ \
+              "cpp/include/", np.get_include(), pybind11.get_include(),
+              pybind11.get_include(True),
+              "3rd/json11", "3rd/spdlog/include"]),
+  Extension("cusim.culda.culda_bind",
+            sources=util_srcs + [ \
+              "cpp/src/culda/culda.cu",
+              "cusim/culda/bindings.cc",
+              "3rd/json11/json11.cpp"],
+            language="c++",
+            extra_compile_args=extra_compile_args,
+            extra_link_args=["-fopenmp"],
+            library_dirs=[CUDA['lib64']],
+            libraries=['cudart', 'cublas', 'curand'],
+            extra_objects=[],
+            include_dirs=[ \
+              "cpp/include/", np.get_include(), pybind11.get_include(),
+              pybind11.get_include(True), CUDA['include'],
+              "3rd/json11", "3rd/spdlog/include"]),
+]
+
+
+# Return the git revision as a string
+def git_version():
+  def _minimal_ext_cmd(cmd):
+    # construct minimal environment
+    env = {}
+    for k in ['SYSTEMROOT', 'PATH']:
+      val = os.environ.get(k)
+      if val is not None:
+        env[k] = val
+    out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env). \
+      communicate()[0]
+    return out
+
+  try:
+    out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
+    git_revision = out.strip().decode('ascii')
+  except OSError:
+    git_revision = "Unknown"
+
+  return git_revision
+
+
+def write_version_py(filename='cusim/version.py'):
+  cnt = """
+short_version = '%(version)s'
+git_revision = '%(git_revision)s'
+"""
+  git_revision = git_version()
+  with open(filename, 'w') as fout:
+    fout.write(cnt % {'version': VERSION,
+                      'git_revision': git_revision})
+
+
+class BuildExtension(BUILDEXT):
+  def run(self):
+    for ext in self.extensions:
+      print(ext.name)
+      if hasattr(ext, 'extension_type') and ext.extension_type == 'cmake':
+        self.cmake()
+    super().run()
+
+  def cmake(self):
+    cwd = pathlib.Path().absolute()
+
+    build_temp = pathlib.Path(self.build_temp)
+    build_temp.mkdir(parents=True, exist_ok=True)
+
+    build_type = 'Debug' if self.debug else 'Release'
+
+    cmake_args = [
+      '-DCMAKE_BUILD_TYPE=' + build_type,
+      '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + CLIB_DIR,
+    ]
+
+    build_args = []
+
+    os.chdir(str(build_temp))
+    self.spawn(['cmake', str(cwd)] + cmake_args)
+    if not self.dry_run:
+      self.spawn(['cmake', '--build', '.'] + build_args)
+    os.chdir(str(cwd))
+
+
+def setup_package():
+  write_version_py()
+  cmdclass = {
+    'build_ext': BuildExtension
+  }
+
+  metadata = dict(
+    name='cusim',
+    maintainer="Jisang Yoon",
+    maintainer_email="vjs10101v@gmail.com",
+    author="Jisang Yoon",
+    author_email="vjs10101v@gmail.com",
+    description=DOCLINES[0],
+    long_description="\n".join(DOCLINES[2:]),
+    url="https://github.com/js1010/cusim",
+    download_url="https://github.com/js1010/cusim/releases",
+    include_package_data=False,
+    license='Apache2',
+    packages=['cusim', 'cusim.ioutils', 'cusim.culda'],
+    install_requires=INSTALL_REQUIRES,
+    cmdclass=cmdclass,
+    classifiers=[_f for _f in CLASSIFIERS.split('\n') if _f],
+    platforms=['Linux', 'Mac OSX', 'Unix'],
+    ext_modules=extensions,
+    entry_points={
+      'console_scripts': [
+      ]
+    },
+    python_requires='>=3.6',
+  )
+
+  metadata['version'] = VERSION
+  setup(**metadata)
+
+
+if __name__ == '__main__':
+  setup_package()
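For reference, `write_version_py()` would leave behind a `cusim/version.py` along these lines (the hash is a placeholder; with `RELEASE = True` the stage suffix is empty, giving `0.0.0`):

```python
# cusim/version.py -- generated at build time by write_version_py()
short_version = '0.0.0'
git_revision = '0000000000000000000000000000000000000000'  # placeholder; real value is `git rev-parse HEAD`
```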