# Link Prediction with RHGNN using the LastFM1b Database

##### Objectives of this notebook

In this notebook, we will be combing the knowledge presented in a modern Heterogeneous Graph Neural Network topic (RHGNNs) with a customized collection of data. 
This will be achieved in hopes to show the applications of RHGNNs for downstream tasks like node classification and link prediction. 

More specifically, this notebook will utilize the dataset provided by LastFM, called LastFM1b. With this collection of listening events 
we will create a heterogeneous graph that our machine learning model with utilize to create high-quality embeddings. 
Once this is finished we will utilize the embeddings to find patterns in user listening behavior


##### Breif Intro to RHGNN
Representational Lerning for Hetergeneous Graph Neural Networks is a trending topic in machine Learning. 
Most heterogenous methods are utilized for propagation of singular node representations, whereas in RHGNN's method, 
the relational information that exists between nodes is utilized for imporving the noderepresentations. 
Each convolutional component of this model is able to learn node represenations of a singlar relational type. 
After this is achieved, a "cross relational" message passing module is able to improve the node represenations by 
factoring in the characteristics of the nodes relational connections. These **relationally aware** representations are then 
passed through a heterogeneous gnn layer stack to allow the model to capture the "relational semantics." This makes 
RHGNN can encapsulate the characteristics of the relational connections between nodes on a heterogeneous graph.

For more info, see the paper "Representational Learning for Heterogeneous Graph Neural Networks" here: https://arxiv.org/abs/2105.11122

##### Breif Intro to LFM1b
LFM1b dataset consists of more than one billion listening events, intended to be used for various music retrieval and recommendation tasks. 
A paper describing the dataset was accepted to the ACM International Conference on Multimedia Retrieval (ICMR) 2016 and is available for download. 

For more info, see the paper "The LFM-1b Dataset for Music Retrieval and Recommendation" or the main download page here: http://www.cp.jku.at/datasets/LFM-1b/ 



##### Methodology of the Notebook

Before the notebook starts it is worth mentioning the fundamentals of a library call Deep Graph Library (DGL). DGL is one of the many competing deep learning with graph networks python libaries. 
Other notebale libraries include, PyTorch Geometric, Spektral, and many others collections. 

Within this particular notebook we will be utilizing DGL's graph database frameworks to compute complex measurements on large graphs. 
As with most libraries, I recommend reading through the user guide in the documentation, as well as maybe another library like pytorch geometric to see the similarities and differeneces.

With all of this being said, and for the interested learner, I'll refer everyone to the Standford CSS224W course on Graphs for Machine Learning here: https://web.stanford.edu/class/cs224w/. 
Many more resources are given as you work your way through the course, which I've found to help my understanding

# Imports

Here's a note on the requirements needed to operate this notebook...

You will want a GPU, it might take to long otherwise. 
Additionally the following requirements:

* PyTorch 1.7.1
* DGL 0.5.3
* PyTorch Geometric 1.6.3
* OGB 1.3.1
* tqdm
* numpy

In [29]:
import os 
import torch as th 
import torch.nn as nn
from tqdm import tqdm
import copy
import json
import shutil
import warnings
from utils.utils import *
from dgl.data.utils import load_graphs
from utils.LinkScorePredictor import *
from utils.EarlyStopping import *
from model.R_HGNN import *

th.cuda.empty_cache()
warnings.filterwarnings('ignore')


# Notebook Arguments

