Skip to content


Merge branch 'keras-team:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
sqali committed Sep 21, 2023
2 parents e829832 + 6383d8a commit 686f339
Show file tree
Hide file tree
Showing 31 changed files with 2,187 additions and 390 deletions.
26 changes: 22 additions & 4 deletions keras_core/backend/jax/
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,21 @@ def add(x1, x2):

def bincount(x, weights=None, minlength=0):
    """Count occurrences of each value in `x`, row by row for 2D input.

    Args:
        x: Integer array. When it has exactly two dimensions, every row is
            counted independently and the per-row results are stacked.
        weights: Optional array of weights. For 2D `x` it is iterated in
            lockstep with `x`, so row `i` of `weights` weights row `i`
            of `x`.
        minlength: Minimum number of bins in the output.

    Returns:
        The result of `jnp.bincount` for non-2D input, or the stacked
        per-row results for 2D input.
    """
    if len(x.shape) == 2:
        if weights is None:
            bincounts = [jnp.bincount(row, minlength=minlength) for row in x]
        else:
            # Pair each row with its own weights row; passing the whole
            # 2D `weights` to every row would not line up with the row.
            bincounts = [
                jnp.bincount(row, weights=row_weights, minlength=minlength)
                for row, row_weights in zip(x, weights)
            ]
        return jnp.stack(bincounts)
    return jnp.bincount(x, weights=weights, minlength=minlength)

