# LightGCN model RecSys

In [1]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import yaml
import glob
import torch
import torch.nn
from lightgcn import LightGCN
# from torch_geometric.nn import LightGCN
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from collections import defaultdict
from utils import *
pd.set_option('display.max_colwidth', None)

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

# Load Config

In [3]:
%cd ..
with open("params.yaml") as config_file:
    config = yaml.safe_load(config_file)
config

/Users/yingkang/4thBrain/GNN-eCommerce


{'base': {'random_seed': 42},
 'data': {'cosmetic_shop': 'data/raw/cosmetic-shop-ecommerce-events/',
  'preprocessed': 'data/preprocessed/'},
 'training': {'event_type_weights': {'view': 0.01,
   'cart': 0.1,
   'remove_from_cart': -0.09,
   'purchase': 1.0}},
 'reports': None}

# Load Interaction Matrix from csv

In [4]:
interaction_matrix = pd.read_csv(config['data']['preprocessed'] + "interaction_matrix.csv")
interaction_matrix = interaction_matrix.rename(columns={"product_id": "item_id"})

In [5]:
im = interaction_matrix[['user_id', 'item_id', 'weight']].copy()

In [6]:
print('Total data size: ', len(im), ', unique user: ', im.user_id.nunique(), ', unique items: ', im.item_id.nunique())

Total data size:  10157408 , unique user:  1639358 , unique items:  54571


In [6]:
# ying = im.head(1000)
# ying

In [41]:
# train_df, test_df = train_test_split(ying, test_size=0.3, random_state=16)
# test_df, val_df = train_test_split(test_df, test_size=0.5, random_state=16)

In [7]:
# ?? Should user never purchase removed at the beginning ??? NO for now!!
# mini_im = purchase_users(im)
mini_im = im.sample(100000, random_state=1)  #100000

In [8]:
print('Mini dataset size: ', len(mini_im), ', Users at least purchased once: ', len(purchase_users(mini_im)))
# print('Valid data percentage: ', f'{len(mini_im)/len(im):.2%}')

Mini dataset size:  100000 , Users at least purchased once:  18729


In [9]:
# mini_im.loc[~mini_im['user_id'].isin(u_id_filter)]

### Prepare Train/ Val/ Test dataset

In [10]:
train_df, test_df = train_test_split(mini_im, test_size=0.3, random_state=16)
test_df, val_df = train_test_split(test_df, test_size=0.5, random_state=16)

In [11]:
print('Mini set unique user: ', mini_im.user_id.nunique(), ', unique items: ', mini_im.item_id.nunique())
print("Train Size  : ", len(train_df))
print("Val Size : ", len (val_df))
print("Test Size : ", len (test_df))

Mini set unique user:  74976 , unique items:  25645
Train Size  :  70000
Val Size :  15000
Test Size :  15000


In [12]:
n_users, n_items, train_df, val_pos_list_df, test_pos_list_df, val_u_i_matrix, test_u_i_matrix = \
    prepare_val_test(train_df, val_df, test_df, device)

In [13]:
print("Users : ", n_users)
print("Items : ", n_items)
print("Train Size  : ", len(train_df))
print("Val Size : ", len (val_pos_list_df))
print("Test Size : ", len (test_pos_list_df))

Users :  55837
Items :  21848
Train Size  :  70000
Val Size :  589
Test Size :  594


In [23]:
# train_df = train_df.loc[train_df['weight'] == 1].drop_duplicates('user_id_idx')

In [24]:
# users = torch.LongTensor(list(np.repeat(train_df['user_id_idx'], 100)))

In [40]:
# p = train_df['item_id_idx_list'].apply(lambda x: sample_pos(x, 100)).tolist()

In [14]:
# torch.LongTensor(p)

In [None]:
users, pos_items, neg_items = pos_neg_edge_index(train_df, 100, n_users, n_items)

In [16]:
print('Size of sample pool: ', len(users))

Size of sample pool:  745000


### Instantiate model and train/val the model

In [17]:
latent_dim = 64
n_layers = 3
LR = 0.005

