## 1. About the Dataset <a class="anchor" id="section1"></a>
- **Source**: The dataset used for experiments is the heterogeneous rating graph, assembled by GroupLens Research from the [*MovieLens*](https://movielens.org).
- **Description**: The dataset contains two types of nodes: "user" and "movie". A user node is linked to a movie node if he has rated the movie. The link is then labeled with the rating he gave.
- **Task**: Predict the rating that users are likely to give to a movie.


In [1]:
import pandas as pd, numpy as np
from itertools import product
import io, os, json

import time

from sklearn.metrics import mean_squared_error

import plotly.io as pio
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
pio.templates.default = "plotly_white"

import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.datasets import MovieLens
from torch_geometric.nn import to_hetero
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, Linear

%matplotlib inline

In [2]:
# Load dataset
dataset = MovieLens(root=os.path.join(os.getcwd(),'movielens'), force_reload=True)
data = dataset[0]

Processing...


Batches:   0%|          | 0/305 [00:00<?, ?it/s]

Done!


In [3]:
data

HeteroData(
  movie={ x=[9742, 404] },
  user={ num_nodes=610 },
  (user, rates, movie)={
    edge_index=[2, 100836],
    edge_label=[100836],
    time=[100836],
  }
)

In [4]:
node_types, edge_types = data.metadata()
print('Node types:', node_types)
print('Edge types:',edge_types)

Node types: ['movie', 'user']
Edge types: [('user', 'rates', 'movie')]


In [5]:
print('Isolated nodes?', data.has_isolated_nodes())
print('Self loops?', data.has_self_loops())
print('Undirected graph? ', data.is_undirected())

Isolated nodes? True
Self loops? False
Undirected graph?  False


In [6]:
# We have an unbalanced dataset with many labels for rating 3 and 4, and very
# few for 0 and 1. Therefore we use a weighted MSE loss.

counts = torch.bincount(data['user', 'movie'].edge_label)
weight = counts.max() / counts

In [7]:
data_dict = {'ratings': (counts, '# rows','coral'), 'weights': (weight, 'weights','royalblue')}

fig = make_subplots(specs=[[{"secondary_y": True}]])

fig.add_trace(
    go.Scatter(x=np.arange(6), y=counts.detach().cpu().numpy(),
               name = 'nb rows', line_color= 'coral'))
fig.add_trace(
    go.Scatter(x=np.arange(6), y=weight.detach().cpu().numpy(),
               name = 'weights', line_color= 'royalblue'),  secondary_y=True)


fig.update_yaxes(title_text="# rows", secondary_y=False)
fig.update_yaxes(title_text="weights", secondary_y=True)
fig.update_xaxes(title_text="Rating")
fig

## 2. Graph-Based Modeling  <a class="anchor" id="section2"></a>

The objective of this section is to train graph-based models to predict the rating a user is likely to give to a movie.  For each model, it tests and evaluates different hyperparameters.

### 2.1. Training a first model <a class="anchor" id="section21"></a>

In [8]:
# Add user node features for message passing:
data['user'].x = torch.eye(data['user'].num_nodes)

In [9]:
# Add a reverse ('movie', 'rev_rates', 'user') relation for message passing:
data = T.ToUndirected()(data)
del data['movie', 'rev_rates', 'user'].edge_label  # Remove "reverse" label.

In [10]:
# Perform a link-level split into training, validation, and test edges:
train_data, val_data, test_data = T.RandomLinkSplit(
    num_val=0.2,
    num_test=0.1,
    neg_sampling_ratio=0.0,
    edge_types=[('user', 'rates', 'movie')],
    rev_edge_types=[('movie', 'rev_rates', 'user')],
)(data)

In [11]:
def weighted_mse_loss(pred, target, weight=None):
    weight = 1. if weight is None else weight[target].to(pred.dtype)
    return (weight * (pred - target.to(pred.dtype)).pow(2)).mean()

In [27]:
class GNNEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, conv):
        super().__init__()
        if conv.__name__ == "GATConv":
            self.conv1 = conv((-1, -1), hidden_channels, add_self_loops=False)
        else: 
            self.conv1 = conv((-1, -1), hidden_channels)

        if conv.__name__ == "GATConv":
            self.conv2 = conv((-1, -1), out_channels, add_self_loops=False)
        else: 
            self.conv2 = conv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        row, col = edge_label_index
        z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)
        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1)

