<a href="https://colab.research.google.com/github/hatef-hosseinpour/dental-clf/blob/main/edl_3_clf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# -*- coding: utf-8 -*-
"""dental_classification_edl.py

Dental classification using Evidential Deep Learning for uncertainty quantification
"""

from google.colab import drive
drive.mount('/content/drive')

!pip install edl_pytorch

import os
import cv2
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
matplotlib.style.use('ggplot')
import numpy as np
from PIL import Image
import gc
import math
from tqdm import tqdm
from edl_pytorch import Dirichlet, evidential_classification
import warnings
warnings.filterwarnings("ignore")

# Hyper-parameters, device selection and output locations shared by the whole pipeline
CONFIG = dict(
    IMAGE_SIZE=224,
    BATCH_SIZE=32,
    LEARNING_RATE=0.001,
    EPOCHS=30,
    SEED=42,
    DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    NORMALIZATION_MEAN=[0.5, 0.5, 0.5],
    NORMALIZATION_STD=[0.5, 0.5, 0.5],
    SAVE_PATH="/content/drive/MyDrive/Dentisrty/models/",
    METRICS_PATH="/content/drive/MyDrive/Dentisrty/metrics/",
    LAMBDA_EDL=1.0,  # weight of the KL term in the evidential loss
    ANNEALING_COEFF=0.01,  # fraction of EPOCHS over which the KL weight ramps up to 1
)

# Class folder name -> integer label (order defines the label indices)
LABEL_MAPPING = {
    class_name: class_index
    for class_index, class_name in enumerate(('amalgam_filling', 'caries', 'healthy'))
}
NUM_CLASSES = len(LABEL_MAPPING)

