Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions legate/core/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
#


from collections.abc import Iterable

from .partition import Restriction


class Expr(object):
def __eq__(self, rhs):
return Alignment(self, rhs)
Expand All @@ -28,8 +33,22 @@ def __add__(self, offset):
raise ValueError("Dimensions don't match")
return Translate(self, offset)

def broadcast(self):
return Broadcast(self)
def broadcast(self, axes=None):
if axes is None:
axes = set(range(self.ndim))
else:
if isinstance(axes, Iterable):
axes = set(axes)
else:
axes = {axes}
restrictions = []
for i in range(self.ndim):
restrictions.append(
Restriction.RESTRICTED
if i in axes
else Restriction.UNRESTRICTED
)
return Broadcast(self, restrictions)


class Lit(Expr):
Expand Down Expand Up @@ -151,8 +170,9 @@ def __repr__(self):


class Broadcast(Constraint):
def __init__(self, expr):
def __init__(self, expr, restrictions):
self._expr = expr
self._restrictions = restrictions

def __repr__(self):
return f"Broadcast({self._expr})"
return f"Broadcast({self._expr}, axes={self._restrictions})"
4 changes: 2 additions & 2 deletions legate/core/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,10 @@ def add_alignment(self, store1, store2):
part2 = self._get_unique_partition(store2)
self.add_constraint(part1 == part2)

def add_broadcast(self, store):
def add_broadcast(self, store, axes=None):
self._check_store(store)
part = self._get_unique_partition(store)
self.add_constraint(part.broadcast())
self.add_constraint(part.broadcast(axes=axes))

def add_constraint(self, constraint):
self._constraints.append(constraint)
Expand Down
37 changes: 24 additions & 13 deletions legate/core/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
from .utils import OrderedSet


def join_restrictions(x, y):
return tuple(min(a, b) for a, b in zip(x, y))


class EqClass(object):
def __init__(self):
# Maps a variable to the equivalent class id
Expand Down Expand Up @@ -172,13 +176,13 @@ def __init__(self, runtime, ops, must_be_single=False):
self._ops = ops
self._must_be_single = must_be_single

def _solve_broadcast_constraints(
self, unknowns, constraints, broadcasts, partitions
def _solve_constraints_for_futures(
self, unknowns, constraints, partitions
):
to_remove = OrderedSet()
for unknown in unknowns:
store = unknown.store
if not (store.kind is Future or unknown in broadcasts):
if store.kind is not Future:
continue

to_remove.add(unknown)
Expand Down Expand Up @@ -220,40 +224,44 @@ def _solve_unbound_constraints(
return unknowns - to_remove, len(to_remove) > 0

@staticmethod
def _find_restrictions(cls):
def _find_restrictions(cls, broadcasts):
merged = None
for unknown in cls:
store = unknown.store
restrictions = store.find_restrictions()
if unknown in broadcasts:
restrictions = join_restrictions(
broadcasts[unknown], restrictions
)
if merged is None:
merged = restrictions
else:
merged = tuple(min(a, b) for a, b in zip(merged, restrictions))
merged = join_restrictions(merged, restrictions)
return merged

def _find_all_restrictions(self, unknowns, constraints):
def _find_all_restrictions(self, unknowns, broadcasts, constraints):
all_restrictions = {}
for unknown in unknowns:
if unknown in all_restrictions:
continue
cls = constraints.find(unknown)
restrictions = self._find_restrictions(cls)
restrictions = self._find_restrictions(cls, broadcasts)
for store in cls:
all_restrictions[unknown] = restrictions
return all_restrictions

def partition_stores(self):
unknowns = OrderedSet()
constraints = EqClass()
broadcasts = OrderedSet()
broadcasts = {}
dependent = {}
for op in self._ops:
unknowns.update(op.all_unknowns)
for c in op.constraints:
if isinstance(c, Alignment):
constraints.record(c._lhs, c._rhs)
elif isinstance(c, Broadcast):
broadcasts.add(c._expr)
broadcasts[c._expr] = c._restrictions
elif isinstance(c, Containment):
if c._rhs in dependent:
raise NotImplementedError(
Expand All @@ -263,15 +271,16 @@ def partition_stores(self):
dependent[c._rhs] = c._lhs

if self._must_be_single or len(unknowns) == 0:
broadcasts = unknowns
for unknown in unknowns:
c = unknown.broadcast()
broadcasts[unknown] = c._restrictions

partitions = {}
fspaces = {}

unknowns = self._solve_broadcast_constraints(
unknowns = self._solve_constraints_for_futures(
unknowns,
constraints,
broadcasts,
partitions,
)

Expand All @@ -282,7 +291,9 @@ def partition_stores(self):
fspaces,
)

all_restrictions = self._find_all_restrictions(unknowns, constraints)
all_restrictions = self._find_all_restrictions(
unknowns, broadcasts, constraints
)

def cost(unknown):
store = unknown.store
Expand Down