Skip to content

Commit

Permalink
Add LMNN statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
iglesias committed Sep 12, 2013
1 parent 618878d commit 0497410
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/interfaces/modular/Metric.i
Expand Up @@ -9,6 +9,7 @@
*/

/* Remove C Prefix */
%rename(LMNNStatistics) CLMNNStatistics;
%rename(LMNN) CLMNN;

/* Include Class Headers to make them visible from within the target language */
Expand Down
67 changes: 67 additions & 0 deletions src/shogun/metric/LMNN.cpp
Expand Up @@ -26,6 +26,9 @@ using namespace Eigen;
CLMNN::CLMNN()
{
init();

m_statistics = new CLMNNStatistics();
SG_REF(m_statistics);
}

CLMNN::CLMNN(CDenseFeatures<float64_t>* features, CMulticlassLabels* labels, int32_t k)
Expand All @@ -38,12 +41,16 @@ CLMNN::CLMNN(CDenseFeatures<float64_t>* features, CMulticlassLabels* labels, int

SG_REF(m_features)
SG_REF(m_labels)

m_statistics = new CLMNNStatistics();
SG_REF(m_statistics);
}

CLMNN::~CLMNN()
{
SG_UNREF(m_features)
SG_UNREF(m_labels)
SG_UNREF(m_statistics);
}

const char* CLMNN::get_name() const
Expand Down Expand Up @@ -85,6 +92,8 @@ void CLMNN::train(SGMatrix<float64_t> init_transform)
uint32_t iter = 0;
// Criterion for termination
bool stop = false;
// Make space for the training statistics
m_statistics->resize(m_maxiter);

/// Main loop
while (!stop)
Expand Down Expand Up @@ -121,10 +130,16 @@ void CLMNN::train(SGMatrix<float64_t> init_transform)
// Update previous set of impostors
prev_impostors = cur_impostors;

// Store statistics for this iteration
m_statistics->set(iter-1, obj[iter-1], stepsize, cur_impostors.size());

SG_DEBUG("iteration=%d, objective=%.4f, #impostors=%4d, stepsize=%.4E\n",
iter, obj[iter-1], cur_impostors.size(), stepsize)
}

// Truncate statistics in case convergence was reached in less than maxiter
m_statistics->resize(iter);

/// Store the transformation found in the class attribute
int32_t nfeats = x->get_num_features();
float64_t* cloned_data = SGMatrix<float64_t>::clone_matrix(L.data(), nfeats, nfeats);
Expand Down Expand Up @@ -250,6 +265,12 @@ void CLMNN::set_diagonal(const bool diagonal)
m_diagonal = diagonal;
}

CLMNNStatistics* CLMNN::get_statistics() const
{
SG_REF(m_statistics);
return m_statistics;
}

void CLMNN::init()
{
SG_ADD(&m_linear_transform, "linear_transform",
Expand All @@ -273,6 +294,8 @@ void CLMNN::init()
SG_ADD(&m_obj_threshold, "obj_threshold", "Objective threshold",
MS_NOT_AVAILABLE)
SG_ADD(&m_diagonal, "m_diagonal", "Diagonal transformation", MS_NOT_AVAILABLE);
SG_ADD((CSGObject**) &m_statistics, "statistics", "Training statistics",
MS_NOT_AVAILABLE);

m_features = NULL;
m_labels = NULL;
Expand All @@ -284,6 +307,50 @@ void CLMNN::init()
m_correction = 15;
m_obj_threshold = 1e-9;
m_diagonal = false;
m_statistics = NULL;
}

CLMNNStatistics::CLMNNStatistics()
{
init();
}

CLMNNStatistics::~CLMNNStatistics()
{
}

const char* CLMNNStatistics::get_name() const
{
return "LMNNStatistics";
}

void CLMNNStatistics::resize(int32_t size)
{
REQUIRE(size > 0, "The new size in CLMNNStatistics::resize must be larger than zero."
" Given value is %d.\n", size);

obj.resize_vector(size);
stepsize.resize_vector(size);
num_impostors.resize_vector(size);
}

void CLMNNStatistics::set(index_t iter, float64_t obj_iter, float64_t stepsize_iter,
uint32_t num_impostors_iter)
{
REQUIRE(iter >= 0 && iter < obj.vlen, "The iteration index in CLMNNStatistics::set "
"must be larger or equal to zero and less than the size (%d). Given valu is %d.\n", obj.vlen, iter);

obj[iter] = obj_iter;
stepsize[iter] = stepsize_iter;
num_impostors[iter] = num_impostors_iter;
}

void CLMNNStatistics::init()
{
SG_ADD(&obj, "obj", "Objective at each iteration", MS_NOT_AVAILABLE);
SG_ADD(&stepsize, "stepsize", "Step size at each iteration", MS_NOT_AVAILABLE);
SG_ADD(&num_impostors, "num_impostors", "Number of impostors at each iteration",
MS_NOT_AVAILABLE);
}

#endif /* HAVE_LAPACK */
Expand Down
63 changes: 63 additions & 0 deletions src/shogun/metric/LMNN.h
Expand Up @@ -25,6 +25,9 @@
namespace shogun
{

// Forward declaration
class CLMNNStatistics;

/**
* @brief Class LMNN that implements the distance metric learning technique
* Large Margin Nearest Neighbour (LMNN) described in
Expand Down Expand Up @@ -171,6 +174,12 @@ class CLMNN : public CSGObject
*/
void set_diagonal(const bool diagonal);

/** get LMNN training statistics
*
* @return LMNN training statistics
*/
CLMNNStatistics* get_statistics() const;

private:
/** register parameters */
void init();
Expand Down Expand Up @@ -229,8 +238,62 @@ class CLMNN : public CSGObject
*/
bool m_diagonal;

/** training statistics, @see CLMNNStatistics */
CLMNNStatistics* m_statistics;

}; /* class CLMNN */

/**
* @brief Class LMNNStatistics used to give access to intermediate results
* obtained training LMNN.
*/
class CLMNNStatistics : public CSGObject
{
public:
/** default constructor */
CLMNNStatistics();

/** destructor */
virtual ~CLMNNStatistics();

/** @return name of SGSerializable */
virtual const char* get_name() const;

/**
* resize CLMNNStatistics::obj, CLMNNStatistics::stepsize and
* CLMNNStatistics::num_impostors to fit the specified number of elements
*
* @param size number of elements
*/
void resize(int32_t size);

/**
* set objective, step size and number of impostors computed at the
* specified iteration
*
* @param iter index to store the parameters, must be greater or equal to zero,
* and less than the size
* @param obj_iter objective to set
* @param stepsize_iter stepsize to set
* @param num_impostors_iter number of impostors to set
*/
void set(index_t iter, float64_t obj_iter, float64_t stepsize_iter, uint32_t num_impostors_iter);

private:
/** register parameters */
void init();

public:
/** objective function at each iteration */
SGVector<float64_t> obj;

/** step size at each iteration */
SGVector<float64_t> stepsize;

/** number of impostors at each iteration */
SGVector<uint32_t> num_impostors;
};

} /* namespace shogun */

#endif /* HAVE_LAPACK */
Expand Down

0 comments on commit 0497410

Please sign in to comment.