Skip to content

Commit

Permalink
issue #308: avoiding bloat in non-debug mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Joris Gillis committed Jun 15, 2018
1 parent 70ac2f3 commit 10ad501
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 44 deletions.
8 changes: 4 additions & 4 deletions casadi/core/function.cpp
Expand Up @@ -1023,14 +1023,14 @@ namespace casadi {
return (*this)->export_code(lang, stream, options);
}

std::string Function::serialize() const {
std::string Function::serialize(const Dict& opts) const {
std::stringstream ss;
serialize(ss);
serialize(ss, opts);
return ss.str();
}

void Function::serialize(std::ostream &stream) const {
Serializer s(stream);
void Function::serialize(std::ostream &stream, const Dict& opts) const {
Serializer s(stream, opts);
return serialize(s);
}

Expand Down
4 changes: 2 additions & 2 deletions casadi/core/function.hpp
Expand Up @@ -704,14 +704,14 @@ namespace casadi {

#ifndef SWIG
/** \brief Serialize */
void serialize(std::ostream &stream) const;
void serialize(std::ostream &stream, const Dict& opts=Dict()) const;

/** \brief Serialize */
void serialize(Serializer &s) const;
#endif

/** \brief Serialize */
std::string serialize() const;
std::string serialize(const Dict& opts=Dict()) const;

std::string export_code(const std::string& lang, const Dict& options=Dict()) const;
#ifndef SWIG
Expand Down
52 changes: 46 additions & 6 deletions casadi/core/serializer.cpp
Expand Up @@ -34,12 +34,50 @@
using namespace std;
namespace casadi {

DeSerializer::DeSerializer(std::istream& in_s) : in(in_s) {
static casadi_int serialization_protocol_version = 1;
static casadi_int serialization_check = 123456789012345;

DeSerializer::DeSerializer(std::istream& in_s) : in(in_s), debug_(false) {

// Sanity check
casadi_int check;
unpack(check);
casadi_assert(check==serialization_check,
"DeSerializer sanity check failed. "
"Expected " + str(serialization_check) + ", but got " + str(check) + ".");

// API version check
casadi_int v;
unpack(v);
casadi_assert(v==serialization_protocol_version,
"Serialization protocol is not compatible. "
"Got version " + str(v) + ", while " + str(serialization_protocol_version) + " was expected.");

bool debug;
unpack(debug);
debug_ = debug;

}

Serializer::Serializer(std::ostream& out_s, const Dict& /*opts*/) : out(out_s) {
Serializer::Serializer(std::ostream& out_s, const Dict& opts) : out(out_s), debug_(false) {
// Sanity check
pack(serialization_check);
// API version check
pack(casadi_int(serialization_protocol_version));

bool debug = false;

// Read options
for (auto&& op : opts) {
if (op.first=="debug") {
debug = op.second;
} else {
casadi_error("Unknown option: '" + op.first + "'.");
}
}

pack(debug);
debug_ = debug;
}

casadi_int Serializer::add(const Function& f) {
Expand All @@ -54,13 +92,15 @@ namespace casadi {
}

void Serializer::decorate(char e) {
pack(e);
if (debug_) pack(e);
}

void DeSerializer::assert_decoration(char e) {
char t;
unpack(t);
casadi_assert(t==e, "Serializer error '" + str(e) + "' vs '" + str(t) + "'.");
if (debug_) {
char t;
unpack(t);
casadi_assert(t==e, "Serializer error '" + str(e) + "' vs '" + str(t) + "'.");
}
}

void DeSerializer::unpack(casadi_int& e) {
Expand Down
19 changes: 10 additions & 9 deletions casadi/core/serializer.hpp
Expand Up @@ -81,15 +81,14 @@ namespace casadi {

template <class T>
void unpack(const std::string& descr, T& e) {
std::string d;
unpack(d);
//uout() << "unpack started: " << descr << std::endl;
casadi_assert(d==descr, "Mismatch: '" + descr + "' expected, got '" + d + "'.");
if (debug_) {
std::string d;
unpack(d);
casadi_assert(d==descr, "Mismatch: '" + descr + "' expected, got '" + d + "'.");
}
unpack(e);
//uout() << "unpack: " << descr << ": " << e << std::endl;
}


template <class T, class M>
void shared_unpack(T& e, M& cache) {
char i;
Expand Down Expand Up @@ -121,6 +120,8 @@ namespace casadi {
std::vector<SXElem> sx_nodes;
std::vector<Sparsity> sparsities;
std::vector<Linsol> linsols;

bool debug_;
};


Expand Down Expand Up @@ -164,10 +165,8 @@ namespace casadi {

template <class T>
void pack(const std::string& descr, const T& e) {
//uout() << " pack started: " << descr << std::endl;
pack(descr);
if (debug_) pack(descr);
pack(e);
//uout() << " pack: " << descr << ": " << e << std::endl;
}

void decorate(char e);
Expand Down Expand Up @@ -198,6 +197,8 @@ namespace casadi {

std::ostream& out;

bool debug_;

};


Expand Down
47 changes: 24 additions & 23 deletions test/python/function.py
Expand Up @@ -1680,37 +1680,38 @@ def test_codegen_avoid_stack(self):


def test_serialize(self):
x = SX.sym("x")
y = x+3
z = sin(y)
for opts in [{},{"debug":True}]:
x = SX.sym("x")
y = x+3
z = sin(y)

f = Function('f',[x],[z])
fs = Function.deserialize(f.serialize())
f = Function('f',[x],[z])
fs = Function.deserialize(f.serialize(opts))

self.checkfunction(f,fs,inputs=[2])
self.checkfunction(f,fs,inputs=[2])

x = SX.sym("x")
y = x+3
z = sin(y)
x = SX.sym("x")
y = x+3
z = sin(y)

f = Function('f',[x],[z,np.nan,-np.inf,np.inf])
fs = Function.deserialize(f.serialize())
self.checkfunction(f,fs,inputs=[2])
f = Function('f',[x],[z,np.nan,-np.inf,np.inf])
fs = Function.deserialize(f.serialize(opts))
self.checkfunction(f,fs,inputs=[2])

x = SX.sym("x")
y = SX.sym("y", Sparsity.lower(3))
z = x+y
z1 = sparsify(vertcat(z[0],0,z[1]))
z2 = z.T
x = SX.sym("x")
y = SX.sym("y", Sparsity.lower(3))
z = x+y
z1 = sparsify(vertcat(z[0],0,z[1]))
z2 = z.T

f = Function('f',[x,y],[z1,z2,x**2],["x","y"],["a","b","c"])
fs = Function.deserialize(f.serialize())
f = Function('f',[x,y],[z1,z2,x**2],["x","y"],["a","b","c"])
fs = Function.deserialize(f.serialize(opts))

self.assertEqual(fs.name_in(0), "x")
self.assertEqual(fs.name_out(0), "a")
self.assertEqual(fs.name(), "f")
self.assertEqual(fs.name_in(0), "x")
self.assertEqual(fs.name_out(0), "a")
self.assertEqual(fs.name(), "f")

self.checkfunction(f,fs,inputs=[3.7,np.array([[1,0,0],[2,3,0],[4,5,6]])],hessian=False)
self.checkfunction(f,fs,inputs=[3.7,np.array([[1,0,0],[2,3,0],[4,5,6]])],hessian=False)


fs = pickle.loads(pickle.dumps(f))
Expand Down

0 comments on commit 10ad501

Please sign in to comment.