In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
import os
from tqdm import tqdm
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Make float32 the default dtype for newly created floating-point torch tensors
torch.set_default_dtype(torch.float32)
# Print NumPy floats with 6 digits of precision.
# NOTE: this only changes how arrays are displayed; it does not change the
# dtype of any array.
np.set_printoptions(precision=6)

# Module-level handle to the mask generator; it is assigned after
# load_sam_model() is defined and may be replaced by process_video()
mask_generator = None

# SAM checkpoint file and the matching key into sam_model_registry
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
# Use the MPS backend when torch reports it as available, otherwise the CPU
device = "mps" if torch.backends.mps.is_available() else "cpu"  # MPS is the macOS GPU backend
print(f"Using device: {device}")

def load_sam_model(device="mps"):
    """Build a SAM automatic mask generator whose model lives on ``device``.

    The checkpoint file and registry key come from the module-level
    ``sam_checkpoint`` and ``model_type`` settings.

    Args:
        device: Device string handed to ``Module.to`` (e.g. ``"mps"``/``"cpu"``).

    Returns:
        A ``SamAutomaticMaskGenerator`` wrapping the freshly loaded model.
    """
    build_model = sam_model_registry[model_type]
    model = build_model(checkpoint=sam_checkpoint)
    model.to(device=device)

    # Generator settings, same values as in the notebook
    generator_options = {
        "points_per_side": 32,
        "pred_iou_thresh": 0.9,
        "stability_score_thresh": 0.96,
        "crop_n_layers": 1,
        "crop_n_points_downscale_factor": 2,
        # Requires open-cv to run post-processing
        "min_mask_region_area": 100,
    }
    return SamAutomaticMaskGenerator(model=model, **generator_options)

# Build the shared mask generator once at import time on the selected device;
# process_video() may later replace it with a CPU-backed generator
mask_generator = load_sam_model(device)

def apply_masks_to_frame(frame, masks):
    """Blend randomly colored, semi-transparent masks over a frame.

    Args:
        frame: Image array of shape (H, W, 3) with values on a 0-255 scale.
        masks: Iterable of dicts providing ``'segmentation'`` (an index usable
            on an (H, W, ...) array, e.g. a boolean mask) and ``'area'``.

    Returns:
        A uint8 array of shape (H, W, 3): the frame with every masked pixel
        mixed 65/35 with that mask's random color.
    """
    height, width = frame.shape[0], frame.shape[1]

    # RGBA overlay: white color channels, alpha 0 everywhere (fully transparent)
    overlay = np.zeros((height, width, 4), dtype=np.float32)
    overlay[..., :3] = 1.0

    # Paint big masks first so smaller ones end up on top of them
    by_area_desc = sorted(masks, key=lambda entry: entry['area'], reverse=True)
    for entry in by_area_desc:
        rgba = np.append(np.random.random(3), 0.35)  # random RGB + fixed alpha
        overlay[entry['segmentation']] = rgba

    # Opaque RGBA copy of the frame scaled to 0..1
    background = np.empty((height, width, 4), dtype=np.float32)
    background[..., 3] = 1.0
    background[..., :3] = frame / 255.0

    # Standard alpha blend driven by the overlay's alpha channel
    alpha = overlay[..., 3:4]
    blended = background * (1 - alpha) + overlay * alpha

    # Drop alpha and go back to 8-bit RGB
    return (blended[..., :3] * 255).astype(np.uint8)

def create_binary_mask_visualization(frame_shape, masks):
    """Paint every mask in its own random color on a black canvas.

    Args:
        frame_shape: Shape of the source frame; only the first two entries
            (height, width) are used.
        masks: Iterable of dicts providing ``'segmentation'`` and ``'area'``.

    Returns:
        A uint8 array of shape (H, W, 3). Pixels not covered by any mask stay 0.
    """
    canvas = np.zeros(tuple(frame_shape[:2]) + (3,), dtype=np.uint8)

    # Largest first, so smaller masks overwrite the larger ones they overlap
    ordered = list(masks)
    ordered.sort(key=lambda entry: entry['area'], reverse=True)

    for entry in ordered:
        canvas[entry['segmentation']] = np.random.randint(0, 255, 3, dtype=np.uint8)

    return canvas

