From 9b89d14c3ee8eee1f5ded8b9fc8eee31f4c3fb85 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Dec 2022 16:29:59 +0100 Subject: [PATCH 1/5] limit JAX version due to CI errors --- .test-conda-env-py3.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.test-conda-env-py3.yml b/.test-conda-env-py3.yml index 15b64965..93b5da33 100644 --- a/.test-conda-env-py3.yml +++ b/.test-conda-env-py3.yml @@ -13,4 +13,4 @@ dependencies: - pyopencl - islpy - pip -- jax +- jax <0.4 From 94d54a356cebc9e75a743fef235f780629e2be9c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Dec 2022 16:55:41 +0100 Subject: [PATCH 2/5] rename DeviceArray -> Array --- arraycontext/impl/jax/__init__.py | 4 ++-- arraycontext/impl/pytato/__init__.py | 31 ++++++++++++---------------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index f4794e46..b49d24d7 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -54,8 +54,8 @@ class and performs all array operations eagerly. See def __init__(self) -> None: super().__init__() - from jax.numpy import DeviceArray - self.array_types = (DeviceArray, ) + from jax import Array + self.array_types = (Array, ) def _get_fake_numpy_namespace(self): from .fake_numpy import EagerJAXFakeNumpyNamespace diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 8ccc7689..83373d78 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -686,14 +686,14 @@ def __init__(self, unstable. """ import pytato as pt - from jax.numpy import DeviceArray + from jax import Array super().__init__(compile_trace_callback=compile_trace_callback) - self.array_types = (pt.Array, DeviceArray) + self.array_types = (pt.Array, Array) @property def _frozen_array_types(self) -> Tuple[Type, ...]: - from jax.numpy import DeviceArray - return (DeviceArray, ) + from jax import Array + return (Array, ) def _rec_map_container( self, func: Callable[[Array], Array], array: ArrayOrContainer, @@ -756,16 +756,16 @@ def freeze(self, array): import pytato as pt - from jax.numpy import DeviceArray + from jax import Array from arraycontext.container.traversal import rec_keyed_map_array_container from arraycontext.impl.pytato.compile import _ary_container_key_stringifier - array_as_dict: Dict[str, Union[DeviceArray, pt.Array]] = {} - key_to_frozen_subary: Dict[str, DeviceArray] = {} + array_as_dict: Dict[str, Union[Array, pt.Array]] = {} + key_to_frozen_subary: Dict[str, Array] = {} key_to_pt_arrays: Dict[str, pt.Array] = {} def _record_leaf_ary_in_dict(key: Tuple[Any, ...], - ary: Union[DeviceArray, pt.Array]) -> None: + ary: Union[Array, pt.Array]) -> None: key_str = "_ary" + _ary_container_key_stringifier(key) array_as_dict[key_str] = ary @@ -774,7 +774,7 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...], # {{{ remove any non pytato arrays from array_as_dict for key, subary in array_as_dict.items(): - if isinstance(subary, DeviceArray): + if isinstance(subary, Array): key_to_frozen_subary[key] = subary.block_until_ready() elif isinstance(subary, pt.DataWrapper): # trivial freeze. @@ -801,7 +801,7 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...], for k, v in out_dict.items()} } - def _to_frozen(key: Tuple[Any, ...], ary) -> DeviceArray: + def _to_frozen(key: Tuple[Any, ...], ary) -> Array: key_str = "_ary" + _ary_container_key_stringifier(key) return key_to_frozen_subary[key_str] @@ -824,10 +824,8 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: return LazilyJAXCompilingFunctionCaller(self, f) def tag(self, tags: ToTagSetConvertible, array): - from jax.numpy import DeviceArray - def _tag(ary): - if isinstance(ary, DeviceArray): + if isinstance(ary, Array): return ary else: return ary.tagged(_preprocess_array_tags(tags)) @@ -835,10 +833,8 @@ def _tag(ary): return self._rec_map_container(_tag, array) def tag_axis(self, iaxis, tags: ToTagSetConvertible, array): - from jax.numpy import DeviceArray - def _tag_axis(ary): - if isinstance(ary, DeviceArray): + if isinstance(ary, Array): return ary else: return ary.with_tagged_axis(iaxis, tags) @@ -857,12 +853,11 @@ def call_loopy(self, program, **kwargs): def einsum(self, spec, *args, arg_names=None, tagged=()): import pytato as pt - from jax.numpy import DeviceArray if arg_names is None: arg_names = (None,) * len(args) def preprocess_arg(name, arg): - if isinstance(arg, DeviceArray): + if isinstance(arg, Array): ary = self.thaw(arg) elif isinstance(arg, pt.Array): ary = arg From 702a6cc0f275b314984e91d1cf4cd821fec10f44 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Dec 2022 22:37:26 +0100 Subject: [PATCH 3/5] import as JAXArray --- arraycontext/impl/jax/__init__.py | 4 ++-- arraycontext/impl/pytato/__init__.py | 29 +++++++++++++++------------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index b49d24d7..50832f82 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -54,8 +54,8 @@ class and performs all array operations eagerly. See def __init__(self) -> None: super().__init__() - from jax import Array - self.array_types = (Array, ) + from jax import Array as JAXArray + self.array_types = (JAXArray, ) def _get_fake_numpy_namespace(self): from .fake_numpy import EagerJAXFakeNumpyNamespace diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 83373d78..8cefa08e 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -686,14 +686,14 @@ def __init__(self, unstable. """ import pytato as pt - from jax import Array + from jax import Array as JAXArray super().__init__(compile_trace_callback=compile_trace_callback) - self.array_types = (pt.Array, Array) + self.array_types = (pt.Array, JAXArray) @property def _frozen_array_types(self) -> Tuple[Type, ...]: - from jax import Array - return (Array, ) + from jax import Array as JAXArray + return (JAXArray, ) def _rec_map_container( self, func: Callable[[Array], Array], array: ArrayOrContainer, @@ -756,16 +756,16 @@ def freeze(self, array): import pytato as pt - from jax import Array + from jax import Array as JAXArray from arraycontext.container.traversal import rec_keyed_map_array_container from arraycontext.impl.pytato.compile import _ary_container_key_stringifier - array_as_dict: Dict[str, Union[Array, pt.Array]] = {} - key_to_frozen_subary: Dict[str, Array] = {} + array_as_dict: Dict[str, Union[JAXArray, pt.Array]] = {} + key_to_frozen_subary: Dict[str, JAXArray] = {} key_to_pt_arrays: Dict[str, pt.Array] = {} def _record_leaf_ary_in_dict(key: Tuple[Any, ...], - ary: Union[Array, pt.Array]) -> None: + ary: Union[JAXArray, pt.Array]) -> None: key_str = "_ary" + _ary_container_key_stringifier(key) array_as_dict[key_str] = ary @@ -774,7 +774,7 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...], # {{{ remove any non pytato arrays from array_as_dict for key, subary in array_as_dict.items(): - if isinstance(subary, Array): + if isinstance(subary, JAXArray): key_to_frozen_subary[key] = subary.block_until_ready() elif isinstance(subary, pt.DataWrapper): # trivial freeze. @@ -801,7 +801,7 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...], for k, v in out_dict.items()} } - def _to_frozen(key: Tuple[Any, ...], ary) -> Array: + def _to_frozen(key: Tuple[Any, ...], ary) -> JAXArray: key_str = "_ary" + _ary_container_key_stringifier(key) return key_to_frozen_subary[key_str] @@ -825,7 +825,8 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: def tag(self, tags: ToTagSetConvertible, array): def _tag(ary): - if isinstance(ary, Array): + from jax import Array as JAXArray + if isinstance(ary, JAXArray): return ary else: return ary.tagged(_preprocess_array_tags(tags)) @@ -834,7 +835,8 @@ def _tag(ary): def tag_axis(self, iaxis, tags: ToTagSetConvertible, array): def _tag_axis(ary): - if isinstance(ary, Array): + from jax import Array as JAXArray + if isinstance(ary, JAXArray): return ary else: return ary.with_tagged_axis(iaxis, tags) @@ -857,7 +859,8 @@ def einsum(self, spec, *args, arg_names=None, tagged=()): arg_names = (None,) * len(args) def preprocess_arg(name, arg): - if isinstance(arg, Array): + from jax import Array as JAXArray + if isinstance(arg, JAXArray): ary = self.thaw(arg) elif isinstance(arg, pt.Array): ary = arg From 691f1d6ac8d58fa41c865fe8d45b454a948be60b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Dec 2022 23:36:25 +0100 Subject: [PATCH 4/5] Revert "limit JAX version due to CI errors" This reverts commit 9b89d14c3ee8eee1f5ded8b9fc8eee31f4c3fb85. --- .test-conda-env-py3.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.test-conda-env-py3.yml b/.test-conda-env-py3.yml index 93b5da33..15b64965 100644 --- a/.test-conda-env-py3.yml +++ b/.test-conda-env-py3.yml @@ -13,4 +13,4 @@ dependencies: - pyopencl - islpy - pip -- jax <0.4 +- jax From 45482ab962934c5a6755412ea45b397c395fcdb8 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 15 Dec 2022 15:37:08 +0100 Subject: [PATCH 5/5] change to jax.numpy.ndarray --- arraycontext/impl/jax/__init__.py | 4 ++-- arraycontext/impl/pytato/__init__.py | 32 ++++++++++++++-------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index 50832f82..4aa30c28 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -54,8 +54,8 @@ class and performs all array operations eagerly. See def __init__(self) -> None: super().__init__() - from jax import Array as JAXArray - self.array_types = (JAXArray, ) + import jax.numpy as jnp + self.array_types = (jnp.ndarray, ) def _get_fake_numpy_namespace(self): from .fake_numpy import EagerJAXFakeNumpyNamespace diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 8cefa08e..c3e4462c 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -686,14 +686,14 @@ def __init__(self, unstable. """ import pytato as pt - from jax import Array as JAXArray + import jax.numpy as jnp super().__init__(compile_trace_callback=compile_trace_callback) - self.array_types = (pt.Array, JAXArray) + self.array_types = (pt.Array, jnp.ndarray) @property def _frozen_array_types(self) -> Tuple[Type, ...]: - from jax import Array as JAXArray - return (JAXArray, ) + import jax.numpy as jnp + return (jnp.ndarray, ) def _rec_map_container( self, func: Callable[[Array], Array], array: ArrayOrContainer, @@ -756,16 +756,16 @@ def freeze(self, array): import pytato as pt - from jax import Array as JAXArray + import jax.numpy as jnp from arraycontext.container.traversal import rec_keyed_map_array_container from arraycontext.impl.pytato.compile import _ary_container_key_stringifier - array_as_dict: Dict[str, Union[JAXArray, pt.Array]] = {} - key_to_frozen_subary: Dict[str, JAXArray] = {} + array_as_dict: Dict[str, Union[jnp.ndarray, pt.Array]] = {} + key_to_frozen_subary: Dict[str, jnp.ndarray] = {} key_to_pt_arrays: Dict[str, pt.Array] = {} def _record_leaf_ary_in_dict(key: Tuple[Any, ...], - ary: Union[JAXArray, pt.Array]) -> None: + ary: Union[jnp.ndarray, pt.Array]) -> None: key_str = "_ary" + _ary_container_key_stringifier(key) array_as_dict[key_str] = ary @@ -774,7 +774,7 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...], # {{{ remove any non pytato arrays from array_as_dict for key, subary in array_as_dict.items(): - if isinstance(subary, JAXArray): + if isinstance(subary, jnp.ndarray): key_to_frozen_subary[key] = subary.block_until_ready() elif isinstance(subary, pt.DataWrapper): # trivial freeze. @@ -801,7 +801,7 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...], for k, v in out_dict.items()} } - def _to_frozen(key: Tuple[Any, ...], ary) -> JAXArray: + def _to_frozen(key: Tuple[Any, ...], ary) -> jnp.ndarray: key_str = "_ary" + _ary_container_key_stringifier(key) return key_to_frozen_subary[key_str] @@ -825,8 +825,8 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: def tag(self, tags: ToTagSetConvertible, array): def _tag(ary): - from jax import Array as JAXArray - if isinstance(ary, JAXArray): + import jax.numpy as jnp + if isinstance(ary, jnp.ndarray): return ary else: return ary.tagged(_preprocess_array_tags(tags)) @@ -835,8 +835,8 @@ def _tag(ary): def tag_axis(self, iaxis, tags: ToTagSetConvertible, array): def _tag_axis(ary): - from jax import Array as JAXArray - if isinstance(ary, JAXArray): + import jax.numpy as jnp + if isinstance(ary, jnp.ndarray): return ary else: return ary.with_tagged_axis(iaxis, tags) @@ -859,8 +859,8 @@ def einsum(self, spec, *args, arg_names=None, tagged=()): arg_names = (None,) * len(args) def preprocess_arg(name, arg): - from jax import Array as JAXArray - if isinstance(arg, JAXArray): + import jax.numpy as jnp + if isinstance(arg, jnp.ndarray): ary = self.thaw(arg) elif isinstance(arg, pt.Array): ary = arg