# Predicting wave contours with LangSAM

In [1]:
import os
import cv2
import numpy as np
from PIL import Image
from lang_sam import LangSAM
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm
import gc
import psutil
import traceback

In [2]:
# Paths used throughout the notebook (relative to the working directory)
video_path = os.path.join("workspace", "original_video.mp4")  # Input video; it is split into frames
frames_path = os.path.join("workspace", "frames_wave")  # Extracted original frames are saved here
save_path = os.path.join("workspace", "results_wave")  # Overlay/mask frames and output videos are saved here

combined_video_path = os.path.join(save_path, "combined_output_video.mp4")  # original | mask | overlay, side by side
mask_video = os.path.join(save_path, 'output_mask_video.mp4')  # contour-only video
final_video = os.path.join(save_path, 'output_video.mp4')  # contours drawn over the original frames

# Create the working directories. exist_ok=True replaces the exists()/makedirs()
# pair, which is racy and fails if the directory appears in between.
os.makedirs(frames_path, exist_ok=True)
os.makedirs(save_path, exist_ok=True)

In [3]:
# Text prompt naming the object class to segment in every frame
text_prompt = "wave."
# Large SAM 2.1 backbone (LangSAM's default is "sam2.1_hiera_small")
model = LangSAM(sam_type="sam2.1_hiera_large")

Downloading: "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt" to /root/.cache/torch/hub/checkpoints/sam2.1_hiera_large.pt
100%|██████████| 856M/856M [00:19<00:00, 46.9MB/s] 


preprocessor_config.json:   0%|          | 0.00/457 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.24k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.74k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/933M [00:00<?, ?B/s]

In [4]:
def extract_frames(video_path, output_path=frames_path):
    """Split a video into numbered JPEG frames on disk.

    Frames are written as ``frame_0000.jpg``, ``frame_0001.jpg``, ... The
    number is zero-padded to at least 4 digits (more if the video reports more
    frames), so a plain ``sorted(os.listdir(...))`` returns playback order.

    Args:
        video_path: path of the video file to read.
        output_path: directory for the frames; created if missing. Defaults to
            ``frames_path`` as it was defined when this cell was executed.

    Raises:
        OSError: if the video cannot be opened or a frame cannot be written.
    """
    os.makedirs(output_path, exist_ok=True)

    video = cv2.VideoCapture(video_path)
    if not video.isOpened():
        # Without this check a bad path silently "extracts" 0 frames
        raise OSError(f"Could not open video: {video_path}")

    count = 0
    try:
        # Count reported by the container; it can differ from the frames actually decoded
        frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # Widen the padding for very long videos so lexicographic order stays numeric
        digits = max(4, len(str(max(frame_count, 0))))

        print("Extracting frames from video...")
        print(f"Total frames: {frame_count}")

        while True:
            ret, frame = video.read()
            if not ret:
                break  # end of stream (or a decode failure)

            frame_path = os.path.join(output_path, f'frame_{count:0{digits}d}.jpg')
            # cv2.imwrite reports failure via its return value, not an exception
            if not cv2.imwrite(frame_path, frame):
                raise OSError(f"Could not write frame to {frame_path}")

            count += 1
            if count % 100 == 0:
                print(f"Processed {count} frames...")
    finally:
        # Release the capture even if writing a frame failed
        video.release()

    print(f"Extraction complete. {count} frames saved to {output_path}")

# Split the input video into JPEG frames inside frames_path
extract_frames(video_path, output_path=frames_path)

Extracting frames from video...
Total frames: 602
Processed 100 frames...
Processed 200 frames...
Processed 300 frames...
Processed 400 frames...
Processed 500 frames...
Processed 600 frames...
Extraction complete. 602 frames saved to workspace/frames_wave


Plot the mask contours on top of the original frame

In [5]:
# Non-interactive backend: figures are only rendered to files, never shown
matplotlib.use('Agg')
# Fixed seed so show_mask(random_color=True) picks reproducible colours
np.random.seed(3)

