Skip to content

GAROM solver validation_step and test_step methods need to convert loss to tensor #715

@green3diamond

Description

@green3diamond

Describe the bug
Training the GAROM solver raises `AttributeError: 'LabelTensor' object has no attribute '_labels'`. The cause is that the losses computed in `validation_step` and `test_step` are `LabelTensor` objects without labels. I believe the framework expects these values either to have labels set or to be plain tensors.

After debugging the classes GAROM and SolverInterface I have the following findings:

  • Both classes implement `validation_step`, which calls `self.store_log` to log metrics.
  • `SolverInterface` passes a plain PyTorch `Tensor` to `store_log`, which works.
  • `GAROM` passes a `LabelTensor` without labels, which causes the error.
  • The bug does not reproduce in PINA 1.2.0. A change made after 1.2.0 to code outside GAROM may have introduced it.

To Reproduce
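Create a simple GAROM solver with FNNs for the generator and the discriminator, then call `trainer.train()`. The error is raised during Lightning's validation sanity check, before the first training epoch starts.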

Expected behavior
Training should run without errors. The issue can be fixed by converting the aggregated loss to a plain tensor right after the loss calculation:

# lines 277 and 300 in pina/solver/garom.py
loss = self.weighting.aggregate(condition_loss).tensor

Output

summarize ---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[34], line 17
     15 start = time.time()
     16 # train
---> 17 trainer.train()
     18 end = time.time()
     19 print(f"Training time: {end - start} seconds")

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/pina/trainer.py:230, in Trainer.train(self, **kwargs)
    222 def train(self, **kwargs):
    223     """
    224     Manage the training process of the solver.
    225 
   (...)
    228         for details.
    229     """
--> 230     return super().fit(self.solver, datamodule=self.data_module, **kwargs)

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:560, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    558 self.training = True
    559 self.should_stop = False
--> 560 call._call_and_handle_interrupt(
    561     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    562 )

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:49, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     47     if trainer.strategy.launcher is not None:
     48         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 49     return trainer_fn(*args, **kwargs)
     51 except _TunerExitException:
     52     _call_teardown_hook(trainer)

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:598, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    591     download_model_from_registry(ckpt_path, self)
    592 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    593     self.state.fn,
    594     ckpt_path,
    595     model_provided=True,
    596     model_connected=self.lightning_module is not None,
    597 )
--> 598 self._run(model, ckpt_path=ckpt_path)
    600 assert self.state.stopped
    601 self.training = False

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:1011, in Trainer._run(self, model, ckpt_path)
   1006 self._signal_connector.register_signal_handlers()
   1008 # ----------------------------
   1009 # RUN THE TRAINER
   1010 # ----------------------------
-> 1011 results = self._run_stage()
   1013 # ----------------------------
   1014 # POST-Training CLEAN UP
   1015 # ----------------------------
   1016 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:1053, in Trainer._run_stage(self)
   1051 if self.training:
   1052     with isolate_rng():
-> 1053         self._run_sanity_check()
   1054     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
   1055         self.fit_loop.run()

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:1082, in Trainer._run_sanity_check(self)
   1079 call._call_callback_hooks(self, "on_sanity_check_start")
   1081 # run eval step
-> 1082 val_loop.run()
   1084 call._call_callback_hooks(self, "on_sanity_check_end")
   1086 # reset logger connector

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/loops/utilities.py:179, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
    177     context_manager = torch.no_grad
    178 with context_manager():
--> 179     return loop_run(self, *args, **kwargs)

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/loops/evaluation_loop.py:145, in _EvaluationLoop.run(self)
    143     self.batch_progress.is_last_batch = data_fetcher.done
    144     # run step hooks
--> 145     self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
    146 except StopIteration:
    147     # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
    148     break

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/loops/evaluation_loop.py:437, in _EvaluationLoop._evaluation_step(self, batch, batch_idx, dataloader_idx, dataloader_iter)
    431 hook_name = "test_step" if trainer.testing else "validation_step"
    432 step_args = (
    433     self._build_step_args_from_hook_kwargs(hook_kwargs, hook_name)
    434     if not using_dataloader_iter
    435     else (dataloader_iter,)
    436 )
--> 437 output = call._call_strategy_hook(trainer, hook_name, *step_args)
    439 self.batch_progress.increment_processed()
    441 if using_dataloader_iter:
    442     # update the hook kwargs now that the step method might have consumed the iterator

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:329, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    326     return None
    328 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 329     output = fn(*args, **kwargs)
    331 # restore current_fx when nested context
    332 pl_module._current_fx_name = prev_fx_name

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py:412, in Strategy.validation_step(self, *args, **kwargs)
    410 if self.model != self.lightning_module:
    411     return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
--> 412 return self.lightning_module.validation_step(*args, **kwargs)

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/pina/solver/garom.py:278, in GAROM.validation_step(self, batch)
    274     condition_loss[condition_name] = self._loss_fn(
    275         snapshots, snapshots_gen
    276     )
    277 loss = self.weighting.aggregate(condition_loss)
--> 278 self.store_log("val_loss", loss, self.get_batch_size(batch))
    279 return loss

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/pina/solver/solver.py:159, in SolverInterface.store_log(self, name, value, batch_size)
    150 def store_log(self, name, value, batch_size):
    151     """
    152     Store the log of the solver.
    153 
   (...)
    156     :param int batch_size: The size of the batch.
    157     """
--> 159     self.log(
    160         name=name,
    161         value=value,
    162         batch_size=batch_size,
    163         **self.trainer.logging_kwargs,
    164     )

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/core/module.py:483, in LightningModule.log(self, name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph, sync_dist, sync_dist_group, add_dataloader_idx, batch_size, metric_attribute, rank_zero_only)
    477 if add_dataloader_idx and "/dataloader_idx_" in name:
    478     raise MisconfigurationException(
    479         f"You called `self.log` with the key `{name}`"
    480         " but it should not contain information about `dataloader_idx` when `add_dataloader_idx=True`"
    481     )
--> 483 value = apply_to_collection(value, (Tensor, numbers.Number), self.__to_tensor, name)
    485 if trainer._logger_connector.should_reset_tensors(self._current_fx_name):
    486     # if we started a new epoch (running its first batch) the hook name has changed
    487     # reset any tensors for the new hook name
    488     results.reset(metrics=False, fx=self._current_fx_name)

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning_utilities/core/apply_func.py:66, in apply_to_collection(data, dtype, function, wrong_dtype, include_none, allow_frozen, *args, **kwargs)
     64 # fast path for the most common cases:
     65 if isinstance(data, dtype):  # single element
---> 66     return function(data, *args, **kwargs)
     67 if data.__class__ is list and all(isinstance(x, dtype) for x in data):  # 1d homogeneous list
     68     return [function(x, *args, **kwargs) for x in data]

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/core/module.py:658, in LightningModule.__to_tensor(self, value, name)
    656 def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
    657     value = (
--> 658         value.clone().detach()
    659         if isinstance(value, Tensor)
    660         else torch.tensor(value, device=self.device, dtype=_get_default_dtype())
    661     )
    662     if not torch.numel(value) == 1:
    663         raise ValueError(
    664             f"`self.log({name}, {value})` was called, but the tensor must have a single element."
    665             f" You can try doing `self.log({name}, {value}.mean())`"
    666         )

File ~/anaconda3/envs/ml/lib/python3.12/site-packages/pina/label_tensor.py:445, in LabelTensor.clone(self, *args, **kwargs)
    434 def clone(self, *args, **kwargs):
    435     """
    436     Clone the :class:`~pina.label_tensor.LabelTensor`. For more details, see
    437     :meth:`torch.Tensor.clone`.
   (...)
    441     :rtype: LabelTensor
    442     """
    444     out = LabelTensor(
--> 445         super().clone(*args, **kwargs), deepcopy(self._labels)
    446     )
    447     return out

AttributeError: 'LabelTensor' object has no attribute '_labels'

Additional context
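The error is raised when Lightning's `LightningModule.__to_tensor` calls `value.clone()` on the logged loss. `LabelTensor.clone` then reads `self._labels`, an attribute that is missing on the `LabelTensor` returned by `self.weighting.aggregate(condition_loss)`.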

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions