Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13,324 changes: 3,658 additions & 9,666 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

18 changes: 0 additions & 18 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,6 @@ jobs:

run_pylint pymbolic test/test_*.py

mypy:
name: Mypy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
-
uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: "Main Script"
run: |
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
build_py_project_in_venv
pip install -e .[test]
python -m pip install mypy numpy
./run-mypy.sh

basedpyright:
runs-on: ubuntu-latest

Expand Down
13 changes: 0 additions & 13 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,6 @@ Pylint:
except:
- tags

Mypy:
script: |
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
build_py_project_in_venv
pip install -e .[test]
python -m pip install mypy numpy
./run-mypy.sh
tags:
- python3
except:
- tags

Documentation:
script:
- EXTRA_INSTALL="numpy sympy"
Expand Down
4 changes: 4 additions & 0 deletions experiments/traversal-benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ def main():
from time import time

if 1:
# Allow for JIT warm-up
for _ in range(10):
main()

t_start = time()
for _ in range(10_000):
main()
Expand Down
42 changes: 39 additions & 3 deletions pymbolic/interop/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import ast
from typing import TYPE_CHECKING, Any, ClassVar

from typing_extensions import override

import pymbolic.primitives as p
from pymbolic.mapper import CachedMapper

Expand Down Expand Up @@ -201,7 +203,7 @@
func = self.rec(expr.func)
args = tuple([self.rec(arg) for arg in expr.args])
if getattr(expr, "keywords", []):
return p.CallWithKwargs(func, args,

Check warning on line 206 in pymbolic/interop/ast.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

CallWithKwargs created with non-hashable kw_parameters. This is deprecated and will stop working in 2025. If you need an immutable mapping, try the :mod:`constantdict` package.

Check warning on line 206 in pymbolic/interop/ast.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

CallWithKwargs created with non-hashable kw_parameters. This is deprecated and will stop working in 2025. If you need an immutable mapping, try the :mod:`constantdict` package.
{
kw.arg: self.rec(kw.value)
for kw in expr.keywords})
Expand Down Expand Up @@ -262,6 +264,7 @@
# {{{ PymbolicToASTMapper

class PymbolicToASTMapper(CachedMapper[ast.expr, []]):
@override
def map_variable(self, expr) -> ast.expr:
return ast.Name(id=expr.name)

Expand All @@ -275,22 +278,27 @@

return result

@override
def map_sum(self, expr: p.Sum) -> ast.expr:
return self._map_multi_children_op(expr.children, ast.Add())

@override
def map_product(self, expr: p.Product) -> ast.expr:
return self._map_multi_children_op(expr.children, ast.Mult())

@override
def map_constant(self, expr: object) -> ast.expr:
return ast.Constant(expr, None)

@override
def map_call(self, expr: p.Call) -> ast.expr:
return ast.Call(
func=self.rec(expr.function),
args=[self.rec(param) for param in expr.parameters],
keywords=[],
)

@override
def map_call_with_kwargs(self, expr) -> ast.expr:
return ast.Call(
func=self.rec(expr.function),
Expand All @@ -301,81 +309,100 @@
value=self.rec(param))
for kw, param in sorted(expr.kw_parameters.items())])

@override
def map_subscript(self, expr) -> ast.expr:
return ast.Subscript(value=self.rec(expr.aggregate),
slice=self.rec(expr.index))

@override
def map_lookup(self, expr) -> ast.expr:
return ast.Attribute(self.rec(expr.aggregate),
expr.name)

@override
def map_quotient(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.numerator,
expr.denominator),
ast.Div())

@override
def map_floor_div(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.numerator,
expr.denominator),
ast.FloorDiv())

@override
def map_remainder(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.numerator,
expr.denominator),
ast.Mod())

@override
def map_power(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.base,
expr.exponent),
ast.Pow())

@override
def map_left_shift(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.shiftee,
expr.shift),
ast.LShift())

@override
def map_right_shift(self, expr) -> ast.expr:
return self._map_multi_children_op((expr.numerator,
expr.denominator),
ast.RShift())

@override
def map_bitwise_not(self, expr) -> ast.expr:
return ast.UnaryOp(ast.Invert(), self.rec(expr.child))

@override
def map_bitwise_or(self, expr) -> ast.expr:
return self._map_multi_children_op(expr.children,
ast.BitOr())

@override
def map_bitwise_xor(self, expr) -> ast.expr:
return self._map_multi_children_op(expr.children,
ast.BitXor())

@override
def map_bitwise_and(self, expr) -> ast.expr:
return self._map_multi_children_op(expr.children,
ast.BitAnd())

