In [1]:
# Notebook-wide switch: when True, the import cell skips `matplotlib.pyplot`
# (it is only imported under `if not headless:`).
headless = True

In [2]:
import os
import numpy as np
import torch
import imageio.v2 as imageio
from PIL import Image
import cv2 # make sure to use headless!

if not headless:
    import matplotlib.pyplot as plt

In [4]:
# Build the SAM 2 video predictor used by the later cells.
# The device is hard-coded to CUDA; this cell has no CPU fallback.
device = torch.device("cuda")

# Enter bfloat16 autocast by calling __enter__() directly. No matching
# __exit__() appears in the visible cells, so autocast stays active for
# the code that runs after this line.
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
# Allow TF32 in matmul and cuDNN when device 0 reports a compute-capability
# major version of 8 or higher.
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

from sam2.build_sam import build_sam2_video_predictor

# Checkpoint weights and model config; both file names refer to the
# SAM 2.1 "hiera large" variant. The checkpoint path is relative to the
# working directory; how the config path is resolved is up to sam2 —
# TODO confirm.
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

# Seeds NumPy's global RNG. Nothing in the visible cells draws from
# np.random, so this is presumably carried over from an upstream example
# and only matters if other cells use NumPy randomness — TODO confirm.
np.random.seed(3)

In [21]:
# Scenes to process. The full set is:
#   ["basketball", "boxes", "football", "juggle", "softball", "tennis"]
names = ["tennis"]
NUM_CAMERAS = 31  # camera directories 0..30 are visited for every scene

# Not read anywhere in this cell; kept because other cells may rely on them.
namesidx = 0
camidx = 0
frameidx = 0


def _read_boxes(path):
    """Read bounding boxes from `path`, one box per line.

    Each non-blank line is a comma-separated list of integers (surrounding
    whitespace is tolerated). Blank lines are skipped so a trailing newline
    in the annotation file does not raise.

    The coordinate order is whatever the annotation tool wrote; it is passed
    to the predictor unchanged (presumably x_min, y_min, x_max, y_max —
    TODO confirm against the tool that produced the files).

    Returns:
        list[list[int]]: one list of ints per box, in file order.

    Raises:
        FileNotFoundError: if `path` does not exist.
        ValueError: if a field is not an integer.
    """
    with open(path, "r") as textf:
        return [
            [int(value) for value in line.split(",")]
            for line in textf
            if line.strip()
        ]


def _union_mask(masks, height, width):
    """OR boolean masks into a single (height, width) uint8 image of 0/255."""
    union = np.zeros((height, width), dtype=bool)
    for mask in masks:
        union |= mask.reshape(height, width).astype(bool)
    return union.astype(np.uint8) * 255


for name in names:
    for cam in range(NUM_CAMERAS):
        inpath = f'./data/{name}/ims/{cam}/' # input data
        boxpath = f'./annotations/{name}/{cam}.txt' # bound boxes from boxbounder
        outpath = f'./output/{name}/{cam}/' # output masks
        print(f"{name}/{cam}")

        # Optional resume: uncomment to skip cameras whose frame 149 mask
        # has already been written.
        #if os.path.exists(os.path.join(outpath, "%06d.png"%(149))):
        #    continue

        # Parse annotations before loading frames so a missing or empty
        # annotation file is detected without paying for frame loading.
        boxes = _read_boxes(boxpath)
        if not boxes:
            print(f"  no boxes found in {boxpath}; skipping")
            continue

        # imageio does not create missing directories.
        os.makedirs(outpath, exist_ok=True)

        inference_state = predictor.init_state(video_path=inpath)

        # Every box is a prompt on the first frame; each gets its own object id.
        ann_frame_idx = 0
        for ann_obj_id, box in enumerate(boxes):
            predictor.add_new_points_or_box(
                inference_state=inference_state,
                frame_idx=ann_frame_idx,
                obj_id=ann_obj_id,
                box=box,
            )

        video_segments = {}  # video_segments contains the per-frame segmentation results
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
            # Positive logits are foreground.
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }

            # Take the canvas size from the predictor output rather than
            # hard-coding a resolution.
            h, w = out_mask_logits.shape[-2:]
            finalimg = _union_mask(video_segments[out_frame_idx].values(), h, w)

            imageio.imwrite(os.path.join(outpath, "%06d.png"%(out_frame_idx)), finalimg)

tennis/0


frame loading (JPEG): 100%|██████████| 150/150 [00:08<00:00, 18.11it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.27it/s]


tennis/1


frame loading (JPEG): 100%|██████████| 150/150 [00:08<00:00, 17.61it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.39it/s]


tennis/2


frame loading (JPEG): 100%|██████████| 150/150 [00:08<00:00, 18.51it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.31it/s]


tennis/3


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 19.97it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.29it/s]


tennis/4


frame loading (JPEG): 100%|██████████| 150/150 [00:08<00:00, 18.31it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.30it/s]


tennis/5


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 19.46it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.33it/s]


tennis/6


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 19.70it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.37it/s]


tennis/7


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 20.00it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.32it/s]


tennis/8


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 19.46it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.28it/s]


tennis/9


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 18.99it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.31it/s]


tennis/10


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 19.47it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.25it/s]


tennis/11


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 18.94it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.16it/s]


tennis/12


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 19.75it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.26it/s]


tennis/13


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 19.24it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.27it/s]


tennis/14


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 19.19it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.30it/s]


tennis/15


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 19.78it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.29it/s]


tennis/16


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 19.75it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.28it/s]


tennis/17


frame loading (JPEG): 100%|██████████| 150/150 [00:08<00:00, 18.73it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.33it/s]


tennis/18


frame loading (JPEG): 100%|██████████| 150/150 [00:08<00:00, 18.31it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.28it/s]


tennis/19


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 19.23it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.31it/s]


tennis/20


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 19.39it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.33it/s]


tennis/21


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 19.68it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.26it/s]


tennis/22


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 18.77it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.25it/s]


tennis/23


frame loading (JPEG): 100%|██████████| 150/150 [00:08<00:00, 18.32it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.23it/s]


tennis/24


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 18.79it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.27it/s]


tennis/25


frame loading (JPEG): 100%|██████████| 150/150 [00:08<00:00, 18.72it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.31it/s]


tennis/26


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 18.77it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.27it/s]


tennis/27


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 19.19it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.24it/s]


tennis/28


frame loading (JPEG): 100%|██████████| 150/150 [00:08<00:00, 18.74it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.29it/s]


tennis/29


frame loading (JPEG): 100%|██████████| 150/150 [00:08<00:00, 18.71it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.37it/s]


tennis/30


frame loading (JPEG): 100%|██████████| 150/150 [00:07<00:00, 19.23it/s]
propagate in video: 100%|██████████| 150/150 [00:06<00:00, 22.31it/s]