def show_mask(mask, ax, random_color=False, borders = True):
    """Draw the external contour(s) of ``mask`` on ``ax`` as 1-px lines.

    The contours are painted onto an otherwise fully transparent RGBA layer
    that is added to ``ax`` with ``imshow``. Nothing inside the mask is filled.

    Args:
        mask: array whose last two dimensions are (height, width); presumably
            a 2-D 0/1 or boolean mask. It is cast to uint8, so fractional
            values would truncate to 0. TODO confirm with the caller.
        ax: matplotlib Axes to draw on.
        random_color: use a random opaque colour instead of pure green.
        borders: accepted for signature compatibility but not used;
            contours are always drawn.
    """
    # Opaque random colour, or pure green (RGB 0,255,0) by default
    rgba = np.array([*np.random.random(3), 1.0]) if random_color else np.array([0/255, 255/255, 0/255, 1.0])

    height, width = mask.shape[-2:]
    binary = mask.astype(np.uint8)

    # Transparent canvas: only contour pixels become visible
    canvas = np.zeros((height, width, 4))

    outlines, _hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    # A 0.01-px tolerance only drops (near-)collinear points; it does not visibly smooth
    outlines = [cv2.approxPolyDP(outline, epsilon=0.01, closed=True) for outline in outlines]
    cv2.drawContours(canvas, outlines, -1, rgba, thickness=1)

    ax.imshow(canvas)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 as green stars, label 0 as red stars.

    Args:
        coords: array of (x, y) points, indexed in parallel with ``labels``.
        labels: array of point labels; 1 = positive, 0 = negative.
        ax: matplotlib Axes to draw on.
        marker_size: scatter marker size (the ``s`` argument).
    """
    # Select both groups up front, then draw positives first and negatives second
    point_groups = [(coords[labels == 1], 'green'), (coords[labels == 0], 'red')]
    for points, colour in point_groups:
        ax.scatter(points[:, 0], points[:, 1], color=colour, marker='*',
                   s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    """Outline an (x0, y0, x1, y1) box on ``ax`` as an unfilled green rectangle (lw=2)."""
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    corner = (x0, y0)
    ax.add_patch(plt.Rectangle(corner, x1 - x0, y1 - y0, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

def show_masks(image, masks, scores, save_path=None, mask_path=None, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Render all predicted masks for one frame and save two images.

    Two figures the size of ``image`` are produced (1 inch = 100 px at dpi=100):
      * an overlay: the frame with every mask contour drawn on top -> ``save_path``
      * a mask-only image: the same contours on a blank canvas -> ``mask_path``

    All masks are composited into this single pair of figures. Previously each
    mask overwrote the same two files, so only the last detection survived.
    With no masks, the plain frame and a blank mask image are still saved,
    which keeps the frame sequence complete for video assembly.

    Args:
        image: PIL image of the frame (``image.size`` gives width, height).
        masks: iterable of masks accepted by ``show_mask``; may be empty.
        scores: per-mask scores. Kept for interface compatibility; not drawn
            (the old per-mask title sat above a full-canvas axes, i.e. off-image).
        save_path: where to write the overlay image, or None to skip.
        mask_path: where to write the contour-only image, or None to skip.
        point_coords: optional prompt points drawn on the overlay
            (requires ``input_labels``).
        box_coords: optional (x0, y0, x1, y1) box drawn on the overlay.
        input_labels: labels for ``point_coords`` (1 = positive, 0 = negative).
        borders: forwarded to ``show_mask``.
    """
    width, height = image.size
    figsize = (width / 100, height / 100)  # at dpi=100 this matches the frame's pixel size

    overlay_fig = plt.figure(figsize=figsize, dpi=100)
    mask_fig = plt.figure(figsize=figsize, dpi=100)
    try:
        # Axes spanning the whole canvas so no padding is added around the image
        overlay_ax = overlay_fig.add_axes([0., 0., 1., 1.])
        mask_ax = mask_fig.add_axes([0., 0., 1., 1.])

        overlay_ax.imshow(image)
        for mask in masks:
            show_mask(mask, overlay_ax, borders=borders)
            show_mask(mask, mask_ax, borders=borders)

        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, overlay_ax)
        if box_coords is not None:
            show_box(box_coords, overlay_ax)

        if save_path:
            overlay_fig.savefig(save_path, dpi=100, bbox_inches=None, pad_inches=0)
        if mask_path:
            mask_fig.savefig(mask_path, dpi=100, bbox_inches=None, pad_inches=0)
    finally:
        # Close by handle so figures never accumulate across hundreds of frames
        plt.close(overlay_fig)
        plt.close(mask_fig)
        gc.collect()

