Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Quic_SVD method in CF #3404

Merged
merged 26 commits into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b99e7cf
add quic_svd for cf
AdarshSantoria Feb 9, 2023
a57b561
fixing comments
AdarshSantoria Feb 9, 2023
0efb5ac
Update src/mlpack/methods/cf/decomposition_policies/quic_svd_method.hpp
AdarshSantoria Feb 9, 2023
c223e64
Update src/mlpack/methods/quic_svd/quic_svd.hpp
AdarshSantoria Feb 9, 2023
4d8f000
fix style
AdarshSantoria Feb 9, 2023
6c51f3d
Added preprocessing part and tests
AdarshSantoria Feb 12, 2023
3315aff
Add comment
AdarshSantoria Feb 12, 2023
b32355d
Update HISTORY.md
AdarshSantoria Feb 12, 2023
2b7bb73
fix typo
AdarshSantoria Feb 12, 2023
6309e47
Revert Changes
AdarshSantoria Feb 12, 2023
d11cb3f
fix quicsvd bug and add template tests
AdarshSantoria Feb 13, 2023
6066623
Update src/mlpack/methods/quic_svd/quic_svd_impl.hpp
AdarshSantoria Feb 14, 2023
70539c8
Update src/mlpack/tests/cf_test.cpp
AdarshSantoria Feb 14, 2023
bcacf4e
Update src/mlpack/tests/cf_test.cpp
AdarshSantoria Feb 14, 2023
32eaf10
Update src/mlpack/tests/cf_test.cpp
AdarshSantoria Feb 14, 2023
3f9367c
Update src/mlpack/tests/cf_test.cpp
AdarshSantoria Feb 14, 2023
84e4ec7
Update src/mlpack/tests/cf_test.cpp
AdarshSantoria Feb 14, 2023
fd58849
Update src/mlpack/tests/cf_test.cpp
AdarshSantoria Feb 14, 2023
30f5763
Update src/mlpack/tests/cf_test.cpp
AdarshSantoria Feb 14, 2023
ddd491a
Update src/mlpack/tests/cf_test.cpp
AdarshSantoria Feb 14, 2023
b760dd7
Revert change
AdarshSantoria Feb 14, 2023
81033aa
Update src/mlpack/tests/cf_test.cpp
AdarshSantoria Feb 14, 2023
ad5e1c4
Update src/mlpack/tests/cf_test.cpp
AdarshSantoria Feb 14, 2023
35a82ec
Update src/mlpack/tests/cf_test.cpp
AdarshSantoria Feb 14, 2023
089863f
Update src/mlpack/tests/cf_test.cpp
AdarshSantoria Feb 14, 2023
b8190d1
Update src/mlpack/tests/cf_test.cpp
AdarshSantoria Feb 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -20,5 +20,6 @@
#include "svd_complete_method.hpp"
#include "svd_incomplete_method.hpp"
#include "svdplusplus_method.hpp"
#include "quic_svd_method.hpp"

#endif
170 changes: 170 additions & 0 deletions src/mlpack/methods/cf/decomposition_policies/quic_svd_method.hpp
@@ -0,0 +1,170 @@
/**
* @file methods/cf/decomposition_policies/quic_svd_method.hpp
* @author Adarsh Santoria
*
* Implementation of the quic svd method for use in
* Collaborative Fitlering.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#ifndef MLPACK_METHODS_CF_DECOMPOSITION_POLICIES_QUIC_SVD_METHOD_HPP
#define MLPACK_METHODS_CF_DECOMPOSITION_POLICIES_QUIC_SVD_METHOD_HPP

#include <mlpack/prereqs.hpp>
#include <mlpack/methods/quic_svd/quic_svd.hpp>

namespace mlpack {

/**
* Implementation of the QUIC-SVD policy to act as a wrapper when
* accessing Quic SVD from within CFType.
*
* An example of how to use QuicSVDPolicy in CF is shown below:
*
* @code
* extern arma::mat data; // data is a (user, item, rating) table.
* // Users for whom recommendations are generated.
* extern arma::Col<size_t> users;
* arma::Mat<size_t> recommendations; // Resulting recommendations.
*
* CFType<QuicSVDPolicy> cf(data);
*
* // Generate 10 recommendations for all users.
* cf.GetRecommendations(10, recommendations);
* @endcode
*/
class QuicSVDPolicy
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the original class is called QUIC_SVD, I think we should call this QUIC_SVDPolicy to be consistent. (I know that is harder to type...)

