In [17]:
import cv2
import numpy as np
import torch
from sam2.build_sam import build_sam2_video_predictor
import matplotlib.pyplot as plt
import os

def create_frames_directory(video_path, max_frames=100):
    """Extract up to ``max_frames`` evenly stepped frames from a video as JPEGs.

    Frames are written to a ``frame_directory`` folder next to the video and
    named ``00000.jpg``, ``00001.jpg``, ... in extraction order.

    Args:
        video_path: Path to the video file to read.
        max_frames: Maximum number of frames to save (must be >= 1).

    Returns:
        Path of the directory containing the extracted frames.

    Raises:
        ValueError: If ``max_frames`` is smaller than 1 or the video cannot
            be opened.
        OSError: If a frame image cannot be written to disk.
    """
    if max_frames < 1:
        # Guard early: max_frames == 0 would otherwise divide by zero below.
        raise ValueError(f"max_frames must be a positive integer, got {max_frames}")

    # Frames live in 'frame_directory' in the same location as the video
    video_dir = os.path.dirname(video_path)
    frames_dir = os.path.join(video_dir, 'frame_directory')
    os.makedirs(frames_dir, exist_ok=True)

    video = cv2.VideoCapture(video_path)
    try:
        if not video.isOpened():
            raise ValueError(f"Error opening video file: {video_path}")

        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # Keep every frame_step-th frame; never below 1 so short videos
        # (or containers reporting 0 frames) still yield frames.
        frame_step = max(1, total_frames // max_frames)

        frame_count = 0
        saved_count = 0
        while saved_count < max_frames:
            ret, frame = video.read()
            if not ret:
                break

            if frame_count % frame_step == 0:
                frame_filename = os.path.join(frames_dir, f"{saved_count:05d}.jpg")
                # imwrite reports failure through its return value, not an exception
                if not cv2.imwrite(frame_filename, frame):
                    raise OSError(f"Failed to write frame image: {frame_filename}")
                saved_count += 1

            frame_count += 1
    finally:
        # Always free the capture handle, even when reading/writing fails
        video.release()

    # Drop numbered frames left over from an earlier, longer extraction so the
    # directory holds exactly the frames of this run.
    for name in os.listdir(frames_dir):
        stem, ext = os.path.splitext(name)
        if ext == ".jpg" and len(stem) == 5 and stem.isdecimal() and int(stem) >= saved_count:
            os.remove(os.path.join(frames_dir, name))

    print(f"Extracted {saved_count} frames to {frames_dir} (from total {total_frames} frames)")

    # Return the path to the directory containing the frames
    return frames_dir

def generate_masklet(predictor, frames_dir, coordinate, frame_number,
                     offload_video_to_cpu=True, offload_state_to_cpu=False):
    """Track one object through the video starting from a single positive click.

    Args:
        predictor: Video predictor exposing ``init_state``,
            ``add_new_points_or_box`` and ``propagate_in_video``.
        frames_dir: Directory of extracted frames, passed to ``init_state``
            as ``video_path``.
        coordinate: ``(x, y)`` position of the positive click.
        frame_number: Index of the frame the click is placed on.
        offload_video_to_cpu: Forwarded to ``init_state``. Enabled by default
            so the loaded frames are not all held in GPU memory at once.
        offload_state_to_cpu: Forwarded to ``init_state``; set to True to
            trade speed for further GPU memory savings.

    Returns:
        Dict mapping each propagated frame index to a dict of
        ``{object id: boolean numpy mask}``, where a mask is the predictor's
        logits thresholded at 0.

    NOTE(review): propagation presumably runs forward from the prompted frame
    only, so frames before ``frame_number`` may be missing from the result --
    confirm against the predictor documentation.
    """
    x, y = coordinate
    points = np.array([[x, y]], dtype=np.float32)
    labels = np.array([1], dtype=np.int32)  # 1 for positive click

    # Initialize inference state with the frames directory
    inference_state = predictor.init_state(
        video_path=frames_dir,
        offload_video_to_cpu=offload_video_to_cpu,
        offload_state_to_cpu=offload_state_to_cpu,
    )

    # Register the click; its immediate outputs are not needed because the
    # propagation loop below yields results for the frames it visits.
    predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=frame_number,
        obj_id=0,  # Assuming we're working with the first object
        points=points,
        labels=labels,
    )

    # Propagate through the video and collect results
    video_segments = {}
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }

    return video_segments

