mse_loss
#174
Hello Team! I am Kaeun, one of the new contributors at Thunder, and I am having a lot of fun with these tasks. I am wondering if it would be OK for me to handle this issue. It seems like I have to update

Would it be OK if I take care of this? I appreciate your help and support!
Absolutely! Anyone is welcome to submit a PR to address this issue. Yes, this PR would start by updating the relevant operator definitions.
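As a sketch of the kind of definition involved: `mse_loss` decomposes into elementwise and reduction operations thunder already supports. A minimal pure-PyTorch sketch of that decomposition (the function name and structure here are illustrative, not the actual thunder implementation):

```python
import torch

def mse_loss_reference(a: torch.Tensor, target: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
    # Elementwise squared error is the core of mse_loss.
    loss = (a - target) ** 2
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"Invalid reduction: {reduction}")
```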
Hi @mruberry! I am almost done with the implementation regarding `mse_loss`. Here is my test script:
```python
import torch
import thunder

reduction = "none"

def mse(input, target):
    return torch.nn.functional.mse_loss(input, target, reduction=reduction)

input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)

# Run the thunder-jitted function and its grad transform.
cfn = thunder.jit(mse)
actual_loss = cfn(input, target)
grad_jfn = thunder.core.transforms.grad(cfn)
actual_grad, = grad_jfn(input, target)

# Compare against eager PyTorch.
expected_loss = torch.nn.functional.mse_loss(input, target, reduction=reduction)
go = torch.ones_like(expected_loss)
expected_grad, = torch.autograd.grad(
    torch.nn.functional.mse_loss(input, target, reduction=reduction), input, go
)

print("Max error in loss:", (actual_loss - expected_loss).abs().max().item())
print("Max error in grad:", (actual_grad - expected_grad).abs().max().item())
```

I am not sure if this is the proper way to calculate the gradient. Would you be able to check this for me? Also, just to check, does the above script call the decomposition of `mse_loss`?
I hope you understand that I am still in the process of learning how things work in thunder, and I would appreciate your help :)
Happy to help!
You should be able to do the following, which might be simpler:
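For example, a sketch using ordinary PyTorch autograd on the jitted function (this is an assumption about what the simpler pattern looks like, not the exact snippet):

```python
import torch
import thunder

def mse(input, target):
    return torch.nn.functional.mse_loss(input, target, reduction="none")

cfn = thunder.jit(mse)

input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)

# thunder-jitted callables participate in PyTorch autograd, so backward()
# can drive the gradient computation instead of the explicit grad transform.
loss = cfn(input, target)
loss.sum().backward()
print(input.grad)
```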
@IvanYashchuk can correct me if I'm mistaken about this.

Your question about the decomposition is great. I'm not sure, because it depends on the implementation details.

Note for @jjsjann123 @t-vi @IvanYashchuk: we should probably reintroduce a developer option for the jit to force executors like torch to execute only primitives. This would be straightforward to do. One way to test the decomposition locally is to not register a direct binding of `mse_loss` with the torch executor and then inspect what actually ran; see the sketch below.
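A rough sketch of that local check (assuming no direct torch-executor binding for `mse_loss` has been registered yet):

```python
import torch
import thunder

def fn(a, b):
    return torch.nn.functional.mse_loss(a, b)

cfn = thunder.jit(fn)
cfn(torch.randn(3, 5), torch.randn(3, 5))

# The final trace records what actually executed. Without a direct binding,
# mse_loss should appear as its decomposition (subtraction, multiplication,
# a mean reduction, ...) rather than as a single mse_loss call.
print(thunder.last_traces(cfn)[-1])
```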
This is a great question! First, if you have a decomposition for `mse_loss` written in terms of operations that already have grad formulas, the gradient can come from differentiating the decomposition. Second, the current state of grad formulas in thunder can be confusing! @IvanYashchuk can help direct you here. It would be helpful if you submitted a draft PR, so we can look at the code in more detail before making a recommendation.
Let's take a look at a draft PR, maybe even one where grad support isn't considered to start, and then we can discuss!
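As a side note on testing grad formulas, PyTorch's own finite-difference checker is a useful baseline (a generic sanity check, not thunder-specific):

```python
import torch

# gradcheck compares analytical gradients against finite differences;
# double precision is required for the numerical comparison to be tight.
a = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
t = torch.randn(3, 5, dtype=torch.double)
torch.autograd.gradcheck(
    lambda x: torch.nn.functional.mse_loss(x, t, reduction="mean"), (a,)
)
```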
The easiest way to check is to inspect the execution trace and verify which symbols actually appear in it. Operations are registered with the torch executor here: `lightning-thunder/thunder/executors/torchex.py`, line 1208 (at 649c3d7), and here: `lightning-thunder/thunder/executors/torchex.py`, line 1548 (at 649c3d7).
The torch executor has some helper functions that make this straightforward.
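For a sense of the shape of that registration, a sketch based on the helper pattern discussed above (the exact helper names and checker are assumptions drawn from the surrounding `torchex.py` code, and this is not runnable outside that file):

```python
# Sketch only: _register_torch_operation, _register_implementation, and
# _always_executable are private helpers inside thunder/executors/torchex.py,
# and ltorch.mse_loss is the thunder-side symbol this binding would map to.
mse_loss = _register_torch_operation("mse_loss", module=torch.nn.functional)
_register_implementation(ltorch.mse_loss, mse_loss, checker=_always_executable)
```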
These are great questions, and I'm glad you're asking them. I hope these responses are helpful. Let us know if you have additional questions!
Hi @mruberry! Thanks so much for the detailed explanation! 😄 It definitely helped me to further understand how trace execution can help my testing during the implementation and how gradients are calculated in general. Currently, I am facing a few issues with the implementation. I do have another question regarding your comment:
Does this mean that I should not have something like:

```python
mse_loss = _register_torch_operation("mse_loss", module=torch.nn.functional)
```

in the torch executor? And that after running:

```python
actual_loss = cfn(input, target)
actual_loss.sum().backward()
thunder_grad = input.grad
traces = thunder.last_traces(cfn)
```

I should be able to see the decomposition of `mse_loss` in the traces?
This sounds like, once I add the direct binding of `mse_loss` to the torch executor, it will be executed directly rather than decomposed.

I am learning a lot from your comments and having a lot of fun! Thank you so much :)
Awesome! I'm glad that was helpful.
I look forward to reviewing the PR in more detail.
Interesting! Let's discuss more with @IvanYashchuk on the PR itself.
Yes, I think that's correct. Without the direct binding of `mse_loss`, you should see its decomposition in the traces.
If you both register the operation with the torch executor and map it to the corresponding thunder symbol, then the torch executor will execute `mse_loss` directly instead of its decomposition.
You're very welcome; I'm glad you're having fun.
🚀 Feature

Implement `torch.nn.functional.mse_loss`.
Motivation
Used by the NeMo text-to-image model.