Skip to content

Commit

Permalink
Merge pull request #3413 from AdarshSantoria/bksvd
Browse files Browse the repository at this point in the history
Add Block Krylov SVD method in CF
  • Loading branch information
rcurtin committed Feb 17, 2023
2 parents d6657f1 + 72bc449 commit 7f60e29
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 33 deletions.
1 change: 1 addition & 0 deletions doc/tutorials/cf.md
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ alternating least squares update rules). These include:
- `BiasSVDPolicy`
- `SVDPlusPlusPolicy`
- `RandomizedSVDPolicy`
- `BlockKrylovSVDPolicy`
The `AMF` class has many other possibilities than those listed here; it is a
framework for alternating matrix factorization techniques. See the `AMF` class
Expand Down
17 changes: 16 additions & 1 deletion src/mlpack/methods/cf/cf_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ BINDING_LONG_DESC(
" - 'SVDCompleteIncremental' -- SVD complete incremental learning\n"
" - 'BiasSVD' -- Bias SVD using a SGD optimizer\n"
" - 'SVDPP' -- SVD++ using a SGD optimizer\n"
" - 'RandSVD' -- RandomizedSVD learning\n"
" - 'QSVD' -- QuicSVD learning\n"
" - 'BKSVD' -- Block Krylov SVD learning\n"
"\n\n"
"The following neighbor search algorithms can be specified via" +
" the " + PRINT_PARAM_STRING("neighbor_search") + " parameter:"
Expand Down Expand Up @@ -196,7 +199,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)

RequireParamInSet<string>(params, "algorithm", { "NMF", "BatchSVD",
"SVDIncompleteIncremental", "SVDCompleteIncremental", "RegSVD",
"RandSVD", "BiasSVD", "SVDPP" }, true, "unknown algorithm");
"RandSVD", "BiasSVD", "SVDPP", "QSVD", "BKSVD" }, true, "unknown algorithm");

ReportIgnoredParam(params, {{ "iteration_only_termination", true }},
"min_residue");
Expand Down Expand Up @@ -282,6 +285,18 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
"when max_iterations is reached");
cf->DecompositionType() = CFModel::SVD_PLUS_PLUS;
}
else if (algo == "QSVD")
{
ReportIgnoredParam(params, "min_residue", "QSVD terminates only "
"when max_iterations is reached");
cf->DecompositionType() = CFModel::QUIC_SVD;
}
else if (algo == "BKSVD")
{
ReportIgnoredParam(params, "min_residue", "BKSVD terminates only "
"when max_iterations is reached");
cf->DecompositionType() = CFModel::BLOCK_KRYLOV_SVD;
}

// Perform the factorization and do whatever the user wanted.
const size_t neighborhood = (size_t) params.Get<int>("neighborhood");
Expand Down
4 changes: 3 additions & 1 deletion src/mlpack/methods/cf/cf_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ class CFModel
SVD_COMPLETE,
SVD_INCOMPLETE,
BIAS_SVD,
SVD_PLUS_PLUS
SVD_PLUS_PLUS,
QUIC_SVD,
BLOCK_KRYLOV_SVD
};

enum NormalizationTypes
Expand Down
47 changes: 24 additions & 23 deletions src/mlpack/methods/cf/cf_model_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,6 @@

#include "cf_model.hpp"

#include "interpolation_policies/average_interpolation.hpp"
#include "interpolation_policies/regression_interpolation.hpp"
#include "interpolation_policies/similarity_interpolation.hpp"

#include "neighbor_search_policies/cosine_search.hpp"
#include "neighbor_search_policies/lmetric_search.hpp"
#include "neighbor_search_policies/pearson_search.hpp"

#include "decomposition_policies/batch_svd_method.hpp"
#include "decomposition_policies/bias_svd_method.hpp"
#include "decomposition_policies/nmf_method.hpp"
#include "decomposition_policies/randomized_svd_method.hpp"
#include "decomposition_policies/regularized_svd_method.hpp"
#include "decomposition_policies/svd_complete_method.hpp"
#include "decomposition_policies/svd_incomplete_method.hpp"
#include "decomposition_policies/svdplusplus_method.hpp"