Create a mask frame and an overlay frame (segmented object on the original) for every video frame (roughly 1 s per frame)

In [6]:
# import tracemalloc
# tracemalloc.start()

# Index of the first frame to process. Raise this to resume a run that was
# interrupted (e.g. after running out of memory); 0 processes every frame.
START_FRAME = 0
RAM_CHECK_EVERY = 10   # check resident memory every N frames
RAM_LIMIT_MB = 10000   # warn above ~10 GB of resident memory

process = psutil.Process()  # this kernel's process, used for RSS memory checks
files = sorted(f for f in os.listdir(frames_path) if f.lower().endswith(('.png', '.jpg', '.jpeg')))
for ind, path_img in tqdm(enumerate(files), desc="Processing frames", total=len(files)):
    if ind < START_FRAME:
        continue

    image_path = os.path.join(frames_path, path_img)
    # Reset per frame so the error handler never sees a stale or deleted value
    image_pil = None
    results = None
    try:
        # convert() loads the pixels, so the file handle can be closed immediately
        with Image.open(image_path) as frame_file:
            image_pil = frame_file.convert("RGB")

        results = model.predict([image_pil], [text_prompt])
        masks = results[0]['masks']
        if masks is None or len(masks) == 0:
            print(f"No masks found in {path_img}")
            masks = []

        # The detection box is deliberately not passed as `point_coords`. It is a
        # single (x0, y0, x1, y1) box, not an array of points, so it never drew
        # anything, and indexing boxes[0] raised IndexError on frames with no detections.
        show_masks(image_pil,
                   masks,
                   results[0]['scores'],
                   save_path=os.path.join(save_path, 'output_' + path_img),
                   mask_path=os.path.join(save_path, 'output_mask_' + path_img),
                   borders=True)
    except Exception as e:
        # Best effort: report and continue with the next frame
        print(f"Error on {path_img}: {e}")
        traceback.print_exc()
    finally:
        # Release per-frame objects before the next iteration
        if image_pil is not None:
            image_pil.close()
        del results
        plt.close('all')
        gc.collect()

    # Periodic memory check
    if ind > 0 and ind % RAM_CHECK_EVERY == 0:
        ram_used = process.memory_info().rss / 1024 / 1024  # bytes -> MB
        if ram_used > RAM_LIMIT_MB:
            print(f"\nHigh RAM usage detected ({ram_used:.0f}MB) at frame {ind}")

            # # Memory snapshot for analysis (requires tracemalloc.start() above)
            # snapshot = tracemalloc.take_snapshot()
            # top_stats = snapshot.statistics("lineno")
            # print("\n[Top Memory Consumers]")
            # for stat in top_stats[:5]:
            #     print(stat)

  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
Processing frames: 100%|██████████| 602/602 [09:27<00:00,  1.06it/s]


Assemble the saved overlay frames and mask frames into two videos

In [7]:
def _frames_to_video(file_names, out_path, fps, frame_dir=save_path):
    """Encode image files, in the given order, into an mp4 (mp4v) video.

    Args:
        file_names: non-empty list of image file names inside ``frame_dir``,
            already in playback order.
        out_path: destination video path.
        fps: frame rate written into the video.
        frame_dir: directory containing the images.

    Raises:
        OSError: if the first image is unreadable or the writer cannot be opened.
    """
    first_image = cv2.imread(os.path.join(frame_dir, file_names[0]))
    if first_image is None:
        raise OSError(f"Could not read {file_names[0]}")
    # The first frame fixes the output size
    height, width = first_image.shape[:2]

    writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    if not writer.isOpened():
        raise OSError(f"Could not open video writer for {out_path}")
    try:
        for name in file_names:
            frame = cv2.imread(os.path.join(frame_dir, name))
            if frame is None:
                print(f"Skipping unreadable image {name}")
                continue
            # cv2.VideoWriter silently drops frames whose size differs from the
            # size it was opened with, so force every frame to that size
            if frame.shape[:2] != (height, width):
                frame = cv2.resize(frame, (width, height))
            writer.write(frame)
    finally:
        writer.release()


# Match the source video's frame rate (fall back to 30 fps if it cannot be read)
_source = cv2.VideoCapture(video_path)
output_fps = _source.get(cv2.CAP_PROP_FPS) or 30.0
_source.release()

