In [None]:
import os
import pickle
import matplotlib.pyplot as plt
import pandas as pd

# Define the directory containing the models
base_dir = "runs"
models = ["base_svm", "base_knn", "base_cnn", "base_resnet", "base_vit"]

# Define the folder to save images
save_dir = "images"
os.makedirs(save_dir, exist_ok=True)  # Create the folder if it doesn't exist

# Load each model's log from <base_dir>/<model>/log.pkl into metrics[model].
# Models without a log are skipped with a notice (instead of silently), so a
# partial set of runs still produces a report.
# NOTE(review): pickle.load can execute arbitrary code from the file; only point
# base_dir at logs produced by our own training runs.
metrics = {}
for model in models:
    log_path = os.path.join(base_dir, model, "log.pkl")
    if not os.path.exists(log_path):
        print(f"Skipping {model}: {log_path} not found.")
        continue
    with open(log_path, "rb") as f:
        metrics[model] = pickle.load(f)

# Nothing to report without at least one log. Raise SystemExit instead of calling
# the ``site``-provided ``exit()`` helper: that helper is meant for interactive use
# only (absent under ``python -S``), exits with status 0 despite the failure, and
# under IPython/Jupyter it ends the session/kernel rather than just stopping here.
# SystemExit(msg) writes the message to stderr and exits with status 1.
if not metrics:
    raise SystemExit("No log.pkl files found in the specified directories.")

# Extract metrics for bar graphs and tables.
# Every list below is aligned index-for-index with ``model_names`` (insertion
# order of ``metrics``). A key missing from a log falls back to 0, or to the
# empty matrix [[]] for the confusion matrix, so an incomplete log still
# contributes a row instead of raising.
model_names = [name.replace("base_", "").upper() for name in metrics]
inference_times = [log.get("inference_time", 0) for log in metrics.values()]
testing_accuracies = [log.get("test_accuracy", 0) for log in metrics.values()]
f1_scores = [log.get("f1_score", 0) for log in metrics.values()]
confusion_matrices = [log.get("confusion_matrix", [[]]) for log in metrics.values()]

# Print a one-row-per-model summary table. Every numeric cell is pre-formatted
# to four decimals; inference time is multiplied by 1000 to match the "(ms)"
# header (presumably it is logged in seconds — TODO confirm against the logger).
print("\n=== Model Testing Metrics ===")
four_dp = "{:.4f}".format
df_metrics = pd.DataFrame(
    {
        "Model": model_names,
        "F1 Score": list(map(four_dp, f1_scores)),
        "Testing Accuracy": list(map(four_dp, testing_accuracies)),
        "Inference Time (ms)": [four_dp(value * 1000) for value in inference_times],
    }
)
print(df_metrics)

# Print each model's confusion matrix in full. ``print(DataFrame)`` applies
# pandas' display limits (max columns/rows) and elides the middle with "...",
# which hides most of a many-class matrix; ``to_string()`` renders every row
# and column (lines get wide for many classes, but nothing is dropped).
for model, cm in zip(model_names, confusion_matrices):
    print(f"\n=== Confusion Matrix for {model} ===")
    cm_df = pd.DataFrame(cm)
    if cm_df.empty:
        # e.g. a log with no "confusion_matrix" entry, which was defaulted to [[]].
        print("(no confusion matrix logged)")
        continue
    print(cm_df.to_string())

# Plot training vs validation accuracy and loss on stacked subplots for each model
plot_models = ["base_cnn", "base_resnet", "base_vit"]


def _plot_curves(ax, log, curves, marker):
    """Draw each ``(key, label, linestyle)`` series from ``log`` onto ``ax``.

    Each series is plotted against its own 1-based epoch index. A series missing
    from the log (treated as empty) or shorter than its partner therefore no
    longer triggers matplotlib's "x and y must have same first dimension" error.

    Returns:
        The length of the longest series drawn (0 if all are empty).
    """
    longest = 0
    for key, label, linestyle in curves:
        values = log.get(key, [])
        ax.plot(range(1, len(values) + 1), values, label=label, marker=marker, linestyle=linestyle)
        longest = max(longest, len(values))
    return longest


