Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20,136 changes: 7,962 additions & 12,174 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

15 changes: 12 additions & 3 deletions pymbolic/imperative/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,21 @@
"""


def get_all_used_insn_ids(insn_stream):
from typing import TYPE_CHECKING


if TYPE_CHECKING:
from collections.abc import Iterable

from pymbolic.imperative.statement import StatementLike


def get_all_used_insn_ids(insn_stream: Iterable[StatementLike]):
return frozenset(insn.id for insn in insn_stream)


def get_all_used_identifiers(insn_stream):
result = set()
def get_all_used_identifiers(insn_stream: Iterable[StatementLike]):
result: set[str] = set()
for insn in insn_stream:
result |= insn.get_read_variables()
result |= insn.get_written_variables()
Expand Down
136 changes: 90 additions & 46 deletions pymbolic/imperative/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,121 +24,155 @@
THE SOFTWARE.
"""

from dataclasses import dataclass, replace
from sys import intern
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, cast

from pytools import RecordWithoutPickling
from typing_extensions import Self, override

from pymbolic.typing import not_none
from pymbolic.typing import Expression, not_none


# {{{ statemetn classes
if TYPE_CHECKING:
from collections.abc import Callable, Set

class Statement(RecordWithoutPickling):
"""
.. attribute:: depends_on
from pymbolic.primitives import Variable


# {{{ statement classes

class BasicStatementLike(Protocol):
@property
def id(self) -> str: ...
@property
def depends_on(self) -> Set[str]: ...

def copy(self, **kwargs: object) -> Self: ...


BasicStatementLikeT = TypeVar("BasicStatementLikeT", bound=BasicStatementLike)


class StatementLike(BasicStatementLike, Protocol):
def get_written_variables(self) -> Set[str]: ...

A :class:`frozenset` of instruction ids that are reuqired to be
executed within this execution context before this instruction can be
executed.
def get_read_variables(self) -> Set[str]: ...

.. attribute:: id
def map_expressions(self,
mapper: Callable[[Expression], Expression],
include_lhs: bool = True
) -> Self: ...

A string, a unique identifier for this instruction.

StatementLikeT = TypeVar("StatementLikeT", bound=StatementLike)


@dataclass(frozen=True)
class Statement:
"""
.. autoattribute:: depends_on
.. autoattribute:: id

.. automethod:: get_written_variables
.. automethod:: get_read_variables
"""
id: str
"""
A string, a unique identifier for this instruction.
"""

def __init__(self, **kwargs):
id = kwargs.pop("id", None)
if id is not None:
id = intern(id)
depends_on: Set[str]
"""A :class:`frozenset` of instruction ids that are reuqired to be
executed within this execution context before this instruction can be
executed."""

depends_on = frozenset(kwargs.pop("depends_on", []))
super().__init__(
id=id,
depends_on=depends_on,
**kwargs)
def __post_init__(self):
object.__setattr__(self, "id", intern(self.id))

def get_written_variables(self):
def get_written_variables(self) -> Set[str]:
"""Returns a :class:`frozenset` of variables being written by this
instruction.
"""
return frozenset()

def get_read_variables(self):
def get_read_variables(self) -> Set[str]:
"""Returns a :class:`frozenset` of variables being read by this
instruction.
"""
return frozenset()

def map_expressions(self, mapper, include_lhs=True):
def map_expressions(self,
mapper: Callable[[Expression], Expression],
include_lhs: bool = True
) -> Self:
"""Returns a new copy of *self* with all expressions
replaced by ``mapepr(expr)`` for every
:class:`pymbolic.primitives.Expression`
contained in *self*.
"""
return self

def get_dependency_mapper(self, include_calls="descend_args"):
def get_dependency_mapper(self,
include_calls: bool | Literal["descend_args"] = True,
):
from pymbolic.mapper.dependency import DependencyMapper
return DependencyMapper(
return DependencyMapper[[]](
include_subscripts=False,
include_lookups=False,
include_calls=include_calls)

def copy(self, **kwargs: Any) -> Self: # pyright: ignore[reportAny]
return replace(self, **kwargs)

# }}}


# {{{ statement with condition

@dataclass(frozen=True)
class ConditionalStatement(Statement):
__doc__ = not_none(Statement.__doc__) + """
.. attribute:: condition

