
Custom Gradients #19302

Merged
merged 18 commits into from
Mar 22, 2024

Conversation

abhaskumarsinha
Contributor

With a little bit of tweaking of the syntax, it is now possible to have custom gradients for all three backends: JAX, PyTorch, and TensorFlow.

Apologies, I'm not sure about this: I added the unit tests, but I don't know how to run them.

Example of syntax:

Here's the rough notebook testing it:

I believe letting users add custom functions together with their own gradient definitions would enable them to implement complex layers and operations even when those aren't available in the underlying frameworks. It could also help them work with custom data structures such as complex numbers, non-metric spaces, or probabilistic spaces.

Example Syntax:

import keras
import torch

@keras.ops.custom_gradient
def fun(x):
    z = x * x

    def grad(*args, upstream=None):
        if upstream is None:
            # tf.custom_gradient convention: upstream is passed positionally.
            (upstream,) = args
        return upstream * x * 10

    return z, grad

# PyTorch backend: the input must track gradients.
x = torch.tensor([2.0], requires_grad=True)
z = fun(x)
z.sum().backward()
print(x.grad)

The only difference is that the grad function accepts two arguments (one of them the keyword upstream) in the case of PyTorch, and one positional argument in the case of TensorFlow or JAX. So the additional

def grad(*args, upstream=None):
    if upstream is None:
        (upstream,) = args

doesn't hurt much.

Thank You.

Added a support for @custom_gradient decorator for PyTorch users.
Added documentation for ops.custom_gradient and edited the syntax of log1pexp(x) example to demonstrate new syntax.
Syntax Error in the example
Updated core_test.py with PyTorch Test
Corrected the syntax in test case.
Correcting docs

google-cla bot commented Mar 13, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Member

@fchollet fchollet left a comment

Thanks for the PR. Very nice work; I would not have anticipated it to be feasible like this!

raise NotImplementedError(
    "`custom_gradient` is not supported with torch backend"
)

class custom_gradient:
Member

To keep the API consistent across backends, we may want to make this a function, def custom_gradient(fun), that instantiates CustomGradientFunction and calls it later when the function is called?

Contributor Author

@fchollet Yes. The @custom_gradient in torch/core.py acts as a decorator, i.e. custom_gradient(fun) as you said. It takes the function fun and returns a torch.autograd.Function instance (that is, a function whose gradient formula is now known to the PyTorch framework).

Additionally, the CustomGradientFunction class has forward() and backward() methods holding our required forward and backward definitions, following the usual way custom functions are defined in PyTorch: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html

The goal was to unify the syntax and the method of creating new functions with custom gradients across all backends. Since PyTorch doesn't support the decorator method of defining new gradients, I had to write a custom one for that.

For testing purposes, the second Colab notebook uses my fork of Keras (because it has custom gradients implemented) and tests each framework one by one. Just make sure to restart the Colab instance before switching to another framework as the backend.
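
For illustration, here is a minimal sketch of that function-style shape, a plain def custom_gradient(fun) wrapping torch.autograd.Function. This is my own simplified sketch, not the code merged in keras/backend/torch/core.py (the merged implementation also forwards extra positional arguments and kwargs):

import torch


def custom_gradient(fun):
    """Decorator sketch: route `fun` through a torch.autograd.Function."""

    class CustomGradientFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            # The wrapped function returns (output, grad_fn); keep the
            # gradient closure on the context for the backward pass.
            output, grad_fn = fun(*args)
            ctx.grad_fn = grad_fn
            return output

        @staticmethod
        def backward(ctx, upstream):
            # Pass the upstream gradient as a keyword, matching the
            # unified `def grad(*args, upstream=None)` convention above.
            grads = ctx.grad_fn(upstream=upstream)
            return grads if isinstance(grads, tuple) else (grads,)

    def wrapped(*args):
        return CustomGradientFunction.apply(*args)

    return wrapped

Calling custom_gradient(fun) directly, or using it as @custom_gradient, then behaves the same way as on the other backends.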


class CustomGradientFunction(torch.autograd.Function):
    """
    Autograd function for custom gradients.
Member

One-line summary should be on the first line (after """)

@fchollet
Member

Apologies, I'm not sure about this: I added the unit tests, but I don't know how to run them.

Run pytest keras/ to run all unit tests. You can also specify the path to a single test file or directory.

@codecov-commenter

codecov-commenter commented Mar 13, 2024

Codecov Report

Attention: Patch coverage is 76.92308% with 6 lines in your changes missing coverage. Please review.

Project coverage is 75.85%. Comparing base (c8700f4) to head (2534ca1).
Report is 125 commits behind head on master.

Files Patch % Lines
keras/backend/torch/core.py 76.92% 4 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19302      +/-   ##
==========================================
- Coverage   80.14%   75.85%   -4.29%     
==========================================
  Files         341      367      +26     
  Lines       36163    40433    +4270     
  Branches     7116     7864     +748     
==========================================
+ Hits        28982    30671    +1689     
- Misses       5578     8065    +2487     
- Partials     1603     1697      +94     
Flag Coverage Δ
keras 75.71% <76.92%> (-4.29%) ⬇️
keras-jax 60.12% <30.76%> (-2.93%) ⬇️
keras-numpy 54.39% <30.76%> (-2.69%) ⬇️
keras-tensorflow 61.29% <30.76%> (-3.37%) ⬇️
keras-torch 60.41% <76.92%> (-3.46%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Added one-line summary of the class in line `454`
@gbaned gbaned added this to Assigned Reviewer in PR Queue via automation Mar 14, 2024
@fchollet fchollet added the keras-team-review-pending Pending review by a Keras team member. label Mar 18, 2024
@fchollet
Member

@james77777778 what do you think about this functionality and the new upstream argument?

@james77777778
Contributor

james77777778 commented Mar 21, 2024

@james77777778 what do you think about this functionality and the new upstream argument?

This functionality is crucial for the following QLoRA-like technique. The new upstream argument is a bit tricky but I think it is worth the effort for the backend-agnostic feature.

I have tested the code and it works well. The training results in torch are consistent with the other backends.
(using my fork: https://github.com/james77777778/keras/blob/gradient/benchmark.py)

@@ -661,10 +662,19 @@ def custom_gradient(f):
def log1pexp(x):
    e = ops.exp(x)

    def grad(upstream):
    def grad(*args, upstream = None):
Member

upstream=None

        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

    return ops.log(1 + e), grad

Note that the grad function that returns gradient computations
Member

To make this clearer, please provide two separate code examples, one for JAX/TF and one for torch.

Contributor Author

To make this clearer, please provide two separate code examples, one for JAX/TF and one for torch.

Added one JAX/TF-specific backend example, one PyTorch-specific example, and one for all three together (backend-invariant).
Can you check whether my edits to the PR are visible to you?
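
For reference, a rough sketch of what those backend-specific examples can look like (a paraphrase, not necessarily the exact docstring text that was merged); each usage snippet assumes the corresponding backend is active:

from keras import ops

@ops.custom_gradient
def log1pexp(x):
    e = ops.exp(x)

    def grad(*args, upstream=None):
        if upstream is None:
            # TensorFlow/JAX pass the upstream gradient positionally.
            (upstream,) = args
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

    return ops.log(1 + e), grad

# JAX / TensorFlow backends:
x = ops.convert_to_tensor(100.0)
y = log1pexp(x)

# PyTorch backend: the input must track gradients.
import torch

x = torch.tensor(100.0, requires_grad=True)
y = log1pexp(x)
y.backward()
print(x.grad)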

@@ -501,15 +501,20 @@ def test_is_tensor(self):
        self.assertFalse(ops.is_tensor([1, 2, 3]))

    @pytest.mark.skipif(
        backend.backend() not in ("tensorflow", "jax"),
        backend.backend() not in ("tensorflow", "jax", "pytorch"),
Member

It's "torch"

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass computation.
Member

Move this line to after """

Contributor Author

I've made the edits, although this part is not of much interest to the user. It just works as a wrapper that unifies the syntax across backends.

    @staticmethod
    def forward(ctx, forward_fn, *args, **kwargs):
        """
        Forward pass computation.
Member

Move this line to after """


class CustomGradientFunction(torch.autograd.Function):
    """
    CustomGradientFunction is a PyTorch autograd function enabling custom forward and backward passes for gradient computation.
Member

Shorten line (break it up into a few lines)

elif backend.backend() == "torch":
    import torch

    x = torch.tensor(100.0, requires_grad = True)  # x = ops.convert_to_tensor(100.0) is NOT supported Yet!
Contributor Author

Hello @james77777778 @fchollet, I just found out why the PyTorch unit tests for @custom_gradient were failing all along.

The problem is that if one wants to compute or work with gradients of a variable in torch, defining it with a plain torch.tensor(100.0) doesn't work. It needs the additional argument, torch.tensor(100.0, requires_grad=True), to have the .grad attribute available later on.

Keras' ops.convert_to_tensor() would likewise need to set this requires_grad argument when defining a tensor in PyTorch in order to calculate the gradient.

This is why I had to use PyTorch directly here to define the tensor with torch.tensor, rather than ops.convert_to_tensor, to make it work.

Member

You can instead do:

x = keras.Variable(100.0).value

This will work with all backends.
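
For example, the backend-specific snippet above could become (assuming log1pexp is the decorated example function discussed earlier):

import keras

# Backend-agnostic replacement for torch.tensor(100.0, requires_grad=True):
x = keras.Variable(100.0).value
y = log1pexp(x)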

@fchollet
Member

I'll merge the PR and fix remaining issues in post. Thank you for the contribution!

@fchollet fchollet merged commit ddba5d8 into keras-team:master Mar 22, 2024
5 of 6 checks passed
PR Queue automation moved this from Assigned Reviewer to Merged Mar 22, 2024
Labels
keras-team-review-pending Pending review by a Keras team member. size:M
Projects
PR Queue
Merged

5 participants