def process_video(input_video_path, output_segmentation_path, output_masks_path, use_cpu=False):
    """Run SAM on every frame of a video and write two visualization videos.

    Both outputs are written with the ``mp4v`` codec at the frame rate and
    frame size reported by the input video.

    Args:
        input_video_path: Path of the video to read.
        output_segmentation_path: Destination of the video showing each frame
            with the colored mask overlay.
        output_masks_path: Destination of the video showing only the colored
            masks on a black background.
        use_cpu: When true and the module-level ``device`` is not ``"cpu"``,
            the global mask generator is rebuilt on the CPU before processing.

    Returns:
        None. If the input cannot be opened an error is printed and the
        function returns without creating any output.
    """
    global mask_generator

    # If forcing CPU, reload the model on CPU
    if use_cpu and device != "cpu":
        print("Switching to CPU as requested")
        mask_generator = load_sam_model("cpu")

    # Open the input video
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video file {input_video_path}")
        cap.release()
        return

    # Get video properties
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    print(f"Video properties: {width}x{height}, {fps} fps, {total_frames} frames")

    # Create output video writers
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    segmentation_writer = cv2.VideoWriter(output_segmentation_path, fourcc, fps, (width, height))
    masks_writer = cv2.VideoWriter(output_masks_path, fourcc, fps, (width, height))

    frame_count = 0
    progress_bar = tqdm(total=total_frames, desc="Processing frames")

    # try/finally so the capture, both writers and the progress bar are
    # released even when a frame fails; otherwise the partially written
    # output files are never finalized.
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Convert from BGR to RGB (SAM expects RGB). The result is passed
            # on as-is: the former uint8 -> float32 -> uint8 round trip
            # returned the same 8-bit values and has been dropped.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Generate masks for the current frame
            try:
                masks = mask_generator.generate(frame_rgb)
                print(f"Frame {frame_count}: Generated {len(masks)} masks")
            except TypeError as e:
                if "Cannot convert a MPS Tensor to float64" not in str(e):
                    raise
                # The CPU generator replaces the global one, so this frame and
                # every later frame (and later call) runs on the CPU.
                print("MPS float64 issue detected. Falling back to CPU for this frame.")
                mask_generator = load_sam_model("cpu")
                masks = mask_generator.generate(frame_rgb)
                print(f"Frame {frame_count}: Generated {len(masks)} masks (CPU fallback)")

            # Create visualization with segmentation overlay
            segmentation_frame = apply_masks_to_frame(frame_rgb, masks)
            segmentation_frame_bgr = cv2.cvtColor(segmentation_frame, cv2.COLOR_RGB2BGR)

            # Create mask visualization
            masks_vis = create_binary_mask_visualization(frame.shape, masks)

            # Write to output videos
            segmentation_writer.write(segmentation_frame_bgr)
            masks_writer.write(masks_vis)

            frame_count += 1
            progress_bar.update(1)
    finally:
        # Clean up
        progress_bar.close()
        cap.release()
        segmentation_writer.release()
        masks_writer.release()
        cv2.destroyAllWindows()

    print(f"Processing complete. Processed {frame_count} frames.")
    print(f"Output videos saved to:\n- {output_segmentation_path}\n- {output_masks_path}")

if __name__ == "__main__":
    # Source clip and the two videos derived from it
    input_video, output_segmentation, output_masks = (
        "pottery.mp4",
        "pottery_segmentation.mp4",
        "pottery_masks.mp4",
    )

    # Flip to True to force CPU inference when MPS misbehaves
    use_cpu = False

    process_video(
        input_video,
        output_segmentation,
        output_masks,
        use_cpu=use_cpu,
    )

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
import os
from tqdm import tqdm
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Make float32 the default dtype for newly created floating-point torch tensors
torch.set_default_dtype(torch.float32)
# Print NumPy floats with 6 digits of precision.
# NOTE: this only changes how arrays are displayed, not the dtype of any array.
np.set_printoptions(precision=6)