The instruction condition as a :mod:`pymbolic` expression (`True` if the
instruction is unconditionally executed)
.. autoattribute:: condition
"""

def __init__(self, **kwargs):
condition = kwargs.pop("condition", True)
super().__init__(
condition=condition,
**kwargs)
condition: Expression
"""The instruction condition as a :mod:`pymbolic` expression (`True` if the
instruction is unconditionally executed)"""

def _condition_printing_suffix(self):
if self.condition is True:
return ""
return " if " + str(self.condition)

@override
def __str__(self):
return (super().__str__()
+ self._condition_printing_suffix())

def get_read_variables(self):
@override
def get_read_variables(self) -> Set[str]:
dep_mapper = self.get_dependency_mapper()
return (
super().get_read_variables()
| frozenset(
dep.name for dep in dep_mapper(self.condition)))
cast("Variable", dep).name for dep in dep_mapper(self.condition)))

# }}}


# {{{ assignment

@dataclass(frozen=True)
class Assignment(Statement):
"""
.. attribute:: lhs
.. attribute:: rhs
"""

def __init__(self, lhs, rhs, **kwargs):
super().__init__(
lhs=lhs,
rhs=rhs,
**kwargs)
lhs: Expression
rhs: Expression

@override
def get_written_variables(self):
from pymbolic.primitives import Subscript, Variable
if isinstance(self.lhs, Variable):
Expand All @@ -149,24 +183,30 @@ def get_written_variables(self):
else:
raise TypeError("unexpected type of LHS")

def get_read_variables(self):
@override
def get_read_variables(self) -> Set[str]:
result = super().get_read_variables()
get_deps = self.get_dependency_mapper()

def get_vars(expr):
return frozenset(dep.name for dep in get_deps(self.rhs))
def get_vars(expr: Expression):
return frozenset(cast("Variable", dep).name for dep in get_deps(expr))

result = get_vars(self.rhs) | get_vars(self.lhs)

return result

def map_expressions(self, mapper, include_lhs=True):
@override
def map_expressions(self,
mapper: Callable[[Expression], Expression],
include_lhs: bool = True
) -> Self:
return (super()
.map_expressions(mapper, include_lhs=include_lhs)
.copy(
lhs=mapper(self.lhs) if include_lhs else self.lhs,
rhs=mapper(self.rhs)))

@override
def __str__(self):
result = "{assignee} <- {expr}".format(
assignee=str(self.lhs),
Expand All @@ -180,7 +220,11 @@ def __str__(self):
# {{{ conditional assignment

class ConditionalAssignment(ConditionalStatement, Assignment):
def map_expressions(self, mapper, include_lhs=True):
@override
def map_expressions(self,
mapper: Callable[[Expression], Expression],
include_lhs: bool = True
) -> Self:
return (super()
.map_expressions(mapper, include_lhs=include_lhs)
.copy(condition=mapper(self.condition)))
Expand Down
47 changes: 36 additions & 11 deletions pymbolic/imperative/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,33 @@
"""


from typing import TYPE_CHECKING


if TYPE_CHECKING:
from collections.abc import Callable, Sequence

from pymbolic.imperative.statement import (
BasicStatementLikeT,
StatementLike,
StatementLikeT,
)
from pymbolic.primitives import Variable


# {{{ fuse statement streams

