In [3]:
%load_ext autoreload
%autoreload 2

GIT_ROOT_LINES = !git rev-parse --show-toplevel
WORK_DIR = GIT_ROOT_LINES[0]
import sys

# assume that grounded-sam-2 is installed in `Grounded-SAM-2` at project root
GROUNDING_SAM2_PATH = "${WORK_DIR}/Grounded-SAM-2"
sys.path.append(GROUNDING_SAM2_PATH)

TEXT_PROMPT_DICT = {
    'iphone': {
        "teddy": "person.teddy bear.",
        "apple": "person.apple.",
        "backpack": "person.backpack.",
        "block": "person.block.", # needs reverse track
        "pillow": "person.pillow."                
    }
}

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
import os
import shutil
import cv2
import torch
import numpy as np
import supervision as sv
from glob import glob
from torchvision.ops import box_convert
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor 
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
from utils.track_utils import sample_points_from_masks
from utils.video_utils import create_video_from_images
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection 
from utils.track_utils import sample_points_from_masks
from utils.video_utils import create_video_from_images
from utils.common_utils import CommonUtils
from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo
import json
import copy


def reverse_files(folder_path):
    """Mirror the contents of the regular files in ``folder_path``.

    The file names stay as they are. With the names sorted
    lexicographically, the bytes of the i-th file are exchanged with the
    bytes of the (n - 1 - i)-th file. When the number of files is odd the
    middle file is left untouched. Sub-directories are ignored.

    Args:
        folder_path: Directory whose direct child files are mirrored in place.
    """
    names = sorted(
        entry for entry in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, entry))
    )

    # Pair the first half with the tail of the list, walking inwards.
    half = len(names) // 2
    mirrored_pairs = list(zip(names[:half], reversed(names)))

    for front_name, back_name in tqdm(mirrored_pairs, desc=f"Swapping files in {folder_path}"):
        front_path = os.path.join(folder_path, front_name)
        back_path = os.path.join(folder_path, back_name)

        # Both payloads are held in memory before anything is overwritten.
        with open(front_path, 'rb') as handle:
            front_bytes = handle.read()
        with open(back_path, 'rb') as handle:
            back_bytes = handle.read()

        with open(front_path, 'wb') as handle:
            handle.write(back_bytes)
        with open(back_path, 'wb') as handle:
            handle.write(front_bytes)
        

    

"""
Hyperparam for Ground and Tracking
"""
ROOT = "/scratch/xz653/datasets/iphone"
GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"

# change them for better results
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25
REDETECT_INTERVAL = 1000

SCENE_NAME = 'pillow'

REVERSE_TRACK = False
VIDEO_PATH = os.path.join(ROOT, SCENE_NAME, 'rgb/1x/0_*.png')

SAVE_TRACKING_RESULTS_DIR = os.path.join(ROOT, SCENE_NAME, 'flow3d_preprocessed/grounding_sam2')
OUTPUT_VIDEO_PATH = os.path.join(SAVE_TRACKING_RESULTS_DIR, 'output.mp4')

TEXT_PROMPT = TEXT_PROMPT_DICT['iphone'][SCENE_NAME]


# SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
# OUTPUT_VIDEO_PATH = "./output.mp4"


############################################ 
ALL_DETECTIONS_OUTPUT = []

SOURCE_VIDEO_FRAME_DIR = "/tmp/custom_video_frames"

PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


for tmp_folder in [SOURCE_VIDEO_FRAME_DIR, ]:
    if os.path.exists(tmp_folder):
        shutil.rmtree(tmp_folder)
    Path(tmp_folder).mkdir(parents=True, exist_ok=True)

output_dir = SAVE_TRACKING_RESULTS_DIR
mask_data_dir = os.path.join(output_dir, "mask_data")
json_data_dir = os.path.join(output_dir, "json_data")
result_dir = os.path.join(output_dir, "result")
CommonUtils.creat_dirs(mask_data_dir)
CommonUtils.creat_dirs(json_data_dir)

Path '/scratch/xz653/datasets/iphone/pillow/flow3d_preprocessed/grounding_sam2/mask_data' did not exist and has been created.
Path '/scratch/xz653/datasets/iphone/pillow/flow3d_preprocessed/grounding_sam2/json_data' did not exist and has been created.


In [10]:
"""
Custom video input directly using video files
"""
if VIDEO_PATH.endswith(".mp4"):
    print("Extracting frames from video...")
    video_info = sv.VideoInfo.from_video_path(VIDEO_PATH)  # get video info
    print(video_info)
    frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)

    # saving video to frames
    source_frames = Path(SOURCE_VIDEO_FRAME_DIR)
    source_frames.mkdir(parents=True, exist_ok=True)

    with sv.ImageSink(
        target_dir_path=source_frames, 
        overwrite=True, 
        image_name_pattern="{:05d}.jpg"
    ) as sink:
        for frame in tqdm(frame_generator, desc="Saving Video Frames"):
            sink.save_image(frame)
