In [None]:
# Reload edited modules (such as the local `sam2util` helpers) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import gc
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import supervision as sv
import torch
from PIL import Image
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2util import convert_images_to_mp4, sam2_output_export
from ultralytics import YOLO

# Use bfloat16 for the entire notebook
torch.autocast(device_type='cuda', dtype=torch.bfloat16).__enter__()

if torch.cuda.get_device_properties(0).major >= 8:
    # Turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

In [None]:
# Parameters
# Directory holding the numbered JPEG frames to segment. This looks like a
# parameterized-run cell (e.g. papermill) -- the value may be overridden externally.
image_dir = Path.home().joinpath('source/driver-dataset/images/2021_09_06_poli_enyaq/normal')

In [None]:
# Normalize the parameter (which may arrive as a plain string) to a Path, expand a
# leading '~', and fail fast: otherwise a bad path only surfaces later as an empty
# frame list.
IMAGES_DIR = Path(image_dir).expanduser()
assert IMAGES_DIR.is_dir(), f'Image directory not found: {IMAGES_DIR}'

In [None]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The checkpoint name and the Hydra config name should refer to the same SAM2
# variant (here: large / hiera_l); change both together when switching models.
SAM2_MODEL = 'sam2_hiera_large'
CONFIG = 'sam2_hiera_l.yaml'
CHECKPOINT = Path.home() / f'source/driver-segmentation/segmentation-model/notebooks/sam/segment-anything-2/checkpoints/{SAM2_MODEL}.pt'

# Report the path that was checked so a missing download is easy to diagnose.
assert CHECKPOINT.exists(), f'Checkpoint not found: {CHECKPOINT}'

In [None]:
# Build SAM2 twice from the same config/checkpoint, both with post-processing off.
# One copy backs the single-image predictor that produces the first-frame mask
# prompt; the other drives the video propagation.
sam2_build_args = (CONFIG, CHECKPOINT)
sam2_build_kwargs = dict(device=DEVICE, apply_postprocessing=False)

sam2_base = build_sam2(*sam2_build_args, **sam2_build_kwargs)
predictor = SAM2ImagePredictor(sam2_base)
video_predictor = build_sam2_video_predictor(*sam2_build_args, **sam2_build_kwargs)

In [None]:
# SAM2's video loader (used by init_state below) accepts exactly these JPEG suffixes
# and orders frames by their integer file stem. Mirror both rules so the frame
# indices SAM2 reports line up with positions in `frame_names`.
# Note: subsetting this list alone would NOT limit what SAM2 loads from IMAGES_DIR.
FRAME_SUFFIXES = ('.jpg', '.jpeg', '.JPG', '.JPEG')
frame_names = sorted(
    (p for p in IMAGES_DIR.iterdir() if p.suffix in FRAME_SUFFIXES),
    key=lambda p: int(p.stem),
)
assert frame_names, f'No JPEG frames found in {IMAGES_DIR}'

# Visualize the first video frame
frame_idx = 0
fig, ax = plt.subplots(figsize=(12, 8))
ax.set_title(f'frame {frame_idx}')
# imshow copies the pixels, so the file can be closed right away.
with Image.open(frame_names[frame_idx]) as first_frame:
    ax.imshow(first_frame)

In [None]:
# Extra keyword arguments splatted into `predictor.predict(...)` in the detection
# cell. Left empty, the prediction uses the box prompt alone.
mask_refinement = dict()

In [None]:
# Detect the driver on the first frame with YOLO, turn the most confident person box
# into a SAM2 mask, and keep that mask as the prompt for video propagation.
YOLO_WEIGHTS = Path.home() / 'source/driver-segmentation/sam2/yolov8x.pt'
model_yolo = YOLO(str(YOLO_WEIGHTS))
person_cls_id = 0  # assumes COCO-pretrained weights, where class 0 is 'person'
frame = frame_names[0]
results = model_yolo(frame)[0]

# Prepare bbox prompt for SAM: the highest-confidence person detection (xyxy).
person_boxes = results.boxes[results.boxes.cls == person_cls_id]
assert len(person_boxes) > 0, f'No person detected in {frame}'
best_idx = int(person_boxes.conf.argmax())
bbox_prompt: np.ndarray = person_boxes.xyxy[best_idx].cpu().numpy()

