# SAM-Enhanced Auto-Labeling Pipeline

**Description**: Uses YOLO for detection + SAM for precise segmentation.

**How it works**:
1. YOLO detects objects (Hardhat, Person, Safety Vest)
2. SAM refines each detection with pixel-perfect masks
3. Tight bounding boxes are extracted from masks

**Result**: More accurate bounding boxes than YOLO alone.

In [None]:
# Step 1: Install Dependencies
# ultralytics provides the YOLO class and segment-anything provides the
# segment_anything module (both imported in Step 3); opencv-python-headless
# provides cv2 (imported in Step 4).
!pip install ultralytics segment-anything opencv-python-headless

In [None]:
# Step 2a: Upload Model
import os
from google.colab import files

# Uploaded weights are moved out of the working directory into ./models.
os.makedirs("models", exist_ok=True)

print("--- UPLOAD MODEL ---")
print("Upload your 'best.pt' model file.")

uploaded_model = files.upload()

# Fallback weights name; it is only replaced when a .pt file is uploaded.
MODEL_PATH = "yolov8n.pt"

for pt_name in (name for name in uploaded_model if name.endswith('.pt')):
    print(f"Model detected: {pt_name}")
    destination = os.path.join("models", pt_name)
    os.rename(pt_name, destination)
    # When several .pt files are uploaded, the last one processed wins.
    MODEL_PATH = destination

if MODEL_PATH != "yolov8n.pt":
    print(f"Model loaded: {MODEL_PATH}")
else:
    print("WARNING: No .pt file uploaded. Using default.")

In [None]:
# Step 2b: Upload Images
import zipfile

# Every uploaded ZIP archive is unpacked into IMAGES_DIR; Step 4 later walks
# this directory recursively, so nested folders inside the archive are fine.
IMAGES_DIR = "input_images"
os.makedirs(IMAGES_DIR, exist_ok=True)

print("--- UPLOAD IMAGES ---")
print("Upload a ZIP file containing your images.")

uploaded_images = files.upload()

extracted_any = False
for filename in uploaded_images:
    # Case-insensitive match so archives named e.g. "IMAGES.ZIP" are accepted,
    # mirroring the case-insensitive image-extension filter used in Step 4.
    if not filename.lower().endswith('.zip'):
        print(f"Skipping non-ZIP upload: {filename}")
        continue
    print(f"Extracting: {filename}")
    with zipfile.ZipFile(filename, 'r') as zip_ref:
        zip_ref.extractall(IMAGES_DIR)
    print(f"Extracted to: {IMAGES_DIR}")
    extracted_any = True

# Mirror the Step 2a warning so a missing archive is noticed here rather than
# as "Processing 0 images" later on.
if not extracted_any:
    print(f"WARNING: No .zip file uploaded. Nothing was extracted to {IMAGES_DIR}.")

In [None]:
# Step 3: Load YOLO and SAM Models
from ultralytics import YOLO
from segment_anything import sam_model_registry, SamPredictor
import torch

# Load YOLO
print("Loading YOLO...")
yolo = YOLO(MODEL_PATH)

# Download and Load SAM
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
SAM_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
if not os.path.exists(SAM_CHECKPOINT):
    print("Downloading SAM checkpoint (~2.5GB)...")
    # torch.hub downloads to a temporary file and only moves it to
    # SAM_CHECKPOINT once the transfer has finished, and it raises on failure.
    # An interrupted or failed download therefore cannot leave a truncated
    # checkpoint that the exists() guard above would accept on the next run,
    # and "Downloaded!" is only printed on success.
    torch.hub.download_url_to_file(SAM_CHECKPOINT_URL, SAM_CHECKPOINT, progress=False)
    print("Downloaded!")

# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {DEVICE}")

print("Loading SAM...")
# The registry key must match the checkpoint variant (ViT-H weights here).
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT)
sam.to(device=DEVICE)
predictor = SamPredictor(sam)

print("Both models loaded!")

In [None]:
# Step 4: SAM-Enhanced Auto-Labeling
import cv2
import numpy as np
import shutil

# Output layout: output_data/images, output_data/labels and a classes.txt index.
OUTPUT_ROOT = "output_data"
IMAGES_OUT, LABELS_OUT = (os.path.join(OUTPUT_ROOT, sub) for sub in ("images", "labels"))

for out_dir in (IMAGES_OUT, LABELS_OUT):
    os.makedirs(out_dir, exist_ok=True)

# Target classes, written one per line to classes.txt in this order.
NEW_CLASSES = ['Hardhat', 'Person', 'Safety Vest']
with open(os.path.join(OUTPUT_ROOT, "classes.txt"), "w") as classes_file:
    classes_file.writelines(f"{class_name}\n" for class_name in NEW_CLASSES)

# Name table of the loaded YOLO model; the labeling loop looks names up in it
# by integer class id.
OLD_NAMES = yolo.names

