In [32]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import pyrootutils
from IPython.display import display

# Register the notebook's parent directory as the project root.
# NOTE(review): pythonpath=True presumably puts that directory on sys.path so the
# `src` imports below resolve, which would explain why this call precedes them —
# confirm against the pyrootutils documentation.
pyrootutils.set_root(os.path.abspath(".."), pythonpath=True)
    
from src.utils.data import load_news, load_behaviors
from src.utils.hydra import RunCollection

# Dataset directory, given relative to the notebook's working directory.
DATA_DIR = "../data"


In [33]:
def get_metrics(runs, split="dev", best_on="AUC", index_by=None, **kwargs):
    """Return, for every run, the `split` metrics of its best dev epoch.

    The best row of a run is the row of its ``metrics_dev.csv`` with the highest
    `best_on` value. For ``split="dev"`` that row itself is returned. For any other
    split the row of ``metrics_{split}.csv`` with the same ``epoch`` value is
    returned, so the two files no longer need an identical row layout. If either
    file has no ``epoch`` column, rows are matched by position as before.

    Args:
        runs: Iterable of run objects exposing ``get_file_path(name)``. The runs are
            used as group-by keys, so they must be hashable.
        split: Name used to build the reported file name, ``metrics_{split}.csv``.
        best_on: Column of the dev metrics that is maximised to pick the best row.
        index_by: Optional column name (or list of names) to use as the index of
            the result; the result is then sorted by that index.
        **kwargs: Extra columns to add. Each key is a column name and each value a
            callable that receives the row's run object and returns the cell value.

    Returns:
        A DataFrame with one row per run, holding the CSV columns, a ``run`` column
        and one column per entry in `kwargs`.

    Raises:
        ValueError: If `runs` is empty.
        KeyError: If the requested split has no row for a run's best dev epoch.
    """
    # Materialise once: `runs` is walked for the dev files and again for the split
    # files, which would silently yield nothing on a one-shot iterable.
    runs = list(runs)
    if not runs:
        raise ValueError("get_metrics() needs at least one run, got an empty collection")

    dev_metrics = _load_split_metrics(runs, "dev")
    # One row label of `dev_metrics` per run: the row with the highest `best_on`.
    best_rows = dev_metrics.groupby("run")[best_on].idxmax()

    if split == "dev":
        # Same file as the one just loaded; reading it a second time is pointless.
        metrics, selected = dev_metrics, best_rows
    else:
        metrics = _load_split_metrics(runs, split)
        if "epoch" in dev_metrics.columns and "epoch" in metrics.columns:
            selected = _match_rows_by_epoch(dev_metrics, metrics, best_rows)
        else:
            # Without an epoch column the only available alignment is positional.
            selected = best_rows

    # Explicit copy so the column assignments below never write into a view.
    best_metrics = metrics.loc[selected].copy()
    for key, value in kwargs.items():
        best_metrics[key] = best_metrics["run"].apply(value)
    if index_by:
        return best_metrics.set_index(index_by, drop=True).sort_index()
    return best_metrics


def _load_split_metrics(runs, split):
    """Stack ``metrics_{split}.csv`` of every run into one frame with a 0..n-1 index."""
    frames = [
        pd.read_csv(run.get_file_path(f"metrics_{split}.csv")).assign(run=run)
        for run in runs
    ]
    return pd.concat(frames).reset_index(drop=True)


def _match_rows_by_epoch(dev_metrics, metrics, best_rows):
    """Translate best dev row labels into labels of `metrics` with the same run and epoch."""
    first_label = {}
    for label, run, epoch in zip(metrics.index, metrics["run"], metrics["epoch"]):
        # Keep the first row if a (run, epoch) pair is logged more than once.
        first_label.setdefault((run, epoch), label)

    selected = []
    for dev_label in best_rows:
        run = dev_metrics.at[dev_label, "run"]
        epoch = dev_metrics.at[dev_label, "epoch"]
        if (run, epoch) not in first_label:
            raise KeyError(f"no row for epoch {epoch!r} of {run!r} in the requested split")
        selected.append(first_label[(run, epoch)])
    return selected