def plot_masklet(frames_dir, video_segments, num_frames=5):
    """Plot sample frames (top row) and the same frames with mask overlay (bottom row).

    Args:
        frames_dir: Directory holding the extracted JPEG frames; the sorted
            file list is indexed by frame index.
        video_segments: Dict mapping frame index to ``{object id: mask}``.
            Only the first mask of each frame is drawn.
        num_frames: Maximum number of columns to plot; frames are picked
            evenly across the segmented frame indices.

    Raises:
        ValueError: If there is nothing to plot (no frames, no segments or
            ``num_frames`` < 1) or a frame image cannot be read.
    """
    # Only JPEG frames count, so stray files cannot shift the index -> file mapping
    frame_files = sorted(
        name for name in os.listdir(frames_dir)
        if name.lower().endswith((".jpg", ".jpeg"))
    )
    frame_indices = sorted(video_segments.keys())

    # Never plot more columns than there are frames or segmented frames,
    # otherwise the same frame would be drawn several times.
    num_frames = min(num_frames, len(frame_files), len(frame_indices))
    if num_frames < 1:
        raise ValueError("Nothing to plot: no frames or no segmented frames available")

    # squeeze=False keeps axs two-dimensional even for a single column
    fig, axs = plt.subplots(2, num_frames, figsize=(5*num_frames, 10), squeeze=False)

    selected_indices = [frame_indices[i * len(frame_indices) // num_frames] for i in range(num_frames)]

    for i, frame_idx in enumerate(selected_indices):
        # Load and plot original frame
        frame_path = os.path.join(frames_dir, frame_files[frame_idx])
        frame = cv2.imread(frame_path)
        if frame is None:
            # imread signals failure by returning None rather than raising
            raise ValueError(f"Error reading frame image: {frame_path}")
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        axs[0, i].imshow(frame_rgb)
        axs[0, i].set_title(f"Frame {frame_idx}")
        axs[0, i].axis('off')

        # Plot frame with masklet overlay
        axs[1, i].imshow(frame_rgb)
        mask = np.asarray(next(iter(video_segments[frame_idx].values())))  # Get the first (and likely only) mask
        if mask.ndim > 2:
            # Masks presumably arrive with leading singleton axes, e.g. (1, H, W);
            # imshow needs a plain (H, W) array.
            mask = mask.reshape(mask.shape[-2:])
        axs[1, i].imshow(mask, alpha=0.7, cmap='jet')
        axs[1, i].set_title(f"Masklet {frame_idx}")
        axs[1, i].axis('off')

    plt.tight_layout()
    plt.show()

def main(video_path=None, coordinate=(210, 350), frame_number=10,
         sam2_checkpoint="../checkpoints/sam2_hiera_large.pt",
         model_cfg="sam2_hiera_l.yaml", max_frames=100):
    """Run the full pipeline: extract frames, segment from one click, plot.

    All arguments are optional; calling ``main()`` uses the notebook's
    original default settings.

    Args:
        video_path: Video to process. ``None`` selects the default recording.
        coordinate: ``(x, y)`` position of the positive click.
        frame_number: Index (among the extracted frames) the click is placed on.
        sam2_checkpoint: Path to the model checkpoint file.
        model_cfg: Name of the model configuration file.
        max_frames: Maximum number of frames extracted from the video.

    Returns:
        The ``video_segments`` dict produced by ``generate_masklet``.

    Raises:
        ValueError: If ``frame_number`` cannot refer to an extracted frame.
    """
    if video_path is None:
        video_path = "/lisc/scratch/neurobiology/zimmer/schaar/Behavior/High_Res_Population/110620024/test_SAM2/2024-06-10_14-58-26_trainingsdata_clean3/2024-06-10_14-58-26_trainingsdata_clean3_track_0/output/track.avi"

    # The click must land on one of the (at most max_frames) extracted frames
    if not 0 <= frame_number < max_frames:
        raise ValueError(
            f"frame_number must be in [0, {max_frames}), got {frame_number}"
        )

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Extract frames from video (limited to max_frames frames)
    frames_dir = create_frames_directory(video_path, max_frames=max_frames)

    # Load the SAM2 model
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

    # Generate masklet across the video based on the specified coordinate and frame
    video_segments = generate_masklet(predictor, frames_dir, coordinate, frame_number)

    print("Masklet generated across the video.")

    # Plot the masklet
    plot_masklet(frames_dir, video_segments)

    return video_segments

# Entry point: run the whole pipeline when this code executes as the main
# program (this includes running it as a notebook cell). The segments
# returned by main() are discarded here.
if __name__ == "__main__":
    main()

Extracted 100 frames to /lisc/scratch/neurobiology/zimmer/schaar/Behavior/High_Res_Population/110620024/test_SAM2/2024-06-10_14-58-26_trainingsdata_clean3/2024-06-10_14-58-26_trainingsdata_clean3_track_0/output/frame_directory (from total 6000 frames)


frame loading (JPEG): 100%|███████████████████████████████████████████████████████████████████████████| 100/100 [00:09<00:00, 10.01it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB. GPU 0 has a total capacity of 14.57 GiB of which 200.75 MiB is free. Process 1557306 has 7.95 GiB memory in use. Including non-PyTorch memory, this process has 6.42 GiB memory in use. Of the allocated memory 5.93 GiB is allocated by PyTorch, and 372.59 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)