In [7]:
import SimpleITK as sitk
import numpy as np
import cv2
import os
import random
import shutil
from ultralytics import YOLO

In [8]:
# Source data: cleaned SPIDER volumes (.mha) with same-named mask volumes.
base_dir = 'SPIDER_cleaned'
img_dir, mask_dir = (os.path.join(base_dir, sub) for sub in ('images', 'masks'))

# Destination: YOLO-style layout, images/<split> and labels/<split>.
output_base = 'YOLO_SPIDER_dataset'
images_train, images_val, labels_train, labels_val = (
    os.path.join(output_base, rel)
    for rel in ('images/train', 'images/val', 'labels/train', 'labels/val')
)

# Recreate every output folder empty so files from a previous run are never
# mixed into the freshly generated split.
for out_dir in (images_train, images_val, labels_train, labels_val):
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

# Mask label value -> YOLO class index: labels 201..205 become classes 0..4
# (insertion order 201, 202, ..., 205).
# NOTE(review): presumably the first five intervertebral-disc labels of the
# SPIDER annotation scheme (process_set iterates them as disc ids) — confirm.
LABELS_MAP = {mask_label: class_idx
              for class_idx, mask_label in enumerate(range(201, 206))}

def normalize_to_jpg(img_slice):
    """Min-max rescale an intensity array to the 0-255 ``uint8`` range.

    All arithmetic is done in float64. With integer inputs (e.g. ``int16``
    volumes), computing ``img_slice - img_min`` in the input dtype would
    overflow whenever the intensity range exceeds that dtype.

    Args:
        img_slice: numeric array of any shape.

    Returns:
        ``uint8`` array with the same shape as ``img_slice``. It is all zeros
        when the input is constant or empty. Values are truncated, not rounded,
        to integers.
    """
    values = np.asarray(img_slice, dtype=np.float64)
    if values.size == 0:
        return np.zeros(values.shape, dtype=np.uint8)

    img_min = values.min()
    img_max = values.max()
    if img_max == img_min:
        # Constant slice: there is no range to stretch, so emit black.
        return np.zeros(values.shape, dtype=np.uint8)

    img_norm = 255.0 * (values - img_min) / (img_max - img_min)
    return img_norm.astype(np.uint8)

def bounding_box(mask_slice, label_id, img_w, img_h):
    """Return the YOLO-normalised bounding box of ``label_id`` in a 2-D mask.

    Pixel ``i`` is treated as covering the interval ``[i, i + 1)``, so the box
    spans ``[min, max + 1)`` on each axis. Using ``max - min`` alone would give
    one-pixel-wide (or one-row-high) regions a zero-size box and shift every
    centre by half a pixel.

    Args:
        mask_slice: 2-D label array indexed ``[row, col]``.
        label_id: label value whose pixels define the box.
        img_w: width in pixels (number of columns) used for normalisation.
        img_h: height in pixels (number of rows) used for normalisation.

    Returns:
        ``(x_center, y_center, width, height)`` as floats, which lie in
        ``[0, 1]`` when ``img_w``/``img_h`` match the mask shape. Returns
        ``None`` when ``label_id`` does not occur in ``mask_slice``.
    """
    rows, cols = np.nonzero(mask_slice == label_id)
    if rows.size == 0:
        return None

    # Half-open pixel extents [min, max + 1).
    x_min, x_max = int(cols.min()), int(cols.max()) + 1
    y_min, y_max = int(rows.min()), int(rows.max()) + 1

    x_center = (x_min + x_max) / 2.0 / img_w
    y_center = (y_min + y_max) / 2.0 / img_h
    bbox_w = (x_max - x_min) / img_w
    bbox_h = (y_max - y_min) / img_h
    return (x_center, y_center, bbox_w, bbox_h)

# Split at the volume level, before any slicing, so slices of one scan never
# land in both train and val. Sorting first makes the seeded shuffle
# independent of the (arbitrary) os.listdir order.
files = sorted(name for name in os.listdir(img_dir) if name.endswith('.mha'))
random.seed(42)
random.shuffle(files)

split_idx = int(0.8 * len(files))  # 80 % train / 20 % val
train_files, val_files = files[:split_idx], files[split_idx:]

