In [1]:
%load_ext autoreload
%autoreload 2

GIT_ROOT_LINES = !git rev-parse --show-toplevel
WORK_DIR = GIT_ROOT_LINES[0]
import sys

# assume that grounded-sam-2 is installed in `Grounded-SAM-2` at project root
GROUNDING_SAM2_PATH = "${WORK_DIR}/Grounded-SAM-2"
sys.path.append(GROUNDING_SAM2_PATH)

In [2]:
import os
import os.path as osp
import shutil
import cv2
import torch
import numpy as np
from skimage.color import label2rgb
import supervision as sv
from glob import glob
from torchvision.ops import box_convert
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor 
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
from utils.track_utils import sample_points_from_masks
from utils.video_utils import create_video_from_images
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection 
from utils.track_utils import sample_points_from_masks
from utils.video_utils import create_video_from_images
from utils.common_utils import CommonUtils
from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo
import json
import copy
from helper import reverse_files, make_tmp_folder

In [198]:
# ---- Detection / tracking settings ----------------------------------------------
REDETECT_INTERVAL = 1000  # run Grounding DINO every N frames; >= #frames means frame 0 only
REVERSE_TRACK = False     # if True, the frame folder is reversed before tracking
CAM_TAG = 'left1'         # choose from ["left1", "right1"] (for hypernerf vrig)

# Grounding DINO score thresholds, used unless a scene preset overrides them.
DEFAULT_BOX_THRESHOLD = 0.35
DEFAULT_TEXT_THRESHOLD = 0.25

# Per-scene text prompts ("."-separated phrases) plus optional threshold overrides.
# Switch scenes by changing SCENE only.
SCENE_PRESETS = {
    "broom2": {"text_prompt": "person.broom."},
    "vrig-peel-banana": {"text_prompt": "person.banana."},
    "vrig-chicken": {"text_prompt": "person.toy."},
    "vrig-3dprinter": {
        "text_prompt": "machine.cable.motor.3dprinter head.plate.board.",
        "box_threshold": 0.2,
        "text_threshold": 0.15,
    },
}
SCENE = "vrig-3dprinter"

_scene_preset = SCENE_PRESETS[SCENE]
TEXT_PROMPT = _scene_preset["text_prompt"]
BOX_THRESHOLD = _scene_preset.get("box_threshold", DEFAULT_BOX_THRESHOLD)
TEXT_THRESHOLD = _scene_preset.get("text_threshold", DEFAULT_TEXT_THRESHOLD)

# The tracking loop only implements mask prompts (it raises NotImplementedError for
# anything else), so "mask" is the only working choice.
PROMPT_TYPE_FOR_VIDEO = "mask"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Dataset location; override with the HYPERNERF_ROOT environment variable.
DATASET_ROOT = os.environ.get("HYPERNERF_ROOT", "/scratch/xz653/datasets/hypernerf")
SCENE_DIR = osp.join(DATASET_ROOT, SCENE)
# set where to read images from (glob over one camera's frames) and where to write results
VIDEO_PATH = osp.join(SCENE_DIR, f'rgb/2x/{CAM_TAG}_*.png')
OUTPUT_PATH = osp.join(SCENE_DIR, 'instance/2x')

# Detect (Grounding DINO) + Track (SAM 2)

In [199]:
# Local Grounding DINO config/checkpoint. NOTE(review): not referenced by the cells
# below, which load the Hugging Face "IDEA-Research/grounding-dino-base" model instead.
GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"

# The tracking cell reads numbered JPEG frames from frame_dir, so every input
# (an .mp4 video or a glob of images) is materialised there first.
frame_dir = make_tmp_folder()

if VIDEO_PATH.endswith(".mp4"):
    print("Extracting frames from video...")
    video_info = sv.VideoInfo.from_video_path(VIDEO_PATH)  # get video info
    print(video_info)
    frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)
    source_frames = Path(frame_dir)

    with sv.ImageSink(
        target_dir_path=source_frames,
        overwrite=True,
        image_name_pattern="{:05d}.jpg"
    ) as sink:
        for frame in tqdm(frame_generator, desc="Saving Video Frames"):
            sink.save_image(frame)
