diff --git a/cpp/include/culda/culda.hpp b/cpp/include/culda/culda.hpp
index 3a126cc..ff52dca 100644
--- a/cpp/include/culda/culda.hpp
+++ b/cpp/include/culda/culda.hpp
@@ -75,6 +75,7 @@ class CuLDA {
   DeviceInfo dev_info_;
   json11::Json opt_;
   std::shared_ptr<spdlog::logger> logger_;
+  std::unique_ptr<CuSimLogger> logger_container_;
   thrust::device_vector<float> dev_alpha_, dev_beta_;
   thrust::device_vector<float> dev_grad_alpha_, dev_new_beta_;
   thrust::device_vector<float> dev_gamma_, dev_new_gamma_, dev_phi_;
diff --git a/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh
new file mode 100644
index 0000000..8046cf3
--- /dev/null
+++ b/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh
@@ -0,0 +1,50 @@
+// Copyright (c) 2021 Jisang Yoon
+// All rights reserved.
+//
+// This source code is licensed under the Apache 2.0 license found in the
+// LICENSE file in the root directory of this source tree.
+#pragma once
+#include "utils/cuda_utils_kernels.cuh"
+
+namespace cusim {
+
+
+__inline__ __device__
+void PositiveFeedback(const float* vec1, float* vec2, float* grad,
+                      float& loss_nume, float& loss_deno, const int num_dims, const float lr) {
+  static __shared__ float g;
+  float dot = Dot(vec1, vec2, num_dims);
+  if (threadIdx.x == 0) {
+    float exp_dot = expf(-dot);
+    g = exp_dot / (1 + exp_dot) * lr;
+    loss_nume += logf(1 + exp_dot);
+    loss_deno++;
+  }
+  __syncthreads();
+  for (int i = threadIdx.x; i < num_dims; i += blockDim.x) {
+    grad[i] += vec2[i] * g;
+    vec2[i] += vec1[i] * g;
+  }
+  __syncthreads();
+}
+
+__inline__ __device__
+void NegativeFeedback(const float* vec1, float* vec2, float* grad,
+                      float& loss_nume, float& loss_deno, const int num_dims, const float lr) {
+  static __shared__ float g;
+  float dot = Dot(vec1, vec2, num_dims);
+  if (threadIdx.x == 0) {
+    float exp_dot = expf(dot);
+    g = exp_dot / (1 + exp_dot) * lr;
+    loss_nume += logf(1 + exp_dot);
+    loss_deno++;
+  }
+  __syncthreads();
+  for (int i = threadIdx.x; i < num_dims; i += blockDim.x) {
+    grad[i] -= vec2[i] * g;
+    vec2[i] -= vec1[i] * g;
+  }
+  __syncthreads();
+}
+
+}  // namespace cusim
diff --git a/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh
new file mode 100644
index 0000000..c7aaca5
--- /dev/null
+++ b/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh
@@ -0,0 +1,149 @@
+// Copyright (c) 2021 Jisang Yoon
+// All rights reserved.
+//
+// This source code is licensed under the Apache 2.0 license found in the
+// LICENSE file in the root directory of this source tree.
+#pragma once
+#include "utils/cuda_utils_kernels.cuh"
+#include "cuw2v/cuda_w2v_base_kernels.cuh"
+
+
+namespace cusim {
+
+__global__ void W2VHsSgKernel(
+  const int* cols, const int* indptr,
+  const bool* codes, const int* points, const int* hs_indptr,
+  const int num_indptr, const int num_dims, const int window_size,
+  default_random_engine* rngs,
+  float* emb_in, float* emb_out,
+  float* loss_nume, float* loss_deno, const float lr) {
+
+  default_random_engine& rng = rngs[blockIdx.x];
+  float& _loss_nume = loss_nume[blockIdx.x];
+  float& _loss_deno = loss_deno[blockIdx.x];
+
+  uniform_int_distribution<int> dist_window(0, window_size - 1);
+  static __shared__ int reduced_windows;
+  extern __shared__ float shared_memory[];
+  float* grad = &shared_memory[0];
+
+  // zero-initialize shared mem
+  for (int i = threadIdx.x; i < num_dims; i += blockDim.x)
+    grad[i] = 0.0f;
+  __syncthreads();
+
+  for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) {
+    int beg = indptr[i], end = indptr[i + 1];
+    for (int j = beg; j < end; ++j) {
+      if (threadIdx.x == 0) reduced_windows = dist_window(rng);
+      __syncthreads();
+      int beg2 = max(beg, j - window_size + reduced_windows);
+      int end2 = min(end, j + window_size - reduced_windows + 1);
+      float* _emb_in = emb_in + num_dims * cols[j];
+      for (int k = beg2; k < end2; ++k) {
+        if (k == j) continue;
+        int beg3 = hs_indptr[cols[k]];
+        int end3 = hs_indptr[cols[k] + 1];
+        for (int l = beg3; l < end3; ++l) {
+          if (codes[l]) {
+            PositiveFeedback(_emb_in, emb_out + num_dims * points[l],
+                             grad, _loss_nume, _loss_deno, num_dims, lr);
+          } else {
+            NegativeFeedback(_emb_in, emb_out + num_dims * points[l],
+                             grad, _loss_nume, _loss_deno, num_dims, lr);
+          }
+          __syncthreads();
+        }
+        for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
+          emb_in[num_dims * cols[j] + l] += grad[l];
+          grad[l] = 0.0f;
+        }
+        __syncthreads();
+      }
+    }
+  }
+}
+
+__global__ void W2VHsCbowKernel(
+  const int* cols, const int* indptr,
+  const bool* codes, const int* points, const int* hs_indptr,
+  const int num_indptr, const int num_dims, const int window_size, default_random_engine* rngs,
+  float* emb_in, float* emb_out,
+  float* loss_nume, float* loss_deno,
+  const bool use_mean, const float lr) {
+
+  default_random_engine& rng = rngs[blockIdx.x];
+  float& _loss_nume = loss_nume[blockIdx.x];
+  float& _loss_deno = loss_deno[blockIdx.x];
+
+  uniform_int_distribution<int> dist_window(0, window_size - 1);
+  static __shared__ int reduced_windows;
+  extern __shared__ float shared_memory[];
+  float* grad = &shared_memory[0];
+  float* cbow = &shared_memory[num_dims];
+
+  __syncthreads();
+
+  for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) {
+    int beg = indptr[i], end = indptr[i + 1];
+    for (int j = beg; j < end; ++j) {
+      if (threadIdx.x == 0) reduced_windows = dist_window(rng);
+      __syncthreads();
+      int beg2 = max(beg, j - window_size + reduced_windows);
+      int end2 = min(end, j + window_size - reduced_windows + 1);
+      if (end2 - beg2 <= 1) continue;
+
+      // zero-initialize shared mem
+      for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
+        grad[k] = 0.0f;
+        cbow[k] = 0.0f;
+      }
+
+      // compute cbow
+      for (int k = beg2; k < end2; ++k) {
+        if (k == j) continue;
+        for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
+          cbow[l] += emb_in[num_dims * cols[k] + l];
+        }
+      }
+      if (use_mean) {
+        for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
+          cbow[k] /= (end2 - beg2 - 1);
+        }
+      }
+      __syncthreads();
+
+      int beg3 = hs_indptr[cols[j]];
+      int end3 = hs_indptr[cols[j] + 1];
+      for (int k = beg3; k < end3; ++k) {
+        if (codes[k]) {
+          PositiveFeedback(cbow, emb_out + num_dims * points[k],
+                           grad, _loss_nume, _loss_deno, num_dims, lr);
+        } else {
+          NegativeFeedback(cbow, emb_out + num_dims * points[k],
+                           grad, _loss_nume, _loss_deno, num_dims, lr);
+        }
+        __syncthreads();
+      }
+
+      // normalize grad if use_mean = true
+      if (use_mean) {
+        for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
+          grad[k] /= (end2 - beg2 - 1);
+        }
+      }
+      __syncthreads();
+
+      // update emb_in
+      for (int k = beg2; k < end2; ++k) {
+        if (k == j) continue;
+        for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
+          emb_in[num_dims * cols[k] + l] += grad[l];
+        }
+        __syncthreads();
+      }
+    }
+  }
+}
+
+}  // namespace cusim
diff --git a/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh
new file mode 100644
index 0000000..8f6bef0
--- /dev/null
+++ b/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh
@@ -0,0 +1,147 @@
+// Copyright (c) 2021 Jisang Yoon
+// All rights reserved.
+//
+// This source code is licensed under the Apache 2.0 license found in the
+// LICENSE file in the root directory of this source tree.
+#pragma once
+#include "utils/cuda_utils_kernels.cuh"
+#include "cuw2v/cuda_w2v_base_kernels.cuh"
+
+
+namespace cusim {
+
+__global__ void W2VNegSgKernel(
+  const int* cols, const int* indptr,
+  const int* random_table, default_random_engine* rngs, const int random_size,
+  const int num_indptr, const int num_dims, const int neg, const int window_size,
+  float* emb_in, float* emb_out, float* loss_nume, float* loss_deno, const float lr) {
+
+  default_random_engine& rng = rngs[blockIdx.x];
+  float& _loss_nume = loss_nume[blockIdx.x];
+  float& _loss_deno = loss_deno[blockIdx.x];
+
+  uniform_int_distribution<int> dist_neg(0, random_size - 1);
+  uniform_int_distribution<int> dist_window(0, window_size - 1);
+  __shared__ int reduced_windows;
+  __shared__ int neg_word;
+  extern __shared__ float shared_memory[];
+  float* grad = &shared_memory[0];
+
+  // zero-initialize shared mem
+  for (int i = threadIdx.x; i < num_dims; i += blockDim.x)
+    grad[i] = 0.0f;
+  __syncthreads();
+
+  for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) {
+    int beg = indptr[i], end = indptr[i + 1];
+    for (int j = beg; j < end; ++j) {
+      if (threadIdx.x == 0) reduced_windows = dist_window(rng);
+      __syncthreads();
+      int beg2 = max(beg, j - window_size + reduced_windows);
+      int end2 = min(end, j + window_size - reduced_windows + 1);
+      float* _emb_in = emb_in + num_dims * cols[j];
+      for (int k = beg2; k < end2; ++k) {
+        if (k == j) continue;
+        PositiveFeedback(_emb_in, emb_out + num_dims * cols[k],
+                         grad, _loss_nume, _loss_deno, num_dims, lr);
+        for (int l = 0; l < neg; ++l) {
+          if (threadIdx.x == 0) neg_word = random_table[dist_neg(rng)];
+          __syncthreads();
+          NegativeFeedback(_emb_in, emb_out + num_dims * neg_word,
+                           grad, _loss_nume, _loss_deno, num_dims, lr);
+        }
+        __syncthreads();
+        for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
+          emb_in[num_dims * cols[j] + l] += grad[l];
+          grad[l] = 0.0f;
+        }
+        __syncthreads();
+      }
+    }
+  }
+}
+
+__global__ void W2VNegCbowKernel(
+  const int* cols, const int* indptr,
+  const int* random_table, default_random_engine* rngs, const int random_size,
+  const int num_indptr, const int num_dims, const int neg, const int window_size,
+  float* emb_in, float* emb_out,
+  float* loss_nume, float* loss_deno, const bool use_mean, const float lr) {
+
+  default_random_engine& rng = rngs[blockIdx.x];
+  float& _loss_nume = loss_nume[blockIdx.x];
+  float& _loss_deno = loss_deno[blockIdx.x];
+
+  uniform_int_distribution<int> dist_neg(0, random_size - 1);
+  uniform_int_distribution<int> dist_window(0, window_size - 1);
+  static __shared__ int reduced_windows;
+  static __shared__ int neg_word;
+  extern __shared__ float shared_memory[];
+  float* grad = &shared_memory[0];
+  float* cbow = &shared_memory[num_dims];
+
+  __syncthreads();
+
+  for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) {
+    int beg = indptr[i], end = indptr[i + 1];
+    for (int j = beg; j < end; ++j) {
+      if (threadIdx.x == 0) reduced_windows = dist_window(rng);
+      __syncthreads();
+      int beg2 = max(beg, j - window_size + reduced_windows);
+      int end2 = min(end, j + window_size - reduced_windows + 1);
+      if (end2 - beg2 <= 1) continue;
+
+      // zero-initialize shared mem
+      for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
+        grad[k] = 0.0f;
+        cbow[k] = 0.0f;
+      }
+
+      // compute cbow
+      for (int k = beg2; k < end2; ++k) {
+        if (k == j) continue;
+        for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
+          cbow[l] += emb_in[num_dims * cols[k] + l];
+        }
+      }
+      if (use_mean) {
+        for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
+          cbow[k] /= (end2 - beg2 - 1);
+        }
+      }
+      __syncthreads();
+
+      PositiveFeedback(cbow, emb_out + num_dims * cols[j], grad,
+                       _loss_nume, _loss_deno, num_dims, lr);
+      __syncthreads();
+
+      // update negative feedback
+      for (int k = 0; k < neg; ++k){
+        if (threadIdx.x == 0) neg_word = random_table[dist_neg(rng)];
+        __syncthreads();
+        NegativeFeedback(cbow, emb_out + num_dims * neg_word,
+                         grad, _loss_nume, _loss_deno, num_dims, lr);
+      }
+      __syncthreads();
+
+      // normalize grad if use_mean = true
+      if (use_mean) {
+        for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
+          grad[k] /= (end2 - beg2 - 1);
+        }
+      }
+      __syncthreads();
+
+      // update emb_in
+      for (int k = beg2; k < end2; ++k) {
+        if (k == j) continue;
+        for (int l = threadIdx.x; l < num_dims; l += blockDim.x)
+          emb_in[num_dims * cols[k] + l] += grad[l];
+      }
+      __syncthreads();
+
+    }
+  }
+}
+
+}  // namespace cusim
diff --git a/cpp/include/cuw2v/cuw2v.hpp b/cpp/include/cuw2v/cuw2v.hpp
new file mode 100644
index 0000000..fba7f2b
--- /dev/null
+++ b/cpp/include/cuw2v/cuw2v.hpp
@@ -0,0 +1,76 @@
+// Copyright (c) 2021 Jisang Yoon
+// All rights reserved.
+//
+// This source code is licensed under the Apache 2.0 license found in the
+// LICENSE file in the root directory of this source tree.
+#pragma once
+#include <thrust/copy.h>
+#include <thrust/fill.h>
+#include <thrust/random.h>
+#include <thrust/host_vector.h>
+#include <thrust/device_vector.h>
+#include <thrust/binary_search.h>
+#include <thrust/execution_policy.h>
+
+#include <omp.h>
+#include <tuple>
+#include <random>
+#include <memory>
+#include <string>
+#include <fstream>
+#include <utility>
+#include <queue>
+#include <algorithm>
+#include <numeric>
+#include <vector>
+#include <cmath>
+#include <chrono>  // NOLINT
+
+#include "json11.hpp"
+#include "utils/log.hpp"
+#include "utils/types.hpp"
+
+using thrust::random::default_random_engine;
+
+namespace cusim {
+
+class CuW2V {
+ public:
+  CuW2V();
+  ~CuW2V();
+  bool Init(std::string opt_path);
+  void LoadModel(float* emb_in, float* emb_out);
+  void BuildHuffmanTree(const float* word_count, const int num_words);
+  void BuildRandomTable(const float* word_count, const int num_words,
+      const int table_size, const int num_threads);
+  int GetBlockCnt();
+  std::pair<float, float> FeedData(const int* cols, const int* indptr,
+      const int num_cols, const int num_indptr);
+  void Pull();
+
+ private:
+  DeviceInfo dev_info_;
+  json11::Json opt_;
+  std::shared_ptr<spdlog::logger> logger_;
+  std::unique_ptr<CuSimLogger> logger_container_;
+  int block_cnt_, block_dim_;
+  int num_dims_, num_words_, window_size_;
+  float *emb_in_, *emb_out_, lr_;
+  thrust::device_vector<float> dev_emb_in_, dev_emb_out_;
+
+  // variables to construct huffman tree
+  int max_depth_;
+  thrust::device_vector<bool> dev_codes_;
+  thrust::device_vector<int> dev_points_, dev_hs_indptr_;
+
+  // related to negative sampling / hierarchical softmax and skip gram / cbow
+  bool sg_, use_mean_;
+  int neg_;
+
+  // variables to construct random table
+  thrust::device_vector<int> dev_random_table_;
+  int random_size_, table_seed_, cuda_seed_;
+  thrust::device_vector<default_random_engine> dev_rngs_;
+};
+
+}  // namespace cusim
diff --git a/cpp/include/utils/cuda_utils_kernels.cuh b/cpp/include/utils/cuda_utils_kernels.cuh
index 026da7f..47d447e 100644
--- a/cpp/include/utils/cuda_utils_kernels.cuh
+++ b/cpp/include/utils/cuda_utils_kernels.cuh
@@ -23,6 +23,9 @@
 #include <vector>
 #include "utils/types.hpp"
 
+using thrust::random::default_random_engine;
+using thrust::random::uniform_int_distribution;
+
 namespace cusim {
 
 // Error Checking utilities, checks status codes from cuda calls
@@ -130,6 +133,45 @@ float warp_reduce_sum(float val) {
   return val;
 }
 
+__inline__ __device__
+float Dot(const float* vec1, const float* vec2, const int length) {
+
+  static __shared__ float shared[32];
+
+  // figure out the warp / position inside the warp
+  int warp = threadIdx.x / WARP_SIZE;
+  int lane = threadIdx.x % WARP_SIZE;
+
+  // partial sum
+  float val = 0.0f;
+  for (int i = threadIdx.x; i < length; i += blockDim.x)
+    val += vec1[i] * vec2[i];
+  val = warp_reduce_sum(val);
+
+  // write out the partial reduction to shared memory if appropriate
+  if (lane == 0) {
+    shared[warp] = val;
+  }
+  __syncthreads();
+
+  // if we don't have multiple warps, we're done
+  if (blockDim.x <= WARP_SIZE) {
+    return shared[0];
+  }
+
+  // otherwise reduce again in the first warp
+  val = (threadIdx.x < blockDim.x / WARP_SIZE) ? shared[lane]: 0.0f;
+  if (warp == 0) {
+    val = warp_reduce_sum(val);
+    // broadcast back to shared memory
+    if (threadIdx.x == 0) {
+      shared[0] = val;
+    }
+  }
+  __syncthreads();
+  return shared[0];
+}
+
 __inline__ __device__
 float ReduceSum(const float* vec, const int length) {
@@ -169,4 +211,8 @@ float ReduceSum(const float* vec, const int length) {
   return shared[0];
 }
 
+__global__ void InitRngsKernel(default_random_engine* rngs, int rand_seed) {
+  rngs[blockIdx.x].seed(blockIdx.x + rand_seed);
+}
+
 }  // namespace cusim
diff --git a/cpp/include/utils/ioutils.hpp b/cpp/include/utils/ioutils.hpp
index 756b4b2..54ca2e1 100644
--- a/cpp/include/utils/ioutils.hpp
+++ b/cpp/include/utils/ioutils.hpp
@@ -33,7 +33,7 @@ class IoUtils {
   int LoadStreamFile(std::string filepath);
   std::pair<int, int> ReadStreamForVocab(int num_lines, int num_threads);
   std::pair<int, int> TokenizeStream(int num_lines, int num_threads);
-  void GetWordVocab(int min_count, std::string keys_path);
+  void GetWordVocab(int min_count, std::string keys_path, std::string count_path);
   void GetToken(int* rows, int* cols, int* indptr);
 private:
   void ParseLine(std::string line, std::vector<std::string>& line_vec);
@@ -45,6 +45,7 @@ class IoUtils {
   std::ifstream stream_fin_;
   json11::Json opt_;
   std::shared_ptr<spdlog::logger> logger_;
+  std::unique_ptr<CuSimLogger> logger_container_;
  std::unordered_map<std::string, int> word_idmap_, word_count_;
  std::vector<std::string> word_list_;
  int num_lines_, remain_lines_;
diff --git a/cpp/include/utils/log.hpp b/cpp/include/utils/log.hpp
index 05f30ec..270e727 100644
--- a/cpp/include/utils/log.hpp
+++ b/cpp/include/utils/log.hpp
@@ -7,7 +7,7 @@
 // reference: https://github.com/kakao/buffalo/blob/5f571c2c7d8227e6625c6e538da929e4db11b66d/lib/misc/log.cc
 #pragma once
 #include <memory>
-
+#include <string>
 #define SPDLOG_EOL ""
 #define SPDLOG_TRACE_ON
 #include "spdlog/spdlog.h"
@@ -32,6 +32,7 @@ namespace cusim {
 class CuSimLogger {
  public:
   CuSimLogger();
+  explicit CuSimLogger(std::string name);
   std::shared_ptr<spdlog::logger>& get_logger();
   void set_log_level(int level);
   int get_log_level();
diff --git a/cpp/src/culda/culda.cu b/cpp/src/culda/culda.cu
index 9217902..c92bfeb 100644
--- a/cpp/src/culda/culda.cu
+++ b/cpp/src/culda/culda.cu
@@ -9,7 +9,8 @@
 namespace cusim {
 
 CuLDA::CuLDA() {
-  logger_ = CuSimLogger().get_logger();
+  logger_container_.reset(new CuSimLogger("lda"));
+  logger_ = logger_container_->get_logger();
   dev_info_ = GetDeviceInfo();
   if (dev_info_.unknown) DEBUG0("Unknown device type");
   INFO("cuda device info, major: {}, minor: {}, multi processors: {}, cores: {}",
@@ -28,7 +29,7 @@ bool CuLDA::Init(std::string opt_path) {
   auto _opt = json11::Json::parse(str, err_cmt);
   if (not err_cmt.empty()) return false;
   opt_ = _opt;
-  CuSimLogger().set_log_level(opt_["c_log_level"].int_value());
+  logger_container_->set_log_level(opt_["c_log_level"].int_value());
   num_topics_ = opt_["num_topics"].int_value();
   block_dim_ = opt_["block_dim"].int_value();
   block_cnt_ = opt_["hyper_threads"].number_value() * (dev_info_.cores / block_dim_);
diff --git a/cpp/src/cuw2v/cuw2v.cu b/cpp/src/cuw2v/cuw2v.cu
new file mode 100644
index 0000000..20f6640
--- /dev/null
+++ b/cpp/src/cuw2v/cuw2v.cu
@@ -0,0 +1,283 @@
+// Copyright (c) 2021 Jisang Yoon
+// All rights reserved.
+//
+// This source code is licensed under the Apache 2.0 license found in the
+// LICENSE file in the root directory of this source tree.
+#include "cuw2v/cuw2v.hpp" +#include "cuw2v/cuda_w2v_base_kernels.cuh" +#include "cuw2v/cuda_w2v_ns_kernels.cuh" +#include "cuw2v/cuda_w2v_hs_kernels.cuh" + +namespace cusim { + +struct HuffmanTreeNode { + float count; + int index, left, right; + HuffmanTreeNode(float count0, int index0, int left0, int right0) { + count = count0; index = index0; left = left0; right = right0; + } +}; + +std::vector huffman_nodes; +bool CompareIndex(int lhs, int rhs) { + return huffman_nodes[lhs].count > huffman_nodes[rhs].count; +} + +CuW2V::CuW2V() { + logger_container_.reset(new CuSimLogger("w2v")); + logger_ = logger_container_->get_logger(); + dev_info_ = GetDeviceInfo(); + if (dev_info_.unknown) DEBUG0("Unknown device type"); + INFO("cuda device info, major: {}, minor: {}, multi processors: {}, cores: {}", + dev_info_.major, dev_info_.minor, dev_info_.mp_cnt, dev_info_.cores); +} + +CuW2V::~CuW2V() {} + +bool CuW2V::Init(std::string opt_path) { + std::ifstream in(opt_path.c_str()); + if (not in.is_open()) return false; + + std::string str((std::istreambuf_iterator(in)), + std::istreambuf_iterator()); + std::string err_cmt; + auto _opt = json11::Json::parse(str, err_cmt); + if (not err_cmt.empty()) return false; + opt_ = _opt; + logger_container_->set_log_level(opt_["c_log_level"].int_value()); + num_dims_ = opt_["num_dims"].int_value(); + block_dim_ = opt_["block_dim"].int_value(); + block_cnt_ = opt_["hyper_threads"].number_value() * (dev_info_.cores / block_dim_); + sg_ = opt_["skip_gram"].bool_value(); + use_mean_ = opt_["use_mean"].bool_value(); + window_size_ = opt_["window_size"].int_value(); + lr_ = opt_["lr"].number_value(); + + // if zero, we will use hierarchical softmax + neg_ = opt_["neg"].int_value(); + + // random seed + table_seed_ = opt_["table_seed"].int_value(); + cuda_seed_ = opt_["cuda_seed"].int_value(); + dev_rngs_.resize(block_cnt_); + InitRngsKernel<<>>( + thrust::raw_pointer_cast(dev_rngs_.data()), cuda_seed_); + + INFO("num_dims: {}, block_dim: {}, block_cnt: {}, objective type: {}, neg: {}", + num_dims_, block_dim_, block_cnt_, sg_? 
"skip gram": "cbow", neg_); + return true; +} + +void CuW2V::BuildRandomTable(const float* word_count, const int num_words, + const int table_size, const int num_threads) { + num_words_ = num_words; + random_size_ = table_size; + std::vector acc; + float cumsum = 0; + for (int i = 0; i < num_words; ++i) { + acc.push_back(cumsum); + cumsum += word_count[i]; + } + + dev_random_table_.resize(random_size_); + std::vector host_random_table(table_size); + #pragma omp parallel num_threads(num_threads) + { + const unsigned int table_seed = table_seed_ + omp_get_thread_num(); + std::mt19937 rng(table_seed); + std::uniform_real_distribution dist(0.0f, cumsum); + #pragma omp for schedule(static) + for (int i = 0; i < random_size_; ++i) { + float r = dist(rng); + int pos = std::lower_bound(acc.begin(), acc.end(), r) - acc.begin(); + host_random_table[i] = pos; + } + } + table_seed_ += num_threads; + + thrust::copy(host_random_table.begin(), host_random_table.end(), dev_random_table_.begin()); + CHECK_CUDA(cudaDeviceSynchronize()); +} + +void CuW2V::BuildHuffmanTree(const float* word_count, const int num_words) { + num_words_ = num_words; + + huffman_nodes.clear(); + std::priority_queue, decltype(&CompareIndex)> pq(CompareIndex); + for (int i = 0; i < num_words; ++i) { + huffman_nodes.emplace_back(word_count[i], i, -1, -1); + pq.push(i); + } + for (int i = 0; i < num_words - 1; ++i) { + auto& min1 = huffman_nodes[pq.top()]; pq.pop(); + auto& min2 = huffman_nodes[pq.top()]; pq.pop(); + huffman_nodes.emplace_back(min1.count + min2.count, i + num_words, min1.index, min2.index); + pq.push(i + num_words); + } + + std::vector, std::vector>> stack = {{pq.top(), {}, {}}}; + int nodeid; + std::vector code; + std::vector point; + std::vector> codes(num_words); + std::vector> points(num_words); + max_depth_ = 0; + while (not stack.empty()) { + std::tie(nodeid, code, point) = stack.back(); + stack.pop_back(); + if (nodeid < num_words) { + codes[nodeid] = code; + points[nodeid] = point; + max_depth_ = std::max(max_depth_, + static_cast(code.size())); + } else { + point.push_back(nodeid - num_words); + std::vector left_code = code; + std::vector right_code = code; + left_code.push_back(false); + right_code.push_back(true); + auto& node = huffman_nodes[nodeid]; + stack.push_back(make_tuple(node.left, left_code, point)); + stack.push_back(make_tuple(node.right, right_code, point)); + } + } + + std::vector host_codes; + std::vector host_points; + std::vector host_hs_indptr = {0}; + int size = 0; + for (int i = 0; i < num_words; ++i) { + code = codes[i]; + point = points[i]; + int n = code.size(); + size += n; + host_hs_indptr.push_back(size); + for (int j = 0; j < n; ++j) { + host_codes.push_back(code[j]); + host_points.push_back(point[j]); + } + } + + dev_codes_.resize(size); dev_points_.resize(size), dev_hs_indptr_.resize(num_words + 1); + thrust::copy(host_codes.begin(), host_codes.end(), dev_codes_.begin()); + thrust::copy(host_points.begin(), host_points.end(), dev_points_.begin()); + thrust::copy(host_hs_indptr.begin(), host_hs_indptr.end(), dev_hs_indptr_.begin()); + CHECK_CUDA(cudaDeviceSynchronize()); + + huffman_nodes.clear(); +} + +void CuW2V::LoadModel(float* emb_in, float* emb_out) { + int out_words = neg_? 
num_words_: num_words_ - 1; + + // copy embedding + DEBUG("copy model({} x {})", num_words_, num_dims_); + dev_emb_in_.resize(num_words_ * num_dims_); + dev_emb_out_.resize(out_words * num_dims_); + thrust::copy(emb_in, emb_in + num_words_ * num_dims_, dev_emb_in_.begin()); + thrust::copy(emb_out, emb_out + out_words * num_dims_, dev_emb_out_.begin()); + emb_in_ = emb_in; emb_out_ = emb_out; + + CHECK_CUDA(cudaDeviceSynchronize()); +} + +int CuW2V::GetBlockCnt() { + return block_cnt_; +} + + +std::pair CuW2V::FeedData(const int* cols, const int* indptr, + const int num_cols, const int num_indptr) { + + // copy feed data to GPU memory + thrust::device_vector dev_cols(num_cols); + thrust::device_vector dev_indptr(num_indptr + 1); + thrust::device_vector dev_loss_nume(block_cnt_, 0.0f); + thrust::device_vector dev_loss_deno(block_cnt_, 0.0f); + thrust::copy(cols, cols + num_cols, dev_cols.begin()); + thrust::copy(indptr, indptr + num_indptr + 1, dev_indptr.begin()); + CHECK_CUDA(cudaDeviceSynchronize()); + DEBUG0("copy feed data to GPU memory"); + + // run GPU kernels + if (neg_ > 0) { + if (sg_) { + W2VNegSgKernel<<>>( + thrust::raw_pointer_cast(dev_cols.data()), + thrust::raw_pointer_cast(dev_indptr.data()), + thrust::raw_pointer_cast(dev_random_table_.data()), + thrust::raw_pointer_cast(dev_rngs_.data()), + random_size_, num_indptr, num_dims_, neg_, window_size_, + thrust::raw_pointer_cast(dev_emb_in_.data()), + thrust::raw_pointer_cast(dev_emb_out_.data()), + thrust::raw_pointer_cast(dev_loss_nume.data()), + thrust::raw_pointer_cast(dev_loss_deno.data()), + lr_); + } else { + W2VNegCbowKernel<<>>( + thrust::raw_pointer_cast(dev_cols.data()), + thrust::raw_pointer_cast(dev_indptr.data()), + thrust::raw_pointer_cast(dev_random_table_.data()), + thrust::raw_pointer_cast(dev_rngs_.data()), + random_size_, num_indptr, num_dims_, neg_, window_size_, + thrust::raw_pointer_cast(dev_emb_in_.data()), + thrust::raw_pointer_cast(dev_emb_out_.data()), + thrust::raw_pointer_cast(dev_loss_nume.data()), + thrust::raw_pointer_cast(dev_loss_deno.data()), + use_mean_, lr_); + } + } else { + if (sg_) { + W2VHsSgKernel<<>>( + thrust::raw_pointer_cast(dev_cols.data()), + thrust::raw_pointer_cast(dev_indptr.data()), + thrust::raw_pointer_cast(dev_codes_.data()), + thrust::raw_pointer_cast(dev_points_.data()), + thrust::raw_pointer_cast(dev_hs_indptr_.data()), + num_indptr, num_dims_, window_size_, + thrust::raw_pointer_cast(dev_rngs_.data()), + thrust::raw_pointer_cast(dev_emb_in_.data()), + thrust::raw_pointer_cast(dev_emb_out_.data()), + thrust::raw_pointer_cast(dev_loss_nume.data()), + thrust::raw_pointer_cast(dev_loss_deno.data()), + lr_); + + } else { + W2VHsCbowKernel<<>>( + thrust::raw_pointer_cast(dev_cols.data()), + thrust::raw_pointer_cast(dev_indptr.data()), + thrust::raw_pointer_cast(dev_codes_.data()), + thrust::raw_pointer_cast(dev_points_.data()), + thrust::raw_pointer_cast(dev_hs_indptr_.data()), + num_indptr, num_dims_, window_size_, + thrust::raw_pointer_cast(dev_rngs_.data()), + thrust::raw_pointer_cast(dev_emb_in_.data()), + thrust::raw_pointer_cast(dev_emb_out_.data()), + thrust::raw_pointer_cast(dev_loss_nume.data()), + thrust::raw_pointer_cast(dev_loss_deno.data()), + use_mean_, lr_); + + } + + } + CHECK_CUDA(cudaDeviceSynchronize()); + + // accumulate loss nume / deno + std::vector loss_nume(block_cnt_), loss_deno(block_cnt_); + thrust::copy(dev_loss_nume.begin(), dev_loss_nume.end(), loss_nume.begin()); + thrust::copy(dev_loss_deno.begin(), dev_loss_deno.end(), loss_deno.begin()); + 
CHECK_CUDA(cudaDeviceSynchronize()); + float loss_nume_sum = std::accumulate(loss_nume.begin(), loss_nume.end(), 0.0f); + float loss_deno_sum = std::accumulate(loss_deno.begin(), loss_deno.end(), 0.0f); + DEBUG("loss nume: {}, deno: {}", loss_nume_sum, loss_deno_sum); + + return {loss_nume_sum, loss_deno_sum}; +} + +void CuW2V::Pull() { + thrust::copy(dev_emb_in_.begin(), dev_emb_in_.end(), emb_in_); + thrust::copy(dev_emb_out_.begin(), dev_emb_out_.end(), emb_out_); + CHECK_CUDA(cudaDeviceSynchronize()); +} + +} // namespace cusim diff --git a/cpp/src/utils/ioutils.cc b/cpp/src/utils/ioutils.cc index 14bd94e..53833a3 100644 --- a/cpp/src/utils/ioutils.cc +++ b/cpp/src/utils/ioutils.cc @@ -8,7 +8,8 @@ namespace cusim { IoUtils::IoUtils() { - logger_ = CuSimLogger().get_logger(); + logger_container_.reset(new CuSimLogger("ioutils")); + logger_ = logger_container_->get_logger(); } IoUtils::~IoUtils() {} @@ -23,7 +24,7 @@ bool IoUtils::Init(std::string opt_path) { auto _opt = json11::Json::parse(str, err_cmt); if (not err_cmt.empty()) return false; opt_ = _opt; - CuSimLogger().set_log_level(opt_["c_log_level"].int_value()); + logger_container_->set_log_level(opt_["c_log_level"].int_value()); return true; } @@ -153,7 +154,7 @@ std::pair IoUtils::ReadStreamForVocab(int num_lines, int num_threads) return {read_lines, word_count_.size()}; } -void IoUtils::GetWordVocab(int min_count, std::string keys_path) { +void IoUtils::GetWordVocab(int min_count, std::string keys_path, std::string count_path) { INFO("number of raw words: {}", word_count_.size()); word_idmap_.clear(); word_list_.clear(); for (auto& it: word_count_) { @@ -164,15 +165,18 @@ void IoUtils::GetWordVocab(int min_count, std::string keys_path) { } INFO("number of words after filtering: {}", word_list_.size()); - // write keys to csv file - std::ofstream fout(keys_path.c_str()); + // write keys and count to csv file + std::ofstream fout1(keys_path.c_str()); + std::ofstream fout2(count_path.c_str()); INFO("dump keys to {}", keys_path); int n = word_list_.size(); for (int i = 0; i < n; ++i) { std::string line = word_list_[i] + "\n"; - fout.write(line.c_str(), line.size()); + fout1.write(line.c_str(), line.size()); + line = std::to_string(word_count_[word_list_[i]]) + "\n"; + fout2.write(line.c_str(), line.size()); } - fout.close(); + fout1.close(); fout2.close(); } } // namespace cusim diff --git a/cpp/src/utils/log.cc b/cpp/src/utils/log.cc index ddfcb0c..cec5fa4 100644 --- a/cpp/src/utils/log.cc +++ b/cpp/src/utils/log.cc @@ -16,6 +16,14 @@ CuSimLogger::CuSimLogger() { logger_ = spdlog::default_logger(); } +CuSimLogger::CuSimLogger(std::string name) { + // auto console_sink = std::make_shared(); + auto stderr_sink = std::make_shared(); + // spdlog::sinks_init_list sinks = {console_sink, stderr_sink}; + logger_ = std::make_shared(name, stderr_sink); + logger_->set_pattern("[%^%-8l%$] %Y-%m-%d %H:%M:%S %v"); +} + std::shared_ptr& CuSimLogger::get_logger() { return logger_; } @@ -23,11 +31,11 @@ std::shared_ptr& CuSimLogger::get_logger() { void CuSimLogger::set_log_level(int level) { global_logging_level_ = level; switch (level) { - case 0: spdlog::set_level(spdlog::level::off); break; - case 1: spdlog::set_level(spdlog::level::warn); break; - case 2: spdlog::set_level(spdlog::level::info); break; - case 3: spdlog::set_level(spdlog::level::debug); break; - default: spdlog::set_level(spdlog::level::trace); break; + case 0: logger_->set_level(spdlog::level::off); break; + case 1: logger_->set_level(spdlog::level::warn); break; + case 2: 
logger_->set_level(spdlog::level::info); break; + case 3: logger_->set_level(spdlog::level::debug); break; + default: logger_->set_level(spdlog::level::trace); break; } } diff --git a/cusim/__init__.py b/cusim/__init__.py index 24fe984..a1d864c 100644 --- a/cusim/__init__.py +++ b/cusim/__init__.py @@ -5,3 +5,4 @@ # LICENSE file in the root directory of this source tree. from cusim.ioutils import IoUtils from cusim.culda import CuLDA +from cusim.cuw2v import CuW2V diff --git a/cusim/constants.py b/cusim/constants.py new file mode 100644 index 0000000..6d69ce8 --- /dev/null +++ b/cusim/constants.py @@ -0,0 +1,10 @@ +# Copyright (c) 2021 Jisang Yoon +# All rights reserved. +# +# This source code is licensed under the Apache 2.0 license found in the +# LICENSE file in the root directory of this source tree. + +# pylint: disable=no-name-in-module,too-few-public-methods,no-member + +EPS = 1e-10 +WARP_SIZE = 32 diff --git a/cusim/culda/pyculda.py b/cusim/culda/pyculda.py index e0fa3d9..dbd0da3 100644 --- a/cusim/culda/pyculda.py +++ b/cusim/culda/pyculda.py @@ -18,14 +18,18 @@ from cusim import aux, IoUtils from cusim.culda.culda_bind import CuLDABind from cusim.config_pb2 import CuLDAConfigProto +from cusim.constants import EPS, WARP_SIZE -EPS = 1e-10 class CuLDA: def __init__(self, opt=None): self.opt = aux.get_opt_as_proto(opt or {}, CuLDAConfigProto) self.logger = aux.get_logger("culda", level=self.opt.py_log_level) + assert self.opt.block_dim <= WARP_SIZE ** 2 and \ + self.opt.block_dim % WARP_SIZE == 0, \ + f"invalid block dim ({self.opt.block_dim}, warp size: {WARP_SIZE})" + tmp = tempfile.NamedTemporaryFile(mode='w', delete=False) opt_content = json.dumps(aux.proto_to_dict(self.opt), indent=2) tmp.write(opt_content) diff --git a/cusim/cuw2v/__init__.py b/cusim/cuw2v/__init__.py new file mode 100644 index 0000000..c3acfc5 --- /dev/null +++ b/cusim/cuw2v/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2021 Jisang Yoon +# All rights reserved. +# +# This source code is licensed under the Apache 2.0 license found in the +# LICENSE file in the root directory of this source tree. +from cusim.cuw2v.pycuw2v import CuW2V diff --git a/cusim/cuw2v/bindings.cc b/cusim/cuw2v/bindings.cc new file mode 100644 index 0000000..3ca45d6 --- /dev/null +++ b/cusim/cuw2v/bindings.cc @@ -0,0 +1,107 @@ +// Copyright (c) 2021 Jisang Yoon +// All rights reserved. +// +// This source code is licensed under the Apache 2.0 license found in the +// LICENSE file in the root directory of this source tree. 
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+#include <pybind11/stl.h>
+
+#include <iostream>
+#include "cuw2v/cuw2v.hpp"
+
+namespace py = pybind11;
+
+typedef py::array_t<float, py::array::c_style | py::array::forcecast> float_array;
+typedef py::array_t<int, py::array::c_style | py::array::forcecast> int_array;
+
+class CuW2VBind {
+ public:
+  CuW2VBind() {}
+
+  bool Init(std::string opt_path) {
+    return obj_.Init(opt_path);
+  }
+
+  void LoadModel(py::object& emb_in, py::object& emb_out) {
+    // check shape of emb_in and emb_out
+    float_array _emb_in(emb_in);
+    float_array _emb_out(emb_out);
+    auto emb_in_buffer = _emb_in.request();
+    auto emb_out_buffer = _emb_out.request();
+    if (emb_in_buffer.ndim != 2 or emb_out_buffer.ndim != 2 or
+        emb_in_buffer.shape[1] != emb_out_buffer.shape[1]) {
+      throw std::runtime_error("invalid emb_in or emb_out");
+    }
+
+    return obj_.LoadModel(_emb_in.mutable_data(0), _emb_out.mutable_data(0));
+  }
+
+  void BuildRandomTable(py::object& word_count, int table_size, int num_threads) {
+    float_array _word_count(word_count);
+    auto wc_buffer = _word_count.request();
+    if (wc_buffer.ndim != 1) {
+      throw std::runtime_error("invalid word count");
+    }
+    int num_words = wc_buffer.shape[0];
+    obj_.BuildRandomTable(_word_count.data(0), num_words, table_size, num_threads);
+  }
+
+  void BuildHuffmanTree(py::object& word_count) {
+    float_array _word_count(word_count);
+    auto wc_buffer = _word_count.request();
+    if (wc_buffer.ndim != 1) {
+      throw std::runtime_error("invalid word count");
+    }
+    int num_words = wc_buffer.shape[0];
+    obj_.BuildHuffmanTree(_word_count.data(0), num_words);
+  }
+
+  std::pair<float, float> FeedData(py::object& cols, py::object& indptr) {
+    int_array _cols(cols);
+    int_array _indptr(indptr);
+    auto cols_buffer = _cols.request();
+    auto indptr_buffer = _indptr.request();
+    if (cols_buffer.ndim != 1 or indptr_buffer.ndim != 1) {
+      throw std::runtime_error("invalid cols or indptr");
+    }
+    int num_cols = cols_buffer.shape[0];
+    int num_indptr = indptr_buffer.shape[0] - 1;
+    return obj_.FeedData(_cols.data(0), _indptr.data(0), num_cols, num_indptr);
+  }
+
+  void Pull() {
+    obj_.Pull();
+  }
+
+  int GetBlockCnt() {
+    return obj_.GetBlockCnt();
+  }
+
+ private:
+  cusim::CuW2V obj_;
+};
+
+PYBIND11_PLUGIN(cuw2v_bind) {
+  py::module m("CuW2VBind");
+
+  py::class_<CuW2VBind>(m, "CuW2VBind")
+  .def(py::init())
+  .def("init", &CuW2VBind::Init, py::arg("opt_path"))
+  .def("load_model", &CuW2VBind::LoadModel,
+      py::arg("emb_in"), py::arg("emb_out"))
+  .def("feed_data", &CuW2VBind::FeedData,
+      py::arg("cols"), py::arg("indptr"))
+  .def("pull", &CuW2VBind::Pull)
+  .def("build_random_table", &CuW2VBind::BuildRandomTable,
+      py::arg("word_count"), py::arg("table_size"), py::arg("num_threads"))
+  .def("build_huffman_tree", &CuW2VBind::BuildHuffmanTree,
+      py::arg("word_count"))
+  .def("get_block_cnt", &CuW2VBind::GetBlockCnt)
+  .def("__repr__",
+  [](const CuW2VBind &a) {
+    return "<CuW2VBind>";
+  }
+  );
+  return m.ptr();
+}
diff --git a/cusim/cuw2v/pycuw2v.py b/cusim/cuw2v/pycuw2v.py
new file mode 100644
index 0000000..f2bd265
--- /dev/null
+++ b/cusim/cuw2v/pycuw2v.py
@@ -0,0 +1,134 @@
+# Copyright (c) 2021 Jisang Yoon
+# All rights reserved.
+#
+# This source code is licensed under the Apache 2.0 license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pylint: disable=no-name-in-module,too-few-public-methods,no-member
+import os
+from os.path import join as pjoin
+
+import json
+import tempfile
+
+import h5py
+import numpy as np
+
+from cusim import aux, IoUtils
+from cusim.cuw2v.cuw2v_bind import CuW2VBind
+from cusim.config_pb2 import CuW2VConfigProto
+from cusim.constants import EPS, WARP_SIZE
+
+class CuW2V:
+  def __init__(self, opt=None):
+    self.opt = aux.get_opt_as_proto(opt or {}, CuW2VConfigProto)
+    self.logger = aux.get_logger("cuw2v", level=self.opt.py_log_level)
+
+    assert self.opt.block_dim <= WARP_SIZE ** 2 and \
+      self.opt.block_dim % WARP_SIZE == 0, \
+      f"invalid block dim ({self.opt.block_dim}, warp size: {WARP_SIZE})"
+
+    tmp = tempfile.NamedTemporaryFile(mode='w', delete=False)
+    opt_content = json.dumps(aux.proto_to_dict(self.opt), indent=2)
+    tmp.write(opt_content)
+    tmp.close()
+
+    self.logger.info("opt: %s", opt_content)
+    self.obj = CuW2VBind()
+    assert self.obj.init(bytes(tmp.name, "utf8")), f"failed to load {tmp.name}"
+    os.remove(tmp.name)
+
+    self.words, self.word_count, self.num_words, self.num_docs = \
+      None, None, None, None
+    self.emb_in, self.emb_out = None, None
+
+  def preprocess_data(self):
+    if self.opt.skip_preprocess:
+      return
+    iou = IoUtils()
+    if not self.opt.processed_data_dir:
+      self.opt.processed_data_dir = tempfile.mkdtemp()
+    iou.convert_stream_to_h5(self.opt.data_path, self.opt.word_min_count,
+                             self.opt.processed_data_dir)
+
+  def init_model(self):
+    # load vocab
+    data_dir = self.opt.processed_data_dir
+    keys_path = pjoin(data_dir, "keys.txt")
+    count_path = pjoin(data_dir, "count.txt")
+    self.logger.info("load key, count from %s, %s", keys_path, count_path)
+    with open(keys_path, "rb") as fin:
+      self.words = [line.strip() for line in fin]
+    with open(count_path, "rb") as fin:
+      self.word_count = np.array([float(line.strip()) for line in fin],
+                                 dtype=np.float32)
+    self.word_count = np.power(self.word_count, self.opt.count_power)
+    self.num_words = len(self.words)
+    assert len(self.words) == len(self.word_count)
+
+    # count number of docs
+    h5f = h5py.File(pjoin(data_dir, "token.h5"), "r")
+    self.num_docs = h5f["indptr"].shape[0] - 1
+    h5f.close()
+
+    self.logger.info("number of words: %d, docs: %d",
+                     self.num_words, self.num_docs)
+
+    if self.opt.neg:
+      self.obj.build_random_table( \
+        self.word_count, self.opt.random_size, self.opt.num_threads)
+    else:
+      self.obj.build_huffman_tree(self.word_count)
+
+    # randomly initialize input / output embeddings
+    np.random.seed(self.opt.seed)
+    self.emb_in = np.random.normal( \
+      size=(self.num_words, self.opt.num_dims)).astype(np.float32)
+    out_words = self.num_words if self.opt.neg else self.num_words - 1
+    self.emb_out = np.random.uniform( \
+      size=(out_words, self.opt.num_dims)).astype(np.float32)
+    self.logger.info("emb_in %s, emb_out %s initialized",
+                     self.emb_in.shape, self.emb_out.shape)
+
+    # push it to gpu
+    self.obj.load_model(self.emb_in, self.emb_out)
+
+  def train_model(self):
+    self.preprocess_data()
+    self.init_model()
+    h5f = h5py.File(pjoin(self.opt.processed_data_dir, "token.h5"), "r")
+    for epoch in range(1, self.opt.epochs + 1):
+      self.logger.info("Epoch %d / %d", epoch, self.opt.epochs)
+      self._train_epoch(h5f)
+    self.obj.pull()
+    h5f.close()
+
+  def _train_epoch(self, h5f):
+    offset, size = 0, h5f["cols"].shape[0]
+    pbar = aux.Progbar(size, stateful_metrics=["loss"])
+    loss_nume, loss_deno = 0, 0
+    while True:
+      target = h5f["indptr"][offset] + self.opt.batch_size
+      if target < size:
+        next_offset = h5f["rows"][target]
+      else:
+        next_offset = h5f["indptr"].shape[0] - 1
+      indptr = h5f["indptr"][offset:next_offset + 1]
+      beg, end = indptr[0], indptr[-1]
+      indptr -= beg
+      cols = h5f["cols"][beg:end]
+      offset = next_offset
+
+      # call cuda kernel
+      _loss_nume, _loss_deno = \
+        self.obj.feed_data(cols, indptr)
+
+      # accumulate loss
+      loss_nume += _loss_nume
+      loss_deno += _loss_deno
+      loss = loss_nume / (loss_deno + EPS)
+
+      # update progress bar
+      pbar.update(end, values=[("loss", loss)])
+      if end == size:
+        break
diff --git a/cusim/ioutils/bindings.cc b/cusim/ioutils/bindings.cc
index 28fbbc8..06204f8 100644
--- a/cusim/ioutils/bindings.cc
+++ b/cusim/ioutils/bindings.cc
@@ -35,8 +35,8 @@ class IoUtilsBind {
     return obj_.TokenizeStream(num_lines, num_threads);
   }
 
-  void GetWordVocab(int min_count, std::string keys_path) {
-    obj_.GetWordVocab(min_count, keys_path);
+  void GetWordVocab(int min_count, std::string keys_path, std::string count_path) {
+    obj_.GetWordVocab(min_count, keys_path, count_path);
   }
 
   void GetToken(py::object& rows, py::object& cols, py::object& indptr) {
@@ -62,7 +62,7 @@ PYBIND11_PLUGIN(ioutils_bind) {
   .def("tokenize_stream", &IoUtilsBind::TokenizeStream,
      py::arg("num_lines"), py::arg("num_threads"))
   .def("get_word_vocab", &IoUtilsBind::GetWordVocab,
-     py::arg("min_count"), py::arg("keys_path"))
+     py::arg("min_count"), py::arg("keys_path"), py::arg("count_path"))
   .def("get_token", &IoUtilsBind::GetToken,
      py::arg("indices"), py::arg("indptr"), py::arg("offset"))
   .def("__repr__",
diff --git a/cusim/ioutils/pyioutils.py b/cusim/ioutils/pyioutils.py
index 5bce9b7..71b52fd 100644
--- a/cusim/ioutils/pyioutils.py
+++ b/cusim/ioutils/pyioutils.py
@@ -33,7 +33,8 @@ def __init__(self, opt=None):
     assert self.obj.init(bytes(tmp.name, "utf8")), f"failed to load {tmp.name}"
     os.remove(tmp.name)
 
-  def load_stream_vocab(self, filepath, min_count, keys_path):
+  def load_stream_vocab(self, filepath, min_count,
+                        keys_path, count_path):
     full_num_lines = self.obj.load_stream_file(filepath)
     pbar = aux.Progbar(full_num_lines, unit_name="line",
                        stateful_metrics=["word_count"])
@@ -46,17 +47,18 @@ def load_stream_vocab(self, filepath, min_count, keys_path):
       pbar.update(processed, values=[("word_count", word_count)])
       if processed == full_num_lines:
         break
-    self.obj.get_word_vocab(min_count, keys_path)
+    self.obj.get_word_vocab(min_count, keys_path, count_path)
 
   def convert_stream_to_h5(self, filepath, min_count, out_dir,
                            chunk_indices=10000, seed=777):
     np.random.seed(seed)
     os.makedirs(out_dir, exist_ok=True)
     keys_path = pjoin(out_dir, "keys.txt")
+    count_path = pjoin(out_dir, "count.txt")
     token_path = pjoin(out_dir, "token.h5")
-    self.logger.info("save key and token to %s, %s",
-                     keys_path, token_path)
-    self.load_stream_vocab(filepath, min_count, keys_path)
+    self.logger.info("save key, count, token to %s, %s, %s",
+                     keys_path, count_path, token_path)
+    self.load_stream_vocab(filepath, min_count, keys_path, count_path)
     full_num_lines = self.obj.load_stream_file(filepath)
     pbar = aux.Progbar(full_num_lines, unit_name="line")
     processed = 0
diff --git a/cusim/proto/config.proto b/cusim/proto/config.proto
index 10b3820..f468bb0 100644
--- a/cusim/proto/config.proto
+++ b/cusim/proto/config.proto
@@ -31,3 +31,34 @@ message CuLDAConfigProto {
   optional double vali_p = 13 [default = 0.2];
   optional int32 seed = 14 [default = 777];
 }
+
+message CuW2VConfigProto {
+  required string data_path = 7;
+
+  optional int32 py_log_level = 1 [default = 2];
+  optional int32 c_log_level = 2 [default = 2];
+
+  optional int32 num_dims = 3 [default = 50];
+  optional int32 block_dim = 4 [default = 32];
+  optional int32 hyper_threads = 5 [default = 10];
+  optional string processed_data_dir = 6;
+  optional bool skip_preprocess = 8;
+  optional int32 word_min_count = 9 [default = 5];
+  optional int32 batch_size = 10 [default = 100000];
+  optional int32 epochs = 11 [default = 10];
+
+  // seed fields
+  optional int32 seed = 14 [default = 777];
+  optional int32 table_seed = 15 [default = 777];
+  optional int32 cuda_seed = 16 [default = 777];
+  optional int32 random_size = 12 [default = 1000000];
+
+  optional int32 neg = 17 [default = 10];
+  // as recommended in w2v paper
+  optional double count_power = 18 [default = 0.75];
+  optional bool skip_gram = 19 [default = true];
+  optional bool use_mean = 20 [default = true];
+  optional double lr = 21 [default = 0.001];
+  optional int32 window_size = 22 [default = 5];
+  optional int32 num_threads = 23 [default = 4];
+}
diff --git a/examples/example1.py b/examples/example1.py
index f9362cb..1bea971 100644
--- a/examples/example1.py
+++ b/examples/example1.py
@@ -12,7 +12,7 @@ import h5py
 import numpy as np
 from gensim import downloader as api
 
-from cusim import aux, IoUtils, CuLDA
+from cusim import aux, IoUtils, CuLDA, CuW2V
 
 LOGGER = aux.get_logger()
 DOWNLOAD_PATH = "./res"
@@ -46,7 +46,7 @@ def run_lda():
   opt = {
     "data_path": DATA_PATH,
     "processed_data_dir": PROCESSED_DATA_DIR,
-    "skip_preprocess":True,
+    # "skip_preprocess":True,
   }
   lda = CuLDA(opt)
   lda.train_model()
@@ -66,5 +66,15 @@ def run_lda():
     print(f"rank {j + 1}. word: {word}, prob: {prob}")
   h5f.close()
 
+def run_w2v():
+  opt = {
+    # "c_log_level": 3,
+    "data_path": DATA_PATH,
+    "processed_data_dir": PROCESSED_DATA_DIR,
+    # "skip_preprocess":True,
+  }
+  w2v = CuW2V(opt)
+  w2v.train_model()
+
 if __name__ == "__main__":
   fire.Fire()
diff --git a/setup.py b/setup.py
index 512d262..669403a 100644
--- a/setup.py
+++ b/setup.py
@@ -98,6 +98,21 @@ def __init__(self, name):
         "cpp/include/", np.get_include(), pybind11.get_include(),
         pybind11.get_include(True), CUDA['include'], "3rd/json11",
         "3rd/spdlog/include"]),
+    Extension("cusim.cuw2v.cuw2v_bind",
+              sources= util_srcs + [ \
+                "cpp/src/cuw2v/cuw2v.cu",
+                "cusim/cuw2v/bindings.cc",
+                "3rd/json11/json11.cpp"],
+              language="c++",
+              extra_compile_args=extra_compile_args,
+              extra_link_args=["-fopenmp"],
+              library_dirs=[CUDA['lib64']],
+              libraries=['cudart', 'cublas', 'curand'],
+              extra_objects=[],
+              include_dirs=[ \
+                "cpp/include/", np.get_include(), pybind11.get_include(),
+                pybind11.get_include(True), CUDA['include'],
+                "3rd/json11", "3rd/spdlog/include"]),
 ]
 
@@ -182,7 +197,7 @@ def setup_package():
     download_url="https://github.com/js1010/cusim/releases",
     include_package_data=False,
     license='Apache2',
-    packages=['cusim/', "cusim/ioutils/", "cusim/culda/"],
+    packages=['cusim/', "cusim/ioutils/", "cusim/culda/", "cusim/cuw2v/"],
     install_requires=INSTALL_REQUIRES,
     cmdclass=cmdclass,
     classifiers=[_f for _f in CLASSIFIERS.split('\n') if _f],
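
For orientation, the heart of this patch is the pair of device functions `PositiveFeedback` / `NegativeFeedback`, which implement the standard sigmoid-based SGD step of word2vec: a positive pair pulls the output vector toward the input vector with step size `g = lr * sigmoid(-dot)`, a sampled negative pair pushes it away with `g = lr * sigmoid(dot)`, and the input-side gradient is accumulated in shared memory (`grad`) and applied once per center word. The `loss_nume` / `loss_deno` counters accumulate the negative log-likelihood and the number of processed pairs, so `loss_nume / loss_deno` is the mean loss shown by the Python progress bar. A minimal single-threaded NumPy sketch of the same math (hypothetical helper names, not part of the patch):

```python
import numpy as np

def positive_feedback(vec1, vec2, grad, lr):
  """Analogue of PositiveFeedback: pull vec2 toward vec1."""
  dot = float(vec1 @ vec2)
  g = lr / (1.0 + np.exp(dot))   # lr * sigmoid(-dot)
  loss = np.log1p(np.exp(-dot))  # -log(sigmoid(dot))
  grad += g * vec2               # gradient w.r.t. the input vector, applied later
  vec2 += g * vec1               # in-place update of the output vector
  return loss

def negative_feedback(vec1, vec2, grad, lr):
  """Analogue of NegativeFeedback: push vec2 away from vec1."""
  dot = float(vec1 @ vec2)
  g = lr / (1.0 + np.exp(-dot))  # lr * sigmoid(dot)
  loss = np.log1p(np.exp(dot))   # -log(sigmoid(-dot))
  grad -= g * vec2
  vec2 -= g * vec1
  return loss
```

In the CUDA kernels the same two calls appear per (center, context) pair: once positively against the true context word (or against its Huffman path nodes, with the sign chosen by `codes[]` in the hierarchical-softmax kernels), and `neg` times negatively against words drawn from the unigram table, with `Dot` performing the block-wide dot product and thread 0 broadcasting `g` through shared memory.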