In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import matplotlib.gridspec as gridspec

from qmlhep.config import results_path, figures_path
# This will load plot configuration & style
from qmlhep.utils.plot_results import *

## Kmeans Results

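This notebook plots the results of the best-performing VQC on the KMeans dataset: AUC score versus number of datapoints for the Regular and KMeans regimes, shown in QML, SVM and logistic-regression panels.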

Author: Miguel Caçador Peixoto

# 

#### Load Data

In [None]:
# Read the k-means study results table from the configured results directory.
results_csv = join(results_path, "kmeans_results.csv")
df = pd.read_csv(results_csv)

# Capitalise the regime labels used in the plot legends:
# "regular" -> "Regular" and "kmeans" -> "KMeans" (applied in this order).
regime_display_names = {"regular": "Regular", "kmeans": "KMeans"}
for raw_label, display_label in regime_display_names.items():
    df["regime"] = df["regime"].str.replace(raw_label, display_label)

# Show the table as the cell output.
df

#### Plot Results

In [None]:
# Draw one AUC-vs-#datapoints panel per model (QML, SVM, Logistic Regression),
# with one line per regime, and save the figure as a PDF.
fig = plt.figure(figsize=(17, 17))

sns.set_style("darkgrid")

# 4x4 grid: two panels side by side on the top half, and one panel
# horizontally centred on the bottom half.
gs = gridspec.GridSpec(4, 4)
ax_qml = plt.subplot(gs[:2, :2])
ax_svm = plt.subplot(gs[:2, 2:])
ax_lr = plt.subplot(gs[2:4, 1:3])

# (axis, value in the "model" column, panel title).
# The original cell filtered on 'qml' for all three panels, so the SVM and
# Log. Reg. panels silently showed the QML curves.
# NOTE(review): only the "qml" label is confirmed by this notebook; "svm" and
# "lr" are assumed -- if they are wrong, the guard below reports the labels
# actually present in the results table.
panels = [
    (ax_qml, "qml", "QML"),
    (ax_svm, "svm", "SVM"),
    (ax_lr, "lr", "Log. Reg."),
]

for ax, model_label, title in panels:
    model_df = df[df["model"] == model_label]

    # Fail loudly rather than saving a figure with an empty or wrong panel.
    if model_df.empty:
        raise ValueError(
            f"No rows with model == {model_label!r} in the results table; "
            f"available models: {sorted(df['model'].astype(str).unique())}"
        )

    # Same line styling for every panel so they are directly comparable.
    sns.lineplot(
        x="n_datapoints",
        y="auc",
        hue="regime",
        data=model_df,
        ax=ax,
        markers=True,
        dashes=True,
        markersize=8,
        linewidth=2,
    )

    ax.set_title(title, fontsize=MEDIUM_SIZE)

    # Identical y-range on every panel so AUC values can be compared visually.
    ax.set_ylim(0.65, 0.9)

    # Axis labels, legend and tick sizes
    ax.set_xlabel("#Datapoints", fontsize=MEDIUM_SIZE)
    ax.set_ylabel("AUC Score", fontsize=MEDIUM_SIZE)
    ax.legend(fontsize=LEGEND_SIZE)
    ax.tick_params(axis="both", which="major", labelsize=TICK_SIZE)

plt.tight_layout()
plt.savefig(join(figures_path, "kmeans_study_results.pdf"))