Skip to content

Commit

Permalink
Merge pull request #1719 from tjb900/generalised_msf
Browse files Browse the repository at this point in the history
Generalised MatrixSparseTimeFunction
  • Loading branch information
FabioLuporini committed Jul 6, 2021
2 parents 47aaa87 + 8b8543a commit cde7a8e
Show file tree
Hide file tree
Showing 8 changed files with 351 additions and 106 deletions.
7 changes: 5 additions & 2 deletions devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from devito.symbolics import (retrieve_functions, retrieve_indexed, split_affine,
uxreplace)
from devito.tools import PartialOrderTuple, filter_sorted, flatten, as_tuple
from devito.types import Dimension, Eq
from devito.types import Dimension, Eq, IgnoreDimSort

__all__ = ['dimension_sort', 'generate_implicit_exprs', 'lower_exprs']

Expand Down Expand Up @@ -36,7 +36,10 @@ def handle_indexed(indexed):
if isinstance(d, Dimension)])
return tuple(relation)

relations = {handle_indexed(i) for i in retrieve_indexed(expr)}
if isinstance(expr.implicit_dims, IgnoreDimSort):
relations = set()
else:
relations = {handle_indexed(i) for i in retrieve_indexed(expr)}

# Add in any implicit dimension (typical of scalar temporaries, or Step)
relations.add(expr.implicit_dims)
Expand Down
15 changes: 11 additions & 4 deletions devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
__all__ = ['cse']


class Temp(Symbol):
pass


@cluster_pass
def cse(cluster, sregistry, *args):
"""
Common sub-expressions elimination (CSE).
"""
make = lambda: Symbol(name=sregistry.make_name(), dtype=cluster.dtype).indexify()
processed = _cse(cluster.exprs, make)
make = lambda: Temp(name=sregistry.make_name(), dtype=cluster.dtype).indexify()
processed = _cse(cluster, make)

return cluster.rebuild(processed)

Expand Down Expand Up @@ -108,9 +112,12 @@ def _compact_temporaries(exprs):
# First of all, convert to SSA
exprs = makeit_ssa(exprs)

# What's gonna be dropped
# Drop candidates are all exprs in the form `t0 = s` where `s` is a symbol
# Note: only CSE-captured Temps, which are by construction local objects, may
# safely be compacted; a generic Symbol could instead be accessed in a subsequent
# Cluster, for example: `for (i = ...) { a = b; for (j = a ...) ...`
mapper = {e.lhs: e.rhs for e in exprs
if e.lhs.is_Symbol and (q_leaf(e.rhs) or e.rhs.is_Function)}
if isinstance(e.lhs, Temp) and (q_leaf(e.rhs) or e.rhs.is_Function)}

processed = []
for e in exprs:
Expand Down
6 changes: 5 additions & 1 deletion devito/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def as_list(item, type=None, length=None):

def as_tuple(item, type=None, length=None):
"""
Force item to a tuple.
Force item to a tuple. Passes tuple subclasses through also.
Partly extracted from: https://github.com/OP2/PyOP2/.
"""
Expand All @@ -39,13 +39,17 @@ def as_tuple(item, type=None, length=None):
t = ()
elif isinstance(item, (str, sympy.Function)):
t = (item,)
elif isinstance(item, tuple):
# this makes tuple subclasses pass through
t = item
else:
# Convert iterable to list...
try:
t = tuple(item)
# ... or create a list of a single item
except (TypeError, NotImplementedError):
t = (item,) * (length or 1)

if length and not len(t) == length:
raise ValueError("Tuple needs to be of length %d" % length)
if type and not all(isinstance(i, type) for i in t):
Expand Down

0 comments on commit cde7a8e

Please sign in to comment.