Skip to content

Commit

Permalink
Merge pull request #17 from aragilar/set-err-options-v3
Browse files Browse the repository at this point in the history
Added wrapper around CVodeSetErrHandlerFn
  • Loading branch information
pplk committed May 9, 2015
2 parents 42f7943 + a93647b commit 9e5f738
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 2 deletions.
15 changes: 15 additions & 0 deletions scikits/odes/sundials/cvode.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,19 @@ cdef class CV_WrapJacTimesVecFunction(CV_JacTimesVecFunction):
cdef int with_userdata
cpdef set_jac_times_vecfn(self, object jac_times_vecfn)

cdef class CV_ErrHandler:
cpdef evaluate(self,
int error_code,
bytes module,
bytes function,
bytes msg,
object user_data = *)

cdef class CV_WrapErrHandler(CV_ErrHandler):
cpdef object _err_handler
cdef int with_userdata
cpdef set_err_handler(self, object err_handler)


cdef class CV_data:
cdef np.ndarray yy_tmp, yp_tmp, jac_tmp, g_tmp, r_tmp, z_tmp
Expand All @@ -82,6 +95,8 @@ cdef class CV_data:
cdef CV_JacTimesVecFunction jac_times_vecfn
cdef bint parallel_implementation
cdef object user_data
cdef CV_ErrHandler err_handler
cdef object err_user_data

cdef class CVODE:
cdef N_Vector atol
Expand Down
84 changes: 82 additions & 2 deletions scikits/odes/sundials/cvode.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -434,11 +434,58 @@ cdef int _jac_times_vecfn(N_Vector v, N_Vector Jv, realtype t, N_Vector y, N_Vec

return 0

cdef class CV_ErrHandler:
cpdef evaluate(self,
int error_code,
bytes module,
bytes function,
bytes msg,
object user_data = None):
""" format that error handling functions must match """
pass

cdef class CV_WrapErrHandler(CV_ErrHandler):
cpdef set_err_handler(self, object err_handler):
"""
set some (c/p)ython function as the error handler
"""
nrarg = len(inspect.getargspec(err_handler)[0])
self.with_userdata = (nrarg > 5) or (
nrarg == 5 and inspect.isfunction(err_handler)
)
self._err_handler = err_handler

cpdef evaluate(self,
int error_code,
bytes module,
bytes function,
bytes msg,
object user_data = None):
if self.with_userdata == 1:
self._err_handler(error_code, module, function, msg, user_data)
else:
self._err_handler(error_code, module, function, msg)

cdef void _cv_err_handler_fn(
int error_code, const char *module, const char *function, char *msg,
void *eh_data
):
"""
function with the signature of CVErrHandlerFn, that calls python error
handler
"""
aux_data = <CV_data> eh_data
aux_data.err_handler.evaluate(error_code,
module,
function,
msg,
aux_data.err_user_data)

cdef class CV_data:
def __cinit__(self, N):
self.parallel_implementation = False
self.user_data = None
self.err_user_data = None

self.yy_tmp = np.empty(N, float)
self.yp_tmp = np.empty(N, float)
Expand Down Expand Up @@ -486,7 +533,9 @@ cdef class CVODE:
'jacfn': None,
'prec_setupfn': None,
'prec_solvefn': None,
'jac_times_vecfn': None
'jac_times_vecfn': None,
'err_handler': None,
'err_user_data': None,
}

self.verbosity = 1
Expand Down Expand Up @@ -670,7 +719,15 @@ cdef class CVODE:
'max_nonlin_iters':
default = 0,
'nonlin_conv_coef':
default = 0.
default = 0,
'err_handler':
Values: function of class CV_ErrHandler, default = None
Description:
Defines a function which controls output from the CVODE
solver
'err_user_data':
Description:
User data used by 'err_handler', defaults to 'user_data'
"""

for (key, value) in options.items():
Expand Down Expand Up @@ -883,6 +940,29 @@ cdef class CVODE:

# auxiliary variables
self.aux_data = CV_data(N)

# Set err_handler
err_handler = opts.get('err_handler', None)
if err_handler is not None:
if not isinstance(err_handler, CV_ErrHandler):
tmpfun = CV_WrapErrHandler()
tmpfun.set_err_handler(err_handler)
err_handler = tmpfun

self.aux_data.err_handler = err_handler

flag = CVodeSetErrHandlerFn(
cv_mem, _cv_err_handler_fn, <void*> self.aux_data)

if flag == CV_SUCCESS:
pass
elif flag == CV_MEM_FAIL:
raise MemoryError(
'CVodeSetErrHandlerFn: Memory allocation error')
else:
raise RuntimeError('CVodeSetErrHandlerFn: Unknown flag raised')
self.aux_data.err_user_data = opts['err_user_data'] or opts['user_data']

self.aux_data.parallel_implementation = self.parallel_implementation

rfn = opts['rfn']
Expand Down

0 comments on commit 9e5f738

Please sign in to comment.