class Model(torch.nn.Module):
    def __init__(self, hidden_channels,  conv=SAGEConv):
        super().__init__()
        self.encoder = GNNEncoder(hidden_channels, hidden_channels, conv)
        self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)

In [28]:
def train(train_data, model, optimizer, loss=weighted_mse_loss):
    model.train()
    optimizer.zero_grad()
    pred = model(train_data.x_dict, train_data.edge_index_dict,
                 train_data['user', 'movie'].edge_label_index)
    target = train_data['user', 'movie'].edge_label
    loss = weighted_mse_loss(pred, target, weight)
    loss.backward()
    optimizer.step()
    return float(loss.sqrt())

@torch.no_grad()
def test(data, model, metric=F.mse_loss):
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict,
                 data['user', 'movie'].edge_label_index)
    pred = pred.clamp(min=0, max=5)
    target = data['user', 'movie'].edge_label.float()
    rmse = F.mse_loss(pred, target).sqrt()
    return float(rmse)

def train_test(model, model_params, learning_rate=0.01, e_patience = 10, min_acc= 0.05, n_epochs=500):
    t0 = time.time()
    model = model(**model_params)

    # Due to lazy initialization, we need to run one model step so the number
    # of parameters can be inferred:
    with torch.no_grad():
        model.encoder(train_data.x_dict, train_data.edge_index_dict)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    k=0
    loss, train_rmse, val_rmse, test_rmse = [], [], [], []
    train_wrmse, val_wrmse, test_wrmse = [], [], []
    for epoch in range(n_epochs):
        loss += [train(train_data, model, optimizer, loss=weighted_mse_loss)]
        
        train_wrmse += [test(train_data, model, metric=weighted_mse_loss)]
        train_rmse += [test(train_data, model, metric=F.mse_loss)]
        
        val_wrmse += [test(val_data, model, metric=weighted_mse_loss)]
        val_rmse += [test(val_data, model, metric=F.mse_loss)]
        
        test_wrmse += [test(test_data, model, metric=weighted_mse_loss)]
        test_rmse += [test(test_data, model, metric=F.mse_loss)]

        if epoch+1 %10==0:
            print(f'Epoch: {epoch+1:03d}, Loss: {loss[-1]:.4f}, Train: {train_rmse[-1]:.4f}, '
                  f'Val: {val_rmse[-1]:.4f}, Test: {test_rmse[-1]:.4f}')

        results = pd.DataFrame({
            'loss': loss,
            'train_rmse': train_rmse, 'val_rmse': val_rmse, 'test_rmse': test_rmse,
            'train_wrmse': train_rmse, 'val_wrmse': val_rmse, 'test_wrmse': test_rmse,
            'time':(time.time()-t0)/60
        })

        # enable early stopping
        if (epoch > 1) and abs(loss[-1]/loss[-2]-1) < min_acc :
            k += 1
        if k> e_patience:
            print('Early stopping')
            break

    return results

def visualize_loss(results, metric='rmse'):
    fig = go.Figure()

    fig.add_trace(go.Scatter(x=results.index, y=results['train_'+metric], name = 'train_'+metric))
    fig.add_trace(go.Scatter(x=results.index, y=results['val_'+metric], name = 'val_'+metric))
    fig.add_trace(go.Scatter(x=results.index, y=results['test_'+metric], name = 'test_'+metric))
    fig.add_trace(go.Scatter(x=results.index, y=results['loss'], name = 'loss'))

    fig.update_yaxes(title_text=metric.upper())
    fig.update_xaxes(title_text="Epoch")

    return fig

In [29]:
N_EPOCHS = 500
E_PATIENCE = 50
LEARNING_RATE = 0.01

