# YOLO + SAM2 object segmentation
Michal Balogh

### Imports

In [1]:
import torch

### Train YOLO model

In [2]:
from ultralytics import YOLO

# Checkpoint file the detector is initialised from before training.
pretrained_weights = "yolo11n.pt"
yolo_model = YOLO(pretrained_weights)

In [3]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Bare expression so the notebook displays the selected device.
device

device(type='cuda')

In [4]:
# Move the YOLO model onto the device chosen earlier (`device`).
yolo_model.to(device)
# Bare expression: display the device the model now reports, to confirm the move.
yolo_model.device

device(type='cuda', index=0)

In [5]:
# Training settings: dataset description file, number of epochs, and image size.
train_config = {
    "data": "mei_dataset.yaml",
    "epochs": 10,
    "imgsz": 640,
}

# Train the model with the settings above and keep whatever `train` returns.
results = yolo_model.train(**train_config)

New https://pypi.org/project/ultralytics/8.3.90 available  Update with 'pip install -U ultralytics'
[34m[1mengine\trainer: [0mtask=detect, mode=train, model=yolo11n.pt, data=mei_dataset.yaml, epochs=10, time=None, patience=100, batch=16, imgsz=640, save=True, save_period=-1, cache=False, device=cuda:0, workers=8, project=None, name=train5, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, embed=None, show=False, save_frames=False, save_txt=False, save_conf=False, save_crop=False, show_labels=True, sh

[34m[1mtrain: [0mScanning C:\Users\balog\Code_win\Code\BP\BP\yolo_dataset\labels\train.cache... 5725 images, 0 backgrounds, 0 corrupt: 100%|██████████| 5725/5725 [00:00<?, ?it/s]


KeyboardInterrupt: 

### Load the best model

In [6]:
# The training run to load is pinned by hand.
# An earlier attempt to pick the newest directory under "runs/detect" automatically
# (os.listdir + sorted(...)[-1]) is disabled. NOTE(review): as written it sorted `run`
# before `run` was assigned, `os` is not imported in the visible cells, and a plain
# lexicographic sort would rank "train10" before "train5" - fix those before re-enabling.
run = "train"

# Load the best checkpoint saved by that run.
best_weights_path = f"runs/detect/{run}/weights/best.pt"
yolo_model = YOLO(best_weights_path)

In [7]:
# Try the loaded detector on a single sample image, passing conf=0.5 as the
# confidence setting.
img_path = "000000000010.png"
results = yolo_model(img_path, conf=0.5)

# Only one image was passed, so visualise the first (only) result.
first_result = results[0]
first_result.show()


image 1/1 c:\Users\balog\Code_win\Code\BP\BP\YOLOv11\000000000010.png: 288x640 2 As, 1 E, 2 Rs, 45.6ms
Speed: 1.9ms preprocess, 45.6ms inference, 2.0ms postprocess per image at shape (1, 3, 288, 640)


In [8]:
# Same check as before, this time on the height-map image.
img_path = "heightMap.png"
results = yolo_model(source=img_path, conf=0.5)

# Only one image was passed, so visualise the first (only) result.
height_map_result = results[0]
height_map_result.show()


image 1/1 c:\Users\balog\Code_win\Code\BP\BP\YOLOv11\heightMap.png: 288x640 1 /, 2 5s, 2 As, 1 D, 1 E, 1 I, 2 Ls, 3 Rs, 1 T, 1 U, 22.7ms
Speed: 3.8ms preprocess, 22.7ms inference, 5.0ms postprocess per image at shape (1, 3, 288, 640)


In [9]:
# Pull the detections of the first result out as plain Python lists.
detected = results[0].boxes
class_ids = detected.cls.int().tolist()   # class index per detection
boxes = detected.xyxy.int().tolist()      # xyxy box coordinates, truncated to int
scores = detected.conf.tolist()           # confidence per detection

# Print each list with its label.
for caption, values in (
    ("Class IDs: ", class_ids),
    ("Boxes: ", boxes),
    ("Scores: ", scores),
):
    print(caption, values)

Class IDs:  [19, 13, 13, 19, 23, 38, 36, 30, 7, 36, 36, 30, 39, 22, 27]
Boxes:  [[1264, 498, 1405, 610], [207, 714, 503, 879], [1359, 705, 1657, 875], [802, 499, 940, 611], [2220, 487, 2347, 601], [1725, 493, 1818, 604], [2122, 699, 2439, 870], [1453, 496, 1552, 609], [592, 698, 889, 890], [629, 501, 756, 614], [2202, 736, 2356, 791], [2393, 487, 2499, 601], [1870, 493, 1997, 603], [991, 497, 1116, 611], [1177, 496, 1207, 609]]
Scores:  [0.9947116374969482, 0.9926922917366028, 0.9920687675476074, 0.9916700124740601, 0.9857156872749329, 0.9573426842689514, 0.9556028842926025, 0.9428810477256775, 0.9369704127311707, 0.933358907699585, 0.9242571592330933, 0.8971443176269531, 0.8158262968063354, 0.7977802753448486, 0.5703930258750916]


### Run YOLO detection and pass the detected boxes to SAM2 for segmentation

In [10]:
from ultralytics import SAM

# Step 1: run YOLO on the image; its boxes become the prompts for SAM.
img_path = "heightMap.png"
yolo_results = yolo_model(img_path, conf=0.3)
yolo_output = yolo_results[0]

# Step 2: build the SAM model from its checkpoint file.
sam_ckpt = "sam2_b.pt"
sam_model = SAM(sam_ckpt)

# Step 3: prompt SAM with the YOLO boxes (xyxy) on the original image.
boxes = yolo_output.boxes.xyxy
sam_results = sam_model(
    yolo_output.orig_img,
    bboxes=boxes,
    verbose=False,
    device=device,
    save=True,
)
sam_output = sam_results[0]

# Step 4: relabel the SAM result so that entry i carries the class name YOLO
# predicted for detection i.
# NOTE(review): assumes SAM returns one entry per prompt box, in box order - confirm.
id2label = yolo_output.names
class_ids = yolo_output.boxes.cls.int().tolist()

sam_output_ids = dict(enumerate(class_ids))
sam_output.names = {
    entry_idx: id2label[int(class_id)]
    for entry_idx, class_id in sam_output_ids.items()
}

sam_output.show()


image 1/1 c:\Users\balog\Code_win\Code\BP\BP\YOLOv11\heightMap.png: 288x640 1 /, 2 5s, 2 As, 2 Ds, 1 E, 1 I, 2 Ls, 3 Rs, 1 T, 1 U, 133.9ms
Speed: 6.6ms preprocess, 133.9ms inference, 5.0ms postprocess per image at shape (1, 3, 288, 640)
Results saved to [1mruns\segment\predict13[0m


In [11]:
# Inspect the labels stored on `sam_output.names` and check them against YOLO.
#
# The segmentation cell has already overwritten `sam_output.names`, so assigning the
# same mapping again here and printing it twice showed two identical dictionaries.
# This cell now only derives the expected mapping and verifies the stored one.
id2label = yolo_output.names
class_ids = yolo_output.boxes.cls.int().tolist()

# Entry i of the SAM result is expected to carry the class YOLO predicted for box i.
# NOTE(review): assumes SAM returns one entry per prompt box, in box order - confirm.
sam_output_ids = dict(enumerate(class_ids))
expected_names = {
    entry_idx: id2label[class_id] for entry_idx, class_id in sam_output_ids.items()
}

print(class_ids)
print(sam_output.names)

# Fails loudly if the cells were run out of order or the relabelling step was skipped.
assert sam_output.names == expected_names, (
    "sam_output.names does not match the YOLO class IDs; "
    "re-run the segmentation cell before this one"
)

[19, 13, 13, 19, 23, 38, 36, 30, 7, 36, 36, 30, 39, 22, 27, 22]
{0: 'A', 1: '5', 2: '5', 3: 'A', 4: 'E', 5: 'T', 6: 'R', 7: 'L', 8: '/', 9: 'R', 10: 'R', 11: 'L', 12: 'U', 13: 'D', 14: 'I', 15: 'D'}
{0: 'A', 1: '5', 2: '5', 3: 'A', 4: 'E', 5: 'T', 6: 'R', 7: 'L', 8: '/', 9: 'R', 10: 'R', 11: 'L', 12: 'U', 13: 'D', 14: 'I', 15: 'D'}
