Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
506 changes: 17 additions & 489 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pymbolic/geometric_algebra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,7 +1196,9 @@ def as_scalar(self) -> CoeffT | int:

return result

def as_vector(self, dtype: DTypeLike = None) -> NDArray[Any]:
def as_vector(self,
dtype: DTypeLike = None
) -> np.ndarray[tuple[int], np.dtype[Any]]:
"""Return a :mod:`numpy` vector corresponding to the grade-1
:class:`MultiVector` *self*.

Expand Down
44 changes: 28 additions & 16 deletions pymbolic/geometric_algebra/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
# Consider yourself warned.
from typing import TYPE_CHECKING, ClassVar

from typing_extensions import override

import pymbolic.geometric_algebra.primitives as prim
from pymbolic.geometric_algebra import MultiVector
from pymbolic.mapper import (
Expand All @@ -48,6 +50,7 @@
PREC_NONE,
StringifyMapper as StringifyMapperBase,
)
from pymbolic.typing import ArithmeticExpression


if TYPE_CHECKING:
Expand Down Expand Up @@ -117,13 +120,20 @@ def map_derivative_source(
self.post_visit(expr, *args, **kwargs)


class EvaluationMapper(EvaluationMapperBase):
class EvaluationMapper(EvaluationMapperBase[ResultT]):
def map_nabla_component(self, expr):
return expr

map_nabla = map_nabla_component

def map_derivative_source(self, expr):

class EvaluationRewriter(EvaluationMapperBase[ArithmeticExpression]):
def map_nabla_component(self, expr: prim.NablaComponent) -> ArithmeticExpression:
return expr

def map_derivative_source(self,
expr: prim.DerivativeSource
) -> ArithmeticExpression:
operand = self.rec(expr.operand)
if operand is expr.operand:
return expr
Expand Down Expand Up @@ -160,15 +170,15 @@ def map_derivative_source(self, expr):

# {{{ dimensionalizer

class Dimensionalizer(EvaluationMapper):
class Dimensionalizer(EvaluationRewriter):
"""
.. attribute:: ambient_dim

Dimension of ambient space. Must be provided by subclass.
"""

@property
def ambient_dim(self):
def ambient_dim(self) -> int:
raise NotImplementedError

def map_multivector_variable(self, expr):
Expand All @@ -183,7 +193,8 @@ def map_nabla(self, expr):
[prim.NablaComponent(axis, expr.nabla_id)
for axis in range(self.ambient_dim)]))

def map_derivative_source(self, expr):
@override
def map_derivative_source(self, expr: prim.DerivativeSource):
rec_op = self.rec(expr.operand)

if isinstance(rec_op, MultiVector):
Expand Down Expand Up @@ -214,44 +225,45 @@ def map_derivative_source(self, expr):
return {expr} | self.rec(expr.operand)


class NablaComponentToUnitVector(EvaluationMapper):
class NablaComponentToUnitVector(EvaluationRewriter):
def __init__(self, nabla_id, ambient_axis):
self.nabla_id = nabla_id
self.ambient_axis = ambient_axis
self.nabla_id: prim.NablaId = nabla_id
self.ambient_axis: int = ambient_axis

def map_variable(self, expr):
return expr

def map_nabla_component(self, expr):
@override
def map_nabla_component(self, expr: prim.NablaComponent) -> ArithmeticExpression:
if expr.nabla_id == self.nabla_id:
if expr.ambient_axis == self.ambient_axis:
return 1
else:
return 0
else:
return EvaluationMapper.map_nabla_component(self, expr)
return super().map_nabla_component(expr)


class DerivativeSourceFinder(EvaluationMapper):
class DerivativeSourceFinder(EvaluationRewriter):
"""Recurses down until it finds the
:class:`pymbolic.geometric_algebra.primitives.DerivativeSource`
with the right *nabla_id*, then calls :method:`DerivativeBinder.take_derivative`
on the source's argument.
"""

def __init__(self, nabla_id, binder, ambient_axis):
self.nabla_id = nabla_id
self.binder = binder
self.ambient_axis = ambient_axis
self.nabla_id: prim.NablaId = nabla_id
self.binder: DerivativeBinder = binder
self.ambient_axis: int = ambient_axis

def map_derivative_source(self, expr):
if expr.nabla_id == self.nabla_id:
return self.binder.take_derivative(self.ambient_axis, expr.operand)
else:
return EvaluationMapper.map_derivative_source(self, expr)
return super().map_derivative_source(expr)


class DerivativeBinder(IdentityMapper):
class DerivativeBinder(IdentityMapper[[]]):
derivative_source_and_nabla_component_collector = \
DerivativeSourceAndNablaComponentCollector
nabla_component_to_unit_vector = NablaComponentToUnitVector
Expand Down
7 changes: 5 additions & 2 deletions pymbolic/geometric_algebra/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# This is experimental, undocumented, and could go away any second.
# Consider yourself warned.

from typing import TYPE_CHECKING, ClassVar
from typing import TYPE_CHECKING, ClassVar, TypeAlias

from pymbolic.primitives import ExpressionNode, Variable, expr_dataclass

Expand All @@ -49,10 +49,13 @@ def stringifier(self):
return StringifyMapper


NablaId: TypeAlias = "Hashable"


@expr_dataclass()
class NablaComponent(_GeometricCalculusExpression):
ambient_axis: int
nabla_id: Hashable
nabla_id: NablaId


@expr_dataclass()
Expand Down
4 changes: 2 additions & 2 deletions pymbolic/mapper/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
if TYPE_CHECKING:
from collections.abc import Sequence, Set

from pymbolic.mapper.dependency import DependenciesT
from pymbolic.mapper.dependency import Dependencies
from pymbolic.typing import ArithmeticExpression, Expression


Expand All @@ -53,7 +53,7 @@ def __init__(self, parameters: Set[p.AlgebraicLeaf] | None = None):
parameters = set()
self.parameters = parameters

def get_dependencies(self, expr: Expression) -> DependenciesT:
def get_dependencies(self, expr: Expression) -> Dependencies:
from pymbolic.mapper.dependency import DependencyMapper
return DependencyMapper()(expr)

Expand Down
32 changes: 18 additions & 14 deletions pymbolic/mapper/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

from __future__ import annotations

from typing_extensions import override


__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"

Expand All @@ -31,17 +29,23 @@
"""

