In [None]:
# default_exp base


In [None]:
# hide
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# BigEarthNet Base Functions
> A collection of common functions that are applied to BigEarthNet (BEN).

In [None]:
# export
import csv
import functools
import hashlib
import json
import urllib
import urllib.parse
import urllib.request
import warnings
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Set, Union

import appdirs
import dateutil
import dateutil.parser
import fastcore.all as fc
from fastcore.basics import compose
from fastcore.dispatch import typedispatch
from pydantic import AnyHttpUrl, DirectoryPath, FilePath, validate_arguments

import bigearthnet_common.constants as ben_constants
from bigearthnet_common.constants import OLD2NEW_LABELS_DICT


In [None]:
# export
# Per-user data directory (resolved by `appdirs`) that holds downloaded BigEarthNet files.
USER_DIR = Path(appdirs.user_data_dir("bigearthnet"))
# create it eagerly at import time so later writes do not have to check for it
USER_DIR.mkdir(parents=True, exist_ok=True)


In [None]:
# hide
from datetime import date

from dateutil.parser import ParserError
from fastcore.test import ExceptionExpected, test_close, test_eq


In [None]:
# export
@validate_arguments
def _download_and_cache_url(url: AnyHttpUrl, force_download: bool = False):
    """
    Download the contents of `url` into the default user directory and return the local path.

    If a cached copy exists, the download is skipped unless `force_download` is set.

    The cache file name is the last segment of the URL path, prefixed with a short hash
    of the full URL. Without the prefix, different URLs that share a file name (e.g. the
    `train.csv` split files of the 19- and 43-class repositories) would silently share a
    single cache entry and the second URL would never be downloaded.

    Args:
        url (AnyHttpUrl): Remote location to download.
        force_download (bool, optional): Re-download even if a cached copy exists. Defaults to False.

    Returns:
        Path: Path to the cached file inside `USER_DIR`.
    """
    # `str()` so that both str-based and object-based URL types are handled
    url = str(url)
    url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest()[:16]
    # derive a readable name from the URL path only (ignores query string / fragment)
    file_name = Path(urllib.parse.urlparse(url).path).name or "download"
    fp = USER_DIR / f"{url_hash}_{file_name}"
    if not fp.exists() or force_download:
        # context manager closes the connection deterministically
        with urllib.request.urlopen(url) as response:
            content = response.read()
        # write to a temporary file first so an interrupted write never leaves a
        # truncated file behind that later calls would treat as a valid cache entry
        tmp_fp = fp.with_name(fp.name + ".part")
        tmp_fp.write_bytes(content)
        tmp_fp.replace(fp)
    return fp


## Safe JSON Parsing

A couple of safe parsing functions that guarantee flexible reading of the default BigEarthNet json entries.

Usually, you should not use any of these parsing functions directly, but use one of the higher-level functions instead.

In [None]:
# export
def parse_datetime(inp: Union[str, datetime]) -> datetime:
    """
    Convert `inp` into a `datetime` object.

    - `str`: parsed with `dateutil.parser.parse`, which infers the format.
    - `datetime`: returned unchanged.
    - anything else: a `TypeError` is raised.

    Strings that cannot be parsed propagate the parser's error (e.g. dateutil's `ParserError`).
    """
    parsed = _parse_datetime(inp)
    return parsed


@typedispatch
def _parse_datetime(acquisition_date: str) -> datetime:
    # parse the string, then route the result through the `datetime` overload below
    parsed_date = dateutil.parser.parse(acquisition_date)
    return _parse_datetime(parsed_date)


@typedispatch
def _parse_datetime(acquisition_date: datetime) -> datetime:
    # already the target type: nothing to do
    return acquisition_date


@typedispatch
def _parse_datetime(acquisition_date: object) -> None:
    # fallback overload for every unsupported input type (including `date`)
    raise TypeError("Could not parse acquisition_date!")


In [None]:
reference_date = datetime(year=2017, month=6, day=13, hour=10, minute=10, second=31)

# ISO format, European dotted format and an existing datetime all yield the same value
for raw_date in ("2017-06-13 10:10:31", "13.06.2017 10:10:31", reference_date):
    test_eq(parse_datetime(raw_date), reference_date)

# unparsable strings surface dateutil's error
with ExceptionExpected(ex=ParserError, regex="format"):
    parse_datetime("large_tile")

# unsupported types (including plain `date`) raise a TypeError
for unsupported in (42, date(year=2017, month=10, day=1)):
    with ExceptionExpected(ex=TypeError, regex="parse"):
        parse_datetime(unsupported)


