# Tree crown detection using DeepForest

{opticon}`tag`
{bdg-primary}`Forest`
{bdg-secondary}`Modelling`
{bdg-warning}`Standard`
{bdg-info}`Python`

<p align="left">
    <a href="https://github.com/eds-book-gallery/15d986da-2d7c-44fb-af71-700494485def/blob/main/LICENSE">
        <img alt="license" src="https://img.shields.io/badge/license-MIT-yellow.svg">
    </a>
    <a href="https://notebooks.gesis.org/binder/v2/gh/eds-book-gallery/15d986da-2d7c-44fb-af71-700494485def/main?labpath=notebook.ipynb">
        <img alt="binder" src="https://mybinder.org/badge_logo.svg">
    </a>
    <a href="https://github.com/eds-book-gallery/15d986da-2d7c-44fb-af71-700494485def/actions/workflows/render.yaml">
        <img alt="render" src="https://github.com/eds-book-gallery/15d986da-2d7c-44fb-af71-700494485def/actions/workflows/render.yaml/badge.svg">
    </a>
    <a href="https://github.com/alan-turing-institute/environmental-ds-book/pull/5">
        <img alt="review" src="https://img.shields.io/badge/view-review-purple">
    </a>
    <br/>
</p>

<p align="left">
    <a href="https://w3id.org/ro-id/15d986da-2d7c-44fb-af71-700494485def">
        <img alt="RoHub" src="https://img.shields.io/badge/RoHub-FAIR_Executable_Research_Object-2ea44f?logo=Open+Access&logoColor=blue">
    </a>
    <a href="https://doi.org/10.24424/td9g-0533">
        <img alt="doi" src="https://zenodo.org/badge/DOI/10.24424/td9g-0533.svg">
    </a>
</p>


## Context
### Purpose
Detect tree crowns using a state-of-the-art deep learning model for object detection.

### Modelling approach
A prebuilt deep learning model, named *DeepForest*, is used to predict individual tree crowns from an airborne RGB image. *DeepForest* was trained on data from the National Ecological Observatory Network (NEON). *DeepForest* was implemented in Python 3.7, initially using TensorFlow v1.14, and later moved to PyTorch. Further details can be found in the [package documentation](https://deepforest.readthedocs.io/en/latest/).

### Highlights
* Fetch a NEON sample image from a Zenodo repository.
* Retrieve and plot the reference annotations (bounding boxes) for the target image.
* Load and use a pretrained *DeepForest* model to generate full-image or tile-wise predictions.
* Discuss the pros and cons of full-image and tile-wise prediction.

### Contributions

#### Notebook
* Alejandro Coca-Castro (author), The Alan Turing Institute, [@acocac](https://github.com/acocac)
* Matt Allen (reviewer), Department of Geography - University of Cambridge, [@mja2106](https://github.com/mja2106)

#### Modelling codebase
* Ben Weinstein (maintainer & developer), University of Florida, [@bw4sz](https://github.com/bw4sz)
* Henry Senyondo (support maintainer), University of Florida, [@henrykironde](https://github.com/henrykironde)
* Ethan White (PI and author), University of Florida, [@ethanwhite](https://github.com/ethanwhite)
* Other contributors are listed in the [GitHub repo](https://github.com/weecology/DeepForest/graphs/contributors)

#### Modelling publications
```{bibliography}
  :style: plain
  :list: bullet
  :filter: topic % "15d986da-2d7c-44fb-af71-700494485def"
```

:::{note}
The author acknowledges [DeepForest](https://deepforest.readthedocs.io/en/latest/) contributors. Some code snippets were extracted from DeepForest [GitHub public repository](https://github.com/weecology/DeepForest).
:::

## Install and load libraries

In [None]:
!pip -q install torchvision==0.10.0
!pip -q install torch==1.9.0
!pip -q install DeepForest==1.0.0
!pip -q install geoviews

In [None]:
import glob
import os
import urllib.request
import numpy as np

import intake
import matplotlib.pyplot as plt
import xmltodict
import cv2

import torch

from shapely.geometry import box
import pandas as pd
from geopandas import GeoDataFrame
import xarray as xr
import panel as pn
import holoviews as hv
import hvplot.pandas
import hvplot.xarray
from skimage.exposure import equalize_hist

import pooch

import warnings
warnings.filterwarnings(action='ignore')

hv.extension('bokeh', width=100)
%matplotlib inline

## Set project structure

In [None]:
notebook_folder = './notebook'
if not os.path.exists(notebook_folder):
    os.makedirs(notebook_folder)

## Fetch an RGB image from Zenodo

Fetch a sample image from a publicly accessible location.

In [None]:
pooch.retrieve(
    url="doi:10.5281/zenodo.3459803/2018_MLBS_3_541000_4140000_image_crop.tif",
    known_hash="md5:01a7cf23b368ff9e006fda8fe9ca4c8c",
    path=notebook_folder,
    fname="2018_MLBS_3_541000_4140000_image_crop.tif"
)
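`pooch.retrieve` only accepts the download if the file's checksum matches `known_hash`. A minimal sketch of an equivalent MD5 check using only the standard library, for readers who want to verify a file by hand; the helper name `md5_of_file` is hypothetical and not part of `pooch`, and the demonstration runs on a temporary file rather than the GeoTIFF.

```python
import hashlib
import os
import tempfile

def md5_of_file(path, chunk_size=1 << 20):
    """Return the hexadecimal MD5 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()

# Self-contained demonstration on a temporary file
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, 'demo.bin')
    with open(demo_path, 'wb') as f:
        f.write(b'DeepForest')
    demo_hash = md5_of_file(demo_path)

print(demo_hash == hashlib.md5(b'DeepForest').hexdigest())  # True
```

Applied to the downloaded image, `'md5:' + md5_of_file(path)` should equal the `known_hash` string passed to `pooch.retrieve` above.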

In [None]:
# set catalogue location
catalog_file = os.path.join(notebook_folder, 'catalog.yaml')

with open(catalog_file, 'w') as f:
    f.write('''
sources:
  NEONTREE_rgb:
    driver: xarray_image
    description: 'NeonTreeEvaluation RGB images (collection)'
    args:
      urlpath: "{{ CATALOG_DIR }}/2018_MLBS_3_541000_4140000_image_crop.tif"
      ''')

Load an intake catalog for the downloaded data.

In [None]:
cat_tc = intake.open_catalog(catalog_file)

## Load sample image

Here we use `intake` to load the image through `dask`.

In [None]:
tc_rgb = cat_tc["NEONTREE_rgb"].to_dask()
tc_rgb

## Load and prepare labels

In [None]:
# functions to load xml and extract bounding boxes

# function to download an .xml annotation file and parse it into an ordered dictionary
def loadxml(imagename):
    imagename = imagename.replace('.tif','')
    fullurl = "https://raw.githubusercontent.com/weecology/NeonTreeEvaluation/master/annotations/" + imagename + ".xml"
    file = urllib.request.urlopen(fullurl)
    data = file.read()
    file.close()
    data = xmltodict.parse(data)
    return data

# function to extract bounding boxes
def extractbb(i):
    bb = [f['bndbox'] for f in allxml[i]['annotation']['object']]
    return bb
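The annotation files follow the `annotation/object/bndbox` layout that `extractbb` indexes into. One caveat with `xmltodict` is that a single `<object>` element is parsed as a dict rather than a list (its `force_list` argument addresses this). A minimal alternative sketch using the standard-library `xml.etree.ElementTree`, which always yields a list and converts the coordinates to integers; the helper name `parse_bboxes` and the sample XML are illustrative assumptions, not taken from the dataset.

```python
import xml.etree.ElementTree as ET

def parse_bboxes(xml_text):
    """Return a list of dicts with integer xmin, ymin, xmax, ymax for every <object>."""
    root = ET.fromstring(xml_text.strip())
    boxes = []
    for obj in root.iter('object'):
        bndbox = obj.find('bndbox')
        boxes.append({key: int(float(bndbox.find(key).text))
                      for key in ('xmin', 'ymin', 'xmax', 'ymax')})
    return boxes

# Illustrative annotation with a single object (hypothetical values)
sample_xml = """
<annotation>
  <object>
    <name>Tree</name>
    <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>50</xmax><ymax>60</ymax></bndbox>
  </object>
</annotation>
"""

print(parse_bboxes(sample_xml))
# [{'xmin': 10, 'ymin': 20, 'xmax': 50, 'ymax': 60}]
```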

In [None]:
filenames = glob.glob(os.path.join(notebook_folder, '*.tif'))
filesn = [os.path.basename(i) for i in filenames]

allxml = [loadxml(i) for i in filesn]
bball = [extractbb(i) for i in range(0,len(allxml))]
print(len(bball))

## Visualise image and labels

In [None]:
# function to plot images
def cv2_imshow(a, **kwargs):
    a = a.clip(0, 255).astype('uint8')
    # cv2 stores colors as BGR; convert to RGB
    if a.ndim == 3:
        if a.shape[2] == 4:
            a = cv2.cvtColor(a, cv2.COLOR_BGRA2RGBA)
        else:
            a = cv2.cvtColor(a, cv2.COLOR_BGR2RGB)

    return plt.imshow(a, **kwargs)
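`cv2_imshow` assumes BGR input, which is why the plotting cells below call it on `np.flip(image2, 2)`: flipping the channel axis of an RGB array yields BGR, and `cv2_imshow` converts it back. A minimal NumPy check that reversing the last axis swaps the red and blue channels, and that doing it twice is a no-op; the 1×1 test pixel is illustrative.

```python
import numpy as np

# A single RGB pixel: pure red
rgb = np.array([[[255, 0, 0]]], dtype=np.uint8)

bgr = np.flip(rgb, 2)       # same as rgb[..., ::-1]
print(bgr[0, 0].tolist())   # [0, 0, 255]: red now sits in the last (BGR) slot

round_trip = np.flip(bgr, 2)
print(np.array_equal(round_trip, rgb))  # True: two flips cancel out
```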

In [None]:
image = tc_rgb

# copy the image so that the reference bounding boxes can be drawn on it
image2 = image.values.copy()
target_bbox = bball[0]
print(type(target_bbox))
print(target_bbox[0:2])

In [None]:
for row in target_bbox:
    cv2.rectangle(image2, (int(row["xmin"]), int(row["ymin"])), (int(row["xmax"]), int(row["ymax"])), (0,255,255), thickness=2, lineType=cv2.LINE_AA)

plot_reference = plt.figure(figsize=(15,15))
cv2_imshow(np.flip(image2,2))
plt.title('Reference labels',fontsize='xx-large')
plt.show()

## Load *DeepForest* pretrained model

Now we're going to load and use a pretrained model from the `deepforest` package.

In [None]:
from deepforest import main

# load deep forest model
model = main.deepforest()
model.use_release()
model.current_device = torch.device("cpu")

In [None]:
pred_boxes = model.predict_image(image=image.values)
print(pred_boxes.head(5))
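Each row of `pred_boxes` is one predicted crown, with pixel coordinates (`xmin`, `ymin`, `xmax`, `ymax`) and a confidence `score`. A common post-processing step is to keep only confident boxes. A minimal pandas sketch on a hand-made DataFrame using those column names; the 0.3 threshold and all values are illustrative assumptions, not recommended settings.

```python
import pandas as pd

# Hand-made stand-in for the output of model.predict_image (values are illustrative)
pred = pd.DataFrame({
    'xmin': [10, 100, 200], 'ymin': [15, 120, 210],
    'xmax': [40, 150, 230], 'ymax': [45, 170, 260],
    'score': [0.85, 0.20, 0.55],
})

score_threshold = 0.3  # hypothetical cut-off
confident = pred[pred['score'] >= score_threshold].reset_index(drop=True)

# Box width and height in pixels, useful for spotting implausibly large or small crowns
confident['width'] = confident['xmax'] - confident['xmin']
confident['height'] = confident['ymax'] - confident['ymin']
print(confident[['score', 'width', 'height']])
```

In the notebook, the same filter would read `pred_boxes[pred_boxes['score'] >= score_threshold]`.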

In [None]:
image3 = image.values.copy() 

for index, row in pred_boxes.iterrows():
    cv2.rectangle(image3, (int(row["xmin"]), int(row["ymin"])), (int(row["xmax"]), int(row["ymax"])), (0,255,255), thickness=2, lineType=cv2.LINE_AA)

plot_fullimage = plt.figure(figsize=(15,15))
cv2_imshow(np.flip(image3,2))
plt.title('Full-image predictions',fontsize='xx-large')
plt.show()

## Comparison of full-image predictions and reference labels

Let's compare the labels and predictions over the tested image.

In [None]:
plot_referandfullimage = plt.figure(figsize=(15,15))
ax1 = plt.subplot(1, 2, 1)
cv2_imshow(np.flip(image2,2))
ax1.set_title('Reference labels',fontsize='xx-large')
ax2 = plt.subplot(1, 2, 2)
cv2_imshow(np.flip(image3,2))
ax2.set_title('Full-image predictions', fontsize='xx-large')
plt.show() # To show figure

**Interpretation:**

*   The pretrained model does not appear to perform well on the tested image.
*   The poor performance might be explained by the fact that the pretrained model was trained on 10 cm resolution images.
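The comparison above is visual. One way to quantify it is to match predicted and reference boxes by intersection over union, IoU = area(A ∩ B) / area(A ∪ B), and count matches above a threshold. The sketch below uses greedy one-to-one matching at IoU ≥ 0.5; the threshold, the matching strategy, the helper names and the toy boxes are assumptions for illustration, not the NeonTreeEvaluation benchmark's official protocol.

```python
def iou(a, b):
    """IoU of two boxes given as (xmin, ymin, xmax, ymax) in pixels."""
    inter_w = max(0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

def precision_recall(predictions, references, threshold=0.5):
    """Greedily match each prediction to its best still-unmatched reference box."""
    matched = set()
    true_positives = 0
    for p in predictions:
        best_idx, best_iou = None, threshold
        for idx, r in enumerate(references):
            if idx in matched:
                continue
            overlap = iou(p, r)
            if overlap >= best_iou:
                best_idx, best_iou = idx, overlap
        if best_idx is not None:
            matched.add(best_idx)
            true_positives += 1
    precision = true_positives / len(predictions) if predictions else 0.0
    recall = true_positives / len(references) if references else 0.0
    return precision, recall

# Toy example: two reference crowns, three predictions (one spurious)
references = [(0, 0, 10, 10), (20, 20, 30, 30)]
predictions = [(1, 1, 10, 10), (20, 20, 30, 31), (50, 50, 60, 60)]
print(iou(references[0], predictions[0]))         # 0.81
print(precision_recall(predictions, references))  # (0.6666666666666666, 1.0)
```

With the notebook's data, the inputs could be tuples built from the `xmin`, `ymin`, `xmax`, `ymax` columns of `pred_boxes`, and from `target_bbox` after converting its string coordinates to integers.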

## Tile-based prediction

To improve the predictions, DeepForest can be run [tile-wise](https://deepforest.readthedocs.io/en/latest/better.html).

The following cells show how to choose a suitable window (i.e. tile) size.

In [None]:
from deepforest import preprocess

#Create windows of 400px
windows = preprocess.compute_windows(image.values, patch_size=400,patch_overlap=0)
print(f'We have {len(windows)} windows in the image')
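`compute_windows` splits the image into crops of `patch_size` pixels, with neighbouring crops overlapping by the fraction `patch_overlap`. A minimal sketch of how such a grid of window origins can be computed; the helpers `axis_starts` and `window_origins` are hypothetical, simplified stand-ins for illustration, and DeepForest's own implementation may place edge windows differently.

```python
def axis_starts(length, patch_size, patch_overlap):
    """Start offsets along one axis so that windows of patch_size cover [0, length)."""
    if length <= patch_size:
        return [0]
    stride = max(1, int(patch_size * (1 - patch_overlap)))
    starts = list(range(0, length - patch_size + 1, stride))
    # Add one window flush with the far edge if the regular grid stops short of it
    if starts[-1] + patch_size < length:
        starts.append(length - patch_size)
    return starts

def window_origins(height, width, patch_size=400, patch_overlap=0.0):
    """Return (row, col) top-left corners of all windows covering an image."""
    return [(r, c)
            for r in axis_starts(height, patch_size, patch_overlap)
            for c in axis_starts(width, patch_size, patch_overlap)]

print(axis_starts(1000, 400, 0.0))      # [0, 400, 600]
print(len(window_origins(1000, 1000)))  # 9
print(axis_starts(1000, 400, 0.25))     # [0, 300, 600]
```

In this sketch, when the image size is not a multiple of `patch_size`, the last window in each direction overlaps its neighbour even with `patch_overlap=0`.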

In [None]:
#Loop through a few sample windows, crop and predict
plot_tilewindows, axes, = plt.subplots(nrows=2,ncols=2, figsize=(15,15))
axes = axes.flatten()
for index2 in range(4):
    crop = image.values[windows[index2].indices()]
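    # the two channel flips in np.flip(crop[...,::-1],2) cancel out, so the crop is passed in its original channel order;
    # return_plot=True returns the crop with the predicted boxes drawn on it
    boxes = model.predict_image(image=crop, return_plot = True)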

    #but plot in rgb channel order
    axes[index2].imshow(boxes[...,::-1])
    axes[index2].set_title(f'Prediction in Window {index2 + 1} out of {len(windows)}', fontsize='xx-large')
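Each crop in the loop above is predicted in its own pixel coordinates, so combining tile-wise predictions on the full image requires shifting every box by its window's top-left corner. A minimal pandas sketch of that step for one window; the function name `to_image_coords` and the example offsets are hypothetical, and this is not presented as DeepForest's internal code.

```python
import pandas as pd

def to_image_coords(crop_boxes, row_offset, col_offset):
    """Shift boxes predicted on a crop into full-image pixel coordinates."""
    shifted = crop_boxes.copy()
    shifted[['xmin', 'xmax']] += col_offset  # x runs along columns
    shifted[['ymin', 'ymax']] += row_offset  # y runs along rows
    return shifted

# A box predicted near the top-left of a window whose origin is row 400, column 800
crop_boxes = pd.DataFrame({'xmin': [5], 'ymin': [10], 'xmax': [45], 'ymax': [60], 'score': [0.9]})
print(to_image_coords(crop_boxes, row_offset=400, col_offset=800))
```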

Once a suitable tile size is defined, we can predict over all tiles in one call using the `predict_tile` function:

In [None]:
tile = model.predict_tile(image=image.values,return_plot=False,patch_overlap=0,iou_threshold=0.05,patch_size=400)

# plot predicted bbox
image_tile = image.values.copy()

for index, row in tile.iterrows():
    cv2.rectangle(image_tile, (int(row["xmin"]), int(row["ymin"])), (int(row["xmax"]), int(row["ymax"])), (0, 255, 255), thickness=2, lineType=cv2.LINE_AA)

plot_tilewise = plt.figure(figsize=(15,15))
ax1 = plt.subplot(1, 2, 1)
cv2_imshow(np.flip(image2,2))
ax1.set_title('Reference labels',fontsize='xx-large')
ax2 = plt.subplot(1, 2, 2)
cv2_imshow(np.flip(image_tile,2))
ax2.set_title('Tile-wise predictions', fontsize='xx-large')
plt.show() # To show figure

**Interpretation**

* The tile-based prediction provides more reasonable results than predicting over the whole image.
* While the predictions look closer to the reference labels, there seem to be some artefacts at tile edges. This requires further investigation, e.g. inspecting the `deepforest` tile-wise prediction function to understand how predictions from different tiles are combined after the model has made them.
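When windows overlap, the same crown can be detected in more than one tile, and duplicates are typically removed with non-maximum suppression (NMS): keep the highest-scoring box and discard any remaining box whose IoU with it exceeds a threshold. The `iou_threshold` passed to `predict_tile` above plays a comparable role for predictions from different windows. A minimal NumPy sketch of generic greedy NMS for illustration; it is not DeepForest's implementation, and the boxes and threshold are made up. Note that with `patch_overlap=0`, a crown cut by a tile edge yields two partial boxes that barely overlap, which NMS cannot merge; this is one plausible source of the edge artefacts noted above.

```python
import numpy as np

def nms(boxes, scores, iou_threshold=0.5):
    """Greedy non-maximum suppression.

    boxes: (N, 4) array of (xmin, ymin, xmax, ymax); scores: (N,) array.
    Returns the indices of the kept boxes, highest score first.
    """
    boxes = np.asarray(boxes, dtype=float)
    scores = np.asarray(scores, dtype=float)
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    order = np.argsort(scores)[::-1]  # indices sorted by descending score
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        # Intersection of the best box with every remaining box
        xx1 = np.maximum(boxes[best, 0], boxes[rest, 0])
        yy1 = np.maximum(boxes[best, 1], boxes[rest, 1])
        xx2 = np.minimum(boxes[best, 2], boxes[rest, 2])
        yy2 = np.minimum(boxes[best, 3], boxes[rest, 3])
        inter = np.clip(xx2 - xx1, 0, None) * np.clip(yy2 - yy1, 0, None)
        iou = inter / (areas[best] + areas[rest] - inter)
        # Keep only boxes that do not overlap the best box too much
        order = rest[iou <= iou_threshold]
    return keep

# Two near-duplicate detections of one crown plus a separate crown
boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (30, 30, 40, 40)]
scores = [0.9, 0.8, 0.7]
print(nms(boxes, scores, iou_threshold=0.5))  # [0, 2]
```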

## Interactive plots

The plot below summarises the static plots above by interactively comparing the bounding boxes and scores of the full-image and tile-wise predictions. To zoom in on the reference NEON RGB image at its original resolution, change `rasterize=True` to `rasterize=False`.

In [None]:
## function to convert a DataFrame of bounding boxes into a GeoDataFrame of box polygons
def bbox_to_geopandas(bbox_df):
    geometry = [box(x1, y1, x2, y2) for x1,y1,x2,y2 in zip(bbox_df.xmin, bbox_df.ymin, bbox_df.xmax, bbox_df.ymax)]
    poly_geo = GeoDataFrame(bbox_df, geometry=geometry)
    return poly_geo

## prepare reference and prediction bbox
### convert data types for reference bbox dictionary
reference = pd.DataFrame.from_dict(target_bbox, dtype=int)
reference[['xmin', 'ymin', 'xmax', 'ymax']] = reference[['xmin', 'ymin', 'xmax', 'ymax']].astype(int)

poly_reference = bbox_to_geopandas(reference)
poly_prediction_image = bbox_to_geopandas(pred_boxes)
poly_prediction_tile = bbox_to_geopandas(tile)

## settings for hvplot objects
settings_vector = dict(fill_color=None, width=400, height=400, clim=(0,1), fontsize={'title': '110%'})
settings_image = dict(x='x', y='y', data_aspect=1, xaxis=False, yaxis=None)

## create hvplot objects
plot_RGB = tc_rgb.hvplot.rgb(**settings_image, bands='channel', hover=False, rasterize=True)
plot_vector_reference = poly_reference.hvplot(hover_cols=False, legend=False).opts(title='Reference labels', alpha=1, **settings_vector)
plot_vector_image = poly_prediction_image.hvplot(hover_cols=['score'], legend=False).opts(title='Full-image predictions', alpha=0.5, **settings_vector)
plot_vector_tile = poly_prediction_tile.hvplot(hover_cols=['score'], legend=False).opts(title='Tile-wise predictions', alpha=0.5, **settings_vector)

plot_comparison = pn.Row(pn.Column(plot_RGB * plot_vector_reference, 
                         plot_RGB * plot_vector_image),
                         pn.Column(pn.Spacer(background='white', width=400, height=400),  
                         plot_RGB * plot_vector_tile), scroll=True)

plot_comparison.embed()

## Summary

This notebook has demonstrated the use of:

* `pooch` and `intake` packages to fetch and load data from a Zenodo repository containing training data files of the [NeonTreeEvaluation Benchmark](https://zenodo.org/record/3459803#.YhI54xPP30o).
* `deepforest` package to easily load and run a pretrained model for tree crown detection from very high resolution RGB imagery.
* The tile-wise option in `deepforest`, which considerably improves tree crown predictions, although the user needs to define a suitable tile size.
* `cv2` to generate static plots comparing the reference labels against the bounding boxes of two prediction strategies, full-image and tile-wise predictions.
* `hvplot` and `panel` to interactively compare both prediction strategies against reference labels.

## Citing this Notebook

Alejandro Coca-Castro, and Matt Allen. "Tree crown detection using DeepForest (Jupyter Notebook) published in the Environmental Data Science book." ROHub. Feb 20, 2022. https://doi.org/10.24424/td9g-0533.

## Additional information
**License**: The code in this notebook is licensed under the MIT License. The Environmental Data Science book is licensed under the Creative Commons by Attribution 4.0 license. See further details [here](https://github.com/alan-turing-institute/environmental-ds-book/blob/master/LICENSE.md).

**Contact**: If you have any suggestions or want to report an issue with this notebook, feel free to [create an issue](https://github.com/alan-turing-institute/environmental-ds-book/issues/new/choose) or send a direct message to [environmental.ds.book@gmail.com](mailto:environmental.ds.book@gmail.com).

In [None]:
from datetime import date
print(f'Last tested: {date.today()}')

## Outputs registration
The cell below saves the notebook outputs so that they can be registered in a Zenodo repository curated by the Environmental DS book.

In [None]:
outputs = {
    'static_tables': {
        'filenames': ['bbox_reference','bbox_fullimage','bbox_tilewise'],
        'data':[poly_reference, poly_prediction_image, poly_prediction_tile]},
    'static_figures': {
        'filenames': ['reference','fullimage_prediction','fullimage_comparison','tilewise_prediction','tilewise_comparison'],
        'data':[plot_reference, plot_fullimage, plot_referandfullimage, plot_tilewindows, plot_tilewise]},
    'interactive_figures': {
        'filenames': ['comparison_interactive'],
        'data':[plot_comparison]},
}

# save static tables as CSV files
for filename, data in zip(outputs['static_tables']['filenames'], outputs['static_tables']['data']):
    data.to_csv(os.path.join(notebook_folder, filename + '.csv'))

# save static figures as PNG files
for filename, data in zip(outputs['static_figures']['filenames'], outputs['static_figures']['data']):
    data.savefig(os.path.join(notebook_folder, filename + '.png'))

# save interactive figures as HTML files
for filename, data in zip(outputs['interactive_figures']['filenames'], outputs['interactive_figures']['data']):
    data.save(os.path.join(notebook_folder, filename + '.html'))