In [16]:
model_params = {"hidden_channels":32, 'conv':SAGEConv}

results = train_test(
    Model, model_params, learning_rate=LEARNING_RATE, e_patience = E_PATIENCE, n_epochs=N_EPOCHS)

made it here
Early stopping


In [20]:
visualize_loss(results, metric='wrmse')

### 2.2. Tuning the model hyperparameters <a class="anchor" id="section22"></a>

In [30]:
hidden_channels = [16, 32, 64]
learning_rates = [0.005, 0.01, 0.05]
convs = [GATConv, SAGEConv]

In [26]:
i=0

for hc, lr, c in product(hidden_channels, learning_rates, convs):
    print('Sc: {}/{}'.format(i+1, len(hidden_channels)*len(learning_rates)*len(convs)))
    name_conv = str(c).split('.')[-1].replace('\'>', "")
    print({'hidden_channels':hc, 'learning_rate': lr, 'conv':name_conv})

    model_params = {'hidden_channels':hc, 'conv':c}
    results = train_test(
        Model, model_params, learning_rate=lr, e_patience = E_PATIENCE)

    results['model'] = name_conv
    results['hc'] = hc
    results['lr'] = lr

    results_gnn = results.iloc[-1:] if i==0 else pd.concat([results_gnn, results.iloc[-1:]], axis=0)
    results_gnn.reset_index(drop=True, inplace=True)

    display(results_gnn)

    i+=1

cols = ['model', 'hc', 'lr']
results_gnn = results_gnn[
    cols+[c for c in results_gnn.columns if c not in cols]
].sort_values(by='test_wrmse')

