In [5]:
import json
import os

def load_metrics(folder):
    """Load every JSON metrics file found directly inside ``folder``.

    Each ``*.json`` file is parsed and stored under a key derived from its
    file name (the text before the first ``.``). Entries that are not regular
    ``.json`` files -- sub-directories such as ``.ipynb_checkpoints`` or stray
    files such as ``.DS_Store`` -- are skipped instead of aborting the load.

    Args:
        folder: Path of the directory to scan (not searched recursively).

    Returns:
        dict mapping model name to the parsed JSON content of its file,
        inserted in sorted file-name order.

    Raises:
        OSError: If ``folder`` cannot be listed or a file cannot be read.
        json.JSONDecodeError: If a ``.json`` file does not hold valid JSON.
    """
    metrics = {}
    # os.listdir returns entries in arbitrary order; sorting keeps the model
    # order (and anything built from it) reproducible between runs.
    for file_name in sorted(os.listdir(folder)):
        path = os.path.join(folder, file_name)
        if not (file_name.lower().endswith('.json') and os.path.isfile(path)):
            continue
        model_name = file_name.split('.')[0]
        # Explicit encoding so the result does not depend on the platform default.
        with open(path, 'r', encoding='utf-8') as f:
            metrics[model_name] = json.load(f)
    return metrics

# Per-model metric dictionaries for each dataset variant, keyed by model name.
original_metrics = load_metrics(folder='./models/results/original/')
rebalanced_metrics = load_metrics(folder='./models/results/rebalanced/')

In [17]:
import plotly.graph_objs as go
from plotly.subplots import make_subplots

# Human-readable labels for the metric keys used in the results dictionaries.
_METRIC_DISPLAY_NAMES = {
    'accuracy': 'Accuracy',
    'precision': 'Precision',
    'recall': 'Recall',
    'f1_score': 'F1 Score',
    'roc_auc': 'ROC AUC',
    'mcc': 'Matthews Correlation Coefficient',
    'balanced_accuracy': 'Balanced Accuracy',
}


def plot_metric_comparison(original_metrics, rebalanced_metrics, metric_name, output_file=None, y_padding=0.02):
    """Show side-by-side bar charts of one metric for every model.

    The left subplot holds the values from ``original_metrics`` and the right
    subplot the values from ``rebalanced_metrics``. Both subplots share the
    same y-axis range so bar heights are directly comparable.

    Args:
        original_metrics: Mapping of model name -> mapping of metric name ->
            numeric value. Its keys decide which models are plotted and in
            which order.
        rebalanced_metrics: Mapping with the same structure; it must contain
            every model present in ``original_metrics``.
        metric_name: Key of the metric to plot (e.g. ``'f1_score'``).
        output_file: Optional path; when truthy the figure is also passed to
            ``fig.write_image``.
        y_padding: Margin added below the smallest and above the largest
            plotted value when setting the shared y-axis range.

    Returns:
        None. The figure is displayed with ``fig.show()``.

    Raises:
        ValueError: If ``original_metrics`` contains no models.
        KeyError: If a model or ``metric_name`` is missing from either mapping.
    """
    models = list(original_metrics.keys())
    if not models:
        raise ValueError("original_metrics is empty; there is nothing to plot")

    # Resolve the label once so the figure title and the y-axis title agree;
    # unknown metrics fall back to a title-cased version of their key.
    display_name = _METRIC_DISPLAY_NAMES.get(
        metric_name, metric_name.replace('_', ' ').title()
    )

    # Values are rounded to two decimals; the same numbers are used as bar labels.
    original_values = [round(original_metrics[model][metric_name], 2) for model in models]
    rebalanced_values = [round(rebalanced_metrics[model][metric_name], 2) for model in models]

    # Zoom the y-axis onto the observed values so small differences stay visible.
    all_values = original_values + rebalanced_values
    y_min = min(all_values) - y_padding
    y_max = max(all_values) + y_padding

    fig = make_subplots(rows=1, cols=2, subplot_titles=("Original", "Rebalanced"))

    fig.add_trace(
        go.Bar(x=models, y=original_values, name="Original Dataset", text=original_values, textposition='auto'),
        row=1, col=1
    )
    fig.add_trace(
        go.Bar(x=models, y=rebalanced_values, name="Rebalanced Dataset", text=rebalanced_values, textposition='auto'),
        row=1, col=2
    )

    # update_xaxes without row/col targets every subplot; the layout-level
    # `xaxis_title` would label only the first one.
    fig.update_xaxes(title_text="Model")
    fig.update_yaxes(range=[y_min, y_max])
    fig.update_yaxes(title_text=display_name, row=1, col=1)

    fig.update_layout(
        title_text=display_name,
        showlegend=False,
        height=600,
        width=1200,
        template="plotly_white"
    )

    if output_file:
        fig.write_image(output_file)

    fig.show()

In [18]:
# Metric keys to visualise, one comparison figure per entry.
metrics_to_plot = [
    "accuracy",
    "precision",
    "recall",
    "f1_score",
    "mcc",
    "balanced_accuracy",
]

for metric in metrics_to_plot:
    plot_metric_comparison(
        original_metrics,
        rebalanced_metrics,
        metric_name=metric,
    )