# An efficient lung cancer detection model based on improved YOLOv8

## Target runtime: Databricks Runtime 17.3 LTS ML (GPU)

This notebook demonstrates an efficient lung cancer detection pipeline built on an improved YOLOv8-based model (YOLO-ed), applied slice by slice to 3D CT scans. MONAI loads and preprocesses each volume (LUNA16 `.mhd`, NIfTI `.nii.gz`, or a DICOM series folder): it reorients the volume to RAS, resamples it, and windows the intensities to 0–255. The custom YOLO-ed detector then runs on every axial slice.

For more details on the original research and methodology, refer to the [source article on PMC](https://pmc.ncbi.nlm.nih.gov/articles/PMC12410754/).

In [0]:
%sh
# Clone the YOLO-ed repository (model code and the best.pt weights referenced
# by Config.WEIGHTS_PATH) into the current working directory.
# `git clone` fails if the target directory already exists, so skip the clone
# on re-runs to keep this cell idempotent.
if [ -d YOLO-ed ]; then
  echo "YOLO-ed already cloned; skipping."
else
  git clone https://github.com/111sadf/YOLO-ed.git
fi

In [0]:
# Pinned OpenCV build plus timm (presumably needed by the YOLO-ed model code - TODO confirm).
%pip install opencv-python==4.12.0.88
%pip install timm
# `ultralytics` is not installed by this cell; it is presumably provided by the
# YOLO-ed package installed in the next cell. Uncomment to upgrade it explicitly
# (the specifier is quoted so the shell does not treat ">" as a redirect).
#%pip install -U "ultralytics>=8.3.162"

# MONAI Label / VISTA3D dependencies from a local checkout (paths are relative
# to the current working directory).
%pip install -r ../../monailabel_model/vista3d/requirements.txt
# --no-deps: install these exact packages without letting pip change the
# versions of anything already installed above.
%pip install ../../monailabel_model/artifacts/monailabel-0.8.5-py3-none-any.whl --no-deps
%pip install monai==1.4.0 --no-deps

In [0]:
# Install the cloned YOLO-ed repository as a package. It presumably provides the
# `ultralytics` module that load_model() imports - confirm.
%pip install ./YOLO-ed/

In [0]:
# Restart the Python process so the packages installed above are importable.
# All Python state is cleared, which is why later cells re-import what they use.
dbutils.library.restartPython()

In [0]:
# ==========================================
# 1. CONFIGURATION SECTION
# ==========================================
import torch
class Config:
    """Static settings for the YOLO-ed lung CT detection pipeline.

    Values are plain class attributes, read throughout the notebook as
    ``Config.NAME``.
    """

    # ---- Inputs -------------------------------------------------------------
    # A folder of *.mhd (LUNA16) or *.nii.gz scans, a folder holding one DICOM
    # series, or a single scan file; see discover_scans().
    INPUT_FOLDER: str = "/Volumes/ema_rina/pixels_solacc_tcia/pixels_volume/LUNA16/subset0/"

    # ---- Model --------------------------------------------------------------
    YOLO_ED_REPO: str = "./YOLO-ed"          # cloned repo, prepended to sys.path
    WEIGHTS_PATH: str = "./YOLO-ed/best.pt"  # checkpoint passed to YOLO(...)
    CONFIDENCE: float = 0.50                 # detection score threshold (conf=)

    # ---- Outputs ------------------------------------------------------------
    OUTPUT_DIR: str = "./results_pipeline"
    # Save every 10th slice for a visual quality check.
    # NOTE(review): confirm run_pipeline() honours this flag.
    SAVE_DEBUG_SLICES: bool = False

    # ---- Compute ------------------------------------------------------------
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

In [0]:
import torch
# ==========================================
# 2. SYSTEM PATCHES (CRITICAL)
# ==========================================
def patch_environment():
    """Prepare the interpreter for loading the YOLO-ed checkpoint.

    1. Wraps ``torch.load`` so that calls which do not pass ``weights_only``
       default to ``weights_only=False``. PyTorch 2.6 changed that default to
       ``True``, which rejects checkpoints that pickle custom model classes.
    2. Prepends the resolved ``Config.YOLO_ED_REPO`` directory to ``sys.path``
       so the repository's modules are importable.

    Safe to call repeatedly: ``torch.load`` is wrapped at most once, and the
    repository path is inserted only if it is not already on ``sys.path``.
    """
    # Local imports: this cell runs before the cell that imports sys/pathlib.
    import functools
    import sys
    from pathlib import Path

    # A. PyTorch 2.6+ weights_only default.
    # NOTE(security): weights_only=False fully unpickles the checkpoint, which
    # can execute arbitrary code - only load weights from trusted sources.
    if not getattr(torch.load, "_weights_only_default_patched", False):
        _original_torch_load = torch.load

        @functools.wraps(_original_torch_load)
        def patched_torch_load(*args, **kwargs):
            # Respect an explicit weights_only from the caller.
            kwargs.setdefault("weights_only", False)
            return _original_torch_load(*args, **kwargs)

        # Marker so a second call does not wrap the wrapper again.
        patched_torch_load._weights_only_default_patched = True
        torch.load = patched_torch_load

    # B. Path injection.
    repo_abs_path = str(Path(Config.YOLO_ED_REPO).resolve())
    if repo_abs_path not in sys.path:
        sys.path.insert(0, repo_abs_path)
        print(f"üîå System: Added YOLO-ed repo to path: {repo_abs_path}")

In [0]:
import sys
import os
import json
import glob
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from pathlib import Path
from PIL import Image

# --- MONAI IMPORTS ---
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    Spacingd,
    ScaleIntensityRanged,
    EnsureTyped
)
from monai.data import Dataset

