In [1]:
import torch
import torchvision 
import os 
import io
import subprocess
import sys
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import zipfile 
import cv2
import shutil
import subprocess
from tensorboard import program
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from PIL import Image as PILImage
from tqdm import tqdm

In [2]:
#######################SET UP ###########################
# Selecting the best device for computation
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Enable precision tuning for CUDA
if device.type == "cuda":
    # Enable TF32 for Ampere GPUs (compute capability >= 8)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

elif device.type == "mps":
    print(
        "\n[Warning] Support for MPS devices is preliminary. SAM 2 was trained with CUDA and may "
        "produce numerically different outputs or slower performance on MPS."
        "\nSee: https://github.com/pytorch/pytorch/issues/84936"
    )

Using device: cuda


In [3]:
########################### ENV SETUP ################################
# display system info
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())

# Install required packages
subprocess.check_call([sys.executable, "-m", "pip", "install", "opencv-python", "matplotlib"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/facebookresearch/sam2.git"])

# Ensuring the folders exist
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("outputs/VidWGroundTruth", exist_ok=True)


PyTorch version: 2.7.0+cu126
Torchvision version: 0.22.0+cu126
CUDA is available: True
Collecting git+https://github.com/facebookresearch/sam2.git
  Cloning https://github.com/facebookresearch/sam2.git to /tmp/pip-req-build-l5zq3tje


  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/sam2.git /tmp/pip-req-build-l5zq3tje


  Resolved https://github.com/facebookresearch/sam2.git to commit 2b90b9f5ceec907a1c18123530e92e794ad901a4
  Installing build dependencies: started
  Installing build dependencies: still running...
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'


In [None]:
############# SAVING ONLY VIDEOS WITH GROUND TRUTH #############
#paths
common_filesCSVPath = "/vol/bitbucket/nc624/echonet/dynamic/common_files.csv"
originalVideosDirPath = "/vol/bitbucket/nc624/echonet/dynamic/a4c-video-dir/Videos"
outputDirPath = "/vol/bitbucket/nc624/sam2/outputs/VidWGroundTruth"

#Get all the filenames from the csv file
df = pd.read_csv(common_filesCSVPath)
unique_filenames = df['Filename'].drop_duplicates() #So now the dataframe only contains unique filename and frames 

filtered_basenames = set() #removes duplicates 
for filename in unique_filenames:
    base_name = os.path.splitext(filename)[0] #Get the file number without .png
    filtered_base = base_name.split('_')[0]
    filtered_basenames.add(filtered_base)


count = 0
for basename in filtered_basenames:
    # search for the video in originalVideosDirPath
    video_filename = f"{basename}.avi"
    video_path = os.path.join(originalVideosDirPath, video_filename)
    if not os.path.exists(video_path):
        print(f"Video not found: {video_path}")
        continue

    shutil.copy2(video_path, outputDirPath)
    count +=1

print(f"Total number of videos copied: {count}")

Total number of videos copied: 1276


In [4]:

# Download the model checkpoint
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
subprocess.check_call([
    "wget", "-P", "checkpoints", checkpoint_url
])

--2025-07-04 12:07:37--  https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 2600:9000:2684:b600:13:6e38:acc0:93a1, 2600:9000:2684:9400:13:6e38:acc0:93a1, 2600:9000:2684:8800:13:6e38:acc0:93a1, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|2600:9000:2684:b600:13:6e38:acc0:93a1|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 898083611 (856M) [application/vnd.snesdev-page-table]
Saving to: ‘checkpoints/sam2.1_hiera_large.pt.10’

     0K .......... .......... .......... .......... ..........  0% 15.7M 54s
    50K .......... .......... .......... .......... ..........  0% 21.9M 47s
   100K .......... .......... .......... .......... ..........  0% 54.1M 36s
   150K .......... .......... .......... .......... ..........  0% 41.0M 33s
   200K .......... .......... .......... .......... ..........  0% 66.6M 29s
   250K .......... .......... .......... .........

0

In [5]:
############### LOADING SAM2 VIDEO PREDICTOR ####################
from sam2.build_sam import build_sam2_video_predictor

sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

#initialise the SAM2 video predictor
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)


''' SHOW MASK
This function draws a segmentation mask on a matplotlib axis (ax).
It helps visually show what area of the image corresponds to the segmented object.

How it works:
mask: a binary (0/1) mask where 1 indicates the segmented region.
ax: the matplotlib axis to draw on.
obj_id: ID of the object, used to pick a consistent color (optional).
random_color: if True, picks a random color.

If random_color is True, generate 3 random RGB values and append 0.6 (transparency).
So color becomes an RGBA array (e.g., [0.8, 0.3, 0.1, 0.6]).
In visualization, colors are often represented as arrays of 4 values:
[Red, Green, Blue, Alpha]
Each of R, G, B is between 0 and 1

Alpha (A) = transparency, also between 0 and 1:
1.0 = fully opaque
0.0 = fully transparent

Otherwise, use a predefined color map (tab10) which has 10 distinct colors.
obj_id selects which color to use from this map.
Adds transparency (alpha = 0.6) to make the overlay semi-transparent.

Reshape mask to match image size.
Multiply it with the RGBA color to create a colored overlay.
Use imshow() to draw it on the plot.
'''
def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


''' SHOW POINTS
This function draws the point prompts you give to SAM2.
coords: array of 2D points.
labels: same length as coords — 1 for positive points (e.g., "this is part of the object"), 0 for negative.
Split the coordinates into positive and negative points.
Plot:
Green stars for positive points.
Red stars for negative points.
edgecolor='white' helps the points stand out against different backgrounds.
'''
def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

''' SHOW BOX
This function draws a bounding box prompt around an object.
box: usually in the format [x0, y0, x1, y1].
Convert (x0, y0, x1, y1) into (x, y, width, height) — which is what plt.Rectangle() expects.
Draw a green rectangle on the image:
facecolor=(0, 0, 0, 0) makes it transparent inside.
lw=2 sets the line width.
'''
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))


