Skip to content

Commit

Permalink
Fix LMNN termination condition (shogun-toolbox#4188)
Browse files Browse the repository at this point in the history
* Fix LMNN termination condition
* Add unit tests on lmnn termination
  • Loading branch information
vinx13 authored and ktiefe committed Jul 26, 2019
1 parent 928f492 commit 344a0cd
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/shogun/metric/LMNNImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ bool CLMNNImpl::check_termination(float64_t stepsize, const SGVector<float64_t>

if (iter >= 10)
{
obj_threshold *= obj[iter - 1];
for (int32_t i = 0; i < 3; ++i)
{
if (CMath::abs(obj[iter-i]-obj[iter-i-1]) >= obj_threshold)
Expand Down
46 changes: 45 additions & 1 deletion tests/unit/metric/LMNN_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
*/
#include <gtest/gtest.h>

#include <shogun/metric/LMNN.h>
#include <shogun/base/some.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/metric/LMNN.h>

using namespace shogun;

Expand Down Expand Up @@ -62,6 +63,49 @@ TEST(LMNN,train_identity_init)
SG_UNREF(lmnn)
}

TEST(LMNN, train_termination)
{
SGMatrix<float64_t> feat_mat(2, 4);
feat_mat(0, 0) = 0;
feat_mat(1, 0) = 0;
feat_mat(0, 1) = 0;
feat_mat(1, 1) = -1;
feat_mat(0, 2) = 1;
feat_mat(1, 2) = 1;
feat_mat(0, 3) = -1;
feat_mat(1, 3) = 1;

CDenseFeatures<float64_t>* features =
new CDenseFeatures<float64_t>(feat_mat);

SGVector<float64_t> lab_vec(4);
lab_vec[0] = 0;
lab_vec[1] = 0;
lab_vec[2] = 1;
lab_vec[3] = 1;

CMulticlassLabels* labels = new CMulticlassLabels(lab_vec);

int32_t k = 1; // number of target neighbors per example
auto lmnn = some<CLMNN>(features, labels, k);

SGMatrix<float64_t> init_transform =
SGMatrix<float64_t>::create_identity_matrix(2, 1);
lmnn->set_maxiter(1500);
lmnn->train(init_transform);

// check linear transform solution
SGMatrix<float64_t> L = lmnn->get_linear_transform();
EXPECT_NEAR(L(0, 0), 0.000041647483219, 1e-10);
EXPECT_NEAR(L(0, 1), 0, 1e-10);
EXPECT_NEAR(L(1, 0), 0, 1e-10);
EXPECT_NEAR(L(1, 1), 0.988162395685451, 1e-10);

// check number of iterations
auto stat = lmnn->get_statistics();
EXPECT_EQ(stat->obj.vlen, 1234);
}

TEST(LMNN,train_pca_init)
{
// create features, each column is a feature vector
Expand Down

0 comments on commit 344a0cd

Please sign in to comment.