Skip to content

Commit

Permalink
[Cython] Retain units info when converting AnyMap to Python
Browse files Browse the repository at this point in the history
  • Loading branch information
speth authored and ischoegl committed Apr 18, 2023
1 parent 803f106 commit 4b88a20
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 25 deletions.
3 changes: 3 additions & 0 deletions include/cantera/base/AnyMap.h
Expand Up @@ -621,6 +621,9 @@ class AnyMap : public AnyBase
//! Return the default units that should be used to convert stored values
const UnitSystem& units() const { return *m_units; }

//! @copydoc units()
shared_ptr<UnitSystem> unitsShared() const { return m_units; }

//! Use the supplied UnitSystem to set the default units, and recursively
//! process overrides from nodes named `units`.
/*!
Expand Down
7 changes: 6 additions & 1 deletion interfaces/cython/cantera/_utils.pxd
Expand Up @@ -7,6 +7,7 @@
from libcpp.unordered_map cimport unordered_map

from .ctcxx cimport *
from .units cimport UnitSystem

cdef extern from "cantera/base/AnyMap.h" namespace "Cantera":
cdef cppclass CxxAnyValue "Cantera::AnyValue"
Expand Down Expand Up @@ -37,6 +38,7 @@ cdef extern from "cantera/base/AnyMap.h" namespace "Cantera":
void update(CxxAnyMap& other, cbool)
string keys_str()
void applyUnits()
shared_ptr[CxxUnitSystem] unitsShared()

cdef cppclass CxxAnyValue "Cantera::AnyValue":
CxxAnyValue()
Expand Down Expand Up @@ -86,14 +88,17 @@ cdef extern from "cantera/cython/utils_utils.h":

cdef void CxxSetLogger "setLogger" (CxxPythonLogger*)

cdef class AnyMap(dict):
cdef _set_CxxUnitSystem(self, shared_ptr[CxxUnitSystem] units)
cdef UnitSystem unitsystem

cdef string stringify(x) except *
cdef pystr(string x)

cdef comp_map_to_dict(Composition m)
cdef Composition comp_map(X) except *

cdef CxxAnyMap dict_to_anymap(dict data, cbool hyphenize=*) except *
cdef CxxAnyMap dict_to_anymap(data, cbool hyphenize=*) except *
cdef anymap_to_dict(CxxAnyMap& m)

cdef CxxAnyValue python_to_anyvalue(item, name=*) except *
Expand Down
24 changes: 19 additions & 5 deletions interfaces/cython/cantera/_utils.pyx
Expand Up @@ -153,6 +153,18 @@ class CanteraError(RuntimeError):

cdef public PyObject* pyCanteraError = <PyObject*>CanteraError


cdef class AnyMap(dict):
def __cinit__(self, *args, **kwawrgs):
self.unitsystem = UnitSystem()

cdef _set_CxxUnitSystem(self, shared_ptr[CxxUnitSystem] units):
self.unitsystem._set_unitSystem(units)

def default_units(self):
return self.unitsystem.defaults()


cdef anyvalue_to_python(string name, CxxAnyValue& v):
cdef CxxAnyMap a
cdef CxxAnyValue b
Expand Down Expand Up @@ -207,12 +219,14 @@ cdef anyvalue_to_python(string name, CxxAnyValue& v):
cdef anymap_to_dict(CxxAnyMap& m):
cdef pair[string,CxxAnyValue] item
m.applyUnits()
if m.empty():
return {}
return {pystr(item.first): anyvalue_to_python(item.first, item.second)
for item in m.ordered()}
cdef AnyMap out = AnyMap()
out._set_CxxUnitSystem(m.unitsShared())
for item in m.ordered():
out[pystr(item.first)] = anyvalue_to_python(item.first, item.second)
return out


cdef CxxAnyMap dict_to_anymap(dict data, cbool hyphenize=False) except *:
cdef CxxAnyMap dict_to_anymap(data, cbool hyphenize=False) except *:
cdef CxxAnyMap m
if hyphenize:
# replace "_" by "-": while Python dictionaries typically use "_" in key names,
Expand Down
24 changes: 12 additions & 12 deletions interfaces/cython/cantera/reaction.pyx
Expand Up @@ -221,7 +221,7 @@ cdef class ArrheniusRate(ArrheniusRateBase):
if init:
self._cinit(input_data, A=A, b=b, Ea=Ea)

def _from_dict(self, dict input_data):
def _from_dict(self, input_data):
self._rate.reset(new CxxArrheniusRate(dict_to_anymap(input_data)))

def _from_parameters(self, A, b, Ea):
Expand All @@ -248,7 +248,7 @@ cdef class BlowersMaselRate(ArrheniusRateBase):
if init:
self._cinit(input_data, A=A, b=b, Ea0=Ea0, w=w)

def _from_dict(self, dict input_data):
def _from_dict(self, input_data):
self._rate.reset(new CxxBlowersMaselRate(dict_to_anymap(input_data)))

def _from_parameters(self, A, b, Ea0, w):
Expand Down Expand Up @@ -321,7 +321,7 @@ cdef class TwoTempPlasmaRate(ArrheniusRateBase):
"""
return self.rate.eval(temperature, elec_temp)

def _from_dict(self, dict input_data):
def _from_dict(self, input_data):
self._rate.reset(
new CxxTwoTempPlasmaRate(dict_to_anymap(input_data, hyphenize=True))
)
Expand Down Expand Up @@ -448,7 +448,7 @@ cdef class LindemannRate(FalloffRate):
"""
_reaction_rate_type = "Lindemann"

def _from_dict(self, dict input_data):
def _from_dict(self, input_data):
self._rate.reset(
new CxxLindemannRate(dict_to_anymap(input_data, hyphenize=True))
)
Expand All @@ -468,7 +468,7 @@ cdef class TroeRate(FalloffRate):
"""
_reaction_rate_type = "Troe"

def _from_dict(self, dict input_data):
def _from_dict(self, input_data):
self._rate.reset(
new CxxTroeRate(dict_to_anymap(input_data, hyphenize=True))
)
Expand All @@ -488,7 +488,7 @@ cdef class SriRate(FalloffRate):
"""
_reaction_rate_type = "SRI"

def _from_dict(self, dict input_data):
def _from_dict(self, input_data):
self._rate.reset(
new CxxSriRate(dict_to_anymap(input_data, hyphenize=True))
)
Expand All @@ -504,7 +504,7 @@ cdef class TsangRate(FalloffRate):
"""
_reaction_rate_type = "Tsang"

def _from_dict(self, dict input_data):
def _from_dict(self, input_data):
self._rate.reset(
new CxxTsangRate(dict_to_anymap(input_data, hyphenize=True))
)
Expand Down Expand Up @@ -834,7 +834,7 @@ cdef class InterfaceRateBase(ArrheniusRateBase):
cdef CxxAnyMap cxx_deps
self.interface.getCoverageDependencies(cxx_deps)
return anymap_to_dict(cxx_deps)
def __set__(self, dict deps):
def __set__(self, deps):
cdef CxxAnyMap cxx_deps = dict_to_anymap(deps)

self.interface.setCoverageDependencies(cxx_deps)
Expand Down Expand Up @@ -892,7 +892,7 @@ cdef class InterfaceArrheniusRate(InterfaceRateBase):
if init:
self._cinit(input_data, A=A, b=b, Ea=Ea)

def _from_dict(self, dict input_data):
def _from_dict(self, input_data):
self._rate.reset(new CxxInterfaceArrheniusRate(dict_to_anymap(input_data)))

def _from_parameters(self, A, b, Ea):
Expand Down Expand Up @@ -920,7 +920,7 @@ cdef class InterfaceBlowersMaselRate(InterfaceRateBase):
if init:
self._cinit(input_data, A=A, b=b, Ea0=Ea0, w=w)

def _from_dict(self, dict input_data):
def _from_dict(self, input_data):
self._rate.reset(new CxxInterfaceBlowersMaselRate(dict_to_anymap(input_data)))

def _from_parameters(self, A, b, Ea0, w):
Expand Down Expand Up @@ -1030,7 +1030,7 @@ cdef class StickingArrheniusRate(StickRateBase):
if init:
self._cinit(input_data, A=A, b=b, Ea=Ea)

def _from_dict(self, dict input_data):
def _from_dict(self, input_data):
self._rate.reset(new CxxStickingArrheniusRate(dict_to_anymap(input_data)))

def _from_parameters(self, A, b, Ea):
Expand All @@ -1057,7 +1057,7 @@ cdef class StickingBlowersMaselRate(StickRateBase):
if init:
self._cinit(input_data, A=A, b=b, Ea0=Ea0, w=w)

def _from_dict(self, dict input_data):
def _from_dict(self, input_data):
self._rate.reset(new CxxStickingBlowersMaselRate(dict_to_anymap(input_data)))

def _from_parameters(self, A, b, Ea0, w):
Expand Down
4 changes: 3 additions & 1 deletion interfaces/cython/cantera/units.pxd
Expand Up @@ -34,4 +34,6 @@ cdef class Units:
cdef copy(CxxUnits)

cdef class UnitSystem:
cdef CxxUnitSystem unitsystem
cdef _set_unitSystem(self, shared_ptr[CxxUnitSystem] units)
cdef shared_ptr[CxxUnitSystem] _unitsystem
cdef CxxUnitSystem* unitsystem
10 changes: 9 additions & 1 deletion interfaces/cython/cantera/units.pyx
Expand Up @@ -94,13 +94,21 @@ cdef class UnitSystem:
ct.UnitSystem()
"""
def __cinit__(self, units=None):
self.unitsystem = CxxUnitSystem()
self._unitsystem.reset(new CxxUnitSystem())
self.unitsystem = self._unitsystem.get()
if units:
self.units = units

def __repr__(self):
return f"<UnitSystem at {id(self):0x}>"

cdef _set_unitSystem(self, shared_ptr[CxxUnitSystem] units):
self._unitsystem = units
self.unitsystem = self._unitsystem.get()

def defaults(self):
return self.unitsystem.defaults()

property units:
"""
Units used by the unit system
Expand Down
2 changes: 1 addition & 1 deletion interfaces/cython/cantera/yamlwriter.pxd
Expand Up @@ -24,4 +24,4 @@ cdef class YamlWriter:
cdef shared_ptr[CxxYamlWriter] _writer
cdef CxxYamlWriter* writer
@staticmethod
cdef CxxUnitSystem _get_unitsystem(UnitSystem units)
cdef shared_ptr[CxxUnitSystem] _get_unitsystem(UnitSystem units)
8 changes: 4 additions & 4 deletions interfaces/cython/cantera/yamlwriter.pyx
Expand Up @@ -3,6 +3,7 @@

from .solutionbase cimport *
from ._utils cimport *
from cython.operator import dereference as deref

cdef class YamlWriter:
"""
Expand Down Expand Up @@ -64,12 +65,11 @@ cdef class YamlWriter:
def __set__(self, units):
if not isinstance(units, UnitSystem):
units = UnitSystem(units)
cdef CxxUnitSystem cxxunits = YamlWriter._get_unitsystem(units)
self.writer.setUnitSystem(cxxunits)
self.writer.setUnitSystem(deref(YamlWriter._get_unitsystem(units).get()))

@staticmethod
cdef CxxUnitSystem _get_unitsystem(UnitSystem units):
return units.unitsystem
cdef shared_ptr[CxxUnitSystem] _get_unitsystem(UnitSystem units):
return units._unitsystem

def __reduce__(self):
raise NotImplementedError('YamlWriter object is not picklable')
Expand Down

0 comments on commit 4b88a20

Please sign in to comment.