Skip to content

Commit

Permalink
Implement ExtensibleRateData for Python interface
Browse files Browse the repository at this point in the history
  • Loading branch information
speth authored and ischoegl committed Jan 21, 2023
1 parent 2e65ab1 commit f687bfd
Show file tree
Hide file tree
Showing 21 changed files with 389 additions and 52 deletions.
47 changes: 38 additions & 9 deletions include/cantera/base/ExtensionManager.h
Expand Up @@ -7,10 +7,29 @@
// at https://cantera.org/license.txt for license and copyright information.

#include "cantera/base/ctexceptions.h"
#include <functional>

namespace Cantera
{

class ReactionDataDelegator;
class Solution;

//! A base class for managing the lifetime of an external object, such as a Python
//! object used by a Delegator
class ExternalHandle
{
public:
ExternalHandle() {}
ExternalHandle(const ExternalHandle&) = delete;
virtual ~ExternalHandle() = default;

//! Get the underlying external object
virtual void* get() {
throw NotImplementedError("ExternalHandle::get");
}
};

//! Base class for managing user-defined Cantera extensions written in other languages
//!
//! @since New in Cantera 3.0
Expand All @@ -24,16 +43,26 @@ class ExtensionManager
virtual void registerRateBuilders(const std::string& extensionName) {
throw NotImplementedError("ExtensionManager::registerRateBuilders");
};
};

//! A base class for managing the lifetime of an external object, such as a Python
//! object used by a Delegator
class ExternalHandle
{
public:
ExternalHandle() {}
ExternalHandle(const ExternalHandle&) = delete;
virtual ~ExternalHandle() = default;
static void wrapReactionData(const std::string& rateName,
ReactionDataDelegator& data);

static shared_ptr<ExternalHandle> wrapSolution(const std::string& rateName,
shared_ptr<Solution> soln);

static void registerReactionDataLinker(const std::string& rateName,
std::function<void(ReactionDataDelegator&)> link);

static void registerSolutionLinker(const std::string& rateName,
std::function<shared_ptr<ExternalHandle>(shared_ptr<Solution>)> link);

protected:
static std::map<std::string,
std::function<void(ReactionDataDelegator&)>> s_ReactionData_linkers;

static std::map<std::string,
std::function<shared_ptr<ExternalHandle>(shared_ptr<Solution>)>> s_Solution_linkers;

};

}
Expand Down
4 changes: 4 additions & 0 deletions include/cantera/extensions/PythonExtensionManager.h
Expand Up @@ -33,6 +33,10 @@ class PythonExtensionManager : public ExtensionManager
static void registerPythonRateBuilder(const std::string& moduleName,
const std::string& className, const std::string& rateName);

//! Function called from Cython to register an ExtensibleRateData implementation
static void registerPythonRateDataBuilder(const std::string& moduleName,
const std::string& className, const std::string& rateName);

private:
static bool s_imported;
};
Expand Down
4 changes: 4 additions & 0 deletions include/cantera/extensions/PythonHandle.h
Expand Up @@ -24,6 +24,10 @@ class PythonHandle : public ExternalHandle
}
}

void* get() {
return m_obj;
}

private:
PyObject* m_obj;
bool m_weak;
Expand Down
4 changes: 4 additions & 0 deletions include/cantera/kinetics/Kinetics.h
Expand Up @@ -1233,6 +1233,10 @@ class Kinetics
m_root = root;
}

shared_ptr<Solution> root() const {
return m_root.lock();
}

//! Calculate the reaction enthalpy of a reaction which
//! has not necessarily been added into the Kinetics object
virtual double reactionEnthalpy(const Composition& reactants,
Expand Down
4 changes: 4 additions & 0 deletions include/cantera/kinetics/MultiRate.h
Expand Up @@ -129,6 +129,10 @@ class MultiRate final : public MultiRateBase
return R.evalFromStruct(m_shared);
}

DataType& sharedData() {
return m_shared;
}

