In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.cm as cm

def plot_metrics(df, save_path='figures/fig_07.png'):
    """Plot mean segmentation metrics against mu for every interface type.

    Draws one panel per metric (Dice coefficient, relative volume error,
    Hausdorff distance). Each non-TANH ``interface_type`` contributes one line
    per panel. Interface types containing ``'TANH'`` contribute one line per
    distinct non-NaN ``epsilon``, coloured along the viridis colormap. Every
    plotted point is the mean of the metric over all rows sharing that ``mu``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``interface_type``, ``epsilon``, ``mu``,
        ``dice``, ``rel_vol_error`` and ``hausdorff``.
    save_path : str or path-like, optional
        File the figure is written to. Missing parent directories are
        created. Defaults to ``'figures/fig_07.png'``.
    """
    from pathlib import Path

    fig, axs = plt.subplots(1, 3, figsize=(8, 2.5), dpi=200, sharex=True)

    # Metric columns and their axis labels, one per panel.
    metrics = ['dice', 'rel_vol_error', 'hausdorff']
    y_labels = ['Dice Coefficient ($\\uparrow$)', 'Rel. Volume Error ($\\downarrow$)', 'Hausdorff Distance ($\\downarrow$)']

    # Drop NaN *before* sorting: NaN compares false with everything, which
    # makes the order produced by sorted() ill-defined.
    unique_epsilons = sorted(df['epsilon'].dropna().unique())

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is
    # the supported equivalent. The lookup-table size must be at least 1, even
    # when the data contains no epsilon values.
    colormap = plt.get_cmap('viridis', max(len(unique_epsilons), 1))
    epsilon_to_color = {eps: colormap(i) for i, eps in enumerate(unique_epsilons)}

    linestyle_dict = {
        'HEAVISIDE': 'dashed',
        'SIGNED_DISTANCE_EXACT': 'dashed',
        'SIGNED_DISTANCE_APPROXIMATE': 'dashed',
        'TANH_EPSILON': 'solid'
    }
    # Distinct labels so the two SDF variants are distinguishable in the legend.
    label_dict = {
        'HEAVISIDE': 'Sharp',
        'SIGNED_DISTANCE_EXACT': 'SDF (exact)',
        'SIGNED_DISTANCE_APPROXIMATE': 'SDF (approx.)',
    }

    for ax, metric, y_label in zip(axs, metrics, y_labels):
        for interface_type in df['interface_type'].dropna().unique():
            subset = df[df['interface_type'] == interface_type]
            # Unknown interface types fall back to a solid line instead of a KeyError.
            linestyle = linestyle_dict.get(interface_type, 'solid')

            if 'TANH' in interface_type:
                # NaN epsilons have no colour assigned above, so they are skipped.
                for epsilon in sorted(subset['epsilon'].dropna().unique()):
                    # groupby sorts its keys, so the line is drawn in increasing mu.
                    mean_values = subset[subset['epsilon'] == epsilon].groupby('mu')[metric].mean()
                    # epsilon == 0 would divide by zero, so it gets a plain label.
                    label = f'Tanh 1/{1 / epsilon:.0f}' if epsilon else 'tanh'
                    ax.plot(mean_values.index, mean_values.values, label=label,
                            color=epsilon_to_color[epsilon], linestyle=linestyle, marker='o')
            else:
                mean_values = subset.groupby('mu')[metric].mean()
                label = label_dict.get(interface_type, interface_type)
                ax.plot(mean_values.index, mean_values.values, label=label,
                        linestyle=linestyle, marker='o')

        ax.set_xlabel('$\\mu$')
        ax.set_ylabel(y_label)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    axs[0].set_ylim(0.6, 1)
    axs[1].set_ylim(0, 0.25)

    # Every panel draws the same series, so the first panel's handles cover all of them.
    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='center left', bbox_to_anchor=(1, 0.5))

    fig.tight_layout()

    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    # The legend is anchored to the right of the figure canvas (x = 1 in figure
    # coordinates). bbox_inches='tight' grows the saved image to include it;
    # without it, the legend is clipped out of the file.
    fig.savefig(save_path, bbox_inches='tight')
    plt.show()

# Load the evaluation results table and render the metric comparison figure from it.
df = pd.read_csv(filepath_or_buffer='evaluation_results_2.csv')
plot_metrics(df=df)