Download Sentinel-2 L2A bands from the Microsoft Planetary Computer

In [None]:
import os
from urllib.parse import urlparse

import matplotlib.pyplot as plt
import planetary_computer
import pystac
import torch
from torch.utils.data import DataLoader

from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples
from torchgeo.datasets.utils import download_url
from torchgeo.samplers import RandomGeoSampler

%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 12)

In [None]:
# Local directory that receives the downloaded per-band .tif files.
root = '/data/sentinel'

In [None]:
# Planetary Computer STAC item URLs to fetch; append more URLs to fetch more scenes.
item_urls = [
    'https://planetarycomputer.microsoft.com/api/stac/v1/collections/sentinel-2-l2a/items/S2B_MSIL2A_20241101T170349_R069_T15TWG_20241101T204038',
]

for item_url in item_urls:
    item = pystac.Item.from_file(item_url)
    # Signing rewrites the asset hrefs so they can be fetched directly.
    signed_item = planetary_computer.sign(item)
    for band in ('B02', 'B03', 'B04', 'B08'):
        asset_href = signed_item.assets[band].href
        # Keep only the last path segment as the local file name; urlparse().path
        # drops the query string (presumably where the signature lives).
        filename = urlparse(asset_href).path.rsplit('/', 1)[-1]
        download_url(asset_href, root, filename)

Load Sentinel-2 as a custom RasterDataset

In [None]:
class Sentinel2(RasterDataset):
    """Four-band (B02, B03, B04, B08) Sentinel-2 raster dataset.

    Each band lives in its own file (``separate_files``); scenes are discovered
    through their B02 file and the sibling bands are located by substituting
    the ``band`` group of ``filename_regex``.
    """

    # Only the B02 file of each scene is globbed; other bands are derived from it.
    filename_glob = 'T*_B02_10m.tif'
    # 6-char prefix (presumably the tile id, e.g. T15TWG), timestamp, band name.
    filename_regex = r'^.{6}_(?P<date>\d{8}T\d{6})_(?P<band>B0[\d])'
    date_format = '%Y%m%dT%H%M%S'
    is_image = True
    separate_files = True
    all_bands = ('B02', 'B03', 'B04', 'B08')
    rgb_bands = ('B04', 'B03', 'B02')

    def plot(self, sample):
        """Show an RGB composite of ``sample['image']`` and return the figure.

        Args:
            sample: Dict whose ``'image'`` tensor is band-first, ordered as
                ``all_bands``.

        Returns:
            The newly created matplotlib figure.
        """
        band_order = [self.all_bands.index(name) for name in self.rgb_bands]

        # Band-first -> channels-last, then map reflectance DNs to [0, 1].
        # NOTE(review): L2A products from processing baseline 04.00 onward are
        # believed to carry a +1000 DN offset; confirm whether it should be
        # subtracted before scaling.
        rgb = sample['image'][band_order].permute(1, 2, 0)
        rgb = (rgb / 10000).clamp(min=0, max=1).numpy()

        figure, axes = plt.subplots()
        axes.imshow(rgb)
        return figure

Visualize random Sentinel-2 samples

In [None]:
# Seed torch's global RNG so the random windows are repeatable.
torch.manual_seed(1)

dataset = Sentinel2(root)
# Three random windows of size 4096; the DataLoader keeps its default batch size of 1.
sampler = RandomGeoSampler(dataset, length=3, size=4096)
dataloader = DataLoader(dataset, collate_fn=stack_samples, sampler=sampler)

for batch in dataloader:
    # Split the collated batch back into per-sample dicts and plot the first.
    sample = unbind_samples(batch)[0]
    dataset.plot(sample)
    plt.gca().axis('off')
    plt.show()

Define the Sentinel-2 + CDL datamodule

