In [None]:
# Print the NVIDIA GPU status for this runtime before doing any work.
!nvidia-smi

# Grounded SAM detection and segmentation of plant-based patties


## Installation and Imports

In [None]:
# Start from the Colab content root.
%cd /content

# Fetch the Grounded-Segment-Anything repository.
!git clone https://github.com/IDEA-Research/Grounded-Segment-Anything

%cd /content/Grounded-Segment-Anything

# Pin the checkout to a fixed commit so the notebook runs against a known revision.
!git checkout b579761a11ffab025d75a0f84a2d8a722abf7d5e

# Install the repository requirements, then the two bundled packages
# (GroundingDINO and segment_anything) from their own sub-directories.
!pip install -q -r requirements.txt
%cd /content/Grounded-Segment-Anything/GroundingDINO
!pip install -q .
%cd /content/Grounded-Segment-Anything/segment_anything
!pip install -q .
# Return to the repository root: the import cell appends <cwd>/GroundingDINO to sys.path.
%cd /content/Grounded-Segment-Anything

In [3]:
# tqdm supplies the progress bar used by the batch-processing loop.
!pip install tqdm



In [4]:
import os, sys

sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))

import argparse
import copy

from IPython.display import display
from PIL import Image, ImageDraw, ImageFont
from torchvision.ops import box_convert

# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from GroundingDINO.groundingdino.util.inference import annotate, load_image, predict

import supervision as sv

# segment anything
from segment_anything import build_sam, SamPredictor
import cv2
import numpy as np
import matplotlib.pyplot as plt


# diffusers
import PIL
import requests
import torch
from io import BytesIO
from diffusers import StableDiffusionInpaintPipeline


from huggingface_hub import hf_hub_download

## Load Models

In [5]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Grounding DINO model

