In [1]:
import pandas as pd
import numpy as np
from session_rec.algorithms.knn import vsknn
from session_rec.evaluation import evaluation_last, evaluation
from session_rec.evaluation.metrics.accuracy import HitRate

In [2]:
# One HitRate metric per cut-off, in ascending order of the cut-off value.
metrics = [HitRate(cutoff) for cutoff in (1, 5, 10, 20)]
# Tuned VSKNN settings for every training window. The values are copied from
# the best-parameter files written by vsknn_optimize.py, found under:
# data/results/vskk_paropt/*train_tr_best.json
_RETAILROCKET_PREFIX = "data/data_sources/retailrocket_ecommerce/retailrocket_processed_view"

# Column order: window, k, sample_size, weighting, weighting_score, idf_weighting
_TUNED_VSKNN_ROWS = [
    ("full", 500, 1000, "log", "same", 2),
    ("91D", 1000, 1000, "log", "same", 10),
    ("56D", 500, 500, "log", "same", 5),
    ("28D", 1000, 500, "same", "log", 5),
    ("14D", 500, 500, "same", "log", 1),
    ("7D", 1500, 2500, "log", "same", 10),
]

# Every window is evaluated against the same test file; only the training
# file changes ("full" has no window suffix in its file name).
experiments_vsknn = {
    window: {
        "train_file": _RETAILROCKET_PREFIX
        + "_train_full"
        + ("" if window == "full" else "_" + window)
        + ".tsv",
        "test_file": _RETAILROCKET_PREFIX + "_test.tsv",
        "best_params": {
            "k": k,
            "sample_size": sample_size,
            "weighting": weighting,
            "weighting_score": weighting_score,
            "idf_weighting": idf_weighting,
        },
    }
    for window, k, sample_size, weighting, weighting_score, idf_weighting in _TUNED_VSKNN_ROWS
}

In [3]:
# Fit one VSKNN model per entry of experiments_vsknn and print its metrics.
for exp_name, setup in experiments_vsknn.items():
    print(exp_name)

    # Read both splits and order their rows by SessionId, Time and ItemId.
    sort_columns = ["SessionId", "Time", "ItemId"]
    train_data = pd.read_csv(setup["train_file"], sep='\t').sort_values(by=sort_columns, ascending=True)
    test_data = pd.read_csv(setup["test_file"], sep='\t').sort_values(by=sort_columns, ascending=True)
    print("\t", train_data.shape, test_data.shape)

    # Keep only the test rows whose item also occurs in this training file.
    train_unique_items = train_data["ItemId"].unique()
    seen_in_training = test_data["ItemId"].isin(train_unique_items)
    test_data = test_data.loc[seen_in_training]

    # After that filter, discard test sessions that are left with a single row.
    session_length_test = test_data.groupby("SessionId").size()
    sessions_to_keep = session_length_test.loc[session_length_test > 1].index
    test_data = test_data.loc[test_data["SessionId"].isin(sessions_to_keep)]

    print(
        "\t",
        train_data.shape,
        test_data.shape,
        train_data["SessionId"].nunique(),
        test_data["SessionId"].nunique(),
    )

    # Build the model from this experiment's tuned parameters and train it.
    model = vsknn.VMContextKNN(**setup["best_params"])
    model.fit(train_data)

    print("\t", "START eval, scoreing ALL")
    res = evaluation.evaluate_sessions(
        pr=model,
        metrics=metrics,
        test_data=test_data,
        train_data=train_data,
    )

    # Only the first two fields of each result entry are printed.
    for r in res:
        metric_label, metric_value = r[0], r[1]
        print("\t", metric_label, metric_value)

full
	 (750832, 3) (29148, 3)
	 (750832, 3) (28965, 3) 196234 8013
	 START eval, scoreing ALL
	 HitRate@1:  0.04095074455899198
	 HitRate@5:  0.30259641084383354
	 HitRate@10:  0.39824360442917145
	 HitRate@20:  0.4753722794959908
91D
	 (502044, 3) (29148, 3)
	 (502044, 3) (28815, 3) 132431 7974
	 START eval, scoreing ALL
	 HitRate@1:  0.041312796890744206
	 HitRate@5:  0.29701070006237706
	 HitRate@10:  0.3880332037810086
	 HitRate@20:  0.46010268221294565
56D
	 (287593, 3) (29148, 3)
	 (287593, 3) (28133, 3) 78370 7799
	 START eval, scoreing ALL
	 HitRate@1:  0.04273630372774663
	 HitRate@5:  0.29197403363824137
	 HitRate@10:  0.37602045834562803
	 HitRate@20:  0.44683780859643946
28D
	 (128942, 3) (29148, 3)
	 (128942, 3) (26087, 3) 35669 7242
	 START eval, scoreing ALL
	 HitRate@1:  0.05577076147519236
	 HitRate@5:  0.28347041655611566
	 HitRate@10:  0.36365083576545504
	 HitRate@20:  0.42308304590076945
14D
	 (65132, 3) (29148, 3)
	 (65132, 3) (23029, 3) 18105 6471
	 START eval, s