else:
    # VIDEO_PATH is a glob pattern
    im_paths = sorted(glob(VIDEO_PATH))
    print(f'images from {VIDEO_PATH}', len(im_paths))
    if not im_paths:
        # Fail here instead of letting the video predictor choke on an empty folder.
        raise FileNotFoundError(f"No images match {VIDEO_PATH!r}")
    for im_path in tqdm(im_paths):
        # "<cam>_<index>.png" -> "<index>.jpg": keep only the numeric part, which the
        # tracking cell sorts on with int().
        frame_id = osp.splitext(osp.basename(im_path))[0].split('_')[-1]
        with Image.open(im_path) as im:  # close the source file handle promptly
            im.convert("RGB").save(osp.join(frame_dir, f"{frame_id}.jpg"))

# Applied after either branch so REVERSE_TRACK behaves the same for videos and image globs.
if REVERSE_TRACK:
    reverse_files(frame_dir)

images from /scratch/xz653/datasets/hypernerf/vrig-3dprinter/rgb/2x/left1_*.png 207


100%|██████████| 207/207 [00:12<00:00, 16.81it/s]


In [200]:
# Load the models once per kernel; re-running this cell reuses them. Every object the
# tracking cell needs is checked, so a partially-populated namespace triggers a reload.
_REQUIRED_MODELS = ("processor", "grounding_model", "video_predictor", "image_predictor")
if not all(name in globals() for name in _REQUIRED_MODELS):
    model_id = "IDEA-Research/grounding-dino-base"
    processor = AutoProcessor.from_pretrained(model_id)
    grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(DEVICE)

    sam2_checkpoint = GROUNDING_SAM2_PATH + "/checkpoints/sam2.1_hiera_large.pt"
    # NOTE(review): config path is resolved by SAM 2's hydra setup, not relative to CWD — confirm.
    model_cfg = "../sam2/configs/sam2.1/sam2.1_hiera_l.yaml"

    # build_sam2 / build_sam2_video_predictor default to device="cuda"; pass DEVICE so the
    # CPU fallback chosen in the config cell is honoured.
    video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=DEVICE)
    sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
    image_predictor = SAM2ImagePredictor(sam2_image_model)

In [201]:
# init video predictor state over the extracted frames
inference_state = video_predictor.init_state(video_path=frame_dir)
# Grounding DINO re-detects objects every `step` frames; SAM 2 propagates masks in between.
step = REDETECT_INTERVAL

# Tracking state carried across detection windows.
sam2_masks = MaskDictionaryModel()  # masks of the most recently propagated frame
# The loop below only implements mask prompts (anything else raises NotImplementedError),
# so this overrides whatever the config cell chose.
PROMPT_TYPE_FOR_VIDEO = "mask"
objects_count = 0        # running count of object ids, updated by update_masks
frame_object_count = {}  # detection frame index -> objects_count after that detection

# JPEG frame names sorted numerically ("10.jpg" after "9.jpg"); assumed to match the
# order in which init_state loaded the frames — the loop indexes this list by frame idx.
frame_names = sorted(
    (p for p in os.listdir(frame_dir)
     if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]),
    key=lambda p: int(os.path.splitext(p)[0]),
)
text = TEXT_PROMPT
device = DEVICE
video_dir = frame_dir

# Per-camera output layout: OUTPUT_PATH/<CAM_TAG>/{groundingsam.mp4, mask_data, json_data, img_result}
cam_out_dir = osp.join(OUTPUT_PATH, CAM_TAG)
output_video_path = osp.join(cam_out_dir, "groundingsam.mp4")
mask_data_dir = osp.join(cam_out_dir, "mask_data")
json_data_dir = osp.join(cam_out_dir, "json_data")
result_dir = osp.join(cam_out_dir, "img_result")
for out_dir in (mask_data_dir, json_data_dir, result_dir):
    os.makedirs(out_dir, exist_ok=True)

