<a href="https://colab.research.google.com/github/daminnock/ChainSentinel/blob/main/Video_Segmentation_with_SAM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Before you start

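Let's make sure that we have access to a GPU. We can use the `nvidia-smi` command to do that. In case of any problems, navigate to `Edit` -> `Notebook settings` -> `Hardware accelerator`, set it to `GPU`, and then click `Save`. We also mount Google Drive first, because the input video and the tracking JSON are read from `/content/drive/MyDrive/videos`.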

In [2]:
from google.colab import drive

# The input video and the tracking JSON are read from Google Drive further down
# (VIDEO_FILE / JSON_FILE), so Drive must be mounted before anything else runs.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [3]:
# Print the GPU(s) visible to this runtime (model, driver/CUDA version, memory usage).
# If this command fails, switch the runtime to a GPU as described above.
!nvidia-smi

Sat Apr 29 15:57:41 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   63C    P8    10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

**NOTE:** To make it easier for us to manage datasets, images, and models, we create a `HOME` constant.

In [5]:
import os

# Anchor directory for weights and generated files; on Colab this resolves to /content.
HOME = os.getcwd()
print(f"HOME: {HOME}")

HOME: /content


## Install Segment Anything Model (SAM) and other dependencies

In [6]:
%cd {HOME}

import sys
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

/content
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-uponpap7
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-uponpap7
  Resolved https://github.com/facebookresearch/segment-anything.git to commit 567662b0fd33ca4b022d94d3b8de896628cd32dd
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: segment-anything
  Building wheel for segment-anything (setup.py) ... [?25l[?25hdone
  Created wheel for segment-anything: filename=segment_anything-1.0-py3-none-any.whl size=36610 sha256=87619179c13cfe89c1d06cfddc126355b093b93d1e0073bb37fb4ae5eee27055
  Stored in directory: /tmp/pip-ephem-wheel-cache-4np4q1o_/wheels/10/cf/59/9ccb2f0a1bcc81d4fbd0e501680b5d088d690c6c

### Download SAM weights

In [7]:
%cd {HOME}
!mkdir {HOME}/weights
%cd {HOME}/weights

!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

/content
/content/weights


In [8]:
import os

# Absolute path of the SAM ViT-H checkpoint downloaded in the previous cell.
CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
print(CHECKPOINT_PATH, "; exist:", os.path.isfile(CHECKPOINT_PATH))

# Fail fast with a clear message here instead of an obscure error later inside
# sam_model_registry when the download did not succeed.
if not os.path.isfile(CHECKPOINT_PATH):
    raise FileNotFoundError(f"SAM checkpoint not found: {CHECKPOINT_PATH}")

/content/weights/sam_vit_h_4b8939.pth ; exist: True


## Load Model

In [9]:
import torch

# Run SAM on the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')

# SAM backbone variant; must match the downloaded checkpoint (sam_vit_h_4b8939.pth).
MODEL_TYPE = "vit_h"

In [10]:
from segment_anything import SamPredictor, sam_model_registry

# Look up the model builder for MODEL_TYPE, load the weights from CHECKPOINT_PATH,
# then move the model to DEVICE.
build_sam = sam_model_registry[MODEL_TYPE]
sam = build_sam(checkpoint=CHECKPOINT_PATH)
sam = sam.to(device=DEVICE)

In [None]:
import os

# NOTE(review): leftover from a single-image SAM example. IMAGE_NAME / IMAGE_PATH are not
# used by the video pipeline in this notebook, and this cell was never executed (In [None]).
IMAGE_NAME = "dog.jpeg"
IMAGE_PATH = os.path.join(os.path.join(HOME, "data"), IMAGE_NAME)

### Video Bounding Box to Mask

In [13]:
import cv2
import numpy as np
from tqdm import tqdm

# SamPredictor wrapper around the loaded model; the video cell below calls
# set_image() on a frame and then predict() with box/point prompts.
mask_predictor = SamPredictor(sam)

In [14]:
import cv2
import json
import os

# Flags to control visualization
SHOW_BOXES = True  # Display bounding boxes and their centroid markers
SHOW_LABELS = True  # Display instance ids and keypoint names
SHOW_SEGMENTATION = True  # Write the frame + mask overlay video (only used when SEGMENTATE is on)
SEGMENTATE = False  # Apply SAM

# Input files (on the mounted Google Drive)
VIDEO_FILE = "/content/drive/MyDrive/videos/cam1_raw_concat.mkv"
JSON_FILE = '/content/drive/MyDrive/videos/cam1_person_tracking.json'
# Output files are anchored to HOME: a stray `%cd` in an earlier cell (e.g. into weights/)
# would otherwise decide where these relative names end up.
OUTPUT_FILE = os.path.join(HOME, "video_with_obj_detection.mkv")
OUTPUT_FILE_MASKS = os.path.join(HOME, "video_with_masks.mkv")
OUTPUT_FILE_MASKED = os.path.join(HOME, "video_masked.mkv")

# Read JSON file: data[i]['instances'] are the tracked objects for the i-th video frame.
with open(JSON_FILE) as f:
    data = json.load(f)

# Read video file
cap = cv2.VideoCapture(VIDEO_FILE)
if not cap.isOpened():
    # Without this check the loop below stops at the first read and silently writes empty videos.
    raise IOError(f"Could not open video file: {VIDEO_FILE}")
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# Colour (in the frame's BGR order) per instance id; ids not listed fall back to grey.
# Id 2 previously reused id 0's colour (100, 0, 0), making those two tracks indistinguishable.
COLORS = {0: (100, 0, 0), 1: (0, 100, 0), 2: (0, 0, 100)}

# Define video codec and output video writers
fourcc = cv2.VideoWriter_fourcc(*"XVID")
out = cv2.VideoWriter(OUTPUT_FILE, fourcc, fps, (width, height))
out_masks = cv2.VideoWriter(OUTPUT_FILE_MASKS, fourcc, fps, (width, height))
out_masked = cv2.VideoWriter(OUTPUT_FILE_MASKED, fourcc, fps, (width, height))


def _segment_frame(frame, instances):
    """Run SAM on every instance of one raw (un-annotated) frame.

    Each predicted mask adds 100 to its pixels on all three channels, so overlapping
    masks get brighter; the sum is saturated at 255.

    Returns a uint8 image with the same shape as `frame`.
    """
    # Accumulate in int32: summing directly into uint8 wraps around
    # (three overlapping masks -> 300 % 256 = 44) before any clipping can happen.
    acc = np.zeros(frame.shape, dtype=np.int32)
    if instances:
        # Embed the frame once per frame rather than once per object. OpenCV decodes
        # frames as BGR, while SamPredictor.set_image defaults to RGB.
        mask_predictor.set_image(frame, image_format="BGR")
        for obj in instances:
            # Box + centroid point is a non-ambiguous prompt, so a single mask
            # (multimask_output=False) is expected to be better than picking among three.
            masks, _scores, _logits = mask_predictor.predict(
                box=np.array(obj["bbox"]),
                point_coords=np.array([obj["bbox_centroid"]]),
                point_labels=np.array([1]),  # 1 = point belongs to the mask, 0 = background
                multimask_output=False,
            )
            acc[masks[0]] += 100
    return np.clip(acc, 0, 255).astype(np.uint8)


def _draw_instance(frame, obj):
    """Draw one tracked instance onto `frame` in place.

    Draws the bbox and a red cross on its centroid (if SHOW_BOXES), the instance id
    (if SHOW_LABELS), and a grey cross for every keypoint, labelled with its name
    (if SHOW_LABELS). Coordinates are cast to int because OpenCV drawing calls
    reject float points.
    """
    color = COLORS.get(obj["instance_id"], (55, 55, 55))
    bbox = obj["bbox"]
    x1, y1 = int(bbox[0]), int(bbox[1])
    x2, y2 = int(bbox[2]), int(bbox[3])

    if SHOW_BOXES:
        centroid = tuple(int(v) for v in obj["bbox_centroid"])
        cv2.drawMarker(frame, centroid, [0, 0, 255], cv2.MARKER_CROSS, 100, 2)
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
    if SHOW_LABELS:
        cv2.putText(frame, str(obj["instance_id"]), (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    for name, point in obj['keypoints'].items():
        kx, ky = int(point[0]), int(point[1])
        cv2.drawMarker(frame, (kx, ky), [128, 128, 128], cv2.MARKER_CROSS, 30, 2)
        if SHOW_LABELS:
            cv2.putText(frame, name, (kx, ky + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)


# Process each frame in the video; always release the capture and writers, even on error,
# so the output containers are finalised.
try:
    for i in tqdm(range(len(data))):
        ret, frame = cap.read()
        if not ret:
            break

        instances = data[i]['instances']

        # Segment BEFORE drawing anything, so SAM never sees the boxes/labels.
        if SEGMENTATE:
            masks_frame = _segment_frame(frame, instances)
        else:
            masks_frame = np.zeros_like(frame)

        for obj in instances:
            _draw_instance(frame, obj)

        out.write(frame)
        out_masks.write(masks_frame)

        if SHOW_SEGMENTATION and SEGMENTATE:
            masked_frame = cv2.addWeighted(frame, 1.0, masks_frame, 0.5, 0)
            out_masked.write(masked_frame)
finally:
    # Release video capture and writers
    cap.release()
    out.release()
    out_masks.release()
    out_masked.release()

100%|██████████| 1240/1240 [00:21<00:00, 58.04it/s]
