Encountered 'NoneType' object has no attribute 'canvas' error while training a semantic segmentation model #1551

Closed
adamjstewart opened this issue Sep 7, 2023 Discussed in #1487 · 1 comment · Fixed by #1585
Labels: trainers (PyTorch Lightning trainers)

@adamjstewart
Collaborator

Discussed in #1487

Originally posted by gtgrp-user July 19, 2023
I encountered the following error while training a semantic segmentation model using a GeoDataModule.

Constants used for training:

# Imports needed by the snippets below (inferred from the stack trace paths;
# not part of the original report).
import os

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torchgeo.trainers import SemanticSegmentationTask

accelerator = "gpu" if torch.cuda.is_available() else "cpu"
default_root_dir = os.path.join('./', "full_experiments_geodataset_v1")
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss", dirpath=default_root_dir, save_top_k=1, save_last=True
)
early_stopping_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=10)
logger = TensorBoardLogger(save_dir=default_root_dir, name="tutorial_logs")

num_workers = 4
max_epochs = 100
fast_dev_run = False

torch.set_float32_matmul_precision('high')

Creating the semantic segmentation task and the corresponding trainer:

task = SemanticSegmentationTask(
    loss="jaccard",
    model="deeplabv3+",
    backbone="resnet50",
    weights='imagenet',
    in_channels=10,
    num_classes=7,
    learning_rate=0.001,
    learning_rate_schedule_patience=5,
    ignore_index=None
)

trainer = Trainer(
    accelerator=accelerator,
    callbacks=[checkpoint_callback, early_stopping_callback],
    enable_progress_bar=True,
    fast_dev_run=fast_dev_run,
    log_every_n_steps=1,
    logger=logger,
    min_epochs=20,
    max_epochs=max_epochs,
)

The error is raised by this call:

trainer.fit(task, datamodule=datamodule)

Stack trace:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[21], line 1
----> 1 trainer.fit(task, datamodule = datamodule)

File /usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py:608, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    606 model = self._maybe_unwrap_optimized(model)
    607 self.strategy._lightning_module = model
--> 608 call._call_and_handle_interrupt(
    609     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    610 )

File /usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py:38, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     36         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     37     else:
---> 38         return trainer_fn(*args, **kwargs)
     40 except _TunerExitException:
     41     trainer._call_teardown_hook()

File /usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py:650, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    643 ckpt_path = ckpt_path or self.resume_from_checkpoint
    644 self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
    645     self.state.fn,
    646     ckpt_path,  # type: ignore[arg-type]
    647     model_provided=True,
    648     model_connected=self.lightning_module is not None,
    649 )
--> 650 self._run(model, ckpt_path=self.ckpt_path)
    652 assert self.state.stopped
    653 self.training = False

File /usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py:1112, in Trainer._run(self, model, ckpt_path)
   1108 self._checkpoint_connector.restore_training_state()
   1110 self._checkpoint_connector.resume_end()
-> 1112 results = self._run_stage()
   1114 log.detail(f"{self.__class__.__name__}: trainer tearing down")
   1115 self._teardown()

File /usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py:1191, in Trainer._run_stage(self)
   1189 if self.predicting:
   1190     return self._run_predict()
-> 1191 self._run_train()

File /usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py:1204, in Trainer._run_train(self)
   1201 self._pre_training_routine()
   1203 with isolate_rng():
-> 1204     self._run_sanity_check()
   1206 # enable train mode
   1207 assert self.model is not None

File /usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py:1276, in Trainer._run_sanity_check(self)
   1274 # run eval step
   1275 with torch.no_grad():
-> 1276     val_loop.run()
   1278 self._call_callback_hooks("on_sanity_check_end")
   1280 # reset logger connector

File /usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/loop.py:199, in Loop.run(self, *args, **kwargs)
    197 try:
    198     self.on_advance_start(*args, **kwargs)
--> 199     self.advance(*args, **kwargs)
    200     self.on_advance_end()
    201     self._restarting = False

File /usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/dataloader/evaluation_loop.py:152, in EvaluationLoop.advance(self, *args, **kwargs)
    150 if self.num_dataloaders > 1:
    151     kwargs["dataloader_idx"] = dataloader_idx
--> 152 dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
    154 # store batch level output per dataloader
    155 self._outputs.append(dl_outputs)

File /usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/loop.py:199, in Loop.run(self, *args, **kwargs)
    197 try:
    198     self.on_advance_start(*args, **kwargs)
--> 199     self.advance(*args, **kwargs)
    200     self.on_advance_end()
    201     self._restarting = False

File /usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py:137, in EvaluationEpochLoop.advance(self, data_fetcher, dl_max_batches, kwargs)
    134 self.batch_progress.increment_started()
    136 # lightning module methods
--> 137 output = self._evaluation_step(**kwargs)
    138 output = self._evaluation_step_end(output)
    140 self.batch_progress.increment_processed()

File /usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py:234, in EvaluationEpochLoop._evaluation_step(self, **kwargs)
    223 """The evaluation step (validation_step or test_step depending on the trainer's state).
    224 
    225 Args:
   (...)
    231     the outputs of the step
    232 """
    233 hook_name = "test_step" if self.trainer.testing else "validation_step"
--> 234 output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
    236 return output

File /usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py:1494, in Trainer._call_strategy_hook(self, hook_name, *args, **kwargs)
   1491     return
   1493 with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1494     output = fn(*args, **kwargs)
   1496 # restore current_fx when nested context
   1497 pl_module._current_fx_name = prev_fx_name

File /usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/strategy.py:390, in Strategy.validation_step(self, *args, **kwargs)
    388 with self.precision_plugin.val_step_context():
    389     assert isinstance(self.model, ValidationStep)
--> 390     return self.model.validation_step(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torchgeo/trainers/segmentation.py:211, in SemanticSegmentationTask.validation_step(self, *args, **kwargs)
    209     fig = datamodule.plot(sample)
    210     summary_writer = self.logger.experiment
--> 211     summary_writer.add_figure(
    212         f"image/{batch_idx}", fig, global_step=self.global_step
    213     )
    214     plt.close()
    215 except ValueError:

File /usr/local/lib/python3.10/dist-packages/torch/utils/tensorboard/writer.py:750, in SummaryWriter.add_figure(self, tag, figure, global_step, close, walltime)
    740     self.add_image(
    741         tag,
    742         figure_to_image(figure, close),
   (...)
    745         dataformats="NCHW",
    746     )
    747 else:
    748     self.add_image(
    749         tag,
--> 750         figure_to_image(figure, close),
    751         global_step,
    752         walltime,
    753         dataformats="CHW",
    754     )

File /usr/local/lib/python3.10/dist-packages/torch/utils/tensorboard/_utils.py:35, in figure_to_image(figures, close)
     33     return np.stack(images)
     34 else:
---> 35     image = render_to_rgb(figures)
     36     return image

File /usr/local/lib/python3.10/dist-packages/torch/utils/tensorboard/_utils.py:24, in figure_to_image.<locals>.render_to_rgb(figure)
     22 canvas.draw()
     23 data = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
---> 24 w, h = figure.canvas.get_width_height()
     25 image_hwc = data.reshape([h, w, 4])[:, :, 0:3]
     26 image_chw = np.moveaxis(image_hwc, source=2, destination=0)

AttributeError: 'NoneType' object has no attribute 'canvas'
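
The traceback ends inside TensorBoard's figure_to_image, which dereferences figure.canvas on whatever object was passed to add_figure, so the failure occurs whenever datamodule.plot(sample) returns None. Below is a minimal standalone sketch of that failure mode; it is illustrative rather than taken from the report, and the log directory name is arbitrary:

from torch.utils.tensorboard import SummaryWriter

# Passing None where a matplotlib Figure is expected, as happens when
# datamodule.plot(sample) returns None, should raise the AttributeError
# seen above (the exact message may vary with the matplotlib version).
writer = SummaryWriter(log_dir="/tmp/none_figure_repro")
writer.add_figure("image/0", None, global_step=0)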
@adamjstewart
Collaborator Author

As discussed in #1487:

I think this bug was introduced in #992 when I switched from catching AttributeError to ValueError. Our datamodules check whether a plot method is defined, but our trainers don't check whether a real figure was returned.

The fix is easy (either check the return value or catch more exceptions), but the thing holding this up is writing a test that prevents this issue from recurring. @gtgrp-user, if you can provide an MWE to reproduce this or figure out how to get the tests to fail, that will make my life easier.
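
For illustration, the "check the return value" option might look like the sketch below. It paraphrases the logging block shown in the traceback (datamodule, sample, batch_idx, self.logger, and plt are the names from torchgeo/trainers/segmentation.py); only the if fig is not None guard is new, and this is not the actual patch from #1585:

try:
    fig = datamodule.plot(sample)
    # Only log when plot() actually produced a figure; the datamodule's plot()
    # can return None, e.g. when the underlying dataset defines no plot method.
    if fig is not None:
        summary_writer = self.logger.experiment
        summary_writer.add_figure(
            f"image/{batch_idx}", fig, global_step=self.global_step
        )
        plt.close()
except ValueError:
    pass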
