-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Quic_SVD method in CF #3404
Changes from 10 commits
b99e7cf
a57b561
0efb5ac
c223e64
4d8f000
6c51f3d
3315aff
b32355d
2b7bb73
6309e47
d11cb3f
6066623
70539c8
bcacf4e
32eaf10
3f9367c
84e4ec7
fd58849
30f5763
ddd491a
b760dd7
81033aa
ad5e1c4
35a82ec
089863f
b8190d1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
/** | ||
* @file methods/cf/decomposition_policies/quic_svd_method.hpp | ||
* @author Adarsh Santoria | ||
* | ||
* Implementation of the quic svd method for use in | ||
* Collaborative Fitlering. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
|
||
#ifndef MLPACK_METHODS_CF_DECOMPOSITION_POLICIES_QUIC_SVD_METHOD_HPP | ||
#define MLPACK_METHODS_CF_DECOMPOSITION_POLICIES_QUIC_SVD_METHOD_HPP | ||
|
||
#include <mlpack/prereqs.hpp> | ||
#include <mlpack/methods/quic_svd/quic_svd.hpp> | ||
|
||
namespace mlpack { | ||
|
||
/** | ||
* Implementation of the QUIC-SVD policy to act as a wrapper when | ||
* accessing Quic SVD from within CFType. | ||
* | ||
* An example of how to use QuicSVDPolicy in CF is shown below: | ||
* | ||
* @code | ||
* extern arma::mat data; // data is a (user, item, rating) table. | ||
* // Users for whom recommendations are generated. | ||
* extern arma::Col<size_t> users; | ||
* arma::Mat<size_t> recommendations; // Resulting recommendations. | ||
* | ||
* CFType<QuicSVDPolicy> cf(data); | ||
* | ||
* // Generate 10 recommendations for all users. | ||
* cf.GetRecommendations(10, recommendations); | ||
* @endcode | ||
*/ | ||
class QuicSVDPolicy | ||
{ | ||
public: | ||
/** | ||
* Use quic SVD method to perform collaborative filtering | ||
*/ | ||
QuicSVDPolicy() | ||
{ | ||
/* Nothing to do here */ | ||
} | ||
|
||
/** | ||
* Apply Collaborative Filtering to the provided data set using the | ||
* quic SVD. | ||
* | ||
* @param * (data) Data matrix: dense matrix (coordinate lists) | ||
* or sparse matrix(cleaned). | ||
* @param cleanedData item user table in form of sparse matrix. | ||
* @param * (rank) Rank parameter for matrix factorization. | ||
* @param * (maxIterations) Maximum number of iterations. | ||
* @param * (minResidue) Residue required to terminate. | ||
* @param * (mit) Whether to terminate only when maxIterations is reached. | ||
*/ | ||
template<typename MatType> | ||
void Apply(const MatType& /* data */, | ||
const arma::sp_mat& cleanedData, | ||
const size_t /* rank */, | ||
const size_t /* maxIterations */, | ||
const double /* minResidue */, | ||
const bool /* mit */) | ||
{ | ||
arma::mat sigma; | ||
|
||
// Preprocessed data converted to mat format | ||
arma::mat data(cleanedData); | ||
|
||
// Do singular value decomposition using the quic SVD algorithm. | ||
QUIC_SVD quicsvd; | ||
quicsvd.Apply(data, w, h, sigma); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. QUIC_SVD gives the approximate SVD |
||
|
||
// Take transpose of the matrix h as required by CF class. | ||
h = arma::trans(h); | ||
} | ||
|
||
/** | ||
* Return predicted rating given user ID and item ID. | ||
* | ||
* @param user User ID. | ||
* @param item Item ID. | ||
*/ | ||
double GetRating(const size_t user, const size_t item) const | ||
{ | ||
double rating = arma::as_scalar(w.row(item) * h.col(user)); | ||
return rating; | ||
} | ||
|
||
/** | ||
* Get predicted ratings for a user. | ||
* | ||
* @param user User ID. | ||
* @param rating Resulting rating vector. | ||
*/ | ||
void GetRatingOfUser(const size_t user, arma::vec& rating) const | ||
{ | ||
rating = w * h.col(user); | ||
} | ||
|
||
/** | ||
* Get the neighborhood and corresponding similarities for a set of users. | ||
* | ||
* @tparam NeighborSearchPolicy The policy to perform neighbor search. | ||
* | ||
* @param users Users whose neighborhood is to be computed. | ||
* @param numUsersForSimilarity The number of neighbors returned for | ||
* each user. | ||
* @param neighborhood Neighbors represented by user IDs. | ||
* @param similarities Similarity between each user and each of its | ||
* neighbors. | ||
*/ | ||
template<typename NeighborSearchPolicy> | ||
void GetNeighborhood(const arma::Col<size_t>& users, | ||
const size_t numUsersForSimilarity, | ||
arma::Mat<size_t>& neighborhood, | ||
arma::mat& similarities) const | ||
{ | ||
// We want to avoid calculating the full rating matrix, so we will do | ||
// nearest neighbor search only on the H matrix, using the observation that | ||
// if the rating matrix X = W*H, then d(X.col(i), X.col(j)) = d(W H.col(i), | ||
// W H.col(j)). This can be seen as nearest neighbor search on the H | ||
// matrix with the Mahalanobis distance where M^{-1} = W^T W. So, we'll | ||
// decompose M^{-1} = L L^T (the Cholesky decomposition), and then multiply | ||
// H by L^T. Then we can perform nearest neighbor search. | ||
arma::mat l = arma::chol(w.t() * w); | ||
arma::mat stretchedH = l * h; // Due to the Armadillo API, l is L^T. | ||
|
||
// Temporarily store feature vector of queried users. | ||
arma::mat query(stretchedH.n_rows, users.n_elem); | ||
// Select feature vectors of queried users. | ||
for (size_t i = 0; i < users.n_elem; ++i) | ||
query.col(i) = stretchedH.col(users(i)); | ||
|
||
NeighborSearchPolicy neighborSearch(stretchedH); | ||
neighborSearch.Search( | ||
query, numUsersForSimilarity, neighborhood, similarities); | ||
} | ||
|
||
//! Get the Item Matrix. | ||
const arma::mat& W() const { return w; } | ||
//! Get the User Matrix. | ||
const arma::mat& H() const { return h; } | ||
|
||
/** | ||
* Serialization. | ||
*/ | ||
template<typename Archive> | ||
void serialize(Archive& ar, const uint32_t /* version */) | ||
{ | ||
ar(CEREAL_NVP(w)); | ||
ar(CEREAL_NVP(h)); | ||
} | ||
|
||
private: | ||
//! Item matrix. | ||
arma::mat w; | ||
//! User matrix. | ||
arma::mat h; | ||
}; | ||
|
||
} // namespace mlpack | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,21 +42,20 @@ namespace mlpack { | |
* const double epsilon = 0.01; // Relative error limit of data in subspace. | ||
* const double delta = 0.1 // Lower error bound for Monte Carlo estimate. | ||
* | ||
* // Make a QuicSVD object. | ||
* QuicSVD qSVD(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't this be |
||
* | ||
* arma::mat u, v, sigma; // Matrices for the factors. data = u * sigma * v.t() | ||
* | ||
* // Get the factorization in the constructor. | ||
* QUIC_SVD(data, u, v, sigma, epsilon, delta); | ||
* // Use the Apply() method to get a factorization. | ||
* qSVD.Apply(data, u, v, sigma, epsilon, delta); | ||
* @endcode | ||
*/ | ||
class QUIC_SVD | ||
{ | ||
public: | ||
/** | ||
* Constructor which implements the QUIC-SVD algorithm. The function calls the | ||
* CosineTree constructor to create a subspace basis, where the original | ||
* matrix's projection has minimum reconstruction error. The constructor then | ||
* uses the ExtractSVD() function to calculate the SVD of the original dataset | ||
* in that subspace. | ||
* Create object for the randomized SVD method. | ||
* | ||
* @param dataset Matrix for which SVD is calculated. | ||
* @param u First unitary matrix. | ||
|
@@ -72,6 +71,35 @@ class QUIC_SVD | |
const double epsilon = 0.03, | ||
const double delta = 0.1); | ||
|
||
/** | ||
* Create object for the QUIC-SVD method. | ||
* | ||
* @param epsilon Error tolerance fraction for calculated subspace. | ||
* @param delta Cumulative probability for Monte Carlo error lower bound. | ||
*/ | ||
QUIC_SVD(const double epsilon = 0.03, | ||
const double delta = 0.1); | ||
|
||
/** | ||
* The function calls the CosineTree constructor to create a subspace basis, | ||
* where the original matrix's projection has minimum reconstruction error. | ||
* The constructor then uses the ExtractSVD() function to calculate the SVD | ||
* of the original dataset in that subspace. | ||
* | ||
* @param dataset Matrix for which SVD is calculated. | ||
* @param u First unitary matrix. | ||
* @param v Second unitary matrix. | ||
* @param sigma Diagonal matrix of singular values. | ||
* @param epsilon Error tolerance fraction for calculated subspace. | ||
* @param delta Cumulative probability for Monte Carlo error lower bound. | ||
*/ | ||
void Apply(const arma::mat& dataset, | ||
arma::mat& u, | ||
arma::mat& v, | ||
arma::mat& sigma, | ||
const double epsilon = 0.03, | ||
const double delta = 0.1); | ||
|
||
/** | ||
* This function uses the vector subspace created using a cosine tree to | ||
* calculate an approximate SVD of the original matrix. | ||
|
@@ -80,11 +108,12 @@ class QUIC_SVD | |
* @param v Second unitary matrix. | ||
* @param sigma Diagonal matrix of singular values. | ||
*/ | ||
void ExtractSVD(arma::mat& u, arma::mat& v, arma::mat& sigma); | ||
void ExtractSVD(const arma::mat& dataset, | ||
arma::mat& u, | ||
arma::mat& v, | ||
arma::mat& sigma); | ||
|
||
private: | ||
//! Matrix for which cosine tree is constructed. | ||
const arma::mat& dataset; | ||
//! Subspace basis of the input dataset. | ||
arma::mat basis; | ||
}; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the original class is called
QUIC_SVD
, I think we should call thisQUIC_SVDPolicy
to be consistent. (I know that is harder to type...)