Skip to content

Commit

Permalink
multiscale discriminators should also be controlled by discr_start_after_step
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 22, 2023
1 parent a7e8377 commit 041ccb6
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
6 changes: 4 additions & 2 deletions magvit2_pytorch/magvit2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1623,9 +1623,11 @@ def forward(
return_recon_loss_only = False,
apply_gradient_penalty = True,
video_contains_first_frame = True,
adversarial_loss_weight = None
adversarial_loss_weight = None,
multiscale_adversarial_loss_weight = None
):
adversarial_loss_weight = default(adversarial_loss_weight, self.adversarial_loss_weight)
multiscale_adversarial_loss_weight = default(multiscale_adversarial_loss_weight, self.multiscale_adversarial_loss_weight)

assert (return_loss + return_codes + return_discr_loss) <= 1
assert video_or_images.ndim in {4, 5}
Expand Down Expand Up @@ -1830,7 +1832,7 @@ def forward(

weighted_multiscale_gen_losses = sum(loss * weight for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights))

total_loss = total_loss + weighted_multiscale_gen_losses * self.multiscale_adversarial_loss_weight
total_loss = total_loss + weighted_multiscale_gen_losses * multiscale_adversarial_loss_weight

# loss breakdown

Expand Down
4 changes: 3 additions & 1 deletion magvit2_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def train_step(self, dl_iter):
train_adversarially = self.model.use_gan and (step + 1) > self.discr_start_after_step

adversarial_loss_weight = 0. if not train_adversarially else None
multiscale_adversarial_loss_weight = 0. if not train_adversarially else None

# main model

Expand All @@ -318,7 +319,8 @@ def train_step(self, dl_iter):
loss, loss_breakdown = self.model(
data,
return_loss = True,
adversarial_loss_weight = adversarial_loss_weight
adversarial_loss_weight = adversarial_loss_weight,
multiscale_adversarial_loss_weight = multiscale_adversarial_loss_weight
)

self.accelerator.backward(loss / self.grad_accum_every)
Expand Down
2 changes: 1 addition & 1 deletion magvit2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.38'
__version__ = '0.1.39'

0 comments on commit 041ccb6

Please sign in to comment.