
Fixed VGG encoder weights query #7

Closed
SidShenoy opened this issue Jun 3, 2019 · 2 comments

@SidShenoy

In your paper you mention that 'We use the encoder-decoder architecture with fixed VGG encoder weights' and that you only train the decoder on the COCO dataset. However, since you have replaced the max pooling/unpooling layers of the original VGG-19 architecture with wavelet pooling/unpooling layers, don't you need to train the entire encoder-decoder architecture? Since these layers are changed, the VGG-19 features will also be affected, so we cannot simply reuse the original ImageNet weights. Can you please provide some clarification on this?

@jaejun-yoo
Collaborator

jaejun-yoo commented Jun 3, 2019

@SidShenoy Thank you for your comment. That is a very sharp observation!

Yes, indeed, changing the pooling layers affects the subsequent feature maps. Since the max-pooling output is now replaced by the LL-filter output (similar to average pooling), the feature maps after each pooling layer are slightly different. (Note that the feature maps from the other three wavelet filters do not propagate to the next layer of the encoder; they are skipped to the decoder, so the only change the encoder has to absorb is the switch from max pooling to the LL filter.)
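
For illustration, here is a minimal sketch of the wavelet (Haar) pooling idea; it is a simplified, hand-written example with normalized Haar kernels (not the exact code in this repository) showing how only the LL band moves forward while the high-frequency bands become skips:

import torch
import torch.nn.functional as F

def haar_pool(x):
    # Minimal illustrative sketch, not the repository's actual implementation.
    # 1-D Haar low- and high-pass filters, normalized by 1/sqrt(2).
    L = torch.tensor([1.0, 1.0]) / 2 ** 0.5
    H = torch.tensor([-1.0, 1.0]) / 2 ** 0.5
    # 2-D kernels as outer products: LL, LH, HL, HH.
    kernels = torch.stack([torch.outer(a, b) for a in (L, H) for b in (L, H)])  # (4, 2, 2)
    n, c, _, _ = x.shape
    # Apply all four kernels depthwise (one group per input channel) with stride 2.
    weight = kernels.repeat(c, 1, 1).unsqueeze(1).to(x)    # (4*c, 1, 2, 2)
    out = F.conv2d(x, weight, stride=2, groups=c)          # (n, 4*c, h/2, w/2)
    out = out.view(n, c, 4, out.shape[-2], out.shape[-1])
    ll, lh, hl, hh = out.unbind(dim=2)
    # Only LL propagates to the next encoder layer; LH/HL/HH are "skipped"
    # straight to the decoder for unpooling.
    return ll, (lh, hl, hh)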

However, as we wrote in the paper, we decided not to touch the encoder and instead let the decoder adapt to those changes. You can fine-tune the encoder by partially or entirely unfreezing its weights, and we actually tried some variants, such as unfreezing only the convolution layers that immediately follow each pooling layer so that the change is also handled inside the encoder (see the sketch below). There was not much difference in the final outcomes, so we chose to stick with the simpler training strategy.
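
For reference, partially unfreezing the encoder could be done along these lines; this is only a hypothetical sketch, and the layer attributes (encoder.conv2_1 and friends) are assumed names rather than the ones used in this repository:

# Hypothetical sketch: fine-tune only the convolutions that immediately
# follow a pooling layer, keeping the rest of the VGG encoder frozen.
for param in encoder.parameters():
    param.requires_grad = False            # freeze the whole encoder first
for layer in (encoder.conv2_1, encoder.conv3_1, encoder.conv4_1):  # assumed names
    for param in layer.parameters():
        param.requires_grad = True         # then free the post-pooling convs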

This can be explained in two ways. 1) It is a well-known and consistently reported observation that style transfer still works when max pooling is replaced with average pooling (even though the VGG network was trained with max pooling), and the effect is sometimes even better. Similarly, our encoder with the LL filter, which is just average pooling up to a scaling factor, shares this characteristic (see the quick check below). 2) Since the decoder is newly trained, it has enough capacity to compensate for such shifts of the encoder feature maps and still output a good reconstruction.
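
As a quick sanity check on point 1), with the normalized Haar kernels from the sketch above, the LL output is exactly 2x average pooling (again, just an illustrative check using the hypothetical haar_pool helper):

x = torch.randn(1, 3, 32, 32)
ll, _ = haar_pool(x)
# The LL kernel is [[0.5, 0.5], [0.5, 0.5]], i.e. twice the 2x2 average-pooling kernel.
print(torch.allclose(ll, 2 * F.avg_pool2d(x, 2), atol=1e-6))  # True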

Still, your comment is very valuable, and we will describe the training procedure in more detail to clarify this point. Thanks a lot for your attention :)

@jaejun-yoo
Collaborator

jaejun-yoo commented Jun 3, 2019

Maybe this partial code snippet will help you understand what we did:

# Freeze the VGG encoder: only the decoder receives gradient updates.
for param in self.encoder.parameters():
    param.requires_grad = False
self.dec_optim = torch.optim.Adam(
    filter(lambda p: p.requires_grad, self.decoder.parameters()),
    lr=self.lr,
    betas=(self.beta1, self.beta2)
)

# Encode, reconstruct, then re-encode the reconstruction.
feature, skips = self.encoder(real_image)
recon_image = self.decoder(feature, skips)
feature_recon, _ = self.encoder(recon_image)

# Pixel-level reconstruction loss plus a feature-level loss against the
# (detached) encoder features.
recon_loss = self.MSE_loss(recon_image, real_image)
feature_loss = torch.zeros(1).to(self.device)
feature_loss += self.MSE_loss(feature_recon, feature.detach())
loss = recon_loss * self.recon_weight + feature_loss * self.feature_weight

# Standard optimization step, updating the decoder only.
self.reset_grad()
loss.backward()
self.dec_optim.step()