{
public:
/**
* Use quic SVD method to perform collaborative filtering
*/
QuicSVDPolicy()
{
/* Nothing to do here */
}

/**
* Apply Collaborative Filtering to the provided data set using the
* quic SVD.
*
* @param * (data) Data matrix: dense matrix (coordinate lists)
* or sparse matrix(cleaned).
* @param cleanedData item user table in form of sparse matrix.
* @param * (rank) Rank parameter for matrix factorization.
* @param * (maxIterations) Maximum number of iterations.
* @param * (minResidue) Residue required to terminate.
* @param * (mit) Whether to terminate only when maxIterations is reached.
*/
template<typename MatType>
void Apply(const MatType& /* data */,
const arma::sp_mat& cleanedData,
const size_t /* rank */,
const size_t /* maxIterations */,
const double /* minResidue */,
const bool /* mit */)
{
arma::mat sigma;

// Preprocessed data converted to mat format
arma::mat data(cleanedData);

// Do singular value decomposition using the quic SVD algorithm.
QUIC_SVD quicsvd;
quicsvd.Apply(data, w, h, sigma);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QUIC_SVD gives the approximate SVD data = u * diagmat(sigma) * v.t(); but your code here ignores sigma. I think that's incorrect and sigma should be incorporated into either w or h, so that we can reconstruct as data = w * h.t().


// Take transpose of the matrix h as required by CF class.
h = arma::trans(h);
}

/**
* Return predicted rating given user ID and item ID.
*
* @param user User ID.
* @param item Item ID.
*/
double GetRating(const size_t user, const size_t item) const
{
double rating = arma::as_scalar(w.row(item) * h.col(user));
return rating;
}

/**
* Get predicted ratings for a user.
*
* @param user User ID.
* @param rating Resulting rating vector.
*/
void GetRatingOfUser(const size_t user, arma::vec& rating) const
{
rating = w * h.col(user);
}

/**
* Get the neighborhood and corresponding similarities for a set of users.
*
* @tparam NeighborSearchPolicy The policy to perform neighbor search.
*
* @param users Users whose neighborhood is to be computed.
* @param numUsersForSimilarity The number of neighbors returned for
* each user.
* @param neighborhood Neighbors represented by user IDs.
* @param similarities Similarity between each user and each of its
* neighbors.
*/
template<typename NeighborSearchPolicy>
void GetNeighborhood(const arma::Col<size_t>& users,
const size_t numUsersForSimilarity,
arma::Mat<size_t>& neighborhood,
arma::mat& similarities) const
{
// We want to avoid calculating the full rating matrix, so we will do
// nearest neighbor search only on the H matrix, using the observation that
// if the rating matrix X = W*H, then d(X.col(i), X.col(j)) = d(W H.col(i),
// W H.col(j)). This can be seen as nearest neighbor search on the H
// matrix with the Mahalanobis distance where M^{-1} = W^T W. So, we'll
// decompose M^{-1} = L L^T (the Cholesky decomposition), and then multiply
// H by L^T. Then we can perform nearest neighbor search.
arma::mat l = arma::chol(w.t() * w);
arma::mat stretchedH = l * h; // Due to the Armadillo API, l is L^T.

// Temporarily store feature vector of queried users.
arma::mat query(stretchedH.n_rows, users.n_elem);
// Select feature vectors of queried users.
for (size_t i = 0; i < users.n_elem; ++i)
query.col(i) = stretchedH.col(users(i));

NeighborSearchPolicy neighborSearch(stretchedH);
neighborSearch.Search(
query, numUsersForSimilarity, neighborhood, similarities);
}

//! Get the Item Matrix.
const arma::mat& W() const { return w; }
//! Get the User Matrix.
const arma::mat& H() const { return h; }

/**
* Serialization.
*/
template<typename Archive>
void serialize(Archive& ar, const uint32_t /* version */)
{
ar(CEREAL_NVP(w));
ar(CEREAL_NVP(h));
}

private:
//! Item matrix.
arma::mat w;
//! User matrix.
arma::mat h;
};

} // namespace mlpack

