Skip to content

Commit

Permalink
Use decorator to register ExtensibleRate objects
Browse files Browse the repository at this point in the history
  • Loading branch information
speth committed Sep 11, 2022
1 parent ad19892 commit da59a90
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 74 deletions.
10 changes: 10 additions & 0 deletions include/cantera/extensions/PythonExtensionManager.h
Expand Up @@ -14,11 +14,21 @@ namespace Cantera
//! Class for managing user-defined Cantera extensions written in Python
//!
//! Handles Python initialization if the main application is not the Python interpreter.
//!
//! Imports a user-specified module, which must be on the Python path and registers
//! user-defined classes that are marked with the `@extension` decorator. See the
//! documentation for
//! <a href="../../sphinx/html/cython/utilities.html#cantera.extension">`@extension`</a>
//! in the Python documentation for more information.
class PythonExtensionManager : public ExtensionManager
{
public:
PythonExtensionManager();
virtual void registerRateBuilders(const std::string& extensionName) override;

//! Function called from Cython to register an ExtensibleRate implementation
static void registerPythonRateBuilder(const std::string& moduleName,
const std::string& className, const std::string& rateName);
};

}
Expand Down
4 changes: 0 additions & 4 deletions include/cantera/kinetics/ReactionRateDelegator.h
Expand Up @@ -25,10 +25,6 @@ class ReactionRateDelegator : public Delegator, public ReactionRate
);
}

ReactionRateDelegator(const AnyMap& node, const UnitStack& rate_units)
: ReactionRateDelegator()
{}

virtual unique_ptr<MultiRateBase> newMultiRate() const override {
return unique_ptr<MultiRateBase>(
new MultiRate<ReactionRateDelegator, ArrheniusData>);
Expand Down
5 changes: 5 additions & 0 deletions interfaces/cython/cantera/delegator.pxd
Expand Up @@ -57,6 +57,11 @@ cdef extern from "cantera/cython/funcWrapper.h":
PyObject*, int(PyFuncInfo&, size_t&, const string&))


cdef extern from "cantera/extensions/PythonExtensionManager.h" namespace "Cantera":
cdef cppclass CxxPythonExtensionManager "Cantera::PythonExtensionManager":
@staticmethod
void registerPythonRateBuilder(string&, string&, string&) except +translate_exception

ctypedef CxxDelegator* CxxDelegatorPtr

cdef int assign_delegates(object, CxxDelegator*) except -1
17 changes: 17 additions & 0 deletions interfaces/cython/cantera/delegator.pyx
Expand Up @@ -6,6 +6,7 @@ import sys

from ._utils import CanteraError
from ._utils cimport stringify, pystr
from .reaction import ExtensibleRate
from cython.operator import dereference as deref

# ## Implementation for each delegated function type
Expand Down Expand Up @@ -309,3 +310,19 @@ cdef int assign_delegates(obj, CxxDelegator* delegator) except -1:
obj._delegates.append(method)

return 0

def extension(*, name):
"""
A decorator for declaring Cantera extensions that should be registered with
the corresponding factory classes to create objects with the specified *name*.
"""
def decorator(cls):
if issubclass(cls, ExtensibleRate):
cls._reaction_rate_type = name
CxxPythonExtensionManager.registerPythonRateBuilder(
stringify(cls.__module__), stringify(cls.__name__), stringify(name))
else:
raise TypeError(f"{cls} is not extensible")
return cls

return decorator
75 changes: 37 additions & 38 deletions src/extensions/PythonExtensionManager.cpp
Expand Up @@ -56,6 +56,9 @@ PythonExtensionManager::PythonExtensionManager()
}

// PEP 489 Multi-phase initialization

