7 changes: 6 additions & 1 deletion keras/src/backend/jax/core.py
@@ -56,7 +56,12 @@ def _convert_to_tensor(self, value, dtype=None):

# Overload native accessor.
def __jax_array__(self):
return self.value
# TODO UNDO
import traceback

print("### __jax_array__")
traceback.print_stack()
raise ValueError("__jax_array__")
Comment on lines +59 to +64
critical

This debugging code, including the TODO UNDO comment, print statements, and the ValueError exception, should be removed before merging the pull request. It appears to have been added to trace usages of __jax_array__ but is not intended for the final version.

Suggested change
# TODO UNDO
import traceback
print("### __jax_array__")
traceback.print_stack()
raise ValueError("__jax_array__")
return self.value
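
For context on what the accessor should do once the debug code is reverted: JAX converts any object that defines __jax_array__ into an array when it is passed to jnp functions, which is exactly the implicit path the temporary traceback/raise instrumentation above is meant to flush out. A minimal sketch of that protocol follows (not part of this PR; Box is a hypothetical stand-in for Keras' JaxVariable, and it assumes jnp.asarray honors __jax_array__ as current JAX releases do):

# Minimal sketch of the __jax_array__ protocol. Box is hypothetical and
# not part of this PR; it mimics how a Keras JaxVariable flows into jnp ops.
import jax.numpy as jnp

class Box:
    def __init__(self, value):
        self._value = jnp.asarray(value)

    def __jax_array__(self):
        # JaxVariable returns self.value here; temporarily raising instead
        # (as in the debug code above) surfaces every implicit conversion.
        return self._value

box = Box([1.0, 2.0, 3.0])
print(jnp.asarray(box) + 1.0)  # Box is converted via __jax_array__

That instrumentation is presumably what motivated the rest of this PR: explicit convert_to_tensor calls and ops.add/ops.multiply in place of Python operators on variables, so backend code no longer relies on the implicit __jax_array__ conversion.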



Variable = JaxVariable
2 changes: 2 additions & 0 deletions keras/src/backend/jax/nn.py
@@ -396,6 +396,8 @@ def depthwise_conv(
feature_group_count = (
inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
)
kernel = convert_to_tensor(kernel)
inputs = convert_to_tensor(inputs)
kernel = jnp.reshape(
kernel,
kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),
25 changes: 16 additions & 9 deletions keras/src/backend/jax/numpy.py
@@ -543,15 +543,18 @@ def clip(x, x_min, x_max):

def concatenate(xs, axis=0):
bcoo_count = builtins.sum(isinstance(x, jax_sparse.BCOO) for x in xs)
if bcoo_count:
if bcoo_count == len(xs):
axis = canonicalize_axis(axis, len(xs[0].shape))
return jax_sparse.bcoo_concatenate(xs, dimension=axis)
else:
xs = [
x.todense() if isinstance(x, jax_sparse.JAXSparse) else x
for x in xs
]
if bcoo_count == len(xs):
axis = canonicalize_axis(axis, len(xs[0].shape))
return jax_sparse.bcoo_concatenate(xs, dimension=axis)
elif bcoo_count:
xs = [
x.todense()
if isinstance(x, jax_sparse.JAXSparse)
else convert_to_tensor(x)
for x in xs
]
else:
xs = [convert_to_tensor(x) for x in xs]
return jnp.concatenate(xs, axis=axis)


@@ -1087,6 +1090,7 @@ def reshape(x, newshape):
if None not in output_shape:
newshape = output_shape
return jax_sparse.bcoo_reshape(x, new_sizes=newshape)
x = convert_to_tensor(x)
return jnp.reshape(x, newshape)


@@ -1149,10 +1153,12 @@ def sort(x, axis=-1):


def split(x, indices_or_sections, axis=0):
x = convert_to_tensor(x)
return jnp.split(x, indices_or_sections, axis=axis)


def stack(x, axis=0):
x = [convert_to_tensor(t) for t in x]
return jnp.stack(x, axis=axis)


@@ -1338,6 +1344,7 @@ def squeeze(x, axis=None):
axis = tuple(i for i, d in enumerate(x.shape) if d == 1)
axis = to_tuple_or_list(axis)
return jax_sparse.bcoo_squeeze(x, dimensions=axis)
x = convert_to_tensor(x)
return jnp.squeeze(x, axis=axis)


5 changes: 3 additions & 2 deletions keras/src/backend/jax/optimizer.py
@@ -36,13 +36,14 @@ def _backend_apply_gradients(self, grads, trainable_variables):
new_g_accs = jax.lax.cond(
is_update_step,
lambda: [jnp.zeros(g.shape, dtype=g.dtype) for g in acc_grads],
lambda: [g + acc_g for g, acc_g in zip(grads, acc_grads)],
lambda: [g + acc_g.value for g, acc_g in zip(grads, acc_grads)],
)

grads = jax.lax.cond(
is_update_step,
lambda: [
(g + acc_g) / steps for g, acc_g in zip(grads, acc_grads)
(g + acc_g.value) / steps
for g, acc_g in zip(grads, acc_grads)
],
lambda: list(grads),
)
2 changes: 1 addition & 1 deletion keras/src/layers/attention/attention.py
@@ -121,7 +121,7 @@ def _calculate_scores(self, query, key):
if self.score_mode == "dot":
scores = ops.matmul(query, ops.transpose(key, axes=[0, 2, 1]))
if self.scale is not None:
scores *= self.scale
scores = ops.multiply(scores, self.scale)
elif self.score_mode == "concat":
# Reshape tensors to enable broadcasting.
# Reshape into [batch_size, Tq, 1, dim].
8 changes: 4 additions & 4 deletions keras/src/layers/layer_test.py
@@ -678,7 +678,7 @@ def __init__(self):
def call(self, x):
# Should not autocast.
assertDType(self.v, "float32")
return ops.cast(x, "float32") + self.v
return ops.add(ops.cast(x, "float32"), self.v)

# A layer that is explicitly full precision.
class InnerLayerTwo(layers.Layer):
@@ -694,7 +694,7 @@ def __init__(self):
def call(self, x):
# Should not autocast.
assertDType(self.v, "float32")
return x + self.v
return ops.add(x, self.v)

# A layer that is explicitly mixed precision but with autocast=False
# weight.
@@ -732,7 +732,7 @@ def call(self, x):
# Should autocast.
assertDType(self.v, "float16")
return self.inner_three(
self.inner_two(self.inner_one(x + self.v))
self.inner_two(self.inner_one(ops.add(x, self.v)))
)

layer = MixedPrecisionLayer()
@@ -935,7 +935,7 @@ def call(self, x):
x = x + backend.random.normal(
shape=(), seed=self._seed_generator
)
return x + self.tw + self.ntw
return ops.add(x, ops.add(self.tw, self.ntw))

data = np.random.random((3, 4))
layer = TestLayer()
4 changes: 2 additions & 2 deletions keras/src/layers/normalization/layer_normalization_test.py
@@ -99,8 +99,8 @@ def test_correctness(self):
).astype("float32")

out = layer(inputs)
out -= layer.beta
out /= layer.gamma
out = ops.subtract(out, layer.beta)
out = ops.divide(out, layer.gamma)

self.assertAllClose(ops.mean(out), 0.0, atol=1e-1)
self.assertAllClose(ops.std(out), 1.0, atol=1e-1)
10 changes: 6 additions & 4 deletions keras/src/layers/normalization/rms_normalization_test.py
@@ -38,10 +38,12 @@ def test_correctness(self):
inputs = ops.convert_to_tensor(inputs)

out = layer(inputs)
expected = (
inputs
* ops.rsqrt(ops.mean(ops.square(inputs), axis=-1, keepdims=True))
* layer.scale
expected = ops.multiply(
ops.multiply(
inputs,
ops.rsqrt(ops.mean(ops.square(inputs), axis=-1, keepdims=True)),
),
layer.scale,
)

self.assertAllClose(out, expected, atol=1e-1)
2 changes: 1 addition & 1 deletion keras/src/layers/rnn/gru.py
@@ -261,7 +261,7 @@ def call(self, inputs, states, training=False):
matrix_x = ops.matmul(inputs, self.kernel)
if self.use_bias:
# biases: bias_z_i, bias_r_i, bias_h_i
matrix_x += input_bias
matrix_x = ops.add(matrix_x, input_bias)

x_z, x_r, x_h = ops.split(matrix_x, 3, axis=-1)

4 changes: 2 additions & 2 deletions keras/src/layers/rnn/lstm.py
@@ -276,9 +276,9 @@ def call(self, inputs, states, training=False):

z = ops.matmul(inputs, self.kernel)

z += ops.matmul(h_tm1, self.recurrent_kernel)
z = ops.add(z, ops.matmul(h_tm1, self.recurrent_kernel))
if self.use_bias:
z += self.bias
z = ops.add(z, self.bias)

z = ops.split(z, 4, axis=1)
c, o = self._compute_carry_and_output_fused(z, c_tm1)
2 changes: 1 addition & 1 deletion keras/src/layers/rnn/simple_rnn.py
@@ -160,7 +160,7 @@ def call(self, sequence, states, training=False):
sequence = sequence * dp_mask
h = ops.matmul(sequence, self.kernel)
if self.bias is not None:
h += self.bias
h = ops.add(h, self.bias)

if training and rec_dp_mask is not None:
prev_output = prev_output * rec_dp_mask
4 changes: 3 additions & 1 deletion keras/src/ops/core_test.py
@@ -1185,7 +1185,9 @@ def __init__(self):
self.b = self.add_weight(shape=(1,), initializer="zeros")

def call(self, x, training=False):
return x * ops.stop_gradient(self.w) + self.b
return ops.add(
ops.multiply(x, ops.stop_gradient(self.w)), self.b
)

model = models.Sequential([ExampleLayer()])
model.compile(
2 changes: 1 addition & 1 deletion keras/src/ops/nn.py
@@ -1443,7 +1443,7 @@ def depthwise_conv(
"""
data_format = standardize_data_format(data_format)
padding = padding.lower()
if any_symbolic_tensors((inputs,)):
if any_symbolic_tensors((inputs, kernel)):
return DepthwiseConv(
strides, padding, data_format, dilation_rate
).symbolic_call(inputs, kernel)
39 changes: 29 additions & 10 deletions keras/src/optimizers/adafactor.py
@@ -158,33 +158,52 @@ def update_step(self, gradient, variable, learning_rate):
rho_t = ops.minimum(lr, 1 / ops.sqrt(local_step))
alpha_t = ops.maximum(epsilon_2, self._rms(variable)) * rho_t
regulated_grad_square = ops.add(ops.square(gradient), self.epsilon_1)
beta_2_t = 1 - ops.power(local_step, self.beta_2_decay)
beta_2_t = ops.subtract(1, ops.power(local_step, self.beta_2_decay))

if len(variable.shape) >= 2:
# `r` deletes the last dimension of gradient, so it is of shape
# `gradient.shape[:-1]`.
self.assign(
r,
beta_2_t * r
+ (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-1),
ops.add(
ops.multiply(beta_2_t, r),
ops.multiply(
ops.subtract(1, beta_2_t),
ops.mean(regulated_grad_square, axis=-1),
),
),
)
# `c` deletes the second last dimension of gradient, so it is of
# shape `gradient.shape[:-2] + gradient.shape[-1]`.
self.assign(
c,
beta_2_t * c
+ (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-2),
ops.add(
ops.multiply(beta_2_t, c),
ops.multiply(
ops.subtract(1, beta_2_t),
ops.mean(regulated_grad_square, axis=-2),
),
),
)
self.assign(
v,
ops.expand_dims(
r / ops.mean(r, axis=-1, keepdims=True), axis=-1
)
* ops.expand_dims(c, -2),
ops.multiply(
ops.expand_dims(
ops.divide(r, ops.mean(r, axis=-1, keepdims=True)),
axis=-1,
),
ops.expand_dims(c, -2),
),
)
else:
self.assign(
v, beta_2_t * v + (1 - beta_2_t) * regulated_grad_square
v,
ops.add(
ops.multiply(beta_2_t, v),
ops.multiply(
ops.subtract(1, beta_2_t), regulated_grad_square
),
),
)

u_t = ops.divide(gradient, ops.sqrt(v))
11 changes: 8 additions & 3 deletions keras/src/optimizers/base_optimizer.py
@@ -969,10 +969,15 @@ def _update_model_variables_moving_average(self, trainable_variables):
):
if average is not None:
not_first_step = ops.not_equal(self.iterations, 0)
momentum = (
ops.cast(not_first_step, var.dtype) * self.ema_momentum
momentum = ops.multiply(
ops.cast(not_first_step, var.dtype), self.ema_momentum
)
average.assign(
ops.add(
ops.multiply(momentum, average),
ops.multiply(ops.subtract(1, momentum), var),
)
)
average.assign(momentum * average + (1 - momentum) * var)

def _overwrite_model_variables_with_average_value(
self, trainable_variables
16 changes: 9 additions & 7 deletions keras/src/optimizers/loss_scale_optimizer.py
@@ -112,7 +112,7 @@ def upscale():
mapping = list(zip(self.variables, optimizer_variables))
with backend.StatelessScope(state_mapping=mapping) as scope:
self.step_counter.assign(0)
self.dynamic_scale.assign(self.dynamic_scale * 2.0)
self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 2.0))
return [scope.get_current_value(v) for v in self._variables]

def increment():
@@ -136,7 +136,9 @@ def increment():
g
if g is None or self._overwrite_variable_with_gradient(v)
else ops.divide(g, scale)
for g, v in zip(grads, trainable_variables)
for g, v in zip(
grads, self.inner_optimizer._trainable_variables
)
]
(
new_trainable_variables,
@@ -156,7 +158,7 @@ def _stateless_handle_non_finite_grads(
mapping = list(zip(self.variables, optimizer_variables))
with backend.StatelessScope(state_mapping=mapping) as scope:
self.step_counter.assign(0)
self.dynamic_scale.assign(self.dynamic_scale / 2.0)
self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 0.5))
new_optimizer_variables = []
for v in self.variables:
new_optimizer_variables.append(scope.get_current_value(v))
@@ -177,7 +179,7 @@ def apply(self, grads, trainable_variables=None):
def _stateful_handle_finite_grads(self, grads, trainable_variables):
scale = self.dynamic_scale
# Unscale gradients.
tvs = trainable_variables or self._trainable_variables
tvs = trainable_variables or self.inner_optimizer._trainable_variables
unscaled_grads = [
g
if g is None or self._overwrite_variable_with_gradient(v)
@@ -190,7 +192,7 @@ def _stateful_handle_finite_grads(self, grads, trainable_variables):

def upscale():
self.step_counter.assign(0)
self.dynamic_scale.assign(self.dynamic_scale * 2.0)
self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 2.0))

def increment():
self.step_counter.assign_add(1)
@@ -205,7 +207,7 @@ def increment():
def _stateful_handle_non_finite_grads(self):
# If any inf or nan in grads, downscale loss and reset counter.
self.step_counter.assign(0)
self.dynamic_scale.assign(self.dynamic_scale / 2.0)
self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 0.5))

def _common_apply(self, grads, trainable_variables=None):
finite = self.check_finite(grads)
@@ -278,7 +280,7 @@ def iterations(self):

def scale_loss(self, loss):
scale = self.dynamic_scale if self.built else self.initial_scale
return loss * scale
return ops.multiply(loss, scale)

def finalize_variable_values(self, var_list):
self.inner_optimizer.finalize_variable_values(var_list)