In [None]:
# Copyright (c) 2023 William Locke

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

This notebook is intended to be run in Google Colab with access to the corresponding Google Drive files. If running locally or on another service, change the Drive-mount, unzip, and install cells accordingly.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lu-liang-geo/UAV_Tree_Detection/blob/main/notebooks/Tree_Detection_and_Segmentation.ipynb)

In [None]:
# Mount Google Drive at /content/drive; the unzip cell below reads its
# archives from /content/drive/MyDrive/UAV/Data.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%%capture
!unzip '/content/drive/MyDrive/UAV/Data/NEONTreeEvaluation/training.zip' -d "/content/training"
!unzip '/content/drive/MyDrive/UAV/Data/NEONTreeEvaluation/annotations.zip' -d "/content"
!unzip '/content/drive/MyDrive/UAV/Data/example_mosaic/SA7_RGB_Multi.zip'

In [None]:
%%capture
!pip install rasterio
!pip install supervision

In [None]:
#@title Copy GroundingDINO from IDEA-Research github repository
%%capture

%cd /content
import os
if not os.path.exists('/content/weights'):
  !mkdir /content/weights
!git clone https://github.com/IDEA-Research/GroundingDINO.git
%cd /content/GroundingDINO
!pip install -q .
%cd /content/weights
!wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth
%cd /content

In [None]:
#@title Copy SAM from personal github repository
%%capture

%cd /content
import os
if os.path.exists('/content/segment-anything'):
  !rm -r /content/segment-anything
!git clone https://github.com/lu-liang-geo/UAV_Tree_Detection.git
%cd /content/segment-anything
!pip install -q .
%cd /content/weights
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
%cd /content

In [None]:
import locale


def getpreferredencoding(do_setlocale=True):
    """Report "UTF-8" as the preferred encoding, whatever the system locale is.

    `do_setlocale` is accepted only so the signature matches
    `locale.getpreferredencoding`; it is ignored.
    """
    return "UTF-8"


# Replace the stdlib function for the rest of the session.
# NOTE(review): presumably a workaround for Colab shell commands that require a
# UTF-8 locale — confirm it is still needed.
setattr(locale, "getpreferredencoding", getpreferredencoding)

In [None]:
import os
import cv2
import glob
import torch
import rasterio
import numpy as np
from PIL import Image
import supervision as sv
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from segment_and_detect_anything.detr import box_ops
from GroundingDINO.groundingdino.util.inference import Model
from segment_and_detect_anything import NEONTreeDataset, sam_model_registry, SamPredictor

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Load GroundingDINO Model
%%capture
GROUNDING_DINO_CONFIG_PATH = "/content/GroundingDINO/groundingdino/config/GroundingDINO_SwinB_cfg.py"
GROUNDING_DINO_CHECKPOINT_PATH = "/content/weights/groundingdino_swinb_cogcoor.pth"
gd_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)

In [None]:
# Load SAM Model
sam_model = sam_model_registry["vit_h"](checkpoint="/content/weights/sam_vit_h_4b8939.pth")
# Move the model to the device selected earlier; before this, `device` was
# computed but never used, leaving the ViT-H model on the CPU.
sam_model.to(device=device)
sam_predictor = SamPredictor(sam_model)

# Example 1: Clearly Separated Trees

In [None]:
# Load Image
rgb_path = "/content/SA7_RGB_Multi_transparent_mosaic_group1_2_2.tif"
# The handle is named `src` so it does not share the name `img` with the
# dataset sample dict used in Example 2.
with rasterio.open(rgb_path) as src:
  # Drop the last band and move the band axis last.
  # NOTE(review): presumably the dropped band is alpha and the remaining three
  # are R, G, B — confirm against the mosaic.
  rgb_img = np.ascontiguousarray(src.read()[:-1].transpose(1,2,0))
# Copy so the BGR image is a contiguous array rather than a negative-stride
# view, matching how Example 2 builds its BGR image.
bgr_img = rgb_img[:,:,::-1].copy()

# Show image
plt.figure(figsize=(10,10))
plt.axis('off')
plt.imshow(rgb_img)
plt.show()

## Detect with GroundingDINO

