In [1]:
# The kernel has to be started from the grounded_sam2_env environment.
# A `!source .../bin/activate` line cannot do that: every `!` command runs in
# its own short-lived subshell, so the activation is discarded as soon as the
# command returns and the running kernel keeps its original interpreter.
# Register the environment as a kernel once (from a terminal with the env
# active) and then select it in Jupyter:
# !python -m ipykernel install --user --name=grounded_sam2_env
import sys

# Show which interpreter this kernel is really using so a wrong environment is
# noticed here instead of as an ImportError further down.
print(f"Kernel interpreter: {sys.executable}")

In [2]:
import os
import torch
import json
import pandas as pd
from IPython.display import Image, clear_output



In [3]:
# Report the torch build and the device it will run on; without CUDA the
# device is reported as the literal string 'CPU'.
device_name = torch.cuda.get_device_properties(0).name if torch.cuda.is_available() else 'CPU'
print(f"Setup complete. Using torch {torch.__version__} ({device_name})")

Setup complete. Using torch 2.6.0+cu124 (NVIDIA GeForce RTX 3090)


In [4]:
# Root directory of the SKU110K dataset. Defaults to the original hard-coded
# location; set the SKU110K_DIR environment variable before starting the
# kernel to point the notebook at a different copy without editing this cell.
ds_path = os.environ.get("SKU110K_DIR", "/SKU110K_fixed/")

In [5]:
# Disabled exploration cell: loads the test-split annotation CSV, assigns the
# eight column names, and shows the first five rows.
# NOTE(review): read_csv is called without header=None and the column names
# are overwritten afterwards; if the CSV has no header row, its first
# annotation is consumed as the header -- confirm against the file.
# annotations = pd.read_csv("/SKU110K_fixed/annotations/annotations_test.csv")
# annotations.columns=['image_name','x1','y1','x2','y2','class','image_width','image_height']
# annotations.head(5)

In [6]:
# Disabled one-off step: reads annotations_{train,val,test}.csv into
# `raw_data[split]` and prints the annotation / image counts per split.
# `raw_data` is only consumed by the (also disabled) cell that builds
# sku110_dataset.json.
# NOTE(review): the CSV location is hard-coded here instead of being derived
# from ds_path.
# NOTE(review): read_csv is called without header=None and the column names
# are overwritten afterwards; if the CSVs have no header row, the first
# annotation of every split is consumed as the header -- confirm.
# raw_data = {
#     'train': None,
#     'val': None,
#     'test': None
# }

# for split in raw_data.keys():
#     annotations = pd.read_csv(f"/SKU110K_fixed/annotations/annotations_{split}.csv")
#     annotations.columns=['image_name','x1','y1','x2','y2','class','image_width','image_height']
#     raw_data[split] = annotations
#     print(f"{split} split: {len(annotations)} annotations, {len(annotations.groupby('image_name'))} samples")


In [7]:
# Disabled one-off step that produced sku110_dataset.json: for every split it
# groups the annotation rows by image and stores
# {"image_path": ..., "bboxes": [...]} per image, then dumps the dict to JSON.
# Kept for provenance of the file that the next cell loads.
# NOTE(review): each bbox is written as [x1, x2, y1, y2] (both x values first),
# not the [x1, y1, x2, y2] order of the CSV columns; anything reading "bboxes"
# from the JSON must use this order -- confirm it is intentional.
# NOTE(review): image_path is os.path.join(ds_path, image_name) with no
# 'images' sub-directory, whereas the inference cell reads from
# os.path.join(ds_path, 'images', ...); the stored paths may not exist.
# NOTE(review): filtering the full frame once per image is quadratic in
# practice; data.groupby('image_name') would do the same work in one pass.
# ds = {}

# for split, data in raw_data.items():
    
#     data = data.reset_index()
#     images = set(data['image_name'])

#     ds[split] = []

#     for i, image_name in enumerate(list(images)):
#         df = data[data['image_name'] == image_name]

#         img_path = os.path.join(ds_path,image_name)
#         bboxes = []

#         for idx, ann in df.iterrows():
#             bbox = [ann['x1'],ann['x2'],ann['y1'],ann['y2']]
#             bboxes.append(bbox)        

#         ds[split].append({"image_path": img_path, "bboxes": bboxes})
#         if i%50 == 0:
#             print(f"{100*i/len(images):.1f}% of {split} split processed")


# with open('sku110_dataset.json', 'w') as f: 
#     json.dump(ds, f)

In [8]:
# Read the dataset index back from sku110_dataset.json and report how many
# entries the 'train' split holds.
with open('sku110_dataset.json') as f:
    ds = json.loads(f.read())

print(len(ds['train']))

8219


In [9]:
# Extra packages the Grounding DINO / SAM-2 cells below depend on, kept
# disabled because they are one-time environment setup.
# NOTE(review): if these ever need to be run from the notebook, prefer
# `%pip install ...` over `!pip install ...` so the packages land in the
# kernel's own environment, and pin the versions for reproducibility.
# !pip install supervision
# !pip install iopath
# !pip install addict
# !pip install yapf
# !pip install pycocotools
# !pip install timm

In [10]:
# Grounding DINO SAM-2

import os
import cv2
import torch
import numpy as np
import supervision as sv
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection 
from utils.track_utils import sample_points_from_masks
from utils.video_utils import create_video_from_images
from grounding_dino.groundingdino.util.inference import load_model, predict, load_image