"""
Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for all frames
"""
print("Total frames:", len(frame_names))
for start_frame_idx in range(0, len(frame_names), step):
# prompt grounding dino to get the box coordinates on specific frame
    print("start_frame_idx", start_frame_idx)
    # continue
    img_path = os.path.join(video_dir, frame_names[start_frame_idx])
    image = Image.open(img_path).convert("RGB")
    image_base_name = frame_names[start_frame_idx].split(".")[0]
    mask_dict = MaskDictionaryModel(promote_type = PROMPT_TYPE_FOR_VIDEO, mask_name = f"mask_{image_base_name}.npy")

    # run Grounding DINO on the image
    inputs = processor(images=image, text=text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = grounding_model(**inputs)

    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD,
        target_sizes=[image.size[::-1]]
    )

    # prompt SAM image predictor to get the mask for the object
    image_predictor.set_image(np.array(image.convert("RGB")))

    # process the detection results
    input_boxes = results[0]["boxes"] # .cpu().numpy()
    # print("results[0]", results[0])
    OBJECTS = results[0]["labels"]
    if input_boxes.shape[0] != 0:
        # prompt SAM 2 image predictor to get the mask for the object
        masks, scores, logits = image_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )
        # convert the mask shape to (n, H, W)
        if masks.ndim == 2:
            masks = masks[None]
            scores = scores[None]
            logits = logits[None]
        elif masks.ndim == 4:
            masks = masks.squeeze(1)
        """
        Step 3: Register each object's positive points to video predictor
        """

        # If you are using point prompts, we uniformly sample positive points based on the mask
        if mask_dict.promote_type == "mask":
            mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
        else:
            raise NotImplementedError("SAM 2 video predictor only support mask prompts")
    else:
        print("No object detected in the frame, skip merge the frame merge {}".format(frame_names[start_frame_idx]))
        mask_dict = sam2_masks

    """
    Step 4: Propagate the video predictor to get the segmentation results for each frame
    """
    objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
    frame_object_count[start_frame_idx] = objects_count
    print("objects_count", objects_count)
    
    if len(mask_dict.labels) == 0:
        mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
        print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
        continue
    else:
        video_predictor.reset_state(inference_state)

        for object_id, object_info in mask_dict.labels.items():
            frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
                    inference_state,
                    start_frame_idx,
                    object_id,
                    object_info.mask,
                )
        
        video_segments = {}  # output the following {step} frames tracking masks
        for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
            frame_masks = MaskDictionaryModel()
            
            for i, out_obj_id in enumerate(out_obj_ids):
                out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
                object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id), logit=mask_dict.get_target_logit(out_obj_id))
                object_info.update_box()
                frame_masks.labels[out_obj_id] = object_info
                image_base_name = frame_names[out_frame_idx].split(".")[0]
                frame_masks.mask_name = f"mask_{image_base_name}.npy"
                frame_masks.mask_height = out_mask.shape[-2]
                frame_masks.mask_width = out_mask.shape[-1]

            video_segments[out_frame_idx] = frame_masks
            sam2_masks = copy.deepcopy(frame_masks)

        print("video_segments:", len(video_segments))
    """
    Step 5: save the tracking masks and json files
    """
    for frame_idx, frame_masks_info in video_segments.items():
        mask = frame_masks_info.labels
        mask_img = torch.zeros(frame_masks_info.mask_height, frame_masks_info.mask_width)
        for obj_id, obj_info in mask.items():
            mask_img[obj_info.mask == True] = obj_id

        mask_img = mask_img.numpy().astype(np.uint16)
        np.save(os.path.join(mask_data_dir, frame_masks_info.mask_name), mask_img)

        json_data_path = os.path.join(json_data_dir, frame_masks_info.mask_name.replace(".npy", ".json"))
        frame_masks_info.to_json(json_data_path)
       

# Overlay the saved masks/boxes on every frame; annotated images are written to result_dir.
CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)

if REVERSE_TRACK:
    # Tracking ran over a reversed frame folder; reverse_files is applied to every output
    # folder too (presumably restoring the original frame order — see helper.reverse_files).
    for out_dir in (mask_data_dir, json_data_dir, result_dir):
        reverse_files(out_dir)

# Stitch the annotated frames into a preview video.
create_video_from_images(result_dir, output_video_path, frame_rate=15)

frame loading (JPEG / PNG): 100%|██████████| 207/207 [00:05<00:00, 38.56it/s]


Total frames: 207
start_frame_idx 0




objects_count 10


propagate in video: 100%|██████████| 207/207 [01:46<00:00,  1.95it/s]


video_segments: 207
Path '/scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result' already exists.
Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000000.jpg
Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000001.jpg
Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000002.jpg
Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000003.jpg
Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000004.jpg
Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000005.jpg
Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000006.jpg
Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instanc

