Skip to content

Commit

Permalink
Merge pull request #160 from sashafrey/master
Browse files Browse the repository at this point in the history
Bug fix (#71, #133, #134, #136)
  • Loading branch information
bigartm committed Mar 15, 2015
2 parents 9e4a996 + dd0cd9b commit c7e9e04
Show file tree
Hide file tree
Showing 16 changed files with 227 additions and 135 deletions.
4 changes: 4 additions & 0 deletions docs/ref/messages.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1946,6 +1946,7 @@ Represents an argument of :c:func:`ArtmInvokeIteration` operation.
message InvokeIterationArgs {
optional int32 iterations_count = 1 [default = 1];
optional bool reset_scores = 2 [default = true];
optional string disk_path = 3;
}

.. attribute:: InvokeIterationArgs.iterations_count
Expand All @@ -1956,6 +1957,9 @@ Represents an argument of :c:func:`ArtmInvokeIteration` operation.

An optional flag that defines whether to reset all scores before this operation.

.. attribute:: InvokeIterationArgs.disk_path

A value that defines the disk location with batches to process on this iteration.

.. _WaitIdleArgs:

Expand Down
3 changes: 2 additions & 1 deletion docs/ref/python_interface.txt
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,10 @@ MasterComponent
It returns *True* if await succeeded within the timeout, otherwise returns *False*.
The provided timeout is in milliseconds. Use *timeout = -1* to allow infinite time for :py:meth:`AddBatch` operation.

.. py:method:: InvokeIteration(iterations_count = 1)
.. py:method:: InvokeIteration(iterations_count = 1, disk_path = None)

Invokes several iterations over the collection. The recommended value for *iterations_count* is 1.
*disk_path* defines the disk location with batches to process on this iteration.
For more iterations use a for loop around the :py:meth:`InvokeIteration` method.
This operation is asynchronous. Use :py:meth:`WaitIdle` to wait until all iterations have completed.

Expand Down
43 changes: 30 additions & 13 deletions src/artm/core/data_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ LocalDataLoader::LocalDataLoader(Instance* instance)
is_stopping(false),
thread_() {
std::string disk_path = instance->schema()->config().disk_path();
if (disk_path.empty()) {
generation_.reset(new MemoryGeneration());
} else {
if (!disk_path.empty()) {
generation_.reset(new DiskGeneration(disk_path));
}

Expand Down Expand Up @@ -147,10 +145,6 @@ bool LocalDataLoader::AddBatch(const AddBatchArgs& args) {
return true;
}

// Forwards to generation_->GetTotalItemsCount() and returns its result.
// NOTE(review): generation_ is dereferenced without a null check — confirm it is always set
// before this method can be called.
int LocalDataLoader::GetTotalItemsCount() const {
return generation_->GetTotalItemsCount();
}

void LocalDataLoader::InvokeIteration(const InvokeIterationArgs& args) {
int iterations_count = args.iterations_count();
if (iterations_count <= 0) {
Expand All @@ -159,14 +153,22 @@ void LocalDataLoader::InvokeIteration(const InvokeIterationArgs& args) {
return;
}

auto latest_generation = generation_.get();
if (generation_->empty()) {
DiskGeneration* generation;
std::unique_ptr<DiskGeneration> args_generation;
if (args.has_disk_path()) {
args_generation.reset(new DiskGeneration(args.disk_path()));
generation = args_generation.get();
} else {
generation = generation_.get();
}

if (generation == nullptr || generation->empty()) {
LOG(WARNING) << "DataLoader::InvokeIteration() - current generation is empty, "
<< "please populate DataLoader data with some data";
return;
}

std::vector<BatchManagerTask> tasks = latest_generation->batch_uuids();
std::vector<BatchManagerTask> tasks = generation->batch_uuids();
for (int iter = 0; iter < iterations_count; ++iter) {
for (auto &task : tasks) {
instance_->batch_manager()->Add(task);
Expand Down Expand Up @@ -282,7 +284,16 @@ void LocalDataLoader::ThreadFunction() {
continue;
}

std::shared_ptr<const Batch> batch = generation_->batch(next_task);
std::shared_ptr<Batch> batch = std::make_shared< ::artm::Batch>();
try {
::artm::core::BatchHelpers::LoadMessage(next_task.file_path, batch.get());
batch->set_id(boost::lexical_cast<std::string>(next_task.uuid)); // keep batch.id and task.uuid in sync
::artm::core::BatchHelpers::PopulateClassId(batch.get());
} catch (std::exception& ex) {
LOG(ERROR) << ex.what() << ", the batch will be skipped.";
batch = nullptr;
}

if (batch == nullptr) {
instance_->batch_manager()->Done(next_task.uuid, ModelName());
continue;
Expand Down Expand Up @@ -401,8 +412,14 @@ void RemoteDataLoader::ThreadFunction() {
std::string batch_file_path = response.batch_file_path(batch_index);

auto batch = std::make_shared< ::artm::Batch>();
::artm::core::BatchHelpers::LoadMessage(batch_file_path, batch.get());
::artm::core::BatchHelpers::PopulateClassId(batch.get());
try {
::artm::core::BatchHelpers::LoadMessage(batch_file_path, batch.get());
::artm::core::BatchHelpers::PopulateClassId(batch.get());
}
catch (std::exception& ex) {
LOG(ERROR) << ex.what() << ", the batch will be skipped";
batch = nullptr;
}

if (batch == nullptr) {
LOG(ERROR) << "Unable to load batch '" << batch_id << "' from " << config.disk_path();
Expand Down
5 changes: 2 additions & 3 deletions src/artm/core/data_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace core {

class Instance;
class MasterComponentService_Stub;
class Generation;
class DiskGeneration;

class DataLoader : boost::noncopyable, public Notifiable {
public:
Expand All @@ -49,7 +49,6 @@ class LocalDataLoader : public DataLoader {
explicit LocalDataLoader(Instance* instance);
virtual ~LocalDataLoader();

int GetTotalItemsCount() const;
bool AddBatch(const AddBatchArgs& args);
virtual void Callback(ModelIncrement* model_increment);

Expand All @@ -62,7 +61,7 @@ class LocalDataLoader : public DataLoader {
::artm::ThetaMatrix* theta_matrix);

private:
std::unique_ptr<Generation> generation_;
std::unique_ptr<DiskGeneration> generation_;

typedef std::pair<boost::uuids::uuid, ModelName> CacheKey;
ThreadSafeCollectionHolder<CacheKey, DataLoaderCacheEntry> cache_;
Expand Down
53 changes: 0 additions & 53 deletions src/artm/core/generation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,63 +23,10 @@ DiskGeneration::DiskGeneration(const std::string& disk_path)
}
}

// Adding batches is rejected by DiskGeneration: this method never returns normally
// and always throws InvalidOperation with a message explaining how to enable
// ArtmAddBatch().
boost::uuids::uuid DiskGeneration::AddBatch(const std::shared_ptr<Batch>& batch) {
  // Adjacent string literals are joined by the compiler into a single message.
  std::string error_text(
      "ArtmAddBatch() is not allowed with current configuration. "
      "Please, set the configuration parameter MasterComponentConfig.disk_path "
      "to an empty string in order to enable ArtmAddBatch() operation. "
      "Use ArtmSaveBatch() operation to save batches to disk.");
  BOOST_THROW_EXCEPTION(InvalidOperation(error_text));
}

// Removal is not implemented for DiskGeneration: the uuid argument is ignored and
// the request is only reported through the error log.
void DiskGeneration::RemoveBatch(const boost::uuids::uuid& uuid) {
LOG(ERROR) << "Remove batch is not supported in disk generation.";
}

// Returns (by value) the tasks stored in the generation_ member.
std::vector<BatchManagerTask> DiskGeneration::batch_uuids() const {
return generation_;
}

std::shared_ptr<Batch> DiskGeneration::batch(const BatchManagerTask& task) const {
auto batch = std::make_shared< ::artm::Batch>();
::artm::core::BatchHelpers::LoadMessage(task.file_path, batch.get());
batch->set_id(boost::lexical_cast<std::string>(task.uuid)); // keep batch.id and task.uuid in sync
::artm::core::BatchHelpers::PopulateClassId(batch.get());
return batch;
}

// Returns whatever the generation_ collection holds under task.uuid.
// Only the uuid of the task is used; task.file_path is not consulted.
std::shared_ptr<Batch> MemoryGeneration::batch(const BatchManagerTask& task) const {
return generation_.get(task.uuid);
}

std::vector<BatchManagerTask> MemoryGeneration::batch_uuids() const {
std::vector<BatchManagerTask> retval;
auto keys = generation_.keys();
for (auto& key : keys)
retval.push_back(BatchManagerTask(key, std::string()));
return retval;
}

// Stores `batch` in generation_ under a newly generated random uuid and
// returns that uuid to the caller.
boost::uuids::uuid MemoryGeneration::AddBatch(const std::shared_ptr<Batch>& batch) {
  boost::uuids::random_generator make_uuid;
  const boost::uuids::uuid batch_id = make_uuid();
  generation_.set(batch_id, batch);
  return batch_id;
}

// Erases the entry stored under `uuid` from the generation_ collection.
void MemoryGeneration::RemoveBatch(const boost::uuids::uuid& uuid) {
generation_.erase(uuid);
}

// Sums item_size() over every batch currently held in generation_.
// A key whose lookup yields nullptr is skipped and contributes nothing to the sum.
int MemoryGeneration::GetTotalItemsCount() const {
  int total_items = 0;
  for (auto& uuid : generation_.keys()) {
    auto stored = generation_.get(uuid);
    if (stored == nullptr) {
      continue;
    }
    total_items += stored->item_size();
  }
  return total_items;
}

} // namespace core
} // namespace artm

37 changes: 3 additions & 34 deletions src/artm/core/generation.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,49 +16,18 @@
namespace artm {
namespace core {

// Abstract interface of a batch collection.
// Must be thread-safe (used concurrently in DataLoader).
class Generation {
 public:
  // Implementations are owned and destroyed through base-class pointers
  // (std::unique_ptr<Generation>), so the destructor has to be virtual;
  // otherwise deleting a derived object via the base is undefined behaviour.
  virtual ~Generation() {}

  // Lists the tasks of this generation.
  virtual std::vector<BatchManagerTask> batch_uuids() const = 0;

  // Returns the batch that corresponds to the given task.
  virtual std::shared_ptr<Batch> batch(const BatchManagerTask& task) const = 0;

  // Tells whether the generation has no content.
  virtual bool empty() const = 0;

  // Reports an item count for the generation (exact meaning is implementation-defined).
  virtual int GetTotalItemsCount() const = 0;

  // Adds a batch and returns the uuid associated with it.
  virtual boost::uuids::uuid AddBatch(const std::shared_ptr<Batch>& batch) = 0;

  // Requests removal of the batch identified by uuid.
  virtual void RemoveBatch(const boost::uuids::uuid& uuid) = 0;
};

class DiskGeneration : public Generation {
class DiskGeneration {
public:
explicit DiskGeneration(const std::string& disk_path);

virtual std::vector<BatchManagerTask> batch_uuids() const;
virtual std::shared_ptr<Batch> batch(const BatchManagerTask& task) const;

virtual boost::uuids::uuid AddBatch(const std::shared_ptr<Batch>& batch);
virtual void RemoveBatch(const boost::uuids::uuid& uuid);
virtual int GetTotalItemsCount() const { return 0; }
virtual bool empty() const { return generation_.empty(); }
std::vector<BatchManagerTask> batch_uuids() const;
bool empty() const { return generation_.empty(); }

private:
std::string disk_path_;
std::vector<BatchManagerTask> generation_; // created one in constructor and then does not change.
};

// Generation implementation whose batches live in a ThreadSafeCollectionHolder
// keyed by boost::uuids::uuid.
class MemoryGeneration : public Generation {
public:
// Task listing and batch lookup for this generation.
virtual std::vector<BatchManagerTask> batch_uuids() const;
virtual std::shared_ptr<Batch> batch(const BatchManagerTask& task) const;

// Mutators: add a batch (returns its uuid) or remove one by uuid.
virtual boost::uuids::uuid AddBatch(const std::shared_ptr<Batch>& batch);
virtual void RemoveBatch(const boost::uuids::uuid& uuid);

// True when the underlying collection reports itself as empty.
virtual bool empty() const { return generation_.empty(); }
virtual int GetTotalItemsCount() const;

private:
// Storage for the batches, keyed by uuid.
ThreadSafeCollectionHolder<boost::uuids::uuid, Batch> generation_;
};

} // namespace core
} // namespace artm

Expand Down

0 comments on commit c7e9e04

Please sign in to comment.