diff --git a/cpp/include/culda.hpp b/cpp/include/culda.hpp deleted file mode 100644 index 0c2ba61..0000000 --- a/cpp/include/culda.hpp +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2021 Jisang Yoon -// All rights reserved. -// -// This source code is licensed under the Apache 2.0 license found in the -// LICENSE file in the root directory of this source tree. -#pragma once -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include // NOLINT - -#include "json11.hpp" -#include "log.hpp" -#include "types.hpp" - -namespace cusim { - -class CuLDA { - public: - CuLDA(); - ~CuLDA(); - private: - std::shared_ptr logger_; - thrust::device_vector device_data_; -}; - -} // namespace cusim diff --git a/cpp/include/culda/cuda_lda_kernels.cuh b/cpp/include/culda/cuda_lda_kernels.cuh new file mode 100644 index 0000000..02dbb37 --- /dev/null +++ b/cpp/include/culda/cuda_lda_kernels.cuh @@ -0,0 +1,121 @@ +// Copyright (c) 2021 Jisang Yoon +// All rights reserved. +// +// This source code is licensed under the Apache 2.0 license found in the +// LICENSE file in the root directory of this source tree. +#pragma once +#include "utils/cuda_utils_kernels.cuh" + + +namespace cusim { + +// reference: http://web.science.mq.edu.au/~mjohnson/code/digamma.c +__inline__ __device__ +float Digamma(float x) { + float result = 0.0f, xx, xx2, xx4; + for ( ; x < 7.0f; ++x) + result -= 1.0f / x; + x -= 0.5f; + xx = 1.0f / x; + xx2 = xx * xx; + xx4 = xx2 * xx2; + result += logf(x) + 1.0f / 24.0f * xx2 + - 7.0f / 960.0f * xx4 + 31.0f / 8064.0f * xx4 * xx2 + - 127.0f / 30720.0f * xx4 * xx4; + return result; +} + +__global__ void EstepKernel( + const int* cols, const int* indptr, const bool* vali, + const int num_cols, const int num_indptr, + const int num_topics, const int num_iters, + float* gamma, float* new_gamma, float* phi, + const float* alpha, const float* beta, + float* grad_alpha, float* new_beta, + float* train_losses, float* vali_losses, int* mutex) { + + // storage for block + float* _gamma = gamma + num_topics * blockIdx.x; + float* _new_gamma = new_gamma + num_topics * blockIdx.x; + float* _phi = phi + num_topics * blockIdx.x; + float* _grad_alpha = grad_alpha + num_topics * blockIdx.x; + + for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) { + int beg = indptr[i], end = indptr[i + 1]; + // initialize gamma + for (int j = threadIdx.x; j < num_topics; j += blockDim.x) + _gamma[j] = alpha[j] + (end - beg) / num_topics; + __syncthreads(); + + // iterate E step + for (int j = 0; j < num_iters; ++j) { + // initialize new gamma + for (int k = threadIdx.x; k < num_topics; k += blockDim.x) + _new_gamma[k] = 0.0f; + __syncthreads(); + + // compute phi from gamma + for (int k = beg; k < end; ++k) { + const int w = cols[k]; + const bool _vali = vali[k]; + + // compute phi + if (not _vali or j + 1 == num_iters) { + for (int l = threadIdx.x; l < num_topics; l += blockDim.x) + _phi[l] = beta[w * num_topics + l] * expf(Digamma(_gamma[l])); + __syncthreads(); + + // normalize phi and add it to new gamma and new beta + float phi_sum = ReduceSum(_phi, num_topics); + + for (int l = threadIdx.x; l < num_topics; l += blockDim.x) { + _phi[l] /= phi_sum; + if (not _vali) _new_gamma[l] += _phi[l]; + } + __syncthreads(); + } + + if (j + 1 == num_iters) { + // write access of w th vector of new_beta + if (threadIdx.x == 0) { + while (atomicCAS(&mutex[w], 0, 1)) {} + } + + __syncthreads(); + for (int l = threadIdx.x; l < num_topics; l += blockDim.x) { + if (j + 1 == num_iters) { + if (not _vali) new_beta[w * num_topics + l] += _phi[l]; + _phi[l] *= beta[w * num_topics + l]; + } + } + __syncthreads(); + + // release lock + if (threadIdx.x == 0) mutex[w] = 0; + __syncthreads(); + + float p = fmaxf(EPS, ReduceSum(_phi, num_topics)); + if (threadIdx.x == 0) { + if (_vali) + vali_losses[blockIdx.x] += logf(p); + else + train_losses[blockIdx.x] += logf(p); + } + } + __syncthreads(); + } + + // update gamma + for (int k = threadIdx.x; k < num_topics; k += blockDim.x) + _gamma[k] = _new_gamma[k] + alpha[k]; + __syncthreads(); + } + float gamma_sum = ReduceSum(_gamma, num_topics); + for (int j = threadIdx.x; j < num_topics; j += blockDim.x) + _grad_alpha[j] += (Digamma(_gamma[j]) - Digamma(gamma_sum)); + + __syncthreads(); + } +} + +} // cusim diff --git a/cpp/include/culda/culda.hpp b/cpp/include/culda/culda.hpp new file mode 100644 index 0000000..3a126cc --- /dev/null +++ b/cpp/include/culda/culda.hpp @@ -0,0 +1,88 @@ +// Copyright (c) 2021 Jisang Yoon +// All rights reserved. +// +// This source code is licensed under the Apache 2.0 license found in the +// LICENSE file in the root directory of this source tree. +#pragma once +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT + +#include "json11.hpp" +#include "utils/log.hpp" +#include "utils/types.hpp" + +namespace cusim { + + +// reference: https://people.math.sc.edu/Burkardt/cpp_src/asa121/asa121.cpp +inline float Trigamma(float x) { + const float a = 0.0001f; + const float b = 5.0f; + const float b2 = 0.1666666667f; + const float b4 = -0.03333333333f; + const float b6 = 0.02380952381f; + const float b8 = -0.03333333333f; + float value = 0, y = 0, z = x; + if (x <= a) return 1.0f / x / x; + while (z < b) { + value += 1.0f / z / z; + z++; + } + y = 1.0f / z / z; + value += value + 0.5 * y + (1.0 + + y * (b2 + + y * (b4 + + y * (b6 + + y * b8)))) / z; + return value; +} + + +class CuLDA { + public: + CuLDA(); + ~CuLDA(); + bool Init(std::string opt_path); + void LoadModel(float* alpha, float* beta, + float* grad_alpha, float* new_beta, const int num_words); + std::pair FeedData( + const int* indices, const int* indptr, const bool* vali, + const int num_indices, const int num_indptr, const int num_iters); + void Pull(); + void Push(); + int GetBlockCnt(); + + private: + DeviceInfo dev_info_; + json11::Json opt_; + std::shared_ptr logger_; + thrust::device_vector dev_alpha_, dev_beta_; + thrust::device_vector dev_grad_alpha_, dev_new_beta_; + thrust::device_vector dev_gamma_, dev_new_gamma_, dev_phi_; + thrust::device_vector dev_mutex_; + + float *alpha_, *beta_, *grad_alpha_, *new_beta_; + int block_cnt_, block_dim_; + int num_topics_, num_words_; +}; + +} // namespace cusim diff --git a/cpp/include/types.hpp b/cpp/include/types.hpp deleted file mode 100644 index 3af46c7..0000000 --- a/cpp/include/types.hpp +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) 2020 Jisang Yoon -// All rights reserved. -// -// This source code is licensed under the Apache 2.0 license found in the -// LICENSE file in the root directory of this source tree. -#pragma once -#include - -// experimental codes to use half precision -// not properly working yet.. -// #define HALF_PRECISION 1 - -// #if __CUDA_ARCH__ < 530 -// #undef HALF_PRECISION -// #endif - -#ifdef HALF_PRECISION - typedef half cuda_scalar; - #define mul(x, y) ( __hmul(x, y) ) - #define add(x, y) ( __hadd(x, y) ) - #define sub(x, y) ( __hsub(x, y) ) - #define gt(x, y) ( __hgt(x, y) ) // x > y - #define ge(x, y) ( __hge(x, y) ) // x >= y - #define lt(x, y) ( __hlt(x, y) ) // x < y - #define le(x, y) ( __hle(x, y) ) // x <= y - #define out_scalar(x) ( __half2float(x) ) - #define conversion(x) ( __float2half(x) ) -#else - typedef float cuda_scalar; - #define mul(x, y) ( x * y ) - #define add(x, y) ( x + y ) - #define sub(x, y) ( x - y ) - #define gt(x, y) ( x > y ) - #define ge(x, y) ( x >= y ) - #define lt(x, y) ( x < y ) - #define le(x, y) ( x <= y ) - #define out_scalar(x) ( x ) - #define conversion(x) ( x ) -#endif - -#define WARP_SIZE 32 diff --git a/cpp/include/utils/cuda_utils_kernels.cuh b/cpp/include/utils/cuda_utils_kernels.cuh new file mode 100644 index 0000000..026da7f --- /dev/null +++ b/cpp/include/utils/cuda_utils_kernels.cuh @@ -0,0 +1,172 @@ +// Copyright (c) 2021 Jisang Yoon +// All rights reserved. +// +// This source code is licensed under the Apache 2.0 license found in the +// LICENSE file in the root directory of this source tree. +#pragma once +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include "utils/types.hpp" + +namespace cusim { + +// Error Checking utilities, checks status codes from cuda calls +// and throws exceptions on failure (which cython can proxy back to python) +#define CHECK_CUDA(code) { checkCuda((code), __FILE__, __LINE__); } +inline void checkCuda(cudaError_t code, const char *file, int line) { + if (code != cudaSuccess) { + std::stringstream err; + err << "Cuda Error: " << cudaGetErrorString(code) << " (" << file << ":" << line << ")"; + throw std::runtime_error(err.str()); + } +} + +inline const char* cublasGetErrorString(cublasStatus_t status) { + switch (status) { + case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; + } + return "Unknown"; +} + +#define CHECK_CUBLAS(code) { checkCublas((code), __FILE__, __LINE__); } +inline void checkCublas(cublasStatus_t code, const char * file, int line) { + if (code != CUBLAS_STATUS_SUCCESS) { + std::stringstream err; + err << "cublas error: " << cublasGetErrorString(code) + << " (" << file << ":" << line << ")"; + throw std::runtime_error(err.str()); + } +} + +inline DeviceInfo GetDeviceInfo() { + DeviceInfo ret; + CHECK_CUDA(cudaGetDevice(&ret.devId)); + cudaDeviceProp prop; + CHECK_CUDA(cudaGetDeviceProperties(&prop, ret.devId)); + ret.mp_cnt = prop.multiProcessorCount; + ret.major = prop.major; + ret.minor = prop.minor; + // reference: https://stackoverflow.com/a/32531982 + switch (ret.major) { + case 2: // Fermi + if (ret.minor == 1) + ret.cores = ret.mp_cnt * 48; + else + ret.cores = ret.mp_cnt * 32; + break; + case 3: // Kepler + ret.cores = ret.mp_cnt * 192; + break; + case 5: // Maxwell + ret.cores = ret.mp_cnt * 128; + break; + case 6: // Pascal + if (ret.minor == 1 or ret.minor == 2) + ret.cores = ret.mp_cnt * 128; + else if (ret.minor == 0) + ret.cores = ret.mp_cnt * 64; + else + ret.unknown = true; + break; + case 7: // Volta and Turing + if (ret.minor == 0 or ret.minor == 5) + ret.cores = ret.mp_cnt * 64; + else + ret.unknown = true; + break; + case 8: // Ampere + if (ret.minor == 0) + ret.cores = ret.mp_cnt * 64; + else if (ret.minor == 6) + ret.cores = ret.mp_cnt * 128; + else + ret.unknown = true; + break; + default: + ret.unknown = true; + break; + } + if (ret.cores == -1) ret.cores = ret.mp_cnt * 128; + return ret; +} + +__inline__ __device__ +float warp_reduce_sum(float val) { + #if __CUDACC_VER_MAJOR__ >= 9 + // __shfl_down is deprecated with cuda 9+. use newer variants + unsigned int active = __activemask(); + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + val += __shfl_down_sync(active, val, offset); + } + #else + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + val += __shfl_down(val, offset); + } + #endif + return val; +} + +__inline__ __device__ +float ReduceSum(const float* vec, const int length) { + + static __shared__ float shared[32]; + + // figure out the warp/ position inside the warp + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // paritial sum + float val = 0.0f; + for (int i = threadIdx.x; i < length; i += blockDim.x) + val += vec[i]; + val = warp_reduce_sum(val); + + // write out the partial reduction to shared memory if appropiate + if (lane == 0) { + shared[warp] = val; + } + __syncthreads(); + + // if we we don't have multiple warps, we're done + if (blockDim.x <= WARP_SIZE) { + return shared[0]; + } + + // otherwise reduce again in the first warp + val = (threadIdx.x < blockDim.x / WARP_SIZE) ? shared[lane]: 0.0f; + if (warp == 0) { + val = warp_reduce_sum(val); + // broadcast back to shared memory + if (threadIdx.x == 0) { + shared[0] = val; + } + } + __syncthreads(); + return shared[0]; +} + +} // namespace cusim diff --git a/cpp/include/ioutils.hpp b/cpp/include/utils/ioutils.hpp similarity index 90% rename from cpp/include/ioutils.hpp rename to cpp/include/utils/ioutils.hpp index 3e8304d..756b4b2 100644 --- a/cpp/include/ioutils.hpp +++ b/cpp/include/utils/ioutils.hpp @@ -21,8 +21,7 @@ #include #include "json11.hpp" -#include "log.hpp" -#include "types.hpp" +#include "utils/log.hpp" namespace cusim { @@ -35,12 +34,12 @@ class IoUtils { std::pair ReadStreamForVocab(int num_lines, int num_threads); std::pair TokenizeStream(int num_lines, int num_threads); void GetWordVocab(int min_count, std::string keys_path); - void GetToken(int* indices, int* indptr, int offset); + void GetToken(int* rows, int* cols, int* indptr); private: void ParseLine(std::string line, std::vector& line_vec); void ParseLineImpl(std::string line, std::vector& line_vec); - std::vector> indices_; + std::vector> cols_; std::vector indptr_; std::mutex global_lock_; std::ifstream stream_fin_; diff --git a/cpp/include/log.hpp b/cpp/include/utils/log.hpp similarity index 100% rename from cpp/include/log.hpp rename to cpp/include/utils/log.hpp diff --git a/cpp/include/utils/types.hpp b/cpp/include/utils/types.hpp new file mode 100644 index 0000000..87e864b --- /dev/null +++ b/cpp/include/utils/types.hpp @@ -0,0 +1,14 @@ +// Copyright (c) 2021 Jisang Yoon +// All rights reserved. +// +// This source code is licensed under the Apache 2.0 license found in the +// LICENSE file in the root directory of this source tree. +#pragma once + +struct DeviceInfo { + int devId, mp_cnt, major, minor, cores; + bool unknown = false; +}; + +#define WARP_SIZE 32 +#define EPS 1e-10f diff --git a/cpp/src/culda.cu b/cpp/src/culda.cu deleted file mode 100644 index 0c12182..0000000 --- a/cpp/src/culda.cu +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) 2020 Jisang Yoon -// All rights reserved. -// -// This source code is licensed under the Apache 2.0 license found in the -// LICENSE file in the root directory of this source tree. -#include "culda.hpp" - -namespace cusim { - -CuLDA::CuLDA() { - logger_ = CuSimLogger().get_logger(); -} - -CuLDA::~CuLDA() {} - -} // namespace cusim diff --git a/cpp/src/culda/culda.cu b/cpp/src/culda/culda.cu new file mode 100644 index 0000000..9217902 --- /dev/null +++ b/cpp/src/culda/culda.cu @@ -0,0 +1,135 @@ +// Copyright (c) 2021 Jisang Yoon +// All rights reserved. +// +// This source code is licensed under the Apache 2.0 license found in the +// LICENSE file in the root directory of this source tree. +#include "culda/culda.hpp" +#include "culda/cuda_lda_kernels.cuh" + +namespace cusim { + +CuLDA::CuLDA() { + logger_ = CuSimLogger().get_logger(); + dev_info_ = GetDeviceInfo(); + if (dev_info_.unknown) DEBUG0("Unknown device type"); + INFO("cuda device info, major: {}, minor: {}, multi processors: {}, cores: {}", + dev_info_.major, dev_info_.minor, dev_info_.mp_cnt, dev_info_.cores); +} + +CuLDA::~CuLDA() {} + +bool CuLDA::Init(std::string opt_path) { + std::ifstream in(opt_path.c_str()); + if (not in.is_open()) return false; + + std::string str((std::istreambuf_iterator(in)), + std::istreambuf_iterator()); + std::string err_cmt; + auto _opt = json11::Json::parse(str, err_cmt); + if (not err_cmt.empty()) return false; + opt_ = _opt; + CuSimLogger().set_log_level(opt_["c_log_level"].int_value()); + num_topics_ = opt_["num_topics"].int_value(); + block_dim_ = opt_["block_dim"].int_value(); + block_cnt_ = opt_["hyper_threads"].number_value() * (dev_info_.cores / block_dim_); + INFO("num_topics: {}, block_dim: {}, block_cnt: {}", num_topics_, block_dim_, block_cnt_); + return true; +} + +void CuLDA::LoadModel(float* alpha, float* beta, + float* grad_alpha, float* new_beta, int num_words) { + num_words_ = num_words; + DEBUG("copy model({} x {})", num_words_, num_topics_); + dev_alpha_.resize(num_topics_); + dev_beta_.resize(num_topics_ * num_words_); + thrust::copy(alpha, alpha + num_topics_, dev_alpha_.begin()); + thrust::copy(beta, beta + num_topics_ * num_words_, dev_beta_.begin()); + alpha_ = alpha; beta_ = beta; + + // resize device vector + grad_alpha_ = grad_alpha; + new_beta_ = new_beta; + dev_grad_alpha_.resize(num_topics_ * block_cnt_); + dev_new_beta_.resize(num_topics_ * num_words_); + // copy to device + thrust::copy(grad_alpha_, grad_alpha_ + block_cnt_ * num_topics_, dev_grad_alpha_.begin()); + thrust::copy(new_beta_, new_beta_ + num_words_ * num_topics_, dev_new_beta_.begin()); + dev_gamma_.resize(num_topics_ * block_cnt_); + dev_new_gamma_.resize(num_topics_ * block_cnt_); + dev_phi_.resize(num_topics_ * block_cnt_); + + // set mutex + dev_mutex_.resize(num_words_); + std::vector host_mutex(num_words_, 0); + thrust::copy(host_mutex.begin(), host_mutex.end(), dev_mutex_.begin()); + + CHECK_CUDA(cudaDeviceSynchronize()); +} + +std::pair CuLDA::FeedData( + const int* cols, const int* indptr, const bool* vali, + const int num_cols, const int num_indptr, const int num_iters) { + + // copy feed data to GPU memory + thrust::device_vector dev_cols(num_cols); + thrust::device_vector dev_indptr(num_indptr + 1); + thrust::device_vector dev_vali(num_cols); + thrust::device_vector dev_train_losses(block_cnt_, 0.0f); + thrust::device_vector dev_vali_losses(block_cnt_, 0.0f); + thrust::copy(cols, cols + num_cols, dev_cols.begin()); + thrust::copy(indptr, indptr + num_indptr + 1, dev_indptr.begin()); + thrust::copy(vali, vali + num_cols, dev_vali.begin()); + CHECK_CUDA(cudaDeviceSynchronize()); + DEBUG0("copy feed data to GPU memory"); + + // run E step in GPU + EstepKernel<<>>( + thrust::raw_pointer_cast(dev_cols.data()), + thrust::raw_pointer_cast(dev_indptr.data()), + thrust::raw_pointer_cast(dev_vali.data()), + num_cols, num_indptr, num_topics_, num_iters, + thrust::raw_pointer_cast(dev_gamma_.data()), + thrust::raw_pointer_cast(dev_new_gamma_.data()), + thrust::raw_pointer_cast(dev_phi_.data()), + thrust::raw_pointer_cast(dev_alpha_.data()), + thrust::raw_pointer_cast(dev_beta_.data()), + thrust::raw_pointer_cast(dev_grad_alpha_.data()), + thrust::raw_pointer_cast(dev_new_beta_.data()), + thrust::raw_pointer_cast(dev_train_losses.data()), + thrust::raw_pointer_cast(dev_vali_losses.data()), + thrust::raw_pointer_cast(dev_mutex_.data())); + CHECK_CUDA(cudaDeviceSynchronize()); + DEBUG0("run E step in GPU"); + + // pull loss + std::vector train_losses(block_cnt_), vali_losses(block_cnt_); + thrust::copy(dev_train_losses.begin(), dev_train_losses.end(), train_losses.begin()); + thrust::copy(dev_vali_losses.begin(), dev_vali_losses.end(), vali_losses.begin()); + CHECK_CUDA(cudaDeviceSynchronize()); + DEBUG0("pull loss values"); + + // accumulate + float train_loss = std::accumulate(train_losses.begin(), train_losses.end(), 0.0f); + float vali_loss = std::accumulate(vali_losses.begin(), vali_losses.end(), 0.0f); + return {train_loss, vali_loss}; +} + +void CuLDA::Pull() { + thrust::copy(dev_grad_alpha_.begin(), dev_grad_alpha_.end(), grad_alpha_); + thrust::copy(dev_new_beta_.begin(), dev_new_beta_.end(), new_beta_); + CHECK_CUDA(cudaDeviceSynchronize()); +} + +void CuLDA::Push() { + thrust::copy(alpha_, alpha_ + num_topics_, dev_alpha_.begin()); + thrust::copy(grad_alpha_, grad_alpha_ + block_cnt_ * num_topics_, dev_grad_alpha_.begin()); + thrust::copy(beta_, beta_ + num_words_ * num_topics_, dev_beta_.begin()); + thrust::copy(new_beta_, new_beta_ + num_words_ * num_topics_, dev_new_beta_.begin()); + CHECK_CUDA(cudaDeviceSynchronize()); +} + +int CuLDA::GetBlockCnt() { + return block_cnt_; +} + +} // namespace cusim diff --git a/cpp/src/ioutils.cc b/cpp/src/utils/ioutils.cc similarity index 87% rename from cpp/src/ioutils.cc rename to cpp/src/utils/ioutils.cc index 45d551b..14bd94e 100644 --- a/cpp/src/ioutils.cc +++ b/cpp/src/utils/ioutils.cc @@ -3,7 +3,7 @@ // // This source code is licensed under the Apache 2.0 license found in the // LICENSE file in the root directory of this source tree. -#include "ioutils.hpp" +#include "utils/ioutils.hpp" namespace cusim { @@ -37,10 +37,10 @@ void IoUtils::ParseLineImpl(std::string line, std::vector& ret) { int n = line.size(); std::string element; for (int i = 0; i < n; ++i) { - if (line[i] == ' ' or line[i] == ',') { + if (line[i] == ' ') { ret.push_back(element); element.clear(); - } else if (line[i] != '"') { + } else { element += std::tolower(line[i]); } } @@ -69,8 +69,8 @@ std::pair IoUtils::TokenizeStream(int num_lines, int num_threads) { int read_lines = std::min(num_lines, remain_lines_); if (not read_lines) return {0, 0}; remain_lines_ -= read_lines; - indices_.clear(); - indices_.resize(read_lines); + cols_.clear(); + cols_.resize(read_lines); indptr_.resize(read_lines); std::fill(indptr_.begin(), indptr_.end(), 0); #pragma omp parallel num_threads(num_threads) @@ -90,28 +90,29 @@ std::pair IoUtils::TokenizeStream(int num_lines, int num_threads) { // tokenize for (auto& word: line_vec) { - if (not word_count_.count(word)) continue; - indices_[i].push_back(word_count_[word]); + if (not word_idmap_.count(word)) continue; + cols_[i].push_back(word_idmap_[word]); } } } int cumsum = 0; for (int i = 0; i < read_lines; ++i) { - cumsum += indices_[i].size(); + cumsum += cols_[i].size(); indptr_[i] = cumsum; } return {read_lines, indptr_[read_lines - 1]}; } -void IoUtils::GetToken(int* indices, int* indptr, int offset) { - int n = indices_.size(); +void IoUtils::GetToken(int* rows, int* cols, int* indptr) { + int n = cols_.size(); for (int i = 0; i < n; ++i) { int beg = i == 0? 0: indptr_[i - 1]; int end = indptr_[i]; for (int j = beg; j < end; ++j) { - indices[j] = indices_[i][j - beg]; + rows[j] = i; + cols[j] = cols_[i][j - beg]; } - indptr[i] = offset + indptr_[i]; + indptr[i] = indptr_[i]; } } @@ -154,6 +155,7 @@ std::pair IoUtils::ReadStreamForVocab(int num_lines, int num_threads) void IoUtils::GetWordVocab(int min_count, std::string keys_path) { INFO("number of raw words: {}", word_count_.size()); + word_idmap_.clear(); word_list_.clear(); for (auto& it: word_count_) { if (it.second >= min_count) { word_idmap_[it.first] = word_idmap_.size(); @@ -165,11 +167,9 @@ void IoUtils::GetWordVocab(int min_count, std::string keys_path) { // write keys to csv file std::ofstream fout(keys_path.c_str()); INFO("dump keys to {}", keys_path); - std::string header = "index,key\n"; - fout.write(header.c_str(), header.size()); int n = word_list_.size(); for (int i = 0; i < n; ++i) { - std::string line = std::to_string(i) + ",\"" + word_list_[i] + "\"\n"; + std::string line = word_list_[i] + "\n"; fout.write(line.c_str(), line.size()); } fout.close(); diff --git a/cpp/src/log.cc b/cpp/src/utils/log.cc similarity index 97% rename from cpp/src/log.cc rename to cpp/src/utils/log.cc index ef5252b..ddfcb0c 100644 --- a/cpp/src/log.cc +++ b/cpp/src/utils/log.cc @@ -5,7 +5,7 @@ // LICENSE file in the root directory of this source tree. // reference: https://github.com/kakao/buffalo/blob/5f571c2c7d8227e6625c6e538da929e4db11b66d/lib/misc/log.cc -#include "log.hpp" +#include "utils/log.hpp" namespace cusim { diff --git a/cusim/__init__.py b/cusim/__init__.py index 796d7b2..24fe984 100644 --- a/cusim/__init__.py +++ b/cusim/__init__.py @@ -4,3 +4,4 @@ # This source code is licensed under the Apache 2.0 license found in the # LICENSE file in the root directory of this source tree. from cusim.ioutils import IoUtils +from cusim.culda import CuLDA diff --git a/cusim/culda/__init__.py b/cusim/culda/__init__.py new file mode 100644 index 0000000..e27fb6a --- /dev/null +++ b/cusim/culda/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2021 Jisang Yoon +# All rights reserved. +# +# This source code is licensed under the Apache 2.0 license found in the +# LICENSE file in the root directory of this source tree. +from cusim.culda.pyculda import CuLDA diff --git a/cusim/culda/bindings.cc b/cusim/culda/bindings.cc new file mode 100644 index 0000000..f85b2d5 --- /dev/null +++ b/cusim/culda/bindings.cc @@ -0,0 +1,111 @@ +// Copyright (c) 2021 Jisang Yoon +// All rights reserved. +// +// This source code is licensed under the Apache 2.0 license found in the +// LICENSE file in the root directory of this source tree. +#include +#include +#include + +#include +#include "culda/culda.hpp" + +namespace py = pybind11; + +typedef py::array_t float_array; +typedef py::array_t int_array; +typedef py::array_t bool_array; + +class CuLDABind { + public: + CuLDABind() {} + + bool Init(std::string opt_path) { + return obj_.Init(opt_path); + } + + void LoadModel(py::object& alpha, py::object& beta, + py::object& grad_alpha, py::object& new_beta) { + // check shape of alpha and beta + float_array _alpha(alpha); + float_array _beta(beta); + auto alpha_buffer = _alpha.request(); + auto beta_buffer = _beta.request(); + if (alpha_buffer.ndim != 1 or beta_buffer.ndim != 2 or + alpha_buffer.shape[0] != beta_buffer.shape[1]) { + throw std::runtime_error("invalid alpha or beta"); + } + + // check shape of grad alpha and new beta + float_array _grad_alpha(grad_alpha); + float_array _new_beta(new_beta); + auto grad_alpha_buffer = _grad_alpha.request(); + auto new_beta_buffer = _new_beta.request(); + if (grad_alpha_buffer.ndim != 2 or + new_beta_buffer.ndim != 2 or + grad_alpha_buffer.shape[1] != new_beta_buffer.shape[1]) { + throw std::runtime_error("invalid grad_alpha or new_beta"); + } + + int num_words = beta_buffer.shape[0]; + + return obj_.LoadModel(_alpha.mutable_data(0), + _beta.mutable_data(0), + _grad_alpha.mutable_data(0), + _new_beta.mutable_data(0), num_words); + } + + std::pair FeedData(py::object& cols, py::object& indptr, py::object& vali, const int num_iters) { + int_array _cols(cols); + int_array _indptr(indptr); + bool_array _vali(vali); + auto cols_buffer = _cols.request(); + auto indptr_buffer = _indptr.request(); + auto vali_buffer = _vali.request(); + if (cols_buffer.ndim != 1 or indptr_buffer.ndim != 1 or vali_buffer.ndim != 1 + or cols_buffer.shape[0] != vali_buffer.shape[0]) { + throw std::runtime_error("invalid cols or indptr"); + } + int num_cols = cols_buffer.shape[0]; + int num_indptr = indptr_buffer.shape[0] - 1; + return obj_.FeedData(_cols.data(0), _indptr.data(0), _vali.data(0), + num_cols, num_indptr, num_iters); + } + + void Pull() { + obj_.Pull(); + } + + void Push() { + obj_.Push(); + } + + int GetBlockCnt() { + return obj_.GetBlockCnt(); + } + + private: + cusim::CuLDA obj_; +}; + +PYBIND11_PLUGIN(culda_bind) { + py::module m("CuLDABind"); + + py::class_(m, "CuLDABind") + .def(py::init()) + .def("init", &CuLDABind::Init, py::arg("opt_path")) + .def("load_model", &CuLDABind::LoadModel, + py::arg("alpha"), py::arg("beta"), + py::arg("grad_alpha"), py::arg("new_beta")) + .def("feed_data", &CuLDABind::FeedData, + py::arg("cols"), py::arg("indptr"), py::arg("vali"), py::arg("num_iters")) + .def("pull", &CuLDABind::Pull) + .def("push", &CuLDABind::Push) + .def("get_block_cnt", &CuLDABind::GetBlockCnt) + .def("__repr__", + [](const CuLDABind &a) { + return ""; + } + ); + return m.ptr(); +} diff --git a/cusim/culda/pyculda.py b/cusim/culda/pyculda.py new file mode 100644 index 0000000..e0fa3d9 --- /dev/null +++ b/cusim/culda/pyculda.py @@ -0,0 +1,167 @@ +# Copyright (c) 2021 Jisang Yoon +# All rights reserved. +# +# This source code is licensed under the Apache 2.0 license found in the +# LICENSE file in the root directory of this source tree. + +# pylint: disable=no-name-in-module,too-few-public-methods,no-member +import os +from os.path import join as pjoin + +import json +import tempfile + +import h5py +import numpy as np +from scipy.special import polygamma as pg + +from cusim import aux, IoUtils +from cusim.culda.culda_bind import CuLDABind +from cusim.config_pb2 import CuLDAConfigProto + +EPS = 1e-10 + +class CuLDA: + def __init__(self, opt=None): + self.opt = aux.get_opt_as_proto(opt or {}, CuLDAConfigProto) + self.logger = aux.get_logger("culda", level=self.opt.py_log_level) + + tmp = tempfile.NamedTemporaryFile(mode='w', delete=False) + opt_content = json.dumps(aux.proto_to_dict(self.opt), indent=2) + tmp.write(opt_content) + tmp.close() + + self.logger.info("opt: %s", opt_content) + self.obj = CuLDABind() + assert self.obj.init(bytes(tmp.name, "utf8")), f"failed to load {tmp.name}" + os.remove(tmp.name) + + self.words, self.num_words, self.num_docs = None, None, None + self.alpha, self.beta, self.grad_alpha, self.new_beta = \ + None, None, None, None + + def preprocess_data(self): + if self.opt.skip_preprocess: + return + iou = IoUtils() + if not self.opt.processed_data_dir: + self.opt.processed_data_dir = tempfile.TemporaryDirectory().name + iou.convert_stream_to_h5(self.opt.data_path, self.opt.word_min_count, + self.opt.processed_data_dir) + + def init_model(self): + # load voca + data_dir = self.opt.processed_data_dir + self.logger.info("load key from %s", pjoin(data_dir, "keys.txt")) + with open(pjoin(data_dir, "keys.txt"), "rb") as fin: + self.words = [line.strip() for line in fin] + self.num_words = len(self.words) + + # count number of docs + h5f = h5py.File(pjoin(data_dir, "token.h5"), "r") + self.num_docs = h5f["indptr"].shape[0] - 1 + h5f.close() + + self.logger.info("number of words: %d, docs: %d", + self.num_words, self.num_docs) + + # random initialize alpha and beta + np.random.seed(self.opt.seed) + self.alpha = np.random.uniform( \ + size=(self.opt.num_topics,)).astype(np.float32) + self.beta = np.random.uniform( \ + size=(self.num_words, self.opt.num_topics)).astype(np.float32) + self.beta /= np.sum(self.beta, axis=0)[None, :] + self.logger.info("alpha %s, beta %s initialized", + self.alpha.shape, self.beta.shape) + + # zero initialize grad alpha and new beta + block_cnt = self.obj.get_block_cnt() + self.grad_alpha = np.zeros(shape=(block_cnt, self.opt.num_topics), + dtype=np.float32) + self.new_beta = np.zeros(shape=self.beta.shape, dtype=np.float32) + self.logger.info("grad alpha %s, new beta %s initialized", + self.grad_alpha.shape, self.new_beta.shape) + + # push it to gpu + self.obj.load_model(self.alpha, self.beta, self.grad_alpha, self.new_beta) + + def train_model(self): + self.preprocess_data() + self.init_model() + h5f = h5py.File(pjoin(self.opt.processed_data_dir, "token.h5"), "r") + for epoch in range(1, self.opt.epochs + 1): + self.logger.info("Epoch %d / %d", epoch, self.opt.epochs) + self._train_e_step(h5f) + self._train_m_step() + h5f.close() + + def _train_e_step(self, h5f): + offset, size = 0, h5f["cols"].shape[0] + pbar = aux.Progbar(size, stateful_metrics=["train_loss", "vali_loss"]) + train_loss_nume, train_loss_deno = 0, 0 + vali_loss_nume, vali_loss_deno = 0, 0 + while True: + target = h5f["indptr"][offset] + self.opt.batch_size + if target < size: + next_offset = h5f["rows"][target] + else: + next_offset = h5f["indptr"].shape[0] - 1 + indptr = h5f["indptr"][offset:next_offset + 1] + beg, end = indptr[0], indptr[-1] + indptr -= beg + cols = h5f["cols"][beg:end] + vali = (h5f["vali"][beg:end] < self.opt.vali_p).astype(np.bool) + offset = next_offset + + # call cuda kernel + train_loss, vali_loss = \ + self.obj.feed_data(cols, indptr, vali, self.opt.num_iters_in_e_step) + + # accumulate loss + train_loss_nume -= train_loss + vali_loss_nume -= vali_loss + vali_cnt = np.count_nonzero(vali) + train_cnt = len(vali) - vali_cnt + train_loss_deno += train_cnt + vali_loss_deno += vali_cnt + train_loss = train_loss_nume / (train_loss_deno + EPS) + vali_loss = vali_loss_nume / (vali_loss_deno + EPS) + + # update progress bar + pbar.update(end, values=[("train_loss", train_loss), + ("vali_loss", vali_loss)]) + if end == size: + break + + def _train_m_step(self): + self.obj.pull() + + # update beta + self.new_beta[:, :] = np.maximum(self.new_beta, EPS) + self.beta[:, :] = self.new_beta / np.sum(self.new_beta, axis=0)[None, :] + self.new_beta[:, :] = 0 + + # update alpha + alpha_sum = np.sum(self.alpha) + gvec = np.sum(self.grad_alpha, axis=0) + gvec += self.num_docs * (pg(0, alpha_sum) - pg(0, self.alpha)) + hvec = self.num_docs * pg(1, self.alpha) + z_0 = pg(1, alpha_sum) + c_nume = np.sum(gvec / hvec) + c_deno = 1 / z_0 + np.sum(1 / hvec) + c_0 = c_nume / c_deno + delta = (gvec - c_0) / hvec + self.alpha -= delta + self.alpha[:] = np.maximum(self.alpha, EPS) + self.grad_alpha[:,:] = 0 + + self.obj.push() + + def save_model(self, model_path): + self.logger.info("save model path: %s", model_path) + h5f = h5py.File(model_path, "w") + h5f.create_dataset("alpha", data=self.alpha) + h5f.create_dataset("beta", data=self.beta) + h5f.create_dataset("keys", data=np.array(self.words)) + h5f.close() diff --git a/cusim/ioutils/bindings.cc b/cusim/ioutils/bindings.cc index 5b2c1dd..28fbbc8 100644 --- a/cusim/ioutils/bindings.cc +++ b/cusim/ioutils/bindings.cc @@ -8,7 +8,7 @@ #include #include -#include "ioutils.hpp" +#include "utils/ioutils.hpp" namespace py = pybind11; @@ -39,10 +39,11 @@ class IoUtilsBind { obj_.GetWordVocab(min_count, keys_path); } - void GetToken(py::object& indices, py::object& indptr, int offset) { - int_array _indices(indices); + void GetToken(py::object& rows, py::object& cols, py::object& indptr) { + int_array _rows(rows); + int_array _cols(cols); int_array _indptr(indptr); - obj_.GetToken(_indices.mutable_data(0), _indptr.mutable_data(0), offset); + obj_.GetToken(_rows.mutable_data(0), _cols.mutable_data(0), _indptr.mutable_data(0)); } private: diff --git a/cusim/ioutils/pyioutils.py b/cusim/ioutils/pyioutils.py index 1a65f74..5bce9b7 100644 --- a/cusim/ioutils/pyioutils.py +++ b/cusim/ioutils/pyioutils.py @@ -49,9 +49,10 @@ def load_stream_vocab(self, filepath, min_count, keys_path): self.obj.get_word_vocab(min_count, keys_path) def convert_stream_to_h5(self, filepath, min_count, out_dir, - chunk_indices=10000): + chunk_indices=10000, seed=777): + np.random.seed(seed) os.makedirs(out_dir, exist_ok=True) - keys_path = pjoin(out_dir, "keys.csv") + keys_path = pjoin(out_dir, "keys.txt") token_path = pjoin(out_dir, "token.h5") self.logger.info("save key and token to %s, %s", keys_path, token_path) @@ -60,9 +61,15 @@ def convert_stream_to_h5(self, filepath, min_count, out_dir, pbar = aux.Progbar(full_num_lines, unit_name="line") processed = 0 h5f = h5py.File(token_path, "w") - indices = h5f.create_dataset("indices", shape=(chunk_indices,), - maxshape=(None,), dtype=np.int32, - chunks=(chunk_indices,)) + rows = h5f.create_dataset("rows", shape=(chunk_indices,), + maxshape=(None,), dtype=np.int32, + chunks=(chunk_indices,)) + cols = h5f.create_dataset("cols", shape=(chunk_indices,), + maxshape=(None,), dtype=np.int32, + chunks=(chunk_indices,)) + vali = h5f.create_dataset("vali", shape=(chunk_indices,), + maxshape=(None,), dtype=np.float32, + chunks=(chunk_indices,)) indptr = h5f.create_dataset("indptr", shape=(full_num_lines + 1,), dtype=np.int32, chunks=True) processed, offset = 1, 0 @@ -70,12 +77,18 @@ def convert_stream_to_h5(self, filepath, min_count, out_dir, while True: read_lines, data_size = self.obj.tokenize_stream( self.opt.chunk_lines, self.opt.num_threads) - _indices = np.empty(shape=(data_size,), dtype=np.int32) + _rows = np.empty(shape=(data_size,), dtype=np.int32) + _cols = np.empty(shape=(data_size,), dtype=np.int32) _indptr = np.empty(shape=(read_lines,), dtype=np.int32) - self.obj.get_token(_indices, _indptr, offset) - indices.resize((offset + data_size,)) - indices[offset:offset + data_size] = _indices - indptr[processed:processed + read_lines] = _indptr + self.obj.get_token(_rows, _cols, _indptr) + rows.resize((offset + data_size,)) + rows[offset:offset + data_size] = _rows + (processed - 1) + cols.resize((offset + data_size,)) + cols[offset:offset + data_size] = _cols + vali.resize((offset + data_size,)) + vali[offset:offset + data_size] = \ + np.random.uniform(size=(data_size,)).astype(np.float32) + indptr[processed:processed + read_lines] = _indptr + offset offset += data_size processed += read_lines pbar.update(processed - 1) diff --git a/cusim/proto/config.proto b/cusim/proto/config.proto index 071184b..10b3820 100644 --- a/cusim/proto/config.proto +++ b/cusim/proto/config.proto @@ -12,3 +12,22 @@ message IoUtilsConfigProto { optional int32 chunk_lines = 3 [default = 100000]; optional int32 num_threads = 4 [default = 4]; } + +message CuLDAConfigProto { + required string data_path = 7; + + optional int32 py_log_level = 1 [default = 2]; + optional int32 c_log_level = 2 [default = 2]; + + optional int32 num_topics = 3 [default = 10]; + optional int32 block_dim = 4 [default = 32]; + optional int32 hyper_threads = 5 [default = 10]; + optional string processed_data_dir = 6; + optional bool skip_preprocess = 8; + optional int32 word_min_count = 9 [default = 5]; + optional int32 batch_size = 10 [default = 100000]; + optional int32 epochs = 11 [default = 10]; + optional int32 num_iters_in_e_step = 12 [default = 5]; + optional double vali_p = 13 [default = 0.2]; + optional int32 seed = 14 [default = 777]; +} diff --git a/examples/example1.py b/examples/example1.py index 6cbdaa9..f9362cb 100644 --- a/examples/example1.py +++ b/examples/example1.py @@ -9,16 +9,20 @@ import subprocess import fire +import h5py +import numpy as np from gensim import downloader as api -from cusim import aux, IoUtils +from cusim import aux, IoUtils, CuLDA LOGGER = aux.get_logger() DOWNLOAD_PATH = "./res" # DATASET = "wiki-english-20171001" -DATASET = "fake-news" +DATASET = "quora-duplicate-questions" DATA_PATH = f"./res/{DATASET}.stream.txt" -DATA_PATH2 = f"./res/{DATASET}-converted" +LDA_PATH = f"./res/{DATASET}-lda.h5" +PROCESSED_DATA_DIR = f"./res/{DATASET}-converted" MIN_COUNT = 5 +TOPK = 10 def download(): if os.path.exists(DATA_PATH): @@ -32,11 +36,35 @@ def download(): LOGGER.info("cmd: %s", cmd) subprocess.call(cmd, shell=True) -def run(): +def run_io(): download() iou = IoUtils(opt={"chunk_lines": 10000, "num_threads": 8}) - iou.convert_stream_to_h5(DATA_PATH, 5, DATA_PATH2) + iou.convert_stream_to_h5(DATA_PATH, 5, PROCESSED_DATA_DIR) +def run_lda(): + opt = { + "data_path": DATA_PATH, + "processed_data_dir": PROCESSED_DATA_DIR, + "skip_preprocess":True, + } + lda = CuLDA(opt) + lda.train_model() + lda.save_model(LDA_PATH) + h5f = h5py.File(LDA_PATH, "r") + beta = h5f["beta"][:] + word_list = h5f["keys"][:] + num_topics = h5f["alpha"].shape[0] + for i in range(num_topics): + print("=" * 50) + print(f"topic {i + 1}") + words = np.argsort(-beta.T[i])[:10] + print("-" * 50) + for j in range(TOPK): + word = word_list[words[j]].decode("utf8") + prob = beta[words[j], i] + print(f"rank {j + 1}. word: {word}, prob: {prob}") + h5f.close() + if __name__ == "__main__": fire.Fire() diff --git a/setup.py b/setup.py index dacb5f3..512d262 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 Jisang Yoon +# Copyright (c) 2021 Jisang Yoon # All rights reserved. # # This source code is licensed under the Apache 2.0 license found in the @@ -30,7 +30,8 @@ raise RuntimeError("Python version 3.6 or later required.") assert platform.system() == 'Linux' # TODO: MacOS - +with open("requirements.txt", "r") as fin: + INSTALL_REQUIRES = [line.strip() for line in fin] MAJOR = 0 MINOR = 0 @@ -68,22 +69,35 @@ def __init__(self, name): extend_compile_flags = get_extend_compile_flags() extra_compile_args = ['-fopenmp', '-std=c++14', '-ggdb', '-O3'] + \ extend_compile_flags -csrcs = glob.glob("cpp/src/*.cu") + glob.glob("cpp/src/*.cc") +util_srcs = glob.glob("cpp/src/utils/*.cc") extensions = [ Extension("cusim.ioutils.ioutils_bind", - sources= csrcs + [ \ + sources = util_srcs + [ \ "cusim/ioutils/bindings.cc", "3rd/json11/json11.cpp"], language="c++", extra_compile_args=extra_compile_args, extra_link_args=["-fopenmp"], + extra_objects=[], + include_dirs=[ \ + "cpp/include/", np.get_include(), pybind11.get_include(), + pybind11.get_include(True), + "3rd/json11", "3rd/spdlog/include"]), + Extension("cusim.culda.culda_bind", + sources= util_srcs + [ \ + "cpp/src/culda/culda.cu", + "cusim/culda/bindings.cc", + "3rd/json11/json11.cpp"], + language="c++", + extra_compile_args=extra_compile_args, + extra_link_args=["-fopenmp"], library_dirs=[CUDA['lib64']], libraries=['cudart', 'cublas', 'curand'], extra_objects=[], include_dirs=[ \ "cpp/include/", np.get_include(), pybind11.get_include(), pybind11.get_include(True), CUDA['include'], - "3rd/json11", "3rd/spdlog/include"]) + "3rd/json11", "3rd/spdlog/include"]), ] @@ -160,13 +174,16 @@ def setup_package(): name='cusim', maintainer="Jisang Yoon", maintainer_email="vjs10101v@gmail.com", + author="Jisang Yoon", + author_email="vjs10101v@gmail.com", description=DOCLINES[0], long_description="\n".join(DOCLINES[2:]), url="https://github.com/js1010/cusim", download_url="https://github.com/js1010/cusim/releases", include_package_data=False, - license='Apac2', - packages=['cusim/', "cusim/ioutils/"], + license='Apache2', + packages=['cusim/', "cusim/ioutils/", "cusim/culda/"], + install_requires=INSTALL_REQUIRES, cmdclass=cmdclass, classifiers=[_f for _f in CLASSIFIERS.split('\n') if _f], platforms=['Linux', 'Mac OSX', 'Unix'],