Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions legate/core/constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2021 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


class Expr(object):
    """Base class for partition expressions in the constraint DSL.

    Comparison and arithmetic operators are overloaded to *build constraint
    or expression objects* rather than compute values: ``==`` yields an
    Alignment, ``<=`` a Containment, and ``+`` a Translate.
    """

    def __eq__(self, rhs):
        # NOTE: does not test equality -- it records an alignment constraint.
        return Alignment(self, rhs)

    def __le__(self, rhs):
        return Containment(self, rhs)

    def __add__(self, offset):
        # Only tuple offsets whose length matches this expression's
        # dimensionality are meaningful.
        if not isinstance(offset, tuple):
            raise ValueError("Offset must be a tuple")
        if len(offset) != self.ndim:
            raise ValueError("Dimensions don't match")
        return Translate(self, offset)

    def broadcast(self):
        return Broadcast(self)


class Lit(Expr):
    """A literal expression wrapping an already-chosen partition.

    A Lit is "closed": it contains no unknown partition symbols, so
    substitution and reduction are both no-ops.
    """

    def __init__(self, part):
        self._part = part

    @property
    def ndim(self):
        raise NotImplementedError("ndim not implemented for literals")

    @property
    def closed(self):
        return True

    def __repr__(self):
        return f"Lit({self._part})"

    def subst(self, mapping=None):
        # Bug fix: accept (and ignore) the substitution mapping. Translate
        # allows a Lit sub-expression and forwards the mapping via
        # `self._expr.subst(mapping)`, which previously raised TypeError here.
        # The default keeps the old zero-argument call form working.
        return self

    def reduce(self):
        return self


class PartSym(Expr):
    """An unknown partition symbol tied to one store of one operation.

    Symbols are the unknowns the partitioning solver resolves; `subst`
    replaces a symbol with the concrete partition chosen for it.
    """

    def __init__(self, op, store, id, disjoint, complete):
        self._op = op
        self._store = store
        self._id = id
        self._disjoint = disjoint
        self._complete = complete

    @property
    def ndim(self):
        # Dimensionality follows the underlying store.
        return self._store.ndim

    @property
    def closed(self):
        # A symbol is, by definition, not yet resolved.
        return False

    def __repr__(self):
        flags = "{},{}".format(
            "D" if self._disjoint else "A",
            "C" if self._complete else "I",
        )
        return f"X{self._id}({flags})@{self._op.get_name()}"

    def __hash__(self):
        return hash((self._op, self._id))

    def subst(self, mapping):
        # Swap this symbol for the partition the solver picked for it.
        return Lit(mapping[self])

    def reduce(self):
        return self


class Translate(Expr):
    """Expression shifting a sub-expression's partition by a fixed offset."""

    def __init__(self, expr, offset):
        # Only simple operands are accepted; nesting Translate (or other
        # compound expressions) is not implemented.
        if not isinstance(expr, (PartSym, Lit)):
            raise NotImplementedError(
                "Compound expression is not supported yet"
            )
        self._expr = expr
        self._offset = offset

    @property
    def ndim(self):
        # The offset tuple fixes the dimensionality.
        return len(self._offset)

    @property
    def closed(self):
        return self._expr.closed

    def __repr__(self):
        return f"{self._expr} + {self._offset}"

    def subst(self, mapping):
        return Translate(self._expr.subst(mapping), self._offset)

    def reduce(self):
        reduced = self._expr.reduce()
        # The sub-expression must have reduced to a literal before we can
        # translate the underlying partition.
        assert isinstance(reduced, Lit)
        return Lit(reduced._part.translate(self._offset))


class Constraint(object):
    """Base class for partitioning constraints built from Expr operators."""

    pass


class Alignment(Constraint):
    """Constraint requiring two partition symbols to partition identically."""

    def __init__(self, lhs, rhs):
        # Only bare partition symbols can be aligned for now.
        for operand in (lhs, rhs):
            if not isinstance(operand, PartSym):
                raise NotImplementedError(
                    "Alignment between complex expressions is not supported yet"
                )
        self._lhs = lhs
        self._rhs = rhs

    def __repr__(self):
        return f"{self._lhs} == {self._rhs}"


