# LightGCN model RecSys

In [1]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import yaml
import glob
import torch
import torch.nn
from lightgcn import LightGCN
# from torch_geometric.nn import LightGCN
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from collections import defaultdict
from utils import *
pd.set_option('display.max_colwidth', None)

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

# Load Config

In [3]:
%cd ..
with open("params.yaml") as config_file:
    config = yaml.safe_load(config_file)
config

/Users/yingkang/4thBrain/GNN-eCommerce


{'base': {'random_seed': 42},
 'data': {'cosmetic_shop': 'data/raw/cosmetic-shop-ecommerce-events/',
  'preprocessed': 'data/preprocessed/'},
 'training': {'event_type_weights': {'view': 0.01,
   'cart': 0.1,
   'remove_from_cart': -0.09,
   'purchase': 1.0}},
 'reports': None}

# Load Interaction Matrix from csv

In [4]:
interaction_matrix = pd.read_csv(config['data']['preprocessed'] + "interaction_matrix.csv")
interaction_matrix = interaction_matrix.rename(columns={"product_id": "item_id"})

In [5]:
im = interaction_matrix[['user_id', 'item_id', 'weight']].copy()

In [6]:
print('Total data size: ', len(im), ', unique user: ', im.user_id.nunique(), ', unique items: ', im.item_id.nunique())

Total data size:  10157408 , unique user:  1639358 , unique items:  54571


In [7]:
# ?? Should user never purchase removed at the beginning ??? NO for now!!
# mini_im = purchase_users(im)
mini_im = im.sample(100000, random_state=1)  #100000

In [8]:
print('Mini dataset size: ', len(mini_im), ', Users at least purchased once: ', len(purchase_users(mini_im)))
# print('Valid data percentage: ', f'{len(mini_im)/len(im):.2%}')

Mini dataset size:  100000 , Users at least purchased once:  18729


### Prepare Train/ Val/ Test dataset

In [9]:
train_df, test_df = train_test_split(mini_im, test_size=0.3, random_state=16)
test_df, val_df = train_test_split(test_df, test_size=0.5, random_state=16)

In [10]:
print('Mini set unique user: ', mini_im.user_id.nunique(), ', unique items: ', mini_im.item_id.nunique())
print("Train Size  : ", len(train_df))
print("Val Size : ", len (val_df))
print("Test Size : ", len (test_df))

Mini set unique user:  74976 , unique items:  25645
Train Size  :  70000
Val Size :  15000
Test Size :  15000


In [11]:
n_users, n_items, train_df, train_pos_list_df, val_pos_list_df, test_pos_list_df = prepare_val_test(train_df, val_df, test_df)  # , val_u_i_matrix, test_u_i_matrix

In [12]:
print("Users : ", n_users)
print("Items : ", n_items)
print("Train Size  : ", len(train_df))
print("Val Size : ", len (val_pos_list_df))
print("Test Size : ", len (test_pos_list_df))

Users :  55837
Items :  21848
Train Size  :  70000
Val Size :  589
Test Size :  594


In [19]:
# train_df

In [20]:
# train_pos_list_df

In [21]:
# val_pos_list_df

In [75]:
# val_df, test_df = sync_nodes(train_df, val_df, test_df)
# n_users, n_items, train_df, val_df, test_df = relabelling(train_df, val_df, test_df)
# train_df['item_id_idx'] = train_df['item_id_idx'] + n_users
# train_pos_list_df = pos_item_list(train_df)
# val_pos_list_df = pos_item_list(val_df)
# test_pos_list_df = pos_item_list(test_df)

In [13]:
train_df

Unnamed: 0,weight,user_id_idx,item_id_idx
2121608,0.04,10627,70358
8381314,0.01,45146,75271
1771809,0.01,8895,72131
6683976,0.01,34963,57858
9386737,0.20,51066,59031
...,...,...,...
417990,0.11,2058,75462
2002131,0.01,10036,61660
8610057,0.11,46466,64178
8725565,0.01,47125,59772


In [14]:
train_pos_list_df

Unnamed: 0,user_id_idx,item_id_idx_list,ignor_neg_list
0,16,[62009],[62009]
1,25,[63977],[63977]
2,27,[56235],[56235]
3,44,[59581],"[68857, 59581]"
4,49,[59260],[59260]
...,...,...,...
7445,55727,[57755],[57755]
7446,55744,[64220],[64220]
7447,55774,"[77004, 76830]","[77004, 64605, 76830, 75181]"
7448,55786,[59467],[59467]


