In [None]:
# default_exp base


In [None]:
# hide
%load_ext autoreload
%autoreload 2

# BigEarthNet Base Functions
> A collection of common functions that are applied to BigEarthNet (BEN).

In [None]:
# export
import functools
import urllib
import warnings
from datetime import datetime
from numbers import Real
from pathlib import Path
from typing import List, Optional, Set, Tuple, Union

import appdirs
import dateutil
import fastcore.all as fc
import geopandas
import pandas as pd
from fastcore.basics import compose
from fastcore.dispatch import typedispatch
from pydantic import AnyHttpUrl, validate_arguments
from shapely.geometry import LineString, Point, Polygon, box

from bigearthnet_common.constants import OLD2NEW_LABELS_DICT


  _pyproj_global_context_initialize()


In [None]:
# export
# Per-user data directory for this package; it is created at import time
# (including missing parents) and an already existing directory is accepted.
USER_DIR = Path(appdirs.user_data_dir("bigearthnet"))
USER_DIR.mkdir(parents=True, exist_ok=True)


In [None]:
# hide
from datetime import date

from dateutil.parser import ParserError
from fastcore.test import ExceptionExpected, test_close, test_eq


In [None]:
# export
@validate_arguments
def _download_and_cache_url(url: AnyHttpUrl, force_download: bool = False):
    """
    Download the contents of `url` into `USER_DIR` and return the local path.

    The file is stored under the last path component of `url`.
    A file that is already present is reused as-is, unless `force_download`
    is set, in which case it is fetched again and overwritten.
    """
    # `import urllib` alone does not load the `request` submodule
    import urllib.request

    fp = USER_DIR / Path(url).name
    if not fp.exists() or force_download:
        # close the HTTP response deterministically instead of leaking it;
        # the body is read completely before the target file is opened
        with urllib.request.urlopen(url) as response:
            fp.write_bytes(response.read())
    return fp


## BEN-JSON helper functions

A couple of safe parsing functions that allow flexible reading of the default BigEarthNet JSON entries.

Usually, you should not use any of these parsing functions directly, but use one of the higher-level functions instead.

In [None]:
# export
def parse_datetime(inp: Union[str, datetime]) -> datetime:
    """
    Convert `inp` into a `datetime` object.

    Strings are parsed with a best-effort guess of their format,
    while a `datetime` instance is handed back unchanged.
    Every other input type results in an error.
    """
    parsed = _parse_datetime(inp)
    return parsed


@typedispatch
def _parse_datetime(acquisition_date: str) -> datetime:
    """Parse a date string and re-dispatch on the parsed result."""
    # `import dateutil` alone does not guarantee that the `parser` submodule is loaded
    import dateutil.parser

    return compose(dateutil.parser.parse, _parse_datetime)(acquisition_date)


@typedispatch
def _parse_datetime(acquisition_date: datetime) -> datetime:
    # Already a `datetime`: nothing to convert, hand back the very same object.
    already_parsed = acquisition_date
    return already_parsed


@typedispatch
def _parse_datetime(acquisition_date: object) -> None:
    # Overload registered for plain `object`; it always raises.
    error_message = "Could not parse acquisition_date!"
    raise TypeError(error_message)


In [None]:
iso_parsed = parse_datetime("2017-06-13 10:10:31")
dotted_parsed = parse_datetime("13.06.2017 10:10:31")
passthrough = parse_datetime(
    datetime(year=2017, month=6, day=13, hour=10, minute=10, second=31)
)

# all three spellings describe the same point in time
test_eq(iso_parsed, dotted_parsed)
test_eq(dotted_parsed, passthrough)

# strings that are not dates are rejected by the parser itself
with ExceptionExpected(ex=ParserError, regex="format"):
    parse_datetime("large_tile")

# unsupported input types, including plain `date` objects, raise a TypeError
for unsupported in (42, date(year=2017, month=10, day=1)):
    with ExceptionExpected(ex=TypeError, regex="parse"):
        parse_datetime(unsupported)