#endif
49 changes: 39 additions & 10 deletions src/mlpack/methods/quic_svd/quic_svd.hpp
Expand Up @@ -42,21 +42,20 @@ namespace mlpack {
* const double epsilon = 0.01; // Relative error limit of data in subspace.
* const double delta = 0.1 // Lower error bound for Monte Carlo estimate.
*
* // Make a QuicSVD object.
* QuicSVD qSVD();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be QUIC_SVD?

*
* arma::mat u, v, sigma; // Matrices for the factors. data = u * sigma * v.t()
*
* // Get the factorization in the constructor.
* QUIC_SVD(data, u, v, sigma, epsilon, delta);
* // Use the Apply() method to get a factorization.
* qSVD.Apply(data, u, v, sigma, epsilon, delta);
* @endcode
*/
class QUIC_SVD
{
public:
/**
* Constructor which implements the QUIC-SVD algorithm. The function calls the
* CosineTree constructor to create a subspace basis, where the original
* matrix's projection has minimum reconstruction error. The constructor then
* uses the ExtractSVD() function to calculate the SVD of the original dataset
* in that subspace.
* Create object for the randomized SVD method.
*
* @param dataset Matrix for which SVD is calculated.
* @param u First unitary matrix.
Expand All @@ -72,6 +71,35 @@ class QUIC_SVD
const double epsilon = 0.03,
const double delta = 0.1);

/**
* Create object for the QUIC-SVD method.
*
* @param epsilon Error tolerance fraction for calculated subspace.
* @param delta Cumulative probability for Monte Carlo error lower bound.
*/
QUIC_SVD(const double epsilon = 0.03,
const double delta = 0.1);

/**
* The function calls the CosineTree constructor to create a subspace basis,
* where the original matrix's projection has minimum reconstruction error.
* The constructor then uses the ExtractSVD() function to calculate the SVD
* of the original dataset in that subspace.
*
* @param dataset Matrix for which SVD is calculated.
* @param u First unitary matrix.
* @param v Second unitary matrix.
* @param sigma Diagonal matrix of singular values.
* @param epsilon Error tolerance fraction for calculated subspace.
* @param delta Cumulative probability for Monte Carlo error lower bound.
*/
void Apply(const arma::mat& dataset,
arma::mat& u,
arma::mat& v,
arma::mat& sigma,
const double epsilon = 0.03,
const double delta = 0.1);

/**
* This function uses the vector subspace created using a cosine tree to
* calculate an approximate SVD of the original matrix.
Expand All @@ -80,11 +108,12 @@ class QUIC_SVD
* @param v Second unitary matrix.
* @param sigma Diagonal matrix of singular values.
*/
void ExtractSVD(arma::mat& u, arma::mat& v, arma::mat& sigma);
void ExtractSVD(const arma::mat& dataset,
arma::mat& u,
arma::mat& v,
arma::mat& sigma);

private:
//! Matrix for which cosine tree is constructed.
const arma::mat& dataset;
//! Subspace basis of the input dataset.
arma::mat basis;
};
Expand Down
26 changes: 22 additions & 4 deletions src/mlpack/methods/quic_svd/quic_svd_impl.hpp
Expand Up @@ -23,8 +23,25 @@ inline QUIC_SVD::QUIC_SVD(
arma::mat& v,
arma::mat& sigma,
const double epsilon,
const double delta) :
dataset(dataset)
const double delta)
{
Apply(dataset, u, v, sigma, epsilon, delta);
}

inline QUIC_SVD::QUIC_SVD(
const double epsilon,
const double delta)
{
/* Nothing to do here */
}

inline void QUIC_SVD::Apply(
const arma::mat& dataset,
arma::mat& u,
arma::mat& v,
arma::mat& sigma,
const double epsilon,
const double delta)
{
// Since columns are sample in the implementation, the matrix is transposed if
// necessary for maximum speedup.
Expand All @@ -42,10 +59,11 @@ inline QUIC_SVD::QUIC_SVD(

// Use the ExtractSVD algorithm mentioned in the paper to extract the SVD of
// the original dataset in the obtained subspace.
ExtractSVD(u, v, sigma);
ExtractSVD(dataset,u, v, sigma);
AdarshSantoria marked this conversation as resolved.
Show resolved Hide resolved
}

inline void QUIC_SVD::ExtractSVD(arma::mat& u,
inline void QUIC_SVD::ExtractSVD(const arma::mat& dataset,
arma::mat& u,
arma::mat& v,
arma::mat& sigma)
{
Expand Down