else:
    # VIDEO_PATH is a glob pattern
    im_paths = sorted(glob(VIDEO_PATH))
    print(f'images from {VIDEO_PATH}', len(im_paths))
    for i, im_path in enumerate(tqdm(im_paths)):
        im = Image.open(im_path)
        bname = os.path.splitext(os.path.basename(im_path))[0]
        im.convert("RGB").save(os.path.join(SOURCE_VIDEO_FRAME_DIR, f"{bname}.jpg"))

    if REVERSE_TRACK:
        reverse_files(SOURCE_VIDEO_FRAME_DIR)
    

images from /scratch/xz653/datasets/iphone/pillow/rgb/1x/0_*.png 330


100%|██████████| 330/330 [00:10<00:00, 32.85it/s]


# Multi Frame Grounding

In [6]:
# Grounding DINO via Hugging Face: the processor prepares image + text inputs,
# the model is the zero-shot object detector that is moved onto DEVICE.
model_id = "IDEA-Research/grounding-dino-base"
processor = AutoProcessor.from_pretrained(model_id)
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
grounding_model = grounding_model.to(DEVICE)

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [7]:
# init sam image predictor and video predictor model
sam2_checkpoint = os.path.join(GROUNDING_SAM2_PATH, "checkpoints", "sam2.1_hiera_large.pt")
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Pass DEVICE explicitly: the SAM 2 builders default to device="cuda", which
# would ignore the CPU fallback that DEVICE encodes on machines without a GPU.
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=DEVICE)
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
image_predictor = SAM2ImagePredictor(sam2_image_model)

In [11]:
# init video predictor state
inference_state = video_predictor.init_state(video_path=SOURCE_VIDEO_FRAME_DIR)
step = REDETECT_INTERVAL # the step to sample frames for Grounding DINO predictor

# Tracking state carried across detection rounds (initialised exactly once).
sam2_masks = MaskDictionaryModel()   # masks of the most recently propagated frame
PROMPT_TYPE_FOR_VIDEO = "mask" # box, mask or point
objects_count = 0                    # running counter returned by update_masks
frame_object_count = {}              # detection frame index -> objects_count