class Containment(Constraint):
    """Constraint requiring the left expression's partition to be contained
    in the right-hand partition symbol's."""

    def __init__(self, lhs, rhs):
        # The right-hand side must be a bare symbol for now.
        if isinstance(rhs, PartSym):
            self._lhs = lhs
            self._rhs = rhs
        else:
            raise NotImplementedError(
                "Containment on a complex expression is not supported yet"
            )

    def __repr__(self):
        return f"{self._lhs} <= {self._rhs}"


class Broadcast(Constraint):
    """Constraint requiring an expression's store to be replicated whole
    (i.e. not partitioned)."""

    def __init__(self, expr):
        self._expr = expr

    def __repr__(self):
        return "Broadcast({})".format(self._expr)
9 changes: 8 additions & 1 deletion legate/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def _create_scope(api, category, max_counts):
config.max_shardings,
)

self._unique_op_id = 0

def destroy(self):
self._library.destroy()

Expand Down Expand Up @@ -172,8 +174,13 @@ def get_tunable(self, tunable_id, dtype, mapper_id=0):
buf = fut.get_buffer(dtype.itemsize)
return np.frombuffer(buf, dtype=dtype)[0]

def get_unique_op_id(self):
op_id = self._unique_op_id
self._unique_op_id += 1
return op_id

def create_task(self, task_id, mapper_id=0):
return Task(self, task_id, mapper_id)
return Task(self, task_id, mapper_id, op_id=self.get_unique_op_id())

def create_copy(self, mapper_id=0):
return Copy(self, mapper_id)
Expand Down
111 changes: 84 additions & 27 deletions legate/core/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,28 @@

import legate.core.types as ty

from .constraints import PartSym
from .launcher import CopyLauncher, TaskLauncher
from .solver import EqClass
from .store import Store
from .utils import OrderedSet


class Operation(object):
def __init__(self, context, mapper_id=0):
def __init__(self, context, mapper_id=0, op_id=0):
self._context = context
self._mapper_id = mapper_id
self._op_id = op_id
self._inputs = []
self._outputs = []
self._reductions = []
self._input_parts = []
self._output_parts = []
self._reduction_parts = []
self._scalar_outputs = []
self._scalar_reductions = []
self._constraints = EqClass()
self._broadcasts = OrderedSet()
self._partitions = {}
self._constraints = []
self._all_parts = []

@property
def context(self):
Expand Down Expand Up @@ -66,8 +71,8 @@ def constraints(self):
return self._constraints

@property
def broadcasts(self):
return self._broadcasts
def all_unknowns(self):
return self._all_parts

def get_all_stores(self):
stores = (
Expand All @@ -82,21 +87,42 @@ def _check_store(store):
if not isinstance(store, Store):
raise ValueError(f"Expected a Store object, but got {type(store)}")

def add_input(self, store):
def _get_unique_partition(self, store):
if store not in self._partitions:
return self.declare_partition(store)

parts = self._partitions[store]
if len(parts) > 1:
raise RuntimeError(
"Ambiguous store argument. When multple partitions exist for "
"this store, a partition should be specified."
)
return parts[0]

def add_input(self, store, partition=None):
self._check_store(store)
if partition is None:
partition = self._get_unique_partition(store)
self._inputs.append(store)
self._input_parts.append(partition)

def add_output(self, store):
def add_output(self, store, partition=None):
self._check_store(store)
if store.scalar:
self._scalar_outputs.append(len(self._outputs))
if partition is None:
partition = self._get_unique_partition(store)
self._outputs.append(store)
self._output_parts.append(partition)

def add_reduction(self, store, redop):
def add_reduction(self, store, redop, partition=None):
self._check_store(store)
if store.scalar:
self._scalar_reductions.append(len(self._reductions))
if partition is None:
partition = self._get_unique_partition(store)
self._reductions.append((store, redop))
self._reduction_parts.append(partition)

def add_alignment(self, store1, store2):
self._check_store(store1)
Expand All @@ -106,28 +132,57 @@ def add_alignment(self, store1, store2):
"Stores must have the same shape to be aligned, "
f"but got {store1.shape} and {store2.shape}"
)
self._constraints.record(store1, store2)
part1 = self._get_unique_partition(store1)
part2 = self._get_unique_partition(store2)
self.add_constraint(part1 == part2)

