Skip to content

Commit

Permalink
[PrematureStopping] Add on_next() and on_complete() method to CMachine.
Browse files Browse the repository at this point in the history
The default behaviour of these method is to set a variable called
m_cancel_computation which can be used to terminate prematurely
algorithms, by pressing CTRL+C.

The algorithm has to be registered into the signal handler
by calling connect_to_signal_handler().

Fixed also a bug inside Signal.cpp.
  • Loading branch information
geektoni committed Jun 21, 2017
1 parent c84a73e commit 6fd340d
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/shogun/lib/Signal.cpp
Expand Up @@ -60,7 +60,8 @@ void CSignal::handler(int signal)
"\n[ShogunSignalHandler] Immediately return to prompt / "
"Prematurely finish "
"computations / Do nothing (I/P/D)? ")
char answer = fgetc(stdin);
char answer = getchar();
getchar();
switch (answer)
{
case 'I':
Expand All @@ -73,6 +74,10 @@ void CSignal::handler(int signal)
"[ShogunSignalHandler] Terminating"
" prematurely current algorithm...\n");
m_sigurg_observable.connect();
m_sigurg_observable =
rxcpp::observable<>::create<int>([](rxcpp::subscriber<int> s) {
s.on_next(1);
}).publish();
break;
default:
SG_SPRINT("[ShogunSignalHandler] Continuing...\n")
Expand Down
13 changes: 13 additions & 0 deletions src/shogun/machine/Machine.cpp
Expand Up @@ -9,6 +9,9 @@
* Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
*/

#include <rxcpp/rx.hpp>
#include <shogun/base/init.h>
#include <shogun/lib/Signal.h>
#include <shogun/machine/Machine.h>

using namespace shogun;
Expand All @@ -18,6 +21,7 @@ CMachine::CMachine() : CSGObject(), m_max_train_time(0), m_labels(NULL),
{
m_data_locked=false;
m_store_model_features=false;
m_cancel_computation = false;

SG_ADD(&m_max_train_time, "max_train_time",
"Maximum training time.", MS_NOT_AVAILABLE);
Expand Down Expand Up @@ -269,3 +273,12 @@ CLatentLabels* CMachine::apply_locked_latent(SGVector<index_t> indices)
"for %s\n", get_name());
return NULL;
}

void CMachine::connect_to_signal_handler()
{
// Subscribe this algorithm to the signal handler
auto subscriber = rxcpp::make_subscriber<int>(
[this](int i) { this->on_next(); }, [this]() { this->on_complete(); });
get_global_signal()->get_SIGINT_observable().subscribe(subscriber);
get_global_signal()->get_SIGURG_observable().subscribe(subscriber);
}
23 changes: 23 additions & 0 deletions src/shogun/machine/Machine.h
Expand Up @@ -306,6 +306,9 @@ class CMachine : public CSGObject
return PT_BINARY;
}

/** connect the machine instance to the signal handler */
void connect_to_signal_handler();

virtual const char* get_name() const { return "Machine"; }

protected:
Expand Down Expand Up @@ -357,6 +360,23 @@ class CMachine : public CSGObject
/** returns whether machine require labels for training */
virtual bool train_require_labels() const { return true; }

/** @return whether the algorithm needs to be stopped */
bool cancel_computation() const
{
return m_cancel_computation;
}

/** */
virtual void on_next()
{
m_cancel_computation = true;
};

/** */
virtual void on_complete()
{
}

protected:
/** maximum training time */
float64_t m_max_train_time;
Expand All @@ -372,6 +392,9 @@ class CMachine : public CSGObject

/** whether data is locked */
bool m_data_locked;

/** Cancel computation */
bool m_cancel_computation;
};
}
#endif // _MACHINE_H__

0 comments on commit 6fd340d

Please sign in to comment.