Skip to content

Commit

Permalink
Register ExtensibleRate types with ReactionRateFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
speth committed Sep 11, 2022
1 parent 4f45b14 commit cd7b90f
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 6 deletions.
1 change: 1 addition & 0 deletions interfaces/cython/cantera/_cantera.pyx
Expand Up @@ -7,6 +7,7 @@
import sys
import importlib
import importlib.abc
import importlib.util

# Chooses the right init function
# See https://stackoverflow.com/a/52714500
Expand Down
35 changes: 32 additions & 3 deletions src/extensions/PythonExtensionManager.cpp
Expand Up @@ -5,9 +5,15 @@

#include "cantera/extensions/PythonExtensionManager.h"

#include "cantera/kinetics/ReactionRateFactory.h"
#include "cantera/kinetics/ReactionRateDelegator.h"
#include "pythonExtensions.h" // generated by Cython

#include <boost/algorithm/string.hpp>

namespace ba = boost::algorithm;
using namespace std;

namespace Cantera
{

Expand Down Expand Up @@ -49,12 +55,35 @@ PythonExtensionManager::PythonExtensionManager()
Py_DECREF(pyModule);
}

void PythonExtensionManager::registerRateBuilders(const std::string& extensionName)
void PythonExtensionManager::registerRateBuilders(const string& extensionName)
{
char* c_rateTypes = ct_getPythonExtensibleRateTypes(extensionName);
std::string rateTypes(c_rateTypes);
string rateTypes(c_rateTypes);
free(c_rateTypes);
writelog("Module returned types: '{}'\n", rateTypes);

// Each line in rateTypes is a (class name, rate name) pair, separated by a tab
vector<string> lines;
ba::split(lines, rateTypes, boost::is_any_of("\n"));
for (auto& line : lines) {
vector<string> tokens;
ba::split(tokens, line, boost::is_any_of("\t"));
if (tokens.size() != 2) {
CanteraError("PythonExtensionManager::registerRateBuilders",
"Got unparsable input from ct_getPythonExtensibleRateTypes:"
"\n'''{}\n'''", rateTypes);
}
string rateName = tokens[0];

// Create a function that constructs and links a C++ ReactionRateDelegator
// object and a Python ExtensibleRate object of a particular type, and register
// this as the builder for reactions of this type
auto builder = [rateName, extensionName](const AnyMap& params, const UnitStack& units) {
auto delegator = make_unique<ReactionRateDelegator>();
ct_addPythonExtensibleRate(delegator.get(), extensionName, rateName);
return delegator.release();
};
ReactionRateFactory::factory()->reg(tokens[1], builder);
}
}

};
26 changes: 23 additions & 3 deletions src/extensions/pythonExtensions.pyx
Expand Up @@ -7,14 +7,22 @@
from libcpp.string cimport string
from libc.stdlib cimport malloc
from libc.string cimport strcpy
from cpython.ref cimport Py_INCREF

import importlib
import inspect

import cantera as ct
from cantera.reaction cimport ExtensibleRate
from cantera.reaction cimport ExtensibleRate, CxxReactionRate
from cantera.delegator cimport CxxDelegator, assign_delegates

cdef public char* ct_getPythonExtensibleRateTypes(const string& module_name):

cdef extern from "cantera/kinetics/ReactionRateDelegator.h" namespace "Cantera":
cdef cppclass CxxReactionRateDelegator "Cantera::ReactionRateDelegator" (CxxDelegator, CxxReactionRate):
CxxReactionRateDelegator()


cdef public char* ct_getPythonExtensibleRateTypes(const string& module_name) except NULL:
"""
Load the named module and find classes derived from ExtensibleRate.
Expand All @@ -23,10 +31,22 @@ cdef public char* ct_getPythonExtensibleRateTypes(const string& module_name):
"""
mod = importlib.import_module(module_name.decode())
names = "\n".join(
f"{name} {cls._reaction_rate_type}"
f"{name}\t{cls._reaction_rate_type}"
for name, cls in inspect.getmembers(mod)
if inspect.isclass(cls) and issubclass(cls, ct.ExtensibleRate))
tmp = bytes(names.encode())
cdef char* c_string = <char*> malloc((len(tmp) + 1) * sizeof(char))
strcpy(c_string, tmp)
return c_string


cdef public int ct_addPythonExtensibleRate(CxxReactionRateDelegator* delegator,
const string& module_name,
const string& class_name) except -1:

mod = importlib.import_module(module_name.decode())
cdef ExtensibleRate rate = getattr(mod, class_name.decode())(init=False)
Py_INCREF(rate)
rate.set_cxx_object(delegator)
assign_delegates(rate, delegator)
return 0

0 comments on commit cd7b90f

Please sign in to comment.