28 changes: 28 additions & 0 deletions keras/src/backend/common/remat_test.py
@@ -116,3 +116,31 @@ def test_remat_basic_call(self):
batch_size=batch_size,
verbose=0,
)

def test_remat_with_kwargs(self):
if backend.backend() in ("openvino", "numpy"):
self.skipTest(
"remat is not supported in openvino and numpy backends."
)

# Define a function that uses keyword arguments
def fn_with_kwargs(x, scale=1.0, offset=0.0):
return x * scale + offset

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)

# Test with keyword arguments
remat_fn = backend.core.remat(fn_with_kwargs)
result_with_kwargs = remat_fn(x, scale=2.0, offset=1.0)
expected = fn_with_kwargs(x, scale=2.0, offset=1.0)
self.assertAllClose(result_with_kwargs, expected)

# Test with default keyword arguments
result_with_defaults = remat_fn(x)
expected_defaults = fn_with_kwargs(x)
self.assertAllClose(result_with_defaults, expected_defaults)

# Test with partial keyword arguments
result_partial = remat_fn(x, scale=3.0)
expected_partial = fn_with_kwargs(x, scale=3.0)
self.assertAllClose(result_partial, expected_partial)
9 changes: 8 additions & 1 deletion keras/src/backend/torch/core.py
@@ -5,6 +5,7 @@
import ml_dtypes
import numpy as np
import torch
from torch.utils.checkpoint import checkpoint

from keras.src import tree
from keras.src.backend.common import KerasVariable
@@ -673,7 +674,13 @@ def remat(f):
"""

def wrapped(*args, **kwargs):
-        return torch.utils.checkpoint.checkpoint(f, *args, use_reentrant=False)
+        if not kwargs:
+            return checkpoint(f, *args, use_reentrant=False)
+
+        def positional_wrapper(*pos_args):
+            return f(*pos_args, **kwargs)
+
+        return checkpoint(positional_wrapper, *args, use_reentrant=False)
Comment on lines -676 to +683
Collaborator

Looking at the documentation, it looks like you can just do:

return torch.utils.checkpoint.checkpoint(f, *args, use_reentrant=False, **kwargs)

Is that not the case?

My concern with your approach is that I think it statically binds the kwargs so they cannot be tensors.
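(For illustration only, not part of the thread: the concern is that positional_wrapper captures the kwargs in its closure instead of handing them to checkpoint() as explicit input tensors. A minimal standalone sketch of that capture, assuming the torch backend; fn, x, and scale are illustrative names. Whether the captured tensor still receives gradients is exactly what the tensor-kwargs test proposed below would confirm.)

import torch
from torch.utils.checkpoint import checkpoint

def fn(x, scale=1.0):
    return x * scale

x = torch.ones(3, requires_grad=True)
scale = torch.tensor(2.0, requires_grad=True)

# Same pattern as the diff: `scale` is closed over rather than passed
# to checkpoint() as an input tensor.
kwargs = {"scale": scale}

def positional_wrapper(*pos_args):
    return fn(*pos_args, **kwargs)

out = checkpoint(positional_wrapper, x, use_reentrant=False)
out.sum().backward()
print(x.grad, scale.grad)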

Contributor Author
@Abhinavexists Nov 26, 2025

@hertschuh
PyTorch's checkpoint() doesn't pass kwargs through to the checkpointed function here. I initially tried return torch.utils.checkpoint.checkpoint(f, *args, use_reentrant=False, **kwargs), but it loses the kwargs and only accepts positional args.

Regarding the static binding concern: you're right that this could be problematic. Let me add a test case with tensor kwargs to verify gradient tracking works correctly.

Does that work here?
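
(A possible shape for that test, purely as a sketch and not part of this diff: the test name, fn, and tensor values are illustrative, and it assumes the test would live in the same test class as test_remat_with_kwargs above.)

    def test_remat_with_tensor_kwargs(self):
        if backend.backend() != "torch":
            self.skipTest("This sketch targets the torch backend only.")

        import torch

        def fn(x, scale=None):
            return x * scale

        x = torch.ones(3, requires_grad=True)
        scale = torch.full((3,), 2.0, requires_grad=True)

        remat_fn = backend.core.remat(fn)
        out = remat_fn(x, scale=scale)
        out.sum().backward()

        # Both the positional arg and the tensor kwarg should end up with
        # gradients if the closure capture does not break tracking.
        self.assertIsNotNone(x.grad)
        self.assertIsNotNone(scale.grad)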


return wrapped
