In [1]:
import networkx as nx
import pandas as pd
import numpy as np
import itertools
import os

import matplotlib.pyplot as plt

import stellargraph as sg
from stellargraph.data import EdgeSplitter
from stellargraph.mapper import GraphSAGELinkGenerator
from stellargraph.layer import GraphSAGE, link_classification
from stellargraph.calibration import expected_calibration_error, plot_reliability_diagram
from stellargraph.calibration import IsotonicCalibration, TemperatureCalibration
from stellargraph import StellarGraph

from tensorflow import keras
from sklearn import preprocessing, feature_extraction, model_selection
from sklearn.calibration import calibration_curve
from sklearn.isotonic import IsotonicRegression

from sklearn.metrics import accuracy_score

from stellargraph import globalvar
from stellargraph import datasets
from IPython.display import display, HTML

%matplotlib inline

In [2]:
batch_size = 50
epochs = 20  # The number of training epochs for training the GraphSAGE model.

In [3]:
playlist_nodes = pd.read_csv('D:\\extracted_features\\playlists_256_full.csv')
playlist_nodes = playlist_nodes.set_index('pid')
playlist_nodes

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,246,247,248,249,250,251,252,253,254,255
pid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0.548065,-0.004969,-0.152679,0.616827,0.274475,0.976250,-0.902120,0.046461,-0.536994,0.020827,...,-0.113749,-0.910833,0.997328,-0.815965,-0.784344,-0.776490,-0.995361,0.768619,0.870750,0.010204
1,0.546270,-0.033351,-0.392608,0.468498,0.234031,0.957838,-0.913424,0.029962,-0.517722,0.039464,...,-0.086244,-0.909943,0.995852,-0.778645,-0.843372,-0.733728,-0.996370,0.646433,0.860160,0.056390
2,0.253194,-0.020886,-0.405861,0.620173,0.321243,0.968687,-0.886844,0.061407,-0.320926,0.023277,...,-0.104075,-0.907041,0.993235,-0.916639,-0.735473,-0.806277,-0.994427,0.797605,0.815965,0.025856
3,0.387236,-0.055463,-0.166385,0.481297,0.278856,0.972705,-0.885096,0.059435,-0.443821,0.010571,...,-0.107550,-0.914139,0.993060,-0.819462,-0.833836,-0.794985,-0.994906,0.583700,0.758970,0.022974
4,0.570883,-0.019489,-0.218323,0.646070,0.287380,0.965244,-0.866297,0.019621,-0.610707,0.014086,...,-0.085111,-0.889409,0.996961,-0.819861,-0.806127,-0.684427,-0.995705,0.726435,0.890715,-0.050058
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14995,0.509209,-0.032868,-0.259281,0.576987,0.271568,0.977360,-0.913212,0.061560,-0.450586,0.038395,...,-0.120194,-0.904451,0.996007,-0.853962,-0.829651,-0.785840,-0.994907,0.715587,0.847042,0.020912
14996,0.532920,0.023744,-0.080581,0.642182,0.273193,0.974883,-0.890937,0.028100,-0.318348,0.036292,...,-0.128532,-0.891057,0.996765,-0.848702,-0.856475,-0.793554,-0.993932,0.639508,0.766159,0.043139
14997,0.438527,0.008225,-0.064779,0.618701,0.240719,0.973851,-0.883567,0.035740,-0.477776,0.058208,...,-0.120472,-0.915622,0.996209,-0.861927,-0.823383,-0.751569,-0.996538,0.701738,0.850584,0.037790
14998,0.579338,-0.052516,-0.360727,0.458734,0.264151,0.979575,-0.924249,0.046682,-0.514681,0.025852,...,-0.082898,-0.879444,0.997280,-0.776388,-0.907032,-0.776681,-0.995498,0.722812,0.852588,-0.011156


