Skip to content

Commit

Permalink
KLResult, Meanshift::Result::aborted, Command::isAborted(), some cons…
Browse files Browse the repository at this point in the history
…t correctness
  • Loading branch information
Georg Altmann committed Nov 14, 2014
1 parent 9cfe6b3 commit 7268ccf
Show file tree
Hide file tree
Showing 12 changed files with 197 additions and 46 deletions.
3 changes: 2 additions & 1 deletion core/progress_observer.h
Expand Up @@ -27,7 +27,8 @@ class ProgressObserver
*
* All derived classes need to query and honor this flag while executing.
*/
bool isAborted() { return abortflag; }
bool isAborted() const { return abortflag; }

/** Tell the worker thread to abort the computation.
*
* This is virtual so that it can be a slot in derived classes.
Expand Down
1 change: 1 addition & 0 deletions seg_meanshift/CMakeLists.txt
Expand Up @@ -17,6 +17,7 @@ vole_compile_library(
"meanshift_shell"
"meanshift_sp"
"meanshift_som"
"meanshift_klresult"
)

vole_add_module()
3 changes: 1 addition & 2 deletions seg_meanshift/meanshift.cpp
Expand Up @@ -22,8 +22,7 @@

namespace seg_meanshift {

std::pair<int, int> MeanShift::findKL(const multi_img& input,
ProgressObserver *po)
KLResult MeanShift::findKL(const multi_img& input, ProgressObserver *po)
{
// load points
FAMS cfams(config, po);
Expand Down
19 changes: 14 additions & 5 deletions seg_meanshift/meanshift.h
Expand Up @@ -17,10 +17,18 @@ class MeanShift {

public:
struct Result {
Result() : labels(new cv::Mat1s()),
modes(new std::vector<multi_img::Pixel>()) {}
Result()
: labels(new cv::Mat1s()),
modes(new std::vector<multi_img::Pixel>()),
aborted(false)
{}
Result(const std::vector<multi_img::Pixel>& m,
const cv::Mat1s& l) { setModes(m); setLabels(l); }
const cv::Mat1s& l)
: aborted(false)
{ setModes(m); setLabels(l); }

// default copy and assignment OK

void setModes(const std::vector<multi_img::Pixel>& in) {
modes = boost::make_shared<std::vector<multi_img::Pixel> >(in);
}
Expand All @@ -42,12 +50,13 @@ class MeanShift {
}
boost::shared_ptr<cv::Mat1s> labels;
boost::shared_ptr<std::vector<multi_img::Pixel> > modes;
// FIXME: make sure this is set and propagated in every return.
bool aborted;
};

MeanShift(const MeanShiftConfig& config) : config(config) {}

std::pair<int, int> findKL(const multi_img& input,
ProgressObserver *po = 0);
KLResult findKL(const multi_img& input, ProgressObserver *po = 0);
Result execute(const multi_img& input, ProgressObserver *po = 0,
vector<double> *bandwidths = 0,
const multi_img& spinput = multi_img());
Expand Down
32 changes: 32 additions & 0 deletions seg_meanshift/meanshift_klresult.cpp
@@ -0,0 +1,32 @@
#include "meanshift_klresult.h"

#include <iostream>

namespace seg_meanshift {

void diagnoseKLResult(KLResult const& ret)
{
if (ret.isState(KLState::Aborted)) {
std::cerr << "findKL computation was aborted" << std::endl;
} else if (ret.isState(KLState::NoneFound)) {
std::cerr << "findKL computation found no solution" << std::endl;
}
}

void KLResult::insertInto(std::map<std::string, boost::any> &dest)
{
std::map<std::string, boost::any> rmap = makeKeyValueMap();
dest.insert(rmap.begin(), rmap.end());
}

std::map<std::string, boost::any> KLResult::makeKeyValueMap() const
{
std::map<std::string, boost::any> res;
res["findKL.K"] = K;
res["findKL.L"] = L;
res["findKL.aborted"] = isState(KLState::Aborted);
res["findKL.good"] = isGood();
return res;
}

} // namespace seg_meanshift
68 changes: 68 additions & 0 deletions seg_meanshift/meanshift_klresult.h
@@ -0,0 +1,68 @@
#ifndef MEANSHIFT_KLRESULT_H
#define MEANSHIFT_KLRESULT_H

// These are for makeKeyValueMap() only. Maybe find better solution than
// adding these include dependencies.
#include <map>
#include <string>
#include <boost/any.hpp>

namespace seg_meanshift {

struct KLState {
/** KLState flags. */
enum t {
Good = 0x0, //! Good result, 0 for bitwise operations.
Aborted = 0x1, //! The computation was aborted.
NoneFound = 0x2 //! No solution found.
};
};

