In [1]:
import torch
import cv2
import numpy as np
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor
from torchvision.transforms import functional as F

In [2]:
import cv2
import os
import re
import torch
import torchvision
import numpy as np
import pandas as pd
import random
import xml.etree.ElementTree as ET
import torchvision.transforms as T
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.rpn import AnchorGenerator
from torch.utils.data import DataLoader, Dataset
from matplotlib import pyplot as plt
from matplotlib import patches
from tqdm import tqdm

In [3]:
#!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
#!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

In [4]:
# Select the compute device in a single place.
# Inference is currently pinned to the CPU; set FORCE_CPU to False to use
# CUDA automatically whenever it is available.
FORCE_CPU = True
device = torch.device("cpu" if FORCE_CPU or not torch.cuda.is_available() else "cuda")
print(device)

cpu


In [5]:
class FasterRCNN_SAM:
    """Detect objects with Faster R-CNN, then segment each detection with SAM.

    The detector proposes boxes; every box whose score clears the confidence
    threshold is handed to Segment Anything as a box prompt, which yields one
    binary mask per detection.
    """

    def __init__(self, faster_rcnn_model, device=device,
                 sam_checkpoint="sam_vit_b_01ec64.pth", sam_model_type="vit_b"):
        """Put the detector in eval mode and load SAM.

        Args:
            faster_rcnn_model: detection model; it is moved to ``device`` and
                switched to eval mode.
            device: device used for both the detector and SAM.
            sam_checkpoint: path to the SAM weights file.
            sam_model_type: key into ``sam_model_registry``; it must match the
                architecture of ``sam_checkpoint``.
        """
        # Initialize Faster R-CNN
        self.detector = faster_rcnn_model
        self.device = device
        self.detector.to(device)
        self.detector.eval()

        # Initialize SAM
        self.sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint)
        self.sam.to(device)
        self.predictor = SamPredictor(self.sam)

    @staticmethod
    def _to_rgb_array(image):
        """Return ``image`` as a numpy array, forcing PIL images to 3-channel RGB."""
        if isinstance(image, Image.Image):
            # PNGs are frequently RGBA / palette / grayscale; the colour blend
            # below and the SAM predictor both need an H x W x 3 array.
            return np.array(image.convert("RGB"))
        return np.asarray(image)

    @staticmethod
    def _label_names(class_names):
        """Build a ``{label_id: name}`` lookup from ``class_names``.

        Accepted forms:
            * ``{id: name}`` with integer keys - used as is;
            * ``{name: id}`` whose ids are exactly ``1..N`` - inverted;
            * any other dict or sequence of names - names are numbered from 1
              in iteration order (label 0 is left for background).
        """
        if isinstance(class_names, dict):
            keys = list(class_names)
            if keys and all(
                isinstance(k, (int, np.integer)) and not isinstance(k, bool)
                for k in keys
            ):
                return {int(k): name for k, name in class_names.items()}
            ids = list(class_names.values())
            if set(ids) == set(range(1, len(ids) + 1)):
                return {int(idx): name for name, idx in class_names.items()}
        return {idx: name for idx, name in enumerate(class_names, start=1)}

    @staticmethod
    def _class_colors(num_classes, seed=0):
        """Return a ``(num_classes, 3)`` integer colour table, identical on every call.

        A fixed seed keeps each class on the same colour from call to call, so
        colours do not flicker between video frames.
        """
        rng = np.random.default_rng(seed)
        return rng.integers(0, 255, size=(max(num_classes, 1), 3))

    def predict_and_segment(self, image, conf_threshold=0.6):
        """Run the detector on ``image`` and segment every confident detection.

        Args:
            image: PIL image or RGB numpy array (assumed H x W x 3 uint8).
            conf_threshold: detections with a score at or below this value
                are discarded.

        Returns:
            dict with numpy arrays under ``'boxes'``, ``'labels'``, ``'scores'``
            and ``'masks'`` (one mask per kept box; an empty ``(0, H, W)``
            boolean array when nothing was detected).
        """
        image_np = self._to_rgb_array(image)

        # Prepare image for Faster R-CNN.
        # NOTE(review): torchvision detection models normalise their input
        # internally, so this Normalize is applied on top of that. It is kept
        # because the checkpoint was presumably trained with the same
        # preprocessing - confirm against the training code before removing.
        transform = T.Compose([
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        image_tensor = transform(Image.fromarray(image_np)).to(self.device)

        # The detector takes a list of image tensors and returns one result per image
        with torch.no_grad():
            detections = self.detector([image_tensor])[0]

        # Filter detections by confidence
        keep = detections['scores'] > conf_threshold
        boxes = detections['boxes'][keep].cpu().numpy()
        labels = detections['labels'][keep].cpu().numpy()
        scores = detections['scores'][keep].cpu().numpy()

        height, width = image_np.shape[:2]
        if len(boxes) == 0:
            # Nothing to segment: skip SAM's image embedding entirely.
            masks = np.zeros((0, height, width), dtype=bool)
        else:
            # The image is embedded once, then prompted with each box in turn
            self.predictor.set_image(image_np)
            masks = np.stack([
                # predict() returns (masks, scores, logits); with
                # multimask_output=False the first mask is the one we want.
                self.predictor.predict(box=box, multimask_output=False)[0][0]
                for box in boxes
            ])

        return {
            'boxes': boxes,
            'labels': labels,
            'scores': scores,
            'masks': masks
        }

    def visualize_predictions(self, image, predictions, class_names, alpha=0.5):
        """Draw masks, boxes and labels from ``predictions`` onto a copy of ``image``.

        Args:
            image: PIL image or RGB numpy array; it is not modified.
            predictions: dict as returned by ``predict_and_segment``.
            class_names: class names as a sequence, a ``{name: id}`` dict or
                an ``{id: name}`` dict; used for label text and colour count.
            alpha: mask opacity (0 keeps the image, 1 paints solid colour).

        Returns:
            A new numpy array with the annotations drawn on it.
        """
        overlay = self._to_rgb_array(image).copy()

        label_names = self._label_names(class_names)
        colors = self._class_colors(len(class_names))

        for mask, label, score, box in zip(
            predictions['masks'],
            predictions['labels'],
            predictions['scores'],
            predictions['boxes']
        ):
            label = int(label)
            # Modulo keeps unexpected label ids from raising IndexError
            color = colors[(label - 1) % len(colors)]

            # Blend the class colour into the masked pixels
            mask = np.asarray(mask, dtype=bool)
            blended = overlay[mask] * (1 - alpha) + color * alpha
            overlay[mask] = blended.astype(overlay.dtype)

            # Draw bounding box
            x0, y0, x1, y1 = (int(v) for v in box[:4])
            cv2.rectangle(overlay, (x0, y0), (x1, y1), color.tolist(), 2)

            # Add label; clamp the baseline so text on boxes touching the top
            # edge is not drawn outside the image.
            class_name = label_names.get(label, "Unknown")
            label_text = f'{class_name}: {score:.2f}'
            cv2.putText(
                overlay,
                label_text,
                (x0, max(y0 - 5, 12)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                color.tolist(),
                2
            )

        return overlay

def process_video_with_sam(model_combo, video_path, output_path, class_names,
                          conf_threshold=0.6, fps=30, show=False):
    """Run detection + segmentation on every frame of a video and save the result.

    Args:
        model_combo: object exposing ``predict_and_segment`` and
            ``visualize_predictions`` (e.g. ``FasterRCNN_SAM``).
        video_path: path of the input video.
        output_path: path of the annotated video to write (mp4v codec).
        class_names: forwarded to ``visualize_predictions``.
        conf_threshold: minimum detection score to keep.
        fps: frame rate of the output file; pass ``None`` to reuse the frame
            rate reported by the input video.
        show: if True, preview each frame in an OpenCV window and stop when
            ``q`` is pressed. Leave False on headless machines.

    Returns:
        None. Errors opening the input or output are reported with ``print``.
    """
    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        print("Error: Could not open video stream.")
        return

    out = None
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        if fps is None:
            # Some containers report 0 when the rate is unknown; fall back to 30
            fps = cap.get(cv2.CAP_PROP_FPS) or 30

        out = cv2.VideoWriter(
            output_path,
            cv2.VideoWriter_fourcc(*'mp4v'),
            fps,
            (frame_width, frame_height)
        )
        if not out.isOpened():
            print("Error: Could not open video writer.")
            return

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes to BGR; the models expect RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Get predictions and segmentations
            predictions = model_combo.predict_and_segment(
                frame_rgb,
                conf_threshold=conf_threshold
            )

            # Visualize results
            output_frame = model_combo.visualize_predictions(
                frame_rgb,
                predictions,
                class_names
            )

            # Convert back to BGR for OpenCV and write the frame
            output_frame = cv2.cvtColor(output_frame, cv2.COLOR_RGB2BGR)
            out.write(output_frame)

            if show:
                cv2.imshow('Video', output_frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Always release, including on KeyboardInterrupt, so the frames
        # written so far end up in a properly finalised file.
        cap.release()
        if out is not None:
            out.release()
        if show:
            cv2.destroyAllWindows()



In [6]:
# Subset of Pascal VOC classes used here, keyed by name with their original VOC ids
VOC_CLASSES = {
    "bicycle": 2,
    "bus": 6,
    "car": 7,
    "motorbike": 14,
    "person": 15,
}
# Renumber the classes 1..N in declaration order
VOC_CLASSES_ReIndex = {
    class_name: new_id for new_id, class_name in enumerate(VOC_CLASSES, start=1)
}
print(VOC_CLASSES_ReIndex)

{'bicycle': 1, 'bus': 2, 'car': 3, 'motorbike': 4, 'person': 5}


In [7]:
# Detector factory
def get_model(num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) whose box head predicts ``num_classes`` classes.

    The network is created with the default pretrained weights and zero
    trainable backbone layers requested; its box predictor is then swapped
    for a new ``FastRCNNPredictor`` sized for ``num_classes``.

    Args:
        num_classes: number of outputs of the new classification head.

    Returns:
        The modified detection model.
    """
    detector = fasterrcnn_resnet50_fpn(
        weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT,
        progress=True,
        weights_backbone=ResNet50_Weights.DEFAULT,
        trainable_backbone_layers=0,
    )

    # The new head must accept the same feature width as the one it replaces
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)

    return detector

In [8]:
# One output per object class, plus one for background
num_classes = 1 + len(VOC_CLASSES_ReIndex)

# Build the Faster R-CNN model and place it on the selected device
# (Module.to returns the module itself, so the call can be chained)
model = get_model(num_classes).to(device)

# Display the replacement prediction head used for fine-tuning
model.roi_heads.box_predictor

FastRCNNPredictor(
  (cls_score): Linear(in_features=1024, out_features=6, bias=True)
  (bbox_pred): Linear(in_features=1024, out_features=24, bias=True)
)

In [9]:
# Reverse lookup: re-indexed class id -> class name
VOC_CLASSES_inverted = {
    class_id: class_name for class_name, class_id in VOC_CLASSES_ReIndex.items()
}
print(VOC_CLASSES_inverted)

{1: 'bicycle', 2: 'bus', 3: 'car', 4: 'motorbike', 5: 'person'}


In [10]:
# Collect only the parameters that are still trainable
params = [weight for weight in model.parameters() if weight.requires_grad]

# SGD over the trainable parameters
# (alternative: pass model.parameters() to optimise everything)
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# Scale the learning rate by 0.1 once the monitored metric has stopped
# increasing (mode='max') for more than `patience` scheduler steps
# (alternative: StepLR(optimizer, step_size=3, gamma=0.1))
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.1,
    patience=3,
)

In [11]:

# Rebuild the detector architecture and restore the fine-tuned weights
faster_rcnn_model = get_model(len(VOC_CLASSES) + 1)
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU still loads when running on CPU.
# NOTE(review): torch.load unpickles the file - only load checkpoints you
# trust (consider weights_only=True if the file holds plain tensors only).
checkpoint = torch.load('checkpoints/best_model_fasterRCNN.pth', map_location=device)
faster_rcnn_model.load_state_dict(checkpoint['model_state_dict'])

# Initialize combined model
model_combo = FasterRCNN_SAM(faster_rcnn_model)

# Process single image; force RGB so grayscale/CMYK/alpha files also work
image = Image.open('img.jpg').convert('RGB')
predictions = model_combo.predict_and_segment(image)
# Pass the same re-indexed class table that the video cell uses
result = model_combo.visualize_predictions(image, predictions, VOC_CLASSES_ReIndex)
# The visualisation is RGB; OpenCV writes BGR
cv2.imwrite('output_image.jpg', cv2.cvtColor(result, cv2.COLOR_RGB2BGR))



  checkpoint = torch.load('checkpoints/best_model_fasterRCNN.pth')
  state_dict = torch.load(f)


True

In [13]:
# Annotate every frame of the input clip and write the result to disk
process_video_with_sam(
    model_combo,
    video_path='input.mp4',
    output_path='output_video.mp4',
    class_names=VOC_CLASSES_ReIndex,
)

KeyboardInterrupt: 