In [None]:
# Detect Trees
# A single text prompt; the same cut-off is used for the box and text thresholds.
threshold = 0.2
classes = ['tree']

gd_boxes_raw = gd_model.predict_with_classes(
    image=bgr_img, classes=classes, box_threshold=threshold, text_threshold=threshold
)

# Post-process the raw detections with the project's custom NMS helper.
gd_boxes = box_ops.custom_nms(gd_boxes_raw)

In [None]:
# Draw the filtered GroundingDINO boxes (no labels) on a copy of the BGR image.
box_annotator = sv.BoxAnnotator(thickness=10, color=sv.Color.red())
gd_plot = box_annotator.annotate(scene=bgr_img.copy(), detections=gd_boxes, skip_label=True)

# The annotated image is BGR; reverse the channel axis for matplotlib.
fig, ax = plt.subplots(figsize=(10,10))
ax.imshow(gd_plot[:,:,::-1])
ax.set_axis_off()
plt.show()

## Segment with SAM

In [None]:
# Register the Example 1 RGB image with the SAM predictor; the box prompts in
# the following cells are evaluated against this image.
sam_predictor.set_image(rgb_img)

In [None]:
def segment(sam_predictor: "SamPredictor", boxes: np.ndarray, multimask_output: bool = False) -> np.ndarray:
    """Segment one object per box prompt with SAM.

    Args:
        sam_predictor: Predictor on which `set_image` has already been called.
        boxes: Iterable of box prompts, one per object (e.g. `Detections.xyxy`).
        multimask_output: Forwarded to `sam_predictor.predict`. When True the
            predictor may return several candidate masks per box and the one
            with the highest score is kept. The default (False) matches the
            previous hard-coded behaviour.

    Returns:
        Array holding one mask per box, stacked along the first axis. For
        empty `boxes` a boolean array with zero masks is returned, so the
        result is always 3-D.
    """
    result_masks = []
    for box in boxes:
        masks, scores, _ = sam_predictor.predict(
            box=box,
            multimask_output=multimask_output
        )
        # Keep the candidate the predictor scored highest.
        result_masks.append(masks[np.argmax(scores)])

    if not result_masks:
        # assumes `original_size` is (height, width) of the current image, as
        # in upstream SAM — TODO confirm for this fork; falls back to (0, 0).
        height, width = getattr(sam_predictor, "original_size", None) or (0, 0)
        return np.zeros((0, height, width), dtype=bool)
    return np.stack(result_masks)

In [None]:
# Prompt SAM with every GroundingDINO box and attach the resulting masks to
# the detections object.
gd_boxes.mask = segment(sam_predictor, gd_boxes.xyxy)

In [None]:
# Overlay the SAM masks on a copy of the box-annotated (BGR) image.
mask_annotator = sv.MaskAnnotator()
gd_masks = mask_annotator.annotate(scene=gd_plot.copy(), detections=gd_boxes)

# Reverse the channel axis so matplotlib receives RGB.
fig, ax = plt.subplots(figsize=(10,10))
ax.imshow(gd_masks[:,:,::-1])
ax.set_axis_off()
plt.show()

# Example 2: Closely Grouped Trees

In [None]:
# Load Image

# Fetch one named sample from the NEON dataset; the returned dict provides the
# RGB image here and the annotation boxes used further down.
ds = NEONTreeDataset(image_path='/content/training', ann_path='/content/annotations')
img = ds.get_image('2018_BART_4_322000_4882000_image_crop')
rgb_img = img['rgb']
# Channel-reversed copy for the BGR-based detection and annotation calls.
bgr_img = rgb_img[:,:,::-1].copy()

fig, ax = plt.subplots(figsize=(10,10))
ax.imshow(rgb_img)
ax.set_axis_off()
plt.show()

## Detect with GroundingDINO

In [None]:
# Detect Trees
# Same prompt and thresholds as Example 1, applied to the NEON image.
threshold = 0.2
classes = ['tree']

gd_boxes_raw = gd_model.predict_with_classes(
    image=bgr_img, classes=classes, box_threshold=threshold, text_threshold=threshold
)

# Post-process the raw detections with the project's custom NMS helper.
gd_boxes = box_ops.custom_nms(gd_boxes_raw)

