From 6fd340d48695531fd1be3844657e6597e60784d4 Mon Sep 17 00:00:00 2001 From: Giovanni De Toni Date: Tue, 20 Jun 2017 10:30:51 +0200 Subject: [PATCH] [PrematureStopping] Add on_next() and on_complete() method to CMachine. The default behaviour of these method is to set a variable called m_cancel_computation which can be used to terminate prematurely algorithms, by pressing CTRL+C. The algorithm has to be registered into the signal handler by calling connect_to_signal_handler(). Fixed also a bug inside Signal.cpp. --- src/shogun/lib/Signal.cpp | 7 ++++++- src/shogun/machine/Machine.cpp | 13 +++++++++++++ src/shogun/machine/Machine.h | 23 +++++++++++++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/src/shogun/lib/Signal.cpp b/src/shogun/lib/Signal.cpp index 3defa0fdd5c..46dfebdf230 100644 --- a/src/shogun/lib/Signal.cpp +++ b/src/shogun/lib/Signal.cpp @@ -60,7 +60,8 @@ void CSignal::handler(int signal) "\n[ShogunSignalHandler] Immediately return to prompt / " "Prematurely finish " "computations / Do nothing (I/P/D)? ") - char answer = fgetc(stdin); + char answer = getchar(); + getchar(); switch (answer) { case 'I': @@ -73,6 +74,10 @@ void CSignal::handler(int signal) "[ShogunSignalHandler] Terminating" " prematurely current algorithm...\n"); m_sigurg_observable.connect(); + m_sigurg_observable = + rxcpp::observable<>::create([](rxcpp::subscriber s) { + s.on_next(1); + }).publish(); break; default: SG_SPRINT("[ShogunSignalHandler] Continuing...\n") diff --git a/src/shogun/machine/Machine.cpp b/src/shogun/machine/Machine.cpp index 44362db11c9..46d610ad9d5 100644 --- a/src/shogun/machine/Machine.cpp +++ b/src/shogun/machine/Machine.cpp @@ -9,6 +9,9 @@ * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society */ +#include +#include +#include #include using namespace shogun; @@ -18,6 +21,7 @@ CMachine::CMachine() : CSGObject(), m_max_train_time(0), m_labels(NULL), { m_data_locked=false; m_store_model_features=false; + m_cancel_computation = false; SG_ADD(&m_max_train_time, "max_train_time", "Maximum training time.", MS_NOT_AVAILABLE); @@ -269,3 +273,12 @@ CLatentLabels* CMachine::apply_locked_latent(SGVector indices) "for %s\n", get_name()); return NULL; } + +void CMachine::connect_to_signal_handler() +{ + // Subscribe this algorithm to the signal handler + auto subscriber = rxcpp::make_subscriber( + [this](int i) { this->on_next(); }, [this]() { this->on_complete(); }); + get_global_signal()->get_SIGINT_observable().subscribe(subscriber); + get_global_signal()->get_SIGURG_observable().subscribe(subscriber); +} diff --git a/src/shogun/machine/Machine.h b/src/shogun/machine/Machine.h index aa8a1d3b940..dc175c0ae10 100644 --- a/src/shogun/machine/Machine.h +++ b/src/shogun/machine/Machine.h @@ -306,6 +306,9 @@ class CMachine : public CSGObject return PT_BINARY; } + /** connect the machine instance to the signal handler */ + void connect_to_signal_handler(); + virtual const char* get_name() const { return "Machine"; } protected: @@ -357,6 +360,23 @@ class CMachine : public CSGObject /** returns whether machine require labels for training */ virtual bool train_require_labels() const { return true; } + /** @return whether the algorithm needs to be stopped */ + bool cancel_computation() const + { + return m_cancel_computation; + } + + /** */ + virtual void on_next() + { + m_cancel_computation = true; + }; + + /** */ + virtual void on_complete() + { + } + protected: /** maximum training time */ float64_t m_max_train_time; @@ -372,6 +392,9 @@ class CMachine : public CSGObject /** whether data is locked */ bool m_data_locked; + + /** Cancel computation */ + bool m_cancel_computation; }; } #endif // _MACHINE_H__