def seed_everything(seed_value=42):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs with ``seed_value``.

    Also switches cuDNN to deterministic mode with benchmarking disabled and
    records the seed in the ``PYTHONHASHSEED`` environment variable.
    """
    import random

    # Apply the same seed to every generator, in a fixed order
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed_value)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed_value)

# Fix every RNG before any random component is built or sampled
seed_everything(CONFIG['SEED'])

# Target (height, width) used by both pipelines
_TARGET_SIZE = (CONFIG['IMAGE_SIZE'], CONFIG['IMAGE_SIZE'])

# Training pipeline: resize, random flip/rotation/colour jitter, then tensor + normalisation
train_transforms = transforms.Compose(
    [
        transforms.Resize(_TARGET_SIZE),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(45),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(CONFIG['NORMALIZATION_MEAN'], CONFIG['NORMALIZATION_STD']),
    ]
)

# Validation/test pipeline: same resize and normalisation, no augmentation
val_transforms = transforms.Compose(
    [
        transforms.Resize(_TARGET_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(CONFIG['NORMALIZATION_MEAN'], CONFIG['NORMALIZATION_STD']),
    ]
)

class DentalDataset(Dataset):
    """In-memory dataset that pairs each image with its label.

    ``images`` and ``labels`` are indexable collections read in parallel;
    ``transform``, when truthy, is called on the image at access time.
    """

    def __init__(self, images, labels, transform=None):
        self.images, self.labels, self.transform = images, labels, transform

    def __len__(self):
        # The image collection alone determines the dataset size
        return len(self.images)

    def __getitem__(self, idx):
        sample, target = self.images[idx], self.labels[idx]
        if not self.transform:
            return sample, target
        return self.transform(sample), target

def load_data(data_dir):
    """Load every readable image under ``data_dir/<class_name>/`` into memory.

    One sub-folder per key of ``LABEL_MAPPING`` is expected. Files are read
    with OpenCV, converted from BGR to RGB and wrapped as PIL images.

    Args:
        data_dir: Root directory containing one folder per class.

    Returns:
        A ``(images, labels)`` pair of equal-length lists: PIL images and the
        matching integer labels from ``LABEL_MAPPING``.

    Raises:
        OSError: If a class folder is missing or unreadable (propagated from
            ``os.listdir``).
    """
    images = []
    labels = []
    label_counts = {class_name: 0 for class_name in LABEL_MAPPING}
    skipped = []

    for class_name, label_idx in LABEL_MAPPING.items():
        folder_path = os.path.join(data_dir, class_name)
        # os.listdir returns entries in arbitrary order; sorting keeps the sample
        # order, and therefore the seeded train/val/test split, stable across runs.
        for file_name in sorted(os.listdir(folder_path)):
            file_path = os.path.join(folder_path, file_name)
            if not os.path.isfile(file_path):
                # Sub-directories are not images
                continue
            try:
                image = cv2.imread(file_path, cv2.IMREAD_COLOR)
                if image is None:
                    # imread signals an unreadable/non-image file by returning None
                    skipped.append(file_path)
                    continue
                # Convert BGR to RGB
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                image = Image.fromarray(image)
            except Exception as e:
                # Best effort: report the file and keep loading the rest
                print(f"Error loading {file_path}: {e}")
                skipped.append(file_path)
                continue
            images.append(image)
            labels.append(label_idx)
            label_counts[class_name] += 1

    if skipped:
        print(f"Skipped {len(skipped)} unreadable file(s), e.g. {skipped[0]}")
    print(f"Loaded class distribution: {label_counts}")
    return images, labels

def plot_distribution(labels):
    """Show a bar chart with the number of samples per class.

    ``labels`` is an iterable of integer class indices; classes from
    ``LABEL_MAPPING`` that never occur are drawn with height 0.
    """
    plt.figure(figsize=(10, 6))

    # Tally how often each label index occurs
    tally = {}
    for label in labels:
        tally[label] = tally.get(label, 0) + 1

    class_names = list(LABEL_MAPPING.keys())
    bar_heights = [tally.get(LABEL_MAPPING[name], 0) for name in class_names]

    plt.bar(class_names, bar_heights)
    plt.title('Distribution of Images by Class')
    plt.xlabel('Class')
    plt.ylabel('Number of Images')
    plt.grid(True, alpha=0.3)
    plt.show()

class EvidentialModel(nn.Module):
    """Wrap a classification backbone so that it outputs Dirichlet parameters.

    The backbone's final classifier (attribute ``fc`` or ``classifier``) is
    replaced by ``nn.Identity`` and a new linear layer maps the resulting
    features to one non-negative evidence value per class.

    Args:
        base_model: Backbone whose head is an ``fc`` or ``classifier``
            attribute exposing ``in_features`` (e.g. a single ``nn.Linear``).
            The backbone is modified in place.
        num_classes: Number of output classes.

    Raises:
        ValueError: If the backbone has no supported classifier head.
    """

    def __init__(self, base_model, num_classes):
        super().__init__()
        self.base_model = base_model
        self.num_classes = num_classes

        # Locate the original classifier: 'fc' (ResNet-style) takes precedence
        # over 'classifier' (DenseNet-style)
        if hasattr(base_model, 'fc'):
            head_name = 'fc'
        elif hasattr(base_model, 'classifier'):
            head_name = 'classifier'
        else:
            raise ValueError(
                "Unsupported backbone: expected an 'fc' or 'classifier' attribute, "
                f"got {type(base_model).__name__}"
            )

        head = getattr(base_model, head_name)
        if not hasattr(head, 'in_features'):
            # e.g. a multi-layer Sequential head, whose input width is not exposed
            raise ValueError(
                f"Unsupported backbone head '{head_name}' of type "
                f"{type(head).__name__}: it has no 'in_features'"
            )
        feature_dim = head.in_features

        # Remove the original classifier so the backbone returns raw features
        setattr(self.base_model, head_name, nn.Identity())

        # Evidential layer - outputs evidence (alpha - 1)
        self.evidence_layer = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return Dirichlet parameters ``alpha = relu(evidence) + 1`` for ``x``."""
        features = self.base_model(x)
        # ReLU keeps the evidence non-negative, so every alpha is >= 1
        evidence = torch.relu(self.evidence_layer(features))
        return evidence + 1