Expand Down Expand Up @@ -102,6 +113,13 @@ def append(

def arange(start, stop=None, step=1, dtype=None):
    """Return evenly spaced values via `jnp.arange`, inferring a dtype.

    When `dtype` is not given it is chosen from `start`: the dtype of
    `start` if it has a `dtype` attribute, `"int32"` if `start` is a Python
    `int`, and `config.floatx()` otherwise.
    """
    if dtype is None:
        if hasattr(start, "dtype"):
            dtype = start.dtype
        elif isinstance(start, int):
            dtype = "int32"
        else:
            # Only fall back to the configured float type when nothing
            # could be inferred from `start`.
            dtype = config.floatx()
    return jnp.arange(start, stop, step=step, dtype=dtype)

Expand Down
35 changes: 34 additions & 1 deletion keras_core/backend/numpy/
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np

from keras_core.backend import config
from keras_core.backend import standardize_dtype

def add(x1, x2):
    """Return the element-wise sum of `x1` and `x2` using `np.add`."""
    return np.add(x1, x2)
Expand Down Expand Up @@ -77,6 +80,13 @@ def append(

def arange(start, stop=None, step=None, dtype=None):
    """Return evenly spaced values via `np.arange`, inferring a dtype.

    When `dtype` is not given it is chosen from `start`: the dtype of
    `start` if it has a `dtype` attribute, `"int32"` if `start` is a Python
    `int`, and `config.floatx()` otherwise.
    """
    if dtype is None:
        if hasattr(start, "dtype"):
            dtype = start.dtype
        elif isinstance(start, int):
            dtype = "int32"
        else:
            # Only fall back to the configured float type when nothing
            # could be inferred from `start`.
            dtype = config.floatx()
    return np.arange(start, stop, step=step, dtype=dtype)

Expand Down Expand Up @@ -124,6 +134,7 @@ def argsort(x, axis=-1):

def array(x, dtype=None):
    """Create a NumPy array from `x`.

    Args:
        x: Array-like input passed to `np.array`.
        dtype: Desired dtype. When `None`, `config.floatx()` is used.
    """
    # Compare against None explicitly: a non-structured `np.dtype` object
    # has length 0 and is therefore falsy, so `dtype or ...` would discard
    # an explicitly requested dtype such as `np.dtype("int64")`.
    if dtype is None:
        dtype = config.floatx()
    return np.array(x, dtype=dtype)

Expand All @@ -133,6 +144,23 @@ def average(x, axis=None, weights=None):

def bincount(x, weights=None, minlength=0):
    """Count occurrences of each value in `x`, row by row for 2D input.

    Args:
        x: Integer array. When it has exactly two dimensions, every row is
            counted independently and the per-row results are stacked.
        weights: Optional array of weights. For 2D `x` it is iterated in
            lockstep with `x`, so row `i` of `weights` weights row `i`
            of `x`.
        minlength: Minimum number of bins in the output.

    Returns:
        The result of `np.bincount` for non-2D input, or the stacked
        per-row results for 2D input.
    """
    if len(x.shape) == 2:
        if weights is None:
            bincounts = [np.bincount(row, minlength=minlength) for row in x]
        else:
            # Pair each row with its own weights row.
            bincounts = [
                np.bincount(row, weights=row_weights, minlength=minlength)
                for row, row_weights in zip(x, weights)
            ]
        return np.stack(bincounts)
    return np.bincount(x, weights, minlength)

Expand Down Expand Up @@ -254,6 +282,7 @@ def floor(x):

def full(shape, fill_value, dtype=None):
    """Return an array of the given shape filled with `fill_value`.

    Args:
        shape: Shape of the output, passed to `np.full`.
        fill_value: Value every element is set to.
        dtype: Desired dtype. When `None`, `config.floatx()` is used.
    """
    # Compare against None explicitly: a non-structured `np.dtype` object
    # has length 0 and is therefore falsy, so `dtype or ...` would discard
    # an explicitly requested dtype such as `np.dtype("int64")`.
    if dtype is None:
        dtype = config.floatx()
    return np.full(shape, fill_value, dtype=dtype)

Expand Down Expand Up @@ -575,7 +604,11 @@ def square(x):

def sqrt(x):
    """Return the element-wise square root of `x` via `np.sqrt`.

    Inputs whose standardized dtype name starts with `"int"` are computed
    in `config.floatx()`; for everything else the output dtype is left to
    NumPy.
    """
    dtype = None
    if hasattr(x, "dtype") and standardize_dtype(x.dtype).startswith("int"):
        dtype = config.floatx()
    return np.sqrt(x, dtype=dtype)

def squeeze(x, axis=None):
Expand Down
179 changes: 169 additions & 10 deletions keras_core/backend/tensorflow/
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import builtins
import functools
import math
import warnings

import tensorflow as tf
from tensorflow.experimental import numpy as tfnp
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops

from keras_core.backend import config
from keras_core.backend.tensorflow.core import convert_to_tensor

def add(x1, x2):
    """Return `x1 + x2`.

    If either operand is a `tf.SparseTensor` the sum is delegated to
    `tf.sparse.add`; otherwise `tfnp.add` is used.
    """
    if isinstance(x1, tf.SparseTensor) or isinstance(x2, tf.SparseTensor):
        return tf.sparse.add(x1, x2)
    return tfnp.add(x1, x2)

Expand Down Expand Up @@ -38,30 +45,129 @@ def einsum(subscripts, *operands, **kwargs):

def subtract(x1, x2):
    """Return `x1 - x2`.

    When either operand is a `tf.SparseTensor`, the difference is computed
    as `tf.sparse.add(x1, -x2)`. Otherwise `tfnp.subtract` is used.
    """
    if isinstance(x1, tf.SparseTensor) or isinstance(x2, tf.SparseTensor):
        if isinstance(x2, tf.SparseTensor):
            # Negate only the stored values of the sparse operand.
            negated_x2 = tf.sparse.map_values(tf.negative, x2)
        else:
            negated_x2 = tf.negative(x2)
        return tf.sparse.add(x1, negated_x2)
    return tfnp.subtract(x1, x2)

def matmul(x1, x2):
    """Matrix product of `x1` and `x2`, with `tf.SparseTensor` support.

    Dispatch, as implemented below:
      * both sparse: CSR sparse-sparse product (inputs of rank > 3 have
        their batch dimensions combined first);
      * exactly one sparse: `tf.nn.embedding_lookup_sparse` for a rank 2
        sparse `x1`, `tf.sparse.sparse_dense_matmul` for rank 2 otherwise,
        a mapped `sparse_dense_matmul` for rank 3, and combined batch
        dimensions for higher ranks;
      * both dense: `tfnp.matmul`.
    """

    def with_combined_batch_dimensions(a, b, fn_3d):
        # Collapse all leading batch dimensions into one, apply the rank 3
        # implementation, then restore the original batch shape.
        batch_shape = (
            b.shape[:-2] if isinstance(b, tf.SparseTensor) else a.shape[:-2]
        )
        # NOTE(review): the right-hand side of this assignment was
        # truncated in the text under review. The product of the batch
        # dimensions is what the reshapes below need -- confirm.
        batch_size = math.prod(batch_shape)
        a_3d = reshape(a, [batch_size] + a.shape[-2:])
        b_3d = reshape(b, [batch_size] + b.shape[-2:])
        result = fn_3d(a_3d, b_3d)
        return reshape(result, batch_shape + result.shape[1:])

    def sparse_sparse_matmul(a, b):
        dtype = a.values.dtype
        # Convert SparseTensors to CSR SparseMatrix.
        a_csr = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
            a.indices, a.values, a.dense_shape
        )
        b_csr = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
            b.indices, b.values, b.dense_shape
        )
        # Compute the CSR SparseMatrix matrix multiplication.
        result_csr = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul(
            a_csr, b_csr, dtype
        )
        # Convert the CSR SparseMatrix to a SparseTensor.
        res = sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
            result_csr, dtype
        )
        return tf.SparseTensor(res.indices, res.values, res.dense_shape)

    def embedding_lookup_sparse_dense_matmul(a, b):
        # We need at least one id per rows for embedding_lookup_sparse,
        # otherwise there will be missing rows in the output.
        a, _ = tf.sparse.fill_empty_rows(a, 0)
        # We need to split `a` into separate ids and weights tensors. The
        # ids should be the column indices of `a` and the values of the
        # weights can continue to be the actual `a`. The column arrangement
        # of ids and weights does not matter as we sum over columns. See
        # details in the documentation for
        # sparse_ops.sparse_tensor_dense_matmul.
        # NOTE(review): only the `values=` argument survived in the text
        # under review. `indices` and `dense_shape` are required by
        # `tf.SparseTensor` and are taken from `a` here -- confirm.
        ids = tf.SparseTensor(
            indices=a.indices,
            values=a.indices[:, 1],
            dense_shape=a.dense_shape,
        )
        return tf.nn.embedding_lookup_sparse(b, ids, a, combiner="sum")

    # Either a or b is sparse
    def sparse_dense_matmul_3d(a, b):
        # NOTE(review): `fn_output_signature` was not present in the text
        # under review. It is supplied because the mapped function returns
        # a single tensor while `elems` is a pair -- confirm.
        return tf.map_fn(
            lambda x: tf.sparse.sparse_dense_matmul(x[0], x[1]),
            elems=(a, b),
            fn_output_signature=a.dtype,
        )

    x1_sparse = isinstance(x1, tf.SparseTensor)
    x2_sparse = isinstance(x2, tf.SparseTensor)
    if x1_sparse and x2_sparse:
        if x1.shape.rank <= 3:
            return sparse_sparse_matmul(x1, x2)
        return with_combined_batch_dimensions(x1, x2, sparse_sparse_matmul)
    elif x1_sparse or x2_sparse:
        # Sparse * dense or dense * sparse
        sparse_rank = x1.shape.rank if x1_sparse else x2.shape.rank

        # Special case: embedding_lookup_sparse for sparse * dense and rank 2
        if x1_sparse and sparse_rank == 2:
            return embedding_lookup_sparse_dense_matmul(x1, x2)
        elif sparse_rank == 2:
            return tf.sparse.sparse_dense_matmul(x1, x2)
        elif sparse_rank == 3:
            return sparse_dense_matmul_3d(x1, x2)
        return with_combined_batch_dimensions(
            x1, x2, sparse_dense_matmul_3d
        )

    return tfnp.matmul(x1, x2)

