-
Notifications
You must be signed in to change notification settings - Fork 90
Closed
Labels
bug: Something isn't working
Description
Describe the bug
Training the GAROM solver fails with `AttributeError: 'LabelTensor' object has no attribute '_labels'`. This is caused by the losses computed in `validation_step` and `test_step` having no labels. The framework appears to expect these values either to have labels set or to be plain `torch.Tensor`s.
After debugging the `GAROM` and `SolverInterface` classes, I found the following:
- Both classes define a `validation_step` method, which uses `self.store_log` to log metrics.
- `SolverInterface` passes a plain PyTorch `Tensor` to `store_log`, which works correctly.
- `GAROM` passes a `LabelTensor` without labels, which causes the error to be raised.
- The bug is not reproducible in PINA 1.2.0. Possibly a change made after 1.2.0, in code other than GAROM, introduced it.
To Reproduce
Create a simple GAROM solver with FNNs for the generator and the discriminator.
Expected behavior
Training should complete without errors. As a workaround, the issue is fixed by converting the aggregated loss to a plain tensor after the loss calculation:
# lines 277 and 300 in pina/solver/garom.py
loss = self.weighting.aggregate(condition_loss).tensor
Output
summarize ---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[34], line 17
15 start = time.time()
16 # train
---> 17 trainer.train()
18 end = time.time()
19 print(f"Training time: {end - start} seconds")
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/pina/trainer.py:230, in Trainer.train(self, **kwargs)
222 def train(self, **kwargs):
223 """
224 Manage the training process of the solver.
225
(...)
228 for details.
229 """
--> 230 return super().fit(self.solver, datamodule=self.data_module, **kwargs)
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:560, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
558 self.training = True
559 self.should_stop = False
--> 560 call._call_and_handle_interrupt(
561 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
562 )
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:49, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
47 if trainer.strategy.launcher is not None:
48 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 49 return trainer_fn(*args, **kwargs)
51 except _TunerExitException:
52 _call_teardown_hook(trainer)
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:598, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
591 download_model_from_registry(ckpt_path, self)
592 ckpt_path = self._checkpoint_connector._select_ckpt_path(
593 self.state.fn,
594 ckpt_path,
595 model_provided=True,
596 model_connected=self.lightning_module is not None,
597 )
--> 598 self._run(model, ckpt_path=ckpt_path)
600 assert self.state.stopped
601 self.training = False
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:1011, in Trainer._run(self, model, ckpt_path)
1006 self._signal_connector.register_signal_handlers()
1008 # ----------------------------
1009 # RUN THE TRAINER
1010 # ----------------------------
-> 1011 results = self._run_stage()
1013 # ----------------------------
1014 # POST-Training CLEAN UP
1015 # ----------------------------
1016 log.debug(f"{self.__class__.__name__}: trainer tearing down")
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:1053, in Trainer._run_stage(self)
1051 if self.training:
1052 with isolate_rng():
-> 1053 self._run_sanity_check()
1054 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
1055 self.fit_loop.run()
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py:1082, in Trainer._run_sanity_check(self)
1079 call._call_callback_hooks(self, "on_sanity_check_start")
1081 # run eval step
-> 1082 val_loop.run()
1084 call._call_callback_hooks(self, "on_sanity_check_end")
1086 # reset logger connector
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/loops/utilities.py:179, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
177 context_manager = torch.no_grad
178 with context_manager():
--> 179 return loop_run(self, *args, **kwargs)
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/loops/evaluation_loop.py:145, in _EvaluationLoop.run(self)
143 self.batch_progress.is_last_batch = data_fetcher.done
144 # run step hooks
--> 145 self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
146 except StopIteration:
147 # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
148 break
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/loops/evaluation_loop.py:437, in _EvaluationLoop._evaluation_step(self, batch, batch_idx, dataloader_idx, dataloader_iter)
431 hook_name = "test_step" if trainer.testing else "validation_step"
432 step_args = (
433 self._build_step_args_from_hook_kwargs(hook_kwargs, hook_name)
434 if not using_dataloader_iter
435 else (dataloader_iter,)
436 )
--> 437 output = call._call_strategy_hook(trainer, hook_name, *step_args)
439 self.batch_progress.increment_processed()
441 if using_dataloader_iter:
442 # update the hook kwargs now that the step method might have consumed the iterator
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:329, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
326 return None
328 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 329 output = fn(*args, **kwargs)
331 # restore current_fx when nested context
332 pl_module._current_fx_name = prev_fx_name
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py:412, in Strategy.validation_step(self, *args, **kwargs)
410 if self.model != self.lightning_module:
411 return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
--> 412 return self.lightning_module.validation_step(*args, **kwargs)
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/pina/solver/garom.py:278, in GAROM.validation_step(self, batch)
274 condition_loss[condition_name] = self._loss_fn(
275 snapshots, snapshots_gen
276 )
277 loss = self.weighting.aggregate(condition_loss)
--> 278 self.store_log("val_loss", loss, self.get_batch_size(batch))
279 return loss
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/pina/solver/solver.py:159, in SolverInterface.store_log(self, name, value, batch_size)
150 def store_log(self, name, value, batch_size):
151 """
152 Store the log of the solver.
153
(...)
156 :param int batch_size: The size of the batch.
157 """
--> 159 self.log(
160 name=name,
161 value=value,
162 batch_size=batch_size,
163 **self.trainer.logging_kwargs,
164 )
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/core/module.py:483, in LightningModule.log(self, name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph, sync_dist, sync_dist_group, add_dataloader_idx, batch_size, metric_attribute, rank_zero_only)
477 if add_dataloader_idx and "/dataloader_idx_" in name:
478 raise MisconfigurationException(
479 f"You called `self.log` with the key `{name}`"
480 " but it should not contain information about `dataloader_idx` when `add_dataloader_idx=True`"
481 )
--> 483 value = apply_to_collection(value, (Tensor, numbers.Number), self.__to_tensor, name)
485 if trainer._logger_connector.should_reset_tensors(self._current_fx_name):
486 # if we started a new epoch (running its first batch) the hook name has changed
487 # reset any tensors for the new hook name
488 results.reset(metrics=False, fx=self._current_fx_name)
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning_utilities/core/apply_func.py:66, in apply_to_collection(data, dtype, function, wrong_dtype, include_none, allow_frozen, *args, **kwargs)
64 # fast path for the most common cases:
65 if isinstance(data, dtype): # single element
---> 66 return function(data, *args, **kwargs)
67 if data.__class__ is list and all(isinstance(x, dtype) for x in data): # 1d homogeneous list
68 return [function(x, *args, **kwargs) for x in data]
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/lightning/pytorch/core/module.py:658, in LightningModule.__to_tensor(self, value, name)
656 def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
657 value = (
--> 658 value.clone().detach()
659 if isinstance(value, Tensor)
660 else torch.tensor(value, device=self.device, dtype=_get_default_dtype())
661 )
662 if not torch.numel(value) == 1:
663 raise ValueError(
664 f"`self.log({name}, {value})` was called, but the tensor must have a single element."
665 f" You can try doing `self.log({name}, {value}.mean())`"
666 )
File ~/anaconda3/envs/ml/lib/python3.12/site-packages/pina/label_tensor.py:445, in LabelTensor.clone(self, *args, **kwargs)
434 def clone(self, *args, **kwargs):
435 """
436 Clone the :class:`~pina.label_tensor.LabelTensor`. For more details, see
437 :meth:`torch.Tensor.clone`.
(...)
441 :rtype: LabelTensor
442 """
444 out = LabelTensor(
--> 445 super().clone(*args, **kwargs), deepcopy(self._labels)
446 )
447 return out
AttributeError: 'LabelTensor' object has no attribute '_labels'
Additional context
Add any other context about the problem here.
Metadata
Metadata
Assignees
Labels
bug: Something isn't working