In [None]:
# 1. Logging Setup (run first)

import logging

class CustomFormatter(logging.Formatter):
    """Log formatter that colours each record with ANSI escapes for notebook output.

    Every record is rendered as ``<timestamp> - <LEVEL> - <message>``, prefixed
    with the colour escape registered for its level in ``FORMATS`` and followed
    by the reset escape. Levels without an entry fall back to grey.
    """
    grey = "\x1b[38;21m"
    blue = "\x1b[38;5;39m"
    yellow = "\x1b[38;5;226m"
    red = "\x1b[38;5;196m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"

    # Shared layout so the base class and the coloured delegates cannot drift apart.
    _BASE_FMT = "%(asctime)s - %(levelname)s - %(message)s"
    _DATE_FMT = "%Y-%m-%d %H:%M:%S"

    def __init__(self):
        super().__init__(fmt=self._BASE_FMT, datefmt=self._DATE_FMT)
        # Level -> colour escape. CRITICAL is listed explicitly so the most
        # severe records do not silently fall back to grey.
        self.FORMATS = {
            logging.DEBUG: self.grey,
            logging.INFO: self.blue,
            logging.WARNING: self.yellow,
            logging.ERROR: self.red,
            logging.CRITICAL: self.bold_red
        }
        # Colour escape -> delegate formatter, filled lazily by format().
        self._formatter_cache = {}

    def format(self, record):
        """Return ``record`` rendered as a coloured, timestamped line.

        The colour is looked up in ``FORMATS`` on every call (so later changes
        to that mapping take effect), but the delegate ``logging.Formatter``
        for each colour is built only once and then reused.
        """
        color = self.FORMATS.get(record.levelno, self.grey)
        formatter = self._formatter_cache.get(color)
        if formatter is None:
            formatter = logging.Formatter(
                f"{color}{self._BASE_FMT}{self.reset}", datefmt=self._DATE_FMT
            )
            self._formatter_cache[color] = formatter
        return formatter.format(record)

# Setup logger
logger = logging.getLogger("YOLOComparison")
logger.setLevel(logging.INFO)
# Stop records from also reaching the root logger's handlers: the recorded
# output of this cell shows every message printed a second time in the root
# logger's default "LEVEL:name:message" format.
logger.propagate = False
# Guard against stacking a new handler each time this cell is re-run.
if not logger.handlers:
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(CustomFormatter())
    logger.addHandler(console_handler)
logger.info("Logger initialized")

