Merged
Commits (49)
fc62175
add gensim example
js1010 Feb 13, 2021
cee1a0b
lower characgter
js1010 Feb 13, 2021
d75bee8
update
js1010 Feb 13, 2021
2997093
use np.int32 => np.int64
js1010 Feb 13, 2021
9cf8c65
bug-fix in ioutils
js1010 Feb 13, 2021
445b3aa
save and load w2v model compatibly
js1010 Feb 13, 2021
25ff133
add evaluation code
js1010 Feb 13, 2021
edbdaf7
quality debugging
js1010 Feb 13, 2021
96d44ce
fix context-word reversion
js1010 Feb 14, 2021
ee76181
evaluate case sensitively
js1010 Feb 14, 2021
a7c5daa
change random table
js1010 Feb 14, 2021
3cb7e31
float => double
js1010 Feb 14, 2021
6107360
bug-fix in random table
js1010 Feb 14, 2021
4b5be55
update example
js1010 Feb 14, 2021
90f08fd
add required libs
js1010 Feb 14, 2021
4a0cc5b
separate w2v and lda
js1010 Feb 14, 2021
410eb23
make fairer
js1010 Feb 14, 2021
c491d98
read bag of words
js1010 Feb 14, 2021
509093b
compile succeed
js1010 Feb 14, 2021
84f3ba9
remove unnecessary requirements
js1010 Feb 14, 2021
6792064
bug-fix in processing bow file
js1010 Feb 14, 2021
f6ba3ee
train counts
js1010 Feb 14, 2021
51eb414
fix typo
js1010 Feb 14, 2021
239663a
bug-fix
js1010 Feb 14, 2021
0c68672
change required fields
js1010 Feb 14, 2021
e6070a8
change default values
js1010 Feb 14, 2021
3d92783
add gensim lda code
js1010 Feb 14, 2021
edbc949
backup gamma
js1010 Feb 14, 2021
3f1e49d
optionize reuse_gammaa
js1010 Feb 14, 2021
b9b52cc
backup
js1010 Feb 14, 2021
6aca507
save gamma and remove tmps
js1010 Feb 15, 2021
ac86e26
remove tmps in w2v
js1010 Feb 15, 2021
1ced64e
make loss more sense
js1010 Feb 15, 2021
425a1df
add comment
js1010 Feb 15, 2021
9180474
update lda results
js1010 Feb 15, 2021
e24a25a
add readme
js1010 Feb 15, 2021
ffc944a
refactor example codes
js1010 Feb 15, 2021
d26e062
add reame
js1010 Feb 15, 2021
955a3c8
add performance
js1010 Feb 15, 2021
1db5a46
Update README.md
js1010 Feb 15, 2021
30af2c7
add files for pypi
js1010 Feb 15, 2021
a31e313
change number
js1010 Feb 15, 2021
9a608c2
Merge branch 'main' of github.com:js1010/cusim into task/add-benchmark
js1010 Feb 15, 2021
fa51d2e
change words
js1010 Feb 15, 2021
94ca9b4
update README.md
js1010 Feb 15, 2021
4251dd5
add performance results
js1010 Feb 15, 2021
94f4e94
use bold
js1010 Feb 15, 2021
5898669
add results
js1010 Feb 15, 2021
469d0ff
Merge pull request #6 from js1010/task/add-benchmark
js1010 Feb 15, 2021
14 changes: 14 additions & 0 deletions MANIFEST.in
@@ -0,0 +1,14 @@
include cuda_setup.py
include requirements.txt
include pyproject.toml
recursive-include cpp/src/cuw2v/ *.cu
recursive-include cpp/src/culda/ *.cu
recursive-include cpp/src/ioutils/ *.cc
recursive-include cpp/include/cuw2v/ *.cuh
recursive-include cpp/include/cuw2v/ *.hpp
recursive-include cpp/include/culda/ *.cuh
recursive-include cpp/include/culda/ *.hpp
recursive-include cpp/include/ioutils/ *.cuh
recursive-include cpp/include/ioutils/ *.hpp
recursive-include 3rd/json11/ *
recursive-include 3rd/spdlog/ *
65 changes: 65 additions & 0 deletions README.md
@@ -1,5 +1,10 @@
### Introduction

This project aims to speed up various ML models (e.g. topic modeling, word embedding) with CUDA. It can be thought of as a GPU version of [gensim](https://github.com/RaRe-Technologies/gensim). As a starting step, I implemented the most widely used word embedding model, [word2vec](https://arxiv.org/pdf/1301.3781.pdf), and the most representative topic model, [LDA (Latent Dirichlet Allocation)](https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf).

### How to install

- install from source

```shell
# clone repo and submodules
@@ -14,3 +19,63 @@ python -m grpc_tools.protoc --python_out cusim/ --proto_path cusim/proto/ config
# install
python setup.py install
```

- pip installation will be available soon

### How to use

- `examples/example_w2v.py`, `examples/example_lda.py`, and `examples/README.md` are very helpful for understanding the usage.
- Parameter descriptions can be found in `cusim/proto/config.proto`; a rough usage sketch follows below.
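
For orientation, here is a minimal sketch of what a training script might look like. It is not taken from the repository: the option keys (`data_path`, `num_dims`, `skip_gram`, `hierarchical_softmax`) and even the exact class and method names are assumptions, so consult `examples/example_w2v.py` and `cusim/proto/config.proto` for the real interface.

```python
# Hedged usage sketch -- option keys and method names below are assumptions,
# not the verified cusim API; see examples/example_w2v.py for the real thing.
from cusim import CuW2V

opt = {
    "data_path": "quora-duplicate-questions.txt",  # hypothetical preprocessed corpus
    "num_dims": 100,                               # embedding dimension
    "skip_gram": True,                             # skip gram vs. CBOW
    "hierarchical_softmax": True,                  # HS vs. negative sampling
}

w2v = CuW2V(opt)    # build the model from the config dict
w2v.train_model()   # run training on the GPU
```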

### Performance

- An [AWS g4dn.2xlarge instance](https://aws.amazon.com/ec2/instance-types/g4/) was used for the experiments (one NVIDIA T4 GPU, 8 vCPUs, Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz).
- Results can be reproduced by simply running `examples/example_w2v.py` and `examples/example_lda.py`.
- To evaluate the w2v model, I used the `evaluate_word_pairs` function ([ref link](https://radimrehurek.com/gensim/auto_examples/tutorials/run_word2vec.html#evaluating)) in gensim. Note that better performance on the WS-353 test set does not necessarily mean the model will work better in applications, as described in the link. Still, it is a convenient quantitative measure, and fast training time is at least an objective measure of performance (a sketch of this evaluation appears at the end of this section).
- I trained the w2v model on the `quora-duplicate-questions` dataset from the gensim downloader API on GPU with cusim and compared the performance (both speed and model quality) with gensim.
- To evaluate the LDA model, I found no good way to measure the quality of the training results quantitatively, but we can inspect the model by looking at the top words of each topic. Training time, at least, can be compared quantitatively.
- W2V (skip gram, hierarchical softmax)

| attr | 1 workers (gensim) | 2 workers (gensim) | 4 workers (gensim) | 8 workers (gensim) | NVIDIA T4 (cusim) |
|:--------------------|---------------------:|---------------------:|---------------------:|---------------------:|--------------------:|
| training time (sec) | 892.596 | 544.212 | 310.727 | 226.472 | **16.162** |
| pearson | 0.487832 | 0.487696 | 0.482821 | 0.487136 | **0.492101** |
| spearman | 0.500846 | 0.506214 | 0.501048 | **0.506718** | 0.479468 |

- W2V (skip gram, negative sampling)

| attr | 1 workers (gensim) | 2 workers (gensim) | 4 workers (gensim) | 8 workers (gensim) | NVIDIA T4 (cusim) |
|:--------------------|---------------------:|---------------------:|---------------------:|---------------------:|--------------------:|
| training time (sec) | 586.545 | 340.489 | 220.804 | 146.23 | **33.9173** |
| pearson | 0.354448 | 0.353952 | 0.352398 | 0.352925 | **0.360436** |
| spearman | 0.369146 | 0.369365 | **0.370565** | 0.365822 | 0.355204 |

- W2V (CBOW, hierarchical softmax)

| attr | 1 workers (gensim) | 2 workers (gensim) | 4 workers (gensim) | 8 workers (gensim) | NVIDIA T4 (cusim) |
|:--------------------|---------------------:|---------------------:|---------------------:|---------------------:|--------------------:|
| training time (sec) | 250.135 | 155.121 | 103.57 | 73.8073 | **6.20787** |
| pearson | 0.309651 | 0.321803 | 0.324854 | 0.314255 | **0.480298** |
| spearman | 0.294047 | 0.308723 | 0.318293 | 0.300591 | **0.480971** |

- W2V (CBOW, negative sampling)

| attr | 1 workers (gensim) | 2 workers (gensim) | 4 workers (gensim) | 8 workers (gensim) | NVIDIA T4 (cusim) |
|:--------------------|---------------------:|---------------------:|---------------------:|---------------------:|--------------------:|
| training time (sec) | 176.923 | 100.369 | 69.7829 | 49.9274 | **9.90391** |
| pearson | 0.18772 | 0.193152 | 0.204509 | 0.187924 | **0.368202** |
| spearman | 0.243975 | 0.24587 | 0.260531 | 0.237441 | **0.358042** |

- LDA (`nytimes` dataset from https://archive.ics.uci.edu/ml/datasets/bag+of+words)
- I found that setting the `workers` variable in gensim's LdaMulticore does not work properly (it uses all the cores on the instance anyway), so I simply compared cusim on a single GPU against gensim with 8 vCPUs.
- One can compare the quality of modeling by looking at `examples/cusim.topics.txt` and `examples/gensim.topics.txt`.

| attr                | gensim (8 vcpus)  | cusim (NVIDIA T4) |
|:--------------------|------------------:|------------------:|
| training time (sec) | 447.376 | **76.6972** |
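
For reference, the pearson/spearman rows above come from gensim's `evaluate_word_pairs` on the WS-353 pairs. The sketch below shows roughly how such an evaluation is run; the saved-model path is hypothetical and the `case_insensitive=False` flag is an assumption based on the "evaluate case sensitively" commit.

```python
# Hedged sketch of the WS-353 evaluation; the model path is hypothetical.
from gensim.models import KeyedVectors
from gensim.test.utils import datapath

kv = KeyedVectors.load_word2vec_format("cusim.w2v.model")  # hypothetical output file
pearson, spearman, oov_ratio = kv.evaluate_word_pairs(
    datapath("wordsim353.tsv"), case_insensitive=False)
print("pearson:", pearson[0], "spearman:", spearman[0], "oov ratio:", oov_ratio)
```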

### Future tasks

- support half precision
- support multi-device training (a multi-device implementation of the LDA model should not be that hard, while multi-device training of w2v may require more consideration)
- implement other models such as FastText, BERT, etc
107 changes: 73 additions & 34 deletions cpp/include/culda/cuda_lda_kernels.cuh
@@ -26,26 +26,36 @@ float Digamma(float x) {
}

__global__ void EstepKernel(
const int* cols, const int* indptr, const bool* vali,
const int num_cols, const int num_indptr,
const int* cols, const int* indptr,
const bool* vali, const float* counts,
const bool init_gamma, const int num_cols, const int num_indptr,
const int num_topics, const int num_iters,
float* gamma, float* new_gamma, float* phi,
const float* alpha, const float* beta,
float* grad_alpha, float* new_beta,
float* train_losses, float* vali_losses, int* mutex) {
float* gamma, float* grad_alpha, float* new_beta,
float* train_losses, float* vali_losses, int* locks) {

// storage for block
float* _gamma = gamma + num_topics * blockIdx.x;
float* _new_gamma = new_gamma + num_topics * blockIdx.x;
float* _phi = phi + num_topics * blockIdx.x;
extern __shared__ float shared_memory[];
float* _new_gamma = &shared_memory[0];
float* _phi = &shared_memory[num_topics];
float* _loss_vec = &shared_memory[num_topics * 2];
float* _vali_phi_sum = &shared_memory[num_topics * 3];

float* _grad_alpha = grad_alpha + num_topics * blockIdx.x;

for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) {
int beg = indptr[i], end = indptr[i + 1];
// initialize gamma
for (int j = threadIdx.x; j < num_topics; j += blockDim.x)
_gamma[j] = alpha[j] + (end - beg) / num_topics;
float* _gamma = gamma + num_topics * i;
if (init_gamma) {
for (int j = threadIdx.x; j < num_topics; j += blockDim.x) {
_gamma[j] = alpha[j] + (end - beg) / num_topics;
}
}
__syncthreads();

// initiate phi sum for validation data for computing vali loss
for (int j = threadIdx.x; j < num_topics; j += blockDim.x)
_vali_phi_sum[j] = 0.0f;

// iterate E step
for (int j = 0; j < num_iters; ++j) {
@@ -58,7 +68,7 @@ __global__ void EstepKernel(
for (int k = beg; k < end; ++k) {
const int w = cols[k];
const bool _vali = vali[k];

const float c = counts[k];
// compute phi
if (not _vali or j + 1 == num_iters) {
for (int l = threadIdx.x; l < num_topics; l += blockDim.x)
@@ -70,37 +80,52 @@ __global__ void EstepKernel(

for (int l = threadIdx.x; l < num_topics; l += blockDim.x) {
_phi[l] /= phi_sum;
if (not _vali) _new_gamma[l] += _phi[l];

// update gamma for train data and phi_sum for computing loss
if (_vali)
_vali_phi_sum[l] += _phi[l] * c;
else
_new_gamma[l] += _phi[l] * c;

}
__syncthreads();
}

if (j + 1 == num_iters) {
// write access of w th vector of new_beta
if (threadIdx.x == 0) {
while (atomicCAS(&mutex[w], 0, 1)) {}
}
// update beta for train data
if (not _vali) {
// write access of w th vector of new_beta
if (threadIdx.x == 0) {
while (atomicCAS(&locks[w], 0, 1)) {}
}

__syncthreads();
__syncthreads();
for (int l = threadIdx.x; l < num_topics; l += blockDim.x)
new_beta[w * num_topics + l] += _phi[l] * c;
__syncthreads();

// release lock
if (threadIdx.x == 0) locks[w] = 0;
__syncthreads();
}

// compute loss and reset shared mem
// see Eq (15) in https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf
for (int l = threadIdx.x; l < num_topics; l += blockDim.x) {
if (j + 1 == num_iters) {
if (not _vali) new_beta[w * num_topics + l] += _phi[l];
_phi[l] *= beta[w * num_topics + l];
}
_loss_vec[l] = logf(fmaxf(beta[w * num_topics + l], EPS));
_loss_vec[l] -= logf(fmaxf(_phi[l], EPS));
_loss_vec[l] *= _phi[l];
}
__syncthreads();

// release lock
if (threadIdx.x == 0) mutex[w] = 0;
__syncthreads();

float p = fmaxf(EPS, ReduceSum(_phi, num_topics));
float _loss = ReduceSum(_loss_vec, num_topics) * c;
if (threadIdx.x == 0) {
if (_vali)
vali_losses[blockIdx.x] += logf(p);
if (_vali)
vali_losses[blockIdx.x] += _loss;
else
train_losses[blockIdx.x] += logf(p);
}
train_losses[blockIdx.x] += _loss;
}
__syncthreads();

}
__syncthreads();
}
@@ -110,9 +135,23 @@ __global__ void EstepKernel(
_gamma[k] = _new_gamma[k] + alpha[k];
__syncthreads();
}

// update gradient of alpha and loss from E[log(theta)]
float gamma_sum = ReduceSum(_gamma, num_topics);
for (int j = threadIdx.x; j < num_topics; j += blockDim.x)
_grad_alpha[j] += (Digamma(_gamma[j]) - Digamma(gamma_sum));
for (int j = threadIdx.x; j < num_topics; j += blockDim.x) {
float Elogthetad = Digamma(_gamma[j]) - Digamma(gamma_sum);
_grad_alpha[j] += Elogthetad;
_new_gamma[j] *= Elogthetad;
_vali_phi_sum[j] *= Elogthetad;
}

// see Eq (15) in https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf
float train_loss = ReduceSum(_new_gamma, num_topics);
float vali_loss = ReduceSum(_vali_phi_sum, num_topics);
if (threadIdx.x == 0) {
train_losses[blockIdx.x] += train_loss;
vali_losses[blockIdx.x] += vali_loss;
}

__syncthreads();
}
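
As a reading aid for the kernel above, the following NumPy sketch shows the single-document variational E-step that `EstepKernel` parallelizes across CUDA blocks: repeatedly update `phi` and `gamma`, then accumulate `E[log theta] = digamma(gamma) - digamma(sum(gamma))` for the alpha gradient. It is an illustration only; the exact `phi` expression is collapsed out of this hunk, so the standard mean-field update is assumed, and the Eq. (15) loss bookkeeping is omitted.

```python
# Single-document E-step sketch (assumes the standard mean-field updates).
import numpy as np
from scipy.special import digamma

def estep_doc(word_ids, counts, alpha, beta, num_iters=5):
    # alpha: (num_topics,), beta: (num_words, num_topics), counts: per-word counts
    num_topics = alpha.shape[0]
    gamma = alpha + len(word_ids) / num_topics           # same init as the kernel
    for _ in range(num_iters):
        new_gamma = np.zeros(num_topics)
        for w, c in zip(word_ids, counts):
            phi = beta[w] * np.exp(digamma(gamma))       # unnormalized phi (assumed form)
            phi /= phi.sum()
            new_gamma += phi * c                         # gamma update weighted by counts
        gamma = new_gamma + alpha
    grad_alpha = digamma(gamma) - digamma(gamma.sum())   # E[log theta_d]
    return gamma, grad_alpha
```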
10 changes: 6 additions & 4 deletions cpp/include/culda/culda.hpp
@@ -65,8 +65,11 @@ class CuLDA {
void LoadModel(float* alpha, float* beta,
float* grad_alpha, float* new_beta, const int num_words);
std::pair<float, float> FeedData(
const int* indices, const int* indptr, const bool* vali,
const int num_indices, const int num_indptr, const int num_iters);
const int* indices, const int* indptr,
const bool* vali, const float* counts,
float* gamma, const bool init_gamma,
const int num_indices, const int num_indptr,
const int num_iters);
void Pull();
void Push();
int GetBlockCnt();
@@ -78,8 +81,7 @@ class CuLDA {
std::unique_ptr<CuSimLogger> logger_container_;
thrust::device_vector<float> dev_alpha_, dev_beta_;
thrust::device_vector<float> dev_grad_alpha_, dev_new_beta_;
thrust::device_vector<float> dev_gamma_, dev_new_gamma_, dev_phi_;
thrust::device_vector<int> dev_mutex_;
thrust::device_vector<int> dev_locks_;

float *alpha_, *beta_, *grad_alpha_, *new_beta_;
int block_cnt_, block_dim_;
6 changes: 4 additions & 2 deletions cpp/include/cuw2v/cuda_w2v_base_kernels.cuh
@@ -6,14 +6,16 @@
#pragma once
#include "utils/cuda_utils_kernels.cuh"

#define MAX_EXP 20

namespace cusim {


__inline__ __device__
void PositiveFeedback(const float* vec1, float* vec2, float* grad,
float& loss_nume, float& loss_deno, const int num_dims, const float lr) {
static __shared__ float g;
float dot = Dot(vec1, vec2, num_dims);
float dot = fmaxf(-MAX_EXP, fminf(MAX_EXP, Dot(vec1, vec2, num_dims)));
if (threadIdx.x == 0) {
float exp_dot = expf(-dot);
g = exp_dot / (1 + exp_dot) * lr;
@@ -32,7 +34,7 @@ __inline__ __device__
void NegativeFeedback(const float* vec1, float* vec2, float* grad,
float& loss_nume, float& loss_deno, const int num_dims, const float lr) {
static __shared__ float g;
float dot = Dot(vec1, vec2, num_dims);
float dot = fmaxf(-MAX_EXP, fminf(MAX_EXP, Dot(vec1, vec2, num_dims)));
if (threadIdx.x == 0) {
float exp_dot = expf(dot);
g = exp_dot / (1 + exp_dot) * lr;
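
The only functional change in this file is the `MAX_EXP` clamp on the dot product before exponentiation, which keeps `expf` from overflowing and bounds the gradient scale. A scalar sketch (not taken from the repository) of what the two feedback helpers compute:

```python
# Scalar sketch of the clipped positive/negative feedback gains.
import math

MAX_EXP = 20.0

def positive_gain(dot, lr):
    dot = max(-MAX_EXP, min(MAX_EXP, dot))   # clamp before exp, as in the kernel
    e = math.exp(-dot)
    return e / (1.0 + e) * lr                # (1 - sigmoid(dot)) * lr

def negative_gain(dot, lr):
    dot = max(-MAX_EXP, min(MAX_EXP, dot))
    e = math.exp(dot)
    return e / (1.0 + e) * lr                # sigmoid(dot) * lr
```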
18 changes: 9 additions & 9 deletions cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh
@@ -39,11 +39,11 @@ __global__ void W2VHsSgKernel(
__syncthreads();
int beg2 = max(beg, j - window_size + reduced_windows);
int end2 = min(end, j + window_size - reduced_windows + 1);
float* _emb_in = emb_in + num_dims * cols[j];
for (int k = beg2; k < end2; ++k) {
if (k == j) continue;
int beg3 = hs_indptr[cols[k]];
int end3 = hs_indptr[cols[k] + 1];
float* _emb_in = emb_in + num_dims * cols[k];
int beg3 = hs_indptr[cols[j]];
int end3 = hs_indptr[cols[j] + 1];
for (int l = beg3; l < end3; ++l) {
if (codes[l]) {
PositiveFeedback(_emb_in, emb_out + num_dims * points[l],
@@ -55,7 +55,7 @@ __global__ void W2VHsSgKernel(
__syncthreads();
}
for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
emb_in[num_dims * cols[j] + l] += grad[l];
_emb_in[l] += grad[l];
grad[l] = 0.0f;
}
__syncthreads();
@@ -70,7 +70,7 @@ __global__ void W2VHsCbowKernel(
const int num_indptr, const int num_dims, const int window_size, default_random_engine* rngs,
float* emb_in, float* emb_out,
float* loss_nume, float* loss_deno,
const bool use_mean, const float lr) {
const bool cbow_mean, const float lr) {

default_random_engine& rng = rngs[blockIdx.x];
float& _loss_nume = loss_nume[blockIdx.x];
@@ -98,15 +98,15 @@ __global__ void W2VHsCbowKernel(
grad[k] = 0.0f;
cbow[k] = 0.0f;
}

// compute cbow
for (int k = beg2; k < end2; ++k) {
if (k == j) continue;
for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
cbow[l] += emb_in[num_dims * cols[k] + l];
}
}
if (use_mean) {
if (cbow_mean) {
for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
cbow[k] /= (end2 - beg2 - 1);
}
@@ -126,8 +126,8 @@ __global__ void W2VHsCbowKernel(
__syncthreads();
}

// normalize grad if use_mean = true
if (use_mean) {
// normalize grad if cbow_mean = true
if (cbow_mean) {
for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
grad[k] /= (end2 - beg2 - 1);
}
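
The first hunk in this file fixes the context-word reversion from commit 96d44ce: for skip-gram with hierarchical softmax, the *context* word (`cols[k]`) now supplies the input embedding while the *center* word (`cols[j]`) supplies the Huffman `(point, code)` path, matching gensim's pairing. A small, repository-independent sketch of that pairing:

```python
# Sketch of the corrected skip-gram/HS pairing: context word -> input vector,
# center word -> Huffman path. hs_points/hs_codes map word id -> path arrays.
def sg_hs_pairs(sentence, window, hs_points, hs_codes):
    pairs = []
    for j, center in enumerate(sentence):
        lo, hi = max(0, j - window), min(len(sentence), j + window + 1)
        for k in range(lo, hi):
            if k == j:
                continue
            context = sentence[k]
            pairs.append((context, hs_points[center], hs_codes[center]))
    return pairs
```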