In [6]:
track_nodes = pd.read_csv('D:\\extracted_features\\tracks_256_full.csv')
track_nodes = track_nodes.set_index('track_uri')
all_nodes = pd.concat([track_nodes,playlist_nodes], axis=0)
all_nodes['id'] = all_nodes.index
all_nodes

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,247,248,249,250,251,252,253,254,255,id
spotify:track:000GjfnQc7ggBayDiy1sLW,0.464777,-0.125684,0.715710,0.358084,0.303762,0.954011,-0.295859,0.028331,-0.495248,-0.033428,...,-0.947096,0.992281,-0.958093,-0.953317,-0.917473,-0.997311,0.216808,0.580898,0.059023,spotify:track:000GjfnQc7ggBayDiy1sLW
spotify:track:000JBgYWfJQqdFaRqu2n3f,0.565548,0.048264,-0.635601,0.347392,0.253677,0.982890,-0.961283,-0.019696,-0.405045,-0.009414,...,-0.837499,0.995980,-0.911891,-0.638338,-0.850850,-0.997972,0.600172,0.876552,-0.082913,spotify:track:000JBgYWfJQqdFaRqu2n3f
spotify:track:000ULyVqUhqnAyA0Um3MEH,0.097290,-0.025617,-0.640280,0.257316,0.254600,0.988062,-0.963823,0.184923,-0.437594,0.033821,...,-0.940450,0.988101,-0.970540,-0.871060,-0.840048,-0.997457,0.465446,0.756995,0.014149,spotify:track:000ULyVqUhqnAyA0Um3MEH
spotify:track:000VZqvXwT0YNqKk7iG2GS,0.563701,-0.095213,-0.217249,-0.281310,0.295809,0.981444,-0.954634,0.029152,-0.755409,-0.055398,...,-0.833007,0.994377,-0.762435,-0.914978,-0.932535,-0.999104,0.015044,0.745456,-0.087516,spotify:track:000VZqvXwT0YNqKk7iG2GS
spotify:track:000mA0etY38nKdvf1N04af,0.822130,0.113043,0.050832,-0.044618,0.432734,0.974628,-0.971879,-0.043917,-0.616016,-0.038279,...,-0.927730,0.998422,-0.863780,-0.951869,-0.879061,-0.997855,0.898769,0.932477,-0.217486,spotify:track:000mA0etY38nKdvf1N04af
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14995,0.509209,-0.032868,-0.259281,0.576987,0.271568,0.977360,-0.913212,0.061560,-0.450586,0.038395,...,-0.904451,0.996007,-0.853962,-0.829651,-0.785840,-0.994907,0.715587,0.847042,0.020912,14995
14996,0.532920,0.023744,-0.080581,0.642182,0.273193,0.974883,-0.890937,0.028100,-0.318348,0.036292,...,-0.891057,0.996765,-0.848702,-0.856475,-0.793554,-0.993932,0.639508,0.766159,0.043139,14996
14997,0.438527,0.008225,-0.064779,0.618701,0.240719,0.973851,-0.883567,0.035740,-0.477776,0.058208,...,-0.915622,0.996209,-0.861927,-0.823383,-0.751569,-0.996538,0.701738,0.850584,0.037790,14997
14998,0.579338,-0.052516,-0.360727,0.458734,0.264151,0.979575,-0.924249,0.046682,-0.514681,0.025852,...,-0.879444,0.997280,-0.776388,-0.907032,-0.776681,-0.995498,0.722812,0.852588,-0.011156,14998


In [10]:
cols = all_nodes.columns.tolist() 
cols = cols[-1:] + cols[:-1] 
all_nodes[cols]