In [6]:
def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Build a Grounding DINO model from files hosted on the Hugging Face Hub.

    Args:
        repo_id: Hub repository that holds both the config and the checkpoint.
        filename: Name of the checkpoint file inside the repository.
        ckpt_config_filename: Name of the model config file inside the repository.
        device: Stored on the loaded config and used as ``map_location`` when
            the checkpoint is deserialised.

    Returns:
        The built model, switched to evaluation mode, with the checkpoint's
        ``'model'`` weights loaded non-strictly (the load report is printed
        instead of raising on key mismatches).
    """
    # Config first: the architecture is built from it before any weights exist.
    config_path = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    model_config = SLConfig.fromfile(config_path)
    model_config.device = device
    model = build_model(model_config)

    # Then the weights, deserialised onto the requested device.
    weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    state = torch.load(weights_path, map_location=device)
    load_report = model.load_state_dict(clean_state_dict(state['model']), strict=False)
    print("Model loaded from {} \n => {}".format(weights_path, load_report))

    model.eval()
    return model

In [None]:
# Hub repository plus the checkpoint and config file names handed to load_model_hf.
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"


# Download (or reuse the cached) files and build the detector used by `detect`.
groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename, device)

SAM model

In [None]:
# Download the SAM checkpoint into the current working directory.
! wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

# Relative path: must match the file name downloaded above.
sam_checkpoint = 'sam_vit_h_4b8939.pth'

# Build SAM from the checkpoint, move it to `device`, and wrap it in a predictor.
sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))

## Grounding DINO for detection

In [9]:

def detect(image, text_prompt, model, box_threshold = 0.3, text_threshold = 0.25, image_source = None):
  """Run Grounding DINO on one image and draw the detections.

  Args:
    image: Preprocessed image handed to ``predict`` (the second value
      returned by ``load_image``).
    text_prompt: Caption describing what to detect.
    model: Grounding DINO model passed through to ``predict``.
    box_threshold: Box confidence threshold forwarded to ``predict``.
    text_threshold: Text confidence threshold forwarded to ``predict``.
    image_source: Image array the boxes and labels are drawn on (the first
      value returned by ``load_image``). When omitted, the notebook-level
      variable ``image_source`` is used, which keeps existing calls working.

  Returns:
    ``(annotated_frame, boxes)``: the annotated image with its channel axis
    reversed, and the boxes exactly as returned by ``predict``.

  Raises:
    NameError: If ``image_source`` is neither passed nor defined at
      notebook level.
  """
  if image_source is None:
    # Backward compatibility: earlier callers relied on a global `image_source`
    # being set by the processing loop instead of passing it explicitly.
    try:
      image_source = globals()["image_source"]
    except KeyError:
      raise NameError(
          "detect() needs `image_source`: pass it explicitly or define it at notebook level"
      ) from None

  boxes, logits, phrases = predict(
      model=model,
      image=image,
      caption=text_prompt,
      box_threshold=box_threshold,
      text_threshold=text_threshold
  )

  annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
  annotated_frame = annotated_frame[...,::-1] # BGR to RGB
  return annotated_frame, boxes

## SAM for Segmentation

In [10]:



def segment(image, sam_model, boxes):
  """Predict one SAM mask per box for the given image.

  Args:
    image: Image array with a three-element shape whose first two entries
      are height and width; it is passed to ``sam_model.set_image``.
    sam_model: Predictor exposing ``set_image``, ``transform.apply_boxes_torch``
      and ``predict_torch``.
    boxes: Boxes in centre/size (cx, cy, w, h) form. They are multiplied by
      the image width/height here, so they are expected to be relative to
      the image size.

  Returns:
    The masks returned by ``predict_torch``, moved to the CPU.
  """
  sam_model.set_image(image)

  height, width, _ = image.shape

  # Centre/size boxes -> corner boxes, then scale up to pixel coordinates.
  boxes_abs = box_ops.box_cxcywh_to_xyxy(boxes)
  boxes_abs = boxes_abs * torch.Tensor([width, height, width, height])

  # Map the boxes into the coordinate frame the predictor works in.
  prompt_boxes = sam_model.transform.apply_boxes_torch(
      boxes_abs.to(device),
      image.shape[:2],
  )

  # Box prompts only, one mask per box; the other two outputs are unused.
  predicted_masks, _second, _third = sam_model.predict_torch(
      point_coords=None,
      point_labels=None,
      boxes=prompt_boxes,
      multimask_output=False,
  )
  return predicted_masks.cpu()


def draw_mask(mask, image, random_color=True):
    """Alpha-composite a coloured, semi-transparent mask over an image.

    Args:
        mask: Tensor-like mask whose last two dimensions are height and
            width (``.cpu()`` is called on the coloured result).
        image: Image array accepted by ``PIL.Image.fromarray``.
        random_color: When true, draw with a random RGB colour at alpha 0.8;
            otherwise use a fixed blue at alpha 0.6.

    Returns:
        The composited RGBA image as a NumPy array.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
        if random_color
        else np.array([30/255, 144/255, 255/255, 0.6])
    )

    # Broadcast the (h, w, 1) mask against the (1, 1, 4) colour -> (h, w, 4) overlay.
    rows, cols = mask.shape[-2:]
    overlay = mask.reshape(rows, cols, 1) * rgba.reshape(1, 1, -1)

    base_rgba = Image.fromarray(image).convert("RGBA")
    overlay_u8 = (overlay.cpu().numpy() * 255).astype(np.uint8)
    overlay_rgba = Image.fromarray(overlay_u8).convert("RGBA")

    composed = Image.alpha_composite(base_rgba, overlay_rgba)
    return np.array(composed)

## Methods for Detection, Segmentation, Generating masks and Saving the output image


In [11]:
def run_dino_detection(image_source, image, text_prompt="patty"):
  """Detect objects matching ``text_prompt`` with the notebook's Grounding DINO model.

  Args:
    image_source: Accepted for call-site symmetry with the other pipeline
      steps; this wrapper does not forward it to ``detect``.
    image: Preprocessed image passed to ``detect``.
    text_prompt: Caption describing what to detect. Defaults to ``"patty"``,
      the value that used to be hard-coded here.

  Returns:
    ``(annotated_frame, detected_boxes)`` exactly as returned by ``detect``.
  """
  annotated_frame, detected_boxes = detect(image, text_prompt=text_prompt, model=groundingdino_model)
  return annotated_frame, detected_boxes

In [12]:
def run_sam_segmentation(image_source, image, annotated_frame, detected_boxes):
  """Segment the first detected box with SAM and overlay its mask.

  Args:
    image_source: Image array handed to ``segment``.
    image: Unused here; kept so all pipeline steps share a call shape.
    annotated_frame: Image the mask overlay is drawn on.
    detected_boxes: Boxes from the detection step, passed to ``segment``.

  Returns:
    The RGBA array produced by ``draw_mask``.

  Raises:
    ValueError: If ``detected_boxes`` is ``None`` or empty. Indexing the
      first mask below would otherwise fail with an unhelpful error.
  """
  if detected_boxes is None or len(detected_boxes) == 0:
    raise ValueError("run_sam_segmentation needs at least one detected box; got none")

  segmented_frame_masks = segment(image_source, sam_predictor, boxes=detected_boxes)
  # Only the mask belonging to the first box is drawn; any further detections are ignored.
  annotated_frame_with_mask = draw_mask(segmented_frame_masks[0][0], annotated_frame)
  return annotated_frame_with_mask