# Best-effort guard against explicit float64 requests while MPS is in use
if torch.backends.mps.is_available():
    # Public API for the float32 matmul precision setting (replaces the
    # private torch._C call). It tunes float32 matmuls only and has no
    # effect on float64 tensors.
    torch.set_float32_matmul_precision('high')

    # Install each patch only once. Re-running this cell used to rebind
    # `original_tensor_method` to the already patched function, which then
    # called itself and recursed without end.
    if not getattr(torch.Tensor.to, "_float32_guard", False):
        original_tensor_method = torch.Tensor.to

        def patched_to_method(self, *args, **kwargs):
            """Forward to ``Tensor.to`` with any float64 dtype swapped for float32.

            The dtype is recognised in any positional slot (``.to(dtype)``,
            ``.to(device, dtype)``) and as the ``dtype=`` keyword.
            """
            dtype_kw_is_f64 = kwargs.get('dtype') is torch.float64
            if dtype_kw_is_f64 or any(arg is torch.float64 for arg in args):
                print("Intercepted attempt to convert tensor to float64, using float32 instead")
                # Identity checks avoid comparing tensors/devices with a dtype
                args = tuple(torch.float32 if arg is torch.float64 else arg for arg in args)
                if dtype_kw_is_f64:
                    kwargs['dtype'] = torch.float32
            return original_tensor_method(self, *args, **kwargs)

        # Marker read by the guard above on later runs
        patched_to_method._float32_guard = True
        torch.Tensor.to = patched_to_method

    # Same treatment for explicit `torch.tensor(..., dtype=torch.float64)`.
    # NOTE: a dtype inferred from the data (e.g. a float64 NumPy array) is
    # not intercepted here.
    if not getattr(torch.tensor, "_float32_guard", False):
        original_tensor = torch.tensor

        def patched_tensor(*args, **kwargs):
            """Forward to ``torch.tensor`` with ``dtype=float64`` swapped for float32."""
            if kwargs.get('dtype') is torch.float64:
                print("Intercepted attempt to create float64 tensor, using float32 instead")
                kwargs['dtype'] = torch.float32
            return original_tensor(*args, **kwargs)

        patched_tensor._float32_guard = True
        torch.tensor = patched_tensor

# Module-level handle to the mask generator; it is assigned after
# load_sam_model() is defined and may be replaced by process_video()
mask_generator = None

# SAM checkpoint file and the matching key into sam_model_registry
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
# Use the MPS backend when torch reports it as available, otherwise the CPU
device = "mps" if torch.backends.mps.is_available() else "cpu"  # MPS is the macOS GPU backend
print(f"Using device: {device}")

def load_sam_model(device="mps"):
    """Build a SAM automatic mask generator with a float32 model on ``device``.

    The checkpoint file and registry key come from the module-level
    ``sam_checkpoint`` and ``model_type`` settings.

    Args:
        device: Device string handed to ``Module.to`` (e.g. ``"mps"``/``"cpu"``).

    Returns:
        A ``SamAutomaticMaskGenerator`` wrapping the freshly loaded model.
    """
    model = sam_model_registry[model_type](checkpoint=sam_checkpoint)

    # Cast every parameter to float32 while the model is still where it was loaded
    for weight in model.parameters():
        weight.data = weight.data.to(torch.float32)

    # Relocate first, then restate the dtype on the target device
    model.to(device=device)
    model.to(dtype=torch.float32)

    # Any buffer still reporting float64 is cast down as well
    for buf in model.buffers():
        if buf.data.dtype != torch.float64:
            continue
        buf.data = buf.data.to(torch.float32)

    # Generator settings, same values as in the notebook
    generator_options = {
        "points_per_side": 32,
        "pred_iou_thresh": 0.9,
        "stability_score_thresh": 0.96,
        "crop_n_layers": 1,
        "crop_n_points_downscale_factor": 2,
        # Requires open-cv to run post-processing
        "min_mask_region_area": 100,
    }
    return SamAutomaticMaskGenerator(model=model, **generator_options)

