In [None]:
import os
import json
import math
import numpy as np
import pandas as pd
import rasterio as rio
import tifffile as tif
from collections import defaultdict
from glob import glob
from pathlib import Path
from pydantic import BaseModel, BaseSettings
from typing import List
from datetime import datetime
from tqdm import tqdm

In [None]:
# The notebook lives in a subfolder: move the working directory one level up (project root).
# NOTE: running this cell a second time moves up one more level.
notebook_dir = os.getcwd()
print(notebook_dir)
root = Path(notebook_dir).parent.absolute()
print(f"Setting cwd to: {str(root)}")
os.chdir(str(root))

In [None]:
# define the pydantic models used below: notebook settings, activation metadata and per-image tile statistics
class Configuration(BaseSettings):
    """Settings of the notebook, declared as a pydantic ``BaseSettings`` model.

    The nested ``Config`` points pydantic at a UTF-8 encoded ``.env`` file.
    All fields are required (none declares a default).
    """
    class Config:
        # relative path: presumably resolved against the current working directory - TODO confirm
        env_file = '.env'
        env_file_encoding = 'utf-8'
    # root folder of the dataset, globbed as `<data_source>/*/s1_raw/*.tif`
    data_source: Path
    # JSON file with the activation metadata (read)
    activations_file: str
    # JSON file receiving the activations annotated with their subset (written)
    summary_file: str
    # not referenced in the visible part of this notebook
    countries_file: str
    # access token handed to plotly's mapbox helpers
    token: str


class ActivationModel(BaseModel):
    """Metadata of a single activation, as loaded from the activations JSON file.

    Keys of the JSON entries are lower-cased before being passed to this model.
    """
    # human-readable title, shown (prefixed by the code) when hovering the maps
    title: str
    type: str
    country: str
    # delineation entries; reduced to bare file names before the summary is exported
    delineation: List[str]
    start: datetime
    end: datetime
    # center of the activation, filled in later from the bounds of its images
    lat: float = None
    lon: float = None
    # one of "train", "test" or "val" once the split has been assigned
    subset: str = None


class SARImage(BaseModel):
    """Size information about one SAR image file, used to weight the dataset split."""
    # location of the .tif file
    path: Path
    # EMSR code: the part of the file stem before the first "-"
    code: str
    # shape of the image, as stored by the tile-counting cell
    shape: tuple
    # number of tiles needed to cover the image
    num_tiles: int

In [None]:
cfg = Configuration()
# Echo the loaded settings for a quick sanity check.
# Credentials are masked: notebook outputs are saved with the file and easily end up in version control.
secret_fields = {"token"}
for k, v in cfg.dict().items():
    shown = "********" if k in secret_fields else str(v)
    print(f"{k:<20s}: {shown}")

In [None]:
# Parse the activations file: one pydantic model per entry, keyed by the activation name.
with open(cfg.activations_file, "r") as f:
    obj = json.load(f)

activations = {}
for name, value in obj.items():
    # field names of the model are lowercase, so normalize the JSON keys first
    lowered = {key.lower(): item for key, item in value.items()}
    activations[name] = ActivationModel(**lowered)


In [None]:
# Group the available images by activation: the EMSR code is the part of the
# file stem that precedes the first "-".
activation_groups = defaultdict(list)
image_pattern = str(cfg.data_source / "*" / "s1_raw" / "*.tif")
for tif_path in map(Path, glob(image_pattern)):
    group_code = tif_path.stem.split("-", 1)[0]
    activation_groups[group_code].append(tif_path)

print(len(activation_groups))

In [None]:
def get_center(path: Path) -> tuple:
    """
    Opens the GeoTIFF at the given path and returns the midpoint of its bounds.

    :param path: location of the GeoTIFF file
    :return: tuple (center_lon, center_lat), i.e. the horizontal midpoint first
    """
    with rio.open(str(path), "r") as src:
        bounds = src.bounds
        horizontal_mid = (bounds.left + bounds.right) / 2
        vertical_mid = (bounds.top + bounds.bottom) / 2
    return horizontal_mid, vertical_mid