def build_evidential_models():
    """Create the two evidential networks and move them to the configured device.

    Returns:
        ``(resnet18_model, densenet121_model)``: ``EvidentialModel`` wrappers
        around torchvision backbones created with ``weights='IMAGENET1K_V1'``,
        each with ``NUM_CLASSES`` outputs, on ``CONFIG['DEVICE']``.
    """
    target_device = CONFIG['DEVICE']

    # ResNet18 backbone
    resnet_edl = EvidentialModel(models.resnet18(weights='IMAGENET1K_V1'), NUM_CLASSES)

    # DenseNet121 backbone
    densenet_edl = EvidentialModel(models.densenet121(weights='IMAGENET1K_V1'), NUM_CLASSES)

    return resnet_edl.to(target_device), densenet_edl.to(target_device)

def evidential_loss(alpha, target, epoch, num_epochs, device):
    """Evidential classification loss: data-fit term plus annealed KL regulariser.

    The data-fit term is the expected cross-entropy under ``Dir(alpha)``
    (digamma form). The regulariser is ``KL(Dir(alpha_tilde) || Dir(1))``,
    where ``alpha_tilde`` equals ``alpha`` with the true-class entry reset
    to 1, so only evidence assigned to wrong classes is penalised.

    Args:
        alpha: Dirichlet parameters; indexed as (sample, class) by this code.
        target: Integer class index per sample (used as a ``scatter_`` index,
            so it must be an int64 tensor of shape (batch,)).
        epoch: Current epoch index; drives the KL annealing weight.
        num_epochs: Total number of epochs.
        device: Kept for backward compatibility. Helper tensors are created
            on ``alpha``'s own device and dtype.

    Returns:
        Scalar tensor: mean loss over the batch.
    """
    # One-hot targets, created to match alpha's device and dtype
    target_one_hot = torch.zeros_like(alpha)
    target_one_hot.scatter_(1, target.unsqueeze(1), 1)

    # Sum of Dirichlet parameters
    S = torch.sum(alpha, dim=1, keepdim=True)

    # Data-fit term: expected cross-entropy under the Dirichlet
    A = torch.sum(target_one_hot * (torch.digamma(S) - torch.digamma(alpha)), dim=1)

    # KL weight ramps linearly from 0 to 1 over num_epochs * ANNEALING_COEFF epochs;
    # a non-positive span disables annealing instead of dividing by zero
    annealing_span = num_epochs * CONFIG['ANNEALING_COEFF']
    annealing_coeff = min(1.0, epoch / annealing_span) if annealing_span > 0 else 1.0

    # Remove the evidence of the true class before regularising; otherwise the
    # KL term would also push correct evidence towards zero
    alpha_tilde = target_one_hot + (1 - target_one_hot) * alpha
    S_tilde = torch.sum(alpha_tilde, dim=1, keepdim=True)

    # KL divergence between Dir(alpha_tilde) and the uniform Dirichlet Dir(1)
    beta = torch.ones_like(alpha[:1])
    S_beta = torch.sum(beta, dim=1, keepdim=True)

    lnB = torch.lgamma(S_tilde) - torch.sum(torch.lgamma(alpha_tilde), dim=1, keepdim=True)
    lnB_beta = torch.sum(torch.lgamma(beta), dim=1, keepdim=True) - torch.lgamma(S_beta)

    dg0 = torch.digamma(S_tilde)
    dg1 = torch.digamma(alpha_tilde)

    kl = lnB + lnB_beta + torch.sum((alpha_tilde - beta) * (dg1 - dg0), dim=1, keepdim=True)

    # Total per-sample loss; squeeze only the class axis so a batch of one keeps its shape
    loss = A + annealing_coeff * CONFIG['LAMBDA_EDL'] * kl.squeeze(1)

    return torch.mean(loss)

