In [None]:
# Lookup table that turns the codes embedded in an image key into text names.
#
# Layout: category_dict[<category code>][<object code>] -> list of names.
# `parse_obj_code` (in the inference cell) splits an object code into its
# leading letters and its final digit and uses that digit as the index into
# the list, e.g. a code such as "PAS1" would resolve to
# category_dict["FOO"]["PAS"][1].
#
# Most lists hold three names; index 0 looks like a generic name and indices
# 1 and 2 look like two finer-grained variants of it -- TODO confirm against
# the dataset documentation.
# Lists with a single name can only be addressed with digit 0; any other digit
# raises IndexError in `parse_obj_code`.
category_dict = {
    "FOO": {
        "PAS": ["pasta", "spiral pasta", "penne pasta"],
        "RIC": ["rice grain", "jasmine rice grain", "brown rice grain"],
        "LIM": ["citrus fruit", "lime", "calamansi"],
        "PEP": ["peppercorn", "black peppercorn", "white peppercorn"],
        "TOM": ["tomato", "normal tomato", "baby tomato"],
        "CHI": ["chili", "long chili", "short chili"],
        "PNT": ["peanut", "peanut with skin", "peanut without skin"],
        "BEA": ["bean", "black bean", "soy bean"],
        "SED": ["seed", "pumpkin seed", "sunflower seed"],
        "CFC": ["coffee candy", "brown coffee candy", "black coffee candy"],
        "ONI": ["shallot"],
        "CAN": ["candy"],
        "GAR": ["garlic"]
    },
    "FUN": {
        "CHK": ["checker piece", "black checker piece", "white checker piece"],
        "MAH": ["mahjong tile", "bamboo mahjong tile", "character mahjong tile"],
        "LEG": ["lego piece", "green lego piece", "light pink lego piece"],
        "CHS": ["chess piece", "black chess piece", "white chess piece"],
        # NOTE(review): "PZP" and "PUZ" carry identical names -- presumably two
        # spellings of the same code that both occur in image keys; confirm.
        "PZP": ["puzzle piece", "edge puzzle piece", "center puzzle piece"],
        "PUZ": ["puzzle piece", "edge puzzle piece", "center puzzle piece"],
        "PKC": ["poker chip", "blue poker chip", "white poker chip"],
        "PLC": ["playing card", "red playing card", "black playing card"],
        "MAR": ["marble", "big marble", "small marble"],
        "DIC": ["dice", "green dice", "white dice"],
        "CSC": ["chinese slim card", "chinese slim card without red marks", "chinese slim card with red marks"]
    },
    "HOU": {
        "TPK": ["toothpick", "straight plastic toothpick", "dental floss"],
        "CTB": ["cotton bud", "wooden cotton bud", "plastic cotton bud"],
        "PIL": ["pill", "white pill", "yellow pill"],
        "BAT": ["battery", "small AAA battery", "big AA battery"],
        "HCP": ["hair clipper", "black hair clipper", "brown hair clipper"],
        "MNY": ["money bill", "1000 vietnamese dong bill", "5000 vietnamese dong bill"],
        "COI": ["coin", "5 Australian cents coin", "10 Australian cents coin"],
        "BOT": ["bottle cap", "beer bottle cap", "plastic bottle cap"],
        "BBT": ["button", "button with 4 holes", "button with 2 holes"],
        "ULT": ["plastic utensil", "plastic spoon", "plastic fork"]
    },
    "OFF": {
        "PPN": ["push pin", "normal push pin", "round push pin"],
        "HST": ["heart sticker", "big heart sticker", "small heart sticker"],
        "CRS": ["craft stick", "red or orange craft stick", "blue or purple craft stick"],
        "RUB": ["rubber band", "yellow rubber band", "blue rubber band"],
        "STN": ["sticky note", "dark green sticky note", "light green sticky note"],
        "PPC": ["paper clip", "colored paper clip", "silver paper clip"],
        "PEN": ["pen", "pen with cap", "pen without cap"],
        "PNC": ["pencil"],
        "RHS": ["rhinestone", "round rhinestone", "star rhinestone"],
        "ZPT": ["zip tie", "short zip tie", "long zip tie"],
        "SFP": ["safety pin", "big safety pin", "small safety pin"],
        "LPP": ["lapel pin"],
        "WWO": ["wall wire organizer"]
    },
    "OTR": {
        "SCR": ["screw", "long silver concrete screw", "short bronze screw"],
        "BOL": ["bolt", "hex head bolt", "mushroom head bolt"],
        "NUT": ["nut", "hex nut", "square nut"],
        "WAS": ["washer", "metal washer", "nylon washer"],
        # NOTE(review): unlike the other entries, the variant names of "BUT"
        # and "BEA" below are bare colours without the object noun; if they
        # are used as detector prompts this may be unintended -- confirm.
        "BUT": ["button", "Beige", "Clear"],
        "NAI": ["nail", "common nail", "concrete nail"],
        "BEA": ["bead", "Blue and purple", "Orange and pink"],
        # NOTE(review): "IKC" and "IKE" carry identical names -- presumably
        # two spellings of the same code; confirm.
        "IKC": ["ikea clip", "green ikea clip", "red ikea clip"],
        "IKE": ["ikea clip", "green ikea clip", "red ikea clip"],
        "PEG": ["peg", "grey peg", "white peg"],
        # NOTE(review): "yellowstone" is possibly a typo for "yellow stone".
        "STO": ["stone", "red stone", "yellowstone"]
    }
}

