
EMA Bug #87

Closed
CiaoHe opened this issue May 12, 2022 · 5 comments
CiaoHe commented May 12, 2022

Hi Phil,

This morning I tried to run the decoder training part. I decided to use DecoderTrainer, but found an issue with the EMA update.

After using decoder_trainer to do sampling, the next training forward pass throws a RuntimeError:

Traceback (most recent call last):
  File "/home/caohe/DPMs/dalle2/train_decoder.py", line 321, in <module>    main()
  File "/home/caohe/DPMs/dalle2/train_decoder.py", line 318, in main
    train(decoder_trainer, train_dl, val_dl, train_config, device)
  File "/home/caohe/DPMs/dalle2/train_decoder.py", line 195, in train
    trainer.update(unet_number)
  File "/home/caohe/DPMs/dalle2/dalle2_pytorch/train.py", line 288, in update
    self.ema_unets[index].update()
  File "/home/caohe/DPMs/dalle2/dalle2_pytorch/train.py", line 119, in update
    self.update_moving_average(self.ema_model, self.online_model)
  File "/home/caohe/DPMs/dalle2/dalle2_pytorch/train.py", line 129, in update_moving_average
    ema_param.data = calculate_ema(self.beta, old_weight, up_weight)
  File "/home/caohe/DPMs/dalle2/dalle2_pytorch/train.py", line 125, in calculate_ema
    return old * beta + new * (1 - beta)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and CPU!

def update(self):
    self.step += 1
    if self.step <= self.update_after_step or (self.step % self.update_every) != 0:
        return
    if not self.initted:
        # on first update, copy the online weights into the EMA model
        self.ema_model.load_state_dict(self.online_model.state_dict())
        self.initted.data.copy_(torch.Tensor([True]))
    self.update_moving_average(self.ema_model, self.online_model)
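For reference, the arithmetic at the heart of calculate_ema is just linear interpolation between the old EMA weight and the new online weight. A minimal pure-Python sketch (no torch; names are illustrative):

```python
# Sketch of the EMA interpolation from calculate_ema:
# new_ema = old * beta + new * (1 - beta)

def calculate_ema(beta, old, new):
    return old * beta + new * (1 - beta)

# With beta close to 1, the EMA drifts only slowly toward the online weight.
ema = 1.0
for online in [0.0, 0.0, 0.0]:
    ema = calculate_ema(0.99, ema, online)
print(round(ema, 6))  # 0.970299, i.e. 0.99 ** 3
```

Note this formula assumes both operands live on the same device, which is exactly where the RuntimeError above comes from.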

I checked up_weight.device (online model) and old_weight.device (EMA model): the online model is on cuda:0 but the EMA model is on the CPU. It's really weird; I debugged for a long time and I think it's caused by the DecoderTrainer.sample() process. When swapping between the EMA and online unets, something goes wrong with the device placement.

@torch.no_grad()
def sample(self, *args, **kwargs):
    if self.use_ema:
        trainable_unets = self.decoder.unets
        self.decoder.unets = self.unets  # swap in exponential moving averaged unets for sampling

    output = self.decoder.sample(*args, **kwargs)

    if self.use_ema:
        self.decoder.unets = trainable_unets  # restore original training unets

    return output

The way I fixed it was to add self.ema_model = self.ema_model.to(next(self.online_model.parameters()).device) before calling self.update_moving_average(self.ema_model, self.online_model) (pretty naive, haha).
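The guard can be illustrated without torch or a GPU. Below is a minimal pure-Python simulation where Param and Model are hypothetical stand-ins for tensors and modules carrying a .device attribute, and to() mimics tensor movement:

```python
# Hypothetical stand-ins for tensors/modules with a .device attribute,
# illustrating the device-alignment guard before the EMA update.
class Param:
    def __init__(self, value, device):
        self.value, self.device = value, device

class Model:
    def __init__(self, device):
        self.params = [Param(1.0, device)]
    def parameters(self):
        return iter(self.params)
    def to(self, device):
        for p in self.params:
            p.device = device
        return self

online_model = Model('cuda:0')  # online weights live on the GPU
ema_model = Model('cpu')        # EMA copy was left on the CPU after sampling

# The fix: move the EMA model to the online model's device before updating.
ema_model = ema_model.to(next(online_model.parameters()).device)
assert all(p.device == 'cuda:0' for p in ema_model.parameters())
```

In real torch code the same one-liner works because Module.to() moves every parameter in place.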

Hope to hear your solution

Enjoy!

lucidrains (Owner) commented May 12, 2022

@CiaoHe ohh yes, you are correct, thank you! i think this should fix it 924455d

CiaoHe commented May 12, 2022

btw, how do you usually debug/test when adding new functions or starting a new repo? My efficiency is quite low (either I run from the command line and wait for an error, or I copy code into a Jupyter notebook and test again and again...)

lucidrains (Owner) commented May 12, 2022

@CiaoHe i've come full circle and just use a simple test.py in the root directory + print lol

CiaoHe commented May 12, 2022

> @CiaoHe i've come full circle and just use a simple test.py in the root directory + print lol

@lucidrains lol. But when moving to a cluster to train, things sometimes get out of control (I hate bugs)

lucidrains (Owner) commented May 12, 2022

🪰 🪱 🐞

CiaoHe closed this as completed May 12, 2022