def smooth_weights(data: np.ndarray, smoothing: float = 0.15, clip: float = 10.0, normalize: bool = True):
    """
    Turns raw counts into (optionally normalized) weights with additive smoothing.

    A fraction ``smoothing`` of the largest count is added to every entry, the values are
    rounded to the nearest integer and clipped to ``[0, clip]``. With ``normalize`` the
    result is divided by its sum, so it can be used as a probability vector.
    The weights stay proportional to the (smoothed) counts: larger counts get larger weights.

    :param data: array of counts; it is never modified, integer arrays are accepted
    :param smoothing: share of the maximum count added to each entry, within [0, 1]
    :param clip: upper bound applied to the rounded values, before the normalization
    :param normalize: when True, the returned weights sum to 1
    :return: new float array with the same shape as ``data``
    :raises AssertionError: when ``smoothing`` is outside [0, 1]
    """
    assert smoothing >= 0 and smoothing <= 1, "Smoothing factor out of range"
    # work on a float copy: an in-place `+=` would alter the caller's array
    # and fails with a casting error on integer arrays
    values = np.array(data, dtype=float)
    if smoothing > 0:
        # the larger the smooth factor, the bigger the quantity added to every count (additive smoothing)
        values += np.max(values) * smoothing

    result = np.clip(np.round(values), a_min=0, a_max=clip)
    if normalize:
        result /= result.sum()
    return result


In [None]:
# Keep only the activations that have at least one image group on disk.
valid_activations = {
    code: activation
    for code, activation in activations.items()
    if code in activation_groups
}
len(valid_activations)

In [None]:
# Store in each activation the average center of its images.
# The loop runs over the filtered activations rather than over the image groups: a group whose
# EMSR code has no entry in the activations file would otherwise raise a KeyError.
orphan_groups = sorted(set(activation_groups) - set(valid_activations))
if orphan_groups:
    print(f"Skipping {len(orphan_groups)} image group(s) without activation metadata: {orphan_groups}")

for emsr, activation in valid_activations.items():
    image_paths = activation_groups[emsr]
    # compute centers for each tiff and get average, one coordinate at a time
    centers = [get_center(p) for p in image_paths]
    avg_center = tuple(sum(axis) / float(len(axis)) for axis in zip(*centers))
    assert len(avg_center) == 2, "it must be a tuple (lon, lat)!"
    # store the coords in the activation
    activation.lon = avg_center[0]
    activation.lat = avg_center[1]

In [None]:
# One row per activation: code, hover title, coordinates and a constant marker size.
records = [
    {
        "code": code,
        "title": code + " " + act.title,
        "lon": act.lon,
        "lat": act.lat,
        "size": 1,
    }
    for code, act in valid_activations.items()
]
df = pd.DataFrame(records)

In [None]:
import plotly.express as px

# the mapbox access token comes from the configuration
px.set_mapbox_access_token(cfg.token)

# world map with one marker per activation, titled on hover
fig = px.scatter_mapbox(
    df,
    lat=df["lat"],
    lon=df["lon"],
    size="size",
    size_max=5,
    hover_name="title",
    zoom=1,
)
fig.show()

## Splitting

In [None]:
# every .tif stored in a `s1_raw` folder, one level below the data source
sar_pattern = cfg.data_source / "*" / "s1_raw" / "*.tif"
sar_files = glob(str(sar_pattern))
len(sar_files)

In [None]:
# Estimate, for every SAR image, how many tiles of side `tile_size` are needed to cover it.
tile_size = 512.0
emsr2sar = list()

for path in tqdm(sar_files):
    code = Path(path).stem.split("-")[0]  # extract the EMSR code
    # only the shape is needed: take it from the TIFF metadata instead of decoding the whole image
    with tif.TiffFile(path) as tiff:
        image_shape = tiff.series[0].shape
    # assumes the first two axes are the spatial ones - TODO confirm for band-first files.
    # Slicing with [:2] (not [:-1]) keeps the width also when the image has no channel axis.
    height, width = image_shape[:2]
    num_tiles = math.ceil(height / tile_size) * math.ceil(width / tile_size)
    # create data and store into list
    sample = SARImage(path=Path(path), code=code, shape=(height, width), num_tiles=num_tiles)
    emsr2sar.append(sample)

In [None]:
# Position of each EMSR code in the weight vector, and the reverse lookup.
emsr2index = {code: position for position, code in enumerate(activation_groups.keys())}
index2emsr = {position: code for code, position in emsr2index.items()}

