Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/include/culda/culda.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class CuLDA {
DeviceInfo dev_info_;
json11::Json opt_;
std::shared_ptr<spdlog::logger> logger_;
std::unique_ptr<CuSimLogger> logger_container_;
thrust::device_vector<float> dev_alpha_, dev_beta_;
thrust::device_vector<float> dev_grad_alpha_, dev_new_beta_;
thrust::device_vector<float> dev_gamma_, dev_new_gamma_, dev_phi_;
Expand Down
50 changes: 50 additions & 0 deletions cpp/include/cuw2v/cuda_w2v_base_kernels.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) 2021 Jisang Yoon
// All rights reserved.
//
// This source code is licensed under the Apache 2.0 license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include "utils/cuda_utils_kernels.cuh"

namespace cusim {


// Gradient step for a positive (observed) pair: pulls vec2 toward vec1.
//
// Computes the scalar step g = lr * sigmoid(-dot(vec1, vec2)) once
// (thread 0, broadcast via shared memory), then block-cooperatively
//   grad[i] += vec2[i] * g   (input-side gradient, applied by caller)
//   vec2[i] += vec1[i] * g   (output-side update, in place)
//
// Must be called by ALL threads of the block: it contains __syncthreads().
// loss_nume / loss_deno are accumulated by thread 0 only.
//
// Numerical note: the naive form expf(-dot) / (1 + expf(-dot)) overflows
// for dot << 0 and yields inf/inf = NaN, and logf(1 + expf(-dot)) yields
// inf. We use the algebraically equivalent lr / (1 + expf(dot)) and the
// stable softplus identity log(1+exp(x)) = max(x,0) + log1p(exp(-|x|)),
// which saturate to the correct limits at both extremes.
__inline__ __device__
void PositiveFeedback(const float* vec1, float* vec2, float* grad,
    float& loss_nume, float& loss_deno, const int num_dims, const float lr) {
  static __shared__ float g;  // one broadcast slot per block
  float dot = Dot(vec1, vec2, num_dims);
  if (threadIdx.x == 0) {
    // g = lr * sigmoid(-dot); never forms inf/inf.
    g = lr / (1.0f + expf(dot));
    // loss = softplus(-dot) = log(1 + exp(-dot)), computed stably.
    loss_nume += fmaxf(-dot, 0.0f) + log1pf(expf(-fabsf(dot)));
    loss_deno++;
  }
  __syncthreads();  // all threads must see g before using it
  for (int i = threadIdx.x; i < num_dims; i += blockDim.x) {
    grad[i] += vec2[i] * g;
    vec2[i] += vec1[i] * g;
  }
  __syncthreads();  // protect g from being overwritten by a later call
}

// Gradient step for a negative (noise) pair: pushes vec2 away from vec1.
//
// Computes the scalar step g = lr * sigmoid(dot(vec1, vec2)) once
// (thread 0, broadcast via shared memory), then block-cooperatively
//   grad[i] -= vec2[i] * g   (input-side gradient, applied by caller)
//   vec2[i] -= vec1[i] * g   (output-side update, in place)
//
// Must be called by ALL threads of the block: it contains __syncthreads().
// loss_nume / loss_deno are accumulated by thread 0 only.
//
// Numerical note: the naive form expf(dot) / (1 + expf(dot)) overflows
// for dot >> 0 and yields inf/inf = NaN, and logf(1 + expf(dot)) yields
// inf. We use the algebraically equivalent lr / (1 + expf(-dot)) and the
// stable softplus identity log(1+exp(x)) = max(x,0) + log1p(exp(-|x|)),
// which saturate to the correct limits at both extremes.
__inline__ __device__
void NegativeFeedback(const float* vec1, float* vec2, float* grad,
    float& loss_nume, float& loss_deno, const int num_dims, const float lr) {
  static __shared__ float g;  // one broadcast slot per block
  float dot = Dot(vec1, vec2, num_dims);
  if (threadIdx.x == 0) {
    // g = lr * sigmoid(dot); never forms inf/inf.
    g = lr / (1.0f + expf(-dot));
    // loss = softplus(dot) = log(1 + exp(dot)), computed stably.
    loss_nume += fmaxf(dot, 0.0f) + log1pf(expf(-fabsf(dot)));
    loss_deno++;
  }
  __syncthreads();  // all threads must see g before using it
  for (int i = threadIdx.x; i < num_dims; i += blockDim.x) {
    grad[i] -= vec2[i] * g;
    vec2[i] -= vec1[i] * g;
  }
  __syncthreads();  // protect g from being overwritten by a later call
}

} // cusim
149 changes: 149 additions & 0 deletions cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Copyright (c) 2021 Jisang Yoon
// All rights reserved.
//
// This source code is licensed under the Apache 2.0 license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include "utils/cuda_utils_kernels.cuh"
#include "cuw2v/cuda_w2v_base_kernels.cuh"


