In [None]:
%matplotlib inline
import os
import numpy as np
import json
from tqdm import tqdm
import pandas as pd
from inference import Detector, Classificator, Inference
from inference.Inference import quantify 

# Object detection
from data.load import load_yolo, loader, load_as_coco
from metrics.coco import CocoEvaluator
from metrics.utils import get_classify_ground_truth

from sklearn.metrics import (
    precision_score,
    recall_score,
    f1_score,
    accuracy_score,
    confusion_matrix,
    explained_variance_score,
    mean_absolute_error,
    r2_score,
    mean_absolute_percentage_error
)

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Single source of truth for the evaluation run. The cells below read these
# keys when they build the Detector, the Classificator and the Inference
# wrapper, and when they load and score the validation data.
config = {
    # ---- Data, model weights and output location ----
    "val_source": "../datasets/orobanche_cummana/autosplit_val.txt",
    "detect_model_path": os.path.normpath("./runs/detect/train2/weights/best.pt"),
    "classify_model_path": os.path.normpath("./runs/classify/train/weights/best.pt"),
    "results": "results",

    # ---- Class id -> human readable class name ----
    "to_classes": {0: "background", 1: "healthy", 2: "necrotic"},

    # ---- Detection: every key in this group is forwarded to Detector ----
    "device": 0,
    "overlap": 0.2,
    "patch_size": 1024,
    "size_factor": 1.0,
    "conf_thresh": 0.25,
    "nms_iou_thresh": 0.7,
    "max_patch_detections": 300,
    "patch_per_batch": 4,
    "pre_wbf_detections": 500,
    "wbf_ios_thresh": 0.5,
    "max_detections": 1000,

    # ---- Classification at inference time ----
    "offset": 0.1,             # passed to Inference(offset=...)
    "classify_img_size": 224,  # passed to Classificator(img_size=...)

    # ---- Classification evaluation ----
    "bg_iou_thresh": 0.7,      # passed to get_classify_ground_truth
}

In [None]:
# Detector: the weights path and single_cls are spelled out, every other
# argument is a config entry forwarded under the same keyword name.
detect = Detector(
    model_path=config["detect_model_path"],
    single_cls=True,
    **{
        key: config[key]
        for key in (
            "device",
            "overlap",
            "patch_size",
            "size_factor",
            "conf_thresh",
            "nms_iou_thresh",
            "max_patch_detections",
            "patch_per_batch",
            "pre_wbf_detections",
            "wbf_ios_thresh",
            "max_detections",
        )
    },
)

# Classifier applied to the patches cut out around the detections.
classify = Classificator(
    img_size=config["classify_img_size"],
    device=config["device"],
    model_path=config["classify_model_path"],
)

# Two-stage pipeline object used by process_image.
inference = Inference(detect, classify, offset=config["offset"])

# Make sure the output directory exists before any result file is written.
os.makedirs(config["results"], exist_ok=True)

In [None]:
# Ground truth per image, converted to NumPy arrays so it can be sliced
# column-wise later (process_image uses columns 0-3 and column 4).
detection_gt = dict(
    (name, np.asarray(rows))
    for name, rows in load_yolo(config["val_source"]).items()
)

# The same validation source loaded through load_as_coco, plus the evaluator
# that scores predictions against it.
coco_gt_detection, names_to_ids = load_as_coco(
    config["val_source"],
    config["to_classes"],
)
coco_detection_evaluator = CocoEvaluator(coco_gt_detection)

In [None]:

