This operation is needed for the correct use of BatchNorm at test time. When updating the EMA weights, our code iterates over the state_dict of the generator:
```python
with torch.no_grad():
    for key in model.module.netEMA.state_dict():
        # Exponential moving average: netEMA <- decay * netEMA + (1 - decay) * netG
        model.module.netEMA.state_dict()[key].data.copy_(
            model.module.netEMA.state_dict()[key].data * opt.EMA_decay
            + model.module.netG.state_dict()[key].data * (1 - opt.EMA_decay)
        )
```
which contains all of the network's weights. However, the running estimates of the batch mean and variance are not carried over correctly by this copy, since these statistics are updated only during a forward pass through the network (see the PyTorch documentation on BatchNorm).
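To make this concrete, here is a minimal, self-contained sketch (plain PyTorch, not the OASIS code) showing that a BatchNorm layer's running_mean buffer changes only when data is forwarded through it in train() mode:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
x = torch.randn(8, 3, 16, 16)

print(bn.running_mean)  # freshly initialized to zeros

bn.eval()
bn(x)
print(bn.running_mean)  # unchanged: eval mode uses, but does not update, the stats

bn.train()
bn(x)
print(bn.running_mean)  # updated towards the batch mean of x
```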
Therefore, for correct testing, before switching to evaluation mode with netEMA.eval(), we need to accumulate running stats for the batch norm layers in the netEMA.train() regime, by forwarding some label maps from the train set through the network. Without this stats accumulation, the images produced by netEMA usually look like colorful noise, so this step is important.
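As a rough illustration, a minimal sketch of such an accumulation step is given below; dataloader_train and preprocess_input are hypothetical placeholders (not the exact OASIS API), and the number of warm-up batches is an arbitrary choice:

```python
netEMA = model.module.netEMA

netEMA.train()         # BN layers update their running stats only in train mode
with torch.no_grad():  # only the BN buffers change; no gradients are needed
    for num_upd, data_i in enumerate(dataloader_train, start=1):
        label = preprocess_input(data_i)  # hypothetical helper: label maps from the train set
        _ = netEMA(label)                 # forward pass refreshes running_mean / running_var
        if num_upd > 50:                  # stop after a few dozen batches
            break
netEMA.eval()          # now safe to switch to evaluation mode
```

Since only the BatchNorm buffers are touched, the loop runs under torch.no_grad() and leaves all learnable weights of netEMA unchanged.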
Hi @SushkoVadim @edgarschnfld,
Thanks for the excellent work. I noticed that when updating the EMA, you collect the running stats for BatchNorm before computing FID and before saving images or the network (see https://github.com/boschresearch/OASIS/blob/master/utils/utils.py#L133).
May I ask about the purpose and intuition behind this operation? How significantly does it affect the model's performance? Thank you in advance.