In [42]:
import argparse
def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=0, type=int, help='seed for reproducibility')
parser.add_argument('--sample_edge_rate', default=0.01, type=float, help='train: validate: test ratio')
parser.add_argument('--num_layers', default=2, type=int, help='number of convolutional layers for a model')
parser.add_argument('--batch_size', default=512, type=int, help='the number of edges to train in each batch')
parser.add_argument('--num_neg_samples', default=5, type=int, help='the number of negative edges to sample when training')
parser.add_argument('--node_min_neighbors', default=5, type=int, help='the number of nodes to sample per target node')
parser.add_argument('--shuffle',  default=True, type=str2bool, nargs='?', const=True, help='string bool wether to shuffle indicies before splitting')
parser.add_argument('--drop_last',  default=False, type=str2bool, nargs='?', const=True, help='string bool wether to drop the last sample in data loading')
parser.add_argument('--num_workers', default=2, type=int, help='number of workers for a specified data loader')
parser.add_argument('--hidden_dim', default=16, type=int, help='dimension of the hidden layer input')
parser.add_argument('--rel_input_dim', default=16, type=int, help='input dimension of the edges')
parser.add_argument('--rel_hidden_dim', default=16, type=int, help='hidden dimension of the edges')
parser.add_argument('--num_heads', default=8, type=int, help='the number of attention heads used')
parser.add_argument('--dropout', default=0.5, type=float, help='the dropout rate for the models')
parser.add_argument('--residual', default=True, type=str2bool, nargs='?', const=True, help='string for using the residual values in computation')
parser.add_argument('--norm', default=True, type=str2bool, nargs='?', const=True, help=' string for using normalization of values in computation')
parser.add_argument('--opt', default='adam', type=str, help='the name of the optimizer to be used')
parser.add_argument('--learning_rate', default=0.0001, type=float, help='the learning rate used for training')
parser.add_argument('--weight_decay', default=0.0, type=float, help='the decay of the weights used for training')
parser.add_argument('--epochs', default=200, type=int, help='the number of epochs to train the model with')
parser.add_argument('--patience', default=25, type=int, help='the number of epochs to allow before early stopping')
parser.add_argument('--split_by_users', default=False, type=str2bool, nargs='?', const=True, help='boolean inidicator if you want to split train/val/test by users and not just targetedges')
parser.add_argument('--device', default='cpu', type=str, help='GPU or CPU device specification')


args = parser.parse_args([])
print(args)

set_random_seed(args.seed)

Namespace(albums=True, artists=True, batch_size=512, context_size=7, device='cpu', drop_last=False, dropout=0.5, emb_dim=16, epochs=200, hidden_dim=16, learing_rate=0.0001, learning_rate=0.001, logs=100, metapath2vec=True, metapath2vec_epochs=5, metapath2vec_epochs_batch_size=512, n_users=None, node_min_neighbors=5, norm=True, norm_playcount_weight=True, num_heads=8, num_layers=2, num_neg_samples=5, num_workers=2, opt='adam', overwrite_preprocessed=False, overwrite_processed=False, patience=25, playcount_weight=False, popular_artists=False, rel_hidden_dim=16, rel_input_dim=16, residual=True, sample_edge_rate=0.01, seed=0, shuffle=True, split_by_users=False, tracks=True, walk_length=16, walks_per_node=3, weight_decay=0.0)


# Loading in heterogenous data 

As a mindful note, graphs are everywhere. In many cases with large databases,
there exists a graphical model that can represent the complex interconnections data in a single visualize. 
If this doesn't make to much since I recommend reading through some of the references above or taking a look at hese visuals

Specifcally a heterogenous graph is a colleciton of sets of nodes/vertices (V), edges/links (E), node types (A), and edge types (R) such that
<center>
<img
  src="https://latex.codecogs.com/svg.image?\LARGE&space;G=(V,E,A,R),&space;where&space;\begin{vmatrix}A&space;\\\end{vmatrix}&space;&plus;&space;\begin{vmatrix}&space;R&space;\\\end{vmatrix}&space;>2"
  />
</center>
<br>
<center>
<img
  src="https://latex.codecogs.com/svg.image?\LARGE&space;e=(\phi\left&space;(u&space;&space;\right&space;),\psi\left&space;(e&space;&space;\right&space;),\phi\left&space;(v&space;&space;\right&space;)),&space;where&space;\left\{u,&space;v&space;\epsilon&space;V&space;\right\},&space;\left\{e&space;\epsilon&space;E&space;\right\},&space;\left\{\phi\left&space;(u&space;&space;\right&space;),\phi\left&space;(v&space;&space;\right&space;)&space;\epsilon&space;A&space;\right\},&space;\left\{\psi\left&space;(e&space;&space;\right&space;)&space;\epsilon&space;R&space;\right\}"
/>
</center>


In many cases, this can be through of as a knowledge graph as well. These concepts tend to be blurred by the task at hand. **Here is an example of a heterogenous graph from the OBG baseline dataset found on PyTorch's documentation**

<center>
<img
  src="https://pytorch-geometric.readthedocs.io/en/latest/_images/hg_example.svg"
  />
</center>

For our purposes, we'll being using a custom made dataset :) It iterates through all the one billion listening events and uses the dgl heterogrpah object 
to create a massive heterogenous model for our machine learning model to work with.

