1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
@@ -8,6 +8,7 @@
from keras.api.ops import linalg
from keras.api.ops import nn
from keras.api.ops import numpy
from keras.src.ops.core import associative_scan
from keras.src.ops.core import cast
from keras.src.ops.core import cond
from keras.src.ops.core import convert_to_numpy
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
@@ -8,6 +8,7 @@
from keras.api.ops import linalg
from keras.api.ops import nn
from keras.api.ops import numpy
from keras.src.ops.core import associative_scan
from keras.src.ops.core import cast
from keras.src.ops.core import cond
from keras.src.ops.core import convert_to_numpy
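With these exports in place, the op is reachable from the public `keras.ops` namespace. A minimal usage sketch (hypothetical; assumes one of the backend implementations added below is active):

from keras import ops

xs = ops.arange(1, 6)
# Inclusive cumulative sum expressed as an associative scan.
print(ops.associative_scan(ops.add, xs))  # -> [ 1  3  6 10 15] (backend tensor)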
12 changes: 12 additions & 0 deletions keras/src/backend/common/backend_utils.py
@@ -511,3 +511,15 @@ def wrapped(*args, **kwargs):
return ops.expand_dims(result, axis=dims_to_expand)

return wrapped


def slice_along_axis(x, start=0, stop=None, step=1, axis=0):
"""Slice a Tensor along the given axis."""
    # Ref: the same utility function is defined in tfp.math.scan_associative
if axis >= 0:
slices = [slice(None)] * axis + [slice(start, stop, step)]
else:
slices = [Ellipsis, slice(start, stop, step)] + [slice(None)] * (
-1 - axis
)
return x[tuple(slices)]
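
For reference, a minimal sketch of how this helper slices (hypothetical usage, assuming plain NumPy arrays; negative axes route through the `Ellipsis` branch):

import numpy as np

x = np.arange(12).reshape(3, 4)
# axis >= 0: equivalent to x[:, ::2] for axis=1.
assert np.array_equal(slice_along_axis(x, 0, None, 2, axis=1), x[:, ::2])
# axis < 0 counts from the last dimension, e.g. axis=-1.
assert np.array_equal(slice_along_axis(x, 1, 3, axis=-1), x[..., 1:3])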
4 changes: 4 additions & 0 deletions keras/src/backend/jax/core.py
@@ -269,6 +269,10 @@ def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
)


def associative_scan(f, elems, reverse=False, axis=0):
    return jax.lax.associative_scan(f, elems, reverse=reverse, axis=axis)


def scatter(indices, values, shape):
zeros = jnp.zeros(shape, values.dtype)
key = tuple(jnp.moveaxis(indices, -1, 0))
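For context, a minimal sketch of what this thin wrapper forwards to (assumes JAX is installed; the scan is inclusive):

import jax
import jax.numpy as jnp

xs = jnp.arange(1, 6)  # [1, 2, 3, 4, 5]
# Cumulative sum as an associative scan.
print(jax.lax.associative_scan(jnp.add, xs))                # [ 1  3  6 10 15]
# reverse=True scans from the end of the axis.
print(jax.lax.associative_scan(jnp.add, xs, reverse=True))  # [15 14 12  9  5]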
113 changes: 113 additions & 0 deletions keras/src/backend/numpy/core.py
@@ -1,11 +1,14 @@
import builtins
import functools
import warnings

import numpy as np
import optree

from keras.src import tree
from keras.src.backend.common import KerasVariable
from keras.src.backend.common import standardize_dtype
from keras.src.backend.common.backend_utils import slice_along_axis
from keras.src.backend.common.dtypes import result_type
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.backend.common.stateless_scope import StatelessScope
@@ -202,6 +205,116 @@ def pack_output(x):
return carry, stacked_y


def associative_scan(f, elems, reverse=False, axis=0):
# Ref: jax.lax.associative_scan
if not callable(f):
raise TypeError(f"`f` should be a callable. Received: f={f}")
    elems_flat, treespec = optree.tree_flatten(elems)
elems_flat = [convert_to_tensor(elem) for elem in elems_flat]
if reverse:
elems_flat = [np.flip(elem, (axis,)) for elem in elems_flat]

def _combine(a_flat, b_flat):
        a = optree.tree_unflatten(treespec, a_flat)
        b = optree.tree_unflatten(treespec, b_flat)
c = f(a, b)
c_flat, _ = optree.tree_flatten(c)
return c_flat

num_elems = int(elems_flat[0].shape[axis])
if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
raise ValueError(
"Array inputs to associative_scan must have the same "
"first dimension. (saw: {})".format(
[elem.shape for elem in elems_flat]
)
)

def _interleave(a, b, axis):
"""Given two Tensors of static shape, interleave them along axis."""
assert (
a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1
)

        # Dilate and pad so that a: [a1, a2], b: [b1, b2] become
        # a: [a1, 0, a2, 0], b: [0, b1, 0, b2]; `op` then merges them.
a_shape = list(a.shape)
a_shape[axis] = a.shape[axis] * 2 - 1

b_shape = list(b.shape)
b_shape[axis] = b.shape[axis] * 2 - 1

        a_dil = np.zeros(a_shape, dtype=a.dtype)
        np.copyto(slice_along_axis(a_dil, 0, None, 2, axis), a)
        b_dil = np.zeros(b_shape, dtype=b.dtype)
        np.copyto(slice_along_axis(b_dil, 0, None, 2, axis), b)

a_pad = [[0, 0] for _ in range(a.ndim)]
a_pad[axis][-1] = 1 if a.shape[axis] == b.shape[axis] else 0

b_pad = [[0, 0] for _ in range(b.ndim)]
b_pad[axis] = [1, 0] if a.shape[axis] == b.shape[axis] else [1, 1]

op = np.bitwise_or if a.dtype == np.bool_ else np.add
return op(
np.pad(a_dil, a_pad),
np.pad(b_dil, b_pad),
)

def _scan(elems):
num_elems = elems[0].shape[axis]
if num_elems < 2:
return elems

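        # Pairwise-combine adjacent elements: reduced_elems[i] is
        # f(elems[2i], elems[2i + 1]); recursing on this half-length
        # sequence yields the scan results at the odd output positions.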
reduced_elems = _combine(
[
slice_along_axis(elem, 0, -1, step=2, axis=axis)
for elem in elems
],
[
slice_along_axis(elem, 1, None, step=2, axis=axis)
for elem in elems
],
)

odd_elems = _scan(reduced_elems)
if num_elems % 2 == 0:
even_elems = _combine(
[slice_along_axis(e, 0, -1, axis=axis) for e in odd_elems],
[
slice_along_axis(e, 2, None, step=2, axis=axis)
for e in elems
],
)
else:
even_elems = _combine(
odd_elems,
[
slice_along_axis(e, 2, None, step=2, axis=axis)
for e in elems
],
)

even_elems = [
np.concatenate(
[slice_along_axis(elem, 0, 1, axis=axis), result],
axis=axis,
)
for (elem, result) in zip(elems, even_elems)
]
return list(
builtins.map(
functools.partial(_interleave, axis=axis), even_elems, odd_elems
)
)

scans = _scan(elems_flat)
if reverse:
scans = [np.flip(scanned, (axis,)) for scanned in scans]

    return optree.tree_unflatten(treespec, scans)


def scatter(indices, values, shape):
indices = convert_to_tensor(indices)
values = convert_to_tensor(values)
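A minimal sketch of the NumPy path (hypothetical usage of the `associative_scan` defined above; the combiner may take and return pytrees):

import numpy as np

xs = np.array([1, 2, 3, 4, 5])
# Inclusive cumulative sum via the recursive odd/even decomposition.
print(associative_scan(np.add, xs))  # [ 1  3  6 10 15]

# Pytree inputs: the combiner sees (and returns) matching structures.
ys = (np.array([1.0, 2.0, 3.0]), np.array([10.0, 20.0, 30.0]))
cumsums = associative_scan(lambda a, b: (a[0] + b[0], a[1] + b[1]), ys)
# -> (array([1., 3., 6.]), array([10., 30., 60.]))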
183 changes: 183 additions & 0 deletions keras/src/backend/tensorflow/core.py
@@ -1,11 +1,15 @@
import builtins

import numpy as np
import optree
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

from keras.src import tree
from keras.src.backend.common import KerasVariable
from keras.src.backend.common import global_state
from keras.src.backend.common import standardize_dtype
from keras.src.backend.common.backend_utils import slice_along_axis
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.backend.common.name_scope import name_scope as base_name_scope
from keras.src.backend.common.stateless_scope import StatelessScope
@@ -350,6 +354,185 @@ def loop_body(i, carry_array, ys_array):
return pack_output(carry_flat), pack_output(ys_flat)


def associative_scan(f, elems, reverse=False, axis=0):
    # Implementation follows tfp.math.scan_associative, with additional
    # checks to match the behavior of jax.lax.associative_scan.
if not callable(f):
raise TypeError(f"`f` should be a callable. Received: f={f}")
elems_flat, treespec = optree.tree_flatten(elems)
elems_flat = [tf.convert_to_tensor(elem) for elem in elems_flat]
if reverse:
elems_flat = [tf.reverse(elem, [axis]) for elem in elems_flat]

def _combine(a_flat, b_flat):
a = optree.tree_unflatten(treespec, a_flat)
b = optree.tree_unflatten(treespec, b_flat)
c = f(a, b)
c_flat, _ = optree.tree_flatten(c)
return c_flat

def _get_dim(x):
return shape(x)[axis]

# TODO add constant dim check
num_elems = _get_dim(elems_flat[0])
if not all(_get_dim(elem) == num_elems for elem in elems_flat[1:]):
raise ValueError(
"Array inputs to associative_scan must have the same "
"first dimension. (saw: {})".format(
[tf.shape(elem) for elem in elems_flat]
)
)

def _interleave(a, b, axis):
# [a b c ...] [d e f ...] -> [a d b e c f ...]
num_elems_a = _get_dim(a)
num_elems_b = _get_dim(b)

# Note that interleaving implies rank(a)==rank(b).
axis = tf.where(axis >= 0, axis, tf.rank(a) + axis)
axis = (
int(axis) # Avoid ndarray values.
if tf.get_static_value(axis) is not None
else axis
)

def _interleave_with_b(a):
return tf.reshape(
# Work around lack of support for Tensor axes in
# `tf.stack` by using `concat` and `expand_dims` instead.
tf.concat(
[
tf.expand_dims(a, axis=axis + 1),
tf.expand_dims(b, axis=axis + 1),
],
axis=axis + 1,
),
tf.concat(
[
a.get_shape()[:axis],
[2 * num_elems_b],
a.get_shape()[axis + 1 :],
],
axis=0,
),
)

return tf.cond(
tf.equal(num_elems_a, num_elems_b + 1),
lambda: tf.concat(
[
_interleave_with_b(
slice_along_axis(a, None, -1, axis=axis)
),
slice_along_axis(a, -1, None, axis=axis),
],
axis=axis,
),
lambda: _interleave_with_b(a),
)

def _scan(elems):
elem_length = _get_dim(elems[0])
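        # a holds the even-indexed elements, b the odd-indexed ones;
        # combining them pairwise halves the scan before recursing.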
        a = [
            slice_along_axis(elem, 0, -1, step=2, axis=axis)
            for elem in elems
        ]
        b = [
            slice_along_axis(elem, 1, None, step=2, axis=axis)
            for elem in elems
        ]
reduced_elems = _combine(a, b)

def _handle_base_case_elem_length_two():
return [
tf.concat(
[slice_along_axis(elem, 0, 1, axis=axis), reduced_elem],
axis=axis,
)
for (reduced_elem, elem) in zip(reduced_elems, elems)
]

def _handle_base_case_elem_length_three():
reduced_reduced_elems = _combine(
reduced_elems,
[slice_along_axis(elem, 2, 3, axis=axis) for elem in elems],
)
return [
tf.concat(
[
slice_along_axis(elem, 0, 1, axis=axis),
reduced_elem,
reduced_reduced_elem,
],
axis=axis,
)
for (reduced_reduced_elem, reduced_elem, elem) in zip(
reduced_reduced_elems, reduced_elems, elems
)
]

at_base_case = tf.logical_or(
tf.equal(elem_length, 2), tf.equal(elem_length, 3)
)

def _base_case():
return tf.cond(
tf.equal(elem_length, 2),
_handle_base_case_elem_length_two,
_handle_base_case_elem_length_three,
)

        def _recursive_case():
            odd_elems = _scan(reduced_elems)

def _even_length_case():
return _combine(
[
slice_along_axis(odd_elem, 0, -1, axis=axis)
for odd_elem in odd_elems
],
[
slice_along_axis(elem, 2, None, 2, axis=axis)
for elem in elems
],
)

def _odd_length_case():
return _combine(
                list(odd_elems),
[
slice_along_axis(elem, 2, None, 2, axis=axis)
for elem in elems
],
)

results = tf.cond(
tf.equal(elem_length % 2, 0),
_even_length_case,
_odd_length_case,
)

even_elems = [
tf.concat(
[slice_along_axis(elem, 0, 1, axis=axis), result], axis=axis
)
for (elem, result) in zip(elems, results)
]
return list(
builtins.map(
lambda a, b: _interleave(a, b, axis=axis),
even_elems,
odd_elems,
)
)

return tf.cond(at_base_case, _base_case, _recursive_case)

scans = _scan(elems_flat)
if reverse:
scans = [tf.reverse(scanned, [axis]) for scanned in scans]

return optree.tree_unflatten(treespec, scans)


def scatter(indices, values, shape):
return tf.scatter_nd(indices, values, shape)

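Similarly, a minimal sketch of the TensorFlow path (hypothetical eager usage of the `associative_scan` defined above):

import tensorflow as tf

xs = tf.constant([3, 1, 4, 1, 5])
# Cumulative sum and cumulative max are both associative combines.
print(associative_scan(tf.add, xs).numpy())      # [ 3  4  8  9 14]
print(associative_scan(tf.maximum, xs).numpy())  # [3 3 4 4 5]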