# ==========================================
# 3. FILE DISCOVERY
# ==========================================
def discover_scans(folder_path):
    """Return the scan inputs found at ``folder_path``.

    Resolution order:
      1. ``folder_path`` is itself a file -> ``[folder_path]`` (single scan).
      2. ``*.mhd`` files inside the folder (LUNA16 MetaImage headers).
      3. ``*.nii.gz`` files inside the folder.
      4. Otherwise the folder itself is assumed to be one DICOM series.

    Returned paths are absolute. File lists are sorted so scans are processed
    in a deterministic order (``glob`` returns entries in arbitrary order).

    Args:
        folder_path: Directory to search, or the path of a single scan file.

    Returns:
        list[str]: Paths to hand to the MONAI loader; never empty.
    """
    folder_path = os.path.abspath(folder_path)

    # 0. A single scan file was given directly.
    if os.path.isfile(folder_path):
        print(f"Input is a single scan file: {folder_path}")
        return [folder_path]

    # 1. Look for LUNA16 (.mhd)
    mhd_files = sorted(glob.glob(os.path.join(folder_path, "*.mhd")))
    if mhd_files:
        print(f"üìÇ Found {len(mhd_files)} .mhd files (LUNA16 format).")
        return mhd_files

    # 2. Look for NIfTI (.nii.gz)
    nii_files = sorted(glob.glob(os.path.join(folder_path, "*.nii.gz")))
    if nii_files:
        print(f"üìÇ Found {len(nii_files)} .nii.gz files.")
        return nii_files

    # 3. Fallback: Assume the folder itself is a single DICOM study
    print("üìÇ No standalone files found. Assuming folder is a DICOM Series.")
    return [folder_path]