print(bbox_prompt)

# Force 3 channels (as the name implies) so grayscale/CMYK JPEGs don't reach SAM2.
image_rgb = Image.open(frame).convert('RGB')
predictor.set_image(image_rgb)

masks, scores, logits = predictor.predict(
    box=bbox_prompt,
    multimask_output=False,
    **mask_refinement,
)

box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)

detections = sv.Detections(
    xyxy=sv.mask_to_xyxy(masks=masks),
    mask=masks.astype(bool),
)

source_image = box_annotator.annotate(scene=image_rgb.copy(), detections=detections)
segmented_image = mask_annotator.annotate(scene=image_rgb.copy(), detections=detections)

sv.plot_images_grid(
    images=[source_image, segmented_image],
    grid_size=(1, 2),
    titles=['source image', 'segmented image'],
)

# One box with multimask_output=False should give a single mask; take it as (H, W).
# Indexing instead of squeeze() cannot accidentally drop a spatial dim of size 1.
mask_prompt = masks[0]
assert mask_prompt.ndim == 2, f'unexpected mask shape {masks.shape}'
print(masks.shape, mask_prompt.shape)

In [None]:
# Free the image-stage models before building the memory-hungry video state.
# `predictor` only wraps `sam2_base`, so the base model must be dropped as well;
# otherwise its weights stay resident on DEVICE.
del predictor, sam2_base, model_yolo
# Collect first so freed tensors return to the allocator, then release its cache.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# init_state loads every frame in IMAGES_DIR up front. With everything on the GPU,
# a 3:35 video (6450 frames) needed ~75 GB of GPU memory; the kernel crashed at
# 1000 frames and worked at 800. Setting the flags below makes SAM2 keep the frames
# and/or tracking state in host RAM instead, which saves GPU memory at some speed cost.
OFFLOAD_VIDEO_TO_CPU = False
OFFLOAD_STATE_TO_CPU = False

inference_state = video_predictor.init_state(
    video_path=str(IMAGES_DIR),
    offload_video_to_cpu=OFFLOAD_VIDEO_TO_CPU,
    offload_state_to_cpu=OFFLOAD_STATE_TO_CPU,
)

# Export presumably pairs SAM2 frame indices with `frame_names`, so both listings
# of IMAGES_DIR must agree.
assert inference_state['num_frames'] == len(frame_names), (
    f"SAM2 loaded {inference_state['num_frames']} frames, "
    f"but frame_names has {len(frame_names)}"
)

In [None]:
# Seed the tracker: register the image-predictor mask as object 1 on frame 0.
# The returned ids/logits are not used later; the propagation loop rebinds both names.
first_frame_prompt = dict(frame_idx=0, obj_id=1, mask=mask_prompt)
_, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
    inference_state=inference_state,
    **first_frame_prompt,
)

In [None]:
# Propagate the prompt through the whole video and keep a boolean mask per object:
# video_segments[frame_idx][obj_id] -> numpy array of (logit > 0).
video_segments = {}
propagation = video_predictor.propagate_in_video(inference_state)
for out_frame_idx, out_obj_ids, out_mask_logits in propagation:
    video_segments[out_frame_idx] = {
        obj_id: (obj_logits > 0.0).cpu().numpy()
        for obj_id, obj_logits in zip(out_obj_ids, out_mask_logits)
    }

In [None]:
# Write the per-frame masks under IMAGES_DIR/sam2_output (the next cell reads the
# 'visualization' subfolder, presumably produced here).
export_dir = IMAGES_DIR.joinpath('sam2_output')
sam2_output_export(video_segments, frame_names, export_dir)

In [None]:
# Stitch the exported visualization frames into an MP4 inside the output folder.
image_folder = IMAGES_DIR.joinpath('sam2_output/visualization')
output_video_path = IMAGES_DIR.joinpath('sam2_output', 'output.mp4')
convert_images_to_mp4(image_folder, output_video_path)

In [None]:
# Final cleanup. Collect unreachable Python objects first so any CUDA tensors they
# held go back to PyTorch's caching allocator; only then can empty_cache() release
# those blocks. The reverse order leaves memory freed by the collection cached.
gc.collect()
torch.cuda.empty_cache()