
[Include example code] mixed_precision="fp16" will break torch.save function. #866

Closed
BurguerJohn opened this issue Nov 17, 2022 · 27 comments · Fixed by #872
Labels: bug (Something isn't working)

Comments

@BurguerJohn

System Info

accelerate-0.14.0
Python 3.7.15
PyTorch 1.12.1+cu113

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

from accelerate import Accelerator
import torch
import torch.nn as nn

class ExampleModule(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3, 3, kernel_size=1)

model = ExampleModule()

# mixed_precision="fp16" will give an error on torch.save
# mixed_precision="no" will work with torch.save
accelerator = Accelerator(
        gradient_accumulation_steps=1,
        mixed_precision="fp16",
        log_with="tensorboard",
        logging_dir=".",
    )

# Always works
torch.save(model, "/model_original.model")

# prepare() will break torch.save on the model if mixed_precision="fp16"
model = accelerator.prepare(model)

# Error with mixed_precision="fp16"
torch.save(model, "/model_acc.model")
# Error as well with mixed_precision="fp16"
torch.save(accelerator.unwrap_model(model), "/model_unwrap.sd")

It returns this error if mixed_precision="fp16":

---------------------------------------------------------------------------
PicklingError                             Traceback (most recent call last)
<ipython-input-1-5ce45839c137> in <module>
     27 
     28 #Error
---> 29 torch.save(model,  "/model_acc.model")
     30 #Error as well
     31 torch.save(accelerator.unwrap_model(model),  "/model_unwrap.sd")

1 frames
/usr/local/lib/python3.7/dist-packages/torch/serialization.py in _save(obj, zip_file, pickle_module, pickle_protocol)
    587     pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    588     pickler.persistent_id = persistent_id
--> 589     pickler.dump(obj)
    590     data_value = data_buf.getvalue()
    591     zip_file.write_record('data.pkl', data_value, len(data_value))

PicklingError: Can't pickle <function _forward_unimplemented at 0x7fb39d0b0320>: it's not the same object as torch.nn.modules.module._forward_unimplemented

Expected behavior

torch.save should work even if Accelerator is set to fp16
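
For context, a common workaround when pickling a whole module fails is to save only the weights: a state_dict contains nothing but tensors, so it is unaffected by whatever gets patched onto forward. A minimal sketch, reusing the ExampleModule and accelerator from the reproduction above (the file name is illustrative):

# The state_dict holds only tensors, so torch.save never touches the
# patched forward method.
state = accelerator.unwrap_model(model).state_dict()
torch.save(state, "model_state.pt")

# To reload, build a fresh module and load the weights into it.
restored = ExampleModule()
restored.load_state_dict(torch.load("model_state.pt"))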
@sgugger
Collaborator

sgugger commented Nov 17, 2022

cc @muellerzr This might already be fixed by your recent work. Or it's what has broken it ;-)

@muellerzr muellerzr self-assigned this Nov 17, 2022
@BurguerJohn
Author

Wow, that was fast, thanks for the quick response lol
Will try to test it with this fix.

@muellerzr
Collaborator

muellerzr commented Nov 17, 2022

@BurguerJohn the solution is that you should only save via this:

(and install accelerate from github, as this fix was put in this morning!)

torch.save(accelerator.unwrap_model(model),  "model_unwrap.sd")

@sgugger we could probably make accelerator.save wrap around this no?

@BurguerJohn
Author

Is the solution on the main branch? I just installed accelerate-0.15.0.dev0 on Colab and it still gives the same error.

@muellerzr
Collaborator

muellerzr commented Nov 17, 2022

@BurguerJohn yes it is. I ran the following to test it:

from accelerate import Accelerator
import torch
import torch.nn as nn

class ExampleModule(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3, 3, kernel_size=1)

model = ExampleModule()

accelerator = Accelerator(
        gradient_accumulation_steps=1,
        mixed_precision="fp16",
        log_with="tensorboard",
        logging_dir=".",
    )

model = accelerator.prepare(model)
torch.save(accelerator.unwrap_model(model),  "model_unwrap.sd")

@sgugger
Collaborator

sgugger commented Nov 17, 2022

@muellerzr No we can't have accelerator.save add a magic unwrap since it will receive other things than models :-)

@muellerzr
Collaborator

Whelp, as Sylvain says, too much magic :) What I presented above would be the official "correct" answer to what you're wanting to do.

@BurguerJohn
Author

Could it be some limitation of Colab or Python 3.7? Still need to test it on my computer, but Colab seems to still have trouble with it:
https://colab.research.google.com/drive/1Y07ElQf1qlD3b5SxCLGilshc_EYFILFg?usp=sharing

@muellerzr
Collaborator

Thanks @BurguerJohn, will look into this as it seems to be py 3.7 specific!

@BurguerJohn
Author

Cool, thanks! Will try it out with other version later. Again, thanks for all the support.

@muellerzr muellerzr added the bug Something isn't working label Nov 17, 2022
@muellerzr
Collaborator

@BurguerJohn sadly I don't have an immediate solution for you. This whole issue stems from Jupyter specifically. Even calling it through the CLI on Jupyter will have this bug. No clue why, but I'll be working on a different solution to this soon.

@BurguerJohn
Author

@muellerzr Alright, no problem. Just letting you know that it seems this bug also happens without jupyter, on Windows with Python 3.9.
There also seem to be two imports in keymap.py that don't exist on Windows: termios and tty (tty calls termios?).
Removing those two imports made the code run on Windows, but torch.save still gives the same error.

@muellerzr
Collaborator

Can you open up a separate issue for that please? Because termios and tty are part of the stdlib for Python
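
For reference, termios and tty are Unix-only stdlib modules, so an unconditional import of either fails on Windows regardless of any name conflict. A sketch of the usual platform guard (hypothetical, not necessarily what keymap.py does; msvcrt is the Windows counterpart for console input):

import sys

if sys.platform == "win32":
  import msvcrt  # Windows-only console I/O
else:
  import termios  # Unix-only terminal control
  import tty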

@BurguerJohn
Author

I'm running more tests; it may be some name conflict with my project. Just tested calling termios in a clean project and it works. Will do more tests before opening another issue.

@BurguerJohn
Author

Can confirm that the error still happens on Windows (Python 3.9) even without jupyter.

Python: 3.9.13 (tags/v3.9.13:6de2ca5, May 17 2022, 16:36:42) [MSC v.1929 64 bit (AMD64)]
Accelerate: 0.15.0.dev0
PyTorch: 1.12.1+cu116
Traceback (most recent call last):
  File "C:\termios.py", line 27, in <module>
    torch.save(accelerator.unwrap_model(model),  "model_unwrap.sd")
  File "C:\Users\T\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "C:\Users\T\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\serialization.py", line 589, in _save
    pickler.dump(obj)
_pickle.PicklingError: Can't pickle <function _forward_unimplemented at 0x000001CF5A4EE160>: it's not the same object as torch.nn.modules.module._forward_unimplemented

@BurguerJohn
Author

I'm pretty sure this is the line causing the problem:
model.forward = torch.cuda.amp.autocast(dtype=torch.float16)(model.forward)
There is probably a better way to handle autocast, a wrapper model or a lambda function? Dunno, I don't have more time today to check this. Will probably remove the prepare line from my code and use autocast directly, which should be enough for my project.
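
To see why that line breaks pickling: the autocast decorator applies functools.wraps, which copies the original function's __module__ and __qualname__ onto the wrapper. pickle serializes functions by reimporting them from that recorded location, finds the real _forward_unimplemented there instead of the wrapper, and refuses. A minimal CPU-only sketch of the failure mode, where decorate is a hypothetical stand-in for torch.cuda.amp.autocast(...), not Accelerate's code:

import functools
import pickle
import torch.nn as nn

class ExampleModule(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3, 3, kernel_size=1)

def decorate(func):
  # functools.wraps copies __module__/__qualname__ from func onto the
  # wrapper, the same way the autocast decorator does.
  @functools.wraps(func)
  def wrapper(*args, **kwargs):
    return func(*args, **kwargs)
  return wrapper

model = ExampleModule()
# Patching forward as an *instance* attribute forces pickle to serialize
# the wrapper itself instead of relying on the class-level method.
model.forward = decorate(model.forward)

try:
  pickle.dumps(model)
except pickle.PicklingError as e:
  print(e)
  # Can't pickle <function _forward_unimplemented at 0x...>: it's not
  # the same object as torch.nn.modules.module._forward_unimplemented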

@muellerzr
Collaborator

muellerzr commented Nov 18, 2022

@BurguerJohn it's actually stemming from ConvertFp32 IIRC when I was looking (see the line at the bottom of that section you were looking at)

@BurguerJohn
Author

Yeah, but this line is enough to break torch.save
https://colab.research.google.com/drive/11fvrk1Jslw2VIRTF6h0pdGgJx5AkAIMv?usp=sharing

Unless unwrap should do something to revert this; I didn't have the time to read all the code.

@muellerzr
Collaborator

It should be with that PR Sylvain mentioned earlier (and it shows it works on Ubuntu-based systems that aren't running Jupyter). I'll be looking deeper into this tomorrow

@BurguerJohn
Author

No problem, thanks for all the help. I already managed to make my code work without the prepare line, so there is no need to rush.

@muellerzr
Collaborator

BTW, here's a very good Stack Overflow post explaining what's happening: https://stackoverflow.com/questions/27641859/pickling-decorated-callable-class-wrapper

@muellerzr
Collaborator

@BurguerJohn found a fix, essentially it's possible for us to follow the trail of __wrapped__ and get back the original .forward() to remove all the wrapping 🤯
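
A sketch of that idea with a hypothetical helper name (not the actual patch; see the PR linked at the top, #872, for the real implementation):

import torch

def restore_original_forward(model: torch.nn.Module) -> torch.nn.Module:
  # Hypothetical helper: every wrapper applied with functools.wraps
  # records the function it wraps as __wrapped__, so following the
  # trail recovers the original bound method.
  patched = model.__dict__.get("forward")
  if patched is None:
    return model  # forward was never patched on the instance
  while hasattr(patched, "__wrapped__"):
    patched = patched.__wrapped__
  # If the trail ends at the class's own forward bound to this instance,
  # deleting the patched attribute restores it, and the module pickles
  # by reference again.
  if getattr(patched, "__func__", None) is type(model).forward:
    del model.__dict__["forward"]
  return model

With the instance attribute gone, torch.save(model, ...) once again pickles the class-level forward by reference.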

@BurguerJohn
Author

Wow, that is pretty cool. It's also something new for me. Glad you managed to find a good solution.

@phananh03x

In a Colab notebook, I tried:

model = ExampleModule()

accelerator = Accelerator(
        gradient_accumulation_steps=1,
        mixed_precision="fp16",
        log_with="tensorboard",
        logging_dir=".",
    )

model = accelerator.prepare(model)
torch.save(accelerator.unwrap_model(model), "model_unwrap.sd")

but it gives this error:

/usr/local/lib/python3.10/dist-packages/torch/serialization.py:441 in save

    438
    439     if _use_new_zipfile_serialization:
    440         with _open_zipfile_writer(f) as opened_zipfile:
--> 441             _save(obj, opened_zipfile, pickle_module, pickle_protocol)
    442             return
    443     else:
    444         with _open_file_like(f, 'wb') as opened_file:

/usr/local/lib/python3.10/dist-packages/torch/serialization.py:653 in _save

    650     data_buf = io.BytesIO()
    651     pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    652     pickler.persistent_id = persistent_id
--> 653     pickler.dump(obj)
    654     data_value = data_buf.getvalue()
    655     zip_file.write_record('data.pkl', data_value, len(data_value))

AttributeError: Can't pickle local object 'convert_outputs_to_fp32.<locals>.forward'
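
This error is the other variant of the same underlying problem: convert_outputs_to_fp32 defines its forward wrapper as a local function, and pickle can only serialize a function by importing it from a module-level path, so any local object fails outright. A minimal stdlib-only illustration, unrelated to Accelerate's code:

import pickle

def outer():
  def inner():
    # Only reachable as outer.<locals>.inner, which pickle cannot
    # import, so serializing it fails.
    pass
  return inner

try:
  pickle.dumps(outer())
except AttributeError as e:
  print(e)  # Can't pickle local object 'outer.<locals>.inner'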

@muellerzr
Collaborator

@phananh03x can you provide us a full repro (what is ExampleModule?) and the version of accelerate you are using?

@phananh03x

from diffusers import UNet2DModel
from accelerate import Accelerator

model = UNet2DModel(in_channels=1, out_channels=1, block_out_channels=(32, 64, 128, 128))

@phananh03x

accelerate-0.19.0
