Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion keras/src/backend/common/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def value(self):
return self._maybe_autocast(self._value)

def assign(self, value):
value = self._convert_to_tensor(value, dtype=self.dtype)
value = self._convert_to_tensor(value, dtype=self._dtype)
if not shape_equal(value.shape, self.shape):
raise ValueError(
"The shape of the target variable and "
Expand Down
32 changes: 30 additions & 2 deletions keras/src/backend/common/variables_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ def test_trainable_setter(self):
v.trainable = False
self.assertFalse(v._value.requires_grad)

def test_autocasting(self):
"""Tests autocasting of float variables."""
def test_autocasting_float(self):
# Tests autocasting of float variables
v = backend.Variable(
initializer=initializers.RandomNormal(),
shape=(2, 2),
Expand All @@ -191,6 +191,33 @@ def test_autocasting(self):
)
self.assertEqual(backend.standardize_dtype(v.value.dtype), "float32")

def test_autocasting_float_assign(self):
    """Assignment inside an AutocastScope must not change the variable's
    stored dtype, whatever the dtype of the assigned value is."""
    v = backend.Variable(
        initializer=initializers.RandomNormal(),
        shape=(2, 2),
        dtype="float32",
    )
    self.assertEqual(v.dtype, "float32")
    self.assertEqual(backend.standardize_dtype(v.value.dtype), "float32")

    # Try both a float16 and a float32 value: in each case the value reads
    # back as float16 while the scope is active, and the variable is still
    # stored as float32 once the scope exits.
    for new_value in (
        ops.ones((2, 2), "float16"),
        ops.zeros((2, 2), "float32"),
    ):
        with AutocastScope("float16"):
            self.assertEqual(
                backend.standardize_dtype(v.value.dtype), "float16"
            )
            v.assign(new_value)
        self.assertEqual(
            backend.standardize_dtype(v.value.dtype), "float32"
        )

def test_autocasting_int(self):
# Test non-float variables are not affected
v = backend.Variable(
initializer=initializers.Ones(),
Expand All @@ -204,6 +231,7 @@ def test_autocasting(self):
with AutocastScope("float16"):
self.assertEqual(backend.standardize_dtype(v.value.dtype), "int32")

def test_autocasting_float_with_autocast_off(self):
# Test autocast argument
v = backend.Variable(
initializer=initializers.RandomNormal(),
Expand Down
Loading