Skip to content

Commit

Permalink
address code review
Browse files Browse the repository at this point in the history
  • Loading branch information
ofrei committed Feb 23, 2018
1 parent 3f7b82a commit 40b21c5
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 7 deletions.
3 changes: 3 additions & 0 deletions src/artm/core/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ const int kIdleLoopFrequency = 1; // 1 ms

const int kBatchNameLength = 6;

// Defined in 3rdparty/protobuf-3.0.0/src/google/protobuf/io/coded_stream.h
const int64_t kProtobufCodedStreamTotalBytesLimit = 2147483647ULL;

template <typename T>
std::string to_string(T value) {
return boost::lexical_cast<std::string>(value);
Expand Down
14 changes: 10 additions & 4 deletions src/artm/core/dictionary_operations.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright 2017, Additive Regularization of Topic Models.

#include <algorithm>
#include <climits>
#include <fstream>
#include <functional>
#include <string>
Expand Down Expand Up @@ -94,9 +95,10 @@ void DictionaryOperations::Export(const ExportDictionaryArgs& args, const Dictio
}

std::string str = token_dict_data.SerializeAsString();
if (str.size() >= 2147483647ULL)
if (str.size() >= kProtobufCodedStreamTotalBytesLimit) {
BOOST_THROW_EXCEPTION(InvalidOperation("Dictionary " +
args.dictionary_name() + " is too large to export"));
}

int length = static_cast<int>(str.size());
fout.write(reinterpret_cast<char *>(&length), sizeof(length));
Expand Down Expand Up @@ -135,10 +137,11 @@ void DictionaryOperations::Export(const ExportDictionaryArgs& args, const Dictio

if ((current_cooc_length >= max_cooc_length) || ((token_id + 1) == token_size)) {
std::string str = cooc_dict_data.SerializeAsString();
if (str.size() >= 2147483647ULL)
if (str.size() >= kProtobufCodedStreamTotalBytesLimit) {
BOOST_THROW_EXCEPTION(InvalidOperation(
"Unable to serialize coocurence information in Dictionary " +
args.dictionary_name()));
}

int length = static_cast<int>(str.size());
fout.write(reinterpret_cast<char *>(&length), sizeof(length));
Expand Down Expand Up @@ -575,11 +578,14 @@ void DictionaryOperations::WriteDictionarySummaryToLog(const Dictionary& dict) {
std::map<ClassId, int> entries_per_class;
for (int i = 0; i < dict.size(); i++) {
const DictionaryEntry* entry = dict.entry(i);
if (entry != nullptr) entries_per_class[entry->token().class_id]++;
if (entry != nullptr) {
entries_per_class[entry->token().class_id]++;
}
}
std::stringstream ss; ss << "Dictionary name='" << dict.name() << "' contains entries: ";
for (auto const& x : entries_per_class)
for (auto const& x : entries_per_class) {
ss << x.first << ":" << x.second << "; ";
}
LOG(INFO) << ss.str();
}

Expand Down
6 changes: 4 additions & 2 deletions src/artm/core/master_component.cc
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,9 @@ void MasterComponent::ExportModel(const ExportModelArgs& args) {
::artm::TopicModel external_topic_model;
PhiMatrixOperations::RetrieveExternalTopicModel(n_wt, get_topic_model_args, &external_topic_model);
std::string str = external_topic_model.SerializeAsString();
if (str.size() >= 2147483647ULL)
if (str.size() >= kProtobufCodedStreamTotalBytesLimit) {
BOOST_THROW_EXCEPTION(InvalidOperation("TopicModel is too large to export"));
}
fout << str.size();
fout << str;
get_topic_model_args.clear_class_id();
Expand Down Expand Up @@ -477,8 +478,9 @@ void MasterComponent::ExportScoreTracker(const ExportScoreTrackerArgs& args) {
// We expect here that each ScoreData object has suitable size (< 2GB)
for (auto& item : instance_->score_tracker()->GetDataUnsafe()) {
auto str = item->SerializeAsString();
if (str.size() >= 2147483647ULL)
if (str.size() >= kProtobufCodedStreamTotalBytesLimit) {
BOOST_THROW_EXCEPTION(InvalidOperation("ScoreTracker is too large to export"));
}
fout << str.size();
fout << str;
}
Expand Down
3 changes: 2 additions & 1 deletion src/artm/core/protobuf_serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "glog/logging.h"

#include "artm/core/common.h"
#include "artm/core/exceptions.h"

namespace pb = ::google::protobuf;
Expand All @@ -28,7 +29,7 @@ void ProtobufSerialization::ParseFromString(const std::string& string, google::p
}

void ProtobufSerialization::ParseFromArray(const char* buffer, int64_t length, google::protobuf::Message* message) {
if (length < 0 || length >= 2147483647) {
if (length < 0 || length >= kProtobufCodedStreamTotalBytesLimit) {
BOOST_THROW_EXCEPTION(CorruptedMessageException("Protobuf message is too long"));
}
ParseFromString((length >= 0) ? std::string(buffer, length) : std::string(buffer), message);
Expand Down

0 comments on commit 40b21c5

Please sign in to comment.