In [86]:
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import pandas as pd

In [3]:
# Read the per-dataset metric tables written to ../output, one CSV per dataset.
# The paths are listed in the same order as the names they are unpacked into.
ltr_MQ2007, ltr_MQ2008, ltr_MSLR10K, ltr_MSLR30K = map(
    pd.read_csv,
    (
        "../output/ltr_MQ2007_metrics.csv",
        "../output/ltr_MQ2008_metrics.csv",
        "../output/ltr_MSLR10K_metrics.csv",
        "../output/ltr_MSLR30K_metrics.csv",
    ),
)

In [None]:
# Datasets to draw (in plotting order) and the visual encoding of each
# dimension: line style <- ranker model, marker <- click model, colour <- dataset.
datasets = {"MQ2007": ltr_MQ2007, "MSLR-WEB10K": ltr_MSLR10K}
linestyle = {"linear": "-", "neural": "--"}
marker = {"informational": "o", "navigational": "^"}
color = {"MQ2007": "#ff7f0e", "MSLR-WEB10K": "#1f77b4"}

# Runs with more queries than this are left out of the figure.
MAX_QUERIES = 16

for dataset_name, metrics in datasets.items():
    # Only the per-run mean AUC is plotted, so aggregate it directly instead of
    # computing the full describe() table and discarding the rest.
    mean_auc = metrics.groupby("name")["auc"].mean()

    # Keep only the "eps_inf" runs *before* parsing names, so that a discarded
    # run whose name lacks a "query" token cannot abort the whole cell.
    mean_auc = mean_auc[mean_auc.index.str.contains("eps_inf")]

    # Run names are "_"-separated: the first token is the model, the token just
    # before "query" is the number of queries and the one before that is the
    # click model. Split each name once and reuse the pieces.
    name_parts = [run_name.split("_") for run_name in mean_auc.index]
    query_at = [parts.index("query") for parts in name_parts]

    df = mean_auc.to_frame("mean").assign(
        query=[int(parts[pos - 1]) for parts, pos in zip(name_parts, query_at)],
        model=[parts[0] for parts in name_parts],
        click_model=[parts[pos - 2] for parts, pos in zip(name_parts, query_at)],
        data=dataset_name,
    )
    df = df[(df["query"] <= MAX_QUERIES) & (df["model"] != "random")]
    df = df.reset_index(drop=True)

    # One line per (model, click model) pair; iterating the encoding dicts keeps
    # the plotted combinations in sync with the styles defined above.
    for model in linestyle:
        for click_model in marker:
            line_df = df[(df["model"] == model) & (df["click_model"] == click_model)].sort_values(by=["query"])
            # x-values are passed as strings so the query counts are evenly
            # spaced categories rather than a numeric scale.
            # NOTE(review): categorical ticks are ordered by first appearance, so
            # this assumes every line covers the same query counts - confirm.
            plt.plot(
                line_df["query"].astype(str).tolist(),
                line_df["mean"].tolist(),
                linestyle=linestyle[model],
                marker=marker[click_model],
                color=color[dataset_name],
                label=f"{dataset_name}: {model} + {click_model}",
            )

plt.xlabel('Number of queries')
plt.ylabel('Mean AUC')
plt.tight_layout()

# The legend explains the three encodings separately (line style, marker,
# colour) using proxy handles, rather than listing every plotted line.
legend_entries = []
for label, value in linestyle.items():
    legend_entries.append(mlines.Line2D([], [], color='black', linestyle=value, label=label))
for label, value in marker.items():
    # linestyle="None" shows the marker on its own; with the default solid line
    # these entries would look like the solid line-style entry above.
    legend_entries.append(mlines.Line2D([], [], color='black', marker=value, linestyle="None", label=label))
for label, value in color.items():
    legend_entries.append(mlines.Line2D([], [], color=value, label=label))
plt.legend(handles=legend_entries)

# Save before show(): bbox_inches='tight' crops the PDF to the drawn artists.
plt.savefig("../plots/plain_metrics.pdf", bbox_inches='tight')
plt.show()
