Skip to content

Commit

Permalink
Remove need for prefixing required delegates
Browse files Browse the repository at this point in the history
  • Loading branch information
speth authored and ischoegl committed Jan 21, 2023
1 parent 8ece96f commit 5e5e7df
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 19 deletions.
14 changes: 10 additions & 4 deletions interfaces/cython/cantera/delegator.pyx
Expand Up @@ -237,8 +237,14 @@ cdef int assign_delegates(obj, CxxDelegator* delegator) except -1:
cdef string cxx_name
cdef string cxx_when
obj._delegates = []
for name in obj.delegatable_methods:
when = None
for name, options in obj.delegatable_methods.items():
if len(options) == 3:
# Delegate with pre-selected mode, without using prefix on method name
when = options[2]
method = getattr(obj, name)
else:
when = None

replace = 'replace_{}'.format(name)
if hasattr(obj, replace):
when = 'replace'
Expand All @@ -263,8 +269,8 @@ cdef int assign_delegates(obj, CxxDelegator* delegator) except -1:
if when is None:
continue

cxx_name = stringify(obj.delegatable_methods[name][0])
callback = obj.delegatable_methods[name][1].replace(' ', '')
cxx_name = stringify(options[0])
callback = options[1].replace(' ', '')

# Make sure that the number of arguments needed by the C++ function
# corresponds to the number of arguments accepted by the Python delegate
Expand Down
6 changes: 3 additions & 3 deletions interfaces/cython/cantera/reaction.pyx
Expand Up @@ -751,8 +751,8 @@ cdef class ExtensibleRate(ReactionRate):
_reaction_rate_type = "extensible"

delegatable_methods = {
"eval": ("evalFromStruct", "double(void*)"),
"set_parameters": ("setParameters", "void(AnyMap&, UnitStack&)")
"eval": ("evalFromStruct", "double(void*)", "replace"),
"set_parameters": ("setParameters", "void(AnyMap&, UnitStack&)", "after")
}
def __cinit__(self, *args, init=True, **kwargs):
if init:
Expand All @@ -776,7 +776,7 @@ cdef class ExtensibleRate(ReactionRate):

cdef class ExtensibleRateData:
delegatable_methods = {
"update": ("update", "double(void*)")
"update": ("update", "double(void*)", "replace")
}

cdef set_cxx_object(self, CxxReactionDataDelegator* data):
Expand Down
6 changes: 3 additions & 3 deletions samples/python/kinetics/custom_reactions.py
Expand Up @@ -40,7 +40,7 @@ class ExtensibleArrheniusData(ct.ExtensibleRateData):
def __init__(self):
self.T = None

def replace_update(self, gas):
def update(self, gas):
T = gas.T
if self.T != T:
self.T = T
Expand All @@ -51,12 +51,12 @@ def replace_update(self, gas):
@ct.extension(name="extensible-Arrhenius", data=ExtensibleArrheniusData)
class ExtensibleArrhenius(ct.ExtensibleRate):
__slots__ = ("A", "b", "Ea_R")
def after_set_parameters(self, params, units):
def set_parameters(self, params, units):
self.A = params["A"]
self.b = params["b"]
self.Ea_R = params["Ea_R"]

def replace_eval(self, data):
def eval(self, data):
return self.A * data.T**self.b * exp(-self.Ea_R/data.T)

extensible_yaml2 = """
Expand Down
6 changes: 3 additions & 3 deletions test/python/test_reaction.py
Expand Up @@ -1503,7 +1503,7 @@ class UserRate1Data(ct.ExtensibleRateData):
def __init__(self):
self.T = None

def replace_update(self, gas):
def update(self, gas):
T = gas.T
if T != self.T:
self.T = T
Expand All @@ -1518,10 +1518,10 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.A = np.nan

def after_set_parameters(self, params, units):
def set_parameters(self, params, units):
self.A = params["A"]

def replace_eval(self, data):
def eval(self, data):
return self.A * data.T**2.7 * exp(-3150.15428/data.T)


Expand Down
6 changes: 3 additions & 3 deletions test/python/user_ext.py
Expand Up @@ -3,16 +3,16 @@
class SquareRateData(ct.ExtensibleRateData):
__slots__ = ("Tsquared",)

def replace_update(self, gas):
def update(self, gas):
self.Tsquared = gas.T**2
return True

@ct.extension(name="square-rate", data=SquareRateData)
class SquareRate(ct.ExtensibleRate):
__slots__ = ("A",)

def after_set_parameters(self, node, units):
def set_parameters(self, node, units):
self.A = node["A"]

def replace_eval(self, data):
def eval(self, data):
return self.A * data.Tsquared
6 changes: 3 additions & 3 deletions test/python/user_ext_invalid.py
Expand Up @@ -3,14 +3,14 @@
this is a syntax error

class SquareRateData(ct.ExtensibleRateData):
def replace_update(self, gas):
def update(self, gas):
self.Tsquared = gas.T**2
return True

@ct.extension(name="square-rate", data=SquareRateData)
class SquareRate(ct.ExtensibleRate):
def after_set_parameters(self, node, units):
def set_parameters(self, node, units):
self.A = node["A"]

def replace_eval(self, data):
def eval(self, data):
return self.A * data.Tsquared

0 comments on commit 5e5e7df

Please sign in to comment.