def log_image_to_tensorboard(writer, image_pil, tag="Image", step=0):
    image = image_pil.convert("RGB")
    image_np = np.array(image).astype(np.float32) / 255.0
    image_tensor = torch.tensor(image_np).permute(2, 0, 1)
    writer.add_image(tag, image_tensor, global_step=step)

#Defining my writer 
log_path = os.path.abspath("outputs/logs/run1")
os.makedirs(log_path, exist_ok=True)
tracking_address = log_path #the path of my log file
writer = SummaryWriter(log_dir=log_path) 

In [21]:
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "/vol/bitbucket/nc624/sam2/outputs/VidWGroundTruth"
output_dir = "/vol/bitbucket/nc624/sam2/outputs/videoFrame"

for filename in os.listdir(video_dir):
    if filename.endswith(".avi"):
        video_path = os.path.join(video_dir, filename)

        video_name = os.path.splitext(filename)[0]
        base_name = os.path.join(output_dir, video_name)

        os.makedirs(base_name, exist_ok=True)

        cmd = [
            "ffmpeg",
            "-i", video_path, 
            "-q:v", "2",
            "-start_number", "0",
            os.path.join(base_name, "%05d.jpg")
        ]
        subprocess.run(cmd, check=True)
print("All videos saved as .jpg frames")


ffmpeg version 6.1.1-3ubuntu5 Copyright (c) 2000-2023 the FFmpeg developers
  built with gcc 13 (Ubuntu 13.2.0-23ubuntu3)
  configuration: --prefix=/usr --extra-version=3ubuntu5 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --disable-omx --enable-gnutls --enable-libaom --enable-libass --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libglslang --enable-libgme --enable-libgsm --enable-libharfbuzz --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzimg --ena

All videos saved as .jpg frames