# ==========================================
# 4. PROCESSING FUNCTIONS
# ==========================================
def get_transforms(*, pixdim=(0.7, 0.7, 1.0), hu_window=(-1000.0, 400.0)):
    """Build the MONAI preprocessing chain applied to every scan.

    Steps:
      1. Load the image, keeping its metadata.
      2. Move channels first.
      3. Reorient to RAS.
      4. Resample to ``pixdim`` with bilinear interpolation.
      5. Linearly map the intensity window ``hu_window`` to [0, 255], clipping
         values outside it.

    Args:
        pixdim: Target voxel spacing per spatial axis, in the image's physical
            units (presumably mm - confirm). The default (0.7, 0.7, 1.0) is the
            pipeline's original setting.
        hu_window: ``(min, max)`` input intensity window. For CT these are
            presumably Hounsfield units; the default -1000..400 is the
            pipeline's original setting.

    Returns:
        monai.transforms.Compose: Dictionary transform acting on key "image".
    """
    a_min, a_max = hu_window
    return Compose([
        LoadImaged(keys=["image"], image_only=False),  # image_only=False keeps metadata
        EnsureChannelFirstd(keys=["image"]),
        Orientationd(keys=["image"], axcodes="RAS"),
        Spacingd(keys=["image"], pixdim=pixdim, mode="bilinear"),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=a_min, a_max=a_max,
            b_min=0.0, b_max=255.0,
            clip=True,
        ),
        EnsureTyped(keys=["image"]),
    ])

def load_model():
    """Load the YOLO detector from ``Config.WEIGHTS_PATH``.

    ``ultralytics`` is imported inside the function, so it is resolved when
    the model is loaded. In run_pipeline() that happens after
    patch_environment() has run.

    Returns:
        ultralytics.YOLO: The loaded model.

    Raises:
        RuntimeError: If ultralytics cannot be imported or the weights cannot
            be loaded. The original exception is chained as ``__cause__``.
    """
    try:
        from ultralytics import YOLO
        return YOLO(Config.WEIGHTS_PATH)
    except Exception as e:
        print(f"‚ùå Model Error: {e}")
        # Raise instead of sys.exit(1): a bare SystemExit hides the real
        # failure in a notebook, while chaining keeps its traceback visible.
        raise RuntimeError(f"Could not load YOLO model from {Config.WEIGHTS_PATH!r}") from e

def save_verification_image(slice_img, z_index, boxes, save_dir, patient_id):
    """Render one slice with its detection boxes and save it as a PNG.

    The file is named ``{patient_id}_z{z_index:04d}.png``, so images from
    different patients do not overwrite each other.

    Args:
        slice_img: 2-D slice, shown in grayscale on a fixed 0-255 display range.
        z_index: Slice index, zero-padded to 4 digits in the filename.
        boxes: Detection dicts with "bbox_2d" = [x1, y1, x2, y2] in the image's
            pixel coordinates (x = column, y = row) and a float "confidence".
        save_dir: Existing directory to write the PNG into.
        patient_id: Filename prefix.
    """
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    try:
        # Fixed 0-255 range matches the intensity scaling done in preprocessing.
        ax.imshow(slice_img, cmap="gray", vmin=0, vmax=255)
        ax.axis('off')

        for box in boxes:
            x1, y1, x2, y2 = box['bbox_2d']
            rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='r', facecolor='none')
            ax.add_patch(rect)
            # Confidence label just above the box's top-left corner.
            ax.text(x1, y1 - 5, f"{box['confidence']:.2f}", color='red', fontsize=8, weight='bold')

        filename = f"{patient_id}_z{z_index:04d}.png"
        fig.savefig(os.path.join(save_dir, filename), bbox_inches='tight', pad_inches=0.1)
    finally:
        # Always release the figure, even if drawing or saving fails, so that
        # calls over many slices do not accumulate open figures.
        plt.close(fig)

# ==========================================
# 5. MAIN LOOP
# ==========================================
def _scan_patient_id(scan_path):
    """Return an identifier for a scan: folder name for DICOM series, file name without extension otherwise."""
    name = os.path.basename(os.path.normpath(scan_path))
    if os.path.isdir(scan_path):
        # DICOM folders are often named by UID ("1.3.6.1...."); splitext would
        # wrongly chop off the last UID component.
        return name
    if name.endswith(".nii.gz"):
        # splitext only removes ".gz" and would leave "<id>.nii".
        return name[: -len(".nii.gz")]
    return os.path.splitext(name)[0]


