Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added check_input_matrices option to python bindings #2787

Merged
merged 39 commits into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
266ae6e
added no_sanity_checks option to python bindings
NippunSharma Dec 28, 2020
f1368e2
made no_sanity_checks more general
NippunSharma Dec 29, 2020
6ef1eca
added arma::Mat<size_t> and categorical data
NippunSharma Dec 31, 2020
22f5932
moved SanityCheck() to IO class and changed no_sanity_checks to check…
NippunSharma Dec 31, 2020
124d39d
changed comments
NippunSharma Dec 31, 2020
e5bfb40
wrapped function with Cython
NippunSharma Dec 31, 2020
c3ff8d3
indentation removed
NippunSharma Dec 31, 2020
ee889c7
removed iostream
NippunSharma Jan 2, 2021
8a8f0ea
changed name from SanityChecks() to CheckInputMatrices()
NippunSharma Jan 2, 2021
62c9f50
made single block in print_pyx.cpp
NippunSharma Jan 2, 2021
be82a0d
reduced num of chars per line in io.cpp
NippunSharma Jan 2, 2021
990ef53
reduced num of chars per line in py_option.hpp
NippunSharma Jan 2, 2021
7d6a209
changed function name while wrapping
NippunSharma Jan 2, 2021
aeb0f49
added inf functionality and tests
Jan 12, 2021
bfff19e
fixed indentations, random indices in tests, changed comments
NippunSharma Jan 12, 2021
7056aeb
fixed indent
NippunSharma Jan 12, 2021
5e0a347
added templated utility function
Jan 13, 2021
dc398bf
fixing errors
Jan 13, 2021
6ad120a
added CheckInputMatrix() to reduce code block in CheckInputMatrices()
Jan 13, 2021
04dee7e
fix errors
Jan 13, 2021
0e8e6f1
simplification
Jan 23, 2021
acc559f
Merge branch 'master' of https://github.com/mlpack/mlpack into sanity…
Jan 23, 2021
654e650
updated HISTORY.md
Jan 23, 2021
33af71d
Update src/mlpack/core/util/io.cpp
NippunSharma Jan 23, 2021
723ecf8
added more tests
Jan 23, 2021
0c60d8c
Merge branch 'sanity_checks' of https://github.com/NippunSharma/mlpac…
Jan 23, 2021
141b92d
fixes
Jan 23, 2021
0d31176
added check for unsigned matrix
Jan 24, 2021
954f3c4
fixing parameter types while checking
Jan 24, 2021
f5b32bc
removing size_t tests
Jan 25, 2021
7ec4058
removed size_t conditions from CheckInputMatrices()
Jan 26, 2021
16ff91b
made definition of TUPLE_TYPE inline
Jan 26, 2021
458b302
removed TUPLE_TYPE
Jan 26, 2021
ab3110e
minor change
Jan 26, 2021
0914fe9
added PARAM_IN_WITH_NAME
Jan 27, 2021
5b76b2f
changed PARAM_IN_WITH_NAME to PARAM_COMPLETE and editted other macros
Jan 28, 2021
3664422
added comments
Jan 28, 2021
23c3903
Update src/mlpack/core/util/io.cpp
NippunSharma Jan 29, 2021
9a3815f
changed name to PARAM
Jan 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/mlpack/bindings/python/mlpack/io.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ cdef extern from "<mlpack/core/util/io.hpp>" namespace "mlpack" nogil:
@staticmethod
void ClearSettings() nogil except +

@staticmethod
void CheckInputMatrices() nogil except +

cdef extern from "<mlpack/bindings/python/mlpack/io_util.hpp>" \
namespace "mlpack::util" nogil:
void SetParam[T](string, T&) nogil except +
Expand Down
10 changes: 10 additions & 0 deletions src/mlpack/bindings/python/print_pyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,16 @@ void PrintPYX(const util::BindingDetails& doc,
cout << " IO.SetPassed(<const string> '" << d.name << "')" << endl;
}

// Checking the type of check_input_matrices parameter.
cout << " if not isinstance(check_input_matrices, bool):" << endl;
cout << " raise TypeError(" <<"\"'check_input_matrices\' must have type "
<< "\'bool'!\")" << endl;
cout << endl;

// Before calling mlpackMain(), we check input matrices for NaN values if needed.
cout << " if check_input_matrices:" << endl;
cout << " IO.CheckInputMatrices()" << endl;