from collections.abc import Set
from typing import Literal, TypeAlias
from typing import TYPE_CHECKING, Literal, TypeAlias

from typing_extensions import override

import pymbolic.primitives as p
from pymbolic.mapper import CachedMapper, Collector, CSECachingMapperMixin, P


DependenciesT: TypeAlias = Set[p.AlgebraicLeaf | p.CommonSubexpression]
Dependency: TypeAlias = p.AlgebraicLeaf | p.CommonSubexpression
Dependencies: TypeAlias = Set[Dependency]

if not TYPE_CHECKING:
DependenciesT: TypeAlias = Dependencies


class DependencyMapper(
CSECachingMapperMixin[DependenciesT, P],
CSECachingMapperMixin[Dependencies, P],
Collector[p.AlgebraicLeaf | p.CommonSubexpression, P],
):
"""Maps an expression to the :class:`set` of expressions it
Expand Down Expand Up @@ -84,13 +88,13 @@ def __init__(
@override
def map_variable(
self, expr: p.Variable, *args: P.args, **kwargs: P.kwargs
) -> DependenciesT:
) -> Dependencies:
return {expr}

@override
def map_call(
self, expr: p.Call, *args: P.args, **kwargs: P.kwargs
) -> DependenciesT:
) -> Dependencies:
if self.include_calls == "descend_args":
return self.combine([
self.rec(child, *args, **kwargs) for child in expr.parameters
Expand All @@ -103,7 +107,7 @@ def map_call(
@override
def map_call_with_kwargs(
self, expr: p.CallWithKwargs, *args: P.args, **kwargs: P.kwargs
) -> DependenciesT:
) -> Dependencies:
if self.include_calls == "descend_args":
return self.combine(
[self.rec(child, *args, **kwargs) for child in expr.parameters]
Expand All @@ -120,7 +124,7 @@ def map_call_with_kwargs(
@override
def map_lookup(
self, expr: p.Lookup, *args: P.args, **kwargs: P.kwargs
) -> DependenciesT:
) -> Dependencies:
if self.include_lookups:
return {expr}
else:
Expand All @@ -129,7 +133,7 @@ def map_lookup(
@override
def map_subscript(
self, expr: p.Subscript, *args: P.args, **kwargs: P.kwargs
) -> DependenciesT:
) -> Dependencies:
if self.include_subscripts:
return {expr}
else:
Expand All @@ -138,7 +142,7 @@ def map_subscript(
@override
def map_common_subexpression_uncached(
self, expr: p.CommonSubexpression, *args: P.args, **kwargs: P.kwargs
) -> DependenciesT:
) -> Dependencies:
if self.include_cses:
return {expr}
else:
Expand All @@ -148,19 +152,19 @@ def map_common_subexpression_uncached(
@override
def map_slice(
self, expr: p.Slice, *args: P.args, **kwargs: P.kwargs
) -> DependenciesT:
) -> Dependencies:
return self.combine([
self.rec(child, *args, **kwargs)
for child in expr.children
if child is not None
])

@override
def map_nan(self, expr: p.NaN, *args: P.args, **kwargs: P.kwargs) -> DependenciesT:
def map_nan(self, expr: p.NaN, *args: P.args, **kwargs: P.kwargs) -> Dependencies:
return set()


class CachedDependencyMapper(CachedMapper[DependenciesT, P],
class CachedDependencyMapper(CachedMapper[Dependencies, P],
DependencyMapper[P]):
def __init__(
self,
Expand Down
11 changes: 7 additions & 4 deletions pymbolic/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import pytools.lex
from pytools import memoize_method

from pymbolic.primitives import is_arithmetic_expression
from pymbolic.primitives import Slice, is_arithmetic_expression


if TYPE_CHECKING:
Expand Down Expand Up @@ -104,7 +104,9 @@
_PREC_CALL = 250


def _join_to_slice(left, right):
def _join_to_slice(
left: Slice | ArithmeticExpression | None,
right: ArithmeticExpression | None):
from pymbolic.primitives import Slice
if isinstance(right, Slice):
return Slice((left, *right.children))
Expand Down Expand Up @@ -247,7 +249,7 @@ def parse_prefix(self, pstate: LexIterator):
expr_pstate = pstate.copy()
from pytools.lex import ParseError
try:
next_expr = self.parse_expression(expr_pstate, _PREC_SLICE)
next_expr = self.parse_arith_expression(expr_pstate, _PREC_SLICE)
except ParseError:
# no expression follows, too bad.
left_exp = primitives.Slice((None,))
Expand Down Expand Up @@ -489,10 +491,11 @@ def parse_postfix(self,
expr_pstate = pstate.copy()

assert not isinstance(left_exp, primitives.Slice)
assert is_arithmetic_expression(left_exp)

from pytools.lex import ParseError
try:
next_expr = self.parse_expression(expr_pstate, _PREC_SLICE)
next_expr = self.parse_arith_expression(expr_pstate, _PREC_SLICE)
except ParseError:
# no expression follows, too bad.
left_exp = primitives.Slice((left_exp, None,))
Expand Down
Loading
Loading