In [None]:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Sentinel-2 and CDL datamodule."""

from typing import Any, Optional, Union

import kornia.augmentation as K
import torch
from kornia.constants import DataKey, Resample
from matplotlib.figure import Figure

from torchgeo.datasets import CDL, Sentinel2, random_grid_cell_assignment
from torchgeo.samplers import GridGeoSampler, RandomBatchGeoSampler
from torchgeo.samplers.utils import _to_tuple
from torchgeo.transforms import AugmentationSequential
from torchgeo.datamodules.geo import GeoDataModule

import os
import tempfile
from urllib.parse import urlparse

import matplotlib.pyplot as plt
import planetary_computer
import pystac
import torch
from torch.utils.data import DataLoader

from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples
from torchgeo.datasets.utils import download_url
from torchgeo.samplers import RandomGeoSampler

%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 12)

class Sentinel2Custom(RasterDataset):
    """Four-band (B02, B03, B04, B08) Sentinel-2 raster dataset.

    Presumably given a distinct name because this cell also imports torchgeo's
    own ``Sentinel2`` class, which would otherwise shadow a custom class of
    that name.
    """

    # Only the B02 file of each scene is globbed; other bands are derived from it.
    filename_glob = 'T*_B02_10m.tif'
    # 6-char prefix (presumably the tile id, e.g. T15TWG), timestamp, band name.
    filename_regex = r'^.{6}_(?P<date>\d{8}T\d{6})_(?P<band>B0[\d])'
    date_format = '%Y%m%dT%H%M%S'
    is_image = True
    separate_files = True
    all_bands = ('B02', 'B03', 'B04', 'B08')
    rgb_bands = ('B04', 'B03', 'B02')

    def plot(self, sample):
        """Show an RGB composite of ``sample['image']`` and return the figure.

        Args:
            sample: Dict whose ``'image'`` tensor is band-first, ordered as
                ``all_bands``.

        Returns:
            The newly created matplotlib figure.
        """
        band_order = [self.all_bands.index(name) for name in self.rgb_bands]

        # Band-first -> channels-last, then map reflectance DNs to [0, 1].
        # NOTE(review): L2A products from processing baseline 04.00 onward are
        # believed to carry a +1000 DN offset; confirm whether it should be
        # subtracted before scaling.
        rgb = sample['image'][band_order].permute(1, 2, 0)
        rgb = (rgb / 10000).clamp(min=0, max=1).numpy()

        figure, axes = plt.subplots()
        axes.imshow(rgb)
        return figure

class Sentinel2CDLDataModule(GeoDataModule):
    """LightningDataModule implementation for the Sentinel-2 and CDL datasets.

    Images come from :class:`Sentinel2Custom` (the 4-band dataset defined in
    this cell) and masks from :class:`~torchgeo.datasets.CDL`; :meth:`setup`
    combines the two with ``&`` and splits the result into train/val/test.

    .. versionadded:: 0.6
    """

    def __init__(
        self,
        batch_size: int = 64,
        patch_size: Union[int, tuple[int, int]] = 64,
        length: Optional[int] = None,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a new Sentinel2CDLDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            patch_size: Size of each patch, either ``size`` or ``(height, width)``.
            length: Length of each training epoch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.CDL` (prefix keys with ``cdl_``) and
                :class:`Sentinel2Custom` (prefix keys with ``sentinel2_``).
                Keys with neither prefix are ignored and trigger a warning.
        """
        import warnings

        # Route prefixed kwargs to the dataset they belong to, stripping the prefix.
        cdl_prefix = "cdl_"
        sentinel2_prefix = "sentinel2_"
        self.cdl_kwargs: dict[str, Any] = {}
        self.sentinel2_kwargs: dict[str, Any] = {}
        unrecognized = []

        for key, val in kwargs.items():
            if key.startswith(cdl_prefix):
                self.cdl_kwargs[key.removeprefix(cdl_prefix)] = val
            elif key.startswith(sentinel2_prefix):
                self.sentinel2_kwargs[key.removeprefix(sentinel2_prefix)] = val
            else:
                unrecognized.append(key)

        if unrecognized:
            # Without this, an unprefixed key (e.g. a typo like ``paths=``)
            # would be dropped with no feedback at all.
            warnings.warn(
                "Ignoring keyword arguments without a 'cdl_' or 'sentinel2_' "
                f"prefix: {sorted(unrecognized)}",
                UserWarning,
                stacklevel=2,
            )

        # CDL is handed to the base class together with the CDL-only kwargs.
        super().__init__(
            CDL, batch_size, patch_size, length, num_workers, **self.cdl_kwargs
        )

        # self.mean / self.std are not defined here; they come from the base class.
        self.train_aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std),
            K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
            K.RandomVerticalFlip(p=0.5),
            K.RandomHorizontalFlip(p=0.5),
            data_keys=["image", "mask"],
            # Nearest-neighbour resampling keeps mask values as valid class ids.
            extra_args={
                DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None}
            },
        )

        self.aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"]
        )

    def setup(self, stage: str) -> None:
        """Set up datasets and samplers.

        Builds the Sentinel-2 and CDL datasets, intersects them, splits the
        intersection 80/10/10 into train/val/test by random grid-cell
        assignment with a fixed seed, and creates the samplers for ``stage``.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        import logging

        self.sentinel2 = Sentinel2Custom(**self.sentinel2_kwargs)
        self.cdl = CDL(**self.cdl_kwargs)

        # Index/CRS diagnostics (useful when the two datasets fail to overlap),
        # logged at DEBUG instead of printing the whole spatial index every call.
        logger = logging.getLogger(__name__)
        logger.debug(
            "Sentinel-2 index=%s crs=%s", self.sentinel2.index, self.sentinel2.crs
        )
        logger.debug("CDL index=%s crs=%s", self.cdl.index, self.cdl.crs)

        self.dataset = self.sentinel2 & self.cdl

        # A freshly seeded generator makes the split identical across setup() calls.
        generator = torch.Generator().manual_seed(0)

        (self.train_dataset, self.val_dataset, self.test_dataset) = (
            random_grid_cell_assignment(
                self.dataset, [0.8, 0.10, 0.10], grid_size=8, generator=generator
            )
        )
        if stage == "fit":
            self.train_batch_sampler = RandomBatchGeoSampler(
                self.train_dataset, self.patch_size, self.batch_size, self.length
            )
        if stage in ("fit", "validate"):
            self.val_sampler = GridGeoSampler(
                self.val_dataset, self.patch_size, self.patch_size
            )
        if stage == "test":
            self.test_sampler = GridGeoSampler(
                self.test_dataset, self.patch_size, self.patch_size
            )

    def plot(self, *args: Any, **kwargs: Any) -> Figure:
        """Run CDL plot method.

        ``self.cdl`` is only assigned in :meth:`setup`, so call that first.

        Args:
            *args: Arguments passed to plot method.
            **kwargs: Keyword arguments passed to plot method.

        Returns:
            A matplotlib Figure with the image, ground truth, and predictions.
        """
        return self.cdl.plot(*args, **kwargs)

In [None]:
# ``sentinel2_*`` keys go to the Sentinel-2 dataset and ``cdl_*`` keys to CDL;
# the unprefixed keys configure the datamodule itself.
# NOTE(review): ``crs`` is passed as a plain string here; confirm torchgeo
# handles it everywhere a rasterio CRS object is expected.
datamodule = Sentinel2CDLDataModule(
    batch_size=64,
    patch_size=224,
    sentinel2_crs="epsg:4326",
    sentinel2_paths="/data/sentinel",
    cdl_crs="epsg:4326",
    cdl_paths="/data/datatorchgeo",
)

In [None]:
# Build the datasets and the training sampler, then fetch the training DataLoader.
datamodule.setup(stage='fit')
cdl_tr_dl = datamodule.train_dataloader()