In [1]:
import pandas as pd
import numpy as np
import sys

# Path from the notebook's working directory to the repository root.
REPO_RELATIVE_PATH = '..'

# Expose the repository root to the import system so that `src.*` modules
# resolve; the membership guard keeps re-running this cell from adding
# duplicate entries.
if sys.path.count(REPO_RELATIVE_PATH) == 0:
    sys.path.append(REPO_RELATIVE_PATH)

In [2]:
from src.metrics.metrics import precision_at_k, recall_at_k

In [5]:
def _load_val_split(split_name: str) -> pd.DataFrame:
    """Read the validation part of one split strategy.

    Parameters
    ----------
    split_name : str
        Directory name under ``splits/`` (for example ``'random'``).

    Returns
    -------
    pandas.DataFrame
        Contents of ``splits/<split_name>/val.csv``, parsed as tab-separated
        text whose first row holds the column names.
    """
    # NOTE: the path is relative to the current working directory, unlike the
    # item catalogue below, which is resolved through REPO_RELATIVE_PATH.
    return pd.read_csv(f'splits/{split_name}/val.csv', sep='\t', header=0)


# Item catalogue. `header=0` combined with `names=` makes pandas discard the
# file's own header row and use the column names given here instead.
items = pd.read_csv(
    f'{REPO_RELATIVE_PATH}/datasets/ml-1m/ml-1m.item',
    sep='\t',
    header=0,
    names=['item_id', 'movie_title', 'release_year', 'genre']
)

# Validation interactions, one frame per split strategy.
random_val = _load_val_split('random')
by_user_val = _load_val_split('by_user')
lol_val = _load_val_split('leave_one_last')
t_user_val = _load_val_split('temporal_user')
t_global_val = _load_val_split('temporal_global')

In [6]:
# Peek at the first few item ids of the catalogue.
items.loc[:, 'item_id'].head()

0    1
1    2
2    3
3    4
4    5
Name: item_id, dtype: int64

In [7]:
# Baseline "recommender": a seeded shuffle of the whole catalogue
# (frac=1 keeps every row), reduced to item ids with a fresh 0..n-1 index
# so that position equals rank.
random_recommendations = (
    items
    .sample(frac=1, random_state=42)
    .loc[:, 'item_id']
    .reset_index(drop=True)
)
random_recommendations.head()

0    1365
1    2706
2    3667
3    3684
4    1881
Name: item_id, dtype: int64

In [8]:
# Display the validation frame of the random split (pandas truncates the
# middle rows of a long frame in the rendered output).
random_val

Unnamed: 0,user_id,item_id,rating,timestamp
0,2857,2268,4,972511121
1,1172,2710,3,974867608
2,5376,3681,5,960409841
3,4362,246,5,965188352
4,2197,246,4,974606004
...,...,...,...,...
100016,4253,1348,2,965678247
100017,3367,2080,4,970081461
100018,261,3698,1,976671742
100019,2777,1665,2,973124102


In [9]:
# Preview the ordering used to build the per-user labels: users ascending,
# and within each user the highest ratings first.
(
    random_val
    .sort_values(by=['user_id', 'rating'], ascending=[True, False])
    .head(20)
)

Unnamed: 0,user_id,item_id,rating,timestamp
28888,1,1035,5,978301753
71232,1,3105,5,978301713
72371,1,1193,5,978300760
98463,1,1836,5,978300172
58747,1,2018,4,978301777
12706,2,1945,5,978298458
15223,2,2002,5,978300100
19144,2,1357,5,978298709
29517,2,1957,5,978298750
98670,2,3468,5,978298542


In [10]:
# One row per user holding two parallel lists: the user's validation items
# and the matching ratings, both ordered from highest to lowest rating
# (the ordering comes from the sort that precedes the group-by).
random_ranged_labels = (
    random_val
    .sort_values(by=['user_id', 'rating'], ascending=[True, False])
    .groupby('user_id')
    .agg(
        items=('item_id', list),
        ratings=('rating', list),
    )
    .reset_index()
)
random_ranged_labels.head()

Unnamed: 0,user_id,items,ratings
0,1,"[1035, 3105, 1193, 1836, 2018]","[5, 5, 5, 5, 4]"
1,2,"[1945, 2002, 1357, 1957, 3468, 3451, 3068, 459...","[5, 5, 5, 5, 5, 4, 4, 3, 3, 3, 3, 3, 2, 1]"
2,3,"[1259, 1196, 1049, 1394, 1270]","[5, 4, 4, 4, 3]"
3,4,[1210],[3]
4,5,"[2427, 3083, 2997, 1175, 2289, 348, 1392, 506,...","[5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 2, 2, ..."