# JPEG frames in the scratch directory, ordered by the integer value of the
# file stem. int() accepts '_' between digits, so stems such as "0_00012"
# parse; a stem that is not an integer literal raises ValueError here.
# NOTE(review): this order presumably has to match the one used internally by
# video_predictor.init_state, since frames are addressed by index below - confirm.
frame_names = [
    p for p in os.listdir(SOURCE_VIDEO_FRAME_DIR)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

# Lower-case aliases used by the tracking loop.
text = TEXT_PROMPT
device = DEVICE
video_dir = SOURCE_VIDEO_FRAME_DIR
output_video_path = OUTPUT_VIDEO_PATH


"""
Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for all frames
"""
print("Total frames:", len(frame_names))
for start_frame_idx in range(0, len(frame_names), step):
# prompt grounding dino to get the box coordinates on specific frame
    print("start_frame_idx", start_frame_idx)
    # continue
    img_path = os.path.join(video_dir, frame_names[start_frame_idx])
    image = Image.open(img_path).convert("RGB")
    image_base_name = frame_names[start_frame_idx].split(".")[0]
    mask_dict = MaskDictionaryModel(promote_type = PROMPT_TYPE_FOR_VIDEO, mask_name = f"mask_{image_base_name}.npy")

    # run Grounding DINO on the image
    inputs = processor(images=image, text=text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = grounding_model(**inputs)

    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD,
        target_sizes=[image.size[::-1]]
    )

    # prompt SAM image predictor to get the mask for the object
    image_predictor.set_image(np.array(image.convert("RGB")))

    # process the detection results
    input_boxes = results[0]["boxes"] # .cpu().numpy()
    # print("results[0]",results[0])
    OBJECTS = results[0]["labels"]
    if input_boxes.shape[0] != 0:

        # prompt SAM 2 image predictor to get the mask for the object
        masks, scores, logits = image_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )
        # convert the mask shape to (n, H, W)
        if masks.ndim == 2:
            masks = masks[None]
            scores = scores[None]
            logits = logits[None]
        elif masks.ndim == 4:
            masks = masks.squeeze(1)
        """
        Step 3: Register each object's positive points to video predictor
        """

        # If you are using point prompts, we uniformly sample positive points based on the mask
        if mask_dict.promote_type == "mask":
            mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
        else:
            raise NotImplementedError("SAM 2 video predictor only support mask prompts")
    else:
        print("No object detected in the frame, skip merge the frame merge {}".format(frame_names[start_frame_idx]))
        mask_dict = sam2_masks

    """
    Step 4: Propagate the video predictor to get the segmentation results for each frame
    """
    objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
    frame_object_count[start_frame_idx] = objects_count
    print("objects_count", objects_count)
    
    if len(mask_dict.labels) == 0:
        mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])
        print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
        continue
    else:
        video_predictor.reset_state(inference_state)

        for object_id, object_info in mask_dict.labels.items():
            frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
                    inference_state,
                    start_frame_idx,
                    object_id,
                    object_info.mask,
                )
        
        video_segments = {}  # output the following {step} frames tracking masks
        for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
            frame_masks = MaskDictionaryModel()
            
            for i, out_obj_id in enumerate(out_obj_ids):
                out_mask = (out_mask_logits[i] > 0.0) # .cpu().numpy()
                object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id), logit=mask_dict.get_target_logit(out_obj_id))
                object_info.update_box()
                frame_masks.labels[out_obj_id] = object_info
                image_base_name = frame_names[out_frame_idx].split(".")[0]
                frame_masks.mask_name = f"mask_{image_base_name}.npy"
                frame_masks.mask_height = out_mask.shape[-2]
                frame_masks.mask_width = out_mask.shape[-1]

            video_segments[out_frame_idx] = frame_masks
            sam2_masks = copy.deepcopy(frame_masks)

        print("video_segments:", len(video_segments))
    """
    Step 5: save the tracking masks and json files
    """
    for frame_idx, frame_masks_info in tqdm(video_segments.items(), desc="Saving Tracking Masks", total=len(video_segments)):
        mask = frame_masks_info.labels
        mask_img = torch.zeros(frame_masks_info.mask_height, frame_masks_info.mask_width)
        for obj_id, obj_info in mask.items():
            mask_img[obj_info.mask == True] = obj_id

        mask_img = mask_img.numpy().astype(np.uint16)
        np.save(os.path.join(mask_data_dir, frame_masks_info.mask_name), mask_img)

        json_data_path = os.path.join(json_data_dir, frame_masks_info.mask_name.replace(".npy", ".json"))
        frame_masks_info.to_json(json_data_path)
       

# Draw the saved masks / boxes for the frames in video_dir; the annotated
# images are written to result_dir.
CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)

# Tracking ran on a mirrored frame sequence when REVERSE_TRACK is set, so the
# outputs are mirrored back before the video is assembled.
if REVERSE_TRACK:
    for reversed_dir in (mask_data_dir, json_data_dir, result_dir):
        reverse_files(reversed_dir)

create_video_from_images(result_dir, output_video_path, frame_rate=15)

frame loading (JPEG / PNG): 100%|██████████| 330/330 [00:10<00:00, 31.94it/s]


Total frames: 330
start_frame_idx 0



Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).


objects_count 2


propagate in video: 100%|██████████| 330/330 [01:16<00:00,  4.34it/s]


video_segments: 330


Saving Tracking Masks: 100%|██████████| 330/330 [00:02<00:00, 125.28it/s]


Path '/scratch/xz653/datasets/iphone/pillow/flow3d_preprocessed/grounding_sam2/result' did not exist and has been created.
Annotated image saved as /scratch/xz653/datasets/iphone/pillow/flow3d_preprocessed/grounding_sam2/result/0_00000.jpg
Annotated image saved as /scratch/xz653/datasets/iphone/pillow/flow3d_preprocessed/grounding_sam2/result/0_00001.jpg
Annotated image saved as /scratch/xz653/datasets/iphone/pillow/flow3d_preprocessed/grounding_sam2/result/0_00002.jpg
Annotated image saved as /scratch/xz653/datasets/iphone/pillow/flow3d_preprocessed/grounding_sam2/result/0_00003.jpg
Annotated image saved as /scratch/xz653/datasets/iphone/pillow/flow3d_preprocessed/grounding_sam2/result/0_00004.jpg
Annotated image saved as /scratch/xz653/datasets/iphone/pillow/flow3d_preprocessed/grounding_sam2/result/0_00005.jpg
Annotated image saved as /scratch/xz653/datasets/iphone/pillow/flow3d_preprocessed/grounding_sam2/result/0_00006.jpg
Annotated image saved as /scratch/xz653/datasets/iphone/pi

100%|██████████| 330/330 [00:03<00:00, 91.07it/s]

Video saved at /scratch/xz653/datasets/iphone/pillow/flow3d_preprocessed/grounding_sam2/output.mp4



