Merge branch 'david-cortes-explicit_cpp'
# Conflicts:
#	R/RcppExports.R
#	R/model_WRMF.R
#	man/WRMF.Rd
#	src/RcppExports.cpp
#	src/SolverAlsWrmf.cpp
dselivanov committed Nov 29, 2020
2 parents f80e220 + cecfb3d commit 35247b4
Showing 9 changed files with 252 additions and 27 deletions.
4 changes: 3 additions & 1 deletion R/MatrixFactorizationRecommender.R
@@ -6,6 +6,8 @@ MatrixFactorizationRecommender = R6::R6Class(
public = list(
#' @field components item embeddings
components = NULL,
+ #' @field global_mean global mean (for centering values in explicit feedback)
+ global_mean = 0.,
#' @description recommends items for users
#' @param x user-item interactions matrix (usually sparse - `Matrix::sparseMatrix`). Users are
#' rows and items are columns
@@ -58,7 +60,7 @@ MatrixFactorizationRecommender = R6::R6Class(
not_recommend = as(not_recommend, "RsparseMatrix")

uids = rownames(user_embeddings)
- indices = find_top_product(user_embeddings, item_embeddings, k, not_recommend, items_exclude)
+ indices = find_top_product(user_embeddings, item_embeddings, k, not_recommend, items_exclude, self$global_mean)

data.table::setattr(indices, "dimnames", list(uids, NULL))
data.table::setattr(indices, "ids", NULL)
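
A note on the new `global_mean` field: it is documented as the global mean used for centering explicit-feedback values, and the commit threads it through to `find_top_product()` so the mean can be added back at scoring time. A minimal sketch of the idea with made-up numbers (illustrative only, not code from this commit):

    r <- c(4, 5, 3, 1)               # observed ratings
    global_mean <- mean(r)           # what the model keeps in `global_mean`
    r_centered <- r - global_mean    # the factors approximate these residuals
    # At prediction time the mean is added back onto every user-item dot
    # product, so scores land back on the original rating scale:
    residual_hat <- 0.2              # hypothetical u . v for one user-item pair
    r_hat <- residual_hat + global_mean
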
16 changes: 14 additions & 2 deletions R/RcppExports.R
@@ -89,6 +89,18 @@ als_explicit_float <- function(m_csc_r, X_, Y_, lambda, n_threads, solver, cg_st
.Call(`_rsparse_als_explicit_float`, m_csc_r, X_, Y_, lambda, n_threads, solver, cg_steps, with_biases, is_bias_last_row)
}

+ initialize_biases_double <- function(m_csc_r, m_csr_r, user_bias, item_bias, lambda, non_negative) {
+   invisible(.Call(`_rsparse_initialize_biases_double`, m_csc_r, m_csr_r, user_bias, item_bias, lambda, non_negative))
+ }
+
+ initialize_biases_float <- function(m_csc_r, m_csr_r, user_bias, item_bias, lambda, non_negative) {
+   invisible(.Call(`_rsparse_initialize_biases_float`, m_csc_r, m_csr_r, user_bias, item_bias, lambda, non_negative))
+ }
+
+ deep_copy <- function(x) {
+   .Call(`_rsparse_deep_copy`, x)
+ }

rankmf_solver_double <- function(x_r, W, H, W2_grad, H2_grad, user_features_r, item_features_r, rank, n_updates, learning_rate = 0.01, gamma = 1, lambda_user = 0.0, lambda_item_positive = 0.0, lambda_item_negative = 0.0, n_threads = 1L, update_items = TRUE, loss = 0L, kernel = 0L, max_negative_samples = 50L, margin = 0.1, optimizer = 0L, report_progress = 10L) {
invisible(.Call(`_rsparse_rankmf_solver_double`, x_r, W, H, W2_grad, H2_grad, user_features_r, item_features_r, rank, n_updates, learning_rate, gamma, lambda_user, lambda_item_positive, lambda_item_negative, n_threads, update_items, loss, kernel, max_negative_samples, margin, optimizer, report_progress))
}
@@ -97,8 +109,8 @@ rankmf_solver_float <- function(x_r, W, H, W2_grad, H2_grad, user_features_r, it
invisible(.Call(`_rsparse_rankmf_solver_float`, x_r, W, H, W2_grad, H2_grad, user_features_r, item_features_r, rank, n_updates, learning_rate, gamma, lambda_user, lambda_item_positive, lambda_item_negative, n_threads, update_items, loss, kernel, max_negative_samples, margin, optimizer, report_progress))
}

- top_product <- function(x, y, k, n_threads, not_recommend_r, exclude) {
-   .Call(`_rsparse_top_product`, x, y, k, n_threads, not_recommend_r, exclude)
+ top_product <- function(x, y, k, n_threads, not_recommend_r, exclude, glob_mean = 0.) {
+   .Call(`_rsparse_top_product`, x, y, k, n_threads, not_recommend_r, exclude, glob_mean)
}

arma_kmeans <- function(x, k, seed_mode, n_iter, verbose, result) {
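
A note on `deep_copy`: judging by the commented-out `c_ui_orig = deep_copy(c_ui@x)` usage in model_WRMF.R below, it exists to take a genuine copy of a sparse matrix's `@x` slot before C++ code mutates those values in place; in-place writes from C++ bypass R's copy-on-modify, so a plain R assignment is not a safe backup. A hedged sketch:

    library(Matrix)
    m <- sparseMatrix(i = 1:3, j = 1:3, x = c(4, 5, 3))
    x_orig <- deep_copy(m@x)  # forces a real copy of the numeric values
    # `x_orig <- m@x` alone would share the same memory until R itself decided
    # to copy; a C++ routine writing into m@x directly never triggers that, so
    # the "backup" would be corrupted along with m@x. With deep_copy, x_orig
    # still holds 4, 5, 3 after an in-place solver call.
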
74 changes: 62 additions & 12 deletions R/model_WRMF.R
@@ -1,7 +1,7 @@
#' @title Weighted Regularized Matrix Factorization for collaborative filtering
#' @description Creates a matrix factorization model which is solved through Alternating Least Squares (Weighted ALS for implicit feedback).
#' For implicit feedback see "Collaborative Filtering for Implicit Feedback Datasets" (Hu, Koren, Volinsky).
- #' For explicit feedback it corresponds to the classic model for rating matrix decomposition with MSE error (without biases at the moment).
+ #' For explicit feedback it corresponds to the classic model for rating matrix decomposition with MSE error.
#' These two algorithms are proven to work well in recommender systems.
#' @references
#' \itemize{
@@ -59,7 +59,7 @@ WRMF = R6::R6Class(
#' @param cg_steps \code{integer > 0} - max number of internal steps in conjugate gradient
#' (if "conjugate_gradient" solver used). \code{cg_steps = 3} by default.
#' Controls precision of linear equation solution at the each ALS step. Usually no need to tune this parameter
- #' @param precision one of \code{c("double", "float")}. Should embeeding matrices be
+ #' @param precision one of \code{c("double", "float")}. Should embedding matrices be
#' numeric or float (from \code{float} package). The latter is usually 2x faster and
#' consumes less RAM. BUT \code{float} matrices are not "base" objects. Use carefully.
#' @param ... not used at the moment
@@ -90,9 +90,6 @@
solver_codes = c("cholesky", "conjugate_gradient", "nnls")
private$solver_code = match(solver, solver_codes) - 1L

if (feedback == "explicit" && precision == "float")
stop("Explicit solver doesn't support single precision at the moment (but in principle can support).")

private$precision = match.arg(precision)
private$feedback = feedback
private$lambda = as.numeric(lambda)
@@ -146,7 +143,7 @@ WRMF = R6::R6Class(
#' @param n_iter max number of ALS iterations
#' @param convergence_tol convergence tolerance checked between iterations
#' @param ... not used at the moment
- fit_transform = function(x, n_iter = 10L, convergence_tol = 0.005, ...) {
+ fit_transform = function(x, n_iter = 10L, convergence_tol = ifelse(private$feedback == "implicit", 0.005, 0.001), ...) {
if (private$feedback == "implicit" ) {
logger$trace("WRMF$fit_transform(): calling `RhpcBLASctl::blas_set_num_threads(1)` (to avoid thread contention)")
blas_threads_keep = RhpcBLASctl::blas_get_num_procs()
@@ -159,13 +156,31 @@

c_ui = as(x, "CsparseMatrix")
c_ui = private$preprocess(c_ui)
- c_iu = t(c_ui)
# store item_ids in order to use them in predict method
private$item_ids = colnames(c_ui)

+ if ((private$feedback != "explicit") || private$non_negative) {
+   stopifnot(all(c_ui@x >= 0))
+ }
+ c_iu = t(c_ui)

# if (private$feedback == "explicit" && !private$non_negative) {
# self$global_mean = mean(c_ui@x)
# c_ui@x = c_ui@x - self$global_mean
# }
# if (private$with_bias) {
# c_ui@x = deep_copy(c_ui@x)
# c_ui_orig = deep_copy(c_ui@x)
# }
# else {
# c_ui_orig = numeric(0L)
# }

# if (private$with_bias) {
# c_iu_orig = deep_copy(c_iu@x)
# } else {
# c_iu_orig = numeric(0L)
# }

# init
n_user = nrow(c_ui)
@@ -213,22 +228,45 @@
}

# NNLS
- if (private$solver_code == 2L) {
+ if (private$non_negative) {
self$components = abs(self$components)
private$U = abs(private$U)
}

+ stopifnot(ncol(private$U) == ncol(c_iu))
+ stopifnot(ncol(self$components) == ncol(c_ui))

- logger$info("starting factorization")
+ # if (private$with_bias) {
+ # logger$debug("initializing biases")
+ # if (private$precision == "double") {
+ # user_bias = numeric(n_user)
+ # item_bias = numeric(n_item)
+ # initialize_biases_double(c_ui, c_iu,
+ # user_bias,
+ # item_bias,
+ # private$lambda,
+ # private$non_negative)
+ # } else {
+ # user_bias = float(n_user)
+ # item_bias = float(n_item)
+ # initialize_biases_float(c_ui, c_iu,
+ # user_bias,
+ # item_bias,
+ # private$lambda,
+ # private$non_negative)
+ # }
+ # self$components[1L, ] = item_bias
+ # private$U[private$rank, ] = user_bias
+ # }

logger$info("starting factorization with %d threads", getOption("rsparse_omp_threads", 1L))

loss_prev_iter = Inf

# iterate
for (i in seq_len(n_iter)) {
# solve for items
loss = private$solver(c_ui, private$U, self$components, TRUE)

# solve for users
loss = private$solver(c_iu, self$components, private$U, FALSE)

@@ -287,6 +325,19 @@
}

loss = private$solver(t(x), self$components, res, FALSE, private$XtX)
# if (private$feedback == "implicit") {
# loss = private$solver(t(x), self$components, res, FALSE, private$XtX)
# } else{
# x_use = t(x)
# if (!private$non_negative)
# x_use@x = x_use@x - self$global_mean
# if (private$with_bias) {
# x_orig = deep_copy(x_use@x)
# } else {
# x_orig = numeric(0L)
# }
# loss = private$solver(x_use, self$components, res, FALSE)
# }
res = t(res)

if (private$precision == "double")
@@ -297,6 +348,7 @@
res
}
),
+ #### private -----
private = list(
solver_code = NULL,
cg_steps = NULL,
@@ -313,8 +365,6 @@
# this is essentially "confidence" transformation from WRMF article
preprocess = NULL,
feedback = NULL,
- cv_data = NULL,
- scorers_ellipsis = NULL,
precision = NULL,
XtX = NULL,
solver = NULL,
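
A note on the commented-out bias code above: it follows the standard trick of dedicating factor rows to biases. The initializer would put item biases in `components[1L, ]` and user biases in `U[private$rank, ]`, and pairing each bias row with a row of ones in the other matrix makes a plain cross-product contain `b_user + b_item` in every score. A minimal sketch of that layout (the placement of the ones rows is my assumption; the `is_bias_last_row` flag of `als_explicit_*` presumably tells the solver which row stays pinned during a half-step):

    rank <- 8L; n_users <- 5L; n_items <- 7L
    U <- matrix(rnorm(rank * n_users), nrow = rank)  # users in columns
    V <- matrix(rnorm(rank * n_items), nrow = rank)  # items in columns
    V[1L, ]   <- runif(n_items)  # item biases in the first row
    U[1L, ]   <- 1               # matching row of ones in U
    U[rank, ] <- runif(n_users)  # user biases in the last row
    V[rank, ] <- 1               # matching row of ones in V
    scores <- crossprod(U, V)    # n_users x n_items; every entry already
                                 # includes b_user + b_item plus latent terms
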
4 changes: 2 additions & 2 deletions R/utils.R
@@ -28,7 +28,7 @@ train_test_split = function(x, test_proportion = 0.5) {
}


- find_top_product = function(x, y, k, not_recommend = NULL, exclude = integer(0), n_threads = getOption("rsparse_omp_threads", 1L)) {
+ find_top_product = function(x, y, k, not_recommend = NULL, exclude = integer(0), n_threads = getOption("rsparse_omp_threads", 1L), glob_mean = 0.) {
n_threads_blas = RhpcBLASctl::blas_get_num_procs()
# set num threads to 1 in order to avoid thread contention between BLAS and openmp threads in `top_product()`
RhpcBLASctl::blas_set_num_threads(1L)
@@ -49,5 +49,5 @@ find_top_product = function(x, y, k, not_recommend = NULL, exclude = integer(0),
stopifnot(ncol(y) == ncol(not_recommend))
not_recommend = as(not_recommend, "RsparseMatrix")
}
- top_product(x, y, k, n_threads, not_recommend, exclude)
+ top_product(x, y, k, n_threads, not_recommend, exclude, glob_mean)
}
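
A usage sketch for the extended helper (illustrative; `find_top_product()` is internal, and the shapes are inferred from the `stopifnot()` checks above, with users as rows of `x` and items as columns of `y`):

    u_emb <- matrix(rnorm(2 * 10), nrow = 2)  # 2 users x rank 10
    i_emb <- matrix(rnorm(10 * 5), ncol = 5)  # rank 10 x 5 items
    top3 <- find_top_product(u_emb, i_emb, k = 3, glob_mean = 3.5)
    # Scores are u_emb %*% i_emb + glob_mean. The uniform shift leaves the
    # ranking itself unchanged; it matters wherever the scores are read back
    # as predicted ratings on the original scale.
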
2 changes: 2 additions & 0 deletions man/MatrixFactorizationRecommender.Rd

(generated file; diff not rendered)

11 changes: 8 additions & 3 deletions man/WRMF.Rd

(generated file; diff not rendered)

53 changes: 49 additions & 4 deletions src/RcppExports.cpp
@@ -321,6 +321,47 @@
return rcpp_result_gen;
END_RCPP
}
+ // initialize_biases_double
+ void initialize_biases_double(const Rcpp::S4& m_csc_r, const Rcpp::S4& m_csr_r, arma::Col<double>& user_bias, arma::Col<double>& item_bias, double lambda, bool non_negative);
+ RcppExport SEXP _rsparse_initialize_biases_double(SEXP m_csc_rSEXP, SEXP m_csr_rSEXP, SEXP user_biasSEXP, SEXP item_biasSEXP, SEXP lambdaSEXP, SEXP non_negativeSEXP) {
+ BEGIN_RCPP
+ Rcpp::RNGScope rcpp_rngScope_gen;
+ Rcpp::traits::input_parameter< const Rcpp::S4& >::type m_csc_r(m_csc_rSEXP);
+ Rcpp::traits::input_parameter< const Rcpp::S4& >::type m_csr_r(m_csr_rSEXP);
+ Rcpp::traits::input_parameter< arma::Col<double>& >::type user_bias(user_biasSEXP);
+ Rcpp::traits::input_parameter< arma::Col<double>& >::type item_bias(item_biasSEXP);
+ Rcpp::traits::input_parameter< double >::type lambda(lambdaSEXP);
+ Rcpp::traits::input_parameter< bool >::type non_negative(non_negativeSEXP);
+ initialize_biases_double(m_csc_r, m_csr_r, user_bias, item_bias, lambda, non_negative);
+ return R_NilValue;
+ END_RCPP
+ }
+ // initialize_biases_float
+ void initialize_biases_float(const Rcpp::S4& m_csc_r, const Rcpp::S4& m_csr_r, Rcpp::S4& user_bias, Rcpp::S4& item_bias, double lambda, bool non_negative);
+ RcppExport SEXP _rsparse_initialize_biases_float(SEXP m_csc_rSEXP, SEXP m_csr_rSEXP, SEXP user_biasSEXP, SEXP item_biasSEXP, SEXP lambdaSEXP, SEXP non_negativeSEXP) {
+ BEGIN_RCPP
+ Rcpp::RNGScope rcpp_rngScope_gen;
+ Rcpp::traits::input_parameter< const Rcpp::S4& >::type m_csc_r(m_csc_rSEXP);
+ Rcpp::traits::input_parameter< const Rcpp::S4& >::type m_csr_r(m_csr_rSEXP);
+ Rcpp::traits::input_parameter< Rcpp::S4& >::type user_bias(user_biasSEXP);
+ Rcpp::traits::input_parameter< Rcpp::S4& >::type item_bias(item_biasSEXP);
+ Rcpp::traits::input_parameter< double >::type lambda(lambdaSEXP);
+ Rcpp::traits::input_parameter< bool >::type non_negative(non_negativeSEXP);
+ initialize_biases_float(m_csc_r, m_csr_r, user_bias, item_bias, lambda, non_negative);
+ return R_NilValue;
+ END_RCPP
+ }
+ // deep_copy
+ SEXP deep_copy(SEXP x);
+ RcppExport SEXP _rsparse_deep_copy(SEXP xSEXP) {
+ BEGIN_RCPP
+ Rcpp::RObject rcpp_result_gen;
+ Rcpp::RNGScope rcpp_rngScope_gen;
+ Rcpp::traits::input_parameter< SEXP >::type x(xSEXP);
+ rcpp_result_gen = Rcpp::wrap(deep_copy(x));
+ return rcpp_result_gen;
+ END_RCPP
+ }
// rankmf_solver_double
void rankmf_solver_double(const Rcpp::S4& x_r, arma::Mat<double>& W, arma::Mat<double>& H, arma::Col<double>& W2_grad, arma::Col<double>& H2_grad, const Rcpp::S4& user_features_r, const Rcpp::S4& item_features_r, const arma::uword rank, const arma::uword n_updates, double learning_rate, double gamma, double lambda_user, double lambda_item_positive, double lambda_item_negative, const arma::uword n_threads, bool update_items, const arma::uword loss, const arma::uword kernel, arma::uword max_negative_samples, double margin, const arma::uword optimizer, const arma::uword report_progress);
RcppExport SEXP _rsparse_rankmf_solver_double(SEXP x_rSEXP, SEXP WSEXP, SEXP HSEXP, SEXP W2_gradSEXP, SEXP H2_gradSEXP, SEXP user_features_rSEXP, SEXP item_features_rSEXP, SEXP rankSEXP, SEXP n_updatesSEXP, SEXP learning_rateSEXP, SEXP gammaSEXP, SEXP lambda_userSEXP, SEXP lambda_item_positiveSEXP, SEXP lambda_item_negativeSEXP, SEXP n_threadsSEXP, SEXP update_itemsSEXP, SEXP lossSEXP, SEXP kernelSEXP, SEXP max_negative_samplesSEXP, SEXP marginSEXP, SEXP optimizerSEXP, SEXP report_progressSEXP) {
@@ -384,8 +425,8 @@
END_RCPP
}
// top_product
- Rcpp::IntegerMatrix top_product(const arma::mat& x, const arma::mat& y, unsigned k, unsigned n_threads, const Rcpp::S4& not_recommend_r, const Rcpp::IntegerVector& exclude);
- RcppExport SEXP _rsparse_top_product(SEXP xSEXP, SEXP ySEXP, SEXP kSEXP, SEXP n_threadsSEXP, SEXP not_recommend_rSEXP, SEXP excludeSEXP) {
+ Rcpp::IntegerMatrix top_product(const arma::mat& x, const arma::mat& y, unsigned k, unsigned n_threads, const Rcpp::S4& not_recommend_r, const Rcpp::IntegerVector& exclude, const double glob_mean);
+ RcppExport SEXP _rsparse_top_product(SEXP xSEXP, SEXP ySEXP, SEXP kSEXP, SEXP n_threadsSEXP, SEXP not_recommend_rSEXP, SEXP excludeSEXP, SEXP glob_meanSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -395,7 +436,8 @@
Rcpp::traits::input_parameter< unsigned >::type n_threads(n_threadsSEXP);
Rcpp::traits::input_parameter< const Rcpp::S4& >::type not_recommend_r(not_recommend_rSEXP);
Rcpp::traits::input_parameter< const Rcpp::IntegerVector& >::type exclude(excludeSEXP);
- rcpp_result_gen = Rcpp::wrap(top_product(x, y, k, n_threads, not_recommend_r, exclude));
+ Rcpp::traits::input_parameter< const double >::type glob_mean(glob_meanSEXP);
+ rcpp_result_gen = Rcpp::wrap(top_product(x, y, k, n_threads, not_recommend_r, exclude, glob_mean));
return rcpp_result_gen;
END_RCPP
}
@@ -477,9 +519,12 @@ static const R_CallMethodDef CallEntries[] = {
{"_rsparse_als_implicit_float", (DL_FUNC) &_rsparse_als_implicit_float, 9},
{"_rsparse_als_explicit_double", (DL_FUNC) &_rsparse_als_explicit_double, 9},
{"_rsparse_als_explicit_float", (DL_FUNC) &_rsparse_als_explicit_float, 9},
{"_rsparse_initialize_biases_double", (DL_FUNC) &_rsparse_initialize_biases_double, 6},
{"_rsparse_initialize_biases_float", (DL_FUNC) &_rsparse_initialize_biases_float, 6},
{"_rsparse_deep_copy", (DL_FUNC) &_rsparse_deep_copy, 1},
{"_rsparse_rankmf_solver_double", (DL_FUNC) &_rsparse_rankmf_solver_double, 22},
{"_rsparse_rankmf_solver_float", (DL_FUNC) &_rsparse_rankmf_solver_float, 22},
{"_rsparse_top_product", (DL_FUNC) &_rsparse_top_product, 6},
{"_rsparse_top_product", (DL_FUNC) &_rsparse_top_product, 7},
{"_rsparse_arma_kmeans", (DL_FUNC) &_rsparse_arma_kmeans, 6},
{"_rsparse_c_nnls_double", (DL_FUNC) &_rsparse_c_nnls_double, 4},
{"_rsparse_omp_thread_count", (DL_FUNC) &_rsparse_omp_thread_count, 0},
