# How to use a SITS dataset

This notebook shows how to create a Satellite Image Time Series (SITS) dataset. The idea is that with this type of sampling, the sample image that is being returned has the shape `[batch, dates, channels, height, width]`.

Additional requirements

In [1]:
%pip install plotly planetary_computer pystac_client tqdm

In [1]:
%reload_ext autoreload
%autoreload 2

## Prepare our data

For this example we are using the RGB bands of the same Sentinel2 tile for 5 different dates. Note that right now we are selecting a specific orbit to ensure that our data covers the same spatial extent.

In [1]:
import planetary_computer
import pystac_client

catalog = pystac_client.Client.open(
    'https://planetarycomputer.microsoft.com/api/stac/v1',
    modifier=planetary_computer.sign_inplace,
)
area_of_interest = {
    'type': 'Polygon',
    'coordinates': [
        [
            [-148.56536865234375, 60.80072385643073],
            [-147.44338989257812, 60.80072385643073],
            [-147.44338989257812, 61.18363894915102],
            [-148.56536865234375, 61.18363894915102],
            [-148.56536865234375, 60.80072385643073],
        ]
    ],
}
time_of_interest = '2019-06-01/2019-10-01'
search = catalog.search(
    collections=['sentinel-2-l2a'],
    intersects=area_of_interest,
    datetime=time_of_interest,
    query={'eo:cloud_cover': {'lt': 13}, 'sat:relative_orbit': {'eq': 143}},
)

# Check how many items were returned
items = search.item_collection()
items

### Download data

In [6]:
import os
import tempfile
from urllib.parse import urlparse

import planetary_computer
import pystac

from torchgeo.datasets.utils import download_url

root = os.path.join(tempfile.gettempdir(), 'sentinel')
item_urls = [item.links[3].href for item in items]

for item_url in item_urls:
    item = pystac.Item.from_file(item_url)
    signed_item = planetary_computer.sign(item)
    for band in ['B02', 'B03', 'B04']:
        asset_href = signed_item.assets[band].href
        filename = urlparse(asset_href).path.split('/')[-1]
        download_url(asset_href, root, filename)

### Dataset
We define a custom dataset which is almost idential to the dataset used in `custom_raster_dataset.ipynb`. The main difference is the glob pattern, which matches any file band instead of a single band, since we are populating the index with all possible files. Lets first create a dataset like we are used to, with `return_as_ts=False`

In [2]:
import logging

import plotly.express as px
import torch

from torchgeo.datasets import RasterDataset


class Sentinel2(RasterDataset):
    # filename_glob = 'T*.tif'
    filename_regex = r'.*(?P<date>\d{8}T\d{6})_(?P<band>B0[\d])'
    date_format = '%Y%m%dT%H%M%S'
    is_image = True
    separate_files = True
    all_bands = ('B02', 'B03', 'B04', 'B08')
    rgb_bands = ('B04', 'B03', 'B02')

    def plot(self, sample, show=True):
        # Find the correct band index order
        rgb_indices = []
        for band in self.rgb_bands:
            rgb_indices.append(self.all_bands.index(band))

        # Reorder and rescale the image
        image = sample['image'][rgb_indices].permute(1, 2, 0)
        image = torch.clamp(image / 10000, min=0, max=1).numpy()

        # Plot the image
        fig = px.imshow(image)
        if show:
            fig.show()
        return fig

In [3]:
single_image_dataset = Sentinel2(root, bands=['B02', 'B03', 'B04'])

NameError: name 'root' is not defined

### Instantiate sampler
We are instantiating a random sampler. This means that we are sampling randomly both spatially, as temporally. 

In [9]:
from torch.utils.data import DataLoader

from torchgeo.datasets.utils import stack_samples
from torchgeo.samplers import RandomGeoSampler

sampler = RandomGeoSampler(single_image_dataset, size=(100, 100), length=2)

dataloader = DataLoader(
    single_image_dataset,
    sampler=sampler,
    batch_size=1,
    collate_fn=stack_samples,
    num_workers=0,
)

In [50]:
for s in sampler:
    print(s)

(slice(451150.0, 452150.0, None), slice(6753850.0, 6754850.0, None), slice(Timestamp('2019-07-06 21:15:19'), Timestamp('2019-07-06 21:15:19.999999'), None))
(slice(504790.0, 505790.0, None), slice(6715560.0, 6716560.0, None), slice(Timestamp('2019-07-06 21:15:19'), Timestamp('2019-07-06 21:15:19.999999'), None))


In [10]:
from torchgeo.datasets.utils import unbind_samples

