Skip to content

Commit

Permalink
Moved some serialization from .h to .cc
Browse files Browse the repository at this point in the history
  • Loading branch information
neubig committed Oct 6, 2016
1 parent 6b021b4 commit a6d937d
Show file tree
Hide file tree
Showing 20 changed files with 328 additions and 199 deletions.
1 change: 1 addition & 0 deletions cnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ set(cnn_library_SRCS
shadow-params.cc
tensor.cc
training.cc
weight-decay.cc
)

# Headers:
Expand Down
29 changes: 27 additions & 2 deletions cnn/cfsm-builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
#include <fstream>
#include <iostream>

#include "cnn/io-macros.h"

using namespace std;

BOOST_CLASS_EXPORT_IMPLEMENT(cnn::StandardSoftmaxBuilder)
BOOST_CLASS_EXPORT_IMPLEMENT(cnn::ClassFactoredSoftmaxBuilder)
// BOOST_CLASS_EXPORT_IMPLEMENT(cnn::StandardSoftmaxBuilder)
// BOOST_CLASS_EXPORT_IMPLEMENT(cnn::ClassFactoredSoftmaxBuilder)

namespace cnn {

Expand Down Expand Up @@ -53,6 +55,14 @@ Expression StandardSoftmaxBuilder::full_log_distribution(const Expression& rep)
return log(softmax(affine_transform({b, w, rep})));
}

template<class Archive>
void StandardSoftmaxBuilder::serialize(Archive& ar, const unsigned int) {
boost::serialization::base_object<SoftmaxBuilder>(*this);
ar & p_w;
ar & p_b;
}
CNN_SERIALIZE_IMPL(StandardSoftmaxBuilder)

ClassFactoredSoftmaxBuilder::ClassFactoredSoftmaxBuilder() {}

ClassFactoredSoftmaxBuilder::ClassFactoredSoftmaxBuilder(unsigned rep_dim,
Expand Down Expand Up @@ -211,4 +221,19 @@ void ClassFactoredSoftmaxBuilder::read_cluster_file(const std::string& cluster_f
cerr << "Read " << wc << " words in " << cdict.size() << " clusters (" << scs << " singleton clusters)\n";
}

template<class Archive>
void ClassFactoredSoftmaxBuilder::serialize(Archive& ar, const unsigned int) {
boost::serialization::base_object<SoftmaxBuilder>(*this);
ar & cdict;
ar & widx2cidx;
ar & widx2cwidx;
ar & cidx2words;
ar & singleton_cluster;
ar & p_r2c;
ar & p_cbias;
ar & p_rc2ws;
ar & p_rcwbiases;
}
CNN_SERIALIZE_IMPL(ClassFactoredSoftmaxBuilder)

} // namespace cnn
23 changes: 4 additions & 19 deletions cnn/cfsm-builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,7 @@ class StandardSoftmaxBuilder : public SoftmaxBuilder {

friend class boost::serialization::access;
template<class Archive>
void serialize(Archive& ar, const unsigned int) {
boost::serialization::base_object<SoftmaxBuilder>(*this);
ar & p_w;
ar & p_b;
}
void serialize(Archive& ar, const unsigned int);
};

// helps with implementation of hierarchical softmax
Expand Down Expand Up @@ -108,20 +104,9 @@ class ClassFactoredSoftmaxBuilder : public SoftmaxBuilder {

friend class boost::serialization::access;
template<class Archive>
void serialize(Archive& ar, const unsigned int) {
boost::serialization::base_object<SoftmaxBuilder>(*this);
ar & cdict;
ar & widx2cidx;
ar & widx2cwidx;
ar & cidx2words;
ar & singleton_cluster;
ar & p_r2c;
ar & p_cbias;
ar & p_rc2ws;
ar & p_rcwbiases;
}
void serialize(Archive& ar, const unsigned int);
};
} // namespace cnn
BOOST_CLASS_EXPORT_KEY(cnn::StandardSoftmaxBuilder)
BOOST_CLASS_EXPORT_KEY(cnn::ClassFactoredSoftmaxBuilder)
//BOOST_CLASS_EXPORT_KEY(cnn::StandardSoftmaxBuilder)
//BOOST_CLASS_EXPORT_KEY(cnn::ClassFactoredSoftmaxBuilder)
#endif
1 change: 1 addition & 0 deletions cnn/cnn.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "cnn/cnn.h"

