In [2]:
# Background removal for video
import os
import cv2
import numpy as np
import sys
from tqdm import tqdm
sys.path.append("..")  # Adjust path accordingly
from segment_anything import sam_model_registry, SamPredictor
from yolov5.detect import YOLODetector

# Initialize the YOLO detector and the SAM predictor
detector = YOLODetector(weights='./yolov5/yolov5m.pt')
sam_checkpoint = "vit_h.pth"
model_type = "vit_h"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

# Input video and destination of the generated mask video
video_path = './input/golf6.mp4'
output_video_path = './output/masked_video.avi'

# Open the input video and stop before querying any property if that fails
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    print("Error opening video file")
    cap.release()
    sys.exit(1)  # non-zero status: this is a failure, not a normal exit

out = None
try:
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )

    # Make sure the destination folder exists before the writer is created
    os.makedirs(os.path.dirname(output_video_path) or '.', exist_ok=True)
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, frame_size)
    if not out.isOpened():
        # Fail here instead of running the whole segmentation and writing nothing
        print("Error opening output video file")
        sys.exit(1)

    frames = []
    bounding_boxes = []  # not populated in this cell; kept in case other cells reference it

    # First phase: load the frames, then run detection on the video file.
    # Frames are read until the capture reports no more of them; the reported
    # frame count is only used to size the progress bar.
    with tqdm(total=frame_count if frame_count > 0 else None, desc="Detecting objects") as progress:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
            progress.update(1)

    # assumes Prediction yields one entry per frame, in frame order — TODO confirm
    boxes = detector.Prediction(video_path)

    # Second phase: Segmentation
    # `frame_boxes` (not `boxes`) so the detections list is not shadowed by the loop.
    for frame, frame_boxes in tqdm(zip(frames, boxes), total=len(frames), desc="Applying segmentation"):
        if frame_boxes:
            # SAM is fed RGB while OpenCV decodes frames as BGR
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            predictor.set_image(rgb_frame)
            masks, _, _ = predictor.predict(
                point_coords=None,
                point_labels=None,
                box=np.array([frame_boxes])[None, :],
                multimask_output=False,
            )
            # NOTE(review): predict() presumably already returns binary masks, in
            # which case this threshold is a no-op — confirm.
            binary_mask = (masks[0] > 0.3).astype(np.uint8) * 255
            # The written frame is the white-on-black mask replicated over three
            # channels, not the original pixels with the background removed.
            segmented_frame = np.stack([binary_mask] * 3, axis=-1)
            out.write(segmented_frame)
        else:
            # No detection for this frame: pass the original frame through
            out.write(frame)
finally:
    # Cleanup, also when a step above raised or exited
    cap.release()
    if out is not None:
        out.release()

print("Finished processing video.")


Fusing layers... 
YOLOv5m summary: 290 layers, 21172173 parameters, 0 gradients
Detecting objects: 100%|██████████| 98/98 [00:00<00:00, 731.80it/s]
Applying segmentation: 100%|██████████| 98/98 [00:34<00:00,  2.81it/s]

Finished processing video.





In [None]:
import cv2
import numpy as np
import argparse
import torch
import torchvision
from PIL import Image
import psutil
import shutil
from tqdm import tqdm
import os
# from track_anything import TrackingAnything, parse_augment
from metaseg import SegAutoMaskPredictor, SegManualMaskPredictor, SahiAutoSegmentation, sahi_sliced_predict
from segment_anythingss import sam_model_registry, SamPredictor
from yolov5m.detect import YOLODetector


# Load the SAM checkpoint and wrap the model in a predictor.
sam_checkpoint = "vit_h.pth"
model_type = "vit_h"
# Prefer the GPU, but fall back to the CPU so the cell still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

def get_frames_from_video(video_path):
    """Decode a video into a list of RGB frames.

    Frames are read with ``cv2.VideoCapture`` and converted from BGR to RGB.
    Extraction stops early, keeping the frames collected so far, as soon as
    system memory usage reported by ``psutil`` exceeds 90%.

    Args:
        video_path: Path of the video file to read.

    Returns:
        list: The decoded frames in order; empty if the video cannot be opened
        or contains no readable frame.
    """
    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            # Every decoded frame stays in memory, so check the headroom after each one.
            if psutil.virtual_memory().percent > 90:
                print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.")
                break
    finally:
        # Release the capture even if decoding or conversion raised
        cap.release()
    return frames

