Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Block Krylov SVD method in CF #3413

Merged
merged 13 commits into from
Feb 17, 2023
1 change: 1 addition & 0 deletions doc/tutorials/cf.md
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ alternating least squares update rules). These include:
- `BiasSVDPolicy`
- `SVDPlusPlusPolicy`
- `RandomizedSVDPolicy`
- `BlockKrylovSVDPolicy`

The `AMF` class has many other possibilities than those listed here; it is a
framework for alternating matrix factorization techniques. See the `AMF` class
Expand Down
17 changes: 16 additions & 1 deletion src/mlpack/methods/cf/cf_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ BINDING_LONG_DESC(
" - 'SVDCompleteIncremental' -- SVD complete incremental learning\n"
" - 'BiasSVD' -- Bias SVD using a SGD optimizer\n"
" - 'SVDPP' -- SVD++ using a SGD optimizer\n"
" - 'RandSVD' -- RandomizedSVD learning\n"
" - 'QSVD' -- QuicSVD learning\n"
" - 'BKSVD' -- Block Krylov SVD learning\n"
"\n\n"
"The following neighbor search algorithms can be specified via" +
" the " + PRINT_PARAM_STRING("neighbor_search") + " parameter:"
Expand Down Expand Up @@ -196,7 +199,7 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)

RequireParamInSet<string>(params, "algorithm", { "NMF", "BatchSVD",
"SVDIncompleteIncremental", "SVDCompleteIncremental", "RegSVD",
"RandSVD", "BiasSVD", "SVDPP" }, true, "unknown algorithm");
"RandSVD", "BiasSVD", "SVDPP", "QSVD", "BKSVD" }, true, "unknown algorithm");

ReportIgnoredParam(params, {{ "iteration_only_termination", true }},
"min_residue");
Expand Down Expand Up @@ -282,6 +285,18 @@ void BINDING_FUNCTION(util::Params& params, util::Timers& timers)
"when max_iterations is reached");
cf->DecompositionType() = CFModel::SVD_PLUS_PLUS;
}
else if (algo == "QSVD")
{
ReportIgnoredParam(params, "min_residue", "QSVD terminates only "
"when max_iterations is reached");
cf->DecompositionType() = CFModel::QUIC_SVD;
}
else if (algo == "BKSVD")
{
ReportIgnoredParam(params, "min_residue", "BKSVD terminates only "
"when max_iterations is reached");
cf->DecompositionType() = CFModel::BLOCK_KRYLOV_SVD;
}

// Perform the factorization and do whatever the user wanted.
const size_t neighborhood = (size_t) params.Get<int>("neighborhood");
Expand Down
4 changes: 3 additions & 1 deletion src/mlpack/methods/cf/cf_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ class CFModel
SVD_COMPLETE,
SVD_INCOMPLETE,
BIAS_SVD,
SVD_PLUS_PLUS
SVD_PLUS_PLUS,
QUIC_SVD,
BLOCK_KRYLOV_SVD
};

enum NormalizationTypes
Expand Down
47 changes: 24 additions & 23 deletions src/mlpack/methods/cf/cf_model_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,6 @@

#include "cf_model.hpp"

#include "interpolation_policies/average_interpolation.hpp"
#include "interpolation_policies/regression_interpolation.hpp"
#include "interpolation_policies/similarity_interpolation.hpp"

#include "neighbor_search_policies/cosine_search.hpp"
#include "neighbor_search_policies/lmetric_search.hpp"
#include "neighbor_search_policies/pearson_search.hpp"

#include "decomposition_policies/batch_svd_method.hpp"
#include "decomposition_policies/bias_svd_method.hpp"
#include "decomposition_policies/nmf_method.hpp"
#include "decomposition_policies/randomized_svd_method.hpp"
#include "decomposition_policies/regularized_svd_method.hpp"
#include "decomposition_policies/svd_complete_method.hpp"
#include "decomposition_policies/svd_incomplete_method.hpp"
#include "decomposition_policies/svdplusplus_method.hpp"

#include "normalization/no_normalization.hpp"
#include "normalization/overall_mean_normalization.hpp"
#include "normalization/user_mean_normalization.hpp"
#include "normalization/item_mean_normalization.hpp"
#include "normalization/z_score_normalization.hpp"

