<a href="https://colab.research.google.com/github/martintmv-git/RB-IBDM/blob/main/GSL/GSL_usage.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<div class="align-center">
  <a href="https://www.fontys.nl/"><img src="https://www.fontys.nl/static/design/FA845701-BD71-466E-9B3D-38580DFAD5B4-fsm/images/logo-inverted@2x.png" height="75"></a>
  <img src="https://i.imgur.com/zyfbV3r.png" width="20">
  <a href="https://www.has.nl/"><img src="https://i.imgur.com/ZxkugVW.png" height="75"></a>
  <img src="https://i.imgur.com/zyfbV3r.png" width="20">
  <a href="https://www.has.nl/onderzoek/lectoraten/lectoraat-innovatieve-biomonitoring/"><img src="https://i.imgur.com/oH3VJpE.png" height="75"></a>
  <img src="https://i.imgur.com/zyfbV3r.png" width="20">
  <a href="https://www.naturalis.nl/"><img src="https://i.imgur.com/mAHW7XQ.png" height="75"></a>
  <img src="https://i.imgur.com/zyfbV3r.png" width="20">
  <a href="https://www.arise-biodiversity.nl/"><img src="https://i.imgur.com/j6gBpqT.png" height="75"></a>
  <img src="https://i.imgur.com/zyfbV3r.png" width="20">
  <a href="https://faunabit.eu/"><img src="https://i.imgur.com/HxqzRYg.png" height="70"></a>
  <img src="https://i.imgur.com/zyfbV3r.png" width="20">
  <a href="https://diopsis.eu/"><img src="https://i.imgur.com/NHZ8e1b.png" height="75"></a>
</div>

# **How to Use the GSL Background Substitution on New Images**

---

[![GitHub](https://badges.aleen42.com/src/github.svg)](https://github.com/martintmv-git/RB-IBDM) <a href="https://github.com/facebookresearch/detectron2"><img src="https://dl.fbaipublicfiles.com/detectron2/Detectron2-Logo-Horz.png" width="120"></a>

<b>This notebook contains two main parts:
1. Environment setup
2. Usage

You need to run the cells in the first part to get everything set up and ready to use, and then pass either the path to a single image or the path to a folder containing multiple images to the `process` function for inference.</b>

# Environment setup

In [None]:
# Mount Google Drive at /content/drive so the notebook can read and write files stored there
# (the example image paths in the Inference section live under /content/drive/MyDrive).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
#GroundedSAM
%cd /content
!git clone https://github.com/IDEA-Research/Grounded-Segment-Anything
%cd /content/Grounded-Segment-Anything
!pip install -q -r requirements.txt
%cd /content/Grounded-Segment-Anything/GroundingDINO
!pip install -q .
%cd /content/Grounded-Segment-Anything/segment_anything
!pip install -q .
%cd /content/Grounded-Segment-Anything

#LaMa
!pip install -q simple-lama-inpainting

#EXIF
!pip install -q piexif

In [None]:
import os, sys

sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))

import argparse
import copy

from IPython.display import display
import PIL
from PIL import Image, ImageDraw, ImageFont, ImageChops, ImageEnhance
import piexif
from torchvision.ops import box_convert

# GroundingDINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from GroundingDINO.groundingdino.util.inference import annotate, load_image, predict

from huggingface_hub import hf_hub_download

import supervision as sv

# SAM
from segment_anything import build_sam, SamPredictor
import cv2
import numpy as np
import matplotlib.pyplot as plt

# LaMa
import requests
import torch
from io import BytesIO
from simple_lama_inpainting import SimpleLama