#include "cnn/exec.h"
#include "cnn/nodes.h"
#include "cnn/param-nodes.h"
Expand Down
5 changes: 1 addition & 4 deletions cnn/cnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@
#include <iostream>
#include <initializer_list>
#include <utility>

#include <boost/serialization/strong_typedef.hpp>
#include <boost/archive/text_iarchive.hpp>
#include <boost/archive/text_oarchive.hpp>
#include <boost/archive/binary_iarchive.hpp>
#include <boost/archive/binary_oarchive.hpp>

#include "cnn/init.h"
#include "cnn/aligned-mem-pool.h"
Expand Down
21 changes: 21 additions & 0 deletions cnn/io-macros.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef CNN_IO_MACROS__
#define CNN_IO_MACROS__

#include <boost/archive/text_iarchive.hpp>
#include <boost/archive/text_oarchive.hpp>
#include <boost/archive/binary_iarchive.hpp>
#include <boost/archive/binary_oarchive.hpp>

#define CNN_SERIALIZE_IMPL(MyClass) \
template void MyClass::serialize<boost::archive::text_oarchive>(boost::archive::text_oarchive &ar, const unsigned int); \
template void MyClass::serialize<boost::archive::text_iarchive>(boost::archive::text_iarchive &ar, const unsigned int); \
template void MyClass::serialize<boost::archive::binary_oarchive>(boost::archive::binary_oarchive &ar, const unsigned int); \
template void MyClass::serialize<boost::archive::binary_iarchive>(boost::archive::binary_iarchive &ar, const unsigned int);

#define CNN_SAVELOAD_IMPL(MyClass) \
template void MyClass::save<boost::archive::text_oarchive>(boost::archive::text_oarchive &ar, const unsigned int) const; \
template void MyClass::load<boost::archive::text_iarchive>(boost::archive::text_iarchive &ar, const unsigned int); \
template void MyClass::save<boost::archive::binary_oarchive>(boost::archive::binary_oarchive &ar, const unsigned int) const; \
template void MyClass::load<boost::archive::binary_iarchive>(boost::archive::binary_iarchive &ar, const unsigned int);

#endif
15 changes: 15 additions & 0 deletions cnn/lstm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
#include <vector>
#include <iostream>

#include <boost/serialization/utility.hpp>
#include <boost/serialization/vector.hpp>
#include <boost/archive/binary_iarchive.hpp>
#include <boost/archive/binary_oarchive.hpp>

#include "cnn/nodes.h"
#include "cnn/io-macros.h"

using namespace std;
using namespace cnn::expr;
Expand Down Expand Up @@ -199,4 +205,13 @@ void LSTMBuilder::load_parameters_pretraining(const string& fname) {
}
}

template<class Archive>
void LSTMBuilder::serialize(Archive& ar, const unsigned int) {
ar & boost::serialization::base_object<RNNBuilder>(*this);
ar & params;
ar & layers;
ar & dropout_rate;
}
CNN_SERIALIZE_IMPL(LSTMBuilder);

} // namespace cnn
7 changes: 1 addition & 6 deletions cnn/lstm.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,7 @@ struct LSTMBuilder : public RNNBuilder {
private:
friend class boost::serialization::access;
template<class Archive>
void serialize(Archive& ar, const unsigned int) {
ar & boost::serialization::base_object<RNNBuilder>(*this);
ar & params;
ar & layers;
ar & dropout_rate;
}
void serialize(Archive& ar, const unsigned int);

};

Expand Down
69 changes: 67 additions & 2 deletions cnn/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "cnn/tensor.h"
#include "cnn/aligned-mem-pool.h"
#include "cnn/cnn.h"
#include "cnn/io-macros.h"

