Skip to content

Commit

Permalink
Implement delegation of ReactionRate::setParameters
Browse files Browse the repository at this point in the history
  • Loading branch information
speth committed Sep 11, 2022
1 parent da59a90 commit bf1e2b4
Show file tree
Hide file tree
Showing 11 changed files with 91 additions and 13 deletions.
31 changes: 31 additions & 0 deletions include/cantera/base/Delegator.h
Expand Up @@ -140,6 +140,21 @@ class Delegator
*m_funcs_v_d[name] = makeDelegate(func, when, *m_funcs_v_d[name]);
}

//! set delegates for member functions with the signature
//! `void(AnyMap&, UnitStack&)`
void setDelegate(const std::string& name,
const std::function<void(const AnyMap&, const UnitStack&)>& func,
const std::string& when)
{
if (!m_funcs_v_cAMr_cUSr.count(name)) {
throw NotImplementedError("Delegator::setDelegate",
"for function named '{}' with signature "
"'void(const AnyMap&, const UnitStack&)'.",
name);
}
*m_funcs_v_cAMr_cUSr[name] = makeDelegate(func, when, *m_funcs_v_cAMr_cUSr[name]);
}

//! Set delegates for member functions with the signature `void(double*)`
void setDelegate(const std::string& name,
const std::function<void(std::array<size_t, 1>, double*)>& func,
Expand Down Expand Up @@ -263,6 +278,18 @@ class Delegator
m_funcs_v_d[name] = &target;
}

//! Install a function with the signature `void(const AnyMap&, const UnitStack&)`
//! as being delegatable
void install(const std::string& name,
std::function<void(const AnyMap&, const UnitStack&)>& target,
const std::function<void(const AnyMap&, const UnitStack&)>& func)
{
target = func;
m_funcs_v_cAMr_cUSr[name] = &target;
}



//! Install a function with the signature `void(double*)` as being delegatable
void install(const std::string& name,
std::function<void(std::array<size_t, 1>, double*)>& target,
Expand Down Expand Up @@ -425,6 +452,8 @@ class Delegator
//! - `d` for `double`
//! - `s` for `std::string`
//! - `sz` for `size_t`
//! - `AM` for `AnyMap`
//! - `US` for `UnitStack`
//! - prefix `c` for `const` arguments
//! - suffix `r` for reference arguments
//! - suffix `p` for pointer arguments
Expand All @@ -434,6 +463,8 @@ class Delegator
std::map<std::string, std::function<void()>*> m_funcs_v;
std::map<std::string, std::function<void(bool)>*> m_funcs_v_b;
std::map<std::string, std::function<void(double)>*> m_funcs_v_d;
std::map<std::string,
std::function<void(const AnyMap&, const UnitStack&)>*> m_funcs_v_cAMr_cUSr;
std::map<std::string,
std::function<void(std::array<size_t, 1>, double*)>*> m_funcs_v_dp;
std::map<std::string,
Expand Down
8 changes: 8 additions & 0 deletions include/cantera/kinetics/ReactionRateDelegator.h
Expand Up @@ -23,6 +23,9 @@ class ReactionRateDelegator : public Delegator, public ReactionRate
return 0.0; // necessary to set lambda's function signature
}
);
install("setParameters", m_setParameters,
[this](const AnyMap& node, const UnitStack& units) {
ReactionRate::setParameters(node, units); });
}

virtual unique_ptr<MultiRateBase> newMultiRate() const override {
Expand All @@ -43,8 +46,13 @@ class ReactionRateDelegator : public Delegator, public ReactionRate
return m_evalFromStruct(&T);
}

void setParameters(const AnyMap& node, const UnitStack& units) override {
m_setParameters(node, units);
}

private:
std::function<double(void*)> m_evalFromStruct;
std::function<void(const AnyMap&, const UnitStack&)> m_setParameters;
};

}
Expand Down
4 changes: 4 additions & 0 deletions interfaces/cython/cantera/delegator.pxd
Expand Up @@ -6,6 +6,7 @@

from .ctcxx cimport *
from .func1 cimport *
from .units cimport CxxUnitStack

