Allow str to be a valid JAX type #3045
I acknowledge that supporting strings might be handy, but it's likely difficult to do in JAX because the underlying compiler (XLA) doesn't support strings. I'm going to close this issue since I don't think there's any reasonable action we can take here, but feel more than welcome to keep brainstorming on it. I will note one thing about your example: you can capture the string via lexical closure rather than passing it as an argument. Hope that helps!
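The closure approach can be sketched as follows (a minimal illustration, assuming the goal is to bake a fixed string into a jitted function; `make_fn` and the `"double"` value are hypothetical names, not from the thread):

```python
import jax
import jax.numpy as jnp

def make_fn(msg):
    # `msg` is captured by closure, so at trace time it is a plain Python
    # string and never needs to be a valid JAX type.
    @jax.jit
    def f(x):
        if msg == "double":  # ordinary Python control flow on the string
            return 2 * x
        return x + 1
    return f
```

Calling `make_fn` with different strings builds independently compiled functions, one per string value.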
Thanks! I had assumed that if a string couldn't be passed as an argument, lexical closure wouldn't work either. Can you elaborate on the recompilation behavior? Let's say I call:

```python
print_yay(np.array([5, 3, 2, 2020]), "this is my first message!")
print_yay(np.array([5, 3, 2, 2020]), "this is my second message!")
print_yay(np.array([5, 3, 2, 2020]), "this is my first message!")
```

Will JAX recompile `print_yay`?
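The recompilation behavior can be observed directly by marking the string static and counting traces (a sketch with a hypothetical `print_yay` body; the trace counter increments only while JAX traces the function, i.e. once per compilation):

```python
from functools import partial
import numpy as np
import jax

trace_count = 0

@partial(jax.jit, static_argnums=(1,))
def print_yay(x, msg):
    global trace_count
    trace_count += 1  # side effect runs only during tracing / compilation
    return x + 1

print_yay(np.array([5, 3, 2, 2020]), "this is my first message!")
print_yay(np.array([5, 3, 2, 2020]), "this is my second message!")   # new static value
print_yay(np.array([5, 3, 2, 2020]), "this is my first message!")    # compilation cache hit
```

After these three calls `trace_count` is 2: the second message triggers a recompile, while the repeated first message reuses the cached compilation.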
Hey, concerning string support in JAX: I'm trying to implement a custom backward pass for a convolution by wrapping it with `@custom_vjp`:

```python
@custom_vjp
def custom_conv_general_dilated(lhs: Array, rhs: Array, window_strides: Sequence[int],
                                padding: Union[str, Sequence[Tuple[int, int]]],
                                lhs_dilation: Optional[Sequence[int]] = None,
                                rhs_dilation: Optional[Sequence[int]] = None,
                                dimension_numbers: jax.lax.ConvGeneralDilatedDimensionNumbers = None,
                                feature_group_count: int = 1, batch_group_count: int = 1,
                                precision: PrecisionLike = None,
                                preferred_element_type: Optional[DType] = None):
  return jax.lax.conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
                                      dimension_numbers=dimension_numbers)

def custom_conv_general_fwd(inputs, W, window_strides, padding, lhs_dilation: Optional[Sequence[int]] = None,
                            rhs_dilation: Optional[Sequence[int]] = None,
                            dimension_numbers: jax.lax.ConvGeneralDilatedDimensionNumbers = None,
                            feature_group_count: int = 1, batch_group_count: int = 1,
                            precision: PrecisionLike = None,
                            preferred_element_type: Optional[DType] = None):
  print("Custom Conv general forward", "="*20)
  # custom modification to inputs and W (producing W_lo)
  outputs = custom_conv_general_dilated(inputs, W_lo, window_strides, padding, lhs_dilation, rhs_dilation,
                                        dimension_numbers=dimension_numbers)
  return outputs, (inputs, W, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers)
```

The error message is:

```
--> 211     return custom_conv_general_dilated(inputs, W, strides, padding, one, one,
    212                                        dimension_numbers=dimension_numbers) + b
    213   return init_fun, apply_fun

    [... skipping hidden 4 frame]

~/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py in concrete_aval(x)
    960   if hasattr(x, '__jax_array__'):
    961     return concrete_aval(x.__jax_array__())
--> 962   raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
    963                   "type")
    964

TypeError: Value 'VALID' with type <class 'str'> is not a valid JAX type
```

Below is the definition of `conv_general_dilated`:
```python
def conv_general_dilated(
    lhs: Array, rhs: Array, window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    lhs_dilation: Optional[Sequence[int]] = None,
    rhs_dilation: Optional[Sequence[int]] = None,
    dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
    feature_group_count: int = 1, batch_group_count: int = 1,
    precision: PrecisionLike = None,
    preferred_element_type: Optional[DType] = None) -> Array:
  """General n-dimensional convolution operator, with optional dilation.

  Wraps XLA's `Conv
  <https://www.tensorflow.org/xla/operation_semantics#conv_convolution>`_
  operator.

  Args:
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
      `n` `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension.
    lhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
      is also known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
      is also known as atrous convolution.
    dimension_numbers: either `None`, a ``ConvDimensionNumbers`` object, or
      a 3-tuple ``(lhs_spec, rhs_spec, out_spec)``, where each element is a
      string of length `n+2`.
    feature_group_count: integer, default 1. See XLA HLO docs.
    batch_group_count: integer, default 1. See XLA HLO docs.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``), a string (e.g. 'highest' or
      'fastest', see the ``jax.default_matmul_precision`` context manager), or a
      tuple of two ``lax.Precision`` enums or strings indicating precision of
      ``lhs`` and ``rhs``.
    preferred_element_type: Optional. Either ``None``, which means the default
      accumulation type for the input types, or a datatype, indicating to
      accumulate results to and return a result with that datatype.

  Returns:
    An array containing the convolution result.

  In the string case of ``dimension_numbers``, each character identifies by
  position:

  - the batch dimensions in ``lhs``, ``rhs``, and the output with the character
    'N',
  - the feature dimensions in `lhs` and the output with the character 'C',
  - the input and output feature dimensions in rhs with the characters 'I'
    and 'O' respectively, and
  - spatial dimension correspondences between lhs, rhs, and the output using
    any distinct characters.

  For example, to indicate dimension numbers consistent with the ``conv``
  function with two spatial dimensions, one could use ``('NCHW', 'OIHW',
  'NCHW')``. As another example, to indicate dimension numbers consistent with
  the TensorFlow Conv2D operation, one could use ``('NHWC', 'HWIO', 'NHWC')``.

  When using the latter form of convolution dimension specification, window
  strides are associated with spatial dimension character labels according to
  the order in which the labels appear in the ``rhs_spec`` string, so that
  ``window_strides[0]`` is matched with the dimension corresponding to the first
  character appearing in rhs_spec that is not ``'I'`` or ``'O'``.

  If ``dimension_numbers`` is ``None``, the default is ``('NCHW', 'OIHW',
  'NCHW')`` (for a 2D convolution).
  """
  dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
  if lhs_dilation is None:
    lhs_dilation = (1,) * (lhs.ndim - 2)
  elif isinstance(padding, str) and not len(lhs_dilation) == lhs_dilation.count(1):
    raise ValueError(
        "String padding is not implemented for transposed convolution "
        "using this op. Please either exactly specify the required padding or "
        "use conv_transpose.")
  if rhs_dilation is None:
    rhs_dilation = (1,) * (rhs.ndim - 2)
  if isinstance(padding, str):
    lhs_perm, rhs_perm, _ = dnums
    rhs_shape = np.take(rhs.shape, rhs_perm)[2:]  # type: ignore[index]
    effective_rhs_shape = [(k-1) * r + 1 for k, r in zip(rhs_shape, rhs_dilation)]
    padding = padtype_to_pads(
        np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape,  # type: ignore[index]
        window_strides, padding)
  preferred_element_type = (None if preferred_element_type is None else
                            np.dtype(preferred_element_type))
  return conv_general_dilated_p.bind(
      lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
      lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),
      dimension_numbers=dnums,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count,
      lhs_shape=lhs.shape, rhs_shape=rhs.shape,
      precision=canonicalize_precision(precision),
      preferred_element_type=preferred_element_type)
```
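For the `custom_vjp` case above, the usual way to keep `'VALID'` out of tracing is to declare the non-array arguments with `nondiff_argnums`, so they are handled as static Python values. A minimal sketch under that assumption (`my_conv` is a hypothetical cut-down wrapper, and the backward rule here simply recomputes the standard VJP with `jax.vjp` as a stand-in for whatever custom rule is intended):

```python
from functools import partial
import jax
import jax.numpy as jnp

# Non-array arguments (strides, padding string) come first and are declared
# non-differentiable, so custom_vjp never tries to turn 'VALID' into a tracer.
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def my_conv(window_strides, padding, lhs, rhs):
    return jax.lax.conv_general_dilated(lhs, rhs, window_strides, padding)

def my_conv_fwd(window_strides, padding, lhs, rhs):
    out = my_conv(window_strides, padding, lhs, rhs)
    return out, (lhs, rhs)  # residuals for the backward pass

def my_conv_bwd(window_strides, padding, res, g):
    # nondiff_argnums are passed as the leading arguments of the bwd rule.
    lhs, rhs = res
    # Placeholder backward rule: recompute the standard VJP of the conv.
    _, vjp_fn = jax.vjp(
        lambda l, r: jax.lax.conv_general_dilated(l, r, window_strides, padding),
        lhs, rhs)
    return vjp_fn(g)  # gradients w.r.t. (lhs, rhs) only

my_conv.defvjp(my_conv_fwd, my_conv_bwd)
```

Putting the static arguments first keeps the primal, fwd, and bwd signatures aligned; the custom modification to the weights would go inside `my_conv_fwd`.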
You should use
results in:
It'd be great if JIT-compiled functions could accept strings. For me, this is useful because my function uses flow control to compute different values based on the value of `msg`. I believe the only way to approach this as of now is to create a separate function for each flow-control "pathway" and JIT-compile each of them separately. To me, this is inconvenient for developers and reduces code readability.

Note: this is a feature request; I'm aware that the only supported JAX types are currently numpy arrays. I've also read https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Control-Flow
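The flow-control use case described here does work today if the string is declared static, at the cost of one compilation per distinct string value. A sketch (with a hypothetical `apply_op` standing in for the function branching on `msg`):

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def apply_op(x, op):
    # `op` is a static Python string, so ordinary `if`/`elif` works here;
    # each distinct value of `op` compiles its own specialized version.
    if op == "square":
        return x * x
    elif op == "negate":
        return -x
    raise ValueError(f"unknown op: {op}")
```

This keeps all the pathways in one function at the source level, while JAX still specializes and caches a separate executable per string value.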