Allow str to be a valid JAX type #3045

Closed
sumanthratna opened this issue May 11, 2020 · 4 comments
Labels
contributions welcome (The JAX team has not prioritized work on this; community contributions are welcome), enhancement (new feature or request)

Comments

@sumanthratna

sumanthratna commented May 11, 2020

import jax
import numpy as np


@jax.jit
def print_yay(arr, msg):
    print(msg)
    print(arr)

print_yay(np.array([5, 3, 2, 2020]), "this is my message!")

results in:

Traceback (most recent call last):
  File "jaxt.py", line 10, in <module>
    print_yay(np.array([5,3, 2, 2020]), "this is my message!")
  File "/private/tmp/venv/lib/python3.8/site-packages/jax/api.py", line 151, in f_jitted
    _check_args(args_flat)
  File "/private/tmp/venv/lib/python3.8/site-packages/jax/api.py", line 1590, in _check_args
    raise TypeError("Argument '{}' of type {} is not a valid JAX type"
TypeError: Argument 'this is my message!' of type <class 'str'> is not a valid JAX type

It'd be great if JIT-compiled functions could accept strings. For me, this is useful because my function uses flow control to calculate different values based on the value of msg.

I believe the only way to approach this as of now is to create different functions for each flow-control "pathway" and separately JIT-compile each of these. To me, this is inconvenient for developers and reduces code readability.

Note: This is a feature request; I'm aware that the only supported JAX types are currently numpy arrays. I've also read https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Control-Flow

@hawkinsp added the contributions welcome and enhancement labels on May 12, 2020
@hawkinsp
Collaborator

I acknowledge that supporting strings might be handy, but I note that it's likely difficult to do in JAX because the underlying compiler (XLA) doesn't support strings.

I'm going to close this issue since I don't think there's any reasonable action we can take here, but feel more than welcome to keep brainstorming on this issue.

I will note a couple of things about your example. JAX expects functions under a jit to be pure, so adding print won't do what you expect: it will print at trace time, not at runtime. As a secondary comment, if you do want to pass a string object into a jit-decorated function, you can, but you must do so either by lexical capture, e.g.,

msg = "hello"
def f(x):
  print(msg)
  return x + 2

or using the static_argnums feature of jit, which has the effect of recompiling the computation for each string.
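
For illustration, here is a minimal sketch of the static_argnums route (assuming the string is passed as argument 1; each new string value triggers a fresh trace and compilation):

import jax

def f(x, msg):
  print(msg)   # trace-time print; fires only when this combination is compiled
  return x + 2

f_jit = jax.jit(f, static_argnums=(1,))
f_jit(1.0, "hello")   # traces and compiles for msg == "hello"
f_jit(1.0, "world")   # new static value, so it traces and compiles again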

Hope that helps!

@sumanthratna
Author

Thanks! I had assumed that if a string couldn't be passed as an argument, lexical closure wouldn't work either. Can you elaborate on the recompilation behavior of static_argnums? From the docs:

An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded. Calling the jitted function with different values for these constants will trigger recompilation.

Let's say I call print_yay (from my original post) 3 times:

print_yay(np.array([5, 3, 2, 2020]), "this is my first message!")
print_yay(np.array([5, 3, 2, 2020]), "this is my second message!")
print_yay(np.array([5, 3, 2, 2020]), "this is my first message!")

Will JAX recompile print_yay on the third call? Relevant code: https://github.com/google/jax/blob/ef4debcaad5a5ac5182899e385f45ca64f5ce600/jax/api.py#L145-L149
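
One way to check this empirically is with a trace-time print, since it only fires when JAX actually traces and compiles (a sketch, adapting print_yay to use static_argnums):

from functools import partial

import jax
import numpy as np

@partial(jax.jit, static_argnums=(1,))
def print_yay(arr, msg):
  print("tracing with:", msg)   # executes only during tracing, i.e. on (re)compilation
  return arr

x = np.array([5, 3, 2, 2020])
print_yay(x, "this is my first message!")   # traces and compiles
print_yay(x, "this is my second message!")  # new static value, so it traces again
print_yay(x, "this is my first message!")   # whether this prints shows whether the cached version is reused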

@llCurious

llCurious commented Jan 14, 2022

Hey, concerning supporting strings in JAX: I notice that jax.lax.conv_general_dilated actually takes dimension_numbers, which is a tuple of strings.

I am trying to customize the backward pass of conv and wrap jax.lax.conv_general_dilated as follows:

# assumed imports for this snippet; Array, DType, and PrecisionLike follow the aliases used in jax.lax's own signature
from typing import Optional, Sequence, Tuple, Union

import jax
from jax import custom_vjp

@custom_vjp
def custom_conv_general_dilated(lhs: Array, rhs: Array, window_strides: Sequence[int],
                                padding: Union[str, Sequence[Tuple[int, int]]],
                                lhs_dilation: Optional[Sequence[int]] = None,
                                rhs_dilation: Optional[Sequence[int]] = None,
                                dimension_numbers: jax.lax.ConvGeneralDilatedDimensionNumbers = None,
                                feature_group_count: int = 1, batch_group_count: int = 1,
                                precision: PrecisionLike = None,
                                preferred_element_type: Optional[DType] = None):
  return jax.lax.conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
                                      dimension_numbers=dimension_numbers)


def custom_conv_general_fwd(inputs, W, window_strides, padding, lhs_dilation: Optional[Sequence[int]] = None,
                            rhs_dilation: Optional[Sequence[int]] = None,
                            dimension_numbers: jax.lax.ConvGeneralDilatedDimensionNumbers = None,
                            feature_group_count: int = 1, batch_group_count: int = 1,
                            precision: PrecisionLike = None,
                            preferred_element_type: Optional[DType] = None):
  print("Custom Conv general forward", "="*20)
  # custom modification to inputs and W
  outputs = custom_conv_general_dilated(inputs, W, window_strides, padding, lhs_dilation, rhs_dilation,
                                        dimension_numbers=dimension_numbers)
  return outputs, (inputs, W, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers)

The error message is:

--> 211     return custom_conv_general_dilated(inputs, W, strides, padding, one, one,
    212                                     dimension_numbers=dimension_numbers) + b
    213   return init_fun, apply_fun

    [... skipping hidden 4 frame]

~/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py in concrete_aval(x)
    960   if hasattr(x, '__jax_array__'):
    961     return concrete_aval(x.__jax_array__())
--> 962   raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
    963                    "type")
    964 

TypeError: Value 'VALID' with type <class 'str'> is not a valid JAX type

Below is the source of jax.lax.conv_general_dilated:

def conv_general_dilated(
  lhs: Array, rhs: Array, window_strides: Sequence[int],
  padding: Union[str, Sequence[Tuple[int, int]]],
  lhs_dilation: Optional[Sequence[int]] = None,
  rhs_dilation: Optional[Sequence[int]] = None,
  dimension_numbers: ConvGeneralDilatedDimensionNumbers  = None,
  feature_group_count: int = 1, batch_group_count: int = 1,
  precision: PrecisionLike = None,
  preferred_element_type: Optional[DType] = None) -> Array:
  """General n-dimensional convolution operator, with optional dilation.

  Wraps XLA's `Conv
  <https://www.tensorflow.org/xla/operation_semantics#conv_convolution>`_
  operator.

  Args:
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
      `n` `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension.
    lhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
      is also known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
      is also known as atrous convolution.
    dimension_numbers: either `None`, a ``ConvDimensionNumbers`` object, or
      a 3-tuple ``(lhs_spec, rhs_spec, out_spec)``, where each element is a
      string of length `n+2`.
    feature_group_count: integer, default 1. See XLA HLO docs.
    batch_group_count: integer, default 1. See XLA HLO docs.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``), a string (e.g. 'highest' or
      'fastest', see the ``jax.default_matmul_precision`` context manager), or a
      tuple of two ``lax.Precision`` enums or strings indicating precision of
      ``lhs`` and ``rhs``.
    preferred_element_type: Optional. Either ``None``, which means the default
      accumulation type for the input types, or a datatype, indicating to
      accumulate results to and return a result with that datatype.

  Returns:
    An array containing the convolution result.

  In the string case of ``dimension_numbers``, each character identifies by
  position:

  - the batch dimensions in ``lhs``, ``rhs``, and the output with the character
    'N',
  - the feature dimensions in `lhs` and the output with the character 'C',
  - the input and output feature dimensions in rhs with the characters 'I'
    and 'O' respectively, and
  - spatial dimension correspondences between lhs, rhs, and the output using
    any distinct characters.

  For example, to indicate dimension numbers consistent with the ``conv``
  function with two spatial dimensions, one could use ``('NCHW', 'OIHW',
  'NCHW')``. As another example, to indicate dimension numbers consistent with
  the TensorFlow Conv2D operation, one could use ``('NHWC', 'HWIO', 'NHWC')``.
  When using the latter form of convolution dimension specification, window
  strides are associated with spatial dimension character labels according to
  the order in which the labels appear in the ``rhs_spec`` string, so that
  ``window_strides[0]`` is matched with the dimension corresponding to the first
  character appearing in rhs_spec that is not ``'I'`` or ``'O'``.

  If ``dimension_numbers`` is ``None``, the default is ``('NCHW', 'OIHW',
  'NCHW')`` (for a 2D convolution).
  """
  dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
  if lhs_dilation is None:
    lhs_dilation = (1,) * (lhs.ndim - 2)
  elif isinstance(padding, str) and not len(lhs_dilation) == lhs_dilation.count(1):
    raise ValueError(
        "String padding is not implemented for transposed convolution "
        "using this op. Please either exactly specify the required padding or "
        "use conv_transpose.")
  if rhs_dilation is None:
    rhs_dilation = (1,) * (rhs.ndim - 2)
  if isinstance(padding, str):
    lhs_perm, rhs_perm, _ = dnums
    rhs_shape = np.take(rhs.shape, rhs_perm)[2:]  # type: ignore[index]
    effective_rhs_shape = [(k-1) * r + 1 for k, r in zip(rhs_shape, rhs_dilation)]
    padding = padtype_to_pads(
        np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape,  # type: ignore[index]
        window_strides, padding)
  preferred_element_type = (None if preferred_element_type is None else
                            np.dtype(preferred_element_type))
  return conv_general_dilated_p.bind(
      lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
      lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),
      dimension_numbers=dnums,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count,
      lhs_shape=lhs.shape, rhs_shape=rhs.shape,
      precision=canonicalize_precision(precision),
      preferred_element_type=preferred_element_type)
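
For reference, a minimal usage sketch of the string-tuple dimension_numbers form described in that docstring (the shapes here are arbitrary):

import jax
import jax.numpy as jnp

lhs = jnp.ones((1, 28, 28, 3))    # N H W C
rhs = jnp.ones((3, 3, 3, 16))     # H W I O
out = jax.lax.conv_general_dilated(
    lhs, rhs,
    window_strides=(1, 1),
    padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
print(out.shape)   # (1, 26, 26, 16)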

@jakevdp
Collaborator

jakevdp commented Jan 14, 2022

You should use nondiff_argnums in custom_vjp or custom_jvp to specify which of your function's arguments are non-differentiable.
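
A minimal sketch of the nondiff_argnums mechanics (not the conv wrapper above; scale and its string mode argument are made up for illustration):

from functools import partial

import jax

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scale(x, mode):
  return 2.0 * x if mode == "double" else x

def scale_fwd(x, mode):
  return scale(x, mode), None   # no residuals needed for this toy rule

def scale_bwd(mode, residuals, g):
  # non-differentiable arguments are passed first to the backward rule
  return (2.0 * g if mode == "double" else g,)

scale.defvjp(scale_fwd, scale_bwd)

print(jax.grad(scale)(3.0, "double"))   # 2.0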
