From 882b1d39fed1f167cb567362d59f8cd1366bc530 Mon Sep 17 00:00:00 2001 From: js1010 Date: Thu, 11 Feb 2021 21:59:49 +0900 Subject: [PATCH 01/18] implement huffman tree --- cpp/src/cuw2v/cuw2v.cu | 130 ++++++++++++++++++++++++++++++++++ cpp/src/cuw2v/huffman_tree.cc | 64 +++++++++++++++++ 2 files changed, 194 insertions(+) create mode 100644 cpp/src/cuw2v/cuw2v.cu create mode 100644 cpp/src/cuw2v/huffman_tree.cc diff --git a/cpp/src/cuw2v/cuw2v.cu b/cpp/src/cuw2v/cuw2v.cu new file mode 100644 index 0000000..e809f90 --- /dev/null +++ b/cpp/src/cuw2v/cuw2v.cu @@ -0,0 +1,130 @@ +// Copyright (c) 2021 Jisang Yoon +// All rights reserved. +// +// This source code is licensed under the Apache 2.0 license found in the +// LICENSE file in the root directory of this source tree. +#include "cuw2v/cuw2v.hpp" +#include "cuw2v/cuda_w2v_kernels.cuh" + +namespace cusim { + +CuW2V::CuW2V() { + logger_ = CuSimLogger().get_logger(); + dev_info_ = GetDeviceInfo(); + if (dev_info_.unknown) DEBUG0("Unknown device type"); + INFO("cuda device info, major: {}, minor: {}, multi processors: {}, cores: {}", + dev_info_.major, dev_info_.minor, dev_info_.mp_cnt, dev_info_.cores); +} + +CuW2V::~CuW2V() {} + +bool CuW2V::Init(std::string opt_path) { + std::ifstream in(opt_path.c_str()); + if (not in.is_open()) return false; + + std::string str((std::istreambuf_iterator(in)), + std::istreambuf_iterator()); + std::string err_cmt; + auto _opt = json11::Json::parse(str, err_cmt); + if (not err_cmt.empty()) return false; + opt_ = _opt; + CuSimLogger().set_log_level(opt_["c_log_level"].int_value()); + num_topics_ = opt_["num_dims"].int_value(); + block_dim_ = opt_["block_dim"].int_value(); + block_cnt_ = opt_["hyper_threads"].number_value() * (dev_info_.cores / block_dim_); + sg_ = opt_["skip_gram"].bool_value(); + // if zero, we will use hierarchical softmax + neg_ = opt["negative_sampling"].int_value(); + INFO("num_dims: {}, block_dim: {}, block_cnt: {}, objective type: {}, neg: {}", + num_dims_, block_dim_, block_cnt_, sg_? "skip gram": "cbow", neg_); + return true; +} + +void CuW2V::LoadModel(float* emb_in, float* emb_out, const int num_words, int num_hs_nodes = 0) { + num_words_ = num_words; + out_size_ = neg_? 
num_words_: num_hs_nodes; + + // copy embedding + DEBUG("copy model({} x {})", num_words_, num_dims_); + dev_emb_in_.resize(num_words_ * num_dims_); + dev_emb_out_.resize(out_size_ * num_dims_); + thrust::copy(emb_in, emb_in + num_words_ * num_dims_, dev_emb_in_.begin()); + thrust::copy(emb_out, emb_out + out_size_ * num_dims_, dev_emb_out_.begin()); + emb_in_ = emb_in; emb_out_ = emb_out; + + // set mutex + dev_mutex_in_.resize(num_words_); + dev_mutex_out_.resize(out_size_); + std::vector host_mutex_in(num_words_, 0); + std::vector host_mutex_out(out_size_, 0); + thrust::copy(host_mutex_in.begin(), host_mutex_in.end(), dev_mutex_in_.begin()); + thrust::copy(host_mutex_out.begin(), host_mutex_out.end(), dev_mutex_out_.begin()); + + CHECK_CUDA(cudaDeviceSynchronize()); +} + +std::pair CuLDA::FeedData( + const int* cols, const int* indptr, const bool* vali, + const int num_cols, const int num_indptr) { + + // copy feed data to GPU memory + thrust::device_vector dev_cols(num_cols); + thrust::device_vector dev_indptr(num_indptr + 1); + thrust::device_vector dev_losses(block_cnt_, 0.0f); + thrust::copy(cols, cols + num_cols, dev_cols.begin()); + thrust::copy(indptr, indptr + num_indptr + 1, dev_indptr.begin()); + thrust::copy(vali, vali + num_cols, dev_vali.begin()); + CHECK_CUDA(cudaDeviceSynchronize()); + DEBUG0("copy feed data to GPU memory"); + + // run E step in GPU + EstepKernel<<>>( + thrust::raw_pointer_cast(dev_cols.data()), + thrust::raw_pointer_cast(dev_indptr.data()), + thrust::raw_pointer_cast(dev_vali.data()), + num_cols, num_indptr, num_topics_, num_iters, + thrust::raw_pointer_cast(dev_gamma_.data()), + thrust::raw_pointer_cast(dev_new_gamma_.data()), + thrust::raw_pointer_cast(dev_phi_.data()), + thrust::raw_pointer_cast(dev_alpha_.data()), + thrust::raw_pointer_cast(dev_beta_.data()), + thrust::raw_pointer_cast(dev_grad_alpha_.data()), + thrust::raw_pointer_cast(dev_new_beta_.data()), + thrust::raw_pointer_cast(dev_train_losses.data()), + thrust::raw_pointer_cast(dev_vali_losses.data()), + thrust::raw_pointer_cast(dev_mutex_.data())); + CHECK_CUDA(cudaDeviceSynchronize()); + DEBUG0("run E step in GPU"); + + // pull loss + std::vector train_losses(block_cnt_), vali_losses(block_cnt_); + thrust::copy(dev_train_losses.begin(), dev_train_losses.end(), train_losses.begin()); + thrust::copy(dev_vali_losses.begin(), dev_vali_losses.end(), vali_losses.begin()); + CHECK_CUDA(cudaDeviceSynchronize()); + DEBUG0("pull loss values"); + + // accumulate + float train_loss = std::accumulate(train_losses.begin(), train_losses.end(), 0.0f); + float vali_loss = std::accumulate(vali_losses.begin(), vali_losses.end(), 0.0f); + return {train_loss, vali_loss}; +} + +void CuLDA::Pull() { + thrust::copy(dev_grad_alpha_.begin(), dev_grad_alpha_.end(), grad_alpha_); + thrust::copy(dev_new_beta_.begin(), dev_new_beta_.end(), new_beta_); + CHECK_CUDA(cudaDeviceSynchronize()); +} + +void CuLDA::Push() { + thrust::copy(alpha_, alpha_ + num_topics_, dev_alpha_.begin()); + thrust::copy(grad_alpha_, grad_alpha_ + block_cnt_ * num_topics_, dev_grad_alpha_.begin()); + thrust::copy(beta_, beta_ + num_words_ * num_topics_, dev_beta_.begin()); + thrust::copy(new_beta_, new_beta_ + num_words_ * num_topics_, dev_new_beta_.begin()); + CHECK_CUDA(cudaDeviceSynchronize()); +} + +int CuLDA::GetBlockCnt() { + return block_cnt_; +} + +} // namespace cusim diff --git a/cpp/src/cuw2v/huffman_tree.cc b/cpp/src/cuw2v/huffman_tree.cc new file mode 100644 index 0000000..3914770 --- /dev/null +++ b/cpp/src/cuw2v/huffman_tree.cc 
@@ -0,0 +1,64 @@ +// Copyright (c) 2021 Jisang Yoon +// All rights reserved. +// +// This source code is licensed under the Apache 2.0 license found in the +// LICENSE file in the root directory of this source tree. +#include "cuw2v/cuw2v.hpp" + +namespace cusim { + +struct PqItem { + float count; + int index; + PqItem *left, *right; + bool operator <(const PqItem& left, const PqItem& right) { + return std::tie(left.count, left.index) < std::tie(right.count, right.index); + } +} + +int CuW2V::BuildHuffmanTree(const float* word_count, const int num_words) { + num_words_ = num_words; + if (neg_) { + out_size_ = num_words_; + return; + } + std::priority_queue pq; + for (int i = 0; i < num_words; ++i) { + pq.emplace(word_count[i], i, nullptr, nullptr); + } + for (int i = 0; i < num_words - 1; ++i) { + auto min1 = pq.top(); pq.pop(); + auto min2 = pq.top(); pq.pop(); + pq.emplace(min1.count + min2.count, i + num_words, &min1, &min2); + } + + std::vector, std::vector>> stack = {{pq.top(), {}, {}}}; + PqItem node; + std::vector codes; + std::vector points; + codes_.clear(); points_.clear(); + codes_.resize(num_words); points_.resize(num_words); + int max_depth = 0; + while (not stack.empty()) { + std::tie(node, codes, points) = stack.back(); + stack.pop_back(); + int k = node.index; + if (k < num_words) { + codes_[k] = codes; + points_[k] = points; + } else { + points.push_back(k - num_words); + std::vector left_codes = codes; + std::vector right_codes = codes; + left_codes.push_back(false); + right_codes.push_back(true); + stack.push_back({node.left, left_codes, points}); + stack.push_back({node.right, right_codes, points}); + } + } + + +} + + +} // namespace cusim From d478883421ff64b895b7d3246c92c20fe3908e55 Mon Sep 17 00:00:00 2001 From: js1010 Date: Thu, 11 Feb 2021 23:35:44 +0900 Subject: [PATCH 02/18] copy dir --- cpp/include/cuw2v/cuda_lda_kernels.cuh | 121 +++++++++++++++++++++++++ cpp/include/cuw2v/culda.hpp | 88 ++++++++++++++++++ 2 files changed, 209 insertions(+) create mode 100644 cpp/include/cuw2v/cuda_lda_kernels.cuh create mode 100644 cpp/include/cuw2v/culda.hpp diff --git a/cpp/include/cuw2v/cuda_lda_kernels.cuh b/cpp/include/cuw2v/cuda_lda_kernels.cuh new file mode 100644 index 0000000..02dbb37 --- /dev/null +++ b/cpp/include/cuw2v/cuda_lda_kernels.cuh @@ -0,0 +1,121 @@ +// Copyright (c) 2021 Jisang Yoon +// All rights reserved. +// +// This source code is licensed under the Apache 2.0 license found in the +// LICENSE file in the root directory of this source tree. 
+#pragma once +#include "utils/cuda_utils_kernels.cuh" + + +namespace cusim { + +// reference: http://web.science.mq.edu.au/~mjohnson/code/digamma.c +__inline__ __device__ +float Digamma(float x) { + float result = 0.0f, xx, xx2, xx4; + for ( ; x < 7.0f; ++x) + result -= 1.0f / x; + x -= 0.5f; + xx = 1.0f / x; + xx2 = xx * xx; + xx4 = xx2 * xx2; + result += logf(x) + 1.0f / 24.0f * xx2 + - 7.0f / 960.0f * xx4 + 31.0f / 8064.0f * xx4 * xx2 + - 127.0f / 30720.0f * xx4 * xx4; + return result; +} + +__global__ void EstepKernel( + const int* cols, const int* indptr, const bool* vali, + const int num_cols, const int num_indptr, + const int num_topics, const int num_iters, + float* gamma, float* new_gamma, float* phi, + const float* alpha, const float* beta, + float* grad_alpha, float* new_beta, + float* train_losses, float* vali_losses, int* mutex) { + + // storage for block + float* _gamma = gamma + num_topics * blockIdx.x; + float* _new_gamma = new_gamma + num_topics * blockIdx.x; + float* _phi = phi + num_topics * blockIdx.x; + float* _grad_alpha = grad_alpha + num_topics * blockIdx.x; + + for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) { + int beg = indptr[i], end = indptr[i + 1]; + // initialize gamma + for (int j = threadIdx.x; j < num_topics; j += blockDim.x) + _gamma[j] = alpha[j] + (end - beg) / num_topics; + __syncthreads(); + + // iterate E step + for (int j = 0; j < num_iters; ++j) { + // initialize new gamma + for (int k = threadIdx.x; k < num_topics; k += blockDim.x) + _new_gamma[k] = 0.0f; + __syncthreads(); + + // compute phi from gamma + for (int k = beg; k < end; ++k) { + const int w = cols[k]; + const bool _vali = vali[k]; + + // compute phi + if (not _vali or j + 1 == num_iters) { + for (int l = threadIdx.x; l < num_topics; l += blockDim.x) + _phi[l] = beta[w * num_topics + l] * expf(Digamma(_gamma[l])); + __syncthreads(); + + // normalize phi and add it to new gamma and new beta + float phi_sum = ReduceSum(_phi, num_topics); + + for (int l = threadIdx.x; l < num_topics; l += blockDim.x) { + _phi[l] /= phi_sum; + if (not _vali) _new_gamma[l] += _phi[l]; + } + __syncthreads(); + } + + if (j + 1 == num_iters) { + // write access of w th vector of new_beta + if (threadIdx.x == 0) { + while (atomicCAS(&mutex[w], 0, 1)) {} + } + + __syncthreads(); + for (int l = threadIdx.x; l < num_topics; l += blockDim.x) { + if (j + 1 == num_iters) { + if (not _vali) new_beta[w * num_topics + l] += _phi[l]; + _phi[l] *= beta[w * num_topics + l]; + } + } + __syncthreads(); + + // release lock + if (threadIdx.x == 0) mutex[w] = 0; + __syncthreads(); + + float p = fmaxf(EPS, ReduceSum(_phi, num_topics)); + if (threadIdx.x == 0) { + if (_vali) + vali_losses[blockIdx.x] += logf(p); + else + train_losses[blockIdx.x] += logf(p); + } + } + __syncthreads(); + } + + // update gamma + for (int k = threadIdx.x; k < num_topics; k += blockDim.x) + _gamma[k] = _new_gamma[k] + alpha[k]; + __syncthreads(); + } + float gamma_sum = ReduceSum(_gamma, num_topics); + for (int j = threadIdx.x; j < num_topics; j += blockDim.x) + _grad_alpha[j] += (Digamma(_gamma[j]) - Digamma(gamma_sum)); + + __syncthreads(); + } +} + +} // cusim diff --git a/cpp/include/cuw2v/culda.hpp b/cpp/include/cuw2v/culda.hpp new file mode 100644 index 0000000..3a126cc --- /dev/null +++ b/cpp/include/cuw2v/culda.hpp @@ -0,0 +1,88 @@ +// Copyright (c) 2021 Jisang Yoon +// All rights reserved. +// +// This source code is licensed under the Apache 2.0 license found in the +// LICENSE file in the root directory of this source tree. 
+#pragma once +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT + +#include "json11.hpp" +#include "utils/log.hpp" +#include "utils/types.hpp" + +namespace cusim { + + +// reference: https://people.math.sc.edu/Burkardt/cpp_src/asa121/asa121.cpp +inline float Trigamma(float x) { + const float a = 0.0001f; + const float b = 5.0f; + const float b2 = 0.1666666667f; + const float b4 = -0.03333333333f; + const float b6 = 0.02380952381f; + const float b8 = -0.03333333333f; + float value = 0, y = 0, z = x; + if (x <= a) return 1.0f / x / x; + while (z < b) { + value += 1.0f / z / z; + z++; + } + y = 1.0f / z / z; + value += value + 0.5 * y + (1.0 + + y * (b2 + + y * (b4 + + y * (b6 + + y * b8)))) / z; + return value; +} + + +class CuLDA { + public: + CuLDA(); + ~CuLDA(); + bool Init(std::string opt_path); + void LoadModel(float* alpha, float* beta, + float* grad_alpha, float* new_beta, const int num_words); + std::pair FeedData( + const int* indices, const int* indptr, const bool* vali, + const int num_indices, const int num_indptr, const int num_iters); + void Pull(); + void Push(); + int GetBlockCnt(); + + private: + DeviceInfo dev_info_; + json11::Json opt_; + std::shared_ptr logger_; + thrust::device_vector dev_alpha_, dev_beta_; + thrust::device_vector dev_grad_alpha_, dev_new_beta_; + thrust::device_vector dev_gamma_, dev_new_gamma_, dev_phi_; + thrust::device_vector dev_mutex_; + + float *alpha_, *beta_, *grad_alpha_, *new_beta_; + int block_cnt_, block_dim_; + int num_topics_, num_words_; +}; + +} // namespace cusim From 6b822b419a64abeae2129118f9d505c72e64f3ce Mon Sep 17 00:00:00 2001 From: js1010 Date: Fri, 12 Feb 2021 00:14:26 +0900 Subject: [PATCH 03/18] add cuw2v --- ...a_lda_kernels.cuh => cuda_w2v_kernels.cuh} | 0 cpp/include/cuw2v/culda.hpp | 88 ----------- cpp/include/cuw2v/cuw2v.hpp | 80 ++++++++++ cpp/src/cuw2v/cuw2v.cu | 149 ++++++++++-------- 4 files changed, 160 insertions(+), 157 deletions(-) rename cpp/include/cuw2v/{cuda_lda_kernels.cuh => cuda_w2v_kernels.cuh} (100%) delete mode 100644 cpp/include/cuw2v/culda.hpp create mode 100644 cpp/include/cuw2v/cuw2v.hpp diff --git a/cpp/include/cuw2v/cuda_lda_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_kernels.cuh similarity index 100% rename from cpp/include/cuw2v/cuda_lda_kernels.cuh rename to cpp/include/cuw2v/cuda_w2v_kernels.cuh diff --git a/cpp/include/cuw2v/culda.hpp b/cpp/include/cuw2v/culda.hpp deleted file mode 100644 index 3a126cc..0000000 --- a/cpp/include/cuw2v/culda.hpp +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (c) 2021 Jisang Yoon -// All rights reserved. -// -// This source code is licensed under the Apache 2.0 license found in the -// LICENSE file in the root directory of this source tree. 
-#pragma once -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include // NOLINT - -#include "json11.hpp" -#include "utils/log.hpp" -#include "utils/types.hpp" - -namespace cusim { - - -// reference: https://people.math.sc.edu/Burkardt/cpp_src/asa121/asa121.cpp -inline float Trigamma(float x) { - const float a = 0.0001f; - const float b = 5.0f; - const float b2 = 0.1666666667f; - const float b4 = -0.03333333333f; - const float b6 = 0.02380952381f; - const float b8 = -0.03333333333f; - float value = 0, y = 0, z = x; - if (x <= a) return 1.0f / x / x; - while (z < b) { - value += 1.0f / z / z; - z++; - } - y = 1.0f / z / z; - value += value + 0.5 * y + (1.0 - + y * (b2 - + y * (b4 - + y * (b6 - + y * b8)))) / z; - return value; -} - - -class CuLDA { - public: - CuLDA(); - ~CuLDA(); - bool Init(std::string opt_path); - void LoadModel(float* alpha, float* beta, - float* grad_alpha, float* new_beta, const int num_words); - std::pair FeedData( - const int* indices, const int* indptr, const bool* vali, - const int num_indices, const int num_indptr, const int num_iters); - void Pull(); - void Push(); - int GetBlockCnt(); - - private: - DeviceInfo dev_info_; - json11::Json opt_; - std::shared_ptr logger_; - thrust::device_vector dev_alpha_, dev_beta_; - thrust::device_vector dev_grad_alpha_, dev_new_beta_; - thrust::device_vector dev_gamma_, dev_new_gamma_, dev_phi_; - thrust::device_vector dev_mutex_; - - float *alpha_, *beta_, *grad_alpha_, *new_beta_; - int block_cnt_, block_dim_; - int num_topics_, num_words_; -}; - -} // namespace cusim diff --git a/cpp/include/cuw2v/cuw2v.hpp b/cpp/include/cuw2v/cuw2v.hpp new file mode 100644 index 0000000..c722295 --- /dev/null +++ b/cpp/include/cuw2v/cuw2v.hpp @@ -0,0 +1,80 @@ +// Copyright (c) 2021 Jisang Yoon +// All rights reserved. +// +// This source code is licensed under the Apache 2.0 license found in the +// LICENSE file in the root directory of this source tree. 
+#pragma once +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT + +#include "json11.hpp" +#include "utils/log.hpp" +#include "utils/types.hpp" + +namespace cusim { + + +struct HuffmanTreeNode { + float count; + int index, left, right; + HuffmanTreeNode(float count0, int index0, int left0, int right0) { + count = count0; index = index0; left = left0; right = right0; + } +}; + +std::vector huffman_nodes; +bool CompareIndex(int lhs, int rhs); + +class CuW2V { + public: + CuW2V(); + ~CuW2V(); + bool Init(std::string opt_path); + void LoadModel(float* emb_in, float* emb_out); + void BuildHuffmanTree(const float* word_count, const int num_words); + int GetBlockCnt(); + + private: + DeviceInfo dev_info_; + json11::Json opt_; + std::shared_ptr logger_; + int block_cnt_, block_dim_; + int num_dims_, num_words_; + float *emb_in_, *emb_out_; + thrust::device_vector dev_emb_in_, dev_emb_out_; + + // variables to construct huffman tree + int max_depth_; + std::vector> codes_; + std::vector> points_; + thrust::device_vector dev_codes_; + thrust::device_vector dev_points_, dev_indptr_; + + + bool sg_; + int neg_; + + // mutex to handle concurrent model update + thrust::device_vector dev_mutex_in_, dev_mutex_out_; +}; + +} // namespace cusim diff --git a/cpp/src/cuw2v/cuw2v.cu b/cpp/src/cuw2v/cuw2v.cu index e809f90..ee0d4a6 100644 --- a/cpp/src/cuw2v/cuw2v.cu +++ b/cpp/src/cuw2v/cuw2v.cu @@ -8,6 +8,10 @@ namespace cusim { +bool CompareIndex(int lhs, int rhs) { + return huffman_nodes[lhs].count > huffman_nodes[rhs].count; +} + CuW2V::CuW2V() { logger_ = CuSimLogger().get_logger(); dev_info_ = GetDeviceInfo(); @@ -29,101 +33,108 @@ bool CuW2V::Init(std::string opt_path) { if (not err_cmt.empty()) return false; opt_ = _opt; CuSimLogger().set_log_level(opt_["c_log_level"].int_value()); - num_topics_ = opt_["num_dims"].int_value(); + num_dims_ = opt_["num_dims"].int_value(); block_dim_ = opt_["block_dim"].int_value(); block_cnt_ = opt_["hyper_threads"].number_value() * (dev_info_.cores / block_dim_); sg_ = opt_["skip_gram"].bool_value(); // if zero, we will use hierarchical softmax - neg_ = opt["negative_sampling"].int_value(); + neg_ = opt_["negative_sampling"].int_value(); INFO("num_dims: {}, block_dim: {}, block_cnt: {}, objective type: {}, neg: {}", num_dims_, block_dim_, block_cnt_, sg_? "skip gram": "cbow", neg_); return true; } -void CuW2V::LoadModel(float* emb_in, float* emb_out, const int num_words, int num_hs_nodes = 0) { + +void CuW2V::BuildHuffmanTree(const float* word_count, const int num_words) { num_words_ = num_words; - out_size_ = neg_? 
num_words_: num_hs_nodes; + if (neg_) return; + + huffman_nodes.clear(); + std::priority_queue, decltype(&CompareIndex)> pq(CompareIndex); + for (int i = 0; i < num_words; ++i) { + huffman_nodes.emplace_back(word_count[i], i, -1, -1); + pq.push(i); + } + for (int i = 0; i < num_words - 1; ++i) { + auto& min1 = huffman_nodes[pq.top()]; pq.pop(); + auto& min2 = huffman_nodes[pq.top()]; pq.pop(); + huffman_nodes.emplace_back(min1.count + min2.count, i + num_words, min1.index, min2.index); + pq.push(i + num_words); + } + + std::vector, std::vector>> stack = {{pq.top(), {}, {}}}; + int nodeid; + std::vector codes; + std::vector points; + codes_.clear(); points_.clear(); + codes_.resize(num_words); points_.resize(num_words); + max_depth_ = 0; + while (not stack.empty()) { + std::tie(nodeid, codes, points) = stack.back(); + stack.pop_back(); + if (nodeid < num_words) { + codes_[nodeid] = codes; + points_[nodeid] = points; + max_depth_ = std::max(max_depth_, + static_cast(codes.size())); + } else { + points.push_back(nodeid - num_words); + std::vector left_codes = codes; + std::vector right_codes = codes; + left_codes.push_back(false); + right_codes.push_back(true); + auto& node = huffman_nodes[nodeid]; + stack.push_back(make_tuple(node.left, left_codes, points)); + stack.push_back(make_tuple(node.right, right_codes, points)); + } + } + std::vector host_codes; + std::vector host_points; + std::vector host_indptr = {0}; + int size = 0; + for (int i = 0; i < num_words; ++i) { + auto& codes = codes_[i]; + auto& points = points_[i]; + int n = codes.size(); + size += n; + host_indptr.push_back(size); + for (int j = 0; j < n; ++j) { + host_codes.push_back(static_cast(codes[j])); + host_points.push_back(points[j]); + } + } + + dev_codes_.resize(size); dev_points_.resize(size), dev_indptr_.resize(num_words + 1); + thrust::copy(host_codes.begin(), host_codes.end(), dev_codes_.begin()); + thrust::copy(host_points.begin(), host_points.end(), dev_points_.begin()); + thrust::copy(host_indptr.begin(), host_indptr.end(), dev_indptr_.begin()); + CHECK_CUDA(cudaDeviceSynchronize()); +} + +void CuW2V::LoadModel(float* emb_in, float* emb_out) { + int out_words = neg_? 
num_words_: num_words_ - 1; + // copy embedding DEBUG("copy model({} x {})", num_words_, num_dims_); dev_emb_in_.resize(num_words_ * num_dims_); - dev_emb_out_.resize(out_size_ * num_dims_); + dev_emb_out_.resize(out_words * num_dims_); thrust::copy(emb_in, emb_in + num_words_ * num_dims_, dev_emb_in_.begin()); - thrust::copy(emb_out, emb_out + out_size_ * num_dims_, dev_emb_out_.begin()); + thrust::copy(emb_out, emb_out + out_words * num_dims_, dev_emb_out_.begin()); emb_in_ = emb_in; emb_out_ = emb_out; // set mutex dev_mutex_in_.resize(num_words_); - dev_mutex_out_.resize(out_size_); + dev_mutex_out_.resize(out_words); std::vector host_mutex_in(num_words_, 0); - std::vector host_mutex_out(out_size_, 0); + std::vector host_mutex_out(out_words, 0); thrust::copy(host_mutex_in.begin(), host_mutex_in.end(), dev_mutex_in_.begin()); thrust::copy(host_mutex_out.begin(), host_mutex_out.end(), dev_mutex_out_.begin()); CHECK_CUDA(cudaDeviceSynchronize()); } -std::pair CuLDA::FeedData( - const int* cols, const int* indptr, const bool* vali, - const int num_cols, const int num_indptr) { - - // copy feed data to GPU memory - thrust::device_vector dev_cols(num_cols); - thrust::device_vector dev_indptr(num_indptr + 1); - thrust::device_vector dev_losses(block_cnt_, 0.0f); - thrust::copy(cols, cols + num_cols, dev_cols.begin()); - thrust::copy(indptr, indptr + num_indptr + 1, dev_indptr.begin()); - thrust::copy(vali, vali + num_cols, dev_vali.begin()); - CHECK_CUDA(cudaDeviceSynchronize()); - DEBUG0("copy feed data to GPU memory"); - - // run E step in GPU - EstepKernel<<>>( - thrust::raw_pointer_cast(dev_cols.data()), - thrust::raw_pointer_cast(dev_indptr.data()), - thrust::raw_pointer_cast(dev_vali.data()), - num_cols, num_indptr, num_topics_, num_iters, - thrust::raw_pointer_cast(dev_gamma_.data()), - thrust::raw_pointer_cast(dev_new_gamma_.data()), - thrust::raw_pointer_cast(dev_phi_.data()), - thrust::raw_pointer_cast(dev_alpha_.data()), - thrust::raw_pointer_cast(dev_beta_.data()), - thrust::raw_pointer_cast(dev_grad_alpha_.data()), - thrust::raw_pointer_cast(dev_new_beta_.data()), - thrust::raw_pointer_cast(dev_train_losses.data()), - thrust::raw_pointer_cast(dev_vali_losses.data()), - thrust::raw_pointer_cast(dev_mutex_.data())); - CHECK_CUDA(cudaDeviceSynchronize()); - DEBUG0("run E step in GPU"); - - // pull loss - std::vector train_losses(block_cnt_), vali_losses(block_cnt_); - thrust::copy(dev_train_losses.begin(), dev_train_losses.end(), train_losses.begin()); - thrust::copy(dev_vali_losses.begin(), dev_vali_losses.end(), vali_losses.begin()); - CHECK_CUDA(cudaDeviceSynchronize()); - DEBUG0("pull loss values"); - - // accumulate - float train_loss = std::accumulate(train_losses.begin(), train_losses.end(), 0.0f); - float vali_loss = std::accumulate(vali_losses.begin(), vali_losses.end(), 0.0f); - return {train_loss, vali_loss}; -} - -void CuLDA::Pull() { - thrust::copy(dev_grad_alpha_.begin(), dev_grad_alpha_.end(), grad_alpha_); - thrust::copy(dev_new_beta_.begin(), dev_new_beta_.end(), new_beta_); - CHECK_CUDA(cudaDeviceSynchronize()); -} - -void CuLDA::Push() { - thrust::copy(alpha_, alpha_ + num_topics_, dev_alpha_.begin()); - thrust::copy(grad_alpha_, grad_alpha_ + block_cnt_ * num_topics_, dev_grad_alpha_.begin()); - thrust::copy(beta_, beta_ + num_words_ * num_topics_, dev_beta_.begin()); - thrust::copy(new_beta_, new_beta_ + num_words_ * num_topics_, dev_new_beta_.begin()); - CHECK_CUDA(cudaDeviceSynchronize()); -} - -int CuLDA::GetBlockCnt() { +int CuW2V::GetBlockCnt() { return 
block_cnt_; } From 90ba14423737e8f82cefd394de644c55b69389ea Mon Sep 17 00:00:00 2001 From: js1010 Date: Fri, 12 Feb 2021 00:14:57 +0900 Subject: [PATCH 04/18] remove huffman_tree.cc --- cpp/src/cuw2v/huffman_tree.cc | 64 ----------------------------------- 1 file changed, 64 deletions(-) delete mode 100644 cpp/src/cuw2v/huffman_tree.cc diff --git a/cpp/src/cuw2v/huffman_tree.cc b/cpp/src/cuw2v/huffman_tree.cc deleted file mode 100644 index 3914770..0000000 --- a/cpp/src/cuw2v/huffman_tree.cc +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) 2021 Jisang Yoon -// All rights reserved. -// -// This source code is licensed under the Apache 2.0 license found in the -// LICENSE file in the root directory of this source tree. -#include "cuw2v/cuw2v.hpp" - -namespace cusim { - -struct PqItem { - float count; - int index; - PqItem *left, *right; - bool operator <(const PqItem& left, const PqItem& right) { - return std::tie(left.count, left.index) < std::tie(right.count, right.index); - } -} - -int CuW2V::BuildHuffmanTree(const float* word_count, const int num_words) { - num_words_ = num_words; - if (neg_) { - out_size_ = num_words_; - return; - } - std::priority_queue pq; - for (int i = 0; i < num_words; ++i) { - pq.emplace(word_count[i], i, nullptr, nullptr); - } - for (int i = 0; i < num_words - 1; ++i) { - auto min1 = pq.top(); pq.pop(); - auto min2 = pq.top(); pq.pop(); - pq.emplace(min1.count + min2.count, i + num_words, &min1, &min2); - } - - std::vector, std::vector>> stack = {{pq.top(), {}, {}}}; - PqItem node; - std::vector codes; - std::vector points; - codes_.clear(); points_.clear(); - codes_.resize(num_words); points_.resize(num_words); - int max_depth = 0; - while (not stack.empty()) { - std::tie(node, codes, points) = stack.back(); - stack.pop_back(); - int k = node.index; - if (k < num_words) { - codes_[k] = codes; - points_[k] = points; - } else { - points.push_back(k - num_words); - std::vector left_codes = codes; - std::vector right_codes = codes; - left_codes.push_back(false); - right_codes.push_back(true); - stack.push_back({node.left, left_codes, points}); - stack.push_back({node.right, right_codes, points}); - } - } - - -} - - -} // namespace cusim From 9a973bfc5a64d11ea8bffe737a3ce8df64cac95b Mon Sep 17 00:00:00 2001 From: js1010 Date: Fri, 12 Feb 2021 17:20:36 +0900 Subject: [PATCH 05/18] implement negative sampling --- cpp/include/cuw2v/cuda_w2v_kernels.cuh | 243 ++++++++++++++++--------- 1 file changed, 155 insertions(+), 88 deletions(-) diff --git a/cpp/include/cuw2v/cuda_w2v_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_kernels.cuh index 02dbb37..0aeac0c 100644 --- a/cpp/include/cuw2v/cuda_w2v_kernels.cuh +++ b/cpp/include/cuw2v/cuda_w2v_kernels.cuh @@ -6,115 +6,182 @@ #pragma once #include "utils/cuda_utils_kernels.cuh" +using thrust::random::default_random_engine; +using thrust::random::uniform_int_distribution; namespace cusim { -// reference: http://web.science.mq.edu.au/~mjohnson/code/digamma.c + __inline__ __device__ -float Digamma(float x) { - float result = 0.0f, xx, xx2, xx4; - for ( ; x < 7.0f; ++x) - result -= 1.0f / x; - x -= 0.5f; - xx = 1.0f / x; - xx2 = xx * xx; - xx4 = xx2 * xx2; - result += logf(x) + 1.0f / 24.0f * xx2 - - 7.0f / 960.0f * xx4 + 31.0f / 8064.0f * xx4 * xx2 - - 127.0f / 30720.0f * xx4 * xx4; - return result; +void PositiveFeedback(const float* vec1, float* vec2, float* grad, + float& loss_nume, float& loss_deno, const int num_dims) { + static __shared__ float g; + float dot = Dot(emb_in[num_dims * j], emb_out[num_dims * k], num_dims); + if 
(threadIdx.x == 0) { + float exp_dot = expf(-dot); + g = exp_dot / (1 + exp_dot); + loss_nume += logf(1 + exp_dot); + loss_deno++; + } + __syncthreads(); + for (int i = threadIdx.x; i < num_dims; i += blockDim.x) { + float tmp = vec2[i]; + vec2[i] += vec1[i] * g; + grad[i] += tmp * g; + } + __syncthreads(); } -__global__ void EstepKernel( - const int* cols, const int* indptr, const bool* vali, - const int num_cols, const int num_indptr, - const int num_topics, const int num_iters, - float* gamma, float* new_gamma, float* phi, - const float* alpha, const float* beta, - float* grad_alpha, float* new_beta, - float* train_losses, float* vali_losses, int* mutex) { +__inline__ __device__ +void NegativeFeedback(const float* vec1, float* vec2, float* grad, + float& loss_nume, float& loss_deno, const int num_dims) { + static __shared__ float g; + float dot = Dot(emb_in[num_dims * j], emb_out[num_dims * k], num_dims); + if (threadIdx.x == 0) { + float exp_dot = expf(dot); + g = exp_dot / (1 + exp_dot); + loss_nume += logf(1 + exp_dot); + loss_deno++; + } + __syncthreads(); + for (int i = threadIdx.x; i < num_dims; i += blockDim.x) { + float tmp = vec2[i]; + vec2[i] -= vec1[i] * g; + grad[i] -= tmp * g; + } + __syncthreads(); +} + +__global__ void W2VNegSgKernel( + const int* cols, const int* indptr, const int window, + const int* random_table, const int random_size, default_random_engine* rngs, + const int num_cols, const int num_indptr, const int num_dims, const int neg, + float* emb_in, float* emb_out, float* loss_nume, float* loss_deno) { - // storage for block - float* _gamma = gamma + num_topics * blockIdx.x; - float* _new_gamma = new_gamma + num_topics * blockIdx.x; - float* _phi = phi + num_topics * blockIdx.x; - float* _grad_alpha = grad_alpha + num_topics * blockIdx.x; + default_random_engine& rng = rngs[blockIdx.x]; + float& _loss_nume = loss_nume[blockIdx.x]; + float& _loss_deno = loss_deno[blockIdx.x]; + + static __shared__ uniform_int_distribution dist_neg(0, random_size - 1); + static __shared__ uniform_int_distribution dist_window(0, window - 1); + static __shared__ int reduced_windows; + static __shared__ int neg_word; + extern __shared__ float shared_memory[]; + float* grad = &shared_memory[0]; + + // zero-initialize shared mem + for (int i = threadIdx.x; i < num_dims; i += blockDim.x) + grad[i] = 0.0f; + __syncthreads(); for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) { int beg = indptr[i], end = indptr[i + 1]; - // initialize gamma - for (int j = threadIdx.x; j < num_topics; j += blockDim.x) - _gamma[j] = alpha[j] + (end - beg) / num_topics; - __syncthreads(); - - // iterate E step - for (int j = 0; j < num_iters; ++j) { - // initialize new gamma - for (int k = threadIdx.x; k < num_topics; k += blockDim.x) - _new_gamma[k] = 0.0f; + for (int j = beg; j < end; ++j) { + if (threadIdx.x == 0) reduced_windows = dist_window(rng); __syncthreads(); - - // compute phi from gamma - for (int k = beg; k < end; ++k) { - const int w = cols[k]; - const bool _vali = vali[k]; - - // compute phi - if (not _vali or j + 1 == num_iters) { - for (int l = threadIdx.x; l < num_topics; l += blockDim.x) - _phi[l] = beta[w * num_topics + l] * expf(Digamma(_gamma[l])); - __syncthreads(); - - // normalize phi and add it to new gamma and new beta - float phi_sum = ReduceSum(_phi, num_topics); - - for (int l = threadIdx.x; l < num_topics; l += blockDim.x) { - _phi[l] /= phi_sum; - if (not _vali) _new_gamma[l] += _phi[l]; - } + int beg2 = max(beg, j - window + reduced_windows); + int end2 = min(end, j + 
window - reduced_windows + 1); + float* _emb_in = emb_in + num_dims * cols[j]; + for (int k = beg2; k < end2; ++k) { + if (k == j) continue; + PositiveFeedback(_emb_in, emb_out + num_dims * cols[k], + grad, _loss_nume, _loss_deno, num_dims) + if (int l = 0; l < neg; ++l) { + if (threadIdx.x == 0) neg_word = random_table[dist_neg(rng)]; __syncthreads(); + NegativeFeedback(_emb_in, emb_out + num_dims * neg_word, + grad, _loss_nume, _loss_deno, num_dims); + } + __syncthreads(); + for (int l = threadIdx.x; l < num_dims; l += blockDim.x) { + emb_in[num_dims * j + l] += grad[l]; + grad[l] = 0.0f; } - - if (j + 1 == num_iters) { - // write access of w th vector of new_beta - if (threadIdx.x == 0) { - while (atomicCAS(&mutex[w], 0, 1)) {} - } + __syncthreads(); + } + } + } +} - __syncthreads(); - for (int l = threadIdx.x; l < num_topics; l += blockDim.x) { - if (j + 1 == num_iters) { - if (not _vali) new_beta[w * num_topics + l] += _phi[l]; - _phi[l] *= beta[w * num_topics + l]; - } - } - __syncthreads(); +__global__ void W2VNegCbowKernel( + const int* cols, const int* indptr, const int window, + const int* random_table, const int random_size, default_random_engine* rngs, + const int num_cols, const int num_indptr, const int num_dims, const int neg, + float* emb_in, float* emb_out, float* loss_nume, float* loss_deno, const bool use_mean) { + + default_random_engine& rng = rngs[blockIdx.x]; + float& _loss_nume = loss_nume[blockIdx.x]; + float& _loss_deno = loss_deno[blockIdx.x]; - // release lock - if (threadIdx.x == 0) mutex[w] = 0; - __syncthreads(); + static __shared__ uniform_int_distribution dist_neg(0, random_size - 1); + static __shared__ uniform_int_distribution dist_window(0, window - 1); + static __shared__ int reduced_windows; + static __shared__ int neg_word; + extern __shared__ float shared_memory[]; + float* grad = &shared_memory[0]; + float* cbow = &shared_memory[num_dims]; - float p = fmaxf(EPS, ReduceSum(_phi, num_topics)); - if (threadIdx.x == 0) { - if (_vali) - vali_losses[blockIdx.x] += logf(p); - else - train_losses[blockIdx.x] += logf(p); - } + __syncthreads(); + + for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) { + int beg = indptr[i], end = indptr[i + 1]; + for (int j = beg; j < end; ++j) { + if (threadIdx.x == 0) reduced_windows = dist_window(rng); + __syncthreads(); + int beg2 = max(beg, j - window + reduced_windows); + int end2 = min(end, j + window - reduced_windows + 1); + if (end2 - beg2 <= 1) continue; + + // zero-initialize shared mem + for (int k = threadIdx.x; k < num_dims; k += blockDim.x) { + grad[k] = 0.0f; + cbow[k] = 0.0f; + } + + // compute cbow + for (int k = beg2; k < end2; ++k) { + if (k == j) continue; + for (int l = threadIdx.x; l < num_dims; l += blockDim.x) { + cbow[l] += emb_in[num_dims * cols[k] + l]; } - __syncthreads(); } + if (use_mean) { + for (int k = threadIdx.x; k < num_dims; k += blockDim.x) { + cbow[k] /= (end2 - beg2 - 1); + } + } + __syncthreads(); + + PositiveFeedback(cbow, emb_out + num_dims * cols[j], grad, + loss_nume, loss_deno, num_dims); + __syncthreads(); + + // update negative feedback + for (int k = 0; k < neg; ++k){ + if (threadIdx.x == 0) neg_word = random_table[dist_neg(rng)]; + __syncthredas(); + NegativeFeedback(cbow, emb_out + num_dims * neg_word, + grad, _loss_nume, _loss_deno, num_dims); + } + __syncthreads(); + + // normalize grad if use_mean = true + if (use_mean) { + for (int k = threadIdx.x; k < num_dims; k += blockDim.x) { + grad[k] /= (end2 - beg2 - 1); + } + } + __syncthreads(); - // update gamma - for 
(int k = threadIdx.x; k < num_topics; k += blockDim.x) - _gamma[k] = _new_gamma[k] + alpha[k]; + // update emb_in + for (int k = beg2; k < end2; ++k) { + if (k == j) continue; + for (int l = threadIdx.x; l < num_dims; l += blockDim.x) + emb_in[num_dims * cols[k] + l] += grad[l]; + } __syncthreads(); - } - float gamma_sum = ReduceSum(_gamma, num_topics); - for (int j = threadIdx.x; j < num_topics; j += blockDim.x) - _grad_alpha[j] += (Digamma(_gamma[j]) - Digamma(gamma_sum)); - __syncthreads(); + } } } From ddd9095d92a25dd9ce42d428986eb46a21c2dd55 Mon Sep 17 00:00:00 2001 From: js1010 Date: Fri, 12 Feb 2021 17:22:25 +0900 Subject: [PATCH 06/18] separate files --- cpp/include/cuw2v/cuda_w2v_base_kernels.cuh | 55 +++++++++++++++++++ ...2v_kernels.cuh => cuda_w2v_ns_kernels.cuh} | 42 +------------- 2 files changed, 56 insertions(+), 41 deletions(-) create mode 100644 cpp/include/cuw2v/cuda_w2v_base_kernels.cuh rename cpp/include/cuw2v/{cuda_w2v_kernels.cuh => cuda_w2v_ns_kernels.cuh} (81%) diff --git a/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh new file mode 100644 index 0000000..ace4964 --- /dev/null +++ b/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh @@ -0,0 +1,55 @@ +// Copyright (c) 2021 Jisang Yoon +// All rights reserved. +// +// This source code is licensed under the Apache 2.0 license found in the +// LICENSE file in the root directory of this source tree. +#pragma once +#include "utils/cuda_utils_kernels.cuh" + +using thrust::random::default_random_engine; +using thrust::random::uniform_int_distribution; + +namespace cusim { + + +__inline__ __device__ +void PositiveFeedback(const float* vec1, float* vec2, float* grad, + float& loss_nume, float& loss_deno, const int num_dims) { + static __shared__ float g; + float dot = Dot(emb_in[num_dims * j], emb_out[num_dims * k], num_dims); + if (threadIdx.x == 0) { + float exp_dot = expf(-dot); + g = exp_dot / (1 + exp_dot); + loss_nume += logf(1 + exp_dot); + loss_deno++; + } + __syncthreads(); + for (int i = threadIdx.x; i < num_dims; i += blockDim.x) { + float tmp = vec2[i]; + vec2[i] += vec1[i] * g; + grad[i] += tmp * g; + } + __syncthreads(); +} + +__inline__ __device__ +void NegativeFeedback(const float* vec1, float* vec2, float* grad, + float& loss_nume, float& loss_deno, const int num_dims) { + static __shared__ float g; + float dot = Dot(emb_in[num_dims * j], emb_out[num_dims * k], num_dims); + if (threadIdx.x == 0) { + float exp_dot = expf(dot); + g = exp_dot / (1 + exp_dot); + loss_nume += logf(1 + exp_dot); + loss_deno++; + } + __syncthreads(); + for (int i = threadIdx.x; i < num_dims; i += blockDim.x) { + float tmp = vec2[i]; + vec2[i] -= vec1[i] * g; + grad[i] -= tmp * g; + } + __syncthreads(); +} + +} // cusim diff --git a/cpp/include/cuw2v/cuda_w2v_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh similarity index 81% rename from cpp/include/cuw2v/cuda_w2v_kernels.cuh rename to cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh index 0aeac0c..d4a07a1 100644 --- a/cpp/include/cuw2v/cuda_w2v_kernels.cuh +++ b/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh @@ -5,53 +5,13 @@ // LICENSE file in the root directory of this source tree. 
#pragma once #include "utils/cuda_utils_kernels.cuh" +#include "w2v/cuda_w2v_base_kernels.cuh" using thrust::random::default_random_engine; using thrust::random::uniform_int_distribution; namespace cusim { - -__inline__ __device__ -void PositiveFeedback(const float* vec1, float* vec2, float* grad, - float& loss_nume, float& loss_deno, const int num_dims) { - static __shared__ float g; - float dot = Dot(emb_in[num_dims * j], emb_out[num_dims * k], num_dims); - if (threadIdx.x == 0) { - float exp_dot = expf(-dot); - g = exp_dot / (1 + exp_dot); - loss_nume += logf(1 + exp_dot); - loss_deno++; - } - __syncthreads(); - for (int i = threadIdx.x; i < num_dims; i += blockDim.x) { - float tmp = vec2[i]; - vec2[i] += vec1[i] * g; - grad[i] += tmp * g; - } - __syncthreads(); -} - -__inline__ __device__ -void NegativeFeedback(const float* vec1, float* vec2, float* grad, - float& loss_nume, float& loss_deno, const int num_dims) { - static __shared__ float g; - float dot = Dot(emb_in[num_dims * j], emb_out[num_dims * k], num_dims); - if (threadIdx.x == 0) { - float exp_dot = expf(dot); - g = exp_dot / (1 + exp_dot); - loss_nume += logf(1 + exp_dot); - loss_deno++; - } - __syncthreads(); - for (int i = threadIdx.x; i < num_dims; i += blockDim.x) { - float tmp = vec2[i]; - vec2[i] -= vec1[i] * g; - grad[i] -= tmp * g; - } - __syncthreads(); -} - __global__ void W2VNegSgKernel( const int* cols, const int* indptr, const int window, const int* random_table, const int random_size, default_random_engine* rngs, From 4bcb5abc60acdd8842b8f217b5b93f89e8031b2e Mon Sep 17 00:00:00 2001 From: js1010 Date: Fri, 12 Feb 2021 17:22:53 +0900 Subject: [PATCH 07/18] add Dot product function --- cpp/include/utils/cuda_utils_kernels.cuh | 39 ++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/cpp/include/utils/cuda_utils_kernels.cuh b/cpp/include/utils/cuda_utils_kernels.cuh index 026da7f..670303d 100644 --- a/cpp/include/utils/cuda_utils_kernels.cuh +++ b/cpp/include/utils/cuda_utils_kernels.cuh @@ -130,6 +130,45 @@ float warp_reduce_sum(float val) { return val; } +__inline__ __device__ +float Dot(const float* vec1, const float* vec2, const int length) { + + static __shared__ float shared[32]; + + // figure out the warp/ position inside the warp + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // paritial sum + float val = 0.0f; + for (int i = threadIdx.x; i < length; i += blockDim.x) + val += vec1[i] * vec2[i]; + val = warp_reduce_sum(val); + + // write out the partial reduction to shared memory if appropiate + if (lane == 0) { + shared[warp] = val; + } + __syncthreads(); + + // if we we don't have multiple warps, we're done + if (blockDim.x <= WARP_SIZE) { + return shared[0]; + } + + // otherwise reduce again in the first warp + val = (threadIdx.x < blockDim.x / WARP_SIZE) ? 
shared[lane]: 0.0f; + if (warp == 0) { + val = warp_reduce_sum(val); + // broadcast back to shared memory + if (threadIdx.x == 0) { + shared[0] = val; + } + } + __syncthreads(); + return shared[0]; +} + __inline__ __device__ float ReduceSum(const float* vec, const int length) { From 31929737010196fc11b0b3df2f299bfd72f96216 Mon Sep 17 00:00:00 2001 From: js1010 Date: Fri, 12 Feb 2021 18:08:03 +0900 Subject: [PATCH 08/18] implement hierarchical softmax --- cpp/include/cuw2v/cuda_w2v_base_kernels.cuh | 8 +- cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh | 146 ++++++++++++++++++++ cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh | 16 +-- 3 files changed, 158 insertions(+), 12 deletions(-) create mode 100644 cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh diff --git a/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh index ace4964..35adfa7 100644 --- a/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh +++ b/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh @@ -14,12 +14,12 @@ namespace cusim { __inline__ __device__ void PositiveFeedback(const float* vec1, float* vec2, float* grad, - float& loss_nume, float& loss_deno, const int num_dims) { + float& loss_nume, float& loss_deno, const int num_dims, const float lr) { static __shared__ float g; float dot = Dot(emb_in[num_dims * j], emb_out[num_dims * k], num_dims); if (threadIdx.x == 0) { float exp_dot = expf(-dot); - g = exp_dot / (1 + exp_dot); + g = exp_dot / (1 + exp_dot) * lr; loss_nume += logf(1 + exp_dot); loss_deno++; } @@ -34,12 +34,12 @@ void PositiveFeedback(const float* vec1, float* vec2, float* grad, __inline__ __device__ void NegativeFeedback(const float* vec1, float* vec2, float* grad, - float& loss_nume, float& loss_deno, const int num_dims) { + float& loss_nume, float& loss_deno, const int num_dims, const float lr) { static __shared__ float g; float dot = Dot(emb_in[num_dims * j], emb_out[num_dims * k], num_dims); if (threadIdx.x == 0) { float exp_dot = expf(dot); - g = exp_dot / (1 + exp_dot); + g = exp_dot / (1 + exp_dot) * lr; loss_nume += logf(1 + exp_dot); loss_deno++; } diff --git a/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh new file mode 100644 index 0000000..d25ec13 --- /dev/null +++ b/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh @@ -0,0 +1,146 @@ +// Copyright (c) 2021 Jisang Yoon +// All rights reserved. +// +// This source code is licensed under the Apache 2.0 license found in the +// LICENSE file in the root directory of this source tree. 
+#pragma once +#include "utils/cuda_utils_kernels.cuh" +#include "w2v/cuda_w2v_base_kernels.cuh" + +using thrust::random::default_random_engine; +using thrust::random::uniform_int_distribution; + +namespace cusim { + +__global__ void W2VHsSgKernel( + const int* cols, const int* indptr, const int window, + const bool* codes, const int* points, const int* hs_indptr, + const int num_indptr, const int num_dims, + float* emb_in, float* emb_out, + float* loss_nume, float* loss_deno, const float lr) { + + float& _loss_nume = loss_nume[blockIdx.x]; + float& _loss_deno = loss_deno[blockIdx.x]; + + static __shared__ int reduced_windows; + extern __shared__ float shared_memory[]; + float* grad = &shared_memory[0]; + + // zero-initialize shared mem + for (int i = threadIdx.x; i < num_dims; i += blockDim.x) + grad[i] = 0.0f; + __syncthreads(); + + for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) { + int beg = indptr[i], end = indptr[i + 1]; + for (int j = beg; j < end; ++j) { + if (threadIdx.x == 0) reduced_windows = dist_window(rng); + __syncthreads(); + int beg2 = max(beg, j - window + reduced_windows); + int end2 = min(end, j + window - reduced_windows + 1); + float* _emb_in = emb_in + num_dims * cols[j]; + for (int k = beg2; k < end2; ++k) { + if (k == j) continue; + int beg3 = hs_indptr[cols[k]]; + int end3 = hs_indptr[cols[k] + 1]; + if (int l = beg3; l < end3; ++l) { + if (codes[l]) { + PositiveFeedback(_emb_in, emb_out + num_dims * points[l], + grad, _loss_nume, _loss_deno, num_dims, lr); + } else { + NegativeFeedback(_emb_in, emb_out + num_dims * points[l], + grad, _loss_nume, _loss_deno, num_dims, lr); + } + __syncthreads(); + } + for (int l = threadIdx.x; l < num_dims; l += blockDim.x) { + emb_in[num_dims * cols[j] + l] += grad[l]; + grad[l] = 0.0f; + } + __syncthreads(); + } + } + } +} + +__global__ void W2VHsCbowKernel( + const int* cols, const int* indptr, const int window, + const bool* codes, const int* points, const int* hs_indptr, + const int num_indptr, const int num_dims, const int neg, + float* emb_in, float* emb_out, + float* loss_nume, float* loss_deno, + const bool use_mean, const float lr) { + + float& _loss_nume = loss_nume[blockIdx.x]; + float& _loss_deno = loss_deno[blockIdx.x]; + + static __shared__ int reduced_windows; + extern __shared__ float shared_memory[]; + float* grad = &shared_memory[0]; + float* cbow = &shared_memory[num_dims]; + + __syncthreads(); + + for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) { + int beg = indptr[i], end = indptr[i + 1]; + for (int j = beg; j < end; ++j) { + if (threadIdx.x == 0) reduced_windows = dist_window(rng); + __syncthreads(); + int beg2 = max(beg, j - window + reduced_windows); + int end2 = min(end, j + window - reduced_windows + 1); + if (end2 - beg2 <= 1) continue; + + // zero-initialize shared mem + for (int k = threadIdx.x; k < num_dims; k += blockDim.x) { + grad[k] = 0.0f; + cbow[k] = 0.0f; + } + + // compute cbow + for (int k = beg2; k < end2; ++k) { + if (k == j) continue; + for (int l = threadIdx.x; l < num_dims; l += blockDim.x) { + cbow[l] += emb_in[num_dims * cols[k] + l]; + } + } + if (use_mean) { + for (int k = threadIdx.x; k < num_dims; k += blockDim.x) { + cbow[k] /= (end2 - beg2 - 1); + } + } + __syncthreads(); + + int beg3 = hs_indptr[cols[k]]; + int end3 = hs_indptr[cols[k] + 1]; + if (int k = beg3; k < end3; ++k) { + if (codes[k]) { + PositiveFeedback(cbow, emb_out + num_dims * points[k], + grad, _loss_nume, _loss_deno, num_dims, lr); + } else { + NegativeFeedback(cbow, emb_out + num_dims * 
points[k], + grad, _loss_nume, _loss_deno, num_dims, lr); + } + __syncthreads(); + } + + // normalize grad if use_mean = true + if (use_mean) { + for (int k = threadIdx.x; k < num_dims; k += blockDim.x) { + grad[k] /= (end2 - beg2 - 1); + } + } + __syncthreads(); + + // update emb_in + for (int k = beg2; k < end2; ++k) { + if (k == j) continue; + for (int l = threadIdx.x; l < num_dims; l += blockDim.x) { + emb_in[num_dims * cols[k] + l] += grad[l]; + } + __syncthreads(); + } + } + } +} + +} // cusim diff --git a/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh index d4a07a1..9e3acfc 100644 --- a/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh +++ b/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh @@ -15,8 +15,8 @@ namespace cusim { __global__ void W2VNegSgKernel( const int* cols, const int* indptr, const int window, const int* random_table, const int random_size, default_random_engine* rngs, - const int num_cols, const int num_indptr, const int num_dims, const int neg, - float* emb_in, float* emb_out, float* loss_nume, float* loss_deno) { + const int num_indptr, const int num_dims, const int neg, + float* emb_in, float* emb_out, float* loss_nume, float* loss_deno, const float lr) { default_random_engine& rng = rngs[blockIdx.x]; float& _loss_nume = loss_nume[blockIdx.x]; @@ -45,16 +45,16 @@ __global__ void W2VNegSgKernel( for (int k = beg2; k < end2; ++k) { if (k == j) continue; PositiveFeedback(_emb_in, emb_out + num_dims * cols[k], - grad, _loss_nume, _loss_deno, num_dims) + grad, _loss_nume, _loss_deno, num_dims, lr) if (int l = 0; l < neg; ++l) { if (threadIdx.x == 0) neg_word = random_table[dist_neg(rng)]; __syncthreads(); NegativeFeedback(_emb_in, emb_out + num_dims * neg_word, - grad, _loss_nume, _loss_deno, num_dims); + grad, _loss_nume, _loss_deno, num_dims, lr); } __syncthreads(); for (int l = threadIdx.x; l < num_dims; l += blockDim.x) { - emb_in[num_dims * j + l] += grad[l]; + emb_in[num_dims * cols[j] + l] += grad[l]; grad[l] = 0.0f; } __syncthreads(); @@ -66,7 +66,7 @@ __global__ void W2VNegSgKernel( __global__ void W2VNegCbowKernel( const int* cols, const int* indptr, const int window, const int* random_table, const int random_size, default_random_engine* rngs, - const int num_cols, const int num_indptr, const int num_dims, const int neg, + const int num_indptr, const int num_dims, const int neg, float* emb_in, float* emb_out, float* loss_nume, float* loss_deno, const bool use_mean) { default_random_engine& rng = rngs[blockIdx.x]; @@ -113,7 +113,7 @@ __global__ void W2VNegCbowKernel( __syncthreads(); PositiveFeedback(cbow, emb_out + num_dims * cols[j], grad, - loss_nume, loss_deno, num_dims); + loss_nume, loss_deno, num_dims, lr); __syncthreads(); // update negative feedback @@ -121,7 +121,7 @@ __global__ void W2VNegCbowKernel( if (threadIdx.x == 0) neg_word = random_table[dist_neg(rng)]; __syncthredas(); NegativeFeedback(cbow, emb_out + num_dims * neg_word, - grad, _loss_nume, _loss_deno, num_dims); + grad, _loss_nume, _loss_deno, num_dims, lr); } __syncthreads(); From 30aae5557791a8fe74a17a10b5ba804b8847cba0 Mon Sep 17 00:00:00 2001 From: js1010 Date: Fri, 12 Feb 2021 18:09:36 +0900 Subject: [PATCH 09/18] add lr to variables --- cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh index 9e3acfc..f25342d 100644 --- a/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh +++ b/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh @@ 
-67,7 +67,8 @@ __global__ void W2VNegCbowKernel( const int* cols, const int* indptr, const int window, const int* random_table, const int random_size, default_random_engine* rngs, const int num_indptr, const int num_dims, const int neg, - float* emb_in, float* emb_out, float* loss_nume, float* loss_deno, const bool use_mean) { + float* emb_in, float* emb_out, + float* loss_nume, float* loss_deno, const bool use_mean, const float lr) { default_random_engine& rng = rngs[blockIdx.x]; float& _loss_nume = loss_nume[blockIdx.x]; From 414bdb1dadce0ccec3a63d34f21e1db782be0800 Mon Sep 17 00:00:00 2001 From: js1010 Date: Fri, 12 Feb 2021 21:22:21 +0900 Subject: [PATCH 10/18] no lint warning --- cpp/include/cuw2v/cuda_w2v_base_kernels.cuh | 10 +- cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh | 21 ++-- cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh | 26 ++--- cpp/include/cuw2v/cuw2v.hpp | 17 ++-- cpp/src/cuw2v/cuw2v.cu | 102 +++++++++++++------- 5 files changed, 107 insertions(+), 69 deletions(-) diff --git a/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh index 35adfa7..e831b05 100644 --- a/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh +++ b/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh @@ -16,7 +16,7 @@ __inline__ __device__ void PositiveFeedback(const float* vec1, float* vec2, float* grad, float& loss_nume, float& loss_deno, const int num_dims, const float lr) { static __shared__ float g; - float dot = Dot(emb_in[num_dims * j], emb_out[num_dims * k], num_dims); + float dot = Dot(vec1, vec2, num_dims); if (threadIdx.x == 0) { float exp_dot = expf(-dot); g = exp_dot / (1 + exp_dot) * lr; @@ -25,9 +25,8 @@ void PositiveFeedback(const float* vec1, float* vec2, float* grad, } __syncthreads(); for (int i = threadIdx.x; i < num_dims; i += blockDim.x) { - float tmp = vec2[i]; + grad[i] += vec2[i] * g; vec2[i] += vec1[i] * g; - grad[i] += tmp * g; } __syncthreads(); } @@ -36,7 +35,7 @@ __inline__ __device__ void NegativeFeedback(const float* vec1, float* vec2, float* grad, float& loss_nume, float& loss_deno, const int num_dims, const float lr) { static __shared__ float g; - float dot = Dot(emb_in[num_dims * j], emb_out[num_dims * k], num_dims); + float dot = Dot(vec1, vec2, num_dims); if (threadIdx.x == 0) { float exp_dot = expf(dot); g = exp_dot / (1 + exp_dot) * lr; @@ -45,9 +44,8 @@ void NegativeFeedback(const float* vec1, float* vec2, float* grad, } __syncthreads(); for (int i = threadIdx.x; i < num_dims; i += blockDim.x) { - float tmp = vec2[i]; + grad[i] -= vec2[i] * g; vec2[i] -= vec1[i] * g; - grad[i] -= tmp * g; } __syncthreads(); } diff --git a/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh index d25ec13..1570d80 100644 --- a/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh +++ b/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh @@ -5,23 +5,24 @@ // LICENSE file in the root directory of this source tree. 
#pragma once #include "utils/cuda_utils_kernels.cuh" -#include "w2v/cuda_w2v_base_kernels.cuh" +#include "cuw2v/cuda_w2v_base_kernels.cuh" -using thrust::random::default_random_engine; -using thrust::random::uniform_int_distribution; namespace cusim { __global__ void W2VHsSgKernel( const int* cols, const int* indptr, const int window, const bool* codes, const int* points, const int* hs_indptr, - const int num_indptr, const int num_dims, + const int num_indptr, const int num_dims, const int window_size, + default_random_engine* rngs, float* emb_in, float* emb_out, float* loss_nume, float* loss_deno, const float lr) { + default_random_engine& rng = rngs[blockIdx.x]; float& _loss_nume = loss_nume[blockIdx.x]; float& _loss_deno = loss_deno[blockIdx.x]; + uniform_int_distribution dist_window(0, window_size - 1); static __shared__ int reduced_windows; extern __shared__ float shared_memory[]; float* grad = &shared_memory[0]; @@ -43,7 +44,7 @@ __global__ void W2VHsSgKernel( if (k == j) continue; int beg3 = hs_indptr[cols[k]]; int end3 = hs_indptr[cols[k] + 1]; - if (int l = beg3; l < end3; ++l) { + for (int l = beg3; l < end3; ++l) { if (codes[l]) { PositiveFeedback(_emb_in, emb_out + num_dims * points[l], grad, _loss_nume, _loss_deno, num_dims, lr); @@ -66,14 +67,16 @@ __global__ void W2VHsSgKernel( __global__ void W2VHsCbowKernel( const int* cols, const int* indptr, const int window, const bool* codes, const int* points, const int* hs_indptr, - const int num_indptr, const int num_dims, const int neg, + const int num_indptr, const int num_dims, const int window_size, default_random_engine* rngs, float* emb_in, float* emb_out, float* loss_nume, float* loss_deno, const bool use_mean, const float lr) { + default_random_engine& rng = rngs[blockIdx.x]; float& _loss_nume = loss_nume[blockIdx.x]; float& _loss_deno = loss_deno[blockIdx.x]; + uniform_int_distribution dist_window(0, window_size - 1); static __shared__ int reduced_windows; extern __shared__ float shared_memory[]; float* grad = &shared_memory[0]; @@ -110,9 +113,9 @@ __global__ void W2VHsCbowKernel( } __syncthreads(); - int beg3 = hs_indptr[cols[k]]; - int end3 = hs_indptr[cols[k] + 1]; - if (int k = beg3; k < end3; ++k) { + int beg3 = hs_indptr[cols[j]]; + int end3 = hs_indptr[cols[j] + 1]; + for (int k = beg3; k < end3; ++k) { if (codes[k]) { PositiveFeedback(cbow, emb_out + num_dims * points[k], grad, _loss_nume, _loss_deno, num_dims, lr); diff --git a/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh index f25342d..be8c160 100644 --- a/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh +++ b/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh @@ -5,7 +5,7 @@ // LICENSE file in the root directory of this source tree. 
#pragma once #include "utils/cuda_utils_kernels.cuh" -#include "w2v/cuda_w2v_base_kernels.cuh" +#include "cuw2v/cuda_w2v_base_kernels.cuh" using thrust::random::default_random_engine; using thrust::random::uniform_int_distribution; @@ -15,17 +15,17 @@ namespace cusim { __global__ void W2VNegSgKernel( const int* cols, const int* indptr, const int window, const int* random_table, const int random_size, default_random_engine* rngs, - const int num_indptr, const int num_dims, const int neg, + const int num_indptr, const int num_dims, const int neg, const int window_size, float* emb_in, float* emb_out, float* loss_nume, float* loss_deno, const float lr) { default_random_engine& rng = rngs[blockIdx.x]; float& _loss_nume = loss_nume[blockIdx.x]; float& _loss_deno = loss_deno[blockIdx.x]; - static __shared__ uniform_int_distribution dist_neg(0, random_size - 1); - static __shared__ uniform_int_distribution dist_window(0, window - 1); - static __shared__ int reduced_windows; - static __shared__ int neg_word; + uniform_int_distribution dist_neg(0, random_size - 1); + uniform_int_distribution dist_window(0, window_size - 1); + __shared__ int reduced_windows; + __shared__ int neg_word; extern __shared__ float shared_memory[]; float* grad = &shared_memory[0]; @@ -45,8 +45,8 @@ __global__ void W2VNegSgKernel( for (int k = beg2; k < end2; ++k) { if (k == j) continue; PositiveFeedback(_emb_in, emb_out + num_dims * cols[k], - grad, _loss_nume, _loss_deno, num_dims, lr) - if (int l = 0; l < neg; ++l) { + grad, _loss_nume, _loss_deno, num_dims, lr); + for (int l = 0; l < neg; ++l) { if (threadIdx.x == 0) neg_word = random_table[dist_neg(rng)]; __syncthreads(); NegativeFeedback(_emb_in, emb_out + num_dims * neg_word, @@ -66,7 +66,7 @@ __global__ void W2VNegSgKernel( __global__ void W2VNegCbowKernel( const int* cols, const int* indptr, const int window, const int* random_table, const int random_size, default_random_engine* rngs, - const int num_indptr, const int num_dims, const int neg, + const int num_indptr, const int num_dims, const int neg, const int window_size, float* emb_in, float* emb_out, float* loss_nume, float* loss_deno, const bool use_mean, const float lr) { @@ -74,8 +74,8 @@ __global__ void W2VNegCbowKernel( float& _loss_nume = loss_nume[blockIdx.x]; float& _loss_deno = loss_deno[blockIdx.x]; - static __shared__ uniform_int_distribution dist_neg(0, random_size - 1); - static __shared__ uniform_int_distribution dist_window(0, window - 1); + uniform_int_distribution dist_neg(0, random_size - 1); + uniform_int_distribution dist_window(0, window_size - 1); static __shared__ int reduced_windows; static __shared__ int neg_word; extern __shared__ float shared_memory[]; @@ -114,13 +114,13 @@ __global__ void W2VNegCbowKernel( __syncthreads(); PositiveFeedback(cbow, emb_out + num_dims * cols[j], grad, - loss_nume, loss_deno, num_dims, lr); + _loss_nume, _loss_deno, num_dims, lr); __syncthreads(); // update negative feedback for (int k = 0; k < neg; ++k){ if (threadIdx.x == 0) neg_word = random_table[dist_neg(rng)]; - __syncthredas(); + __syncthreads(); NegativeFeedback(cbow, emb_out + num_dims * neg_word, grad, _loss_nume, _loss_deno, num_dims, lr); } diff --git a/cpp/include/cuw2v/cuw2v.hpp b/cpp/include/cuw2v/cuw2v.hpp index c722295..3313acd 100644 --- a/cpp/include/cuw2v/cuw2v.hpp +++ b/cpp/include/cuw2v/cuw2v.hpp @@ -32,6 +32,7 @@ namespace cusim { +bool CompareIndex(int lhs, int rhs); struct HuffmanTreeNode { float count; @@ -51,7 +52,11 @@ class CuW2V { bool Init(std::string opt_path); void 
LoadModel(float* emb_in, float* emb_out); void BuildHuffmanTree(const float* word_count, const int num_words); + void BuildRandomTable(const float* word_count, const int num_words, + const int table_size, const int num_threads); int GetBlockCnt(); + float FeedData(const int* cols, const int* indptr, + const int num_cols, const int num_indptr); private: DeviceInfo dev_info_; @@ -64,17 +69,17 @@ class CuW2V { // variables to construct huffman tree int max_depth_; - std::vector> codes_; - std::vector> points_; thrust::device_vector dev_codes_; - thrust::device_vector dev_points_, dev_indptr_; - + thrust::device_vector dev_points_, dev_hs_indptr_; + // related to negative sampling / hierarchical softmax and skip gram / cbow bool sg_; int neg_; - // mutex to handle concurrent model update - thrust::device_vector dev_mutex_in_, dev_mutex_out_; + // variables to construct random table + thrust::device_vector dev_random_table_; + int table_size_, table_seed_; + std::mt19937 table_rng_; }; } // namespace cusim diff --git a/cpp/src/cuw2v/cuw2v.cu b/cpp/src/cuw2v/cuw2v.cu index ee0d4a6..5cefa9f 100644 --- a/cpp/src/cuw2v/cuw2v.cu +++ b/cpp/src/cuw2v/cuw2v.cu @@ -4,7 +4,9 @@ // This source code is licensed under the Apache 2.0 license found in the // LICENSE file in the root directory of this source tree. #include "cuw2v/cuw2v.hpp" -#include "cuw2v/cuda_w2v_kernels.cuh" +#include "cuw2v/cuda_w2v_base_kernels.cuh" +#include "cuw2v/cuda_w2v_ns_kernels.cuh" +#include "cuw2v/cuda_w2v_hs_kernels.cuh" namespace cusim { @@ -37,17 +39,50 @@ bool CuW2V::Init(std::string opt_path) { block_dim_ = opt_["block_dim"].int_value(); block_cnt_ = opt_["hyper_threads"].number_value() * (dev_info_.cores / block_dim_); sg_ = opt_["skip_gram"].bool_value(); + // if zero, we will use hierarchical softmax neg_ = opt_["negative_sampling"].int_value(); + + // set seed for constructing random table of negative sampling + table_seed_ = opt_["table_seed"].int_value(); + const unsigned int table_seed = table_seed_; + table_rng_.seed(table_seed); + INFO("num_dims: {}, block_dim: {}, block_cnt: {}, objective type: {}, neg: {}", num_dims_, block_dim_, block_cnt_, sg_? 
"skip gram": "cbow", neg_); return true; } +void CuW2V::BuildRandomTable(const float* word_count, const int num_words, + const int table_size, const int num_threads) { + num_words_ = num_words; + table_size_ = table_size; + std::vector acc; + float cumsum = 0; + for (int i = 0; i < num_words; ++i) { + cumsum += word_count[i]; + acc.push_back(cumsum); + } + + std::uniform_real_distribution dist(0.0f, cumsum); + dev_random_table_.resize(table_size_); + std::vector host_random_table(table_size); + #pragma omp parallel num_threads(num_threads) + { + #pragma omp for schedule(static) + for (int i = 0; i < table_size_; ++i) { + float r = dist(table_rng_); + int pos = std::lower_bound(acc.begin(), acc.end(), r) - acc.begin(); + host_random_table[i] = pos; + } + } + + thrust::copy(host_random_table.begin(), host_random_table.end(), dev_random_table_.begin()); + CHECK_CUDA(cudaDeviceSynchronize()); +} void CuW2V::BuildHuffmanTree(const float* word_count, const int num_words) { num_words_ = num_words; - if (neg_) return; huffman_nodes.clear(); std::priority_queue, decltype(&CompareIndex)> pq(CompareIndex); @@ -64,51 +99,51 @@ void CuW2V::BuildHuffmanTree(const float* word_count, const int num_words) { std::vector, std::vector>> stack = {{pq.top(), {}, {}}}; int nodeid; - std::vector codes; - std::vector points; - codes_.clear(); points_.clear(); - codes_.resize(num_words); points_.resize(num_words); + std::vector code; + std::vector point; + std::vector> codes(num_words); + std::vector> points(num_words); max_depth_ = 0; while (not stack.empty()) { - std::tie(nodeid, codes, points) = stack.back(); + std::tie(nodeid, code, point) = stack.back(); stack.pop_back(); if (nodeid < num_words) { - codes_[nodeid] = codes; - points_[nodeid] = points; + codes[nodeid] = code; + points[nodeid] = point; max_depth_ = std::max(max_depth_, - static_cast(codes.size())); + static_cast(code.size())); } else { - points.push_back(nodeid - num_words); - std::vector left_codes = codes; - std::vector right_codes = codes; - left_codes.push_back(false); - right_codes.push_back(true); + point.push_back(nodeid - num_words); + std::vector left_code = code; + std::vector right_code = code; + left_code.push_back(false); + right_code.push_back(true); auto& node = huffman_nodes[nodeid]; - stack.push_back(make_tuple(node.left, left_codes, points)); - stack.push_back(make_tuple(node.right, right_codes, points)); + stack.push_back(make_tuple(node.left, left_code, point)); + stack.push_back(make_tuple(node.right, right_code, point)); } } - std::vector host_codes; + std::vector host_codes; std::vector host_points; - std::vector host_indptr = {0}; + std::vector host_hs_indptr = {0}; int size = 0; for (int i = 0; i < num_words; ++i) { - auto& codes = codes_[i]; - auto& points = points_[i]; - int n = codes.size(); + code = codes[i]; + point = points[i]; + int n = code.size(); size += n; - host_indptr.push_back(size); + host_hs_indptr.push_back(size); for (int j = 0; j < n; ++j) { - host_codes.push_back(static_cast(codes[j])); - host_points.push_back(points[j]); + host_codes.push_back(code[j]); + host_points.push_back(point[j]); } } - dev_codes_.resize(size); dev_points_.resize(size), dev_indptr_.resize(num_words + 1); + dev_codes_.resize(size); dev_points_.resize(size), dev_hs_indptr_.resize(num_words + 1); thrust::copy(host_codes.begin(), host_codes.end(), dev_codes_.begin()); thrust::copy(host_points.begin(), host_points.end(), dev_points_.begin()); - thrust::copy(host_indptr.begin(), host_indptr.end(), dev_indptr_.begin()); + 
thrust::copy(host_hs_indptr.begin(), host_hs_indptr.end(), dev_hs_indptr_.begin()); CHECK_CUDA(cudaDeviceSynchronize()); } @@ -123,14 +158,6 @@ void CuW2V::LoadModel(float* emb_in, float* emb_out) { thrust::copy(emb_out, emb_out + out_words * num_dims_, dev_emb_out_.begin()); emb_in_ = emb_in; emb_out_ = emb_out; - // set mutex - dev_mutex_in_.resize(num_words_); - dev_mutex_out_.resize(out_words); - std::vector host_mutex_in(num_words_, 0); - std::vector host_mutex_out(out_words, 0); - thrust::copy(host_mutex_in.begin(), host_mutex_in.end(), dev_mutex_in_.begin()); - thrust::copy(host_mutex_out.begin(), host_mutex_out.end(), dev_mutex_out_.begin()); - CHECK_CUDA(cudaDeviceSynchronize()); } @@ -138,4 +165,9 @@ int CuW2V::GetBlockCnt() { return block_cnt_; } + +float FeedData(const int* cols, const int* indptr, const int num_cols, const int* num_indptr) { + return 0; +} + } // namespace cusim From add8461a6d3797a8cb157a6aa9b6f5c416ff34c6 Mon Sep 17 00:00:00 2001 From: js1010 Date: Fri, 12 Feb 2021 23:04:54 +0900 Subject: [PATCH 11/18] implement FeedData --- cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh | 12 +-- cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh | 18 ++-- cpp/include/cuw2v/cuw2v.hpp | 15 ++-- cpp/src/cuw2v/cuw2v.cu | 101 ++++++++++++++++++++-- 4 files changed, 117 insertions(+), 29 deletions(-) diff --git a/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh index 1570d80..c7aaca5 100644 --- a/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh +++ b/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh @@ -11,7 +11,7 @@ namespace cusim { __global__ void W2VHsSgKernel( - const int* cols, const int* indptr, const int window, + const int* cols, const int* indptr, const bool* codes, const int* points, const int* hs_indptr, const int num_indptr, const int num_dims, const int window_size, default_random_engine* rngs, @@ -37,8 +37,8 @@ __global__ void W2VHsSgKernel( for (int j = beg; j < end; ++j) { if (threadIdx.x == 0) reduced_windows = dist_window(rng); __syncthreads(); - int beg2 = max(beg, j - window + reduced_windows); - int end2 = min(end, j + window - reduced_windows + 1); + int beg2 = max(beg, j - window_size + reduced_windows); + int end2 = min(end, j + window_size - reduced_windows + 1); float* _emb_in = emb_in + num_dims * cols[j]; for (int k = beg2; k < end2; ++k) { if (k == j) continue; @@ -65,7 +65,7 @@ __global__ void W2VHsSgKernel( } __global__ void W2VHsCbowKernel( - const int* cols, const int* indptr, const int window, + const int* cols, const int* indptr, const bool* codes, const int* points, const int* hs_indptr, const int num_indptr, const int num_dims, const int window_size, default_random_engine* rngs, float* emb_in, float* emb_out, @@ -89,8 +89,8 @@ __global__ void W2VHsCbowKernel( for (int j = beg; j < end; ++j) { if (threadIdx.x == 0) reduced_windows = dist_window(rng); __syncthreads(); - int beg2 = max(beg, j - window + reduced_windows); - int end2 = min(end, j + window - reduced_windows + 1); + int beg2 = max(beg, j - window_size + reduced_windows); + int end2 = min(end, j + window_size - reduced_windows + 1); if (end2 - beg2 <= 1) continue; // zero-initialize shared mem diff --git a/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh index be8c160..8f6bef0 100644 --- a/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh +++ b/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh @@ -7,14 +7,12 @@ #include "utils/cuda_utils_kernels.cuh" #include "cuw2v/cuda_w2v_base_kernels.cuh" -using thrust::random::default_random_engine; -using 
thrust::random::uniform_int_distribution; namespace cusim { __global__ void W2VNegSgKernel( - const int* cols, const int* indptr, const int window, - const int* random_table, const int random_size, default_random_engine* rngs, + const int* cols, const int* indptr, + const int* random_table, default_random_engine* rngs, const int random_size, const int num_indptr, const int num_dims, const int neg, const int window_size, float* emb_in, float* emb_out, float* loss_nume, float* loss_deno, const float lr) { @@ -39,8 +37,8 @@ __global__ void W2VNegSgKernel( for (int j = beg; j < end; ++j) { if (threadIdx.x == 0) reduced_windows = dist_window(rng); __syncthreads(); - int beg2 = max(beg, j - window + reduced_windows); - int end2 = min(end, j + window - reduced_windows + 1); + int beg2 = max(beg, j - window_size + reduced_windows); + int end2 = min(end, j + window_size - reduced_windows + 1); float* _emb_in = emb_in + num_dims * cols[j]; for (int k = beg2; k < end2; ++k) { if (k == j) continue; @@ -64,8 +62,8 @@ __global__ void W2VNegSgKernel( } __global__ void W2VNegCbowKernel( - const int* cols, const int* indptr, const int window, - const int* random_table, const int random_size, default_random_engine* rngs, + const int* cols, const int* indptr, + const int* random_table, default_random_engine* rngs, const int random_size, const int num_indptr, const int num_dims, const int neg, const int window_size, float* emb_in, float* emb_out, float* loss_nume, float* loss_deno, const bool use_mean, const float lr) { @@ -89,8 +87,8 @@ __global__ void W2VNegCbowKernel( for (int j = beg; j < end; ++j) { if (threadIdx.x == 0) reduced_windows = dist_window(rng); __syncthreads(); - int beg2 = max(beg, j - window + reduced_windows); - int end2 = min(end, j + window - reduced_windows + 1); + int beg2 = max(beg, j - window_size + reduced_windows); + int end2 = min(end, j + window_size - reduced_windows + 1); if (end2 - beg2 <= 1) continue; // zero-initialize shared mem diff --git a/cpp/include/cuw2v/cuw2v.hpp b/cpp/include/cuw2v/cuw2v.hpp index 3313acd..5d499ce 100644 --- a/cpp/include/cuw2v/cuw2v.hpp +++ b/cpp/include/cuw2v/cuw2v.hpp @@ -30,6 +30,8 @@ #include "utils/log.hpp" #include "utils/types.hpp" +using thrust::random::default_random_engine; + namespace cusim { bool CompareIndex(int lhs, int rhs); @@ -55,7 +57,7 @@ class CuW2V { void BuildRandomTable(const float* word_count, const int num_words, const int table_size, const int num_threads); int GetBlockCnt(); - float FeedData(const int* cols, const int* indptr, + std::pair FeedData(const int* cols, const int* indptr, const int num_cols, const int num_indptr); private: @@ -63,23 +65,24 @@ class CuW2V { json11::Json opt_; std::shared_ptr logger_; int block_cnt_, block_dim_; - int num_dims_, num_words_; - float *emb_in_, *emb_out_; + int num_dims_, num_words_, window_size_; + float *emb_in_, *emb_out_, lr_; thrust::device_vector dev_emb_in_, dev_emb_out_; // variables to construct huffman tree int max_depth_; - thrust::device_vector dev_codes_; + thrust::device_vector dev_codes_; thrust::device_vector dev_points_, dev_hs_indptr_; // related to negative sampling / hierarchical softmax and skip gram / cbow - bool sg_; + bool sg_, use_mean_; int neg_; // variables to construct random table thrust::device_vector dev_random_table_; - int table_size_, table_seed_; + int random_size_, table_seed_; std::mt19937 table_rng_; + thrust::device_vector dev_rngs_; }; } // namespace cusim diff --git a/cpp/src/cuw2v/cuw2v.cu b/cpp/src/cuw2v/cuw2v.cu index 5cefa9f..ae02ac8 
100644 --- a/cpp/src/cuw2v/cuw2v.cu +++ b/cpp/src/cuw2v/cuw2v.cu @@ -39,7 +39,10 @@ bool CuW2V::Init(std::string opt_path) { block_dim_ = opt_["block_dim"].int_value(); block_cnt_ = opt_["hyper_threads"].number_value() * (dev_info_.cores / block_dim_); sg_ = opt_["skip_gram"].bool_value(); - + use_mean_ = opt_["use_mean"].bool_value(); + window_size_ = opt_["window_size"].int_value(); + lr_ = opt_["lr"].number_value(); + // if zero, we will use hierarchical softmax neg_ = opt_["negative_sampling"].int_value(); @@ -47,7 +50,7 @@ bool CuW2V::Init(std::string opt_path) { table_seed_ = opt_["table_seed"].int_value(); const unsigned int table_seed = table_seed_; table_rng_.seed(table_seed); - + INFO("num_dims: {}, block_dim: {}, block_cnt: {}, objective type: {}, neg: {}", num_dims_, block_dim_, block_cnt_, sg_? "skip gram": "cbow", neg_); return true; @@ -56,7 +59,7 @@ bool CuW2V::Init(std::string opt_path) { void CuW2V::BuildRandomTable(const float* word_count, const int num_words, const int table_size, const int num_threads) { num_words_ = num_words; - table_size_ = table_size; + random_size_ = table_size; std::vector acc; float cumsum = 0; for (int i = 0; i < num_words; ++i) { @@ -65,12 +68,12 @@ void CuW2V::BuildRandomTable(const float* word_count, const int num_words, } std::uniform_real_distribution dist(0.0f, cumsum); - dev_random_table_.resize(table_size_); + dev_random_table_.resize(random_size_); std::vector host_random_table(table_size); #pragma omp parallel num_threads(num_threads) { #pragma omp for schedule(static) - for (int i = 0; i < table_size_; ++i) { + for (int i = 0; i < random_size_; ++i) { float r = dist(table_rng_); int pos = std::lower_bound(acc.begin(), acc.end(), r) - acc.begin(); host_random_table[i] = pos; @@ -166,8 +169,92 @@ int CuW2V::GetBlockCnt() { } -float FeedData(const int* cols, const int* indptr, const int num_cols, const int* num_indptr) { - return 0; +std::pair CuW2V::FeedData(const int* cols, const int* indptr, + const int num_cols, const int num_indptr) { + + // copy feed data to GPU memory + thrust::device_vector dev_cols(num_cols); + thrust::device_vector dev_indptr(num_indptr + 1); + thrust::device_vector dev_loss_nume(block_cnt_, 0.0f); + thrust::device_vector dev_loss_deno(block_cnt_, 0.0f); + thrust::copy(cols, cols + num_cols, dev_cols.begin()); + thrust::copy(indptr, indptr + num_indptr + 1, dev_indptr.begin()); + CHECK_CUDA(cudaDeviceSynchronize()); + DEBUG0("copy feed data to GPU memory"); + + // run GPU kernels + if (neg_ > 0) { + if (sg_) { + W2VNegSgKernel<<>>( + thrust::raw_pointer_cast(dev_cols.data()), + thrust::raw_pointer_cast(dev_indptr.data()), + thrust::raw_pointer_cast(dev_random_table_.data()), + thrust::raw_pointer_cast(dev_rngs_.data()), + random_size_, num_indptr, num_dims_, neg_, window_size_, + thrust::raw_pointer_cast(dev_emb_in_.data()), + thrust::raw_pointer_cast(dev_emb_out_.data()), + thrust::raw_pointer_cast(dev_loss_nume.data()), + thrust::raw_pointer_cast(dev_loss_deno.data()), + lr_); + } else { + W2VNegCbowKernel<<>>( + thrust::raw_pointer_cast(dev_cols.data()), + thrust::raw_pointer_cast(dev_indptr.data()), + thrust::raw_pointer_cast(dev_random_table_.data()), + thrust::raw_pointer_cast(dev_rngs_.data()), + random_size_, num_indptr, num_dims_, neg_, window_size_, + thrust::raw_pointer_cast(dev_emb_in_.data()), + thrust::raw_pointer_cast(dev_emb_out_.data()), + thrust::raw_pointer_cast(dev_loss_nume.data()), + thrust::raw_pointer_cast(dev_loss_deno.data()), + use_mean_, lr_); + } + } else { + if (sg_) { + 
W2VHsSgKernel<<>>( + thrust::raw_pointer_cast(dev_cols.data()), + thrust::raw_pointer_cast(dev_indptr.data()), + thrust::raw_pointer_cast(dev_codes_.data()), + thrust::raw_pointer_cast(dev_points_.data()), + thrust::raw_pointer_cast(dev_hs_indptr_.data()), + num_indptr, num_dims_, window_size_, + thrust::raw_pointer_cast(dev_rngs_.data()), + thrust::raw_pointer_cast(dev_emb_in_.data()), + thrust::raw_pointer_cast(dev_emb_out_.data()), + thrust::raw_pointer_cast(dev_loss_nume.data()), + thrust::raw_pointer_cast(dev_loss_deno.data()), + lr_); + + } else { + W2VHsCbowKernel<<>>( + thrust::raw_pointer_cast(dev_cols.data()), + thrust::raw_pointer_cast(dev_indptr.data()), + thrust::raw_pointer_cast(dev_codes_.data()), + thrust::raw_pointer_cast(dev_points_.data()), + thrust::raw_pointer_cast(dev_hs_indptr_.data()), + num_indptr, num_dims_, window_size_, + thrust::raw_pointer_cast(dev_rngs_.data()), + thrust::raw_pointer_cast(dev_emb_in_.data()), + thrust::raw_pointer_cast(dev_emb_out_.data()), + thrust::raw_pointer_cast(dev_loss_nume.data()), + thrust::raw_pointer_cast(dev_loss_deno.data()), + use_mean_, lr_); + + } + + } + CHECK_CUDA(cudaDeviceSynchronize()); + + // accumulate loss nume / deno + std::vector loss_nume(block_cnt_), loss_deno(block_cnt_); + thrust::copy(dev_loss_nume.begin(), dev_loss_nume.end(), loss_nume.begin()); + thrust::copy(dev_loss_deno.begin(), dev_loss_deno.end(), loss_deno.begin()); + CHECK_CUDA(cudaDeviceSynchronize()); + float loss_nume_sum = std::accumulate(loss_nume.begin(), loss_nume.end(), 0.0f); + float loss_deno_sum = std::accumulate(loss_deno.begin(), loss_deno.end(), 0.0f); + DEBUG("loss nume: {}, deno: {}", loss_nume_sum, loss_deno_sum); + + return {loss_nume_sum, loss_deno_sum}; } } // namespace cusim From 084c4cb27c2a2e5e5bcdf0f2b3f3ee38cb86691e Mon Sep 17 00:00:00 2001 From: js1010 Date: Fri, 12 Feb 2021 23:21:01 +0900 Subject: [PATCH 12/18] implement InitRngsKernel --- cpp/include/cuw2v/cuw2v.hpp | 3 +-- cpp/include/utils/cuda_utils_kernels.cuh | 4 ++++ cpp/src/cuw2v/cuw2v.cu | 21 ++++++++++++++------- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/cpp/include/cuw2v/cuw2v.hpp b/cpp/include/cuw2v/cuw2v.hpp index 5d499ce..133aac8 100644 --- a/cpp/include/cuw2v/cuw2v.hpp +++ b/cpp/include/cuw2v/cuw2v.hpp @@ -80,8 +80,7 @@ class CuW2V { // variables to construct random table thrust::device_vector dev_random_table_; - int random_size_, table_seed_; - std::mt19937 table_rng_; + int random_size_, table_seed_, cuda_seed_; thrust::device_vector dev_rngs_; }; diff --git a/cpp/include/utils/cuda_utils_kernels.cuh b/cpp/include/utils/cuda_utils_kernels.cuh index 670303d..126b00f 100644 --- a/cpp/include/utils/cuda_utils_kernels.cuh +++ b/cpp/include/utils/cuda_utils_kernels.cuh @@ -208,4 +208,8 @@ float ReduceSum(const float* vec, const int length) { return shared[0]; } +__global__ void InitRngsKernel(default_random_engine* rngs, int rand_seed) { + rngs[blockIdx.x].seed(blockIdx.x + rand_seed); +} + } // namespace cusim diff --git a/cpp/src/cuw2v/cuw2v.cu b/cpp/src/cuw2v/cuw2v.cu index ae02ac8..cec5620 100644 --- a/cpp/src/cuw2v/cuw2v.cu +++ b/cpp/src/cuw2v/cuw2v.cu @@ -46,11 +46,13 @@ bool CuW2V::Init(std::string opt_path) { // if zero, we will use hierarchical softmax neg_ = opt_["negative_sampling"].int_value(); - // set seed for constructing random table of negative sampling table_seed_ = opt_["table_seed"].int_value(); - const unsigned int table_seed = table_seed_; - table_rng_.seed(table_seed); - + cuda_seed_ = 
opt_["cuda_seed"].int_value(); + dev_rngs_.resize(block_cnt_); + InitRngsKernel<<>>( + thrust::raw_pointer_cast(dev_rngs_.data()), cuda_seed_); + INFO("num_dims: {}, block_dim: {}, block_cnt: {}, objective type: {}, neg: {}", num_dims_, block_dim_, block_cnt_, sg_? "skip gram": "cbow", neg_); return true; @@ -63,22 +65,25 @@ void CuW2V::BuildRandomTable(const float* word_count, const int num_words, std::vector acc; float cumsum = 0; for (int i = 0; i < num_words; ++i) { - cumsum += word_count[i]; acc.push_back(cumsum); + cumsum += word_count[i]; } - std::uniform_real_distribution dist(0.0f, cumsum); dev_random_table_.resize(random_size_); std::vector host_random_table(table_size); #pragma omp parallel num_threads(num_threads) { + const unsigned int table_seed = table_seed_ + omp_get_thread_num(); + std::mt19937 rng(table_seed); + std::uniform_real_distribution dist(0.0f, cumsum); #pragma omp for schedule(static) for (int i = 0; i < random_size_; ++i) { - float r = dist(table_rng_); + float r = dist(rng); int pos = std::lower_bound(acc.begin(), acc.end(), r) - acc.begin(); host_random_table[i] = pos; } } + table_seed_ += num_threads; thrust::copy(host_random_table.begin(), host_random_table.end(), dev_random_table_.begin()); CHECK_CUDA(cudaDeviceSynchronize()); @@ -148,6 +153,8 @@ void CuW2V::BuildHuffmanTree(const float* word_count, const int num_words) { thrust::copy(host_points.begin(), host_points.end(), dev_points_.begin()); thrust::copy(host_hs_indptr.begin(), host_hs_indptr.end(), dev_hs_indptr_.begin()); CHECK_CUDA(cudaDeviceSynchronize()); + + huffman_nodes.clear(); } void CuW2V::LoadModel(float* emb_in, float* emb_out) { From 7cf87c91c0216453e64b028523c69f116ddc1280 Mon Sep 17 00:00:00 2001 From: js1010 Date: Sat, 13 Feb 2021 00:04:45 +0900 Subject: [PATCH 13/18] implement bindings --- cpp/include/cuw2v/cuw2v.hpp | 1 + cpp/src/cuw2v/cuw2v.cu | 6 ++ cusim/cuw2v/__init__.py | 6 ++ cusim/cuw2v/bindings.cc | 107 +++++++++++++++++++++++ cusim/cuw2v/pycuw2v.py | 167 ++++++++++++++++++++++++++++++++++++ 5 files changed, 287 insertions(+) create mode 100644 cusim/cuw2v/__init__.py create mode 100644 cusim/cuw2v/bindings.cc create mode 100644 cusim/cuw2v/pycuw2v.py diff --git a/cpp/include/cuw2v/cuw2v.hpp b/cpp/include/cuw2v/cuw2v.hpp index 133aac8..d59fbbc 100644 --- a/cpp/include/cuw2v/cuw2v.hpp +++ b/cpp/include/cuw2v/cuw2v.hpp @@ -59,6 +59,7 @@ class CuW2V { int GetBlockCnt(); std::pair FeedData(const int* cols, const int* indptr, const int num_cols, const int num_indptr); + void Pull(); private: DeviceInfo dev_info_; diff --git a/cpp/src/cuw2v/cuw2v.cu b/cpp/src/cuw2v/cuw2v.cu index cec5620..6500b00 100644 --- a/cpp/src/cuw2v/cuw2v.cu +++ b/cpp/src/cuw2v/cuw2v.cu @@ -264,4 +264,10 @@ std::pair CuW2V::FeedData(const int* cols, const int* indptr, return {loss_nume_sum, loss_deno_sum}; } +void CuW2V::Pull() { + thrust::copy(dev_emb_in_.begin(), dev_emb_in_.end(), emb_in_); + thrust::copy(dev_emb_out_.begin(), dev_emb_out_.end(), emb_out_); + CHECK_CUDA(cudaDeviceSynchronize()); +} + } // namespace cusim diff --git a/cusim/cuw2v/__init__.py b/cusim/cuw2v/__init__.py new file mode 100644 index 0000000..c3acfc5 --- /dev/null +++ b/cusim/cuw2v/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2021 Jisang Yoon +# All rights reserved. +# +# This source code is licensed under the Apache 2.0 license found in the +# LICENSE file in the root directory of this source tree. 
+from cusim.cuw2v.pycuw2v import CuW2V diff --git a/cusim/cuw2v/bindings.cc b/cusim/cuw2v/bindings.cc new file mode 100644 index 0000000..643228d --- /dev/null +++ b/cusim/cuw2v/bindings.cc @@ -0,0 +1,107 @@ +// Copyright (c) 2021 Jisang Yoon +// All rights reserved. +// +// This source code is licensed under the Apache 2.0 license found in the +// LICENSE file in the root directory of this source tree. +#include +#include +#include + +#include +#include "cuw2v/cuw2v.hpp" + +namespace py = pybind11; + +typedef py::array_t float_array; +typedef py::array_t int_array; + +class CuW2VBind { + public: + CuW2VBind() {} + + bool Init(std::string opt_path) { + return obj_.Init(opt_path); + } + + void LoadModel(py::object& emb_in, py::object& emb_out) { + // check shape of alpha and beta + float_array _emb_in(emb_in); + float_array _emb_out(emb_out); + auto emb_in_buffer = _emb_in.request(); + auto emb_out_buffer = _emb_out.request(); + if (emb_in_buffer.ndim != 2 or emb_out_buffer_buffer.ndim != 2 or + emb_in_buffer.shape[1] != emb_out_buffer.shape[1]) { + throw std::runtime_error("invalid emb_in or emb_out"); + } + + return obj_.LoadModel(_emb_in.mutable_data(0), _emb_out.mutable_data(0)); + } + + void BuildRandomTable(py::object& word_count, int table_size, int num_threads) { + float_array _word_count(word_count); + auto wc_buffer = _word_count.requiest(); + if (wc_buffer.ndim != 1) { + throw std::runtime_error("invalid word count"); + } + int num_words = wc_buffer.shape[0]; + obj_.BuildRandomTable(_word_count.data(0), num_words, table_size, num_threads); + } + + void BuildHuffmanTree(py::object& word_count) { + float_array _word_count(word_count); + auto wc_buffer = _word_count.requiest(); + if (wc_buffer.ndim != 1) { + throw std::runtime_error("invalid word count"); + } + int num_words = wc_buffer.shape[0]; + obj_.BuildHuffmanTree(_word_count.data(0), num_words); + } + + std::pair FeedData(py::object& cols, py::object& indptr) { + int_array _cols(cols); + int_array _indptr(indptr); + auto cols_buffer = _cols.request(); + auto indptr_buffer = _indptr.request(); + if (cols_buffer.ndim != 1 or indptr_buffer.ndim != 1) { + throw std::runtime_error("invalid cols or indptr"); + } + int num_cols = cols_buffer.shape[0]; + int num_indptr = indptr_buffer.shape[0] - 1; + return obj_.FeedData(_cols.data(0), _indptr.data(0), num_cols, num_indptr); + } + + void Pull() { + obj_.Pull(); + } + + int GetBlockCnt() { + return obj_.GetBlockCnt(); + } + + private: + cusim::CuW2V obj_; +}; + +PYBIND11_PLUGIN(cuw2v_bind) { + py::module m("CuW2VBind"); + + py::class_(m, "CuW2VBind") + .def(py::init()) + .def("init", &CuW2VBind::Init, py::arg("opt_path")) + .def("load_model", &CuW2VBind::LoadModel, + py::arg("emb_in"), py::arg("emb_out")) + .def("feed_data", &CuW2VBind::FeedData, + py::arg("cols"), py::arg("indptr")) + .def("pull", &CuW2VBind::Pull) + .def("build_random_table", &CuW2VBind::BuildRandomTable, + py::arg("word_count"), py::arg("table_size"), py::arg("num_threads")) + .def("build_huffman_tree", &CuW2VBind::BuildHuffmanTree, + py::arg("word_count")) + .def("get_block_cnt", &CuW2VBind::GetBlockCnt) + .def("__repr__", + [](const CuW2VBind &a) { + return ""; + } + ); + return m.ptr(); +} diff --git a/cusim/cuw2v/pycuw2v.py b/cusim/cuw2v/pycuw2v.py new file mode 100644 index 0000000..e0fa3d9 --- /dev/null +++ b/cusim/cuw2v/pycuw2v.py @@ -0,0 +1,167 @@ +# Copyright (c) 2021 Jisang Yoon +# All rights reserved. 
+# +# This source code is licensed under the Apache 2.0 license found in the +# LICENSE file in the root directory of this source tree. + +# pylint: disable=no-name-in-module,too-few-public-methods,no-member +import os +from os.path import join as pjoin + +import json +import tempfile + +import h5py +import numpy as np +from scipy.special import polygamma as pg + +from cusim import aux, IoUtils +from cusim.culda.culda_bind import CuLDABind +from cusim.config_pb2 import CuLDAConfigProto + +EPS = 1e-10 + +class CuLDA: + def __init__(self, opt=None): + self.opt = aux.get_opt_as_proto(opt or {}, CuLDAConfigProto) + self.logger = aux.get_logger("culda", level=self.opt.py_log_level) + + tmp = tempfile.NamedTemporaryFile(mode='w', delete=False) + opt_content = json.dumps(aux.proto_to_dict(self.opt), indent=2) + tmp.write(opt_content) + tmp.close() + + self.logger.info("opt: %s", opt_content) + self.obj = CuLDABind() + assert self.obj.init(bytes(tmp.name, "utf8")), f"failed to load {tmp.name}" + os.remove(tmp.name) + + self.words, self.num_words, self.num_docs = None, None, None + self.alpha, self.beta, self.grad_alpha, self.new_beta = \ + None, None, None, None + + def preprocess_data(self): + if self.opt.skip_preprocess: + return + iou = IoUtils() + if not self.opt.processed_data_dir: + self.opt.processed_data_dir = tempfile.TemporaryDirectory().name + iou.convert_stream_to_h5(self.opt.data_path, self.opt.word_min_count, + self.opt.processed_data_dir) + + def init_model(self): + # load voca + data_dir = self.opt.processed_data_dir + self.logger.info("load key from %s", pjoin(data_dir, "keys.txt")) + with open(pjoin(data_dir, "keys.txt"), "rb") as fin: + self.words = [line.strip() for line in fin] + self.num_words = len(self.words) + + # count number of docs + h5f = h5py.File(pjoin(data_dir, "token.h5"), "r") + self.num_docs = h5f["indptr"].shape[0] - 1 + h5f.close() + + self.logger.info("number of words: %d, docs: %d", + self.num_words, self.num_docs) + + # random initialize alpha and beta + np.random.seed(self.opt.seed) + self.alpha = np.random.uniform( \ + size=(self.opt.num_topics,)).astype(np.float32) + self.beta = np.random.uniform( \ + size=(self.num_words, self.opt.num_topics)).astype(np.float32) + self.beta /= np.sum(self.beta, axis=0)[None, :] + self.logger.info("alpha %s, beta %s initialized", + self.alpha.shape, self.beta.shape) + + # zero initialize grad alpha and new beta + block_cnt = self.obj.get_block_cnt() + self.grad_alpha = np.zeros(shape=(block_cnt, self.opt.num_topics), + dtype=np.float32) + self.new_beta = np.zeros(shape=self.beta.shape, dtype=np.float32) + self.logger.info("grad alpha %s, new beta %s initialized", + self.grad_alpha.shape, self.new_beta.shape) + + # push it to gpu + self.obj.load_model(self.alpha, self.beta, self.grad_alpha, self.new_beta) + + def train_model(self): + self.preprocess_data() + self.init_model() + h5f = h5py.File(pjoin(self.opt.processed_data_dir, "token.h5"), "r") + for epoch in range(1, self.opt.epochs + 1): + self.logger.info("Epoch %d / %d", epoch, self.opt.epochs) + self._train_e_step(h5f) + self._train_m_step() + h5f.close() + + def _train_e_step(self, h5f): + offset, size = 0, h5f["cols"].shape[0] + pbar = aux.Progbar(size, stateful_metrics=["train_loss", "vali_loss"]) + train_loss_nume, train_loss_deno = 0, 0 + vali_loss_nume, vali_loss_deno = 0, 0 + while True: + target = h5f["indptr"][offset] + self.opt.batch_size + if target < size: + next_offset = h5f["rows"][target] + else: + next_offset = h5f["indptr"].shape[0] - 1 + indptr = 
h5f["indptr"][offset:next_offset + 1] + beg, end = indptr[0], indptr[-1] + indptr -= beg + cols = h5f["cols"][beg:end] + vali = (h5f["vali"][beg:end] < self.opt.vali_p).astype(np.bool) + offset = next_offset + + # call cuda kernel + train_loss, vali_loss = \ + self.obj.feed_data(cols, indptr, vali, self.opt.num_iters_in_e_step) + + # accumulate loss + train_loss_nume -= train_loss + vali_loss_nume -= vali_loss + vali_cnt = np.count_nonzero(vali) + train_cnt = len(vali) - vali_cnt + train_loss_deno += train_cnt + vali_loss_deno += vali_cnt + train_loss = train_loss_nume / (train_loss_deno + EPS) + vali_loss = vali_loss_nume / (vali_loss_deno + EPS) + + # update progress bar + pbar.update(end, values=[("train_loss", train_loss), + ("vali_loss", vali_loss)]) + if end == size: + break + + def _train_m_step(self): + self.obj.pull() + + # update beta + self.new_beta[:, :] = np.maximum(self.new_beta, EPS) + self.beta[:, :] = self.new_beta / np.sum(self.new_beta, axis=0)[None, :] + self.new_beta[:, :] = 0 + + # update alpha + alpha_sum = np.sum(self.alpha) + gvec = np.sum(self.grad_alpha, axis=0) + gvec += self.num_docs * (pg(0, alpha_sum) - pg(0, self.alpha)) + hvec = self.num_docs * pg(1, self.alpha) + z_0 = pg(1, alpha_sum) + c_nume = np.sum(gvec / hvec) + c_deno = 1 / z_0 + np.sum(1 / hvec) + c_0 = c_nume / c_deno + delta = (gvec - c_0) / hvec + self.alpha -= delta + self.alpha[:] = np.maximum(self.alpha, EPS) + self.grad_alpha[:,:] = 0 + + self.obj.push() + + def save_model(self, model_path): + self.logger.info("save model path: %s", model_path) + h5f = h5py.File(model_path, "w") + h5f.create_dataset("alpha", data=self.alpha) + h5f.create_dataset("beta", data=self.beta) + h5f.create_dataset("keys", data=np.array(self.words)) + h5f.close() From b805ed0deab9ea711626c3310b90adb4dcf00259 Mon Sep 17 00:00:00 2001 From: js1010 Date: Sat, 13 Feb 2021 00:31:38 +0900 Subject: [PATCH 14/18] first draft --- cpp/include/utils/ioutils.hpp | 2 +- cpp/src/cuw2v/cuw2v.cu | 2 +- cpp/src/utils/ioutils.cc | 11 ++-- cusim/__init__.py | 1 + cusim/cuw2v/pycuw2v.py | 116 ++++++++++++---------------------- cusim/ioutils/bindings.cc | 6 +- cusim/ioutils/pyioutils.py | 12 ++-- cusim/proto/config.proto | 32 ++++++++++ setup.py | 15 +++++ 9 files changed, 106 insertions(+), 91 deletions(-) diff --git a/cpp/include/utils/ioutils.hpp b/cpp/include/utils/ioutils.hpp index 756b4b2..00639a4 100644 --- a/cpp/include/utils/ioutils.hpp +++ b/cpp/include/utils/ioutils.hpp @@ -33,7 +33,7 @@ class IoUtils { int LoadStreamFile(std::string filepath); std::pair ReadStreamForVocab(int num_lines, int num_threads); std::pair TokenizeStream(int num_lines, int num_threads); - void GetWordVocab(int min_count, std::string keys_path); + void GetWordVocab(int min_count, std::string keys_path, std::string count_path); void GetToken(int* rows, int* cols, int* indptr); private: void ParseLine(std::string line, std::vector& line_vec); diff --git a/cpp/src/cuw2v/cuw2v.cu b/cpp/src/cuw2v/cuw2v.cu index 6500b00..c122783 100644 --- a/cpp/src/cuw2v/cuw2v.cu +++ b/cpp/src/cuw2v/cuw2v.cu @@ -44,7 +44,7 @@ bool CuW2V::Init(std::string opt_path) { lr_ = opt_["lr"].number_value(); // if zero, we will use hierarchical softmax - neg_ = opt_["negative_sampling"].int_value(); + neg_ = opt_["neg"].int_value(); // random seed table_seed_ = opt_["table_seed"].int_value(); diff --git a/cpp/src/utils/ioutils.cc b/cpp/src/utils/ioutils.cc index 14bd94e..802712c 100644 --- a/cpp/src/utils/ioutils.cc +++ b/cpp/src/utils/ioutils.cc @@ -153,7 +153,7 @@ std::pair 
IoUtils::ReadStreamForVocab(int num_lines, int num_threads) return {read_lines, word_count_.size()}; } -void IoUtils::GetWordVocab(int min_count, std::string keys_path) { +void IoUtils::GetWordVocab(int min_count, std::string keys_path, std::string count_path) { INFO("number of raw words: {}", word_count_.size()); word_idmap_.clear(); word_list_.clear(); for (auto& it: word_count_) { @@ -164,13 +164,16 @@ void IoUtils::GetWordVocab(int min_count, std::string keys_path) { } INFO("number of words after filtering: {}", word_list_.size()); - // write keys to csv file - std::ofstream fout(keys_path.c_str()); + // write keys and count to csv file + std::ofstream fout1(keys_path.c_str()); + std::ofstream fout2(count_path.c_str()); INFO("dump keys to {}", keys_path); int n = word_list_.size(); for (int i = 0; i < n; ++i) { std::string line = word_list_[i] + "\n"; - fout.write(line.c_str(), line.size()); + fout1.write(line.c_str(), line.size()); + line = std::to_string(word_count_[word_list_[i]]) + "\n"; + fout2.write(line.c_str(), line.size()); } fout.close(); } diff --git a/cusim/__init__.py b/cusim/__init__.py index 24fe984..a1d864c 100644 --- a/cusim/__init__.py +++ b/cusim/__init__.py @@ -5,3 +5,4 @@ # LICENSE file in the root directory of this source tree. from cusim.ioutils import IoUtils from cusim.culda import CuLDA +from cusim.cuw2v import CuW2V diff --git a/cusim/cuw2v/pycuw2v.py b/cusim/cuw2v/pycuw2v.py index e0fa3d9..8ba4e4c 100644 --- a/cusim/cuw2v/pycuw2v.py +++ b/cusim/cuw2v/pycuw2v.py @@ -13,17 +13,16 @@ import h5py import numpy as np -from scipy.special import polygamma as pg from cusim import aux, IoUtils -from cusim.culda.culda_bind import CuLDABind -from cusim.config_pb2 import CuLDAConfigProto +from cusim.cuw2v.cuw2v_bind import CuW2VBind +from cusim.config_pb2 import CuW2VConfigProto EPS = 1e-10 -class CuLDA: +class CuW2V: def __init__(self, opt=None): - self.opt = aux.get_opt_as_proto(opt or {}, CuLDAConfigProto) + self.opt = aux.get_opt_as_proto(opt or {}, CuW2VConfigProto) self.logger = aux.get_logger("culda", level=self.opt.py_log_level) tmp = tempfile.NamedTemporaryFile(mode='w', delete=False) @@ -32,13 +31,13 @@ def __init__(self, opt=None): tmp.close() self.logger.info("opt: %s", opt_content) - self.obj = CuLDABind() + self.obj = CuW2VBind() assert self.obj.init(bytes(tmp.name, "utf8")), f"failed to load {tmp.name}" os.remove(tmp.name) - self.words, self.num_words, self.num_docs = None, None, None - self.alpha, self.beta, self.grad_alpha, self.new_beta = \ + self.words, self.word_count, self.num_words, self.num_docs = \ None, None, None, None + self.emb_in, self.emb_out = None, None def preprocess_data(self): if self.opt.skip_preprocess: @@ -52,9 +51,15 @@ def preprocess_data(self): def init_model(self): # load voca data_dir = self.opt.processed_data_dir - self.logger.info("load key from %s", pjoin(data_dir, "keys.txt")) - with open(pjoin(data_dir, "keys.txt"), "rb") as fin: + keys_path = pjoin(data_dir, "keys.txt") + count_path = pjoin(data_dir, "count.txt") + self.logger.info("load key, count from %s, %s", keys_path, count_path) + with open(keys_path, "rb") as fin: self.words = [line.strip() for line in fin] + with open(count_path, "rb") as fin: + self.word_count = np.array([float(line.strip()) for line in fin], + dtype=np.float32) + self.word_count = np.power(self.word_count, self.opt.count_power) self.num_words = len(self.words) # count number of docs @@ -67,40 +72,33 @@ def init_model(self): # random initialize alpha and beta np.random.seed(self.opt.seed) - 
self.alpha = np.random.uniform( \ - size=(self.opt.num_topics,)).astype(np.float32) - self.beta = np.random.uniform( \ - size=(self.num_words, self.opt.num_topics)).astype(np.float32) - self.beta /= np.sum(self.beta, axis=0)[None, :] - self.logger.info("alpha %s, beta %s initialized", - self.alpha.shape, self.beta.shape) - - # zero initialize grad alpha and new beta - block_cnt = self.obj.get_block_cnt() - self.grad_alpha = np.zeros(shape=(block_cnt, self.opt.num_topics), - dtype=np.float32) - self.new_beta = np.zeros(shape=self.beta.shape, dtype=np.float32) - self.logger.info("grad alpha %s, new beta %s initialized", - self.grad_alpha.shape, self.new_beta.shape) + self.emb_in = np.random.normal( \ + size=(self.num_words, self.opt.num_dims)).astype(np.float32) + out_words = self.num_words if self.opt.neg else self.num_words - 1 + self.emb_out = np.random.uniform( \ + size=(out_words, self.opt.num_dims)).astype(np.float32) + self.logger.info("emb_in %s, emb_out %s initialized", + self.emb_in.shape, self.emb_out.shape) # push it to gpu - self.obj.load_model(self.alpha, self.beta, self.grad_alpha, self.new_beta) + self.obj.load_model(self.emb_in, self.emb_out) def train_model(self): self.preprocess_data() self.init_model() + if not self.opt.neg: + self.obj.build_huffman_tree(self.word_count) h5f = h5py.File(pjoin(self.opt.processed_data_dir, "token.h5"), "r") for epoch in range(1, self.opt.epochs + 1): self.logger.info("Epoch %d / %d", epoch, self.opt.epochs) - self._train_e_step(h5f) - self._train_m_step() + self._train_epoch(h5f) + self.pull() h5f.close() - def _train_e_step(self, h5f): + def _train_epoch(self, h5f): offset, size = 0, h5f["cols"].shape[0] - pbar = aux.Progbar(size, stateful_metrics=["train_loss", "vali_loss"]) - train_loss_nume, train_loss_deno = 0, 0 - vali_loss_nume, vali_loss_deno = 0, 0 + pbar = aux.Progbar(size, stateful_metrics=["loss"]) + loss_nume, loss_deno = 0, 0 while True: target = h5f["indptr"][offset] + self.opt.batch_size if target < size: @@ -111,57 +109,21 @@ def _train_e_step(self, h5f): beg, end = indptr[0], indptr[-1] indptr -= beg cols = h5f["cols"][beg:end] - vali = (h5f["vali"][beg:end] < self.opt.vali_p).astype(np.bool) offset = next_offset # call cuda kernel - train_loss, vali_loss = \ - self.obj.feed_data(cols, indptr, vali, self.opt.num_iters_in_e_step) + if self.opt.neg: + self.obj.build_random_table( \ + self.word_count, self.opt.random_size, self.opt.num_threads) + _loss_nume, _loss_deno = \ + self.obj.feed_data(cols, indptr) # accumulate loss - train_loss_nume -= train_loss - vali_loss_nume -= vali_loss - vali_cnt = np.count_nonzero(vali) - train_cnt = len(vali) - vali_cnt - train_loss_deno += train_cnt - vali_loss_deno += vali_cnt - train_loss = train_loss_nume / (train_loss_deno + EPS) - vali_loss = vali_loss_nume / (vali_loss_deno + EPS) + loss_nume += _loss_nume + loss_deno += _loss_deno + loss = loss_nume / (loss_deno + EPS) # update progress bar - pbar.update(end, values=[("train_loss", train_loss), - ("vali_loss", vali_loss)]) + pbar.update(end, values=[("loss", loss)]) if end == size: break - - def _train_m_step(self): - self.obj.pull() - - # update beta - self.new_beta[:, :] = np.maximum(self.new_beta, EPS) - self.beta[:, :] = self.new_beta / np.sum(self.new_beta, axis=0)[None, :] - self.new_beta[:, :] = 0 - - # update alpha - alpha_sum = np.sum(self.alpha) - gvec = np.sum(self.grad_alpha, axis=0) - gvec += self.num_docs * (pg(0, alpha_sum) - pg(0, self.alpha)) - hvec = self.num_docs * pg(1, self.alpha) - z_0 = pg(1, alpha_sum) - 
c_nume = np.sum(gvec / hvec) - c_deno = 1 / z_0 + np.sum(1 / hvec) - c_0 = c_nume / c_deno - delta = (gvec - c_0) / hvec - self.alpha -= delta - self.alpha[:] = np.maximum(self.alpha, EPS) - self.grad_alpha[:,:] = 0 - - self.obj.push() - - def save_model(self, model_path): - self.logger.info("save model path: %s", model_path) - h5f = h5py.File(model_path, "w") - h5f.create_dataset("alpha", data=self.alpha) - h5f.create_dataset("beta", data=self.beta) - h5f.create_dataset("keys", data=np.array(self.words)) - h5f.close() diff --git a/cusim/ioutils/bindings.cc b/cusim/ioutils/bindings.cc index 28fbbc8..73a4bf0 100644 --- a/cusim/ioutils/bindings.cc +++ b/cusim/ioutils/bindings.cc @@ -35,8 +35,8 @@ class IoUtilsBind { return obj_.TokenizeStream(num_lines, num_threads); } - void GetWordVocab(int min_count, std::string keys_path) { - obj_.GetWordVocab(min_count, keys_path); + void GetWordVocab(int min_count, std::string keys_path, std::string count_path) { + obj_.GetWordVocab(min_count, keys_path, count_path); } void GetToken(py::object& rows, py::object& cols, py::object& indptr) { @@ -62,7 +62,7 @@ PYBIND11_PLUGIN(ioutils_bind) { .def("tokenize_stream", &IoUtilsBind::TokenizeStream, py::arg("num_lines"), py::arg("num_threads")) .def("get_word_vocab", &IoUtilsBind::GetWordVocab, - py::arg("min_count"), py::arg("keys_path")) + py::arg("min_count"), py::arg("keys_path"), py::Arg("count_path")) .def("get_token", &IoUtilsBind::GetToken, py::arg("indices"), py::arg("indptr"), py::arg("offset")) .def("__repr__", diff --git a/cusim/ioutils/pyioutils.py b/cusim/ioutils/pyioutils.py index 5bce9b7..71b52fd 100644 --- a/cusim/ioutils/pyioutils.py +++ b/cusim/ioutils/pyioutils.py @@ -33,7 +33,8 @@ def __init__(self, opt=None): assert self.obj.init(bytes(tmp.name, "utf8")), f"failed to load {tmp.name}" os.remove(tmp.name) - def load_stream_vocab(self, filepath, min_count, keys_path): + def load_stream_vocab(self, filepath, min_count, + keys_path, count_path): full_num_lines = self.obj.load_stream_file(filepath) pbar = aux.Progbar(full_num_lines, unit_name="line", stateful_metrics=["word_count"]) @@ -46,17 +47,18 @@ def load_stream_vocab(self, filepath, min_count, keys_path): pbar.update(processed, values=[("word_count", word_count)]) if processed == full_num_lines: break - self.obj.get_word_vocab(min_count, keys_path) + self.obj.get_word_vocab(min_count, keys_path, count_path) def convert_stream_to_h5(self, filepath, min_count, out_dir, chunk_indices=10000, seed=777): np.random.seed(seed) os.makedirs(out_dir, exist_ok=True) keys_path = pjoin(out_dir, "keys.txt") + count_path = pjoin(out_dir, "count.txt") token_path = pjoin(out_dir, "token.h5") - self.logger.info("save key and token to %s, %s", - keys_path, token_path) - self.load_stream_vocab(filepath, min_count, keys_path) + self.logger.info("save key, count, token to %s, %s, %s", + keys_path, count_path, token_path) + self.load_stream_vocab(filepath, min_count, keys_path, count_path) full_num_lines = self.obj.load_stream_file(filepath) pbar = aux.Progbar(full_num_lines, unit_name="line") processed = 0 diff --git a/cusim/proto/config.proto b/cusim/proto/config.proto index 10b3820..79680fb 100644 --- a/cusim/proto/config.proto +++ b/cusim/proto/config.proto @@ -31,3 +31,35 @@ message CuLDAConfigProto { optional double vali_p = 13 [default = 0.2]; optional int32 seed = 14 [default = 777]; } + +message CuW2VConfigProto { + required string data_path = 7; + + optional int32 py_log_level = 1 [default = 2]; + optional int32 c_log_level = 2 [default = 2]; + + 
optional int32 num_dims = 3 [default = 50]; + optional int32 block_dim = 4 [default = 32]; + optional int32 hyper_threads = 5 [default = 10]; + optional string processed_data_dir = 6; + optional bool skip_preprocess = 8; + optional int32 word_min_count = 9 [default = 5]; + optional int32 batch_size = 10 [default = 100000]; + optional int32 epochs = 11 [default = 10]; + + // seed fields + optional int32 seed = 14 [default = 777]; + optional int32 table_seed = 15 [default = 777]; + optional int32 cuda_seed = 16 [default = 777]; + optional int32 random_size = 17 [default = 1000000]; + + optional int32 neg = 17 [default = 10]; + // as recommended in w2v paper + optional double count_power = 18 [default = 0.75]; + optional bool skip_gram = 19 [default = true]; + optional bool use_mean = 20 [default = true]; + optional double lr = 21 [default = 0.001]; + optional int32 window_size = 22 [default = 5]; + + +} diff --git a/setup.py b/setup.py index 512d262..3624c2f 100644 --- a/setup.py +++ b/setup.py @@ -98,6 +98,21 @@ def __init__(self, name): "cpp/include/", np.get_include(), pybind11.get_include(), pybind11.get_include(True), CUDA['include'], "3rd/json11", "3rd/spdlog/include"]), + Extension("cusim.cuw2v.cuw2v_bind", + sources= util_srcs + [ \ + "cpp/src/cuw2v/cuw2v.cu", + "cusim/cuw2v/bindings.cc", + "3rd/json11/json11.cpp"], + language="c++", + extra_compile_args=extra_compile_args, + extra_link_args=["-fopenmp"], + library_dirs=[CUDA['lib64']], + libraries=['cudart', 'cublas', 'curand'], + extra_objects=[], + include_dirs=[ \ + "cpp/include/", np.get_include(), pybind11.get_include(), + pybind11.get_include(True), CUDA['include'], + "3rd/json11", "3rd/spdlog/include"]), ] From f63a7065cac89da29d6674ae59e4bb3c511dfa06 Mon Sep 17 00:00:00 2001 From: js1010 Date: Sat, 13 Feb 2021 10:24:06 +0900 Subject: [PATCH 15/18] compile succeed --- cpp/include/cuw2v/cuda_w2v_base_kernels.cuh | 3 --- cpp/include/cuw2v/cuw2v.hpp | 13 ------------- cpp/include/utils/cuda_utils_kernels.cuh | 3 +++ cpp/src/cuw2v/cuw2v.cu | 9 +++++++++ cpp/src/utils/ioutils.cc | 2 +- cusim/cuw2v/bindings.cc | 6 +++--- cusim/ioutils/bindings.cc | 2 +- 7 files changed, 17 insertions(+), 21 deletions(-) diff --git a/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh index e831b05..8046cf3 100644 --- a/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh +++ b/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh @@ -6,9 +6,6 @@ #pragma once #include "utils/cuda_utils_kernels.cuh" -using thrust::random::default_random_engine; -using thrust::random::uniform_int_distribution; - namespace cusim { diff --git a/cpp/include/cuw2v/cuw2v.hpp b/cpp/include/cuw2v/cuw2v.hpp index d59fbbc..6ca189d 100644 --- a/cpp/include/cuw2v/cuw2v.hpp +++ b/cpp/include/cuw2v/cuw2v.hpp @@ -34,19 +34,6 @@ using thrust::random::default_random_engine; namespace cusim { -bool CompareIndex(int lhs, int rhs); - -struct HuffmanTreeNode { - float count; - int index, left, right; - HuffmanTreeNode(float count0, int index0, int left0, int right0) { - count = count0; index = index0; left = left0; right = right0; - } -}; - -std::vector huffman_nodes; -bool CompareIndex(int lhs, int rhs); - class CuW2V { public: CuW2V(); diff --git a/cpp/include/utils/cuda_utils_kernels.cuh b/cpp/include/utils/cuda_utils_kernels.cuh index 126b00f..47d447e 100644 --- a/cpp/include/utils/cuda_utils_kernels.cuh +++ b/cpp/include/utils/cuda_utils_kernels.cuh @@ -23,6 +23,9 @@ #include #include "utils/types.hpp" +using thrust::random::default_random_engine; +using 
 thrust::random::uniform_int_distribution;
+
 namespace cusim {
 
 // Error Checking utilities, checks status codes from cuda calls
diff --git a/cpp/src/cuw2v/cuw2v.cu b/cpp/src/cuw2v/cuw2v.cu
index c122783..448382f 100644
--- a/cpp/src/cuw2v/cuw2v.cu
+++ b/cpp/src/cuw2v/cuw2v.cu
@@ -10,6 +10,15 @@
 namespace cusim {
 
+struct HuffmanTreeNode {
+  float count;
+  int index, left, right;
+  HuffmanTreeNode(float count0, int index0, int left0, int right0) {
+    count = count0; index = index0; left = left0; right = right0;
+  }
+};
+
+std::vector<HuffmanTreeNode> huffman_nodes;
 bool CompareIndex(int lhs, int rhs) {
   return huffman_nodes[lhs].count > huffman_nodes[rhs].count;
 }
diff --git a/cpp/src/utils/ioutils.cc b/cpp/src/utils/ioutils.cc
index 802712c..4af202b 100644
--- a/cpp/src/utils/ioutils.cc
+++ b/cpp/src/utils/ioutils.cc
@@ -175,7 +175,7 @@ void IoUtils::GetWordVocab(int min_count, std::string keys_path, std::string cou
     line = std::to_string(word_count_[word_list_[i]]) + "\n";
     fout2.write(line.c_str(), line.size());
   }
-  fout.close();
+  fout1.close();
   fout2.close();
 }
 }  // namespace cusim
diff --git a/cusim/cuw2v/bindings.cc b/cusim/cuw2v/bindings.cc
index 643228d..3ca45d6 100644
--- a/cusim/cuw2v/bindings.cc
+++ b/cusim/cuw2v/bindings.cc
@@ -29,7 +29,7 @@ class CuW2VBind {
     float_array _emb_out(emb_out);
     auto emb_in_buffer = _emb_in.request();
     auto emb_out_buffer = _emb_out.request();
-    if (emb_in_buffer.ndim != 2 or emb_out_buffer_buffer.ndim != 2 or
+    if (emb_in_buffer.ndim != 2 or emb_out_buffer.ndim != 2 or
         emb_in_buffer.shape[1] != emb_out_buffer.shape[1]) {
       throw std::runtime_error("invalid emb_in or emb_out");
     }
@@ -39,7 +39,7 @@ class CuW2VBind {
   void BuildRandomTable(py::object& word_count, int table_size, int num_threads) {
     float_array _word_count(word_count);
-    auto wc_buffer = _word_count.requiest();
+    auto wc_buffer = _word_count.request();
     if (wc_buffer.ndim != 1) {
       throw std::runtime_error("invalid word count");
     }
@@ -49,7 +49,7 @@ class CuW2VBind {
   void BuildHuffmanTree(py::object& word_count) {
     float_array _word_count(word_count);
-    auto wc_buffer = _word_count.requiest();
+    auto wc_buffer = _word_count.request();
     if (wc_buffer.ndim != 1) {
       throw std::runtime_error("invalid word count");
     }
diff --git a/cusim/ioutils/bindings.cc b/cusim/ioutils/bindings.cc
index 73a4bf0..06204f8 100644
--- a/cusim/ioutils/bindings.cc
+++ b/cusim/ioutils/bindings.cc
@@ -62,7 +62,7 @@ PYBIND11_PLUGIN(ioutils_bind) {
   .def("tokenize_stream", &IoUtilsBind::TokenizeStream,
     py::arg("num_lines"), py::arg("num_threads"))
   .def("get_word_vocab", &IoUtilsBind::GetWordVocab,
-    py::arg("min_count"), py::arg("keys_path"), py::Arg("count_path"))
+    py::arg("min_count"), py::arg("keys_path"), py::arg("count_path"))
   .def("get_token", &IoUtilsBind::GetToken,
     py::arg("indices"), py::arg("indptr"), py::arg("offset"))
   .def("__repr__",

From 3859d7db4365957ca2729d556aa32b00ffcca377 Mon Sep 17 00:00:00 2001
From: js1010
Date: Sat, 13 Feb 2021 11:46:40 +0900
Subject: [PATCH 16/18] separate logger

---
 cpp/include/culda/culda.hpp   |  1 +
 cpp/include/cuw2v/cuw2v.hpp   |  1 +
 cpp/include/utils/ioutils.hpp |  1 +
 cpp/include/utils/log.hpp     |  3 ++-
 cpp/src/culda/culda.cu        |  5 +++--
 cpp/src/cuw2v/cuw2v.cu        |  5 +++--
 cpp/src/utils/ioutils.cc      |  5 +++--
 cpp/src/utils/log.cc          | 18 +++++++++++++-----
 cusim/cuw2v/pycuw2v.py        |  8 ++++----
 cusim/proto/config.proto      |  2 +-
 examples/example1.py          | 14 ++++++++++++--
 setup.py                      |  2 +-
 12 files changed, 45 insertions(+), 20 deletions(-)

diff --git a/cpp/include/culda/culda.hpp b/cpp/include/culda/culda.hpp
index 3a126cc..ff52dca 100644
--- a/cpp/include/culda/culda.hpp
+++ b/cpp/include/culda/culda.hpp
@@ -75,6 +75,7 @@ class CuLDA {
   DeviceInfo dev_info_;
   json11::Json opt_;
   std::shared_ptr<spdlog::logger> logger_;
+  std::unique_ptr<CuSimLogger> logger_container_;
   thrust::device_vector<float> dev_alpha_, dev_beta_;
   thrust::device_vector<float> dev_grad_alpha_, dev_new_beta_;
   thrust::device_vector<float> dev_gamma_, dev_new_gamma_, dev_phi_;
diff --git a/cpp/include/cuw2v/cuw2v.hpp b/cpp/include/cuw2v/cuw2v.hpp
index 6ca189d..fba7f2b 100644
--- a/cpp/include/cuw2v/cuw2v.hpp
+++ b/cpp/include/cuw2v/cuw2v.hpp
@@ -52,6 +52,7 @@ class CuW2V {
   DeviceInfo dev_info_;
   json11::Json opt_;
   std::shared_ptr<spdlog::logger> logger_;
+  std::unique_ptr<CuSimLogger> logger_container_;
   int block_cnt_, block_dim_;
   int num_dims_, num_words_, window_size_;
   float *emb_in_, *emb_out_, lr_;
diff --git a/cpp/include/utils/ioutils.hpp b/cpp/include/utils/ioutils.hpp
index 00639a4..54ca2e1 100644
--- a/cpp/include/utils/ioutils.hpp
+++ b/cpp/include/utils/ioutils.hpp
@@ -45,6 +45,7 @@ class IoUtils {
   std::ifstream stream_fin_;
   json11::Json opt_;
   std::shared_ptr<spdlog::logger> logger_;
+  std::unique_ptr<CuSimLogger> logger_container_;
   std::unordered_map<std::string, int> word_idmap_, word_count_;
   std::vector<std::string> word_list_;
   int num_lines_, remain_lines_;
diff --git a/cpp/include/utils/log.hpp b/cpp/include/utils/log.hpp
index 05f30ec..270e727 100644
--- a/cpp/include/utils/log.hpp
+++ b/cpp/include/utils/log.hpp
@@ -7,7 +7,7 @@
 // reference: https://github.com/kakao/buffalo/blob/5f571c2c7d8227e6625c6e538da929e4db11b66d/lib/misc/log.cc
 #pragma once
 #include
-
+#include
 #define SPDLOG_EOL ""
 #define SPDLOG_TRACE_ON
 #include "spdlog/spdlog.h"
@@ -32,6 +32,7 @@ namespace cusim {
 class CuSimLogger {
  public:
   CuSimLogger();
+  explicit CuSimLogger(std::string name);
   std::shared_ptr<spdlog::logger>& get_logger();
   void set_log_level(int level);
   int get_log_level();
diff --git a/cpp/src/culda/culda.cu b/cpp/src/culda/culda.cu
index 9217902..c92bfeb 100644
--- a/cpp/src/culda/culda.cu
+++ b/cpp/src/culda/culda.cu
@@ -9,7 +9,8 @@
 namespace cusim {
 
 CuLDA::CuLDA() {
-  logger_ = CuSimLogger().get_logger();
+  logger_container_.reset(new CuSimLogger("lda"));
+  logger_ = logger_container_->get_logger();
   dev_info_ = GetDeviceInfo();
   if (dev_info_.unknown) DEBUG0("Unknown device type");
   INFO("cuda device info, major: {}, minor: {}, multi processors: {}, cores: {}",
@@ -28,7 +29,7 @@ bool CuLDA::Init(std::string opt_path) {
   auto _opt = json11::Json::parse(str, err_cmt);
   if (not err_cmt.empty()) return false;
   opt_ = _opt;
-  CuSimLogger().set_log_level(opt_["c_log_level"].int_value());
+  logger_container_->set_log_level(opt_["c_log_level"].int_value());
   num_topics_ = opt_["num_topics"].int_value();
   block_dim_ = opt_["block_dim"].int_value();
   block_cnt_ = opt_["hyper_threads"].number_value() * (dev_info_.cores / block_dim_);
diff --git a/cpp/src/cuw2v/cuw2v.cu b/cpp/src/cuw2v/cuw2v.cu
index 448382f..8474291 100644
--- a/cpp/src/cuw2v/cuw2v.cu
+++ b/cpp/src/cuw2v/cuw2v.cu
@@ -24,7 +24,8 @@ bool CompareIndex(int lhs, int rhs) {
 }
 
 CuW2V::CuW2V() {
-  logger_ = CuSimLogger().get_logger();
+  logger_container_.reset(new CuSimLogger("w2v"));
+  logger_ = logger_container_->get_logger();
   dev_info_ = GetDeviceInfo();
   if (dev_info_.unknown) DEBUG0("Unknown device type");
   INFO("cuda device info, major: {}, minor: {}, multi processors: {}, cores: {}",
@@ -43,7 +44,7 @@ bool CuW2V::Init(std::string opt_path) {
   auto _opt = json11::Json::parse(str, err_cmt);
   if (not err_cmt.empty()) return false;
   opt_ = _opt;
-  CuSimLogger().set_log_level(opt_["c_log_level"].int_value());
+  logger_container_->set_log_level(opt_["c_log_level"].int_value());
   num_dims_ = opt_["num_dims"].int_value();
   block_dim_ = opt_["block_dim"].int_value();
   block_cnt_ = opt_["hyper_threads"].number_value() * (dev_info_.cores / block_dim_);
diff --git a/cpp/src/utils/ioutils.cc b/cpp/src/utils/ioutils.cc
index 4af202b..53833a3 100644
--- a/cpp/src/utils/ioutils.cc
+++ b/cpp/src/utils/ioutils.cc
@@ -8,7 +8,8 @@
 namespace cusim {
 
 IoUtils::IoUtils() {
-  logger_ = CuSimLogger().get_logger();
+  logger_container_.reset(new CuSimLogger("ioutils"));
+  logger_ = logger_container_->get_logger();
 }
 
 IoUtils::~IoUtils() {}
@@ -23,7 +24,7 @@ bool IoUtils::Init(std::string opt_path) {
   auto _opt = json11::Json::parse(str, err_cmt);
   if (not err_cmt.empty()) return false;
   opt_ = _opt;
-  CuSimLogger().set_log_level(opt_["c_log_level"].int_value());
+  logger_container_->set_log_level(opt_["c_log_level"].int_value());
   return true;
 }
diff --git a/cpp/src/utils/log.cc b/cpp/src/utils/log.cc
index ddfcb0c..cec5fa4 100644
--- a/cpp/src/utils/log.cc
+++ b/cpp/src/utils/log.cc
@@ -16,6 +16,14 @@ CuSimLogger::CuSimLogger() {
   logger_ = spdlog::default_logger();
 }
 
+CuSimLogger::CuSimLogger(std::string name) {
+  // auto console_sink = std::make_shared();
+  auto stderr_sink = std::make_shared();
+  // spdlog::sinks_init_list sinks = {console_sink, stderr_sink};
+  logger_ = std::make_shared<spdlog::logger>(name, stderr_sink);
+  logger_->set_pattern("[%^%-8l%$] %Y-%m-%d %H:%M:%S %v");
+}
+
 std::shared_ptr<spdlog::logger>& CuSimLogger::get_logger() {
   return logger_;
 }
@@ -23,11 +31,11 @@ std::shared_ptr<spdlog::logger>& CuSimLogger::get_logger() {
 void CuSimLogger::set_log_level(int level) {
   global_logging_level_ = level;
   switch (level) {
-    case 0: spdlog::set_level(spdlog::level::off); break;
-    case 1: spdlog::set_level(spdlog::level::warn); break;
-    case 2: spdlog::set_level(spdlog::level::info); break;
-    case 3: spdlog::set_level(spdlog::level::debug); break;
-    default: spdlog::set_level(spdlog::level::trace); break;
+    case 0: logger_->set_level(spdlog::level::off); break;
+    case 1: logger_->set_level(spdlog::level::warn); break;
+    case 2: logger_->set_level(spdlog::level::info); break;
+    case 3: logger_->set_level(spdlog::level::debug); break;
+    default: logger_->set_level(spdlog::level::trace); break;
   }
 }
diff --git a/cusim/cuw2v/pycuw2v.py b/cusim/cuw2v/pycuw2v.py
index 8ba4e4c..af58c46 100644
--- a/cusim/cuw2v/pycuw2v.py
+++ b/cusim/cuw2v/pycuw2v.py
@@ -86,7 +86,10 @@ def init_model(self):
   def train_model(self):
     self.preprocess_data()
     self.init_model()
-    if not self.opt.neg:
+    if self.opt.neg:
+      self.obj.build_random_table( \
+        self.word_count, self.opt.random_size, self.opt.num_threads)
+    else:
       self.obj.build_huffman_tree(self.word_count)
     h5f = h5py.File(pjoin(self.opt.processed_data_dir, "token.h5"), "r")
     for epoch in range(1, self.opt.epochs + 1):
@@ -112,9 +115,6 @@ def _train_epoch(self, h5f):
      offset = next_offset

      # call cuda kernel
-      if self.opt.neg:
-        self.obj.build_random_table( \
-          self.word_count, self.opt.random_size, self.opt.num_threads)
      _loss_nume, _loss_deno = \
        self.obj.feed_data(cols, indptr)
diff --git a/cusim/proto/config.proto b/cusim/proto/config.proto
index 79680fb..a54a094 100644
--- a/cusim/proto/config.proto
+++ b/cusim/proto/config.proto
@@ -51,7 +51,7 @@ message CuW2VConfigProto {
   optional int32 seed = 14 [default = 777];
   optional int32 table_seed = 15 [default = 777];
   optional int32 cuda_seed = 16 [default = 777];
-  optional int32 random_size = 17 [default = 1000000];
+  optional int32 random_size = 12 [default = 1000000];
   optional int32 neg = 17 [default = 10];  // as recommended in w2v paper
diff --git a/examples/example1.py b/examples/example1.py
index f9362cb..25132d0 100644
--- a/examples/example1.py
+++ b/examples/example1.py
@@ -12,7 +12,7 @@
 import h5py
 import numpy as np
 from gensim import downloader as api
-from cusim import aux, IoUtils, CuLDA
+from cusim import aux, IoUtils, CuLDA, CuW2V
 
 LOGGER = aux.get_logger()
 DOWNLOAD_PATH = "./res"
@@ -46,7 +46,7 @@ def run_lda():
   opt = {
     "data_path": DATA_PATH,
     "processed_data_dir": PROCESSED_DATA_DIR,
-    "skip_preprocess":True,
+    # "skip_preprocess":True,
   }
   lda = CuLDA(opt)
   lda.train_model()
@@ -66,5 +66,15 @@ def run_lda():
     print(f"rank {j + 1}. word: {word}, prob: {prob}")
   h5f.close()
 
+def run_w2v():
+  opt = {
+    "c_log_level": 3,
+    "data_path": DATA_PATH,
+    "processed_data_dir": PROCESSED_DATA_DIR,
+    # "skip_preprocess":True,
+  }
+  w2v = CuW2V(opt)
+  w2v.train_model()
+
 if __name__ == "__main__":
   fire.Fire()
diff --git a/setup.py b/setup.py
index 3624c2f..669403a 100644
--- a/setup.py
+++ b/setup.py
@@ -197,7 +197,7 @@ def setup_package():
     download_url="https://github.com/js1010/cusim/releases",
     include_package_data=False,
     license='Apache2',
-    packages=['cusim/', "cusim/ioutils/", "cusim/culda/"],
+    packages=['cusim/', "cusim/ioutils/", "cusim/culda/", "cusim/cuw2v/"],
     install_requires=INSTALL_REQUIRES,
     cmdclass=cmdclass,
     classifiers=[_f for _f in CLASSIFIERS.split('\n') if _f],

From 46b31a5325f3bada40012e403e1026c379d24e59 Mon Sep 17 00:00:00 2001
From: js1010
Date: Sat, 13 Feb 2021 11:52:45 +0900
Subject: [PATCH 17/18] bug-fix

---
 cpp/src/cuw2v/cuw2v.cu   |  2 +-
 cusim/culda/pyculda.py   |  5 +++++
 cusim/cuw2v/pycuw2v.py   | 17 ++++++++++++-----
 cusim/proto/config.proto |  3 +--
 4 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/cpp/src/cuw2v/cuw2v.cu b/cpp/src/cuw2v/cuw2v.cu
index 8474291..20f6640 100644
--- a/cpp/src/cuw2v/cuw2v.cu
+++ b/cpp/src/cuw2v/cuw2v.cu
@@ -265,7 +265,7 @@ std::pair<float, float> CuW2V::FeedData(const int* cols, const int* indptr,
   // accumulate loss nume / deno
   std::vector<float> loss_nume(block_cnt_), loss_deno(block_cnt_);
   thrust::copy(dev_loss_nume.begin(), dev_loss_nume.end(), loss_nume.begin());
-  thrust::copy(dev_loss_deno.begin(), dev_loss_deno.end(), loss_nume.begin());
+  thrust::copy(dev_loss_deno.begin(), dev_loss_deno.end(), loss_deno.begin());
   CHECK_CUDA(cudaDeviceSynchronize());
   float loss_nume_sum = std::accumulate(loss_nume.begin(), loss_nume.end(), 0.0f);
   float loss_deno_sum = std::accumulate(loss_deno.begin(), loss_deno.end(), 0.0f);
diff --git a/cusim/culda/pyculda.py b/cusim/culda/pyculda.py
index e0fa3d9..052f52c 100644
--- a/cusim/culda/pyculda.py
+++ b/cusim/culda/pyculda.py
@@ -20,12 +20,17 @@
 from cusim.config_pb2 import CuLDAConfigProto
 
 EPS = 1e-10
+WARP_SIZE = 32
 
 class CuLDA:
   def __init__(self, opt=None):
     self.opt = aux.get_opt_as_proto(opt or {}, CuLDAConfigProto)
     self.logger = aux.get_logger("culda", level=self.opt.py_log_level)
 
+    assert self.opt.block_dim <= WARP_SIZE ** 2 and \
+      self.opt.block_dim % WARP_SIZE == 0, \
+      f"invalid block dim ({self.opt.block_dim}, warp size: {WARP_SIZE})"
+
     tmp = tempfile.NamedTemporaryFile(mode='w', delete=False)
     opt_content = json.dumps(aux.proto_to_dict(self.opt), indent=2)
     tmp.write(opt_content)
diff --git a/cusim/cuw2v/pycuw2v.py b/cusim/cuw2v/pycuw2v.py
index af58c46..3d4c4ee 100644
--- a/cusim/cuw2v/pycuw2v.py
+++ b/cusim/cuw2v/pycuw2v.py
@@ -19,12 +19,17 @@
 from cusim.config_pb2 import CuW2VConfigProto
 
 EPS = 1e-10
+WARP_SIZE = 32
 
 class CuW2V:
   def __init__(self, opt=None):
     self.opt = aux.get_opt_as_proto(opt or {}, CuW2VConfigProto)
     self.logger = aux.get_logger("culda", level=self.opt.py_log_level)
 
+    assert self.opt.block_dim <= WARP_SIZE ** 2 and \
+      self.opt.block_dim % WARP_SIZE == 0, \
+      f"invalid block dim ({self.opt.block_dim}, warp size: {WARP_SIZE})"
+
     tmp = tempfile.NamedTemporaryFile(mode='w', delete=False)
     opt_content = json.dumps(aux.proto_to_dict(self.opt), indent=2)
     tmp.write(opt_content)
@@ -61,6 +66,7 @@ def init_model(self):
       dtype=np.float32)
     self.word_count = np.power(self.word_count, self.opt.count_power)
     self.num_words = len(self.words)
+    assert len(self.words) == len(self.word_count)
 
     # count number of docs
     h5f = h5py.File(pjoin(data_dir, "token.h5"), "r")
@@ -70,6 +76,12 @@ def init_model(self):
     self.logger.info("number of words: %d, docs: %d",
       self.num_words, self.num_docs)
 
+    if self.opt.neg:
+      self.obj.build_random_table( \
+        self.word_count, self.opt.random_size, self.opt.num_threads)
+    else:
+      self.obj.build_huffman_tree(self.word_count)
+
     # random initialize alpha and beta
     np.random.seed(self.opt.seed)
     self.emb_in = np.random.normal( \
@@ -86,11 +98,6 @@ def init_model(self):
   def train_model(self):
     self.preprocess_data()
     self.init_model()
-    if self.opt.neg:
-      self.obj.build_random_table( \
-        self.word_count, self.opt.random_size, self.opt.num_threads)
-    else:
-      self.obj.build_huffman_tree(self.word_count)
     h5f = h5py.File(pjoin(self.opt.processed_data_dir, "token.h5"), "r")
     for epoch in range(1, self.opt.epochs + 1):
       self.logger.info("Epoch %d / %d", epoch, self.opt.epochs)
diff --git a/cusim/proto/config.proto b/cusim/proto/config.proto
index a54a094..f468bb0 100644
--- a/cusim/proto/config.proto
+++ b/cusim/proto/config.proto
@@ -60,6 +60,5 @@ message CuW2VConfigProto {
   optional bool use_mean = 20 [default = true];
   optional double lr = 21 [default = 0.001];
   optional int32 window_size = 22 [default = 5];
-
-
+  optional int32 num_threads = 23 [default = 4];
 }

From a6b9e7ee98bcb5c6a30a2e852c6e840795d528c9 Mon Sep 17 00:00:00 2001
From: js1010
Date: Sat, 13 Feb 2021 11:58:22 +0900
Subject: [PATCH 18/18] add constants.py

---
 cusim/constants.py     | 10 ++++++++++
 cusim/culda/pyculda.py |  3 +--
 cusim/cuw2v/pycuw2v.py |  6 ++----
 examples/example1.py   |  2 +-
 4 files changed, 14 insertions(+), 7 deletions(-)
 create mode 100644 cusim/constants.py

diff --git a/cusim/constants.py b/cusim/constants.py
new file mode 100644
index 0000000..6d69ce8
--- /dev/null
+++ b/cusim/constants.py
@@ -0,0 +1,10 @@
+# Copyright (c) 2021 Jisang Yoon
+# All rights reserved.
+#
+# This source code is licensed under the Apache 2.0 license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pylint: disable=no-name-in-module,too-few-public-methods,no-member
+
+EPS = 1e-10
+WARP_SIZE = 32
diff --git a/cusim/culda/pyculda.py b/cusim/culda/pyculda.py
index 052f52c..dbd0da3 100644
--- a/cusim/culda/pyculda.py
+++ b/cusim/culda/pyculda.py
@@ -18,9 +18,8 @@
 from cusim import aux, IoUtils
 from cusim.culda.culda_bind import CuLDABind
 from cusim.config_pb2 import CuLDAConfigProto
+from cusim.constants import EPS, WARP_SIZE
 
-EPS = 1e-10
-WARP_SIZE = 32
 
 class CuLDA:
   def __init__(self, opt=None):
diff --git a/cusim/cuw2v/pycuw2v.py b/cusim/cuw2v/pycuw2v.py
index 3d4c4ee..f2bd265 100644
--- a/cusim/cuw2v/pycuw2v.py
+++ b/cusim/cuw2v/pycuw2v.py
@@ -17,9 +17,7 @@
 from cusim import aux, IoUtils
 from cusim.cuw2v.cuw2v_bind import CuW2VBind
 from cusim.config_pb2 import CuW2VConfigProto
-
-EPS = 1e-10
-WARP_SIZE = 32
+from cusim.constants import EPS, WARP_SIZE
 
 class CuW2V:
   def __init__(self, opt=None):
@@ -102,7 +100,7 @@ def train_model(self):
     for epoch in range(1, self.opt.epochs + 1):
       self.logger.info("Epoch %d / %d", epoch, self.opt.epochs)
       self._train_epoch(h5f)
-      self.pull()
+      self.obj.pull()
     h5f.close()
diff --git a/examples/example1.py b/examples/example1.py
index 25132d0..1bea971 100644
--- a/examples/example1.py
+++ b/examples/example1.py
@@ -68,7 +68,7 @@
 def run_w2v():
   opt = {
-    "c_log_level": 3,
+    # "c_log_level": 3,
     "data_path": DATA_PATH,
     "processed_data_dir": PROCESSED_DATA_DIR,
     # "skip_preprocess":True,
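
Editor's addendum (illustrative only, not part of the patch series above): fire.Fire() at the bottom of examples/example1.py exposes each top-level function as a command-line subcommand, so the run_w2v entry point introduced in these patches can be exercised with "python examples/example1.py run_w2v", or CuW2V can be driven directly with the same kind of option dict. The sketch below is a minimal usage sketch assuming the cusim package is installed; the path values are hypothetical placeholders, not values taken from the patches.

  # hedged sketch: drive the word2vec trainer directly (hypothetical paths)
  from cusim import CuW2V

  opt = {
    "data_path": "./res/corpus.txt",           # hypothetical corpus location
    "processed_data_dir": "./res/processed",   # hypothetical output directory
    # "skip_preprocess": True,                 # uncomment to reuse cached tokens
  }
  CuW2V(opt).train_model()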