namespace cusim {

// Skip-gram word2vec training with hierarchical softmax.
//
// Launch layout: one block cooperates on one "sentence" (a row of the
// CSR structure cols/indptr); blocks grid-stride over num_indptr rows.
// Dynamic shared memory requirement: num_dims * sizeof(float)
// (gradient scratch buffer).
//
// cols, indptr       : CSR sentences; cols[j] is the word id at position j.
// codes, points      : per-word Huffman code bits and inner-node ids,
//                      laid out CSR-style and indexed through hs_indptr
//                      by word id.
// rngs               : one RNG per block, used to randomly shrink the
//                      context window (standard word2vec window sampling).
// emb_in, emb_out    : input-word and inner-node embedding tables,
//                      updated in place without atomics — concurrent
//                      blocks may race (Hogwild-style, as in canonical
//                      word2vec; presumably tolerated by design — NOTE(review)).
// loss_nume/loss_deno: per-block loss numerator / sample count
//                      (written by thread 0 of each block only).
__global__ void W2VHsSgKernel(
  const int* cols, const int* indptr,
  const bool* codes, const int* points, const int* hs_indptr,
  const int num_indptr, const int num_dims, const int window_size,
  default_random_engine* rngs,
  float* emb_in, float* emb_out,
  float* loss_nume, float* loss_deno, const float lr) {

  default_random_engine& rng = rngs[blockIdx.x];
  float& _loss_nume = loss_nume[blockIdx.x];
  float& _loss_deno = loss_deno[blockIdx.x];

  uniform_int_distribution<int> dist_window(0, window_size - 1);
  static __shared__ int reduced_windows;  // broadcast of the sampled window shrink
  extern __shared__ float shared_memory[];
  float* grad = &shared_memory[0];  // num_dims floats: input-side gradient accumulator

  // zero-initialize shared mem
  for (int i = threadIdx.x; i < num_dims; i += blockDim.x)
    grad[i] = 0.0f;
  __syncthreads();

  for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) {
    int beg = indptr[i], end = indptr[i + 1];
    for (int j = beg; j < end; ++j) {  // j: center-word position
      // thread 0 samples the window shrink; barrier publishes it block-wide
      if (threadIdx.x == 0) reduced_windows = dist_window(rng);
      __syncthreads();
      // effective context range [beg2, end2), clipped to the sentence
      int beg2 = max(beg, j - window_size + reduced_windows);
      int end2 = min(end, j + window_size - reduced_windows + 1);
      float* _emb_in = emb_in + num_dims * cols[j];
      for (int k = beg2; k < end2; ++k) {  // k: context-word position
        if (k == j) continue;
        // Huffman path of the context word: inner nodes [beg3, end3)
        int beg3 = hs_indptr[cols[k]];
        int end3 = hs_indptr[cols[k] + 1];
        for (int l = beg3; l < end3; ++l) {
          // code bit selects the branch direction at this inner node;
          // both calls accumulate the input-side gradient into grad
          // and update the inner-node vector in place.
          if (codes[l]) {
            PositiveFeedback(_emb_in, emb_out + num_dims * points[l],
                grad, _loss_nume, _loss_deno, num_dims, lr);
          } else {
            NegativeFeedback(_emb_in, emb_out + num_dims * points[l],
                grad, _loss_nume, _loss_deno, num_dims, lr);
          }
          __syncthreads();
        }
        // apply accumulated gradient to the center word and reset grad
        // for the next context word (same-thread stride, so no race
        // between the write and the reset).
        for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
          emb_in[num_dims * cols[j] + l] += grad[l];
          grad[l] = 0.0f;
        }
        __syncthreads();  // also orders grad reset before next iteration's use
      }
    }
  }
}