Unnamed: 0,id,0,1,2,3,4,5,6,7,8,...,246,247,248,249,250,251,252,253,254,255
spotify:track:000GjfnQc7ggBayDiy1sLW,spotify:track:000GjfnQc7ggBayDiy1sLW,0.464777,-0.125684,0.715710,0.358084,0.303762,0.954011,-0.295859,0.028331,-0.495248,...,-0.242615,-0.947096,0.992281,-0.958093,-0.953317,-0.917473,-0.997311,0.216808,0.580898,0.059023
spotify:track:000JBgYWfJQqdFaRqu2n3f,spotify:track:000JBgYWfJQqdFaRqu2n3f,0.565548,0.048264,-0.635601,0.347392,0.253677,0.982890,-0.961283,-0.019696,-0.405045,...,-0.089524,-0.837499,0.995980,-0.911891,-0.638338,-0.850850,-0.997972,0.600172,0.876552,-0.082913
spotify:track:000ULyVqUhqnAyA0Um3MEH,spotify:track:000ULyVqUhqnAyA0Um3MEH,0.097290,-0.025617,-0.640280,0.257316,0.254600,0.988062,-0.963823,0.184923,-0.437594,...,-0.187630,-0.940450,0.988101,-0.970540,-0.871060,-0.840048,-0.997457,0.465446,0.756995,0.014149
spotify:track:000VZqvXwT0YNqKk7iG2GS,spotify:track:000VZqvXwT0YNqKk7iG2GS,0.563701,-0.095213,-0.217249,-0.281310,0.295809,0.981444,-0.954634,0.029152,-0.755409,...,0.102217,-0.833007,0.994377,-0.762435,-0.914978,-0.932535,-0.999104,0.015044,0.745456,-0.087516
spotify:track:000mA0etY38nKdvf1N04af,spotify:track:000mA0etY38nKdvf1N04af,0.822130,0.113043,0.050832,-0.044618,0.432734,0.974628,-0.971879,-0.043917,-0.616016,...,-0.006954,-0.927730,0.998422,-0.863780,-0.951869,-0.879061,-0.997855,0.898769,0.932477,-0.217486
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14995,14995,0.509209,-0.032868,-0.259281,0.576987,0.271568,0.977360,-0.913212,0.061560,-0.450586,...,-0.120194,-0.904451,0.996007,-0.853962,-0.829651,-0.785840,-0.994907,0.715587,0.847042,0.020912
14996,14996,0.532920,0.023744,-0.080581,0.642182,0.273193,0.974883,-0.890937,0.028100,-0.318348,...,-0.128532,-0.891057,0.996765,-0.848702,-0.856475,-0.793554,-0.993932,0.639508,0.766159,0.043139
14997,14997,0.438527,0.008225,-0.064779,0.618701,0.240719,0.973851,-0.883567,0.035740,-0.477776,...,-0.120472,-0.915622,0.996209,-0.861927,-0.823383,-0.751569,-0.996538,0.701738,0.850584,0.037790
14998,14998,0.579338,-0.052516,-0.360727,0.458734,0.264151,0.979575,-0.924249,0.046682,-0.514681,...,-0.082898,-0.879444,0.997280,-0.776388,-0.907032,-0.776681,-0.995498,0.722812,0.852588,-0.011156


In [5]:
import csv

csv_headers = ['identifier']
[csv_headers.append(i) for i in range(0, 256)]

with open('D:\\extracted_features\\all_nodes.csv', 'w', encoding='UTF8', newline='') as f:
    writer = csv.writer(f)
    
    writer.writerow(csv_headers)
    
    writer.writerows(all_nodes.values)

In [7]:
train_graph = pd.read_csv('D:\\extracted_features\\train_edges_graph.csv').drop('pos', axis=1)
test_graph = pd.read_csv('D:\\extracted_features\\test_edges_graph.csv').drop('pos', axis=1)

In [13]:
train_graph.columns = ['target', 'source']
train_graph

Unnamed: 0,target,source
0,0,spotify:track:0UaMYEvWZi0ZqiDOoHU3YI
1,0,spotify:track:6I9VzXrHxO9rA9A5euc8Ak
2,0,spotify:track:0WqIKmW4BTrj3eJFmnCKMv
3,0,spotify:track:1AWQoqb9bSvzTjaLralEkT
4,0,spotify:track:1lzr43nnXAijIGYnCT8M8H
...,...,...
767001,14999,spotify:track:5hTpBe8h35rJ67eAWHQsJx
767002,14999,spotify:track:3kxfsdsCpFgN412fpnW85Y
767003,14999,spotify:track:6eT7xZZlB2mwyzJ2sUKG6w
767004,14999,spotify:track:4Q3N4Ct4zCuIHuZ65E3BD4


