Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,12 @@

from arraycontext.context import ArrayContext
import numpy as np
from typing import Any, Callable, Union, Sequence
from typing import Any, Callable, Union, Sequence, TYPE_CHECKING
from pytools.tag import Tag

if TYPE_CHECKING:
import pytato


class PytatoPyOpenCLArrayContext(ArrayContext):
"""
Expand All @@ -62,6 +65,8 @@ class PytatoPyOpenCLArrayContext(ArrayContext):
to use the default allocator.

.. automethod:: __init__

.. automethod:: transform_dag
"""

def __init__(self, queue, allocator=None):
Expand Down Expand Up @@ -116,6 +121,7 @@ def call_loopy(self, program, **kwargs):
def freeze(self, array):
import pytato as pt
import pyopencl.array as cla
import loopy as lp

if isinstance(array, cla.Array):
return array.with_queue(None)
Expand All @@ -134,20 +140,27 @@ def freeze(self, array):
# }}}

from arraycontext.impl.pytato.utils import _normalize_pt_expr
normalized_expr, bound_arguments = _normalize_pt_expr(array)
pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(
{"_actx_out": array})

normalized_expr, bound_arguments = _normalize_pt_expr(
pt_dict_of_named_arrays)

try:
pt_prg = self._freeze_prg_cache[normalized_expr]
except KeyError:
pt_prg = pt.generate_loopy(normalized_expr, cl_device=self.queue.device)
pt_prg = pt.generate_loopy(self.transform_dag(normalized_expr),
options=lp.Options(return_dict=True,
no_numpy=True),
cl_device=self.queue.device)
pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program)
self._freeze_prg_cache[normalized_expr] = pt_prg

assert len(pt_prg.bound_arguments) == 0
evt, (cl_array,) = pt_prg(self.queue, **bound_arguments)
evt, out_dict = pt_prg(self.queue, **bound_arguments)
evt.wait()

return cl_array.with_queue(None)
return out_dict["_actx_out"].with_queue(None)

def thaw(self, array):
import pytato as pt
Expand All @@ -170,6 +183,20 @@ def transform_loopy_program(self, t_unit):
"transform_loopy_program. Sub-classes are supposed "
"to implement it.")

def transform_dag(self, dag: "pytato.DictOfNamedArrays"
) -> "pytato.DictOfNamedArrays":
"""
Returns a transformed version of *dag*. Sub-classes are supposed to
override this method to implement context-specific transformations on
*dag* (most likely to perform domain-specific optimizations). Every
:mod:`pytato` DAG that is compiled to a :mod:`pyopencl` kernel is
passed through this routine.

:arg dag: An instance of :class:`pytato.DictOfNamedArrays`
:returns: A transformed version of *dag*.
"""
return dag

def tag(self, tags: Union[Sequence[Tag], Tag], array):
return array.tagged(tags)

Expand Down
6 changes: 5 additions & 1 deletion arraycontext/impl/pytato/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,11 @@ def _as_dict_of_named_arrays(keys, ary):
outputs)

import loopy as lp
pytato_program = pt.generate_loopy(dict_of_named_arrays,

pt_dict_of_named_arrays = self.actx.transform_dag(
pt.make_dict_of_named_arrays(dict_of_named_arrays))

pytato_program = pt.generate_loopy(pt_dict_of_named_arrays,
options=lp.Options(
return_dict=True,
no_numpy=True),
Expand Down
9 changes: 5 additions & 4 deletions arraycontext/impl/pytato/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from typing import Any, Dict, Set, Tuple, Mapping
from pytato.array import SizeParam, Placeholder
from pytato.array import Array, DataWrapper
from pytato.array import Array, DataWrapper, DictOfNamedArrays
from pytato.transform import CopyMapper
from pytools import UniqueNameGenerator

Expand Down Expand Up @@ -66,8 +66,8 @@ def map_placeholder(self, expr: Placeholder) -> Array:
" DatawrapperToBoundPlaceholderMapper.")


def _normalize_pt_expr(expr: Array) -> Tuple[Array,
Mapping[str, Any]]:
def _normalize_pt_expr(expr: DictOfNamedArrays) -> Tuple[DictOfNamedArrays,
Mapping[str, Any]]:
"""
Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a
normalized form of *expr*, with all instances of
Expand All @@ -78,5 +78,6 @@ def _normalize_pt_expr(expr: Array) -> Tuple[Array,
equivalent graphs.
"""
normalize_mapper = _DatawrapperToBoundPlaceholderMapper()
normalized_expr = normalize_mapper(expr)
# type-ignore reason: Mapper.__call__ takes Array, passed DictOfNamedArrays
normalized_expr = normalize_mapper(expr) # type: ignore
return normalized_expr, normalize_mapper.bound_arguments