def _detect_slices(model, volume, patient_id, visuals_dir, debug_dir=None):
    """Run the detector on each slice ``volume[:, :, z]`` and return all detection dicts.

    Slices with at least one detection are rendered into ``visuals_dir``. If
    ``debug_dir`` is given, every 10th slice is also rendered there (with its
    boxes, if any) for a visual quality check.
    """
    detections = []
    for z in range(volume.shape[2]):
        slice_gray = volume[:, :, z]
        # Replicate the single intensity channel into 3 channels, cast to uint8.
        slice_rgb = np.stack((slice_gray,) * 3, axis=-1).astype(np.uint8)

        results = model(slice_rgb, verbose=False, conf=Config.CONFIDENCE, device=Config.DEVICE)

        slice_boxes = []
        for result in results:
            for box in result.boxes:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                slice_boxes.append({
                    "patient_id": patient_id,
                    "z_index": z,
                    "bbox_2d": [float(x1), float(y1), float(x2), float(y2)],
                    "confidence": float(box.conf[0].cpu().numpy()),
                    "class": int(box.cls[0].cpu().numpy()),
                })

        # Save visual proof if found
        if slice_boxes:
            save_verification_image(slice_gray, z, slice_boxes, visuals_dir, patient_id)
            print(".", end="", flush=True)
        if debug_dir is not None and z % 10 == 0:
            save_verification_image(slice_gray, z, slice_boxes, debug_dir, patient_id)

        detections.extend(slice_boxes)
    return detections


def run_pipeline():
    """Detect candidates in every scan under ``Config.INPUT_FOLDER`` and write a JSON report.

    For each scan returned by discover_scans(), the pipeline:
      1. Preprocesses the scan with get_transforms().
      2. Runs the YOLO model on every slice ``volume[:, :, z]`` on
         ``Config.DEVICE``.
      3. Saves an annotated PNG to ``<OUTPUT_DIR>/verified_detections`` for
         each slice with detections.

    If ``Config.SAVE_DEBUG_SLICES`` is True, every 10th slice is also saved to
    ``<OUTPUT_DIR>/debug_slices``. Scans that fail to load are reported and
    skipped.

    All detections go to ``<OUTPUT_DIR>/final_batch_report.json`` as a list of
    dicts with keys patient_id, z_index, bbox_2d ([x1, y1, x2, y2]),
    confidence and class.
    """
    patch_environment()
    os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
    visuals_dir = os.path.join(Config.OUTPUT_DIR, "verified_detections")
    os.makedirs(visuals_dir, exist_ok=True)

    # Kept separate from verified_detections so debug output does not mix with detections.
    debug_dir = None
    if Config.SAVE_DEBUG_SLICES:
        debug_dir = os.path.join(Config.OUTPUT_DIR, "debug_slices")
        os.makedirs(debug_dir, exist_ok=True)

    # 1. Get the list of scans, the model and the (loop-invariant) transforms.
    scan_paths = discover_scans(Config.INPUT_FOLDER)
    model = load_model()
    transforms = get_transforms()

    all_findings = []
    scans_processed = 0

    # 2. Loop through EVERY scan found
    for scan_path in scan_paths:
        patient_id = _scan_patient_id(scan_path)
        print(f"\n‚û°Ô∏è  Processing Patient: {patient_id}")

        try:
            data_item = Dataset(data=[{"image": scan_path}], transform=transforms)[0]
            # Drop the channel axis -> spatial volume in RAS axis order.
            volume = data_item["image"][0].numpy()
        except Exception as e:
            # Best-effort batch: report this scan's failure and continue.
            print(f"   ‚ùå Load Failed: {e}")
            continue
        scans_processed += 1

        # 3. Slice Loop
        patient_detections = _detect_slices(model, volume, patient_id, visuals_dir, debug_dir)

        # End of patient
        if patient_detections:
            print(f" Found {len(patient_detections)} candidates.")
            all_findings.extend(patient_detections)
        else:
            print(" Clean.")

    # 4. Final Report
    report_path = os.path.join(Config.OUTPUT_DIR, "final_batch_report.json")
    with open(report_path, "w", encoding="utf-8") as f:
        json.dump(all_findings, f, indent=4)

    print("\n" + "=" * 40)
    print("üèÅ BATCH COMPLETE")
    # Report successfully loaded scans out of all discovered scans.
    print(f"‚Ä¢ Scans Processed: {scans_processed}/{len(scan_paths)}")
    print(f"‚Ä¢ Total Detections: {len(all_findings)}")
    print(f"‚Ä¢ Report: {report_path}")
    print("=" * 40)