def fuse_statement_streams_with_unique_ids(statements_a, statements_b):
def fuse_statement_streams_with_unique_ids(
statements_a: Sequence[BasicStatementLikeT],
statements_b: Sequence[BasicStatementLikeT]
) -> tuple[list[BasicStatementLikeT], dict[str, str]]:
new_statements = list(statements_a)
from pytools import UniqueNameGenerator
stmt_id_gen = UniqueNameGenerator(
{stmta.id for stmta in new_statements})

b_unique_statements = []
old_b_id_to_new_b_id = {}
b_unique_statements: list[BasicStatementLikeT] = []
old_b_id_to_new_b_id: dict[str, str] = {}
for stmtb in statements_b:
old_id = stmtb.id
new_id = stmt_id_gen(old_id)
Expand All @@ -53,7 +70,10 @@ def fuse_statement_streams_with_unique_ids(statements_a, statements_b):
return new_statements, old_b_id_to_new_b_id


def fuse_instruction_streams_with_unique_ids(insns_a, insns_b):
def fuse_instruction_streams_with_unique_ids(
insns_a: Sequence[StatementLikeT],
insns_b: Sequence[StatementLikeT]
):
from warnings import warn
warn("fuse_instruction_streams_with_unique_ids has been renamed to "
"fuse_statement_streams_with_unique_ids", DeprecationWarning,
Expand All @@ -66,11 +86,13 @@ def fuse_instruction_streams_with_unique_ids(insns_a, insns_b):

# {{{ disambiguate_identifiers

def disambiguate_identifiers(statements_a, statements_b,
should_disambiguate_name=None):
def disambiguate_identifiers(
statements_a: Sequence[StatementLike],
statements_b: Sequence[StatementLike],
should_disambiguate_name: Callable[[str], bool] | None = None,
):
if should_disambiguate_name is None:
def should_disambiguate_name(name): # pylint:disable=function-redefined
return True
should_disambiguate_name = lambda name: True # noqa: E731

from pymbolic.imperative.analysis import get_all_used_identifiers

Expand All @@ -81,7 +103,7 @@ def should_disambiguate_name(name): # pylint:disable=function-redefined
vng = UniqueNameGenerator(id_a | id_b)

from pymbolic import var
subst_b = {}
subst_b: dict[str, Variable] = {}
for clash in id_a & id_b:
if should_disambiguate_name(clash):
unclash = vng(clash)
Expand All @@ -100,8 +122,11 @@ def should_disambiguate_name(name): # pylint:disable=function-redefined

# {{{ disambiguate_and_fuse

def disambiguate_and_fuse(statements_a, statements_b,
should_disambiguate_name=None):
def disambiguate_and_fuse(
statements_a: Sequence[StatementLike],
statements_b: Sequence[StatementLike],
should_disambiguate_name: Callable[[str], bool] | None = None,
):
statements_b, subst_b = disambiguate_identifiers(
statements_a, statements_b,
should_disambiguate_name)
Expand Down
2 changes: 1 addition & 1 deletion pymbolic/interop/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@
func = self.rec(expr.func)
args = tuple([self.rec(arg) for arg in expr.args])
if getattr(expr, "keywords", []):
return p.CallWithKwargs(func, args,

Check warning on line 206 in pymbolic/interop/ast.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

CallWithKwargs created with non-hashable kw_parameters. This is deprecated and will stop working in 2025. If you need an immutable mapping, try the :mod:`constantdict` package.

Check warning on line 206 in pymbolic/interop/ast.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

CallWithKwargs created with non-hashable kw_parameters. This is deprecated and will stop working in 2025. If you need an immutable mapping, try the :mod:`constantdict` package.
{
kw.arg: self.rec(kw.value)
for kw in expr.keywords})
Expand Down Expand Up @@ -288,7 +288,7 @@

@override
def map_constant(self, expr: object) -> ast.expr:
return ast.Constant(expr, None)
return ast.Constant(expr, None) # pyright: ignore[reportArgumentType]

@override
def map_call(self, expr: p.Call) -> ast.expr:
Expand Down
Loading
Loading