cdef extern from "<array>" namespace "std" nogil:
cdef cppclass size_array1 "std::array<size_t, 1>":
Expand All @@ -28,6 +29,7 @@ cdef extern from "cantera/base/Delegator.h" namespace "Cantera":
void setDelegate(string&, function[void()], string&) except +translate_exception
void setDelegate(string&, function[void(cbool)], string&) except +translate_exception
void setDelegate(string&, function[void(double)], string&) except +translate_exception
void setDelegate(string&, function[void(const CxxAnyMap&, const CxxUnitStack&)], string&) except +translate_exception
void setDelegate(string&, function[void(size_array1, double*)], string&) except +translate_exception
void setDelegate(string&, function[void(size_array1, double, double*)], string&) except +translate_exception
void setDelegate(string&, function[void(size_array2, double, double*, double*)], string&) except +translate_exception
Expand All @@ -43,6 +45,8 @@ cdef extern from "cantera/cython/funcWrapper.h":
cdef function[void(double)] pyOverride(PyObject*, void(PyFuncInfo&, double))
cdef function[void(cbool)] pyOverride(PyObject*, void(PyFuncInfo&, cbool))
cdef function[void()] pyOverride(PyObject*, void(PyFuncInfo&))
cdef function[void(const CxxAnyMap&, const CxxUnitStack&)] pyOverride(
PyObject*, void(PyFuncInfo&, const CxxAnyMap&, const CxxUnitStack&))
cdef function[void(size_array1, double*)] pyOverride(
PyObject*, void(PyFuncInfo&, size_array1, double*))
cdef function[void(size_array1, double, double*)] pyOverride(
Expand Down
19 changes: 18 additions & 1 deletion interfaces/cython/cantera/delegator.pyx
Expand Up @@ -5,7 +5,8 @@ import inspect
import sys

from ._utils import CanteraError
from ._utils cimport stringify, pystr
from ._utils cimport stringify, pystr, anymap_to_dict
from .units cimport Units
from .reaction import ExtensibleRate
from cython.operator import dereference as deref

Expand Down Expand Up @@ -103,6 +104,19 @@ cdef void callback_v_b(PyFuncInfo& funcInfo, cbool arg):
funcInfo.setExceptionType(<PyObject*>exc_type)
funcInfo.setExceptionValue(<PyObject*>exc_value)

# Wrapper for functions of type void(const AnyMap&, const UnitStack&)
cdef void callback_v_cAMr_cUSr(PyFuncInfo& funcInfo, const CxxAnyMap& arg1,
const CxxUnitStack& arg2):

pyArg1 = anymap_to_dict(<CxxAnyMap&>arg1) # cast away constness
pyArg2 = Units.copy(arg2.product())
try:
(<object>funcInfo.func())(pyArg1, pyArg2)
except BaseException as e:
exc_type, exc_value = sys.exc_info()[:2]
funcInfo.setExceptionType(<PyObject*>exc_type)
funcInfo.setExceptionValue(<PyObject*>exc_value)

# Wrapper for functions of type void(double*)
cdef void callback_v_dp(PyFuncInfo& funcInfo, size_array1 sizes, double* arg):
cdef double[:] view = <double[:sizes[0]]>arg if sizes[0] else None
Expand Down Expand Up @@ -276,6 +290,9 @@ cdef int assign_delegates(obj, CxxDelegator* delegator) except -1:
elif callback == 'void(double)':
delegator.setDelegate(cxx_name,
pyOverride(<PyObject*>method, callback_v_d), cxx_when)
elif callback == 'void(AnyMap&,UnitStack&)':
delegator.setDelegate(cxx_name,
pyOverride(<PyObject*>method, callback_v_cAMr_cUSr), cxx_when)
elif callback == 'void(double*)':
delegator.setDelegate(cxx_name,
pyOverride(<PyObject*>method, callback_v_dp), cxx_when)
Expand Down
5 changes: 4 additions & 1 deletion interfaces/cython/cantera/reaction.pyx
Expand Up @@ -710,11 +710,13 @@ cdef class CustomRate(ReactionRate):

self.cxx_object().setRateFunction(self._rate_func._func)


cdef class ExtensibleRate(ReactionRate):
_reaction_rate_type = "extensible"

delegatable_methods = {
"eval": ("evalFromStruct", "double(void*)")
"eval": ("evalFromStruct", "double(void*)"),
"set_parameters": ("setParameters", "void(AnyMap&, UnitStack&)")
}
def __cinit__(self, *args, init=True, **kwargs):
if init:
Expand All @@ -734,6 +736,7 @@ cdef class ExtensibleRate(ReactionRate):
self.rate = rate
assign_delegates(self, dynamic_cast[CxxDelegatorPtr](self.rate))


cdef class InterfaceRateBase(ArrheniusRateBase):
"""
Base class collecting commonly used features of Arrhenius-type rate objects
Expand Down
4 changes: 4 additions & 0 deletions interfaces/cython/cantera/units.pxd
Expand Up @@ -23,6 +23,10 @@ cdef extern from "cantera/base/Units.h" namespace "Cantera":
stdmap[string, string] defaults()
void setDefaults(stdmap[string, string]&) except +translate_exception

cdef cppclass CxxUnitStack "Cantera::UnitStack":
CxxUnitStack()
CxxUnits product()


cdef class Units:
cdef CxxUnits units
Expand Down
2 changes: 2 additions & 0 deletions src/extensions/PythonExtensionManager.cpp
Expand Up @@ -123,6 +123,8 @@ void PythonExtensionManager::registerPythonRateBuilder(
"Problem in ct_newPythonExtensibleRate:\n{}",
getPythonExceptionInfo());
}
//! Call setParameters after the delegated functions have been connected
delegator->setParameters(params, units);

// Make the delegator responsible for eventually deleting the Python object
Py_IncRef(extRate);
Expand Down
1 change: 0 additions & 1 deletion src/extensions/pythonExtensions.pyx
Expand Up @@ -39,5 +39,4 @@ cdef public object ct_newPythonExtensibleRate(CxxReactionRateDelegator* delegato
mod = importlib.import_module(module_name.decode())
cdef ExtensibleRate rate = getattr(mod, class_name.decode())(init=False)
rate.set_cxx_object(delegator)
assign_delegates(rate, delegator)
return rate
1 change: 1 addition & 0 deletions test/data/extensible-reactions.yaml
Expand Up @@ -12,3 +12,4 @@ phases:
reactions:
- equation: H + O2 = HO2
type: square-rate
A: 3.14
24 changes: 15 additions & 9 deletions test/python/test_reaction.py
Expand Up @@ -1501,8 +1501,15 @@ def func(T):

@ct.extension(name="user-rate-1")
class UserRate1(ct.ExtensibleRate):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.A = np.nan

def after_set_parameters(self, params, units):
self.A = params["A"]

def replace_eval(self, T):
return 38.7 * T**2.7 * exp(-3150.15428/T)
return self.A * T**2.7 * exp(-3150.15428/T)


class TestExtensible(ReactionTests, utilities.CanteraTest):
Expand All @@ -1515,27 +1522,26 @@ class TestExtensible(ReactionTests, utilities.CanteraTest):
_rate_type = "user-rate-1"
_rate = {
"type": "user-rate-1",
"A": 38.7
}
_yaml = """
equation: H2 + O <=> H + OH
type: user-rate-1
A: 38.7
"""

def setUp(self):
super().setUp()
self._rate_obj = UserRate1()
self._rate_obj = ct.ReactionRate.from_dict(self._rate)

def test_no_rate(self):
pytest.skip("ExtensibleRate does not support 'empty' rates")
pytest.skip("ExtensibleRate does not yet support validation")

def test_from_dict(self):
pytest.skip("ExtensibleRate does not support serialization")

def from_rate(self, rate):
pytest.skip("ExtensibleRate does not support dict-based instantiation")
pytest.skip("ExtensibleRate does not yet support serialization")

def test_roundtrip(self):
pytest.skip("ExtensibleRate does not support roundtrip conversion")
pytest.skip("ExtensibleRate does not yet support roundtrip conversion")


class TestExtensible2(utilities.CanteraTest):
Expand All @@ -1548,7 +1554,7 @@ def test_load_module(self):

for T in np.linspace(300, 3000, 10):
gas.TP = T, None
assert gas.forward_rate_constants[0] == pytest.approx(T**2)
assert gas.forward_rate_constants[0] == pytest.approx(3.14 * T**2)


class InterfaceReactionTests(ReactionTests):
Expand Down
5 changes: 4 additions & 1 deletion test/python/user_ext.py
Expand Up @@ -2,5 +2,8 @@

@ct.extension(name="square-rate")
class SquareRate(ct.ExtensibleRate):
def after_set_parameters(self, node, units):
self.A = node["A"]

def replace_eval(self, T):
return T**2
return self.A * T**2

0 comments on commit bf1e2b4

Please sign in to comment.