In [1]:
import pandas as pd
import numpy as np
import argparse
import random
from model import KGCN
from data_loader import DataLoader
import torch
import torch.optim as optim
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, f1_score

In [2]:
# parser = argparse.ArgumentParser()

# parser.add_argument('--dataset', type=str, default='music', help='which dataset to use')
# parser.add_argument('--aggregator', type=str, default='sum', help='which aggregator to use')
# parser.add_argument('--n_epochs', type=int, default=100, help='the number of epochs')
# parser.add_argument('--neighbor_sample_size', type=int, default=8, help='the number of neighbors to be sampled')
# parser.add_argument('--dim', type=int, default=16, help='dimension of user and entity embeddings')
# parser.add_argument('--n_iter', type=int, default=1, help='number of iterations when computing entity representation')
# parser.add_argument('--batch_size', type=int, default=32, help='batch size')
# parser.add_argument('--l2_weight', type=float, default=1e-4, help='weight of l2 regularization')
# parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
# parser.add_argument('--ratio', type=float, default=0.6, help='size of training dataset')

# args = parser.parse_args([])

In [3]:
# parser = argparse.ArgumentParser()

# parser.add_argument('--dataset', type=str, default='ml100k', help='which dataset to use')
# parser.add_argument('--aggregator', type=str, default='sum', help='which aggregator to use')
# parser.add_argument('--n_epochs', type=int, default=100, help='the number of epochs')
# parser.add_argument('--neighbor_sample_size', type=int, default=8, help='the number of neighbors to be sampled')
# parser.add_argument('--dim', type=int, default=16, help='dimension of user and entity embeddings')
# parser.add_argument('--n_iter', type=int, default=1, help='number of iterations when computing entity representation')
# parser.add_argument('--batch_size', type=int, default=16, help='batch size')
# parser.add_argument('--l2_weight', type=float, default=1e-4, help='weight of l2 regularization')
# parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
# parser.add_argument('--ratio', type=float, default=0.6, help='size of training dataset')

# args = parser.parse_args([])

In [4]:
parser = argparse.ArgumentParser()

parser.add_argument('--dataset', type=str, default='bookcrossing', help='which dataset to use')
parser.add_argument('--aggregator', type=str, default='sum', help='which aggregator to use')
parser.add_argument('--n_epochs', type=int, default=100, help='the number of epochs')
parser.add_argument('--neighbor_sample_size', type=int, default=8, help='the number of neighbors to be sampled')
parser.add_argument('--dim', type=int, default=32, help='dimension of user and entity embeddings')
parser.add_argument('--n_iter', type=int, default=1, help='number of iterations when computing entity representation')
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--l2_weight', type=float, default=1e-4, help='weight of l2 regularization')
parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
parser.add_argument('--ratio', type=float, default=0.6, help='size of training dataset')

args = parser.parse_args([])

In [5]:
# parser = argparse.ArgumentParser()

# parser.add_argument('--dataset', type=str, default='ml1m', help='which dataset to use')
# parser.add_argument('--aggregator', type=str, default='sum', help='which aggregator to use')
# parser.add_argument('--n_epochs', type=int, default=100, help='the number of epochs')
# parser.add_argument('--neighbor_sample_size', type=int, default=8, help='the number of neighbors to be sampled')
# parser.add_argument('--dim', type=int, default=32, help='dimension of user and entity embeddings')
# parser.add_argument('--n_iter', type=int, default=1, help='number of iterations when computing entity representation')
# parser.add_argument('--batch_size', type=int, default=2048, help='batch size')
# parser.add_argument('--l2_weight', type=float, default=1e-7, help='weight of l2 regularization')
# parser.add_argument('--lr', type=float, default=2e-2, help='learning rate')
# parser.add_argument('--ratio', type=float, default=0.6, help='size of training dataset')

# args = parser.parse_args([])

