
v0.4.0


fferflo released this 04 Mar 17:34

einx v0.4.0: Fully embrace vectorization!

Summary

Vectorization. This release fully embraces vectorization, with the analogy to loop notation as the core abstraction of einx: any einx expression

# einx notation
z = einx.{OP}("a [i j], b -> a b [j]", x, y)

will yield the same output as invoking the underlying elementary operation in an analogous loop expression:

# Loop notation
for a in range(...):
    for b in range(...):
        z[a, b, :] = {OP}(x[a, :, :], y[b])
         "a  b [j]"        "a [i  j]"  "b"
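To make the analogy concrete, the NumPy sketch below picks a hypothetical elementary operation `op` (it sums its first input over `i` and scales the result by its scalar second input) for the expression `"a [i j], b -> a b [j]"`. It runs the loop literally and compares it with a hand-written broadcast version. It does not call einx and only illustrates the semantics described above.

```python
import numpy as np

# Hypothetical elementary operation: maps x[a] with shape (i, j) and the
# scalar y[b] to a vector with shape (j,)
def op(x_a, y_b):
    return x_a.sum(axis=0) * y_b

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4, 5))  # "a i j" with a=3, i=4, j=5
y = rng.normal(size=(2,))       # "b" with b=2

# Loop notation: apply op once for every (a, b) pair
z_loop = np.empty((3, 2, 5))
for a in range(3):
    for b in range(2):
        z_loop[a, b, :] = op(x[a, :, :], y[b])

# Equivalent vectorized computation using broadcasting
z_vec = x.sum(axis=1)[:, None, :] * y[None, :, None]

assert np.allclose(z_loop, z_vec)
```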

Previous versions already adhered to this definition almost entirely; it is now strictly enforced through smaller changes to the interface, such as renaming einx.rearrange to einx.id and removing some special behavior from the notation (see details below). See the new documentation for more information.

Backends. This release introduces major updates to how tensor operations are implemented in einx. This allows adapting arbitrary functions to einx notation:

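import torch
import einx

# Define some custom operation
def op(x, y):
    return torch.sum(x, dim=0) * torch.flip(y, dims=(0,))

# Adapt to einx notation
einop = einx.torch.adapt_with_vmap(op)

# Invoke using einx notation (x has shape "a b c", y has shape "a c")
x, y = torch.randn(4, 3, 5), torch.randn(4, 5)
result = einop("a [b c], a [c] -> a [c]", x, y)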

and choosing between different backend implementations for each operation (e.g., Numpy-like, vmap-based, or einsum-based implementations).
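For the adapted operation above, the loop analogy says that `"a [b c], a [c] -> a [c]"` calls `op` once per index along the vectorized axis `a`. The following NumPy-only sketch spells this out, with `np.sum` and `np.flip` standing in for the PyTorch calls. It is an illustration of what the result should contain under that reading, not einx's vmap-based implementation.

```python
import numpy as np

# NumPy analogue of the custom op: x_a has shape (b, c), y_a has shape (c,)
def op(x_a, y_a):
    return np.sum(x_a, axis=0) * np.flip(y_a)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3, 5))  # "a b c"
y = rng.normal(size=(4, 5))     # "a c"

# "a [b c], a [c] -> a [c]": one call of op per index along the vectorized axis a
result_loop = np.stack([op(x[i], y[i]) for i in range(x.shape[0])])

# Hand-vectorized equivalent: reduce over b, flip along c
result_vec = np.sum(x, axis=1) * np.flip(y, axis=1)

assert np.allclose(result_loop, result_vec)
```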

Clarity. This release improves clarity through better error reporting (among others for syntax and shape errors), new documentation, and the removal of special behavior and edge cases from the einx notation (see details below).

Added

  • Allow adapting arbitrary functions to einx notation. einx provides different adapters in the namespace einx.{framework}.adapt_*, chosen based on the signature of the wrapped function. The simplest is einx.{framework}.adapt_with_vmap, which uses a framework's vmap transformation internally but is only supported for frameworks that provide vmap (e.g., Jax, PyTorch, MLX, but not Numpy). Other adapters are provided for functions with Numpy-like signatures (e.g., reduction operations with an axis parameter). See the documentation for more information.

    The functions einx.{reduce|elementwise|vmap|vmap_with_axis} that partially provided this functionality in previous versions have been removed in favor of the new adapters.

  • Add different backend implementations for operations. Each einx operation can now be invoked using different backend implementations by specifying the backend argument. For example, backend="torch.numpylike" uses only Numpy-like operations from PyTorch, backend="torch.vmap" uses torch.vmap, and backend="torch.einsum" uses torch.einsum internally (if the operation can be expressed with torch.einsum). The default, backend="torch", uses a combination of the above. See the documentation for more information and for examples of the compiled code with different backends.

    Indexing functions (einx.{get_at|set_at|...}) were previously implemented only using vmap, which led to problems with frameworks that have limited support for vmap (e.g., PyTorch) or no support for vmap (e.g., Numpy). The default backend for all frameworks now uses a purely Numpy-like implementation of indexing functions, which avoids these issues.

  • Add support for new operations: einx.{argmin|argmax|sort|argsort|logaddexp}.

  • Support multiple vectorized axes with the same name in input expressions. In this case, the diagonal of the input tensor is extracted along the specified axes
    before applying the operation. This adheres to the loop notation analogy. For example:

    einx.id("a b b c -> a b c", x) # Extracts diagonal along the 'b' axes
    einx.sum("[a] b b c", x) # Extracts diagonal along the 'b' axes, and computes sum along 1st axis
    einx.sum("a [b b] c", x) # 'b' is not vectorized, so no diagonal is extracted. Computes the sum along the 2nd and 3rd axes.
  • Add support for Array API backend. As a result, einx now supports all tensor frameworks that implement the Array API standard. This requires the array-api-compat package to be installed.

  • Add einx.solve_axes and einx.solve_shapes.
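The diagonal rule for repeated vectorized axes follows directly from loop notation: both occurrences of `b` are indexed by the same loop variable. The NumPy sketch below writes out that loop for the `einx.id("a b b c -> a b c", x)` example listed above and compares it with `np.diagonal` and `np.einsum` (which also accepts repeated input subscripts). einx itself is not called.

```python
import numpy as np

x = np.arange(2 * 3 * 3 * 4).reshape(2, 3, 3, 4)  # "a b b c"

# Loop notation for "a b b c -> a b c": the single loop variable b
# indexes both 'b' axes of the input
z_loop = np.empty((2, 3, 4), dtype=x.dtype)
for a in range(2):
    for b in range(3):
        z_loop[a, b, :] = x[a, b, b, :]

# np.diagonal removes axes 1 and 2 and appends the diagonal as the last axis,
# so move it back to position 1
z_diag = np.moveaxis(np.diagonal(x, axis1=1, axis2=2), -1, 1)

# np.einsum extracts the same diagonal from the repeated subscript
z_einsum = np.einsum("abbc->abc", x)

assert np.array_equal(z_loop, z_diag) and np.array_equal(z_loop, z_einsum)
```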
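The switch from vmap-based to Numpy-like indexing functions can be illustrated with a batched lookup. A vmap-style formulation performs one scalar lookup per batch element, while a Numpy-like formulation expresses the same gather as a single advanced-indexing call with an explicit batch index. The sketch below is a hypothetical NumPy example of that equivalence, not the code einx generates.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 6, 7))  # batch of 4 arrays with shape (6, 7)

# One (row, column) index pair per batch element, shape (4, 2)
idx = np.stack([rng.integers(0, 6, size=4), rng.integers(0, 7, size=4)], axis=-1)

# vmap-style view: a separate scalar lookup for each batch element
out_loop = np.array([x[b, idx[b, 0], idx[b, 1]] for b in range(x.shape[0])])

# Numpy-like view: one advanced-indexing call with an explicit batch index
out_adv = x[np.arange(x.shape[0]), idx[:, 0], idx[:, 1]]

assert np.array_equal(out_loop, out_adv)
```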

Changed

  • Improve error reporting for clarity. Most errors should now be much easier to fix. For example:

    x = np.zeros((10, 5))
    einx.id("(a b) c -> a b c", x)

    raises

    einx.errors.AxisSizeError: Failed to uniquely determine the size of the axes a, b. Please provide more constraints.
    Expression: "(a b) c -> a b c"
                  ^ ^       ^ ^
    The operation was called with the following arguments:
      - Positional argument #1: Tensor with shape (10, 5)
    
  • Simplify einx notation by removing special behavior and edge cases:

    • Deprecate keepdims argument in reduction functions:
      einx.sum("a [b]", x, keepdims=True) # version < 0.4.0
      The behavior can be equally achieved using a flattened axis:
      einx.sum("a ([b])", x) # version >= 0.4.0
    • Remove the cse argument from einx functions, which previously allowed disabling common subexpression elimination.
    • Remove the special shorthand notation in dot-product and elementwise operations where two tensors are passed but the second input expression is determined implicitly:
      einx.dot("b [c_in] -> b [c_out]", x, weight) # version < 0.4.0
      einx.add("b [c]", x, bias) # version < 0.4.0
      The behavior can be equally achieved by explicitly specifying the second input:
      einx.dot("b [c_in], [c_in] c_out -> b c_out", x, weight) # version >= 0.4.0
      einx.add("b c, c", x, bias) # version >= 0.4.0
    • Remove einx.arange:
      einx.arange("a b [2]", a=5, b=10) # version < 0.4.0
      The behavior can be equally achieved using einx.id with np.arange:
      einx.id("a, b -> a b (1 + 1)", np.arange(5), np.arange(10)) # version >= 0.4.0
    • Deprecate einx.check:
      einx.check("a b", x) # version < 0.4.0
      The behavior can be equally achieved using einx.id:
      einx.id("a b", x) # version >= 0.4.0
    • Named axes ("a") and unnamed axes ("1") now behave identically. Among other things, this allows squeezing named axes:
      einx.id("a b c -> a b", x, c=1) # version >= 0.4.0
    • Remove automatic reordering of arguments in einx.id:
      einx.id("a, b -> (b + a)", x, y) # version < 0.4.0
      The behavior can be equally achieved by switching the order of the arguments:
      einx.id("b, a -> (b + a)", y, x) # version >= 0.4.0
  • Rename einx.rearrange to einx.id to reflect that it computes a vectorized identity map. This follows the general naming convention of einx where function names reflect the elementary operation that is computed.

  • Clean up public API by moving implementation into einx._src namespace.

  • Remove einx.experimental.shard.

  • Remove einx.nn. This namespace contained implementations of neural network layers in einx notation for different frameworks. Supporting many different neural network libraries created an overhead that was not warranted by the benefit. Rather than providing dedicated einx layers, einx can be used internally by layer implementations.

  • Remove support for passing lists or tuples as tensor arguments:

    einx.add("a b, a", x, [1.0, 2.0, 4.0]) # version < 0.4.0

    The behavior can be equally achieved by using a Numpy array instead:

    einx.add("a b, a", x, np.asarray([1.0, 2.0, 4.0])) # version >= 0.4.0
  • Bump required Python version to 3.10 since 3.8 and 3.9 have reached end-of-life.

  • Remove all usages of tensorflow.experimental.numpy in the Tensorflow backend, and instead rely only on standard Tensorflow operations.

  • Remove dedicated support for the Dask framework. Dask is now instead supported using the Array API backend.

  • Disallow changing order of non-vectorized axes in some einx functions:

    einx.softmax("a [b c] -> a [c b]", x) # version < 0.4.0

    This avoids confusion between vectorized axes (where axis ordering indicates a permutation) and non-vectorized axes (where axis ordering only indicates the signature of the elementary operation).

  • Disallow using | as an alternative to -> in einx notation (previously supported).

  • einx.dot now only supports dot-product operations and no longer supports other operation signatures that einsum allows.
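The replacement for the deprecated keepdims argument relies on a flattened axis: in `"a ([b])"`, reducing `b` leaves an empty group `()`, which (as we read it) has size 1 as the product of no axis sizes. This NumPy-only sketch checks that this shape matches `keepdims=True`; it mirrors the described semantics and does not call einx.

```python
import numpy as np

x = np.arange(12.0).reshape(3, 4)  # "a b"

# keepdims-style result, as with the deprecated einx.sum("a [b]", x, keepdims=True)
kept = np.sum(x, axis=1, keepdims=True)

# "a ([b])": reduce b, then the now-empty flattened group becomes one axis of size 1
reduced = np.sum(x, axis=1)
flattened = reduced.reshape(reduced.shape[0], 1)

assert np.array_equal(kept, flattened)
```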
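The einx.arange replacement `einx.id("a, b -> a b (1 + 1)", np.arange(5), np.arange(10))` can be read in loop notation as concatenating the scalars at positions `a` and `b` into an axis of length 2. The NumPy sketch below writes out that loop and compares it with the equivalent `np.meshgrid` coordinate grid; it illustrates the described behavior only and does not call einx.

```python
import numpy as np

ya, yb = np.arange(5), np.arange(10)

# Loop notation for "a, b -> a b (1 + 1)": concatenate ya[a] and yb[b]
grid_loop = np.empty((5, 10, 2), dtype=ya.dtype)
for a in range(5):
    for b in range(10):
        grid_loop[a, b, :] = np.concatenate([ya[a:a + 1], yb[b:b + 1]])

# Same coordinate grid with meshgrid ("ij" indexing keeps the a-axis first)
grid_mesh = np.stack(np.meshgrid(ya, yb, indexing="ij"), axis=-1)

assert np.array_equal(grid_loop, grid_mesh)
```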

Fixed

  • When initializing a backend, delay raising an exception until the backend is used in an operation. This avoids errors when importing a framework fails even though the framework is not actually used with einx.
  • Use torch.{amin|amax} instead of torch.{min|max} since in some configurations the latter returns a tuple rather than only the reduced tensor (see #24 and #26).