def process_image(image_name, image, inference: Inference, detection_gt, config):
    """Run one image through detection, patch classification and quantification.

    Args:
        image_name: Key used to look up this image's rows in ``detection_gt``;
            also stored in the returned results.
        image: The image, handed unchanged to the ``inference`` methods.
        inference: Pipeline object providing ``detect``, ``patch``, ``classify``
            and ``merge_detect_and_classification`` as well as the profiler
            attributes read at the end of this function.
        detection_gt: Mapping from image name to a 2-D array. Columns 0-3 are
            used as boxes and column 4 as the class label.
        config: Dict providing ``to_classes`` and ``bg_iou_thresh``.

    Returns:
        Tuple ``(results, profiler_results)``. ``results`` holds the image name,
        the raw detection, the classification predictions and their ground
        truth, the merged inference prediction and the ground-truth / predicted
        quantification. ``profiler_results`` maps a stage name to the ``dt``
        value of the corresponding profiler.
    """
    num_classes = len(config["to_classes"])

    # Stage 1: detection on the full image.
    detection = inference.detect(image)

    # Stage 2: cut patches around the detected boxes (second return value is unused).
    patched_images, _ = inference.patch(image, detection.boxes)

    # Stage 3: classify every patch.
    class_predictions, confidences = inference.classify(patched_images)

    # Ground truth for this image: box columns and the label column.
    gt_rows = detection_gt[image_name]
    gt_boxes = gt_rows[:, :4]
    gt_labels = gt_rows[:, 4]

    # Classification ground truth for each predicted box.
    class_ground_truth = get_classify_ground_truth(
        detection.boxes.xyxy,
        gt_boxes,
        gt_labels,
        config["bg_iou_thresh"],
    )

    # Stage 4: combine boxes with their predicted classes and confidences.
    merged = inference.merge_detect_and_classification(
        image,
        detection.boxes.data,
        class_predictions,
        confidences,
    )

    # Optional visualisation, disabled:
    # merged.plot(
    #     img=np.asarray(image)[..., ::-1],
    #     filename=f"{config['results']}/{image_name}.jpg",
    #     save=True,
    #     line_width=5,
    #     font_size=16,
    # )

    # Quantification for the annotations and for the pipeline output.
    quantified_gt = quantify(gt_boxes, gt_labels, num_classes)
    quantified_prediction = quantify(merged.boxes.xyxy, merged.boxes.cls, num_classes)

    results = {
        "image_name": image_name,
        "detect_prediction": detection,
        "classify_prediction": class_predictions,
        "classify_gt": class_ground_truth,
        "inference_prediction": merged,
        "quantification_gt": quantified_gt,
        "quantification_prediction": quantified_prediction,
    }

    # Snapshot of the timers exposed by the pipeline and by its detector.
    detector = inference.detector
    profiler_results = {
        "inference_detect": inference.detect_profiler.dt,
        "inference_patch": inference.patch_profiler.dt,
        "inference_classify": inference.classify_profiler.dt,
        "inference_merge": inference.merge_profiler.dt,
        "detector_resizer": detector.resizer_profiler.dt,
        "detector_patches": detector.patches_profiler.dt,
        "detector_yolo": detector.yolo_detector_profiler.dt,
        "detector_merge": detector.merge_predictions_profiler.dt,
        "detector_wbf": detector.wbf_profiler.dt,
    }
    return results, profiler_results

# Run the pipeline over the whole validation set. Outputs and timings are
# collected in two parallel lists, one entry per image.
class_num = len(config["to_classes"])
image_loader = loader(config["val_source"])
image_results, profiler_results = [], []

for image_name, image in tqdm(image_loader):
    im_res, profile_res = process_image(
        image_name,
        image,
        inference,
        detection_gt,
        config,
    )
    image_results += [im_res]
    profiler_results += [profile_res]

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

