Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions arraycontext/impl/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ class and performs all array operations eagerly. See
def __init__(self) -> None:
super().__init__()

from jax.numpy import DeviceArray
self.array_types = (DeviceArray, )
import jax.numpy as jnp
self.array_types = (jnp.ndarray, )

def _get_fake_numpy_namespace(self):
from .fake_numpy import EagerJAXFakeNumpyNamespace
Expand Down
34 changes: 16 additions & 18 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,14 +686,14 @@ def __init__(self,
unstable.
"""
import pytato as pt
from jax.numpy import DeviceArray
import jax.numpy as jnp
super().__init__(compile_trace_callback=compile_trace_callback)
self.array_types = (pt.Array, DeviceArray)
self.array_types = (pt.Array, jnp.ndarray)

@property
def _frozen_array_types(self) -> Tuple[Type, ...]:
from jax.numpy import DeviceArray
return (DeviceArray, )
import jax.numpy as jnp
return (jnp.ndarray, )

def _rec_map_container(
self, func: Callable[[Array], Array], array: ArrayOrContainer,
Expand Down Expand Up @@ -756,16 +756,16 @@ def freeze(self, array):

import pytato as pt

from jax.numpy import DeviceArray
import jax.numpy as jnp
from arraycontext.container.traversal import rec_keyed_map_array_container
from arraycontext.impl.pytato.compile import _ary_container_key_stringifier

array_as_dict: Dict[str, Union[DeviceArray, pt.Array]] = {}
key_to_frozen_subary: Dict[str, DeviceArray] = {}
array_as_dict: Dict[str, Union[jnp.ndarray, pt.Array]] = {}
key_to_frozen_subary: Dict[str, jnp.ndarray] = {}
key_to_pt_arrays: Dict[str, pt.Array] = {}

def _record_leaf_ary_in_dict(key: Tuple[Any, ...],
ary: Union[DeviceArray, pt.Array]) -> None:
ary: Union[jnp.ndarray, pt.Array]) -> None:
key_str = "_ary" + _ary_container_key_stringifier(key)
array_as_dict[key_str] = ary

Expand All @@ -774,7 +774,7 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...],
# {{{ remove any non pytato arrays from array_as_dict

for key, subary in array_as_dict.items():
if isinstance(subary, DeviceArray):
if isinstance(subary, jnp.ndarray):
key_to_frozen_subary[key] = subary.block_until_ready()
elif isinstance(subary, pt.DataWrapper):
# trivial freeze.
Expand All @@ -801,7 +801,7 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...],
for k, v in out_dict.items()}
}

def _to_frozen(key: Tuple[Any, ...], ary) -> DeviceArray:
def _to_frozen(key: Tuple[Any, ...], ary) -> jnp.ndarray:
key_str = "_ary" + _ary_container_key_stringifier(key)
return key_to_frozen_subary[key_str]

Expand All @@ -824,21 +824,19 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
return LazilyJAXCompilingFunctionCaller(self, f)

def tag(self, tags: ToTagSetConvertible, array):
from jax.numpy import DeviceArray

def _tag(ary):
if isinstance(ary, DeviceArray):
import jax.numpy as jnp
if isinstance(ary, jnp.ndarray):
return ary
else:
return ary.tagged(_preprocess_array_tags(tags))

return self._rec_map_container(_tag, array)

def tag_axis(self, iaxis, tags: ToTagSetConvertible, array):
from jax.numpy import DeviceArray

def _tag_axis(ary):
if isinstance(ary, DeviceArray):
import jax.numpy as jnp
if isinstance(ary, jnp.ndarray):
return ary
else:
return ary.with_tagged_axis(iaxis, tags)
Expand All @@ -857,12 +855,12 @@ def call_loopy(self, program, **kwargs):

def einsum(self, spec, *args, arg_names=None, tagged=()):
import pytato as pt
from jax.numpy import DeviceArray
if arg_names is None:
arg_names = (None,) * len(args)

def preprocess_arg(name, arg):
if isinstance(arg, DeviceArray):
import jax.numpy as jnp
if isinstance(arg, jnp.ndarray):
ary = self.thaw(arg)
elif isinstance(arg, pt.Array):
ary = arg
Expand Down