namespace mlpack {

inline CFModel::CFModel() :
Expand Down Expand Up @@ -361,6 +338,12 @@ inline CFWrapperBase* InitializeModel(

case CFModel::SVD_PLUS_PLUS:
return InitializeModelHelper<SVDPlusPlusPolicy>(normalizationType);

case CFModel::QUIC_SVD:
return InitializeModelHelper<QUIC_SVDPolicy>(normalizationType);

case CFModel::BLOCK_KRYLOV_SVD:
return InitializeModelHelper<BlockKrylovSVDPolicy>(normalizationType);
}

// This shouldn't ever happen.
Expand Down Expand Up @@ -473,6 +456,16 @@ inline void CFModel::Train(
cf = TrainHelper(SVDPlusPlusPolicy(), normalizationType, data,
numUsersForSimilarity, rank, maxIterations, minResidue, mit);
break;

case QUIC_SVD:
cf = TrainHelper(QUIC_SVDPolicy(), normalizationType, data,
numUsersForSimilarity, rank, maxIterations, minResidue, mit);
break;

case BLOCK_KRYLOV_SVD:
cf = TrainHelper(BlockKrylovSVDPolicy(), normalizationType, data,
numUsersForSimilarity, rank, maxIterations, minResidue, mit);
break;
}
}

Expand Down Expand Up @@ -555,6 +548,14 @@ void CFModel::serialize(Archive& ar, const uint32_t /* version */)
case SVD_PLUS_PLUS:
SerializeHelper<SVDPlusPlusPolicy>(ar, cf, normalizationType);
break;

case QUIC_SVD:
SerializeHelper<QUIC_SVDPolicy>(ar, cf, normalizationType);
break;

case BLOCK_KRYLOV_SVD:
SerializeHelper<BlockKrylovSVDPolicy>(ar, cf, normalizationType);
break;
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/**
* @file methods/cf/decomposition_policies/block_krylov_svd_method.hpp
* @author Adarsh Santoria
*
* Implementation of the block Krylov SVD method for use in
* Collaborative Filtering.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#ifndef MLPACK_METHODS_CF_DECOMPOSITION_POLICIES_BLOCK_KRYLOV_SVD_METHOD_HPP
#define MLPACK_METHODS_CF_DECOMPOSITION_POLICIES_BLOCK_KRYLOV_SVD_METHOD_HPP

#include <mlpack/prereqs.hpp>
#include <mlpack/methods/block_krylov_svd/block_krylov_svd.hpp>

namespace mlpack {

/**
 * Decomposition policy that wraps the randomized block Krylov SVD so that it
 * can serve as the decomposition type of CF.
 *
 * A short usage example of BlockKrylovSVDPolicy with CF:
 *
 * @code
 * extern arma::mat data; // data is a (user, item, rating) table.
 * // Users for whom recommendations are generated.
 * extern arma::Col<size_t> users;
 * arma::Mat<size_t> recommendations; // Resulting recommendations.
 *
 * CFType<BlockKrylovSVDPolicy> cf(data);
 *
 * // Generate 10 recommendations for all users.
 * cf.GetRecommendations(10, recommendations);
 * @endcode
 */
class BlockKrylovSVDPolicy
{
 public:
  /**
   * Build a block Krylov SVD policy object for collaborative filtering.
   */
  BlockKrylovSVDPolicy()
  {
    // Intentionally empty: the policy has nothing to configure.
  }