In [6]:
# build dataset and knowledge graph
data_loader = DataLoader(args.dataset)
kg = data_loader.load_kg()
df_dataset = data_loader.load_dataset()
df_dataset

Construct knowledge graph ... Done
Build dataset dataframe ... Done


Unnamed: 0,userID,itemID,label
0,190,1216,1
1,3317,1824,1
2,8686,62938,0
3,15392,61295,0
4,15731,55338,0
...,...,...,...
139741,7530,34662,0
139742,1855,3786,1
139743,17081,7056,0
139744,3256,73870,0


In [7]:
# Dataset class
class KGCNDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        user_id = np.array(self.df.iloc[idx]["userID"])
        item_id = np.array(self.df.iloc[idx]["itemID"])
        label = np.array(self.df.iloc[idx]["label"], dtype=np.float32)
        return user_id, item_id, label


In [8]:
# train test split
x_train, x_test, y_train, y_test = train_test_split(
    df_dataset,
    df_dataset["label"],
    test_size=1 - args.ratio,
    shuffle=False,
    random_state=999,
)
x_val, x_test, y_val, y_test = train_test_split(
    x_test, y_test, test_size=0.5, shuffle=False, random_state=999
)
train_dataset = KGCNDataset(x_train)
val_dataset = KGCNDataset(x_val)
test_dataset = KGCNDataset(x_test)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_dataset.__len__())

test_loader_epoch = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size)

In [9]:
# prepare network, loss function, optimizer
num_user, num_entity, num_relation = data_loader.get_num()
user_encoder, entity_encoder, relation_encoder = data_loader.get_encoders()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = 'cpu'
net = KGCN(num_user, num_entity, num_relation, kg, args, device).to(device)
criterion = torch.nn.BCELoss()
optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.l2_weight)
print("device: ", device)

device:  cuda


In [10]:
name_version = "KGCN"
patience = 5

In [11]:
from pytorchtools import EarlyStopping
# add early stopping
early_stopping = EarlyStopping(patience=patience, verbose=True, path=f'./checkpoint/{name_version}_{args.dataset}.pt', delta=0.005)
# early_stopping = EarlyStopping(patience=patience, verbose=True, path=f'./checkpoint/{name_version}_{args.dataset}.pt', delta=0.01)

# train
loss_list = []
val_loss_list = []
auc_score_list = []
f1_score_list = []

import time

start_train = time.time()
print("start_train:", start_train)


for epoch in range(args.n_epochs):
    running_loss = 0.0

    for i, (user_ids, item_ids, labels) in enumerate(train_loader):
        user_ids, item_ids, labels = (
            user_ids.to(device),
            item_ids.to(device),
            labels.to(device),
        )
        optimizer.zero_grad()
        outputs = net(user_ids, item_ids)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # print train loss per every epoch
    print("[Epoch {}]".format(epoch + 1))
    print("train_loss: ".format(epoch + 1), running_loss / len(train_loader))
    loss_list.append(running_loss / len(train_loader))

    # evaluate per every epoch
    with torch.no_grad():
        val_loss = 0
        total_roc = 0
        total_f1 = 0
        for user_ids, item_ids, labels in val_loader:
            user_ids, item_ids, labels = (
                user_ids.to(device),
                item_ids.to(device),
                labels.to(device),
            )
            outputs = net(user_ids, item_ids)
            val_loss += criterion(outputs, labels).item()
            outputs = outputs.cpu().detach().numpy()
            labels = labels.cpu().detach().numpy()
            total_roc += roc_auc_score(labels, outputs)
            outputs = np.where(outputs >= 0.5, 1, 0)
            total_f1 += f1_score(labels, outputs)

        print("val_loss: ".format(epoch + 1), val_loss / len(val_loader))
        print("val_auc: ".format(epoch + 1), total_roc / len(val_loader))
        print("val_f1: ".format(epoch + 1), total_f1 / len(val_loader))
        print("--------------------------------")
        val_loss_list.append(val_loss / len(val_loader))
        auc_score_list.append(total_roc / len(val_loader))
        f1_score_list.append(total_f1 / len(val_loader))

    # early stopping
    early_stopping(val_loss / len(val_loader), net)
    if early_stopping.early_stop:
        print("Early stopping")
        break