In [None]:
# export
def _get_box_from_two_coords(p1: Tuple[Real, Real], p2: Tuple[Real, Real]) -> Polygon:
    """
    Return the rectangular `Polygon` that bounds the two given coordinates.
    Both coordinates have to be supplied as numerical values.
    """
    # the bounds of the segment connecting both points are the bounds of the box
    connecting_segment = LineString([p1, p2])
    return box(*connecting_segment.bounds)


In [None]:
# hide
reference_box = _get_box_from_two_coords([0, 0], [2, 2])
# neither the order of the points nor the chosen diagonal may change the result
for first, second in ([[2, 2], [0, 0]], [[0, 2], [2, 0]]):
    test_eq(reference_box, _get_box_from_two_coords(first, second))


In [None]:
# export
def box_from_ul_lr_coords(ulx: Real, uly: Real, lrx: Real, lry: Real) -> Polygon:
    """
    Create a box (`Polygon`) from the x/y coordinates of its upper left
    corner (`ulx`, `uly`) and of its lower right corner (`lrx`, `lry`).

    This corner specification is the default BigEarthNet style.
    """
    upper_left = [ulx, uly]
    lower_right = [lrx, lry]
    return _get_box_from_two_coords(upper_left, lower_right)


In [None]:
built_square = box_from_ul_lr_coords(ulx=0, uly=4, lrx=4, lry=0)
expected_square = Polygon([[0, 0], [0, 4], [4, 4], [4, 0], [0, 0]])
assert isinstance(built_square, Polygon)
assert built_square.equals(expected_square)


In [None]:
# hide
# example of how to use geopandas to read in weirdly shaped data
# (`geopandas` and the shapely helpers are already imported in the import cell
# at the top of the notebook, so they are not imported a second time here)

# CRS with easting/northing input, i.e. no input axis-swap required
wkt_crs = 'PROJCS["WGS 84 / UTM zone 34N",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AXIS["Latitude",NORTH],AXIS["Longitude",EAST],AUTHORITY["EPSG","4326"]],PROJECTION["Transverse_Mercator"],PARAMETER["latitude_of_origin",0],PARAMETER["central_meridian",21],PARAMETER["scale_factor",0.9996],PARAMETER["false_easting",500000],PARAMETER["false_northing",0],UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH],AUTHORITY["EPSG","32634"]]'
ul_x, ul_y, lr_x, lr_y = 499980, 7046040, 501180, 7044840
patch_shapes = [_get_box_from_two_coords([ul_x, ul_y], [lr_x, lr_y])]

patch_series = geopandas.GeoSeries(patch_shapes, crs=wkt_crs)
assert patch_series.is_valid.all()


In [None]:
# hide

# crs with northing/easting input, i.e. input axis-swap required
north_east_crs = "epsg:2953"
encoded_point = Point(1099489.55, 9665176.75)
source_series = geopandas.GeoSeries([encoded_point], crs=north_east_crs)
transformed_series = source_series.to_crs(epsg="4326")
lon, lat = transformed_series.x[0], transformed_series.y[0]

# _golden values_ from http://epsg.io/
# http://epsg.io/transform#s_srs=2953&t_srs=4326&x=1099489.55&y=9665176.75
expected_lon_lat = [-94.375, 63.25]
test_close([lon, lat], expected_lon_lat, eps=0.1)


CRSError: Invalid projection: epsg:2953: (Internal Proj Error: proj_create: no database context specified)

## Common BEN patch checks and transformations

These tiny functions mainly allow the user to write cleaner code by importing them instead of writing lambda functions.
All functions use caches to guarantee fast lookups, so feel free to use them on large data.

The most relevant functions are:
- Check if a patch name is part of the cloud/snow collection
    - `is_snowy_patch`
    - `is_cloudy_shadowy_patch`
- Retrieve the original split by looking up the patch name
    - `get_original_split_from_patch_name`
- Convert the old 43-label nomenclature to the new 19-label variant
    - `old2new_labels`