Sc: 1/18
{'hidden_channels': 16, 'learning_rate': 0.005, 'conv': 'GATConv'}
success1!
success2!
output GNNEncoder(
  (conv1): GATConv((-1, -1), 16, heads=1)
  (conv2): GATConv((-1, -1), 16, heads=1)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005


Sc: 2/18
{'hidden_channels': 16, 'learning_rate': 0.005, 'conv': 'SAGEConv'}
output GNNEncoder(
  (conv1): SAGEConv((-1, -1), 16, aggr=mean)
  (conv2): SAGEConv((-1, -1), 16, aggr=mean)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005


Sc: 3/18
{'hidden_channels': 16, 'learning_rate': 0.01, 'conv': 'GATConv'}
success1!
success2!
output GNNEncoder(
  (conv1): GATConv((-1, -1), 16, heads=1)
  (conv2): GATConv((-1, -1), 16, heads=1)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005
2,2.348879,1.280352,1.32259,1.318935,1.280352,1.32259,1.318935,0.176415,GATConv,16,0.01


Sc: 4/18
{'hidden_channels': 16, 'learning_rate': 0.01, 'conv': 'SAGEConv'}
output GNNEncoder(
  (conv1): SAGEConv((-1, -1), 16, aggr=mean)
  (conv2): SAGEConv((-1, -1), 16, aggr=mean)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005
2,2.348879,1.280352,1.32259,1.318935,1.280352,1.32259,1.318935,0.176415,GATConv,16,0.01
3,1.894977,1.187023,1.185236,1.189817,1.187023,1.185236,1.189817,0.38687,SAGEConv,16,0.01


Sc: 5/18
{'hidden_channels': 16, 'learning_rate': 0.05, 'conv': 'GATConv'}
success1!
success2!
output GNNEncoder(
  (conv1): GATConv((-1, -1), 16, heads=1)
  (conv2): GATConv((-1, -1), 16, heads=1)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005
2,2.348879,1.280352,1.32259,1.318935,1.280352,1.32259,1.318935,0.176415,GATConv,16,0.01
3,1.894977,1.187023,1.185236,1.189817,1.187023,1.185236,1.189817,0.38687,SAGEConv,16,0.01
4,2.072306,1.137587,1.150139,1.127268,1.137587,1.150139,1.127268,0.189843,GATConv,16,0.05


Sc: 6/18
{'hidden_channels': 16, 'learning_rate': 0.05, 'conv': 'SAGEConv'}
output GNNEncoder(
  (conv1): SAGEConv((-1, -1), 16, aggr=mean)
  (conv2): SAGEConv((-1, -1), 16, aggr=mean)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005
2,2.348879,1.280352,1.32259,1.318935,1.280352,1.32259,1.318935,0.176415,GATConv,16,0.01
3,1.894977,1.187023,1.185236,1.189817,1.187023,1.185236,1.189817,0.38687,SAGEConv,16,0.01
4,2.072306,1.137587,1.150139,1.127268,1.137587,1.150139,1.127268,0.189843,GATConv,16,0.05
5,1.756259,1.095855,1.131961,1.12405,1.095855,1.131961,1.12405,0.368111,SAGEConv,16,0.05


Sc: 7/18
{'hidden_channels': 32, 'learning_rate': 0.005, 'conv': 'GATConv'}
success1!
success2!
output GNNEncoder(
  (conv1): GATConv((-1, -1), 32, heads=1)
  (conv2): GATConv((-1, -1), 32, heads=1)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005
2,2.348879,1.280352,1.32259,1.318935,1.280352,1.32259,1.318935,0.176415,GATConv,16,0.01
3,1.894977,1.187023,1.185236,1.189817,1.187023,1.185236,1.189817,0.38687,SAGEConv,16,0.01
4,2.072306,1.137587,1.150139,1.127268,1.137587,1.150139,1.127268,0.189843,GATConv,16,0.05
5,1.756259,1.095855,1.131961,1.12405,1.095855,1.131961,1.12405,0.368111,SAGEConv,16,0.05
6,2.393313,1.296184,1.333969,1.334517,1.296184,1.333969,1.334517,0.251543,GATConv,32,0.005


Sc: 8/18
{'hidden_channels': 32, 'learning_rate': 0.005, 'conv': 'SAGEConv'}
output GNNEncoder(
  (conv1): SAGEConv((-1, -1), 32, aggr=mean)
  (conv2): SAGEConv((-1, -1), 32, aggr=mean)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005
2,2.348879,1.280352,1.32259,1.318935,1.280352,1.32259,1.318935,0.176415,GATConv,16,0.01
3,1.894977,1.187023,1.185236,1.189817,1.187023,1.185236,1.189817,0.38687,SAGEConv,16,0.01
4,2.072306,1.137587,1.150139,1.127268,1.137587,1.150139,1.127268,0.189843,GATConv,16,0.05
5,1.756259,1.095855,1.131961,1.12405,1.095855,1.131961,1.12405,0.368111,SAGEConv,16,0.05
6,2.393313,1.296184,1.333969,1.334517,1.296184,1.333969,1.334517,0.251543,GATConv,32,0.005
7,1.952181,1.17057,1.166871,1.173838,1.17057,1.166871,1.173838,0.424624,SAGEConv,32,0.005


Sc: 9/18
{'hidden_channels': 32, 'learning_rate': 0.01, 'conv': 'GATConv'}
success1!
success2!
output GNNEncoder(
  (conv1): GATConv((-1, -1), 32, heads=1)
  (conv2): GATConv((-1, -1), 32, heads=1)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005
2,2.348879,1.280352,1.32259,1.318935,1.280352,1.32259,1.318935,0.176415,GATConv,16,0.01
3,1.894977,1.187023,1.185236,1.189817,1.187023,1.185236,1.189817,0.38687,SAGEConv,16,0.01
4,2.072306,1.137587,1.150139,1.127268,1.137587,1.150139,1.127268,0.189843,GATConv,16,0.05
5,1.756259,1.095855,1.131961,1.12405,1.095855,1.131961,1.12405,0.368111,SAGEConv,16,0.05
6,2.393313,1.296184,1.333969,1.334517,1.296184,1.333969,1.334517,0.251543,GATConv,32,0.005
7,1.952181,1.17057,1.166871,1.173838,1.17057,1.166871,1.173838,0.424624,SAGEConv,32,0.005
8,2.244836,1.309098,1.329747,1.291823,1.309098,1.329747,1.291823,0.241752,GATConv,32,0.01


Sc: 10/18
{'hidden_channels': 32, 'learning_rate': 0.01, 'conv': 'SAGEConv'}
output GNNEncoder(
  (conv1): SAGEConv((-1, -1), 32, aggr=mean)
  (conv2): SAGEConv((-1, -1), 32, aggr=mean)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005
2,2.348879,1.280352,1.32259,1.318935,1.280352,1.32259,1.318935,0.176415,GATConv,16,0.01
3,1.894977,1.187023,1.185236,1.189817,1.187023,1.185236,1.189817,0.38687,SAGEConv,16,0.01
4,2.072306,1.137587,1.150139,1.127268,1.137587,1.150139,1.127268,0.189843,GATConv,16,0.05
5,1.756259,1.095855,1.131961,1.12405,1.095855,1.131961,1.12405,0.368111,SAGEConv,16,0.05
6,2.393313,1.296184,1.333969,1.334517,1.296184,1.333969,1.334517,0.251543,GATConv,32,0.005
7,1.952181,1.17057,1.166871,1.173838,1.17057,1.166871,1.173838,0.424624,SAGEConv,32,0.005
8,2.244836,1.309098,1.329747,1.291823,1.309098,1.329747,1.291823,0.241752,GATConv,32,0.01
9,1.907324,1.194591,1.19261,1.193682,1.194591,1.19261,1.193682,0.463669,SAGEConv,32,0.01


Sc: 11/18
{'hidden_channels': 32, 'learning_rate': 0.05, 'conv': 'GATConv'}
success1!
success2!
output GNNEncoder(
  (conv1): GATConv((-1, -1), 32, heads=1)
  (conv2): GATConv((-1, -1), 32, heads=1)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005
2,2.348879,1.280352,1.32259,1.318935,1.280352,1.32259,1.318935,0.176415,GATConv,16,0.01
3,1.894977,1.187023,1.185236,1.189817,1.187023,1.185236,1.189817,0.38687,SAGEConv,16,0.01
4,2.072306,1.137587,1.150139,1.127268,1.137587,1.150139,1.127268,0.189843,GATConv,16,0.05
5,1.756259,1.095855,1.131961,1.12405,1.095855,1.131961,1.12405,0.368111,SAGEConv,16,0.05
6,2.393313,1.296184,1.333969,1.334517,1.296184,1.333969,1.334517,0.251543,GATConv,32,0.005
7,1.952181,1.17057,1.166871,1.173838,1.17057,1.166871,1.173838,0.424624,SAGEConv,32,0.005
8,2.244836,1.309098,1.329747,1.291823,1.309098,1.329747,1.291823,0.241752,GATConv,32,0.01
9,1.907324,1.194591,1.19261,1.193682,1.194591,1.19261,1.193682,0.463669,SAGEConv,32,0.01


Sc: 12/18
{'hidden_channels': 32, 'learning_rate': 0.05, 'conv': 'SAGEConv'}
output GNNEncoder(
  (conv1): SAGEConv((-1, -1), 32, aggr=mean)
  (conv2): SAGEConv((-1, -1), 32, aggr=mean)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005
2,2.348879,1.280352,1.32259,1.318935,1.280352,1.32259,1.318935,0.176415,GATConv,16,0.01
3,1.894977,1.187023,1.185236,1.189817,1.187023,1.185236,1.189817,0.38687,SAGEConv,16,0.01
4,2.072306,1.137587,1.150139,1.127268,1.137587,1.150139,1.127268,0.189843,GATConv,16,0.05
5,1.756259,1.095855,1.131961,1.12405,1.095855,1.131961,1.12405,0.368111,SAGEConv,16,0.05
6,2.393313,1.296184,1.333969,1.334517,1.296184,1.333969,1.334517,0.251543,GATConv,32,0.005
7,1.952181,1.17057,1.166871,1.173838,1.17057,1.166871,1.173838,0.424624,SAGEConv,32,0.005
8,2.244836,1.309098,1.329747,1.291823,1.309098,1.329747,1.291823,0.241752,GATConv,32,0.01
9,1.907324,1.194591,1.19261,1.193682,1.194591,1.19261,1.193682,0.463669,SAGEConv,32,0.01


Sc: 13/18
{'hidden_channels': 64, 'learning_rate': 0.005, 'conv': 'GATConv'}
success1!
success2!
output GNNEncoder(
  (conv1): GATConv((-1, -1), 64, heads=1)
  (conv2): GATConv((-1, -1), 64, heads=1)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005
2,2.348879,1.280352,1.32259,1.318935,1.280352,1.32259,1.318935,0.176415,GATConv,16,0.01
3,1.894977,1.187023,1.185236,1.189817,1.187023,1.185236,1.189817,0.38687,SAGEConv,16,0.01
4,2.072306,1.137587,1.150139,1.127268,1.137587,1.150139,1.127268,0.189843,GATConv,16,0.05
5,1.756259,1.095855,1.131961,1.12405,1.095855,1.131961,1.12405,0.368111,SAGEConv,16,0.05
6,2.393313,1.296184,1.333969,1.334517,1.296184,1.333969,1.334517,0.251543,GATConv,32,0.005
7,1.952181,1.17057,1.166871,1.173838,1.17057,1.166871,1.173838,0.424624,SAGEConv,32,0.005
8,2.244836,1.309098,1.329747,1.291823,1.309098,1.329747,1.291823,0.241752,GATConv,32,0.01
9,1.907324,1.194591,1.19261,1.193682,1.194591,1.19261,1.193682,0.463669,SAGEConv,32,0.01


Sc: 14/18
{'hidden_channels': 64, 'learning_rate': 0.005, 'conv': 'SAGEConv'}
output GNNEncoder(
  (conv1): SAGEConv((-1, -1), 64, aggr=mean)
  (conv2): SAGEConv((-1, -1), 64, aggr=mean)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005
2,2.348879,1.280352,1.32259,1.318935,1.280352,1.32259,1.318935,0.176415,GATConv,16,0.01
3,1.894977,1.187023,1.185236,1.189817,1.187023,1.185236,1.189817,0.38687,SAGEConv,16,0.01
4,2.072306,1.137587,1.150139,1.127268,1.137587,1.150139,1.127268,0.189843,GATConv,16,0.05
5,1.756259,1.095855,1.131961,1.12405,1.095855,1.131961,1.12405,0.368111,SAGEConv,16,0.05
6,2.393313,1.296184,1.333969,1.334517,1.296184,1.333969,1.334517,0.251543,GATConv,32,0.005
7,1.952181,1.17057,1.166871,1.173838,1.17057,1.166871,1.173838,0.424624,SAGEConv,32,0.005
8,2.244836,1.309098,1.329747,1.291823,1.309098,1.329747,1.291823,0.241752,GATConv,32,0.01
9,1.907324,1.194591,1.19261,1.193682,1.194591,1.19261,1.193682,0.463669,SAGEConv,32,0.01


Sc: 15/18
{'hidden_channels': 64, 'learning_rate': 0.01, 'conv': 'GATConv'}
success1!
success2!
output GNNEncoder(
  (conv1): GATConv((-1, -1), 64, heads=1)
  (conv2): GATConv((-1, -1), 64, heads=1)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005
2,2.348879,1.280352,1.32259,1.318935,1.280352,1.32259,1.318935,0.176415,GATConv,16,0.01
3,1.894977,1.187023,1.185236,1.189817,1.187023,1.185236,1.189817,0.38687,SAGEConv,16,0.01
4,2.072306,1.137587,1.150139,1.127268,1.137587,1.150139,1.127268,0.189843,GATConv,16,0.05
5,1.756259,1.095855,1.131961,1.12405,1.095855,1.131961,1.12405,0.368111,SAGEConv,16,0.05
6,2.393313,1.296184,1.333969,1.334517,1.296184,1.333969,1.334517,0.251543,GATConv,32,0.005
7,1.952181,1.17057,1.166871,1.173838,1.17057,1.166871,1.173838,0.424624,SAGEConv,32,0.005
8,2.244836,1.309098,1.329747,1.291823,1.309098,1.329747,1.291823,0.241752,GATConv,32,0.01
9,1.907324,1.194591,1.19261,1.193682,1.194591,1.19261,1.193682,0.463669,SAGEConv,32,0.01


Sc: 16/18
{'hidden_channels': 64, 'learning_rate': 0.01, 'conv': 'SAGEConv'}
output GNNEncoder(
  (conv1): SAGEConv((-1, -1), 64, aggr=mean)
  (conv2): SAGEConv((-1, -1), 64, aggr=mean)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005
2,2.348879,1.280352,1.32259,1.318935,1.280352,1.32259,1.318935,0.176415,GATConv,16,0.01
3,1.894977,1.187023,1.185236,1.189817,1.187023,1.185236,1.189817,0.38687,SAGEConv,16,0.01
4,2.072306,1.137587,1.150139,1.127268,1.137587,1.150139,1.127268,0.189843,GATConv,16,0.05
5,1.756259,1.095855,1.131961,1.12405,1.095855,1.131961,1.12405,0.368111,SAGEConv,16,0.05
6,2.393313,1.296184,1.333969,1.334517,1.296184,1.333969,1.334517,0.251543,GATConv,32,0.005
7,1.952181,1.17057,1.166871,1.173838,1.17057,1.166871,1.173838,0.424624,SAGEConv,32,0.005
8,2.244836,1.309098,1.329747,1.291823,1.309098,1.329747,1.291823,0.241752,GATConv,32,0.01
9,1.907324,1.194591,1.19261,1.193682,1.194591,1.19261,1.193682,0.463669,SAGEConv,32,0.01


Sc: 17/18
{'hidden_channels': 64, 'learning_rate': 0.05, 'conv': 'GATConv'}
success1!
success2!
output GNNEncoder(
  (conv1): GATConv((-1, -1), 64, heads=1)
  (conv2): GATConv((-1, -1), 64, heads=1)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005
2,2.348879,1.280352,1.32259,1.318935,1.280352,1.32259,1.318935,0.176415,GATConv,16,0.01
3,1.894977,1.187023,1.185236,1.189817,1.187023,1.185236,1.189817,0.38687,SAGEConv,16,0.01
4,2.072306,1.137587,1.150139,1.127268,1.137587,1.150139,1.127268,0.189843,GATConv,16,0.05
5,1.756259,1.095855,1.131961,1.12405,1.095855,1.131961,1.12405,0.368111,SAGEConv,16,0.05
6,2.393313,1.296184,1.333969,1.334517,1.296184,1.333969,1.334517,0.251543,GATConv,32,0.005
7,1.952181,1.17057,1.166871,1.173838,1.17057,1.166871,1.173838,0.424624,SAGEConv,32,0.005
8,2.244836,1.309098,1.329747,1.291823,1.309098,1.329747,1.291823,0.241752,GATConv,32,0.01
9,1.907324,1.194591,1.19261,1.193682,1.194591,1.19261,1.193682,0.463669,SAGEConv,32,0.01


Sc: 18/18
{'hidden_channels': 64, 'learning_rate': 0.05, 'conv': 'SAGEConv'}
output GNNEncoder(
  (conv1): SAGEConv((-1, -1), 64, aggr=mean)
  (conv2): SAGEConv((-1, -1), 64, aggr=mean)
)
Early stopping


Unnamed: 0,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time,model,hc,lr
0,2.450284,1.304201,1.348433,1.347838,1.304201,1.348433,1.347838,0.185396,GATConv,16,0.005
1,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496,SAGEConv,16,0.005
2,2.348879,1.280352,1.32259,1.318935,1.280352,1.32259,1.318935,0.176415,GATConv,16,0.01
3,1.894977,1.187023,1.185236,1.189817,1.187023,1.185236,1.189817,0.38687,SAGEConv,16,0.01
4,2.072306,1.137587,1.150139,1.127268,1.137587,1.150139,1.127268,0.189843,GATConv,16,0.05
5,1.756259,1.095855,1.131961,1.12405,1.095855,1.131961,1.12405,0.368111,SAGEConv,16,0.05
6,2.393313,1.296184,1.333969,1.334517,1.296184,1.333969,1.334517,0.251543,GATConv,32,0.005
7,1.952181,1.17057,1.166871,1.173838,1.17057,1.166871,1.173838,0.424624,SAGEConv,32,0.005
8,2.244836,1.309098,1.329747,1.291823,1.309098,1.329747,1.291823,0.241752,GATConv,32,0.01
9,1.907324,1.194591,1.19261,1.193682,1.194591,1.19261,1.193682,0.463669,SAGEConv,32,0.01


In [31]:
results_gnn

Unnamed: 0,model,hc,lr,loss,train_rmse,val_rmse,test_rmse,train_wrmse,val_wrmse,test_wrmse,time
5,SAGEConv,16,0.05,1.756259,1.095855,1.131961,1.12405,1.095855,1.131961,1.12405,0.368111
4,GATConv,16,0.05,2.072306,1.137587,1.150139,1.127268,1.137587,1.150139,1.127268,0.189843
15,SAGEConv,64,0.01,1.85277,1.142838,1.148923,1.150611,1.142838,1.148923,1.150611,0.577445
7,SAGEConv,32,0.005,1.952181,1.17057,1.166871,1.173838,1.17057,1.166871,1.173838,0.424624
13,SAGEConv,64,0.005,1.90705,1.184614,1.184296,1.185819,1.184614,1.184296,1.185819,0.486634
3,SAGEConv,16,0.01,1.894977,1.187023,1.185236,1.189817,1.187023,1.185236,1.189817,0.38687
11,SAGEConv,32,0.05,1.967347,1.185745,1.187492,1.19182,1.185745,1.187492,1.19182,0.481373
9,SAGEConv,32,0.01,1.907324,1.194591,1.19261,1.193682,1.194591,1.19261,1.193682,0.463669
16,GATConv,64,0.05,2.039817,1.085559,1.191846,1.197255,1.085559,1.191846,1.197255,0.41325
1,SAGEConv,16,0.005,2.199547,1.266163,1.267525,1.275587,1.266163,1.267525,1.275587,0.403496


In [32]:
metric='wrmse'
cols = ['train', 'test', 'val']

table = results_gnn.groupby(['model'],as_index=False).agg(
    {c+'_'+metric:'first' for c in cols})

colors = ['#404B69', '#F73859', '#666d9b', '#ddddff', '#DBEDF3']

r_colors = {m: colors[i] for i, m in enumerate(table.model.unique())}

fig = go.Figure(data=[
    go.Bar(
        name=m, x=cols, y=[table[table.model==m][c+'_'+metric].iloc[0] for c in cols],
        textposition='outside',
        text=['{}'.format(round(a,3)) for a in [table[table.model==m][c+'_'+metric].iloc[0] for c in cols]],
        marker_color= r_colors[m]
    )
    for m in table.model.unique()])

fig.update_layout(barmode='group', width=800, height=400)
fig.update_yaxes(range=[0, 1.5])
fig.update_layout(legend=dict(
    orientation="h",
    yanchor="bottom",
    y=1.02,
    xanchor="right",
    x=1
))
fig.show()

In [33]:
param_names = ['hc', 'lr']

best_params = results_gnn.groupby(['model'],as_index=False).agg(
    {p:'first' for p in param_names})
best_params.index= best_params.model
best_params.drop('model', axis=1, inplace=True)
best_params = best_params.to_dict(orient='index')

In [34]:
best_params

{'GATConv': {'hc': 16, 'lr': 0.05}, 'SAGEConv': {'hc': 16, 'lr': 0.05}}