Skip to content

Commit

Permalink
Add DFC2022 dataset (#354)
Browse files Browse the repository at this point in the history
* add DFC2022 dataset

* plot fix

* mypy fixes

* add tests and tests data

* maximum coverage

* remove local dir

* update per suggestions

* update monkeypatching

* update docstring

* fix indentation in docstring
  • Loading branch information
isaaccorley committed Jan 12, 2022
1 parent ff28a3b commit 45f3703
Show file tree
Hide file tree
Showing 22 changed files with 585 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ CV4A Kenya Crop Type Competition

.. autoclass:: CV4AKenyaCropType

2022 IEEE GRSS Data Fusion Contest (DFC2022)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: DFC2022

ETCI2021 Flood Detection
^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
121 changes: 121 additions & 0 deletions tests/data/dfc2022/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import random
import shutil

import numpy as np
import rasterio

from torchgeo.datasets import DFC2022

# Height and width, in pixels, of every raster written by create_file().
SIZE = 32

# Seed both random number generators used/imported by this script.
np.random.seed(0)
random.seed(0)


# Files generated for the "train" split. Each entry lists the relative paths of
# an image, a DEM, and a target mask (the mask is only created for this split).
train_set = [
{
"image": "labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif", # noqa: E501
"dem": "labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
"target": "labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif", # noqa: E501
},
{
"image": "labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif", # noqa: E501
"dem": "labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
"target": "labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif", # noqa: E501
},
]

# Files generated for the "train-unlabeled" split: image and DEM only, no target.
unlabeled_set = [
{
"image": "unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif", # noqa: E501
"dem": "unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
},
{
"image": "unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif", # noqa: E501
"dem": "unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
},
]

# Files generated for every remaining split (the script's fallback branch):
# image and DEM only, no target.
val_set = [
{
"image": "val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif", # noqa: E501
"dem": "val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
},
{
"image": "val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif", # noqa: E501
"dem": "val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
},
]


def create_file(path: str, dtype: str, num_channels: int) -> None:
    """Write a SIZE x SIZE GeoTIFF filled with random values to ``path``.

    A single random array is generated and written to every band, so all
    bands of the resulting file hold identical data.

    Args:
        path: destination of the raster file
        dtype: NumPy dtype name of the raster; names containing "float" are
            filled from a normal distribution, all others with random
            integers below the dtype's maximum value
        num_channels: number of bands to write
    """
    profile = {
        "driver": "GTiff",
        "dtype": dtype,
        "count": num_channels,
        "crs": "epsg:4326",
        "transform": rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1),
        "height": SIZE,
        "width": SIZE,
    }

    if "float" in dtype:
        Z = np.random.randn(SIZE, SIZE).astype(dtype)
    else:
        Z = np.random.randint(np.iinfo(dtype).max, size=(SIZE, SIZE), dtype=dtype)

    # Close the dataset deterministically so the file is fully flushed to disk
    # before the caller archives it and computes its checksum.
    with rasterio.open(path, "w", **profile) as dst:
        for band in range(1, num_channels + 1):
            dst.write(Z, band)


if __name__ == "__main__":
    # File listings for the splits that are handled explicitly; every other
    # split falls back to the validation listing.
    listings = {"train": train_set, "train-unlabeled": unlabeled_set}

    for split, meta in DFC2022.metadata.items():
        directory = meta["directory"]
        filename = meta["filename"]

        # Start from a clean slate: drop any previously generated data
        if os.path.isdir(directory):
            shutil.rmtree(directory)
        if os.path.exists(filename):
            os.remove(filename)

        files = listings.get(split, val_set)

        # (key, dtype, band count) of each raster to create, in creation order:
        # image, DEM, and -- for the labeled split only -- the mask
        rasters = [("image", "uint8", 3), ("dem", "float32", 1)]
        if split == "train":
            rasters.append(("target", "uint8", 1))

        for entry in files:
            for key, dtype, num_channels in rasters:
                path = entry[key]
                os.makedirs(os.path.dirname(path), exist_ok=True)
                create_file(path, dtype=dtype, num_channels=num_channels)

        # Compress data
        archive_base = filename.replace(".zip", "")
        shutil.make_archive(archive_base, "zip", ".", directory)

        # Compute checksums
        with open(filename, "rb") as archive:
            md5 = hashlib.md5(archive.read()).hexdigest()
        print(f"{filename}: {md5}")