for batch in dataloader:
    for sample in unbind_samples(batch):
        single_image_dataset.plot(sample, show=True)

Merging filpaths: ['C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190626T211519_B02_10m.tif', 'C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190626T211519_B02_10m.tif', 'C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190626T211519_B02_10m.tif']
Merging filpaths: ['C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190626T211519_B03_10m.tif', 'C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190626T211519_B03_10m.tif', 'C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190626T211519_B03_10m.tif']
Merging filpaths: ['C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190626T211519_B04_10m.tif', 'C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190626T211519_B04_10m.tif', 'C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190626T211519_B04_10m.tif']


Merging filpaths: ['C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190706T211519_B02_10m.tif', 'C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190706T211519_B02_10m.tif', 'C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190706T211519_B02_10m.tif']
Merging filpaths: ['C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190706T211519_B03_10m.tif', 'C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190706T211519_B03_10m.tif', 'C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190706T211519_B03_10m.tif']
Merging filpaths: ['C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190706T211519_B04_10m.tif', 'C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190706T211519_B04_10m.tif', 'C:\\Users\\SIEGER~1.FAL\\AppData\\Local\\Temp\\sentinel\\T06VVN_20190706T211519_B04_10m.tif']


## Create the TS dataset
Now we create the SITS dataset. The only difference is that we add the `return_as_ts=True` and add a custom plotting function that allows to visualize timeseries data.

In [7]:
class Sentinel2SITS(RasterDataset):
    filename_glob = 'T*.tif'
    filename_regex = r'.*(?P<date>\d{8}T\d{6})_(?P<band>B0[\d])'
    date_format = '%Y%m%dT%H%M%S'
    is_image = True
    separate_files = True
    rgb_bands = ('B04', 'B03', 'B02')

    def plot(self, sample, indices_to_plot=None, show=False, **kwargs):
        """
        Plots the image data from the given sample.

        Args:
            sample (dict): A dictionary containing the image data returned by self.__get_item__.
            indices_to_plot (list, optional): A list of indices to plot. If not provided, the method will use the RGB bands defined in `self.rgb_bands`.
            show (bool, optional): Whether to display the plot. Defaults to False.
            **kwargs: Additional keyword arguments to be passed to the plot function.

        Returns:
            fig: The plotly figure object.

        Raises:
            None

        """

        if indices_to_plot:
            indices = indices_to_plot
        else:
            if self.bands == self.all_bands:
                # Find the correct band index order
                indices = []
                for band in self.rgb_bands:
                    indices.append(self.all_bands.index(band))
            else:
                logging.info('No indices to plot provided, using first band by default')
                print('Here')
                indices = [0]

        print(f'Plotting bands: {[self.bands[i] for i in indices]}')

        image = sample['image']

        # Reorder and rescale the image
        if self.time_series:
            # Shape of image = [d, c, h, w]
            image = image[:, indices, :, :].permute(0, 2, 3, 1)
            if image.shape[-1] == 1:
                image = image.squeeze(-1)
            image = torch.clamp(image / 5000, min=0, max=1).numpy()

            fig = px.imshow(
                image, animation_frame=0, labels={'animation_frame': 'Date'}, **kwargs
            )
            date_labels = [
                date.strftime('%m/%d/%Y, %H:%M:%S') for date in sample['dates']
            ]
            for i, label in enumerate(date_labels):
                fig.layout.sliders[0].steps[i].label = label

        else:
            image = image[indices].permute(1, 2, 0)
            image = torch.clamp(image / 5000, min=0, max=1).numpy()

            # Plot the image
            fig = px.imshow(image)

        fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
        if show:
            fig.show()
        return fig

In [8]:
sits_dataset = Sentinel2SITS(root, bands=['B02', 'B03', 'B04'], time_series=True)

In [11]:
sits_dataset.bounds

(slice(np.float64(399960.0), np.float64(509760.0), 10.0),
 slice(np.float64(6690240.0), np.float64(6800040.0), 10.0),
 slice(Timestamp('2019-06-26 21:15:19'), Timestamp('2019-08-20 21:15:21.999999'), 1))

In [20]:
import pandas as pd

sample = sits_dataset[
    slice(451150.0, 452150.0, None),
    slice(6753850.0, 6754850.0, None),
    slice(
        pd.Timestamp('2019-06-26 00:00:00'), pd.Timestamp('2019-09-06 21:15:19.999999')
    ),
]

In [21]:
sample['image'].shape

torch.Size([4, 3, 100, 100])

In [22]:
sits_dataset.plot(sample, indices_to_plot=[0, 1, 2], show=False)

Plotting bands: ['B02', 'B03', 'B04']