In [14]:
test_graph.columns = ['target', 'source']
test_graph

Unnamed: 0,target,source
0,0,spotify:track:0UaMYEvWZi0ZqiDOoHU3YI
1,0,spotify:track:6I9VzXrHxO9rA9A5euc8Ak
2,0,spotify:track:0WqIKmW4BTrj3eJFmnCKMv
3,0,spotify:track:1AWQoqb9bSvzTjaLralEkT
4,0,spotify:track:1lzr43nnXAijIGYnCT8M8H
...,...,...
922430,14999,spotify:track:5hTpBe8h35rJ67eAWHQsJx
922431,14999,spotify:track:3kxfsdsCpFgN412fpnW85Y
922432,14999,spotify:track:6eT7xZZlB2mwyzJ2sUKG6w
922433,14999,spotify:track:4Q3N4Ct4zCuIHuZ65E3BD4


In [42]:
G_train = StellarGraph({'item': all_nodes} , train_graph)
print(G_train.info())

G_test = StellarGraph({'item': all_nodes} , test_graph)
print(G_test.info())

StellarGraph: Undirected multigraph
 Nodes: 233129, Edges: 767006

 Node types:
  item: [233129]
    Features: float32 vector, length 256
    Edge types: item-default->item

 Edge types:
    item-default->item: [767006]
        Weights: all 1 (default)
        Features: none
StellarGraph: Undirected multigraph
 Nodes: 233129, Edges: 922435

 Node types:
  item: [233129]
    Features: float32 vector, length 256
    Edge types: item-default->item

 Edge types:
    item-default->item: [922435]
        Weights: all 1 (default)
        Features: none


In [16]:
training_samples = pd.read_csv('D:\\extracted_features\\train_samples.csv')
testing_samples = pd.read_csv('D:\\extracted_features\\test_samples.csv')

In [26]:
training_labels = np.asarray(training_samples['link'].values)
testing_labels = np.asarray(testing_samples['link'].values)

In [36]:
training_sample_nodes = np.asarray(training_samples.values[:,:2])
testing_sample_nodes = np.asarray(testing_samples.values[:,:2])

In [37]:
num_samples = [10, 5]

In [43]:
train_gen = GraphSAGELinkGenerator(G_train, batch_size, num_samples)
test_gen = GraphSAGELinkGenerator(G_test, batch_size, num_samples)

In [44]:
layer_sizes = [32, 32]
graphsage = GraphSAGE(
    layer_sizes=layer_sizes, generator=train_gen, bias=True, dropout=0.2
)

In [45]:
x_inp, x_out = graphsage.in_out_tensors()

In [46]:
logits = link_classification(
    output_dim=1, output_act="linear", edge_embedding_method="ip"
)(x_out)

prediction = keras.layers.Activation(keras.activations.sigmoid)(logits)

link_classification: using 'ip' method to combine node embeddings into edge embeddings


In [48]:
model = keras.Model(inputs=x_inp, outputs=prediction)

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.binary_crossentropy,
    metrics=[keras.metrics.binary_accuracy],
)

In [49]:
train_flow = train_gen.flow(training_sample_nodes, training_labels, shuffle=True)
test_flow = test_gen.flow(testing_sample_nodes, testing_labels)