In [0]:
%sh
# Delete outputs from any previous run so the report and images come only from
# the next run_pipeline() call. This path must match Config.OUTPUT_DIR.
rm -rf ./results_pipeline

In [0]:
# Run detection over every scan found in Config.INPUT_FOLDER. Annotated PNGs
# and final_batch_report.json are written under Config.OUTPUT_DIR.
run_pipeline()

In [0]:
from monai.transforms import LoadImage

# Visual sanity check: preprocess one LUNA16 scan and display a single slice.
# NOTE: resamples to spacing (1.0, 1.0, 1.0), unlike get_transforms(), which
# resamples to (0.7, 0.7, 1.0). The preview is therefore not pixel-identical
# to what the detector sees.
mhd_file = os.path.join(
    Config.INPUT_FOLDER,
    "1.3.6.1.4.1.14519.5.2.1.6279.6001.105756658031515062000744821260.mhd",
)

transforms = Compose([
    LoadImaged(keys=["image"], image_only=False),  # image_only=False keeps metadata
    EnsureChannelFirstd(keys=["image"]),
    Orientationd(keys=["image"], axcodes="RAS"),
    Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode="bilinear"),
    ScaleIntensityRanged(
        keys=["image"],
        a_min=-1000.0, a_max=400.0,
        b_min=0.0, b_max=255.0,
        clip=True,
    ),
    EnsureTyped(keys=["image"]),
])

try:
    ds = Dataset(data=[{"image": mhd_file}], transform=transforms)
    data_item = ds[0]
    # Drop the channel axis -> spatial volume in RAS axis order (X, Y, Z).
    volume = data_item["image"][0].numpy()
except Exception as e:
    print(f"‚ùå Error: {e}")
else:
    depth = volume.shape[2]
    # Show the middle axial slice. After RAS reorientation, z = 0 is the
    # inferior-most slice of the volume.
    z = depth // 2
    slice_rgb = np.stack((volume[:, :, z],) * 3, axis=-1).astype(np.uint8)

    print("‚úÖ Success!")
    print(f"Shape: {volume.shape}")
    plt.figure(figsize=(5, 5))
    plt.imshow(slice_rgb)
    plt.axis('off')
    plt.title(f"Slice {z}")
    display(plt.gcf())
    plt.close()

In [0]:
import random
from PIL import Image
from pathlib import Path

import pydicom

# 1) Print the Study/Series Instance UIDs of a sample DICOM series. Only the
#    header of the first file (in sorted order) is read, without pixel data.
dicom_dir = "/Volumes/ema_rina/pixels_solacc_tcia/pixels_volume/unzipped/1.3.6.1.4.1.9328.50.1.112793496676844175431872842334447612042/"
dicom_files = sorted(f for f in Path(dicom_dir).iterdir() if f.is_file())
if dicom_files:
    dicom_header = pydicom.dcmread(str(dicom_files[0]), stop_before_pixels=True)
    print("StudyInstanceUID:", dicom_header.StudyInstanceUID)
    print("SeriesInstanceUID:", dicom_header.SeriesInstanceUID)

# 2) Show one randomly chosen annotated slice from the folder run_pipeline()
#    writes to, derived from Config so it follows OUTPUT_DIR changes.
img_dir = Path(Config.OUTPUT_DIR) / "verified_detections"
img_files = sorted(img_dir.glob("*.png"))
if img_files:
    img_path = str(random.choice(img_files))
    print(img_path)
    display(Image.open(img_path))
else:
    print(f"No detection images found in {img_dir}")