diff --git a/src/symnum/array.py b/src/symnum/array.py index bda8fb7..9a4f44e 100644 --- a/src/symnum/array.py +++ b/src/symnum/array.py @@ -16,7 +16,13 @@ SympyArray: TypeAlias = Union[sympy.NDimArray, sympy.MatrixBase] ScalarLike: TypeAlias = Union[ - sympy.Expr, sympy.logic.boolalg.Boolean, bool, int, float, complex, np.number, + sympy.Expr, + sympy.logic.boolalg.Boolean, + bool, + int, + float, + complex, + np.number, ] ShapeLike: TypeAlias = Union[int, tuple[int, ...], sympy.Tuple] @@ -504,6 +510,28 @@ def subs(self, *args) -> SymbolicArray: dtype=self._dtype, ) + def simplify(self, **kwargs) -> SymbolicArray: + """Simplify symbolic expressions in array. + + Args: + **kwargs: Any keyword arguments to :py:meth:`sympy.NDimArray.simplify`. + + Returns: + Array with simplified symbolic expressions. + """ + if self.shape == (): + return SymbolicArray( + self[()].simplify(**kwargs), + shape=(), + dtype=self._dtype, + ) + else: + return SymbolicArray( + self._base_array.simplify(**kwargs), + shape=self.shape, + dtype=self._dtype, + ) + @property def free_symbols(self) -> set[sympy.Symbol]: """Set of all free symbols in symbolic expressions defined in array."""