
Tutorials: issue with trainer and logger? #1143

Closed
adamjstewart opened this issue Feb 24, 2023 · 7 comments · Fixed by #1145
Labels: documentation (Improvements or additions to documentation), trainers (PyTorch Lightning trainers)
Milestone: 0.4.1

Comments

adamjstewart (Collaborator) commented Feb 24, 2023

Description

Seeing the following error when I run our Pretrained Weights tutorial on Colab:

AttributeError                            Traceback (most recent call last)
<ipython-input-10-009018e0ed8d> in <module>
----> 1 trainer.fit(model=task, datamodule=datamodule)

14 frames
/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    606         model = self._maybe_unwrap_optimized(model)
    607         self.strategy._lightning_module = model
--> 608         call._call_and_handle_interrupt(
    609             self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    610         )

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/call.py in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     36             return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     37         else:
---> 38             return trainer_fn(*args, **kwargs)
     39 
     40     except _TunerExitException:

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    648             model_connected=self.lightning_module is not None,
    649         )
--> 650         self._run(model, ckpt_path=self.ckpt_path)
    651 
    652         assert self.state.stopped

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
   1110         self._checkpoint_connector.resume_end()
   1111 
-> 1112         results = self._run_stage()
   1113 
   1114         log.detail(f"{self.__class__.__name__}: trainer tearing down")

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py in _run_stage(self)
   1189         if self.predicting:
   1190             return self._run_predict()
-> 1191         self._run_train()
   1192 
   1193     def _pre_training_routine(self) -> None:

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
   1202 
   1203         with isolate_rng():
-> 1204             self._run_sanity_check()
   1205 
   1206         # enable train mode

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py in _run_sanity_check(self)
   1274             # run eval step
   1275             with torch.no_grad():
-> 1276                 val_loop.run()
   1277 
   1278             self._call_callback_hooks("on_sanity_check_end")

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
    197             try:
    198                 self.on_advance_start(*args, **kwargs)
--> 199                 self.advance(*args, **kwargs)
    200                 self.on_advance_end()
    201                 self._restarting = False

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py in advance(self, *args, **kwargs)
    150         if self.num_dataloaders > 1:
    151             kwargs["dataloader_idx"] = dataloader_idx
--> 152         dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
    153 
    154         # store batch level output per dataloader

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/loop.py in run(self, *args, **kwargs)
    197             try:
    198                 self.on_advance_start(*args, **kwargs)
--> 199                 self.advance(*args, **kwargs)
    200                 self.on_advance_end()
    201                 self._restarting = False

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in advance(self, data_fetcher, dl_max_batches, kwargs)
    135 
    136         # lightning module methods
--> 137         output = self._evaluation_step(**kwargs)
    138         output = self._evaluation_step_end(output)
    139 

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in _evaluation_step(self, **kwargs)
    232         """
    233         hook_name = "test_step" if self.trainer.testing else "validation_step"
--> 234         output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
    235 
    236         return output

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py in _call_strategy_hook(self, hook_name, *args, **kwargs)
   1492 
   1493         with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1494             output = fn(*args, **kwargs)
   1495 
   1496         # restore current_fx when nested context

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/strategies/strategy.py in validation_step(self, *args, **kwargs)
    388         with self.precision_plugin.val_step_context():
    389             assert isinstance(self.model, ValidationStep)
--> 390             return self.model.validation_step(*args, **kwargs)
    391 
    392     def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:

/usr/local/lib/python3.8/dist-packages/torchgeo/trainers/classification.py in validation_step(self, *args, **kwargs)
    202                 fig = datamodule.plot(sample)
    203                 summary_writer = self.logger.experiment
--> 204                 summary_writer.add_figure(
    205                     f"image/{batch_idx}", fig, global_step=self.global_step
    206                 )

AttributeError: 'ExperimentWriter' object has no attribute 'add_figure'

Could it be that our trainers only work with TensorBoardLogger and crash with CSVLogger? If so, we should fix the trainers. I'm also curious why our tests didn't catch this.

Steps to reproduce

  1. Launch Colab
  2. Edit the first cell to install main because 0.4.0 has a broken download link:
    %pip install git+https://github.com/microsoft/torchgeo.git
    
  3. Run all
  4. Restart and run all (yay jupyter)
  5. Last cell should fail

Version

0.5.0.dev0 (e81af42)

adamjstewart added the documentation and trainers labels Feb 24, 2023
adamjstewart added this to the 0.4.1 milestone Feb 24, 2023
adamjstewart changed the title from "Pretrained Weights tutorial: issue with trainer and logger?" to "Tutorials: issue with trainer and logger?" Feb 24, 2023
adamjstewart (Collaborator, Author) commented:

Same issue with the trainers tutorial.

isaaccorley (Collaborator) commented:

We probably need to add a hasattr check to verify that the logger's experiment object has an add_figure method. That would also allow users to train without the plotting overhead.
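For reference, a minimal sketch of what that guard could look like around the plotting block from the traceback above (illustrative only; the actual change landed in #1145 and may differ):

    # Sketch only: guard the figure logging inside validation_step so that loggers
    # whose ``experiment`` object lacks ``add_figure`` (e.g. CSVLogger's
    # ExperimentWriter) skip plotting instead of raising AttributeError.
    if self.logger is not None and hasattr(self.logger.experiment, "add_figure"):
        fig = datamodule.plot(sample)
        summary_writer = self.logger.experiment
        summary_writer.add_figure(
            f"image/{batch_idx}", fig, global_step=self.global_step
        )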

adamjstewart (Collaborator, Author) commented:

I'm surprised that mypy doesn't catch this; it normally complains about possibly undefined attributes like this. But yeah, I think hasattr is the easiest fix. We could also expand the try-except to catch everything, but that's bad practice.

ashnair1 (Collaborator) commented Feb 24, 2023

I believe this is because of the logger used. In the tutorial we're using CSVLogger, but add_figure is a method of the SummaryWriter that TensorBoardLogger exposes, not of CSVLogger's ExperimentWriter.
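A quick way to see the difference (sketch only; the save_dir value is a placeholder):

    from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

    # Both loggers expose an ``experiment`` attribute, but the objects behind it
    # differ: CSVLogger wraps an ExperimentWriter, TensorBoardLogger a SummaryWriter.
    csv_logger = CSVLogger(save_dir="logs")
    tb_logger = TensorBoardLogger(save_dir="logs")

    print(hasattr(csv_logger.experiment, "add_figure"))  # False
    print(hasattr(tb_logger.experiment, "add_figure"))   # True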

adamjstewart (Collaborator, Author) commented:

Yeah, but we run the exact same tutorial in our release tests and it passes, so I don't know why it only happens on Colab.

I'll try adding a test with CSVLogger to see if I can reproduce it locally.
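Roughly something like this (sketch only; the datamodule and task keyword arguments are illustrative placeholders, not the tutorial's exact configuration — the point is just passing CSVLogger and hitting validation_step once):

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import CSVLogger

    from torchgeo.datamodules import EuroSATDataModule
    from torchgeo.trainers import ClassificationTask

    # Keep the run tiny: one train batch and one val batch still exercises the
    # figure-logging branch in validation_step.
    datamodule = EuroSATDataModule(batch_size=2, num_workers=0, root="data", download=True)
    task = ClassificationTask(model="resnet18", num_classes=10, in_channels=13)
    trainer = Trainer(
        logger=CSVLogger(save_dir="logs"),
        accelerator="cpu",
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
    )
    trainer.fit(model=task, datamodule=datamodule)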

ashnair1 (Collaborator) commented:

Just switched to TensorBoardLogger in Colab and it seems to be working now.

adamjstewart (Collaborator, Author) commented:

Nope, can't reproduce locally with CSVLogger. I can fix the bug but it's frustrating that I can't test it.