class KLResult{
public:
/** Create KLResult with flags s0 and s1 set. */
KLResult(int K, int L,
KLState::t s0 = KLState::Good,
KLState::t s1 = KLState::Good
)
: K(K), L(L), state(KLState::t(s0 | s1))
{}

const int K;
const int L;

/** Returns true if flag s is set. */
bool isState(KLState::t s) const {
if (KLState::Good == s)
return isGood();
else
return (state & s) != 0;
}

/** Returns true if findKL did not abort and found a valid result. */
bool isGood() const {
return state == KLState::Good;
}

/** Insert result and flags into key-value map.
*
* The inserted values are:
*
* * findKL.K int
* * findKL.L int
* * findKL.aborted bool
* * findKL.good bool
*/
void insertInto(std::map<std::string, boost::any> &dest);

private:
std::map<std::string, boost::any> makeKeyValueMap() const;
KLState::t state;
};

/** Print informational messages to cerr on KLResult state. */
void diagnoseKLResult(KLResult const& ret);

} // namespace seg_meanshift

#endif // MEANSHIFT_KLRESULT_H
49 changes: 29 additions & 20 deletions seg_meanshift/meanshift_shell.cpp
Expand Up @@ -62,16 +62,18 @@ int MeanShiftShell::execute() {

MeanShift ms(config);
if (config.findKL) {
// find K, L
std::pair<int, int> ret = ms.findKL(
// find K, L
KLResult ret = ms.findKL(
#ifdef WITH_SEG_FELZENSZWALB
(config.sp_withGrad ? *input_grad : *input));
#else
*input);
#endif
config.K = ret.first; config.L = ret.second;
diagnoseKLResult(ret);
std::cout << "Found K = " << config.K
<< "\tL = " << config.L << std::endl;
<< "\tL = " << config.L << std::endl;
config.K = ret.K; config.L = ret.L;

return 0;
}

Expand All @@ -92,8 +94,10 @@ int MeanShiftShell::execute() {
res = ms.execute(*input, NULL, NULL, *input);
}

if (res.modes->empty())
return 1; // something went wrong, there should always be one mode!
if (res.aborted ||
// something went wrong, there should always be one mode!
res.modes->empty())
return 1;

res.printModes();

Expand Down Expand Up @@ -142,19 +146,20 @@ MeanShiftShell::execute(std::map<std::string, boost::any> &input,
MeanShift ms(config);
std::map<std::string, boost::any> output;
if (config.findKL) {
// find K, L
std::pair<int, int> ret = ms.findKL(
// find K, L
KLResult res = ms.findKL(
#ifdef WITH_SEG_FELZENSZWALB
(config.sp_withGrad ? *inputgrad : *inputimg));
(config.sp_withGrad ? *inputgrad : *inputimg));
#else
*inputimg);
*inputimg);
#endif
config.K = ret.first; config.L = ret.second;
std::cout << "Found K = " << config.K
<< "\tL = " << config.L << std::endl;

output["findKL.K"] = ret.first;
output["findKL.L"] = ret.second;
if (res.isGood()) {
config.K = res.K; config.L = res.L;
std::cout << "Found K = " << config.K
<< "\tL = " << config.L << std::endl;
}
res.insertInto(output);
return output;
} else {
MeanShift::Result res = ms.execute(
#ifdef WITH_SEG_FELZENSZWALB
Expand All @@ -163,11 +168,15 @@ MeanShiftShell::execute(std::map<std::string, boost::any> &input,
*inputimg,
#endif
progress, NULL, *inputimg);
output["labels"] = res.labels;
output["modes"] = res.modes;
if (!res.aborted) {
output["labels"] = res.labels;
output["modes"] = res.modes;
return output;
} else {
output["aborted"] = true;
return output;
}
}

return output;
}


Expand Down
8 changes: 4 additions & 4 deletions seg_meanshift/meanshift_som.cpp
Expand Up @@ -126,13 +126,13 @@ MeanShiftSOM::Result MeanShiftSOM::execute(multi_img::ptr input)

assert(!config.findKL);