def analyze_runtime(profiler_results, output_dir="."):
    """
    Analyze and visualize the runtime of different components in the detection pipeline.

    Prints a mean/std summary per operation, draws six figures (box plot, mean
    bar plot, detector-vs-inference bars, pie chart, percentage breakdown and a
    per-image line plot) and saves each of them as a PNG file.

    Args:
        profiler_results: List of dictionaries containing timing information,
            one dictionary per image. Keys starting with ``detector_`` and
            ``inference_`` are grouped into the two categories.
        output_dir: Directory the PNG files are written to. Defaults to the
            current working directory, which is where the files were written
            before this parameter existed.

    Returns:
        dict with the DataFrames ``raw_data``, ``detector_data``,
        ``inference_data`` and ``total_data``.

    Raises:
        ValueError: If ``profiler_results`` contains no timings.
    """
    # One row per image, one column per profiled operation.
    df = pd.DataFrame(profiler_results)
    if df.empty:
        raise ValueError("profiler_results is empty; there are no timings to analyze")

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    def save_figure(filename):
        """Tighten the current figure's layout and save it into output_dir."""
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, filename), dpi=300)

    # Group operations by category
    detector_cols = [col for col in df.columns if col.startswith('detector_')]
    inference_cols = [col for col in df.columns if col.startswith('inference_')]

    # Per-category frames with a per-image total column
    detector_df = df[detector_cols].copy()
    inference_df = df[inference_cols].copy()
    detector_df['detector_total'] = detector_df.sum(axis=1)
    inference_df['inference_total'] = inference_df.sum(axis=1)

    # NOTE(review): if the detector_* timers are measured inside the
    # inference_detect step, Detector + Inference (and the percentages below)
    # count that time twice -- confirm against the profiler implementation.
    total_df = pd.DataFrame()
    total_df['Detector'] = detector_df['detector_total']
    total_df['Inference'] = inference_df['inference_total']
    total_df['Total'] = total_df['Detector'] + total_df['Inference']

    # Print summary statistics
    print("Runtime Analysis Summary (in seconds):")
    print("\nDetector Operations:")
    for col in detector_cols:
        print(f"  {col.replace('detector_', '')}: {df[col].mean():.4f}s ± {df[col].std():.4f}s")
    print(f"  Total detector time: {detector_df['detector_total'].mean():.4f}s")

    print("\nInference Operations:")
    for col in inference_cols:
        print(f"  {col.replace('inference_', '')}: {df[col].mean():.4f}s ± {df[col].std():.4f}s")
    print(f"  Total inference time: {inference_df['inference_total'].mean():.4f}s")

    print(f"\nTotal pipeline time: {total_df['Total'].mean():.4f}s ± {total_df['Total'].std():.4f}s")

    # Mean seconds per operation and each operation's share of the summed time
    mean_times = df.mean()
    total_time = df.sum(axis=1).mean()
    percent_df = mean_times / total_time * 100

    # Visualizations
    sns.set_theme(style="whitegrid")

    # 1. Box plot of all operations
    plt.figure(figsize=(14, 6))
    plt.title('Runtime Distribution of All Operations', fontsize=16)
    sns.boxplot(data=df)
    plt.xticks(rotation=45, ha='right')
    plt.ylabel('Time (seconds)')
    save_figure('all_operations_boxplot.png')

    # 2. Bar plot with mean runtime for each operation
    plt.figure(figsize=(14, 6))
    plt.title('Mean Runtime of Each Operation', fontsize=16)
    sns.barplot(x=df.columns, y=mean_times, palette='viridis')
    plt.xticks(rotation=45, ha='right')
    plt.ylabel('Time (seconds)')
    save_figure('mean_runtime_barplot.png')

    # 3. Grouped bars: one bar per operation, grouped by category
    plt.figure(figsize=(10, 6))
    plt.title('Proportion of Time: Detector vs Inference', fontsize=16)

    detector_means = detector_df.drop('detector_total', axis=1).mean()
    inference_means = inference_df.drop('inference_total', axis=1).mean()

    detector_data = pd.DataFrame({
        'Operation': detector_means.index.str.replace('detector_', ''),
        'Time': detector_means.values,
        'Category': 'Detector'
    })
    inference_data = pd.DataFrame({
        'Operation': inference_means.index.str.replace('inference_', ''),
        'Time': inference_means.values,
        'Category': 'Inference'
    })
    combined_data = pd.concat([detector_data, inference_data])

    sns.barplot(x='Category', y='Time', hue='Operation', data=combined_data, palette='viridis')
    plt.ylabel('Time (seconds)')
    plt.legend(title='Operation', bbox_to_anchor=(1.05, 1), loc='upper left')
    save_figure('detector_vs_inference.png')

    # 4. Pie chart of time distribution
    plt.figure(figsize=(12, 10))
    plt.title('Time Distribution Across All Operations', fontsize=16)

    def make_label(pct, values):
        """Format a wedge label as its percentage plus the matching seconds."""
        absolute = pct / 100. * sum(values)
        return f"{pct:.1f}%\n({absolute:.3f}s)"

    # The wedges are built from the mean seconds, not from the percentages:
    # the wedge sizes are the same either way, but make_label needs values
    # that sum to the total seconds. Feeding it percentages (which sum to 100)
    # made the "(…s)" part of the label repeat the percentage.
    plt.pie(
        mean_times,
        labels=mean_times.index,
        autopct=lambda pct: make_label(pct, mean_times.values),
        startangle=90,
        shadow=False,
        explode=[0.05] * len(mean_times),
        textprops={'fontsize': 9}
    )
    plt.axis('equal')
    save_figure('time_distribution_pie.png')

    # 5. Detailed breakdown with horizontal bars
    plt.figure(figsize=(12, 8))
    plt.title('Detailed Runtime Breakdown (% of Total Pipeline)', fontsize=16)

    sorted_percent = percent_df.sort_values(ascending=False)

    # Detector operations in blue, everything else in green
    colors = [
        'royalblue' if 'detector' in op else 'forestgreen'
        for op in sorted_percent.index
    ]

    ax = sns.barplot(x=sorted_percent.values, y=sorted_percent.index, palette=colors)

    # Write the percentage next to each bar
    for i, v in enumerate(sorted_percent.values):
        ax.text(v + 0.5, i, f"{v:.1f}%", va='center')

    plt.xlabel('Percentage of Total Pipeline Time')
    save_figure('detailed_time_breakdown.png')

    # 6. Line plot showing runtime consistency across images
    plt.figure(figsize=(14, 8))
    plt.title('Runtime Consistency Across Images', fontsize=16)

    # Only plot the key operations that were actually profiled, so a missing
    # timer does not abort the whole analysis with a KeyError.
    key_ops = [
        op
        for op in ('detector_yolo', 'detector_wbf', 'inference_detect', 'inference_classify')
        if op in df.columns
    ]
    for op in key_ops:
        plt.plot(df.index, df[op], label=op, marker='o', alpha=0.7, markersize=4)

    plt.xlabel('Image Index')
    plt.ylabel('Time (seconds)')
    if key_ops:
        plt.legend()
    plt.grid(True)
    save_figure('runtime_consistency.png')

    # Return the processed dataframes for further analysis if needed
    return {
        'raw_data': df,
        'detector_data': detector_df,
        'inference_data': inference_df,
        'total_data': total_df
    }