// CBOW word2vec training with hierarchical softmax.
//
// Launch layout: one block per "sentence" (row of cols/indptr), blocks
// grid-stride over num_indptr rows. Dynamic shared memory requirement:
// 2 * num_dims * sizeof(float) (gradient buffer + context average).
//
// Parameters as in W2VHsSgKernel, plus:
// use_mean : if true, the context vector is the mean (not the sum) of
//            the window embeddings, and the gradient is scaled back by
//            the same factor before being applied to each context word.
//
// Embedding tables are updated in place without atomics (Hogwild-style
// racing between blocks, as in canonical word2vec).
__global__ void W2VHsCbowKernel(
  const int* cols, const int* indptr,
  const bool* codes, const int* points, const int* hs_indptr,
  const int num_indptr, const int num_dims, const int window_size, default_random_engine* rngs,
  float* emb_in, float* emb_out,
  float* loss_nume, float* loss_deno,
  const bool use_mean, const float lr) {

  default_random_engine& rng = rngs[blockIdx.x];
  float& _loss_nume = loss_nume[blockIdx.x];
  float& _loss_deno = loss_deno[blockIdx.x];

  uniform_int_distribution<int> dist_window(0, window_size - 1);
  static __shared__ int reduced_windows;  // broadcast of the sampled window shrink
  extern __shared__ float shared_memory[];
  float* grad = &shared_memory[0];        // num_dims floats: gradient w.r.t. cbow
  float* cbow = &shared_memory[num_dims]; // num_dims floats: summed/averaged context

  __syncthreads();

  for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) {
    int beg = indptr[i], end = indptr[i + 1];
    for (int j = beg; j < end; ++j) {  // j: center-word position
      if (threadIdx.x == 0) reduced_windows = dist_window(rng);
      __syncthreads();
      int beg2 = max(beg, j - window_size + reduced_windows);
      int end2 = min(end, j + window_size - reduced_windows + 1);
      // reduced_windows is shared, so this continue is block-uniform
      // (no divergent barrier hazard below).
      if (end2 - beg2 <= 1) continue;

      // zero-initialize shared mem
      // (no barrier needed before the accumulation below: each index k
      // is owned by the same thread in both loops)
      for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
        grad[k] = 0.0f;
        cbow[k] = 0.0f;
      }

      // compute cbow
      for (int k = beg2; k < end2; ++k) {
        if (k == j) continue;
        for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
          cbow[l] += emb_in[num_dims * cols[k] + l];
        }
      }
      if (use_mean) {
        // average over the (end2 - beg2 - 1) context words; > 0 by the
        // continue guard above
        for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
          cbow[k] /= (end2 - beg2 - 1);
        }
      }
      __syncthreads();  // cbow must be complete before Dot reads it cross-thread

      // walk the Huffman path of the CENTER word, training emb_out
      // nodes against the context vector and accumulating grad
      int beg3 = hs_indptr[cols[j]];
      int end3 = hs_indptr[cols[j] + 1];
      for (int k = beg3; k < end3; ++k) {
        if (codes[k]) {
          PositiveFeedback(cbow, emb_out + num_dims * points[k],
              grad, _loss_nume, _loss_deno, num_dims, lr);
        } else {
          NegativeFeedback(cbow, emb_out + num_dims * points[k],
              grad, _loss_nume, _loss_deno, num_dims, lr);
        }
        __syncthreads();
      }

      // normalize grad if use_mean = true
      if (use_mean) {
        for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
          grad[k] /= (end2 - beg2 - 1);
        }
      }
      __syncthreads();

      // update emb_in
      // (distribute the shared gradient to every context word)
      for (int k = beg2; k < end2; ++k) {
        if (k == j) continue;
        for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
          emb_in[num_dims * cols[k] + l] += grad[l];
        }
        __syncthreads();
      }
    }
  }
}

} // cusim
147 changes: 147 additions & 0 deletions cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Copyright (c) 2021 Jisang Yoon
// All rights reserved.
//
// This source code is licensed under the Apache 2.0 license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include "utils/cuda_utils_kernels.cuh"
#include "cuw2v/cuda_w2v_base_kernels.cuh"


