mse_loss #174

Closed
Tracked by #179
tfogal opened this issue Apr 12, 2024 · 6 comments · Fixed by #218
Labels
enhancement (New feature or request) · MegatronImagen (Needed to support NeMo's MegatronImagen model, text-to-image generation) · nemo (Issues needed to support NVIDIA NeMo models)

Comments

@tfogal
Collaborator

tfogal commented Apr 12, 2024

🚀 Feature

Implement torch.nn.functional.mse_loss

Motivation

NeMo text-to-image model.
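For reference, the operation to match is PyTorch's elementwise squared error with an optional reduction; a quick usage sketch of the expected behavior:

import torch
import torch.nn.functional as F

input = torch.randn(3, 5)
target = torch.randn(3, 5)

# reduction="mean" (the default) averages the squared errors;
# "sum" adds them; "none" returns the elementwise tensor
loss = F.mse_loss(input, target)
assert torch.allclose(loss, ((input - target) ** 2).mean())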

tfogal added the enhancement, nemo, and MegatronImagen labels Apr 12, 2024
@k223kim
Contributor

k223kim commented Apr 15, 2024

Hello Team! I am Kaeun, one of the new contributors at Thunder who is having a lot of fun with these tasks. I am wondering if it would be ok for me to handle this issue. It seems like I have to update torch/__init__.py with something like:

@torchsymbol(torch.nn.functional.mse_loss, id="torch.nn.functional.mse_loss", is_method=False)
def mse_loss(a: TensorLike, b: TensorLike, /) -> TensorLike:

Would it be ok if I take care of this? Appreciate your help and support!
Best,
Kaeun

@mruberry
Collaborator

Hello Team! I am Kaeun, one of the new contributors at Thunder who is having a lot of fun with these tasks. I am wondering if it would be ok for me to handle this issue. It seems like I have to update torch/__init__.py with something like:

@torchsymbol(torch.nn.functional.mse_loss, id="torch.nn.functional.mse_loss", is_method=False)
def mse_loss(a: TensorLike, b: TensorLike, /) -> TensorLike:

Would it be ok if I take care of this? Appreciate your help and support! Best, Kaeun

Absolutely! Anyone is welcome to submit a PR to address this issue. Yes, this PR would start by updating torch/__init__.py. A few additional notes:

  • Make sure to capture the additional arguments to the function, too (see the sketch after these notes)
  • The decomposition of this function may be a little tricky
  • The torch executor can be updated to run the operation if it's called without having to execute the decomposition
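To sketch the first two notes (treat this as a starting point, not the final signature; the bare mean and sum helpers are assumed to exist in that file, and the deprecated size_average/reduce arguments need handling too):

@torchsymbol(torch.nn.functional.mse_loss, id="torch.nn.functional.mse_loss", is_method=False)
def mse_loss(a: TensorLike, b: TensorLike, /, reduction: str = "mean") -> TensorLike:
    # elementwise squared error, then the requested reduction
    loss = (a - b) ** 2
    if reduction == "mean":
        return mean(loss)
    if reduction == "sum":
        return sum(loss)
    return loss  # reduction == "none"

Because the body is built from operations thunder already knows, the decomposition comes along for free.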

@k223kim
Contributor

k223kim commented Apr 18, 2024

Hi @mruberry! I am almost done with the implementation regarding mse_loss. However, I have some questions and would appreciate your help! (this will help me a lot :) )

  1. I am using the following script to quickly confirm the forward and backward pass of mse_loss.
import torch
import thunder

reduction = "none"

def mse(input, target):
    output = torch.nn.functional.mse_loss(input, target, reduction=reduction)
    return output

input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)

cfn = thunder.jit(mse)
actual_loss = cfn(input, target)

grad_jfn = thunder.core.transforms.grad(cfn)
actual_grad, = grad_jfn(input, target)

expected_loss = torch.nn.functional.mse_loss(input, target, reduction = reduction)
go = torch.ones_like(expected_loss)
expected_grad, = torch.autograd.grad(torch.nn.functional.mse_loss(input, target, reduction=reduction), input, go)

print("Max error in loss:", (actual_loss - expected_loss).abs().max().item())
print("Max error in logits grad:", (actual_grad - expected_grad).abs().max().item())

I am not sure if this is the proper way to calculate the gradient. Would you be able to check this for me? Also, just to check, does the above script call mse_loss with decomposition (the one in torch/__init__.py, not the one in torchex.py)?

  2. Regarding the backward pass that will be added in torch/__init__.py, I am trying to mimic what has been done with _cross_entropy_grad. However, I am having a hard time understanding the main purpose of _cross_entropy_grad and cross_entropy_backward. cross_entropy_backward simply returns a TensorProxy of the gradient that has been computed through get_grad(fwd). Why would we want to separate the two?
    When taking a look at the implementation of log_softmax_backward, it explicitly calculates the derivative of log_softmax and does not do any get_grad or put_grad. Why is there a difference between cross_entropy's backward pass and log_softmax's backward pass?
    For the mse_loss, I am assuming I should be doing something like cross_entropy's implementation. Would that be a proper approach? (having mse_loss_backward and _mse_loss_grad)

  3. Referring to the third bullet point in your comment, I am also currently working on implementing mse_loss in torchex.py. How would I call mse_loss without having to execute the decomposition (as you mentioned)? I am wondering if there is a way to check if I have implemented the non-decomposed mse_loss in torchex.py properly.

I hope you understand that I am in the process of learning how things work in thunder and would appreciate your help :)

@mruberry
Collaborator

Hi @mruberry! I am almost done with the implementation regarding mse_loss. However, I have some questions and would appreciate your help! (this will help me a lot :) )

Happy to help!

  1. I am using the following script to quickly confirm the forward and backward pass of mse_loss.
import torch
import thunder

reduction = "none"

def mse(input, target):
    output = torch.nn.functional.mse_loss(input, target, reduction=reduction)
    return output

input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5)

cfn = thunder.jit(mse)
actual_loss = cfn(input, target)

grad_jfn = thunder.core.transforms.grad(cfn)
actual_grad, = grad_jfn(input, target)

expected_loss = torch.nn.functional.mse_loss(input, target, reduction = reduction)
go = torch.ones_like(expected_loss)
expected_grad, = torch.autograd.grad(torch.nn.functional.mse_loss(input, target, reduction=reduction), input, go)

print("Max error in loss:", (actual_loss - expected_loss).abs().max().item())
print("Max error in logits grad:", (actual_grad - expected_grad).abs().max().item())

I am not sure if this is the proper way to calculate the gradient. Would you be able to check this for me? Also, just to check, does the above script call mse_loss with decomposition (the one in torch/__init__.py, not the one in torchex.py)?

You should be able to do the following, which might be simpler:

from torch.testing import assert_close

# inp/target are the tensors, and fn/cfn the eager and jitted mse
# functions from your script above
actual_loss = cfn(inp, target)
actual_loss.sum().backward()
thunder_grad = inp.grad
inp.grad = None

expected_loss = fn(inp, target)
expected_loss.sum().backward()
pytorch_grad = inp.grad

assert_close(thunder_grad, pytorch_grad)

@IvanYashchuk can correct me if I'm mistaken about this. Take a look at the assert_close utility for comparing tensors.

Your question about the decomposition is great. I'm not sure, because it depends on the details of the mse_loss implementation and any updates to the torch executor. You can see how the program is actually being executed by printing the execution trace of the program, which will show whether torch.nn.functional.mse_loss is being called directly, or a decomposition is being called instead.
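For example, something like this prints the trace that actually ran (call cfn at least once first; thunder.last_traces is the same utility you'd use anywhere for this kind of inspection):

# the last trace in the list is the one that was executed
traces = thunder.last_traces(cfn)
print(traces[-1])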

Note for @jjsjann123 @t-vi @IvanYashchuk, we should probably reintroduce a developer option for the jit to force executors like torch to execute only primitives. This would be straightforward to do.

One way to test the decomposition locally is to not register a direct binding of mse_loss with the torch executor and check that the execution trace decomposes as expected. Then you can add the direct binding for the loss and verify it's working as expected, too. In the future this should be easier (once the above developer option is available).

  2. Regarding the backward pass that will be added in torch/__init__.py, I am trying to mimic what has been done with _cross_entropy_grad. However, I am having a hard time understanding the main purpose of _cross_entropy_grad and cross_entropy_backward. cross_entropy_backward simply returns a TensorProxy of the gradient that has been computed through get_grad(fwd). Why would we want to separate the two?
    When taking a look at the implementation of log_softmax_backward, it explicitly calculates the derivative of log_softmax and does not do any get_grad or put_grad. Why is there a difference between cross_entropy's backward pass and log_softmax's backward pass?

This is a great question! First, if you have a decomposition for mse_loss, and a grad formula is defined for every element of that decomposition, then you will also define an (implicit) grad formula for mse_loss, and shouldn't have to add a custom grad formula at all.
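For intuition: with reduction="mean" the loss is mean((a - b) ** 2), so the gradient that a sub/square/mean decomposition composes to is 2 * (a - b) / N with N = a.numel(). A quick plain-PyTorch check of that identity (illustrative only):

import torch

a = torch.randn(3, 5, requires_grad=True)
b = torch.randn(3, 5)

# mean((a - b) ** 2) differentiates to 2 * (a - b) / a.numel()
torch.nn.functional.mse_loss(a, b, reduction="mean").backward()
expected = 2 * (a.detach() - b) / a.numel()
torch.testing.assert_close(a.grad, expected)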

Second, the current state of grad formulas in thunder can be confusing! @IvanYashchuk can help direct you here. It would be helpful if you submitted a draft PR, so we can look at the code in more detail before making a recommendation.

For the mse_loss, I am assuming I should be doing something like cross_entropy's implementation. Would that be a proper approach? (having mse_loss_backward and _mse_loss_grad)

Let's take a look at a draft PR, maybe even one where grad support isn't considered to start, and then we can discuss!

  3. Referring to the third bullet point in your comment, I am also currently working on implementing mse_loss in torchex.py. How would I call mse_loss without having to execute the decomposition (as you mentioned)? I am wondering if there is a way to check if I have implemented the non-decomposed mse_loss in torchex.py properly.

The easiest way to check is to inspect the execution trace and verify that torch.nn.functional.mse_loss is called directly. An example of a "direct" lowering to torch is the dropout operation. See here:

dropout = _register_torch_operation("dropout", module=torch.nn.functional)

and here:

_register_implementation(ltorch.dropout, dropout, checker=_always_executable)

The torch executor has some helper functions that make this straightforward. The _register_torch_operation function tells thunder how to call torch operations (like dropout), and the _register_implementation function tells thunder that it can call PyTorch's dropout to execute thunder.torch.dropout.
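By analogy, the mse_loss registration might look something like this (a sketch mirroring the dropout lines above; the exact checker and argument handling may need adjusting):

mse_loss = _register_torch_operation("mse_loss", module=torch.nn.functional)
_register_implementation(ltorch.mse_loss, mse_loss, checker=_always_executable)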

I hope you understand that I am in the process of learning how things work in thunder and would appreciate your help :)

These are great questions, and it's great you're asking them. I hope these responses are helpful. Let us know if you have additional questions!

k223kim mentioned this issue Apr 18, 2024
@k223kim
Contributor

k223kim commented Apr 18, 2024

Hi @mruberry!

Thanks so much for the detailed explanation! 😄 It definitely helped me to further understand how trace execution can help my testing during the implementation and how gradients are calculated in general.
I have submitted a draft PR that at least passes tests/test_ops.py. Also, using the script that you provided above, it seems to calculate the gradient and forward pass correctly. (+ I have performed traces = thunder.last_traces(cfn) and confirmed that it is calling the decomposition of mse_loss) It'll be awesome if you can take a look so we can further discuss the implementation.

Currently, I am facing issues with tests/test_grad.py. Specifically, there is some discrepancy when running test_vjp_correctness_mse_loss_torch_cpu_float64, which I suspect is due to some implementation in torchex.py (strangely, the current implementation's grad output is different from torch.ops.aten.mse_loss_backward's output).

I do have another question regarding your comment:

One way to test the decomposition locally is to not register a direct binding of mse_loss with the torch executor and check that the execution trace decomposes as expected.

Does this mean that I should not have something like:

mse_loss = _register_torch_operation("mse_loss", module=torch.nn.functional)

in torchex.py? And that when I run:

actual_loss = cfn(input, target)
actual_loss.sum().backward()
thunder_grad = input.grad
traces = thunder.last_traces(cfn)

I should be able to see the decomposition of mse_loss?

Then you can add the direct binding for the loss and verify it's working as expected, too. In the future this should be easier (once the above developer option is available).

This sounds like, once I add _register_torch_operation for mse_loss, I should be able to see torch.nn.functional.mse_loss in traces = thunder.last_traces(cfn). However, I am only seeing the decomposed version of mse_loss. Did I misunderstand something? Please let me know!

I am learning a lot from your comments and having a lot of fun! Thank you so much :)

@mruberry
Collaborator

Hi @mruberry!

Thanks so much for the detailed explanation! 😄 It definitely helped me to further understand how trace execution can help my testing during the implementation and how gradients are calculated in general. I have submitted a draft PR that at least passes tests/test_ops.py.

Awesome! I'm glad that was helpful.

Also, using the script that you provided above, it seems to calculate the gradient and forward pass correctly. (+ I have performed traces = thunder.last_traces(cfn) and confirmed that it is calling the decomposition of mse_loss) It'll be awesome if you can take a look so we can further discuss the implementation.

I look forward to reviewing the PR in more detail.

Currently, I am facing issues with tests/test_grad.py. Specifically, there is some discrepancy when running test_vjp_correctness_mse_loss_torch_cpu_float64, which I suspect is due to some implementation in torchex.py (strangely, the current implementation's grad output is different from torch.ops.aten.mse_loss_backward's output).

Interesting! Let's discuss more with @IvanYashchuk on the PR itself.

I do have another question regarding your comment:

One way to test the decomposition locally is to not register a direct binding of mse_loss with the torch executor and check that the execution trace decomposes as expected.

Does this mean that I should not have something like:

mse_loss = _register_torch_operation("mse_loss", module=torch.nn.functional)

in torchex.py? And that when I run:

actual_loss = cfn(input, target)
actual_loss.sum().backward()
thunder_grad = input.grad
traces = thunder.last_traces(cfn)

I should be able to see the decomposition of mse_loss?

Yes, I think that's correct. Without the direct binding of mse_loss registered with the torch executor, thunder will have to execute it by decomposing it into other operations.

Then you can add the direct binding for the loss and verify it's working as expected, too. In the future this should be easier (once the above developer option is available).

This sounds like, once I add _register_torch_operation for mse_loss, I should be able to see torch.nn.functional.mse_loss in traces = thunder.last_traces(cfn). However, I am only seeing the decomposed version of mse_loss. Did I misunderstand something? Please let me know!

If you both register the operation with _register_torch_operation and then bind it with _register_implementation you should see the execution trace call torch.nn.functional.mse_loss directly. The decomposition will still appear below it, but it will be commented out.

I am learning a lot from your comments and having a lot of fun! Thank you so much :)

You're very welcome; I'm glad you're having fun.