  /**
   * Factorize the cleaned rating matrix with the block Krylov SVD and store
   * the resulting item matrix (w) and user matrix (h).
   *
   * Only cleanedData and rank are used; the remaining parameters are
   * accepted but ignored by this policy.
   *
   * @param * (data) Data matrix: dense matrix (coordinate lists)
   *     or sparse matrix (cleaned).  Unused.
   * @param cleanedData Item-user table in the form of a sparse matrix.
   * @param rank Rank parameter for the matrix factorization.
   * @param * (maxIterations) Maximum number of iterations.  Unused.
   * @param * (minResidue) Residue required to terminate.  Unused.
   * @param * (mit) Whether to terminate only when maxIterations is reached.
   *     Unused.
   */
  template<typename MatType>
  void Apply(const MatType& /* data */,
             const arma::sp_mat& cleanedData,
             const size_t rank,
             const size_t /* maxIterations */,
             const double /* minResidue */,
             const bool /* mit */)
  {
    // Build a dense copy of the preprocessed (sparse) data for the solver.
    arma::mat denseRatings(cleanedData);

    // Run the block Krylov SVD; the singular values are only needed locally.
    arma::vec singularValues;
    RandomizedBlockKrylovSVD solver;
    solver.Apply(denseRatings, w, singularValues, h, rank);

    // Fold the singular values into w.
    w = w * arma::diagmat(singularValues);

    // The CF class works with the transposed form of h.
    h = h.t();
  }

  /**
   * Compute the predicted rating for one (user, item) pair.
   *
   * @param user User ID.
   * @param item Item ID.
   * @return Product of row `item` of w with column `user` of h.
   */
  double GetRating(const size_t user, const size_t item) const
  {
    return arma::as_scalar(w.row(item) * h.col(user));
  }

  /**
   * Compute the predicted ratings of every item for a single user.
   *
   * @param user User ID.
   * @param rating Output vector, overwritten with the predicted ratings.
   */
  void GetRatingOfUser(const size_t user, arma::vec& rating) const
  {
    rating = w * h.col(user);
  }

  /**
   * Find the neighborhood, and the matching similarities, of a set of users.
   *
   * @tparam NeighborSearchPolicy The policy used to perform neighbor search.
   *
   * @param users Users whose neighborhood is to be computed.
   * @param numUsersForSimilarity The number of neighbors returned for
   *     each user.
   * @param neighborhood Neighbors represented by user IDs.
   * @param similarities Similarity between each user and each of its
   *     neighbors.
   */
  template<typename NeighborSearchPolicy>
  void GetNeighborhood(const arma::Col<size_t>& users,
                       const size_t numUsersForSimilarity,
                       arma::Mat<size_t>& neighborhood,
                       arma::mat& similarities) const
  {
    // Forming the full rating matrix X = W*H is avoided.  Since
    // d(X.col(i), X.col(j)) = d(W H.col(i), W H.col(j)), the search can be
    // viewed as a nearest neighbor search on H under the Mahalanobis distance
    // with M^{-1} = W^T W.  With the Cholesky decomposition M^{-1} = L L^T,
    // multiplying H by L^T lets us run an ordinary neighbor search instead.
    arma::mat cholFactor = arma::chol(arma::trans(w) * w);
    // Due to the Armadillo API, cholFactor is already L^T.
    arma::mat stretchedH = cholFactor * h;

    // Gather the stretched feature vectors of the queried users, one column
    // per requested user.
    arma::mat query(stretchedH.n_rows, users.n_elem);
    for (size_t col = 0; col < users.n_elem; ++col)
    {
      query.col(col) = stretchedH.col(users(col));
    }

    NeighborSearchPolicy neighborSearch(stretchedH);
    neighborSearch.Search(
        query, numUsersForSimilarity, neighborhood, similarities);
  }

  //! Read-only access to the item matrix.
  const arma::mat& W() const { return w; }
  //! Read-only access to the user matrix.
  const arma::mat& H() const { return h; }

  /**
   * Serialize the policy (the item and user matrices).
   */
  template<typename Archive>
  void serialize(Archive& ar, const uint32_t /* version */)
  {
    ar(CEREAL_NVP(w));
    ar(CEREAL_NVP(h));
  }

 private:
  //! Item matrix.
  arma::mat w;
  //! User matrix.
  arma::mat h;
};

} // namespace mlpack

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@
#include "svd_incomplete_method.hpp"
#include "svdplusplus_method.hpp"
#include "quic_svd_method.hpp"
#include "block_krylov_svd_method.hpp"

