# Cropping pole images with Grounded SAM

As a way of integrating NLP techniques into this work, a good idea is to employ Grounded SAM to segment the images. This is a combination of two techniques:

- Grounding DINO: uses NLP techniques to compute bounding boxes of objects in an image.
- SAM: automatically segments objects in an image.

## Required imports

In [1]:
import cv2
import os
import numpy as np
import supervision as sv
import spectral.io.envi as envi

import torch
import torchvision

%matplotlib inline
import matplotlib.pyplot as plt

from groundingdino.util.inference import Model
from segment_anything import sam_model_registry, SamPredictor

DEVICE = 'cpu'
IMG_FOLDER = '../../RGB'
CROP_FOLDER = '../../RGB/sam_crops'

import os
os.environ['CURL_CA_BUNDLE'] = ''



## The code itself

In [29]:
def crop_img(path):
    # GroundingDINO config and checkpoint
    GROUNDING_DINO_CONFIG_PATH = "../Grounded-Segment-Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    GROUNDING_DINO_CHECKPOINT_PATH = "../Grounded-Segment-Anything/groundingdino_swint_ogc.pth"

    # Segment-Anything checkpoint
    SAM_ENCODER_VERSION = "vit_h"
    SAM_CHECKPOINT_PATH = "../Grounded-Segment-Anything/sam_vit_h_4b8939.pth"

    # Building GroundingDINO inference model
    grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, device=DEVICE)

    # Building SAM Model and SAM Predictor
    sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
    sam.to(device=DEVICE)
    sam_predictor = SamPredictor(sam)


    # Predict classes and hyper-param for GroundingDINO
    SOURCE_IMAGE_PATH = IMG_FOLDER + '/' + path
    CLASSES = ["A wooden pole"]
    BOX_THRESHOLD = 0.25
    TEXT_THRESHOLD = 0.25
    NMS_THRESHOLD = 0.8


    # load image
    image = cv2.imread(SOURCE_IMAGE_PATH)

    # detect objects
    detections = grounding_dino_model.predict_with_classes(
        image=image,
        classes=CLASSES,
        box_threshold=BOX_THRESHOLD,
        text_threshold=BOX_THRESHOLD
    )
    
    crop_img = image[int(detections.xyxy[0,1]):int(detections.xyxy[0,3]),int(detections.xyxy[0,0]):int(detections.xyxy[0,2]),:]
    cv2.imwrite(f"{CROP_FOLDER}/{path[:-4]}_crop.jpg", crop_img)
    
    # annotate image with detections
    box_annotator = sv.BoxAnnotator()
    labels = [
        f"{CLASSES[class_id]} {confidence:0.2f}" 
        for _, _, confidence, class_id, _ 
        in detections]
    annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)

    # save the annotated grounding dino image
    #cv2.imwrite("test_dino.jpg", annotated_frame)


    # NMS post process
    print(f"Before NMS: {len(detections.xyxy)} boxes")
    nms_idx = torchvision.ops.nms(
        torch.from_numpy(detections.xyxy), 
        torch.from_numpy(detections.confidence), 
        NMS_THRESHOLD
    ).numpy().tolist()

    detections.xyxy = detections.xyxy[nms_idx]
    detections.confidence = detections.confidence[nms_idx]
    detections.class_id = detections.class_id[nms_idx]

    print(f"After NMS: {len(detections.xyxy)} boxes")

    # Prompting SAM with detected boxes
    def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
        sam_predictor.set_image(image)
        result_masks = []
        for box in xyxy:
            masks, scores, logits = sam_predictor.predict(
                box=box,
                multimask_output=True
            )
            index = np.argmax(scores)
            result_masks.append(masks[index])
        return np.array(result_masks)


    # convert detections to masks
    detections.mask = segment(
        sam_predictor=sam_predictor,
        image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        xyxy=detections.xyxy
    )
    
    cv2.imwrite(f"{CROP_FOLDER}/{path[:-4]}_mask.jpg", np.where(detections.mask[0], 255, 0))
    img_masked = np.zeros_like(image)
    img_masked[:,:,0] = detections.mask[0]
    img_masked[:,:,1] = detections.mask[0]
    img_masked[:,:,2] = detections.mask[0]
    img_masked = np.where(img_masked, image, 255)
    cv2.imwrite(f"{CROP_FOLDER}/{path[:-4]}_masked.jpg", img_masked)
    return

    # annotate image with detections
    box_annotator = sv.BoxAnnotator()
    mask_annotator = sv.MaskAnnotator()
    labels = [
        f"{CLASSES[class_id]} {confidence:0.2f}" 
        for _, _, confidence, class_id, _ 
        in detections]
    annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)

    # save the annotated grounded-sam image
    cv2.imwrite("test_sam.jpg", annotated_image)

In [30]:
for img in os.listdir("../../RGB/"):
    if os.path.isfile(os.path.join("../../RGB/", img)):
        print(f"Cropping {img}")
        crop_img(img)

Cropping 0_0.jpg
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
Cropping 0_180.jpg
final text_encoder_type: bert-base-uncased



KeyboardInterrupt



## Cropping hyperspectral images

Since the results for RGB images are very good, it is natural to ask whether this technique can also be applied to hyperspectral scans

