In [1]:
import copy
import os

import numpy as np
import pandas as pd

from session_rec.algorithms.knn import vsknn
from session_rec.evaluation import evaluation_last, evaluation
from session_rec.evaluation.metrics.accuracy import HitRate
from utils.experiment_setup import gru4rec_vsknn_setups as setups

In [2]:
# Hit rate evaluated at cut-offs 1, 5, 10 and 20 (in that order).
metrics = [HitRate(cutoff) for cutoff in (1, 5, 10, 20)]
# After running vsknn_optimize.py, the best parameters can also be found in the corresponding best parameter files under: data/results/vsknn_paropt/*train_tr_best.json
#
# One experiment per training window. Every experiment is evaluated on the same
# test file; only the training file changes. "full" maps to
# ..._train_full.tsv, every other key maps to ..._train_full_<key>.tsv.
DATA_DIR = "data/data_sources/retailrocket_ecommerce"
TEST_FILE = f"{DATA_DIR}/retailrocket_processed_view_test.tsv"
TRAIN_WINDOWS = ["full", "91D", "56D", "28D", "14D", "7D"]


def _train_file_for(window):
    """Return the training-file path for ``window`` ("full" has no suffix)."""
    suffix = "" if window == "full" else f"_{window}"
    return f"{DATA_DIR}/retailrocket_processed_view_train_full{suffix}.tsv"


# Insertion order follows TRAIN_WINDOWS, so experiments run in that order.
experiments = {
    window: {"train_file": _train_file_for(window), "test_file": TEST_FILE}
    for window in TRAIN_WINDOWS
}

In [3]:
def _load_sorted(path):
    """Read a tab-separated event log and sort it by SessionId, Time, ItemId (ascending)."""
    events = pd.read_csv(path, sep='\t')
    return events.sort_values(by=["SessionId", "Time", "ItemId"], ascending=True)


def _restrict_test_to_train(train, test):
    """Keep only test events whose ItemId occurs in ``train``, then drop test
    sessions left with a single event.

    Returns a new frame (boolean indexing); ``test`` itself is not modified.
    """
    known_items = test[test["ItemId"].isin(train["ItemId"].unique())]
    session_length = known_items.groupby("SessionId").size()
    # NOTE(review): presumably single-event sessions are dropped because they
    # leave no next item to evaluate against — confirm with the evaluator.
    multi_event_sessions = session_length[session_length > 1].index
    return known_items[known_items["SessionId"].isin(multi_event_sessions)]


def _best_params_key(train_file):
    """Derive the ``setups`` key from a train-file path: the first and last
    ``_``-separated tokens of the file name without its extension, e.g.
    ``retailrocket_processed_view_train_full_91D.tsv`` -> ``retailrocket_91D``.
    """
    stem = os.path.splitext(os.path.basename(train_file))[0]
    tokens = stem.split('_')
    return f"{tokens[0]}_{tokens[-1]}"


# Per-experiment metric values: {experiment name: {metric name: value}}.
results = {}
for exp_name, exp_setup in experiments.items():
    print(exp_name)
    train_data = _load_sorted(exp_setup["train_file"])
    test_data = _load_sorted(exp_setup["test_file"])
    # Sizes before the test set is restricted to items seen in training.
    print("\t", train_data.shape, test_data.shape)

    test_data = _restrict_test_to_train(train_data, test_data)
    # Sizes after filtering, plus distinct train / test session counts.
    print("\t", train_data.shape, test_data.shape, train_data.SessionId.nunique(), test_data.SessionId.nunique())

    best_params_key = _best_params_key(exp_setup["train_file"])
    if best_params_key not in setups:
        raise KeyError(
            f"No tuned VSKNN parameters for {best_params_key!r} "
            f"(derived from {exp_setup['train_file']!r})"
        )
    best_params = setups[best_params_key]["vsknn_params"]
    model = vsknn.VMContextKNN(**best_params)
    model.fit(train_data)

    print("\t", "START eval, scoring ALL")
    # Pass fresh copies of the metric objects so no state held by them can carry
    # over between experiments. NOTE(review): only matters if evaluate_sessions
    # does not reset its metrics itself — harmless either way.
    res = evaluation.evaluate_sessions(pr=model, metrics=copy.deepcopy(metrics), test_data=test_data, train_data=train_data)
    for r in res:
        print("\t", r[0], f"{r[1]:.6f}")
    # Judging by the printed output, r[0] looks like "HitRate@1: "; strip the
    # trailing ": " so the summary columns are clean metric names.
    results[exp_name] = {str(r[0]).strip().rstrip(":"): r[1] for r in res}

# One row per experiment, one column per metric — easier to compare than the log.
results_summary = pd.DataFrame.from_dict(results, orient="index")
results_summary

full
	 (750832, 3) (29148, 3)
	 (750832, 3) (28965, 3) 196234 8013
	 START eval, scoreing ALL
	 HitRate@1:  0.040951
	 HitRate@5:  0.302596
	 HitRate@10:  0.398244
	 HitRate@20:  0.475372
91D
	 (502044, 3) (29148, 3)
	 (502044, 3) (28815, 3) 132431 7974
	 START eval, scoreing ALL
	 HitRate@1:  0.041313
	 HitRate@5:  0.297011
	 HitRate@10:  0.388033
	 HitRate@20:  0.460103
56D
	 (287593, 3) (29148, 3)
	 (287593, 3) (28133, 3) 78370 7799
	 START eval, scoreing ALL
	 HitRate@1:  0.042736
	 HitRate@5:  0.291974
	 HitRate@10:  0.376020
	 HitRate@20:  0.446838
28D
	 (128942, 3) (29148, 3)
	 (128942, 3) (26087, 3) 35669 7242
	 START eval, scoreing ALL
	 HitRate@1:  0.055771
	 HitRate@5:  0.283470
	 HitRate@10:  0.363651
	 HitRate@20:  0.423083
14D
	 (65132, 3) (29148, 3)
	 (65132, 3) (23029, 3) 18105 6471
	 START eval, scoreing ALL
	 HitRate@1:  0.064017
	 HitRate@5:  0.278536
	 HitRate@10:  0.350948
	 HitRate@20:  0.406812
7D
	 (33843, 3) (29148, 3)
	 (33843, 3) (19292, 3) 9421 5508
	 START 