Skip to content

Commit

Permalink
T-ARTM code refactoring and speed-up (#886)
Browse files Browse the repository at this point in the history
  • Loading branch information
MelLain committed Mar 8, 2018
1 parent 0262269 commit d6dd1f7
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 179 deletions.
18 changes: 5 additions & 13 deletions src/artm/core/processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,23 +228,15 @@ void Processor::ThreadFunction() {
if (ptdw_agents.empty() && !part->has_ptdw_cache_manager()) {
std::shared_ptr<BatchTransactionInfo> batch_info;
{
CuckooWatch cuckoo2("GetBatchTransactionsInfo", &cuckoo, kTimeLoggingThreshold);
batch_info = ProcessorTransactionHelpers::GetBatchTransactionsInfo(batch);
}

std::shared_ptr<CsrMatrix<float>> sparse_ndx;
{
CuckooWatch cuckoo2("InitializeSparseNdx", &cuckoo, kTimeLoggingThreshold);
sparse_ndx = ProcessorTransactionHelpers::InitializeSparseNdx(batch, args,
batch_info->class_id_to_tt,
batch_info->transaction_ids_to_index);
CuckooWatch cuckoo2("PrepareBatchInfo", &cuckoo, kTimeLoggingThreshold);
batch_info = ProcessorTransactionHelpers::PrepareBatchInfo(
batch, args, p_wt);
}

CuckooWatch cuckoo2("InferThetaAndUpdateNwtSparseNew", &cuckoo, kTimeLoggingThreshold);
ProcessorTransactionHelpers::TransactionInferThetaAndUpdateNwtSparse(
args, batch, part->batch_weight(), *sparse_ndx,
batch_info->transaction_to_index, batch_info->token_to_index,
batch_info->transactions, p_wt, theta_agents,
args, batch, part->batch_weight(),
batch_info, p_wt, theta_agents,
theta_matrix.get(), nwt_writer.get(),
blas, new_cache_entry_ptr.get());
} else {
Expand Down
207 changes: 86 additions & 121 deletions src/artm/core/processor_transaction_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,75 +5,30 @@
namespace artm {
namespace core {

std::shared_ptr<BatchTransactionInfo> ProcessorTransactionHelpers::GetBatchTransactionsInfo(const Batch& batch) {
ClassIdToTt class_id_to_tt;
std::unordered_map<std::vector<int>, int, IntVectorHasher> transaction_ids_to_index;
TransactionToIndex transaction_to_index;
std::vector<std::vector<Token>> transactions;
std::unordered_map<Token, int, TokenHasher> token_to_index;

for (int item_index = 0; item_index < batch.item_size(); ++item_index) {
const Item& item = batch.item(item_index);

for (int token_index = 0; token_index < item.transaction_start_index_size(); ++token_index) {
const int start_index = item.transaction_start_index(token_index);
const int end_index = (token_index + 1) < item.transaction_start_index_size() ?
item.transaction_start_index(token_index + 1) :
item.transaction_token_id_size();
std::vector<int> vec;
for (int i = start_index; i < end_index; ++i) {
vec.push_back(item.transaction_token_id(i));
}
auto iter = transaction_ids_to_index.find(vec);
if (iter == transaction_ids_to_index.end()) {
transaction_ids_to_index.insert(std::make_pair(vec, transaction_ids_to_index.size()));

std::string str;
for (int token_id = start_index; token_id < end_index; ++token_id) {
auto& tmp = batch.class_id(item.transaction_token_id(token_id));
str += (token_id == start_index) ? tmp : TransactionSeparator + tmp;
}

TransactionType tt(str);
for (const ClassId& class_id : tt.AsVector()) {
class_id_to_tt[class_id].emplace(tt);
}

std::vector<Token> transaction;
for (int idx = start_index; idx < end_index; ++idx) {
const int token_id = item.transaction_token_id(idx);
auto token = Token(batch.class_id(token_id), batch.token(token_id), tt);
transaction.push_back(token);
token_to_index.insert(std::make_pair(token, token_to_index.size()));
}
transactions.push_back(transaction);
transaction_to_index.insert(std::make_pair(transaction, transaction_to_index.size()));
}
namespace {
struct IntVectorHasher {
size_t operator()(const std::vector<int>& elems) const {
size_t hash = 0;
for (const int e : elems) {
boost::hash_combine<std::string>(hash, std::to_string(e));
}
return hash;
}
};

if (transaction_to_index.size() != transactions.size()) {
LOG(ERROR) << "Fatal error: transaction_to_index.size() [ " << transaction_to_index.size()
<< " ] != transactions.size() [ " << transactions.size();
}

if (transaction_ids_to_index.size() != transactions.size()) {
LOG(ERROR) << "Fatal error: transaction_ids_to_index.size() [ " << transaction_ids_to_index.size()
<< " ] != transactions.size() [ " << transactions.size();
}
typedef std::unordered_map<std::vector<int>, std::shared_ptr<TransactionInfo>, IntVectorHasher> TokenIdsToInfo;
} // namespace

return std::make_shared<BatchTransactionInfo>(
BatchTransactionInfo(class_id_to_tt, transaction_ids_to_index,
transaction_to_index, transactions, token_to_index));
}

std::shared_ptr<CsrMatrix<float>> ProcessorTransactionHelpers::InitializeSparseNdx(const Batch& batch,
const ProcessBatchesArgs& args, const ClassIdToTt& class_id_to_tt,
const std::unordered_map<std::vector<int>, int, IntVectorHasher>& transaction_to_index) {
std::shared_ptr<BatchTransactionInfo> ProcessorTransactionHelpers::PrepareBatchInfo(
const Batch& batch, const ProcessBatchesArgs& args, const ::artm::core::PhiMatrix& p_wt) {
std::vector<float> n_dw_val;
std::vector<int> n_dw_row_ptr;
std::vector<int> n_dw_col_ind;

std::unordered_map<Token, int, TokenHasher> token_to_index;
TokenIdsToInfo token_ids_to_info;
TransactionIdToInfo transaction_id_to_info;

bool use_weights = false;
std::unordered_map<TransactionType, float, TransactionHasher> tt_to_weight;
if (args.transaction_type_size() != 0) {
Expand All @@ -90,21 +45,27 @@ std::shared_ptr<CsrMatrix<float>> ProcessorTransactionHelpers::InitializeSparseN
n_dw_row_ptr.push_back(static_cast<int>(n_dw_val.size()));
const Item& item = batch.item(item_index);

auto func = [](int start_idx, int end_idx, const Batch& b, const Item& d) -> std::shared_ptr<TransactionType> {
std::string str;
for (int token_id = start_idx; token_id < end_idx; ++token_id) {
auto& tmp = b.class_id(d.transaction_token_id(token_id));
str += (token_id == start_idx) ? tmp : TransactionSeparator + tmp;
}
return std::make_shared<TransactionType>(str);
};

for (int token_index = 0; token_index < item.transaction_start_index_size(); ++token_index) {
const int start_index = item.transaction_start_index(token_index);
const int end_index = (token_index + 1) < item.transaction_start_index_size() ?
item.transaction_start_index(token_index + 1) :
item.transaction_token_id_size();

std::shared_ptr<TransactionType> tt = nullptr;
float transaction_weight = 1.0f;
if (use_weights) {
std::string str;
for (int token_id = start_index; token_id < end_index; ++token_id) {
auto& tmp = batch.class_id(item.transaction_token_id(token_id));
str += (token_id == start_index) ? tmp : TransactionSeparator + tmp;
}
auto iter = tt_to_weight.find(TransactionType(str));
transaction_weight = (iter == tt_to_weight.end()) ? 0.0f : iter->second;
tt = func(start_index, end_index, batch, item);
auto it = tt_to_weight.find(*tt);
transaction_weight = (it == tt_to_weight.end()) ? 0.0f : it->second;
}

const float token_weight = item.token_weight(token_index);
Expand All @@ -114,37 +75,54 @@ std::shared_ptr<CsrMatrix<float>> ProcessorTransactionHelpers::InitializeSparseN
for (int i = start_index; i < end_index; ++i) {
vec.push_back(item.transaction_token_id(i));
}
auto iter = transaction_to_index.find(vec);
if (iter != transaction_to_index.end()) {
n_dw_col_ind.push_back(iter->second);
auto iter = token_ids_to_info.find(vec);

if (iter != token_ids_to_info.end()) {
n_dw_col_ind.push_back(iter->second->transaction_index);
} else {
std::stringstream ss;
ss << "Fatal error: transaction_to_index doesn't contain transaction from indices:";
for (const int e : vec) {
ss << " " << e;
std::vector<int> local_indices;
std::vector<int> global_indices;

if (tt == nullptr) {
tt = func(start_index, end_index, batch, item);
}

for (int idx = start_index; idx < end_index; ++idx) {
const int token_id = item.transaction_token_id(idx);
auto token = Token(batch.class_id(token_id), batch.token(token_id), *tt);
token_to_index.insert(std::make_pair(token, token_to_index.size()));

local_indices.push_back(token_to_index.size() - 1);
global_indices.push_back(p_wt.token_index(token));
}
ss << " read from item with index " << item_index << " from batch " << batch.id()
<< ", empty matrix will be returned for this batch.";
LOG(ERROR) << ss.str();

return std::make_shared<CsrMatrix<float>>(0, 0, 0);
auto ptr = std::make_shared<TransactionInfo>(token_ids_to_info.size(), local_indices, global_indices);

token_ids_to_info.insert(std::make_pair(vec, ptr));
transaction_id_to_info.insert(std::make_pair(transaction_id_to_info.size(), ptr));

n_dw_col_ind.push_back(token_ids_to_info.size() - 1);
}
}
}
n_dw_row_ptr.push_back(static_cast<int>(n_dw_val.size()));

return std::make_shared<CsrMatrix<float>>(
transaction_to_index.size(), &n_dw_val, &n_dw_row_ptr, &n_dw_col_ind);
if (token_ids_to_info.size() != transaction_id_to_info.size()) {
LOG(ERROR) << "Fatal error: token_ids_to_info.size() [ " << token_ids_to_info.size()
<< " ] != transaction_id_to_info.size() [ " << transaction_id_to_info.size();
}

return std::make_shared<BatchTransactionInfo>(
std::make_shared<CsrMatrix<float>>(token_ids_to_info.size(),
&n_dw_val, &n_dw_row_ptr, &n_dw_col_ind),
transaction_id_to_info, token_to_index.size());
}

void ProcessorTransactionHelpers::TransactionInferThetaAndUpdateNwtSparse(
const ProcessBatchesArgs& args,
const Batch& batch,
float batch_weight,
const CsrMatrix<float>& sparse_ndx,
const TransactionToIndex& transaction_to_index,
const std::unordered_map<Token, int, TokenHasher>& token_to_local_index,
const std::vector<std::vector<Token>>& transactions,
std::shared_ptr<BatchTransactionInfo> batch_info,
const ::artm::core::PhiMatrix& p_wt,
const RegularizeThetaAgentCollection& theta_agents,
LocalThetaMatrix<float>* theta_matrix,
Expand All @@ -158,8 +136,9 @@ void ProcessorTransactionHelpers::TransactionInferThetaAndUpdateNwtSparse(
LocalThetaMatrix<float> n_td(theta_matrix->num_topics(), theta_matrix->num_items());
const int num_topics = p_wt.topic_size();
const int docs_count = theta_matrix->num_items();
const auto& sparse_ndx = *(batch_info->n_dx);

LocalPhiMatrix<float> local_phi(token_to_local_index.size(), num_topics);
LocalPhiMatrix<float> local_phi(batch_info->token_size, num_topics);
LocalThetaMatrix<float> r_td(num_topics, 1);
std::vector<float> helper_vector(num_topics, 0.0f);

Expand All @@ -172,19 +151,18 @@ void ProcessorTransactionHelpers::TransactionInferThetaAndUpdateNwtSparse(
local_phi.InitializeZeros();
bool item_has_tokens = false;
for (int i = begin_index; i < end_index; ++i) {
int w = sparse_ndx.col_ind()[i];
auto& transaction = transactions[w];
for (const auto& token : transaction) {
if (p_wt.token_index(token) == ::artm::core::PhiMatrix::kUndefIndex) {
continue;
}
auto iter = token_to_local_index.find(token);
if (iter == token_to_local_index.end()) {
auto it = batch_info->transaction_id_to_info.find(sparse_ndx.col_ind()[i]);

for (int k = 0; k < it->second->local_pwt_token_index.size(); ++k) {
int global_index = it->second->global_pwt_token_index[k];
if (global_index == ::artm::core::PhiMatrix::kUndefIndex) {
continue;
}

item_has_tokens = true;
float* local_phi_ptr = &local_phi(iter->second, 0);
p_wt.get(p_wt.token_index(token), &helper_vector);

float* local_phi_ptr = &local_phi(it->second->local_pwt_token_index[k], 0);
p_wt.get(global_index, &helper_vector);
for (int k = 0; k < num_topics; ++k) {
local_phi_ptr[k] = helper_vector[k];
}
Expand All @@ -202,17 +180,11 @@ void ProcessorTransactionHelpers::TransactionInferThetaAndUpdateNwtSparse(
}

for (int i = begin_index; i < end_index; ++i) {
int w = sparse_ndx.col_ind()[i];
std::fill(p_xt_local.begin(), p_xt_local.end(), 1.0f);
auto& transaction = transactions[w];
for (const auto& token : transaction) {
auto iter = token_to_local_index.find(token);

if (iter == token_to_local_index.end()) {
continue;
}
auto it = batch_info->transaction_id_to_info.find(sparse_ndx.col_ind()[i]);

const float* phi_ptr = &local_phi(iter->second, 0);
for (int local_index : it->second->local_pwt_token_index) {
const float* phi_ptr = &local_phi(local_index, 0);
for (int k = 0; k < num_topics; ++k) {
p_xt_local[k] *= phi_ptr[k];
}
Expand Down Expand Up @@ -252,21 +224,16 @@ void ProcessorTransactionHelpers::TransactionInferThetaAndUpdateNwtSparse(

std::vector<float> values(num_topics, 0.0f);
std::vector<float> p_xt_local(num_topics, 1.0f);
for (const auto& transaction : transactions) {
auto tr_iter = transaction_to_index.find(transaction);
if (tr_iter == transaction_to_index.end()) {
continue;
}
int transaction_index = tr_iter->second;

for (const auto& tuple : batch_info->transaction_id_to_info) {
int transaction_index = tuple.first;

std::fill(p_xt_local.begin(), p_xt_local.end(), 1.0f);
for (const auto& token : transaction) {
int phi_token_index = p_wt.token_index(token);
if (phi_token_index == ::artm::core::PhiMatrix::kUndefIndex) {
for (int global_index : tuple.second->global_pwt_token_index) {
if (global_index == ::artm::core::PhiMatrix::kUndefIndex) {
continue;
}

p_wt.get(phi_token_index, &helper_vector);
p_wt.get(global_index, &helper_vector);
for (int i = 0; i < num_topics; ++i) {
p_xt_local[i] *= helper_vector[i];
}
Expand All @@ -284,17 +251,15 @@ void ProcessorTransactionHelpers::TransactionInferThetaAndUpdateNwtSparse(
&(*theta_matrix)(0, d), 1, &helper_vector[0], 1); // NOLINT
}

for (const auto& token : transaction) {
int phi_token_index = p_wt.token_index(token);
if (phi_token_index == ::artm::core::PhiMatrix::kUndefIndex) {
for (int global_index : tuple.second->global_pwt_token_index) {
if (global_index == ::artm::core::PhiMatrix::kUndefIndex) {
continue;
}

for (int topic_index = 0; topic_index < num_topics; ++topic_index) {
values[topic_index] = p_xt_local[topic_index] * helper_vector[topic_index] * batch_weight;
}

nwt_writer->Store(-1, phi_token_index, values);
nwt_writer->Store(-1, global_index, values);
}
}
}
Expand Down

0 comments on commit d6dd1f7

Please sign in to comment.