# Analyze the per-image timings collected by the processing loop. The returned
# dict holds the raw, detector, inference and total timing DataFrames.
results = analyze_runtime(profiler_results)

In [None]:
def structure_results(image_results, class_num):
    """Regroup the per-image result dicts into collections used for evaluation.

    Args:
        image_results: Iterable of result dicts, each with the keys
            ``image_name``, ``detect_prediction``, ``classify_prediction``,
            ``classify_gt``, ``inference_prediction``, ``quantification_gt``
            and ``quantification_prediction``. The two quantification entries
            are mappings with ``areas`` and ``counts``.
        class_num: Number of classes, i.e. the width of the stacked
            quantification arrays.

    Returns:
        dict with one per-image mapping for every field, the classification
        predictions / ground truth flattened over all images, and the
        quantification values stacked into ``(num_images, class_num)`` arrays
        (``areas`` as float64, ``counts`` as uint64).
    """
    # Later entries overwrite earlier ones that share an image name.
    by_name = {entry["image_name"]: entry for entry in image_results}

    def per_image(field):
        """Map image name -> the given field of its result."""
        return {name: entry[field] for name, entry in by_name.items()}

    def flatten(field):
        """Concatenate the given field over every result, in input order."""
        return [item for entry in image_results for item in entry[field]]

    def stack(quantification, key, dtype):
        """Stack one quantification entry of every image into a 2-D array."""
        table = np.zeros((len(quantification), class_num), dtype=dtype)
        for row, values in enumerate(quantification.values()):
            table[row] = values[key]
        return table

    quantification_gt = per_image("quantification_gt")
    quantification_predictions = per_image("quantification_prediction")

    return {
        "detect_predictions": per_image("detect_prediction"),
        "classify_predictions": per_image("classify_prediction"),
        "classify_gt": per_image("classify_gt"),
        "inference_predictions": per_image("inference_prediction"),
        "quantification_gt": quantification_gt,
        "quantification_predictions": quantification_predictions,
        "classify_predictions_flat": flatten("classify_prediction"),
        "classify_gt_flat": flatten("classify_gt"),
        "quantification_gt_flat": {
            "areas": stack(quantification_gt, "areas", np.float64),
            "counts": stack(quantification_gt, "counts", np.uint64),
        },
        "quantification_predictions_flat": {
            "areas": stack(quantification_predictions, "areas", np.float64),
            "counts": stack(quantification_predictions, "counts", np.uint64),
        },
    }