In [15]:
val_pos_list_df

Unnamed: 0,user_id_idx,item_id_idx_list
0,95,[9152]
1,248,[12703]
2,331,[14474]
3,437,[19202]
4,447,[10242]
...,...,...
584,54143,[3649]
585,54216,[3745]
586,54240,[8179]
587,54271,[3754]


In [16]:
test_pos_list_df

Unnamed: 0,user_id_idx,item_id_idx_list
0,44,[13020]
1,83,[17319]
2,104,[5519]
3,127,[3092]
4,255,[20453]
...,...,...
589,55090,[16015]
590,55214,[3477]
591,55325,"[12895, 20438]"
592,55634,[7893]


In [25]:
# v = val_pos_list_df
# v.loc[3] = [809, [2600, 99]]
# v.loc[4] = [999, []]
# v

In [51]:
# v.item_id_idx_list = v.item_id_idx_list.apply(lambda x: np.array(x)+100)

In [80]:
# def ignor_neg_item_list(train_pos_list_df, val_pos_list_df, test_pos_list_df, n_users):
#     v = pd.merge(val_pos_list_df, test_pos_list_df, how='outer', left_on='user_id_idx', right_on='user_id_idx')
#     train_pos_list_df = pd.merge(train_pos_list_df, v, how='left', left_on='user_id_idx', right_on='user_id_idx')
#     train_pos_list_df.item_id_idx_list = train_pos_list_df.item_id_idx_list.fillna('').apply(list)
#     train_pos_list_df.item_id_idx_list_x = train_pos_list_df.item_id_idx_list_x.fillna('').apply(list)
#     train_pos_list_df.item_id_idx_list_x = train_pos_list_df.item_id_idx_list_x.apply(lambda x: np.array(x) + n_users)
#     train_pos_list_df.item_id_idx_list_y = train_pos_list_df.item_id_idx_list_y.fillna('').apply(list)
#     train_pos_list_df.item_id_idx_list_y = train_pos_list_df.item_id_idx_list_y.apply(lambda x: np.array(x) + n_users)
#
#     train_pos_list_df['ignor_neg_list'] = [list((set(a).union(b).union(c))) for a, b, c in
#                                            zip(train_pos_list_df.item_id_idx_list, train_pos_list_df.item_id_idx_list_x, train_pos_list_df.item_id_idx_list_y)]
#     train_pos_list_df = train_pos_list_df[['user_id_idx', 'item_id_idx_list', 'ignor_neg_list']]
#
#     return train_pos_list_df

In [26]:
# x = ignor_neg_item_list(train_pos_list_df, val_pos_list_df, test_pos_list_df, n_users)
# x

In [13]:
# users, pos_items, neg_items = pos_neg_edge_index(train_pos_list_df, 100, n_users, n_items)

In [27]:
# print('Size of sample pool: ', len(users))

### Instantiate model and train/val the model

In [17]:
latent_dim = 64
n_layers = 3
LR = 0.005

model = LightGCN(num_nodes=n_users+n_items, embedding_dim=latent_dim, num_layers=n_layers)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
print("Size of Learnable Embedding : ", [x.shape for x in list(model.parameters())])

Size of Learnable Embedding :  [torch.Size([77685, 64])]


In [18]:
def train_and_evl(n_users, n_items, n_neg, train_df, train_pos_list_df, test_pos_list_df, model, optimizer, device, EPOCHS = 50, BATCH_SIZE = 1024, K = 20, DECAY = 0.0001):  # test_u_i_matrix,
    edge_index, edge_weight = df_to_graph(train_df, True)

    # test_u_i_matrix = test_u_i_matrix.to(device)
    edge_index = edge_index.to(device)
    edge_weight = edge_weight.to(device)
    model.to(device)

    bpr_loss_epoch_list = []
    reg_loss_epoch_list = []
    final_loss_epoch_list = []
    recall_epoch_list = []
    precision_epoch_list = []

    print('bpr_loss | reg_loss | final_loss | precision | recall')
    for epoch in tqdm(range(EPOCHS)):
        bpr_loss, reg_loss, final_loss = train_loop(train_pos_list_df, n_users, n_items, n_neg, edge_index, edge_weight, model, optimizer, BATCH_SIZE)

        precision, recall = evaluation(model, n_users, n_items, edge_index, edge_weight, test_pos_list_df, K)

        print(bpr_loss, reg_loss, final_loss, precision, recall)
        bpr_loss_epoch_list.append(bpr_loss)
        reg_loss_epoch_list.append(reg_loss)
        final_loss_epoch_list.append(final_loss)
        recall_epoch_list.append(recall)
        precision_epoch_list.append(precision)

    return (
        bpr_loss_epoch_list,
        reg_loss_epoch_list,
        final_loss_epoch_list,
        recall_epoch_list,
        precision_epoch_list)

