
Compatibility with PyTorch 0.4 #32

Closed
seyiqi opened this issue Mar 27, 2018 · 11 comments

Comments

seyiqi commented Mar 27, 2018

Thanks very much for sharing this implementation. I forked the code, and it works great on PyTorch 0.3.1. But when I ran it with 0.4.0 (master version), I got the following error (I made some minor changes, so the line numbers won't match):

    File "../networks/densenet_efficient.py", line 330, in forward
      bn_input_var = Variable(type(inputs[0])(storage).resize_(size), volatile=True)
    TypeError: Variable data has to be a tensor, but got torch.cuda.FloatStorage

It turns out the problem is this line:

    bn_input_var = Variable(type(inputs[0])(storage).resize_(size), volatile=True)

The inputs in version 0.3.1 are FloatTensors, but in 0.4.0 they are Variables.
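For context, a small sketch of the behavior change (written against a current PyTorch build; `Variable` is now kept only for backward compatibility, and `volatile=True` was replaced by the `torch.no_grad()` context):

```python
import torch
from torch.autograd import Variable

# Since PyTorch 0.4, Variable and Tensor are merged: wrapping a tensor
# in Variable is a no-op that simply returns a Tensor.
t = torch.ones(3)
v = Variable(t)
print(isinstance(v, torch.Tensor))  # True

# volatile=True no longer exists; inference is expressed with no_grad.
with torch.no_grad():
    y = v * 2
print(y.requires_grad)  # False
```

So any code that assumes `inputs[0]` is a raw tensor (and reaches for its storage directly) needs to account for the merged type.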

I am wondering what's the best way to update the code for 0.4.0?

Many thanks!

seyiqi (Author) commented Mar 27, 2018

I forgot to mention that this issue is related to this file:

https://github.com/gpleiss/efficient_densenet_pytorch/blob/master/models/densenet_efficient.py

gpleiss (Owner) commented Mar 30, 2018

Right... the way it's currently written is incompatible with PyTorch 0.4. I've mostly been testing against 0.3. I'll try to make it compatible with both versions.

gpleiss (Owner) commented Apr 16, 2018

Sorry for the slow reply. @taineleau pointed out that PyTorch 0.4 has a checkpoint feature. This will essentially do most of the work that's being done in the efficient implementation right now. I'll look into this, hopefully next week.
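For reference, a hedged sketch of what that checkpoint feature looks like applied to a DenseNet-style concat-BN-ReLU-conv bottleneck (module sizes are illustrative, not taken from this repo; `use_reentrant=False` selects the non-reentrant variant added in later PyTorch releases):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Illustrative DenseNet-style bottleneck: concat -> BN -> ReLU -> 1x1 conv.
bn = torch.nn.BatchNorm2d(8)
conv = torch.nn.Conv2d(8, 4, kernel_size=1, bias=False)

def bottleneck(*features):
    x = torch.cat(features, dim=1)   # concatenate incoming feature maps
    return conv(torch.relu(bn(x)))

f1 = torch.randn(2, 4, 8, 8, requires_grad=True)
f2 = torch.randn(2, 4, 8, 8, requires_grad=True)

# checkpoint() discards the intermediate activations during the forward
# pass and recomputes them during backward, trading compute for memory.
out = checkpoint(bottleneck, f1, f2, use_reentrant=False)
out.sum().backward()
print(f1.grad.shape)  # torch.Size([2, 4, 8, 8])
```

This is essentially the same recompute-on-backward trick the efficient implementation does by hand with a shared memory buffer.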

wandering007 (Contributor) commented

@gpleiss @seyiqi I just released code runnable with PyTorch 0.4, and it passes the single-GPU case. However, it does not work in the multi-GPU case, which is weird. Help needed~

taineleau-zz (Collaborator) commented

@wandering007 I suggest you take a look at the checkpoint feature. That helper function should handle multiple GPUs nicely.

gpleiss (Owner) commented Apr 26, 2018

@taineleau @wandering007 I'm thinking we should re-write this code to just work with PyTorch 0.4, using the checkpointing feature. We can make a branch/tag of the current code for people who are still using PyTorch 0.3.

wandering007 (Contributor) commented

@taineleau Actually, what we do is basically the same as the checkpoint feature, and ours seems more efficient...

gpleiss (Owner) commented Apr 26, 2018

Sorry @wandering007 not sure what you mean. Are you saying that you currently have an implementation that's using checkpointing?

wandering007 (Contributor) commented Apr 26, 2018

@gpleiss uh... no, I didn't use the checkpoint feature. But I've looked into the source code of the checkpointing feature, and the implementations are very similar. The inefficiencies of checkpointing may include these points:

  1. Though concat-bn-relu can be checkpointed, the relu output cannot use the same shared memory unless we hack it like what you've done before.
  2. If concat-bn-relu-conv is checkpointed, the conv operations need to be recomputed during the backward pass, which is time consuming.

Besides, we still need to restore the BN statistics (running_mean and running_var), since the checkpoint feature simply ignores that they change.
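The running-stats point can be seen directly with the checkpoint API (a minimal sketch on a current PyTorch build; `use_reentrant=True` selects the original reentrant behavior): the recomputation during backward passes the same batch through the BatchNorm layer a second time, so the running statistics are updated twice unless they are saved and restored.

```python
import torch
from torch.utils.checkpoint import checkpoint

bn = torch.nn.BatchNorm2d(4)       # in training mode by default
x = torch.randn(2, 4, 8, 8, requires_grad=True)

out = checkpoint(bn, x, use_reentrant=True)
after_forward = bn.running_mean.clone()

out.sum().backward()               # re-runs bn(x) to rebuild activations
after_backward = bn.running_mean.clone()

# The recomputed forward updated the running statistics a second time.
print(torch.equal(after_forward, after_backward))  # False
```

A hand-rolled implementation can snapshot `running_mean`/`running_var` before the recompute and restore them afterwards, which the built-in checkpoint does not do.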

gpleiss (Owner) commented Apr 26, 2018

@wandering007 checkpointing takes care of it all. I profiled it, and the models are far more memory efficient than they ever were. The PyTorch team has optimized the memory usage of autograd like crazy.

Besides that, the checkpointing feature seems to be really smart. When memory isn't an issue (e.g. for big GPUs/smaller models), the checkpointing code applies less aggressive memory optimizations so the code runs faster. When memory IS an issue (e.g. for bigger models), the checkpointing code squeezes out a TON of memory savings.

gpleiss (Owner) commented Apr 27, 2018

Closed by #35

gpleiss closed this as completed Apr 27, 2018