In [None]:
import json
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import requests
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import GroundingDinoProcessor
from hf_model.modeling_grounding_dino import GroundingDinoForObjectDetection

def draw_boxes(image, boxes, labels, img_save_path, color="red", width=2, draw_labels=False):
    """Draw bounding boxes on ``image`` and save it to ``img_save_path``.

    The image is drawn on in place. Boxes and labels are paired with ``zip``,
    so surplus entries of the longer sequence are ignored.

    Args:
        image: Image object accepted by ``ImageDraw.Draw``; it is modified.
        boxes: Iterable of ``(x_min, y_min, x_max, y_max)`` boxes.
        labels: Iterable with one label per box.
        img_save_path: Destination passed to ``image.save``.
        color: Colour used for the box outline and, if enabled, the label text.
        width: Outline width passed to ``rectangle``.
        draw_labels: When true, write ``str(label)`` at the top-left corner of
            each box. Off by default, so only the rectangles are drawn.
    """
    draw = ImageDraw.Draw(image)
    for box, label in zip(boxes, labels):
        x_min, y_min, x_max, y_max = box
        draw.rectangle([x_min, y_min, x_max, y_max], outline=color, width=width)
        if draw_labels:
            draw.text((x_min, y_min), str(label), fill=color)
    image.save(img_save_path)

# Checkpoint identifier handed to both ``from_pretrained`` calls below.
model_id = "fushh7/llmdet_swin_tiny_hf"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

processor = GroundingDinoProcessor.from_pretrained(model_id)
model = GroundingDinoForObjectDetection.from_pretrained(model_id)
model = model.to(device)

# Validation annotations: a JSON object keyed by image key; each value is
# expected to provide at least a "frame_path" entry (read by the loop below).
annotation_path = "../../dataset/pairtally_dataset/annotations/bbx_anno_valid.json"
with open(annotation_path, "r") as f:
    valid_img_dict = json.load(f)

# --- Run the detector on the validation frames and record its predictions ---

# Output directory for the annotated frames and the aggregated result.json.
res_save_root = "results/llm_det_res"
os.makedirs(res_save_root, exist_ok=True)

# Score cut-offs forwarded to the processor's post-processing step.
BOX_THRESHOLD = 0.25
TEXT_THRESHOLD = 0.3

# Upper bound on the number of frames to process. The default of 1 handles
# only the first frame (quick smoke test); set to None to run on every frame.
MAX_IMAGES = 1


def parse_obj_code(obj_code, category):
    """Resolve an object code to its text name via ``category_dict``.

    Args:
        obj_code: Object code whose last character is a digit; everything
            before that digit is the key inside the category.
        category: Top-level key of ``category_dict``.

    Returns:
        The name stored at ``category_dict[category][code][digit]``.

    Raises:
        KeyError: Unknown category or object code.
        IndexError: The digit exceeds the number of names for that code.
        ValueError: The last character of ``obj_code`` is not a digit.
    """
    code, idx = obj_code[:-1], int(obj_code[-1])
    return category_dict[category][code][idx]


def _format_prompt(name):
    """Return ``name`` as a lowercase, dot-terminated text query.

    Grounding DINO style processors expect text queries in this form.
    """
    prompt = name.strip().lower()
    return prompt if prompt.endswith(".") else prompt + "."


llm_det_res_dict = {}
for frame_idx, (k, v) in enumerate(valid_img_dict.items()):
    if MAX_IMAGES is not None and frame_idx >= MAX_IMAGES:
        break

    # The key is underscore separated: field 3 is the category code, field 4
    # the code of the object to detect, fields 6 and 7 the two object counts.
    key_fields = k.split("_")
    category = key_fields[3]
    obj_code_0 = key_fields[4]
    obj_name_0 = parse_obj_code(obj_code_0, category)
    obj_1_num = int(key_fields[6])
    obj_2_num = int(key_fields[7])

    frame_path = v["frame_path"]
    # Read the pixels and release the file handle right away; converting to
    # RGB gives the model and the red box outlines a consistent colour mode.
    with Image.open(frame_path) as frame:
        image = frame.convert("RGB")

    text = _format_prompt(obj_name_0)
    inputs = processor(images=image, text=text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD,
        # PIL reports (width, height); the processor wants (height, width).
        target_sizes=[image.size[::-1]],
    )
    detections = results[0]
    # Convert once and reuse for drawing, logging and the JSON record.
    boxes = detections["boxes"].cpu().numpy().tolist()

    img_save_path = os.path.join(res_save_root, k)
    draw_boxes(image, boxes, detections["labels"], img_save_path)
    print(boxes)

    llm_det_res_dict[k] = {
        "boxes": boxes,
        "pred_count": len(detections["labels"]),
        "frame_path": frame_path,
        # The un-normalised object name is stored, not the formatted prompt.
        "caption": obj_name_0,
        "obj_1_num": obj_1_num,
        "obj_2_num": obj_2_num,
    }

with open(os.path.join(res_save_root, "result.json"), "w") as f:
    json.dump(llm_det_res_dict, f)