In [13]:
def generate_masked_image(annotated_frame, annotated_frame_with_mask, image_source=None):
  """Black out every pixel of the source image that the mask overlay left unchanged.

  A pixel counts as background when all of its colour channels are identical
  in ``annotated_frame`` and in the RGB part of ``annotated_frame_with_mask``.
  Comparing whole pixels rather than single channels avoids zeroing one
  channel of a masked pixel that happens to match by coincidence.

  Args:
    annotated_frame: Image array before the mask overlay, shape (H, W, 3).
    annotated_frame_with_mask: Image array after the overlay, with at least
      three channels; only the first three are compared.
    image_source: Image array to mask. It is copied, never modified. When
      omitted, the notebook-level variable ``image_source`` is used, which
      keeps existing two-argument calls working.

  Returns:
    A copy of ``image_source`` with the background pixels set to 0.

  Raises:
    NameError: If ``image_source`` is neither passed nor defined at
      notebook level.
  """
  if image_source is None:
    # Backward compatibility: earlier callers relied on a global `image_source`.
    try:
      image_source = globals()["image_source"]
    except KeyError:
      raise NameError(
          "generate_masked_image() needs `image_source`: pass it explicitly or define it at notebook level"
      ) from None

  composited_rgb = annotated_frame_with_mask[:, :, :3]
  # (H, W) boolean map: True where the overlay changed no channel of the pixel.
  unchanged = np.all(annotated_frame == composited_rgb, axis=-1)

  img = image_source.copy()
  img[unchanged] = 0
  return img

In [14]:
def save_image(masked_image, save_folder_path, image_name):
  """Write an image array to ``save_folder_path/image_name`` as a JPEG.

  The file is always encoded as JPEG, whatever extension ``image_name`` has.

  Args:
    masked_image: Image array accepted by ``PIL.Image.fromarray``.
    save_folder_path: Destination directory.
    image_name: File name to write inside ``save_folder_path``.
  """
  pil_image = Image.fromarray(masked_image)
  destination = os.path.join(save_folder_path, f"{image_name}")
  pil_image.save(destination, "JPEG")

## Loading IMAGES from drive and running the methods


In [15]:
# Mount Google Drive so the dataset folders under /content/drive are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
from PIL import Image
import numpy as np
import random
from tqdm import tqdm

# Root folder holding one sub-folder of images per class.
image_folders_path = '/content/drive/MyDrive/Food Recognition Project/Resized Dataset'

# Class sub-folders to process.
# Every entry needs its own trailing comma: adjacent string literals are
# concatenated by Python, which previously merged 'inhouse_unbaked' and
# 'commercial_deep_normal' into one non-existent class name.
selected_classes = [
 'commercial_deep_over',
 'commercial_unbaked',
 'inhouse_deep_normal',
 'inhouse_deep_over',
 'inhouse_old_deep_normal',
 'inhouse_old_deep_over',
 'inhouse_unbaked',
 'commercial_deep_normal']

# Masked images are written here, mirroring the per-class folder layout.
save_directory = '/content/drive/MyDrive/Food Recognition Project/New_Dataset'

for class_name in selected_classes:

    class_folder_path = os.path.join(image_folders_path, class_name)

    if not os.path.exists(class_folder_path):
        print(f"Class folder not found: {class_folder_path}")
        continue

    # Create the output folder only once the input folder is known to exist,
    # so a missing or misspelled class does not leave an empty directory behind.
    save_folder_path = os.path.join(save_directory, class_name)
    os.makedirs(save_folder_path, exist_ok=True)

    class_images = os.listdir(class_folder_path)

    for image_name in tqdm(class_images):
        image_path = os.path.join(class_folder_path, image_name)
        # NOTE: `image_source` must keep this exact name at notebook level;
        # helper functions defined earlier look it up as a global.
        image_source, image = load_image(image_path)

        annotated_frame, detected_boxes = run_dino_detection(image_source, image)
        if len(detected_boxes) == 0:
            # Nothing to segment: skip this image instead of aborting the whole run.
            tqdm.write(f"No detection, skipping: {image_path}")
            continue

        annotated_frame_with_mask = run_sam_segmentation(image_source, image, annotated_frame, detected_boxes)
        masked_image = generate_masked_image(annotated_frame, annotated_frame_with_mask)

        save_image(masked_image, save_folder_path, image_name)