def multiply(x1, x2):
    """Return the element-wise product of `x1` and `x2`.

    For two `tf.SparseTensor` operands, both are first trimmed to the
    indices they have in common and the remaining values are multiplied.
    When exactly one operand is sparse, the sparse tensor's `*` operator
    is used with the sparse operand on the left. Dense inputs go through
    `tfnp.multiply`.
    """
    if isinstance(x1, tf.SparseTensor):
        if isinstance(x2, tf.SparseTensor):
            ones_like_int8 = functools.partial(tf.ones_like, dtype=tf.int8)
            zeros_like_int8 = functools.partial(tf.zeros_like, dtype=tf.int8)

            # compute the intersection of indices in the form of a sparse
            # tensor containing ones as values
            ones1 = tf.sparse.map_values(ones_like_int8, x1)
            ones2 = tf.sparse.map_values(ones_like_int8, x2)
            # tf.sets.intersection ignores the last dimension when
            # comparing, so we need to add a dummy extra dimension and then
            # remove it
            # NOTE(review): the `tf.sets.intersection(` call and the
            # reshape target were missing from the text under review. They
            # are restored from the comment above and the need to drop the
            # dummy dimension again -- confirm.
            intersection = tf.sparse.reshape(
                tf.sets.intersection(
                    tf.sparse.expand_dims(ones1, axis=-1),
                    tf.sparse.expand_dims(ones2, axis=-1),
                ),
                x1.dense_shape,
            )

            # compute the masks to remove indices in x1 and x2 that are not
            # part of the intersection, then trim x1 and x2
            zeros1 = tf.sparse.map_values(zeros_like_int8, x1)
            zeros2 = tf.sparse.map_values(zeros_like_int8, x2)
            mask1 = tf.sparse.add(zeros1, intersection)
            mask2 = tf.sparse.add(zeros2, intersection)
            x1_trimmed = tf.sparse.retain(x1, tf.cast(mask1.values, tf.bool))
            x2_trimmed = tf.sparse.retain(x2, tf.cast(mask2.values, tf.bool))

            # now it is an element-wise multiplication on the values
            return tf.sparse.map_values(tf.multiply, x1_trimmed, x2_trimmed)
        else:
            return x1 * x2
    elif isinstance(x2, tf.SparseTensor):
        return x2 * x1
    return tfnp.multiply(x1, x2)