def process_set(file_list, img_dest, lbl_dest, slice_offsets=(-1, 0, 1)):
    """Export 2-D slices of each volume as JPEGs plus YOLO label files.

    For every file in ``file_list`` that has a same-named mask in ``mask_dir``,
    the slices at ``mid + offset`` along the last NumPy axis are exported.
    ``mid`` is half the size of that axis. The volumes come from
    ``sitk.GetArrayFromImage``, so for 3-D volumes that axis is x in
    (z, y, x) order. Each slice is written to ``img_dest`` as a min-max
    normalised JPEG. When at least one label from ``LABELS_MAP`` is present,
    a YOLO ``.txt`` file with one line per label is written to ``lbl_dest``.
    Slices without any mapped label are still written as images, just
    without a label file.

    Args:
        file_list: ``.mha`` file names relative to ``img_dir`` / ``mask_dir``.
        img_dest: directory receiving the ``.jpg`` slices.
        lbl_dest: directory receiving the ``.txt`` label files.
        slice_offsets: offsets from the middle index of the last axis to
            export. Offsets falling outside the volume are skipped and
            duplicates are dropped.

    Returns:
        Number of slices for which a label file was written.
    """
    count = 0
    for filename in file_list:
        img_path = os.path.join(img_dir, filename)
        mask_path = os.path.join(mask_dir, filename)

        if not os.path.exists(mask_path):
            print(f"Skipping {filename}: mask not found at {mask_path}")
            continue

        try:
            arr_img = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
            arr_mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_path))

            # Boxes come from the mask and are paired with the image slice of
            # the same index, so both arrays must line up exactly.
            if arr_img.shape != arr_mask.shape:
                raise ValueError(
                    f"image shape {arr_img.shape} != mask shape {arr_mask.shape}"
                )

            depth = arr_img.shape[2]
            mid = depth // 2
            # Keep only in-range indices (deduplicated, in offset order).
            # Without this, mid-1 is -1 for a single-slice volume and silently
            # wraps to the last slice, while mid+1 can raise IndexError.
            slice_indices = list(dict.fromkeys(
                mid + off for off in slice_offsets if 0 <= mid + off < depth
            ))

            stem = os.path.splitext(filename)[0]
            for s_idx in slice_indices:
                # Flip image and mask identically so they stay aligned.
                img_slice = np.flipud(arr_img[:, :, s_idx])
                mask_slice = np.flipud(arr_mask[:, :, s_idx])

                base_name = f"{stem}_s{s_idx}"
                jpg_path = os.path.join(img_dest, base_name + ".jpg")
                # cv2.imwrite signals failure through its return value only.
                if not cv2.imwrite(jpg_path, normalize_to_jpg(img_slice)):
                    raise OSError(f"cv2.imwrite failed for {jpg_path}")

                h, w = mask_slice.shape
                labels_content = []
                for disc_id, class_id in LABELS_MAP.items():
                    bbox = bounding_box(mask_slice, disc_id, w, h)
                    if bbox is not None:
                        labels_content.append(
                            f"{class_id} {bbox[0]:.6f} {bbox[1]:.6f} {bbox[2]:.6f} {bbox[3]:.6f}"
                        )

                if labels_content:
                    with open(os.path.join(lbl_dest, base_name + ".txt"), 'w') as f:
                        f.write("\n".join(labels_content))
                    count += 1

        except Exception as e:
            # Best-effort export: report the failing volume and move on.
            print(f"Error {filename}: {e}")
    return count

# Keep the per-split counts: a silently empty export would otherwise only
# surface later as a confusing training failure.
n_train = process_set(train_files, images_train, labels_train)
n_val = process_set(val_files, images_val, labels_val)
print(f"Labelled slices: {n_train} train, {n_val} val")
print("Data generation complete.")

Data generation complete.


In [9]:
# Dataset YAML for Ultralytics, resolved against the notebook's working
# directory, the same base the relative 'SPIDER_cleaned' and
# 'YOLO_SPIDER_dataset' paths use, instead of a machine-specific absolute path.
# NOTE(review): assumes spider_data.yaml sits next to this notebook — confirm.
DATA_YAML = os.path.abspath('spider_data.yaml')
if not os.path.isfile(DATA_YAML):
    raise FileNotFoundError(f"Dataset config not found: {DATA_YAML}")

model = YOLO('yolov8n.pt')

results = model.train(
    data=DATA_YAML,
    epochs=50,
    imgsz=320,
    batch=16,
    name='disc_location'
)

Downloading https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt to 'yolov8n.pt': 100% ━━━━━━━━━━━━ 6.2MB 2.1MB/s 3.0s
Ultralytics 8.3.233 🚀 Python-3.13.2 torch-2.6.0 CPU (Apple M4)
[34m[1mengine/trainer: [0magnostic_nms=False, amp=True, augment=False, auto_augment=randaugment, batch=16, bgr=0.0, box=7.5, cache=False, cfg=None, classes=None, close_mosaic=10, cls=0.5, compile=False, conf=None, copy_paste=0.0, copy_paste_mode=flip, cos_lr=False, cutmix=0.0, data=/Users/leniecka/python/AIiID/Projekt/spider_data.yaml, degrees=0.0, deterministic=True, device=cpu, dfl=1.5, dnn=False, dropout=0.0, dynamic=False, embed=None, epochs=50, erasing=0.4, exist_ok=False, fliplr=0.5, flipud=0.0, format=torchscript, fraction=1.0, freeze=None, half=False, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, imgsz=320, int8=False, iou=0.7, keras=False, kobj=1.0, line_width=None, lr0=0.01, lrf=0.01, mask_ratio=4, max_det=300, mixup=0.0, mode=train, model=yol