In [19]:
bpr_loss, reg_loss, final_loss, recall, precision = \
    train_and_evl(n_users, n_items, 100, train_df, train_pos_list_df, val_pos_list_df, model, optimizer, device=device, EPOCHS = 50, BATCH_SIZE = 1024, K = 20, DECAY = 0.0001)  # val_u_i_matrix,

bpr_loss | reg_loss | final_loss | precision | recall


  0%|          | 0/50 [00:00<?, ?it/s]

0.0515 0.0005 0.052 0.0003 0.0059
0.0004 0.0007 0.0011 0.0003 0.0059
0.0002 0.0008 0.001 0.0003 0.0059


KeyboardInterrupt: 

In [18]:
edge_index, edge_weight = df_to_graph(train_df, True)
edge_index

tensor([[ 1753,  5772,  4461,  ...,  9099,  7207,  8270],
        [10986,  9195,  7952,  ...,  4856,  1356,   580]])

In [23]:
idx = list(range(len(users)))
random.shuffle(idx)
loader = DataLoader(idx, batch_size=100, shuffle=True)
len(idx)

81100

In [24]:
model.train()
for batch in loader:
    optimizer.zero_grad()

    batch_usr = users[batch]
    batch_pos_items = pos_items[batch]
    batch_neg_items = neg_items[batch]

    batch_pos_neg_labels = batch_pos_neg_edges(batch_usr, batch_pos_items, batch_neg_items)

In [25]:
batch_pos_neg_labels