def get_metrics_per_epoch(runs, split="dev", index_by=None, **kwargs):
    """Collect every logged row of ``metrics_{split}.csv`` across all runs.

    Args:
        runs: Iterable of run objects exposing ``get_file_path(name)``.
        split: Name used to build the file name, ``metrics_{split}.csv``.
        index_by: Optional column name (or list of names) to use as the index of
            the result; the result is then sorted by that index.
        **kwargs: Extra columns to add. Each key is a column name and each value a
            callable that receives the row's run object and returns the cell value.

    Returns:
        A DataFrame with all rows of all runs, a ``run`` column and one column per
        entry in `kwargs`.
    """
    csv_name = f"metrics_{split}.csv"
    per_run = []
    for run in runs:
        frame = pd.read_csv(run.get_file_path(csv_name))
        # Tag every row with the run it came from.
        per_run.append(frame.assign(run=run))
    stacked = pd.concat(per_run).reset_index(drop=True)

    for column, extract in kwargs.items():
        stacked[column] = stacked["run"].apply(extract)

    if not index_by:
        return stacked
    return stacked.set_index(index_by, drop=True).sort_index()

In [34]:
# Runs of the "train_recommender" job that carry the "hparams" tag. The last filter
# keeps only runs that produced metrics_dev.csv, i.e. only completed runs.
hparam_runs = (
    RunCollection.from_path("../multirun")
    .filter_by_job("train_recommender")
    .filter(lambda run: "hparams" in run.config.tags)
    .filter(lambda run: "metrics_dev.csv" in run.list_files())
)


In [35]:
# Baseline experiments, one row per run, indexed by (model, history_mask).
baseline_runs = hparam_runs.filter_by_override("+experiment", "baseline")
baseline_metrics = get_metrics(
    baseline_runs,
    index_by=["model", "history_mask"],
    # The model name comes from the "+model" override of the run.
    model=lambda run: run.overrides["+model"],
    history_mask=lambda run: run.config.use_history_mask,
)

# Full table first, then only the highest-AUC row of each model.
display(
    baseline_metrics,
    baseline_metrics.loc[baseline_metrics.groupby("model")["AUC"].idxmax()],
)

Unnamed: 0_level_0,Unnamed: 1_level_0,AUC,MRR,NDCG@5,NDCG@10,epoch,run
model,history_mask,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
naml,False,0.684824,0.32987,0.365823,0.428846,4,Run(path=../multirun/2023-06-12/16-21-31/0)
naml,True,0.677233,0.320281,0.356732,0.420236,5,Run(path=../multirun/2023-06-12/16-21-31/1)
nrms,False,0.658464,0.309063,0.342386,0.407371,5,Run(path=../multirun/2023-06-12/14-44-55/0)
nrms,True,0.678373,0.327599,0.361544,0.424687,5,Run(path=../multirun/2023-06-12/14-44-55/1)
tanr,False,0.66975,0.318366,0.350656,0.414704,5,Run(path=../multirun/2023-06-12/14-44-55/2)
tanr,True,0.666312,0.315725,0.348641,0.411907,5,Run(path=../multirun/2023-06-12/14-44-55/3)


Unnamed: 0_level_0,Unnamed: 1_level_0,AUC,MRR,NDCG@5,NDCG@10,epoch,run
model,history_mask,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
naml,False,0.684824,0.32987,0.365823,0.428846,4,Run(path=../multirun/2023-06-12/16-21-31/0)
nrms,True,0.678373,0.327599,0.361544,0.424687,5,Run(path=../multirun/2023-06-12/14-44-55/1)
tanr,False,0.66975,0.318366,0.350656,0.414704,5,Run(path=../multirun/2023-06-12/14-44-55/2)


In [36]:
# Runs whose "+experiment" override contains "hierec", one row per run,
# indexed by (model, history_mask).
hierec_runs = hparam_runs.filter(lambda run: "hierec" in run.overrides["+experiment"])
hierec_metrics = get_metrics(
    hierec_runs,
    index_by=["model", "history_mask"],
    # The model name is the part of the experiment name before the first underscore.
    model=lambda run: run.overrides["+experiment"].split("_")[0],
    history_mask=lambda run: run.config.use_history_mask,
)

# Full table first, then only the highest-AUC row of each model.
display(
    hierec_metrics,
    hierec_metrics.loc[hierec_metrics.groupby("model")["AUC"].idxmax()],
)

