In [6]:
# --------------------- IMPORT NECESSARY LIBRARIES ---------------------
from fastai.vision.all import *
import time
import pandas as pd
import numpy as np
from pathlib import Path
import torch
import re
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# --------------------- DEFINE REQUIRED FUNCTIONS ---------------------

def get_image(r): 
    """
    Constructs the full image path for a given row in the DataFrame.
    
    Args:
        r (pd.Series): A row from the DataFrame.
    
    Returns:
        str: Full path to the image file.
    """
    return str(dataroot_path / 'frames' / f"{r['filename']}.jpg")

def smooth_predictions_with_neighbors(predicted_classes: List[str], window_size: int = 1) -> List[str]:
    """
    Smooths the predicted class labels by exploring the neighboring frames.
    
    Args:
        predicted_classes (List[str]): The list of predicted class labels per frame.
        window_size (int): Number of neighboring frames to consider on each side for smoothing.
    
    Returns:
        List[str]: The smoothed list of predicted class labels.
    """
    if not predicted_classes:
        return []
    
    num_frames = len(predicted_classes)
    smoothed_predictions = []
    
    for i in range(num_frames):
        # Define the window boundaries
        start_idx = max(0, i - window_size)
        end_idx = min(num_frames - 1, i + window_size)
    
        # Extract the neighborhood labels
        neighborhood = predicted_classes[start_idx:end_idx + 1]
    
        # Count the frequency of each label in the neighborhood
        label_counts = {}
        for label in neighborhood:
            label_counts[label] = label_counts.get(label, 0) + 1
    
        # Identify the most common label(s)
        max_count = max(label_counts.values())
        common_labels = [label for label, count in label_counts.items() if count == max_count]
    
        # Resolve ties by retaining the current frame's label
        if len(common_labels) == 1:
            most_common_label = common_labels[0]
        else:
            most_common_label = predicted_classes[i]
    
        smoothed_predictions.append(most_common_label)
    
    return smoothed_predictions

def decode_predictions(preds: torch.Tensor, vocab: List[str]) -> List[str]:
    """
    Convert raw predictions into class labels.
    
    Args:
        preds (torch.Tensor): Raw predictions (logits or probabilities).
        vocab (List[str]): List of class labels.
    
    Returns:
        List[str]: List of predicted class labels.
    """
    pred_indices = preds.argmax(dim=1)
    predicted_classes = [vocab[i] for i in pred_indices]
    return predicted_classes

# --------------------- DATA PREPARATION ---------------------

# Define your data root path
dataroot_path = Path('/home/exsdatalab/data/surgvu24')

# Load the images DataFrame
images_df = pd.read_csv(dataroot_path / 'final_labels.csv')

# Get all model paths with '.pkl' extension in 'models/baseline/' directory
models_path = [pth for pth in Path('models/baseline/').ls() if pth.suffix == '.pkl']

# Check if models are found
if not models_path:
    raise ValueError("No models found in 'models/baseline/' directory with '.pkl' extension.")

print(f"Found {len(models_path)} models for ensemble.")

# Filter validation data
images_df_valid = images_df[images_df['valid'] == True].copy()

# Construct image paths
images_df_valid['image_path'] = images_df_valid.apply(get_image, axis=1)
image_files = images_df_valid['image_path'].tolist()
true_labels = images_df_valid['task_label'].tolist()

# Ensure image files exist
image_files = [Path(p) for p in image_files if Path(p).exists()]
if not image_files:
    raise ValueError("No valid image files found for validation.")

# --------------------- MODEL PREDICTIONS ---------------------

# Store predicted labels, raw predictions, and computation times for each model
all_model_preds = []
all_raw_preds = []
model_times = []
vocab = None  # To store vocabulary from the first model

for idx, model_path in enumerate(models_path):
    model_name = model_path.stem  # Get model filename without extension

    # Load the model
    try:
        learner = load_learner(model_path, cpu=False)
    except Exception as e:
        print(f"Error loading model at {model_path}: {e}")
        continue
    print(f'Loaded model {model_name} ({idx+1}/{len(models_path)}) from {model_path}.')

    # Create test_dl using the loaded learner's DataLoaders
    test_dl = learner.dls.test_dl(image_files, bs=128, num_workers=8)

    # Get predictions from the model
    try:
        start_time = time.time()
        with torch.no_grad():
            raw_preds, _ = learner.get_preds(dl=test_dl)
        end_time = time.time()
        elapsed_time = end_time - start_time
        preds = decode_predictions(raw_preds.cpu(), learner.dls.vocab)
        all_model_preds.append({'model_name': model_name, 'preds': preds, 'raw_preds': raw_preds.cpu()})
        all_raw_preds.append(raw_preds.cpu())  # Store raw predictions for ensemble
        model_times.append({'model_name': model_name, 'time': elapsed_time})
    except Exception as e:
        print(f"Error during prediction with model at {model_path}: {e}")
        continue

    # Store vocab from the first model
    if vocab is None:
        vocab = learner.dls.vocab

    # Free up memory
    del learner
    torch.cuda.empty_cache()

if not all_raw_preds:
    raise ValueError("No predictions were made. Please check your models and image paths.")

# --------------------- ENSEMBLE PREDICTIONS ---------------------

# Measure time for ensemble prediction
ensemble_start_time = time.time()

# Average the raw predictions across all models
ensemble_raw_preds = torch.stack(all_raw_preds).mean(0)

