-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added Adjusted R2 #2624
Added Adjusted R2 #2624
Changes from 8 commits
87dff6d
726a15e
1668879
786e5e9
cd9fbcd
3ef6020
85556de
9a36925
e4df97a
8346677
30a1dbc
19eb5c9
60d4c2a
d810ca6
e1993dd
dd0d16f
c95a1f5
a48e158
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -16,7 +16,7 @@ namespace mlpack { | |||||||
namespace cv { | ||||||||
|
||||||||
template<typename MLAlgorithm, typename DataType, typename ResponsesType> | ||||||||
double R2Score::Evaluate(MLAlgorithm& model, | ||||||||
double R2Score<false>::Evaluate(MLAlgorithm& model, | ||||||||
const DataType& data, | ||||||||
const ResponsesType& responses) | ||||||||
{ | ||||||||
|
@@ -45,10 +45,46 @@ double R2Score::Evaluate(MLAlgorithm& model, | |||||||
// Handling undefined R2 Score when both denominator and numerator is 0.0. | ||||||||
if (residualSumSquared == 0.0) | ||||||||
return totalSumSquared ? 1.0 : DBL_MIN; | ||||||||
|
||||||||
|
||||||||
// Returning R-squared | ||||||||
return 1 - residualSumSquared / totalSumSquared; | ||||||||
} | ||||||||
|
||||||||
template<typename MLAlgorithm, typename DataType, typename ResponsesType> | ||||||||
double R2Score<true>::Evaluate(MLAlgorithm& model, | ||||||||
const DataType& data, | ||||||||
const ResponsesType& responses) | ||||||||
{ | ||||||||
if (data.n_cols != responses.n_cols) | ||||||||
{ | ||||||||
std::ostringstream oss; | ||||||||
oss << "R2Score::Evaluate(): number of points (" << data.n_cols << ") " | ||||||||
<< "does not match number of responses (" << responses.n_cols << ")!" | ||||||||
<< std::endl; | ||||||||
throw std::invalid_argument(oss.str()); | ||||||||
} | ||||||||
|
||||||||
ResponsesType predictedResponses; | ||||||||
// Taking Predicted Output from the model. | ||||||||
model.Predict(data, predictedResponses); | ||||||||
// Mean value of response. | ||||||||
double meanResponses = arma::mean(responses); | ||||||||
|
||||||||
// Calculate the numerator i.e. residual sum of squares. | ||||||||
double residualSumSquared = arma::accu(arma::square(responses - | ||||||||
predictedResponses)); | ||||||||
|
||||||||
// Calculate the denominator i.e.total sum of squares. | ||||||||
double totalSumSquared = arma::accu(arma::square(responses - meanResponses)); | ||||||||
|
||||||||
// Handling undefined R2 Score when both denominator and numerator is 0.0. | ||||||||
if (residualSumSquared == 0.0) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Dear @zoq Thanks for pointing it out. First of all I think, I have placed the Second, since I do not know what There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, so if (totalSumSquared = 0) I am checking I hope it makes sense. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like this is correctly handled on lines 46 and 47 now. 👍 |
||||||||
return totalSumSquared ? 1.0 : DBL_MIN; | ||||||||
// Returning adjusted R-squared. | ||||||||
double rsq = 1 - (residualSumSquared / totalSumSquared); | ||||||||
return (1 - ((1 - rsq) * ((data.n_cols - 1) / (data.n_cols - data.n_rows - 1)))); | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
This line was longer than 80 characters, so I wrapped it. 👍 |
||||||||
} | ||||||||
|
||||||||
} // namespace cv | ||||||||
} // namespace mlpack | ||||||||
|
||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -187,10 +187,32 @@ TEST_CASE("R2ScoreTest", "[CVTest]") | |||||
|
||||||
double expectedR2 = 0.99999779; | ||||||
|
||||||
REQUIRE(R2Score::Evaluate(lr, data, responses) | ||||||
REQUIRE(R2Score<false>::Evaluate(lr, data, responses) | ||||||
== Approx(expectedR2).epsilon(1e-7)); | ||||||
} | ||||||
|
||||||
/** | ||||||
* Test the Adjusted R squared metric | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Add stop to be consistent with the rest of the codebase. |
||||||
*/ | ||||||
TEST_CASE("AdjR2ScoreTest", "[CVTest]") | ||||||
{ | ||||||
// Making two variables that define the linear function is | ||||||
// f(x1, x2) = x1 + x2 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Add stop at the end to be consistent with the rest of the codebase. |
||||||
arma::mat X; | ||||||
X << 1 << 2 << 3 << 4 << 5 << 6 << arma::endr | ||||||
<< 2 << 3 << 4 << 5 << 6 << 7 << arma::endr; | ||||||
arma::rowvec Y; | ||||||
Y << 3 << 5 << 7 << 9 << 11 << 13; | ||||||
|
||||||
LinearRegression lr(X, Y); | ||||||
|
||||||
//Theoretically Adjusted R squared should be equal 1 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Insert an extra space right after |
||||||
double expAdjR2 = 1; | ||||||
REQUIRE(std::abs(R2Score<true>::Evaluate(lr, X, Y) - expAdjR2) | ||||||
<= 1e-7); | ||||||
} | ||||||
|
||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
No need for two blank lines---one will be fine. 👍 |
||||||
/** | ||||||
* Test the mean squared error with matrix responses. | ||||||
*/ | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @shawnbrar, thanks for taking the time to add this template parameter! Actually, I think you don't need template specialization here. All you should need to do is declare the class as:
and then in the implementation of
Evaluate()
, you can change the bottom to this:and that should be all that's necessary. The nice thing about templates is that that code above will actually be compiled into two different functions at compile time, so the
if (AdjustedR2)
won't actually be run when the program is executed---only the correct branch will be run!Would you mind refactoring it to try this? It should result in a significantly shorter diff. 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dear @rcurtin , sure, even I was thinking of a way which would have been shorter but like I said I am not an experienced programmer in C++.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dear @rcurtin , I have removed the template specialization and made it the way you had asked for.