Unnamed: 0_level_0,Unnamed: 1_level_0,AUC,MRR,NDCG@5,NDCG@10,epoch,run
model,history_mask,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
naml,False,0.662146,0.312854,0.345113,0.410379,2,Run(path=../multirun/2023-06-12/16-03-44/0)
naml,True,0.67425,0.319615,0.353117,0.417583,5,Run(path=../multirun/2023-06-12/16-03-44/1)
nrms,False,0.652659,0.311511,0.342616,0.407263,4,Run(path=../multirun/2023-06-12/14-57-07/0)
nrms,True,0.676315,0.328161,0.362376,0.425634,5,Run(path=../multirun/2023-06-12/14-57-07/1)
tanr,False,0.658235,0.314197,0.34695,0.411404,4,Run(path=../multirun/2023-06-12/14-57-07/2)
tanr,True,0.667404,0.318588,0.350952,0.415962,5,Run(path=../multirun/2023-06-12/14-57-07/3)


Unnamed: 0_level_0,Unnamed: 1_level_0,AUC,MRR,NDCG@5,NDCG@10,epoch,run
model,history_mask,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
naml,True,0.67425,0.319615,0.353117,0.417583,5,Run(path=../multirun/2023-06-12/16-03-44/1)
nrms,True,0.676315,0.328161,0.362376,0.425634,5,Run(path=../multirun/2023-06-12/14-57-07/1)
tanr,True,0.667404,0.318588,0.350952,0.415962,5,Run(path=../multirun/2023-06-12/14-57-07/3)


In [37]:
# Runs whose "+experiment" override contains "multi_interest", one row per run,
# indexed by (model, history_mask, n_interest_vectors).
multi_interest_runs = hparam_runs.filter(lambda run: "multi_interest" in run.overrides["+experiment"])
multi_interest_metrics = get_metrics(
    multi_interest_runs,
    index_by=["model", "history_mask", "n_interest_vectors"],
    # The model name is the part of the experiment name before the first underscore.
    model=lambda run: run.overrides["+experiment"].split("_")[0],
    history_mask=lambda run: run.config.use_history_mask,
    n_interest_vectors=lambda run: run.config.model.user_encoder.n_interest_vectors
)
display(multi_interest_metrics)

# Drop the rows with a single interest vector before picking the best row per model.
# NOTE(review): presumably because one interest vector is not a real multi-interest
# configuration — confirm. The level is moved to a column for the filter and then
# appended back, so the index keeps its three levels.
multi_interest_metrics = (
    multi_interest_metrics
    .reset_index(level="n_interest_vectors")
    .loc[lambda frame: frame["n_interest_vectors"] != 1]
    .set_index("n_interest_vectors", append=True)
)
display(multi_interest_metrics.loc[multi_interest_metrics.groupby("model")["AUC"].idxmax()])

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,AUC,MRR,NDCG@5,NDCG@10,epoch,run
model,history_mask,n_interest_vectors,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
naml,False,1,0.684764,0.329485,0.365053,0.428359,4,Run(path=../multirun/2023-06-12/14-56-49/0)
naml,False,4,0.627116,0.287156,0.31363,0.37665,1,Run(path=../multirun/2023-06-12/14-56-49/2)
naml,False,16,0.650986,0.300137,0.329913,0.394253,3,Run(path=../multirun/2023-06-12/14-56-49/4)
naml,False,32,0.668336,0.310847,0.340232,0.405489,2,Run(path=../multirun/2023-06-12/14-56-49/6)
naml,False,48,0.673778,0.320066,0.352016,0.415405,2,Run(path=../multirun/2023-06-12/14-56-49/8)
naml,False,64,0.674898,0.319567,0.35181,0.416314,4,Run(path=../multirun/2023-06-12/14-56-49/10)
naml,True,1,0.676562,0.322725,0.358234,0.420974,3,Run(path=../multirun/2023-06-12/14-56-49/1)
naml,True,4,0.630421,0.299526,0.325673,0.387773,5,Run(path=../multirun/2023-06-12/14-56-49/3)
naml,True,16,0.653178,0.309235,0.338908,0.401661,5,Run(path=../multirun/2023-06-12/14-56-49/5)
naml,True,32,0.655951,0.316177,0.345929,0.409312,5,Run(path=../multirun/2023-06-12/14-56-49/7)


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,AUC,MRR,NDCG@5,NDCG@10,epoch,run
model,history_mask,n_interest_vectors,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
naml,False,64,0.674898,0.319567,0.35181,0.416314,4,Run(path=../multirun/2023-06-12/14-56-49/10)
nrms,False,16,0.669293,0.321003,0.354082,0.41758,5,Run(path=../multirun/2023-06-12/14-46-08/4)
tanr,True,64,0.650847,0.309249,0.3378,0.401798,5,Run(path=../multirun/2023-06-12/14-47-09/11)