def get_uncertainty_measures(alpha):
    """Derive probabilities and uncertainty scores from Dirichlet parameters.

    The number of classes is taken from ``alpha`` itself, so the function
    works for any class count.

    Args:
        alpha: Tensor of Dirichlet parameters of shape (batch, num_classes).

    Returns:
        Dict of tensors:
            ``probability``: Dirichlet mean ``alpha / S``, shape (batch, K).
            ``total_uncertainty``: ``K / S`` per sample.
            ``aleatoric_uncertainty``: ``sum_k p_k * (1 - p_k)`` per sample.
            ``epistemic_uncertainty``: total minus aleatoric. The two terms
                are on different scales, so this value can be negative
                (e.g. alpha = [10, 10, 10] gives 0.1 - 0.667).
            ``evidence``: total evidence ``S - K`` per sample.
    """
    num_classes = alpha.size(1)

    # Dirichlet strength (sum of parameters) per sample
    S = torch.sum(alpha, dim=1)

    # Predicted probability (mean of Dirichlet)
    prob = alpha / S.unsqueeze(1)

    # Equals 1 when there is no evidence (all alpha == 1) and shrinks as evidence grows
    predictive_uncertainty = num_classes / S

    # Spread of the expected class probabilities
    aleatoric_uncertainty = torch.sum(prob * (1 - prob), dim=1)

    # Remainder after removing the aleatoric part
    epistemic_uncertainty = predictive_uncertainty - aleatoric_uncertainty

    return {
        'probability': prob,
        'total_uncertainty': predictive_uncertainty,
        'aleatoric_uncertainty': aleatoric_uncertainty,
        'epistemic_uncertainty': epistemic_uncertainty,
        'evidence': S - num_classes
    }

def train_evidential_models(model_1, model_2, train_loader, model_1_optimizer, model_2_optimizer, epoch):
    """Run one training epoch for both networks on the same batches.

    For every batch, model 1 is updated first and model 2 second, each with
    its own optimizer and the evidential loss for the given ``epoch``.

    Returns:
        ``(mean_loss_model_1, mean_loss_model_2)`` averaged over the number
        of batches in ``train_loader``.
    """
    device = CONFIG['DEVICE']
    networks = (model_1, model_2)
    optimizers = (model_1_optimizer, model_2_optimizer)

    for network in networks:
        network.train()

    running_loss = [0, 0]
    num_batches = len(train_loader)

    for inputs, labels in tqdm(train_loader, desc='Training'):
        inputs, labels = inputs.to(device), labels.to(device)

        # Slot 0 is model 1 (ResNet18), slot 1 is model 2 (DenseNet121)
        for slot, (network, optimizer) in enumerate(zip(networks, optimizers)):
            optimizer.zero_grad()
            batch_loss = evidential_loss(network(inputs), labels, epoch, CONFIG['EPOCHS'], device)
            batch_loss.backward()
            optimizer.step()
            running_loss[slot] += batch_loss.item()

    return running_loss[0] / num_batches, running_loss[1] / num_batches

def combine_evidential_predictions(alpha_1, alpha_2):
    """Fuse two sets of Dirichlet parameters by summing their evidence.

    Evidence is ``alpha - 1`` for each input; the summed evidence is shifted
    back by 1 to give the combined alpha.
    """
    # (alpha_1 - 1) + (alpha_2 - 1) is the pooled evidence; + 1 turns it back into alpha
    pooled_evidence = (alpha_1 - 1) + (alpha_2 - 1)
    return pooled_evidence + 1