Do mind this warning: This is note the full dataset being used in this notebook. Rather in this notebook we'll just pull the first 10 million
listening events and to collect the unique user, artist, album and track ids that exist in the database

In [43]:
# using DGl's load_graphs function to load pre-computed and processed files
glist,_=load_graphs('data/DGL_LFM1b/processed/lastfm1b.bin') # <- this file represents a subset of the full dataset
hg=glist[0] # hg=='heterogeneous graph' ;) from the list of graphs in the processed file (hint: theres only one) pick our heterogenous subset graph
hg

Graph(num_nodes={'album': 4902904, 'artist': 12173, 'genre': 21, 'track': 9111082, 'user': 10},
      num_edges={('album', 'album_listened_by', 'user'): 102644, ('album', 'produced_by', 'artist'): 4902904, ('artist', 'artist_listened_by', 'user'): 102644, ('artist', 'in_genre', 'genre'): 40063, ('artist', 'preformed', 'track'): 9111082, ('artist', 'produced', 'album'): 4902904, ('genre', 'is_genre_of', 'artist'): 40063, ('track', 'preformed_by', 'artist'): 9111082, ('track', 'track_listened_by', 'user'): 102644, ('user', 'listened_to_album', 'album'): 102644, ('user', 'listened_to_artist', 'artist'): 102644, ('user', 'listened_to_track', 'track'): 102644},
      metagraph=[('album', 'user', 'album_listened_by'), ('album', 'artist', 'produced_by'), ('user', 'album', 'listened_to_album'), ('user', 'artist', 'listened_to_artist'), ('user', 'track', 'listened_to_track'), ('artist', 'user', 'artist_listened_by'), ('artist', 'genre', 'in_genre'), ('artist', 'track', 'preformed'), ('artist', 

This is how DGL represents their HeteroData Object, for more info see there documentation...

You might see that in our meta-graph variable, there are also reverse edges labeled conviently with a '-rev' ending

In [44]:
# creating a dictionary of every edge and it's reverse edge
reverse_etypes = dict()
for stype, etype, dtype in hg.canonical_etypes: # for every edge type structured as (phi(u), psi(e), phi(v))
    for srctype, reltype, dsttype in hg.canonical_etypes:
        if srctype == dtype and dsttype == stype and reltype != etype:
            reverse_etypes[etype] = reltype
            break
reverse_etypes

{'album_listened_by': 'listened_to_album',
 'produced_by': 'produced',
 'artist_listened_by': 'listened_to_artist',
 'in_genre': 'is_genre_of',
 'preformed': 'preformed_by',
 'produced': 'produced_by',
 'is_genre_of': 'in_genre',
 'preformed_by': 'preformed',
 'track_listened_by': 'listened_to_track',
 'listened_to_album': 'album_listened_by',
 'listened_to_artist': 'artist_listened_by',
 'listened_to_track': 'track_listened_by'}

From our data, we are tasked with generating representations for each these nodes in a latent space. Once this is achieved, we can utilize 
different algorithms to generate predictions. However, throughout the academic research of GNNs, the authors of the published RHGNN paper have brought up a very important flaw that most 
traditional GNN models do not address. Specifically, many of these algorithms do not utilize the relational information that exists between 
nodes of different relatinoal edges for generating node representations.

As we continue through this notebook we will continue this dicussion of why it is important to utilize the relational information that exists in a 
graph for creating high quality node embedding representations. For now however, we need to understand what we will need for our task at hand

# Link Prediction on Heterogeneous Graphs

Simply put, this is the task for prediting the probability of a edge exsisting between two nodes in a graph. Mathematically, we can present it as 
The likelihood of connectivity between two nodes u and v such that

<center><img
src="https://latex.codecogs.com/svg.image?\LARGE&space;y_{u,v}=\phi\left&space;(&space;h_{u}^L&space;,&space;h_{v}^L\right&space;)"
/></center>

where we have a function 
<img
src="https://latex.codecogs.com/svg.image?\LARGE&space;\phi&space;"
/>
to predict the likelihood of an edge existing between the embedding representations that our GNN model is capable of computing
<img
src="https://latex.codecogs.com/svg.image?\LARGE&space;h_{u}^L,&space;h_{v}^L"
/>

If you care to theorize, you can notice that we haven't added the edge type into this equation. This is due to the fact that for hetergeneous models, 
we will have separate functions for each edge type that exists in the graph. (Or a least just the edges that we want to predict for a particular task).
 With this intution in mind, what does this mean for our heterogenous graph? This means that with our data we will be able to preform link prediciton for any of the edge types that exist! 

In this notebook we will just be working on predicting the likelihood of a user listening to a track, but the notebook can easily be extended to work with different edges

In this notebook the SAMPLED_EDGE_TYPE, references which link we want our model to learn to predict

In [45]:
hg.etypes

['album_listened_by',
 'produced_by',
 'artist_listened_by',
 'in_genre',
 'preformed',
 'produced',
 'is_genre_of',
 'preformed_by',
 'track_listened_by',
 'listened_to_album',
 'listened_to_artist',
 'listened_to_track']

In [65]:
SAMPLED_EDGE_TYPE='listened_to_track'

In [66]:
train_edge_idx, valid_edge_idx, test_edge_idx = get_predict_edge_index(
    hg,
    sample_edge_rate=args.sample_edge_rate,
    sampled_edge_type=SAMPLED_EDGE_TYPE,
    seed=args.seed)
    
print(f'train edge num: {len(train_edge_idx)}, valid edge num: {len(valid_edge_idx)}, test edge num: {len(test_edge_idx)}')

train edge num: 3078, valid edge num: 1026, test edge num: 1026


# Train, Val, Test Splits for link Prediction

Now the reason we stop here to discuss more about link prediction is due to our model needing to learn to compare scores between nodes 
connected by an edge against the scores between an arbitrary pair of nodes. (Wait what?)

Link prediction is a common unsupervised or self-supervised task. Meaning we need to split the our graph and create corresponding labels 
for our edges to train our model on. For training with Graph Neural Networks there exist two general methods for this (more have been discovered during academic research):


Inductive Splits:

    Where the training, validation, and test sets are different graphs. When using this approach a successful model should be able generalize unseen graphs for node, edge, graph level tasks

Transductive Splits: 

    Where the training, validation, and test sets exist all on the same graph. This might not be so intuitive to think about, but the original full graph has all the splits, but the labels for the edges are different. This is specifically only applicable to node or edge level tasks


For example, given an edge connecoing node u and v, the model will train to score between 
node u and v to be higher than a score between node 𝑢 and another node v' from an arbitrary noise distribution v'~P(v). 
This is known as a funamental concept of negative sampling. Because a score prediction model operates on graphs, 
we need to express the negative samples as graphs. The graph will contain all negative node pairs as edges.

So what does this mean for our hetergeous graph, and the inputs into our model of choice? Well it means we first need to utilize our labeled train, validation, and test edges to preform a specific split. 
For the purposes of this notebook, we will be using a inductive split of the graph.

DGL offers the unique ability to split a very large graph for the specific purposes of training a GNN model in a stochastic process. 
This alows us to generate small training batches that our model can iteratively learn as be evaluate the loss of the model over time.

Specifically, well be using the DataLoader object to create iterable objects that can create next batchs of the training, validation, 
and testing data for the model. To see the DGL implementation see the utils.py file and find the function get_edge_data_loader().


In [67]:
train_loader, val_loader, test_loader = get_edge_data_loader(
    args.node_min_neighbors,
    args.num_layers,
    hg,
    args.batch_size,
    SAMPLED_EDGE_TYPE,
    args.num_neg_samples,
    train_edge_idx=train_edge_idx,
    valid_edge_idx=valid_edge_idx,
    test_edge_idx=test_edge_idx,
    reverse_etypes=reverse_etypes,
    shuffle = args.shuffle, 
    drop_last = args.dropout,
    num_workers = args.num_workers
    )

# Model Selection: Relation-aware Heterogeneous Graph Neural Networks (RHGNN)

Now that we've prepared our dataset and have compiled the necessary iterators to send train, validation, and testing batches to our model. 
All that is left is to select the model. Model selection is a very important step in deep learning, it must particularly reflect the goals of the task, 
as well as utilize the data to it's full capabilities. 

For our heterogeneous graph we will be using a novel Relation-aware Heterogeneous Graph Neural Network. 
Published in 2021 by Le Yu, Leilei Sun, Bowen Du, Chuanren Liu, Weifeng Lv, and Hui Xiong, these researchers have implemented a model capable of 
high quality relationally aware node embeddings that are able to capture characteristics of not only the hetergenous neighbouring nodes, 
but the relation ships that exist between them. To be breif their research proposes 3 contributions as well as outlines the 4 necessary steps 
inorder for their models final node representations to be computed.

Contributions of the paper:
1. A methogolgy to compute relational aware node embeddings
2. A methogolgy to compute relational edge embeddings
3. A methogolgy to compute a embeddings through a fusing module that utilizes the information of both the pior contributions 

Computation steps that are used to compute the final node represenations
1. Multiple convolutional layers that are able to learn the specific node represenations independently of the relational connections they have
2. A Cross Relational Learning module to determine the importance of the edges between the nodes depending on the type of the relationship
3. A GNN containing the neccessary deep learning methodologies to utilize the piror computed representations to update the graph
4. A fusing aggregate module of relationally aware node representations that results in a singluar compact node representation to facilitate downstream prediction tasks

Here is the depiction the authors made for the published paper, it represents the full alogirhtm with the major components mentioned above


<center>
<img 
src="https://d3i71xaburhd42.cloudfront.net/8a30c43eec88d087d2029c8de1f3a7961b753340/4-Figure2-1.png"
/>
</center>


For our project, all we have to do is understand how the model works, the implementation of the physical model is actaully importable! 
If you havent'y noticed already, the directory you are in is actually a modified version of the published repository found here: https://github.com/yule-BUAA/R-HGNN

# Initializing the Model

In [68]:
r_hgnn = R_HGNN(graph=hg,
                input_dim_dict={ntype: hg.nodes[ntype].data['feat'].shape[1] for ntype in hg.ntypes},
                hidden_dim=args.hidden_dim, 
                relation_input_dim=args.rel_input_dim,
                relation_hidden_dim=args.rel_hidden_dim,
                num_layers=args.num_layers, 
                n_heads=args.num_heads, 
                dropout=args.dropout,
                residual=args.residual, 
                norm=args.norm)

link_score_predictor = LinkScorePredictor(args.hidden_dim * args.num_heads)

model = nn.Sequential(r_hgnn, link_score_predictor)
model = convert_to_gpu(model, device='cuda')

print(f'Model #Params: {get_n_params(model)}.')
print(model)

optimizer, scheduler = get_optimizer_and_lr_scheduler(
    model, 
    args.opt, 
    args.learning_rate, 
    args.weight_decay,
    steps_per_epoch=len(train_loader), 
    epochs=args.epochs)



# save the model result
save_result_dir= f"results/lfm1b-demo/{SAMPLED_EDGE_TYPE}"
if not os.path.exists(save_result_dir):
    os.makedirs(save_result_dir, exist_ok=True)

early_stopping = EarlyStopping(
    patience=args.patience, 
    save_model_folder=save_result_dir,
    save_model_name=SAMPLED_EDGE_TYPE)

Model #Params: 1077962.
Sequential(
  (0): R_HGNN(
    (relation_embedding): ParameterDict(
        (album_listened_by): Parameter containing: [torch.cuda.FloatTensor of size 16x1 (GPU 0)]
        (artist_listened_by): Parameter containing: [torch.cuda.FloatTensor of size 16x1 (GPU 0)]
        (in_genre): Parameter containing: [torch.cuda.FloatTensor of size 16x1 (GPU 0)]
        (is_genre_of): Parameter containing: [torch.cuda.FloatTensor of size 16x1 (GPU 0)]
        (listened_to_album): Parameter containing: [torch.cuda.FloatTensor of size 16x1 (GPU 0)]
        (listened_to_artist): Parameter containing: [torch.cuda.FloatTensor of size 16x1 (GPU 0)]
        (listened_to_track): Parameter containing: [torch.cuda.FloatTensor of size 16x1 (GPU 0)]
        (preformed): Parameter containing: [torch.cuda.FloatTensor of size 16x1 (GPU 0)]
        (preformed_by): Parameter containing: [torch.cuda.FloatTensor of size 16x1 (GPU 0)]
        (produced): Parameter containing: [torch.cuda.FloatTe

# Define the evaluation

Once we have our model, we'll need a way to measure the model's ability to predict the probability of an edge existing between two nodes with accuracy.
This means we'll first need an algorithm to determine the loss of the model.

Quick note on that... what loss function are well even using for Link Prediction

There are lots of loss functions that can achieve the behavior above if minimized. A non-exhaustive list include:
- Cross-entropy loss
- BPR loss
- Margin loss

By training our model to minimize any of the above mentioned loss functions, we will be able to achieve a model that is able to score nodes that 
should have an edge between them to have a higher score than two nodes who should not have a connection between them.

This definition will only be used twice for every epoch of the training process. 
It will compute the loss of the validation and tests sets after the training process in that specific epoch is complete.
This will let use visually see the change in loss over time for our model

In [69]:
def evaluate(model, loader, loss_func, sampled_edge_type, device, mode):
    """
    :param model: model
    :param loader: data loader (validate or test)
    :param loss_func: loss function
    :param sampled_edge_type: str
    :param device: device str
    :param mode: str, evaluation mode, validate or test
    :return:
    total_loss, y_trues, y_predicts
    """
    model.eval()
    with th.no_grad():
        y_trues = []
        y_predicts = []
        total_loss = 0.0
        for batch, (input_nodes, positive_graph, negative_graph, blocks) in enumerate(loader):
            blocks = [convert_to_gpu(b, device=device) for b in blocks]
            blocks = [convert_to_gpu(b, device=device) for b in blocks]
            positive_graph, negative_graph = convert_to_gpu(positive_graph, negative_graph, device=device)
            # target node relation representation in the heterogeneous graph
            input_features = {(stype, etype, dtype): blocks[0].srcnodes[dtype].data['feat'] for stype, etype, dtype in
                              blocks[0].canonical_etypes}

            nodes_representation, _ = model[0](blocks, copy.deepcopy(input_features))

            positive_score = model[1](
                positive_graph, 
                nodes_representation, 
                sampled_edge_type).squeeze(dim=-1)
            negative_score = model[1](
                negative_graph, 
                nodes_representation, 
                sampled_edge_type).squeeze(dim=-1)

            y_predict = th.cat([positive_score, negative_score], dim=0)
            y_true = th.cat(
                [th.ones_like(positive_score), 
                th.zeros_like(negative_score)], dim=0)

            loss = loss_func(y_predict, y_true)
            total_loss += loss.item()
            y_trues.append(y_true.detach().cpu())
            y_predicts.append(y_predict.detach().cpu())

        total_loss /= (batch + 1)
        y_trues = th.cat(y_trues, dim=0)
        y_predicts = th.cat(y_predicts, dim=0)

    return total_loss, y_trues, y_predicts

# Training Loop

With our model defined and out data loaders defined, we can begin training our model. 
Our loss for our RHGNN model will be the Binary Cross Entropy loss. After every batch training block, 
we'll calculate the loss, using the ground truth edges, and the predicted edges. 
To calculate the error we'll use RMSE and MAE, and print the results of each epoch below

In [72]:
def train_model(model, optimizer, scheduler, train_loader, val_loader, test_loader, save_folder, sample_edge_type, date, args):
    shutil.rmtree(save_folder, ignore_errors=True)
    os.makedirs(save_folder, exist_ok=True)
    
    early_stopping = EarlyStopping(patience=args.patience, save_model_folder=save_folder, save_model_name=sample_edge_type)
    tqdm_loader = tqdm(range(args.epochs), total=args.epochs)
    loss_func = nn.BCELoss()
    train_steps = 0
    best_validate_RMSE, final_result = None, None

    total_loss_vals={'train':[],'val':[],'test':[]}
    RMSE_vals={'train':[],'val':[],'test':[]}
    MAE_vals={'train':[],'val':[],'test':[]}
    AUC_vals={'train':[],'val':[],'test':[]}
    AP_vals={'train':[],'val':[],'test':[]}
    for epoch in tqdm_loader:
        model.train()
        train_y_trues = []
        train_y_predicts = []
        train_total_loss = 0.0
        for batch, (input_nodes, positive_graph, negative_graph, blocks) in enumerate(train_loader):
            blocks = [convert_to_gpu(b, device=args.device) for b in blocks]
            blocks = [convert_to_gpu(b, device=args.device) for b in blocks]
            positive_graph, negative_graph = convert_to_gpu(positive_graph, negative_graph, device=args.device)

            input_features = {(stype, etype, dtype): blocks[0].srcnodes[dtype].data['feat'] for stype, etype, dtype in blocks[0].canonical_etypes}

            nodes_representation, _ = model[0](blocks, copy.deepcopy(input_features), args=args)

            positive_score = model[1](positive_graph, nodes_representation, sample_edge_type).squeeze(dim=-1)
            negative_score = model[1](negative_graph, nodes_representation, sample_edge_type).squeeze(dim=-1)

            train_y_predict = th.cat([positive_score, negative_score], dim=0)
            train_y_true = th.cat([th.ones_like(positive_score), th.zeros_like(negative_score)], dim=0)
            loss = loss_func(train_y_predict, train_y_true)
            train_total_loss += loss.item()
            train_y_trues.append(train_y_true.detach().cpu())
            train_y_predicts.append(train_y_predict.detach().cpu())
            

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # step should be called after a batch has been used for training.
            train_steps += 1
            scheduler.step(train_steps)
        
        train_total_loss /= (batch + 1)
        train_y_trues = th.cat(train_y_trues, dim=0)
        train_y_predicts = th.cat(train_y_predicts, dim=0)
        train_RMSE, train_MAE, train_AUC, train_AP = evaluate_link_prediction(
            predict_scores=train_y_predicts, 
            true_scores=train_y_trues)

        total_loss_vals['train'].append(train_total_loss)
        RMSE_vals['train'].append(train_RMSE)
        MAE_vals['train'].append(train_MAE)
        AUC_vals['train'].append(train_AUC)
        AP_vals['train'].append(train_AP)

        model.eval()

        val_total_loss, val_y_trues, val_y_predicts = evaluate(
            model, 
            loader=val_loader, 
            loss_func=loss_func,
            sampled_edge_type=sample_edge_type,
            device=args.device, 
            mode='validate')
        val_RMSE, val_MAE, val_AUC, val_AP = evaluate_link_prediction(
            predict_scores=val_y_predicts,
            true_scores=val_y_trues)

        total_loss_vals['val'].append(val_total_loss)
        RMSE_vals['val'].append(val_RMSE)
        MAE_vals['val'].append(val_MAE)
        AUC_vals['val'].append(val_AUC)
        AP_vals['val'].append(val_AP)

        test_total_loss, test_y_trues, test_y_predicts = evaluate(
            model, 
            loader=test_loader, 
            loss_func=loss_func,
            sampled_edge_type=sample_edge_type,
            device=args.device, 
            mode='test')
        test_RMSE, test_MAE, test_AUC, test_AP = evaluate_link_prediction(
            predict_scores=test_y_predicts,
            true_scores=test_y_trues)

        total_loss_vals['test'].append(test_total_loss)
        RMSE_vals['test'].append(test_RMSE)
        MAE_vals['test'].append(test_MAE)
        AUC_vals['test'].append(test_AUC)
        AP_vals['test'].append(test_AP)

        if best_validate_RMSE is None or val_RMSE < best_validate_RMSE:
            best_validate_RMSE = val_RMSE
            scores = {"RMSE": float(f"{test_RMSE:.4f}"), "MAE": float(f"{test_MAE:.4f}"),"AUC": float(f"{test_AUC:.4f}"), "AP": float(f"{test_AP:.4f}")}
            final_result = json.dumps(scores, indent=4)

        tqdm_loader.set_description(f'EPOCH #{epoch} RMSE: {test_RMSE:.4f}, MAE: {test_MAE:.4f}, AUC: {test_AUC:.4f}, AP: {test_AP:.4f} ')

        early_stop = early_stopping.step([('RMSE', val_RMSE, False), ('MAE', val_MAE, False)], model)

        if early_stop:
            break

    print(f"predicted relation: {sample_edge_type}")
    print(f'result: {final_result}')


In [73]:
train_model(model, optimizer,scheduler, train_loader, val_loader, test_loader, save_result_dir, SAMPLED_EDGE_TYPE, None, args)

EPOCH #120 RMSE: 0.3648, MAE: 0.1571, AUC: 0.9581, AP: 0.9615 :  60%|██████    | 120/200 [05:08<03:25,  2.57s/it]

predicted relation: listened_to_track
result: {
    "RMSE": 0.3728,
    "MAE": 0.1597,
    "AUC": 0.9523,
    "AP": 0.9551
}