#endif
18 changes: 10 additions & 8 deletions src/mlpack/tests/cf_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ void Serialization()
TEMPLATE_TEST_CASE("CFGetRecommendationsAllUsersTest", "[CFTest]",
RandomizedSVDPolicy, RegSVDPolicy, BatchSVDPolicy, NMFPolicy,
SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, SVDPlusPlusPolicy,
QUIC_SVDPolicy)
QUIC_SVDPolicy, BlockKrylovSVDPolicy)
{
GetRecommendationsAllUsers<TestType>();
}
Expand All @@ -513,7 +513,7 @@ TEMPLATE_TEST_CASE("CFGetRecommendationsAllUsersTest", "[CFTest]",
TEMPLATE_TEST_CASE("CFGetRecommendationsQueriedUsersTest", "[CFTest]",
RandomizedSVDPolicy, RegSVDPolicy, BatchSVDPolicy, NMFPolicy,
SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, SVDPlusPlusPolicy,
QUIC_SVDPolicy)
QUIC_SVDPolicy, BlockKrylovSVDPolicy)
{
GetRecommendationsQueriedUser<TestType>();
}
Expand All @@ -524,7 +524,8 @@ TEMPLATE_TEST_CASE("CFGetRecommendationsQueriedUsersTest", "[CFTest]",
*/
TEMPLATE_TEST_CASE("RecommendationAccuracyTest", "[CFTest]",
RandomizedSVDPolicy, RegSVDPolicy, BatchSVDPolicy, NMFPolicy,
SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, QUIC_SVDPolicy)
SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, QUIC_SVDPolicy,
BlockKrylovSVDPolicy)
{
RecommendationAccuracy<TestType>();
}
Expand All @@ -546,7 +547,7 @@ TEMPLATE_TEST_CASE("RecommendationAccuracyTest", "[CFTest]",
TEMPLATE_TEST_CASE("CFPredictTest", "[CFTest]",
RandomizedSVDPolicy, RegSVDPolicy, BatchSVDPolicy, NMFPolicy,
SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, SVDPlusPlusPolicy,
QUIC_SVDPolicy)
QUIC_SVDPolicy, BlockKrylovSVDPolicy)
{
CFPredict<TestType>();
}
Expand All @@ -557,7 +558,7 @@ TEMPLATE_TEST_CASE("CFPredictTest", "[CFTest]",
TEMPLATE_TEST_CASE("CFBatchPredictTest", "[CFTest]",
RandomizedSVDPolicy, RegSVDPolicy, BatchSVDPolicy, NMFPolicy,
SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, SVDPlusPlusPolicy,
QUIC_SVDPolicy)
QUIC_SVDPolicy, BlockKrylovSVDPolicy)
{
BatchPredict<TestType>();
}
Expand All @@ -568,7 +569,7 @@ TEMPLATE_TEST_CASE("CFBatchPredictTest", "[CFTest]",
*/
TEMPLATE_TEST_CASE("TrainTest_1", "[CFTest]",
RandomizedSVDPolicy, BatchSVDPolicy, NMFPolicy, SVDCompletePolicy,
SVDIncompletePolicy, QUIC_SVDPolicy)
SVDIncompletePolicy, QUIC_SVDPolicy, BlockKrylovSVDPolicy)
{
TestType decomposition;
Train(decomposition);
Expand All @@ -591,7 +592,8 @@ TEMPLATE_TEST_CASE("TrainTest_2", "[CFTest]",
*/
TEMPLATE_TEST_CASE("EmptyConstructorTrainTest", "[CFTest]",
RandomizedSVDPolicy, RegSVDPolicy, BatchSVDPolicy, NMFPolicy,
SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, QUIC_SVDPolicy)
SVDCompletePolicy, SVDIncompletePolicy, BiasSVDPolicy, QUIC_SVDPolicy,
BlockKrylovSVDPolicy)
{
EmptyConstructorTrain<TestType>();
}
Expand All @@ -601,7 +603,7 @@ TEMPLATE_TEST_CASE("EmptyConstructorTrainTest", "[CFTest]",
*/
TEMPLATE_TEST_CASE("SerializationTest", "[CFTest]",
RandomizedSVDPolicy, BatchSVDPolicy, NMFPolicy, SVDCompletePolicy,
SVDIncompletePolicy, QUIC_SVDPolicy)
SVDIncompletePolicy, QUIC_SVDPolicy, BlockKrylovSVDPolicy)
{
Serialization<TestType>();
}
Expand Down