[38;5;39m2025-06-07 09:56:12 - INFO - Logger initialized[0m
INFO:YOLOComparison:Logger initialized


In [None]:
# Install all required packages (run this first if you get ModuleNotFoundError)
!pip install ultralytics sahi pycocotools pandas seaborn --quiet

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/87.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m26.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.7/114.7 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.5/62.5 MB[0m [31m38.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m119.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m96.4 MB/s[0m eta [36m0:0

In [None]:
# 2. Imports & Config (run after logging setup)

import os
import yaml
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import matplotlib.patches as patches
from ultralytics import YOLO, __version__ as yolo_version
from sklearn.metrics import confusion_matrix

# Google Drive mounting (only possible inside Google Colab).
# The import is guarded so the cell still runs in other environments,
# where the paths below must then point at locally available data.
try:
    from google.colab import drive
except ImportError:
    drive = None
    logger.warning("google.colab is not available; skipping Google Drive mount")
else:
    drive.mount('/content/drive')

# Path configs (Change if needed)
BASE_DIR = '/content/drive/MyDrive/new scope model'
DATA_YAML = f'{BASE_DIR}/data.yaml'
DATASET_DIR = '/content/drive/MyDrive/Samplesmall_dataset'

logger.info(f"Using Ultralytics YOLO version: {yolo_version}")
# Report every missing path up front so a wrong location is visible before training starts.
for path in [BASE_DIR, DATA_YAML, DATASET_DIR]:
    if not os.path.exists(path):
        logger.error(f"Path not found: {path}")

Creating new Ultralytics Settings v0.0.6 file ✅ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.


[38;5;39m2025-06-07 09:58:12 - INFO - Using Ultralytics YOLO version: 8.3.151[0m
INFO:YOLOComparison:Using Ultralytics YOLO version: 8.3.151


Mounted at /content/drive


In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def plot_comparison_grid(images, detections_list, class_names, method_names, save_path):
    """
    Draw every image once per detection method and save the grid as one figure.

    Rows are images, columns are methods; the method name is used as the
    column title on the first row.

    images: list of np.array images [img1, img2]
    detections_list: list of lists, shape [n_methods][n_images][detections];
        each detection is a dict with 'bbox', 'class_id' and 'confidence'.
        'bbox' is handed to matplotlib as (x, y, width, height).
    class_names: list of str, indexed by each detection's 'class_id'
    method_names: list of str, e.g. ['YOLOv8', 'Enhanced YOLO', 'SAHI']
    save_path: where to save the output image

    Raises ValueError if `images` or `method_names` is empty.
    """
    n_images = len(images)
    n_methods = len(method_names)
    if n_images == 0 or n_methods == 0:
        raise ValueError("plot_comparison_grid needs at least one image and one method")

    # squeeze=False keeps `axes` two-dimensional for every grid size. With the
    # default squeezing, a single column yields a 1-D array and a 1x1 grid a
    # bare Axes, both of which broke the previous axes[i, j] / axes[j] lookup.
    fig, axes = plt.subplots(n_images, n_methods,
                             figsize=(5 * n_methods, 5 * n_images),
                             squeeze=False)

    for i, image in enumerate(images):
        for j, method_name in enumerate(method_names):
            ax = axes[i][j]
            ax.imshow(image)
            # Draw detections for this method/image
            for det in detections_list[j][i]:
                bbox = det['bbox']
                label = class_names[det['class_id']]
                conf = det['confidence']
                rect = patches.Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3],
                                         linewidth=2, edgecolor='r', facecolor='none')
                ax.add_patch(rect)
                # Caption sits just above the box's (x, y) anchor.
                ax.text(bbox[0], bbox[1]-5, f'{label}: {conf:.2f}',
                        color='white', bbox=dict(facecolor='red', alpha=0.5))
            if i == 0:
                ax.set_title(method_name)
            ax.axis('off')

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    # Close this specific figure rather than whichever one happens to be current.
    plt.close(fig)
    print(f"[INFO] Saved comparison grid to {save_path}")

In [None]:
# 3. ResultsVisualizer Class (data visualization)

class ResultsVisualizer:
    """Handles all visualization tasks.

    On construction a timestamped ``comparison_results_<timestamp>`` directory
    is created under ``base_dir`` with ``plots``, ``metrics``,
    ``detection_examples`` and ``logs`` sub-directories. Every plotting method
    writes its output there and closes the figure it created.
    """

    def __init__(self, base_dir):
        self.base_dir = base_dir
        self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.results_dir = os.path.join(base_dir, f'comparison_results_{self.timestamp}')
        self.dirs = {
            'plots': os.path.join(self.results_dir, 'plots'),
            'metrics': os.path.join(self.results_dir, 'metrics'),
            'detections': os.path.join(self.results_dir, 'detection_examples'),
            'logs': os.path.join(self.results_dir, 'logs')
        }
        for dir_path in self.dirs.values():
            os.makedirs(dir_path, exist_ok=True)
        logger.info(f"Created results directory at {self.results_dir}")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _label_list(class_names):
        """Return class names as a list; a dict is ordered by its sorted keys."""
        if isinstance(class_names, dict):
            return [class_names[key] for key in sorted(class_names)]
        return list(class_names)

    @staticmethod
    def _frame_to_markdown(df):
        """Render ``df`` as a markdown table, or as a fenced plain-text table
        when pandas cannot import its optional markdown dependency."""
        try:
            return df.to_markdown()
        except ImportError:
            return f"```\n{df.to_string()}\n```"

    @staticmethod
    def _save_figure(fig, save_path):
        """Write ``fig`` to ``save_path`` and close that specific figure."""
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

    def _plot_confusion_heatmap(self, matrix, labels, model_name):
        """Draw ``matrix`` (rows = true, columns = predicted) as an annotated heatmap."""
        matrix = np.asarray(matrix)
        # Only attach names when there is exactly one per row/column; otherwise
        # let seaborn number the ticks instead of showing misaligned labels.
        tick_labels = labels if len(labels) == matrix.shape[0] else 'auto'
        # 'd' is only a valid annotation format for integer counts.
        fmt = 'd' if np.issubdtype(matrix.dtype, np.integer) else '.0f'
        fig, ax = plt.subplots(figsize=(12, 10))
        sns.heatmap(matrix, annot=True, fmt=fmt,
                    xticklabels=tick_labels,
                    yticklabels=tick_labels,
                    ax=ax)
        ax.set_title(f'Confusion Matrix - {model_name}')
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')
        save_path = os.path.join(self.dirs['plots'], f'confusion_matrix_{model_name}.png')
        self._save_figure(fig, save_path)

    def _plot_confusion_entry(self, model_name, cm_data, class_names):
        """Plot one ``confusion_matrices`` entry, whichever supported form it has.

        Supported forms: a dict with 'true' and 'pred' label sequences, a square
        2-D array of counts, or an object exposing such an array as ``.matrix``.
        Anything else is skipped with a warning.
        """
        if cm_data is None:
            logger.warning(f"No confusion matrix data for {model_name}; skipping")
            return
        if isinstance(cm_data, dict):
            self.plot_confusion_matrix(cm_data['true'], cm_data['pred'],
                                       class_names, model_name)
            return

        wraps_matrix = hasattr(cm_data, 'matrix')
        matrix = np.asarray(cm_data.matrix if wraps_matrix else cm_data)
        if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
            logger.warning(f"Unsupported confusion matrix data for {model_name}; skipping")
            return
        if wraps_matrix:
            # NOTE(review): objects exposing `.matrix` are assumed to follow the
            # Ultralytics ConfusionMatrix layout (rows = predicted, columns =
            # true, plus one trailing background class). Transposed so rows are
            # true labels, as this plot expects. Confirm for the installed version.
            matrix = matrix.T
        labels = self._label_list(class_names)
        if matrix.shape[0] == len(labels) + 1:
            labels = labels + ['background']
        self._plot_confusion_heatmap(matrix, labels, model_name)

    # ------------------------------------------------------------------- plots

    def plot_metrics_comparison(self, metrics_dict):
        """Bar-plot the summary metrics and export them as CSV and markdown.

        metrics_dict: mapping of model name -> {metric name: value}.
        Returns the DataFrame that was plotted (one row per model).
        """
        df = pd.DataFrame(metrics_dict).T
        # Create the figure explicitly and pass its axes to pandas. Calling
        # plt.figure() first and then df.plot() opened a second figure, so the
        # requested size was ignored and the first figure was never closed.
        fig, ax = plt.subplots(figsize=(12, 8))
        df.plot(kind='bar', width=0.8, ax=ax)
        ax.set_title('Performance Comparison Across Models')
        ax.set_xlabel('Model Type')
        ax.set_ylabel('Score')
        for container in ax.containers:
            ax.bar_label(container, fmt='%.3f')
        fig.tight_layout()
        save_path = os.path.join(self.dirs['plots'], 'metrics_comparison.png')
        self._save_figure(fig, save_path)

        csv_path = os.path.join(self.dirs['metrics'], 'metrics_comparison.csv')
        df.to_csv(csv_path)
        markdown_path = os.path.join(self.dirs['metrics'], 'metrics_summary.md')
        with open(markdown_path, 'w', encoding='utf-8') as f:
            f.write("# Model Performance Comparison\n\n")
            f.write(self._frame_to_markdown(df))
        return df

    def plot_confusion_matrix(self, true_labels, pred_labels, class_names, model_name):
        """Compute a confusion matrix from label sequences and save it as a heatmap."""
        cm = confusion_matrix(true_labels, pred_labels)
        self._plot_confusion_heatmap(cm, self._label_list(class_names), model_name)

    def plot_precision_recall_curves(self, precisions, recalls, model_names):
        """Plot one precision-vs-recall line per model.

        precisions / recalls: sequences indexed in the same order as model_names.
        """
        fig, ax = plt.subplots(figsize=(10, 8))
        for i, model_name in enumerate(model_names):
            ax.plot(recalls[i], precisions[i], label=model_name)
        ax.set_title('Precision-Recall Curves')
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
        ax.legend()
        ax.grid(True)
        save_path = os.path.join(self.dirs['plots'], 'precision_recall_curves.png')
        self._save_figure(fig, save_path)

    def plot_per_class_map(self, class_maps, class_names, model_names):
        """Bar-plot per-class mAP with one group per model and one bar per class."""
        df = pd.DataFrame(class_maps, index=model_names,
                          columns=self._label_list(class_names))
        # Same explicit-axes pattern as plot_metrics_comparison, so the
        # requested figure size is honoured and no empty figure is left open.
        fig, ax = plt.subplots(figsize=(15, 8))
        df.plot(kind='bar', width=0.8, ax=ax)
        ax.set_title('Per-Class mAP Comparison')
        ax.set_xlabel('Model')
        ax.set_ylabel('mAP')
        ax.legend(title='Classes', bbox_to_anchor=(1.05, 1), loc='upper left')
        for container in ax.containers:
            ax.bar_label(container, fmt='%.3f', rotation=90)
        fig.tight_layout()
        save_path = os.path.join(self.dirs['plots'], 'per_class_map.png')
        self._save_figure(fig, save_path)

    def create_detection_grid(self, images, detections, class_names, model_names):
        """Save a grid with one row per image and one column per model.

        detections: nested as [n_models][n_images][detections]; each detection
        is a dict with 'bbox' (passed to matplotlib as x, y, width, height),
        'class_id' and 'confidence'.
        """
        n_images = len(images)
        n_models = len(model_names)
        # squeeze=False keeps `axes` 2-D even for a single row or column, so
        # the same lookup works for every grid size.
        fig, axes = plt.subplots(n_images, n_models,
                                 figsize=(5*n_models, 5*n_images),
                                 squeeze=False)
        for i, image in enumerate(images):
            for j, model_name in enumerate(model_names):
                ax = axes[i][j]
                ax.imshow(image)
                for det in detections[j][i]:
                    bbox = det['bbox']
                    label = class_names[det['class_id']]
                    conf = det['confidence']
                    rect = patches.Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3],
                                             linewidth=2, edgecolor='r', facecolor='none')
                    ax.add_patch(rect)
                    ax.text(bbox[0], bbox[1]-5, f'{label}: {conf:.2f}',
                            color='white', bbox=dict(facecolor='red', alpha=0.5))
                if i == 0:
                    ax.set_title(model_name)
                ax.axis('off')
        fig.tight_layout()
        save_path = os.path.join(self.dirs['detections'], 'detection_grid.png')
        self._save_figure(fig, save_path)

    # ------------------------------------------------------------------ report

    def save_markdown_report(self, metrics_df, additional_notes=None):
        """Write ``complete_report.md`` with the metrics table, the output
        directory layout, optional free-text notes and the list of plot kinds."""
        report_path = os.path.join(self.dirs['metrics'], 'complete_report.md')
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("# Model Comparison Report\n\n")
            f.write(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            f.write("## Summary Metrics\n")
            f.write(self._frame_to_markdown(metrics_df))
            f.write("\n\n")
            f.write("## Visualization Directory Structure\n")
            for dir_name, dir_path in self.dirs.items():
                f.write(f"- {dir_name}: {dir_path}\n")
            f.write("\n")
            if additional_notes:
                f.write("## Additional Notes\n")
                f.write(additional_notes)
                f.write("\n")
            f.write("\n## Plots Generated\n")
            f.write("1. Metrics Comparison (Bar Plot)\n")
            f.write("2. Confusion Matrices\n")
            f.write("3. Precision-Recall Curves\n")
            f.write("4. Per-Class mAP Comparison\n")
            f.write("5. Detection Examples Grid\n")

    def generate_all_visualizations(self, results_dict, class_names):
        """Produce every plot the given results allow, then the markdown report.

        results_dict keys:
          - 'summary_metrics' (required): model name -> {metric: value}
          - 'confusion_matrices': model name -> confusion data (see
            _plot_confusion_entry for the accepted forms)
          - 'precision' and 'recall': per-model curves, plotted only when both
            are present
          - 'per_class_map': per-model, per-class mAP values
          - 'example_images' and 'example_detections': inputs for the grid
          - 'notes': free text appended to the report

        Optional keys that are absent are skipped. Returns the results directory.
        """
        model_names = list(results_dict['summary_metrics'].keys())

        # 1. Overall metrics comparison
        metrics_df = self.plot_metrics_comparison(results_dict['summary_metrics'])
        # 2. Confusion matrices
        for model_name, cm_data in results_dict.get('confusion_matrices', {}).items():
            self._plot_confusion_entry(model_name, cm_data, class_names)
        # 3. Precision-recall curves (optional, like the sections below)
        if 'precision' in results_dict and 'recall' in results_dict:
            self.plot_precision_recall_curves(
                results_dict['precision'],
                results_dict['recall'],
                model_names
            )
        # 4. Per-class mAP
        if 'per_class_map' in results_dict:
            self.plot_per_class_map(
                results_dict['per_class_map'],
                class_names,
                model_names
            )
        # 5. Detection grid
        if results_dict.get('example_images') and results_dict.get('example_detections'):
            self.create_detection_grid(
                results_dict['example_images'],
                results_dict['example_detections'],
                class_names,
                model_names
            )
        # 6. Markdown report
        self.save_markdown_report(
            metrics_df,
            additional_notes=results_dict.get('notes', None)
        )
        logger.info(f"All visualizations saved in {self.results_dir}")
        return self.results_dir

In [None]:
class ModelEvaluator:
    """Trains a baseline and an enhanced YOLOv8 model, runs SAHI sliced
    inference with the enhanced weights, and hands the collected metrics to a
    ``ResultsVisualizer``.

    NOTE(review): a class with this same name is defined again in a later
    cell; whichever cell runs last is the definition actually in use.
    """

    # Values reported when validation results expose no box metrics.
    _ZERO_METRICS = {'mAP50': 0, 'mAP50-95': 0, 'recall': 0, 'precision': 0}
    # Lower-case file extensions treated as test images by the SAHI step.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, base_dir, data_yaml, dataset_dir=None):
        """
        base_dir: directory used as the Ultralytics ``project`` and as the
            parent of the results directory.
        data_yaml: path to the dataset YAML; its 'names' entry supplies the
            class names.
        dataset_dir: root of the dataset whose ``test/images`` folder is used
            for SAHI inference. When omitted, the notebook-level
            ``DATASET_DIR`` is looked up at the time SAHI runs.
        """
        self.base_dir = base_dir
        self.data_yaml = data_yaml
        self.dataset_dir = dataset_dir

        # Load class names from yaml first, so a bad path fails before an
        # empty results directory is created.
        with open(data_yaml, 'r', encoding='utf-8') as f:
            self.data_config = yaml.safe_load(f)
        self.class_names = self._normalize_class_names(self.data_config['names'])

        self.visualizer = ResultsVisualizer(base_dir)

        logger.info(f"Initialized evaluator with {len(self.class_names)} classes")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _normalize_class_names(names):
        """Return class names as a list; an id -> name dict is ordered by id."""
        if isinstance(names, dict):
            return [names[key] for key in sorted(names)]
        return list(names)

    @staticmethod
    def _mean_value(value, default=0):
        """Collapse ``value`` to a float: scalars are converted, sequences are
        averaged, and ``None`` or an empty sequence yields ``default``."""
        if value is None:
            return default
        try:
            count = len(value)
        except TypeError:
            return float(value)
        if count == 0:
            return default
        return float(sum(value) / count)

    @classmethod
    def _safe_get_metrics(cls, results):
        """Safe metric extraction with fallbacks.

        Returns a dict with 'mAP50', 'mAP50-95', 'recall' and 'precision'.
        All four are 0 when ``results`` has no usable ``box`` attribute.
        """
        box = getattr(results, 'box', None)
        if box is None:
            logger.error("Validation results missing 'box' attribute")
            return dict(cls._ZERO_METRICS)

        def first_available(*attr_names):
            # Use the first attribute that exists and is not None.
            for attr_name in attr_names:
                value = getattr(box, attr_name, None)
                if value is not None:
                    return cls._mean_value(value)
            return 0

        # NOTE(review): in Ultralytics the mean recall/precision appear to be
        # `mr`/`mp`, while `r`/`p` hold per-class values. The means are
        # preferred and the per-class values are averaged as a fallback.
        # Confirm against the installed version.
        return {
            'mAP50': first_available('map50'),
            'mAP50-95': first_available('map'),
            'recall': first_available('mr', 'r'),
            'precision': first_available('mp', 'p')
        }

    @staticmethod
    def _default_device():
        """Return 'cuda' when PyTorch reports a usable GPU, otherwise 'cpu'."""
        try:
            import torch
        except ImportError:
            return 'cpu'
        return 'cuda' if torch.cuda.is_available() else 'cpu'

    def _best_weights_path(self, model, run_name):
        """Locate the best checkpoint written by ``model``'s training run.

        NOTE(review): Ultralytics appears to append a numeric suffix to the run
        directory when ``name`` already exists, so the path recorded by the
        trainer (``model.trainer.best``) is preferred. Confirm this attribute
        for the installed version. Falls back to
        ``<base_dir>/<run_name>/weights/best.pt``.
        """
        trainer = getattr(model, 'trainer', None)
        recorded = getattr(trainer, 'best', None)
        if recorded is not None and os.path.exists(str(recorded)):
            return str(recorded)
        return os.path.join(self.base_dir, run_name, 'weights', 'best.pt')

    def _train_and_validate(self, run_name, train_imgsz, epochs, **val_overrides):
        """Train ``yolov8n.pt`` on the configured data, then validate it.

        val_overrides are forwarded to ``model.val``. Returns (model, val_results).
        """
        model = YOLO('yolov8n.pt')
        model.train(
            data=self.data_yaml,
            epochs=epochs,
            imgsz=train_imgsz,
            project=self.base_dir,
            name=run_name
        )
        val_results = model.val(data=self.data_yaml, **val_overrides)
        return model, val_results

    # ---------------------------------------------------------------- pipeline

    def train_and_evaluate_baseline(self, epochs=100):
        """Train and evaluate baseline YOLOv8 model (640 px). Returns (model, val_results)."""
        logger.info("Starting baseline model training...")
        return self._train_and_validate('baseline_model', train_imgsz=640, epochs=epochs)

    def train_and_evaluate_enhanced(self, epochs=100):
        """Train and evaluate enhanced model (larger size + TTA). Returns (model, val_results)."""
        logger.info("Starting enhanced model training...")
        # Evaluate with TTA at the training resolution
        return self._train_and_validate(
            'enhanced_model', train_imgsz=1024, epochs=epochs,
            imgsz=1024, augment=True
        )

    def evaluate_with_sahi(self, model_path, device=None):
        """Run SAHI sliced prediction over ``<dataset_dir>/test/images``.

        model_path: weights file to load.
        device: device string passed to SAHI; chosen automatically when None.

        Returns the list of per-image prediction results. Images that fail are
        skipped with a warning. Any other failure is logged and an empty list
        is returned, so this step never aborts the pipeline.
        """
        try:
            from sahi import AutoDetectionModel
            from sahi.predict import get_sliced_prediction

            logger.info(f"Loading model from {model_path}")
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")

            detection_model = AutoDetectionModel.from_pretrained(
                model_type='ultralytics',
                model_path=str(model_path),
                confidence_threshold=0.3,
                device=device if device is not None else self._default_device()
            )

            dataset_dir = self.dataset_dir if self.dataset_dir is not None else DATASET_DIR
            test_images_dir = os.path.join(dataset_dir, 'test/images')
            if not os.path.exists(test_images_dir):
                raise FileNotFoundError(f"Test images dir not found: {test_images_dir}")

            results = []
            # Sorted for a reproducible order; extension match is case-insensitive.
            for image_name in sorted(os.listdir(test_images_dir)):
                if not image_name.lower().endswith(self._IMAGE_EXTENSIONS):
                    continue
                image_path = os.path.join(test_images_dir, image_name)
                try:
                    result = get_sliced_prediction(
                        image=image_path,
                        detection_model=detection_model,
                        slice_height=512,
                        slice_width=512,
                        overlap_height_ratio=0.2,
                        overlap_width_ratio=0.2
                    )
                    results.append(result)
                    logger.debug(f"Processed {image_name} successfully")
                except Exception as e:
                    logger.warning(f"Failed to process {image_name}: {str(e)}")

            return results

        except Exception as e:
            logger.error(f"SAHI evaluation failed: {str(e)}", exc_info=True)
            return []

    def run_complete_evaluation(self):
        """Run complete evaluation pipeline.

        Trains and validates both models, runs SAHI, builds the results dict
        and asks the visualizer to render it. A failure while rendering is
        logged but does not discard the metrics. Any other failure is logged
        and re-raised. Returns the results dict.
        """
        try:
            # 1. Baseline evaluation
            _, baseline_results = self.train_and_evaluate_baseline()

            # 2. Enhanced evaluation
            enhanced_model, enhanced_results = self.train_and_evaluate_enhanced()

            # 3. SAHI evaluation
            sahi_results = self.evaluate_with_sahi(
                self._best_weights_path(enhanced_model, 'enhanced_model')
            )

            # 4. Collect all results
            baseline_metrics = self._safe_get_metrics(baseline_results)
            enhanced_metrics = self._safe_get_metrics(enhanced_results)
            summary_keys = ('mAP50', 'mAP50-95', 'recall')

            # Placeholder labels (all class 0): this pipeline collects no
            # per-sample true/predicted labels, so the resulting confusion
            # plots are not meaningful.
            placeholder_labels = [0]*len(self.class_names)

            results_dict = {
                'summary_metrics': {
                    'Baseline': {key: baseline_metrics[key] for key in summary_keys},
                    'Enhanced+TTA': {key: enhanced_metrics[key] for key in summary_keys}
                },
                'confusion_matrices': {
                    'Baseline': {
                        'true': list(placeholder_labels),
                        'pred': list(placeholder_labels)
                    },
                    'Enhanced+TTA': {
                        'true': list(placeholder_labels),
                        'pred': list(placeholder_labels)
                    }
                },
                'notes': (
                    f"SAHI sliced inference returned predictions for "
                    f"{len(sahi_results)} test image(s)."
                )
            }

            # 5. Generate visualizations. Training is expensive, so a plotting
            # problem must not throw away the metrics collected above.
            try:
                self.visualizer.generate_all_visualizations(
                    results_dict,
                    self.class_names
                )
            except Exception as e:
                logger.error(
                    f"Visualization failed; returning metrics without plots: {str(e)}",
                    exc_info=True
                )

            logger.info("Evaluation completed successfully!")
            return results_dict

        except Exception as e:
            logger.error(f"Error during evaluation: {str(e)}", exc_info=True)
            raise

    def test_metrics_extraction(self):
        """Self-check that metric extraction handles the expected result shapes."""
        from types import SimpleNamespace

        # Results object exposing the mean metrics directly
        fake_results = SimpleNamespace(
            box=SimpleNamespace(map50=0.5, map=0.4, mr=0.3, mp=0.6),
            confusion_matrix=None
        )
        metrics = self._safe_get_metrics(fake_results)
        assert metrics['mAP50'] == 0.5
        assert metrics['recall'] == 0.3
        assert metrics['precision'] == 0.6

        # Only per-class values available: they must be averaged to a scalar
        per_class_results = SimpleNamespace(
            box=SimpleNamespace(map50=0.5, map=0.4, r=[0.25, 0.75], p=[1.0, 0.5])
        )
        metrics = self._safe_get_metrics(per_class_results)
        assert metrics['recall'] == 0.5
        assert metrics['precision'] == 0.75
        logger.info("✅ Metrics extraction test passed!")

        # Test missing attribute handling
        fake_results.box = None
        metrics = self._safe_get_metrics(fake_results)
        assert metrics['mAP50'] == 0
        logger.info("✅ Error handling test passed!")

In [None]:
class ModelEvaluator:
    """Trains a baseline and an enhanced YOLOv8 model, runs SAHI sliced
    inference with the enhanced weights, and hands the collected metrics to a
    ``ResultsVisualizer``.

    NOTE(review): a class with this same name is also defined in an earlier
    cell; running this cell replaces that definition.
    """

    # Values reported when validation results expose no box metrics.
    _ZERO_METRICS = {'mAP50': 0, 'mAP50-95': 0, 'recall': 0, 'precision': 0}
    # Lower-case file extensions treated as test images by the SAHI step.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, base_dir, data_yaml, dataset_dir=None):
        """
        base_dir: directory used as the Ultralytics ``project`` and as the
            parent of the results directory.
        data_yaml: path to the dataset YAML; its 'names' entry supplies the
            class names.
        dataset_dir: root of the dataset whose ``test/images`` folder is used
            for SAHI inference. When omitted, the notebook-level
            ``DATASET_DIR`` is looked up at the time SAHI runs.
        """
        self.base_dir = base_dir
        self.data_yaml = data_yaml
        self.dataset_dir = dataset_dir

        # Load class names from yaml first, so a bad path fails before an
        # empty results directory is created.
        with open(data_yaml, 'r', encoding='utf-8') as f:
            self.data_config = yaml.safe_load(f)
        self.class_names = self._normalize_class_names(self.data_config['names'])

        self.visualizer = ResultsVisualizer(base_dir)

        logger.info(f"Initialized evaluator with {len(self.class_names)} classes")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _normalize_class_names(names):
        """Return class names as a list; an id -> name dict is ordered by id."""
        if isinstance(names, dict):
            return [names[key] for key in sorted(names)]
        return list(names)

    @staticmethod
    def _mean_value(value, default=0):
        """Collapse ``value`` to a float: scalars are converted, sequences are
        averaged, and ``None`` or an empty sequence yields ``default``."""
        if value is None:
            return default
        try:
            count = len(value)
        except TypeError:
            return float(value)
        if count == 0:
            return default
        return float(sum(value) / count)

    @classmethod
    def _safe_get_metrics(cls, results):
        """Safe metric extraction with fallbacks.

        Returns a dict with 'mAP50', 'mAP50-95', 'recall' and 'precision'.
        All four are 0 when ``results`` has no usable ``box`` attribute.
        """
        box = getattr(results, 'box', None)
        if box is None:
            logger.error("Validation results missing 'box' attribute")
            return dict(cls._ZERO_METRICS)

        def first_available(*attr_names):
            # Use the first attribute that exists and is not None.
            for attr_name in attr_names:
                value = getattr(box, attr_name, None)
                if value is not None:
                    return cls._mean_value(value)
            return 0

        # NOTE(review): in Ultralytics the mean recall/precision appear to be
        # `mr`/`mp`, while `r`/`p` hold per-class values. The means are
        # preferred and the per-class values are averaged as a fallback.
        # Confirm against the installed version.
        return {
            'mAP50': first_available('map50'),
            'mAP50-95': first_available('map'),
            'recall': first_available('mr', 'r'),
            'precision': first_available('mp', 'p')
        }

    @staticmethod
    def _default_device():
        """Return 'cuda' when PyTorch reports a usable GPU, otherwise 'cpu'."""
        try:
            import torch
        except ImportError:
            return 'cpu'
        return 'cuda' if torch.cuda.is_available() else 'cpu'

    def _best_weights_path(self, model, run_name):
        """Locate the best checkpoint written by ``model``'s training run.

        NOTE(review): Ultralytics appears to append a numeric suffix to the run
        directory when ``name`` already exists, so the path recorded by the
        trainer (``model.trainer.best``) is preferred. Confirm this attribute
        for the installed version. Falls back to
        ``<base_dir>/<run_name>/weights/best.pt``.
        """
        trainer = getattr(model, 'trainer', None)
        recorded = getattr(trainer, 'best', None)
        if recorded is not None and os.path.exists(str(recorded)):
            return str(recorded)
        return os.path.join(self.base_dir, run_name, 'weights', 'best.pt')

    def _train_and_validate(self, run_name, train_imgsz, epochs, **val_overrides):
        """Train ``yolov8n.pt`` on the configured data, then validate it.

        val_overrides are forwarded to ``model.val``. Returns (model, val_results).
        """
        model = YOLO('yolov8n.pt')
        model.train(
            data=self.data_yaml,
            epochs=epochs,
            imgsz=train_imgsz,
            project=self.base_dir,
            name=run_name
        )
        val_results = model.val(data=self.data_yaml, **val_overrides)
        return model, val_results

    # ---------------------------------------------------------------- pipeline

    def train_and_evaluate_baseline(self, epochs=15):
        """Train and evaluate baseline YOLOv8 model (640 px). Returns (model, val_results)."""
        logger.info("Starting baseline model training...")
        return self._train_and_validate('baseline_model', train_imgsz=640, epochs=epochs)

    def train_and_evaluate_enhanced(self, epochs=15):
        """Train and evaluate enhanced model (larger size + TTA). Returns (model, val_results)."""
        logger.info("Starting enhanced model training...")
        # Evaluate with TTA at the training resolution
        return self._train_and_validate(
            'enhanced_model', train_imgsz=1024, epochs=epochs,
            imgsz=1024, augment=True
        )

    def evaluate_with_sahi(self, model_path, device=None):
        """Run SAHI sliced prediction over ``<dataset_dir>/test/images``.

        model_path: weights file to load.
        device: device string passed to SAHI; chosen automatically when None.

        Returns the list of per-image prediction results. Images that fail are
        skipped with a warning. Any other failure is logged and an empty list
        is returned, so this step never aborts the pipeline.
        """
        try:
            from sahi import AutoDetectionModel
            from sahi.predict import get_sliced_prediction

            logger.info(f"Loading model from {model_path}")
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")

            detection_model = AutoDetectionModel.from_pretrained(
                model_type='ultralytics',
                model_path=str(model_path),
                confidence_threshold=0.3,
                device=device if device is not None else self._default_device()
            )

            dataset_dir = self.dataset_dir if self.dataset_dir is not None else DATASET_DIR
            test_images_dir = os.path.join(dataset_dir, 'test/images')
            if not os.path.exists(test_images_dir):
                raise FileNotFoundError(f"Test images dir not found: {test_images_dir}")

            results = []
            # Sorted for a reproducible order; extension match is case-insensitive.
            for image_name in sorted(os.listdir(test_images_dir)):
                if not image_name.lower().endswith(self._IMAGE_EXTENSIONS):
                    continue
                image_path = os.path.join(test_images_dir, image_name)
                try:
                    result = get_sliced_prediction(
                        image=image_path,
                        detection_model=detection_model,
                        slice_height=512,
                        slice_width=512,
                        overlap_height_ratio=0.2,
                        overlap_width_ratio=0.2
                    )
                    results.append(result)
                    logger.debug(f"Processed {image_name} successfully")
                except Exception as e:
                    logger.warning(f"Failed to process {image_name}: {str(e)}")

            return results

        except Exception as e:
            logger.error(f"SAHI evaluation failed: {str(e)}", exc_info=True)
            return []

    def run_complete_evaluation(self):
        """Run complete evaluation pipeline.

        Trains and validates both models, runs SAHI, builds the results dict
        and asks the visualizer to render it. A failure while rendering is
        logged but does not discard the metrics. Any other failure is logged
        and re-raised. Returns the results dict.
        """
        try:
            # 1. Baseline evaluation
            _, baseline_results = self.train_and_evaluate_baseline()

            # 2. Enhanced evaluation
            enhanced_model, enhanced_results = self.train_and_evaluate_enhanced()

            # 3. SAHI evaluation
            sahi_results = self.evaluate_with_sahi(
                self._best_weights_path(enhanced_model, 'enhanced_model')
            )

            # 4. Collect all results
            baseline_metrics = self._safe_get_metrics(baseline_results)
            enhanced_metrics = self._safe_get_metrics(enhanced_results)
            summary_keys = ('mAP50', 'mAP50-95', 'recall')

            # Confusion data as exposed by the validation results. Entries
            # without one are left out instead of being passed on as None.
            # NOTE(review): the visualizer must be able to consume a
            # precomputed matrix object here; confirm ResultsVisualizer
            # supports it.
            confusion_matrices = {}
            for label, val_results in (('Baseline', baseline_results),
                                       ('Enhanced+TTA', enhanced_results)):
                matrix = getattr(val_results, 'confusion_matrix', None)
                if matrix is not None:
                    confusion_matrices[label] = matrix

            results_dict = {
                'summary_metrics': {
                    'Baseline': {key: baseline_metrics[key] for key in summary_keys},
                    'Enhanced+TTA': {key: enhanced_metrics[key] for key in summary_keys}
                },
                'confusion_matrices': confusion_matrices,
                'notes': (
                    f"SAHI sliced inference returned predictions for "
                    f"{len(sahi_results)} test image(s)."
                )
            }

            # 5. Generate visualizations. Training is expensive, so a plotting
            # problem must not throw away the metrics collected above.
            try:
                self.visualizer.generate_all_visualizations(
                    results_dict,
                    self.class_names
                )
            except Exception as e:
                logger.error(
                    f"Visualization failed; returning metrics without plots: {str(e)}",
                    exc_info=True
                )

            logger.info("Evaluation completed successfully!")
            return results_dict

        except Exception as e:
            logger.error(f"Error during evaluation: {str(e)}", exc_info=True)
            raise

    def test_metrics_extraction(self):
        """Self-check that metric extraction handles the expected result shapes."""
        from types import SimpleNamespace

        # Results object exposing the mean metrics directly
        fake_results = SimpleNamespace(
            box=SimpleNamespace(map50=0.5, map=0.4, mr=0.3, mp=0.6),
            confusion_matrix=None
        )
        metrics = self._safe_get_metrics(fake_results)
        assert metrics['mAP50'] == 0.5
        assert metrics['recall'] == 0.3
        assert metrics['precision'] == 0.6

        # Only per-class values available: they must be averaged to a scalar
        per_class_results = SimpleNamespace(
            box=SimpleNamespace(map50=0.5, map=0.4, r=[0.25, 0.75], p=[1.0, 0.5])
        )
        metrics = self._safe_get_metrics(per_class_results)
        assert metrics['recall'] == 0.5
        assert metrics['precision'] == 0.75
        logger.info("✅ Metrics extraction test passed!")

        # Test missing attribute handling
        fake_results.box = None
        metrics = self._safe_get_metrics(fake_results)
        assert metrics['mAP50'] == 0
        logger.info("✅ Error handling test passed!")

In [None]:
# 5. Main Pipeline

# Initialize evaluator.
# Only base_dir and data_yaml are passed: the recorded output of this cell
# shows the earlier three-argument call failing with a TypeError. The SAHI
# step falls back to the notebook-level DATASET_DIR for the test images.
evaluator = ModelEvaluator(BASE_DIR, DATA_YAML)
evaluator.test_metrics_extraction()

# Run evaluation
results = evaluator.run_complete_evaluation()

logger.info(f"Results saved in: {evaluator.visualizer.results_dir}")

TypeError: ModelEvaluator.__init__() takes 3 positional arguments but 4 were given