// Call the method.
cout << " # Call the mlpack program." << endl;
cout << " mlpackMain()" << endl;
Expand Down
6 changes: 4 additions & 2 deletions src/mlpack/bindings/python/py_option.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ class PyOption
data.required = required;
data.input = input;
data.loaded = false;
// Only "verbose" and "copy_all_inputs" will be persistent.
if (identifier == "verbose" || identifier == "copy_all_inputs")
// Only "verbose", "copy_all_inputs" and "check_input_matrices"
// will be persistent.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// will be persistent.
// will be persistent.

Remove extra space.

if (identifier == "verbose" || identifier == "copy_all_inputs" ||
identifier == "check_input_matrices")
data.persistent = true;
else
data.persistent = false;
Expand Down
49 changes: 49 additions & 0 deletions src/mlpack/core/util/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,52 @@ void IO::ClearSettings()
GetSingleton().aliases = persistentAliases;
GetSingleton().functionMap = persistentFunctions;
}

void IO::CheckInputMatrices()
{
typedef typename std::tuple<data::DatasetInfo, arma::mat> TupleType;
std::map<std::string, util::ParamData>::iterator itr;

for (itr = IO::Parameters().begin(); itr != IO::Parameters().end(); ++itr)
{
std::string paramName = itr->first;
std::string paramType = itr->second.cppType;
std::string errMsg = "The input " + paramName + " has NaN values.";
if (paramType == "arma::mat")
{
if (IO::GetParam<arma::Mat<double>>(paramName).has_nan())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably check for Inf too---you could use .is_finite() for that. http://arma.sourceforge.net/docs.html#is_finite 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is a good idea; I have added the same. I decided to use .has_inf() just for uniformity with .has_nan(). I have also added separate tests for the NaN and Inf values in src/mlpack/bindings/python/tests/test_python_binding.py

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NippunSharma maybe now it makes sense to split out the body of each if statement into a templated function? The bodies of each if statement are basically identical, so it would be great if we could capture that identical functionality in one standalone function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that will reduce the repetitive code. I have added CheckInputMatrix() function that is called inside the CheckInputMatrices() function.

Log::Fatal << errMsg << std::endl;
}
else if (paramType == "arma::Mat<size_t>")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think now you can remove these if statements, since it's not possible to find a NaN or Inf in a matrix that has type size_t. 👍

{
if (IO::GetParam<arma::Mat<size_t>>(paramName).has_nan())
Log::Fatal << errMsg << std::endl;
}
else if (paramType == "arma::colvec")
{
if (IO::GetParam<arma::Col<double>>(paramName).has_nan())
Log::Fatal << errMsg << std::endl;
}
else if (paramType == "arma::Col<size_t>")
{
if (IO::GetParam<arma::Col<size_t>>(paramName).has_nan())
Log::Fatal << errMsg << std::endl;
}
else if (paramType == "arma::rowvec")
{
if (IO::GetParam<arma::Row<double>>(paramName).has_nan())
Log::Fatal << errMsg << std::endl;
}
else if (paramType == "arma::Row<size_t>")
{
if (IO::GetParam<arma::Row<size_t>>(paramName).has_nan())
Log::Fatal << errMsg << std::endl;
}
else if (paramType == "std::tuple<data::DatasetInfo, arma::mat>")
{
if (std::get<1>(IO::GetParam<TupleType>(paramName)).has_nan())
Log::Fatal << errMsg << std::endl;
}
}
}

5 changes: 5 additions & 0 deletions src/mlpack/core/util/io.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ class IO
*/
static void ClearSettings();

/**
* Checks all input matrices for NaN values; if any are found, an exception is thrown.
*/
static void CheckInputMatrices();

private:
//! Convenience map from alias values to names.
std::map<char, std::string> aliases;
Expand Down
2 changes: 2 additions & 0 deletions src/mlpack/core/util/mlpack_main.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ PARAM_FLAG("copy_all_inputs", "If specified, all input parameters will be deep"
" copied before the method is run. This is useful for debugging problems "
"where the input parameters are being modified by the algorithm, but can "
"slow down the code.", "");
PARAM_FLAG("check_input_matrices", "If specified, the input matrix is checked for"
" NaN values; an exception is thrown if any are found.", "");

// Nothing else needs to be defined---the binding will use mlpackMain() as-is.

Expand Down