namespace cusim {

// Skip-gram word2vec training with negative sampling.
//
// Launch layout: one block per "sentence" (row of cols/indptr), blocks
// grid-stride over num_indptr rows. Dynamic shared memory requirement:
// num_dims * sizeof(float) (gradient scratch buffer).
//
// random_table : pre-built unigram noise table of size random_size;
//   uniform draws from it realize the noise distribution.
// neg : number of negative samples per (center, context) pair.
// NOTE(review): a drawn negative may equal the current context word;
//   nothing here skips that case — confirm this is intended.
//
// Embedding tables are updated in place without atomics (Hogwild-style
// racing between blocks, as in canonical word2vec).
__global__ void W2VNegSgKernel(
  const int* cols, const int* indptr,
  const int* random_table, default_random_engine* rngs, const int random_size,
  const int num_indptr, const int num_dims, const int neg, const int window_size,
  float* emb_in, float* emb_out, float* loss_nume, float* loss_deno, const float lr) {

  default_random_engine& rng = rngs[blockIdx.x];
  float& _loss_nume = loss_nume[blockIdx.x];
  float& _loss_deno = loss_deno[blockIdx.x];

  uniform_int_distribution<int> dist_neg(0, random_size - 1);
  uniform_int_distribution<int> dist_window(0, window_size - 1);
  __shared__ int reduced_windows;  // broadcast of the sampled window shrink
  __shared__ int neg_word;         // broadcast of the current negative sample
  extern __shared__ float shared_memory[];
  float* grad = &shared_memory[0];  // num_dims floats: input-side gradient accumulator

  // zero-initialize shared mem
  for (int i = threadIdx.x; i < num_dims; i += blockDim.x)
    grad[i] = 0.0f;
  __syncthreads();

  for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) {
    int beg = indptr[i], end = indptr[i + 1];
    for (int j = beg; j < end; ++j) {  // j: center-word position
      if (threadIdx.x == 0) reduced_windows = dist_window(rng);
      __syncthreads();
      int beg2 = max(beg, j - window_size + reduced_windows);
      int end2 = min(end, j + window_size - reduced_windows + 1);
      float* _emb_in = emb_in + num_dims * cols[j];
      for (int k = beg2; k < end2; ++k) {  // k: context-word position
        if (k == j) continue;
        // one positive update against the true context word...
        PositiveFeedback(_emb_in, emb_out + num_dims * cols[k],
            grad, _loss_nume, _loss_deno, num_dims, lr);
        // ...then neg updates against sampled noise words
        for (int l = 0; l < neg; ++l) {
          if (threadIdx.x == 0) neg_word = random_table[dist_neg(rng)];
          __syncthreads();  // publish neg_word before all threads read it
          NegativeFeedback(_emb_in, emb_out + num_dims * neg_word,
              grad, _loss_nume, _loss_deno, num_dims, lr);
        }
        __syncthreads();
        // apply accumulated gradient to the center word and reset grad
        // (same-thread stride, so write/reset do not race)
        for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
          emb_in[num_dims * cols[j] + l] += grad[l];
          grad[l] = 0.0f;
        }
        __syncthreads();
      }
    }
  }
}

