In [10]:
import os
import pandas as pd


def load_results(path="../scripts/data/benchmarks-2"):

    _results = {}

    for root, dirs, files in os.walk(path):
        for file in files:
            path, factors = os.path.split(root)
            _, model = os.path.split(path)
            key=f"{model}-{factors}"
            _results[key] = pd.read_csv(os.path.join(root, file), index_col=0)

    return _results


def get_best_stats(results, key, filtered=lambda _: False):
    stats = []
    for k, df in results.items():
        print(k, filtered(k))
        if not filtered(k):
            best_vals = [df[_].max() for _ in df.columns if _.startswith(key)]
            stats.append(pd.DataFrame(data=zip(list(range(1, len(best_vals))), best_vals, [k for _ in range(len(best_vals))]), columns=["k", key, "key"]))

    return pd.concat(stats)


def get_best_stats_by_factor(results, k, filtered=lambda _: False):
    stats = []
    for key, df in results.items():
        if not filtered(key):
            label, factors = key.split("-")
            stats.append([label, int(factors), df[f"ndcg{k}"].max(), df[f"hr{k}"].max()])

    return pd.DataFrame(data=stats, columns=["key", "factors", "ndcg", "hr"])


def plot_best_stats(results, key, **kwargs):
    stats = get_best_stats(results, key, **kwargs)
    return alt.Chart(stats).mark_line().encode(
        x="k:O",
        y=alt.Y(f'{key}:Q',
            scale=alt.Scale(zero=False)
        ),
        color="key"
    )


def plot_best_stats_by_factor(results, key, k):
    stats = get_best_stats_by_factor(results, k)

    return alt.Chart(stats.sort_values(by="factors")).mark_line().encode(
        x=alt.X('factors:O',
            scale=alt.Scale(zero=False)
        ),
        y=alt.Y(f'{key}:Q',
            scale=alt.Scale(zero=False)
        ),
        color="key"
    )


def filter_on(factors):
    return lambda _: False if int(_.split("-")[1]) == factors else True


In [11]:
import altair as alt

results = load_results()

plot_best_stats(results, "ndcg")

gmf-8 False
nmf-8 False
als-8 False
mlp-8 False
bpr-8 False


In [13]:
plot_best_stats(results, "hr")

gmf-8 False
nmf-8 False
als-8 False
mlp-8 False
bpr-8 False


In [14]:
plot_best_stats_by_factor(results, "ndcg", 10)

In [15]:
plot_best_stats_by_factor(results, "hr", 10)