In [None]:
# export
@validate_arguments
def _read_json(json_fp: FilePath, expected_keys: Set, read_only_expected: bool = True) -> Dict[str, str]:
    """
    Parse the json file given with the file path `json_fp`.

    The function checks that all of the `expected_keys` are present, which
    ensures that no keys have been accidentally deleted (this has happened before).
    If `read_only_expected` is set, only the keys provided in `expected_keys` are read
    and returned. This prevents accidental processing of injected metadata.

    Args:
        json_fp (FilePath): Path to json file
        expected_keys (Set): Keys that are expected to be present in the json file
        read_only_expected (bool, optional): Read only the keys given in `expected_keys`. Defaults to True.

    Returns:
        [Dict[str, str]]: A dictionary of the top-level keys and their values
            (values are not necessarily strings, e.g. nested json objects).

    Raises:
        ValueError: If the file is not valid UTF-8 json, its top level is not a json
            object, or any of the `expected_keys` is missing.
    """
    try:
        # json is UTF-8 by specification; do not depend on the platform's locale encoding
        complete_data = json.loads(json_fp.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        raise ValueError(f"Error trying to read json from: {json_fp}") from err

    # `.keys()` below requires a json object (dict) at the top level
    if not isinstance(complete_data, dict):
        raise ValueError(f"Expected a json object at the top level of: {json_fp}")

    missing_elements = expected_keys - complete_data.keys()
    if missing_elements:
        raise ValueError(f"{json_fp} is missing entries! {missing_elements}")

    # ensure that the original values are loaded, as some users may customize the original json files
    if read_only_expected:
        return {k: v for k, v in complete_data.items() if k in expected_keys}
    return complete_data


def read_S1_json(json_fp: FilePath) -> Dict[str, str]:
    """
    *Safely* read a BigEarthNet-S1 json file.

    All expected S1 entries must be present, and only those entries are returned.
    This guards against json files that were accidentally modified or partially deleted.

    Note: The typo in the `coordinates` entry of version S1_v1.0 is fixed silently:
    the key `lly` is renamed to the correct `lry`.
    """
    s1_metadata = _read_json(json_fp, ben_constants.BEN_S1_V1_0_JSON_KEYS)
    coordinates = s1_metadata["coordinates"]
    # rename the misspelled key in place (the nested dict is owned by `s1_metadata`)
    if "lly" in coordinates:
        coordinates["lry"] = coordinates.pop("lly")
    return s1_metadata

def read_S2_json(json_fp: FilePath) -> Dict[str, str]:
    """
    A helper function that *safely* reads a BigEarthNet-S2 json file.
    It will ensure that all expected entries are present and only read those
    entries.
    This helps to avoid issues where the JSON files were accidentally modified
    or partially deleted.

    Raises:
        ValueError: If the file cannot be parsed or an expected S2 entry is missing.
    """
    return _read_json(json_fp, ben_constants.BEN_S2_V1_0_JSON_KEYS)

In [None]:
# hide
s2_json_path = "S2_json_only/S2A_MSIL2A_20170617T113321_4_55/S2A_MSIL2A_20170617T113321_4_55_labels_metadata.json"
s1_json_path = "S1_json_only/S1A_IW_GRDH_1SDV_20170613T165043_33UUP_61_39/S1A_IW_GRDH_1SDV_20170613T165043_33UUP_61_39_labels_metadata.json"

# every reader returns exactly the expected keys of its own sensor
for reader, json_path, expected_keys in (
    (read_S2_json, s2_json_path, ben_constants.BEN_S2_V1_0_JSON_KEYS),
    (read_S1_json, s1_json_path, ben_constants.BEN_S1_V1_0_JSON_KEYS),
):
    metadata = reader(json_path)
    assert all(key in expected_keys for key in metadata)
    assert len(metadata) == len(expected_keys)

# reading a file of the other sensor fails because its expected keys are missing
with ExceptionExpected(ValueError, "missing entries"):
    read_S2_json(s1_json_path)

with ExceptionExpected(ValueError, "missing entries"):
    read_S1_json(s2_json_path)

## Common BEN patch checks and transformations

To quickly filter a list of directories and ensure that only Sentinel directories are accessed, use:


In [None]:
# export
def _filter_patch_directories(dir_path: Path, pattern) -> List[Path]:
    """Return the sub-directories of `dir_path` whose names fully match the compiled regex `pattern`, sorted."""
    # `is_dir` excludes plain files that happen to match the naming convention;
    # sorting makes the result independent of the file system's listing order
    return sorted(
        p for p in dir_path.iterdir() if p.is_dir() and pattern.fullmatch(p.name) is not None
    )


@validate_arguments
def get_s2_patch_directories(dir_path: DirectoryPath) -> List[Path]:
    """
    Will find all S2 patch directories in the provided `dir_path`.
    Only directories that strictly adhere to the naming convention will be returned,
    in sorted order.
    """
    return _filter_patch_directories(dir_path, ben_constants.BEN_S2_RE)


@validate_arguments
def get_s1_patch_directories(dir_path: DirectoryPath) -> List[Path]:
    """
    Will find all S1 patch directories in the provided `dir_path`.
    Only directories that strictly adhere to the naming convention will be returned,
    in sorted order.
    """
    return _filter_patch_directories(dir_path, ben_constants.BEN_S1_RE)

In [None]:
# hide

s2_dir = "S2_json_only"
s1_dir = "S1_json_only"

# each sample directory holds two patches of its own sensor and none of the other
for list_patch_dirs, own_dir, other_dir in (
    (get_s2_patch_directories, s2_dir, s1_dir),
    (get_s1_patch_directories, s1_dir, s2_dir),
):
    assert len(list_patch_dirs(own_dir)) == 2
    assert len(list_patch_dirs(other_dir)) == 0

The following functions mainly allow the user to write cleaner code by importing them instead of writing lambda functions.
All functions use caches to guarantee fast lookups, so feel free to use them on large data.

The most relevant functions are:
- Check whether a patch name is part of the cloud/shadow or seasonal-snow collection
    - `is_snowy_patch`
    - `is_cloudy_shadowy_patch`
- Retrieve the original split by looking up the patch name
    - `get_original_split_from_patch_name`
- Convert the old 43-label nomenclature to the new 19-label variant
    - `old2new_labels`

In [None]:
# export
# Previously used URLs on bigearth.net (kept for reference):
#   http://bigearth.net/static/documents/patches_with_seasonal_snow.csv
#   http://bigearth.net/static/documents/get_patches_with_cloud_and_shadow.csv
# The files are now read from the `ben-mirror` repository on git.tu-berlin.de.
PATCHES_WITH_SNOW_URL = (
    "https://git.tu-berlin.de/k.clasen/ben-mirror/-/raw/master/patches_with_seasonal_snow.csv"
)
PATCHES_WITH_CLOUD_AND_SHADOW_URL = (
    "https://git.tu-berlin.de/k.clasen/ben-mirror/-/raw/master/patches_with_cloud_and_shadow.csv"
)


In [None]:
# export
@validate_arguments
def _conv_single_col_csv_to_set(
    url: AnyHttpUrl, name: str = "Name", force_download: bool = False
) -> Set:
    """
    Given a url to a CSV file *without* a header
    line and only a single column, return the set of
    all values.

    Will write remote csv to disk for better performance.
    Set `force_download` to re-download the file.

    Args:
        url (AnyHttpUrl): Location of the CSV file.
        name (str, optional): Internal field name used for the single column. Defaults to "Name".
        force_download (bool, optional): Re-download even if a cached copy exists. Defaults to False.

    Returns:
        Set: All values of the single column.
    """
    fp = _download_and_cache_url(url, force_download=force_download)
    # the csv module requires `newline=""`; an explicit encoding avoids depending on the
    # platform locale (assumes the files are UTF-8/ASCII — TODO confirm for all sources)
    with open(fp, mode="r", encoding="utf-8", newline="") as csv_file:
        reader = csv.DictReader(csv_file, fieldnames=[name])
        return {row[name] for row in reader}

@functools.lru_cache()
def get_patches_with_seasonal_snow(force_download: bool = False) -> Set:
    """
    Return the set of patch names with seasonal snow from the **original** BigEarthNet dataset.

    Results are memoized per argument value by `functools.lru_cache`; the first call with
    `force_download=True` re-downloads the remote CSV file.
    """
    snowy_patches = _conv_single_col_csv_to_set(
        PATCHES_WITH_SNOW_URL,
        force_download=force_download,
    )
    return snowy_patches


@functools.lru_cache()
def get_patches_with_cloud_and_shadow(force_download: bool = False) -> Set:
    """
    Return the set of patch names with clouds and shadows from the **original** BigEarthNet dataset.

    Results are memoized per argument value by `functools.lru_cache`; the first call with
    `force_download=True` re-downloads the remote CSV file.
    """
    cloudy_patches = _conv_single_col_csv_to_set(
        PATCHES_WITH_CLOUD_AND_SHADOW_URL,
        force_download=force_download,
    )
    return cloudy_patches


In [None]:
# the downloaded collections must match the published patch counts
snow_patches = get_patches_with_seasonal_snow()
test_eq(len(snow_patches), ben_constants.BEN_SNOWY_PATCHES_COUNT)

cloud_and_shadow_patches = get_patches_with_cloud_and_shadow()
test_eq(len(cloud_and_shadow_patches), ben_constants.BEN_CLOUDY_OR_SHADOWY_PATCHES_COUNT)


In [None]:
# export
@validate_arguments
def is_snowy_patch(patch_name: str):
    """
    Return whether `patch_name` belongs to the patches that contain a lot of seasonal snow.

    The membership test runs against a memoized set, so repeated calls stay fast.
    """
    snowy_patches = get_patches_with_seasonal_snow()
    return patch_name in snowy_patches


@validate_arguments
def is_cloudy_shadowy_patch(patch_name: str):
    """
    Return whether `patch_name` belongs to the patches that contain a lot of shadow
    or are obstructed by clouds.

    The membership test runs against a memoized set, so repeated calls stay fast.
    """
    cloudy_patches = get_patches_with_cloud_and_shadow()
    return patch_name in cloudy_patches


In [None]:
# hide

# the lookup does not validate that the name follows the BigEarthNet naming convention
assert not is_snowy_patch("hello")
assert not is_cloudy_shadowy_patch("hello")

assert is_snowy_patch("S2A_MSIL2A_20180205T100211_2_0")


In [None]:
# export
@functools.lru_cache()
@validate_arguments
def patches_from_original_train_split(
    split_url: AnyHttpUrl = "https://git.tu-berlin.de/rsim/BigEarthNet-S2_19-classes_models/-/raw/master/splits/train.csv",
    force_download: bool = False,
) -> Set:
    """
    Return the set of patch names in the train part of the original train/validation/test split.

    Known sources for the split file:

    1. https://git.tu-berlin.de/rsim/BigEarthNet-S2_19-classes_models/-/raw/master/splits/train.csv (default)
    2. https://git.tu-berlin.de/rsim/BigEarthNet-S2_43-classes_models/-/raw/master/splits/train.csv

    At the time of writing, both sources contain the **same** patches.
    This may change in the future!
    """
    train_patches = _conv_single_col_csv_to_set(split_url, force_download=force_download)
    return train_patches


@functools.lru_cache()
@validate_arguments
def patches_from_original_validation_split(
    split_url: AnyHttpUrl = "https://git.tu-berlin.de/rsim/BigEarthNet-S2_19-classes_models/-/raw/master/splits/val.csv",
    force_download: bool = False,
) -> Set:
    """
    List all validation patches from the original train/validation/test split.
    There are two possible sources:

    1. https://git.tu-berlin.de/rsim/BigEarthNet-S2_19-classes_models/-/raw/master/splits/val.csv
    2. https://git.tu-berlin.de/rsim/BigEarthNet-S2_43-classes_models/-/raw/master/splits/val.csv

    While writing this function, there is **no** difference between these two splits.
    But this may change in the future!

    Args:
        split_url (AnyHttpUrl, optional): URL of a header-less, single-column CSV file with the patch names.
        force_download (bool, optional): Re-download the CSV file even if a cached copy exists.

    Returns:
        Set: The patch names of the validation split.
    """
    return _conv_single_col_csv_to_set(split_url, force_download=force_download)


@functools.lru_cache()
@validate_arguments
def patches_from_original_test_split(
    split_url: AnyHttpUrl = "https://git.tu-berlin.de/rsim/BigEarthNet-S2_19-classes_models/-/raw/master/splits/test.csv",
    force_download: bool = False,
) -> Set:
    """
    Return the set of patch names in the test part of the original train/validation/test split.

    Known sources for the split file:

    1. https://git.tu-berlin.de/rsim/BigEarthNet-S2_19-classes_models/-/raw/master/splits/test.csv (default)
    2. https://git.tu-berlin.de/rsim/BigEarthNet-S2_43-classes_models/-/raw/master/splits/test.csv

    At the time of writing, both sources contain the **same** patches.
    This may change in the future!
    """
    test_patches = _conv_single_col_csv_to_set(split_url, force_download=force_download)
    return test_patches


In [None]:
# hide
# The 43-class repository hosts its own copies of the split files. The docstrings state
# that both sources are identical, so compare the complete sets (not just one direction).
BEN_43_SPLITS_URL = "https://git.tu-berlin.de/rsim/BigEarthNet-S2_43-classes_models/-/raw/master/splits"

train1 = patches_from_original_train_split()
train2 = patches_from_original_train_split(f"{BEN_43_SPLITS_URL}/train.csv")
assert train1 == train2

val1 = patches_from_original_validation_split()
val2 = patches_from_original_validation_split(f"{BEN_43_SPLITS_URL}/val.csv")
assert val1 == val2

test1 = patches_from_original_test_split()
test2 = patches_from_original_test_split(f"{BEN_43_SPLITS_URL}/test.csv")
assert test1 == test2

assert len(test1) < len(train1)
assert len(val1) < len(train1)


In [None]:
# export
@validate_arguments
def get_original_split_from_patch_name(patch: str) -> Optional[str]:
    """
    Look up which part of the original BigEarthNet split `patch` belongs to.

    Returns "train", "validation" or "test". If the patch is not part of any split,
    a `UserWarning` is emitted and `None` is returned. This happens for patches from
    the cloud/shadow or seasonal snow collections, or for patches without a 19-label target.

    The splits are taken from the 19-classes version. At the time of writing there was
    no difference between the 19-classes and the 43-classes version.
    """
    # load all three sets before checking membership (the loaders are memoized)
    candidate_splits = (
        ("train", patches_from_original_train_split()),
        ("validation", patches_from_original_validation_split()),
        ("test", patches_from_original_test_split()),
    )
    for split_name, split_patches in candidate_splits:
        if patch in split_patches:
            return split_name

    warnings.warn(
        "Provided an input patch name which was not part of the original split.",
        UserWarning,
    )
    return None


In [None]:
expected_splits = {
    "S2A_MSIL2A_20170717T113321_28_87": "train",
    "S2B_MSIL2A_20170812T092029_75_6": "validation",
    "S2A_MSIL2A_20170717T113321_28_88": "test",
}
for patch_name, expected_split in expected_splits.items():
    assert get_original_split_from_patch_name(patch_name) == expected_split


In [None]:
# hide
# an unknown patch name emits exactly one UserWarning
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    get_original_split_from_patch_name("WrongName")
assert len(caught) == 1
assert issubclass(caught[-1].category, UserWarning)


In [None]:
# export
@validate_arguments
def _old2new_label(old_label: str) -> Optional[str]:
    """
    Map a single old-style (43-class) BigEarthNet label to its new (19-class) counterpart.

    > Note: Some labels were removed! For those, `None` is returned.
    An unknown input label raises a `KeyError`.
    """
    new_label = OLD2NEW_LABELS_DICT[old_label]
    return new_label


def old2new_labels(old_labels: Sequence[str]) -> Optional[List[str]]:
    """
    Convert a sequence of old-style BigEarthNet labels to the new-style labels.

    Labels that were removed in the new nomenclature are dropped. Duplicates are kept,
    so several old labels may map to the same new label more than once.

    If there are no corresponding new labels (which can happen with original BEN patches!)
    then the function will return `None` and emit a `UserWarning`.
    An empty input returns an empty list without a warning.

    Raises:
        KeyError: If an illegal/unknown input label is provided.
    """
    # materialize once so that one-shot iterables also work with the emptiness check below
    old_labels = list(old_labels)
    # look each label up only once; `None` marks a label that was removed
    converted = (_old2new_label(old_label) for old_label in old_labels)
    new_labels = [new_label for new_label in converted if new_label is not None]
    if old_labels and not new_labels:
        warnings.warn(
            "Provided a list of old labels that only contains `removed` labels!",
            UserWarning,
        )
        return None
    return new_labels


> Warning: Some of the original 43-class nomenclature patches have 0 labels with the 19-class nomenclature! This function might return `None` instead of an empty list!

In [None]:
# hide
# removed-only input: list and tuple both return `None` and emit exactly one UserWarning.
# Each call gets its own `catch_warnings` block; otherwise the warning is not recorded.
for removed_only in (["Burnt areas"], ("Burnt areas",)):
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        assert old2new_labels(removed_only) is None
    assert len(w) == 1
    assert issubclass(w[-1].category, UserWarning)

with fc.ExceptionExpected(ex=KeyError):
    old2new_labels(["Illegal input label"])

# several old labels may map to the same new label; duplicates are kept
fc.test_eq(
    old2new_labels(
        [
            "Continuous urban fabric",
            "Discontinuous urban fabric",
        ]
    ),
    ["Urban fabric", "Urban fabric"],
)




In [None]:
# export
def _labels_to_multi_hot(labels: Sequence[str], label_to_idx: Dict[str, int], n_labels: int) -> List[float]:
    """Multi-hot encode `labels` as a list of `n_labels` floats using the `label_to_idx` mapping."""
    # resolve every index first so that an unknown label raises a `KeyError` up front
    idxs = [label_to_idx[label] for label in labels]
    # start from floats so every entry matches the advertised `List[float]`
    multi_hot = [0.0] * n_labels
    for idx in idxs:
        multi_hot[idx] = 1.0
    return multi_hot


@validate_arguments
def ben_19_labels_to_multi_hot(labels: Sequence[str]) -> List[float]:
    """
    Convenience function that converts an input sequence of labels into
    a multi-hot encoded vector of floats.
    The naturally ordered label list is used as an encoder reference
    - `bigearthnet_common.NEW_LABELS`

    If an unknown label is given, a `KeyError` is raised.

    Be aware that this approach assumes that **all** labels are actually used in the dataset!
    This is not necessarily the case if you are using a subset!
    For example, the "Agro-forestry areas" class is only present in Portugal and in no other country!
    """
    return _labels_to_multi_hot(labels, ben_constants.NEW_LABELS_TO_IDX, len(ben_constants.NEW_LABELS))


@validate_arguments
def ben_43_labels_to_multi_hot(labels: Sequence[str]) -> List[float]:
    """
    Convenience function that converts an input sequence of labels into
    a multi-hot encoded vector of floats.
    The naturally ordered label list is used as an encoder reference
    - `bigearthnet_common.OLD_LABELS`

    If an unknown label is given, a `KeyError` is raised.

    Be aware that this approach assumes that **all** labels are actually used in the dataset!
    This is not necessarily the case if you are using a subset!
    For example, the "Agro-forestry areas" class is only present in Portugal and in no other country!
    """
    return _labels_to_multi_hot(labels, ben_constants.OLD_LABELS_TO_IDX, len(ben_constants.OLD_LABELS))


In [None]:
# 19-class nomenclature: "Agro-forestry areas" and "Arable land" are its first two labels
agro = ben_19_labels_to_multi_hot(("Agro-forestry areas",))
assert len(agro) == 19
assert agro[0] == 1.0 and not any(agro[1:])

multi = ben_19_labels_to_multi_hot(("Agro-forestry areas", "Arable land"))
assert multi[:2] == [1.0, 1.0] and not any(multi[2:])

# "Airports" only exists in the old 43-class nomenclature
with fc.ExceptionExpected(ex=KeyError):
    ben_19_labels_to_multi_hot(["Airports"])

# 43-class nomenclature: "Agro-forestry areas" and "Airports" are its first two labels
agro43 = ben_43_labels_to_multi_hot(("Agro-forestry areas",))
assert len(agro43) == 43
assert agro43[0] == 1.0 and not any(agro43[1:])

multi43 = ben_43_labels_to_multi_hot(("Agro-forestry areas", "Airports"))
assert multi43[:2] == [1.0, 1.0] and not any(multi43[2:])

# "Arable land" only exists in the new 19-class nomenclature
with fc.ExceptionExpected(ex=KeyError):
    ben_43_labels_to_multi_hot(["Arable land"])

In [None]:
# hide
# TODO: Add a function to transform an S2 patch name to an S1 patch name and
# vice versa.
# I think the best option is to simply return the entire dictionary
# if desired.
# This will ensure that the access is very fast. It will cost
# some memory, but then it is easy to distribute, and the dictionary
# can be cached and wrapped around a single mapper.

In [None]:
# hide
from nbdev.cli import nbdev_build_docs
from nbdev.export import notebook2script

# Export the notebooks' `# export` cells to the library first, then rebuild the
# documentation (the output below shows every notebook in the folder being converted).
notebook2script()
nbdev_build_docs()


Converted 01a_constants.ipynb.
Converted 01b_base.ipynb.
Converted index.ipynb.
converting: /home/kai/git/bigearthnet_common/nbs/01b_base.ipynb
converting /home/kai/git/bigearthnet_common/nbs/index.ipynb to README.md
