Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,104 changes: 0 additions & 1,104 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"numpy": ("https://numpy.org/doc/stable/", None),
"python": ("https://docs.python.org/3", None),
"sympy": ("https://docs.sympy.org/dev/", None),
"pytools": ("https://documen.tician.de/pytools/", None),
"typing_extensions":
("https://typing-extensions.readthedocs.io/en/latest/", None),
"constantdict":
Expand Down
33 changes: 25 additions & 8 deletions pymbolic/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,20 @@
THE SOFTWARE.
"""

from typing import TYPE_CHECKING, cast

from typing_extensions import override

import pymbolic.primitives as prim
from pymbolic.mapper import IdentityMapper, WalkMapper


if TYPE_CHECKING:
from collections.abc import Callable, Hashable

from pymbolic.typing import Expression


COMMUTATIVE_CLASSES = (prim.Sum, prim.Product)


Expand Down Expand Up @@ -76,27 +86,33 @@ def map_common_subexpression(self, expr, *args, **kwargs):
self.subexpr_counts[key] = 1


class CSEMapper(IdentityMapper):
class CSEMapper(IdentityMapper[[]]):
def __init__(self, to_eliminate, get_key):
self.to_eliminate = to_eliminate
self.get_key = get_key
self.get_key: Callable[[Expression], Hashable] = get_key

self.canonical_subexprs = {}
self.canonical_subexprs: dict[Hashable, Expression] = {}

def get_cse(self, expr, key=None):
def get_cse(self, expr: prim.ExpressionNode, key: Hashable = None):
if key is None:
key = self.get_key(expr)

try:
return self.canonical_subexprs[key]
except KeyError:
new_expr = prim.make_common_subexpression(
new_expr = cast("Expression", prim.make_common_subexpression(
getattr(IdentityMapper, expr.mapper_method)(self, expr)
)
))
self.canonical_subexprs[key] = new_expr
return new_expr

def map_sum(self, expr):
@override
def map_sum(self,
expr: (prim.Sum | prim.Product | prim.Power
| prim.Quotient | prim.Remainder | prim.FloorDiv
| prim.Call
)
):
key = self.get_key(expr)
if key in self.to_eliminate:
result = self.get_cse(expr, key)
Expand All @@ -111,7 +127,8 @@ def map_sum(self, expr):
map_floor_div = map_sum
map_call = map_sum

def map_common_subexpression(self, expr):
@override
def map_common_subexpression(self, expr: prim.CommonSubexpression):
# Avoid creating CSE(CSE(...))
if type(expr) is prim.CommonSubexpression:
return prim.make_common_subexpression(
Expand Down
31 changes: 23 additions & 8 deletions pymbolic/geometric_algebra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from typing_extensions import Self, override

from pytools import memoize, memoize_method
from pytools.obj_array import ObjectArray, ObjectArray1D, ShapeT

from pymbolic.primitives import expr_dataclass, is_zero

Expand Down Expand Up @@ -151,16 +152,16 @@
class _HasArithmetic(Protocol):
def __neg__(self: CoeffT) -> CoeffT: ...
def __abs__(self: CoeffT) -> CoeffT: ...

def __add__(self: CoeffT, other: CoeffT, /) -> CoeffT: ...
def __radd__(self: CoeffT, other: CoeffT, /) -> CoeffT: ...
def __radd__(self: CoeffT, other: int, /) -> CoeffT: ...

def __sub__(self: CoeffT, other: CoeffT, /) -> CoeffT: ...
def __rsub__(self: CoeffT, other: CoeffT, /) -> CoeffT: ...

def __mul__(self: CoeffT, other: CoeffT, /) -> CoeffT: ...
def __rmul__(self: CoeffT, other: CoeffT, /) -> CoeffT: ...
def __rmul__(self: CoeffT, other: int, /) -> CoeffT: ...

def __pow__(self: CoeffT, other: CoeffT, /) -> CoeffT: ...
def __rpow__(self: CoeffT, other: CoeffT, /) -> CoeffT: ...


CoeffT = TypeVar("CoeffT", bound=_HasArithmetic)
Expand Down Expand Up @@ -632,6 +633,7 @@ def __init__(
data: (Mapping[int, CoeffT | int]
| Mapping[tuple[int, ...], CoeffT | int]
| NDArray[np.generic]
| ObjectArray1D[CoeffT]
| CoeffT
| int),
space: Space[CoeffT] | None = None
Expand All @@ -654,7 +656,7 @@ def __init__(
"""