#include <unordered_set>
#include <iostream>
Expand All @@ -12,6 +13,11 @@
#ifndef __CUDACC__
#include <boost/archive/text_iarchive.hpp>
#include <boost/archive/text_oarchive.hpp>
#include <boost/serialization/export.hpp>
#include <boost/serialization/vector.hpp>
#include <boost/serialization/serialization.hpp>
#include <boost/serialization/access.hpp>
#include <boost/serialization/split_member.hpp>
#endif

// Macros for defining functions over parameters
Expand Down Expand Up @@ -40,8 +46,8 @@
using namespace std;

#ifndef __CUDACC__
BOOST_CLASS_EXPORT_IMPLEMENT(cnn::ParameterStorage)
BOOST_CLASS_EXPORT_IMPLEMENT(cnn::LookupParameterStorage)
// BOOST_CLASS_EXPORT_IMPLEMENT(cnn::ParameterStorage)
// BOOST_CLASS_EXPORT_IMPLEMENT(cnn::LookupParameterStorage)
#endif

namespace cnn {
Expand Down Expand Up @@ -81,6 +87,17 @@ void ParameterStorage::clear() {
TensorTools::Zero(g);
}

#ifndef __CUDACC__
template<class Archive>
void ParameterStorage::serialize(Archive& ar, const unsigned int) {
boost::serialization::base_object<ParameterStorageBase>(*this);
ar & dim;
ar & values;
ar & g;
}
CNN_SERIALIZE_IMPL(ParameterStorage)
#endif

LookupParameterStorage::LookupParameterStorage(unsigned n, const Dim& d) : dim(d) {
all_dim = dim; all_dim.d[all_dim.nd++] = n;
all_grads.d = all_values.d = all_dim;
Expand Down Expand Up @@ -125,6 +142,25 @@ void LookupParameterStorage::clear() {
non_zero_grads.clear();
}

#ifndef __CUDACC__
template<class Archive>
void LookupParameterStorage::save(Archive& ar, const unsigned int) const {
ar << boost::serialization::base_object<ParameterStorageBase>(*this);
ar << all_dim;
ar << all_values;
ar << all_grads;
}
template<class Archive>
void LookupParameterStorage::load(Archive& ar, const unsigned int) {
ar >> boost::serialization::base_object<ParameterStorageBase>(*this);
ar >> all_dim;
ar >> all_values;
ar >> all_grads;
initialize_lookups();
}
CNN_SAVELOAD_IMPL(LookupParameterStorage)
#endif

Parameter::Parameter() {
mp = nullptr;
index = 0;
Expand All @@ -140,6 +176,15 @@ void Parameter::zero() {
return mp->parameters_list()[index]->zero();
}

#ifndef __CUDACC__
template<class Archive>
void Parameter::serialize(Archive& ar, const unsigned int) {
ar & mp;
ar & index;
}
CNN_SERIALIZE_IMPL(Parameter)
#endif

LookupParameter::LookupParameter() {
mp = nullptr;
index = 0;
Expand All @@ -159,6 +204,15 @@ void LookupParameter::initialize(unsigned index, const std::vector<float>& val)
get()->initialize(index, val);
}

#ifndef __CUDACC__
template<class Archive>
void LookupParameter::serialize(Archive& ar, const unsigned int) {
ar & mp;
ar & index;
}
CNN_SERIALIZE_IMPL(LookupParameter)
#endif

Model::Model() : gradient_norm_scratch(nullptr) {
weight_decay.set_lambda(weight_decay_lambda);
}
Expand Down Expand Up @@ -219,6 +273,17 @@ size_t Model::parameter_count() const {
return r;
}

#ifndef __CUDACC__
template<class Archive>
void Model::serialize(Archive& ar, const unsigned int) {
ar & all_params;
ar & params;
ar & lookup_params;
ar & weight_decay;
}
CNN_SERIALIZE_IMPL(Model)
#endif

void save_cnn_model(std::string filename, Model* model) {
std::ofstream out(filename);
boost::archive::text_oarchive oa(out);
Expand Down

0 comments on commit a6d937d

Please sign in to comment.