In [None]:
# export
# Disabled alternative locations on bigearth.net, kept for reference only:
# PATCHES_WITH_SNOW_URL = "http://bigearth.net/static/documents/patches_with_seasonal_snow.csv"
# PATCHES_WITH_CLOUD_AND_SHADOW_URL = "http://bigearth.net/static/documents/get_patches_with_cloud_and_shadow.csv"
# Active locations of the CSV files that list the patches with seasonal snow
# and the patches with cloud and shadow; both are read as header-less
# single-column files by the loader functions of this module.
PATCHES_WITH_SNOW_URL = "https://git.tu-berlin.de/k.clasen/ben-mirror/-/raw/master/patches_with_seasonal_snow.csv"
PATCHES_WITH_CLOUD_AND_SHADOW_URL = "https://git.tu-berlin.de/k.clasen/ben-mirror/-/raw/master/patches_with_cloud_and_shadow.csv"


In [None]:
# export
@validate_arguments
def _conv_single_col_csv_to_set(
    url: AnyHttpUrl, name: str = "Name", force_download: bool = False
) -> Set:
    """
    Given a url to a CSV file *without* a header
    line and only a single column, return the set of
    all values.

    `name` is only used as the label of that single column while parsing.

    Will write remote csv to disk for better performance.
    Set `force_download` to re-download the file.
    """
    fp = _download_and_cache_url(url, force_download=force_download)
    # `squeeze=True` was deprecated in pandas 1.4 and removed in pandas 2.0;
    # selecting the single named column yields the same `Series`.
    frame = pd.read_csv(fp, names=[name])
    return set(frame[name].values)


@functools.lru_cache()
def get_patches_with_seasonal_snow(force_download: bool = False) -> Set:
    """
    Return the names of all patches with seasonal snow from the **original** BigEarthNet dataset.

    The result is memoized by `functools.lru_cache`, i.e. a repeated call with
    the same `force_download` value returns the cached set.
    """
    snowy_patch_names = _conv_single_col_csv_to_set(
        PATCHES_WITH_SNOW_URL, force_download=force_download
    )
    return snowy_patch_names


@functools.lru_cache()
def get_patches_with_cloud_and_shadow(force_download: bool = False) -> Set:
    """
    Return the names of all patches with cloud and shadow from the **original** BigEarthNet dataset.

    The result is memoized by `functools.lru_cache`, i.e. a repeated call with
    the same `force_download` value returns the cached set.
    """
    cloudy_patch_names = _conv_single_col_csv_to_set(
        PATCHES_WITH_CLOUD_AND_SHADOW_URL, force_download=force_download
    )
    return cloudy_patch_names


In [None]:
# sizes of the two published patch collections
seasonal_snow_names = get_patches_with_seasonal_snow()
assert len(seasonal_snow_names) == 61_707

cloud_shadow_names = get_patches_with_cloud_and_shadow()
assert len(cloud_shadow_names) == 9_280


In [None]:
# export
@validate_arguments
def is_snowy_patch(patch_name: str):
    """
    Quickly check whether `patch_name` belongs to the patches
    that contain a lot of seasonal snow.
    """
    snowy_patch_names = get_patches_with_seasonal_snow()
    return patch_name in snowy_patch_names


@validate_arguments
def is_cloudy_shadowy_patch(patch_name: str):
    """
    Quickly check whether `patch_name` belongs to the patches
    that contain a lot of shadow or are obstructed by clouds.
    """
    cloudy_patch_names = get_patches_with_cloud_and_shadow()
    return patch_name in cloudy_patch_names


In [None]:
# hide

# the name itself is not validated: unknown names are simply not contained
assert not is_snowy_patch("hello")
assert not is_cloudy_shadowy_patch("hello")

assert is_snowy_patch("S2A_MSIL2A_20180205T100211_2_0")