#include "normalization/no_normalization.hpp"
#include "normalization/overall_mean_normalization.hpp"
#include "normalization/user_mean_normalization.hpp"
#include "normalization/item_mean_normalization.hpp"
#include "normalization/z_score_normalization.hpp"

namespace mlpack {

inline CFModel::CFModel() :
Expand Down Expand Up @@ -361,6 +338,12 @@ inline CFWrapperBase* InitializeModel(

case CFModel::SVD_PLUS_PLUS:
return InitializeModelHelper<SVDPlusPlusPolicy>(normalizationType);

case CFModel::QUIC_SVD:
return InitializeModelHelper<QUIC_SVDPolicy>(normalizationType);

case CFModel::BLOCK_KRYLOV_SVD:
return InitializeModelHelper<BlockKrylovSVDPolicy>(normalizationType);
}

// This shouldn't ever happen.
Expand Down Expand Up @@ -473,6 +456,16 @@ inline void CFModel::Train(
cf = TrainHelper(SVDPlusPlusPolicy(), normalizationType, data,
numUsersForSimilarity, rank, maxIterations, minResidue, mit);
break;

case QUIC_SVD:
cf = TrainHelper(QUIC_SVDPolicy(), normalizationType, data,
numUsersForSimilarity, rank, maxIterations, minResidue, mit);
break;

case BLOCK_KRYLOV_SVD:
cf = TrainHelper(BlockKrylovSVDPolicy(), normalizationType, data,
numUsersForSimilarity, rank, maxIterations, minResidue, mit);
break;
}
}

Expand Down Expand Up @@ -555,6 +548,14 @@ void CFModel::serialize(Archive& ar, const uint32_t /* version */)
case SVD_PLUS_PLUS:
SerializeHelper<SVDPlusPlusPolicy>(ar, cf, normalizationType);
break;

case QUIC_SVD:
SerializeHelper<QUIC_SVDPolicy>(ar, cf, normalizationType);
break;

case BLOCK_KRYLOV_SVD:
SerializeHelper<BlockKrylovSVDPolicy>(ar, cf, normalizationType);
break;
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/**
* @file methods/cf/decomposition_policies/block_krylov_svd_method.hpp
* @author Adarsh Santoria
*
* Implementation of the Block Krylov SVD method for use in
* Collaborative Filtering.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#ifndef MLPACK_METHODS_CF_DECOMPOSITION_POLICIES_BLOCK_KRYLOV_SVD_METHOD_HPP
#define MLPACK_METHODS_CF_DECOMPOSITION_POLICIES_BLOCK_KRYLOV_SVD_METHOD_HPP

#include <mlpack/prereqs.hpp>
#include <mlpack/methods/block_krylov_svd/block_krylov_svd.hpp>

namespace mlpack {

/**
 * Decomposition policy that wraps the randomized Block Krylov SVD so that it
 * can be plugged into CF as the matrix factorization step.
 *
 * After Apply() the policy holds two factors, W (item matrix) and H (user
 * matrix), and a rating is predicted as the product of a row of W with a
 * column of H.
 *
 * A short example of using BlockKrylovSVDPolicy with CF:
 *
 * @code
 * extern arma::mat data; // data is a (user, item, rating) table.
 * // Users for whom recommendations are generated.
 * extern arma::Col<size_t> users;
 * arma::Mat<size_t> recommendations; // Resulting recommendations.
 *
 * CFType<BlockKrylovSVDPolicy> cf(data);
 *
 * // Generate 10 recommendations for all users.
 * cf.GetRecommendations(10, recommendations);
 * @endcode
 */
class BlockKrylovSVDPolicy
{
 public:
  /**
   * Construct an empty policy; W and H stay empty until Apply() is called.
   */
  BlockKrylovSVDPolicy()
  {
    // Intentionally empty.
  }

