In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import os
import glob
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Directory containing the output CSV files (every '*.csv' directly inside it
# is scanned further down)
output_csv_directory = './PR_Benefit/'
# Directory to save the plots; created on the spot, and an already existing
# directory is not treated as an error (exist_ok=True)
output_plots_directory = './output_plots_by_model_PR_Benefit/'
os.makedirs(output_plots_directory, exist_ok=True)

# Read the file containing the best models info.
# The code below reads its 'csv_file', 'model_name' and 'accuracy' columns.
best_models_info = pd.read_csv('./best_models_infoPR_Benefit.csv')

# Derive the short label used in plot titles from a file name
def extract_title_part(filename):
    """Return the part of *filename* that precedes its first underscore.

    A name without any underscore is returned unchanged.
    """
    prefix, _separator, _remainder = filename.partition('_')
    return prefix

# Dictionary to store data for each best model.
# Key: (model_name, csv_file_basename, accuracy); value: the rows of that CSV
# whose 'Model' column equals the best model's name.
best_model_data_dict = {}

# glob returns paths in arbitrary, filesystem-dependent order. Sort them so the
# dictionary is filled in the same order on every run; the later top-3
# selection breaks accuracy ties by insertion order.
csv_files = sorted(glob.glob(os.path.join(output_csv_directory, '*.csv')))

for csv_file in csv_files:
    # Reduce the file name to the identifier used in best_models_info['csv_file']
    csv_file_basename = os.path.basename(csv_file).replace('output_o', 'o').replace('_PR_Benefit.csv', '')

    best_model_row = best_models_info[best_models_info['csv_file'] == csv_file_basename]
    if best_model_row.empty:
        # No best-model entry for this file: skip it without parsing the CSV
        continue

    # If several rows match, the first one is used
    best_model_name = best_model_row['model_name'].values[0]
    accuracy = best_model_row['accuracy'].values[0]

    data = pd.read_csv(csv_file)
    best_model_data = data[data['Model'] == best_model_name]
    # Keep the file only when it actually contains rows for the best model
    if not best_model_data.empty:
        best_model_data_dict[(best_model_name, csv_file_basename, accuracy)] = best_model_data
# Function to plot confusion matrix for each best model
def plot_confusion_matrix(ax, model_name, csv_filename, model_data, accuracy, labels=(0, 1, 2)):
    """Draw an annotated confusion-matrix heatmap for one model on *ax*.

    Args:
        ax: Matplotlib axes the heatmap, axis labels and title are drawn on.
        model_name: Model name shown in the title.
        csv_filename: Name passed to ``extract_title_part``; the result is
            shown in parentheses in the title.
        model_data: Table providing the 'Actual' and 'Predicted' columns.
        accuracy: Value shown verbatim after "Accuracy:" in the title.
        labels: Class labels, in display order, used both to build the matrix
            and as tick labels. Defaults to ``(0, 1, 2)``, the classes that
            were previously hard-coded.
    """
    # One list drives both the matrix layout and the tick labels so that the
    # rows/columns can never get out of step with what the axes show.
    class_labels = list(labels)
    cm = confusion_matrix(model_data['Actual'], model_data['Predicted'], labels=class_labels)

    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_labels, yticklabels=class_labels, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    title_part = extract_title_part(csv_filename)
    ax.set_title(f'Confusion Matrix for {model_name} ({title_part})\nAccuracy: {accuracy}')

# Plot the top 3 confusion matrices based on accuracy.
# Each dictionary key is (model_name, csv_filename, accuracy), so x[0][2] is the accuracy.
top_3_models = sorted(best_model_data_dict.items(), key=lambda x: x[0][2], reverse=True)[:3]
num_models = len(top_3_models)

if num_models == 0:
    # plt.subplots() rejects a grid with zero rows, so report the situation
    # instead of failing with a ValueError.
    print("No best-model data found; no confusion matrices were plotted")
else:
    # squeeze=False always yields a 2-D array of axes, so a single subplot
    # needs no special-casing.
    fig, axes_grid = plt.subplots(num_models, 1, figsize=(10, 8 * num_models), squeeze=False)
    axes = axes_grid[:, 0]  # the only column: one axes per model

    for ax, ((model_name, csv_filename, accuracy), model_data) in zip(axes, top_3_models):
        plot_confusion_matrix(ax, model_name, csv_filename, model_data, accuracy)

    fig.tight_layout(pad=3.0)
    plot_filename = os.path.join(output_plots_directory, "top_3_confusion_matrices.png")
    fig.savefig(plot_filename)
    plt.close(fig)

    # NOTE(review): the image is written to plot_filename (inside
    # output_plots_directory); the message text is kept unchanged.
    print("Top 3 confusion matrices saved to output_plots_by_model")

Top 3 confusion matrices saved to output_plots_by_model
