44 changes: 0 additions & 44 deletions cpp/include/culda.hpp

This file was deleted.

121 changes: 121 additions & 0 deletions cpp/include/culda/cuda_lda_kernels.cuh
@@ -0,0 +1,121 @@
// Copyright (c) 2021 Jisang Yoon
// All rights reserved.
//
// This source code is licensed under the Apache 2.0 license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include "utils/cuda_utils_kernels.cuh"


namespace cusim {

// reference: http://web.science.mq.edu.au/~mjohnson/code/digamma.c
__inline__ __device__
float Digamma(float x) {
  float result = 0.0f, xx, xx2, xx4;
  for ( ; x < 7.0f; ++x)
    result -= 1.0f / x;
  x -= 0.5f;
  xx = 1.0f / x;
  xx2 = xx * xx;
  xx4 = xx2 * xx2;
  result += logf(x) + 1.0f / 24.0f * xx2
    - 7.0f / 960.0f * xx4 + 31.0f / 8064.0f * xx4 * xx2
    - 127.0f / 30720.0f * xx4 * xx4;
  return result;
}
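// Note: the routine above follows the referenced digamma.c. The recurrence
// psi(x) = psi(x + 1) - 1/x raises the argument to x >= 7, then, with
// t = x - 1/2, it evaluates the asymptotic series
//   psi(x) ~ ln t + 1/(24 t^2) - 7/(960 t^4) + 31/(8064 t^6) - 127/(30720 t^8),
// which is accurate once t is large; valid for x > 0.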

__global__ void EstepKernel(
  const int* cols, const int* indptr, const bool* vali,
  const int num_cols, const int num_indptr,
  const int num_topics, const int num_iters,
  float* gamma, float* new_gamma, float* phi,
  const float* alpha, const float* beta,
  float* grad_alpha, float* new_beta,
  float* train_losses, float* vali_losses, int* mutex) {

  // per-block scratch: each block owns one num_topics-wide slice
  float* _gamma = gamma + num_topics * blockIdx.x;
  float* _new_gamma = new_gamma + num_topics * blockIdx.x;
  float* _phi = phi + num_topics * blockIdx.x;
  float* _grad_alpha = grad_alpha + num_topics * blockIdx.x;

  // each block processes one document at a time
  for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) {
    int beg = indptr[i], end = indptr[i + 1];
    // initialize gamma to alpha + (doc length / num_topics);
    // the cast keeps the division fractional instead of truncating to int
    for (int j = threadIdx.x; j < num_topics; j += blockDim.x)
      _gamma[j] = alpha[j] + static_cast<float>(end - beg) / num_topics;
    __syncthreads();

    // iterate E step
    for (int j = 0; j < num_iters; ++j) {
      // initialize new gamma
      for (int k = threadIdx.x; k < num_topics; k += blockDim.x)
        _new_gamma[k] = 0.0f;
      __syncthreads();

      // compute phi from gamma
      for (int k = beg; k < end; ++k) {
        const int w = cols[k];
        const bool _vali = vali[k];

        // variational update: phi_l is proportional to
        // beta[w][l] * exp(digamma(gamma_l)); the digamma of the
        // gamma sum cancels in the normalization below
        if (not _vali or j + 1 == num_iters) {
          for (int l = threadIdx.x; l < num_topics; l += blockDim.x)
            _phi[l] = beta[w * num_topics + l] * expf(Digamma(_gamma[l]));
          __syncthreads();

          // normalize phi and add it to new gamma and new beta
          float phi_sum = ReduceSum(_phi, num_topics);

          for (int l = threadIdx.x; l < num_topics; l += blockDim.x) {
            _phi[l] /= phi_sum;
            if (not _vali) _new_gamma[l] += _phi[l];
          }
          __syncthreads();
        }

        if (j + 1 == num_iters) {
          // acquire the per-word spin lock so that only one block
          // updates row w of new_beta at a time
          if (threadIdx.x == 0) {
            while (atomicCAS(&mutex[w], 0, 1)) {}
          }

          __syncthreads();
          for (int l = threadIdx.x; l < num_topics; l += blockDim.x) {
            if (not _vali) new_beta[w * num_topics + l] += _phi[l];
            _phi[l] *= beta[w * num_topics + l];
          }
          __syncthreads();

          // release lock
          if (threadIdx.x == 0) mutex[w] = 0;
          __syncthreads();

          // per-token likelihood: log(sum_l phi_l * beta[w][l])
          float p = fmaxf(EPS, ReduceSum(_phi, num_topics));
          if (threadIdx.x == 0) {
            if (_vali)
              vali_losses[blockIdx.x] += logf(p);
            else
              train_losses[blockIdx.x] += logf(p);
          }
        }
        __syncthreads();
      }

      // update gamma
      for (int k = threadIdx.x; k < num_topics; k += blockDim.x)
        _gamma[k] = _new_gamma[k] + alpha[k];
      __syncthreads();
    }

    // accumulate the alpha gradient: digamma(gamma_k) - digamma(sum of gamma)
    float gamma_sum = ReduceSum(_gamma, num_topics);
    for (int j = threadIdx.x; j < num_topics; j += blockDim.x)
      _grad_alpha[j] += (Digamma(_gamma[j]) - Digamma(gamma_sum));

    __syncthreads();
  }
}

} // namespace cusim
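For context, a hypothetical host-side launch of this kernel. The gamma, new_gamma, phi, and grad_alpha buffers are per-block scratch, so each must hold num_topics floats per block to match the `+ num_topics * blockIdx.x` offsets above; the function name, launch shape, and sizes below are illustrative assumptions, not part of this diff.

// sketch.cu -- illustrative launch of EstepKernel; names, sizes, and the
// launch shape are assumptions for this sketch.
#include <thrust/device_vector.h>
#include "culda/cuda_lda_kernels.cuh"

void RunEstep(const thrust::device_vector<int>& cols,
              const thrust::device_vector<int>& indptr,
              const thrust::device_vector<bool>& vali,
              const thrust::device_vector<float>& alpha,      // num_topics
              const thrust::device_vector<float>& beta,       // num_words * num_topics
              thrust::device_vector<float>& grad_alpha,       // block_cnt * num_topics
              thrust::device_vector<float>& new_beta,         // num_words * num_topics
              int num_topics, int num_words, int num_iters) {
  const int block_cnt = 32, block_dim = 256;  // assumed launch shape
  const int num_cols = static_cast<int>(cols.size());
  const int num_indptr = static_cast<int>(indptr.size()) - 1;
  // per-block scratch: one num_topics-wide slice per block
  thrust::device_vector<float> gamma(block_cnt * num_topics);
  thrust::device_vector<float> new_gamma(block_cnt * num_topics);
  thrust::device_vector<float> phi(block_cnt * num_topics);
  thrust::device_vector<float> train_losses(block_cnt, 0.0f);
  thrust::device_vector<float> vali_losses(block_cnt, 0.0f);
  thrust::device_vector<int> mutex(num_words, 0);  // one spin lock per word
  cusim::EstepKernel<<<block_cnt, block_dim>>>(
    thrust::raw_pointer_cast(cols.data()),
    thrust::raw_pointer_cast(indptr.data()),
    thrust::raw_pointer_cast(vali.data()),
    num_cols, num_indptr, num_topics, num_iters,
    thrust::raw_pointer_cast(gamma.data()),
    thrust::raw_pointer_cast(new_gamma.data()),
    thrust::raw_pointer_cast(phi.data()),
    thrust::raw_pointer_cast(alpha.data()),
    thrust::raw_pointer_cast(beta.data()),
    thrust::raw_pointer_cast(grad_alpha.data()),
    thrust::raw_pointer_cast(new_beta.data()),
    thrust::raw_pointer_cast(train_losses.data()),
    thrust::raw_pointer_cast(vali_losses.data()),
    thrust::raw_pointer_cast(mutex.data()));
}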
88 changes: 88 additions & 0 deletions cpp/include/culda/culda.hpp
@@ -0,0 +1,88 @@
// Copyright (c) 2021 Jisang Yoon
// All rights reserved.
//
// This source code is licensed under the Apache 2.0 license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <thrust/copy.h>
#include <thrust/fill.h>
#include <thrust/random.h>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/binary_search.h>
#include <thrust/execution_policy.h>

#include <omp.h>
#include <set>
#include <random>
#include <memory>
#include <string>
#include <fstream>
#include <utility>
#include <queue>
#include <deque>
#include <functional>
#include <vector>
#include <cmath>
#include <chrono> // NOLINT

#include "json11.hpp"
#include "utils/log.hpp"
#include "utils/types.hpp"

namespace cusim {


// reference: https://people.math.sc.edu/Burkardt/cpp_src/asa121/asa121.cpp
inline float Trigamma(float x) {
  const float a = 0.0001f;
  const float b = 5.0f;
  const float b2 = 0.1666666667f;
  const float b4 = -0.03333333333f;
  const float b6 = 0.02380952381f;
  const float b8 = -0.03333333333f;
  float value = 0, y = 0, z = x;
  // small-argument approximation
  if (x <= a) return 1.0f / x / x;
  // recurrence: trigamma(x) = trigamma(x + 1) + 1/x^2
  while (z < b) {
    value += 1.0f / z / z;
    z++;
  }
  // asymptotic series, accumulated once into value as in the referenced
  // asa121.cpp (the earlier `value += value + ...` doubled the accumulator)
  y = 1.0f / z / z;
  value += 0.5f * y + (1.0f
    + y * (b2
    + y * (b4
    + y * (b6
    + y * b8)))) / z;
  return value;
}
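// Note: Trigamma complements the Digamma values accumulated into grad_alpha
// on the device; in variational LDA the Dirichlet parameter alpha is
// typically refit with a Newton step whose Hessian involves trigamma,
// which is presumably its role here (the M step is not shown in this diff).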


class CuLDA {
 public:
  CuLDA();
  ~CuLDA();
  bool Init(std::string opt_path);
  void LoadModel(float* alpha, float* beta,
    float* grad_alpha, float* new_beta, const int num_words);
  std::pair<float, float> FeedData(
    const int* indices, const int* indptr, const bool* vali,
    const int num_indices, const int num_indptr, const int num_iters);
  void Pull();
  void Push();
  int GetBlockCnt();

 private:
  DeviceInfo dev_info_;
  json11::Json opt_;
  std::shared_ptr<spdlog::logger> logger_;
  thrust::device_vector<float> dev_alpha_, dev_beta_;
  thrust::device_vector<float> dev_grad_alpha_, dev_new_beta_;
  thrust::device_vector<float> dev_gamma_, dev_new_gamma_, dev_phi_;
  thrust::device_vector<int> dev_mutex_;

  float *alpha_, *beta_, *grad_alpha_, *new_beta_;
  int block_cnt_, block_dim_;
  int num_topics_, num_words_;
};

} // namespace cusim
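For orientation, a hypothetical call sequence against this interface, inferred from the declarations above. The option-file path, the sizes, and the Pull semantics are assumptions; num_topics in particular must match whatever the json11 option file configures.

// sketch.cc -- illustrative use of CuLDA; paths and sizes are assumptions.
#include <vector>
#include "culda/culda.hpp"

void TrainOneBatch(const int* indices, const int* indptr, const bool* vali,
                   int num_indices, int num_indptr, int num_iters) {
  cusim::CuLDA lda;
  lda.Init("config/lda.json");  // hypothetical json11 option file
  const int num_words = 10000, num_topics = 100;  // assumed; must match opt
  std::vector<float> alpha(num_topics), beta(num_words * num_topics);
  std::vector<float> new_beta(num_words * num_topics);
  std::vector<float> grad_alpha(lda.GetBlockCnt() * num_topics);
  lda.LoadModel(alpha.data(), beta.data(),
                grad_alpha.data(), new_beta.data(), num_words);
  // one E step over a CSR batch; returns (train loss, validation loss)
  std::pair<float, float> losses =
      lda.FeedData(indices, indptr, vali, num_indices, num_indptr, num_iters);
  lda.Pull();  // presumably copies grad_alpha / new_beta back to the host
}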
41 changes: 0 additions & 41 deletions cpp/include/types.hpp

This file was deleted.
