Skip to content

Commit

Permalink
Add own plot method to CMSMangroveCanopy (#427)
Browse files Browse the repository at this point in the history
* add own plot method

* Update torchgeo/datasets/cms_mangrove_canopy.py

* Update cms_mangrove_canopy.py

* whitespace

* Removing versionchanged

* Any instead of Tensor

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
  • Loading branch information
nilsleh and calebrob6 committed Feb 26, 2022
1 parent 7e724dc commit 3cc9ef9
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
8 changes: 7 additions & 1 deletion tests/datasets/test_cms_mangrove_canopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,10 @@ def test_or(self, dataset: CMSGlobalMangroveCanopy) -> None:
def test_plot(self, dataset: CMSGlobalMangroveCanopy) -> None:
query = dataset.bounds
x = dataset[query]
dataset.plot(x["mask"])
dataset.plot(x, suptitle="Test")

def test_plot_prediction(self, dataset: CMSGlobalMangroveCanopy) -> None:
query = dataset.bounds
x = dataset[query]
x["prediction"] = x["mask"].clone()
dataset.plot(x, suptitle="Prediction")
46 changes: 46 additions & 0 deletions torchgeo/datasets/cms_mangrove_canopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
from typing import Any, Callable, Dict, Optional

import matplotlib.pyplot as plt
from rasterio.crs import CRS

from .geo import RasterDataset
Expand Down Expand Up @@ -249,3 +250,48 @@ def _extract(self) -> None:
"""Extract the dataset."""
pathname = os.path.join(self.root, self.zipfile)
extract_archive(pathname)

def plot( # type: ignore[override]
self,
sample: Dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`RasterDataset.__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
Returns:
a matplotlib Figure with the rendered sample
"""
mask = sample["mask"].squeeze()
ncols = 1

showing_predictions = "prediction" in sample
if showing_predictions:
pred = sample["prediction"].squeeze()
ncols = 2

fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4, 4))

if showing_predictions:
axs[0].imshow(mask)
axs[0].axis("off")
axs[1].imshow(pred)
axs[1].axis("off")
if show_titles:
axs[0].set_title("Mask")
axs[1].set_title("Prediction")
else:
axs.imshow(mask)
axs.axis("off")
if show_titles:
axs.set_title("Mask")

if suptitle is not None:
plt.suptitle(suptitle)

return

0 comments on commit 3cc9ef9

Please sign in to comment.