# Contribute a new Non-Geospatial Dataset

Open-source datasets have significantly accelerated machine learning research. In geospatial machine learning, datasets can be significantly more complex to handle and load than more standard RGB-based vision datasets. To spare the community from implementing the same loading logic over and over again, TorchGeo supports dozens of datasets that can be downloaded and made ready for use in a PyTorch framework with a single line of code. This tutorial shows how you can add a Non-Geospatial Dataset to this growing collection. As a reminder, TorchGeo differentiates between two dataset types: Geospatial and Non-Geospatial datasets. Non-Geospatial datasets are indexed by integers, like the datasets one might be familiar with from torchvision, while Geospatial datasets are indexed via geospatial coordinate bounding boxes. Non-Geospatial datasets can still return geospatial and other metadata and should be specific to the remote sensing domain.

## Where to start

If you are interested in an overview of existing geospatial datasets, see for example the [Satellite-Image-Deep-Learning](https://github.com/satellite-image-deep-learning/datasets) page, which contains a list of datasets and links to other dataset lists.

Two aspects make it a lot easier to add a dataset: whether the dataset can be easily downloaded, and whether it comes with a GitHub repository and publication that outline how the authors intend the dataset to be used. These are not necessary criteria. Sometimes it can be even more worthwhile to add a dataset without an existing code base, precisely because the marginal contribution to the community is greater: users of the dataset no longer need to write the loading implementation from scratch.

## Adding the dataset

Once you have identified a dataset that you would like to add to TorchGeo, determine which application category it roughly falls into. For example, a segmentation dataset based on a collection of *.png* files is handled differently from a classification dataset based on pre-defined image chips in *.tif* files. In the latter case, if the *.tif* files have very large spatial dimensions, such that loading a single file is costly, consider adding the dataset as a `Geospatial` dataset for easier indexing. Once you have identified the task (such as segmentation vs. classification) and the dataset format, check whether a dataset of the same or a similar category already exists in TorchGeo. All datasets inherit from a base class (for non-geospatial datasets, `NonGeoDataset` or one of its subclasses such as `NonGeoClassificationDataset`) that outlines the implementation logic, and existing utility functions should be reused wherever possible. This reduces code duplication and makes it easier to unit test datasets.

Adding a dataset to TorchGeo consists of roughly four parts:
- a `dataset_name.py` file itself that implements the logic of the dataset
- a `data.py` file that creates dummy data in the same structure and format as the original dataset for unit tests
- a `test_dataset_name.py` file that implements unit tests for the dataset
- entries in the documentation files `non_geo_datasets.csv` and `datasets.rst`
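
As a rough orientation, the following sketch maps a dataset module name to the files these parts live in, plus the `torchgeo/datasets/__init__.py` registration mentioned in the final checklist. The helper `contribution_files` is hypothetical, and the paths assume the current layout of the TorchGeo repository, so double-check them against the checkout you are working in.

```python
from pathlib import PurePosixPath


def contribution_files(module_name: str) -> dict[str, PurePosixPath]:
    """Return the (assumed) repository paths touched when adding a dataset.

    module_name is the snake_case module name, e.g. 'my_new_dataset'.
    """
    return {
        # implementation of the dataset class
        'dataset': PurePosixPath('torchgeo/datasets') / f'{module_name}.py',
        # registration of the new class so it can be imported from torchgeo.datasets
        'init': PurePosixPath('torchgeo/datasets/__init__.py'),
        # script that generates the dummy test data
        'dummy_data': PurePosixPath('tests/data') / module_name / 'data.py',
        # unit tests
        'tests': PurePosixPath('tests/datasets') / f'test_{module_name}.py',
        # documentation entries
        'docs_rst': PurePosixPath('docs/api/datasets.rst'),
        'docs_csv': PurePosixPath('docs/api/datasets/non_geo_datasets.csv'),
    }


for part, path in contribution_files('my_new_dataset').items():
    print(f'{part:>10}: {path}')
```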

## The `dataset_name.py` file

This file implements the logic to load a sample from the dataset, as well as to download the dataset automatically if possible. The new dataset inherits from a base class, and the class docstring should contain:

- a short summary of the dataset
- the dataset's features, such as the task
- the format the dataset comes in, e.g., file types, pixel dimensions, etc.
- a proper reference to the dataset, such as a link to the paper, so users can adequately cite the dataset when using it
- if required, a note about additional dependencies that are not part of TorchGeo's dependencies

The dataset implementation itself should contain:

- a method to create an index structure the dataset can iterate over to load samples. This index structure also defines the length (`__len__`) of the dataset, i.e., how many individual samples can be loaded from the dataset
- a `__getitem__` method that takes an integer index argument, loads a sample of the dataset, and returns its components in a dictionary
- a `_verify` method that checks whether the dataset can be found on the filesystem, has already been downloaded and only needs to be extracted, or downloads and extracts the dataset from the web
- a `plot` method that can visually display a single sample of the dataset
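
The `_verify` logic described above usually follows a three-step fallback. Below is a minimal, framework-free sketch of that control flow, assuming the dataset ships as a single zip archive. The names `verify`, `directory`, `filename`, and the `download_fn` callback are hypothetical; in a real TorchGeo dataset you would reuse the helpers in `torchgeo.datasets.utils` and raise `DatasetNotFoundError` instead of the stand-in exception used here.

```python
import os
import zipfile
from collections.abc import Callable


class DatasetNotFound(FileNotFoundError):
    """Stand-in for TorchGeo's DatasetNotFoundError in this sketch."""


def verify(
    root: str,
    directory: str,
    filename: str,
    download: bool,
    download_fn: Callable[[str], None],
) -> str:
    """Make sure the extracted dataset exists under root and report how it was found.

    Args:
        root: root directory of the dataset
        directory: name of the extracted dataset directory (hypothetical)
        filename: name of the zip archive (hypothetical)
        download: whether downloading is allowed
        download_fn: callback that places the archive in root
    """
    # 1. The dataset is already extracted: nothing to do.
    if os.path.isdir(os.path.join(root, directory)):
        return 'extracted'

    archive = os.path.join(root, filename)

    def extract() -> None:
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(root)

    # 2. The archive was downloaded before and only needs to be extracted.
    if os.path.isfile(archive):
        extract()
        return 'archive'

    # 3. Nothing on disk: download only if the user explicitly allowed it.
    if not download:
        raise DatasetNotFound(f'Dataset not found in {root!r} and download=False')
    download_fn(root)
    extract()
    return 'downloaded'
```

Keeping `download=False` as the default and raising an error otherwise means nothing is fetched from the web unless the user asks for it.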

The code below roughly outlines the parts required for a new `NonGeoDataset`. The specifics depend heavily on the type of dataset you want to add, but this template and other existing datasets should give you a decent starting point.

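```python
from collections.abc import Callable

from matplotlib.figure import Figure
from torch import Tensor

from torchgeo.datasets import NonGeoDataset
from torchgeo.datasets.utils import Path


class MyNewDataset(NonGeoDataset):
    """MyNewDataset dataset.

    Short summary of the dataset.

    Dataset features:

    * number of classes, different sensors, area covered, etc.

    Dataset format:

    * what file format and shape the input data comes in
    * what file format and shape the target data comes in
    * possible metadata files

    If you use this dataset in your research, please cite the following paper:

    * link to the publication

    .. versionadded:: X.Y
    """

    # In this part of the code you can define class attributes such as a list of
    # class names, color maps, URLs and checksums for the data download, and other
    # attributes that the subsequent logic requires repeatedly.

    def __init__(
        self,
        root: Path = 'data',
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
    ) -> None:
        """Initialize a new MyNewDataset instance.

        The init arguments can include additional arguments, for example a dedicated
        split of the data, the ability to subselect specific modalities or individual
        bands, or other arguments that give dedicated control over the dataset, e.g.,
        to reproduce experiments from the publication. They should have reasonable
        defaults.

        Args:
            root: root directory where the dataset is stored
            transforms: a function/transform that takes an input sample
                and returns a transformed version
            download: if True, download the dataset if it is not found in root

        Raises:
            DatasetNotFoundError: if the dataset is not found and download is False
        """
        self.root = root
        self.transforms = transforms
        self.download = download

        # check for the dataset on disk, and extract or download it if necessary
        self._verify()
        # build the index structure that __len__ and __getitem__ rely on
        self.files = self._load_files()

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Check whether the extracted files exist; otherwise, whether the archive has
        already been downloaded and only needs to be extracted; otherwise, download
        it if self.download is True, and raise a DatasetNotFoundError if it is not.
        """

    def _load_files(self) -> list[dict[str, str]]:
        """Create the index structure of the dataset.

        This method should create a data structure from which individual samples can
        be retrieved and the total number of samples can be derived, for example a
        list of dicts mapping 'image' and 'label' to file paths. You can name the
        method as you like, but sticking with the naming schemes common across other
        implemented datasets is recommended.
        """
        return []

    def __len__(self) -> int:
        """Return the number of samples for the arguments passed to __init__."""
        return len(self.files)

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Return a dictionary with the data and target information for an index.

        This might involve separate methods to load the data and the target, but
        __getitem__ decides "how to assemble" a dataset sample.
        """
        # load the data and target referenced by self.files[index], e.g.
        # sample = {'image': image, 'label': label}
        sample: dict[str, Tensor] = {}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample of the dataset for visualization purposes.

        This might involve subselecting RGB bands that can be displayed, displaying
        class labels in segmentation tasks, etc. Implemented datasets already cover a
        wide range of visualization techniques, so it helps to check similar datasets.

        Args:
            sample: a sample returned by __getitem__
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        raise NotImplementedError  # replace with the actual plotting logic
```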
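
For multispectral imagery, the `plot` method typically has to pick three bands and rescale them into the [0, 1] range that `matplotlib.pyplot.imshow` expects for float images. Below is a minimal NumPy sketch of one common choice, a per-band 2nd to 98th percentile stretch. The helper `to_rgb`, the band indices, and the percentile values are assumptions to adapt to your sensor.

```python
import numpy as np


def to_rgb(
    image: np.ndarray,
    rgb_indices: tuple[int, int, int],
    lower: float = 2.0,
    upper: float = 98.0,
) -> np.ndarray:
    """Select three bands from a (C, H, W) array and stretch them to [0, 1].

    Returns an (H, W, 3) float array that can be passed to plt.imshow.
    """
    rgb = image[list(rgb_indices)].astype(np.float64)
    # Per-band percentiles with shape (3, 1, 1), so they broadcast over H and W.
    lo = np.percentile(rgb, lower, axis=(1, 2), keepdims=True)
    hi = np.percentile(rgb, upper, axis=(1, 2), keepdims=True)
    # Guard against constant bands, which would otherwise divide by zero.
    scaled = (rgb - lo) / np.maximum(hi - lo, 1e-12)
    # Clip the outliers beyond the percentiles and move channels last.
    return np.clip(scaled, 0.0, 1.0).transpose(1, 2, 0)
```

For example, assuming a 13-band Sentinel-2 stack ordered B01, B02, ..., B12, the red, green, and blue bands B04, B03, and B02 sit at indices 3, 2, and 1, so `to_rgb(image, (3, 2, 1))` yields a displayable true-color image. A percentile stretch is more robust to a few saturated pixels than a plain min/max rescaling.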

## The `data.py` file

The `data.py` file is placed in the `tests/data/dataset_name/` directory and creates a smaller dummy dataset that replicates the features and formats of the full dataset for unit tests. The script should therefore:

- replicate the directory structure
- replicate the naming scheme of directories and individual files
- roughly replicate the value ranges found in the dataset, for example by containing the same number of classes
- use the same compression scheme to simulate downloading the dataset, and compute the archive's checksum for verification

The details depend heavily on the format and structure of the new dataset. Below is an outline of the usual building blocks of a `data.py` script, using an image segmentation dataset with 10 classes as an example.

```python
#!/usr/bin/env python3

import hashlib
import os
import shutil

import numpy as np
from PIL import Image

# Define the root directory, subdirectories, and splits
root_dir = 'my_new_dataset'
sub_dirs = ['images', 'targets']
splits = ['train', 'val', 'test']

image_file_names = [
    'sample_1.png',
    'sample_2.png',
    'sample_3.png',
]

IMG_SIZE = 32
NUM_CLASSES = 10

# Fix the random seed so that reruns produce the same pixel values
np.random.seed(0)


def create_dummy_image(
    path: str, shape: tuple[int, ...], pixel_values: list[int]
) -> None:
    """Save an image whose pixels are randomly drawn from pixel_values."""
    data = np.random.choice(pixel_values, size=shape, replace=True).astype(np.uint8)
    Image.fromarray(data).save(path)


def create_input_image(split: str, filename: str) -> None:
    """Create a dummy 8-bit RGB input image."""
    path = os.path.join(root_dir, 'images', split, filename)
    create_dummy_image(path, (IMG_SIZE, IMG_SIZE, 3), list(range(256)))


def create_target_image(split: str, filename: str) -> None:
    """Create a dummy segmentation mask containing NUM_CLASSES class values."""
    path = os.path.join(root_dir, 'targets', split, filename)
    create_dummy_image(path, (IMG_SIZE, IMG_SIZE), list(range(NUM_CLASSES)))


def md5(fname: str) -> str:
    """Compute the MD5 checksum of a file, reading it in chunks."""
    hash_md5 = hashlib.md5()
    with open(fname, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b''):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


# Create a clean version when rerunning the script
if os.path.exists(root_dir):
    shutil.rmtree(root_dir)

# Create the directory structure
for sub_dir in sub_dirs:
    for split in splits:
        os.makedirs(os.path.join(root_dir, sub_dir, split), exist_ok=True)

# Create dummy inputs and targets for all splits and file names
for split in splits:
    for filename in image_file_names:
        create_input_image(split, filename)
        create_target_image(split, filename.replace('_', '_target_'))

# Zip the directory; this writes my_new_dataset.zip next to it
shutil.make_archive(root_dir, 'zip', '.', root_dir)

# Compute the checksum of the archive
md5sum = md5(f'{root_dir}.zip')
print(f'MD5 checksum: {md5sum}')
# Add this checksum to test_dataset_name.py to mock the dataset download if applicable
```
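
Before committing the dummy data, it can help to check that the archive really mirrors the layout your dataset class expects. The sketch below assumes that inputs and targets live in parallel directories inside the archive and that each target file name is derived from its input file name; `unpaired_inputs` and the `to_target_name` rule are hypothetical.

```python
import posixpath
import zipfile
from collections.abc import Callable


def unpaired_inputs(
    archive: str,
    input_dir: str,
    target_dir: str,
    to_target_name: Callable[[str], str],
) -> list[str]:
    """List input files in a zip archive that have no matching target file.

    Zip member names always use '/' as separator, hence posixpath.
    """
    with zipfile.ZipFile(archive) as zf:
        names = set(zf.namelist())

    missing = []
    for name in sorted(names):
        head, filename = posixpath.split(name)
        parts = head.split('/')
        # Skip directory entries (empty filename) and files outside input_dir.
        if not filename or input_dir not in parts:
            continue
        # Swap the input directory for the target directory, keep the split etc.
        target_parts = [target_dir if p == input_dir else p for p in parts]
        target = posixpath.join(*target_parts, to_target_name(filename))
        if target not in names:
            missing.append(name)
    return missing
```

For an archive with `images/` and `targets/` directories and target names such as `sample_target_1.png`, calling `unpaired_inputs('my_new_dataset.zip', 'images', 'targets', lambda f: f.replace('_', '_target_'))` should return an empty list.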

## The `test_dataset_name.py` file

The `test_dataset_name.py` file is placed in the `tests/datasets/` directory. It implements the unit tests for the dataset, such that every line of code in `dataset_name.py` is executed by at least one test. The individual test cases will likely be very similar to those in existing test files, so you can look at those to see how to test the individual parts of the dataset logic.


```python
import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

import torchgeo.datasets.my_new_dataset
from torchgeo.datasets import DatasetNotFoundError, MyNewDataset


def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
    # mock download: copy the local dummy archive instead of fetching it from the web
    shutil.copy(url, root)


class TestMyNewDataset:
    # pytest fixtures can be used to test different argument configurations,
    # for example the different splits of the dataset or subselections of
    # modalities/bands
    @pytest.fixture(params=['train', 'val', 'test'])
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> MyNewDataset:
        split: str = request.param
        # monkeypatch can overwrite module functions and the class attributes defined
        # above the __init__ method to mock behavior such as downloading; patching
        # download_url here assumes the dataset module imports it by name
        monkeypatch.setattr(
            torchgeo.datasets.my_new_dataset, 'download_url', download_url
        )
        # point the url class attribute at the dummy archive created by data.py
        url = os.path.join('tests', 'data', 'my_new_dataset', 'my_new_dataset.zip')
        monkeypatch.setattr(MyNewDataset, 'url', url)
        # for downloads you can also patch in the specific checksum that the
        # data.py script printed for the dummy data

        root = tmp_path
        transforms = nn.Identity()
        return MyNewDataset(
            root=root, split=split, transforms=transforms, download=True, checksum=True
        )

    def test_getitem(self, dataset: MyNewDataset) -> None:
        # retrieve a sample and check some of the desired properties
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x['image'], torch.Tensor)
        assert isinstance(x['label'], torch.Tensor)

    # for all the additional class arguments, check what happens if you pass invalid values
    def test_invalid_split(self) -> None:
        with pytest.raises(AssertionError):
            MyNewDataset(split='foo')

    # for example, if you have a list of bands, check what happens with invalid bands
    def test_invalid_bands(self) -> None:
        with pytest.raises(ValueError):
            MyNewDataset(bands=('OK', 'BK'))

    # test the length of the dataset; this should match the dummy data created in
    # data.py (three samples per split in the example script)
    def test_len(self, dataset: MyNewDataset) -> None:
        assert len(dataset) == 3

    # test the logic when the dataset is already downloaded and extracted
    def test_already_downloaded(self, dataset: MyNewDataset, tmp_path: Path) -> None:
        MyNewDataset(root=tmp_path, download=True)

    # test the logic when the dataset is already downloaded but not extracted
    def test_already_downloaded_not_extracted(self, dataset: MyNewDataset) -> None:
        # remove everything, then put only the archive back into an empty root
        shutil.rmtree(dataset.root)
        os.makedirs(dataset.root)
        download_url(dataset.url, root=dataset.root)
        MyNewDataset(root=dataset.root, download=False)

    # test the logic when the dataset is missing and downloading is not allowed
    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            MyNewDataset(tmp_path)

    # test the plotting method through something like the following
    def test_plot(self, dataset: MyNewDataset) -> None:
        x = dataset[0].copy()
        dataset.plot(x, suptitle='Test')
        plt.close()
        dataset.plot(x, show_titles=False)
        plt.close()
        x['prediction'] = x['label'].clone()
        dataset.plot(x)
        plt.close()
```

## Documentation Entries

The entry point for new and experienced users of domain libraries is often the documentation that accompanies a GitHub repository. TorchGeo uses the popular `sphinx` framework to build its documentation. To display the docstrings you have written in `dataset_name.py` on the documentation page, create an entry in `docs/api/datasets.rst`, in alphabetical order:


```rst
My New Dataset
^^^^^^^^^^^^^^

.. autoclass:: MyNewDataset
```

Additionally, add a row in the `non_geo_datasets.csv` file under `docs/api/datasets` to include the dataset in the overview table.

## Linters

See the [linter docs](https://torchgeo.readthedocs.io/en/stable/user/contributing.html#linters) for an overview of the linters that TorchGeo employs and how to apply them, for example automatically during commits.

## Test Coverage

TorchGeo maintains a test coverage of 100%. This means that every line of code within the `torchgeo` directory is executed by some unit test. The [testing docs](https://torchgeo.readthedocs.io/en/stable/user/contributing.html#tests) provide instructions on how to check the coverage locally for the `dataset_name.py` file that you are adding.

## Final Checklist

This final checklist provides an overview of the individual parts discussed in this tutorial. You do not need to check all boxes before submitting a PR. If you have any questions, feel free to ask in the Slack channel, or open a PR early so that maintainers or other community members can answer specific questions or give pointers. If you want to keep your PR as a work in progress, so that the CI tests run against your code while you tick off more boxes, you can convert the PR to a draft on the right side of the PR page.

- Dataset implementation in `dataset_name.py`
    - Class docstring containing:
        - Summary intro
        - Dataset features
        - Dataset format
        - Link to publication
        - `versionadded` tag
        - if applicable a note on additional dependencies
    - all class methods have docstrings
    - all class methods have argument and return type hints; mypy (the tool that checks type hints) can be confusing at first, so don't hesitate to ask for help
    - if the dataset is hosted on Hugging Face, the URL should contain the commit hash
    - checksum added
    - plot method that can display a single sample from the dataset (you can add the resulting figure in your PR description)
    - add the dataset to `torchgeo/datasets/__init__.py`
    - Microsoft copyright header at the top of the file
- Dummy data script `data.py`
    - replicate directory structure
    - replicate naming of directory and files
    - for image-based datasets, use a much smaller image size than the original; 32x32 pixels is common
- Unit tests `test_dataset_name.py`
    - 100% test coverage 
- Documentation with `non_geo_datasets.csv` and `datasets.rst`
    - entry in `datasets.rst`
    - entry in `non_geo_datasets.csv`
    - documentation displays properly; this can be checked locally or via the GitHub CI check `docs/readthedocs.org:torchgeo`