# Accumulate the tile count of each activation, both per code and per position.
counts = dict.fromkeys(emsr2index, 0)
weights = np.zeros(len(emsr2index))
for sar_image in emsr2sar:
    counts[sar_image.code] += sar_image.num_tiles
    weights[emsr2index[sar_image.code]] += sar_image.num_tiles

# Smooth and normalize the tile counts, so they can be used as sampling probabilities.
weights = smooth_weights(data=weights, clip=2000)
print(len(weights), weights.sum())

In [None]:
emsr_codes = list(activation_groups.keys())
train_percent = 0.65  # share of the activations; it usually becomes 80% of the data, after tile count
valid_percent = 0.12  # share of the sampled training activations that is moved to validation
split_seed = 42  # fixed seed: Restart & Run All must reproduce the same split

rng = np.random.default_rng(split_seed)
# training activations: drawn without replacement, using `weights` as probabilities.
# `.tolist()` yields plain Python strings, so the membership tests below are ordinary list/set lookups.
train = rng.choice(emsr_codes, int(len(emsr_codes) * train_percent), replace=False, p=weights).tolist()
# validation activations: uniform sample (the default when no `p` is given) of the training ones
valid = rng.choice(train, int(len(train) * valid_percent), replace=False).tolist()
valid_codes = set(valid)
train = [c for c in train if c not in valid_codes]
train_codes = set(train)
test = [c for c in emsr_codes if (c not in train_codes and c not in valid_codes)]
assert len(train) + len(valid) + len(test) == len(emsr_codes), "the subsets must partition the activations"

train_tiles = sum([counts[code] for code in train])
valid_tiles = sum([counts[code] for code in valid])
test_tiles = sum([counts[code] for code in test])
total = sum(list(counts.values()))

print(f"train: {train_tiles} ({train_tiles / float(total) * 100}%), val: {valid_tiles} ({valid_tiles / float(total) * 100}%), test: {test_tiles} ({test_tiles / float(total) * 100}%)")

In [None]:
# Label every row with its subset: anything that is neither train nor val counts as test.
sets = [
    "train" if code in train else ("val" if code in valid else "test")
    for code in df.code
]
# same frame as before, plus the subset column and a smaller marker size
splitdf = df.assign(set=sets, size=[0.1] * len(sets))
print(len(train), len(valid), len(test), len(df))
splitdf.head()

In [None]:
import plotly.express as px

# same map as above, with the markers colored by the assigned subset
split_map_options = dict(color="set", size="size", size_max=5, hover_name="title", zoom=1)
fig = px.scatter_mapbox(splitdf, lat=splitdf["lat"], lon=splitdf["lon"], **split_map_options)
fig.show()

In [None]:
# dump the codes assigned to each subset, one subset per line
print(train, valid, test, sep="\n")

In [None]:
# Tag every known activation with the subset its EMSR code was assigned to.
# The loop runs over the activations rather than over the split lists: a split code with images
# on disk but no entry in the activations file would otherwise raise a KeyError.
# Insertion order matters: on overlapping lists the last match wins (train, then test, then val).
subset_codes = {"train": set(train), "test": set(test), "val": set(valid)}

for code, act in valid_activations.items():
    for subset_name, codes in subset_codes.items():
        if code in codes:
            act.subset = subset_name
    # keep only the file name of each delineation entry
    act.delineation = [Path(v).name for v in act.delineation]
    assert act.subset is not None and act.subset in ("train", "test", "val")


In [None]:
class DateTimeAwareEncoder(json.JSONEncoder):
    """JSON encoder that also accepts ``datetime`` values, written as ISO-formatted strings."""

    def default(self, v):
        """Returns the ISO representation of a datetime, deferring any other type to the base encoder."""
        if not isinstance(v, datetime):
            # unsupported types are handled (i.e. rejected) by the standard implementation
            return super().default(v)
        return v.isoformat()

In [None]:
# Convert the activations to plain dictionaries, renaming "delineation" to "delineations"
# (the renamed key ends up last), then store everything as indented JSON.
raw_data = {}
for code, activation in valid_activations.items():
    fields = activation.dict()
    fields["delineations"] = fields.pop("delineation")
    raw_data[code] = fields

with open(cfg.summary_file, "w") as file:
    serialized = json.dumps(raw_data, indent=4, cls=DateTimeAwareEncoder)
    file.write(serialized)