# Structure all results: per-image dicts, flat classification label lists and
# (num_images, class_num) quantification arrays used by the evaluation cells.
all_results = structure_results(image_results, class_num)

In [None]:
def plot_detection_metrics(results, title_prefix=""):
    """Draw one bar chart for the AP metrics and one for the AR metrics.

    Args:
        results: Mapping from metric name to score. Entries whose name starts
            with ``"AP"`` go into the first chart, entries starting with
            ``"AR"`` into the second; all other entries are ignored.
        title_prefix: Text placed in front of each chart title.
    """
    # (metric-name prefix, seaborn palette, title suffix) for each chart
    panels = (
        ("AP", "coolwarm", "Average Precision (AP) Metrics"),
        ("AR", "viridis", "Average Recall (AR) Metrics"),
    )

    for prefix, palette, heading in panels:
        names, scores = [], []
        for metric_name, score in results.items():
            if metric_name.startswith(prefix):
                names.append(metric_name)
                scores.append(score)

        plt.figure(figsize=(8, 6))
        sns.barplot(x=names, y=scores, palette=palette)
        plt.title(f"{title_prefix}{heading}")
        plt.ylabel("Score")
        plt.ylim(0, 1)
        plt.tight_layout()
        plt.show()

# Score the raw detector output against the COCO ground truth.
# NOTE(review): useCats=False presumably makes the matching class-agnostic,
# as with pycocotools' useCats flag -- confirm in metrics.coco.
detect_results = coco_detection_evaluator.eval(
    all_results["detect_predictions"],
    names_to_ids,
    useCats=False,
)

# Keep a JSON copy of the detection metrics in the results directory.
with open(os.path.join(config["results"], "detect_coco.json"), "w") as f:
    json.dump(detect_results, f, indent=4)

# Bar charts of the AP / AR entries.
plot_detection_metrics(detect_results, "Detection: ")

In [None]:
def evaluate_classification(predictions, ground_truth, class_names):
    """Evaluate classification metrics and plot results.

    Shows a confusion-matrix heatmap and a bar chart of the summary metrics.

    Args:
        predictions: Predicted class id for every classified detection.
        ground_truth: True class id for every classified detection, aligned
            with ``predictions``.
        class_names: Mapping from class id to display name. Its keys fix the
            row/column order of the confusion matrix and its values label the
            heatmap axes.

    Returns:
        dict with macro-averaged ``Precision``, ``Recall`` and ``F1 Score``
        plus ``Accuracy``.
    """
    class_ids = list(class_names.keys())
    tick_labels = list(class_names.values())

    precision = precision_score(ground_truth, predictions, average='macro')
    recall = recall_score(ground_truth, predictions, average='macro')
    f1 = f1_score(ground_truth, predictions, average='macro')
    accuracy = accuracy_score(ground_truth, predictions)

    # Pin the label set: without `labels`, confusion_matrix only contains the
    # classes that occur in the data, so a class missing from both inputs
    # shrinks the matrix and the tick labels below no longer line up with it.
    cm = confusion_matrix(ground_truth, predictions, labels=class_ids)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_labels,
                yticklabels=tick_labels)
    plt.title('Confusion Matrix for Classification')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.tight_layout()
    plt.show()

    # Plot metrics
    metrics = {
        "Precision": precision,
        "Recall": recall,
        "F1 Score": f1,
        "Accuracy": accuracy,
    }

    plt.figure(figsize=(8, 6))
    sns.barplot(x=list(metrics.keys()), y=list(metrics.values()), palette="viridis")
    plt.title("Classification Metrics")
    plt.ylabel("Score")
    plt.ylim(0, 1)
    plt.tight_layout()
    plt.show()

    return metrics

# Evaluate classification results
# The flat lists hold one predicted / ground-truth label per detection over
# all images; config["to_classes"] supplies the axis labels for the heatmap.
classification_metrics = evaluate_classification(
    all_results["classify_predictions_flat"], 
    all_results["classify_gt_flat"],
    config["to_classes"]
)

In [None]:
# Score the merged detection + classification output with the same COCO
# evaluator, this time with useCats=True.
infere_results = coco_detection_evaluator.eval(
    all_results["inference_predictions"],
    names_to_ids,
    useCats=True,
)