100%|██████████| 207/207 [00:01<00:00, 123.32it/s]

Video saved at /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/groundingsam.mp4





# Option 1: Export Masks With One Class Per Tracked Instance

In [187]:
print("Labels of each instance:", OBJECTS, ". Please check them to be identical in all cameras when using multiple cameras, otherwise, fix them by setting class_map")
class_map = {i: i for i in range(len(OBJECTS)+1)}
# class_map = {0: 0, 1: 2, 2: 1}

Labels of each instance: ['##printer head', '##printer head', 'cable'] . Please check them to be aligned when using multiple cameras


In [165]:
# transform the results to a standard format:
#   names.json / names.{CAM_TAG}.json : "<label>:<motion type>" for instance ids 1..len(OBJECTS)
#   imask/<cam>_<frame>.png           : instance ids remapped through class_map
#   cmask/<cam>_<frame>.png           : colour preview of imask
# NOTE(review): id k is paired with OBJECTS[class_map[k] - 1], which assumes ids were handed
# out in detection order in a single detection pass — confirm when REDETECT_INTERVAL < #frames.
names = []
for inst_id in range(1, 1 + len(OBJECTS)):
    label = OBJECTS[class_map[inst_id] - 1]
    motion = 'kinematic' if label == 'person' else 'deformable'
    names.append(f"{label}:{motion}")

out_path = Path(OUTPUT_PATH)
(out_path / "names.json").write_text(json.dumps(names))
(out_path / f"names.{CAM_TAG}.json").write_text(json.dumps(names))

imask_path = out_path / "imask"
cmask_path = out_path / "cmask"
imask_path.mkdir(parents=True, exist_ok=True)
cmask_path.mkdir(parents=True, exist_ok=True)

# Lookup table for a vectorised remap: class_lut[k] == class_map[k]. Replaces a per-pixel
# np.vectorize(dict.get), which is slow and yields None for ids missing from class_map.
class_lut = np.zeros(max(class_map) + 1, dtype=np.int64)
for src_id, dst_id in class_map.items():
    class_lut[src_id] = dst_id

for mask_path in tqdm(sorted(Path(mask_data_dir).glob("*.npy"))):
    imask = np.load(mask_path)
    dtype = imask.dtype
    missing_ids = [int(k) for k in np.unique(imask) if int(k) not in class_map]
    if missing_ids:
        raise KeyError(f"{mask_path.name}: instance ids {missing_ids} are missing from class_map")
    imask = class_lut[imask].astype(dtype)  # keep the stored dtype (uint16 from the tracker)
    cmask = label2rgb(imask, bg_label=0)

    imask = Image.fromarray(imask)
    cmask = Image.fromarray((cmask * 255).astype(np.uint8))

    img_id = mask_path.name.split(".")[0].split("_")[-1]
    if CAM_TAG:
        img_id = f"{CAM_TAG}_{img_id}"

    imask.save(imask_path / (img_id + '.png'))
    cmask.save(cmask_path / (img_id + '.png'))

100%|██████████| 207/207 [00:16<00:00, 12.45it/s]


# Option 2: Merge All Instances Into a Single Foreground Class

In [205]:
names = ["object:deformable"]
out_path = Path(OUTPUT_PATH)
(out_path / "names.json").write_text(json.dumps(names))

imask_path = out_path / "imask"
cmask_path = out_path / "cmask"

imask_path.mkdir(parents=True, exist_ok=True)
cmask_path.mkdir(parents=True, exist_ok=True)

for mask_path in tqdm(list(Path(mask_data_dir).glob("*.npy"))):
    imask = np.load(mask_path)
    dtype = imask.dtype
    imask = (imask > 0).astype(dtype)
    cmask = label2rgb(imask, bg_label=0)

    imask = Image.fromarray(imask)
    cmask = Image.fromarray((cmask * 255).astype(np.uint8))

    img_id = mask_path.name.split(".")[0].split("_")[-1]

    if CAM_TAG: img_id = f"{CAM_TAG}_{img_id}"

    imask.save(imask_path / (img_id + '.png'))
    cmask.save(cmask_path / (img_id + '.png')) 

100%|██████████| 207/207 [00:21<00:00,  9.74it/s]