In [None]:
# export
@functools.lru_cache()
@validate_arguments
def patches_from_original_train_split(
    split_url: AnyHttpUrl = "https://git.tu-berlin.de/rsim/BigEarthNet-S2_19-classes_models/-/raw/master/splits/train.csv",
    force_download: bool = False,
) -> Set:
    """
    Return the set of all train patches of the original train/validation/test split.
    The split file is available from two sources:

    1. https://git.tu-berlin.de/rsim/BigEarthNet-S2_19-classes_models/-/raw/master/splits/train.csv
    2. https://git.tu-berlin.de/rsim/BigEarthNet-S2_43-classes_models/-/raw/master/splits/train.csv

    At the time of writing, there is **no** difference between these two splits.
    But this may change in the future!
    """
    train_patches = _conv_single_col_csv_to_set(
        split_url, force_download=force_download
    )
    return train_patches


@functools.lru_cache()
@validate_arguments
def patches_from_original_validation_split(
    split_url: AnyHttpUrl = "https://git.tu-berlin.de/rsim/BigEarthNet-S2_19-classes_models/-/raw/master/splits/val.csv",
    force_download: bool = False,
) -> Set:
    """
    Return the set of all validation patches of the original train/validation/test split.
    The split file is available from two sources:

    1. https://git.tu-berlin.de/rsim/BigEarthNet-S2_19-classes_models/-/raw/master/splits/val.csv
    2. https://git.tu-berlin.de/rsim/BigEarthNet-S2_43-classes_models/-/raw/master/splits/val.csv

    At the time of writing, there is **no** difference between these two splits.
    But this may change in the future!
    """
    validation_patches = _conv_single_col_csv_to_set(
        split_url, force_download=force_download
    )
    return validation_patches


@functools.lru_cache()
@validate_arguments
def patches_from_original_test_split(
    split_url: AnyHttpUrl = "https://git.tu-berlin.de/rsim/BigEarthNet-S2_19-classes_models/-/raw/master/splits/test.csv",
    force_download: bool = False,
) -> Set:
    """
    Return the set of all test patches of the original train/validation/test split.
    The split file is available from two sources:

    1. https://git.tu-berlin.de/rsim/BigEarthNet-S2_19-classes_models/-/raw/master/splits/test.csv
    2. https://git.tu-berlin.de/rsim/BigEarthNet-S2_43-classes_models/-/raw/master/splits/test.csv

    At the time of writing, there is **no** difference between these two splits.
    But this may change in the future!
    """
    test_patches = _conv_single_col_csv_to_set(
        split_url, force_download=force_download
    )
    return test_patches


In [None]:
# hide
# Every patch of the default (19-classes) split file must also be listed in
# the corresponding 43-classes split file.
train_19 = patches_from_original_train_split()
train_43 = patches_from_original_train_split(
    "https://git.tu-berlin.de/rsim/BigEarthNet-S2_43-classes_models/-/raw/master/splits/train.csv"
)
assert train_19.issubset(train_43)

val_19 = patches_from_original_validation_split()
val_43 = patches_from_original_validation_split(
    "https://git.tu-berlin.de/rsim/BigEarthNet-S2_43-classes_models/-/raw/master/splits/val.csv"
)
assert val_19.issubset(val_43)

test_19 = patches_from_original_test_split()
test_43 = patches_from_original_test_split(
    "https://git.tu-berlin.de/rsim/BigEarthNet-S2_43-classes_models/-/raw/master/splits/test.csv"
)
assert test_19.issubset(test_43)

# the train split is the largest of the three splits
assert len(test_19) < len(train_19)
assert len(val_19) < len(train_19)


In [None]:
# export
@validate_arguments
def get_original_split_from_patch_name(patch: str) -> Optional[str]:
    """
    Look up `patch` in the original BigEarthNet train/validation/test
    split and return "train", "validation" or "test" accordingly.

    If the patch is not part of any split, a `UserWarning` is issued and
    `None` is returned. This happens for patches that are either in the
    cloud/shadow or seasonal snow set or if there exists no 19-label target.

    The splits are taken from the 19-classes version.
    While writing this function there was no difference between the
    19-classes and the 43-classes version.
    """
    # all three splits are loaded up front, in lookup order
    patches_per_split = {
        "train": patches_from_original_train_split(),
        "validation": patches_from_original_validation_split(),
        "test": patches_from_original_test_split(),
    }
    for split_name, split_patches in patches_per_split.items():
        if patch in split_patches:
            return split_name

    warnings.warn(
        "Provided an input patch name which was not part of the original split.",
        UserWarning,
    )
    return None