In [11]:
# use bfloat16 for the entire notebook
# The autocast context is entered manually and never exited, so it stays active
# for every cell executed afterwards; code that cannot run in bfloat16 has to
# disable autocast locally.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# NOTE(review): device 0 is queried without a torch.cuda.is_available() guard,
# so this cell requires a CUDA GPU.
if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# init sam image predictor and video predictor model
# Both paths are relative, so the notebook has to be run from the directory
# that contains ./checkpoints and the SAM-2 configs.
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# The same config/checkpoint pair is loaded twice: once as a video predictor
# and once as a plain SAM-2 model that is wrapped in SAM2ImagePredictor.
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
print(f"Video predictor loaded from config {model_cfg}, checkpoint {sam2_checkpoint}")
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
print(f"SAM2 image model loaded from {sam2_checkpoint}")
image_predictor = SAM2ImagePredictor(sam2_image_model)
print(f"image_predictor loaded from SAM2 image model")

# Grounding DINO (Swin-T OGC config and weights) supplies the text-prompted boxes.
grounding_model_config = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
grounding_model_weights = "checkpoints/groundingdino_swint_ogc.pth"
grounding_model = load_model(grounding_model_config, grounding_model_weights)
print(f"Grounding model loaded from config {grounding_model_config}, checkpoint {grounding_model_weights}")
# Cast the detector to float32; the inference cell additionally turns autocast
# off around predict() because its deformable-attention kernel has no BFloat16
# implementation.
grounding_model = grounding_model.float()
print(f"Model converted to float32")

Video predictor loaded from config configs/sam2.1/sam2.1_hiera_l.yaml, checkpoint ./checkpoints/sam2.1_hiera_large.pt
SAM2 image model loaded from ./checkpoints/sam2.1_hiera_large.pt
image_predictor loaded from SAM2 image model




final text_encoder_type: bert-base-uncased
Grounding model loaded from config grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py, checkpoint checkpoints/groundingdino_swint_ogc.pth
Model converted to float32


In [None]:
"""
Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for specific frame
"""

# Load image
image_path = os.path.join(ds_path, 'images', 'train_7763.jpg')
# NOTE(review): upstream GroundingDINO's load_image returns
# (RGB numpy array, transformed tensor) rather than a PIL image -- confirm
# against the vendored grounding_dino copy before relying on image_source.
image_source, image = load_image(image_path)
print(f"Image loaded from {image_path}")

# Define text prompt
prompt = "product . object"  # multiple objects separated by ' . '
print(f"Using prompt: {prompt}")

# Run prediction (includes processing internally).
# The notebook-wide bfloat16 autocast has to be switched off for this call:
# "ms_deform_attn_forward_cuda" is not implemented for 'BFloat16'.
# torch.autocast(device_type="cuda", ...) is used instead of
# torch.cuda.amp.autocast(...), which is deprecated in current torch releases,
# and matches the form the autocast setup earlier in the notebook uses.
with torch.autocast(device_type="cuda", enabled=False):
    boxes, logits, phrases = predict(
        model=grounding_model,
        image=image,
        caption=prompt,
        box_threshold=0.3,
        text_threshold=0.25
    )

# NOTE(review): upstream predict() returns boxes normalized to [0, 1] in
# (cx, cy, w, h) order, not corner format -- confirm for the vendored copy.
print(boxes, logits, phrases)

Image loaded from /SKU110K_fixed/images/train_7763.jpg
Using prompt: product . object


In [None]:
def plot_boxes_on_image(image_path, boxes, phrases, box_color='red', text_color='white',
                        box_format='xyxy'):
    """
    Plots bounding boxes and phrases on an image.

    The input ``boxes`` object is never modified.

    :param image_path: Path to the image file.
    :param boxes: Iterable of N boxes (e.g. a torch.Tensor of shape (N, 4)) with
        coordinates normalized to [0, 1]; the layout is given by ``box_format``.
    :param phrases: List of strings, same length as boxes.
    :param box_color: Matplotlib color for the rectangle edge and label background.
    :param text_color: Matplotlib color for the label text.
    :param box_format: ``'xyxy'`` for [x1, y1, x2, y2] corners (default, the
        original behaviour) or ``'cxcywh'`` for [center_x, center_y, width, height].
    :raises ValueError: If ``box_format`` is not one of the supported layouts.
    """
    if box_format not in ('xyxy', 'cxcywh'):
        raise ValueError(f"Unsupported box_format {box_format!r}; expected 'xyxy' or 'cxcywh'")

    image = Image.open(image_path).convert("RGB")
    width, height = image.size

    fig, ax = plt.subplots(1, figsize=(12, 9))
    ax.imshow(image)

    for box, phrase in zip(boxes, phrases):
        # Copy the coordinates out as plain floats. Unpacking a tensor row gives
        # views into `boxes`, and scaling those in place (`x1 *= width`) would
        # silently rescale the caller's tensor on every call.
        a, b, c, d = (float(v) for v in box)

        if box_format == 'cxcywh':
            w, h = c * width, d * height
            x1 = a * width - w / 2
            y1 = b * height - h / 2
        else:
            x1, y1 = a * width, b * height
            x2, y2 = c * width, d * height
            w, h = x2 - x1, y2 - y1

        rect = patches.Rectangle((x1, y1), w, h, linewidth=2, edgecolor=box_color, facecolor='none')
        ax.add_patch(rect)
        ax.text(x1, y1 - 5, phrase, color=text_color, fontsize=12, bbox=dict(facecolor=box_color, alpha=0.5))

    ax.axis('off')
    fig.tight_layout()
    plt.show()


# Grounding DINO's predict() yields normalized (cx, cy, w, h) boxes, so say so
# explicitly instead of letting them be read as corner coordinates.
plot_boxes_on_image(image_path, boxes, phrases, box_format='cxcywh')

