In [2]:
#%% [code]
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pandas.plotting import parallel_coordinates

# Ensure the output directory for the figures exists; exist_ok=True makes
# re-running the script harmless when 'plots' is already there.
os.makedirs('plots', exist_ok=True)

# Load evaluation data.
# 'eval.csv' is a relative path, so it is resolved against the current working
# directory. When the file is missing, re-raise with a clearer message and
# chain the original exception explicitly so its details stay in the traceback.
try:
    df = pd.read_csv('eval.csv')
except FileNotFoundError as err:
    raise FileNotFoundError("eval.csv not found in the current directory.") from err

# Metric column names used from df. 'log_loss' is kept last; it is the one
# metric that gets replaced by an inverted counterpart for plotting.
metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'roc_auc', 'mcc', 'specificity', 'log_loss']

# Inverted log loss, 1 / (1 + log_loss), so that a higher value is better.
# For non-negative log_loss the result lies in (0, 1].
df['inv_log_loss'] = 1 / (df['log_loss'] + 1)

# Metrics actually plotted: the same names in the same order, with
# 'inv_log_loss' standing in for 'log_loss'.
plot_metrics = [name for name in metrics if name != 'log_loss'] + ['inv_log_loss']

# 1. Heatmap of Average Metrics per Model
# Rows are labelled by model_index, columns are the plotted metrics; every
# cell is annotated with its value formatted to three decimals.
heatmap_data = df.set_index('model_index')[plot_metrics]
heatmap_fig, heatmap_ax = plt.subplots(figsize=(10, 6))
sns.heatmap(heatmap_data, annot=True, fmt='.3f', cmap='YlGnBu', ax=heatmap_ax)
heatmap_ax.set_title('Heatmap of Average Metrics per Model')
heatmap_ax.set_xlabel('Metrics')
heatmap_ax.set_ylabel('Model Index')
heatmap_fig.tight_layout()
heatmap_fig.savefig('plots/heatmap_avg.png')
# Close the figure once it is written to disk so it is not kept open.
plt.close(heatmap_fig)

# 2. Radar Chart (Spider Plot)
# One spoke per plotted metric, evenly spaced around the circle. The first
# angle is repeated at the end so every outline closes back on itself.
angles = np.linspace(0, 2 * np.pi, len(plot_metrics), endpoint=False).tolist()
angles = angles + angles[:1]

fig, ax = plt.subplots(figsize=(8, 8), subplot_kw={'polar': True})

# Draw one outlined, lightly filled polygon per row of df.
metric_rows = df[plot_metrics].to_numpy().tolist()
for model_idx, values in zip(df['model_index'], metric_rows):
    closed_values = values + values[:1]  # repeat first value to close the loop
    ax.plot(angles, closed_values, label=f"Model {int(model_idx)}")
    ax.fill(angles, closed_values, alpha=0.1)

# Tick labels go on the distinct spokes only (the duplicated closing angle is dropped).
ax.set_xticks(angles[:-1])
ax.set_xticklabels(plot_metrics)
ax.set_title('Radar Chart of Average Metrics per Model', pad=20)
ax.legend(loc='upper right', bbox_to_anchor=(1.2, 1.1))
fig.tight_layout()
fig.savefig('plots/radar_avg.png')
plt.close(fig)

# 3. Parallel Coordinates Plot
# Work on a copy so the extra 'model' label column does not end up in df.
pc_df = df.copy()
pc_df['model'] = [f"Model {int(idx)}" for idx in pc_df['model_index']]

pc_fig = plt.figure(figsize=(12, 6))
# One line per row, coloured by its 'model' label, crossing one axis per metric.
pc_ax = parallel_coordinates(
    pc_df[['model'] + plot_metrics],
    class_column='model',
    colormap=plt.get_cmap('tab10'),
)
pc_ax.set_title('Parallel Coordinates Plot of Average Metrics per Model')
pc_ax.set_xlabel('Metrics')
# NOTE(review): the label says 'Scaled Value', but no scaling step is visible
# in this script — confirm the label is intended.
pc_ax.set_ylabel('Scaled Value')
pc_ax.legend(loc='upper right', bbox_to_anchor=(1.2, 1.0))
pc_fig.tight_layout()
pc_fig.savefig('plots/parallel_avg.png')
plt.close(pc_fig)

# 4. Grouped Bar Chart
# Reshape to long form (one row per model_index / metric pair) so the bars can
# be grouped by metric on the x-axis and coloured by model_index.
melted = pd.melt(
    df,
    id_vars=['model_index'],
    value_vars=plot_metrics,
    var_name='Metric',
    value_name='Value',
)
bar_fig = plt.figure(figsize=(12, 6))
bar_ax = sns.barplot(data=melted, x='Metric', y='Value', hue='model_index', palette='tab20')
bar_ax.set_title('Grouped Bar Chart of Average Metrics per Model')
bar_ax.set_xlabel('Metric')
bar_ax.set_ylabel('Value')
# Anchor the legend just outside the right edge of the axes.
bar_ax.legend(title='Model Index', bbox_to_anchor=(1.05, 1), loc='upper left')
bar_fig.tight_layout()
bar_fig.savefig('plots/grouped_bar_avg.png')
plt.close(bar_fig)

print("All plots saved in the 'plots/' directory.")


All plots saved in the 'plots/' directory.
