Skip to content

Commit

Permalink
Fix a pretty major bug
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Apr 26, 2023
1 parent 0d70141 commit aa88067
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 26 deletions.
2 changes: 0 additions & 2 deletions keras_core/backend/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from keras_core.utils.naming import auto_name

DYNAMIC_SHAPES_OK = True
# Disable autograph
tf.__internal__.autograph.control_status_ctx().status = 2


class Variable(KerasVariable, tf.__internal__.types.Tensor):
Expand Down
6 changes: 2 additions & 4 deletions keras_core/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@ def max(x, axis=None, keepdims=False):


def ones(shape, dtype="float32"):
with tf.init_scope():
return tf.ones(shape, dtype=dtype)
return tf.ones(shape, dtype=dtype)


def zeros(shape, dtype="float32"):
with tf.init_scope():
return tf.zeros(shape, dtype=dtype)
return tf.zeros(shape, dtype=dtype)
32 changes: 14 additions & 18 deletions keras_core/backend/tensorflow/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

def tf_draw_seed(seed):
# TF ops only accept int32/64 seeds but our base seed is uint32.
with tf.init_scope():
return tf.cast(draw_seed(seed), dtype="int32")
return tf.cast(draw_seed(seed), dtype="int32")


def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
Expand All @@ -35,10 +34,9 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
"""
dtype = dtype or floatx()
seed = tf_draw_seed(seed)
with tf.init_scope():
return tf.random.stateless_normal(
shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
)
return tf.random.stateless_normal(
shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
)


def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
Expand All @@ -65,14 +63,13 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
"""
dtype = dtype or floatx()
seed = tf_draw_seed(seed)
with tf.init_scope():
return tf.random.stateless_uniform(
shape=shape,
minval=minval,
maxval=maxval,
dtype=dtype,
seed=seed,
)
return tf.random.stateless_uniform(
shape=shape,
minval=minval,
maxval=maxval,
dtype=dtype,
seed=seed,
)


def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
Expand All @@ -98,10 +95,9 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
"""
dtype = dtype or floatx()
seed = tf_draw_seed(seed)
with tf.init_scope():
return tf.random.stateless_truncated_normal(
shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
)
return tf.random.stateless_truncated_normal(
shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
)


def dropout(inputs, rate, noise_shape=None, seed=None):
Expand Down
2 changes: 1 addition & 1 deletion keras_core/models/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def functional_like_constructor(cls):


def unpack_singleton(x):
if len(x) == 1:
if isinstance(x, (list, tuple)) and len(x) == 1:
return x[0]
return x

Expand Down
2 changes: 1 addition & 1 deletion keras_core/testing/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_temp_dir(self):
self.addCleanup(lambda: shutil.rmtree(temp_dir))
return temp_dir

def assertAllClose(self, x1, x2, atol=1e-7, rtol=1e-7):
def assertAllClose(self, x1, x2, atol=1e-6, rtol=1e-6):
np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol)

def assertAlmostEqual(self, x1, x2, decimal=3):
Expand Down

0 comments on commit aa88067

Please sign in to comment.