MeanShift::Result ret_in = ms.execute(msinput, 0,
MeanShift::Result msres = ms.execute(msinput, 0,
(config.sp_weight > 0 ? &weights : 0));
if (ret_in.labels->empty())
if (msres.labels->empty())
return Result();

Result ret_out;
ret_out.modes = ret_in.modes;
ret_out.modes = msres.modes;
ret_out.som = som;
ret_out.lookup = mapping;

Expand All @@ -147,7 +147,7 @@ MeanShiftSOM::Result MeanShiftSOM::execute(multi_img::ptr input)
cv::Point pos = som->getCoord2D(answer.first->index);

// get segement number of 2D coordinate position
short index = (*ret_in.labels)(pos) - 1;
short index = (*msres.labels)(pos) - 1;

(*ret_out.labels)(y, x) = index;
}
Expand Down
38 changes: 31 additions & 7 deletions seg_meanshift/meanshift_sp.cpp
Expand Up @@ -82,6 +82,8 @@ std::map<std::string, boost::any> MeanShiftSP::execute(std::map<std::string, boo
#ifdef WITH_SEG_FELZENSZWALB
// XXX: for now, gradient/rescale is expected to be done by caller

setProgressObserver(progress);

boost::shared_ptr<multi_img> inputimg =
boost::any_cast<boost::shared_ptr<multi_img> >(input["multi_img"]);
boost::shared_ptr<multi_img> inputgrad;
Expand All @@ -92,17 +94,32 @@ std::map<std::string, boost::any> MeanShiftSP::execute(std::map<std::string, boo

// make sure pixel caches are built
inputimg->rebuildPixels(true);

// FIXME New behaviour: output is empty if we are aborted. Make this clear
// in the execute description and check client code.
std::map<std::string, boost::any> output;
output["aborted"] = true;
if (isAborted())
return output;

if (config.sp_withGrad)
inputgrad->rebuildPixels(true);

if (isAborted())
return output;

MeanShift::Result res = execute(inputimg, inputgrad);
std::map<std::string, boost::any> output;
output["labels"] = res.labels;
output["modes"] = res.modes;

if (isAborted())
return output;

output["aborted"] = false;
output["labels"] = res.labels;
output["modes"] = res.modes;
return output;
#else
#else // WITH_SEG_FELZENSZWALB
throw std::runtime_error("Module seg_felzenszwalb needed, but missing!");
#endif
#endif // WITH_SEG_FELZENSZWALB
}

MeanShift::Result MeanShiftSP::execute(multi_img::ptr input, multi_img::ptr input_grad)
Expand Down Expand Up @@ -184,8 +201,15 @@ MeanShift::Result MeanShiftSP::execute(multi_img::ptr input, multi_img::ptr inpu

if (config.findKL) {
// find K, L
std::pair<int, int> ret = ms.findKL(msinput);
config.K = ret.first; config.L = ret.second;
KLResult ret = ms.findKL(msinput);
diagnoseKLResult(ret);
if (ret.isState(KLState::Aborted)) {
MeanShift::Result myres;
myres.aborted = true;
return myres;
}

config.K = ret.K; config.L = ret.L;
std::cout << "Found K = " << config.K
<< "\tL = " << config.L << std::endl;
return MeanShift::Result();
Expand Down
12 changes: 6 additions & 6 deletions seg_meanshift/mfams.cpp
Expand Up @@ -378,7 +378,7 @@ bool FAMS::finishFAMS() {
}

// main function to find K and L
std::pair<int,int> FAMS::FindKL() {
KLResult FAMS::FindKL() {
int Kmin = config.Kmin, Kmax = config.K, Kjump = config.Kjump;
int Lmax = config.L, k = config.K;
float width = config.bandwidth, epsilon = config.epsilon;
Expand All @@ -388,7 +388,7 @@ std::pair<int,int> FAMS::FindKL() {

if (datapoints.empty()) {
bgLog("Load points first\n");
return make_pair(0, 0);
return KLResult(0, 0, KLState::Aborted);
}

int hWidth = 0;
Expand Down Expand Up @@ -442,7 +442,7 @@ std::pair<int,int> FAMS::FindKL() {
bool cont = progressUpdate(50.f * (Kmax-Kcrt)/(Kmax-Kmin));
if (!cont) {
bgLog("FindKL aborted\n");
return std::make_pair(0, 0);
return KLResult(0, 0, KLState::Aborted);
}

// update Lcrt to reduce running time!
Expand All @@ -462,7 +462,7 @@ std::pair<int,int> FAMS::FindKL() {
bool cont = progressUpdate(50.f + 50.f * i/nBest);
if (!cont) {
bgLog("FindKL aborted\n");
return std::make_pair(0, 0);
return KLResult(0, 0, KLState::Aborted);
}

if (LBest[i] <= 0)
Expand All @@ -482,10 +482,10 @@ std::pair<int,int> FAMS::FindKL() {
bgLog("done\n");

if (iBest != -1) {
return std::make_pair(KBest[iBest], LBest[iBest]);
return KLResult(KBest[iBest], LBest[iBest]);
} else {
bgLog("No valid pairs found.\n");
return std::make_pair(0, 0);
return KLResult(0, 0, KLState::NoneFound);
}
}

Expand Down

0 comments on commit 7268ccf

Please sign in to comment.