In [11]:
# Spot-check a single user (id 2): flag each of their ratings as relevant
# when it is strictly greater than 3.
(
    random_ranged_labels
    .loc[random_ranged_labels['user_id'] == 2, 'ratings']
    .apply(lambda ratings: [rating > 3 for rating in ratings])
)

1    [True, True, True, True, True, True, True, Fal...
Name: ratings, dtype: object

In [12]:
# Same relevance flags (rating strictly greater than 3) for every user.
random_ranged_labels['ratings'].apply(
    lambda ratings: [rating > 3 for rating in ratings]
)

0                          [True, True, True, True, True]
1       [True, True, True, True, True, True, True, Fal...
2                         [True, True, True, True, False]
3                                                 [False]
4       [True, True, True, True, True, True, True, Tru...
                              ...                        
5956    [True, True, True, True, True, True, True, Tru...
5957    [True, True, True, True, True, True, True, Tru...
5958                                        [True, False]
5959        [True, True, True, True, False, False, False]
5960    [True, True, True, True, True, True, True, Tru...
Name: ratings, Length: 5961, dtype: object

In [13]:
# Display the per-user label frame: one row per user with parallel
# `items` / `ratings` lists.
random_ranged_labels

Unnamed: 0,user_id,items,ratings
0,1,"[1035, 3105, 1193, 1836, 2018]","[5, 5, 5, 5, 4]"
1,2,"[1945, 2002, 1357, 1957, 3468, 3451, 3068, 459...","[5, 5, 5, 5, 5, 4, 4, 3, 3, 3, 3, 3, 2, 1]"
2,3,"[1259, 1196, 1049, 1394, 1270]","[5, 4, 4, 4, 3]"
3,4,[1210],[3]
4,5,"[2427, 3083, 2997, 1175, 2289, 348, 1392, 506,...","[5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 2, 2, ..."
...,...,...,...
5956,6036,"[3471, 3150, 1303, 971, 3365, 3028, 1214, 232,...","[5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, ..."
5957,6037,"[1207, 1253, 2973, 1956, 50, 1193, 2324, 1210,...","[5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, ..."
5958,6038,"[3088, 1354]","[5, 3]"
5959,6039,"[364, 903, 3396, 914, 2174, 1269, 2622]","[5, 5, 4, 4, 3, 3, 3]"


In [20]:
# Self-contained pipeline from raw validation rows to each user's relevant
# items: sort (user ascending, rating descending), collect per-user lists,
# then keep only the items rated strictly above 3. Users without such items
# end up with an empty list.
(
    random_val
    .sort_values(by=['user_id', 'rating'], ascending=[True, False])
    .groupby('user_id')
    .agg(
        items=('item_id', list),
        ratings=('rating', list),
    )
    .reset_index()
    .apply(
        lambda row: [
            item
            for item, rating in zip(row['items'], row['ratings'])
            if rating > 3
        ],
        axis=1,
    )
)

0                          [1035, 3105, 1193, 1836, 2018]
1              [1945, 2002, 1357, 1957, 3468, 3451, 3068]
2                                [1259, 1196, 1049, 1394]
3                                                      []
4       [2427, 3083, 2997, 1175, 2289, 348, 1392, 506,...
                              ...                        
5956    [3471, 3150, 1303, 971, 3365, 3028, 1214, 232,...
5957    [1207, 1253, 2973, 1956, 50, 1193, 2324, 1210,...
5958                                               [3088]
5959                                [364, 903, 3396, 914]
5960    [1111, 1237, 1077, 750, 1248, 1294, 1295, 1249...
Length: 5961, dtype: object

In [14]:
# For every user row, keep the items whose paired rating is strictly
# greater than 3; rows with no such rating yield an empty list.
random_ranged_labels.apply(
    lambda row: [
        item
        for item, rating in zip(row['items'], row['ratings'])
        if rating > 3
    ],
    axis=1,
)

0                          [1035, 3105, 1193, 1836, 2018]
1              [1945, 2002, 1357, 1957, 3468, 3451, 3068]
2                                [1259, 1196, 1049, 1394]
3                                                      []
4       [2427, 3083, 2997, 1175, 2289, 348, 1392, 506,...
                              ...                        
5956    [3471, 3150, 1303, 971, 3365, 3028, 1214, 232,...
5957    [1207, 1253, 2973, 1956, 50, 1193, 2324, 1210,...
5958                                               [3088]
5959                                [364, 903, 3396, 914]
5960    [1111, 1237, 1077, 750, 1248, 1294, 1295, 1249...
Length: 5961, dtype: object

In [16]:
# Display the shuffled item ids that serve as the random recommendation list.
random_recommendations

0    1365
1    2706
2    3667
3    3684
4    1881
5    2224
6     730
7     355
8    1241
9     326
Name: item_id, dtype: int64