model = LightGCN(num_nodes=n_users+n_items, embedding_dim=latent_dim, num_layers=n_layers)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
print("Size of Learnable Embedding : ", [x.shape for x in list(model.parameters())])

Size of Learnable Embedding :  [torch.Size([77685, 64])]


In [18]:
def train_and_evl(n_users, n_items, n_neg, train_df, test_u_i_matrix, test_pos_list_df, model, optimizer, device, EPOCHS = 50, BATCH_SIZE = 1024, K = 20, DECAY = 0.0001):
    edge_index, edge_weight = df_to_graph(train_df, True)

    test_u_i_matrix = test_u_i_matrix.to(device)
    edge_index = edge_index.to(device)
    edge_weight = edge_weight.to(device)
    model.to(device)

    bpr_loss_epoch_list = []
    reg_loss_epoch_list = []
    final_loss_epoch_list = []
    recall_epoch_list = []
    precision_epoch_list = []

    print('bpr_loss | reg_loss | final_loss | precision | recall')
    for epoch in tqdm(range(EPOCHS)):
        bpr_loss, reg_loss, final_loss = train_loop(train_df, n_users, n_items, n_neg, edge_index, edge_weight, model, optimizer, BATCH_SIZE)

        precision, recall = evaluation(model, n_users, n_items, edge_index, edge_weight, test_u_i_matrix, test_pos_list_df, K)

        print(bpr_loss, reg_loss, final_loss, precision, recall)
        bpr_loss_epoch_list.append(bpr_loss)
        reg_loss_epoch_list.append(reg_loss)
        final_loss_epoch_list.append(final_loss)
        recall_epoch_list.append(recall)
        precision_epoch_list.append(precision)

    return (
        bpr_loss_epoch_list,
        reg_loss_epoch_list,
        final_loss_epoch_list,
        recall_epoch_list,
        precision_epoch_list)

In [19]:
bpr_loss, reg_loss, final_loss, recall, precision = \
    train_and_evl(n_users, n_items, 100, train_df, val_u_i_matrix, val_pos_list_df, model, optimizer, device=device, EPOCHS = 50, BATCH_SIZE = 1024, K = 20, DECAY = 0.0001)

bpr_loss | reg_loss | final_loss | precision | recall


  0%|          | 0/50 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [15]:
# users, pos_items, neg_items = pos_neg_edge_index(train_df, n_users, n_items)
# len(pos_items)

8264

817 pos sample out of 7000 train data: Pos_Sample_Rate = 11.67%
848091 pos sample out of 7110185 train data: Pos_Sample Rate = 11.93%

In [16]:
# idx = list(range(len(users)))
# random.shuffle(idx)
# loader = DataLoader(idx, batch_size=100, shuffle=True)

In [17]:
# model.train()
# for batch in loader:
#     optimizer.zero_grad()
#
#     batch_usr = users[batch]
#     batch_pos_items = pos_items[batch]
#     batch_neg_items = neg_items[batch]
#
#     batch_pos_neg_labels = batch_pos_neg_edges(batch_usr, batch_pos_items, batch_neg_items)

In [18]:
# batch_pos_neg_labels