In [None]:
# Run the models on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Build a model from a config and checkpoint hosted in a Hugging Face Hub repo.

    Args:
        repo_id: Hub repository that holds both files.
        filename: Name of the checkpoint file inside the repository.
        ckpt_config_filename: Name of the model config file inside the repository.
        device: Device stored on the config and used as `map_location` when
            loading the checkpoint.

    Returns:
        The model built by `build_model`, with the checkpoint weights loaded
        and switched to evaluation mode.
    """
    # Fetch the config first and build the (still uninitialised) network from it.
    config_path = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    model_args = SLConfig.fromfile(config_path)
    model_args.device = device
    net = build_model(model_args)

    # Fetch the weights; the state dict is stored under the 'model' key.
    weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    state_dict = torch.load(weights_path, map_location=device)['model']
    # strict=False: key mismatches are returned in the report (printed below) instead of raising.
    load_report = net.load_state_dict(clean_state_dict(state_dict), strict=False)
    print("Model loaded from {} \n => {}".format(weights_path, load_report))

    net.eval()
    return net

In [None]:
# Hugging Face Hub repository and the checkpoint / config files to fetch from it.
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"

# Download both files and build the detector on the selected device.
groundingdino_model = load_model_hf(
    repo_id=ckpt_repo_id,
    filename=ckpt_filenmae,
    ckpt_config_filename=ckpt_config_filename,
    device=device,
)

In [None]:
! wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

sam_checkpoint = 'sam_vit_h_4b8939.pth'

sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))

In [None]:
# Create the SimpleLama inpainting object once; the processing functions below call it
# as simple_lama(image, mask) for every image.
simple_lama = SimpleLama()

In [None]:
def _gsl_file_name(file_name):
  """Return `file_name` with `_gsl` inserted before its extension.

  Using `os.path.splitext` (instead of replacing the literal '.jpg') guarantees
  the output name always differs from the input name, so inputs such as
  'a.JPG' or 'a.jpeg' are never overwritten.
  """
  stem, extension = os.path.splitext(file_name)
  return stem + '_gsl' + extension


def _detect_insect_boxes(image, text_prompt='insect . flower . cloud', box_threshold=0.25, text_threshold=0.25):
  """Run GroundingDINO on `image` and return only the boxes whose phrase contains 'insect'.

  `image` is the second value returned by `load_image`. The prompt also lists
  'flower' and 'cloud'; boxes matched to those phrases are discarded here.
  """
  boxes, _, phrases = predict(
      image=image,
      model=groundingdino_model,
      caption=text_prompt,
      box_threshold=box_threshold,
      text_threshold=text_threshold,
      device=device
  )
  insect_indices = [i for i, phrase in enumerate(phrases) if 'insect' in phrase]
  return boxes[insect_indices]


def _segment_boxes(image_source, boxes):
  """Return the SAM masks (on the CPU) for `boxes` in the `image_source` array."""
  sam_predictor.set_image(image_source)
  H, W, _ = image_source.shape
  # The boxes are scaled by the image size to obtain pixel xyxy coordinates.
  boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
  transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_xyxy.to(device), image_source.shape[:2])
  masks, _, _ = sam_predictor.predict_torch(
      point_coords=None,
      point_labels=None,
      boxes=transformed_boxes,
      multimask_output=True,
  )
  return masks.cpu()


def _insect_mask(image_source, boxes, dilate_factor=15):
  """Return one dilated uint8 mask (255 = insect) covering all `boxes`.

  The first mask returned for every box is used and the per-box masks are OR-ed
  together. Reducing with `any` also works for a single box, which the previous
  pairwise loop skipped entirely (leaving the mask as None).
  """
  masks = _segment_boxes(image_source, boxes)
  combined = masks[:, 0].any(dim=0)
  mask = combined.numpy().astype(np.uint8) * 255
  # Grow the mask so the inpainting also covers the borders of each insect.
  kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
  return cv2.dilate(mask, kernel, iterations=1)


def _substitute_background(image_path, background_color=(255, 236, 10), threshold=7):
  """Return a PIL image in which everything except the detected insects is `background_color`.

  Pipeline: detect insects, segment them, inpaint them away with LaMa, then keep
  the original pixels only where the inpainted image differs from the original
  by more than `threshold` (on the grayscale difference).
  When no insect is detected the result is a plain `background_color` image.
  """
  image_source, image = load_image(image_path)
  original = Image.fromarray(image_source)
  background = Image.new('RGB', original.size, background_color)

  boxes = _detect_insect_boxes(image)
  if len(boxes) == 0:
    # Nothing to keep; also avoids calling SAM with an empty set of boxes.
    return background

  mask = _insect_mask(image_source, boxes)
  inpainted = simple_lama(image_source, Image.fromarray(mask))

  # Pixels changed by the inpainting are the insect pixels.
  diff = ImageChops.difference(inpainted, original)
  changed = diff.convert('L').point(lambda p: 255 if p > threshold else 0).convert('1')
  return Image.composite(original, background, changed)


def _save_with_exif(result, source_path, save_path):
  """Save `result` to `save_path` and copy the EXIF data of `source_path` into it."""
  result.save(save_path)
  try:
    piexif.transplant(source_path, save_path)
  except ValueError as error:
    # piexif raises ValueError when the source carries no EXIF data (or is not a
    # supported image); the processed image is already saved, so only report it.
    print('EXIF data not copied for {}: {}'.format(source_path, error))


def single_process(image_path):
  """Process one image and save the result next to it with a `_gsl` suffix.

  Args:
    image_path: Path to the image file. 'name.jpg' is written to 'name_gsl.jpg'
      in the same folder; the EXIF data of the input is copied to the output.
  """
  folder, file_name = os.path.split(image_path)
  save_path = os.path.join(folder, _gsl_file_name(file_name))
  _save_with_exif(_substitute_background(image_path), image_path, save_path)
  print('Processing completed!')


def batch_process(path):
  """Process every JPEG file directly inside `path`.

  Results keep their file names and are written to the 'GSL_output' subfolder
  of `path`, which is created when missing. Files are matched by a
  case-insensitive '.jpg' / '.jpeg' extension; sub-directories are skipped.
  """
  save_path = os.path.join(path, 'GSL_output')
  os.makedirs(save_path, exist_ok=True)

  for file in sorted(os.listdir(path)):
    source_path = os.path.join(path, file)
    if not file.lower().endswith(('.jpg', '.jpeg')) or not os.path.isfile(source_path):
      continue
    _save_with_exif(_substitute_background(source_path), source_path, os.path.join(save_path, file))
  print('Batch completed, find processed images in GSL_output!')


def process(path):
  """Run the background substitution on `path`.

  A directory is handled by `batch_process`, anything else is treated as a
  single image file and handled by `single_process`.
  """
  if os.path.isdir(path):
    batch_process(path)
  else:
    single_process(path)

# Inference

## Single image

In [None]:
# Select image file
# `process` hands a file path to `single_process`, which saves the result in the
# same folder as `<name>_gsl.jpg`.
image_path = '/content/drive/MyDrive/Fontys/Fontys_Sem7/insect_detection/GSL/GSL_test_images/20230715012944.jpg'
process(image_path)



Processing completed!


## Batch

In [None]:
# Select folder containing images
# `process` hands a directory path to `batch_process`, which writes the results
# to a `GSL_output` subfolder of this folder.
folder_path = '/content/drive/MyDrive/Fontys/Fontys_Sem7/insect_detection/GSL/GSL_test_images'
process(folder_path)



Batch completed, find processed images in GSL_output!
