In [1]:
import pandas as pd
from scipy import stats
import seaborn as sns

In [2]:
# Load one pairwise-similarity table per scoring variant. Files are read in the
# listed order and each path maps positionally onto the name on the left.
(
    similarities,
    stein_similarities1,
    stein_similarities2,
    stein_similarities3,
    binary_similarities,
    no_mz_similarities,
) = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../../data/network_method/similarities_10m.csv',
        '../../data/network_method/similarities_10m_stein_0.6_3.csv',
        '../../data/network_method/similarities_10m_stein_0.5_2.csv',
        '../../data/network_method/similarities_10m_stein_0.5_1.csv',
        '../../data/network_method/similarities_10m_binary.csv',
        '../../data/network_method/similarities_10m_without_mz.csv',
    )
)

In [5]:
# Label -> similarity table for every scoring variant loaded above. The order of
# the pairs below is the iteration order of the resulting dict.
all_similarities = dict([
    ('original', similarities),
    ('stein_0.6_3', stein_similarities1),
    ('stein_0.5_2', stein_similarities2),
    ('stein_0.5_1', stein_similarities3),
    ('binary', binary_similarities),
    ('no_mz', no_mz_similarities),
])

In [6]:
# Keep only rows with a non-missing 'tanimoto' value in every table. The dict is
# updated in place, so the same dict object now holds the filtered frames.
all_similarities.update({
    label: frame.loc[frame['tanimoto'].notna()]
    for label, frame in all_similarities.items()
})

In [None]:
# Fit Gaussian KDEs to the Tanimoto, modified-cosine and weighted
# modified-cosine score distributions.
# NOTE(review): the previous version of this cell did not parse
# (`list(.tanimoto)`) and read from an undefined frame `test`; it now uses the
# 'original' similarity table — TODO confirm that is the intended input.
# `kl_div_df` is an empty placeholder; no cell shown here fills it.
kl_div_df = pd.DataFrame(columns=['kl_div'])
kde_source = all_similarities['original']

# Missing values would break gaussian_kde's covariance estimate, so drop them
# per column ('tanimoto' NaNs were already removed in the filtering cell).
tanimoto_kde = stats.gaussian_kde(kde_source['tanimoto'].dropna().to_numpy())
modified_cosine_kde = stats.gaussian_kde(
    kde_source['modified_cosine'].dropna().to_numpy()
)
weighted_cosine_kde = stats.gaussian_kde(
    kde_source['weighted_modified_cosine'].dropna().to_numpy()
)

In [None]:
# Four-panel figure: for each similarity metric, violins of its weighted and
# unweighted variants per Tanimoto interval; saved to benchmark_metrics.png.
# NOTE(review): `plt` (matplotlib.pyplot) and `similarities_tanimoto` are not
# defined in any cell shown here and must be imported/created before this cell.
# `similarities_tanimoto` is assumed to be a long-format frame with columns
# "tanimoto_interval", "variable" and "value" — TODO confirm.

# Metric column name -> legend text. Every pair follows "X" / "X (weighted)".
label_dict = {
    "modified_cosine": "Modified cosine",
    "weighted_modified_cosine": "Modified cosine (weighted)",
    "entropy": "Spectral entropy",
    "weighted_entropy": "Spectral entropy (weighted)",
    "bhattacharya_1": "Bhattacharya 1",
    "weighted_bhattacharya_1": "Bhattacharya 1 (weighted)",
    "fidelity": "Fidelity",
    "weighted_fidelity": "Fidelity (weighted)",
}

# (weighted, unweighted) metric pairs, one per panel, in row-major panel order.
METRIC_PAIRS = [
    ("weighted_modified_cosine", "modified_cosine"),
    ("weighted_entropy", "entropy"),
    ("weighted_bhattacharya_1", "bhattacharya_1"),
    ("weighted_fidelity", "fidelity"),
]


def _plot_similarity_pair(ax, data, hue_order, labels):
    """Draw one violin panel comparing metric variants per Tanimoto interval.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    data : pandas.DataFrame
        Long-format frame with "tanimoto_interval", "variable" and "value".
    hue_order : sequence of str
        Values of "variable" to draw, in legend order.
    labels : dict
        Maps a metric name to its legend text; unmapped names are shown as-is.
    """
    # NOTE(review): seaborn >= 0.13 deprecates `scale`/`scale_hue` in favour
    # of `density_norm`/`common_norm`; kept here for older seaborn versions.
    sns.violinplot(
        data=data,
        x="tanimoto_interval",
        y="value",
        hue="variable",
        hue_order=list(hue_order),
        cut=0,
        scale="width",
        scale_hue=False,
        ax=ax,
    )
    ax.set_xlabel("Tanimoto index")
    ax.set_ylabel("Spectrum similarity")

    # Replace raw column names in the legend with readable labels.
    for text in ax.legend().get_texts():
        raw = text.get_text()
        text.set_text(labels.get(raw, raw))

    # Anchor the legend's lower centre at the top centre of the axes,
    # i.e. place it above the plotting area without a frame or title.
    sns.move_legend(
        ax,
        "lower center",
        bbox_to_anchor=(.5, 1),
        ncol=3,
        title=None,
        frameon=False,
    )
    sns.despine(ax=ax)


with sns.plotting_context("paper", font_scale=1.6):
    fig, axes = plt.subplots(2, 2, constrained_layout=True, figsize=(20, 10))

    for ax, hue_order in zip(axes.flat, METRIC_PAIRS):
        _plot_similarity_pair(ax, similarities_tanimoto, hue_order, label_dict)

    # Save figure.
    fig.savefig("benchmark_metrics.png", dpi=400, bbox_inches="tight")
    plt.show()
    plt.close(fig)