@override
def map_logical_not(self, expr) -> ast.expr:
return ast.UnaryOp(ast.Not(), self.rec(expr.child))

@override
def map_logical_or(self, expr) -> ast.expr:
return ast.BoolOp(ast.Or(), [self.rec(child)
for child in expr.children])

@override
def map_logical_and(self, expr) -> ast.expr:
return ast.BoolOp(ast.And(), [self.rec(child)
for child in expr.children])

@override
def map_list(self, expr: list[Any]) -> ast.expr:
return ast.List([self.rec(el) for el in expr])

@override
def map_tuple(self, expr: tuple[Any, ...]) -> ast.expr:
return ast.Tuple([self.rec(el) for el in expr])

@override
def map_if(self, expr: p.If) -> ast.expr:
return ast.IfExp(test=self.rec(expr.condition),
body=self.rec(expr.then),
orelse=self.rec(expr.else_))

@override
def map_nan(self, expr: p.NaN) -> ast.expr:
assert expr.data_type is not None
if isinstance(expr.data_type(float("nan")), float):
Expand All @@ -387,46 +414,55 @@
# TODO: would need attributes of NumPy
raise NotImplementedError("Non-float nan not implemented")

@override
def map_slice(self, expr: p.Slice) -> ast.expr:
return ast.Slice(*[None if child is None else self.rec(child)
for child in expr.children])

@override
def map_numpy_array(self, expr) -> ast.expr:
raise NotImplementedError

@override
def map_multivector(self, expr) -> ast.expr:
raise NotImplementedError

def map_common_subexpression(self, expr) -> ast.expr:
raise NotImplementedError

@override
def map_substitution(self, expr) -> ast.expr:
raise NotImplementedError

@override
def map_derivative(self, expr) -> ast.expr:
raise NotImplementedError

def map_if_positive(self, expr) -> ast.expr:
raise NotImplementedError

@override
def map_comparison(self, expr: p.Comparison) -> ast.expr:
raise NotImplementedError

@override
def map_wildcard(self, expr) -> ast.expr:
raise NotImplementedError

@override
def map_dot_wildcard(self, expr) -> ast.expr:
raise NotImplementedError

@override
def map_star_wildcard(self, expr) -> ast.expr:
raise NotImplementedError

@override
def map_function_symbol(self, expr) -> ast.expr:
raise NotImplementedError

@override
def map_min(self, expr) -> ast.expr:
raise NotImplementedError

@override
def map_max(self, expr) -> ast.expr:
raise NotImplementedError

Expand Down
12 changes: 11 additions & 1 deletion pymbolic/interop/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

from functools import partial

from typing_extensions import override

import pymbolic.primitives as prim
from pymbolic.mapper.evaluator import EvaluationMapper

Expand Down Expand Up @@ -154,15 +155,19 @@ def sym(self):
def raise_conversion_error(self, message):
raise NotImplementedError

@override
def map_variable(self, expr):
return self.sym.Symbol(expr.name)

@override
def map_constant(self, expr):
return self.sym.sympify(expr)

@override
def map_floor_div(self, expr):
return self.sym.floor(self.rec(expr.numerator) / self.rec(expr.denominator))

@override
def map_call(self, expr):
if isinstance(expr.function, prim.Variable):
func_name = expr.function.name
Expand All @@ -174,25 +179,29 @@ def map_call(self, expr):
else:
self.raise_conversion_error(expr)

@override
def map_subscript(self, expr):
if isinstance(expr.aggregate, prim.Variable):
return self.sym.Function("Indexed")(expr.aggregate.name,
*(self.rec(idx) for idx in expr.index_tuple))
else:
self.raise_conversion_error(expr)

@override
def map_substitution(self, expr):
return self.sym.Subs(self.rec(expr.child),
tuple([self.sym.Symbol(v) for v in expr.variables]),
tuple([self.rec(v) for v in expr.values]),
)

@override
def map_if(self, expr):
cond = self.rec(expr.condition)
return self.sym.Piecewise((self.rec(expr.then), cond),
(self.rec(expr.else_), True)
)

@override
def map_comparison(self, expr):
left = self.rec(expr.left)
right = self.rec(expr.right)
Expand All @@ -211,6 +220,7 @@ def map_comparison(self, expr):
else:
raise NotImplementedError(f"Unknown operator '{expr.operator}'")

@override
def map_derivative(self, expr):
return self.sym.Derivative(self.rec(expr.child),
*[self.sym.Symbol(v) for v in expr.variables])
Expand Down
Loading
Loading