PyTorch 0.3 compatibility #28

Merged (10 commits) on Mar 13, 2018

Conversation

gpleiss (Owner) commented Mar 5, 2018

The efficient model now works on PyTorch 0.3.

Some other changes:

  • Multi-GPU support is now baked directly into DenseNetEfficient, so I removed the multi-GPU-specific model.
  • Renamed the cifar option to small_inputs (more generic).

I will merge it tomorrow, after I confirm that the demo gets the same error on CIFAR-10.

gpleiss (Owner, Author) commented Mar 5, 2018

This PR fixes #11, #24, and #26.

wandering007 (Contributor) commented Mar 6, 2018

@gpleiss Thanks for the hard work. I have a few questions, though. Line 116 in densenet_efficient.py:

relu_output = fn(self.norm_weight, self.norm_bias, *inputs)
conv_output = F.conv2d(relu_output, self.conv_weight, bias=None, stride=1,
                       padding=0, dilation=1, groups=1)

Based on my understanding, the F.conv2d backward pass needs relu_output to compute the gradient of conv_weight. Since relu_output lives in the shared memory buffer, its data could be corrupted by the time the backward pass runs. Am I wrong?
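
To make the concern concrete, here is a minimal standalone sketch (PyTorch 0.3 idiom, not the repo's code) of how overwriting the storage behind the tensor that was fed to F.conv2d corrupts the weight gradient:

import torch
import torch.nn.functional as F
from torch.autograd import Variable

x = Variable(torch.randn(1, 3, 8, 8))
w = Variable(torch.randn(4, 3, 1, 1), requires_grad=True)

buf = x.clone()              # stand-in for the shared memory buffer
out = F.conv2d(buf, w)       # autograd keeps a reference to buf for backward
buf.data.zero_()             # simulate a later layer reusing the storage
out.sum().backward()
print(w.grad.abs().sum())    # ~0: the gradient was computed from the
                             # overwritten buffer, not the real conv input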

Besides, I also tested the previous multi-GPU version; it was weird that _efficient_conv2d (aka cudnn_conv) passed its test independently, but the gradients came out wrong when the module was used in the whole model. Is that the reason you put the conv ops outside of the function?

gpleiss (Owner, Author) commented Mar 6, 2018

@wandering007 I put the conv outside because that made everything much more memory efficient. When the conv is inside the rest of the bottleneck function, everything uses a lot more memory.

But you're right, it does seem that the gradient is incorrect. In practice, it's not that incorrect, since what's written in the memory buffer at any given time is often similar to what the actual conv input should be.

To fix this, we'll probably need to do something fancy with PyTorch hooks. I'll give this a shot, and hold off on merging for now.

wandering007 (Contributor) commented

@gpleiss I've posted a question on the PyTorch forum. Based on the answers, I think what you've implemented is workable.

gpleiss (Owner, Author) commented Mar 12, 2018

@wandering007 I just pushed a fix, which should correctly compute the gradient.
What I had before sort of worked, for a very subtle reason. The gradients for the convolution were incorrect, but not too incorrect! The input variable to the convolution held a batch-normed version of the features, but the wrong batch-normed version (i.e. the batch norm output of the very final layer). As a result, the network effectively performs gradient descent with respect to only a single set of batch norm parameters, rather than layer-specific batch norm parameters.

This "works" - in the sense that the network still learns something. However, it does not have the same capacity as a normal DenseNet, since the network effectively has one batch norm layer. And consequently, the network is not as accurate as the non-efficient network.

TL;DR: The fixes I just pushed correctly re-populate the shared storage before the convolution backward pass. The efficient model now seems to get the same accuracy as the non-efficient network. I'm going to run one more test today, and push tomorrow.
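
For anyone following along, the pattern is roughly the following (a stripped-down sketch in PyTorch 0.3 idiom, with made-up names; this is not the actual densenet_efficient.py code):

import torch
import torch.nn.functional as F
from torch.autograd import Variable

x1, x2 = torch.randn(1, 3, 8, 8), torch.randn(1, 3, 8, 8)
w1 = Variable(torch.randn(4, 3, 1, 1), requires_grad=True)
w2 = Variable(torch.randn(4, 3, 1, 1), requires_grad=True)

# one chunk of storage shared by every layer's ReLU output
shared_storage = torch.FloatStorage(x1.numel())

def relu_into_shared(x):
    # wrap the shared storage in a fresh Variable and write ReLU(x) into it
    out = Variable(torch.FloatTensor(shared_storage).view_as(x))
    out.data.copy_(x).clamp_(min=0)
    return out

# forward: layer 2 overwrites layer 1's conv input in the shared storage
relu1 = relu_into_shared(x1)
out1 = F.conv2d(relu1, w1)
relu2 = relu_into_shared(x2)
out2 = F.conv2d(relu2, w2)

# backward: re-populate the shared storage with the right contents before
# each convolution's backward pass
out2.backward(torch.ones(out2.size()))   # storage still holds layer 2's input
relu1.data.copy_(x1).clamp_(min=0)       # restore layer 1's input
out1.backward(torch.ones(out1.size()))

# sanity check against an ordinary, non-shared computation
w1_ref = Variable(w1.data.clone(), requires_grad=True)
F.conv2d(Variable(x1.clamp(min=0)), w1_ref).backward(torch.ones(out1.size()))
print((w1.grad.data - w1_ref.grad.data).abs().max())   # ~0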

wandering007 (Contributor) commented Mar 13, 2018

@gpleiss That is a nice workaround!
I think it can be further improved in two ways (#29):

  1. For the forward pass, since the autograd graph does not need to be built, the Variables created in the forward function can be volatile, i.e. in purely inference mode (see the sketch after this list).
  2. Restoring running_mean and running_var can be done more simply.
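
For suggestion 1, this is roughly what I mean (a tiny sketch in PyTorch 0.3 idiom, illustrative names only):

import torch
from torch.autograd import Variable

# volatile=True puts the computation in pure inference mode: no autograd
# graph is built for anything derived from this Variable
bn_input = Variable(torch.randn(1, 3, 8, 8), volatile=True)
bn_output = bn_input.clamp(min=0)   # stand-in for the shared BN/ReLU forward
# bn_output.volatile is True, so no intermediate buffers are kept for backward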

One question about the code, at line 389:

 self.bn_output_var.backward(gradient=relu_grad_input)

Will this backward call run the rest of the model's backward pass, since the inputs are non-leaf Variables? Maybe use the torch.autograd.grad function instead?
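
For example (a minimal sketch with placeholder variables, not the repo's code), torch.autograd.grad confines the backward computation to the inputs you ask for, without accumulating .grad anywhere else in the model:

import torch
from torch.autograd import Variable, grad

x = Variable(torch.randn(4, 8), requires_grad=True)
w = Variable(torch.randn(8, 8), requires_grad=True)

bn_output = x.mm(w)                            # stand-in for self.bn_output_var
relu_grad_input = Variable(torch.randn(4, 8))  # stand-in for the incoming gradient

# gradient w.r.t. w only; no other .grad attribute in the graph is touched
w_grad, = grad(bn_output, (w,), grad_outputs=relu_grad_input)
print(w_grad.size())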

The suggestions above are not tested, just based on my experience.

gpleiss (Owner, Author) commented Mar 13, 2018

The efficient densenet matches the error of the normal densenet, so I'm merging.
