# Google Colab VM

Notebook to set up an SSH-enabled Google Colab VM.

Remember to select the required runtime type (e.g. GPU) before running the cells below.

In [None]:
!nvidia-smi

## Set up the environment

### Using pip

Use this alternative when running the training script.

In [None]:
def write_requirements_txt():
  import sys
  import tensorflow
  import numpy
  import pandas
  import sklearn
  import skimage
  import cv2
  import six
  import requests
  import folium
  import scipy

  lines0 = [f"tensorflow~={tensorflow.__version__}",
            f"numpy~={numpy.__version__}",
            f"pandas~={pandas.__version__}",
            f"scikit-learn~={sklearn.__version__}",
            f"scikit-image~={skimage.__version__}",
            f"opencv-python~={cv2.__version__}",
            f"six~={six.__version__}",
            f"requests~={requests.__version__}",
            f"folium~={folium.__version__}",
            f"scipy~={scipy.__version__}"]

  lines1 = ["pygeos",
            "rtree",
            "geopandas",
            "rasterio",
            "descartes",
            "tqdm",
            "pydantic",
            "black",
            "pycodestyle",
            "pydocstyle",
            "mypy",
            "gitpython",
            "jinja2",
            "pyyaml==5.4.1",
            "dill",
            "mpld3",
            "typing_extensions",
            "colorama",
            "tabulate",
            ]


  lines2 = ["git+https://github.com/aleju/imgaug.git@0.4.0",
            "git+https://github.com/albumentations-team/albumentations@1.0.0",
            "tensorflow_addons==0.13.0",
            #"albumentations[imgaug]==1.0.0",
            "colab_ssh",
            "kaggle"
            ]

  lines = lines0 + lines1 + lines2
  # Backports
  if sys.version_info[:2] == (3, 7):
    lines.append("shared-memory38")

  with open("requirements.txt", "w") as file:
    file.write("\n".join(lines) + "\n")

write_requirements_txt()

In [None]:
!cat requirements.txt

In [None]:
%%time
!pip install -r requirements.txt

In [None]:
!curl https://rclone.org/install.sh | bash

### Using Conda

Unlike the pip setup above, this alternative also installs `gdal`.


In [None]:
%%time
#!pip install -q condacolab
!pip install condacolab@git+https://github.com/adolfogc/condacolab@temporary-fixes
import os
import condacolab
condacolab.install()

In [None]:
%%writefile environment.yml
name: base
channels:
  - conda-forge
dependencies:
  - geopandas
  - gdal
  - rasterio
  - descartes
  - tqdm
  - pydantic
  - black
  - pycodestyle
  - pydocstyle
  - mypy
  - rclone
  - imgaug
  - ipython
  - gitpython
  - jinja2
  - pyyaml
  - dill
  - pip
  - pip:
    - tensorflow_addons==0.13.0
    - albumentations[imgaug]
    - colab_ssh
    - kaggle
    - mpld3

Install the dependencies using `mamba`, which is faster than `conda`.

In [None]:
%%time
!mamba env update -n base -f environment.yml

## Check GPU, TPU, CPU

### CPU

In [None]:
# New in TF 2.5: support for oneDNN
# See: https://github.com/tensorflow/tensorflow/blob/master/RELEASE.md#release-250
# Also see: https://github.com/oneapi-src/oneDNN
%env TF_ENABLE_ONEDNN_OPTS=1

### GPU

In [None]:
# Make sure our Tensorflow installation recognizes the GPU device
%tensorflow_version 2.x
import tensorflow as tf
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

### TPU

In [None]:
# Alternatively, set up environment to use TPU devices from Colab
# See: https://www.tensorflow.org/guide/tpu
%tensorflow_version 2.x
import tensorflow as tf
COLAB_TPU_RESOLVER = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')  # '' is the special name used in Colab
tf.config.experimental_connect_to_cluster(COLAB_TPU_RESOLVER)
tf.tpu.experimental.initialize_tpu_system(COLAB_TPU_RESOLVER)
print("All devices: ", tf.config.list_logical_devices('TPU'))

# Note: we currently can't use TPUs because the dataset is built from a Python generator.
# See: https://github.com/tensorflow/tensorflow/issues/39099
# To use them, we first need to implement what is described in issue #6, but it is not a priority.

## Mount Google Drive


In [None]:
from google.colab import drive, files
drive.mount("/content/gdrive")

In [None]:
BASE_PATH = "/content/gdrive/MyDrive/Thesis"

Create symlinks to get shorter paths.

In [None]:
!ln -s $BASE_PATH/datasets /content/datasets

In [None]:
!ln -s $BASE_PATH/output /content/output

In [None]:
!ln -s $BASE_PATH/models /content/models

In [None]:
!ln -s $BASE_PATH/other /content/other
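
The `ln -s` cells above fail when re-run, because the links already exist. A minimal Python sketch of an idempotent alternative follows; `link_directories` is a hypothetical helper, not part of the codebase, and the paths in the usage comment are the ones from the cells above.

```python
from pathlib import Path


def link_directories(base_path, link_root, names):
    """Create link_root/<name> -> base_path/<name> for each name, idempotently."""
    created = []
    for name in names:
        target = Path(base_path) / name
        link = Path(link_root) / name
        if link.is_symlink():
            # Replace an existing (possibly stale) link.
            link.unlink()
        elif link.exists():
            # Never overwrite a real file or directory.
            raise FileExistsError(f"{link} exists and is not a symlink")
        link.symlink_to(target, target_is_directory=True)
        created.append(link)
    return created


# Usage in Colab:
# link_directories("/content/gdrive/MyDrive/Thesis", "/content",
#                  ["datasets", "output", "models", "other"])
```

Refusing to replace anything that is not a symlink keeps the helper from clobbering data already on the local disk.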

## Transfer dataset(s) from ERDA into Google Drive

We keep a mirror of the dataset(s) in Google Drive to access from within the Colab VM.

First, we need to transfer the data from ERDA into Google Drive (when we did it, it took about 11 minutes for 20 GB). For this, you will need to configure the SFTP credentials in your ERDA account. You will also need to create an access token in Google Cloud (the `rclone` client will guide you through the required steps).

Open up the terminal in Google Colab and set up the two `rclone` remotes using

```bash
rclone config
```

After that, copy files between remotes.

```bash
rclone copy -P erda:datasets/sahel gdrive:datasets/sahel
```

We can verify everything was copied.

```bash
rclone size erda:datasets/sahel
rclone size gdrive:datasets/sahel
```
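
The size comparison can also be scripted. The sketch below assumes the `erda` and `gdrive` remotes are already configured and relies on the `count` and `bytes` fields of `rclone size --json`; `remote_size` and `same_size` are hypothetical helper names.

```python
import json
import subprocess


def remote_size(remote):
    """Return the parsed report of `rclone size --json` for a remote path."""
    result = subprocess.run(
        ["rclone", "size", "--json", remote],
        check=True, capture_output=True, text=True,
    )
    return json.loads(result.stdout)


def same_size(a, b):
    """True when both reports agree on object count and total bytes."""
    return a["count"] == b["count"] and a["bytes"] == b["bytes"]


# Usage (needs the remotes configured with `rclone config`):
# assert same_size(remote_size("erda:datasets/sahel"),
#                  remote_size("gdrive:datasets/sahel"))
```

Matching counts and sizes only suggest that the copy is complete. For a file-by-file comparison, `rclone check` is the more thorough option.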

Then, we can back up the `rclone` config to Google Drive.

```bash
# Verify config file path
rclone config file
# Copy to Google Drive (`copyto` treats the destination as a file name, not a directory)
rclone copyto /root/.config/rclone/rclone.conf gdrive:rclone.conf
```

In [None]:
# Copy back rclone config
!mkdir -p /root/.config/rclone/ && cp $BASE_PATH/rclone.conf /root/.config/rclone/rclone.conf

It's best to keep the frames on the local disk for training.

In [None]:
!rclone ls erda:datasets

In [None]:
!rclone config file

In [None]:
!cat /root/.config/rclone/rclone.conf

In [None]:
!ls /content/datasets/frames_zips

## Set up the VM

The `colab_ssh` utility lets you create an SSH tunnel to connect directly to the Colab VM.

Note: this also requires installing and setting up Cloudflared on your local machine.

See: [PyPI page](https://pypi.org/project/colab-ssh/) | [Cloudflare One](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/installation)

In [None]:
import os
import string
import secrets

from colab_ssh import launch_ssh_cloudflared, init_git_cloudflared, init_git
from IPython.display import display, HTML

def launch_ssl():
  # References:
  # https://docs.python.org/3/library/secrets.html
  alphabet = string.ascii_letters + string.digits
  password = ''.join(secrets.choice(alphabet) for i in range(64))

  launch_ssh_cloudflared(password=password)
  display(HTML('''<div style="padding:10px;">
                    SSH Password: <input type="password" value="{0}" />
                    <input type="button" value="Copy" onClick="(function () {{ navigator.clipboard.writeText(\'{0}\'); }})();" />
                  </div>'''.format(password)))
  
def clone_repo(cloudflared=True):
  repo_url = os.environ["GITHUB_REPO_URL"]  # fail early with a KeyError if unset
  repo_name = repo_url.split("/")[-1]
  if cloudflared:
    init_git_cloudflared(repo_url,
                        personal_token=os.environ["GITHUB_PERSONAL_TOKEN"],
                        branch=os.environ["GITHUB_REPO_BRANCH"],
                        email=os.environ["GIT_USER_EMAIL"],
                        username=os.environ["GIT_USER_NAME"],
                        )
  else:
    init_git(repo_url,
             personal_token=os.environ["GITHUB_PERSONAL_TOKEN"],
             branch=os.environ["GITHUB_REPO_BRANCH"],
             email=os.environ["GIT_USER_EMAIL"],
             username=os.environ["GIT_USER_NAME"],
             )

  return repo_name

The `colab_ssh` tool can also clone a GitHub repo. Using a VS Code SSH session, you can then edit and add files, run them in the Colab VM, and commit the changes to the repo.

If you just want to clone the Git repository:

In [None]:
GITHUB_REPO_NAME = clone_repo(cloudflared=False)

If you want to run VS Code:

In [None]:
launch_ssl()

In [None]:
GITHUB_REPO_NAME = clone_repo()

Some additional configuration options for Git.

In [None]:
!git config --global core.editor "code --wait"
!git config --global merge.tool vscode
# "$$" makes IPython pass a literal "$", so that Git (not the shell) expands $MERGED.
!git config --global mergetool.vscode.cmd 'code --wait $$MERGED'
!git config --global pull.rebase true
!cd $GITHUB_REPO_NAME && git branch

## Imports

Libraries we frequently use when prototyping in notebook cells.

In [None]:
import importlib
import pathlib
import subprocess
import random

from pprint import pprint
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import geopandas as gpd
import skimage
import albumentations
import cv2
import rasterio
import rasterio.plot
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow.keras.layers.experimental.preprocessing as tf_preprocessing
import shapely.geometry
import skimage.io

from IPython.display import clear_output, display
import mpld3

Import and reload modules from our codebase.

In [None]:
import dlc.tools.db
import dlc.tools.images
import dlc.tools.plots
import dlc.tools.mpld3
import dlc.tools.splits
import dlc.tools.datasets
import dlc.tools.scalers
import dlc.tools.evaluation
import dlc.models
import dlc.frames.raster
import dlc.frames.creators.base
import dlc.frames.creators.data
import dlc.frames.creators.segmentation
import dlc.frames.creators.density
import dlc.frames.creators.image
import dlc.frames.creators.scalar
import dlc.frames.factory

importlib.reload(dlc.tools.db)
importlib.reload(dlc.tools.images)
importlib.reload(dlc.tools.plots)
importlib.reload(dlc.tools.mpld3)
importlib.reload(dlc.tools.splits)
importlib.reload(dlc.tools.datasets)
importlib.reload(dlc.tools.scalers)
importlib.reload(dlc.tools.evaluation)
importlib.reload(dlc.models)
importlib.reload(dlc.frames.raster)
importlib.reload(dlc.frames.creators.base)
importlib.reload(dlc.frames.creators.data)
importlib.reload(dlc.frames.creators.segmentation)
importlib.reload(dlc.frames.creators.density)
importlib.reload(dlc.frames.creators.image)
importlib.reload(dlc.frames.creators.scalar)
importlib.reload(dlc.frames.factory)
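
The explicit list of `importlib.reload` calls has to be kept in sync with the package by hand. As a sketch of an alternative, the hypothetical helper below reloads every already-imported module under a package prefix, deepest submodules first, so that a parent module re-executes against its refreshed children. It only touches modules that are already in `sys.modules`.

```python
import importlib
import sys


def reload_package(prefix):
    """Reload all imported modules named `prefix` or `prefix.*`, deepest first."""
    names = [
        name for name in sys.modules
        if name == prefix or name.startswith(prefix + ".")
    ]
    # Submodules (more dots) are reloaded before their parent packages.
    names.sort(key=lambda name: name.count("."), reverse=True)
    reloaded = []
    for name in names:
        module = sys.modules.get(name)
        if module is not None:
            importlib.reload(module)
            reloaded.append(name)
    return reloaded


# Usage in this notebook, after the imports above:
# reload_package("dlc")
```

Objects created before the reload still point to the old class definitions, so re-create them after reloading.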

If the code was updated remotely, pull the changes and then re-run the reload cell above.

In [None]:
!cd $GITHUB_REPO_NAME && git stash && git pull origin develop && git stash pop

## Random generator seeds

In [None]:
RNG_SEEDS = [
  207146141,
  175229620,
  244591021,
  944519419,
  219512321,
  310416901,
  951601665,
  481291421,
  351411481,
  362181641,
]

DS_SPLIT_SEED = 591477907
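
How the training code consumes `RNG_SEEDS` is not shown in this notebook. As a hedged sketch, a single seed is typically applied to every random number generator in use before a run. `seed_everything` is a hypothetical helper, and NumPy and TensorFlow are only seeded when they are installed.

```python
import random


def seed_everything(seed):
    """Seed Python's RNG and, when available, NumPy's and TensorFlow's."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)  # accepts integers in [0, 2**32 - 1]
    except ImportError:
        pass
    try:
        import tensorflow as tf
        tf.random.set_seed(seed)  # global TensorFlow seed
    except ImportError:
        pass


# Usage: one run per seed.
# for seed in RNG_SEEDS:
#     seed_everything(seed)
#     ...  # build and train the model
```

Seeding alone does not guarantee bit-identical results on GPU, which is why the deterministic-ops environment variables are also set.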

In [None]:
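# `%env` stores the value verbatim, so the values must not be quoted.
# PYTHONHASHSEED only affects Python processes started after this point.
%env TF_DETERMINISTIC_OPS=1
%env PYTHONHASHSEED=0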

In [None]:
%env TF_CUDNN_DETERMINISTIC=1

## Preprocessing

### Database

#### Creating a database

```bash
!export PYTHONPATH="${GITHUB_REPO_NAME}:${PYTHONPATH}" && \
   python3 ${GITHUB_REPO_NAME}/scripts/create_db.py \
   --img_path="/content/datasets/sahara-sahel/StackedImages" \
   --pattern="*.tif" \
   --areas_path="/content/datasets/sahara-sahel/areas.geojson" \
   --polygons_path="/content/datasets/sahara-sahel/polygons.geojson" \
   --output_path="."
```

We consider a database to be composed of three `GeoDataFrame`s:
* `tiles`: contains information about the raw image tiles.
* `areas`: contains information about the training rectangles.
* `polygons`: contains information about the polygon annotations.

We assume the last two are given; the first one can be computed from the images contained in a folder tree.

In [None]:
%%time
sahara_sahel_tiles = dlc.tools.db.get_objects_from_images("/content/datasets/sahara-sahel/StackedImages", "*.tif")
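
To illustrate what computing the tile table from a folder tree involves, here is a minimal sketch that only collects file paths and sizes. It is not the implementation of `dlc.tools.db.get_objects_from_images`, which presumably also reads georeferencing metadata; `list_image_files` is a hypothetical helper, and `file_size` mirrors the column queried later in this notebook.

```python
from pathlib import Path


def list_image_files(root, pattern="*.tif"):
    """Recursively collect files matching `pattern` with their size in bytes."""
    records = []
    for path in sorted(Path(root).rglob(pattern)):
        if path.is_file():
            records.append({"path": str(path), "file_size": path.stat().st_size})
    return records


# Usage:
# records = list_image_files("/content/datasets/sahara-sahel/StackedImages", "*.tif")
# print(len(records), sum(r["file_size"] for r in records))
```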

#### Loading the database files

We use the dataset source files to load the database:

In [None]:
rwanda = dlc.tools.db.load_database_rwanda(
    tiles="/content/datasets/rwanda/tiles.geojson",
    areas="/content/datasets/rwanda/v1/labeled_rectangles.shp",
    polygons="/content/datasets/rwanda/v1/tree_labels.shp",
    outliers=True,
)

In [None]:
rwanda.to_files("/content/datasets/rwanda")

In [None]:
(len(rwanda.tiles), len(rwanda.tiles.query("n_areas > 0")))

In [None]:
(rwanda.tiles["file_size"].sum(), rwanda.tiles.query("n_areas > 0")["file_size"].sum())

In [None]:
(len(rwanda.areas), len(rwanda.areas.query("n_tiles > 0")), len(rwanda.areas.query("n_polygons == 0")))

In [None]:
(len(rwanda.polygons), len(rwanda.polygons.query("is_orphan == False")))

In [None]:
(len(rwanda.areas),
len(rwanda.areas.query("split == 'training'")),
len(rwanda.areas.query("split == 'test'")))

In [None]:
(len(rwanda.areas.query("n_tiles > 0")), 
len(rwanda.areas.query("split == 'training' and n_tiles > 0")),
len(rwanda.areas.query("split == 'test' and n_tiles > 0")))

In [None]:
(len(rwanda.areas.query("n_polygons == 0")), 
len(rwanda.areas.query("split == 'training' and n_polygons == 0")),
len(rwanda.areas.query("split == 'test' and n_polygons == 0")))

In [None]:
(len(rwanda.polygons), 
len(rwanda.polygons.query("split == 'training'")),
len(rwanda.polygons.query("split == 'test'")))

In [None]:
(len(rwanda.polygons.query("is_orphan == False")), 
len(rwanda.polygons.query("split == 'training' and is_orphan == False")),
len(rwanda.polygons.query("split == 'test' and is_orphan == False")))

We can load a saved database:

In [None]:
rwanda = dlc.tools.db.load_database(
    tiles="/content/datasets/rwanda/tiles.geojson",
    areas="/content/datasets/rwanda/areas.geojson",
    polygons="/content/datasets/rwanda/polygons.geojson",
)

In [None]:
rwanda.polygons.iloc[0:1].area

In [None]:
rwanda.polygons.crs

In [None]:
rwanda.polygons.iloc[0:1].to_crs("EPSG:6933").area

In [None]:
rwanda.polygons.crs

In [None]:
x = gpd.read_file("/content/datasets/rwanda/fixed-polygons-0_9.geojson")

In [None]:
x.head()

In [None]:
(len(rwanda.polygons), len(rwanda.polygons.query("is_orphan == False")))

For the Sahara/Sahel-Sudan dataset we consolidate the files for the train/test splits into a single one, for convenience. We add a column to the database indicating the corresponding split.

In [None]:
sahara_sahel = dlc.tools.db.load_database_sahara_sahel(
    tiles="/content/datasets/sahara-sahel/tiles.geojson",
    train_areas="/content/datasets/sahara-sahel/trainingData/training4_rectangles_all_lat.shp",
    test_areas="/content/datasets/sahara-sahel/evaluationData/evaluation_area.shp",
    train_polygons="/content/datasets/sahara-sahel/trainingData/training4_28and_29_lat.shp",
    test_polygons="/content/datasets/sahara-sahel/evaluationData/evaluation_polygons.shp",
    test_predicted_polygons="/content/datasets/sahara-sahel/evaluationData/predicted_polygon.shp",
    outliers=True,
)

In [None]:
sahara_sahel.polygons.query("is_outlier == True")

In [None]:
sahara_sahel.to_files("/content/datasets/sahara-sahel")

In [None]:
(len(sahara_sahel.tiles), len(sahara_sahel.tiles.query("n_areas > 0")))

In [None]:
(sahara_sahel.tiles["file_size"].sum(), sahara_sahel.tiles.query("n_areas > 0")["file_size"].sum())

In [None]:
(len(sahara_sahel.tiles.query("region == 'sahara'")), len(sahara_sahel.tiles.query("region == 'sahara' and n_areas > 0")))

In [None]:
(sahara_sahel.tiles.query("region == 'sahara'")["file_size"].sum(), sahara_sahel.tiles.query("region == 'sahara' and n_areas > 0")["file_size"].sum())

In [None]:
(len(sahara_sahel.tiles.query("region == 'sahel'")), len(sahara_sahel.tiles.query("region == 'sahel' and n_areas > 0")))

In [None]:
(sahara_sahel.tiles.query("region == 'sahel'")["file_size"].sum(), sahara_sahel.tiles.query("region == 'sahel' and n_areas > 0")["file_size"].sum())

In [None]:
len(sahara_sahel.polygons), len(sahara_sahel.polygons.query("is_orphan == False"))

We can open the consolidated database:

In [None]:
sahara_sahel = dlc.tools.db.load_database(
    tiles="/content/datasets/sahara-sahel/tiles.geojson",
    areas="/content/datasets/sahara-sahel/areas.geojson",
    polygons="/content/datasets/sahara-sahel/polygons.geojson",
    predicted_polygons="/content/datasets/sahara-sahel/predicted_polygons.geojson",
)

In [None]:
len(sahara_sahel.polygons), len(sahara_sahel.polygons.query("is_orphan == False"))

In [None]:
sahara_sahel.areas.head()

In [None]:
(len(sahara_sahel.areas),
len(sahara_sahel.areas.query("split == 'train'")),
len(sahara_sahel.areas.query("split == 'test'")))

In [None]:
(len(sahara_sahel.areas.query("region == 'sahara'")), 
len(sahara_sahel.areas.query("region == 'sahara' and split == 'train'")),
len(sahara_sahel.areas.query("region == 'sahara' and split == 'test'")))

In [None]:
(len(sahara_sahel.areas.query("region == 'sahel'")), 
len(sahara_sahel.areas.query("region == 'sahel' and split == 'train'")),
len(sahara_sahel.areas.query("region == 'sahel' and split == 'test'")))

In [None]:
(len(sahara_sahel.areas.query("n_tiles > 0")), 
len(sahara_sahel.areas.query("split == 'train' and n_tiles > 0")),
len(sahara_sahel.areas.query("split == 'test' and n_tiles > 0")))

In [None]:
(len(sahara_sahel.areas.query("region == 'sahara' and n_tiles > 0")), 
len(sahara_sahel.areas.query("region == 'sahara' and split == 'train' and n_tiles > 0")),
len(sahara_sahel.areas.query("region == 'sahara' and split == 'test' and n_tiles > 0")))

In [None]:
(len(sahara_sahel.areas.query("region == 'sahel' and n_tiles > 0")), 
len(sahara_sahel.areas.query("region == 'sahel' and split == 'train' and n_tiles > 0")),
len(sahara_sahel.areas.query("region == 'sahel' and split == 'test' and n_tiles > 0")))

In [None]:
(len(sahara_sahel.areas.query("n_polygons == 0")), 
len(sahara_sahel.areas.query("split == 'train' and n_polygons == 0")),
len(sahara_sahel.areas.query("split == 'test' and n_polygons == 0")))

In [None]:
(len(sahara_sahel.areas.query("region == 'sahara' and n_polygons == 0")), 
len(sahara_sahel.areas.query("region == 'sahara' and split == 'train' and n_polygons == 0")),
len(sahara_sahel.areas.query("region == 'sahara' and split == 'test' and n_polygons == 0")))

In [None]:
(len(sahara_sahel.areas.query("region == 'sahel' and n_polygons == 0")), 
len(sahara_sahel.areas.query("region == 'sahel' and split == 'train' and n_polygons == 0")),
len(sahara_sahel.areas.query("region == 'sahel' and split == 'test' and n_polygons == 0")))

In [None]:
sahara_sahel.polygons.head()

In [None]:
(len(sahara_sahel.polygons), 
len(sahara_sahel.polygons.query("split == 'train'")),
len(sahara_sahel.polygons.query("split == 'test'")))

In [None]:
(len(sahara_sahel.polygons.query("is_orphan == False")), 
len(sahara_sahel.polygons.query("split == 'train' and is_orphan == False")),
len(sahara_sahel.polygons.query("split == 'test' and is_orphan == False")))

In [None]:
(len(sahara_sahel.polygons.query("region == 'sahara'")), 
len(sahara_sahel.polygons.query("region == 'sahara' and split == 'train'")),
len(sahara_sahel.polygons.query("region == 'sahara' and split == 'test'")))

In [None]:
(len(sahara_sahel.polygons.query("region == 'sahel'")), 
len(sahara_sahel.polygons.query("region == 'sahel' and split == 'train'")),
len(sahara_sahel.polygons.query("region == 'sahel' and split == 'test'")))

In [None]:
(len(sahara_sahel.polygons.query("region == 'sahara' and is_orphan == False")), 
len(sahara_sahel.polygons.query("region == 'sahara' and split == 'train' and is_orphan == False")),
len(sahara_sahel.polygons.query("region == 'sahara' and split == 'test' and is_orphan == False")))

In [None]:
(len(sahara_sahel.polygons.query("region == 'sahel' and is_orphan == False")), 
len(sahara_sahel.polygons.query("region == 'sahel' and split == 'train' and is_orphan == False")),
len(sahara_sahel.polygons.query("region == 'sahel' and split == 'test' and is_orphan == False")))

#### Analyzing the area and convexity of polygons

In [None]:
labeled_areas = dict(
    sahara=sahara_sahel.areas.query("region == 'sahara'"),
    sahel=sahara_sahel.areas.query("region == 'sahel'"),
    rwanda=rwanda.areas,
)

In [None]:
labeled_polygons = dict(
    sahara=sahara_sahel.polygons.query("region == 'sahara'"),
    sahel=sahara_sahel.polygons.query("region == 'sahel'"),
    rwanda=rwanda.polygons,
)

In [None]:
!mkdir -p figures

In [None]:
!zip -r -j figures.zip figures

In [None]:
from google.colab import files
files.download("figures.zip")

In [None]:
labeled_polygons = dict(
    sahara_sahel=sahara_sahel.polygons.query("is_orphan == False"),
    sahel=sahara_sahel.polygons.query("region == 'sahel' and is_orphan == False"),
    sahara=sahara_sahel.polygons.query("region == 'sahara' and is_orphan == False"),
    rwanda=rwanda.polygons.query("is_orphan == False"),
)

In [None]:
for label, polygons in labeled_polygons.items():
  fig = dlc.tools.plots.plot_1d_histogram(polygons, key="size", label="Area", log = True, fontsize=18)
  fig.savefig(f"figures/area-histogram-{label}.pdf", bbox_inches="tight")
  plt.close(fig)

In [None]:
for label, polygons in labeled_polygons.items():
  print(label, dlc.tools.db.get_nonoutlier_range(polygons, key="size", range_min=0.0))
  fig = dlc.tools.plots.plot_box(polygons, key="size", label="Area", fontsize=18)
  fig.savefig(f"figures/area-non-outliers-{label}.pdf", bbox_inches="tight")
  plt.close(fig)
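
The definition used by `dlc.tools.db.get_nonoutlier_range` is not shown in this notebook. As a sketch under the assumption that it follows the usual box-plot convention, the non-outlier range is given by the Tukey fences `Q1 - 1.5 * IQR` and `Q3 + 1.5 * IQR`, clamped to the valid range of the variable. `nonoutlier_range` and the factor `k` are hypothetical names.

```python
import statistics


def nonoutlier_range(values, range_min=None, range_max=None, k=1.5):
    """Tukey fences (Q1 - k*IQR, Q3 + k*IQR), clamped to [range_min, range_max]."""
    q1, _, q3 = statistics.quantiles(values, n=4, method="inclusive")
    iqr = q3 - q1
    low, high = q1 - k * iqr, q3 + k * iqr
    if range_min is not None:
        low = max(low, range_min)
    if range_max is not None:
        high = min(high, range_max)
    return low, high


# Areas cannot be negative, so the lower fence is clamped at 0.
print(nonoutlier_range([1, 2, 3, 4, 5, 6, 7, 8, 9], range_min=0.0))
```

Clamping matters for bounded quantities: a polygon area cannot be negative and a convexity measure cannot exceed 1, so fences outside those bounds carry no information.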

In [None]:
for label, polygons in labeled_polygons.items():
  print(label, dlc.tools.db.get_nonoutlier_range(polygons, key="convexity_measure", range_max=1.0))
  fig = dlc.tools.plots.plot_1d_histogram(polygons, key="convexity_measure", label="Convexity measure", log = True, fontsize=18)
  fig.savefig(f"figures/convexity-histogram-{label}.pdf", bbox_inches="tight")
  plt.close(fig)

In [None]:
labeled_polygons["sahara_sahel"]["size"].describe()

In [None]:
labeled_polygons["sahara_sahel"]["convexity_measure"].describe()

In [None]:
labeled_polygons["sahara"]["size"].describe()

In [None]:
labeled_polygons["sahara"]["convexity_measure"].describe()

In [None]:
labeled_polygons["sahel"]["size"].describe()

In [None]:
labeled_polygons["sahel"]["convexity_measure"].describe()

In [None]:
labeled_polygons["rwanda"]["size"].describe()

In [None]:
labeled_polygons["rwanda"]["convexity_measure"].describe()

In [None]:
for label, polygons in labeled_polygons.items():
  fig = dlc.tools.plots.plot_box(polygons, key="convexity_measure", label="Convexity measure", fontsize=18)
  fig.savefig(f"figures/convexity-non-outliers-{label}.pdf", bbox_inches="tight")
  plt.close(fig)

In [None]:
polygon_examples = dict(
    sahara=[47640, 89541, 47677, 44904, 50036, 89557, 50238, 50208],
    sahel=[33170, 24286, 30010, 24151, 60232, 80417, 89837, 89935],
    rwanda=[6437, 8862, 44228, 69142, 69346, 69545, 69608, 69739],
)

In [None]:
for label, ids in polygon_examples.items():
  fig, axs = plt.subplots(2, 4, figsize=(6, 4))
  for polygon_id, ax in zip(ids, axs.flat):
    polygon = labeled_polygons[label].query(f"id == {polygon_id}")
    ax.axis("off")
    ax.set_title(f"{polygon.iloc[0]['convexity_measure']:.3f}")
    polygon.boundary.plot(ax=ax, color="black")
    polygon.to_crs("EPSG:6933").centroid.to_crs(labeled_polygons[label].crs).plot(ax=ax, color="purple")
  fig.savefig(f"figures/{label}_polygons.pdf", bbox_inches="tight")
  plt.close(fig)

In [None]:
hlines = dict(sahara=(37.96,), sahel=(62.44,), rwanda=(51.92,))
vlines = dict(sahara=(0.97,), sahel=(0.97,), rwanda=(0.82,))

for label, polygons in labeled_polygons.items():
  # Skip entries without reference lines (e.g. the combined "sahara_sahel" entry).
  if label not in hlines or label not in vlines:
    continue
  fig, meta = dlc.tools.plots.plot_2d_histogram(polygons.query("is_orphan == False"), x_key="convexity_measure", x_label="Convexity measure",
                                                y_key="size", y_label="Area", fontsize=18, figsize=(6, 4),
                                                vlines=vlines[label], hlines=hlines[label],)
  fig.tight_layout()
  fig.savefig(f"figures/{label}-2d-histogram.pdf", bbox_inches="tight")
  plt.close(fig)

In [None]:
labeled_areas = dict(
    sahara_sahel=sahara_sahel.areas.query("n_tiles > 0"),
    sahel=sahara_sahel.areas.query("n_tiles > 0 and region == 'sahel'"),
    sahara=sahara_sahel.areas.query("n_tiles > 0 and region == 'sahara'"),
    rwanda=rwanda.areas.query("n_tiles > 0"),
)

In [None]:
labeled_areas["sahara_sahel"]["n_polygons"].describe()

In [None]:
labeled_areas["sahara"]["n_polygons"].describe()

In [None]:
labeled_areas["sahel"]["n_polygons"].describe()

In [None]:
labeled_areas["rwanda"]["n_polygons"].describe()

In [None]:
outlier_spec = [
                dict(name="size", range_min=0.0),
                dict(name="convexity_measure", range_min=0.0, range_max=1.0,)]
labeled_polygons["sahel"] = dlc.tools.db.mark_outliers(labeled_polygons["sahel"], keys=outlier_spec)
cmap = mpl.colors.ListedColormap(["gray", "purple"], N=2)
fig, ax = plt.subplots(1, 1, figsize=(12, 6))
labeled_polygons["sahel"].query("area_id == 195").plot(column="outlier", ax=ax, cmap=cmap, alpha=0.8)
fig.savefig("figures/outliers_350.pdf", bbox_inches="tight")

In [None]:
labeled_polygons["sahel"] = dlc.tools.db.mark_query(
    labeled_polygons["sahel"],
    "desired",
    "size <= 200.0 and convexity_measure >= 0.80")
cmap = mpl.colors.ListedColormap(["purple", "gray"], N=2)
fig, ax = plt.subplots(1, 1, figsize=(12, 6))
labeled_polygons["sahel"].query("area_id == 195").plot(column="desired", ax=ax, cmap=cmap, alpha=0.8)
fig.savefig("figures/desired_350.pdf", bbox_inches="tight")

In [None]:
labeled_polygons["rwanda"] = dlc.tools.db.mark_outliers(labeled_polygons["rwanda"], keys=outlier_spec)
cmap = mpl.colors.ListedColormap(["gray", "purple"], N=2)
fig, ax = plt.subplots(1, 1, figsize=(12, 6))
labeled_polygons["rwanda"].query("area_id == 1").plot(column="outlier", ax=ax, cmap=cmap, alpha=0.8)
fig.savefig("figures/outliers_1.pdf", bbox_inches="tight")
#mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
#mpld3.display()

In [None]:
labeled_polygons["rwanda"] = dlc.tools.db.mark_query(
    labeled_polygons["rwanda"],
    "desired",
    "size <= 200.0 and convexity_measure >= 0.80")
cmap = mpl.colors.ListedColormap(["purple", "gray",], N=2)
fig, ax = plt.subplots(1, 1, figsize=(12, 6))
labeled_polygons["rwanda"].query("area_id == 1").plot(column="desired", ax=ax, cmap=cmap, alpha=0.8)
fig.savefig("figures/desired_1.pdf", bbox_inches="tight")
#mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
#mpld3.display()

#### Fix overlapping polygons

In [None]:
fixed_polygons = dict()

Fix the polygons:

In [None]:
fixed_polygons[("rwanda", 1.0)] = dlc.tools.db.fix_overlapped_polygons(
    rwanda.polygons.query("is_orphan == False"),
    initial_scale=1.0)

In [None]:
fixed_polygons[("rwanda", 0.9)] = dlc.tools.db.fix_overlapped_polygons(
    rwanda.polygons.query("is_orphan == False"),
    initial_scale=0.9)

In [None]:
fixed_polygons[("sahara-sahel", 1.0)] = dlc.tools.db.fix_overlapped_polygons(
    sahara_sahel.polygons.query("is_orphan == False"),
    initial_scale=1.0)

In [None]:
fixed_polygons[("sahara-sahel", 0.9)] = dlc.tools.db.fix_overlapped_polygons(
    sahara_sahel.polygons.query("is_orphan == False"),
    initial_scale=0.9)

In [None]:
for (label, initial_scale), polygons in fixed_polygons.items():
  initial_scale = f"{initial_scale:.1f}".replace(".", "_")
  polygons.to_file(f"/content/datasets/{label}/fixed-polygons-{initial_scale}.geojson", driver="GeoJSON")

Or, load already fixed polygons:

In [None]:
fixed_polygons = dict()
fixed_polygons[("sahara-sahel", 0.9)] = None
fixed_polygons[("rwanda", 0.9)] = None
for (label, initial_scale), _ in fixed_polygons.items():
  initial_scale_s = f"{initial_scale:.1f}".replace(".", "_")
  filename = f"/content/datasets/{label}/fixed-polygons-{initial_scale_s}.geojson"
  fixed_polygons[(label, initial_scale)] = gpd.read_file(filename)

#### Manual inspection of fixed polygons

Inspect interactively (to find a proper extent):

In [None]:
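# Find some interesting cases: the areas whose polygons have the largest `fixed` value.
# `fixed_polygons` is keyed by (dataset, scale); here we inspect one entry.
fixed = fixed_polygons[("sahara-sahel", 0.9)]
max_fixed = fixed["fixed"].max()
area_ids = set(fixed.query(f"fixed == {max_fixed}")["area_id"])
print(area_ids)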

In [None]:
fixed_polygons[("sahara-sahel", 0.9)].query("fixed > 0")

In [None]:
def inspect_area(polygons, area_id, *, max_passes=1, figsize=(12, 8),
                 highlight=None, highlight_color="black",
                 color=True, colorbar=True, zoom_to=None,
                 scale=0.90, save=False, fontsize=12):
  x = dlc.tools.db.fix_overlapped_polygons(polygons,
                                            area_id=area_id,
                                            initial_scale=scale,)

  fig = plt.figure(figsize=figsize)
  #fig.suptitle(f"Area {area_id}, Dataset: {DATASET}")
  ax1 = fig.add_subplot(1, 2, 1)
  ax1.set_title("With overlap correction", fontsize=fontsize)
  ax2 = fig.add_subplot(1, 2, 2)
  ax2.set_title("Without overlap correction", fontsize=fontsize)
  areas.iloc[area_id:area_id+1].boundary.plot(color="black", ax=ax1, linestyle="dashed")
  areas.iloc[area_id:area_id+1].boundary.plot(color="black", ax=ax2, linestyle="dashed")
  polygons.boundary.plot(edgecolor="gray", linestyle="dashed", ax=ax1)
  if len(x) > 0:
    x.boundary.plot(edgecolor="black", ax=ax1)
    print(x.columns)
    fixed_x = x.query("fixed > 0")
    if len(fixed_x) > 0:
      fixed_x.plot(color="green", alpha=0.5, ax=ax1)
    overlapped_x = x.query("overlapped == True")
    if len(overlapped_x) > 0:
      overlapped_x.plot(color="red", alpha=0.5, ax=ax1)
    overlapped_x = x.query("overlapped == True or fixed > 0")
    if len(overlapped_x) > 0:
      overlapped_x.plot(color="red", alpha=0.5, ax=ax2)
  polygons.query(f"area_id == {area_id}").boundary.plot(edgecolor="black", ax=ax2)
  ax1.axis("off")
  ax2.axis("off")
  if highlight is not None:
    highlight_polygons = polygons.query(f"id in {highlight}")
    highlight_polygons.plot(color=highlight_color, ax=ax2)
  #polygons.boundary.plot(edgecolor="gray", linestyle="dashed", ax=ax2)
  if zoom_to is not None:
    ax1.set_xlim(zoom_to[:2])
    ax2.set_xlim(zoom_to[:2])
    ax1.set_ylim(zoom_to[2:])
    ax2.set_ylim(zoom_to[2:])
  fig.tight_layout()
  if save:
    fig.savefig(f"/content/output/fixed_area_{area_id}-{DATASET}.png")
    fig.savefig(f"/content/output/fixed_area_{area_id}-{DATASET}.pdf")
  return fig

In [None]:
f_area_id = 181
f_scale = 0.90
zoom_to = None
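polygons = sahara_sahel.polygons
areas = sahara_sahel.areas  # assumption: `inspect_area` reads this global; adjust if it is defined elsewhere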
#zoom_to = (-14.051819097589249, -14.050454978133669, 15.309446991508144, 15.310388615191732)
zoom_to = (-13.942109216236247, -13.941586232120281, 15.408869070220973, 15.409089388357195)
fig = inspect_area(polygons, f_area_id, max_passes=None, highlight=None,
                   colorbar=False, zoom_to=zoom_to, scale=f_scale)
#fig.savefig("figures/overlapped_polygons_350.pdf", bbox_inches="tight")
#mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
#mpld3.display()

In [None]:
from google.colab import files
filename = "fixed_polygons.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
def inspect_area(polygons, area_id, *, max_passes=1, figsize=(6,6),
                 highlight=None, highlight_color="black",
                 color=True, colorbar=True, zoom_to=None,
                 scale=0.90, save=False, fontsize=12,):
  x = dlc.tools.db.fix_overlapped_polygons(polygons,
                                            area_id=area_id,
                                            initial_scale=scale,)

  fig, ax = plt.subplots(1, 1, figsize=figsize)
  areas.iloc[area_id:area_id+1].boundary.plot(color="gray", linestyle="dashed", ax=ax)
  if len(x) > 0:
    x.boundary.plot(edgecolor="black", ax=ax)
    polygons.query(f"area_id == {area_id}").boundary.plot(edgecolor="black", linestyle="dotted", ax=ax)
    overlapped_x = x.query("overlapped == True or fixed > 0")
    if len(overlapped_x) > 0:
      overlapped_x.plot(color="purple", alpha=0.8, ax=ax)
  ax.axis("off")
  if zoom_to is not None:
    ax.set_xlim(zoom_to[:2])
    ax.set_ylim(zoom_to[2:])
  fig.tight_layout()
  return fig

In [None]:
f_area_id = 350
f_scale = 0.90
zoom_to = None
polygons = sahara_sahel.polygons
#zoom_to = (-14.051819097589249, -14.050454978133669, 15.309446991508144, 15.310388615191732)
fig = inspect_area(polygons, f_area_id, max_passes=None, highlight=None,
                   colorbar=False, zoom_to=zoom_to, scale=f_scale)
#fig.savefig("figures/overlapped_polygons_350.pdf", bbox_inches="tight")
# mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
# mpld3.display()

In [None]:
f_area_id = 170
f_scale = 1.0
zoom_to = (-14.05088262754295, -14.050575983537563, 15.309966409582259, 15.31017807965024)
fig = inspect_area(polygons, f_area_id, max_passes=None, highlight=None,
                   colorbar=False, zoom_to=zoom_to, scale=f_scale)
fig.savefig("figures/overlapped_polygons_170.pdf", bbox_inches="tight")
# mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
# mpld3.display()

In [None]:
f_area_id = 462
f_scale = 1.0
zoom_to = (-12.275696933849984, -12.275648172260922, 23.379950931018126, 23.379986533827907)
fig = inspect_area(polygons, f_area_id, max_passes=None, highlight=None,
                   colorbar=False, zoom_to=zoom_to, scale=f_scale)
fig.savefig("figures/heavily_overlapped_polygons_462.pdf", bbox_inches="tight")
#mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
#mpld3.display()

In [None]:
f_area_id = 81
f_scale = 0.90
fig = inspect_area(polygons, f_area_id, max_passes=None, highlight=None,
                   colorbar=False, zoom_to=None, scale=f_scale)
#fig.savefig("figures/overlapped_polygons.pdf", bbox_inches="tight")
mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
mpld3.display()

In [None]:
# Warning: Found 20 overlapped pairs at finish.
#                 Areas: {258, 272, 145, 146, 281, 285, 289, 297, 170, 181, 186, 67, 195, 76, 213, 86, 87, 94, 97, 251}

Save the figure:

In [None]:
zoom_to = None
fig = inspect_area(polygons, f_area_id, max_passes=None, highlight=None,
                   color=False, zoom_to=zoom_to,
                   scale=f_scale, save=True)

#### Loading the frames to local disk
See the sections below for how the frame files and the frames database are created.

In [None]:
import pathlib
import shutil
import subprocess

def download_frames_zip(dataset, namespace="frames", suffix="base", local_base_path="./data/datasets"):
  archive_path = pathlib.Path(f"/content/datasets/frames_zips/{namespace}-{dataset}-{suffix}.zip")
  if not archive_path.exists():
    msg = f"Archive does not exist: {archive_path}"
    raise ValueError(msg)
  local_path = pathlib.Path(local_base_path).joinpath(f"{namespace}/{dataset}")
  if not local_path.exists():
    local_path.mkdir(parents=True)
  else:
    print(f"Removing previous files at {local_path}")
    shutil.rmtree(local_path)
    local_path.mkdir(parents=True)
  print(f"Downloading {archive_path}")
  subprocess.call(["cp", str(archive_path), "./tmp.zip"])
  print(f"Saving files to {local_path}")
  subprocess.call(["unzip", "-j", "./tmp.zip", "-d", str(local_path)])
  pathlib.Path("./tmp.zip").unlink()
  frames = gpd.read_file(local_path.joinpath("frames.geojson"))
  return frames
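
As an alternative to shelling out to `cp` and `unzip -j`, the flat extraction step can be sketched with the standard library's `zipfile` module. `extract_flat` is a hypothetical helper, not part of `dlc`; like `unzip -j`, it drops the folder structure stored in the archive, so files with the same name in different folders would overwrite each other.

```python
import pathlib
import shutil
import zipfile


def extract_flat(archive_path, target_dir):
    """Extract every file in the archive into target_dir, ignoring internal folders."""
    target = pathlib.Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    extracted = []
    with zipfile.ZipFile(archive_path) as archive:
        for member in archive.infolist():
            if member.is_dir():
                continue
            # Keep only the file name, which mirrors the behaviour of `unzip -j`.
            destination = target / pathlib.PurePosixPath(member.filename).name
            with archive.open(member) as src, open(destination, "wb") as dst:
                shutil.copyfileobj(src, dst)
            extracted.append(destination)
    return extracted
```

Reading directly from the archive on the mounted drive also avoids the temporary `./tmp.zip` copy, at the cost of slower reads if the drive is remote.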

#### Visualizing the tiles in the Sahara, Sahel and Sudan areas


In [None]:
africa = gpd.read_file("/content/other/shapefiles/africa.geojson")

In [None]:
def territories(objects):
  return gpd.sjoin(africa, gpd.GeoDataFrame(geometry=[objects.unary_union], crs=objects.crs), how="inner", op="intersects")

In [None]:
def zoom(ax, bbox):
  # bbox is ordered (xmin, xmax, ymin, ymax).
  ax.set_xlim(bbox[:2])
  ax.set_ylim(bbox[2:])

In [None]:
def plot_rwanda(zoom_to=None, path=None, polygons=None, polygons_as_boundaries=False, figsize=(9, 9)):
  fig = plt.figure(figsize=figsize)
  ax = fig.add_subplot(1, 1, 1)
  africa.query("ISO3 in ['RWA']").boundary.plot(ax=ax, color="black")
  #rwanda.tiles.query("n_areas == 0").boundary.plot(ax=ax, color="skyblue", label="Tiles with no areas")
  rwanda.tiles.query("n_areas > 0").boundary.plot(ax=ax, color="purple", label="Tiles with areas")
  rwanda.areas.boundary.plot(ax=ax, color="gray", linestyle="dashed", label="Areas")
  bbox = None
  if zoom_to is not None:
    bbox = shapely.geometry.box(zoom_to[0], zoom_to[2], zoom_to[1], zoom_to[3])
    zoom(ax, zoom_to)
  if polygons is not None:
    if bbox is not None:
      polygons = gpd.sjoin(polygons, gpd.GeoDataFrame(geometry=[bbox], crs=polygons.crs), how="inner", op="intersects")
    if polygons_as_boundaries:
      polygons.boundary.plot(
          ax=ax, color="forestgreen", label="Tree annotations"
      )
    else:
      polygons.plot(ax=ax, color="forestgreen", label="Tree annotations")
  ax.set_ylabel("Latitude", fontsize=18)
  ax.set_xlabel("Longitude", fontsize=18)
  #handles, labels = ax.get_legend_handles_labels()
  #fig.legend(handles, labels, loc="lower center", fancybox=True, shadow=True, ncol=3)
  if path is not None:
    fig.savefig(path, bbox_inches="tight")
  return fig
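
The `zoom_to` tuples in this notebook are ordered `(xmin, xmax, ymin, ymax)`, which matches `set_xlim`/`set_ylim`, whereas `shapely.geometry.box` expects `(minx, miny, maxx, maxy)`. That is why the plotting functions reorder the indices before building the bounding box. A minimal sketch of the reordering, where `zoom_to_bounds` is a hypothetical helper and not part of `dlc`:

```python
def zoom_to_bounds(zoom_to):
    """Convert (xmin, xmax, ymin, ymax) into (minx, miny, maxx, maxy)."""
    xmin, xmax, ymin, ymax = zoom_to
    if xmin > xmax or ymin > ymax:
        raise ValueError(f"Invalid extent: {zoom_to}")
    return (xmin, ymin, xmax, ymax)


# Example extent with longitudes first and latitudes second.
zoom_to = (29.99, 30.13, -2.00, -1.89)
print(zoom_to_bounds(zoom_to))
```

The result can be unpacked straight into `shapely.geometry.box(*bounds)`.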

In [None]:
fig = plot_rwanda()
mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
mpld3.display()

In [None]:
!mkdir -p figures

In [None]:
from google.colab import files
files.download("./figures/rwanda-areas.png")

In [None]:
fig = plot_rwanda(path="./figures/rwanda-tiles.pdf")

In [None]:
zoom_to = (29.998086308683405, 30.13197761201941, -2.0059204534785198, -1.8936863709496667)
fig = plot_rwanda(zoom_to=zoom_to, polygons=rwanda.polygons,
                  path="./figures/rwanda-areas.png")

In [None]:
zoom_to = (30.200470901174018, 30.20192267090786, -1.7068311353141248, -1.7056141926146275)
fig = plot_rwanda(zoom_to=zoom_to, polygons=rwanda_polygons,
                  polygons_as_boundaries=True,
                  path="./figures/rwanda-trees.png")

In [None]:
def plot_sahara_sahel(zoom_to=None, path=None, polygons=None, polygons_as_boundaries=False, figsize=(9, 9),
                      interactive=False):
  fig = plt.figure(figsize=figsize)
  ax = fig.add_subplot(1, 1, 1)
  if not interactive:
    territories(sahara_sahel.tiles).boundary.plot(ax=ax, color="black")
  #sahara_sahel.tiles.query("n_areas == 0").boundary.plot(ax=ax, color="skyblue", label="Tiles with no areas")
  sahara_sahel.tiles.query("region == 'sahel' and n_areas > 0").boundary.plot(ax=ax, color="purple", label="Sahel-Sudan tiles")
  sahara_sahel.tiles.query("region == 'sahara' and n_areas > 0").boundary.plot(ax=ax, color="gold", label="Sahara tiles")
  sahara_sahel.areas.boundary.plot(ax=ax, color="gray", linestyle="dashed", label="Areas")
  bbox = None
  if zoom_to is not None:
    bbox = shapely.geometry.box(zoom_to[0], zoom_to[2], zoom_to[1], zoom_to[3])
    zoom(ax, zoom_to)
  if polygons is not None:
    if bbox is not None:
      polygons = gpd.sjoin(polygons, gpd.GeoDataFrame(geometry=[bbox], crs=polygons.crs), how="inner", op="intersects")
    if polygons_as_boundaries:
      polygons.boundary.plot(
          ax=ax, color="forestgreen", label="Tree annotations"
      )
    else:
      polygons.plot(ax=ax, color="forestgreen", label="Tree annotations")
  #handles, labels = ax.get_legend_handles_labels()
  #fig.legend(handles, labels, loc="lower center", fancybox=True, shadow=True, ncol=3)
  ax.set_ylabel("Latitude", fontsize=18)
  ax.set_xlabel("Longitude", fontsize=18)
  if path is not None:
    fig.savefig(path, bbox_inches="tight")
  return fig

In [None]:
fig = plot_sahara_sahel(interactive=True, figsize=(12, 6))
mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
mpld3.display()

In [None]:
fig = plot_sahara_sahel(path="./figures/sahara-sahel-tiles.pdf")

In [None]:
zoom_to = (-15.106290532941847, -14.567430706045057, 15.642839549101144, 16.238970734236812)

fig = plot_sahara_sahel(path="./figures/sahara-sahel-areas.png", zoom_to=zoom_to,
                        polygons=sahara_sahel.polygons)

In [None]:
zoom_to = (-10.832529509520917, -10.829493690373182, 13.282456625293477, 13.285673112691013)

fig = plot_sahara_sahel(path="./figures/sahara-sahel-trees.png", zoom_to=zoom_to,
                        polygons=sahara_sahel_polygons,
                        polygons_as_boundaries=True)

### Frame data creators

In [None]:
import dlc.tools.cache
import dlc.frames.centroids
import dlc.frames.raster
import dlc.frames.creators.base
import dlc.frames.creators.data
import dlc.frames.creators.density
import dlc.frames.creators.image
import dlc.frames.creators.scalar
import dlc.frames.creators.segmentation
import dlc.frames.factory

importlib.reload(dlc.tools.cache)
importlib.reload(dlc.frames.centroids)
importlib.reload(dlc.frames.raster)
importlib.reload(dlc.frames.creators.base)
importlib.reload(dlc.frames.creators.data)
importlib.reload(dlc.frames.creators.density)
importlib.reload(dlc.frames.creators.image)
importlib.reload(dlc.frames.creators.scalar)
importlib.reload(dlc.frames.creators.segmentation)
importlib.reload(dlc.frames.factory)

#### Selecting area and tile

##### Sahara-Sahel

In [None]:
dataset = "sahara-sahel"
tiles = sahara_sahel.tiles
areas = sahara_sahel.areas
polygons = fixed_polygons[(dataset, 0.9)]
output_path = f"./new_frames/{dataset}"
tiles_path = f"/content/datasets/{dataset}/StackedImages"

In [None]:
# Big trees
area_id, tile_id = sahara_sahel.get_area_and_tile(211)
zoom_to = (-14.871731243347035, -14.869030669749963, 16.05790845472101, 16.058391813736648)

In [None]:
# Big trees
area_id, tile_id = sahara_sahel.get_area_and_tile(350)
zoom_to = None

In [None]:
# No polygons
area_id, tile_id = sahara_sahel.get_area_and_tile(1)
zoom_to = None

In [None]:
# Dense area
area_id, tile_id = sahara_sahel.get_area_and_tile(153)
zoom_to = (-9.434685015373589, -9.433487773486778, 13.74502450648988, 13.745933990550967)

In [None]:
# Random
area_id, tile_id = sahara_sahel.get_area_and_tile()
zoom_to = None

##### Rwanda

In [None]:
dataset = "rwanda"
tiles = rwanda.tiles
areas = rwanda.areas
polygons = fixed_polygons[(dataset, 0.9)]
output_path = f"./new_frames/{dataset}"
tiles_path = f"/content/datasets/{dataset}/Images"

In [None]:
area_id, tile_id = rwanda.get_area_and_tile(79)
zoom_to = None

In [None]:
# Extremely dense area
area_id, tile_id = rwanda.get_area_and_tile(11)
zoom_to = (30.625761063670012, 30.62624789752035, -1.42618801533133, -1.425869104433315)
# Zoom to a zone where polygons barely touch the area,
# resulting in rasterization inaccuracies
zoom_to = (30.626560903146476, 30.626697263995233, -1.4252136443760592, -1.4251243182934532)

In [None]:
# Another very dense area
area_id, tile_id = rwanda.get_area_and_tile(3)
zoom_to = (30.599905901773624, 30.60023820852144, -1.4137027731199092, -1.413465157037063)

In [None]:
area_id, tile_id = rwanda.get_area_and_tile(34)
zoom_to = (29.432809336666153, 29.433450492110822, -2.758289038240984, -2.7577796536127304)

Random sample:

In [None]:
# Random
area_id, tile_id = rwanda.get_area_and_tile()
zoom_to = None

#### Plotting creator results

We can plot the polygons in the chosen area to qualitatively verify the frames in the next sections.

In [None]:
dlc.tools.plots.plot_polygons_in_area(areas, polygons, area_id)

In [None]:
# TODO: Refactor/simplify this function.
def plot_result(result, *, path=None, cmap=None,
                        log=False, normalize=False, standardize=False, bins=20, figsize=None,
                        show_hist=True, output_path = None, masked=False, polygons=None, polygon_color="red",
                        zoom_to=None, save=False, show_title=False, additive_constant=None,
                        nodata=None, norm=None, bad_value=None, bad_color=None,
                        show_axis=True, show_band_names=True,):
  if path is None:
      if isinstance(result.payload, dict):
        path = result.payload[result.payload_main_key]
      else:
        path = result.payload

  return plot_frame(result.area_id, result.tile_id, path=path, figsize=figsize,
             cmap=cmap, log=log, normalize=normalize, standardize=standardize,
             bins=bins, show_hist=show_hist, output_path=output_path, masked=masked,
             polygons=polygons, polygon_color=polygon_color,zoom_to=zoom_to,
             show_title=show_title, save=save, additive_constant=additive_constant,
             nodata=nodata, norm=norm, bad_value=bad_value, bad_color=bad_color,
             show_axis=show_axis, show_band_names=show_band_names)

def plot_histogram(result, factor=1.0, fontsize=12, *, path=None):
  # Fall back to the result's main payload when no explicit path is given.
  if path is None:
    if isinstance(result.payload, dict):
      path = result.payload[result.payload_main_key]
    else:
      path = result.payload
  image = dlc.tools.images.load_image(path)
  image *= factor
  fig, ax = plt.subplots(1, 1, figsize=(12, 4))
  ax.hist(image.flatten(), bins="auto", color="purple")
  ax.set_xlabel("Pixel value")
  ax.set_ylabel("Frequency")
  return fig

def image_normalize(im, axis=(0, 1), c=1e-8):
    """Normalize to zero mean and unit standard deviation along the given axis"""
    return (im - im.mean(axis)) / (im.std(axis) + c)

def plot_frame(area_id, tile_id, path, *,
               cmap=None, log=False, normalize=False, standardize=False, figsize=None,
               show_hist=True, bins=20, output_path = None, masked=False, polygons=None, polygon_color="red",
               zoom_to=None, save=False, show_title=False, additive_constant=None, nodata=None,
               norm=None, bad_value=None, bad_color=None, show_axis=True, show_band_names=True,):

  with rasterio.open(path, "r") as src:
    transform = src.transform

  image = dlc.tools.images.load_image(path, masked=masked)
  if bad_value is not None:
    image = np.ma.masked_less_equal(image, bad_value)

  if additive_constant is not None:
    image[image != 0] += additive_constant

  if standardize:
    image = dlc.tools.scalers.standardize_image_np(image, axis=(0, 1))

  if normalize:
    image = dlc.tools.scalers.normalize_image_np(image, axis=(0, 1))

  if polygons is not None:
    polygons_in_area =  polygons.query(f"area_id == {area_id}", inplace=False)
    title = f"Area {area_id}, Tile {tile_id}, {len(polygons_in_area)} polygons"
  else:
    polygons_in_area = None
    title = f"Area {area_id}, Tile {tile_id}"

  output_path=None
  if save:
    path = pathlib.Path(path)
    name = path.name.split(".")[0]
    output_path = pathlib.Path(f"/content/output/plots/{DATASET}-{name}.png")

  return dlc.tools.plots.plot_frame(image, title=title, cmap=cmap, figsize=figsize,
                                    log=log, bins=bins, show_hist=show_hist,
                                    output_path=output_path, transform=transform,
                                    polygons=polygons_in_area, polygon_color=polygon_color,
                                    zoom_to=zoom_to, show_title=show_title, nodata=nodata,
                                    norm=norm, mask_color=bad_color,
                                    show_axis=show_axis, show_band_names=show_band_names,)

#### Data source and cache

The data source object stores the tiles, areas and polygons databases, and also loads and stores the polygon and area raster data.

In [None]:
data = dlc.frames.creators.data.CoreFrameDataSource(
    dataset, tiles_path,
    tiles, areas, polygons,
)

In [None]:
%%time
data.load_raster_data()

In [None]:
data2 = dlc.frames.creators.data.CoreFrameDataSource(
    dataset, tiles_path,
    tiles, areas, sahara_sahel.polygons,
)

In [None]:
%%time
data2.load_raster_data()

In [None]:
data.suggested_filter_size()

We can inspect the raster of a specific polygon with `data.get_polygon_raster_data(area_id, tile_id, 80392)`, or of a random one with `data.get_random_polygon_raster_data()`.

In [None]:
xs = []
ys = []
for k,v in data._polygon_raster_data.items():
  ys.append(v.raster.shape[0])
  xs.append(v.raster.shape[1])
xs = np.asarray(xs, dtype="float32")
ys = np.asarray(ys, dtype="float32")

print(np.mean(xs), np.quantile(xs, 0.25), np.quantile(xs, 0.50), np.quantile(xs, 0.75))
print(np.mean(ys), np.quantile(ys, 0.25), np.quantile(ys, 0.50), np.quantile(ys, 0.75))

In [None]:
x = data.get_polygon_raster_data(area_id, tile_id, polygon_id)

In [None]:
x = data.get_random_polygon_raster_data()

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(8,4))
axs[0].imshow(x.raster, cmap="gray")
axs[0].set_title("bbox")
axs[1].imshow(x.raster_padded, cmap="gray")
axs[1].set_title("bbox_padded")

In [None]:
def mask_bad(x):
  return np.ma.masked_less_equal(x, 0.0)

In [None]:
c = dlc.frames.centroids.standard_centroid(x.raster)
cm = dlc.frames.centroids.centroid_mask(x.raster, c)
cp = dlc.frames.centroids.standard_centroid(x.raster_padded)
cmp = dlc.frames.centroids.centroid_mask(x.raster_padded, cp)
y = dlc.frames.creators.density.apply_alt_gaussian_filter(cm, 5.0, 0.2)
yp = dlc.frames.creators.density.apply_alt_gaussian_filter(cmp, 5.0, 0.2)

fig, axs = plt.subplots(1, 2, figsize=(8,4))
cmap = mpl.cm.get_cmap("gray")
cmap.set_bad(color="purple")
axs[0].imshow(mask_bad(y), cmap=cmap)
axs[0].set_title("bbox")
axs[1].imshow(mask_bad(yp), cmap=cmap)
axs[1].set_title("bbox_padded")
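
The cell above derives a centroid from a polygon raster and turns it into a mask before filtering. How `dlc.frames.centroids` implements this is not shown in this notebook. The sketch below assumes the standard centroid is the center of mass of the non-zero pixels and that the mask marks the single pixel nearest to it; `center_of_mass` and `one_hot_mask` are hypothetical names, and plain nested lists stand in for arrays.

```python
def center_of_mass(raster):
    """Return the (row, col) center of mass of the non-zero pixels of a 2-D raster."""
    count = 0
    row_sum = 0.0
    col_sum = 0.0
    for r, row in enumerate(raster):
        for c, value in enumerate(row):
            if value:
                count += 1
                row_sum += r
                col_sum += c
    if count == 0:
        raise ValueError("raster has no foreground pixels")
    return row_sum / count, col_sum / count


def one_hot_mask(raster, centroid):
    """Return a raster-shaped mask with a single 1 at the pixel nearest the centroid."""
    # Note: Python's round() rounds exact halves to the nearest even integer.
    r, c = (int(round(v)) for v in centroid)
    mask = [[0] * len(row) for row in raster]
    mask[r][c] = 1
    return mask


raster = [[0, 1, 0],
          [1, 1, 1],
          [0, 1, 0]]
centroid = center_of_mass(raster)
print(centroid, one_hot_mask(raster, centroid))
```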

In [None]:
f = dlc.frames.creators.density.get_gaussian_filter(15, 5.0)
y = dlc.frames.creators.density.apply_filter(cm, f)
yp = dlc.frames.creators.density.apply_filter(cmp, f)

fig, axs = plt.subplots(1, 2, figsize=(8,4))
axs[0].imshow(mask_bad(y), cmap="gray")
axs[0].set_title("bbox")
axs[1].imshow(mask_bad(yp), cmap="gray")
axs[1].set_title("bbox_padded")
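
`get_gaussian_filter(15, 5.0)` presumably takes a filter size and a sigma, matching the `filter_size` and `sigma` options used elsewhere in this notebook. How `dlc` builds the kernel is not shown here. A common construction, given as an assumption, samples a 2-D Gaussian on a `size × size` grid and normalizes it to sum to 1, so that filtering a single-pixel centroid mask spreads one unit of density without changing the total count. The sketch uses only the standard library; `gaussian_kernel` is a hypothetical name.

```python
import math


def gaussian_kernel(size, sigma):
    """Return a size x size Gaussian kernel (nested lists) that sums to 1."""
    if size % 2 == 0:
        raise ValueError("size must be odd so the kernel has a center pixel")
    half = size // 2
    weights = [
        [math.exp(-(dx * dx + dy * dy) / (2.0 * sigma * sigma))
         for dx in range(-half, half + 1)]
        for dy in range(-half, half + 1)
    ]
    total = sum(map(sum, weights))
    return [[w / total for w in row] for row in weights]


kernel = gaussian_kernel(15, 5.0)
print(len(kernel), len(kernel[0]), kernel[7][7])
```

With a size of 15 and sigma of 5 the kernel covers only 1.4 sigma on each side of the center, so under this construction a noticeable share of the Gaussian tail is cut off and redistributed by the normalization.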

In [None]:
y = dlc.frames.creators.density.edt_transform(x.raster)
yp = dlc.frames.creators.density.edt_transform(x.raster_padded, pad_width=1)

fig, axs = plt.subplots(1, 2, figsize=(8,4))
axs[0].imshow(mask_bad(y), cmap="gray")
axs[0].set_title("bbox")
axs[1].imshow(mask_bad(yp), cmap="gray")
axs[1].set_title("bbox_padded")

The cache can store intermediate results that multiple data creators might need (e.g. centroids, energy maps).
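
The idea can be sketched as a small keyed cache with hit and miss counters. `SimpleCache` is a hypothetical stand-in and not the `dlc.tools.cache.ArrayCache` implementation; the key layout and the decision to reset the counters in `clear()` are assumptions.

```python
import sys


class SimpleCache:
    """Keyed cache that computes a value once and counts hits and misses."""

    def __init__(self):
        self._items = {}
        self.hits = 0
        self.misses = 0

    @property
    def size(self):
        # Use the value's own byte count (e.g. NumPy's nbytes) when it has one.
        return sum(getattr(v, "nbytes", sys.getsizeof(v))
                   for v in self._items.values())

    def get_or_compute(self, key, compute):
        if key in self._items:
            self.hits += 1
        else:
            self.misses += 1
            self._items[key] = compute()
        return self._items[key]

    def clear(self):
        self._items.clear()
        self.hits = 0
        self.misses = 0


cache_sketch = SimpleCache()
# Two creators asking for the same centroid: the second call is served from the cache.
first = cache_sketch.get_or_compute(("centroid", 153, 7), lambda: (12.5, 8.0))
second = cache_sketch.get_or_compute(("centroid", 153, 7), lambda: (0.0, 0.0))
print(second, cache_sketch.hits, cache_sketch.misses)
```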

In [None]:
cache = dlc.tools.cache.ArrayCache()

In [None]:
cache.clear()

In [None]:
cache.hits, cache.misses, cache.size * 1e-6

In [None]:
cache = None

In [None]:
figsize = (12, 4)

#### Images

In [None]:
creator = dlc.frames.creators.image.AltImageFrameCreator(data, output_path)

res = creator.run(area_id, tile_id)

fig = plot_result(res, cmap="gray", show_hist=True, figsize=figsize,
            polygons=polygons, polygon_color="gold", zoom_to=zoom_to,
            show_band_names=True, show_axis=True,)

In [None]:
creator = dlc.frames.creators.image.AltImageFrameCreator(data, output_path)

res = creator.run(area_id, tile_id)

fig = plot_result(res, cmap="gray", show_hist=True, figsize=figsize,
            polygons=polygons, polygon_color="red", zoom_to=zoom_to)

#### Segmentation

In [None]:
creator = dlc.frames.creators.segmentation.SegmentationMaskFrameCreator(data, output_path)

res = creator.run(area_id, tile_id)

fig = plot_result(res, cmap="gray_r", show_hist=False, figsize=(12, 4),
            norm=None, polygons=polygons, polygon_color="magenta",
            bad_value=0.0, bad_color="purple",
            zoom_to=zoom_to, show_axis=False, show_band_names=False)
# To be able to zoom in
# mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
# mpld3.display()

In [None]:
from google.colab import files
filename = "gt-segmentation-mask-0_9.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
creator = dlc.frames.creators.segmentation.SegmentationMaskFrameCreator(data2, output_path)

res = creator.run(area_id, tile_id)

fig = plot_result(res, cmap="gray_r", show_hist=False, figsize=(12, 4),
            norm=None, polygons=polygons, polygon_color="magenta",
            bad_value=0.0, bad_color="purple",
            zoom_to=zoom_to, show_axis=False, show_band_names=False)
# To be able to zoom in
# mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
# mpld3.display()

In [None]:
from google.colab import files
filename = "gt-segmentation-mask-1_0.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
creator = dlc.frames.creators.segmentation.AltSegmentationMaskFrameCreator(data, output_path)

res = creator.run(area_id, tile_id)

fig = plot_result(res, cmap="gray", show_hist=True, figsize=figsize,
            norm=None, polygons=polygons, polygon_color="magenta")

In [None]:
creator = dlc.frames.creators.segmentation.SegmentationBoundaryWeightsFrameCreator(data, output_path)

res = creator.run(area_id, tile_id, overwrite=True)

fig = plot_result(res, cmap="gray", show_hist=False, figsize=figsize,
            norm=None, polygons=polygons, polygon_color="magenta", show_axis=False, show_band_names=False)

In [None]:
from google.colab import files
filename = "boundary-weights.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
creator = dlc.frames.creators.segmentation.OutlierWeightsMaskFrameCreator(data, output_path)

res = creator.run(area_id, tile_id, overwrite=True)

fig = plot_result(res, cmap="gray", show_hist=True, figsize=(12, 6),
            norm=None, polygons=polygons, polygon_color="magenta")

#### Density

In [None]:
creator = dlc.frames.creators.density.DMGaussianDensityFrameCreator(data,
                                                                  output_path,
                                                                  sigma=5,
                                                                  filter_size=13,
                                                                  )

res = creator.run(area_id, tile_id, cache=cache)
fig = plot_result(res, cmap="gray", show_hist=False, figsize=figsize,
                  polygons=polygons, polygon_color="magenta",
                  norm=mpl.colors.LogNorm(), bad_value=0.0, bad_color="purple",
                  log=True, additive_constant=None,
                  zoom_to=zoom_to, show_axis=False, show_band_names=False)
# To be able to zoom in
#mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
#mpld3.display()

In [None]:
from google.colab import files
filename = "gt-density-dm-s5-fs13.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
creator = dlc.frames.creators.density.DMGaussianDensityFrameCreator(data,
                                                                  output_path,
                                                                  sigma=5,
                                                                  filter_size=15,
                                                                  )

res = creator.run(area_id, tile_id, cache=cache)
fig = plot_result(res, cmap="gray", show_hist=False, figsize=figsize,
                  polygons=polygons, polygon_color="magenta",
                  norm=mpl.colors.LogNorm(), bad_value=0.0, bad_color="purple",
                  log=True, additive_constant=None,
                  zoom_to=zoom_to, show_axis=False, show_band_names=False)
# To be able to zoom in
#mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
#mpld3.display()

In [None]:
from google.colab import files
filename = "gt-dm-fs3-sahel.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
creator = dlc.frames.creators.density.DMGaussianDensityFrameCreator(data,
                                                                  output_path,
                                                                  sigma=2,
                                                                  filter_size=7,
                                                                  )

res = creator.run(area_id, tile_id, cache=cache)
fig = plot_result(res, cmap="gray", show_hist=True, figsize=figsize,
                  polygons=polygons, polygon_color="magenta",
                  norm=mpl.colors.LogNorm(), bad_value=0.0, bad_color="purple",
                  log=True, additive_constant=None,
                  zoom_to=zoom_to)
# To be able to zoom in
#mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
#mpld3.display()

In [None]:
creator = dlc.frames.creators.density.GaussianDensityFrameCreator(data,
                                                                  output_path,
                                                                  sigma=5,
                                                                  filter_size=13,
                                                                  filter_target="centroid",
                                                                  centroid_type="energy",
                                                                  use_padded_bbox=True,
                                                                  )

res = creator.run(area_id, tile_id, cache=cache)
fig = plot_result(res, cmap="gray", show_hist=False, figsize=figsize,
                  polygons=polygons, polygon_color="magenta",
                  norm=mpl.colors.LogNorm(), bad_value=0.0, bad_color="purple",
                  log=False, additive_constant=None,
                  zoom_to=zoom_to, show_band_names=False, show_axis=False)
# To be able to zoom in
#mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
#mpld3.display()

In [None]:
fig = plot_histogram(res)

In [None]:
from google.colab import files
filename = "gt-hist-default.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = plot_histogram(res, 1e2)

In [None]:
from google.colab import files
filename = "gt-hist-1e2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = plot_histogram(res, 1e1)

In [None]:
from google.colab import files
filename = "gt-hist-1e1.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
def plot_histogram(result, factor = 1.0, fontsize=12, log=True):
  if isinstance(result.payload, dict):
    path = result.payload[result.payload_main_key]
  else:
    path = result.payload
  image = dlc.tools.images.load_image(path)
  image *= factor
  fig, ax = plt.subplots(1, 1, figsize=(4, 4))
  ax.hist(image.flatten(), bins="auto", color="purple", log=log)
  ax.set_xlabel("Pixel value")
  ax.set_ylabel("Frequency")
  return fig

In [None]:
from google.colab import files
filename = "gt-density-g-ec-s5-fs13.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
creator = dlc.frames.creators.density.GaussianDensityFrameCreator(data,
                                                                  output_path,
                                                                  sigma=5,
                                                                  filter_size=13,
                                                                  filter_target="centroid",
                                                                  centroid_type="standard",
                                                                  use_padded_bbox=True,
                                                                  )

res = creator.run(area_id, tile_id, cache=cache)
fig = plot_result(res, cmap="gray", show_hist=True, figsize=figsize,
                  polygons=polygons, polygon_color="magenta",
                  norm=mpl.colors.LogNorm(), bad_value=0.0, bad_color="purple",
                  log=False, additive_constant=None,
                  zoom_to=zoom_to, show_band_names=False, show_axis=False)
# To be able to zoom in
#mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
#mpld3.display()

In [None]:
from google.colab import files
filename = "gt-density-g-sc-s5-fs13.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
creator = dlc.frames.creators.density.GaussianDensityFrameCreator(data,
                                                                  output_path,
                                                                  sigma=5,
                                                                  filter_size=3,
                                                                  filter_target="polygon",
                                                                  use_padded_bbox=True,
                                                                  )

res = creator.run(area_id, tile_id, cache=cache)
fig = plot_result(res, cmap="gray", show_hist=False, figsize=figsize,
                  polygons=polygons, polygon_color="magenta",
                  norm=mpl.colors.LogNorm(), bad_value=0.0, bad_color="purple",
                  log=False, additive_constant=None,
                  zoom_to=zoom_to, show_band_names=False, show_axis=False)
# To be able to zoom in
#mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
#mpld3.display()

In [None]:
from google.colab import files
filename = "gt-density-g-poly-s5-fs3.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
creator = dlc.frames.creators.density.THGaussianDensityFrameCreator(data,
                                                                  output_path,
                                                                  sigma=3,
                                                                  thresh_z_score=None,
                                                                  filter_target="centroid",
                                                                  centroid_type="energy",
                                                                  )

res = creator.run(area_id, tile_id, cache=cache)
fig = plot_result(res, cmap="gray", show_hist=False, figsize=figsize,
                  polygons=polygons, polygon_color="magenta",
                  norm=mpl.colors.LogNorm(), bad_value=0.0, bad_color="purple",
                  log=False, additive_constant=None,
                  zoom_to=zoom_to)
# To be able to zoom in
#mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
#mpld3.display()

In [None]:
creator = dlc.frames.creators.density.EDTDensityFrameCreator(data, output_path)

res = creator.run(area_id, tile_id, cache=None)

fig = plot_result(res, cmap="gray", show_hist=False, figsize=figsize,
            polygons=polygons, polygon_color="magenta",
            norm=mpl.colors.LogNorm(), bad_value=0.0, bad_color="purple",
            log=False, zoom_to=zoom_to, additive_constant=0.0,
            show_band_names=False, show_axis=False)
#mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
#mpld3.display()

In [None]:
from google.colab import files
filename = "gt-density-edt1.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
creator = dlc.frames.creators.density.UniformDensityFrameCreator(data, output_path)

res = creator.run(area_id, tile_id, cache=cache)

fig = plot_result(res, cmap="gray", show_hist=False, figsize=figsize,
            polygons=polygons, polygon_color="magenta", zoom_to=zoom_to,
            norm=mpl.colors.LogNorm(), bad_value=0.0, bad_color="purple", log=False, additive_constant=0.0,
            show_band_names=False, show_axis=False)
#mpld3.plugins.connect(fig, dlc.tools.mpld3.ZoomSizePlugin())
#mpld3.display()

In [None]:
from google.colab import files
filename = "gt-density-uniform.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

#### Scalars

##### Frame properties

In [None]:
creator = dlc.frames.creators.scalar.FramePropertiesDataCreator(data)

creator.run(area_id, tile_id)

##### Model selection (Sahara-Sahel dataset)

In [None]:
creator = dlc.frames.creators.scalar.ModelSelectionDataCreator(data)

In [None]:
creator.run(area_id, tile_id)

In [None]:
# We know area 343 is in the Sahel/Sudan region
creator.run(*sahara_sahel.get_area_and_tile(343))

In [None]:
# We know area 21 is in the Sahara region
creator.run(*sahara_sahel.get_area_and_tile(21))

### Frame data factory

A frame data factory takes a list of frame data creators and produces frames from either a pair of tiles and areas databases or an existing frames database. In the first case it creates a new frames database; in the second it updates the existing one.

```bash
!export PYTHONPATH="${GITHUB_REPO_NAME}:${PYTHONPATH}" && \
   python3 ${GITHUB_REPO_NAME}/scripts/create_frames.py \
   --dataset-name="sahara-sahel" \
   --datasets-path="/content/datasets" \
   --img-dir="StackedImages" \
   --output-path="./data/datasets/frames" \
   --creators-group="segmentation"
```
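
The create-or-update behaviour described above can be sketched in plain Python. `MiniFactory` and the toy creators are hypothetical and far simpler than `dlc.frames.factory`; the sketch assumes records are keyed by `(area_id, tile_id)` and that each creator contributes one named entry per frame.

```python
class MiniFactory:
    """Run a set of named creators over frames and collect their outputs."""

    def __init__(self, creators):
        # creators: mapping of entry name -> callable(area_id, tile_id)
        self.creators = dict(creators)

    @property
    def keys(self):
        return tuple(self.creators)

    def run(self, pairs, existing=None):
        """Create records for the given pairs, updating `existing` if provided."""
        # Copy the existing records so the caller's database is not mutated.
        frames = {key: dict(record) for key, record in (existing or {}).items()}
        for area_id, tile_id in pairs:
            record = frames.setdefault((area_id, tile_id), {})
            for name, creator in self.creators.items():
                record[name] = creator(area_id, tile_id)
        return frames


# First case: no frames database yet, so a new one is created.
image_factory = MiniFactory({"image": lambda a, t: f"image-{a}-{t}.tif"})
frames_db = image_factory.run([(1, 10), (2, 10)])

# Second case: an existing database is updated with a new creator's output.
mask_factory = MiniFactory({"mask": lambda a, t: f"mask-{a}-{t}.tif"})
frames_db = mask_factory.run([(1, 10)], existing=frames_db)
print(frames_db[(1, 10)])
```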

In [None]:
!export PYTHONPATH="counting-trees-private:${PYTHONPATH}" && \
  python3 counting-trees-private/scripts/create_frames.py \
   --dataset-name="sahara-sahel" \
   --datasets-path="/content/datasets" \
   --img-dir="StackedImages" \
   --output-path="./frames" \
   --creators-group="density" \
   --fixed-polygons \
   --initial-scale="0.9"

In [None]:
!ls frames/

In [None]:
!cp frames/frames-sahara-sahel-density-0_9.zip /content/datasets/frames_zips/

In [None]:
!unzip -vl frames/frames-sahara-sahel-density-0_9.zip | grep geojson

In [None]:
!shasum -a 256 frames/frames-sahara-sahel-density-0_9.zip
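
The same SHA-256 checksum can be computed from Python, which may be convenient when verifying an archive from inside the notebook. A minimal sketch using only the standard library; `sha256_of_file` is a helper introduced here for illustration.

```python
import hashlib
from pathlib import Path


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Path taken from the cell above; nothing is printed if the archive is absent.
zip_path = Path("frames/frames-sahara-sahel-density-0_9.zip")
if zip_path.exists():
    print(sha256_of_file(zip_path))
```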

In [None]:
!zip -j -r frames/frames-sahara-sahel-density-0_9.zip 

#### Training frames

In [None]:
FRAMES_SUFFIX = "segmentation"
FRAMES_PATH = f"./data/datasets/frames/{DATASET}"
FRAMES_ZIP = f"frames-{DATASET}-{FRAMES_SUFFIX}.zip"

In [None]:
creator_names = ("image", "props", "model", "segmentation-mask",
                 "segmentation-boundary-weights", "outlier-weights",)

creator_names += ("gaussian-density", "th-gaussian-density", "dm-gaussian-density", "uniform-density", "edt-density",)

data = dlc.frames.creators.data.CoreFrameDataSource(DATASET, DATASET_IMG_PATH, tiles, areas, fixed_polygons)

gaussian_options=(
    dict(filter_size=15, sigma=5.0, centroid_type="energy", filter_target="centroid",),
    dict(filter_size=15, sigma=5.0, centroid_type="standard", filter_target="centroid",),
    dict(filter_size=3, sigma=5.0, centroid_type="energy", filter_target="polygon",),
)

th_gaussian_options=(
    dict(sigma=5.0, thresh_z_score=0.1, centroid_type="energy", filter_target="centroid"),
)

dm_gaussian_options=(
    dict(sigma=5.0, filter_size=15),
)

factory = dlc.frames.factory.create_and_configure_factory(
    data, FRAMES_PATH,
    gaussian_options=gaussian_options,
    th_gaussian_options=th_gaussian_options,
    dm_gaussian_options=dm_gaussian_options,
    creator_names=creator_names,
)
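
The Gaussian options above set a `filter_size` and a `sigma`. As a rough illustration of what such parameters usually control in a density ground truth, the sketch below builds a normalized Gaussian kernel and adds one copy per object centroid, so that the map sums to the number of objects. This is an assumption about the general technique, not the `dlc` implementation; `gaussian_kernel` and `density_map` are hypothetical helpers.

```python
import math


def gaussian_kernel(filter_size, sigma):
    """Return a filter_size x filter_size Gaussian kernel that sums to 1."""
    half = (filter_size - 1) / 2.0
    raw = [[math.exp(-((r - half) ** 2 + (c - half) ** 2) / (2.0 * sigma ** 2))
            for c in range(filter_size)]
           for r in range(filter_size)]
    total = sum(map(sum, raw))
    return [[value / total for value in row] for row in raw]


def density_map(shape, centroids, kernel):
    """Add one kernel per (row, col) centroid; mass outside the map is clipped."""
    height, width = shape
    half = len(kernel) // 2
    out = [[0.0] * width for _ in range(height)]
    for cy, cx in centroids:
        for kr, row in enumerate(kernel):
            for kc, value in enumerate(row):
                y, x = cy + kr - half, cx + kc - half
                if 0 <= y < height and 0 <= x < width:
                    out[y][x] += value
    return out


kernel = gaussian_kernel(filter_size=15, sigma=5.0)
dm = density_map((64, 64), [(20, 20), (40, 45)], kernel)
count = sum(map(sum, dm))  # close to 2.0: both kernels fit inside the map
```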

In [None]:
pprint(factory.keys)

Update an existing database.

In [None]:
factory_db = f"/content/datasets/frames/{DATASET}/frames.geojson"
_ = factory.run_jobs(factory_db,
                     dry_run=True)

Create a new frames database from scratch (derived from tiles and areas databases).

In [None]:
factory_db = (tiles, areas)
_ = factory.run_jobs(factory_db,
                     dry_run=True,
                     job_slice=slice(0, 6),
                    )

In [None]:
%%time
data.load_raster_data(n_processes=None)

In [None]:
%%time
result = factory.run_jobs(factory_db,
                          dry_run=False,
                          job_slice=None,
                          n_processes=4,
                          overwrite=True,
                          output_path=f"{FRAMES_PATH}/frames.geojson",
                          save_keys=None,
                          cache_enabled=False,
                          cache_shared=True,
                          )

In [None]:
result.frames.query("model == 'sahel' and canopy_cover > 0.0").head()

In [None]:
!rm -f $FRAMES_ZIP

In [None]:
!zip -j -r $FRAMES_ZIP $FRAMES_PATH

In [None]:
!cp $FRAMES_ZIP /content/datasets/frames_zips/$FRAMES_ZIP

We can pick an area and inspect one of the created frames:

In [None]:
area_id, tile_id = get_area_and_tile(350)

In [None]:
dlc.tools.plots.plot_polygons_in_area(areas, fixed_polygons, area_id)

In [None]:
fig = dlc.tools.plots.plot_frame_by_key(result.frames, "image", FRAMES_PATH,
                                  area_id=area_id, tile_id=tile_id,
                                  polygons=fixed_polygons,
                                  show_title=True, show_hist=False,
                                  log=False, norm=None,
                                  cmap="gray", figsize=(12, 6), zoom_to=None)

In [None]:
fig = dlc.tools.plots.plot_frame_by_key(result.frames, "segmentation-boundary-weights", FRAMES_PATH,
                                  area_id=area_id, tile_id=tile_id,
                                  #polygons=fixed_polygons,
                                  show_title=True, show_hist=True,
                                  log=True, cmap="gray", figsize=(12, 6), zoom_to=None)

In [None]:
fig = dlc.tools.plots.plot_frame_by_key(result.frames, "outlier-weights", FRAMES_PATH,
                                  area_id=area_id, tile_id=tile_id,
                                  #polygons=fixed_polygons,
                                  show_title=True, show_hist=True,
                                  log=True, cmap="gray", figsize=(12, 6), zoom_to=None)

### Object splits

We can create splits for different objects (e.g. tiles or frames).

In [None]:
rng = np.random.default_rng()

In [None]:
seed = int(rng.integers(0, np.iinfo(np.int32).max))

In [None]:
seed = 591477907

In [None]:
splitter = dlc.tools.splits.LatitudeObjectSplitter()

#### Tile splits

For tiles, a split can be used to create a smaller dataset.

In [None]:
split_areas = areas.query("n_tiles > 1", inplace=False)

splits, sampled_areas = splitter.run(split_areas, splits=(0.20,), seed=seed)

splits

In [None]:
dlc.tools.plots.plot_object_splits(split_areas, splits, areas=None, colors=["white", "green", "red"])

In [None]:
splitter = dlc.tools.splits.SimpleSplitter()

rwanda_tiles = gpd.read_file("/content/datasets/rwanda/tiles.gpkg")
splits, _ = splitter.run(rwanda_tiles, splits=(0.20, 0.20), seed=seed)
dlc.tools.plots.plot_object_splits(rwanda_tiles, splits, areas=areas, colors=["white", "green", "red"])

In [None]:
sahel_tiles = gpd.read_file("/content/datasets/sahel/tiles.gpkg")
splits, sampled_areas = splitter.run(sahel_tiles, splits=(0.20, 0.20), seed=seed)
dlc.tools.plots.plot_object_splits(sahel_tiles, splits, areas=sampled_areas, colors=["white", "green", "red"])

In [None]:
sahel_tiles["centroid_y"] = sahel_tiles.to_crs("EPSG:6933").centroid.to_crs(sahel_tiles.crs).y
sahel_tiles.hist("centroid_y")

#### Frame splits

For frames, splits can be used to divide the data into training, validation, and test sets.
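
As a rough illustration of what a seeded split with fractions such as `(0.20, 0.20)` produces, the sketch below assigns each object a split index at random. It assumes that index 0 is the remainder (training) and that indices 1 and 2 are the two 20% splits. It ignores latitude entirely, so it is not a stand-in for `LatitudeObjectSplitter`; `random_splits` is a hypothetical helper.

```python
import random


def random_splits(n_objects, fractions, seed):
    """Assign each object a split index: 0 for the remainder, 1..k for `fractions`."""
    rng = random.Random(seed)
    order = list(range(n_objects))
    rng.shuffle(order)
    assignment = [0] * n_objects
    start = 0
    for split_id, fraction in enumerate(fractions, start=1):
        size = round(n_objects * fraction)
        for index in order[start:start + size]:
            assignment[index] = split_id
        start += size
    return assignment


# The same seed always yields the same assignment.
splits_a = random_splits(100, (0.20, 0.20), seed=591477907)
splits_b = random_splits(100, (0.20, 0.20), seed=591477907)
```

Fixing the seed, as the notebook does, keeps the validation and test frames stable across sessions.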

In [None]:
frames = gpd.read_file(f"/content/datasets/frames/{DATASET}/frames.geojson")

In [None]:
splits, _ = splitter.run(frames, splits=(0.20, 0.20), seed=seed)

In [None]:
dlc.tools.plots.plot_object_splits(frames, splits, areas=None, colors=["white", "green", "red"])

In [None]:
dlc.tools.plots.plot_object_splits(frames, splits, areas=None, colors=["white", "green", "red"],
                                   window=((14.2, 14.3), (-14.75, -14.6)))

## Datasets

### Remote sensing datasets

In [None]:
%%time
frames = download_frames_zip(DATASET, namespace="frames", suffix="density")

In [None]:
src_frames = frames.copy()

In [None]:
frames = src_frames.query("model == 'sahel'")

In [None]:
frames.columns

In [None]:
splitter = dlc.tools.splits.LatitudeObjectSplitter()

splits, _ = splitter.run(frames, [0.20], seed=591477907)

splits_map = {"training": 0, "validation1": 1,}

In [None]:
images_ds_gen = dlc.tools.datasets.ImageDatasetGenerator(frames,
                                                         splits=splits,
                                                         splits_map=splits_map,
                                                         image_keys=(["image"],
                                                                     ["segmentation-mask",
                                                                      "dm-gaussian-density-sgm_500-fs_1500"]),
                                                         input_base_path=f"./data/datasets/frames/{DATASET}",
                                                        )

In [None]:
images_ds_gen = dlc.tools.datasets.ImageDatasetGenerator(frames,
                                                         splits=splits,
                                                         splits_map=splits_map,
                                                         image_keys=(["image"],
                                                                     ["segmentation-mask"]),
                                                         input_base_path=f"./data/datasets/frames/{DATASET}",
                                                        )

#### Sequential frames

In [None]:
# Finite cardinality
images = images_ds_gen.get_sequential_images(split="validation1", shuffle=False,
                                             seed=None, verbose=True)

for x in images.take(1):
  pprint(x)

#### Random frame patches

In [None]:
# Infinite cardinality
images = images_ds_gen.get_random_patches((256, 256), split="training", seed=None, verbose=True)

for x in images.take(1):
  pprint(x)
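
"Infinite cardinality" means the dataset never runs out of patches, so the consumer has to bound it (here with `.take(1)`). A minimal sketch of the idea with a plain Python generator; `random_patch_windows` is hypothetical and yields window coordinates rather than pixel data.

```python
import itertools
import random


def random_patch_windows(image_shape, patch_shape, seed=None):
    """Yield (row, col, height, width) windows forever, uniformly at random."""
    rng = random.Random(seed)
    (height, width), (ph, pw) = image_shape, patch_shape
    while True:
        yield (rng.randint(0, height - ph), rng.randint(0, width - pw), ph, pw)


# The generator never ends, so it is bounded explicitly, like `.take(5)`.
windows = list(itertools.islice(
    random_patch_windows((1000, 1200), (256, 256), seed=0), 5))
```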

#### Sequential frame patches

In [None]:
# Finite cardinality
images = images_ds_gen.get_sequential_patches((256, 256), split="validation1", shuffle=False, seed=None, verbose=True)
images = images.batch(8)

for x in images.skip(41).take(1):
  pprint(x)
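
Sequential patches have finite cardinality because the windows tile each frame once. The sketch below assumes non-overlapping windows and drops incomplete windows at the border; how `get_sequential_patches` treats overlap and borders is not shown in this notebook. It also mimics the `batch`, `skip`, and `take` steps with `itertools`. All helper names are hypothetical.

```python
import itertools


def sequential_patch_windows(image_shape, patch_shape):
    """Yield non-overlapping (row, col, height, width) windows in row-major order."""
    (height, width), (ph, pw) = image_shape, patch_shape
    for row in range(0, height - ph + 1, ph):
        for col in range(0, width - pw + 1, pw):
            yield (row, col, ph, pw)


def batched(iterable, size):
    """Group items into lists of at most `size` elements."""
    iterator = iter(iterable)
    while True:
        batch = list(itertools.islice(iterator, size))
        if not batch:
            return
        yield batch


windows = list(sequential_patch_windows((1024, 2048), (256, 256)))  # 4 * 8 windows
batches = list(batched(windows, 8))                                  # 4 batches of 8
third = next(itertools.islice(iter(batches), 2, None))              # skip 2, take 1
```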

#### Loading images

In [None]:
cmf = dlc.tools.plots.ColorMapFactory()
# cmf.add_group(dict(keys=(("annotations", 2),),
#                    cmap="gray",
#                    bad_value=0.0,
#                    bad_color="red",))

In [None]:
cache = dlc.tools.cache.ArrayCache()
image_loader = dlc.tools.images.ImageLoader(local_standardization_p=[0.0, None],
                                            # Defaults must take into account
                                            # bands per each image passed
                                            # defaults=([(0.0, 0.0)],
                                            #           [0.0, 0.0, 1.0]),
                                            seed=None, cache=cache)

In [None]:
images = images_ds_gen.get_sequential_patches((256, 256), split="validation1", shuffle=False, seed=None, verbose=True)
images = images.map(image_loader.load)
images = images.batch(8)

In [None]:
importlib.reload(dlc.tools.plots)

In [None]:
for xs, ys in images.take(1):
  fig = dlc.tools.plots.plot_batch(xs, ys, cmf=cmf, show_hist=True)

In [None]:
images = images_ds_gen.get_sequential_patches((256, 256), split="validation1", shuffle=False, seed=None, verbose=True)
images = images.map(image_loader.load)
images = images.map(fix_boundary_weights_v2)
images = images.map(dlc.transformers.to_cover5_annotation)
images = images.batch(8)

In [None]:
for xs, ys in images.skip(41).take(1):
  fig = dlc.tools.plots.plot_multioutput_batch(xs, ys,
                                               keys=("segmentation_map",),
                                               cmf=cmf,)

#### Augmentations

In [None]:
density_aug_transform = dlc.augmentation.DensityAugTransform0()

In [None]:
segmentation_aug_transform = dlc.augmentation.SegmentationAugTransform0()

In [None]:
for x, y in training_images.map(image_loader.load).map(segmentation_aug_transform).take(10):
  xn = x.numpy()
  yn = y.numpy()
  fig = plt.figure(constrained_layout=True)
  # Two rows (features, annotations) and one column per band
  gs = fig.add_gridspec(2, max(xn.shape[2], yn.shape[2]))
  for i in range(xn.shape[2]):
    ax = fig.add_subplot(gs[0, i])
    ax.imshow(xn[:, :, i], cmap="gray")
    ax.set_title(f"Features (Band {i + 1})")
  for i in range(yn.shape[2]):
    ax = fig.add_subplot(gs[1, i])
    ax.imshow(yn[:, :, i], cmap="gray")
    ax.set_title(f"Annotations (Band {i + 1})")

## Model visualization

### U-Net model

In [None]:
# Example from the 2015 U-Net paper
input_shape = (572, 572, 1)
output_depth = 2
batch_size = None
model = dlc.models.base.unet.create_model(input_shape,
                                  output_depth,
                                  batch_size=batch_size,
                                  padding="valid",
                                  name='2015_unet_paper',
                                  resize_output=False,
                                  use_attention_gate=False,
                                  initializer_seed=0)
model.summary()
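
With `padding="valid"` every 3x3 convolution trims two pixels per spatial dimension, so the output is smaller than the input. The sketch below reproduces the arithmetic for the architecture in the 2015 U-Net paper (four pooling steps, two 3x3 convolutions per block). Whether `create_model` uses exactly this layout is an assumption; `unet_valid_output_size` is a hypothetical helper.

```python
def unet_valid_output_size(input_size, depth=4, convs_per_block=2, kernel_size=3):
    """Spatial output size of a U-Net built from unpadded convolutions."""
    shrink = convs_per_block * (kernel_size - 1)
    size = input_size
    for _ in range(depth):  # contracting path: conv block, then 2x2 pooling
        size -= shrink
        if size % 2:
            raise ValueError("feature map size must be even before 2x2 pooling")
        size //= 2
    size -= shrink  # bottleneck block
    for _ in range(depth):  # expanding path: 2x upsampling, then conv block
        size = size * 2 - shrink
    return size


print(unet_valid_output_size(572))  # 388, the output tile size given in the paper
```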

### Segmentation-based models

In [None]:
# U-Net model v1.0 from the reference paper
input_shape = (256, 256, 2)
model = dlc.models.misc.unet0.load_unet0_model('/content/other/sahara_v1_0_0.h5')
model.call(tf.keras.layers.Input(shape=input_shape))
model.summary()

#### CSR-Net

In [None]:
input_shape = (256, 256, 3)
output_depth = 1
batch_size = None
model = dlc.models.base.csrnet.create_csrnet_b(input_shape,
                                           output_depth,
                                           batch_size=batch_size,
                                           name='csrnet_test',
                                           initializer_seed=0)
model.summary()

## Model training


In [None]:
!cd $GITHUB_REPO_NAME && git stash && git pull origin develop && git stash pop

In [None]:
dataset = "sahara-sahel"
data_directory = "data"

In [None]:
dataset = "rwanda"
data_directory = "data"

In [None]:
%%time
frames = download_frames_zip(dataset, namespace="frames", suffix="density-0_9")

In [None]:
frames.columns

In [None]:
!ls -l $data_directory/datasets/frames/$dataset | wc -l

In [None]:
# Optional: to load existing models (e.g. for evaluation)
!ln -s $BASE_PATH/models $data_directory/models

In [None]:
!ls -l data/models

In [None]:
%%writefile train.sh
#!/bin/bash

export DLC_PROJECT_DIRECTORY="./counting-trees-private"
export DLC_DATA_DIRECTORY="./data"


python3 $DLC_PROJECT_DIRECTORY/scripts/train_eval.py --config \
    seeds \
    ds/sahel \
    multi0/model \
    multi0/density-block-adapter \
    multi0/gt-adapter-g4 \
    multi0/loss/b1wds \
    multi0/metrics multi0/plots \
    opt/cyclical/sgd \
    --train 100 \
    --seed 0 \
    --datadir $DLC_DATA_DIRECTORY \
    --projectdir $DLC_PROJECT_DIRECTORY

In [None]:
!bash train.sh

## Model evaluation

In [None]:
!cd $GITHUB_REPO_NAME && git stash && git pull origin develop && git stash pop

In [None]:
import dlc.tools.evaluation
importlib.reload(dlc.tools.evaluation)

### Experiment: Cover

In [None]:
models_path = f"{BASE_PATH}/models/archive_final"
settings_path = "counting-trees-private/config/tables/cover.yml"
table, df = dlc.tools.evaluation.make_table(models_path, settings_path, hidecolumns=["Seed", "Loss"],
                                           format_rows="best", showindex="never", tablefmt="html", verbose=True, group_by="Dataset")

In [None]:
display(HTML(table))

#### Sahara

In [None]:
!mkdir -p figures

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     label_cols=["model", "loss_short"], label_type="str", legend_cols=1,
                                     key="loss", key_label="Loss", sharey=True, dataset="Sahara",
                                     key_best="min",)

In [None]:
filename = "figures/cover_model_sahara_loss.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     label_cols=["model", "loss_short"], label_type="str", legend_cols=1,
                                     key="cover_r_square", key_label="Cover $R^2$", sharey=False, dataset="Sahara",
                                     key_best="max",
                                     hlines=(0.0, 1.0,), start_at=5)
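
The horizontal lines at 0.0 and 1.0 are the natural reference values for $R^2$: 1.0 is a perfect fit and 0.0 matches a model that always predicts the mean. A minimal sketch of the standard definition follows; how `dlc` aggregates the metric over frames is not shown here, and `r_square` is a hypothetical helper.

```python
def r_square(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


y_true = [0.1, 0.4, 0.3, 0.8]
perfect = r_square(y_true, y_true)                 # 1.0
baseline = r_square(y_true, [0.4, 0.4, 0.4, 0.4])  # about 0.0: predicts the mean
worse = r_square(y_true, [0.8, 0.1, 0.9, 0.0])     # negative: worse than the mean
```

Because the value is unbounded below, early epochs can fall far under zero, which is one plausible reason for skipping them with `start_at`.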

In [None]:
filename = "figures/cover_model_sahara_cover_r2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     label_cols=["model", "loss_short"], label_type="str", legend_cols=1,
                                     key="segmentation_map_mcc", key_label="MCC", sharey=False, dataset="Sahara",
                                     key_best="max",
                                     hlines=(1.0,), start_at=5)
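
MCC is the Matthews correlation coefficient, which ranges from -1 (every pixel wrong) through 0 (no better than chance) to 1 (perfect). A minimal sketch of the standard definition from confusion-matrix counts; `mcc` is a hypothetical helper, and returning 0.0 for an empty row or column is a common convention rather than something confirmed for `dlc`.

```python
import math


def mcc(tp, fp, fn, tn):
    """Matthews correlation coefficient from confusion-matrix counts."""
    denominator = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    if denominator == 0:
        return 0.0  # convention when a row or column of the matrix is empty
    return (tp * tn - fp * fn) / denominator


perfect = mcc(tp=50, fp=0, fn=0, tn=50)    # 1.0
chance = mcc(tp=25, fp=25, fn=25, tn=25)   # 0.0
inverted = mcc(tp=0, fp=50, fn=50, tn=0)   # -1.0
```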

In [None]:
filename = "figures/cover_model_sahara_cover_mcc.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     label_cols=["model", "loss_short"], label_type="str", legend_cols=1,
                                     key="segmentation_map_r_square", key_label="Segmentation Map $R^2$", sharey=False, dataset="Sahara",
                                     key_best="max",
                                     hlines=(1.0,), start_at=5)

In [None]:
filename = "figures/cover_model_sahara_seg_map_r2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

#### Sahel

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     label_cols=["model", "loss_short"], label_type="str", legend_cols=1,
                                     key="loss", key_label="Loss", sharey=True, dataset="Sahel-Sudan",
                                     key_best="min",)

In [None]:
filename = "figures/cover_model_sahel_loss.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     label_cols=["model", "loss_short"], label_type="str", legend_cols=1,
                                     key="cover_r_square", key_label="Cover $R^2$", sharey=False, dataset="Sahel-Sudan",
                                     key_best="max",
                                     hlines=(0.0, 1.0,), start_at=5)

In [None]:
filename = "figures/cover_model_sahel_cover_r2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     label_cols=["model", "loss_short"], label_type="str", legend_cols=1,
                                     key="segmentation_map_mcc", key_label="MCC", sharey=False, dataset="Sahel-Sudan",
                                     key_best="max",
                                     hlines=(1.0, 0.0), start_at=0)

In [None]:
filename = "figures/cover_model_sahel_cover_mcc.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     label_cols=["model", "loss_short"], label_type="str", legend_cols=1,
                                     key="segmentation_map_r_square", key_label="Segmentation Map $R^2$", sharey=False, dataset="Sahel-Sudan",
                                     key_best="max",
                                     hlines=(0.0, 1.0), start_at=5)

In [None]:
filename = "figures/cover_model_sahel_seg_map_r2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

#### Rwanda

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     label_cols=["model", "loss_short"], label_type="str", legend_cols=1,
                                     key="loss", key_label="Loss", sharey=True, dataset="Rwanda",
                                     key_best="min",)

In [None]:
filename = "figures/cover_model_rwanda_loss.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     label_cols=["model", "loss_short"], label_type="str", legend_cols=1,
                                     key="cover_r_square", key_label="Cover $R^2$", sharey=False, dataset="Rwanda",
                                     key_best="max",
                                     hlines=(0.0, 1.0,), start_at=5)

In [None]:
filename = "figures/cover_model_rwanda_cover_r2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     label_cols=["model", "loss_short"], label_type="str", legend_cols=1,
                                     key="segmentation_map_mcc", key_label="MCC", sharey=False, dataset="Rwanda",
                                     key_best="max",
                                     hlines=(1.0, 0.0), start_at=0)

In [None]:
filename = "figures/cover_model_rwanda_cover_mcc.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     label_cols=["model", "loss_short"], label_type="str", legend_cols=1,
                                     key="segmentation_map_r_square", key_label="Segmentation Map $R^2$", sharey=False, dataset="Rwanda",
                                     key_best="max",
                                     hlines=(0.0, 1.0), start_at=5)

In [None]:
filename = "figures/cover_model_rwanda_seg_map_r2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

### Experiment: Density GTs

In [None]:
models_path = f"{BASE_PATH}/models/archive_final"
settings_path = "counting-trees-private/config/tables/density-gt.yml"
table, df = dlc.tools.evaluation.make_table(models_path, settings_path, hidecolumns=["Loss", "Seed"],
                                            format_rows="best", showindex="never", tablefmt="html")

In [None]:
display(HTML(table))

In [None]:
models_path = f"{BASE_PATH}/models/archive_final"
settings_path = "counting-trees-private/config/tables/density-gt.yml"
table, df = dlc.tools.evaluation.make_table(models_path, settings_path, hidecolumns=["Loss", "Seed", "Best Epoch"],
                                            format_rows="best", showindex="never", tablefmt="html", is_test=True)

In [None]:
display(HTML(table))

#### Sahara

In [None]:
!mkdir -p figures

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     key="loss", key_label="Loss", sharey=True,
                                     dataset="Sahara", label_type="str", label_cols=["model", "gt"],
                                     monitor_key="val_loss", monitor_best="min",
                                     legend_cols=1,)

In [None]:
filename = "figures/density_model_sahara_loss.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     key="count_r_square", key_label="Count $R^2$", sharey=False,
                                     dataset="Sahara", label_type="str", label_cols=["model", "gt"],
                                     key_best="max",
                                     legend_cols=1, hlines=(1.0,))

In [None]:
filename = "figures/density_model_sahara_count_r2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     key="density_map_r_square", key_label="Density Map $R^2$", sharey=False,
                                     dataset="Sahara", label_type="str", label_cols=["model", "gt"],
                                     key_best="max",
                                     legend_cols=1, hlines=(0.0, 1.0,))

In [None]:
filename = "figures/density_model_sahara_dm_r2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     key="cover_r_square", key_label="Cover $R^2$", sharey=False,
                                     dataset="Sahara", label_type="str", label_cols=["model", "gt"],
                                     key_best="max",
                                     legend_cols=1, hlines=(0.0, 1.0,), start_at=10)

In [None]:
filename = "figures/density_model_sahara_cover_r2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

#### Sahel

In [None]:
!mkdir -p figures

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     key="loss", key_label="Loss", sharey=False,
                                     dataset="Sahel-Sudan", label_type="str", label_cols=["model", "gt"],
                                     monitor_key="val_loss", monitor_best="min",
                                     legend_cols=1,)

In [None]:
filename = "figures/density_model_sahel_loss.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     key="count_r_square", key_label="Count $R^2$", sharey=False,
                                     dataset="Sahel-Sudan", label_type="str", label_cols=["model", "gt"],
                                     key_best="max",
                                     legend_cols=1, hlines=(1.0,))

In [None]:
filename = "figures/density_model_sahel_count_r2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     key="density_map_r_square", key_label="Density Map $R^2$", sharey=False,
                                     dataset="Sahel-Sudan", label_type="str", label_cols=["model", "gt"],
                                     key_best="max",
                                     legend_cols=1, hlines=(0.0, 1.0,))

In [None]:
filename = "figures/density_model_sahel_dm_r2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     key="cover_r_square", key_label="Cover $R^2$", sharey=False,
                                     dataset="Sahel-Sudan", label_type="str", label_cols=["model", "gt"],
                                     key_best="max",
                                     legend_cols=1, hlines=(0.0, 1.0,), start_at=10)

In [None]:
filename = "figures/density_model_sahel_cover_r2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

#### Rwanda

In [None]:
!mkdir -p figures

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     key="loss", key_label="Loss", sharey=False,
                                     dataset="Rwanda", label_type="str", label_cols=["model", "gt"],
                                     monitor_key="val_loss", monitor_best="min",
                                     legend_cols=1,)

In [None]:
filename = "figures/density_model_rwanda_loss.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     key="count_r_square", key_label="Count $R^2$", sharey=False,
                                     dataset="Rwanda", label_type="str", label_cols=["model", "gt"],
                                     key_best="max",
                                     legend_cols=1, hlines=(1.0,))

In [None]:
filename = "figures/density_model_rwanda_count_r2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     key="density_map_r_square", key_label="Density Map $R^2$", sharey=False,
                                     dataset="Rwanda", label_type="str", label_cols=["model", "gt"],
                                     key_best="max",
                                     legend_cols=1, hlines=(0.0, 1.0,))

In [None]:
filename = "figures/density_model_rwanda_dm_r2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     key="cover_r_square", key_label="Cover $R^2$", sharey=False,
                                     dataset="Rwanda", label_type="str", label_cols=["model", "gt"],
                                     key_best="max",
                                     legend_cols=1, hlines=(0.0, 1.0,), start_at=10)

In [None]:
filename = "figures/density_model_rwanda_cover_r2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

### Experiment: Density Optimizer

In [None]:
models_path = f"{BASE_PATH}/models/archive_final"
settings_path = "counting-trees-private/config/tables/density-250.yml"
table, df = dlc.tools.evaluation.make_table(models_path, settings_path, hidecolumns=["Seed", "Loss", "Folder"],
                                            format_rows="best", showindex="never", tablefmt="html")

In [None]:
display(HTML(table))

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 3.5), legend_pos=1, show_legend=True,
                                     key="loss", key_best="min", key_label="Loss", sharey=True,
                                     dataset="Sahel-Sudan", label_type="str", label_cols=["opt"],
                                     legend_cols=1, combined=True, legend_loc="upper left", bbox_to_anchor=(1.0, 1.03))

In [None]:
filename = "figures/density_opt_sahel_loss.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     key="count_r_square", key_label="Count $R^2$", sharey=False,
                                     dataset="Sahel-Sudan", label_type="str", label_cols=["opt"],
                                     key_best="max",
                                     legend_cols=1, hlines=(0.0, 1.0,))

In [None]:
filename = "figures/density_opt_sahel_count_r2.pdf"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     key="density_map_r_square", key_label="Density Map $R^2$", sharey=False,
                                     dataset="Sahel-Sudan", label_type="str", label_cols=["opt"],
                                     key_best="max",
                                     legend_cols=1, hlines=(0.0, 1.0,))

### Experiment: Multitask

In [None]:
models_path = f"{BASE_PATH}/models/archive_final"
settings_path = "counting-trees-private/config/tables/multitask.yml"
table, df = dlc.tools.evaluation.make_table(models_path, settings_path, hidecolumns=["Loss", "Seed"],
                                            format_rows="best", showindex="never", tablefmt="html",
                                            none_replacement="")

In [None]:
display(HTML(table))

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     key="loss", key_best="min", key_label="Loss", sharey=False,
                                     dataset="Sahara", label_type="str", label_cols=["task", "opt"],
                                     legend_cols=1, combined=True, legend_loc="upper left", bbox_to_anchor=(1.0, 1.03))

In [None]:
fig = dlc.tools.evaluation.make_plot(models_path, settings_path, figsize=(12, 6), legend_pos=1, show_legend=True,
                                     key="count_r_square", key_label="Count $R^2$", sharey=False,
                                     dataset="Sahara", label_type="str", label_cols=["task","opt"],
                                     key_best="max",
                                     legend_cols=1, hlines=(0.0, 1.0,), start_at=5)

In [None]:
models_path = f"{BASE_PATH}/models/archive_final"
settings_path = f"counting-trees-private/config/tables/multitask.yml"
table, df = dlc.tools.evaluation.make_table(models_path, settings_path, hidecolumns=["Loss", "Seed", "Folder", "Best Epoch"],
                                            format_rows="best", showindex="never", tablefmt="html",
                                            none_replacement="", is_test=True)

In [None]:
display(HTML(table))

### Evaluation on the test set

In [None]:
%%writefile eval.sh
#!/bin/bash

export DLC_PROJECT_DIRECTORY="./counting-trees-private"
export DLC_DATA_DIRECTORY="./data"

python3 $DLC_PROJECT_DIRECTORY/scripts/train_eval.py --config \
    seeds \
    ds/sahel \
    cover/1/model/d4-ds cover/1/loss/b1wds \
    cover/1/eval-metrics \
    opt/cyclical/sgd \
    --models-dir "models/archive_final" \
    --eval

In [None]:
!sh eval.sh

## Visualizing model predictions

In [None]:
dataset = "sahara-sahel"
data_directory = "data"

In [None]:
dataset = "rwanda"
data_directory = "data"

In [None]:
%%time
frames = download_frames_zip(dataset, namespace="frames", suffix="density-0_9")

### Sahara

In [None]:
# Best DM R2: 9, 6
# Worst DM R2: 4, 5
include = (4, 5, 6, 9)

In [None]:
settings_path = f"counting-trees-private/config/tables/density-gt.yml"
models_path = pathlib.Path(f"{BASE_PATH}/models/archive_final")
cache = None
frames_path = pathlib.Path(f"./data/datasets/frames/sahara-sahel")
sahara_data = dlc.tools.evaluation.load_models_from_settings(settings_path, models_path, frames_path,
                                                      include=(0, 1, 2, 3),
                                                      label_cols=("model", "gt"), cache=cache)

Search for interesting cases:

In [None]:
rng = np.random.default_rng()

In [None]:
n_skip = rng.integers(0, 172, size=1)[0]
print(n_skip)
plotter = dlc.tools.evaluation.CountPredictionPlotter(sahara_data["ds"], sahara_data["models"], sahara_data["labels"])
fig = plotter.plot_unique_gt(n_skip, figsize=(6, 4), show_hist=False, include=(0,))
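
Because the generator above is unseeded, every run draws a different frame. A minimal sketch, assuming a fixed seed is acceptable for the search, that makes the draws repeatable so an interesting case can be found again. The `draw_candidates` helper and the seed value are hypothetical; the upper bound of 172 is taken from the cell above.

```python
import numpy as np


def draw_candidates(n_frames, n_draws, seed):
    """Draw distinct frame indices in [0, n_frames) reproducibly."""
    local_rng = np.random.default_rng(seed)
    # Sampling without replacement avoids inspecting the same frame twice.
    picks = local_rng.choice(n_frames, size=n_draws, replace=False)
    return sorted(int(i) for i in picks)


candidates = draw_candidates(n_frames=172, n_draws=5, seed=0)
print(candidates)
```

Each entry of `candidates` can then be passed as `n_skip` to the plotter, and the same seed reproduces the same list in a later session.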

In [None]:
n_skip_list = [18]

In [None]:
n_skip = n_skip_list[0]

Example for models with independent ground truths (GTs):

In [None]:
plotter = dlc.tools.evaluation.CountPredictionPlotter(sahara_data["ds"], sahara_data["models"], sahara_data["labels"])
fig = plotter.plot_unique_gt(n_skip, figsize=(8.5, 8), show_hist=True)

Example for models with a shared GT. Here we plot the segmentation mask and the thresholded density map.

In [None]:
plotter = dlc.tools.evaluation.DMCoverPredictionPlotter(sahara_data["ds"], sahara_data["models"], sahara_data["labels"])
# Note: the scalar is now the cover, not the count.
fig = plotter.plot_common_gt(n_skip, figsize=(10.2, 4), show_hist=True, scalar_format="{:.2E}")
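
To make the comparison above concrete, here is a minimal sketch of how a density map could be thresholded into a binary mask and summarized as a cover value. It assumes cover is the fraction of pixels above the threshold; the threshold, the toy array, and the helper names `density_to_mask` and `cover_fraction` are illustrative and are not taken from the `dlc` package.

```python
import numpy as np


def density_to_mask(density_map, threshold):
    """Binarize a density map: pixels above `threshold` count as covered."""
    return density_map > threshold


def cover_fraction(mask):
    """Fraction of pixels marked as covered (assumed meaning of 'cover')."""
    return float(np.mean(mask))


# Toy 4x4 density map; values and threshold are made up for illustration.
density = np.array([
    [0.0, 0.1, 0.0, 0.0],
    [0.2, 0.9, 0.3, 0.0],
    [0.0, 0.4, 0.8, 0.1],
    [0.0, 0.0, 0.2, 0.0],
])
mask = density_to_mask(density, threshold=0.25)
print(mask.astype(int))
# The count integrates the density; the cover only looks at the mask.
print(f"cover = {cover_fraction(mask):.2E}, count = {density.sum():.2f}")
```

The same density map therefore yields two different scalars, which is why the plotted value changes meaning when switching from the count plotter to the cover plotter.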

In [None]:
settings_path = f"counting-trees-private/config/tables/cover.yml"
models_path = pathlib.Path(f"{BASE_PATH}/models/archive_final")
cache = dlc.tools.cache.ArrayCache()
frames_path = pathlib.Path(f"./data/datasets/frames/sahara-sahel")
sahara_cover_data = dlc.tools.evaluation.load_models_from_settings(settings_path, models_path, frames_path,
                                                      label_cols=("model", "loss"), cache=cache,
                                                      include=(2, 3, 6, 7),
                                                      number_labels=False)

In [None]:
n_skip = rng.integers(0, 25, size=1)[0]
print(n_skip)
plotter = dlc.tools.evaluation.CoverPredictionPlotter(sahara_cover_data["ds"], sahara_cover_data["models"], sahara_cover_data["labels"])
fig = plotter.plot_unique_gt(n_skip, figsize=(6, 4), show_hist=False, include=(0,))

In [None]:
plotter = dlc.tools.evaluation.CoverPredictionPlotter(sahara_cover_data["ds"], sahara_cover_data["models"], sahara_cover_data["labels"])
fig = plotter.plot_common_gt(n_skip, figsize=(12, 4), show_hist=True, scalar_format="{:.2E}", log=True)

In [None]:
filename = f"figures/pred_cover_sahara_{n_skip}.png"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
settings_path = f"counting-trees-private/config/tables/multitask.yml"
models_path = pathlib.Path(f"{BASE_PATH}/models/archive_final")
cache = dlc.tools.cache.ArrayCache()
frames_path = pathlib.Path(f"./data/datasets/frames/sahara-sahel")
sahara_multi_data = dlc.tools.evaluation.load_models_from_settings(settings_path, models_path, frames_path,
                                                      label_cols=("task", "gt"), cache=cache,
                                                      include=(1,2,3),
                                                      number_labels=False)

In [None]:
n_skip = 18
plotter = dlc.tools.evaluation.CoverPredictionPlotter(sahara_multi_data["ds"], sahara_multi_data["models"], sahara_multi_data["labels"])
fig = plotter.plot_common_gt(n_skip, figsize=(12, 4), show_hist=True, scalar_format="{:.2E}", log=True)

In [None]:
filename = f"figures/pred_cover_multi_sahara_{n_skip}.png"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

### Sahel

In [None]:
include = (10, 13, 17, 23)

In [None]:
settings_path = f"counting-trees-private/config/tables/density-gt.yml"
models_path = pathlib.Path(f"{BASE_PATH}/models/archive_final")
cache = dlc.tools.cache.ArrayCache()
frames_path = pathlib.Path(f"./data/datasets/frames/sahara-sahel")
sahel_data = dlc.tools.evaluation.load_models_from_settings(settings_path, models_path, frames_path,
                                                      include=include,
                                                      label_cols=("model", "gt"), cache=cache,
                                                      number_labels=False)

In [None]:
settings_path = f"counting-trees-private/config/tables/multitask.yml"
models_path = pathlib.Path(f"{BASE_PATH}/models/archive_final")
cache = dlc.tools.cache.ArrayCache()
frames_path = pathlib.Path(f"./data/datasets/frames/sahara-sahel")
sahel_data = dlc.tools.evaluation.load_models_from_settings(settings_path, models_path, frames_path,
                                                      #include=(6,7,8,9),
                                                      include=(5,8,9),
                                                      label_cols=("task", "gt"), cache=cache,
                                                      number_labels=False)

In [None]:
n_skip = rng.integers(0, 70, size=1)[0]
print(n_skip)
plotter = dlc.tools.evaluation.CountPredictionPlotter(sahel_data["ds"], sahel_data["models"], sahel_data["labels"])
fig = plotter.plot_unique_gt(n_skip, figsize=(6, 4), show_hist=False, include=(0,))

In [None]:
# Candidate frames found during the search: 14, 15, 51

In [None]:
n_skip = 51

In [None]:
plotter = dlc.tools.evaluation.CountPredictionPlotter(sahel_data["ds"], sahel_data["models"], sahel_data["labels"],
                                                      mask=False, y_band_idx=(0, 0, 1, 1,))
fig = plotter.plot_unique_gt(n_skip, figsize=(12, 8), show_hist=True, log=True)

In [None]:
filename = f"figures/pred_multi_sahel_{n_skip}.png"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
plotter = dlc.tools.evaluation.CoverPredictionPlotter(sahel_data["ds"], sahel_data["models"], sahel_data["labels"])
fig = plotter.plot_common_gt(n_skip, figsize=(12, 4), show_hist=True, scalar_format="{:.2E}", log=False)

In [None]:
filename = f"figures/pred_multi_cover_sahel_{n_skip}.png"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
n_skip = 18

plotter = dlc.tools.evaluation.CountPredictionPlotter(sahel_data["ds"], sahel_data["models"], sahel_data["labels"],
                                                      mask=False)
fig = plotter.plot_unique_gt(n_skip, figsize=(12, 4), show_hist=False, log=False)

In [None]:
plotter = dlc.tools.evaluation.DMCoverPredictionPlotter(sahel_data["ds"], sahel_data["models"], sahel_data["labels"])
fig = plotter.plot_common_gt(n_skip, figsize=(12, 4), show_hist=True, scalar_format="{:.2E}")

In [None]:
settings_path = f"counting-trees-private/config/tables/density-250.yml"
models_path = pathlib.Path(f"{BASE_PATH}/models/archive_final")
cache = dlc.tools.cache.ArrayCache()
frames_path = pathlib.Path(f"./data/datasets/frames/sahara-sahel")
sahel_opt_data = dlc.tools.evaluation.load_models_from_settings(settings_path, models_path, frames_path,
                                                      label_cols=("opt",), cache=cache,
                                                      number_labels=False)

In [None]:
plotter = dlc.tools.evaluation.CountPredictionPlotter(sahel_opt_data["ds"], sahel_opt_data["models"], sahel_opt_data["labels"])
fig = plotter.plot_common_gt(n_skip, figsize=(12, 4), show_hist=True, log=True)

In [None]:
plotter = dlc.tools.evaluation.CountPredictionPlotter(sahel_opt_data["ds"], sahel_opt_data["models"], sahel_opt_data["labels"])
fig = plotter.plot_common_gt(n_skip, figsize=(12, 4), show_hist=False, log=True)

In [None]:
settings_path = f"counting-trees-private/config/tables/cover.yml"
models_path = pathlib.Path(f"{BASE_PATH}/models/archive_final")
cache = dlc.tools.cache.ArrayCache()
frames_path = pathlib.Path(f"./data/datasets/frames/sahara-sahel")
sahel_cover_data = dlc.tools.evaluation.load_models_from_settings(settings_path, models_path, frames_path,
                                                      label_cols=("model", "loss"), cache=cache,
                                                      include=(10, 11, 14, 15),
                                                      number_labels=False)

In [None]:
n_skip = rng.integers(0, 70, size=1)[0]
print(n_skip)
plotter = dlc.tools.evaluation.CoverPredictionPlotter(sahel_cover_data["ds"], sahel_cover_data["models"], sahel_cover_data["labels"])
fig = plotter.plot_unique_gt(n_skip, figsize=(6, 4), show_hist=False, include=(0,))

In [None]:
plotter = dlc.tools.evaluation.CoverPredictionPlotter(sahel_cover_data["ds"], sahel_cover_data["models"], sahel_cover_data["labels"])
fig = plotter.plot_common_gt(n_skip, figsize=(12, 4), show_hist=True, log=True)

In [None]:
filename = f"figures/pred_cover_sahel_{n_skip}.png"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

### Rwanda

In [None]:
settings_path = f"counting-trees-private/config/tables/cover.yml"
models_path = pathlib.Path(f"{BASE_PATH}/models/archive_final")
cache = dlc.tools.cache.ArrayCache()
frames_path = pathlib.Path(f"./data/datasets/frames/rwanda")
rwanda_cover_data = dlc.tools.evaluation.load_models_from_settings(settings_path, models_path, frames_path,
                                                                   include=(18, 19, 22,  23),
                                                                   label_cols=("model", "gt"), cache=cache,
                                                                    number_labels=True)

In [None]:
settings_path = f"counting-trees-private/config/tables/density-gt.yml"
models_path = pathlib.Path(f"{BASE_PATH}/models/archive_final")
cache = dlc.tools.cache.ArrayCache()
frames_path = pathlib.Path(f"./data/datasets/frames/rwanda")
rwanda_density_data = dlc.tools.evaluation.load_models_from_settings(settings_path, models_path, frames_path,
                                                                     include=(26, 28, 32, 33),
                                                                     label_cols=("model", "gt"), cache=cache,
                                                                      number_labels=False)

In [None]:
settings_path = f"counting-trees-private/config/tables/multitask.yml"
models_path = pathlib.Path(f"{BASE_PATH}/models/archive_final")
cache = dlc.tools.cache.ArrayCache()
frames_path = pathlib.Path(f"./data/datasets/frames/rwanda")
rwanda_density_data = dlc.tools.evaluation.load_models_from_settings(settings_path, models_path, frames_path,
                                                                     #include=(11,12,13,14),
                                                                     include=(10,13,14),
                                                                     label_cols=("task", "gt"), cache=cache,
                                                                      number_labels=False)

In [None]:
n_skip = rng.integers(0, 506, size=1)[0]
print(n_skip)
plotter = dlc.tools.evaluation.CountPredictionPlotter(rwanda_density_data["ds"], rwanda_density_data["models"], rwanda_density_data["labels"])
fig = plotter.plot_unique_gt(n_skip, figsize=(6, 4), show_hist=False, include=(0,))

In [None]:
filename = f"figures/pred_multi_rwanda_{n_skip}.png"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
n_skip = 194

In [None]:
plotter = dlc.tools.evaluation.CountPredictionPlotter(rwanda_density_data["ds"], rwanda_density_data["models"], rwanda_density_data["labels"],
                                                      y_band_idx=(0, 0, 1, 1))
fig = plotter.plot_unique_gt(n_skip, figsize=(12, 8), show_hist=True, log=True)

In [None]:
filename = f"figures/pred_multi_density_rwanda_{n_skip}.png"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
plotter = dlc.tools.evaluation.CoverPredictionPlotter(rwanda_density_data["ds"], rwanda_density_data["models"], rwanda_density_data["labels"])
fig = plotter.plot_common_gt(n_skip, figsize=(14, 4), show_hist=True)

In [None]:
filename = f"figures/pred_multi_cover_rwanda_{n_skip}.png"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)

In [None]:
plotter = dlc.tools.evaluation.CountPredictionPlotter(sahel_opt_data["ds"], sahel_opt_data["models"], sahel_opt_data["labels"])
fig = plotter.plot_common_gt(n_skip, figsize=(12, 4), show_hist=False, log=True)

In [None]:
settings_path = f"counting-trees-private/config/tables/cover.yml"
models_path = pathlib.Path(f"{BASE_PATH}/models/archive_final")
cache = dlc.tools.cache.ArrayCache()
frames_path = pathlib.Path(f"./data/datasets/frames/rwanda")
rwanda_cover_data = dlc.tools.evaluation.load_models_from_settings(settings_path, models_path, frames_path,
                                                                   include=(18, 19, 22,  23),
                                                                   label_cols=("model", "loss"), cache=cache,
                                                                    number_labels=False)

In [None]:
n_skip = rng.integers(0, 506, size=1)[0]
print(n_skip)
plotter = dlc.tools.evaluation.CoverPredictionPlotter(rwanda_cover_data["ds"], rwanda_cover_data["models"], rwanda_cover_data["labels"])
fig = plotter.plot_unique_gt(n_skip, figsize=(6, 4), show_hist=False, include=(0,))

In [None]:
plotter = dlc.tools.evaluation.CoverPredictionPlotter(rwanda_cover_data["ds"], rwanda_cover_data["models"], rwanda_cover_data["labels"])
fig = plotter.plot_common_gt(n_skip, figsize=(12, 4), show_hist=True, log=True)

In [None]:
filename = f"figures/pred_cover_rwanda_{n_skip}.png"
fig.savefig(filename, bbox_inches="tight")
files.download(filename)