Skip to content

Commit

Permalink
revised mex files based on new updates of light-matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
lindahua committed Dec 20, 2012
1 parent f50fceb commit 6fca2db
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 17 deletions.
6 changes: 3 additions & 3 deletions common/private/aggreg_base.h
Expand Up @@ -7,7 +7,7 @@
#ifndef PLI_AGGREG_BASE_H
#define PLI_AGGREG_BASE_H

#include <light_mat/common/basic_defs.h>
#include <light_mat/common/memory.h>
#include <limits>

// operations
Expand All @@ -30,7 +30,7 @@ struct aop_max
void init(const lmat::index_t n, T *x)
{
T v0 = -std::numeric_limits<T>::infinity();
fill_val(v0, n, x);
lmat::fill_vec(n, x, v0);
}

LMAT_ENSURE_INLINE
Expand All @@ -48,7 +48,7 @@ struct aop_min
void init(const lmat::index_t n, T *x)
{
T v0 = std::numeric_limits<T>::infinity();
fill_val(v0, n, x);
lmat::fill_vec(n, x, v0);
}

LMAT_ENSURE_INLINE
Expand Down
10 changes: 5 additions & 5 deletions metrics/private/dist_base.h
Expand Up @@ -18,8 +18,8 @@

template<typename T, class Dist, class SMat, class DMat>
inline void pw_dists(Dist dist,
const lmat::IDenseMatrix<SMat, T>& X,
lmat::IDenseMatrix<DMat, T>& D)
const lmat::IRegularMatrix<SMat, T>& X,
lmat::IRegularMatrix<DMat, T>& D)
{
const lmat::index_t d = X.nrows();
const lmat::index_t n = X.ncolumns();
Expand Down Expand Up @@ -47,9 +47,9 @@ inline void pw_dists(Dist dist,

template<typename T, class Dist, class LMat, class RMat, class DMat>
inline void pw_dists(Dist dist,
const lmat::IDenseMatrix<LMat, T>& X,
const lmat::IDenseMatrix<RMat, T>& Y,
lmat::IDenseMatrix<DMat, T>& D)
const lmat::IRegularMatrix<LMat, T>& X,
const lmat::IRegularMatrix<RMat, T>& Y,
lmat::IRegularMatrix<DMat, T>& D)
{
const lmat::index_t d = X.nrows();
const lmat::index_t m = X.ncolumns();
Expand Down
7 changes: 5 additions & 2 deletions optim/private/lbfgs_calcdir_cimp.cpp
Expand Up @@ -8,6 +8,7 @@
**********************************************************/

#include <light_mat/matlab/matlab_port.h>
#include <light_mat/mateval/mat_reduce.h>

using namespace lmat;
using namespace lmat::matlab;
Expand Down Expand Up @@ -69,7 +70,8 @@ struct LBFGS_Update
cvec_t y = Y.column(j);

alpha[j] = rho[j] * dot(s, q);
q -= alpha[j] * y;

accum_to(q, -alpha[j], y);
}


Expand All @@ -80,7 +82,8 @@ struct LBFGS_Update
cvec_t y = Y.column(j);

double beta = rho[j] * dot(y, z);
z += (alpha[j] - beta) * s;

accum_to(z, alpha[j] - beta, s);
}


Expand Down
1 change: 0 additions & 1 deletion svm/private/pegasos_cimp.cpp
Expand Up @@ -6,7 +6,6 @@

#include "svm_sgdx_common.h"


/**
* w = alpha * w + (eta/k) sum_{i in A} y_i x_i
*
Expand Down
5 changes: 3 additions & 2 deletions svm/private/svm_sgdx_common.h
Expand Up @@ -5,6 +5,7 @@
**********************************************************/

#include <light_mat/matlab/matlab_port.h>
#include <light_mat/mateval/mat_reduce.h>

using namespace lmat;
using namespace lmat::matlab;
Expand Down Expand Up @@ -36,7 +37,7 @@ class WeightVec
LMAT_ENSURE_INLINE
double sqnorm() const
{
return lmat::sqL2norm(_w);
return lmat::sqsum(_w);
}

LMAT_ENSURE_INLINE
Expand Down Expand Up @@ -103,7 +104,7 @@ class WeightVecX
LMAT_ENSURE_INLINE
double sqnorm() const
{
return lmat::sqL2norm(_w) + _w0 * _w0;
return lmat::sqsum(_w) + _w0 * _w0;
}

LMAT_ENSURE_INLINE
Expand Down
8 changes: 4 additions & 4 deletions vq/private/ssvq_cimp.cpp
Expand Up @@ -220,8 +220,8 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
k = _C0.ncolumns();

copy_mem(d * k, _C0.data<double>(), C.ptr_data());
copy_mem(k, _w0.data<double>(), w.ptr_data());
copy_vec(d * k, _C0.data<double>(), C.ptr_data());
copy_vec(k, _w0.data<double>(), w.ptr_data());
}


Expand All @@ -240,8 +240,8 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
_C = marray::numeric_matrix<double>(d, k);
_w = marray::numeric_matrix<double>(1, k);

copy_mem(d * k, C.ptr_data(), _C.data<double>());
copy_mem(k, w.ptr_data(), _w.data<double>());
copy_vec(d * k, C.ptr_data(), _C.data<double>());
copy_vec(k, w.ptr_data(), _w.data<double>());

_Ctmp.destroy();
_wtmp.destroy();
Expand Down

0 comments on commit 6fca2db

Please sign in to comment.