-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
simple_residue_termination.hpp
113 lines (95 loc) · 3.16 KB
/
simple_residue_termination.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
/**
* @file simple_residue_termination.hpp
* @author Sumedh Ghaisas
*
* Termination policy used in AMF (Alternating Matrix Factorization).
*/
#ifndef _MLPACK_METHODS_AMF_SIMPLERESIDUETERMINATION_HPP_INCLUDED
#define _MLPACK_METHODS_AMF_SIMPLERESIDUETERMINATION_HPP_INCLUDED
#include <mlpack/core.hpp>
namespace mlpack {
namespace amf {
/**
 * This class implements a simple residue-based termination policy. The
 * termination decision depends on two factors: the value of the residue (the
 * relative change of the Frobenius norm of WH between the previous iteration
 * and this one, i.e. |normOld - norm| / normOld), and the number of
 * iterations. If the current value of residue drops below the threshold or the
 * number of iterations goes above the iteration limit, IsConverged() will
 * return true. This class is meant for use with the AMF (alternating matrix
 * factorization) class.
 *
 * @see AMF
 */
class SimpleResidueTermination
{
 public:
  /**
   * Construct the SimpleResidueTermination object with the given minimum
   * residue (or the default) and the given maximum number of iterations (or the
   * default). 0 indicates no iteration limit.
   *
   * The tracking state is given the same starting values Initialize() uses, so
   * the accessors never read uninitialized memory even if Initialize() has not
   * been called yet.
   *
   * @param minResidue Minimum residue for termination.
   * @param maxIterations Maximum number of iterations (0 means unlimited).
   */
  SimpleResidueTermination(const double minResidue = 1e-5,
                           const size_t maxIterations = 10000)
      : minResidue(minResidue),
        maxIterations(maxIterations),
        residue(DBL_MAX),
        iteration(1),
        normOld(0),
        nm(0) { }

  /**
   * Initializes the termination policy before starting the factorization.
   *
   * Resets the residue, the iteration counter and the stored previous norm.
   *
   * @param V Input matrix being factorized.
   */
  template<typename MatType>
  void Initialize(const MatType& V)
  {
    // Initialize the things we keep track of.
    residue = DBL_MAX;
    iteration = 1;
    nm = V.n_rows * V.n_cols;

    // Remove history.
    normOld = 0;
  }

  /**
   * Check if termination criterion is met.
   *
   * Each call computes the Frobenius norm of W * H, updates the residue
   * (relative change of that norm since the previous call), stores the norm
   * for the next call, and advances the iteration counter.
   *
   * @param W Basis matrix of output.
   * @param H Encoding matrix of output.
   * @return true if the residue is below MinResidue(), or if an iteration
   *     limit is set (MaxIterations() != 0) and has been exceeded.
   */
  bool IsConverged(arma::mat& W, arma::mat& H)
  {
    // Calculate the norm of the current reconstruction.
    const double norm = arma::norm(W * H, "fro");

    // Compute the relative residue. A zero previous norm (the value set by
    // Initialize(), or a genuinely zero reconstruction) would make the
    // division produce inf or NaN. A NaN residue compares false against
    // minResidue forever and so silently disables this criterion. Instead,
    // "nothing changed" counts as zero residue and any change from zero
    // counts as maximal residue.
    if (normOld == 0.0)
      residue = (norm == 0.0) ? 0.0 : DBL_MAX;
    else
      residue = std::fabs(normOld - norm) / normOld;

    // Store the norm for the next call.
    normOld = norm;

    // Increment iteration count.
    ++iteration;

    // As documented in the constructor, maxIterations == 0 means there is no
    // iteration limit; without this guard, 0 would stop after one iteration.
    const bool iterationLimitReached =
        (maxIterations != 0) && (iteration > maxIterations);

    // Check if termination criterion is met.
    return (residue < minResidue) || iterationLimitReached;
  }

  //! Get current value of residue
  const double& Index() const { return residue; }
  //! Get current iteration count
  const size_t& Iteration() const { return iteration; }

  //! Access max iteration count
  const size_t& MaxIterations() const { return maxIterations; }
  size_t& MaxIterations() { return maxIterations; }

  //! Access minimum residue value
  const double& MinResidue() const { return minResidue; }
  double& MinResidue() { return minResidue; }

 public:
  // NOTE(review): these data members are public even though accessors exist
  // above; they are left public so existing code that touches them directly
  // keeps compiling. Prefer the accessors.

  //! residue threshold
  double minResidue;
  //! iteration threshold (0 means no limit)
  size_t maxIterations;

  //! current value of residue
  double residue;
  //! current iteration count
  size_t iteration;
  //! norm of previous iteration
  double normOld;

  //! Number of elements (n_rows * n_cols) of the matrix given to
  //! Initialize(); not read by this class.
  size_t nm;
}; // class SimpleResidueTermination
}; // namespace amf
}; // namespace mlpack
#endif // _MLPACK_METHODS_AMF_SIMPLERESIDUETERMINATION_HPP_INCLUDED