def get_bbox_from_mask(mask):
    """Return the tight bounding box around the truthy cells of a mask.

    The result is ``(x_min, y_min, x_max, y_max)`` where every value is an
    inclusive index (columns for x, rows for y). ``None`` is returned when
    the mask contains no truthy cell.
    """
    occupied_rows = np.nonzero(np.any(mask, axis=1))[0]
    occupied_cols = np.nonzero(np.any(mask, axis=0))[0]
    if occupied_rows.size == 0 or occupied_cols.size == 0:
        return None
    # nonzero() yields ascending indices, so the ends are the extremes.
    top, bottom = occupied_rows[0], occupied_rows[-1]
    left, right = occupied_cols[0], occupied_cols[-1]
    return left, top, right, bottom

# Find images
# The os.walk tuple is unpacked into names that do not collide with `files`,
# the google.colab helper imported in Step 2a; rebinding that name here would
# break any later files.upload()/files.download() call.
files_list = []
for root, _dirs, names in os.walk(IMAGES_DIR):
    files_list.extend(
        os.path.join(root, name)
        for name in names
        if name.lower().endswith(('.jpg', '.png', '.jpeg'))
    )

print(f"Processing {len(files_list)} images with SAM refinement...")

# Output class id for each kept class name, derived from NEW_CLASSES so the
# ids written to the label files always match the order of classes.txt.
CLASS_TO_ID = {class_name: idx for idx, class_name in enumerate(NEW_CLASSES)}

count = 0

for img_path in files_list:
    # Read image
    image_bgr = cv2.imread(img_path)
    if image_bgr is None:
        # Unreadable or non-image file: skip it.
        continue
    h, w = image_bgr.shape[:2]

    # YOLO Detection
    results = yolo.predict(img_path, conf=0.25, verbose=False)

    # Keep only detections whose model class name is one of the target classes.
    targets = []
    for box in results[0].boxes:
        new_id = CLASS_TO_ID.get(OLD_NAMES[int(box.cls[0])])
        if new_id is not None:
            targets.append((new_id, box))

    if not targets:
        continue

    # Set image for SAM. This is done only once there is something to refine,
    # so images without target detections skip the image-encoder pass.
    predictor.set_image(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))

    valid_lines = []

    for new_id, box in targets:
        # Get YOLO box coordinates (xyxy format, pixels)
        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()

        # Prompt SAM with the box centre as a foreground point AND the YOLO
        # box itself. The box keeps the mask on the detected object; a centre
        # point alone can select a sub-part or a neighbouring object (e.g. the
        # vest in the middle of a Person box).
        center_x = int((x1 + x2) / 2)
        center_y = int((y1 + y2) / 2)

        masks, scores, _ = predictor.predict(
            point_coords=np.array([[center_x, center_y]]),
            point_labels=np.array([1]),  # 1 = foreground
            box=np.array([x1, y1, x2, y2]),
            multimask_output=True,
        )

        # Use the candidate mask SAM scores highest
        mask = masks[np.argmax(scores)]

        # Get refined bounding box from mask
        refined_bbox = get_bbox_from_mask(mask)

        if refined_bbox is None:
            # Empty mask: fall back to the normalized YOLO box
            bx = box.xywhn[0].cpu().numpy()
            valid_lines.append(f"{new_id} {bx[0]:.6f} {bx[1]:.6f} {bx[2]:.6f} {bx[3]:.6f}")
        else:
            # Use SAM-refined box (convert to normalized YOLO format).
            # The mask box holds inclusive pixel indices, so the covered
            # extent is [min, max + 1): the +1 keeps a one-pixel-wide mask
            # from producing a zero-size box and lets a full-width mask map
            # to a width of exactly 1.0.
            rx1, ry1, rx2, ry2 = refined_bbox
            x_center = ((rx1 + rx2 + 1) / 2) / w
            y_center = ((ry1 + ry2 + 1) / 2) / h
            box_w = (rx2 - rx1 + 1) / w
            box_h = (ry2 - ry1 + 1) / h
            valid_lines.append(f"{new_id} {x_center:.6f} {y_center:.6f} {box_w:.6f} {box_h:.6f}")

    # Save if valid objects found
    if valid_lines:
        # NOTE(review): outputs are keyed by base name only, so two inputs
        # with the same file name in different sub-folders overwrite each
        # other. Confirm whether the uploaded archives can contain duplicates.
        filename = os.path.basename(img_path)
        shutil.copy(img_path, os.path.join(IMAGES_OUT, filename))
        txt_name = os.path.splitext(filename)[0] + ".txt"
        with open(os.path.join(LABELS_OUT, txt_name), "w") as f:
            f.write("\n".join(valid_lines))
        count += 1
        if count % 10 == 0:
            print(f"Processed {count} images...")

print(f"\nDone! {count} images with SAM-refined labels.")

In [None]:
# Step 5: Download Results
# Bundle the whole output_data folder (images, labels, classes.txt) into one
# archive.
!zip -r sam_labels.zip output_data

# Re-import so that `files` is the google.colab helper module here even if an
# earlier cell has rebound the name `files` to something else.
from google.colab import files
# Send the archive to the user's machine.
files.download('sam_labels.zip')