# Keep a JSON copy of the full-inference metrics in the results directory.
with open(os.path.join(config["results"], "inference_coco.json"), "w") as f:
    json.dump(infere_results, f, indent=4)

# Bar charts of the AP / AR entries.
plot_detection_metrics(infere_results, "Full Inference: ")

In [None]:
def evaluate_quantification_metrics(gt_values, pred_values, class_names, metric_type="Area"):
    """Calculate and return quantification metrics.

    Args:
        gt_values: 2-D array of ground-truth values, one row per image and one
            column per class id.
        pred_values: Predicted values with the same layout as ``gt_values``.
        class_names: Mapping from class id to class name. Class id 0 is
            treated as background and skipped.
        metric_type: Accepted so call sites can state what is being measured;
            it is not used in the computation.

    Returns:
        dict mapping a class name (and ``"All Classes"`` for the sum over all
        non-background columns) to ``{"Explained Variance": ..., "R2": ...}``.
        Classes whose ground truth sums to zero are left out.
    """
    # The count arrays are built as unsigned integers. Unsigned subtraction
    # wraps around whenever a prediction exceeds the ground truth, and
    # scikit-learn versions that do not promote their inputs to float then
    # return meaningless scores, so work in float64 from the start.
    gt_values = np.asarray(gt_values, dtype=np.float64)
    pred_values = np.asarray(pred_values, dtype=np.float64)

    metrics = {}

    # Per-class metrics
    for class_id, class_name in class_names.items():
        if class_id != 0:  # Skip background
            gt = gt_values[:, class_id]
            pred = pred_values[:, class_id]

            # Skip if no ground truth
            if np.sum(gt) == 0:
                continue

            metrics[class_name] = {
                "Explained Variance": explained_variance_score(gt, pred),
                #"MAE": mean_absolute_error(gt, pred),
                "R2": r2_score(gt, pred),
              #  "MAPE": mean_absolute_percentage_error(gt, pred) if np.any(gt != 0) else float('nan'),
            }

    # Overall metrics (excluding background, i.e. column 0)
    all_gt = gt_values[:, 1:].sum(axis=1)
    all_pred = pred_values[:, 1:].sum(axis=1)

    if np.sum(all_gt) > 0:
        metrics["All Classes"] = {
            "Explained Variance": explained_variance_score(all_gt, all_pred),
            #"MAE": mean_absolute_error(all_gt, all_pred),
            "R2": r2_score(all_gt, all_pred),
           # "MAPE": mean_absolute_percentage_error(all_gt, all_pred) if np.any(all_gt != 0) else float('nan'),
        }

    return metrics

def plot_quantification_metrics(metrics, title_prefix=""):
    """Show one bar chart per entry of ``metrics``.

    Args:
        metrics: Mapping from a class name to a mapping of metric name to
            score. NaN scores are dropped; an entry with no remaining score
            produces no chart.
        title_prefix: Text placed at the start of every chart title.
    """
    for name, class_metrics in metrics.items():
        metric_names, scores = [], []
        for metric_name, score in class_metrics.items():
            if np.isnan(score):
                continue
            metric_names.append(metric_name)
            scores.append(score)

        if metric_names:
            plt.figure(figsize=(8, 6))
            sns.barplot(x=metric_names, y=scores)
            plt.title(f"{title_prefix} Metrics for {name}")
            plt.ylabel("Score")
            plt.tight_layout()
            plt.show()

In [None]:
# Compare the predicted per-image, per-class areas with the ground-truth areas
# (explained variance and R2 per class and for all non-background classes).
area_metrics = evaluate_quantification_metrics(
    all_results["quantification_gt_flat"]["areas"],
    all_results["quantification_predictions_flat"]["areas"],
    config["to_classes"],
    "Area"
)



# One bar chart per class (and one for "All Classes" when present).
plot_quantification_metrics(area_metrics, "Area")

In [None]:
# Same comparison for the per-image, per-class object counts.
count_metrics = evaluate_quantification_metrics(
    all_results["quantification_gt_flat"]["counts"],
    all_results["quantification_predictions_flat"]["counts"],
    config["to_classes"],
    "Count"
)

# One bar chart per class (and one for "All Classes" when present).
plot_quantification_metrics(count_metrics, "Count")