def evaluate_evidential_models(model_1, model_2, dataloader, phase="val", visualize=False):
    """Evaluate the evidence-fused prediction of both models on ``dataloader``.

    Args:
        model_1: First evidential model (returns Dirichlet parameters).
        model_2: Second evidential model.
        dataloader: Yields ``(images, labels)`` batches.
        phase: Label used in the progress bar; ``"test"`` additionally prints
            the detailed metrics and enables visualisation.
        visualize: When true and ``phase == "test"``, the first batch is
            rendered with ``visualize_evidential_results``.

    Returns:
        ``(accuracy, metrics, all_combined_alphas)`` where ``metrics`` holds
        accuracy, weighted precision/recall/F1, the confusion matrix and the
        mean/std of the total uncertainty, and ``all_combined_alphas`` is a
        list with one combined alpha array per sample.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    model_1.eval()
    model_2.eval()

    total_acc = 0
    total_count = 0
    all_labels = []
    all_preds = []
    all_uncertainties = []
    all_combined_alphas = []

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(tqdm(dataloader, desc=f'Evaluating {phase}')):
            images, labels = images.to(CONFIG['DEVICE']), labels.to(CONFIG['DEVICE'])

            # Get Dirichlet parameters from both models
            alpha_1 = model_1(images)
            alpha_2 = model_2(images)

            # Combine predictions
            combined_alpha = combine_evidential_predictions(alpha_1, alpha_2)

            # Get uncertainty measures
            uncertainty_measures = get_uncertainty_measures(combined_alpha)

            # Make predictions
            probs = uncertainty_measures['probability']
            preds = torch.argmax(probs, dim=1)

            # Update metrics
            total_acc += (preds == labels).sum().item()
            total_count += len(labels)

            # Store results
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_uncertainties.extend(uncertainty_measures['total_uncertainty'].cpu().numpy())
            all_combined_alphas.extend(combined_alpha.cpu().numpy())

            # Visualize results for first batch in test phase
            if visualize and phase == "test" and batch_idx == 0:
                visualize_evidential_results(images, uncertainty_measures, labels, batch_idx)

    # Calculate metrics
    accuracy = total_acc / total_count
    all_labels_np = np.array(all_labels)
    all_preds_np = np.array(all_preds)

    # Fix the class set so the confusion matrix is always NUM_CLASSES x NUM_CLASSES,
    # even when a class is missing from this split; the heatmap in
    # print_detailed_metrics labels its axes with every class name.
    class_indices = list(range(NUM_CLASSES))

    # Compute metrics (zero_division=0: classes never predicted score 0)
    metrics = {
        'accuracy': accuracy,
        'precision': precision_score(all_labels_np, all_preds_np, average='weighted', zero_division=0),
        'recall': recall_score(all_labels_np, all_preds_np, average='weighted', zero_division=0),
        'f1_score': f1_score(all_labels_np, all_preds_np, average='weighted', zero_division=0),
        'confusion_matrix': confusion_matrix(all_labels_np, all_preds_np, labels=class_indices),
        'mean_uncertainty': np.mean(all_uncertainties),
        'std_uncertainty': np.std(all_uncertainties)
    }

    if phase == "test":
        print_detailed_metrics(metrics)

    return accuracy, metrics, all_combined_alphas

def print_detailed_metrics(metrics):
    """Print the scalar test metrics and show the confusion matrix as a heatmap.

    ``metrics`` must provide the keys ``accuracy``, ``precision``, ``recall``,
    ``f1_score``, ``mean_uncertainty``, ``std_uncertainty`` and
    ``confusion_matrix``.
    """
    print("\n----- Detailed Metrics -----")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1 Score: {metrics['f1_score']:.4f}")
    print(f"Mean Uncertainty: {metrics['mean_uncertainty']:.4f}")
    print(f"Std Uncertainty: {metrics['std_uncertainty']:.4f}")

    # Heatmap of the confusion matrix, labelled with the class names on both axes
    matrix = metrics['confusion_matrix']
    class_names = list(LABEL_MAPPING.keys())

    plt.figure(figsize=(10, 8))
    heatmap_ax = sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        xticklabels=class_names,
        yticklabels=list(class_names),
    )
    heatmap_ax.set_title('Confusion Matrix')
    heatmap_ax.set_xlabel('Predicted')
    heatmap_ax.set_ylabel('True')
    plt.show()

def visualize_evidential_results(images, uncertainty_measures, labels, batch_idx=0):
    """Plot up to 16 images with predictions, probabilities and uncertainties.

    Misclassified images get a red frame; correctly classified images whose
    total uncertainty exceeds 0.5 get an orange frame. The figure is saved as
    ``edl_prediction_batch_<batch_idx>.png`` in the working directory and shown.

    Args:
        images: Batch of image tensors indexed as (sample, channel, height,
            width), normalised with the mean/std from ``CONFIG``.
        uncertainty_measures: Dict with per-sample ``probability``,
            ``total_uncertainty``, ``aleatoric_uncertainty`` and
            ``epistemic_uncertainty`` tensors.
        labels: Tensor of true class indices.
        batch_idx: Used only in the output file name.
    """
    num_images = min(16, len(images))
    if num_images == 0:
        # plt.subplots cannot build a grid with zero rows
        return

    grid_cols = 4
    grid_rows = math.ceil(num_images / grid_cols)

    # squeeze=False always yields a 2-D array of axes, whatever the grid shape
    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(15, 4 * grid_rows), squeeze=False)
    axes = axes.flatten()

    # Get mapping from index to label name
    label_idx_to_name = {v: k for k, v in LABEL_MAPPING.items()}

    norm_mean = np.array(CONFIG['NORMALIZATION_MEAN'])
    norm_std = np.array(CONFIG['NORMALIZATION_STD'])

    for i, ax in enumerate(axes):
        if i >= num_images:
            # Unused grid cell
            ax.axis('off')
            continue

        # Undo the normalisation and move channels last for imshow
        img = images[i].cpu().numpy().transpose(1, 2, 0)
        img = np.clip(img * norm_std + norm_mean, 0, 1)

        # Get predictions and uncertainties
        probs = uncertainty_measures['probability'][i].cpu().numpy()
        pred_idx = int(np.argmax(probs))
        pred_class = label_idx_to_name[pred_idx]
        true_class = label_idx_to_name[labels[i].item()]

        total_unc = uncertainty_measures['total_uncertainty'][i].cpu().item()
        aleatoric_unc = uncertainty_measures['aleatoric_uncertainty'][i].cpu().item()
        epistemic_unc = uncertainty_measures['epistemic_uncertainty'][i].cpu().item()

        # Format text
        prob_text = "\n".join([f"{label_idx_to_name[j]}: {probs[j]:.3f}"
                              for j in range(len(probs))])
        unc_text = f"Total: {total_unc:.3f}\nAleatoric: {aleatoric_unc:.3f}\nEpistemic: {epistemic_unc:.3f}"

        # Show image with predictions
        ax.imshow(img)
        ax.set_title(f"Pred: {pred_class}\nTrue: {true_class}\n\nProbs:\n{prob_text}\n\nUncertainty:\n{unc_text}",
                    fontsize=8)

        # Hide ticks and grid but keep the spines: ax.axis('off') would also
        # hide the spines, making the highlight frame below invisible.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(False)

        # Highlight incorrect predictions or high uncertainty
        is_wrong = pred_class != true_class
        if is_wrong or total_unc > 0.5:
            color = 'red' if is_wrong else 'orange'
            for spine in ax.spines.values():
                spine.set_visible(True)
                spine.set_edgecolor(color)
                spine.set_linewidth(3)
        else:
            for spine in ax.spines.values():
                spine.set_visible(False)

    plt.tight_layout()
    plt.savefig(f'edl_prediction_batch_{batch_idx}.png')
    plt.show()

def plot_uncertainty_distribution(uncertainties, labels, save_path, preds=None):
    """Plot per-class uncertainty histograms and uncertainty vs. correctness.

    Args:
        uncertainties: Per-sample total uncertainty values.
        labels: True class index per sample.
        save_path: Directory in which ``uncertainty_analysis.png`` is written.
        preds: Optional predicted class index per sample. When given, the
            second panel shows real correctness (1 = correct, 0 = wrong).
            When omitted the legacy behaviour is kept: 1-D ``uncertainties``
            mark every sample as correct, and 2-D input is arg-maxed as if it
            held per-class scores.
    """
    uncertainties = np.asarray(uncertainties)
    labels = np.asarray(labels)

    plt.figure(figsize=(12, 8))

    label_idx_to_name = {v: k for k, v in LABEL_MAPPING.items()}

    # Plot uncertainty by class
    plt.subplot(2, 2, 1)
    for class_idx in range(NUM_CLASSES):
        class_uncertainties = uncertainties[labels == class_idx]
        plt.hist(class_uncertainties, alpha=0.7, label=label_idx_to_name[class_idx], bins=20)
    plt.xlabel('Total Uncertainty')
    plt.ylabel('Frequency')
    plt.title('Uncertainty Distribution by Class')
    plt.legend()

    # Plot uncertainty vs accuracy
    plt.subplot(2, 2, 2)
    if preds is not None:
        # Correctness needs the model's predictions; uncertainties alone cannot provide it
        correct_predictions = (labels == np.asarray(preds)).astype(int)
    elif uncertainties.ndim > 1:
        correct_predictions = labels == np.argmax(uncertainties, axis=1)
    else:
        correct_predictions = np.ones_like(labels)
    plt.scatter(uncertainties, correct_predictions, alpha=0.6)
    plt.xlabel('Total Uncertainty')
    plt.ylabel('Correct Prediction')
    plt.title('Uncertainty vs Prediction Correctness')

    plt.tight_layout()
    plt.savefig(os.path.join(save_path, 'uncertainty_analysis.png'))
    plt.show()

def _plot_training_history(metrics_history, output_path):
    """Plot per-epoch losses, validation scores and mean uncertainty; save to ``output_path``."""
    epochs = [m['epoch'] for m in metrics_history]

    def series(key):
        return [m[key] for m in metrics_history]

    plt.figure(figsize=(15, 10))

    plt.subplot(2, 2, 1)
    plt.plot(epochs, series('resnet18_loss'), label='ResNet18 Loss')
    plt.plot(epochs, series('densenet121_loss'), label='DenseNet121 Loss')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(2, 2, 2)
    plt.plot(epochs, series('val_accuracy'), label='Validation Accuracy')
    plt.plot(epochs, series('val_precision'), label='Validation Precision')
    plt.plot(epochs, series('val_f1'), label='Validation F1')
    plt.title('Validation Metrics')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.legend()

    plt.subplot(2, 2, 3)
    plt.plot(epochs, series('val_mean_uncertainty'), label='Mean Uncertainty')
    plt.title('Validation Mean Uncertainty')
    plt.xlabel('Epoch')
    plt.ylabel('Uncertainty')
    plt.legend()

    plt.tight_layout()
    plt.savefig(output_path)
    plt.show()


def main():
    """Run the full pipeline: load data, train both models, evaluate, save results.

    Checkpoints of the epoch with the best validation accuracy are written to
    ``CONFIG['SAVE_PATH']`` and restored before the final test evaluation.
    The training history and final metrics go to ``CONFIG['METRICS_PATH']``.
    """
    # Set the data directory
    data_dir = '/content/drive/MyDrive/Dentisrty/panoramic_data'

    # torch.save / to_csv / savefig do not create missing directories
    os.makedirs(CONFIG['SAVE_PATH'], exist_ok=True)
    os.makedirs(CONFIG['METRICS_PATH'], exist_ok=True)

    # Load data
    print("Loading data...")
    images, labels = load_data(data_dir)
    print(f"Loaded {len(images)} images")

    # Plot class distribution
    plot_distribution(labels)

    # Split data: 20% test, then 20% of the remainder for validation
    print("Splitting data...")
    X_train, X_test, y_train, y_test = train_test_split(
        images, labels, test_size=0.2, random_state=CONFIG['SEED']
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=CONFIG['SEED']
    )

    print(f"Train: {len(X_train)}, Validation: {len(X_val)}, Test: {len(X_test)}")

    # Create datasets and dataloaders
    print("Creating datasets and dataloaders...")
    train_dataset = DentalDataset(X_train, y_train, transform=train_transforms)
    val_dataset = DentalDataset(X_val, y_val, transform=val_transforms)
    test_dataset = DentalDataset(X_test, y_test, transform=val_transforms)

    train_loader = DataLoader(train_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=False, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=False, num_workers=2, pin_memory=True)

    # Initialize evidential models
    print("Initializing evidential models...")
    model_resnet18, model_densenet121 = build_evidential_models()

    # Optimizers
    model_1_optimizer = optim.AdamW(model_resnet18.parameters(), lr=CONFIG['LEARNING_RATE'], weight_decay=1e-4)
    model_2_optimizer = optim.AdamW(model_densenet121.parameters(), lr=CONFIG['LEARNING_RATE'], weight_decay=1e-4)

    # Create schedulers
    model_1_scheduler = optim.lr_scheduler.CosineAnnealingLR(model_1_optimizer, T_max=CONFIG['EPOCHS'])
    model_2_scheduler = optim.lr_scheduler.CosineAnnealingLR(model_2_optimizer, T_max=CONFIG['EPOCHS'])

    # Training loop
    print(f"\n{'='*50}")
    print(f"Starting evidential training for {CONFIG['EPOCHS']} epochs")
    print(f"{'='*50}\n")

    resnet_checkpoint = os.path.join(CONFIG['SAVE_PATH'], 'best_resnet18_edl.pth')
    densenet_checkpoint = os.path.join(CONFIG['SAVE_PATH'], 'best_densenet121_edl.pth')

    best_val_accuracy = 0.0
    checkpoint_saved = False  # True once this run has written its own checkpoints
    metrics_history = []

    for epoch in range(CONFIG['EPOCHS']):
        print(f'\nEpoch {epoch+1}/{CONFIG["EPOCHS"]}')

        # Train
        model_1_loss, model_2_loss = train_evidential_models(
            model_resnet18, model_densenet121, train_loader,
            model_1_optimizer, model_2_optimizer, epoch
        )

        # Step schedulers
        model_1_scheduler.step()
        model_2_scheduler.step()

        print(f'ResNet18 EDL Loss: {model_1_loss:.4f}, DenseNet121 EDL Loss: {model_2_loss:.4f}')

        # Validate
        val_accuracy, val_metrics, _ = evaluate_evidential_models(
            model_resnet18, model_densenet121, val_loader, "val"
        )
        print(f'Validation Accuracy: {val_accuracy:.4f}, Mean Uncertainty: {val_metrics["mean_uncertainty"]:.4f}')

        # Save best models
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model_resnet18.state_dict(), resnet_checkpoint)
            torch.save(model_densenet121.state_dict(), densenet_checkpoint)
            checkpoint_saved = True
            print(f'Saved new best models with validation accuracy: {val_accuracy:.4f}')

        # Track metrics
        metrics_history.append({
            'epoch': epoch + 1,
            'resnet18_loss': model_1_loss,
            'densenet121_loss': model_2_loss,
            'val_accuracy': val_accuracy,
            'val_precision': val_metrics['precision'],
            'val_recall': val_metrics['recall'],
            'val_f1': val_metrics['f1_score'],
            'val_mean_uncertainty': val_metrics['mean_uncertainty']
        })

    # Save training history
    pd.DataFrame(metrics_history).to_csv(os.path.join(CONFIG['METRICS_PATH'], 'edl_training_history.csv'), index=False)

    # Plot training history
    _plot_training_history(metrics_history, os.path.join(CONFIG['METRICS_PATH'], 'edl_training_history.png'))

    # Report test metrics for the best-validation weights rather than whatever the
    # last epoch produced. Only checkpoints written by this run are restored, so a
    # stale file from an earlier run is never picked up.
    if checkpoint_saved:
        model_resnet18.load_state_dict(torch.load(resnet_checkpoint, map_location=CONFIG['DEVICE']))
        model_densenet121.load_state_dict(torch.load(densenet_checkpoint, map_location=CONFIG['DEVICE']))
        print(f'Restored best models (validation accuracy: {best_val_accuracy:.4f})')

    # Final evaluation
    print('\nPerforming final evidential evaluation...')
    test_accuracy, test_metrics, final_alphas = evaluate_evidential_models(
        model_resnet18, model_densenet121, test_loader, "test", visualize=True
    )

    # Print final results
    print(50 * '=')
    print('FINAL EVIDENTIAL RESULTS:')
    print(f'Test Accuracy: {test_accuracy:.4f}')
    print(f'Test Precision: {test_metrics["precision"]:.4f}')
    print(f'Test Recall: {test_metrics["recall"]:.4f}')
    print(f'Test F1 Score: {test_metrics["f1_score"]:.4f}')
    print(f'Mean Test Uncertainty: {test_metrics["mean_uncertainty"]:.4f}')
    print(f'Std Test Uncertainty: {test_metrics["std_uncertainty"]:.4f}')
    print(50 * '=')

    # Save final results
    results_path = os.path.join(CONFIG['METRICS_PATH'], 'edl_final_results.csv')
    results_df = pd.DataFrame({
        'Model': ['EDL ResNet18+DenseNet121'],
        'Accuracy': [test_accuracy],
        'Precision': [test_metrics['precision']],
        'Recall': [test_metrics['recall']],
        'F1_Score': [test_metrics['f1_score']],
        'Mean_Uncertainty': [test_metrics['mean_uncertainty']],
        'Std_Uncertainty': [test_metrics['std_uncertainty']]
    })
    results_df.to_csv(results_path, index=False)

    print(f"Results saved to {results_path}")

# Run the pipeline only when this file is executed directly (script or notebook
# cell), not when it is imported as a module.
if __name__ == '__main__':
    main()