Skip to content

Commit 3e9919e

Browse files
committed
issue #2433: low-overhead function calls in python
1 parent 92dcf7e commit 3e9919e

File tree

9 files changed

+321
-5
lines changed

9 files changed

+321
-5
lines changed

casadi/core/callback.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ namespace casadi {
5252
}
5353
}
5454

55+
int Callback::eval_buffer(const double **arg, const std::vector<casadi_int>& sizes_arg,
56+
double **res, const std::vector<casadi_int>& sizes_res) const {
57+
casadi_error("eval_buffer not overloaded.");
58+
}
59+
/// Tells whether eval_buffer is available. The base class answers "no".
bool Callback::has_eval_buffer() const {
  return false;
}
5562
std::vector<DM> Callback::eval(const std::vector<DM>& arg) const {
5663
return (*this)->FunctionInternal::eval_dm(arg);
5764
}

casadi/core/callback.hpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,15 @@ namespace casadi {
8585
/** \brief Evaluate numerically, temporary matrices and work vectors */
8686
virtual std::vector<DM> eval(const std::vector<DM>& arg) const;
8787

88-
/** \brief Get the number of inputs
88+
/** \brief A copy-free low level interface
89+
*
90+
* In Python, you will be passed two tuples of memoryview objects
91+
*/
92+
virtual int eval_buffer(const double **arg, const std::vector<casadi_int>& sizes_arg,
93+
double **res, const std::vector<casadi_int>& sizes_res) const;
94+
virtual bool has_eval_buffer() const;
95+
96+
/** \brief Get the number of inputs
8997
* This function is called during construction.
9098
*/
9199
virtual casadi_int get_n_in();

casadi/core/callback_internal.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ namespace casadi {
3939

4040
CallbackInternal::
4141
CallbackInternal(const std::string& name, Callback *self)
42-
: FunctionInternal(name), self_(self) {
42+
: FunctionInternal(name), self_(self), has_eval_buffer_(false) {
4343
}
4444

4545
CallbackInternal::~CallbackInternal() {
@@ -70,6 +70,10 @@ namespace casadi {
7070
TRY_CALL(get_name_out, self_, i);
7171
}
7272

73+
/// Asks the public Callback object (self_) whether it offers eval_buffer.
bool CallbackInternal::has_eval_buffer() const {
  TRY_CALL(has_eval_buffer, self_);
}
76+
7377
void CallbackInternal::init(const Dict& opts) {
7478
// Initialize the base classes
7579
FunctionInternal::init(opts);
@@ -86,12 +90,35 @@ namespace casadi {
8690

8791
// Finalize the base classes
8892
FunctionInternal::finalize();
93+
94+
has_eval_buffer_ = has_eval_buffer();
95+
96+
if (has_eval_buffer_) {
97+
sizes_arg_.resize(n_in_);
98+
for (casadi_int i=0;i<n_in_;++i) {
99+
sizes_arg_[i] = nnz_in(i);
100+
}
101+
sizes_res_.resize(n_out_);
102+
for (casadi_int i=0;i<n_out_;++i) {
103+
sizes_res_[i] = nnz_out(i);
104+
}
105+
}
89106
}
90107

91108
std::vector<DM> CallbackInternal::eval_dm(const std::vector<DM>& arg) const {
92109
TRY_CALL(eval, self_, arg);
93110
}
94111

112+
/** \brief Evaluate numerically */
113+
int CallbackInternal::eval(const double** arg, double** res,
114+
casadi_int* iw, double* w, void* mem) const {
115+
if (has_eval_dm()) {
116+
return FunctionInternal::eval(arg, res, iw, w, mem);
117+
} else {
118+
TRY_CALL(eval_buffer, self_, arg, sizes_arg_, res, sizes_res_);
119+
}
120+
}
121+
95122
bool CallbackInternal::uses_output() const {
96123
TRY_CALL(uses_output, self_);
97124
}

casadi/core/callback_internal.hpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,16 @@ namespace casadi {
6868
void finalize() override;
6969

7070
///@{
71-
/** \brief Evaluate with DM matrices (NOTE: eval not defined) */
71+
/** \brief Evaluate with DM matrices */
7272
std::vector<DM> eval_dm(const std::vector<DM>& arg) const override;
73-
bool has_eval_dm() const override { return true;}
73+
bool has_eval_dm() const override { return !has_eval_buffer_;}
7474
///@}
7575

76+
/** \brief Evaluate numerically */
77+
virtual int eval(const double** arg, double** res,
78+
casadi_int* iw, double* w, void* mem) const override;
79+
bool has_eval_buffer() const;
80+
7681
/** \brief Do the derivative functions need nondifferentiated outputs? */
7782
bool uses_output() const override;
7883

@@ -105,6 +110,10 @@ namespace casadi {
105110

106111
/** \brief Pointer to the public class */
107112
Callback* self_;
113+
114+
// For buffered evaluation
115+
std::vector<casadi_int> sizes_arg_, sizes_res_;
116+
bool has_eval_buffer_;
108117
};
109118

110119
} // namespace casadi

casadi/core/function.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,5 +1694,67 @@ namespace casadi {
16941694
return (*this)->info();
16951695
}
16961696

1697+
/// Main constructor.
/// Sizes the work vectors and the argument/result pointer arrays from the
/// function's sz_w/sz_iw/sz_arg/sz_res, checks out a memory object and caches
/// the internal node pointer.
/// ret_ starts at 0 so that ret() is well defined before the first _eval();
/// it used to be left uninitialized.
FunctionBuffer::FunctionBuffer(const Function& f) : f_(f), ret_(0) {
  w_.resize(f_.sz_w());
  iw_.resize(f_.sz_iw());
  arg_.resize(f_.sz_arg());
  res_.resize(f_.sz_res());
  mem_ = f_->checkout();
  mem_internal_ = f.memory(mem_);
  f_node_ = f.operator->();
}
1706+
1707+
/// Gives back the memory object held in mem_: through the release_ hook of
/// the internal node when one is set, through Function::release otherwise.
FunctionBuffer::~FunctionBuffer() {
  if (!f_->release_) {
    f_.release(mem_);
    return;
  }
  f_->release_(mem_);
}
1714+
1715+
/// Copy constructor.
/// Copies the function handle, the work vectors and the registered buffer
/// pointers, and checks out a memory object of its own.
/// Every member is initialized: mem_internal_ and ret_ used to be left
/// indeterminate on some paths.
FunctionBuffer::FunctionBuffer(const FunctionBuffer& f)
    : f_(f.f_), w_(f.w_), iw_(f.iw_), arg_(f.arg_), res_(f.res_),
      f_node_(f.f_node_), mem_(0), mem_internal_(nullptr), ret_(0) {
  // Checkout fresh memory
  if (f_->checkout_) {
    mem_ = f_->checkout_();
  } else {
    mem_ = f_.checkout();
    mem_internal_ = f_.memory(mem_);
  }
}

/// Copy assignment.
/// Takes over the function, work vectors and buffer pointers of f and checks
/// out fresh memory for them. The memory object held so far is released;
/// previously it was simply overwritten and therefore leaked, on
/// self-assignment too. ret_ is left untouched, as before.
FunctionBuffer& FunctionBuffer::operator=(const FunctionBuffer& f) {
  if (this == &f) return *this;

  // Acquire the new memory first, so that a failing checkout leaves this
  // object in its previous, still consistent state.
  Function next = f.f_;
  casadi_int next_mem = 0;
  void* next_internal = nullptr;
  if (next->checkout_) {
    next_mem = next->checkout_();
  } else {
    next_mem = next.checkout();
    next_internal = next.memory(next_mem);
  }

  // Release the memory held so far (same logic as the destructor)
  if (f_->release_) {
    f_->release_(mem_);
  } else {
    f_.release(mem_);
  }

  f_ = next;
  w_ = f.w_;
  iw_ = f.iw_;
  arg_ = f.arg_;
  res_ = f.res_;
  f_node_ = f.f_node_;
  mem_ = next_mem;
  mem_internal_ = next_internal;

  return *this;
}
1732+
1733+
/// Register the buffer to read input i from.
/// \param i     input index
/// \param a     start of the buffer
/// \param size  size of the buffer in bytes
/// Fails the assertion when the buffer is smaller than
/// nnz_in(i)*sizeof(double) bytes.
void FunctionBuffer::set_arg(casadi_int i, const double* a, casadi_int size) {
  // Do the arithmetic in casadi_int: multiplying by the unsigned sizeof()
  // result directly makes the comparison unsigned, so a negative size would
  // be accepted.
  const casadi_int needed = f_.nnz_in(i) * static_cast<casadi_int>(sizeof(double));
  casadi_assert(size>=needed,
    "Buffer is not large enough. Needed " + str(needed) +
    " bytes, got " + str(size) + ".");
  arg_.at(i) = a;
}
1739+
/// Register the buffer to write output i to.
/// \param i     output index
/// \param a     start of the buffer
/// \param size  size of the buffer in bytes
/// Fails the assertion when the buffer is smaller than
/// nnz_out(i)*sizeof(double) bytes.
void FunctionBuffer::set_res(casadi_int i, double* a, casadi_int size) {
  // Do the arithmetic in casadi_int: multiplying by the unsigned sizeof()
  // result directly makes the comparison unsigned, so a negative size would
  // be accepted.
  const casadi_int needed = f_.nnz_out(i) * static_cast<casadi_int>(sizeof(double));
  casadi_assert(size>=needed,
    "Buffer is not large enough. Needed " + str(needed) +
    " bytes, got " + str(size) + ".");
  res_.at(i) = a;
}
1745+
void FunctionBuffer::_eval() {
1746+
if (f_node_->eval_) {
1747+
ret_ = f_node_->eval_(get_ptr(arg_), get_ptr(res_), get_ptr(iw_), get_ptr(w_), mem_);
1748+
} else {
1749+
ret_ = f_node_->eval(get_ptr(arg_), get_ptr(res_), get_ptr(iw_), get_ptr(w_), mem_internal_);
1750+
}
1751+
}
1752+
int FunctionBuffer::ret() {
1753+
return ret_;
1754+
}
1755+
1756+
/// Type-erased entry point: treats raw as a FunctionBuffer* and calls
/// _eval() on it.
void CASADI_EXPORT _function_buffer_eval(void* raw) {
  FunctionBuffer* buffer = static_cast<FunctionBuffer*>(raw);
  buffer->_eval();
}
16971759

16981760
} // namespace casadi

casadi/core/function.hpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "mx.hpp"
3030
#include "printable.hpp"
3131
#include <exception>
32+
#include <stack>
3233

3334
namespace casadi {
3435

@@ -1024,6 +1025,52 @@ namespace casadi {
10241025

10251026
};
10261027

1028+
1029+
/** \brief Class to achieve minimal overhead function evaluations
1030+
*/
1031+
class CASADI_EXPORT FunctionBuffer {
1032+
Function f_;
1033+
std::vector<double> w_;
1034+
std::vector<casadi_int> iw_;
1035+
std::vector<const double*> arg_;
1036+
std::vector<double*> res_;
1037+
FunctionInternal* f_node_;
1038+
casadi_int mem_;
1039+
void *mem_internal_;
1040+
int ret_;
1041+
public:
1042+
/** \brief Main constructor */
1043+
FunctionBuffer(const Function& f);
1044+
#ifndef SWIG
1045+
~FunctionBuffer();
1046+
FunctionBuffer(const FunctionBuffer& f);
1047+
FunctionBuffer& operator=(const FunctionBuffer& f);
1048+
#endif // SWIG
1049+
1050+
/** \brief Set input buffer for input i
1051+
1052+
mem.set_arg(0, memoryview(a))
1053+
1054+
Note that CasADi uses 'fortran' order: column-by-column
1055+
*/
1056+
void set_arg(casadi_int i, const double* a, casadi_int size);
1057+
1058+
/** \brief Set output buffer for ouput i
1059+
1060+
mem.set_res(0, memoryview(a))
1061+
1062+
Note that CasADi uses 'fortran' order: column-by-column
1063+
*/
1064+
void set_res(casadi_int i, double* a, casadi_int size);
1065+
/// Get last return value
1066+
int ret();
1067+
void _eval();
1068+
void* _self() { return this; }
1069+
};
1070+
1071+
/// Calls _eval() on the FunctionBuffer that raw points to.
void CASADI_EXPORT _function_buffer_eval(void* raw);
1072+
1073+
10271074
} // namespace casadi
10281075

10291076
#include "casadi_interrupt.hpp"

swig/casadi.i

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
*/
2424

2525

26+
2627
%module(package="casadi",directors=1) casadi
2728

2829
#ifdef CASADI_WITH_COPYSIGN_UNDEF
@@ -2294,6 +2295,59 @@ namespace std {
22942295
#define L_STR "str"
22952296
#endif
22962297

2298+
#ifdef SWIGPYTHON
2299+
// Input typemap for a read-only buffer: only memoryview objects are accepted.
// The data pointer and the byte length of the underlying Py_buffer are passed
// on as (a, size).
%typemap(in, doc="memoryview(ro)", noblock=1, fragment="casadi_all") (const double * a, casadi_int size) (Py_buffer* buffer) {
  if (!PyMemoryView_Check($input)) {
    SWIG_exception_fail(SWIG_TypeError, "Must supply a MemoryView.");
  }
  buffer = PyMemoryView_GET_BUFFER($input);
  $1 = static_cast<double*>(buffer->buf); // const double cast comes later
  $2 = buffer->len;
}
2305+
2306+
// Input typemap for a writable buffer: only memoryview objects whose
// Py_buffer is not read-only are accepted. The data pointer and the byte
// length are passed on as (a, size).
%typemap(in, doc="memoryview(rw)", noblock=1, fragment="casadi_all") (double * a, casadi_int size) (Py_buffer* buffer) {
  if (!PyMemoryView_Check($input)) {
    SWIG_exception_fail(SWIG_TypeError, "Must supply a writable MemoryView.");
  }
  buffer = PyMemoryView_GET_BUFFER($input);
  if (buffer->readonly) {
    SWIG_exception_fail(SWIG_TypeError, "Must supply a writable MemoryView.");
  }
  $1 = static_cast<double*>(buffer->buf);
  $2 = buffer->len;
}
2313+
2314+
// Directorin typemap (C++ -> Python) for the input buffers.
// Builds a tuple with one entry per element of sizes_arg: a read-only view
// of sizes_arg[i] doubles for every non-null pointer, SWIG_Py_Void() for a
// null pointer.
%typemap(directorin, noblock=1, fragment="casadi_all") (const double** arg, const std::vector<casadi_int>& sizes_arg) (PyObject* my_tuple) {
  PyObject * arg_tuple = PyTuple_New($2.size());
  for (casadi_int i=0; i<$2.size(); ++i) {
#ifdef WITH_PYTHON3
    PyObject* view = $1[i]
      ? PyMemoryView_FromMemory(reinterpret_cast<char*>(const_cast<double*>($1[i])), $2[i]*sizeof(double), PyBUF_READ)
      : SWIG_Py_Void();
#else
    PyObject* view = $1[i]
      ? PyBuffer_FromMemory(const_cast<double*>($1[i]), $2[i]*sizeof(double))
      : SWIG_Py_Void();
#endif
    PyTuple_SET_ITEM(arg_tuple, i, view);
  }
  $input = arg_tuple;
}
2328+
2329+
// Directorin typemap (C++ -> Python) for the output buffers.
// Builds a tuple with one entry per element of sizes_res: a writable view
// of sizes_res[i] doubles for every non-null pointer, SWIG_Py_Void() for a
// null pointer.
%typemap(directorin, noblock=1, fragment="casadi_all") (double** res, const std::vector<casadi_int>& sizes_res) {
  PyObject* res_tuple = PyTuple_New($2.size());
  for (casadi_int i=0; i<$2.size(); ++i) {
#ifdef WITH_PYTHON3
    PyObject* view = $1[i]
      ? PyMemoryView_FromMemory(reinterpret_cast<char*>($1[i]), $2[i]*sizeof(double), PyBUF_WRITE)
      : SWIG_Py_Void();
#else
    PyObject* view = $1[i]
      ? PyBuffer_FromReadWriteMemory($1[i], $2[i]*sizeof(double))
      : SWIG_Py_Void();
#endif
    PyTuple_SET_ITEM(res_tuple, i, view);
  }
  $input = res_tuple;
}
2341+
2342+
// Input typemap for an opaque pointer carried in an unnamed PyCapsule.
// PyCapsule_GetPointer returns NULL (with a Python exception set) when the
// object is not a matching capsule. That case is now reported to Python
// instead of handing a null pointer to the wrapped function.
%typemap(in, doc="void*", noblock=1, fragment="casadi_all") void* raw {
  $1 = PyCapsule_GetPointer($input, NULL);
  if (!$1) SWIG_fail;
}
2345+
2346+
// Output typemap: wraps a returned void* in an unnamed PyCapsule that has no
// destructor, so the capsule does not own what it points to.
%typemap(out, doc="void*", noblock=1, fragment="casadi_all") void* {
  $result = PyCapsule_New($1, NULL, NULL);
}
2349+
#endif
2350+
22972351
%casadi_typemaps(L_STR, PREC_STRING, std::string)
22982352
%casadi_template(LL L_STR LR, PREC_VECTOR, std::vector<std::string>)
22992353
%casadi_typemaps("Sparsity", PREC_SPARSITY, casadi::Sparsity)
@@ -3950,6 +4004,11 @@ def PyFunction(name, obj, inputs, outputs, opts={}):
39504004
%}
39514005
#endif
39524006

4007+
// FunctionBuffer and its raw entry point are wrapped for Python only;
// they are hidden from every other target language.
#ifndef SWIGPYTHON
%ignore FunctionBuffer;
%ignore _function_buffer_eval;
#endif
4011+
39534012
%include <casadi/core/function.hpp>
39544013
#ifdef SWIGPYTHON
39554014
namespace casadi{
@@ -3971,8 +4030,21 @@ namespace casadi{
39714030
else:
39724031
# Named inputs -> return dictionary
39734032
return self.call(kwargs)
4033+
4034+
def buffer(self):
    """
    Create a FunctionBuffer object for evaluating with minimal overhead

    Returns a tuple (fb, caller): register buffers on fb with set_arg/set_res,
    then call caller() to evaluate.
    """
    import functools
    fb = FunctionBuffer(self)
    caller = functools.partial(_casadi._function_buffer_eval, fb._self())
    # The capsule returned by fb._self() is only a raw pointer to fb. Keep a
    # reference on the callable so that fb cannot be collected while caller
    # is still usable.
    caller.fb = fb
    return (fb, caller)
39744043
%}
4044+
4045+
39754046
}
4047+
39764048
}
39774049
#endif // SWIGPYTHON
39784050

swig/python/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ endif()
5858
# Python-specific SWIG settings
set(PYTHONFLAG "")
# Extra command-line flags handed to SWIG for the Python module
list(APPEND CMAKE_SWIG_FLAGS "-DPy_USING_UNICODE")
list(APPEND CMAKE_SWIG_FLAGS "-noproxydel")
list(APPEND CMAKE_SWIG_FLAGS "-fastunpack")
list(APPEND CMAKE_SWIG_FLAGS "-modernargs")
6163
if("${PYTHON_VERSION_MAJOR}" STREQUAL "3")
6264
set(PYTHONFLAG "3")
6365
set(CMAKE_SWIG_FLAGS ${CMAKE_SWIG_FLAGS} "-py3")

0 commit comments

Comments
 (0)