  /**
   * Factorize the cleaned rating matrix with the Block Krylov SVD and store
   * the resulting factors in W and H.
   *
   * Only cleanedData and rank are used; the remaining parameters exist to
   * satisfy the decomposition policy interface and are ignored.
   *
   * @tparam MatType Type of the (unused) raw data matrix.
   * @param data Raw data matrix (ignored by this policy).
   * @param cleanedData Item-user table in the form of a sparse matrix.
   * @param rank Rank parameter for the matrix factorization.
   * @param maxIterations Maximum number of iterations (ignored).
   * @param minResidue Residue required to terminate (ignored).
   * @param mit Whether to terminate only when maxIterations is reached
   *     (ignored).
   */
  template<typename MatType>
  void Apply(const MatType& /* data */,
             const arma::sp_mat& cleanedData,
             const size_t rank,
             const size_t /* maxIterations */,
             const double /* minResidue */,
             const bool /* mit */)
  {
    // The solver is given a dense matrix, so densify the cleaned ratings.
    arma::mat denseRatings(cleanedData);

    // Singular values produced by the decomposition.
    arma::vec singularValues;

    // Run the Block Krylov SVD; w and h receive the two factor matrices.
    RandomizedBlockKrylovSVD solver;
    solver.Apply(denseRatings, w, singularValues, h, rank);

    // Fold the singular values into the item factor.
    w = w * arma::diagmat(singularValues);

    // CF reads users as columns of H, so store the transposed factor.
    h = h.t();
  }

  /**
   * Predict the rating a given user would assign to a given item.
   *
   * @param user User ID (column of H).
   * @param item Item ID (row of W).
   * @return The predicted rating.
   */
  double GetRating(const size_t user, const size_t item) const
  {
    return arma::as_scalar(w.row(item) * h.col(user));
  }

  /**
   * Predict the ratings of every item for a single user.
   *
   * @param user User ID (column of H).
   * @param rating Output vector, overwritten with the predicted ratings.
   */
  void GetRatingOfUser(const size_t user, arma::vec& rating) const
  {
    rating = w * h.col(user);
  }

  /**
   * Compute the neighborhood and the matching similarities for a set of
   * users.
   *
   * @tparam NeighborSearchPolicy Policy used to perform the neighbor search.
   *
   * @param users Users whose neighborhoods are requested.
   * @param numUsersForSimilarity Number of neighbors returned for each user.
   * @param neighborhood Output: neighbors, represented by user IDs.
   * @param similarities Output: similarity between each user and each of its
   *     neighbors.
   */
  template<typename NeighborSearchPolicy>
  void GetNeighborhood(const arma::Col<size_t>& users,
                       const size_t numUsersForSimilarity,
                       arma::Mat<size_t>& neighborhood,
                       arma::mat& similarities) const
  {
    // Forming the full rating matrix X = W * H is what we want to avoid.
    // Since d(X.col(i), X.col(j)) = d(W H.col(i), W H.col(j)), the search can
    // be seen as a Mahalanobis-distance search over H with M^{-1} = W^T W.
    // Writing M^{-1} = L L^T (Cholesky), it is enough to multiply H by L^T
    // and run an ordinary neighbor search on the result.
    const arma::mat gram = w.t() * w;
    // Due to the Armadillo API, the factor returned here is already L^T.
    arma::mat cholFactor = arma::chol(gram);
    arma::mat projectedH = cholFactor * h;

    // Gather the projected feature vectors of the queried users.
    arma::mat queryPoints(projectedH.n_rows, users.n_elem);
    for (size_t q = 0; q < users.n_elem; ++q)
      queryPoints.col(q) = projectedH.col(users(q));

    // Search all users (reference set) for neighbors of the queried users.
    NeighborSearchPolicy searcher(projectedH);
    searcher.Search(
        queryPoints, numUsersForSimilarity, neighborhood, similarities);
  }

  //! Read-only access to the item matrix W.
  const arma::mat& W() const { return w; }
  //! Read-only access to the user matrix H.
  const arma::mat& H() const { return h; }

  /**
   * Serialize the policy: the archive handles W first and then H.
   */
  template<typename Archive>
  void serialize(Archive& ar, const uint32_t /* version */)
  {
    ar(CEREAL_NVP(w));
    ar(CEREAL_NVP(h));
  }

 private:
  //! Item matrix.
  arma::mat w;
  //! User matrix.
  arma::mat h;
};

} // namespace mlpack

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@
#include "svd_incomplete_method.hpp"
#include "svdplusplus_method.hpp"
#include "quic_svd_method.hpp"
#include "block_krylov_svd_method.hpp"

