Permalink
Browse files

more optimizations for CG solver

  • Loading branch information...
dselivanov committed Aug 1, 2017
1 parent d935d35 commit 7a8742602c2bfc1d55b7aaf84db1807962f74d69
Showing with 20 additions and 23 deletions.
  1. +3 −3 R/RcppExports.R
  2. +6 −6 src/RcppExports.cpp
  3. +11 −14 src/als_implicit_core_solver.cpp
View
@@ -2,14 +2,14 @@
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
als_implicit <- function(Conf, X, Y, lambda, n_threads, solver, cg_steps = 3L) {
.Call('reco_als_implicit', PACKAGE = 'reco', Conf, X, Y, lambda, n_threads, solver, cg_steps)
.Call('_reco_als_implicit', PACKAGE = 'reco', Conf, X, Y, lambda, n_threads, solver, cg_steps)
}
als_loss_explicit <- function(mat, X, Y, lambda, n_threads) {
.Call('reco_als_loss_explicit', PACKAGE = 'reco', mat, X, Y, lambda, n_threads)
.Call('_reco_als_loss_explicit', PACKAGE = 'reco', mat, X, Y, lambda, n_threads)
}
top_k_indices_byrow <- function(x, k, n_threads) {
.Call('reco_top_k_indices_byrow', PACKAGE = 'reco', x, k, n_threads)
.Call('_reco_top_k_indices_byrow', PACKAGE = 'reco', x, k, n_threads)
}
View
@@ -8,7 +8,7 @@ using namespace Rcpp;
// als_implicit
double als_implicit(const arma::sp_mat& Conf, arma::mat& X, arma::mat& Y, double lambda, int n_threads, int solver, int cg_steps);
RcppExport SEXP reco_als_implicit(SEXP ConfSEXP, SEXP XSEXP, SEXP YSEXP, SEXP lambdaSEXP, SEXP n_threadsSEXP, SEXP solverSEXP, SEXP cg_stepsSEXP) {
RcppExport SEXP _reco_als_implicit(SEXP ConfSEXP, SEXP XSEXP, SEXP YSEXP, SEXP lambdaSEXP, SEXP n_threadsSEXP, SEXP solverSEXP, SEXP cg_stepsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -25,7 +25,7 @@ END_RCPP
}
// als_loss_explicit
double als_loss_explicit(const arma::sp_mat& mat, arma::mat& X, arma::mat& Y, double lambda, int n_threads);
RcppExport SEXP reco_als_loss_explicit(SEXP matSEXP, SEXP XSEXP, SEXP YSEXP, SEXP lambdaSEXP, SEXP n_threadsSEXP) {
RcppExport SEXP _reco_als_loss_explicit(SEXP matSEXP, SEXP XSEXP, SEXP YSEXP, SEXP lambdaSEXP, SEXP n_threadsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -40,7 +40,7 @@ END_RCPP
}
// top_k_indices_byrow
IntegerMatrix top_k_indices_byrow(NumericMatrix x, int k, int n_threads);
RcppExport SEXP reco_top_k_indices_byrow(SEXP xSEXP, SEXP kSEXP, SEXP n_threadsSEXP) {
RcppExport SEXP _reco_top_k_indices_byrow(SEXP xSEXP, SEXP kSEXP, SEXP n_threadsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -53,9 +53,9 @@ END_RCPP
}
static const R_CallMethodDef CallEntries[] = {
{"reco_als_implicit", (DL_FUNC) &reco_als_implicit, 7},
{"reco_als_loss_explicit", (DL_FUNC) &reco_als_loss_explicit, 5},
{"reco_top_k_indices_byrow", (DL_FUNC) &reco_top_k_indices_byrow, 3},
{"_reco_als_implicit", (DL_FUNC) &_reco_als_implicit, 7},
{"_reco_als_loss_explicit", (DL_FUNC) &_reco_als_loss_explicit, 5},
{"_reco_top_k_indices_byrow", (DL_FUNC) &_reco_top_k_indices_byrow, 3},
{NULL, NULL, 0}
};
@@ -13,36 +13,33 @@ using namespace Rcpp;
using namespace RcppArmadillo;
using namespace arma;
arma::vec chol_solver(const arma::sp_mat& Conf,
const arma::mat& X,
const arma::mat& XtX,
const arma::mat& Y,
arma::vec chol_solver(const arma::sp_mat &Conf,
const arma::mat &XtX,
const arma::mat &Y,
const arma::mat &X_nnz,
arma::vec confidence) {
const arma::vec &confidence) {
arma::mat inv = XtX + X_nnz.each_row() % (confidence.t() - 1) * X_nnz.t();
arma::mat rhs = X_nnz * confidence;
return solve(inv, rhs, solve_opts::fast );
}
arma::vec cg_solver(const arma::sp_mat& Conf,
const arma::mat& X,
const arma::mat& XtX,
inline arma::vec cg_solver(const arma::sp_mat &Conf,
const arma::mat &XtX,
const arma::mat &X_nnz,
const arma::vec &confidence,
const arma::vec &x_old,
int n_iter) {
const int n_iter) {
arma::colvec x = x_old;
arma::mat X_nnz_t = X_nnz.t();
arma::vec confidence_1 = confidence - 1;
arma::mat Ap;
arma::vec r = X_nnz * confidence - XtX * x - X_nnz * (confidence_1 % (X_nnz_t * x));
arma::vec r = X_nnz * (confidence - (confidence_1 % (X_nnz.t() * x))) - XtX * x;
arma::vec p = r;
double rsold, rsnew, alpha;
rsold = as_scalar(r.t() * r);
for(int k = 0; k < n_iter; k++) {
Ap = XtX * p + X_nnz * (confidence_1 % (X_nnz_t * p));
Ap = XtX * p + X_nnz * (confidence_1 % (X_nnz.t() * p));
alpha = rsold / as_scalar(p.t() * Ap);
x += alpha * p;
r -= alpha * Ap;
@@ -82,9 +79,9 @@ double als_implicit(const arma::sp_mat& Conf,
arma::vec confidence = vec(&Conf.values[p1], p2 - p1);
arma::mat X_nnz = X.cols(idx);
if(solver == CHOLESKY)
Y.col(i) = chol_solver(Conf, X, XtX, Y, X_nnz, confidence);
Y.col(i) = chol_solver(Conf, XtX, Y, X_nnz, confidence);
else if(solver == CONJUGATE_GRADIENT)
Y.col(i) = cg_solver(Conf, X, XtX, X_nnz, confidence, Y.col(i), cg_steps);
Y.col(i) = cg_solver(Conf, XtX, X_nnz, confidence, Y.col(i), cg_steps);
else stop("Unknown solver code %d", solver);
// if we don't want to calc loss - will provide lambda = -1
if(lambda >= 0)

0 comments on commit 7a87426

Please sign in to comment.