Skip to content

Commit

Permalink
align the ema model device back after sampling from the cascading ddp…
Browse files Browse the repository at this point in the history
…m in the decoder
  • Loading branch information
lucidrains committed May 12, 2022
1 parent 6021945 commit 924455d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
9 changes: 9 additions & 0 deletions dalle2_pytorch/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def __init__(
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))

def restore_ema_model_device(self):
    """Move the EMA model back onto this module's own device.

    Used after sampling may have shuffled the EMA copy to another device;
    the `initted` buffer always lives on the module's device, so it serves
    as the device reference.
    """
    self.ema_model.to(self.initted.device)

def update(self):
self.step += 1

Expand Down Expand Up @@ -305,6 +309,11 @@ def sample(self, *args, **kwargs):

if self.use_ema:
self.decoder.unets = trainable_unets # restore original training unets

# cast the ema_model unets back to original device
for ema in self.ema_unets:
ema.restore_ema_model_device()

return output

def forward(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.11',
version = '0.2.12',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
Expand Down

0 comments on commit 924455d

Please sign in to comment.