In [2]:
def crop_hyper(pole_id, img_idx):
    # GroundingDINO config and checkpoint
    GROUNDING_DINO_CONFIG_PATH = "../Grounded-Segment-Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    GROUNDING_DINO_CHECKPOINT_PATH = "../Grounded-Segment-Anything/groundingdino_swint_ogc.pth"

    # Segment-Anything checkpoint
    SAM_ENCODER_VERSION = "vit_h"
    SAM_CHECKPOINT_PATH = "../Grounded-Segment-Anything/sam_vit_h_4b8939.pth"

    # Building GroundingDINO inference model
    grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, device=DEVICE)

    # Building SAM Model and SAM Predictor
    sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
    sam.to(device=DEVICE)
    sam_predictor = SamPredictor(sam)


    # Predict classes and hyper-param for GroundingDINO
    CLASSES = ["A wooden pole"]
    BOX_THRESHOLD = 0.25
    TEXT_THRESHOLD = 0.25
    NMS_THRESHOLD = 0.8


    # load scan
    paths = os.listdir(os.path.join("C:\\Users\\ext-lugo\\Hyperspectral\\Radiance", str(pole_id)))
    paths = [path for path in paths if "float32.hdr" in path]
    paths.sort()
    image = envi.open(os.path.join("C:\\Users\\ext-lugo\\Hyperspectral\\Radiance", str(pole_id), paths[img_idx])).load()
    image = np.take(image, [8, 24, 39], axis=2)
    image = np.where(image > 0.025, 0, image)
    
    # Normalize scan and convert to uint8
    image[:,:,0] = np.interp(image[:,:,0], (np.min(image[:,:,0]), np.max(image[:,:,0])), (0, 255)).astype(np.uint8)
    image[:,:,1] = np.interp(image[:,:,1], (np.min(image[:,:,1]), np.max(image[:,:,1])), (0, 255)).astype(np.uint8)
    image[:,:,2] = np.interp(image[:,:,2], (np.min(image[:,:,2]), np.max(image[:,:,2])), (0, 255)).astype(np.uint8)
    image = image.astype(np.uint8)
    cv2.imwrite(f"../../Hyperspectral_masks/{pole_id}_{img_idx}_frgb.jpg", image)

    # detect objects
    detections = grounding_dino_model.predict_with_classes(
        image=image,
        classes=CLASSES,
        box_threshold=BOX_THRESHOLD,
        text_threshold=BOX_THRESHOLD
    )
    
    crop_img = image[int(detections.xyxy[0,1]):int(detections.xyxy[0,3]),int(detections.xyxy[0,0]):int(detections.xyxy[0,2]),:]
    cv2.imwrite(f"../../Hyperspectral_masks/{pole_id}_{img_idx}_crop.jpg", crop_img)
    
    # annotate image with detections
    box_annotator = sv.BoxAnnotator()
    labels = [
        f"{CLASSES[class_id]} {confidence:0.2f}" 
        for _, _, confidence, class_id, _ 
        in detections]
    annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)

    # save the annotated grounding dino image
    #cv2.imwrite("test_dino.jpg", annotated_frame)


    # NMS post process
    print(f"Before NMS: {len(detections.xyxy)} boxes")
    nms_idx = torchvision.ops.nms(
        torch.from_numpy(detections.xyxy), 
        torch.from_numpy(detections.confidence), 
        NMS_THRESHOLD
    ).numpy().tolist()

    detections.xyxy = detections.xyxy[nms_idx]
    detections.confidence = detections.confidence[nms_idx]
    detections.class_id = detections.class_id[nms_idx]

    print(f"After NMS: {len(detections.xyxy)} boxes")

    # Prompting SAM with detected boxes
    def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
        sam_predictor.set_image(image)
        result_masks = []
        for box in xyxy:
            masks, scores, logits = sam_predictor.predict(
                box=box,
                multimask_output=True
            )
            index = np.argmax(scores)
            result_masks.append(masks[index])
        return np.array(result_masks)


    # convert detections to masks
    detections.mask = segment(
        sam_predictor=sam_predictor,
        image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        xyxy=detections.xyxy
    )
    
    cv2.imwrite(f"../../Hyperspectral_masks/{pole_id}_{img_idx}_mask.jpg", np.where(detections.mask[0], 255, 0))
    img_masked = np.zeros_like(image)
    img_masked[:,:,0] = detections.mask[0]
    img_masked[:,:,1] = detections.mask[0]
    img_masked[:,:,2] = detections.mask[0]
    img_masked = np.where(img_masked, image, 255)
    cv2.imwrite(f"../../Hyperspectral_masks/{pole_id}_{img_idx}_masked.jpg", img_masked)
    return

    # annotate image with detections
    box_annotator = sv.BoxAnnotator()
    mask_annotator = sv.MaskAnnotator()
    labels = [
        f"{CLASSES[class_id]} {confidence:0.2f}" 
        for _, _, confidence, class_id, _ 
        in detections]
    annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)

    # save the annotated grounded-sam image
    cv2.imwrite("test_sam.jpg", annotated_image)

In [3]:
for pole_id in [0,5,6,30,41]:
    for rotation_idx in range(4):
        crop_hyper(pole_id,rotation_idx)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
final text_encoder_type: bert-base-uncased




Before NMS: 1 boxes
After NMS: 1 boxes