// CBOW word2vec training with negative sampling.
//
// Launch layout: one block per "sentence" (row of cols/indptr), blocks
// grid-stride over num_indptr rows. Dynamic shared memory requirement:
// 2 * num_dims * sizeof(float) (gradient buffer + context average).
//
// random_table : pre-built unigram noise table of size random_size.
// neg          : number of negative samples per center word.
// use_mean     : if true, the context vector is the mean (not the sum)
//                of the window embeddings, and the gradient is scaled
//                back by the same factor before being applied.
// NOTE(review): a drawn negative may equal the center word; nothing
//   here skips that case — confirm this is intended.
//
// Embedding tables are updated in place without atomics (Hogwild-style
// racing between blocks, as in canonical word2vec).
__global__ void W2VNegCbowKernel(
  const int* cols, const int* indptr,
  const int* random_table, default_random_engine* rngs, const int random_size,
  const int num_indptr, const int num_dims, const int neg, const int window_size,
  float* emb_in, float* emb_out,
  float* loss_nume, float* loss_deno, const bool use_mean, const float lr) {

  default_random_engine& rng = rngs[blockIdx.x];
  float& _loss_nume = loss_nume[blockIdx.x];
  float& _loss_deno = loss_deno[blockIdx.x];

  uniform_int_distribution<int> dist_neg(0, random_size - 1);
  uniform_int_distribution<int> dist_window(0, window_size - 1);
  static __shared__ int reduced_windows;  // broadcast of the sampled window shrink
  static __shared__ int neg_word;         // broadcast of the current negative sample
  extern __shared__ float shared_memory[];
  float* grad = &shared_memory[0];        // num_dims floats: gradient w.r.t. cbow
  float* cbow = &shared_memory[num_dims]; // num_dims floats: summed/averaged context

  __syncthreads();

  for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) {
    int beg = indptr[i], end = indptr[i + 1];
    for (int j = beg; j < end; ++j) {  // j: center-word position
      if (threadIdx.x == 0) reduced_windows = dist_window(rng);
      __syncthreads();
      int beg2 = max(beg, j - window_size + reduced_windows);
      int end2 = min(end, j + window_size - reduced_windows + 1);
      // reduced_windows is shared, so this continue is block-uniform
      // (no divergent barrier hazard below).
      if (end2 - beg2 <= 1) continue;

      // zero-initialize shared mem
      // (no barrier needed before the accumulation below: each index k
      // is owned by the same thread in both loops)
      for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
        grad[k] = 0.0f;
        cbow[k] = 0.0f;
      }

      // compute cbow
      for (int k = beg2; k < end2; ++k) {
        if (k == j) continue;
        for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
          cbow[l] += emb_in[num_dims * cols[k] + l];
        }
      }
      if (use_mean) {
        // average over the (end2 - beg2 - 1) context words; > 0 by the
        // continue guard above
        for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
          cbow[k] /= (end2 - beg2 - 1);
        }
      }
      __syncthreads();  // cbow must be complete before Dot reads it cross-thread

      // one positive update: center word against the context vector
      PositiveFeedback(cbow, emb_out + num_dims * cols[j], grad,
          _loss_nume, _loss_deno, num_dims, lr);
      __syncthreads();

      // update negative feedback
      for (int k = 0; k < neg; ++k){
        if (threadIdx.x == 0) neg_word = random_table[dist_neg(rng)];
        __syncthreads();  // publish neg_word before all threads read it
        NegativeFeedback(cbow, emb_out + num_dims * neg_word,
            grad, _loss_nume, _loss_deno, num_dims, lr);
      }
      __syncthreads();

      // normalize grad if use_mean = true
      if (use_mean) {
        for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
          grad[k] /= (end2 - beg2 - 1);
        }
      }
      __syncthreads();

      // update emb_in
      // (distribute the shared gradient to every context word; each l
      // is owned by one thread across all k, so no write race)
      for (int k = beg2; k < end2; ++k) {
        if (k == j) continue;
        for (int l = threadIdx.x; l < num_dims; l += blockDim.x)
          emb_in[num_dims * cols[k] + l] += grad[l];
      }
      __syncthreads();  // emb_in writes visible before next iteration reads it

    }
  }
}

} // cusim
Loading