// The 'pythonExtensions' Cython module defines some functions that are used
// to instantiate ExtensibleSomething objects.
PyModuleDef* modDef = (PyModuleDef*) PyInit_pythonExtensions();
if (!modDef->m_slots || !PyModuleDef_Init(modDef)) {
throw CanteraError("PythonExtensionManager::PythonExtensionManager",
Expand Down Expand Up @@ -89,49 +92,45 @@ PythonExtensionManager::PythonExtensionManager()

void PythonExtensionManager::registerRateBuilders(const string& extensionName)
{
char* c_rateTypes = ct_getPythonExtensibleRateTypes(extensionName);
if (c_rateTypes == nullptr) {
// Each rate builder class is decorated with @extension, which calls the
// registerPythonRateBuilder method to register that class. So all we have
// to do here is load the module.
PyObject* module_name = PyUnicode_FromString(extensionName.c_str());
PyObject* py_module = PyImport_Import(module_name);
Py_DECREF(module_name);
if (py_module == nullptr) {
throw CanteraError("PythonExtensionManager::registerRateBuilders",
"Problem loading module:\n{}", getPythonExceptionInfo());
}
string rateTypes(c_rateTypes);
free(c_rateTypes);

// Each line in rateTypes is a (class name, rate name) pair, separated by a tab
vector<string> lines;
ba::split(lines, rateTypes, boost::is_any_of("\n"));
for (auto& line : lines) {
vector<string> tokens;
ba::split(tokens, line, boost::is_any_of("\t"));
if (tokens.size() != 2) {
CanteraError("PythonExtensionManager::registerRateBuilders",
"Got unparsable input from ct_getPythonExtensibleRateTypes:"
"\n'''{}\n'''", rateTypes);
}

void PythonExtensionManager::registerPythonRateBuilder(
const std::string& moduleName, const std::string& className,
const std::string& rateName)
{
// Make sure the helper module has been loaded
PythonExtensionManager mgr;

// Create a function that constructs and links a C++ ReactionRateDelegator
// object and a Python ExtensibleRate object of a particular type, and register
// this as the builder for reactions of this type
auto builder = [moduleName, className](const AnyMap& params, const UnitStack& units) {
auto delegator = make_unique<ReactionRateDelegator>();
PyObject* extRate = ct_newPythonExtensibleRate(delegator.get(),
moduleName, className);
if (extRate == nullptr) {
throw CanteraError("PythonExtensionManager::registerRateBuilders",
"Problem in ct_newPythonExtensibleRate:\n{}",
getPythonExceptionInfo());
}

string rateName = tokens[0];

// Create a function that constructs and links a C++ ReactionRateDelegator
// object and a Python ExtensibleRate object of a particular type, and register
// this as the builder for reactions of this type
auto builder = [rateName, extensionName](const AnyMap& params, const UnitStack& units) {
auto delegator = make_unique<ReactionRateDelegator>();
PyObject* extRate = ct_newPythonExtensibleRate(delegator.get(),
extensionName, rateName);
if (extRate == nullptr) {
throw CanteraError("PythonExtensionManager::registerRateBuilders",
"Problem in ct_newPythonExtensibleRate:\n{}",
getPythonExceptionInfo());
}

// Make the delegator responsible for eventually deleting the Python object
Py_IncRef(extRate);
delegator->addCleanupFunc([extRate]() { Py_DecRef(extRate); });

return delegator.release();
};
ReactionRateFactory::factory()->reg(tokens[1], builder);
}
// Make the delegator responsible for eventually deleting the Python object
Py_IncRef(extRate);
delegator->addCleanupFunc([extRate]() { Py_DecRef(extRate); });

return delegator.release();
};
ReactionRateFactory::factory()->reg(rateName, builder);
}

};
18 changes: 0 additions & 18 deletions src/extensions/pythonExtensions.pyx
Expand Up @@ -32,24 +32,6 @@ cdef public char* ct_getExceptionString(object exType, object exValue, object ex
return c_string


cdef public char* ct_getPythonExtensibleRateTypes(const string& module_name) except NULL:
"""
Load the named module and find classes derived from ExtensibleRate.
Returns a string where each line contains the class name and the corresponding
rate name, separated by a space
"""
mod = importlib.import_module(module_name.decode())
names = "\n".join(
f"{name}\t{cls._reaction_rate_type}"
for name, cls in inspect.getmembers(mod)
if inspect.isclass(cls) and issubclass(cls, ct.ExtensibleRate))
tmp = bytes(names.encode())
cdef char* c_string = <char*> malloc((len(tmp) + 1) * sizeof(char))
strcpy(c_string, tmp)
return c_string


cdef public object ct_newPythonExtensibleRate(CxxReactionRateDelegator* delegator,
const string& module_name,
const string& class_name):
Expand Down
5 changes: 0 additions & 5 deletions src/kinetics/ReactionRateFactory.cpp
Expand Up @@ -16,7 +16,6 @@
#include "cantera/kinetics/InterfaceRate.h"
#include "cantera/kinetics/PlogRate.h"
#include "cantera/kinetics/TwoTempPlasmaRate.h"
#include "cantera/kinetics/ReactionRateDelegator.h"

namespace Cantera
{
Expand Down Expand Up @@ -99,10 +98,6 @@ ReactionRateFactory::ReactionRateFactory()
reg("sticking-Blowers-Masel", [](const AnyMap& node, const UnitStack& rate_units) {
return new StickingBlowersMaselRate(node, rate_units);
});

reg("extensible", [](const AnyMap& node, const UnitStack& rate_units) {
return new ReactionRateDelegator(node, rate_units);
});
}

shared_ptr<ReactionRate> newReactionRate(const std::string& type)
Expand Down
14 changes: 14 additions & 0 deletions test/data/extensible-reactions.yaml
@@ -0,0 +1,14 @@
extensions:
- type: python
name: user_ext

phases:
- name: gas
thermo: ideal-gas
species: [{h2o2.yaml/species: all}]
kinetics: gas
state: {T: 300.0, P: 1 atm}

reactions:
- equation: H + O2 = HO2
type: square-rate
41 changes: 32 additions & 9 deletions test/python/test_reaction.py
@@ -1,5 +1,6 @@
from math import exp
from pathlib import Path
import sys
import textwrap

import cantera as ct
Expand Down Expand Up @@ -1498,28 +1499,37 @@ def func(T):
assert (gas.forward_rate_constants == gas.T).all()


@ct.extension(name="user-rate-1")
class UserRate1(ct.ExtensibleRate):
def replace_eval(self, T):
return 38.7 * T**2.7 * exp(-3150.15428/T)


class TestExtensible(ReactionTests, utilities.CanteraTest):
# test Extensible reaction rate
class UserRate(ct.ExtensibleRate):
def replace_eval(self, T):
return 38.7 * T**2.7 * exp(-3150.15428/T)

# probe O + H2 <=> H + OH
_rate_cls = UserRate
_rate_cls = UserRate1
_equation = "H2 + O <=> H + OH"
_index = 0
_rate_type = "extensible"
_yaml = None
_rate_type = "user-rate-1"
_rate = {
"type": "user-rate-1",
}
_yaml = """
equation: H2 + O <=> H + OH
type: user-rate-1
"""

def setUp(self):
super().setUp()
self._rate_obj = self.UserRate()
self._rate_obj = UserRate1()

def test_no_rate(self):
pytest.skip("ExtensibleRate does not support 'empty' rates")

def from_yaml(self):
pytest.skip("ExtensibleRate does not support YAML")
def test_from_dict(self):
pytest.skip("ExtensibleRate does not support serialization")

def from_rate(self, rate):
pytest.skip("ExtensibleRate does not support dict-based instantiation")
Expand All @@ -1528,6 +1538,19 @@ def test_roundtrip(self):
pytest.skip("ExtensibleRate does not support roundtrip conversion")


class TestExtensible2(utilities.CanteraTest):
def test_load_module(self):
here = str(Path(__file__).parent)
if here not in sys.path:
sys.path.append(here)

gas = ct.Solution("extensible-reactions.yaml", transport_model=None)

for T in np.linspace(300, 3000, 10):
gas.TP = T, None
assert gas.forward_rate_constants[0] == pytest.approx(T**2)


class InterfaceReactionTests(ReactionTests):
# test suite for surface reaction expressions

Expand Down
6 changes: 6 additions & 0 deletions test/python/user_ext.py
@@ -0,0 +1,6 @@
import cantera as ct

@ct.extension(name="square-rate")
class SquareRate(ct.ExtensibleRate):
def replace_eval(self, T):
return T**2

0 comments on commit da59a90

Please sign in to comment.