
fix error connected to retain_grad() #57

Merged: 8 commits into constantinpape:main, Apr 18, 2022

Conversation

@JonasHell (Contributor) commented on Apr 15, 2022

Connected to 82d44dd

Since pred.retain_grad() is now part of _forward_and_loss, it can also be called from _validate_impl inside a torch.no_grad() context. In that case pred.requires_grad is False, which raises the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/g/kreshuk/hellgoth/domain_adaptation/scripts/check_new_trainer_impl.py in <module>
     18 optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
     20 trainer = DefaultTrainer(
     21     name='test',
     22     train_loader=loader,
   (...)
     29     log_image_interval=20
     30 )
---> 32 trainer.fit(50)

File ~/torch-em/torch_em/trainer/default_trainer.py:351, in DefaultTrainer.fit(self, iterations, load_from_checkpoint)
    349 for _ in range(train_epochs):
    350     t_per_iter = train_epoch(progress)
--> 351     current_metric = validate()
    353     if self.lr_scheduler is not None:
    354         self.lr_scheduler.step(current_metric)

File ~/torch-em/torch_em/trainer/default_trainer.py:439, in DefaultTrainer._validate_mixed(self)
    438 def _validate_mixed(self):
--> 439     return self._validate_impl(amp.autocast)

File ~/torch-em/torch_em/trainer/default_trainer.py:451, in DefaultTrainer._validate_impl(self, forward_context)
    449 x, y = x.to(self.device), y.to(self.device)
    450 with forward_context():
--> 451     pred, loss = self._forward_and_loss(x, y)
    452     metric = self.metric(pred, y)
    454 loss_val += loss.item()

File ~/torch-em/torch_em/trainer/default_trainer.py:402, in DefaultTrainer._forward_and_loss(self, x, y)
    399 pred = self.model(x)
    400 if self._iteration % self.log_image_interval == 0:
    401     #if pred.requires_grad:
--> 402     pred.retain_grad()
    404 loss = self.loss(pred, y)
    405 return pred, loss

RuntimeError: can't retain_grad on Tensor that has requires_grad=False
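For illustration, the underlying behavior can be reproduced without torch-em: any tensor computed inside a torch.no_grad() block has requires_grad=False, and calling retain_grad() on it raises the same RuntimeError. A minimal sketch, independent of the trainer code:

import torch

x = torch.rand(2, 3, requires_grad=True)

# Outside no_grad(): the result tracks gradients, so retain_grad() is allowed.
y = (x * 2).sum()
y.retain_grad()

# Inside no_grad(): the result does not track gradients ...
with torch.no_grad():
    z = (x * 2).sum()

# ... so retain_grad() raises the RuntimeError shown in the traceback above.
try:
    z.retain_grad()
except RuntimeError as e:
    print(e)  # can't retain_grad on Tensor that has requires_grad=False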

I added an if statement that fixes the issue. I hope this does not interfere with any desired functionality.
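A minimal sketch of the change, based on the _forward_and_loss body visible in the traceback above (only the requires_grad check is new):

def _forward_and_loss(self, x, y):
    pred = self.model(x)
    if self._iteration % self.log_image_interval == 0:
        # retain_grad() is only valid when gradients are tracked at all,
        # which is not the case during validation under torch.no_grad()
        if pred.requires_grad:
            pred.retain_grad()

    loss = self.loss(pred, y)
    return pred, loss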

Code to reproduce the error:

import torch
import torch_em
from torch_em.trainer import DefaultTrainer
from torch_em.model import UNet2d


ds = torch.utils.data.TensorDataset(torch.rand(10, 1, 256, 256), torch.rand(10, 1, 256, 256))
# shuffle has to be passed to the DataLoader constructor; setting the attribute afterwards has no effect
loader = torch.utils.data.DataLoader(ds, shuffle=True)

model = UNet2d(in_channels=1, out_channels=1, final_activation="Sigmoid")

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

trainer = DefaultTrainer(
    name='test',
    train_loader=loader,
    val_loader=loader,
    model=model,
    loss=torch_em.loss.DiceLoss(),
    optimizer=optimizer,
    metric=torch_em.loss.DiceLoss(),
    device=torch.device('cuda'),
    log_image_interval=20
)

trainer.fit(50)

@constantinpape (Owner) commented:

Thanks! The change looks good, and indeed it only makes sense to call retain_grad if we also require gradients.

@constantinpape merged commit b07182f into constantinpape:main on Apr 18, 2022