
EMA, 3D noise #24

Closed

pmh9960 opened this issue Apr 19, 2022 · 2 comments

pmh9960 commented Apr 19, 2022

Hi @edgarschnfld @SushkoVadim,

I am a student studying semantic image synthesis. Thank you for the great work. I have two questions about differences between the paper and the code.

1. EMA

   As you cite [Yazıcı et al., 2018], the exponential moving average is a good technique for training GANs. However, in your code (OASIS/utils/utils.py, lines 125 to 132 in 6e728ec):

   ```python
   def update_EMA(model, cur_iter, dataloader, opt, force_run_stats=False):
       # update weights based on new generator weights
       with torch.no_grad():
           for key in model.module.netEMA.state_dict():
               model.module.netEMA.state_dict()[key].data.copy_(
                   model.module.netEMA.state_dict()[key].data * opt.EMA_decay +
                   model.module.netG.state_dict()[key].data * (1 - opt.EMA_decay)
               )
   ```

   I think the following code might need to be added:

   ```python
   model.module.netG.state_dict()[key].data.copy_(
       model.module.netEMA.state_dict()[key].data
   )
   ```

   Otherwise, netG is not trained using EMA.

   Yazıcı, Yasin, et al. "The Unusual Effectiveness of Averaging in GAN Training." International Conference on Learning Representations, 2018.

2. 3D noise

   If I do not misunderstand, the paper says that the noise in OASIS is sampled from a 3D normal distribution, and that this is one of the main differences from SPADE. However, in your code:

   ```python
   if not self.opt.no_3dnoise:
       dev = seg.get_device() if self.opt.gpu_ids != "-1" else "cpu"
       z = torch.randn(seg.size(0), self.opt.z_dim, dtype=torch.float32, device=dev)
       z = z.view(z.size(0), self.opt.z_dim, 1, 1)
       z = z.expand(z.size(0), self.opt.z_dim, seg.size(2), seg.size(3))
       seg = torch.cat((z, seg), dim=1)
   ```

   Here the noise is not sampled from a full 3D normal distribution: a 1D noise vector is sampled per image and then expanded to 3D, replicating the same vector at every spatial location. In my opinion, this code should be replaced by

   ```python
   z = torch.randn(seg.shape, ...)
   ```
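   For concreteness, a fully spatial (per-pixel) variant might look like the sketch below. The shapes are assumed from the snippet above; this is only an illustration, not the repository's code:

   ```python
   # Hypothetical per-pixel sampling: draw an independent value for every
   # noise channel and spatial location instead of replicating one vector.
   if not self.opt.no_3dnoise:
       dev = seg.get_device() if self.opt.gpu_ids != "-1" else "cpu"
       z = torch.randn(seg.size(0), self.opt.z_dim, seg.size(2), seg.size(3),
                       dtype=torch.float32, device=dev)
       seg = torch.cat((z, seg), dim=1)
   ```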

I think both of these parts are quite crucial for your paper. If there is a reason for these choices, or if I am mistaken, please let me know.

Thank you.

@SushkoVadim (Contributor)

Hi,

1. The netEMA checkpoint is meant simply to track the running average of the weights of the generator network netG. When used at inference instead of netG, it indeed has the potential to improve performance. It is usually not meant to be used during training. In your suggested example, netG is not allowed to differ from netEMA, which imposes a strong constraint on the generator and will likely impair training by making it much harder for the generator to fool the discriminator netD. (A sketch of the usual pattern follows below.)
2. We indeed do not assume sampling from a 3D normal distribution, and use a simpler "replication" strategy. Please refer to Appendix A.7 in the paper for the related discussion, and to this paragraph in Sec. 3.3 of the main paper:

> Note that for simplicity during training we sample the 3D noise tensor globally, i.e. per-channel, replicating each channel value spatially along the height and width of the tensor. We analyse alternative ways of sampling 3D noise during training in App. A.7.
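A minimal sketch of the usual EMA pattern described in point 1, assuming a generic PyTorch generator module (the function and variable names here are illustrative, not the repository's exact API):

```python
import copy
import torch

def make_ema(netG):
    # Shadow copy of the generator; it is never trained directly.
    netEMA = copy.deepcopy(netG)
    for p in netEMA.parameters():
        p.requires_grad_(False)
    return netEMA

@torch.no_grad()
def update_ema(netEMA, netG, decay=0.9999):
    # netEMA <- decay * netEMA + (1 - decay) * netG.
    # netG itself is left untouched, so its training is unconstrained.
    ema_state = netEMA.state_dict()
    for key, value in netG.state_dict().items():
        ema_state[key].copy_(ema_state[key] * decay + value * (1 - decay))

# Training: optimizer steps update netG, then call update_ema(netEMA, netG).
# Inference: generate with netEMA. Copying netEMA back into netG every step,
# as proposed above, would pin the generator to its own running average.
```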

pmh9960 (Author) commented Apr 19, 2022

Thank you for the fast reply and the detailed explanation!
I will consider it in my work.

Thank you.

pmh9960 closed this as completed Apr 19, 2022