In [50]:
history = model.fit(
    train_flow, epochs=10, verbose=1, shuffle=True
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [51]:
train_metrics = model.evaluate(train_flow)
test_metrics = model.evaluate(test_flow)

print("\nTrain Set Metrics of the trained model:")
for name, val in zip(model.metrics_names, train_metrics):
    print("\t{}: {:0.4f}".format(name, val))

print("\nTest Set Metrics of the trained model:")
for name, val in zip(model.metrics_names, test_metrics):
    print("\t{}: {:0.4f}".format(name, val))


Train Set Metrics of the trained model:
	loss: 0.4901
	binary_accuracy: 0.8487

Test Set Metrics of the trained model:
	loss: 0.5042
	binary_accuracy: 0.8299


In [53]:
y_pred = model.predict(test_flow)
y_pred

array([[0.69759125],
       [0.7119384 ],
       [0.5603485 ],
       ...,
       [0.67303574],
       [0.6317589 ],
       [0.34943926]], dtype=float32)

In [61]:
y_predicted = []
[y_predicted.append(int(round(value[0], 0))) for value in y_pred]

from sklearn.metrics import roc_auc_score

roc_auc_score(testing_labels, y_predicted)

0.8301655829625925

In [62]:
from sklearn.metrics import average_precision_score

average_precision_score(testing_labels, y_predicted)

0.7605334689705625

In [63]:
GHIN_train = StellarGraph({'playlist': playlist_nodes, 'track': track_nodes} , train_graph)
print(GHIN_train.info())

GHIN_test = StellarGraph({'playlist': playlist_nodes, 'track': track_nodes} , test_graph)
print(GHIN_test.info())

StellarGraph: Undirected multigraph
 Nodes: 233129, Edges: 767006

 Node types:
  track: [218129]
    Features: float32 vector, length 256
    Edge types: track-default->playlist
  playlist: [15000]
    Features: float32 vector, length 256
    Edge types: playlist-default->track

 Edge types:
    playlist-default->track: [767006]
        Weights: all 1 (default)
        Features: none
StellarGraph: Undirected multigraph
 Nodes: 233129, Edges: 922435

 Node types:
  track: [218129]
    Features: float32 vector, length 256
    Edge types: track-default->playlist
  playlist: [15000]
    Features: float32 vector, length 256
    Edge types: playlist-default->track

 Edge types:
    playlist-default->track: [922435]
        Weights: all 1 (default)
        Features: none


In [65]:
from stellargraph.mapper import HinSAGELinkGenerator

train_gen = HinSAGELinkGenerator(GHIN_train, batch_size, num_samples, head_node_types=["playlist", "track"])
test_gen = HinSAGELinkGenerator(GHIN_test, batch_size, num_samples, head_node_types=["playlist", "track"])

In [66]:
from stellargraph.layer import HinSAGE

layer_sizes = [32, 32]
hinsage = HinSAGE(
    layer_sizes=layer_sizes, generator=train_gen, bias=True, dropout=0.2
)

In [67]:
x_inp, x_out = hinsage.in_out_tensors()

In [68]:
logits = link_classification(
    output_dim=1, output_act="linear", edge_embedding_method="ip"
)(x_out)

prediction = keras.layers.Activation(keras.activations.sigmoid)(logits)

link_classification: using 'ip' method to combine node embeddings into edge embeddings


In [69]:
modelHIN = keras.Model(inputs=x_inp, outputs=prediction)

modelHIN.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.binary_crossentropy,
    metrics=[keras.metrics.binary_accuracy],
)

In [70]:
history2 = modelHIN.fit(
    train_flow, epochs=10, verbose=1, shuffle=True
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [71]:
y_pred = modelHIN.predict(test_flow)
y_pred

array([[0.72315085],
       [0.717582  ],
       [0.72284466],
       ...,
       [0.527136  ],
       [0.63651794],
       [0.55006355]], dtype=float32)

In [72]:
y_predicted = []
[y_predicted.append(int(round(value[0], 0))) for value in y_pred]

roc_auc_score(testing_labels, y_predicted)

0.8133637823465922

In [73]:
average_precision_score(testing_labels, y_predicted)

0.7464333227652322

In [74]:
train_metrics = modelHIN.evaluate(train_flow)
test_metrics = modelHIN.evaluate(test_flow)

print("\nTrain Set Metrics of the trained model:")
for name, val in zip(modelHIN.metrics_names, train_metrics):
    print("\t{}: {:0.4f}".format(name, val))

print("\nTest Set Metrics of the trained model:")
for name, val in zip(modelHIN.metrics_names, test_metrics):
    print("\t{}: {:0.4f}".format(name, val))


Train Set Metrics of the trained model:
	loss: 0.4847
	binary_accuracy: 0.8526

Test Set Metrics of the trained model:
	loss: 0.5148
	binary_accuracy: 0.8128