ensemble_end_time = time.time()
ensemble_elapsed_time = ensemble_end_time - ensemble_start_time

ensemble_preds = decode_predictions(ensemble_raw_preds, vocab)

# --------------------- POST-PROCESSING ---------------------

# Measure time for post-processing (prediction correction)
best_avg_ws = 8  # As defined earlier or determined through evaluation
postproc_start_time = time.time()
smoothed_preds = smooth_predictions_with_neighbors(ensemble_preds, window_size=best_avg_ws)
postproc_end_time = time.time()
postproc_elapsed_time = postproc_end_time - postproc_start_time

# --------------------- METRICS CALCULATION ---------------------

# Map labels to indices
label_to_idx = {label: idx for idx, label in enumerate(vocab)}
true_indices = [label_to_idx[label] for label in true_labels]

# Initialize a list to store metrics for each model
metrics_list = []

# Compute metrics for each model
for model_info in all_model_preds:
    model_name = model_info['model_name']
    preds = model_info['preds']
    pred_indices = [label_to_idx[label] for label in preds]

    acc = accuracy_score(true_indices, pred_indices)
    prec = precision_score(true_indices, pred_indices, average='macro', zero_division=0)
    rec = recall_score(true_indices, pred_indices, average='macro', zero_division=0)
    f1 = f1_score(true_indices, pred_indices, average='macro', zero_division=0)

    # Find the time taken for this model
    model_time = next((item['time'] for item in model_times if item['model_name'] == model_name), None)

    metrics_list.append({
        'Model': model_name,
        'Accuracy': acc,
        'Precision': prec,
        'Recall': rec,
        'F1 Score': f1,
        'Time (s)': model_time
    })

# Compute metrics for the ensemble
ensemble_pred_indices = [label_to_idx[label] for label in ensemble_preds]

acc = accuracy_score(true_indices, ensemble_pred_indices)
prec = precision_score(true_indices, ensemble_pred_indices, average='macro', zero_division=0)
rec = recall_score(true_indices, ensemble_pred_indices, average='macro', zero_division=0)
f1 = f1_score(true_indices, ensemble_pred_indices, average='macro', zero_division=0)

metrics_list.append({
    'Model': 'Ensemble',
    'Accuracy': acc,
    'Precision': prec,
    'Recall': rec,
    'F1 Score': f1,
    'Time (s)': ensemble_elapsed_time
})

# Compute metrics for the post-processed predictions
smoothed_pred_indices = [label_to_idx[label] for label in smoothed_preds]

acc = accuracy_score(true_indices, smoothed_pred_indices)
prec = precision_score(true_indices, smoothed_pred_indices, average='macro', zero_division=0)
rec = recall_score(true_indices, smoothed_pred_indices, average='macro', zero_division=0)
f1 = f1_score(true_indices, smoothed_pred_indices, average='macro', zero_division=0)

metrics_list.append({
    'Model': 'Prediction Correction',
    'Accuracy': acc,
    'Precision': prec,
    'Recall': rec,
    'F1 Score': f1,
    'Time (s)': postproc_elapsed_time
})

# Create a DataFrame from the metrics list
metrics_df = pd.DataFrame(metrics_list)

# Display the results table
print(metrics_df)

# Save the results table to a CSV file
results_csv_path = dataroot_path / 'results_table.csv'
metrics_df.to_csv(results_csv_path, index=False)
print(f"Results table saved to {results_csv_path}")

Found 3 models for ensemble.
Loaded model regnety_008 (1/3) from models/baseline/regnety_008.pkl.


Loaded model convnextv2_tiny (2/3) from models/baseline/convnextv2_tiny.pkl.


Loaded model vit_tiny (3/3) from models/baseline/vit_tiny.pkl.


                   Model  Accuracy  Precision    Recall  F1 Score   Time (s)
0            regnety_008  0.584119   0.675336  0.512029  0.540557  74.195408
1        convnextv2_tiny  0.818238   0.828582  0.704525  0.746908  78.108038
2               vit_tiny  0.811633   0.786008  0.755999  0.768067  67.449265
3               Ensemble  0.812745   0.832694  0.715647  0.755712   0.000802
4  Prediction Correction  0.834749   0.879821  0.724274  0.767982   0.089151
Results table saved to /home/exsdatalab/data/surgvu24/results_table.csv


In [5]:
# Save the results table to a CSV file
results_csv_path = dataroot_path / 'results_table.csv'
metrics_df.to_csv(results_csv_path, index=False)
print(f"Results table saved to {results_csv_path}")

Found 3 models for ensemble.
Loaded model regnety_008 (1/3) from models/baseline/regnety_008.pkl.


Loaded model convnextv2_tiny (2/3) from models/baseline/convnextv2_tiny.pkl.


Loaded model vit_tiny (3/3) from models/baseline/vit_tiny.pkl.


                   Model  Accuracy  Precision    Recall  F1 Score   Time (s)
0            regnety_008  0.584119   0.675336  0.512029  0.540557  73.783131
1        convnextv2_tiny  0.818238   0.828582  0.704525  0.746908  78.418473
2               vit_tiny  0.811633   0.786008  0.755999  0.768067  67.497329
3               Ensemble  0.812745   0.832694  0.715647  0.755712   0.000877
4  Prediction Correction  0.834749   0.879821  0.724274  0.767982   0.089676
Results table saved to /home/exsdatalab/data/surgvu24/results_table.csv
