Custom Gradients #19302
Conversation
- Added support for the `@custom_gradient` decorator for PyTorch users.
- Added documentation for `ops.custom_gradient` and edited the `log1pexp(x)` example to demonstrate the new syntax.
- Fixed a syntax error in the example.
- Updated `core_test.py` with a PyTorch test.
- Corrected the syntax in the test case.
- Corrected the docs.
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Thanks for the PR. Very nice work, I would not have anticipated it to be feasible like this!
```python
raise NotImplementedError(
    "`custom_gradient` is not supported with torch backend"
)
class custom_gradient:
```
To keep the API consistent across backends, we may want to make this a function, `def custom_gradient(fun)`, that instantiates `CustomGradientFunction` and calls it later when the function is called?
@fchollet Yes. The `@custom_gradient` in `torch/core.py` acts as a decorator, or `custom_gradient(fun)` as you said. It takes the function `fun` and returns a `torch.autograd.Function` instance (that is, a function whose gradient formula is now known to the PyTorch framework). Additionally, the `CustomGradientFunction` class has `forward()` and `backward()` methods with our required forward and backward definitions, the way custom functions are usually defined in PyTorch: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
The effort was to unify the syntax and the method of creating new functions with custom gradients across all backends. Since PyTorch doesn't support the decorator method of defining new gradients, I had to write a custom one for that.
For testing purposes, the second Colab notebook uses my fork of Keras (because it has custom gradients implemented) and tests each framework one by one. Just make sure to restart the Colab instance before testing another framework as the backend.
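The wrapper described above can be sketched roughly as follows. This is an illustrative reconstruction, not the PR's exact code; the names `custom_gradient`, `CustomGradientFunction`, and `wrapper` and the exact call sequence are assumptions based on the discussion:

```python
import torch


def custom_gradient(fun):
    """Sketch of a decorator registering a custom gradient with
    PyTorch via torch.autograd.Function (hypothetical reconstruction)."""

    class CustomGradientFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            # fun returns (output, grad_fn); stash grad_fn for backward.
            output, grad_fn = fun(*args)
            ctx.grad_fn = grad_fn
            return output

        @staticmethod
        def backward(ctx, grad_output):
            # Call the user-supplied gradient with the upstream gradient.
            return ctx.grad_fn(upstream=grad_output)

    def wrapper(*args):
        return CustomGradientFunction.apply(*args)

    return wrapper


@custom_gradient
def log1pexp(x):
    e = torch.exp(x)

    def grad(*args, upstream=None):
        return upstream * (1.0 - 1.0 / (1.0 + e))

    return torch.log(1 + e), grad


x = torch.tensor(0.0, requires_grad=True)
y = log1pexp(x)
y.backward()
print(x.grad)  # d/dx log(1 + e^x) at x=0 is sigmoid(0) = 0.5
```

The key trick is that the user's `grad` closure is stored on `ctx` in `forward` and invoked in `backward`, so PyTorch sees a single autograd node with the custom gradient formula.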
keras/backend/torch/core.py (outdated)

```python
class CustomGradientFunction(torch.autograd.Function):
    """
    Autograd function for custom gradients.
```
One-line summary should be on the first line (after `"""`).
Codecov Report — Attention: Patch coverage is …

```
@@ Coverage Diff @@
##           master   #19302      +/-   ##
==========================================
- Coverage   80.14%   75.85%   -4.29%
==========================================
  Files         341      367      +26
  Lines       36163    40433    +4270
  Branches     7116     7864     +748
==========================================
+ Hits        28982    30671    +1689
- Misses       5578     8065    +2487
- Partials     1603     1697      +94
```

Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
Added a one-line summary of the class at line `454`.
@james77777778 what do you think about this functionality and the new …

This functionality is crucial for the following QLoRA-like technique. The new … I have tested the code and it works well. The training results in torch are consistent with the other backends.
keras/ops/core.py (outdated)

```diff
@@ -661,10 +662,19 @@ def custom_gradient(f):
 def log1pexp(x):
     e = ops.exp(x)

-    def grad(upstream):
+    def grad(*args, upstream=None):
```
`upstream=None`
```python
        return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

    return ops.log(1 + e), grad
```
Note that the `grad` function that returns gradient computations …
To make this clearer, please provide two separate code examples, one for JAX/TF and one for torch.
Added one JAX/TF-specific backend example, one PyTorch example, and one for all three backends together (backend-invariant).
Can you check whether my edits to the PR are visible to you?
keras/ops/core_test.py (outdated)

```diff
@@ -501,15 +501,20 @@ def test_is_tensor(self):
         self.assertFalse(ops.is_tensor([1, 2, 3]))

     @pytest.mark.skipif(
-        backend.backend() not in ("tensorflow", "jax"),
+        backend.backend() not in ("tensorflow", "jax", "pytorch"),
```
It's "torch"
keras/backend/torch/core.py (outdated)

```python
@staticmethod
def backward(ctx, grad_output):
    """
    Backward pass computation.
```
Move this line to after `"""`.
I've made the edits, although this part is not of much interest to the user; it just works as a wrapper to unify the syntax across backends.
keras/backend/torch/core.py (outdated)

```python
@staticmethod
def forward(ctx, forward_fn, *args, **kwargs):
    """
    Forward pass computation.
```
Move this line to after `"""`.
keras/backend/torch/core.py (outdated)

```python
class CustomGradientFunction(torch.autograd.Function):
    """
    CustomGradientFunction is a PyTorch autograd function enabling custom forward and backward passes for gradient computation.
```
Shorten line (break it up into a few lines)
```python
elif backend.backend() == "torch":
    import torch

    x = torch.tensor(100.0, requires_grad=True)  # x = ops.convert_to_tensor(100.0) is NOT supported yet!
```
Hello @james77777778 @fchollet, I just found out why the PyTorch unit tests for `@custom_gradient` were failing all along.

The problem is that if one wants to calculate or work with gradients of a variable in torch, the regular definition `torch.tensor(100.0)` doesn't work. The tensor needs the additional `requires_grad=True` argument, i.e. `torch.tensor(100.0, requires_grad=True)`, in order to have a `.grad` attribute later on. Keras' `ops.convert_to_tensor()` would likewise need a `requires_grad` argument to define a PyTorch tensor whose gradient can be calculated. This is why I had to use PyTorch directly here to define the tensor with `torch.tensor` rather than use `ops.convert_to_tensor`.
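The `requires_grad` point can be seen in a few lines of plain PyTorch, independent of Keras (a minimal demonstration, not part of the PR):

```python
import torch

# Without requires_grad, autograd does not track this tensor,
# so it can never receive a gradient.
x = torch.tensor(100.0)
assert not x.requires_grad
assert x.grad is None

# With requires_grad=True, the tensor accumulates a .grad
# once backward() is called on something computed from it.
y = torch.tensor(100.0, requires_grad=True)
out = y * 3.0
out.backward()
print(y.grad)  # d(3y)/dy = 3
```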
You can instead do:

```python
x = keras.Variable(100.0).value
```

This will work with all backends.
I'll merge the PR and fix remaining issues in post. Thank you for the contribution!
With a little bit of tweaking of the syntax, it's now possible to have custom gradients for all three backends: JAX, PyTorch, and TensorFlow.
Sorry, though I added the unit tests, I don't know how to run them.
Example of syntax:
Here's the rough notebook testing it:
I believe that adding custom functions together with their custom gradient definitions would enable users to implement complex layers and operations even when they aren't available in other frameworks. It could also help them work with custom data structures such as complex numbers, non-metric spaces, or probabilistic spaces.
Example Syntax:
The only difference is that the `grad` function accepts two arguments (one keyword) in the case of PyTorch and one argument in the case of TensorFlow or JAX. So the additional … doesn't hurt much.
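The two calling conventions mentioned above can be illustrated with a plain-Python sketch (hypothetical, just to show how `def grad(*args, upstream=None)` accommodates both; it assumes TF/JAX pass the upstream gradient positionally while the torch wrapper passes it by keyword):

```python
# A stand-in gradient function using the backend-invariant signature.
# The body (upstream * 0.5) is arbitrary, chosen only for the demo.
def grad(*args, upstream=None):
    if upstream is None:
        # TF/JAX style: upstream gradient arrives as a positional arg.
        (upstream,) = args
    return upstream * 0.5

print(grad(2.0))           # positional call (TF/JAX style)
print(grad(upstream=2.0))  # keyword call (torch style)
```

Both calls route to the same code path, which is what lets one example work unchanged across backends.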
Thank You.