Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue error when serializing ANN models without MLPACK_ENABLE_ANN_SERIALIZATION #3451

Merged
merged 6 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

* Adapt PReLU layer for new neural network API (#3420).

* Add CF decomposition methods: `QUIC_SVDPolicy` and `BlockKrylovSVDPolicy` (#3413, #3404).
* Add CF decomposition methods: `QUIC_SVDPolicy` and `BlockKrylovSVDPolicy`
(#3413, #3404).

* Update outdated code in tutorials (#3398, #3401).

Expand All @@ -17,6 +18,9 @@

* Avoid deprecation warnings in Armadillo 11.4.4+ (#3405).

* Issue runtime error when serialization of neural networks is attempted but
`MLPACK_ENABLE_ANN_SERIALIZATION` is not defined (#3451).

### mlpack 4.0.1
###### 2022-12-23
* Fix mapping of categorical data for Julia bindings (#3305).
Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ g++ -O3 -std=c++14 -o my_program my_program.cpp -larmadillo -fopenmp

Note that if you want to serialize (save or load) neural networks, you should
add `#define MLPACK_ENABLE_ANN_SERIALIZATION` before including `<mlpack.hpp>`.
If you don't define `MLPACK_ENABLE_ANN_SERIALIZATION` and your code serializes a
neural network, a compilation error will occur.

See the [C++ quickstart](doc/quickstart/cpp.md) and the
[examples](https://github.com/mlpack/examples) repository for some examples of
Expand All @@ -198,7 +200,8 @@ reduce compilation time:
* Only use the `MLPACK_ENABLE_ANN_SERIALIZATION` definition if you are
serializing neural networks in your code. When this define is enabled,
compilation time will increase significantly, as the compiler must generate
code for every possible type of layer.
code for every possible type of layer. (The large amount of extra
compilation overhead is why this is not enabled by default.)

* If you are using mlpack in multiple .cpp files, consider using [`extern
templates`](https://isocpp.org/wiki/faq/cpp11-language-templates) so that the
Expand Down
68 changes: 40 additions & 28 deletions src/mlpack/methods/ann/ffn_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,34 +374,46 @@ void FFN<
MatType
>::serialize(Archive& ar, const uint32_t /* version */)
{
// Serialize the output layer and initialization rule.
ar(CEREAL_NVP(outputLayer));
ar(CEREAL_NVP(initializeRule));

// Serialize the network itself.
ar(CEREAL_NVP(network));
ar(CEREAL_NVP(parameters));

// Serialize the expected input size.
ar(CEREAL_NVP(inputDimensions));

// If we are loading, we need to initialize the weights.
if (cereal::is_loading<Archive>())
{
// We can clear these members, since it's not possible to serialize in the
// middle of training and resume.
predictors.clear();
responses.clear();

networkOutput.clear();
networkDelta.clear();

layerMemoryIsSet = false;
inputDimensionsAreSet = false;

// The weights in `parameters` will be correctly set for each layer in the
// first call to Forward().
}
#ifndef MLPACK_ENABLE_ANN_SERIALIZATION
// Note: if you define MLPACK_IGNORE_ANN_SERIALIZATION_WARNING, you had
// better ensure that every layer you are serializing has had
// CEREAL_REGISTER_TYPE() called somewhere. See layer/serialization.hpp for
// more information.
#ifndef MLPACK_ANN_IGNORE_SERIALIZATION_WARNING
throw std::runtime_error("Cannot serialize a neural network unless "
"MLPACK_ENABLE_ANN_SERIALIZATION is defined! See the \"Additional "
"build options\" section of the README for more information.");
#endif
#else
// Serialize the output layer and initialization rule.
ar(CEREAL_NVP(outputLayer));
ar(CEREAL_NVP(initializeRule));

// Serialize the network itself.
ar(CEREAL_NVP(network));
ar(CEREAL_NVP(parameters));

// Serialize the expected input size.
ar(CEREAL_NVP(inputDimensions));

// If we are loading, we need to initialize the weights.
if (cereal::is_loading<Archive>())
{
// We can clear these members, since it's not possible to serialize in the
// middle of training and resume.
predictors.clear();
responses.clear();

networkOutput.clear();
networkDelta.clear();

layerMemoryIsSet = false;
inputDimensionsAreSet = false;

// The weights in `parameters` will be correctly set for each layer in the
// first call to Forward().
}
#endif
}

template<typename OutputLayerType,
Expand Down
34 changes: 23 additions & 11 deletions src/mlpack/methods/ann/rnn_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,17 +290,29 @@ void RNN<
>::serialize(
Archive& ar, const uint32_t /* version */)
{
ar(CEREAL_NVP(bpttSteps));
ar(CEREAL_NVP(single));
ar(CEREAL_NVP(network));

if (Archive::is_loading::value)
{
// We can clear these members, since it's not possible to serialize in the
// middle of training and resume.
predictors.clear();
responses.clear();
}
#ifndef MLPACK_ENABLE_ANN_SERIALIZATION
// Note: if you define MLPACK_IGNORE_ANN_SERIALIZATION_WARNING, you had
// better ensure that every layer you are serializing has had
// CEREAL_REGISTER_TYPE() called somewhere. See layer/serialization.hpp for
// more information.
#ifndef MLPACK_IGNORE_ANN_SERIALIZATION_WARNING
throw std::runtime_error("Cannot serialize a neural network unless "
"MLPACK_ENABLE_ANN_SERIALIZATION is defined! See the \"Additional "
"build options\" section of the README for more information.");
#endif
#else
ar(CEREAL_NVP(bpttSteps));
ar(CEREAL_NVP(single));
ar(CEREAL_NVP(network));

if (Archive::is_loading::value)
{
// We can clear these members, since it's not possible to serialize in the
// middle of training and resume.
predictors.clear();
responses.clear();
}
#endif
}

template<
Expand Down
9 changes: 8 additions & 1 deletion src/mlpack/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,14 @@ else()
target_compile_definitions(mlpack_test PUBLIC -DMLPACK_SUPPRESS_FATAL)
endif()

set_target_properties(mlpack_test PROPERTIES COTIRE_CXX_PREFIX_HEADER_INIT "../core.hpp")
# This has to be added here so that cotire picks it up (even though it is in
# individual tests).
target_compile_definitions(mlpack_test PUBLIC -DMLPACK_ENABLE_ANN_SERIALIZATION)
set_target_properties(mlpack_test PROPERTIES COTIRE_CXX_PREFIX_HEADER_INIT
"../core.hpp")
# TODO: use the source below, but this requires the DET test to be refactored
# and cleaned up.
# "../../mlpack.hpp")
cotire(mlpack_test)

# Copy test data into right place.
Expand Down
4 changes: 3 additions & 1 deletion src/mlpack/tests/ann/feedforward_network_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#define MLPACK_ENABLE_ANN_SERIALIZATION
#ifndef MLPACK_ENABLE_ANN_SERIALIZATION
#define MLPACK_ENABLE_ANN_SERIALIZATION
#endif
#include <mlpack/core.hpp>
#include <mlpack/methods/ann/ann.hpp>
#include <mlpack/methods/kmeans/kmeans.hpp>
Expand Down
4 changes: 3 additions & 1 deletion src/mlpack/tests/ann/layer/batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#define MLPACK_ENABLE_ANN_SERIALIZATION
#ifndef MLPACK_ENABLE_ANN_SERIALIZATION
#define MLPACK_ENABLE_ANN_SERIALIZATION
#endif
#include <mlpack/core.hpp>
#include <mlpack/methods/ann.hpp>

Expand Down
3 changes: 3 additions & 0 deletions src/mlpack/tests/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_ENABLE_ANN_SERIALIZATION
#define MLPACK_ENABLE_ANN_SERIALIZATION
#endif
#include <mlpack.hpp>

// #define CATCH_CONFIG_MAIN // catch.hpp will define main()
Expand Down