From 889d5bfe3ef340b6a2ccb0438b6d7cb12c72ff8f Mon Sep 17 00:00:00 2001 From: Fabien Hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Tue, 30 Sep 2025 15:05:52 -0700 Subject: [PATCH] Remove reliance on `__jax_array__` to unwrap variables. Several months ago JAX deprecated passing __jax_array__-implementing objects directly, to, e.g., a jit-ted function. JAX has emitted a warning since that time. In a future release of JAX this will become a hard error. --- keras/src/backend/jax/core.py | 7 +++- keras/src/backend/jax/nn.py | 2 + keras/src/backend/jax/numpy.py | 25 +++++++----- keras/src/backend/jax/optimizer.py | 5 ++- keras/src/layers/attention/attention.py | 2 +- keras/src/layers/layer_test.py | 8 ++-- .../normalization/layer_normalization_test.py | 4 +- .../normalization/rms_normalization_test.py | 10 +++-- keras/src/layers/rnn/gru.py | 2 +- keras/src/layers/rnn/lstm.py | 4 +- keras/src/layers/rnn/simple_rnn.py | 2 +- keras/src/ops/core_test.py | 4 +- keras/src/ops/nn.py | 2 +- keras/src/optimizers/adafactor.py | 39 +++++++++++++----- keras/src/optimizers/base_optimizer.py | 11 +++-- keras/src/optimizers/loss_scale_optimizer.py | 16 ++++---- .../optimizers/loss_scale_optimizer_test.py | 40 +++++++++++++------ 17 files changed, 121 insertions(+), 62 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 7dc5a98fb8d5..528ccfdb3b88 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -56,7 +56,12 @@ def _convert_to_tensor(self, value, dtype=None): # Overload native accessor. def __jax_array__(self): - return self.value + # TODO UNDO + import traceback + + print("### __jax_array__") + traceback.print_stack() + raise ValueError("__jax_array__") Variable = JaxVariable diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index ea83e758a22f..23fc1249e3ed 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -396,6 +396,8 @@ def depthwise_conv( feature_group_count = ( inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1] ) + kernel = convert_to_tensor(kernel) + inputs = convert_to_tensor(inputs) kernel = jnp.reshape( kernel, kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]), diff --git a/keras/src/backend/jax/numpy.py b/keras/src/backend/jax/numpy.py index 0899a1dc11ac..e9def4b255c9 100644 --- a/keras/src/backend/jax/numpy.py +++ b/keras/src/backend/jax/numpy.py @@ -543,15 +543,18 @@ def clip(x, x_min, x_max): def concatenate(xs, axis=0): bcoo_count = builtins.sum(isinstance(x, jax_sparse.BCOO) for x in xs) - if bcoo_count: - if bcoo_count == len(xs): - axis = canonicalize_axis(axis, len(xs[0].shape)) - return jax_sparse.bcoo_concatenate(xs, dimension=axis) - else: - xs = [ - x.todense() if isinstance(x, jax_sparse.JAXSparse) else x - for x in xs - ] + if bcoo_count == len(xs): + axis = canonicalize_axis(axis, len(xs[0].shape)) + return jax_sparse.bcoo_concatenate(xs, dimension=axis) + elif bcoo_count: + xs = [ + x.todense() + if isinstance(x, jax_sparse.JAXSparse) + else convert_to_tensor(x) + for x in xs + ] + else: + xs = [convert_to_tensor(x) for x in xs] return jnp.concatenate(xs, axis=axis) @@ -1087,6 +1090,7 @@ def reshape(x, newshape): if None not in output_shape: newshape = output_shape return jax_sparse.bcoo_reshape(x, new_sizes=newshape) + x = convert_to_tensor(x) return jnp.reshape(x, newshape) @@ -1149,10 +1153,12 @@ def sort(x, axis=-1): def split(x, indices_or_sections, axis=0): + x = convert_to_tensor(x) return jnp.split(x, 
indices_or_sections, axis=axis) def stack(x, axis=0): + x = [convert_to_tensor(t) for t in x] return jnp.stack(x, axis=axis) @@ -1338,6 +1344,7 @@ def squeeze(x, axis=None): axis = tuple(i for i, d in enumerate(x.shape) if d == 1) axis = to_tuple_or_list(axis) return jax_sparse.bcoo_squeeze(x, dimensions=axis) + x = convert_to_tensor(x) return jnp.squeeze(x, axis=axis) diff --git a/keras/src/backend/jax/optimizer.py b/keras/src/backend/jax/optimizer.py index ef366d2cf502..5cd6a40f65fb 100644 --- a/keras/src/backend/jax/optimizer.py +++ b/keras/src/backend/jax/optimizer.py @@ -36,13 +36,14 @@ def _backend_apply_gradients(self, grads, trainable_variables): new_g_accs = jax.lax.cond( is_update_step, lambda: [jnp.zeros(g.shape, dtype=g.dtype) for g in acc_grads], - lambda: [g + acc_g for g, acc_g in zip(grads, acc_grads)], + lambda: [g + acc_g.value for g, acc_g in zip(grads, acc_grads)], ) grads = jax.lax.cond( is_update_step, lambda: [ - (g + acc_g) / steps for g, acc_g in zip(grads, acc_grads) + (g + acc_g.value) / steps + for g, acc_g in zip(grads, acc_grads) ], lambda: list(grads), ) diff --git a/keras/src/layers/attention/attention.py b/keras/src/layers/attention/attention.py index 84bf035257f7..04e3f399c5e5 100644 --- a/keras/src/layers/attention/attention.py +++ b/keras/src/layers/attention/attention.py @@ -121,7 +121,7 @@ def _calculate_scores(self, query, key): if self.score_mode == "dot": scores = ops.matmul(query, ops.transpose(key, axes=[0, 2, 1])) if self.scale is not None: - scores *= self.scale + scores = ops.multiply(scores, self.scale) elif self.score_mode == "concat": # Reshape tensors to enable broadcasting. # Reshape into [batch_size, Tq, 1, dim]. diff --git a/keras/src/layers/layer_test.py b/keras/src/layers/layer_test.py index aa27eb9aac71..53531b679cc5 100644 --- a/keras/src/layers/layer_test.py +++ b/keras/src/layers/layer_test.py @@ -678,7 +678,7 @@ def __init__(self): def call(self, x): # Should not autocast. assertDType(self.v, "float32") - return ops.cast(x, "float32") + self.v + return ops.add(ops.cast(x, "float32"), self.v) # A layer that is explicitly full precision. class InnerLayerTwo(layers.Layer): @@ -694,7 +694,7 @@ def __init__(self): def call(self, x): # Should not autocast. assertDType(self.v, "float32") - return x + self.v + return ops.add(x, self.v) # A layer that is explicitly mixed precision but with autocast=False # weight. @@ -732,7 +732,7 @@ def call(self, x): # Should autocast. 
assertDType(self.v, "float16") return self.inner_three( - self.inner_two(self.inner_one(x + self.v)) + self.inner_two(self.inner_one(ops.add(x, self.v))) ) layer = MixedPrecisionLayer() @@ -935,7 +935,7 @@ def call(self, x): x = x + backend.random.normal( shape=(), seed=self._seed_generator ) - return x + self.tw + self.ntw + return ops.add(x, ops.add(self.tw, self.ntw)) data = np.random.random((3, 4)) layer = TestLayer() diff --git a/keras/src/layers/normalization/layer_normalization_test.py b/keras/src/layers/normalization/layer_normalization_test.py index 384d2053b2e2..ad2c72006204 100644 --- a/keras/src/layers/normalization/layer_normalization_test.py +++ b/keras/src/layers/normalization/layer_normalization_test.py @@ -99,8 +99,8 @@ def test_correctness(self): ).astype("float32") out = layer(inputs) - out -= layer.beta - out /= layer.gamma + out = ops.subtract(out, layer.beta) + out = ops.divide(out, layer.gamma) self.assertAllClose(ops.mean(out), 0.0, atol=1e-1) self.assertAllClose(ops.std(out), 1.0, atol=1e-1) diff --git a/keras/src/layers/normalization/rms_normalization_test.py b/keras/src/layers/normalization/rms_normalization_test.py index c15390b920f7..5e56fa94634b 100644 --- a/keras/src/layers/normalization/rms_normalization_test.py +++ b/keras/src/layers/normalization/rms_normalization_test.py @@ -38,10 +38,12 @@ def test_correctness(self): inputs = ops.convert_to_tensor(inputs) out = layer(inputs) - expected = ( - inputs - * ops.rsqrt(ops.mean(ops.square(inputs), axis=-1, keepdims=True)) - * layer.scale + expected = ops.multiply( + ops.multiply( + inputs, + ops.rsqrt(ops.mean(ops.square(inputs), axis=-1, keepdims=True)), + ), + layer.scale, ) self.assertAllClose(out, expected, atol=1e-1) diff --git a/keras/src/layers/rnn/gru.py b/keras/src/layers/rnn/gru.py index 8b76d6baca2d..3a6abd2d1cbb 100644 --- a/keras/src/layers/rnn/gru.py +++ b/keras/src/layers/rnn/gru.py @@ -261,7 +261,7 @@ def call(self, inputs, states, training=False): matrix_x = ops.matmul(inputs, self.kernel) if self.use_bias: # biases: bias_z_i, bias_r_i, bias_h_i - matrix_x += input_bias + matrix_x = ops.add(matrix_x, input_bias) x_z, x_r, x_h = ops.split(matrix_x, 3, axis=-1) diff --git a/keras/src/layers/rnn/lstm.py b/keras/src/layers/rnn/lstm.py index 9ede7261153d..32a426a8ee50 100644 --- a/keras/src/layers/rnn/lstm.py +++ b/keras/src/layers/rnn/lstm.py @@ -276,9 +276,9 @@ def call(self, inputs, states, training=False): z = ops.matmul(inputs, self.kernel) - z += ops.matmul(h_tm1, self.recurrent_kernel) + z = ops.add(z, ops.matmul(h_tm1, self.recurrent_kernel)) if self.use_bias: - z += self.bias + z = ops.add(z, self.bias) z = ops.split(z, 4, axis=1) c, o = self._compute_carry_and_output_fused(z, c_tm1) diff --git a/keras/src/layers/rnn/simple_rnn.py b/keras/src/layers/rnn/simple_rnn.py index b68ecd64792a..b811baf88234 100644 --- a/keras/src/layers/rnn/simple_rnn.py +++ b/keras/src/layers/rnn/simple_rnn.py @@ -160,7 +160,7 @@ def call(self, sequence, states, training=False): sequence = sequence * dp_mask h = ops.matmul(sequence, self.kernel) if self.bias is not None: - h += self.bias + h = ops.add(h, self.bias) if training and rec_dp_mask is not None: prev_output = prev_output * rec_dp_mask diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py index c0059d965b22..ff49a4d34e05 100644 --- a/keras/src/ops/core_test.py +++ b/keras/src/ops/core_test.py @@ -1185,7 +1185,9 @@ def __init__(self): self.b = self.add_weight(shape=(1,), initializer="zeros") def call(self, x, training=False): - return x * 
ops.stop_gradient(self.w) + self.b + return ops.add( + ops.multiply(x, ops.stop_gradient(self.w)), self.b + ) model = models.Sequential([ExampleLayer()]) model.compile( diff --git a/keras/src/ops/nn.py b/keras/src/ops/nn.py index 6acaedfce999..7f171b83ab53 100644 --- a/keras/src/ops/nn.py +++ b/keras/src/ops/nn.py @@ -1443,7 +1443,7 @@ def depthwise_conv( """ data_format = standardize_data_format(data_format) padding = padding.lower() - if any_symbolic_tensors((inputs,)): + if any_symbolic_tensors((inputs, kernel)): return DepthwiseConv( strides, padding, data_format, dilation_rate ).symbolic_call(inputs, kernel) diff --git a/keras/src/optimizers/adafactor.py b/keras/src/optimizers/adafactor.py index 54fd74d1e783..6c406043353e 100644 --- a/keras/src/optimizers/adafactor.py +++ b/keras/src/optimizers/adafactor.py @@ -158,33 +158,52 @@ def update_step(self, gradient, variable, learning_rate): rho_t = ops.minimum(lr, 1 / ops.sqrt(local_step)) alpha_t = ops.maximum(epsilon_2, self._rms(variable)) * rho_t regulated_grad_square = ops.add(ops.square(gradient), self.epsilon_1) - beta_2_t = 1 - ops.power(local_step, self.beta_2_decay) + beta_2_t = ops.subtract(1, ops.power(local_step, self.beta_2_decay)) if len(variable.shape) >= 2: # `r` deletes the last dimension of gradient, so it is of shape # `gradient.shape[:-1]`. self.assign( r, - beta_2_t * r - + (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-1), + ops.add( + ops.multiply(beta_2_t, r), + ops.multiply( + ops.subtract(1, beta_2_t), + ops.mean(regulated_grad_square, axis=-1), + ), + ), ) # `c` deletes the second last dimension of gradient, so it is of # shape `gradient.shape[:-2] + gradient.shape[-1]`. self.assign( c, - beta_2_t * c - + (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-2), + ops.add( + ops.multiply(beta_2_t, c), + ops.multiply( + ops.subtract(1, beta_2_t), + ops.mean(regulated_grad_square, axis=-2), + ), + ), ) self.assign( v, - ops.expand_dims( - r / ops.mean(r, axis=-1, keepdims=True), axis=-1 - ) - * ops.expand_dims(c, -2), + ops.multiply( + ops.expand_dims( + ops.divide(r, ops.mean(r, axis=-1, keepdims=True)), + axis=-1, + ), + ops.expand_dims(c, -2), + ), ) else: self.assign( - v, beta_2_t * v + (1 - beta_2_t) * regulated_grad_square + v, + ops.add( + ops.multiply(beta_2_t, v), + ops.multiply( + ops.subtract(1, beta_2_t), regulated_grad_square + ), + ), ) u_t = ops.divide(gradient, ops.sqrt(v)) diff --git a/keras/src/optimizers/base_optimizer.py b/keras/src/optimizers/base_optimizer.py index c3ecdd2baab9..020e92a3fce0 100644 --- a/keras/src/optimizers/base_optimizer.py +++ b/keras/src/optimizers/base_optimizer.py @@ -969,10 +969,15 @@ def _update_model_variables_moving_average(self, trainable_variables): ): if average is not None: not_first_step = ops.not_equal(self.iterations, 0) - momentum = ( - ops.cast(not_first_step, var.dtype) * self.ema_momentum + momentum = ops.multiply( + ops.cast(not_first_step, var.dtype), self.ema_momentum + ) + average.assign( + ops.add( + ops.multiply(momentum, average), + ops.multiply(ops.subtract(1, momentum), var), + ) ) - average.assign(momentum * average + (1 - momentum) * var) def _overwrite_model_variables_with_average_value( self, trainable_variables diff --git a/keras/src/optimizers/loss_scale_optimizer.py b/keras/src/optimizers/loss_scale_optimizer.py index 1b9945c4157b..566886c1fa3f 100644 --- a/keras/src/optimizers/loss_scale_optimizer.py +++ b/keras/src/optimizers/loss_scale_optimizer.py @@ -112,7 +112,7 @@ def upscale(): mapping = list(zip(self.variables, 
optimizer_variables)) with backend.StatelessScope(state_mapping=mapping) as scope: self.step_counter.assign(0) - self.dynamic_scale.assign(self.dynamic_scale * 2.0) + self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 2.0)) return [scope.get_current_value(v) for v in self._variables] def increment(): @@ -136,7 +136,9 @@ def increment(): g if g is None or self._overwrite_variable_with_gradient(v) else ops.divide(g, scale) - for g, v in zip(grads, trainable_variables) + for g, v in zip( + grads, self.inner_optimizer._trainable_variables + ) ] ( new_trainable_variables, @@ -156,7 +158,7 @@ def _stateless_handle_non_finite_grads( mapping = list(zip(self.variables, optimizer_variables)) with backend.StatelessScope(state_mapping=mapping) as scope: self.step_counter.assign(0) - self.dynamic_scale.assign(self.dynamic_scale / 2.0) + self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 0.5)) new_optimizer_variables = [] for v in self.variables: new_optimizer_variables.append(scope.get_current_value(v)) @@ -177,7 +179,7 @@ def apply(self, grads, trainable_variables=None): def _stateful_handle_finite_grads(self, grads, trainable_variables): scale = self.dynamic_scale # Unscale gradients. - tvs = trainable_variables or self._trainable_variables + tvs = trainable_variables or self.inner_optimizer._trainable_variables unscaled_grads = [ g if g is None or self._overwrite_variable_with_gradient(v) @@ -190,7 +192,7 @@ def _stateful_handle_finite_grads(self, grads, trainable_variables): def upscale(): self.step_counter.assign(0) - self.dynamic_scale.assign(self.dynamic_scale * 2.0) + self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 2.0)) def increment(): self.step_counter.assign_add(1) @@ -205,7 +207,7 @@ def increment(): def _stateful_handle_non_finite_grads(self): # If any inf or nan in grads, downscale loss and reset counter. 
self.step_counter.assign(0) - self.dynamic_scale.assign(self.dynamic_scale / 2.0) + self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 0.5)) def _common_apply(self, grads, trainable_variables=None): finite = self.check_finite(grads) @@ -278,7 +280,7 @@ def iterations(self): def scale_loss(self, loss): scale = self.dynamic_scale if self.built else self.initial_scale - return loss * scale + return ops.multiply(loss, scale) def finalize_variable_values(self, var_list): self.inner_optimizer.finalize_variable_values(var_list) diff --git a/keras/src/optimizers/loss_scale_optimizer_test.py b/keras/src/optimizers/loss_scale_optimizer_test.py index c053d96787f6..e9408d4e8242 100644 --- a/keras/src/optimizers/loss_scale_optimizer_test.py +++ b/keras/src/optimizers/loss_scale_optimizer_test.py @@ -40,7 +40,9 @@ def test_finite_step(self, stateless): if stateless: optimizer.build(vars) vars, _ = optimizer.stateless_apply( - optimizer.variables, grads, vars + [v.value for v in optimizer.variables], + grads, + [v.value for v in vars], ) else: optimizer.apply(grads, vars) @@ -60,7 +62,9 @@ def test_finite_step_with_inner_loss_scale(self, stateless): if stateless: optimizer.build(vars) vars, _ = optimizer.stateless_apply( - optimizer.variables, grads, vars + [v.value for v in optimizer.variables], + grads, + [v.value for v in vars], ) else: optimizer.apply(grads, vars) @@ -79,7 +83,9 @@ def test_infinite_step(self, stateless): if stateless: optimizer.build(vars) vars, _ = optimizer.stateless_apply( - optimizer.variables, grads, vars + [v.value for v in optimizer.variables], + grads, + [v.value for v in vars], ) else: optimizer.apply(grads, vars) @@ -98,7 +104,9 @@ def test_finite_step_with_overwrite(self, stateless): if stateless: optimizer.build(vars) vars, _ = optimizer.stateless_apply( - optimizer.variables, grads, vars + [v.value for v in optimizer.variables], + grads, + [v.value for v in vars], ) else: optimizer.apply(grads, vars) @@ -112,12 +120,14 @@ def test_downscaling(self, stateless): optimizer = LossScaleOptimizer(inner_optimizer, initial_scale=400.0) vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])] optimizer.build(vars) - opt_vars = optimizer.variables + opt_var_values = [v.value for v in optimizer.variables] grads = [ops.array([np.inf, np.inf, np.inf, np.inf])] for _ in range(4): if stateless: - _, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars) - for ref_v, v in zip(optimizer.variables, opt_vars): + _, opt_var_values = optimizer.stateless_apply( + opt_var_values, grads, [v.value for v in vars] + ) + for ref_v, v in zip(optimizer.variables, opt_var_values): ref_v.assign(v) else: optimizer.apply(grads, vars) @@ -135,12 +145,14 @@ def test_upscaling(self, stateless): ) vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])] optimizer.build(vars) - opt_vars = optimizer.variables + opt_var_values = [v.value for v in optimizer.variables] grads = [ops.array([1.0, 6.0, 7.0, 2.0])] for _ in range(8): if stateless: - _, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars) - for ref_v, v in zip(optimizer.variables, opt_vars): + _, opt_var_values = optimizer.stateless_apply( + opt_var_values, grads, [v.value for v in vars] + ) + for ref_v, v in zip(optimizer.variables, opt_var_values): ref_v.assign(v) else: optimizer.apply(grads, vars) @@ -154,15 +166,17 @@ def test_iterations_update(self, stateless): optimizer = LossScaleOptimizer(inner_optimizer) vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])] optimizer.build(vars) - opt_vars = optimizer.variables + opt_var_values = [v.value for v 
in optimizer.variables] grads = [ops.array([1.0, 6.0, 7.0, 2.0])] self.assertEqual(optimizer.iterations.value, 0) for i in range(3): if stateless: - _, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars) - for ref_v, v in zip(optimizer.variables, opt_vars): + _, opt_var_values = optimizer.stateless_apply( + opt_var_values, grads, [v.value for v in vars] + ) + for ref_v, v in zip(optimizer.variables, opt_var_values): ref_v.assign(v) else: optimizer.apply(grads, vars)
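
Editor's note (not part of the patch): every hunk above applies the same migration — unwrap a Keras variable explicitly before handing it to a jnp/ops call, instead of relying on JAX's implicit `__jax_array__` conversion, which the commit message says JAX has deprecated and will eventually reject. Below is a minimal illustrative sketch of that pattern, assuming Keras 3 installed with the JAX backend; the variable name and values are made up for the example.

    # Editor's sketch (not part of the patch): the migration the hunks above apply.
    import os

    os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

    import jax.numpy as jnp
    import keras

    v = keras.Variable([1.0, 2.0, 3.0])

    # Deprecated path: pass the variable straight into a jnp call and let
    # Variable.__jax_array__ unwrap it implicitly. Per the commit message,
    # JAX currently warns on this and will eventually make it a hard error.
    # out = jnp.square(v)

    # Explicit unwrapping, as done throughout the patch:
    out = jnp.square(v.value)  # backend-native value of the variable
    out = jnp.square(keras.ops.convert_to_tensor(v))  # backend-agnostic conversion

The `convert_to_tensor` form is what the backend hunks (e.g. jax/nn.py and jax/numpy.py) use, since it also accepts plain Python and NumPy inputs; `.value` appears where a wrapped JAX array is known to be present, such as the optimizer accumulators in jax/optimizer.py and the `optimizer.variables` lists in the loss-scale optimizer tests.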