# Build the shared mask generator once at import time on the selected device;
# process_video() may later replace it (forced CPU, MPS reload, or CPU fallback)
mask_generator = load_sam_model(device)

def apply_masks_to_frame(frame, masks):
    """Blend randomly colored, semi-transparent masks over a frame.

    Args:
        frame: Image array of shape (H, W, 3) with values on a 0-255 scale.
        masks: Iterable of dicts providing ``'segmentation'`` (an index usable
            on an (H, W, ...) array, e.g. a boolean mask) and ``'area'``.

    Returns:
        A uint8 array of shape (H, W, 3): the frame with every masked pixel
        mixed 65/35 with that mask's random color.
    """
    rows, cols = frame.shape[0], frame.shape[1]

    # Color and opacity are tracked separately; opacity 0 means "show the frame"
    tint = np.ones((rows, cols, 3), dtype=np.float32)
    opacity = np.zeros((rows, cols, 1), dtype=np.float32)

    # Largest masks are painted first so smaller ones land on top
    for item in sorted(masks, key=lambda entry: entry['area'], reverse=True):
        region = item['segmentation']
        rgba = np.concatenate([np.random.random(3), [0.35]])  # RGB + alpha
        tint[region] = rgba[:3]
        opacity[region] = rgba[3]

    # Frame scaled to 0..1 in float32
    base = np.empty((rows, cols, 3), dtype=np.float32)
    base[...] = frame / 255.0

    # Alpha blend, then back to 8-bit RGB
    mixed = base * (1 - opacity) + tint * opacity
    return (mixed * 255).astype(np.uint8)

def create_binary_mask_visualization(frame_shape, masks):
    """Paint every mask in its own random color on a black canvas.

    Args:
        frame_shape: Shape of the source frame; only the first two entries
            (height, width) are used.
        masks: Iterable of dicts providing ``'segmentation'`` and ``'area'``.

    Returns:
        A uint8 array of shape (H, W, 3). Pixels not covered by any mask stay 0.
    """
    canvas_shape = tuple(frame_shape[:2]) + (3,)
    canvas = np.zeros(canvas_shape, dtype=np.uint8)

    def mask_area(entry):
        return entry['area']

    # Descending area: later (smaller) masks overwrite earlier (larger) ones
    for entry in sorted(masks, key=mask_area, reverse=True):
        fill = np.random.randint(0, 255, 3, dtype=np.uint8)
        canvas[entry['segmentation']] = fill

    return canvas