for model in plot_models:
    if model not in metrics:
        continue  # no log was loaded for this model
    data = metrics[model]
    display_name = model.replace("base_", "").upper()

    # Create a figure with stacked subplots: accuracy on top, loss below.
    fig, axes = plt.subplots(2, 1, figsize=(8, 10), sharex=True)

    # Plot training and validation accuracy
    n_acc = _plot_curves(
        axes[0],
        data,
        (("train_accuracy", "Training Accuracy", "-"), ("val_accuracy", "Validation Accuracy", "--")),
        marker="o",
    )
    axes[0].set_title(f"Training vs Validation Accuracy: {display_name}", fontsize=14, fontweight='bold')
    axes[0].set_ylabel("Accuracy", fontsize=12)
    axes[0].legend(fontsize=10)
    axes[0].grid(visible=True, linestyle='--', alpha=0.6)

    # Plot training and validation loss
    n_loss = _plot_curves(
        axes[1],
        data,
        (("train_loss", "Training Loss", "-"), ("val_loss", "Validation Loss", "--")),
        marker="s",
    )
    axes[1].set_title(f"Training vs Validation Loss: {display_name}", fontsize=14, fontweight='bold')
    axes[1].set_xlabel("Epochs", fontsize=12)
    axes[1].set_ylabel("Loss", fontsize=12)
    axes[1].legend(fontsize=10)
    axes[1].grid(visible=True, linestyle='--', alpha=0.6)

    # One tick per epoch, covering the longest series on either panel
    # (the x-axis is shared, so setting it on the bottom panel suffices).
    epochs = range(1, max(n_acc, n_loss) + 1)
    axes[1].set_xticks(epochs)
    axes[1].set_xticklabels(epochs, fontsize=10)

    # Adjust layout, save, and release this figure explicitly.
    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, f"{model.replace('base_', '').lower()}_metrics.png"), dpi=300, format="png")
    plt.close(fig)

# Bar chart of per-model inference latency, each bar annotated with its value.
# Note: the module-level ``inference_times`` list is rebound here to its
# millisecond values (x1000), matching the "(ms)" axis and annotations.
inference_times = [value * 1000 for value in inference_times]

fig, ax = plt.subplots(figsize=(8, 5))
bars = ax.bar(model_names, inference_times, color='royalblue', edgecolor='black')
for rect in bars:
    top = rect.get_height()
    centre_x = rect.get_x() + rect.get_width() / 2.0
    ax.text(centre_x, top, f'{top:.4f} ms', ha='center', va='bottom', fontsize=10)

ax.set_title("Inference Time (ms)", fontsize=14, fontweight='bold')
ax.set_xlabel("Models", fontsize=12)
ax.set_ylabel("Time (ms)", fontsize=12)
ax.grid(axis="y", linestyle='--', alpha=0.6)

fig.tight_layout()
fig.savefig(os.path.join(save_dir, "inference_time.png"), dpi=300, format="png")
plt.close(fig)

# Bar graph for testing accuracy, each bar annotated with its value.
plt.figure(figsize=(8, 5))
bars = plt.bar(model_names, testing_accuracies, color='forestgreen', edgecolor='black')
for bar in bars:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width() / 2.0, height, f'{height:.4f}', ha='center', va='bottom', fontsize=10)
plt.title("Testing Accuracy", fontsize=14, fontweight='bold')
plt.xlabel("Models", fontsize=12)
plt.ylabel("Accuracy", fontsize=12)  # same label size as the other charts
# Accuracy is on a 0-1 scale. The small headroom above 1.0 keeps the value
# labels of near-perfect bars (drawn above the bar top via va='bottom') inside
# the axes instead of overlapping the top spine.
plt.ylim(0, 1.05)
plt.grid(axis="y", linestyle='--', alpha=0.6)
plt.tight_layout()
plt.savefig(os.path.join(save_dir, "testing_accuracy.png"), dpi=300, format="png")
plt.close()



=== Model Testing Metrics ===
    Model F1 Score Testing Accuracy Inference Time (ms)
0     SVM   0.7252           0.7323              0.0005
1     KNN   0.5112           0.5051              1.3251
2     CNN   0.9590           0.9593              0.9763
3  RESNET   0.9816           0.9818              4.5004
4     VIT   0.9161           0.9177              3.9087

=== Confusion Matrix for SVM ===
    0    1    2    3    4    5   6    7    8    9   ...   33   34   35   36  \
0    8   40    2    0   10    0   0    0    0    0  ...    0    0    0    0   
1    2  566   77    0   40    8   0   10    4    0  ...    0    0    1    0   
2    2   97  572   15   27   23   0    2    1    1  ...    0    0    0    0   
3    0   26   53  209   25   72   2   13    5    3  ...    1    1    0    0   
4    2   66   45   14  415   22   0   11   53    0  ...    0    0    0    0   
5    1   71  182   24   84  192   1    7   12    1  ...    0    0    0    0   
6    0    0    1    0   22    7  89    0    3 