[out#0/image2 @ 0x63cea4305cc0] video:677kB audio:0kB subtitle:0kB other streams:0kB global headers:0kB muxing overhead: unknown
frame=  168 fps=0.0 q=2.0 Lsize=N/A time=00:00:03.34 bitrate=N/A speed=5.93x    


In [11]:
############## Previewing the first video ###################
%matplotlib widget 

import matplotlib.pyplot as plt


# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "/vol/bitbucket/nc624/sam2/outputs/videoFrame"

# scan all the JPEG frame names in this directory
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

# take a look the first video frame
frame_idx = 0
frame_path = os.path.join(video_dir, frame_names[frame_idx])
img = Image.open(frame_path)

#Convert PIL Image to NumPy array for pixel value
img_np = np.array(img)

#Create the figure
fig, ax = plt.subplots(figsize=(9, 6))
ax.set_title(f"frame {frame_idx}")
im = ax.imshow(img_np)
plt.show()

IndexError: list index out of range

In [7]:
torch.cuda.empty_cache()

In [None]:
################## INITIALIZING THE INFERENCE STATE #########################
'''
SAM 2 requires us to initialise an inference state on the video 
Like preparing all the video frames ahead of time so that 
 its ready to respond instantly when you give a prompt.
 the results live teporarily in RAM, and not saved to files. 
'''
inference_state = predictor.init_state(video_path=video_dir)

'''to reset the inference state, clears any previous prompts or segmentation'''
#UNCOMMENT BELOW 
predictor.reset_state(inference_state)



In [6]:
%pip install tqdm

Note: you may need to restart the kernel to use updated packages.


In [None]:
############ BOUNDING BOXES AS INPUT ##################
'''
for each row of bounding box.csv
for filename with more than one frame, pick the earliest frame occurence
then using the filename column before.avi, look cor the corresponding directory name in /vol/bitbucket/nc624/sam2/outputs/videoFrame
on the earliest frame occurence, input the bounding box coordinate and propagate 
'''
# Paths
bbox_csv_path = "/vol/bitbucket/nc624/echonet/dynamic/bounding_boxes.csv"
video_frame_root = "/vol/bitbucket/nc624/sam2/outputs/videoFrame"
filesWithGroundtruth_csv = "/vol/bitbucket/nc624/sam2/outputs/filesWithGroundtruth.csv"

#Create saving directory
segmented_mask_dir = "/vol/bitbucket/nc624/sam2/outputs/segmentedMasksCommon"
os.makedirs(segmented_mask_dir, exist_ok=True)
print(f"Created/verified directory: {segmented_mask_dir}")


# Load bounding box data
df = pd.read_csv(bbox_csv_path)

# Sort to get earliest frame
df = df.sort_values("Frame")
grouped = df.groupby("Filename", as_index=False).first()  # Get earliest frame per video

existing_filenames = []

for idx, row in tqdm(grouped.iterrows(), total=len(grouped), desc="Processing videos"): #loop through each row 
    filename = row["Filename"]  # e.g., "0X18B2F3A2E992AF3E.avi"
    ann_frame_idx = row["Frame"]
    box = np.array([row["Left_X"], row["Top_Y"], row["Right_X"], row["Bottom_Y"]], dtype=np.float32)

    # Get directory name by removing `.avi`
    video_dir_name = filename.replace(".avi", "")
    video_dir = os.path.join(video_frame_root, video_dir_name)

    # Check if directory exists
    if not os.path.isdir(video_dir):
        print(f"Skipping {video_dir_name}: directory not found")
        continue

    existing_filenames.append(filename)

    inference_state = predictor.init_state(video_path=video_dir)
    '''to reset the inference state, clears any previous prompts or segmentation'''
    #predictor.reset_state(inference_state)

    # Load frame names
    frame_names = sorted([
        p for p in os.listdir(video_dir)
        if os.path.splitext(p)[-1].lower() in [".jpg", ".jpeg"] #filters for only jpg images 
    ])

    if ann_frame_idx >= len(frame_names):
        print(f"Frame index {ann_frame_idx} out of range for {video_dir_name}")
        continue

    print(f"Processing: {filename} at frame {ann_frame_idx}")

    ann_obj_id = 1  # assuming this is set somewhere in your environment
    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        box=box,
    )
    
    #Propagate throughout video 
    video_segments = {}
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }

final_df = pd.DataFrame(existing_filenames, columns=["Filename"])
final_df.to_csv(filesWithGroundtruth_csv, index=False)
print(f"Saved {len(existing_filenames)} files to {filesWithGroundtruth_csv}")


In [None]:
'''
now for each filename with ground truth in the csv file 
we want 3 randomly generated coordinates of the ground truth mask from the same first frame
and then the respective box coordinate
propagate it through 
save only the 2 frame numbers 
'''
#csv path
csv_path = "/vol/bitbucket/nc624/sam2/outputs/filesWithGroundtruth.csv"
df = pd.read_csv(csv_path)

for filename in df["Filename"]: 


ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)

# Let's add a positive click at (x, y) = (460, 60) to refine the mask
points = np.array([[460, 60]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1], np.int32)
# note that we also need to send the original box input along with
# the new refinement click together into `add_new_points_or_box`
box = np.array([300, 0, 500, 400], dtype=np.float32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
    box=box,
)

# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_box(box, plt.gca())
show_points(points, labels, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])

In [None]:
# run propagation throughout the video and collect the results in a dict
video_segments = {}  # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

# render the segmentation results every few frames
vis_frame_stride = 30
plt.close("all")
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    plt.figure(figsize=(6, 4))
    plt.title(f"frame {out_frame_idx}")
    plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)

In [None]:
#################################### TENSORBOARD ######################################

writer.close()
'''
Function purpose: 
Convert to RGB:
Ensures the image is in RGB mode, which is required for consistent visualization.

Convert to NumPy + Normalize:
Converts the image to a NumPy array of type float32 and scales pixel values to [0, 1] range.

Reformat for PyTorch:
Changes the shape from (H, W, C) to (C, H, W), which is the format PyTorch expects for images.

Log to TensorBoard:
Writes the image to TensorBoard under a given tag and step number.
Like recording data to visualise it later

'''

# Start TensorBoard
if __name__ == "__main__":
    tb = program.TensorBoard()
    tb.configure(argv=[None, '--logdir', tracking_address])
    url = tb.launch()
    print(f"TensorBoard is listed on {url}")