# Recommendation system: LightGCN model

@Jialu Wang, University of Notre Dame 


Reference: 

1. Xiangnan He et al. LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation.

In [None]:
import os
from os.path import join
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from scipy.sparse import csr_matrix
from tensorboardX import SummaryWriter
import time

In [None]:
# data process
from dataloader import Loader
# model implementation
from model import LightGCN, PairWiseModel
# config settings
import world
# utility function
import utils
# training and testing functions
import Procedure

## Load data

In [None]:
# Load the dataset through the project's Loader and print its user/item counts.
seeData = Loader()
user_count, item_count = seeData.n_users, seeData.m_item
print("No. of users:", user_count, "No. of items:", item_count)

In [None]:
# Build the graph via the loader, then print the loader's `Graph` attribute.
# NOTE(review): the printed object is `seeData.Graph`, not the returned `graph`;
# presumably getSparseGraph() stores its result on that attribute — confirm in dataloader.py.
graph = seeData.getSparseGraph()
print(seeData.Graph)

## Create model

In [None]:
# Build the LightGCN recommender from the global config and the loaded dataset.
rec_model = LightGCN(
    config=world.config,
    dataset=seeData,
)

In [None]:
# Sanity-check the sizes the model picked up.
user_count, item_count = rec_model.num_users, rec_model.num_items
print("No. of users:", user_count, ", No. of items:", item_count)
embedding_size = rec_model.latent_dim
print("Embedding size:", embedding_size)

In [None]:
# Inspect the two embedding tables, then the shape and contents of the model's graph.
for embedding in (rec_model.embedding_user, rec_model.embedding_item):
    print(embedding)
model_graph = rec_model.Graph
print(np.shape(model_graph))
print(model_graph)

In [None]:
# Call the model's train() method before running the optimisation loop.
# NOTE(review): presumably torch.nn.Module.train(), i.e. switches the model to training mode — confirm in model.py.
rec_model.train()

## Train the model (300 epochs, tested every 10 epochs)

In [None]:
# Instantiate PairWiseModel with no arguments.
# NOTE(review): `pair` is not referenced again in the visible notebook — confirm whether it is still needed.
pair = PairWiseModel()

# BPR loss helper built from the model and the global config; it is passed to
# Procedure.BPR_train_original in the training loop.
BPRLoss_input = utils.BPRLoss(rec_model, world.config)

In [None]:
# Summary writer for this run. The log directory lives under world.BOARD_PATH and is
# named after the current time followed by the configured comment.
# NOTE: the time format already ends in "-", so the name contains "--" before the comment.
run_stamp = time.strftime("%m-%d-%Hh%Mm%Ss-")
log_dir = join(world.BOARD_PATH, run_stamp + "-" + world.comment)
w: SummaryWriter = SummaryWriter(log_dir)
    
    

In [None]:
# Per-run histories, filled by the training loop:
#   loss_avg    - one entry per training epoch
#   test_rec    - full result dict of every test pass
#   test_recall - first recall value of every test pass
#   test_ndcg   - first ndcg value of every test pass
loss_avg, test_rec, test_recall, test_ndcg = [], [], [], []

In [None]:
# Main loop: 300 training epochs. On every 10th epoch (0, 10, ..., 290) the model
# is tested *before* that epoch's training step and the metrics are recorded.
for epoch in range(300):
    is_test_epoch = epoch % 10 == 0
    if is_test_epoch:
        print("==========test==========")
        test_epoch = Procedure.Test(seeData, rec_model, epoch, w, world.config['multicore'])
        recall, ndcg = test_epoch['recall'], test_epoch['ndcg']
        print(recall, ndcg)
        # Keep the whole result dict plus the first recall / ndcg entry for plotting.
        test_rec.append(test_epoch)
        test_recall.append(recall[0])
        test_ndcg.append(ndcg[0])

    # One training pass with a single negative sample per positive (neg_k=1).
    # NOTE(review): `output` is presumably the epoch's average loss — confirm in Procedure.py.
    output = Procedure.BPR_train_original(
        seeData, rec_model, BPRLoss_input, epoch, neg_k=1, w=w
    )
    loss_avg.append(output)