def add_broadcast(self, store):
self._broadcasts.add(store)
self._check_store(store)
part = self._get_unique_partition(store)
self.add_constraint(part.broadcast())

def add_constraint(self, constraint):
self._constraints.append(constraint)

def execute(self):
self._context.runtime.submit(self)

def get_tag(self, strategy, store):
if strategy.is_key_store(store):
def get_tag(self, strategy, part):
if strategy.is_key_part(part):
return 1 # LEGATE_CORE_KEY_STORE_TAG
else:
return 0

def _get_symbol_id(self):
return len(self._all_parts)

def declare_partition(self, store, disjoint=True, complete=True):
sym = PartSym(
self,
store,
self._get_symbol_id(),
disjoint=disjoint,
complete=complete,
)
if store not in self._partitions:
self._partitions[store] = [sym]
else:
self._partitions[store].append(sym)
self._all_parts.append(sym)
return sym


class Task(Operation):
def __init__(self, context, task_id, mapper_id=0):
Operation.__init__(self, context, mapper_id)
def __init__(self, context, task_id, mapper_id=0, op_id=0):
Operation.__init__(self, context, mapper_id=mapper_id, op_id=op_id)
self._task_id = task_id
self._scalar_args = []
self._futures = []

def get_name(self):
libname = self.context.library.get_name()
return f"{libname}.Task(tid:{self._task_id}, uid:{self._op_id})"

def add_scalar_arg(self, value, dtype):
self._scalar_args.append((value, dtype))

Expand All @@ -141,32 +196,34 @@ def add_future(self, future):
def launch(self, strategy):
launcher = TaskLauncher(self.context, self._task_id, self.mapper_id)

for input in self._inputs:
proj = strategy.get_projection(input)
tag = self.get_tag(strategy, input)
for input, input_part in zip(self._inputs, self._input_parts):
proj = strategy.get_projection(input_part)
tag = self.get_tag(strategy, input_part)
launcher.add_input(input, proj, tag=tag)
for output in self._outputs:
for output, output_part in zip(self._outputs, self._output_parts):
if output.unbound:
continue
proj = strategy.get_projection(output)
tag = self.get_tag(strategy, output)
proj = strategy.get_projection(output_part)
tag = self.get_tag(strategy, output_part)
launcher.add_output(output, proj, tag=tag)
partition = strategy.get_partition(output)
partition = strategy.get_partition(output_part)
# We update the key partition of a store only when it gets updated
output.set_key_partition(partition)
for (reduction, redop) in self._reductions:
partition = strategy.get_partition(reduction)
for ((reduction, redop), reduction_part) in zip(
self._reductions, self._reduction_parts
):
partition = strategy.get_partition(reduction_part)
can_read_write = partition.is_disjoint_for(strategy, reduction)
proj = strategy.get_projection(reduction)
proj = strategy.get_projection(reduction_part)
proj.redop = reduction.type.reduction_op_id(redop)
tag = self.get_tag(strategy, reduction)
launcher.add_reduction(
reduction, proj, tag=tag, read_write=can_read_write
)
for output in self._outputs:
for (output, output_part) in zip(self._outputs, self._output_parts):
if not output.unbound:
continue
fspace = strategy.get_field_space(output)
fspace = strategy.get_field_space(output_part)
field_id = fspace.allocate_field(output.type)
launcher.add_unbound_output(output, fspace, field_id)

Expand Down
Loading