data_dict: Mapping[tuple[int, ...], CoeffT | int] | Mapping[int, CoeffT | int]
if isinstance(data, np.ndarray):
if isinstance(data, (np.ndarray, ObjectArray)):
if len(data.shape) != 1:
raise ValueError(
"Only numpy vectors (not higher-rank objects) "
Expand Down Expand Up @@ -1247,17 +1249,30 @@ def map(self, f: Callable[[CoeffT], CoeffT]) -> MultiVector[CoeffT]:
# }}}


T = TypeVar("T")
@overload
def componentwise(
f: Callable[[CoeffT], CoeffT],
expr: MultiVector[CoeffT]
) -> MultiVector[CoeffT]: ...

@overload
def componentwise(
f: Callable[[CoeffT], CoeffT],
expr: ObjectArray[ShapeT, CoeffT]
) -> ObjectArray[ShapeT, CoeffT]: ...


def componentwise(f: Callable[[CoeffT], CoeffT], expr: T) -> T:
def componentwise(
f: Callable[[CoeffT], CoeffT],
expr: MultiVector[CoeffT] | ObjectArray[ShapeT, CoeffT]
) -> MultiVector[CoeffT] | ObjectArray[ShapeT, CoeffT]:
"""Apply function *f* componentwise to object arrays and
:class:`MultiVector` instances. *expr* is also allowed to
be a scalar.
"""

if isinstance(expr, MultiVector):
return cast("T", cast("MultiVector[CoeffT]", expr).map(f))
return expr.map(f)

from pytools.obj_array import obj_array_vectorize
return obj_array_vectorize(f, expr)
Expand Down
13 changes: 7 additions & 6 deletions pymbolic/geometric_algebra/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
if TYPE_CHECKING:
from collections.abc import Hashable

from pymbolic.typing import Expression
from pymbolic.typing import ArithmeticExpression, Expression


class MultiVectorVariable(Variable):
Expand Down Expand Up @@ -77,20 +77,21 @@ class Derivative:
_next_id: ClassVar[list[int]] = [0]

def __init__(self):
self.my_id = f"id{self._next_id[0]}"
self.my_id: str = f"id{self._next_id[0]}"
self._next_id[0] += 1

@property
def nabla(self):
return Nabla(self.my_id)

def dnabla(self, ambient_dim):
def dnabla(self, ambient_dim: int):
from pytools.obj_array import make_obj_array

from pymbolic.geometric_algebra import MultiVector
return MultiVector(make_obj_array(
[NablaComponent(axis, self.my_id)
for axis in range(ambient_dim)]))
nablas: list[ArithmeticExpression] = [
NablaComponent(axis, self.my_id)
for axis in range(ambient_dim)]
return MultiVector(make_obj_array(nablas))

def __call__(self, operand):
from pymbolic.geometric_algebra import MultiVector
Expand Down
7 changes: 4 additions & 3 deletions pymbolic/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from constantdict import constantdict
from typing_extensions import Any, ParamSpec, TypeIs, override

from pytools import ndindex

import pymbolic.primitives as p
from pymbolic.typing import ArithmeticExpression, Expression

Expand Down Expand Up @@ -201,7 +203,7 @@
else:
return self.handle_unsupported_expression(expr, *args, **kwargs)
else:
return self.map_foreign(expr, *args, **kwargs)

Check warning on line 206 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

Check warning on line 206 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

Check warning on line 206 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

Check warning on line 206 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

Check warning on line 206 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

Check warning on line 206 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

rec = __call__

Expand Down Expand Up @@ -1076,7 +1078,7 @@

import numpy
result = numpy.empty(expr.shape, dtype=object)
for i in numpy.ndindex(expr.shape):
for i in ndindex(expr.shape):
result[i] = self.rec(expr[i], *args, **kwargs)

# True fact: ndarrays aren't expressions
Expand Down Expand Up @@ -1376,8 +1378,7 @@
if not self.visit(expr, *args, **kwargs):
return

import numpy
for i in numpy.ndindex(expr.shape):
for i in ndindex(expr.shape):
self.rec(expr[i], *args, **kwargs)

self.post_visit(expr, *args, **kwargs)
Expand Down
4 changes: 3 additions & 1 deletion pymbolic/mapper/differentiator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

from typing_extensions import override

from pytools import ndindex

import pymbolic
import pymbolic.mapper
import pymbolic.mapper.evaluator
Expand Down Expand Up @@ -207,7 +209,7 @@ def map_power(self, expr, *args):
def map_numpy_array(self, expr, *args):
import numpy
result = numpy.empty(expr.shape, dtype=object)
for i in numpy.ndindex(result.shape):
for i in ndindex(result.shape):
result[i] = self.rec(expr[i], *args)
return result

Expand Down
4 changes: 3 additions & 1 deletion pymbolic/mapper/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
"""
from __future__ import annotations

from pytools import ndindex


__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"

Expand Down Expand Up @@ -187,7 +189,7 @@ def map_list(self, expr: list[Expression]) -> ResultT:
def map_numpy_array(self, expr: NDArray[np.generic]) -> ResultT:
import numpy
result = numpy.empty(expr.shape, dtype=object)
for i in numpy.ndindex(expr.shape):
for i in ndindex(expr.shape):
result[i] = self.rec(expr[i])
return result # type: ignore[return-value]

Expand Down
4 changes: 3 additions & 1 deletion pymbolic/mapper/stringifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

from typing_extensions import deprecated, override

from pytools import ndindex

import pymbolic.primitives as p
from pymbolic.mapper import CachedMapper, Mapper, P

Expand Down Expand Up @@ -632,7 +634,7 @@ def map_numpy_array(

str_array = numpy.zeros(expr.shape, dtype="object")
max_length = 0
for i in numpy.ndindex(expr.shape):
for i in ndindex(expr.shape):
s = self.rec(expr[i], PREC_NONE, *args, **kwargs)
max_length = max(len(s), max_length)
str_array[i] = s.replace("\n", "\n ")
Expand Down
Loading
Loading