Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions legate/core/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ def reduce(self):


class PartSym(Expr):
def __init__(self, op, store, id, disjoint, complete):
self._op = op
def __init__(self, op_hash, op_name, store, id, disjoint, complete):
self._op_hash = op_hash
self._op_name = op_name
self._store = store
self._id = id
self._disjoint = disjoint
Expand All @@ -66,17 +67,21 @@ def __init__(self, op, store, id, disjoint, complete):
def ndim(self):
return self._store.ndim

@property
def store(self):
return self._store

@property
def closed(self):
return False

def __repr__(self):
disj = "D" if self._disjoint else "A"
comp = "C" if self._complete else "I"
return f"X{self._id}({disj},{comp})@{self._op.get_name()}"
return f"X{self._id}({disj},{comp})@{self._op_name}"

def __hash__(self):
return hash((self._op, self._id))
return hash((self._op_hash, self._id))

def subst(self, mapping):
return Lit(mapping[self])
Expand Down
3 changes: 2 additions & 1 deletion legate/core/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ def _get_symbol_id(self):

def declare_partition(self, store, disjoint=True, complete=True):
sym = PartSym(
self,
self._op_id,
self.get_name(),
store,
self._get_symbol_id(),
disjoint=disjoint,
Expand Down
60 changes: 45 additions & 15 deletions legate/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import gc
import math
import struct
import weakref
from collections import deque
from functools import reduce

Expand Down Expand Up @@ -319,11 +320,19 @@ def __init__(self, ptr, extent, region_field):
self.ptr = ptr
self.extent = extent
self.end = ptr + extent - 1
self.region_field = region_field
self._region_field = weakref.ref(region_field)

def overlaps(self, other):
return not (self.end < other.ptr or other.end < self.ptr)

@property
def region_field(self):
return self._region_field()

@region_field.setter
def region_field(self, region_field):
self._region_field = weakref.ref(region_field)


class AttachmentManager(object):
def __init__(self, runtime):
Expand Down Expand Up @@ -359,30 +368,55 @@ def attachment_key(alloc):

def has_attachment(self, alloc):
key = self.attachment_key(alloc)
return key in self._attachments
attachment = self._attachments.get(key, None)
return attachment is not None and attachment.region_field

def reuse_existing_attachment(self, alloc):
key = self.attachment_key(alloc)
if key not in self._attachments:
attachment = self._attachments.get(key, None)
if attachment is None:
return None
attachment = self._attachments[key]
return attachment.region_field
rf = attachment.region_field
# If the region field is already collected, we don't need to keep
# track of it for de-duplication.
if rf is None:
del self._attachments[key]
return rf

def attach_external_allocation(self, alloc, region_field):
key = self.attachment_key(alloc)
if key in self._attachments:
attachment = self._attachments.get(key, None)
if not (attachment is None or attachment.region_field is None):
raise RuntimeError(
"Cannot attach two different RegionFields to the same buffer"
)
attachment = Attachment(*key, region_field)
if attachment is None:
attachment = Attachment(*key, region_field)
else:
attachment.region_field = region_field
# We temporary remove the attachment from the map for
# the following alias checking
del self._attachments[key]
for other in self._attachments.values():
if other.overlaps(attachment):
raise RuntimeError(
"Aliased attachments not supported by Legate"
)
self._attachments[key] = attachment

def detach_external_allocation(self, alloc, detach, defer):
def _remove_allocation(self, alloc):
key = self.attachment_key(alloc)
if key not in self._attachments:
raise RuntimeError("Unable to find attachment to remove")
del self._attachments[key]

def detach_external_allocation(
self, alloc, detach, defer=False, previously_deferred=False
):
# If the detachment was previously deferred, then we don't
# need to remove the allocation from the map again.
if not previously_deferred:
self._remove_allocation(alloc)
if defer:
# If we need to defer this until later do that now
self._deferred_detachments.append((alloc, detach))
Expand All @@ -391,12 +425,6 @@ def detach_external_allocation(self, alloc, detach, defer):
# Dangle a reference to the field off the future to prevent the
# field from being recycled until the detach is done
future.field_reference = detach.field
# We also need to tell the core legate library that this buffer
# is no longer attached
key = self.attachment_key(alloc)
if key not in self._attachments:
raise RuntimeError("Unable to find attachment to remove")
del self._attachments[key]
# If the future is already ready, then no need to track it
if future.is_ready():
return
Expand All @@ -417,7 +445,9 @@ def perform_detachments(self):
detachments = self._deferred_detachments
self._deferred_detachments = list()
for alloc, detach in detachments:
self.detach_external_allocation(alloc, detach, defer=False)
self.detach_external_allocation(
alloc, detach, defer=False, previously_deferred=True
)

def prune_detachments(self):
to_remove = []
Expand Down
18 changes: 9 additions & 9 deletions legate/core/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,16 @@ def launch_domain(self):

def get_projection(self, part):
partition = self.get_partition(part)
return partition.get_requirement(self._launch_shape, part._store)
return partition.get_requirement(self._launch_shape, part.store)

def get_partition(self, part):
assert not part._store.unbound
assert not part.store.unbound
if part not in self._strategy:
raise ValueError(f"No strategy is found for {part}")
return self._strategy[part]

def get_field_space(self, part):
assert part._store.unbound
assert part.store.unbound
if part not in self._fspaces:
raise ValueError(f"No strategy is found for {part}")
return self._fspaces[part]
Expand Down Expand Up @@ -160,7 +160,7 @@ def _solve_broadcast_constraints(
):
to_remove = OrderedSet()
for unknown in unknowns:
store = unknown._store
store = unknown.store
if not (store.kind is Future or unknown in broadcasts):
continue

Expand All @@ -183,7 +183,7 @@ def _solve_unbound_constraints(
):
to_remove = OrderedSet()
for unknown in unknowns:
store = unknown._store
store = unknown.store
if not store.unbound:
continue

Expand All @@ -193,7 +193,7 @@ def _solve_unbound_constraints(
continue

cls = constraints.find(unknown)
assert all(to_align._store.unbound for to_align in cls)
assert all(to_align.store.unbound for to_align in cls)

fspace = self._runtime.create_field_space()
for to_align in cls:
Expand All @@ -206,7 +206,7 @@ def _solve_unbound_constraints(
def _find_restrictions(cls):
merged = None
for unknown in cls:
store = unknown._store
store = unknown.store
restrictions = store.find_restrictions()
if merged is None:
merged = restrictions
Expand Down Expand Up @@ -268,7 +268,7 @@ def partition_stores(self):
all_restrictions = self._find_all_restrictions(unknowns, constraints)

def cost(unknown):
store = unknown._store
store = unknown.store
return (
-store.comm_volume(),
not store.has_key_partition(all_restrictions[unknown]),
Expand All @@ -284,7 +284,7 @@ def cost(unknown):
elif unknown in dependent:
continue

store = unknown._store
store = unknown.store
restrictions = all_restrictions[unknown]

if isinstance(prev_part, NoPartition):
Expand Down