# YOLO + SAM2 Object Segmentation

**Author**: Michal Balogh

This notebook demonstrates a combined approach for object segmentation using:
1. **YOLO** (You Only Look Once) for object detection
2. **SAM2** (Segment Anything Model 2) for precise segmentation

The workflow first detects objects with YOLO to obtain bounding boxes, then passes those boxes to SAM2 as prompts so that it generates one segmentation mask per detected object.
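The two-stage workflow can be sketched as a single function. This is a minimal sketch, not the notebook's implementation: `detect_then_segment`, `fake_detector`, and `fake_segmenter` are hypothetical names, and the stand-ins only mimic the attributes used here (`boxes.xyxy`, `masks`) so the example runs without model weights.

```python
from types import SimpleNamespace


def detect_then_segment(image, detector, segmenter, conf=0.3):
    """Run a detector, then prompt a segmenter with the detected boxes.

    `detector` and `segmenter` are assumed to be callables shaped like the
    YOLO and SAM model objects used later in this notebook.
    """
    detection = detector(image, conf=conf)[0]   # result for the single image
    boxes = detection.boxes.xyxy                # [x1, y1, x2, y2] per object
    if len(boxes) == 0:
        return detection, None                  # nothing to segment
    segmentation = segmenter(image, bboxes=boxes)[0]
    return detection, segmentation


# Hypothetical stand-in models so the sketch runs without weights or a GPU.
def fake_detector(image, conf=0.3):
    boxes = SimpleNamespace(xyxy=[[10, 20, 50, 80], [60, 20, 90, 80]])
    return [SimpleNamespace(boxes=boxes)]


def fake_segmenter(image, bboxes):
    return [SimpleNamespace(masks=[f"mask for {b}" for b in bboxes])]


detection, segmentation = detect_then_segment("img.png", fake_detector, fake_segmenter)
print(len(segmentation.masks))  # one mask per box
```

Passing the models in as arguments keeps the control flow visible and lets the empty-detection case be handled before SAM2 is called.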

## 1. Setup and Imports

In [None]:
import torch
import os
from pathlib import Path
from ultralytics import YOLO, SAM

In [2]:
os.makedirs('results', exist_ok=True)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


## 2. Train YOLO Model

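This section loads the pretrained `yolo11n.pt` checkpoint and fine-tunes it on a custom dataset described by `mei_dataset.yaml`. The training call is commented out so the notebook can be re-run without retraining.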

In [4]:
yolo_model = YOLO("yolo11n.pt")
yolo_model.to(device)
print(f"YOLO model loaded on {yolo_model.device}")

YOLO model loaded on cpu


In [5]:
# Train the model
# Uncomment to run training
# results = yolo_model.train(data="mei_dataset.yaml", epochs=10, imgsz=640)

## 3. Object Detection with YOLO

Load the best weights saved by the training run and run object detection on `heightMap.png`.

In [6]:
# Load the best model from training
run = "train"
model_path = f"runs/detect/{run}/weights/best.pt"

# Or use the most recently modified training run
# (sorted() on directory names would rank "train10" before "train2"):
# candidates = Path("runs/detect").glob("*/weights/best.pt")
# model_path = str(max(candidates, key=lambda p: p.stat().st_mtime))

yolo_model = YOLO(model_path)
print(f"Loaded model from {model_path}")

Loaded model from runs/detect/train/weights/best.pt


In [7]:
img_path = "heightMap.png"

results = yolo_model(img_path, conf=0.5)
results[0].show()


image 1/1 c:\Users\balog\Code_win\Code\BP\yolo-sam2\heightMap.png: 288x640 1 /, 2 5s, 2 As, 1 D, 1 E, 1 I, 2 Ls, 3 Rs, 1 T, 1 U, 78.2ms
Speed: 1.9ms preprocess, 78.2ms inference, 1.2ms postprocess per image at shape (1, 3, 288, 640)


In [8]:
# Analyze detection results
class_ids = results[0].boxes.cls.int().tolist()
boxes = results[0].boxes.xyxy.int().tolist()
scores = results[0].boxes.conf.tolist()

print(f"Detected {len(class_ids)} objects:")
print("Class IDs: ", class_ids)
print("Boxes: ", boxes)
print("Confidence scores: ", [f"{score:.3f}" for score in scores])

Detected 15 objects:
Class IDs:  [19, 13, 13, 19, 23, 38, 36, 30, 7, 36, 36, 30, 39, 22, 27]
Boxes:  [[1264, 498, 1405, 610], [207, 714, 503, 879], [1359, 705, 1657, 875], [802, 499, 940, 611], [2220, 487, 2347, 601], [1725, 493, 1818, 604], [2122, 699, 2439, 870], [1453, 496, 1552, 609], [592, 698, 889, 890], [629, 501, 756, 614], [2202, 736, 2356, 791], [2393, 487, 2499, 601], [1870, 493, 1997, 603], [991, 497, 1116, 611], [1177, 496, 1207, 609]]
Confidence scores:  ['0.995', '0.993', '0.992', '0.992', '0.986', '0.958', '0.956', '0.943', '0.937', '0.933', '0.924', '0.897', '0.815', '0.797', '0.572']
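To see how sensitive the detection count is to the confidence threshold, the printed scores can be filtered after the fact. This is a minimal sketch using the values copied from the output above (rounded to three decimals); `filter_by_confidence` is a hypothetical helper, and the thresholds other than 0.5 are arbitrary examples. Filtering afterwards should give a similar set to passing a higher `conf` to the model, though that equivalence is an assumption here.

```python
# Values copied from the detection output above (scores rounded to 3 decimals).
class_ids = [19, 13, 13, 19, 23, 38, 36, 30, 7, 36, 36, 30, 39, 22, 27]
scores = [0.995, 0.993, 0.992, 0.992, 0.986, 0.958, 0.956, 0.943,
          0.937, 0.933, 0.924, 0.897, 0.815, 0.797, 0.572]


def filter_by_confidence(class_ids, scores, threshold):
    """Keep (class_id, score) pairs whose score is at least `threshold`."""
    return [(c, s) for c, s in zip(class_ids, scores) if s >= threshold]


for threshold in (0.5, 0.8, 0.9):
    kept = filter_by_confidence(class_ids, scores, threshold)
    print(f"conf >= {threshold}: {len(kept)} of {len(scores)} detections kept")
```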


## 4. Object Segmentation with SAM2

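Use the SAM2 model to generate segmentation masks, prompted with YOLO's bounding boxes. The confidence threshold is lowered from 0.5 to 0.3 here, so YOLO returns 16 boxes instead of 15.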

In [9]:
img_path = "heightMap.png"
conf_threshold = 0.3

# Get bounding boxes from YOLO detection
print(f"Running YOLO detection with confidence threshold {conf_threshold}...")
yolo_output = yolo_model(img_path, conf=conf_threshold)[0]

# Load SAM model
sam_ckpt = "sam2_b.pt"
print(f"Loading SAM2 model from {sam_ckpt}...")
sam_model = SAM(sam_ckpt)

# Extract bounding boxes from YOLO detection
boxes = yolo_output.boxes.xyxy 
print(f"Found {len(boxes)} bounding boxes for segmentation")

Running YOLO detection with confidence threshold 0.3...

image 1/1 c:\Users\balog\Code_win\Code\BP\yolo-sam2\heightMap.png: 288x640 1 /, 2 5s, 2 As, 2 Ds, 1 E, 1 I, 2 Ls, 3 Rs, 1 T, 1 U, 63.3ms
Speed: 1.4ms preprocess, 63.3ms inference, 1.0ms postprocess per image at shape (1, 3, 288, 640)
Loading SAM2 model from sam2_b.pt...
Found 16 bounding boxes for segmentation
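Because every box becomes a separate SAM2 prompt, it can be worth inspecting the box geometry first. The sketch below works on plain `[x1, y1, x2, y2]` lists and flags boxes that lie entirely inside another box, which would produce overlapping masks. `box_area` and `contains` are hypothetical helpers, and the three boxes are copied from the `conf=0.5` detection output earlier in the notebook.

```python
def box_area(box):
    """Area in pixels of an [x1, y1, x2, y2] box."""
    x1, y1, x2, y2 = box
    return max(0, x2 - x1) * max(0, y2 - y1)


def contains(outer, inner):
    """True if `inner` lies entirely within `outer` (both in xyxy format)."""
    return (outer[0] <= inner[0] and outer[1] <= inner[1]
            and outer[2] >= inner[2] and outer[3] >= inner[3])


# Three boxes copied from the conf=0.5 detection output.
boxes = [
    [1264, 498, 1405, 610],
    [2122, 699, 2439, 870],
    [2202, 736, 2356, 791],
]

for i, inner in enumerate(boxes):
    for j, outer in enumerate(boxes):
        if i != j and contains(outer, inner):
            print(f"box {i} (area {box_area(inner)}) is nested inside box {j}")
```

With a tensor such as `yolo_output.boxes.xyxy`, calling `.tolist()` first yields the nested lists this sketch expects.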


In [10]:
# Run SAM2 segmentation
print(f"Running SAM2 segmentation on {device}...")
sam_output = sam_model(
    yolo_output.orig_img, 
    bboxes=boxes, 
    verbose=False, 
    device=device, 
    save=True
)[0]

# Save segmentation results
sam_output.save(filename=f"results/segmentation_{Path(img_path).name}")

Running SAM2 segmentation on cpu...
Results saved to runs\segment\predict6


'results/segmentation_heightMap.png'

In [11]:
# Map class IDs to class names
id2label = yolo_output.names
class_ids = yolo_output.boxes.cls.int().tolist()

print("Detected class IDs:", class_ids)
print("Class mapping:", id2label)

# Assign class names to segmentation masks (mask i corresponds to YOLO box i)
sam_output.names = {i: id2label[class_id] for i, class_id in enumerate(class_ids)}
print("\nSegmentation masks with class names:")
print(sam_output.names)

Detected class IDs: [19, 13, 13, 19, 23, 38, 36, 30, 7, 36, 36, 30, 39, 22, 27, 22]
Class mapping: {0: 'tick', 1: '(', 2: ')', 3: 'asterisk', 4: '+', 5: '_', 6: '.', 7: '/', 8: '0', 9: '1', 10: '2', 11: '3', 12: '4', 13: '5', 14: '6', 15: '7', 16: '8', 17: '9', 18: 'colon', 19: 'A', 20: 'B', 21: 'C', 22: 'D', 23: 'E', 24: 'F', 25: 'G', 26: 'H', 27: 'I', 28: 'J', 29: 'K', 30: 'L', 31: 'M', 32: 'N', 33: 'O', 34: 'P', 35: 'Q', 36: 'R', 37: 'S', 38: 'T', 39: 'U', 40: 'V', 41: 'W', 42: 'X', 43: 'Y', 44: 'Z', 45: 'a', 46: 'b', 47: 'd', 48: 'e', 49: 'g', 50: 'i', 51: 'k', 52: 'l', 53: 'n', 54: 'o', 55: 'p', 56: 'r', 57: 's', 58: 't', 59: 'u', 60: 'speical_symbol_127', 61: 'speical_symbol_128', 62: 'speical_symbol_129', 63: 'speical_symbol_131', 64: 'speical_symbol_132', 65: 'speical_symbol_133', 66: 'speical_symbol_134', 67: 'speical_symbol_135', 68: 'speical_symbol_136', 69: 'speical_symbol_137', 70: 'speical_symbol_138'}

Segmentation masks with class names:
{0: 'A', 1: '5', 2: '5', 3: 'A',
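Once each mask index carries a class name, the inverse lookup (which masks belong to a given character) is often handy. The following is a minimal sketch using the class IDs printed above and the subset of the class mapping that covers them; `group_masks_by_label` is a hypothetical helper, and it assumes, as the cell above does, that mask `i` corresponds to YOLO box `i`.

```python
from collections import defaultdict

# Class IDs copied from the notebook output above (one per YOLO box / SAM2 mask).
class_ids = [19, 13, 13, 19, 23, 38, 36, 30, 7, 36, 36, 30, 39, 22, 27, 22]
# Subset of the class mapping printed above, limited to the IDs that occur.
id2label = {7: "/", 13: "5", 19: "A", 22: "D", 23: "E", 27: "I",
            30: "L", 36: "R", 38: "T", 39: "U"}


def group_masks_by_label(class_ids, id2label):
    """Map each class name to the list of mask indices that carry it."""
    groups = defaultdict(list)
    for mask_index, class_id in enumerate(class_ids):
        groups[id2label[class_id]].append(mask_index)
    return dict(groups)


groups = group_masks_by_label(class_ids, id2label)
print(groups["A"])  # mask indices labelled 'A'
```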

In [19]:
sam_output.show(labels=True)