
Having trouble moving a module from one GPU to another. #1191

Closed

Biotot opened this issue Dec 19, 2023 · Discussed in #1190 · 22 comments
Biotot commented Dec 19, 2023

Discussed in #1190

Originally posted by Biotot December 19, 2023
I've been banging my head against this for a couple of days and I'm still coming up empty.

I have multiple modules and multiple GPUs, but this sequence continues to fail. I've narrowed it down to a problem with the model itself: if I load it fresh from a file each loop, the error goes away.

(Pseudocode)
ModuleA.to(cuda:0)
TrainLoop
ModuleA.to(cpu)

ModuleA.to(cuda:1)
TrainLoop
Exception: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1

The error consistently occurs in the loss's output.backward() call, if that helps.

This error doesn't happen if I load the module from a file each loop. The input data is not the issue; the model isn't correctly switching devices. I've tried many different combinations of code, including moving directly from cuda:0 to cuda:1, without luck.
I'm not sure what is going wrong. I've been porting my code over from PyTorch and this is the hurdle I can't get past. Any help would be appreciated.

Running on TorchSharp-cuda-windows 0.101.4

NiklasGustafsson (Contributor) commented Dec 19, 2023

Is there a more detailed traceback?

Given that it's in backward() that it happens, maybe the gradients aren't being moved? What if you first zero the gradients using zero_grad(), and then move the model?

Do you know if this works in PyTorch?

Also, and this is just out of curiosity -- why move back and forth between GPUs?
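
A minimal sketch of the zero-then-move suggestion above, using only calls already shown in this thread; moduleA and the device string are placeholders:

    // Editor's sketch: clear any gradients accumulated on the old device first,
    // then move the parameters to the target device.
    moduleA.zero_grad();
    moduleA.to(torch.device("cuda:1"));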

Biotot (Author) commented Dec 19, 2023

Sorry for the massive block.


System.Runtime.InteropServices.ExternalException
  HResult=0x80004005
  Message=Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!
Exception raised from compute_types at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\TensorIterator.cpp:484 (most recent call first):
00007FFB67B6AF4200007FFB67B6AEE0 c10.dll!c10::Error::Error [<unknown file> @ <unknown line number>]
00007FFB67B6AACA00007FFB67B6AA70 c10.dll!c10::detail::torchCheckFail [<unknown file> @ <unknown line number>]
00007FFB4C48704A00007FFB4C486740 torch_cpu.DLL!at::TensorIteratorBase::compute_types [<unknown file> @ <unknown line number>]
00007FFB4C483C4C00007FFB4C483BD0 torch_cpu.DLL!at::TensorIteratorBase::build [<unknown file> @ <unknown line number>]
00007FFB4C48429C00007FFB4C484190 torch_cpu.DLL!at::TensorIteratorBase::build_borrowing_binary_op [<unknown file> @ <unknown line number>]
00007FFB4C6A864800007FFB4C6A8610 torch_cpu.DLL!at::meta::structured_add_Tensor::meta [<unknown file> @ <unknown line number>]
00007FFAD1E35BFA <unknown symbol address> torch_cuda.dll!<unknown symbol> [<unknown file> @ <unknown line number>]
00007FFB4D0231EC00007FFB4D023170 torch_cpu.DLL!at::_ops::add__Tensor::redispatch [<unknown file> @ <unknown line number>]
00007FFB4ED285CE00007FFB4E8DE6E0 torch_cpu.DLL!torch::jit::tracer::TracingState::leaveFrame [<unknown file> @ <unknown line number>]
00007FFB4D0231EC00007FFB4D023170 torch_cpu.DLL!at::_ops::add__Tensor::redispatch [<unknown file> @ <unknown line number>]
00007FFB4E6872E800007FFB4E3469A0 torch_cpu.DLL!torch::autograd::CopySlices::~CopySlices [<unknown file> @ <unknown line number>]
00007FFB4CF8EF2800007FFB4CF8EE20 torch_cpu.DLL!at::_ops::add__Tensor::call [<unknown file> @ <unknown line number>]
00007FFB4C3CEC0100007FFB4C3CEBD0 torch_cpu.DLL!at::Tensor::operator+= [<unknown file> @ <unknown line number>]
00007FFB4EF9A20C00007FFB4EF98880 torch_cpu.DLL!torch::autograd::Node::name [<unknown file> @ <unknown line number>]
00007FFB4EF9B40200007FFB4EF9B300 torch_cpu.DLL!torch::autograd::AccumulateGrad::apply [<unknown file> @ <unknown line number>]
00007FFB4E27B83D00007FFB4E27B550 torch_cpu.DLL!torch::autograd::Node::operator() [<unknown file> @ <unknown line number>]
00007FFB4EF8FEEE00007FFB4EF8F830 torch_cpu.DLL!torch::autograd::Engine::add_thread_pool_task [<unknown file> @ <unknown line number>]
00007FFB4EF90FBF00007FFB4EF909D0 torch_cpu.DLL!torch::autograd::Engine::evaluate_function [<unknown file> @ <unknown line number>]
00007FFB4EF9686800007FFB4EF963D0 torch_cpu.DLL!torch::autograd::Engine::thread_main [<unknown file> @ <unknown line number>]
00007FFB4EF9633B00007FFB4EF96260 torch_cpu.DLL!torch::autograd::Engine::thread_init [<unknown file> @ <unknown line number>]
00007FFB4EF8B88500007FFB4EF8ADF0 torch_cpu.DLL!torch::autograd::Engine::get_base_engine [<unknown file> @ <unknown line number>]
00007FFC34491BB200007FFC34491B20 ucrtbase.dll!configthreadlocale [<unknown file> @ <unknown line number>]
00007FFC3458734400007FFC34587330 KERNEL32.DLL!BaseThreadInitThunk [<unknown file> @ <unknown line number>]
00007FFC365626B100007FFC36562690 ntdll.dll!RtlUserThreadStart [<unknown file> @ <unknown line number>]

  Source=TorchSharp
  StackTrace:
   at TorchSharp.torch.CheckForErrors()
   at TorchSharp.torch.Tensor.backward()
   at MultiBrain.Processing.Torcher.fuckaround(List`1& tModels) in C:\Users\Megabox\source\repos\Biotot\MultiBrain\MultiBrain\Processing\Torcher.cs:line 79
   at MultiBrain.Program.TrainDickaround() in C:\Users\Megabox\source\repos\Biotot\MultiBrain\MultiBrain\Program.cs:line 47
   at MultiBrain.Program.Main(String[] args) in C:\Users\Megabox\source\repos\Biotot\MultiBrain\MultiBrain\Program.cs:line 18

That's with trying to zero_grad the model.
I know switching GPUs is an edge case, but I'm iterating over a list of models and there's no guarantee that the same GPU will get the same model on the next cycle.

I had things running in PyTorch, but for some external reasons I wanted to switch over to C#. I had most of my code ported before actually testing on the multi-GPU PC. It runs great on one GPU.

NiklasGustafsson (Contributor) commented Dec 19, 2023

Okay... The AccumulateGrad appearance in the trace is suggestive (but just suggestive, no more than that :-)).

Stopping in the debugger just before calling backward, can you iterate through the parameters of the module and check the 'device' attribute of all the parameters and their gradients? That should tell us whether the theory is correct.

Unfortunately, I don't have two GPUs that are usable for training, so I can't help debug this myself.

Biotot (Author) commented Dec 19, 2023

I've tried just about everything I could think of.
In my PyTorch version I would go through and update each parameter of the optimizer to get it to switch over, but no luck here. So I tried looping over the parameters, first checking that the devices matched, and then just manually setting them all to the right device.
Right now my test code is the basic training loop from the examples: the first iteration on one device, the second on another.

aModel is the model that's training and aOutput is the loss output about to have .backward() called on it. I didn't expand them all for the screenshot, but each _module and _internal_params were tagged with the expected cuda:1.
I noticed that _deviceType doesn't seem to change, but it isn't complaining about anything being on the CPU, since the error mentions cuda:0 and cuda:1.

[Screenshot: debugger view of aModel and aOutput showing the parameter devices]

NiklasGustafsson (Contributor) commented Dec 19, 2023

Right, but could you do a foreach (var (name, p) in m.named_parameters()) and print out the device_type attribute of every parameter and of every gradient that isn't null? Unless it's the input data that mismatches, there has to be some tensor in the module that's not in the right place.
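
Spelled out, the diagnostic being requested looks roughly like this; an editor's sketch in which m stands for the module and output for the loss tensor, with p.grad() used as elsewhere in this thread:

    // Just before backward(), print where every parameter and its gradient
    // (when one exists) currently lives.
    foreach (var (name, p) in m.named_parameters())
    {
        var g = p.grad();
        if (g is null)
            Console.WriteLine($"{name}: param on {p.device}, no grad");
        else
            Console.WriteLine($"{name}: param on {p.device}, grad on {g.device}");
    }
    output.backward();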

Biotot (Author) commented Dec 19, 2023

Oh hey!! Thank you!
You were right: p.grad() was on the old device.
zero_grad doesn't clear it, and it doesn't seem like I can move the gradient to the new device.

NiklasGustafsson (Contributor) commented:
@shaltielshmid -- could this have anything to do with the recent no_grad() changes to Module.to()?

@Biotot -- does this work better on a version < 0.101.3?

Biotot (Author) commented Dec 19, 2023

Same issue in 0.101.0

Do you have an idea for a work-around that I could put together for now?

NiklasGustafsson (Contributor) commented:
Not at the moment -- @shaltielshmid recently worked on fixing some stuff in Module.to() and he is more familiar with the new logic, so it will be good to wait for him to reply. He's in a different time zone, though, so it may be tomorrow before he has anything to say.
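
Editor's note on the workaround question: the original post already observes that loading the model fresh from a file each loop avoids the error, since no gradient tensors survive the move. A rough sketch of that interim approach, assuming TorchSharp's Module.save/Module.load round-trip and hypothetical names (model, weights.bin):

    // Interim workaround sketch: persist the weights, then rebuild the module
    // when the next GPU is assigned, so no stale gradients carry over.
    model.save("weights.bin");                               // hypothetical path

    var fresh = Sequential(Linear(1000, 1), Softsign());     // must match the saved architecture
    fresh.load("weights.bin");
    fresh.to(torch.device("cuda:1"));
    var optimizer = torch.optim.AdamW(fresh.parameters());   // recreate the optimizer as well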

NiklasGustafsson (Contributor) commented:
BTW, tomorrow is my last day working this calendar year, so there's not likely to be another release before 2024.

shaltielshmid (Contributor) commented:
If you call module.zero_grad() before moving devices, does that solve the problem?

Biotot (Author) commented Dec 19, 2023

No luck. Even calling it everywhere. It's zeroed (I'm assuming) but still on the original device.

shaltielshmid (Contributor) commented Dec 19, 2023

Can you provide a minimal code sample that replicates the issue?

I thought I had replicated it on my end, but the zero_grad() fix worked for me.

Biotot (Author) commented Dec 19, 2023

I added a few extra zero_grads for fun.
I'm running 0.101.4 from NuGet.

Excuse some lazy naming; I'm copy-pasting from the tutorials and my own code.

        public void playaround2()
        {
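            // Assumed usings (not shown in the snippet): using TorchSharp;
            // using static TorchSharp.torch; using static TorchSharp.torch.nn;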
            var aModule = Sequential(
                Linear(1000, 1), Softsign());
            {
                Device aDevice = torch.device("cuda:0");
                var dataBatch = rand(32, 1000).to(aDevice);
                var resultBatch = rand(32, 1).to(aDevice);
                aModule.to(aDevice);
                foreach (var (name, p) in aModule.named_parameters())
                {
                    Console.WriteLine($"{name} {p.device} {p.grad()}");
                }
                var aMseLoss = nn.MSELoss();
                var optimizer = torch.optim.AdamW(aModule.parameters());
                // Compute the loss
                using var output = aMseLoss.forward(aModule.forward(dataBatch), resultBatch);

                // Clear the gradients before doing the back-propagation
                aModule.zero_grad();

                // Do back-propagation, which computes all the gradients.
                output.backward();

                optimizer.step();
                aModule.zero_grad();
                aModule.to(torch.CPU);
            }
            {
                aModule.zero_grad();
                Device aDevice = torch.device("cuda:1");
                var dataBatch = rand(32, 1000).to(aDevice);
                var resultBatch = rand(32, 1).to(aDevice);
                aModule.to(aDevice);
                aModule.zero_grad();
                foreach (var (name, p) in aModule.named_parameters())
                {
                    Console.WriteLine($"{name} {p.device} {p.grad()}");
                }
                var aMseLoss = nn.MSELoss();
                var optimizer = torch.optim.AdamW(aModule.parameters());
                // Compute the loss
                using var output = aMseLoss.forward(aModule.forward(dataBatch), resultBatch);

                // Clear the gradients before doing the back-propagation
                aModule.zero_grad();

                // Do back-propagation, which computes all the gradients.
                output.backward();

                optimizer.step();
            }
        }

System.Runtime.InteropServices.ExternalException: 'Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

The error triggers on the second output.backward().

Output from the added prints. The first set doesn't show gradients yet:

0.weight cuda:0
0.bias cuda:0
0.weight cuda:1 [1x1000], type = Float32, device = cuda:0
0.bias cuda:1 [1], type = Float32, device = cuda:0

shaltielshmid (Contributor) commented:
Okay, I see two things wrong here.

1. The gradients of a parameter should be copied during a move, as in PyTorch. I should have a fix ready for that in a few minutes.

2. When you call module.zero_grad() on the Sequential, it doesn't seem to actually zero the gradients. Working on that now.
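
For reference, the first point can be checked with a small loop after any .to() call. An editor's sketch that compares device strings, since the thread's own prints show Device rendering as e.g. cuda:1:

    // After aModule.to(newDevice), every gradient that exists should report the
    // same device as its parameter once the gradient-moving fix is in place.
    foreach (var (name, p) in aModule.named_parameters())
    {
        var g = p.grad();
        if (g is not null && g.device.ToString() != p.device.ToString())
            Console.WriteLine($"MISMATCH {name}: param on {p.device}, grad on {g.device}");
    }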

NiklasGustafsson (Contributor) commented:
Awesome to have you involved with TorchSharp, @shaltielshmid!

shaltielshmid (Contributor) commented:
@NiklasGustafsson awesome to be involved!

Question for you:
In PyTorch, the default behavior of zero_grad is actually to set the gradients to None, as opposed to zeroing them. Should we add this behavior as well?
https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L2438

Right now, with the fix of moving the gradients, the bug no longer occurs, but it raises that question.
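
To make the distinction concrete, an editor's sketch reusing aModule from the repro above: with set-to-none semantics, p.grad() comes back null after zero_grad(); with zero-in-place semantics, it stays a tensor of zeros on whatever device it already occupied.

    aModule.zero_grad();
    foreach (var (name, p) in aModule.named_parameters())
    {
        var g = p.grad();
        Console.WriteLine(g is null
            ? $"{name}: grad is null (set-to-none)"
            : $"{name}: grad is a zero tensor on {g.device} (zero-in-place)");
    }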

shaltielshmid (Contributor) commented:
Update: It seems that libtorch's default behavior is to set the gradients to None as well: https://github.com/pytorch/pytorch/blob/main/torch/csrc/api/src/nn/module.cpp#L253

Updating TorchSharp's behavior accordingly.

Biotot (Author) commented Dec 20, 2023

You guys are fantastic. Thanks for the quick turnaround.

NiklasGustafsson (Contributor) commented:
If everything goes well, there should be a release today with this fix in it.

shaltielshmid (Contributor) commented:
@Biotot Is the problem solved in version 0.101.5?

Biotot (Author) commented Dec 21, 2023

It's working fantastic now. Thanks again.

Biotot closed this as completed Dec 21, 2023