tensor([[ 4647,   142,  1749,  1675,   523,   782,  3760,  5455,   159,  2272,
          4928,  1460,  6241,  3454,  1992,  3984,   898,  5821,   406,  6291,
          4630,  3049,  4261,  6359,  4463,  3290,  6498,  2889,  3236,   813,
          6695,  4228,  6431,  4883,  6329,  4536,  3299,  1699,  2499,  4954,
          2807,  6749,  4835,   838,  4648,  3746,   819,  1175,   304,  6219,
          3593,  6021,  1122,  2116,   994,  6398,  3194,  2419,  4156,  2019,
          5275,  3031,  4143,  3562,   745,   591,  4795,  6338,  1191,  2961,
          1855,   950,   213,    32,  3634,  2159,  4488,  3664,  1837,   443,
          5006,  2791,   330,  3838,   130,  2164,  2475,  2899,  5838,  4215,
          2159,  5032,    66,   398,  5475,  5314,  5193,  4008,   312,  1061,
          4647,   142,  1749,  1675,   523,   782,  3760,  5455,   159,  2272,
          4928,  1460,  6241,  3454,  1992,  3984,   898,  5821,   406,  6291,
          4630,  3049,  4261,  6359,  4463,  3290,  

In [19]:
# batch_pos_neg_labels[1].min()

tensor(55920)

In [26]:
out = model(edge_index, batch_pos_neg_labels, edge_weight)
out

tensor([ 8.7462e+00,  9.0312e+00,  8.7245e+00,  8.7949e+00,  8.7454e+00,
         8.7560e+00,  8.7704e+00,  8.6928e+00,  8.6050e+00,  8.8123e+00,
         8.7946e+00,  8.7746e+00,  9.3868e+00,  8.7652e+00,  8.8764e+00,
         8.8347e+00,  8.9700e+00,  8.7243e+00,  8.5606e+00,  8.7655e+00,
         8.4840e+00,  8.6740e+00,  8.7507e+00,  8.7256e+00,  8.6734e+00,
         9.2560e+00,  8.9011e+00,  8.7272e+00,  8.7289e+00,  8.4832e+00,
         8.8131e+00,  8.6571e+00,  8.6996e+00,  8.7895e+00,  9.0355e+00,
         8.6716e+00,  8.8301e+00,  8.8911e+00,  8.9282e+00,  8.8197e+00,
         8.6396e+00,  9.0734e+00,  8.7252e+00,  8.9373e+00,  8.6486e+00,
         8.9984e+00,  8.8802e+00,  8.7302e+00,  9.2221e+00,  8.7233e+00,
         8.7469e+00,  8.9942e+00,  8.7192e+00,  8.6693e+00,  8.8153e+00,
         8.6592e+00,  8.6334e+00,  8.7553e+00,  8.6475e+00,  8.6949e+00,
         8.6151e+00,  8.7265e+00,  8.7388e+00,  8.9374e+00,  8.9761e+00,
         8.6015e+00,  8.7837e+00,  8.8329e+00,  8.9

In [28]:
size = len(batch)

bpr_loss = model.recommendation_loss(out[:size], out[size:], 0) * size
reg_loss = regularization_loss(model.embedding.weight, size, batch_usr, batch_pos_items, batch_neg_items)
loss = bpr_loss + reg_loss

tensor(0.0010, grad_fn=<AddBackward0>)

In [23]:
# bpr_loss_batch_list = []
# reg_loss_batch_list = []
# final_loss_batch_list = []

In [24]:
# bpr_loss_batch_list.append(bpr_loss.item())
# reg_loss_batch_list.append(reg_loss.item())
# final_loss_batch_list.append(loss.item())

In [25]:
# bpr_loss = round(np.mean(bpr_loss_batch_list), 8)
# reg_loss = round(np.mean(reg_loss_batch_list), 8)
# final_loss = round(np.mean(final_loss_batch_list), 8)

In [30]:
print("bpr loss: ", loss, "reg loss: ", reg_loss, "final loss", loss)

bpr loss:  tensor(0.0010, grad_fn=<AddBackward0>) reg loss:  tensor(0.0009, grad_fn=<MulBackward0>) final loss tensor(0.0010, grad_fn=<AddBackward0>)


In [27]:
# bpr_loss_epoch_list = []
# reg_loss_epoch_list = []
# final_loss_epoch_list = []
# recall_epoch_list = []
# precision_epoch_list = []

In [31]:
model.eval()
with torch.no_grad():
    embeds = model.get_embedding(edge_index, edge_weight)   # ?? ???
    final_usr_embed, final_item_embed = torch.split(embeds, (n_users, n_items))

In [30]:
# matrix = interact_matrix(train_df, n_users, n_items)

In [34]:
test_topK_recall, test_topK_precision = get_metrics(final_usr_embed, final_item_embed, val_pos_list_df, 20)

In [35]:
print('precision: ', test_topK_precision, 'recall: ', test_topK_recall)

precision:  0.0 recall:  0.0


In [39]:
relevance_score = final_usr_embed @ final_item_embed.t()
relevance_score.shape

torch.Size([6767, 5015])

In [94]:
topk_relevance_indices = torch.topk(relevance_score, 20).indices
topk_relevance_indices[0:6]

tensor([[2077, 4210, 1475,  884, 4279, 4977, 1204, 1829, 1261, 2638, 3566,  769,
         2911, 2088, 2524, 1313, 4375, 1624, 1093, 4767],
        [4653, 2535,  259, 2326,  434, 2558, 4091,  858, 2389, 3123, 2491, 1451,
          365, 2345, 4836,  427,  364, 1295, 2467, 4362],
        [3824, 2881, 3703,  193, 4956, 2139,  141,  119, 3304, 1601, 2023, 4878,
         2449, 3161, 3079, 3209, 2413, 3356, 2238,  607],
        [3203,  595, 3981, 4555, 4016, 1511, 3913,  555,  426, 3555, 2531, 2815,
          400, 1468, 2810, 1314, 2585, 1901, 4093,  539],
        [ 520,  355, 2326, 1520, 2750, 1199, 2520, 2615, 2490, 2569, 4563, 4279,
         3696, 2056, 3063, 4091, 2764, 4758, 2065,  705],
        [4587, 2747, 2157, 1884, 2640, 1175, 2159, 1429, 4898, 1822, 3696, 1016,
         2326, 1411, 2293, 3529,  259, 4563, 3277, 2413]])

In [52]:
topk_relevance_indices.shape

torch.Size([6767, 20])

In [102]:
# t = topk_relevance_indices
# t = torch.flatten(t)
# t.shape

In [99]:
# torch.max(t)

In [100]:
# torch.min(t)

In [101]:
# aa = list([68,561,1949,2478,4236,5117,5320,6173])
# train_df.loc[train_df['user_id_idx'].isin(aa)]

In [96]:
topk_relevance_indices_df = pd.DataFrame(topk_relevance_indices.cpu().numpy())
topk_relevance_indices_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,2077,4210,1475,884,4279,4977,1204,1829,1261,2638,3566,769,2911,2088,2524,1313,4375,1624,1093,4767
1,4653,2535,259,2326,434,2558,4091,858,2389,3123,2491,1451,365,2345,4836,427,364,1295,2467,4362
2,3824,2881,3703,193,4956,2139,141,119,3304,1601,2023,4878,2449,3161,3079,3209,2413,3356,2238,607
3,3203,595,3981,4555,4016,1511,3913,555,426,3555,2531,2815,400,1468,2810,1314,2585,1901,4093,539
4,520,355,2326,1520,2750,1199,2520,2615,2490,2569,4563,4279,3696,2056,3063,4091,2764,4758,2065,705
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6762,1266,2924,1101,2838,1751,2915,2628,3923,1916,2558,400,1494,3755,4272,1489,3123,4050,1665,1117,2238
6763,1041,256,1748,2326,2643,1313,2023,3707,701,3753,4653,964,3182,2537,2676,4279,2491,187,1199,571
6764,2815,632,266,1028,2030,2121,3064,2290,4016,4084,1169,4847,3203,3844,1102,2799,2363,2859,4620,478
6765,2325,2324,4955,1320,3419,35,3923,2920,1487,4200,3487,4218,2174,2585,3079,3424,937,2640,3160,838


In [103]:
topk_relevance_indices_df['top_rlvnt_itm'] = topk_relevance_indices_df.values.tolist()
topk_relevance_indices_df['top_rlvnt_itm']

0        [2077, 4210, 1475, 884, 4279, 4977, 1204, 1829, 1261, 2638, 3566, 769, 2911, 2088, 2524, 1313, 4375, 1624, 1093, 4767]
1            [4653, 2535, 259, 2326, 434, 2558, 4091, 858, 2389, 3123, 2491, 1451, 365, 2345, 4836, 427, 364, 1295, 2467, 4362]
2          [3824, 2881, 3703, 193, 4956, 2139, 141, 119, 3304, 1601, 2023, 4878, 2449, 3161, 3079, 3209, 2413, 3356, 2238, 607]
3           [3203, 595, 3981, 4555, 4016, 1511, 3913, 555, 426, 3555, 2531, 2815, 400, 1468, 2810, 1314, 2585, 1901, 4093, 539]
4         [520, 355, 2326, 1520, 2750, 1199, 2520, 2615, 2490, 2569, 4563, 4279, 3696, 2056, 3063, 4091, 2764, 4758, 2065, 705]
                                                                 ...                                                           
6762    [1266, 2924, 1101, 2838, 1751, 2915, 2628, 3923, 1916, 2558, 400, 1494, 3755, 4272, 1489, 3123, 4050, 1665, 1117, 2238]
6763        [1041, 256, 1748, 2326, 2643, 1313, 2023, 3707, 701, 3753, 4653, 964, 3182, 2537, 2676, 4279

In [104]:
topk_relevance_indices_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,11,12,13,14,15,16,17,18,19,top_rlvnt_itm
0,2077,4210,1475,884,4279,4977,1204,1829,1261,2638,...,769,2911,2088,2524,1313,4375,1624,1093,4767,"[2077, 4210, 1475, 884, 4279, 4977, 1204, 1829, 1261, 2638, 3566, 769, 2911, 2088, 2524, 1313, 4375, 1624, 1093, 4767]"
1,4653,2535,259,2326,434,2558,4091,858,2389,3123,...,1451,365,2345,4836,427,364,1295,2467,4362,"[4653, 2535, 259, 2326, 434, 2558, 4091, 858, 2389, 3123, 2491, 1451, 365, 2345, 4836, 427, 364, 1295, 2467, 4362]"
2,3824,2881,3703,193,4956,2139,141,119,3304,1601,...,4878,2449,3161,3079,3209,2413,3356,2238,607,"[3824, 2881, 3703, 193, 4956, 2139, 141, 119, 3304, 1601, 2023, 4878, 2449, 3161, 3079, 3209, 2413, 3356, 2238, 607]"
3,3203,595,3981,4555,4016,1511,3913,555,426,3555,...,2815,400,1468,2810,1314,2585,1901,4093,539,"[3203, 595, 3981, 4555, 4016, 1511, 3913, 555, 426, 3555, 2531, 2815, 400, 1468, 2810, 1314, 2585, 1901, 4093, 539]"
4,520,355,2326,1520,2750,1199,2520,2615,2490,2569,...,4279,3696,2056,3063,4091,2764,4758,2065,705,"[520, 355, 2326, 1520, 2750, 1199, 2520, 2615, 2490, 2569, 4563, 4279, 3696, 2056, 3063, 4091, 2764, 4758, 2065, 705]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6762,1266,2924,1101,2838,1751,2915,2628,3923,1916,2558,...,1494,3755,4272,1489,3123,4050,1665,1117,2238,"[1266, 2924, 1101, 2838, 1751, 2915, 2628, 3923, 1916, 2558, 400, 1494, 3755, 4272, 1489, 3123, 4050, 1665, 1117, 2238]"
6763,1041,256,1748,2326,2643,1313,2023,3707,701,3753,...,964,3182,2537,2676,4279,2491,187,1199,571,"[1041, 256, 1748, 2326, 2643, 1313, 2023, 3707, 701, 3753, 4653, 964, 3182, 2537, 2676, 4279, 2491, 187, 1199, 571]"
6764,2815,632,266,1028,2030,2121,3064,2290,4016,4084,...,4847,3203,3844,1102,2799,2363,2859,4620,478,"[2815, 632, 266, 1028, 2030, 2121, 3064, 2290, 4016, 4084, 1169, 4847, 3203, 3844, 1102, 2799, 2363, 2859, 4620, 478]"
6765,2325,2324,4955,1320,3419,35,3923,2920,1487,4200,...,4218,2174,2585,3079,3424,937,2640,3160,838,"[2325, 2324, 4955, 1320, 3419, 35, 3923, 2920, 1487, 4200, 3487, 4218, 2174, 2585, 3079, 3424, 937, 2640, 3160, 838]"


In [105]:
topk_relevance_indices_df['user_ID'] = topk_relevance_indices_df.index
topk_relevance_indices_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,12,13,14,15,16,17,18,19,top_rlvnt_itm,user_ID
0,2077,4210,1475,884,4279,4977,1204,1829,1261,2638,...,2911,2088,2524,1313,4375,1624,1093,4767,"[2077, 4210, 1475, 884, 4279, 4977, 1204, 1829, 1261, 2638, 3566, 769, 2911, 2088, 2524, 1313, 4375, 1624, 1093, 4767]",0
1,4653,2535,259,2326,434,2558,4091,858,2389,3123,...,365,2345,4836,427,364,1295,2467,4362,"[4653, 2535, 259, 2326, 434, 2558, 4091, 858, 2389, 3123, 2491, 1451, 365, 2345, 4836, 427, 364, 1295, 2467, 4362]",1
2,3824,2881,3703,193,4956,2139,141,119,3304,1601,...,2449,3161,3079,3209,2413,3356,2238,607,"[3824, 2881, 3703, 193, 4956, 2139, 141, 119, 3304, 1601, 2023, 4878, 2449, 3161, 3079, 3209, 2413, 3356, 2238, 607]",2
3,3203,595,3981,4555,4016,1511,3913,555,426,3555,...,400,1468,2810,1314,2585,1901,4093,539,"[3203, 595, 3981, 4555, 4016, 1511, 3913, 555, 426, 3555, 2531, 2815, 400, 1468, 2810, 1314, 2585, 1901, 4093, 539]",3
4,520,355,2326,1520,2750,1199,2520,2615,2490,2569,...,3696,2056,3063,4091,2764,4758,2065,705,"[520, 355, 2326, 1520, 2750, 1199, 2520, 2615, 2490, 2569, 4563, 4279, 3696, 2056, 3063, 4091, 2764, 4758, 2065, 705]",4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6762,1266,2924,1101,2838,1751,2915,2628,3923,1916,2558,...,3755,4272,1489,3123,4050,1665,1117,2238,"[1266, 2924, 1101, 2838, 1751, 2915, 2628, 3923, 1916, 2558, 400, 1494, 3755, 4272, 1489, 3123, 4050, 1665, 1117, 2238]",6762
6763,1041,256,1748,2326,2643,1313,2023,3707,701,3753,...,3182,2537,2676,4279,2491,187,1199,571,"[1041, 256, 1748, 2326, 2643, 1313, 2023, 3707, 701, 3753, 4653, 964, 3182, 2537, 2676, 4279, 2491, 187, 1199, 571]",6763
6764,2815,632,266,1028,2030,2121,3064,2290,4016,4084,...,3203,3844,1102,2799,2363,2859,4620,478,"[2815, 632, 266, 1028, 2030, 2121, 3064, 2290, 4016, 4084, 1169, 4847, 3203, 3844, 1102, 2799, 2363, 2859, 4620, 478]",6764
6765,2325,2324,4955,1320,3419,35,3923,2920,1487,4200,...,2174,2585,3079,3424,937,2640,3160,838,"[2325, 2324, 4955, 1320, 3419, 35, 3923, 2920, 1487, 4200, 3487, 4218, 2174, 2585, 3079, 3424, 937, 2640, 3160, 838]",6765


In [106]:
topk_relevance_indices_df = topk_relevance_indices_df[['user_ID', 'top_rlvnt_itm']]
topk_relevance_indices_df

Unnamed: 0,user_ID,top_rlvnt_itm
0,0,"[2077, 4210, 1475, 884, 4279, 4977, 1204, 1829, 1261, 2638, 3566, 769, 2911, 2088, 2524, 1313, 4375, 1624, 1093, 4767]"
1,1,"[4653, 2535, 259, 2326, 434, 2558, 4091, 858, 2389, 3123, 2491, 1451, 365, 2345, 4836, 427, 364, 1295, 2467, 4362]"
2,2,"[3824, 2881, 3703, 193, 4956, 2139, 141, 119, 3304, 1601, 2023, 4878, 2449, 3161, 3079, 3209, 2413, 3356, 2238, 607]"
3,3,"[3203, 595, 3981, 4555, 4016, 1511, 3913, 555, 426, 3555, 2531, 2815, 400, 1468, 2810, 1314, 2585, 1901, 4093, 539]"
4,4,"[520, 355, 2326, 1520, 2750, 1199, 2520, 2615, 2490, 2569, 4563, 4279, 3696, 2056, 3063, 4091, 2764, 4758, 2065, 705]"
...,...,...
6762,6762,"[1266, 2924, 1101, 2838, 1751, 2915, 2628, 3923, 1916, 2558, 400, 1494, 3755, 4272, 1489, 3123, 4050, 1665, 1117, 2238]"
6763,6763,"[1041, 256, 1748, 2326, 2643, 1313, 2023, 3707, 701, 3753, 4653, 964, 3182, 2537, 2676, 4279, 2491, 187, 1199, 571]"
6764,6764,"[2815, 632, 266, 1028, 2030, 2121, 3064, 2290, 4016, 4084, 1169, 4847, 3203, 3844, 1102, 2799, 2363, 2859, 4620, 478]"
6765,6765,"[2325, 2324, 4955, 1320, 3419, 35, 3923, 2920, 1487, 4200, 3487, 4218, 2174, 2585, 3079, 3424, 937, 2640, 3160, 838]"


In [42]:
# test_interacted_items = val_df.groupby('user_id_idx')['item_id_idx'].apply(list).reset_index()
# test_interacted_items

In [43]:
test_pos_list_df

Unnamed: 0,user_id_idx,item_id_idx_list
0,68,[2557]
1,561,[97]
2,1949,[2314]
3,2478,[2887]
4,4236,[547]
5,5117,[2181]
6,5320,[868]
7,6173,[2660]


In [44]:
metrics_df = pd.merge(test_pos_list_df, topk_relevance_indices_df, how='left', left_on='user_id_idx', right_on='user_ID')
metrics_df

Unnamed: 0,user_id_idx,item_id_idx_list,user_ID,top_rlvnt_itm
0,68,[2557],68,"[4409, 3868, 241, 2960, 2675, 2251, 710, 313, 4160, 2531, 400, 882, 2924, 150, 365, 1797, 4856, 1355, 3030, 610]"
1,561,[97],561,"[4300, 2531, 2014, 705, 958, 2915, 4821, 1114, 1910, 4743, 2570, 1080, 1380, 685, 3795, 1028, 1794, 1259, 3914, 1445]"
2,1949,[2314],1949,"[1873, 2326, 2078, 2139, 2947, 1199, 4210, 4575, 4401, 1475, 122, 2628, 2520, 4813, 1204, 4396, 2286, 2065, 2911, 3707]"
3,2478,[2887],2478,"[4784, 69, 1180, 1687, 4940, 1027, 3763, 3161, 3612, 1093, 3745, 1596, 1820, 2884, 4507, 2325, 2139, 3451, 3554, 465]"
4,4236,[547],4236,"[2436, 377, 1451, 3981, 1800, 2965, 323, 2295, 1057, 782, 508, 2945, 1314, 1969, 329, 2726, 3030, 89, 3327, 4396]"
5,5117,[2181],5117,"[3616, 2520, 303, 2799, 2268, 1127, 676, 2045, 481, 1289, 3798, 1607, 373, 353, 4426, 555, 1371, 3566, 1489, 168]"
6,5320,[868],5320,"[3203, 595, 3981, 4555, 4016, 1511, 3913, 555, 426, 2531, 2815, 3555, 400, 1314, 2585, 1468, 2810, 1901, 3125, 3062]"
7,6173,[2660],6173,"[2139, 3811, 4668, 4801, 1178, 2184, 3555, 3637, 1289, 508, 3667, 3164, 3327, 353, 2077, 1259, 3078, 2643, 3030, 2828]"


In [46]:
metrics_df['intrsctn_itm'] = [list(set(a).intersection(b)) for a, b in
                                  zip(metrics_df.item_id_idx_list, metrics_df.top_rlvnt_itm)]
metrics_df

Unnamed: 0,user_id_idx,item_id_idx_list,user_ID,top_rlvnt_itm,intrsctn_itm
0,68,[2557],68,"[4409, 3868, 241, 2960, 2675, 2251, 710, 313, 4160, 2531, 400, 882, 2924, 150, 365, 1797, 4856, 1355, 3030, 610]",[]
1,561,[97],561,"[4300, 2531, 2014, 705, 958, 2915, 4821, 1114, 1910, 4743, 2570, 1080, 1380, 685, 3795, 1028, 1794, 1259, 3914, 1445]",[]
2,1949,[2314],1949,"[1873, 2326, 2078, 2139, 2947, 1199, 4210, 4575, 4401, 1475, 122, 2628, 2520, 4813, 1204, 4396, 2286, 2065, 2911, 3707]",[]
3,2478,[2887],2478,"[4784, 69, 1180, 1687, 4940, 1027, 3763, 3161, 3612, 1093, 3745, 1596, 1820, 2884, 4507, 2325, 2139, 3451, 3554, 465]",[]
4,4236,[547],4236,"[2436, 377, 1451, 3981, 1800, 2965, 323, 2295, 1057, 782, 508, 2945, 1314, 1969, 329, 2726, 3030, 89, 3327, 4396]",[]
5,5117,[2181],5117,"[3616, 2520, 303, 2799, 2268, 1127, 676, 2045, 481, 1289, 3798, 1607, 373, 353, 4426, 555, 1371, 3566, 1489, 168]",[]
6,5320,[868],5320,"[3203, 595, 3981, 4555, 4016, 1511, 3913, 555, 426, 2531, 2815, 3555, 400, 1314, 2585, 1468, 2810, 1901, 3125, 3062]",[]
7,6173,[2660],6173,"[2139, 3811, 4668, 4801, 1178, 2184, 3555, 3637, 1289, 508, 3667, 3164, 3327, 353, 2077, 1259, 3078, 2643, 3030, 2828]",[]
