In [1]:
from torchgeo.datasets import Sentinel2

# Relative path to the sample Sentinel-2 files, resolved against the current
# working directory (presumably the repository root — adjust if running the
# notebook from elsewhere). Forward slashes are accepted as a separator on
# Windows as well as Linux/macOS, unlike a backslash-separated path.
data_dir = "tests/data/sentinel2"

# Index four bands at resolution 10 (in the dataset CRS units); see the
# Sentinel2 documentation for the meaning of `cache`.
ds = Sentinel2(data_dir, bands=["B02", "B03", "B04", "B08"], cache=False, res=10)


Populating index


100%|██████████| 16/16 [00:00<00:00, 76.25it/s]


In [2]:
# Spatiotemporal extent of the dataset, displayed as a BoundingBox.
dataset_bounds = ds.bounds
dataset_bounds

BoundingBox(minx=399960.0, maxx=401240.0, miny=4498720.0, maxy=4500000.0, mint=1555079321.0, maxt=1649927271.999999)

In [3]:
from torchgeo.datasets.utils import BoundingBox
# A 128 x 128 (CRS units) tile anchored at the dataset's (minx, miny) corner,
# covering the dataset's entire time range (mint/maxt match `ds.bounds`).
# Passing a one-element list yields an image stack with a leading axis of 1.
full_t_query = BoundingBox(
    minx=399960.0,
    maxx=400088.0,
    miny=4498720.0,
    maxy=4498848.0,
    mint=1555079321.0,
    maxt=1649927271.999999,
)
sample = ds[[full_t_query]]

print(sample["image"].shape)
print(sample["dates"])

torch.Size([1, 4, 13, 13])
[[datetime.datetime(2022, 4, 12, 16, 28, 41), datetime.datetime(2019, 4, 12, 16, 28, 41), datetime.datetime(2019, 4, 14, 11, 7, 51), datetime.datetime(2022, 4, 14, 11, 7, 51)]]


In [4]:
# The same spatial tile queried over two back-to-back time windows; each
# window is one list entry and becomes one entry of the returned stack.
# NOTE(review): both windows include the instant 1605264929, so a file stamped
# exactly then would be returned in both — confirm this is intended.
tile_extent = {"minx": 399960.0, "maxx": 400088.0, "miny": 4498720.0, "maxy": 4498848.0}
time_windows = [
    (1555079321.0, 1605264929),
    (1605264929.0, 1649927272),
]
multi_t_query = [
    BoundingBox(**tile_extent, mint=window_start, maxt=window_end)
    for window_start, window_end in time_windows
]

sample = ds[multi_t_query]
print(sample["image"].shape)
print(sample["dates"])

torch.Size([2, 4, 13, 13])
[[datetime.datetime(2019, 4, 12, 16, 28, 41), datetime.datetime(2019, 4, 14, 11, 7, 51)], [datetime.datetime(2022, 4, 12, 16, 28, 41), datetime.datetime(2022, 4, 14, 11, 7, 51)]]


In [None]:
import torch
import plotly.express as px

def plot(
    sample: dict,
    indices_to_plot,
    show: bool = False,
    scale: float = 10000.0,
    **kwargs,
):
    """Plot selected channels of a sample, animating over its first axis if needed.

    Args:
        sample (dict): A sample obtained by indexing the dataset with a list of
            bounding boxes (e.g. ``ds[[query]]``). Must contain ``"image"``, a
            tensor shaped ``[c, h, w]`` or ``[d, c, h, w]`` (one entry per
            query). May contain ``"dates"``, a list holding one list of
            datetimes per entry, used to label the animation frames.
        indices_to_plot (int | list[int]): Channel index or indices to display,
            e.g. ``[2, 1, 0]`` for a colour composite. Select three (or four)
            channels for a colour image, or a single channel for a heatmap.
        show (bool, optional): Whether to display the figure. Defaults to False.
        scale (float, optional): Divisor applied to the raw pixel values before
            they are clamped to [0, 1]. Defaults to 10000.
        **kwargs: Additional keyword arguments passed to ``px.imshow``. For
            animated plots a ``labels`` dict is merged over the default
            ``{"animation_frame": "Date"}`` instead of clashing with it.

    Returns:
        The plotly figure object.

    Raises:
        ValueError: If ``sample["image"]`` is neither 3-D nor 4-D.
    """
    image = sample["image"]
    if image.ndim not in (3, 4):
        raise ValueError(
            "Expected a 3-D [c, h, w] or 4-D [d, c, h, w] image, "
            f"got shape {tuple(image.shape)}"
        )
    if isinstance(indices_to_plot, int):
        indices_to_plot = [indices_to_plot]

    # A stack with a single entry (e.g. from ``ds[[query]]``) has nothing to
    # animate; drop the leading axis so it is drawn as a still image. Without
    # this, channel indices would be applied to the size-1 leading axis.
    if image.ndim == 4 and image.shape[0] == 1:
        image = image[0]
    animated = image.ndim == 4

    # Select channels and move them last, the layout px.imshow expects:
    # [d, c, h, w] -> [d, h, w, c] or [c, h, w] -> [h, w, c].
    if animated:
        image = image[:, indices_to_plot, :, :].permute(0, 2, 3, 1)
    else:
        image = image[indices_to_plot].permute(1, 2, 0)
    # px.imshow takes single-channel data without a trailing channel axis.
    if image.shape[-1] == 1:
        image = image.squeeze(-1)
    image = torch.clamp(image / scale, min=0, max=1).numpy()

    if animated:
        labels = {"animation_frame": "Date", **(kwargs.pop("labels", None) or {})}
        fig = px.imshow(image, animation_frame=0, labels=labels, **kwargs)
        # TODO: an entry may hold several dates; only the first one is shown.
        frame_dates = sample.get("dates") or []
        if fig.layout.sliders:
            for step, dates in zip(fig.layout.sliders[0].steps, frame_dates):
                if dates:
                    step.label = dates[0].strftime("%m/%d/%Y, %H:%M:%S")
    else:
        fig = px.imshow(image, **kwargs)

    fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
    if show:
        fig.show()
    return fig

In [12]:
# Channels 2, 1, 0 of `sample` correspond to bands B04, B03, B02 in the band
# list the dataset was built with; the figure is returned for rich display.
rgb_figure = plot(sample, [2, 1, 0], show=False)
rgb_figure