def process_video(input_video_path, output_segmentation_path, output_masks_path, use_cpu=False, frame_skip=0, resize_factor=1.0):
    """Run SAM on the frames of a video and write two visualization videos.

    Both outputs use the ``mp4v`` codec. Their frame size is the source size
    scaled by ``resize_factor`` and their frame rate is the source rate
    divided by ``frame_skip + 1``, so the clips keep the source duration.
    Every 10th processed frame is also saved as ``progress_frame_<n>.jpg`` in
    the working directory.

    Args:
        input_video_path: Path of the video to read.
        output_segmentation_path: Destination of the video showing each frame
            with the colored mask overlay.
        output_masks_path: Destination of the video showing only the colored
            masks on a black background.
        use_cpu: When true and the module-level ``device`` is not ``"cpu"``,
            the global mask generator is rebuilt on the CPU before processing.
        frame_skip: Number of frames to drop after each processed frame
            (0 or less processes every frame).
        resize_factor: Scale applied to width and height before segmentation.

    Returns:
        None. If the input cannot be opened an error is printed and the
        function returns without creating any output.

    Raises:
        ValueError: If ``resize_factor`` is not positive.
    """
    global mask_generator

    if resize_factor <= 0:
        raise ValueError(f"resize_factor must be positive, got {resize_factor}")

    # `active_device` tracks where the global generator really lives so the
    # per-frame log line can report it truthfully.
    if use_cpu and device != "cpu":
        print("Switching to CPU as requested")
        mask_generator = load_sam_model("cpu")
        active_device = "cpu"
    else:
        # For MPS, reload so the generator is known to be the float32 MPS one
        # (an earlier call may have swapped in a CPU fallback).
        if device == "mps":
            print("Ensuring all model parameters are float32 for MPS compatibility")
            mask_generator = load_sam_model("mps")
        active_device = device

    # Open the input video
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video file {input_video_path}")
        cap.release()
        return

    # Get video properties
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    print(f"Video properties: {width}x{height}, {fps} fps, {total_frames} frames")

    # The writers must be opened with the size of the frames that will be
    # written: the frames are resized below, so the writer size follows the
    # resize factor instead of the source size.
    out_size = (int(width * resize_factor), int(height * resize_factor))
    # Only every `step`-th frame is written, so lower the output frame rate
    # accordingly to keep the playback duration.
    step = frame_skip + 1 if frame_skip > 0 else 1
    out_fps = fps / step

    # Create output video writers
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    segmentation_writer = cv2.VideoWriter(output_segmentation_path, fourcc, out_fps, out_size)
    masks_writer = cv2.VideoWriter(output_masks_path, fourcc, out_fps, out_size)

    frame_count = 0
    processed_count = 0
    progress_bar = tqdm(total=total_frames, desc="Processing frames")

    # try/finally so the capture, both writers and the progress bar are
    # released even when a frame fails.
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Optional frame skipping
            if frame_count % step != 0:
                frame_count += 1
                progress_bar.update(1)
                continue

            # Bring the frame to the writer size (covers resize_factor and any
            # mismatch between the reported and the decoded frame size)
            if (frame.shape[1], frame.shape[0]) != out_size:
                frame = cv2.resize(frame, out_size)

            # Convert from BGR to RGB (SAM expects RGB). The former repeated
            # uint8 -> float32 -> uint8 round trips returned the same 8-bit
            # values and have been dropped.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Generate masks for the current frame
            try:
                with torch.no_grad():  # Disable gradient tracking for inference
                    masks = mask_generator.generate(frame_rgb)
                device_used = "MPS (GPU)" if active_device == "mps" else "CPU"
            except TypeError as e:
                if "Cannot convert a MPS Tensor to float64" not in str(e):
                    raise
                # The CPU generator replaces the global one, so this frame and
                # every later frame runs on the CPU.
                print("MPS float64 issue detected. Falling back to CPU for this frame.")
                mask_generator = load_sam_model("cpu")
                active_device = "cpu"
                with torch.no_grad():
                    masks = mask_generator.generate(frame_rgb)
                device_used = "CPU fallback"
            except Exception as e:
                print(f"Error processing frame {frame_count}: {str(e)}")
                if active_device == "cpu":
                    # Already on the CPU: reloading the model and retrying
                    # would only repeat the same failure.
                    raise
                # Try CPU as last resort
                mask_generator = load_sam_model("cpu")
                active_device = "cpu"
                with torch.no_grad():
                    masks = mask_generator.generate(frame_rgb)
                device_used = "CPU fallback (after error)"

            print(f"Frame {frame_count}: Generated {len(masks)} masks ({device_used})")

            # Create visualization with segmentation overlay
            segmentation_frame = apply_masks_to_frame(frame_rgb, masks)
            segmentation_frame_bgr = cv2.cvtColor(segmentation_frame, cv2.COLOR_RGB2BGR)

            # Create mask visualization
            masks_vis = create_binary_mask_visualization(frame.shape, masks)

            # Write to output videos
            segmentation_writer.write(segmentation_frame_bgr)
            masks_writer.write(masks_vis)

            frame_count += 1
            processed_count += 1
            progress_bar.update(1)

            # Save a frame periodically to check progress
            if processed_count % 10 == 0:
                cv2.imwrite(f"progress_frame_{processed_count}.jpg", segmentation_frame_bgr)
    finally:
        # Clean up
        progress_bar.close()
        cap.release()
        segmentation_writer.release()
        masks_writer.release()
        cv2.destroyAllWindows()

    print(f"Processing complete. Processed {processed_count}/{frame_count} frames.")
    print(f"Output videos saved to:\n- {output_segmentation_path}\n- {output_masks_path}")