Expand Down Expand Up @@ -133,6 +239,13 @@ def append(
def arange(start, stop=None, step=1, dtype=None):
    """Return evenly spaced values via `tf.range`, inferring a dtype.

    When `dtype` is not given it is chosen from `start`: the dtype of
    `start` if it has a `dtype` attribute, `"int32"` if `start` is a Python
    `int`, and `config.floatx()` otherwise.
    """
    # tfnp.arange has trouble with dynamic Tensors in compiled function.
    # tf.range does not.
    if dtype is None:
        if hasattr(start, "dtype"):
            dtype = start.dtype
        elif isinstance(start, int):
            dtype = "int32"
        else:
            # Only fall back to the configured float type when nothing
            # could be inferred from `start`.
            dtype = config.floatx()
    return tf.range(start, stop, delta=step, dtype=dtype)

Expand Down Expand Up @@ -202,6 +315,15 @@ def clip(x, x_min, x_max):

def concatenate(xs, axis=0):
    """Concatenate the tensors in `xs` along `axis`.

    If every input is a `tf.SparseTensor` the result comes from
    `tf.sparse.concat`. If only some inputs are sparse, those are
    densified first and `tfnp.concatenate` is applied to the dense inputs.
    """
    sparse_count = builtins.sum(isinstance(x, tf.SparseTensor) for x in xs)
    if sparse_count:
        if sparse_count == len(xs):
            return tf.sparse.concat(axis=axis, sp_inputs=xs)
        # Mixed sparse/dense inputs: densify the sparse ones.
        xs = [
            tf.sparse.to_dense(x) if isinstance(x, tf.SparseTensor) else x
            for x in xs
        ]
    return tfnp.concatenate(xs, axis=axis)

Expand Down Expand Up @@ -294,6 +416,8 @@ def exp(x):

def expand_dims(x, axis):
    """Insert a new dimension at `axis`.

    Uses `tf.sparse.expand_dims` for a `tf.SparseTensor` input and
    `tfnp.expand_dims` otherwise.
    """
    if isinstance(x, tf.SparseTensor):
        return tf.sparse.expand_dims(x, axis)
    return tfnp.expand_dims(x, axis)

Expand Down Expand Up @@ -420,6 +544,13 @@ def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):