def process_frame_with_sahi(frame_path, sam_model_type, detection_model_path, output_mask_dir):
    """Run sliced detection and SAM segmentation on one frame image.

    The frame is passed to ``sahi_sliced_predict`` and the result is handed to
    ``SahiAutoSegmentation.predict`` as ``input_box``. The first mask of the
    segmentation output is returned.

    Args:
        frame_path: Path of the frame image on disk.
        sam_model_type: SAM model type forwarded to the segmenter.
        detection_model_path: Weights file forwarded to the sliced detector.
        output_mask_dir: Directory that is created if missing; this function
            does not write any file into it.

    Returns:
        numpy.ndarray: A 2-D ``uint8`` mask. Non-``uint8`` masks are cast to
        ``uint8`` and multiplied by 255. An all-zero mask is returned when the
        segmenter yields no mask or an output that is neither 3-D nor 4-D.
    """
    os.makedirs(output_mask_dir, exist_ok=True)

    boxes = sahi_sliced_predict(
        image_path=frame_path,
        detection_model_type="yolov8",
        detection_model_path=detection_model_path,
        conf_th=0.5,
        image_size=640,
        slice_height=256,
        slice_width=256,
        overlap_height_ratio=0.2,
        overlap_width_ratio=0.2
    )

    # NOTE(review): a new SahiAutoSegmentation is built for every frame; if its
    # constructor or predict() loads model weights this is costly — confirm.
    sahi_segmentation = SahiAutoSegmentation()
    mask_image = sahi_segmentation.predict(
        source=frame_path,
        model_type=sam_model_type,
        input_box=boxes,
        multimask_output=False,
        random_color=False,
        show=False,
        save=False
    )

    if isinstance(mask_image, torch.Tensor):
        # detach() so a tensor that tracks gradients can still be converted
        mask_image_np = mask_image.detach().cpu().numpy()
    else:
        mask_image_np = np.array(mask_image)

    # Bring the output to a stack of 2-D masks indexed along the first axis.
    if mask_image_np.ndim == 4:
        # Merge the two leading axes; assumes the trailing two are height and
        # width — TODO confirm against the segmenter's output layout.
        leading = mask_image_np.shape[0] * mask_image_np.shape[1]
        mask_stack = mask_image_np.reshape((leading,) + mask_image_np.shape[2:])
    elif mask_image_np.ndim == 3:
        mask_stack = mask_image_np
    else:
        mask_stack = None

    if mask_stack is None:
        print(f"Unexpected mask shape: {mask_image_np.shape}")
        # The frame is only needed here, to size the empty mask. Use its height
        # and width only, so the fallback is 2-D like every other return value.
        frame = cv2.imread(frame_path)
        height, width = frame.shape[:2] if frame is not None else (0, 0)
        mask_np = np.zeros((height, width), dtype=np.uint8)
    elif mask_stack.shape[0] == 0:
        # No mask was produced: return an empty mask instead of raising IndexError
        mask_np = np.zeros(mask_stack.shape[1:], dtype=np.uint8)
    else:
        mask_np = mask_stack[0]  # only the first mask is used

    if mask_np.dtype != np.uint8:
        mask_np = mask_np.astype(np.uint8) * 255

    return mask_np



def main():
    """Extract the frames of the sample video and segment each one.

    The video is decoded into memory, the ``./input/frames`` and
    ``./input/masks`` working directories are emptied and recreated, and every
    frame is saved as a zero-padded PNG before being passed to
    ``process_frame_with_sahi``.
    """
    frames = get_frames_from_video("./input/golf6.mp4")

    temp_frames_dir = './input/frames'
    temp_masks_dir = './input/masks'

    # Start from empty 'frames' and 'masks' directories
    for work_dir in (temp_frames_dir, temp_masks_dir):
        if os.path.exists(work_dir):
            shutil.rmtree(work_dir)
        os.makedirs(work_dir, exist_ok=True)

    progress = tqdm(enumerate(frames), total=len(frames), desc="Processing frames and masks")
    for index, rgb_frame in progress:
        frame_path = os.path.join(temp_frames_dir, f'{index:05d}.png')
        # Frames are held as RGB; OpenCV expects BGR when writing an image
        cv2.imwrite(frame_path, cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR))

        # NOTE(review): the returned mask is currently discarded — nothing in
        # this function writes it to temp_masks_dir.
        process_frame_with_sahi(
            frame_path=frame_path,
            sam_model_type="vit_h",
            detection_model_path="./checkpoints/yolov8n.pt",
            output_mask_dir=temp_masks_dir,
        )
    

# Run the pipeline only when this code is executed as the main module,
# not when it is imported.
if __name__ == "__main__":
    main()