if __name__ == "__main__":
    # Source clip and the two videos derived from it
    input_video, output_segmentation, output_masks = (
        "pottery.mp4",
        "pottery_segmentation.mp4",
        "pottery_masks.mp4",
    )

    # Flip to True to force CPU inference when MPS misbehaves
    use_cpu = False

    # Frames dropped after each processed one (0 = process every frame)
    frame_skip = 0

    # Scale for width and height (1.0 = original size); 0.5 halves both,
    # leaving a quarter of the pixels to segment
    resize_factor = 0.5

    process_video(
        input_video,
        output_segmentation,
        output_masks,
        use_cpu=use_cpu,
        frame_skip=frame_skip,
        resize_factor=resize_factor,
    )

Using device: mps
Ensuring all model parameters are float32 for MPS compatibility
Video properties: 1080x1920, 25.0 fps, 84 frames



[Acessing frames:   0%|                                                                      | 0/84 [00:00<?, ?it/s]

MPS float64 issue detected. Falling back to CPU for this frame.



[Acessing frames:   1%|▋                                                          | 1/84 [01:49<2:32:00, 109.89s/it]

Frame 0: Generated 31 masks (CPU fallback)



[Acessing frames:   2%|█▍                                                          | 2/84 [03:17<2:12:22, 96.86s/it]

Frame 1: Generated 31 masks (MPS (GPU))



[Acessing frames:   4%|██▏                                                         | 3/84 [04:45<2:04:56, 92.55s/it]

Frame 2: Generated 29 masks (MPS (GPU))



[Acessing frames:   5%|██▊                                                         | 4/84 [06:10<1:59:45, 89.82s/it]

Frame 3: Generated 29 masks (MPS (GPU))



[Acessing frames:   6%|███▌                                                        | 5/84 [07:35<1:55:40, 87.85s/it]

Frame 4: Generated 32 masks (MPS (GPU))



[Acessing frames:   7%|████▎                                                       | 6/84 [08:59<1:52:53, 86.83s/it]

Frame 5: Generated 30 masks (MPS (GPU))



[Acessing frames:   8%|█████                                                       | 7/84 [10:32<1:54:00, 88.84s/it]

Frame 6: Generated 32 masks (MPS (GPU))



[Acessing frames:  10%|█████▋                                                      | 8/84 [12:00<1:52:09, 88.55s/it]

Frame 7: Generated 30 masks (MPS (GPU))



[Acessing frames:  11%|██████▍                                                     | 9/84 [13:27<1:49:50, 87.88s/it]

Frame 8: Generated 29 masks (MPS (GPU))



[Acessing frames:  12%|███████                                                    | 10/84 [14:53<1:47:41, 87.32s/it]

Frame 9: Generated 33 masks (MPS (GPU))



[Acessing frames:  13%|███████▋                                                   | 11/84 [16:26<1:48:34, 89.24s/it]

Frame 10: Generated 30 masks (MPS (GPU))



[Acessing frames:  14%|████████▍                                                  | 12/84 [17:53<1:46:18, 88.59s/it]

Frame 11: Generated 31 masks (MPS (GPU))



[Acessing frames:  15%|█████████▏                                                 | 13/84 [19:21<1:44:32, 88.34s/it]

Frame 12: Generated 31 masks (MPS (GPU))



[Acessing frames:  17%|█████████▊                                                 | 14/84 [20:46<1:41:55, 87.36s/it]

Frame 13: Generated 33 masks (MPS (GPU))



[Acessing frames:  18%|██████████▌                                                | 15/84 [22:12<1:39:57, 86.92s/it]

Frame 14: Generated 36 masks (MPS (GPU))



[Acessing frames:  19%|███████████▏                                               | 16/84 [23:46<1:41:00, 89.12s/it]

Frame 15: Generated 33 masks (MPS (GPU))



[Acessing frames:  20%|███████████▉                                               | 17/84 [25:25<1:42:42, 91.98s/it]

Frame 16: Generated 34 masks (MPS (GPU))



[Acessing frames:  21%|████████████▋                                              | 18/84 [27:04<1:43:28, 94.06s/it]

Frame 17: Generated 37 masks (MPS (GPU))