end_train = time.time()
print("end_train:", end_train)
print("end_train - start_train:", end_train - start_train)

pytorch tools loaded
start_train: 1669214150.660836
[Epoch 1]
train_loss:  1.218475737586254
val_loss:  1.0021519279915447
val_auc:  0.5058752159045528
val_f1:  0.5075499807001249
--------------------------------
Validation loss decreased (inf --> 1.002152).  Saving model ...
[Epoch 2]
train_loss:  0.8689831809208888
val_loss:  0.764110267978825
val_auc:  0.508602720424158
val_f1:  0.5059092667850424
--------------------------------
Validation loss decreased (1.002152 --> 0.764110).  Saving model ...
[Epoch 3]
train_loss:  0.7220862172998306
val_loss:  0.7066533203538694
val_auc:  0.5115537133666634
val_f1:  0.5076688397906157
--------------------------------
Validation loss decreased (0.764110 --> 0.706653).  Saving model ...
[Epoch 4]
train_loss:  0.6931373578746144
val_loss:  0.697339199721541
val_auc:  0.513102886488381
val_f1:  0.5070530936239748
--------------------------------
Validation loss decreased (0.706653 --> 0.697339).  Saving model ...
[Epoch 5]
train_loss:  0.686940216

In [12]:
# load the last checkpoint with the best model
net = KGCN(num_user, num_entity, num_relation, kg, args, device).to(device)
net.load_state_dict(torch.load(f'./checkpoint/{name_version}_{args.dataset}.pt'))

# test
with torch.no_grad():
    total_roc = 0
    total_f1 = 0
    for user_ids, item_ids, labels in test_loader:
        user_ids, item_ids, labels = (
            user_ids.to(device),
            item_ids.to(device),
            labels.to(device),
        )
        outputs = net(user_ids, item_ids)
        outputs = outputs.cpu().detach().numpy()
        labels = labels.cpu().detach().numpy()
        print("outputs:", outputs)
        print("labels:", labels)
        total_roc += roc_auc_score(labels, outputs)
        outputs = np.where(outputs >= 0.5, 1, 0)
        total_f1 += f1_score(labels, outputs)

    print("test_auc: ", total_roc / len(test_loader))
    print("test_f1: ", total_f1 / len(test_loader))

outputs: [0.38348708 0.38004613 0.56233376 ... 0.44893408 0.3101708  0.373608  ]
labels: [0. 0. 0. ... 0. 0. 0.]
test_auc:  0.8175006227426221
test_f1:  0.7484843175048187


In [13]:
len(test_loader_epoch)

219

In [14]:
# load the last checkpoint with the best model
net = KGCN(num_user, num_entity, num_relation, kg, args, device).to(device)
net.load_state_dict(torch.load(f'./checkpoint/{name_version}_{args.dataset}.pt'))

# test
with torch.no_grad():
    total_roc = 0
    total_f1 = 0
    for user_ids, item_ids, labels in test_loader_epoch:
        user_ids, item_ids, labels = (
            user_ids.to(device),
            item_ids.to(device),
            labels.to(device),
        )
        outputs = net(user_ids, item_ids)
        outputs = outputs.cpu().detach().numpy()
        labels = labels.cpu().detach().numpy()
        # print("outputs:", outputs)
        # print("labels:", labels)
        total_roc += roc_auc_score(labels, outputs)
        outputs = np.where(outputs >= 0.5, 1, 0)
        total_f1 += f1_score(labels, outputs)

    print("test_auc: ", total_roc / len(test_loader_epoch))
    print("test_f1: ", total_f1 / len(test_loader_epoch))

test_auc:  0.8177852042568738
test_f1:  0.7483137543077604