def maximum(x1, x2):
    """Return the element-wise maximum of `x1` and `x2`.

    Two `tf.SparseTensor` operands are handled by `tf.sparse.maximum`. If
    only one operand is sparse it is densified and `tfnp.maximum` is used.
    """
    if isinstance(x1, tf.SparseTensor):
        if isinstance(x2, tf.SparseTensor):
            return tf.sparse.maximum(x1, x2)
        else:
            x1 = tf.sparse.to_dense(x1)
    elif isinstance(x2, tf.SparseTensor):
        x2 = tf.sparse.to_dense(x2)
    return tfnp.maximum(x1, x2)

Expand Down Expand Up @@ -449,6 +580,13 @@ def min(x, axis=None, keepdims=False, initial=None):

def minimum(x1, x2):
    """Return the element-wise minimum of `x1` and `x2`.

    Two `tf.SparseTensor` operands are handled by `tf.sparse.minimum`. If
    only one operand is sparse it is densified and `tfnp.minimum` is used.
    """
    if isinstance(x1, tf.SparseTensor):
        if isinstance(x2, tf.SparseTensor):
            return tf.sparse.minimum(x1, x2)
        else:
            x1 = tf.sparse.to_dense(x1)
    elif isinstance(x2, tf.SparseTensor):
        x2 = tf.sparse.to_dense(x2)
    return tfnp.minimum(x1, x2)

Expand Down Expand Up @@ -683,10 +821,31 @@ def square(x):

def sqrt(x):
    """Return the element-wise square root of `x` via `tfnp.sqrt`.

    The input is converted to a tensor first; integer tensors are cast to
    `config.floatx()` before the square root is taken.
    """
    x = convert_to_tensor(x)
    if tf.as_dtype(x.dtype).is_integer:
        x = tf.cast(x, dtype=config.floatx())
    return tfnp.sqrt(x)

def squeeze(x, axis=None):
    """Remove dimensions of size 1 from `x`.

    Args:
        x: Input tensor; `tf.SparseTensor` inputs are handled by rebuilding
            the sparse tensor with the squeezed index columns removed.
        axis: Axis to squeeze. When `None`, every dimension of size 1 is
            removed.

    Raises:
        ValueError: For a sparse input when `axis` is given and that
            dimension is not 1.
    """
    if isinstance(x, tf.SparseTensor):
        new_shape = list(x.shape)
        gather_indices = list(range(len(new_shape)))
        if axis is None:
            # Walk backwards so deletions do not shift pending positions.
            for i in range(len(new_shape) - 1, -1, -1):
                if new_shape[i] == 1:
                    del new_shape[i]
                    del gather_indices[i]
        else:
            if new_shape[axis] != 1:
                raise ValueError(
                    f"Cannot squeeze axis {axis}, because the "
                    "dimension is not 1."
                )
            del new_shape[axis]
            del gather_indices[axis]
        # Keep only the index columns of the dimensions that remain.
        new_indices = tf.gather(x.indices, gather_indices, axis=1)
        return tf.SparseTensor(new_indices, x.values, tuple(new_shape))
    return tfnp.squeeze(x, axis=axis)

Expand Down

0 comments on commit 686f339

Please sign in to comment.