tensor([[32115, 28194, 21421, 18533,  6468, 23668, 42851, 10046, 40984, 49278,
         21264,  5819, 24200, 33788, 10791, 26484, 50754, 25508, 21396, 15941,
         22990, 41688, 17044, 26760, 42211,  2168, 12126,  4852,  9065, 38664,
         22303, 27577, 38019, 26925, 40179,  5836,  7054, 49337, 15728, 20238,
         18254, 12742, 23836, 42906, 47775, 32403,  9764, 18732,  6613,  5310,
         27452, 17747, 40770, 11488, 35654, 27395,  6554, 22329, 47689, 29664,
          5811, 34993, 33453, 42565, 32115, 28194, 21421, 18533,  6468, 23668,
         42851, 10046, 40984, 49278, 21264,  5819, 24200, 33788, 10791, 26484,
         50754, 25508, 21396, 15941, 22990, 41688, 17044, 26760, 42211,  2168,
         12126,  4852,  9065, 38664, 22303, 27577, 38019, 26925, 40179,  5836,
          7054, 49337, 15728, 20238, 18254, 12742, 23836, 42906, 47775, 32403,
          9764, 18732,  6613,  5310, 27452, 17747, 40770, 11488, 35654, 27395,
          6554, 22329, 47689, 29664,  5811, 34993, 3

In [19]:
# batch_pos_neg_labels[1].min()

tensor(55920)

In [21]:
# out = model(edge_index, batch_pos_neg_labels, edge_weight)
# out

tensor([ 7.8876e-04,  6.6548e-04,  2.9060e-04,  5.8692e-04,  1.0170e-03,
         2.4106e-04,  7.3517e-04,  9.2481e-04,  8.5302e-04,  3.4077e-04,
         8.7647e-04,  5.4225e-04,  7.7014e-04,  2.6808e-04,  2.2831e-04,
         4.8092e-04,  7.1286e-04,  5.4340e-04,  4.8758e-04,  3.8278e-04,
         5.8570e-04,  1.1317e-04,  5.9058e-04,  8.1913e-04,  3.6317e-04,
         9.2634e-04,  2.7700e-04,  2.9253e-04,  4.8097e-04,  3.0643e-04,
         4.2362e-04,  4.0224e-04,  6.0253e-04,  7.5705e-04,  9.7720e-04,
         6.4622e-04,  2.8558e-04,  4.4712e-04,  5.3895e-04,  3.5354e-04,
         4.0005e-04,  5.8406e-04,  4.5432e-04,  4.0665e-04,  5.5185e-04,
         3.0077e-04,  2.7294e-04,  7.7393e-04,  5.5425e-04,  6.9607e-04,
         3.6969e-04,  4.5125e-04,  8.1844e-04,  4.5082e-04,  2.1936e-04,
         8.2935e-04,  8.0989e-04,  8.6049e-04,  4.5159e-04,  5.2694e-04,
         8.0631e-04,  8.8362e-04,  3.1063e-04,  3.9110e-04, -8.7095e-05,
        -1.8979e-05,  5.5088e-05, -2.3548e-05, -1.9

In [22]:
# size = len(batch)
#
# bpr_loss = model.recommendation_loss(out[:size], out[size:], 0) * size
# reg_loss = regularization_loss(model.embedding.weight, size, batch_usr, batch_pos_items, batch_neg_items)
# loss = bpr_loss + reg_loss

In [23]:
# bpr_loss_batch_list = []
# reg_loss_batch_list = []
# final_loss_batch_list = []

In [24]:
# bpr_loss_batch_list.append(bpr_loss.item())
# reg_loss_batch_list.append(reg_loss.item())
# final_loss_batch_list.append(loss.item())

In [25]:
# bpr_loss = round(np.mean(bpr_loss_batch_list), 8)
# reg_loss = round(np.mean(reg_loss_batch_list), 8)
# final_loss = round(np.mean(final_loss_batch_list), 8)

In [26]:
# print("bpr loss: ", loss, "reg loss: ", reg_loss, "final loss", final_loss)

bpr loss:  tensor(0.6929, grad_fn=<AddBackward0>) reg loss:  2.5e-07 final loss 0.69287777


In [27]:
# bpr_loss_epoch_list = []
# reg_loss_epoch_list = []
# final_loss_epoch_list = []
# recall_epoch_list = []
# precision_epoch_list = []

In [29]:
# model.eval()
# with torch.no_grad():
#     embeds = model.get_embedding(edge_index, edge_weight)   # ?? ???
#     final_usr_embed, final_item_embed = torch.split(embeds, (n_users, n_items))

In [30]:
matrix = interact_matrix(train_df, n_users, n_items)

In [31]:
test_topK_recall, test_topK_precision = get_metrics(final_usr_embed, final_item_embed, matrix, val_df, 20)

In [32]:
print('precision: ', test_topK_precision, 'recall: ', test_topK_recall)

precision:  2.7449903925336262e-05 recall:  0.0002744990392533626


In [33]:
matrix

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64)

In [34]:
relevance_score = final_usr_embed @ final_item_embed.t()
relevance_score

tensor([[ 6.5503e-05, -1.2144e-05, -3.5105e-05,  ..., -2.7406e-05,
          4.9693e-06, -6.8124e-05],
        [-7.2272e-05, -2.6420e-05,  1.3902e-05,  ..., -2.7292e-05,
         -3.2123e-05, -4.6970e-05],
        [-6.1275e-06,  5.6032e-05, -2.4367e-05,  ..., -1.6769e-04,
         -1.6696e-05, -4.8051e-05],
        ...,
        [-1.9772e-05,  8.6049e-05, -1.0757e-04,  ..., -2.3527e-05,
         -1.4579e-05, -2.4870e-05],
        [ 3.7011e-05, -8.9452e-06,  1.0440e-06,  ...,  2.9516e-05,
         -4.1773e-05,  2.5515e-05],
        [-1.4688e-04, -1.2512e-05, -8.6683e-05,  ...,  5.1958e-06,
          3.2031e-05, -1.0007e-05]])

In [35]:
relevance_score = matrix = torch.mul(relevance_score, (1 - matrix))
relevance_score.shape

torch.Size([55837, 21848])

In [36]:
topk_relevance_indices = torch.topk(relevance_score, 20).indices + n_users
topk_relevance_indices[0:6]

tensor([[68953, 61906, 57160, 65925, 72183, 66884, 75556, 64867, 58660, 68567,
         76008, 55916, 68004, 72730, 57959, 58041, 73858, 67402, 69536, 71235],
        [60379, 62224, 56314, 74592, 63314, 57673, 67046, 62044, 59997, 75624,
         61462, 63471, 76928, 69757, 70303, 57168, 75129, 59868, 61044, 67499],
        [74922, 73959, 65184, 58568, 71186, 71598, 61827, 69022, 68489, 72623,
         69591, 59018, 74509, 75462, 68181, 59361, 60893, 63412, 70471, 76413],
        [67858, 58989, 76093, 72266, 75732, 67497, 65922, 69232, 64275, 73523,
         66349, 57818, 65273, 66434, 69224, 57474, 69060, 74521, 58485, 60997],
        [68902, 60080, 74127, 74931, 65532, 66609, 61802, 68125, 63193, 77664,
         59371, 73148, 61979, 60451, 56652, 64926, 58533, 61027, 65662, 73434],
        [62182, 64137, 76442, 69006, 56525, 63828, 65370, 57310, 74259, 56245,
         58883, 58902, 60310, 70391, 64302, 59167, 62294, 62193, 62089, 65281]])

In [42]:
topk_relevance_indices = torch.topk(relevance_score, 20).indices
topk_relevance_indices[0:6]

tensor([[2565, 3615, 4526, 1155, 2243, 2282, 3804,  437,  775, 1583, 2662, 4791,
         1505, 3363,  326,  850, 4986, 3243,  152, 2113],
        [1221, 1365, 3813, 3627, 3836, 3190, 3722,  198, 3091, 1454, 4221,  229,
          906, 4833, 3938, 2276, 3373, 3600, 4000,  714],
        [3991, 3492,   88, 1285,  376, 1062, 2144, 1893, 1070, 4829, 3643, 2802,
          968,  398, 1635, 1026,  135, 3445, 1207, 1627],
        [2881, 4725, 4520,  333, 2503, 1540,  915,  445, 2594, 4724, 3747,  135,
          627, 3414, 4910, 3812, 3903,  563, 1867, 3467],
        [ 488, 4794,  302, 4236, 3975, 3400,  454,   90,  664, 2241, 3914, 2190,
          924, 3250, 4042, 1715, 2239, 2732, 3422, 3256],
        [4432, 4330, 2292, 1054, 4223, 2485, 3655, 4589, 2029,   13, 4574, 1789,
         3842, 3473, 3344, 1615, 3489,  341, 1594, 3313]])

In [37]:
topk_relevance_indices_df = pd.DataFrame(topk_relevance_indices.cpu().numpy())
topk_relevance_indices_df['top_rlvnt_itm'] = topk_relevance_indices_df.values.tolist()
topk_relevance_indices_df['user_ID'] = topk_relevance_indices_df.index
topk_relevance_indices_df = topk_relevance_indices_df[['user_ID', 'top_rlvnt_itm']]
topk_relevance_indices_df

Unnamed: 0,user_ID,top_rlvnt_itm
0,0,"[68953, 61906, 57160, 65925, 72183, 66884, 75556, 64867, 58660, 68567, 76008, 55916, 68004, 72730, 57959, 58041, 73858, 67402, 69536, 71235]"
1,1,"[60379, 62224, 56314, 74592, 63314, 57673, 67046, 62044, 59997, 75624, 61462, 63471, 76928, 69757, 70303, 57168, 75129, 59868, 61044, 67499]"
2,2,"[74922, 73959, 65184, 58568, 71186, 71598, 61827, 69022, 68489, 72623, 69591, 59018, 74509, 75462, 68181, 59361, 60893, 63412, 70471, 76413]"
3,3,"[67858, 58989, 76093, 72266, 75732, 67497, 65922, 69232, 64275, 73523, 66349, 57818, 65273, 66434, 69224, 57474, 69060, 74521, 58485, 60997]"
4,4,"[68902, 60080, 74127, 74931, 65532, 66609, 61802, 68125, 63193, 77664, 59371, 73148, 61979, 60451, 56652, 64926, 58533, 61027, 65662, 73434]"
...,...,...
55832,55832,"[64299, 76591, 57933, 72247, 69309, 57926, 59647, 76905, 64896, 56331, 59176, 57818, 65543, 57722, 61798, 58568, 62853, 66775, 57259, 71862]"
55833,55833,"[56101, 66487, 66415, 74465, 60585, 59253, 72778, 64483, 59029, 64827, 69660, 64413, 70982, 59708, 72951, 59596, 66116, 63067, 62456, 58020]"
55834,55834,"[67003, 70988, 70444, 70479, 64497, 75580, 59297, 63397, 65770, 60378, 60057, 61511, 75250, 70652, 72099, 74239, 65402, 66967, 74698, 63418]"
55835,55835,"[68662, 65931, 60242, 69695, 58560, 59252, 73044, 75178, 69302, 69778, 60741, 71776, 71649, 69069, 68339, 63173, 65107, 75160, 65591, 56379]"


In [38]:
test_interacted_items = val_df.groupby('user_id_idx')['item_id_idx'].apply(list).reset_index()
test_interacted_items

Unnamed: 0,user_id_idx,item_id_idx
0,9,"[12828, 2301]"
1,16,[12633]
2,19,[12936]
3,34,[21064]
4,47,"[17121, 9641]"
...,...,...
3638,55601,[1110]
3639,55607,[5488]
3640,55747,[4594]
3641,55780,[2251]


In [39]:
metrics_df = pd.merge(test_interacted_items, topk_relevance_indices_df, how='left', left_on='user_id_idx', right_on=['user_ID'])
metrics_df

Unnamed: 0,user_id_idx,item_id_idx,user_ID,top_rlvnt_itm
0,9,"[12828, 2301]",9,"[68780, 58231, 68539, 71708, 65792, 65621, 76969, 63235, 65050, 59871, 76228, 68064, 74592, 60264, 73602, 74730, 71644, 66758, 74514, 67201]"
1,16,[12633],16,"[72451, 65526, 66905, 70322, 62534, 77137, 59169, 66668, 60812, 61102, 75475, 75044, 74279, 69698, 77154, 59350, 72246, 62717, 76225, 70751]"
2,19,[12936],19,"[75287, 75767, 60299, 61737, 76128, 72479, 67589, 75825, 62614, 73229, 58816, 76207, 59518, 75276, 65993, 77139, 60387, 63435, 67354, 67914]"
3,34,[21064],34,"[62304, 70380, 71033, 60507, 66624, 58211, 71919, 71519, 66847, 62329, 55877, 66358, 65670, 75128, 57702, 63089, 59707, 75093, 64499, 66225]"
4,47,"[17121, 9641]",47,"[59131, 73765, 76128, 67606, 69739, 71819, 71666, 77270, 60841, 76009, 76602, 75770, 64656, 57772, 57256, 62086, 58717, 70292, 65649, 73905]"
...,...,...,...,...
3638,55601,[1110],55601,"[63308, 65107, 64265, 62099, 70077, 63206, 57981, 63194, 58675, 59667, 58462, 68910, 65612, 76601, 70542, 61576, 74811, 71583, 62569, 74936]"
3639,55607,[5488],55607,"[67431, 73624, 71891, 65735, 72983, 61669, 73151, 75909, 59448, 66304, 68771, 71172, 70725, 57288, 77522, 66595, 57802, 74719, 69220, 65932]"
3640,55747,[4594],55747,"[70393, 62620, 72378, 73172, 68485, 59062, 62885, 69143, 72402, 73470, 77483, 58416, 74249, 63960, 57507, 76064, 59646, 68004, 72520, 61950]"
3641,55780,[2251],55780,"[57929, 64702, 65787, 67361, 63609, 71610, 67988, 73280, 59296, 70380, 69591, 62481, 60418, 60214, 62391, 59538, 68638, 55842, 64421, 57221]"


In [40]:
metrics_df['intrsctn_itm'] = [list(set(a).intersection(b)) for a, b in
                                  zip(metrics_df.item_id_idx, metrics_df.top_rlvnt_itm)]
metrics_df

Unnamed: 0,user_id_idx,item_id_idx,user_ID,top_rlvnt_itm,intrsctn_itm
0,9,"[12828, 2301]",9,"[68780, 58231, 68539, 71708, 65792, 65621, 76969, 63235, 65050, 59871, 76228, 68064, 74592, 60264, 73602, 74730, 71644, 66758, 74514, 67201]",[]
1,16,[12633],16,"[72451, 65526, 66905, 70322, 62534, 77137, 59169, 66668, 60812, 61102, 75475, 75044, 74279, 69698, 77154, 59350, 72246, 62717, 76225, 70751]",[]
2,19,[12936],19,"[75287, 75767, 60299, 61737, 76128, 72479, 67589, 75825, 62614, 73229, 58816, 76207, 59518, 75276, 65993, 77139, 60387, 63435, 67354, 67914]",[]
3,34,[21064],34,"[62304, 70380, 71033, 60507, 66624, 58211, 71919, 71519, 66847, 62329, 55877, 66358, 65670, 75128, 57702, 63089, 59707, 75093, 64499, 66225]",[]
4,47,"[17121, 9641]",47,"[59131, 73765, 76128, 67606, 69739, 71819, 71666, 77270, 60841, 76009, 76602, 75770, 64656, 57772, 57256, 62086, 58717, 70292, 65649, 73905]",[]
...,...,...,...,...,...
3638,55601,[1110],55601,"[63308, 65107, 64265, 62099, 70077, 63206, 57981, 63194, 58675, 59667, 58462, 68910, 65612, 76601, 70542, 61576, 74811, 71583, 62569, 74936]",[]
3639,55607,[5488],55607,"[67431, 73624, 71891, 65735, 72983, 61669, 73151, 75909, 59448, 66304, 68771, 71172, 70725, 57288, 77522, 66595, 57802, 74719, 69220, 65932]",[]
3640,55747,[4594],55747,"[70393, 62620, 72378, 73172, 68485, 59062, 62885, 69143, 72402, 73470, 77483, 58416, 74249, 63960, 57507, 76064, 59646, 68004, 72520, 61950]",[]
3641,55780,[2251],55780,"[57929, 64702, 65787, 67361, 63609, 71610, 67988, 73280, 59296, 70380, 69591, 62481, 60418, 60214, 62391, 59538, 68638, 55842, 64421, 57221]",[]