#endif
18 changes: 10 additions & 8 deletions src/mlpack/tests/cf_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ void Serialization()
TEMPLATE_TEST_CASE("CFGetRecommendationsAllUsersTest", "[CFTest]",
RandomizedSVDPolicy, RegSVDPolicy, BatchSVDPolicy, NMFPolicy,
SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, SVDPlusPlusPolicy,
QUIC_SVDPolicy)
QUIC_SVDPolicy, BlockKrylovSVDPolicy)
{
GetRecommendationsAllUsers<TestType>();
}
Expand All @@ -513,7 +513,7 @@ TEMPLATE_TEST_CASE("CFGetRecommendationsAllUsersTest", "[CFTest]",
TEMPLATE_TEST_CASE("CFGetRecommendationsQueriedUsersTest", "[CFTest]",
RandomizedSVDPolicy, RegSVDPolicy, BatchSVDPolicy, NMFPolicy,
SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, SVDPlusPlusPolicy,
QUIC_SVDPolicy)
QUIC_SVDPolicy, BlockKrylovSVDPolicy)
{
GetRecommendationsQueriedUser<TestType>();
}
Expand All @@ -524,7 +524,8 @@ TEMPLATE_TEST_CASE("CFGetRecommendationsQueriedUsersTest", "[CFTest]",
*/
TEMPLATE_TEST_CASE("RecommendationAccuracyTest", "[CFTest]",
RandomizedSVDPolicy, RegSVDPolicy, BatchSVDPolicy, NMFPolicy,
SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, QUIC_SVDPolicy)
SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, QUIC_SVDPolicy,
BlockKrylovSVDPolicy)
{
RecommendationAccuracy<TestType>();
}
Expand All @@ -546,7 +547,7 @@ TEMPLATE_TEST_CASE("RecommendationAccuracyTest", "[CFTest]",
TEMPLATE_TEST_CASE("CFPredictTest", "[CFTest]",
RandomizedSVDPolicy, RegSVDPolicy, BatchSVDPolicy, NMFPolicy,
SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, SVDPlusPlusPolicy,
QUIC_SVDPolicy)
QUIC_SVDPolicy, BlockKrylovSVDPolicy)
{
CFPredict<TestType>();
}
Expand All @@ -557,7 +558,7 @@ TEMPLATE_TEST_CASE("CFPredictTest", "[CFTest]",
TEMPLATE_TEST_CASE("CFBatchPredictTest", "[CFTest]",
RandomizedSVDPolicy, RegSVDPolicy, BatchSVDPolicy, NMFPolicy,
SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, SVDPlusPlusPolicy,
QUIC_SVDPolicy)
QUIC_SVDPolicy, BlockKrylovSVDPolicy)
{
BatchPredict<TestType>();
}
Expand All @@ -568,7 +569,7 @@ TEMPLATE_TEST_CASE("CFBatchPredictTest", "[CFTest]",
*/
TEMPLATE_TEST_CASE("TrainTest_1", "[CFTest]",
RandomizedSVDPolicy, BatchSVDPolicy, NMFPolicy, SVDCompletePolicy,
SVDIncompletePolicy, QUIC_SVDPolicy)
SVDIncompletePolicy, QUIC_SVDPolicy, BlockKrylovSVDPolicy)
{
TestType decomposition;
Train(decomposition);
Expand All @@ -591,7 +592,8 @@ TEMPLATE_TEST_CASE("TrainTest_2", "[CFTest]",
*/
TEMPLATE_TEST_CASE("EmptyConstructorTrainTest", "[CFTest]",
RandomizedSVDPolicy, RegSVDPolicy, BatchSVDPolicy, NMFPolicy,
SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, QUIC_SVDPolicy)
SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, QUIC_SVDPolicy,
BlockKrylovSVDPolicy)
{
EmptyConstructorTrain<TestType>();
}
Expand All @@ -601,7 +603,7 @@ TEMPLATE_TEST_CASE("EmptyConstructorTrainTest", "[CFTest]",
*/
TEMPLATE_TEST_CASE("SerializationTest", "[CFTest]",
RandomizedSVDPolicy, BatchSVDPolicy, NMFPolicy, SVDCompletePolicy,
SVDIncompletePolicy, QUIC_SVDPolicy)
SVDIncompletePolicy, QUIC_SVDPolicy, BlockKrylovSVDPolicy)
{
Serialization<TestType>();
}
Expand Down

0 comments on commit 7f60e29

Please sign in to comment.