This is one of the deliverables produced from this project: https://library.hkust.edu.hk/ds/project/p002/
> Created by LAU Ming Kit, Jack (Year 4 student, BEng in Computer Engineering, HKUST)

## GroundingDINO demo

In [None]:
import torch, os, cv2, copy
import matplotlib.pyplot as plt
from PIL import ImageDraw, ImageFont
from torchvision.ops import box_convert
from PIL import Image
import argparse
import numpy as np

# RAM
from PIL import Image
from ram.models import ram_plus
from ram import inference_ram as inference
from ram import get_transform

# Grounding DINO
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from groundingdino.util.inference import annotate, load_image, predict

import supervision as sv

# segment anything
from segment_anything import build_sam, SamPredictor 
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Hugging face
from huggingface_hub import hf_hub_download

## Follow the instructions at https://github.com/IDEA-Research/GroundingDINO if you encounter any difficulties

### Global variable

In [None]:
# Global variables shared by the cells below.
# image_path may point to a single image file or to a folder of images (e.g. "images");
# the last cell calls byfile() for a file and byfolder() for a folder.
image_path = "images/MED_VIN_004SM.jpg" # input: path to one image, or to a folder of images
output_file = "output"  # output directory (despite the name) that annotated images are written under
TEXT_PROMPT = ""   # caption naming the objects to detect; read by byfile() (byfolder() builds its own prompt from RAM tags)

### Global function

In [None]:
# Global function
def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Fetch a Grounding DINO config and checkpoint from the Hugging Face Hub and build the model.

    Args:
        repo_id: Hub repository that holds both the config and the checkpoint.
        filename: Name of the checkpoint file inside the repository.
        ckpt_config_filename: Name of the model config file inside the repository.
        device: Device name stored on the config before the model is built.
            The returned model is not moved here; the checkpoint itself is
            always deserialised with ``map_location='cpu'``.

    Returns:
        The built model, with the checkpoint weights loaded, in eval mode.
    """
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)

    args = SLConfig.fromfile(cache_config_file)
    # Set the device on the config *before* building, so build_model() can see it
    # (assigning it afterwards, as before, could never affect construction).
    args.device = device
    model = build_model(args)

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location='cpu')
    # strict=False tolerates key mismatches; the returned report is printed below
    # so missing/unexpected keys stay visible.
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print("Model loaded from {} \n => {}".format(cache_file, log))
    model.eval()
    return model

### Download model

In [None]:
# Hugging Face Hub location of the Grounding DINO checkpoint and its config.
# NOTE: the spelling of `ckpt_filenmae` is kept as-is because it is a
# notebook-level name that other cells may reference.
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"

# Download (or reuse the cached copy of) both files and build the model
# with load_model_hf()'s default device.
groundingdino_model = load_model_hf(
    repo_id=ckpt_repo_id,
    filename=ckpt_filenmae,
    ckpt_config_filename=ckpt_config_filename,
)

### Prediction

In [None]:
# Thresholds handed to Grounding DINO's predict() by byfile() as
# box_threshold / text_threshold (byfolder() sets its own local copies).
# NOTE(review): presumably minimum confidence scores for boxes / matched phrases - confirm
# against the Grounding DINO documentation.
BOX_TRESHOLD = 0.3
TEXT_TRESHOLD = 0.25

In [None]:
def byfile(image_path):
    """Run Grounding DINO on one image with the global TEXT_PROMPT and save the annotated result.

    The annotated image is written to ``output_file + '/' + image_path``, i.e. the
    input's relative path is reproduced under the output directory. Missing parent
    directories are created first, because ``cv2.imwrite`` does not create them and
    only signals failure through its return value.

    Args:
        image_path: Path of the image file to process.

    Raises:
        OSError: If the annotated image could not be written.
    """
    image_source, image = load_image(image_path)

    # predict() defaults to CUDA; fall back to the CPU when no GPU is available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    boxes, logits, phrases = predict(
        model=groundingdino_model,
        image=image,
        caption=TEXT_PROMPT,
        box_threshold=BOX_TRESHOLD,
        text_threshold=TEXT_TRESHOLD,
        device=device,
    )

    # The frame is written exactly as annotate() returns it (no channel flip is applied).
    annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)

    output_path = output_file + '/' + image_path
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    if not cv2.imwrite(output_path, annotated_frame):
        raise OSError("could not write annotated image to " + output_path)

In [None]:
def byfolder(input_file):
    """Tag every image in a folder with RAM++ and box the tags with Grounding DINO.

    For each ``.jpg`` / ``.jpeg`` / ``.png`` file (case-insensitive) directly inside
    ``input_file``, the RAM++ tags are turned into a Grounding DINO text prompt, the
    detections are printed, and the annotated image is saved as
    ``output_file + '/boxes_' + <file name>``.

    Args:
        input_file: Path of the folder that contains the images (despite the name).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # download the model from the URL below and put in the "pretrained" folder of your working directory
    # https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth
    ram_pth = "pretrained/ram_plus_swin_large_14m.pth"
    image_size = 384
    transform = get_transform(image_size=image_size)

    # load ram model
    model = ram_plus(pretrained=ram_pth,
                     image_size=image_size,
                     vit='swin_l')
    model.eval()
    model = model.to(device)

    # Loop-invariant detection thresholds (local to this function).
    BOX_TRESHOLD = 0.3
    TEXT_TRESHOLD = 0.25

    # cv2.imwrite does not create directories, so make sure the target exists.
    if output_file:
        os.makedirs(output_file, exist_ok=True)

    # sorted() only makes the processing order reproducible.
    for i in sorted(os.listdir(input_file)):
        if not i.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue

        image_path = input_file + '/' + i
        # Close the file handle as soon as the tensor has been produced.
        with Image.open(image_path) as ori_image:
            image = transform(ori_image).unsqueeze(0).to(device)
        res = inference(image, model)

        print("Image Tags: ", res[0])
        # Turn the "|"-separated tag string into a "."-separated prompt.
        TEXT_PROMPT = res[0].replace("|", ".") + ' .'
        print(TEXT_PROMPT)

        image_source, image = load_image(image_path)

        boxes, logits, phrases = predict(
            model=groundingdino_model,
            image=image,
            caption=TEXT_PROMPT,
            box_threshold=BOX_TRESHOLD,
            text_threshold=TEXT_TRESHOLD,
            device=str(device),  # same device as RAM instead of predict()'s CUDA default
        )

        print("boxes=", boxes)
        print("logits=", logits)
        print("phrases=", phrases)

        # The frame is written exactly as annotate() returns it (no channel flip is applied).
        annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
        output_path = output_file + '/boxes_' + i
        # imwrite reports failure through its return value; keep going with the
        # remaining images but do not claim success.
        if cv2.imwrite(output_path, annotated_frame):
            print("image saved")
        else:
            print("failed to save image: " + output_path)

In [None]:
# Dispatch on the kind of path stored in image_path.
# An existing directory is checked first, so folder paths that contain a dot
# (e.g. "./images" or "../data/images") are no longer rejected as a wrong file type.
if os.path.isdir(image_path):
    byfolder(image_path)
# Case-insensitive match on the real extension (".JPG" is accepted as well).
elif image_path.lower().endswith(('.png', '.jpg', '.jpeg')):
    byfile(image_path)
# Only a dot in the final path component indicates some other file type.
elif '.' in os.path.basename(image_path):
    print('wrong file type')
else:
    # No extension: treat the path as a folder, as before.
    byfolder(image_path)