protected:
//! Helper function to process updates for rate types that implement the
//! `updateFromStruct` method.
Expand Down
66 changes: 46 additions & 20 deletions include/cantera/kinetics/ReactionRateDelegator.h
Expand Up @@ -13,29 +13,58 @@
namespace Cantera
{

//! Delegate methods of the ReactionRate class to external functions
//! Delegate methods of the ReactionData class to external functions
//!
//! @since New in Cantera 3.0
class ReactionRateDelegator : public Delegator, public ReactionRate
class ReactionDataDelegator : public Delegator, public ReactionData
{
public:
ReactionRateDelegator() {
install("evalFromStruct", m_evalFromStruct,
[](void*) {
throw NotImplementedError("ReactionRateDelegator::evalFromStruct");
return 0.0; // necessary to set lambda's function signature
}
);
install("setParameters", m_setParameters,
[this](const AnyMap& node, const UnitStack& units) {
ReactionRate::setParameters(node, units); });
ReactionDataDelegator();

bool update(const ThermoPhase& phase, const Kinetics& kin) override;

void update(double T) override {
throw NotImplementedError("ReactionDataDelegator",
"Not implemented for delegated reaction rates");
}

using ReactionData::update;

void setType(const std::string& name) {
m_rateType = name;
}

shared_ptr<ExternalHandle> getWrapper() const {
return m_wrappedData;
}

virtual unique_ptr<MultiRateBase> newMultiRate() const override {
return unique_ptr<MultiRateBase>(
new MultiRate<ReactionRateDelegator, ArrheniusData>);
void setWrapper(shared_ptr<ExternalHandle> wrapper) {
m_wrappedData = wrapper;
}

void setSolutionWrapperType(const std::string& type) {
m_solutionWrapperType = type;
}

protected:
std::string m_rateType;
std::string m_solutionWrapperType;
shared_ptr<ExternalHandle> m_wrappedSolution;
shared_ptr<ExternalHandle> m_wrappedData;

std::function<double(void*)> m_update;
};

//! Delegate methods of the ReactionRate class to external functions
//!
//! @since New in Cantera 3.0
class ReactionRateDelegator : public Delegator, public ReactionRate
{
public:
ReactionRateDelegator();

virtual unique_ptr<MultiRateBase> newMultiRate() const override;

//! Set the reaction type based on the user-provided reaction rate parameterization
void setType(const std::string& type) {
m_rateType = type;
Expand All @@ -50,11 +79,8 @@ class ReactionRateDelegator : public Delegator, public ReactionRate
//! Evaluate reaction rate
//!
//! @param shared_data data shared by all reactions of a given type
double evalFromStruct(const ArrheniusData& shared_data) {
// @TODO: replace passing pointer to temperature with a language-specific
// wrapper of the ReactionData object
double T = shared_data.temperature;
return m_evalFromStruct(&T);
double evalFromStruct(const ReactionDataDelegator& shared_data) {
return m_evalFromStruct(shared_data.getWrapper()->get());
}

void setParameters(const AnyMap& node, const UnitStack& units) override {
Expand Down
3 changes: 2 additions & 1 deletion interfaces/cython/cantera/delegator.pxd
Expand Up @@ -60,11 +60,12 @@ cdef extern from "cantera/cython/funcWrapper.h":
cdef function[int(size_t&, const string&)] pyOverride(
PyObject*, int(PyFuncInfo&, size_t&, const string&))


cdef extern from "cantera/extensions/PythonExtensionManager.h" namespace "Cantera":
cdef cppclass CxxPythonExtensionManager "Cantera::PythonExtensionManager":
@staticmethod
void registerPythonRateBuilder(string&, string&, string&) except +translate_exception
@staticmethod
void registerPythonRateDataBuilder(string&, string&, string&) except +translate_exception

ctypedef CxxDelegator* CxxDelegatorPtr

Expand Down
14 changes: 11 additions & 3 deletions interfaces/cython/cantera/delegator.pyx
Expand Up @@ -7,7 +7,7 @@ import sys
from ._utils import CanteraError
from ._utils cimport stringify, pystr, anymap_to_dict
from .units cimport Units
from .reaction import ExtensibleRate
from .reaction import ExtensibleRate, ExtensibleRateData
from cython.operator import dereference as deref

# ## Implementation for each delegated function type
Expand Down Expand Up @@ -157,7 +157,7 @@ cdef void callback_v_dp_dp_dp(PyFuncInfo& funcInfo,
# Wrapper for functions of type double(void*)
cdef int callback_d_vp(PyFuncInfo& funcInfo, double& out, void* obj):
try:
ret = (<object>funcInfo.func())(deref(<double*>obj))
ret = (<object>funcInfo.func())(<object>obj)
if ret is None:
return 0
else:
Expand Down Expand Up @@ -332,8 +332,9 @@ cdef int assign_delegates(obj, CxxDelegator* delegator) except -1:
# ReactionRateFactory. This list is read by PythonExtensionManager::registerRateBuilders
# and then cleared.
_rate_delegators = []
_rate_data_delegators = []

def extension(*, name):
def extension(*, name, data=None):
"""
A decorator for declaring Cantera extensions that should be registered with
the corresponding factory classes to create objects with the specified *name*.
Expand Down Expand Up @@ -386,6 +387,13 @@ def extension(*, name):
# Deferred registration supports the case where the main application
# is not Python
_rate_delegators.append((cls.__module__, cls.__name__, name))

# Register the ReactionData delegator
if not issubclass(data, ExtensibleRateData):
raise ValueError("'data' must inherit from 'ExtensibleRateData'")
CxxPythonExtensionManager.registerPythonRateDataBuilder(
stringify(data.__module__), stringify(data.__name__), stringify(name))
_rate_data_delegators.append((data.__module__, data.__name__, name))
else:
raise TypeError(f"{cls} is not extensible")
return cls
Expand Down
6 changes: 6 additions & 0 deletions interfaces/cython/cantera/reaction.pxd
Expand Up @@ -189,6 +189,9 @@ cdef extern from "cantera/kinetics/Custom.h" namespace "Cantera":


cdef extern from "cantera/kinetics/ReactionRateDelegator.h" namespace "Cantera":
cdef cppclass CxxReactionDataDelegator "Cantera::ReactionDataDelegator":
CxxReactionDataDelegator()

cdef cppclass CxxReactionRateDelegator "Cantera::ReactionRateDelegator" (CxxReactionRate):
CxxReactionRateDelegator()
void setType(string&)
Expand Down Expand Up @@ -256,6 +259,9 @@ cdef class CustomRate(ReactionRate):
cdef class ExtensibleRate(ReactionRate):
cdef set_cxx_object(self, CxxReactionRate* rate=*)

cdef class ExtensibleRateData:
cdef set_cxx_object(self, CxxReactionDataDelegator* rate)

cdef class InterfaceRateBase(ArrheniusRateBase):
cdef CxxInterfaceRateBase* interface

Expand Down
8 changes: 8 additions & 0 deletions interfaces/cython/cantera/reaction.pyx
Expand Up @@ -774,6 +774,14 @@ cdef class ExtensibleRate(ReactionRate):
(<CxxReactionRateDelegator*>self.rate).setType(
stringify(self._reaction_rate_type))

cdef class ExtensibleRateData:
delegatable_methods = {
"update": ("update", "double(void*)")
}

cdef set_cxx_object(self, CxxReactionDataDelegator* data):
assign_delegates(self, dynamic_cast[CxxDelegatorPtr](data))


cdef class InterfaceRateBase(ArrheniusRateBase):
"""
Expand Down
2 changes: 1 addition & 1 deletion interfaces/cython/cantera/solutionbase.pxd
Expand Up @@ -73,7 +73,7 @@ ctypedef void (*transportPolyMethod1i)(CxxTransport*, size_t, double*) except +t
ctypedef void (*transportPolyMethod2i)(CxxTransport*, size_t, size_t, double*) except +translate_exception

cdef _assign_Solution(_SolutionBase soln, shared_ptr[CxxSolution] cxx_soln,
pybool reset_adjacent)
pybool reset_adjacent, pybool weak=?)
cdef object _wrap_Solution(shared_ptr[CxxSolution] cxx_soln)

cdef class _SolutionBase:
Expand Down
9 changes: 6 additions & 3 deletions interfaces/cython/cantera/solutionbase.pyx
Expand Up @@ -100,7 +100,7 @@ cdef class _SolutionBase:

cdef shared_ptr[CxxExternalHandle] handle
handle.reset(new CxxPythonHandle(<PyObject*>self, True))
self.base.holdExternalHandle(stringify("python-Solution"), handle)
self.base.holdExternalHandle(stringify("python"), handle)

def __init__(self, *args, **kwargs):
if isinstance(self, Transport) and kwargs.get("init", True):
Expand Down Expand Up @@ -425,8 +425,11 @@ cdef class _SolutionBase:
# These cdef functions are declared as free functions to avoid creating layout
# conflicts with types derived from _SolutionBase
cdef _assign_Solution(_SolutionBase soln, shared_ptr[CxxSolution] cxx_soln,
pybool reset_adjacent):
soln._base = cxx_soln
pybool reset_adjacent, pybool weak=False):
if not weak:
# When the main application isn't Python, we should only hold a weak reference
# here, since the C++ Solution object owns this Python Solution.
soln._base = cxx_soln
soln.base = cxx_soln.get()
soln.thermo = soln.base.thermo().get()
soln.kinetics = soln.base.kinetics().get()
Expand Down
18 changes: 15 additions & 3 deletions samples/python/kinetics/custom_reactions.py
Expand Up @@ -32,15 +32,27 @@

# construct reactions based on ExtensibleRate: replace 2nd reaction with equivalent
# ExtensibleRate
@ct.extension(name="extensible-Arrhenius")
class ExtensibleArrheniusData(ct.ExtensibleRateData):
def __init__(self):
self.T = None

def replace_update(self, gas):
T = gas.T
if self.T != T:
self.T = T
return True
else:
return False

@ct.extension(name="extensible-Arrhenius", data=ExtensibleArrheniusData)
class ExtensibleArrhenius(ct.ExtensibleRate):
def after_set_parameters(self, params, units):
self.A = params["A"]
self.b = params["b"]
self.Ea_R = params["Ea_R"]

def replace_eval(self, T):
return self.A * T**self.b * exp(-self.Ea_R/T)
def replace_eval(self, data):
return self.A * data.T**self.b * exp(-self.Ea_R/data.T)

extensible_yaml = """
equation: H2 + O <=> H + OH
Expand Down

0 comments on commit f687bfd

Please sign in to comment.