In [None]:
expected_split_by_patch = {
    "S2A_MSIL2A_20170717T113321_28_87": "train",
    "S2B_MSIL2A_20170812T092029_75_6": "validation",
    "S2A_MSIL2A_20170717T113321_28_88": "test",
}
for patch_name, expected_split in expected_split_by_patch.items():
    assert expected_split == get_original_split_from_patch_name(patch_name)


In [None]:
# hide
# an unknown patch name must trigger exactly one `UserWarning`
with warnings.catch_warnings(record=True) as caught_warnings:
    warnings.simplefilter("always")
    get_original_split_from_patch_name("WrongName")
    assert len(caught_warnings) == 1
    assert issubclass(caught_warnings[-1].category, UserWarning)


In [None]:
# export
@validate_arguments
def _old2new_label(old_label: str) -> Optional[str]:
    """
    Map a single old-style BigEarthNet label to its
    new-style counterpart.

    > Note: Some labels were removed! This function
    will return `None` if the label was removed and
    raise a `KeyError` if the input label is unknown.
    """
    new_label = OLD2NEW_LABELS_DICT[old_label]
    return new_label


def old2new_labels(old_labels: List[str]) -> Optional[List[str]]:
    """
    Converts a list of old-style BigEarthNet labels
    to a list of new-style labels.

    Labels that were removed in the new nomenclature are dropped from the result.
    If there are no corresponding new labels at all (which can happen with
    original BEN patches!) then the function will return `None` and
    issue a `UserWarning`. An empty input list is returned as an empty list.

    If an illegal/unknown input label is provided, a `KeyError` is raised.
    """
    # convert every label exactly once; removed labels are mapped to `None`
    converted = (_old2new_label(old_label) for old_label in old_labels)
    new_labels = [new_label for new_label in converted if new_label is not None]
    if len(old_labels) > 0 and len(new_labels) == 0:
        warnings.warn(
            "Provided a list of old labels that only contains `removed` labels!",
            UserWarning,
        )
        new_labels = None
    return new_labels


> Warning: Some of the original 43-class nomenclature patches have 0 labels with the 19-class nomenclature! This function might return `None` instead of an empty list!

In [None]:
# hide
# a list that only consists of removed labels triggers exactly one `UserWarning`
with warnings.catch_warnings(record=True) as caught_warnings:
    warnings.simplefilter("always")
    old2new_labels(["Burnt areas"])
    assert len(caught_warnings) == 1
    assert issubclass(caught_warnings[-1].category, UserWarning)

# unknown labels are rejected
with ExceptionExpected(ex=KeyError):
    old2new_labels(["Illegal input label"])

# several old labels may map to the same new label; duplicates are kept
urban_old_labels = ["Continuous urban fabric", "Discontinuous urban fabric"]
test_eq(old2new_labels(urban_old_labels), ["Urban fabric", "Urban fabric"])


NameError: name 'old2new_labels' is not defined

In [None]:
# hide
from nbdev.cli import nbdev_build_docs
from nbdev.export import notebook2script

# nbdev: convert the notebooks of this project to Python modules
notebook2script()
# nbdev: rebuild the documentation from the notebooks
# (presumably also regenerates README.md from index.ipynb — verify against the nbdev version in use)
nbdev_build_docs()


Converted 01a_constants.ipynb.
Converted 01b_base.ipynb.
Converted 01c_gdf_builder.ipynb.
Converted 01d_subset_builder.ipynb.
Converted index.ipynb.
converting: /home/kai/git/bigearthnet_common/nbs/01b_base.ipynb
converting /home/kai/git/bigearthnet_common/nbs/index.ipynb to README.md