In [None]:
# Draw the filtered GroundingDINO boxes (no labels) on a copy of the BGR image.
box_annotator = sv.BoxAnnotator(thickness=2, color=sv.Color.red())
gd_plot = box_annotator.annotate(scene=bgr_img.copy(), detections=gd_boxes, skip_label=True)

# The annotated image is BGR; reverse the channel axis for matplotlib.
fig, ax = plt.subplots(figsize=(10,10))
ax.imshow(gd_plot[:,:,::-1])
ax.set_axis_off()
plt.show()

## True Annotations

In [None]:
# Wrap the dataset's annotation boxes in a Detections object so they can be
# drawn and segmented exactly like the GroundingDINO output. Every box gets
# confidence 1 and class id 0.
true_boxes_raw = img['annotation']
n_true = len(true_boxes_raw)
true_boxes = sv.Detections(
    xyxy=true_boxes_raw,
    confidence=np.ones(n_true),
    class_id=np.zeros(n_true, dtype='int64'),
)

box_annotator = sv.BoxAnnotator(thickness=2, color=sv.Color.red())
true_plot = box_annotator.annotate(scene=bgr_img.copy(), detections=true_boxes, skip_label=True)

# The annotated image is BGR; reverse the channel axis for matplotlib.
fig, ax = plt.subplots(figsize=(10,10))
ax.imshow(true_plot[:,:,::-1])
ax.set_axis_off()
plt.show()

## Segment with SAM

In [None]:
# Switch the SAM predictor to the Example 2 RGB image before prompting it with
# this example's boxes.
sam_predictor.set_image(rgb_img)

In [None]:
def segment(sam_predictor: "SamPredictor", boxes: np.ndarray, multimask_output: bool = False) -> np.ndarray:
    """Segment one object per box prompt with SAM.

    Args:
        sam_predictor: Predictor on which `set_image` has already been called.
        boxes: Iterable of box prompts, one per object (e.g. `Detections.xyxy`).
        multimask_output: Forwarded to `sam_predictor.predict`. When True the
            predictor may return several candidate masks per box and the one
            with the highest score is kept. The default (False) matches the
            previous hard-coded behaviour.

    Returns:
        Array holding one mask per box, stacked along the first axis. For
        empty `boxes` a boolean array with zero masks is returned, so the
        result is always 3-D.
    """
    result_masks = []
    for box in boxes:
        masks, scores, _ = sam_predictor.predict(
            box=box,
            multimask_output=multimask_output
        )
        # Keep the candidate the predictor scored highest.
        result_masks.append(masks[np.argmax(scores)])

    if not result_masks:
        # assumes `original_size` is (height, width) of the current image, as
        # in upstream SAM — TODO confirm for this fork; falls back to (0, 0).
        height, width = getattr(sam_predictor, "original_size", None) or (0, 0)
        return np.zeros((0, height, width), dtype=bool)
    return np.stack(result_masks)

In [None]:
# Box-prompt SAM twice on the same image: once with the GroundingDINO
# detections and once with the dataset's annotation boxes, attaching the masks
# to each detections object.
gd_boxes.mask = segment(sam_predictor, gd_boxes.xyxy)
true_boxes.mask = segment(sam_predictor, true_boxes.xyxy)

In [None]:
# Overlay each set of masks on a copy of its box-annotated (BGR) image; the
# results are plotted side by side in the next cell.
mask_annotator = sv.MaskAnnotator()

gd_masks = mask_annotator.annotate(scene=gd_plot.copy(), detections=gd_boxes)
true_masks = mask_annotator.annotate(scene=true_plot.copy(), detections=true_boxes)

## Plot Masks

In [None]:
# Side-by-side comparison: masks from GroundingDINO boxes vs. masks from the
# dataset's annotation boxes.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(15,10))

panels = (('GroundingDINO', gd_masks), ('True Detection', true_masks))
for ax, (panel_title, panel_bgr) in zip(axs, panels):
  ax.set_title(panel_title, fontsize=22)
  # Annotated images are BGR; reverse the channel axis for matplotlib.
  ax.imshow(panel_bgr[:,:,::-1])

# Lay out the figure first, then hide the axes, in the same order as before.
plt.tight_layout()
for ax in axs.ravel():
  ax.set_axis_off()