From fc8768e662bfad2d58e31b0903cc4e5af03a1756 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 28 Jul 2021 20:22:03 -0500 Subject: [PATCH 1/2] defines PytatoPyOpenCLArrayContext.transform_dag Co-authored-by: Andreas Kloeckner --- arraycontext/impl/pytato/__init__.py | 37 ++++++++++++++++++++++++---- arraycontext/impl/pytato/compile.py | 6 ++++- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index b4ac63e6..beaebc4b 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -43,9 +43,12 @@ from arraycontext.context import ArrayContext import numpy as np -from typing import Any, Callable, Union, Sequence +from typing import Any, Callable, Union, Sequence, TYPE_CHECKING from pytools.tag import Tag +if TYPE_CHECKING: + import pytato + class PytatoPyOpenCLArrayContext(ArrayContext): """ @@ -62,6 +65,8 @@ class PytatoPyOpenCLArrayContext(ArrayContext): to use the default allocator. .. automethod:: __init__ + + .. automethod:: transform_dag """ def __init__(self, queue, allocator=None): @@ -116,6 +121,7 @@ def call_loopy(self, program, **kwargs): def freeze(self, array): import pytato as pt import pyopencl.array as cla + import loopy as lp if isinstance(array, cla.Array): return array.with_queue(None) @@ -134,20 +140,27 @@ def freeze(self, array): # }}} from arraycontext.impl.pytato.utils import _normalize_pt_expr - normalized_expr, bound_arguments = _normalize_pt_expr(array) + pt_dict_of_named_arrays = pt.make_dict_of_named_arrays( + {"_actx_out": array}) + + normalized_expr, bound_arguments = _normalize_pt_expr( + pt_dict_of_named_arrays) try: pt_prg = self._freeze_prg_cache[normalized_expr] except KeyError: - pt_prg = pt.generate_loopy(normalized_expr, cl_device=self.queue.device) + pt_prg = pt.generate_loopy(self.transform_dag(normalized_expr), + options=lp.Options(return_dict=True, + no_numpy=True), + cl_device=self.queue.device) pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program) self._freeze_prg_cache[normalized_expr] = pt_prg assert len(pt_prg.bound_arguments) == 0 - evt, (cl_array,) = pt_prg(self.queue, **bound_arguments) + evt, out_dict = pt_prg(self.queue, **bound_arguments) evt.wait() - return cl_array.with_queue(None) + return out_dict["_actx_out"].with_queue(None) def thaw(self, array): import pytato as pt @@ -170,6 +183,20 @@ def transform_loopy_program(self, t_unit): "transform_loopy_program. Sub-classes are supposed " "to implement it.") + def transform_dag(self, dag: "pytato.DictOfNamedArrays" + ) -> "pytato.DictOfNamedArrays": + """ + Returns a transformed version of *dag*. Sub-classes are supposed to + override this method to implement context-specific transformations on + *dag* (most likely to perform domain-specific optimizations). Every + :mod:`pytato` DAG that is compiled to a :mod:`pyopencl` kernel is + passed through this routine. + + :arg dag: An instance of :class:`pytato.DictOfNamedArrays` + :returns: A transformed version of *dag*. + """ + return dag + def tag(self, tags: Union[Sequence[Tag], Tag], array): return array.tagged(tags) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 5fad96cb..faed2cfe 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -226,7 +226,11 @@ def _as_dict_of_named_arrays(keys, ary): outputs) import loopy as lp - pytato_program = pt.generate_loopy(dict_of_named_arrays, + + pt_dict_of_named_arrays = self.actx.transform_dag( + pt.make_dict_of_named_arrays(dict_of_named_arrays)) + + pytato_program = pt.generate_loopy(pt_dict_of_named_arrays, options=lp.Options( return_dict=True, no_numpy=True), From 49f747310333e5ee646cd5056b4355ebfa698320 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 14 Aug 2021 16:19:27 -0500 Subject: [PATCH 2/2] _normalize_pt_expr: perform the transformation for a dict-of-named-arrays --- arraycontext/impl/pytato/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index 1e00c8d6..184bdd90 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -25,7 +25,7 @@ from typing import Any, Dict, Set, Tuple, Mapping from pytato.array import SizeParam, Placeholder -from pytato.array import Array, DataWrapper +from pytato.array import Array, DataWrapper, DictOfNamedArrays from pytato.transform import CopyMapper from pytools import UniqueNameGenerator @@ -66,8 +66,8 @@ def map_placeholder(self, expr: Placeholder) -> Array: " DatawrapperToBoundPlaceholderMapper.") -def _normalize_pt_expr(expr: Array) -> Tuple[Array, - Mapping[str, Any]]: +def _normalize_pt_expr(expr: DictOfNamedArrays) -> Tuple[DictOfNamedArrays, + Mapping[str, Any]]: """ Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a normalized form of *expr*, with all instances of @@ -78,5 +78,6 @@ def _normalize_pt_expr(expr: Array) -> Tuple[Array, equivalent graphs. """ normalize_mapper = _DatawrapperToBoundPlaceholderMapper() - normalized_expr = normalize_mapper(expr) + # type-ignore reason: Mapper.__call__ takes Array, passed DictOfNamedArrays + normalized_expr = normalize_mapper(expr) # type: ignore return normalized_expr, normalize_mapper.bound_arguments