Binary file added tests/data/dfc2022/labeled_train.zip
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/dfc2022/unlabeled_train.zip
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/dfc2022/val.zip
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
96 changes: 96 additions & 0 deletions tests/datasets/test_dfc2022.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path
from typing import Generator

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch

from torchgeo.datasets import DFC2022


class TestDFC2022:
    """Tests for the DFC2022 dataset using the archives under tests/data."""

    @pytest.fixture(params=["train", "train-unlabeled", "val"])
    def dataset(
        self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest
    ) -> DFC2022:
        """Return a checksum-verified dataset for the requested split.

        The expected md5 of every split is patched to the value of the
        corresponding test archive before the dataset is constructed.
        """
        test_md5s = {
            "train": "6e380c4fa659d05ca93be71b50cacd90",
            "train-unlabeled": "b2bf3839323d4eae636f198921442945",
            "val": "e018dc6865bd3086738038fff27b818a",
        }
        for split_name, md5 in test_md5s.items():
            monkeypatch.setitem(  # type: ignore[attr-defined]
                DFC2022.metadata[split_name], "md5", md5
            )

        data_root = os.path.join("tests", "data", "dfc2022")
        identity = nn.Identity()  # type: ignore[attr-defined]
        return DFC2022(data_root, request.param, identity, checksum=True)

    def test_getitem(self, dataset: DFC2022) -> None:
        """A sample is a dict with a 4-channel image and, for train, a 2D mask."""
        sample = dataset[0]
        assert isinstance(sample, dict)

        image = sample["image"]
        assert isinstance(image, torch.Tensor)
        assert image.ndim == 3
        assert image.shape[0] == 4

        if dataset.split == "train":
            mask = sample["mask"]
            assert isinstance(mask, torch.Tensor)
            assert mask.ndim == 2

    def test_len(self, dataset: DFC2022) -> None:
        """Each test split contains exactly two samples."""
        assert len(dataset) == 2

    def test_extract(self, tmp_path: Path) -> None:
        """The dataset can be built from a root holding only the archives."""
        source_dir = os.path.join("tests", "data", "dfc2022")
        for archive in ("labeled_train.zip", "unlabeled_train.zip", "val.zip"):
            shutil.copyfile(
                os.path.join(source_dir, archive), os.path.join(tmp_path, archive)
            )
        DFC2022(root=str(tmp_path))

    def test_corrupted(self, tmp_path: Path) -> None:
        """An archive that fails the checksum raises a RuntimeError."""
        (tmp_path / "labeled_train.zip").write_text("bad")
        with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
            DFC2022(root=str(tmp_path), checksum=True)

    def test_invalid_split(self) -> None:
        """An unknown split name is rejected with an AssertionError."""
        with pytest.raises(AssertionError):
            DFC2022(split="foo")

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty root directory raises a RuntimeError."""
        with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
            DFC2022(str(tmp_path))

    def test_plot(self, dataset: DFC2022) -> None:
        """Plotting runs with and without titles, predictions, and masks."""
        sample = dataset[0].copy()

        dataset.plot(sample, suptitle="Test")
        plt.close()
        dataset.plot(sample, show_titles=False)
        plt.close()

        # Only the labeled split carries a mask to derive a prediction from
        if dataset.split != "train":
            return

        sample["prediction"] = sample["mask"].clone()
        dataset.plot(sample)
        plt.close()

        del sample["mask"]
        dataset.plot(sample)
        plt.close()
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .cowc import COWC, COWCCounting, COWCDetection
from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .cyclone import TropicalCycloneWindEstimation
from .dfc2022 import DFC2022
from .etci2021 import ETCI2021
from .eurosat import EuroSAT
from .fair1m import FAIR1M
Expand Down Expand Up @@ -115,6 +116,7 @@
"COWCCounting",
"COWCDetection",
"CV4AKenyaCropType",
"DFC2022",
"ETCI2021",
"EuroSAT",
"FAIR1M",
Expand Down
Loading

0 comments on commit 45f3703

Please sign in to comment.