Skip to content

Commit

Permalink
Merge pull request pybamm-team#1067 from pybamm-team/issue-1066-addit…
Browse files Browse the repository at this point in the history
…ional-casadi-funcitons

pybamm-team#1066 add numpy function sqrt, sin, cos and exp to convert_to_casadi
  • Loading branch information
valentinsulzer committed Jun 23, 2020
2 parents 17460e8 + 67f0fc1 commit 6c7a06e
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

## Bug fixes

- Allowed for pybamm functions exp, sin, cos, sqrt to be used in expression trees that
are converted to casadi format ([#1067](https://github.com/pybamm-team/PyBaMM/pull/1067)
- Fix a bug where variables that depend on y and z were transposed in `QuickPlot` ([#1055](https://github.com/pybamm-team/PyBaMM/pull/1055))

## Breaking changes
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/compare_lithium_ion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
import pybamm

pybamm.set_logging_level("INFO")
# pybamm.set_logging_level("INFO")

# load models
models = [
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def entries_string(self, value):
if issparse(entries):
self._entries_string = str(entries.__dict__)
else:
self._entries_string = entries.tostring()
self._entries_string = entries.tobytes()

def set_id(self):
""" See :meth:`pybamm.Symbol.set_id()`. """
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def entries_string(self, value):
self._entries_string = value
else:
entries = self.data
self._entries_string = entries.tostring()
self._entries_string = entries.tobytes()

def set_id(self):
""" See :meth:`pybamm.Symbol.set_id()`. """
Expand Down
22 changes: 22 additions & 0 deletions pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,28 @@ def _convert(self, symbol, t, y, y_dot, inputs):
return casadi.mmax(*converted_children)
elif symbol.function == np.abs:
return casadi.fabs(*converted_children)
elif symbol.function == np.sqrt:
return casadi.sqrt(*converted_children)
elif symbol.function == np.sin:
return casadi.sin(*converted_children)
elif symbol.function == np.arcsinh:
return casadi.arcsinh(*converted_children)
elif symbol.function == np.arccosh:
return casadi.arccosh(*converted_children)
elif symbol.function == np.tanh:
return casadi.tanh(*converted_children)
elif symbol.function == np.cosh:
return casadi.cosh(*converted_children)
elif symbol.function == np.sinh:
return casadi.sinh(*converted_children)
elif symbol.function == np.cos:
return casadi.cos(*converted_children)
elif symbol.function == np.exp:
return casadi.exp(*converted_children)
elif symbol.function == np.log:
return casadi.log(*converted_children)
elif symbol.function == np.sign:
return casadi.sign(*converted_children)
elif isinstance(symbol.function, (PchipInterpolator, CubicSpline)):
return casadi.interpolant("LUT", "bspline", [symbol.x], symbol.y)(
*converted_children
Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ def shape(self):
# Default behaviour is to try to evaluate the object directly
# Try with some large y, to avoid having to unpack (slow)
try:
y = np.linspace(0.1, 0.9, int(1e4))
y = np.nan * np.ones((1000, 1))
evaluated_self = self.evaluate(0, y, y, inputs="shape test")
# If that fails, fall back to calculating how big y should really be
except ValueError:
Expand All @@ -753,7 +753,7 @@ def shape(self):
len(x._evaluation_array) for x in state_vectors_in_node
)
# Pick a y that won't cause RuntimeWarnings
y = np.linspace(0.1, 0.9, min_y_size)
y = np.nan * np.ones((min_y_size, 1))
evaluated_self = self.evaluate(0, y, y, inputs="shape test")

# Return shape of evaluated object
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ def test_convert_scalar_symbols(self):
self.assertEqual(abs(c).to_casadi(), casadi.MX(1))

# function
def sin(x):
return np.sin(x)
def square_plus_one(x):
return x ** 2 + 1

f = pybamm.Function(sin, b)
self.assertEqual(f.to_casadi(), casadi.MX(np.sin(1)))
f = pybamm.Function(square_plus_one, b)
self.assertEqual(f.to_casadi(), 2)

def myfunction(x, y):
return x + y
Expand Down Expand Up @@ -95,6 +95,12 @@ def test_special_functions(self):
self.assert_casadi_equal(
pybamm.Function(np.abs, c).to_casadi(), casadi.MX(3), evalf=True
)
for np_fun in [np.sqrt, np.tanh, np.cosh, np.sinh,
np.exp, np.log, np.sign, np.sin, np.cos,
np.arccosh, np.arcsinh]:
self.assert_casadi_equal(
pybamm.Function(np_fun, c).to_casadi(), casadi.MX(np_fun(3)), evalf=True
)

def test_interpolation(self):
x = np.linspace(0, 1)[:, np.newaxis]
Expand Down

0 comments on commit 6c7a06e

Please sign in to comment.