# Overlay frames are named "output_<frame>", mask frames "output_mask_<frame>"
image_files = sorted([f for f in os.listdir(save_path) if f.endswith(('.png', '.jpg', '.jpeg')) and 'output' in f and 'output_output' not in f and 'output_mask' not in f])
mask_files = sorted([f for f in os.listdir(save_path) if f.endswith(('.png', '.jpg', '.jpeg')) and 'output_mask' in f])

# Overlay video
if len(image_files) > 0:
    _frames_to_video(image_files, final_video, output_fps)
    print(f"Video saved as {final_video}")
else:
    print("No images found in the directory")

# Mask-only video
if len(mask_files) > 0:
    _frames_to_video(mask_files, mask_video, output_fps)
    print(f"Mask video saved as {mask_video}")
else:
    print("No mask images found in the directory")

Video saved as workspace/results_wave/output_video.mp4
Mask video saved as workspace/results_wave/output_mask_video.mp4


Create a single side-by-side video showing the original video, the mask video, and the overlay video

In [8]:
# Open the three videos: original, contour-only mask, and overlay
cap1 = cv2.VideoCapture(video_path)
cap2 = cv2.VideoCapture(mask_video)
cap3 = cv2.VideoCapture(final_video)
captures = (cap1, cap2, cap3)

try:
    # Fail loudly instead of silently writing an empty combined video
    for path, cap in zip((video_path, mask_video, final_video), captures):
        if not cap.isOpened():
            raise OSError(f"Could not open video: {path}")

    # Output geometry and rate come from the original video. FPS stays a float:
    # int() would turn e.g. 29.97 into 29 and make the combined video drift.
    fps = cap1.get(cv2.CAP_PROP_FPS)
    width = int(cap1.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap1.get(cv2.CAP_PROP_FRAME_HEIGHT))

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for MP4
    out = cv2.VideoWriter(combined_video_path, fourcc, fps, (width * 3, height))
    try:
        while True:
            # Read one frame from each video; stop as soon as any of them ends
            reads = [cap.read() for cap in captures]
            if not all(ok for ok, _ in reads):
                break

            # Resize every panel to the original's size so they stack cleanly
            panels = [cv2.resize(frame, (width, height)) for _, frame in reads]
            out.write(np.hstack(panels))
    finally:
        out.release()

    print(f"Video 1 - Width: {int(cap1.get(cv2.CAP_PROP_FRAME_WIDTH))}, Height: {int(cap1.get(cv2.CAP_PROP_FRAME_HEIGHT))}, FPS: {int(cap1.get(cv2.CAP_PROP_FPS))}")
    print(f"Video 2 - Width: {int(cap2.get(cv2.CAP_PROP_FRAME_WIDTH))}, Height: {int(cap2.get(cv2.CAP_PROP_FRAME_HEIGHT))}, FPS: {int(cap2.get(cv2.CAP_PROP_FPS))}")
    print(f"Video 3 - Width: {int(cap3.get(cv2.CAP_PROP_FRAME_WIDTH))}, Height: {int(cap3.get(cv2.CAP_PROP_FRAME_HEIGHT))}, FPS: {int(cap3.get(cv2.CAP_PROP_FPS))}")
finally:
    # Release the captures even if reading or writing failed
    for cap in captures:
        cap.release()

print(f"Combined video saved as {combined_video_path}")

Video 1 - Width: 1024, Height: 1024, FPS: 30
Video 2 - Width: 1024, Height: 1024, FPS: 30
Video 3 - Width: 1024, Height: 1024, FPS: 30
Combined video saved as workspace/results_wave/combined_output_video.mp4


# SAM2 (reference experiments, currently commented out)

In [9]:
#https://ai.meta.com/blog/segment-anything-2/?utm_source=twitter&utm_medium=organic_social&utm_content=reel&utm_campaign=sam2
#Load image from desktop
# from PIL import Image
# import numpy as np
# import matplotlib.pyplot as plt
# import os
# import torch
# import cv2
# from sam2.build_sam import build_sam2
# from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
# from sam2.sam2_image_predictor import SAM2ImagePredictor

In [10]:
# image_path = frames_path+"/frame_0180.jpg"#"/workspace/wave1.png"#os.path.join(os.path.expanduser("~"), "Desktop", "wave1.png")
# image = Image.open(image_path)
# image_np = np.array(image.convert("RGB"))
# image