In [None]:
# Dump the raw training-loss, recall and ndcg histories, one per line.
for history in (loss_avg, test_recall, test_ndcg):
    print(history)

In [None]:
def _first_value(entry):
    """Return the first element of *entry*, or *entry* itself when it is a scalar.

    The previous version indexed the first ten entries with ``[0]`` and took
    the rest as-is, which only worked for one particular notebook history and
    fails with an IndexError when every entry is already a scalar. Flattening
    first handles both shapes uniformly.
    """
    return np.ravel(entry)[0]


# Flat per-test-pass metric lists used by the plots and the pickle export.
test_recall_list = [_first_value(entry) for entry in test_recall]
test_ndcg_list = [_first_value(entry) for entry in test_ndcg]

print(test_ndcg_list)

In [None]:
import matplotlib.pyplot as plt

# Font sizes are set first: rc settings only affect artists created afterwards,
# so applying them after plt.plot() can leave this figure's axis text at the defaults.
si = 18
plt.rc('axes', titlesize=si)
plt.rc('axes', labelsize=si)
plt.rc('xtick', labelsize=si)
plt.rc('ytick', labelsize=si)
plt.rc('legend', fontsize=14)

# Training loss per epoch, skipping the first 5 epochs. The x-range is derived
# from the history length so it stays valid if the number of epochs changes.
plt.plot(range(5, len(loss_avg)), loss_avg[5:])

plt.xlabel("Epoch")
plt.ylabel("Training Loss")
plt.title("Gowalla")

In [None]:
# Font sizes are set before plotting so they apply to this figure's artists.
si = 18
plt.rc('axes', titlesize=si)
plt.rc('axes', labelsize=si)
plt.rc('xtick', labelsize=si)
plt.rc('ytick', labelsize=si)
plt.rc('legend', fontsize=14)

# The training loop tests on every 10th epoch (0, 10, 20, ...), so each recorded
# value is plotted against the epoch it was measured at, not its list index.
test_epochs = range(0, 10 * len(test_recall_list), 10)
plt.plot(test_epochs, test_recall_list)

plt.xlabel("Epoch")
plt.ylabel("Recall")
plt.title("Gowalla")

In [None]:
# Font sizes are set before plotting so they apply to this figure's artists.
si = 18
plt.rc('axes', titlesize=si)
plt.rc('axes', labelsize=si)
plt.rc('xtick', labelsize=si)
plt.rc('ytick', labelsize=si)
plt.rc('legend', fontsize=14)

# The training loop tests on every 10th epoch (0, 10, 20, ...), so each recorded
# value is plotted against the epoch it was measured at, not its list index.
test_epochs = range(0, 10 * len(test_ndcg_list), 10)
plt.plot(test_epochs, test_ndcg_list)

plt.xlabel("Epoch")
plt.ylabel("ndcg")
plt.title("Gowalla")

In [None]:
import pickle

# Persist the three histories to disk. Pickle is a binary format, so the files
# must be opened in binary mode; `with` guarantees each file is closed even if
# pickle.dump raises.
for path, history in (
    ('gcn300-loss', loss_avg),
    ('gcn300-recall', test_recall_list),
    ('gcn300-ndcg', test_ndcg_list),
):
    with open(path, 'wb') as fh:
        pickle.dump(history, fh)

In [None]:
def _load_pickle(path):
    """Read and return the single pickled object stored at *path*.

    The file is opened in binary mode and closed on exit, even if
    unpickling raises.
    """
    # NOTE: pickle.load can execute arbitrary code; only use it on files
    # produced by this notebook, never on untrusted input.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Read back the histories written by the export cell.
data_loss = _load_pickle('gcn300-loss')
data_recall = _load_pickle('gcn300-recall')
data_ndcg = _load_pickle('gcn300-ndcg')

In [None]:
# Show the reloaded loss, recall and ndcg histories, one per line.
for loaded in (data_loss, data_recall, data_ndcg):
    print(loaded)