In [11]:
# download_url="https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"

In [12]:
# #Taken from https://github.com/facebookresearch/sam2/blob/main/notebooks/automatic_mask_generator_example.ipynb
# # checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
# # model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
# # checkpoint = "/workspace/sam2/checkpoints/sam2.1_hiera_large.pt"
# # model_cfg = "/workspace/sam2/configs/sam2.1/sam2.1_hiera_l.yaml"
# checkpoint = desktop_path+"/sam2.1_hiera_large.pt"
# model_cfg = desktop_path+"/sam2.1_hiera_l.yaml"

# # mask_generator = SAM2AutomaticMaskGenerator(build_sam2(model_cfg, checkpoint, device="cuda"))
# sam2 = build_sam2(model_cfg, checkpoint, device="cuda", apply_postprocessing=False)
# mask_generator=SAM2AutomaticMaskGenerator(
#     model=sam2,
#     points_per_side=60,
#     # points_per_batch=128,
#     # pred_iou_thresh=0.7,
#     # stability_score_thresh=0.92,
#     # stability_score_offset=0.7,
#     # crop_n_layers=1,
#     # box_nms_thresh=0.7,
#     # crop_n_points_downscale_factor=2,
#     # min_mask_region_area=25.0,
#     # use_m2m=True,
# )

# masks = mask_generator.generate(image_np)

In [13]:
# np.random.seed(3)

# def show_anns(anns, borders=True):
#     if len(anns) == 0:
#         return
#     sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
#     ax = plt.gca()
#     ax.set_autoscale_on(False)

#     img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
#     img[:, :, 3] = 0
#     for ann in sorted_anns:
#         m = ann['segmentation']
#         color_mask = np.concatenate([np.random.random(3), [0.5]])
#         img[m] = color_mask 
#         if borders:
#             contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 
#             # Try to smooth contours
#             contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
#             cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1) 

#     ax.imshow(img)

In [14]:
# plt.figure(figsize=(20, 20))
# plt.imshow(image)
# show_anns(masks)
# plt.axis('off')
# plt.show() 

Image Predictor

In [15]:
# # checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
# # model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
# predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))

# input_point = np.array([[300, 150]])
# input_label = np.array([1])

# with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
#     predictor.set_image(image_np)
#     masks, scores, logits = predictor.predict(point_coords=input_point,
#     point_labels=input_label,
#     multimask_output=True)

In [16]:
# np.random.seed(3)

# def show_mask(mask, ax, random_color=False, borders = True):
#     if random_color:
#         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
#     else:
#         color = np.array([30/255, 144/255, 255/255, 0.6])
#     h, w = mask.shape[-2:]
#     mask = mask.astype(np.uint8)
#     mask_image =  mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
#     if borders:
#         contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 
#         # Try to smooth contours
#         contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
#         mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2) 
#     ax.imshow(mask_image)

# def show_points(coords, labels, ax, marker_size=375):
#     pos_points = coords[labels==1]
#     neg_points = coords[labels==0]
#     ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
#     ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   

# def show_box(box, ax):
#     x0, y0 = box[0], box[1]
#     w, h = box[2] - box[0], box[3] - box[1]
#     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))    

# def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
#     for i, (mask, score) in enumerate(zip(masks, scores)):
#         plt.figure(figsize=(10, 10))
#         plt.imshow(image)
#         show_mask(mask, plt.gca(), borders=borders)
#         if point_coords is not None:
#             assert input_labels is not None
#             show_points(point_coords, input_labels, plt.gca())
#         if box_coords is not None:
#             # boxes
#             show_box(box_coords, plt.gca())
#         if len(scores) > 1:
#             plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
#         plt.axis('off')
#         plt.show()

In [17]:
# #Load image from desktop
# path='./test_video/'
# for path_img in os.listdir(path):
#     image_path = path+path_img
#     image = Image.open(image_path)
#     image_np = np.array(image.convert("RGB"))
#     show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True)

In [18]:
# mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

# masks, scores, _ = predictor.predict(
#     point_coords=input_point,
#     point_labels=input_label,
#     mask_input=mask_input[None, :, :],
#     multimask_output=False,
# )

In [19]:
# show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label)