In [1]:
# %load_ext autoreload
# %autoreload 2

In [31]:
import pandas as pd

from stellargraph import StellarGraph
from stellargraph.data import UnsupervisedSampler, EdgeSplitter, BiasedRandomWalk
from stellargraph.mapper import HinSAGELinkGenerator
from stellargraph.layer import HinSAGE, link_classification

# Load the data

In [3]:
from shared.schema import DatasetSchema, GraphSchema
from shared.graph.loading import pd_from_entity_schema


In [4]:
# Load the Star Wars dataset description, then derive the graph schema from it.
# `schema.nodes` / `schema.edges` map each label to its entity schema and are
# iterated below to build one DataFrame per label.
DATASET = DatasetSchema.load_schema('star-wars')
schema = GraphSchema.from_dataset(DATASET)

In [5]:
# Options forwarded unchanged to every pd_from_entity_schema call below
# (both node and edge labels).
explicit_label, explicit_timestamp, unix_timestamp = False, True, True
prefix_id = None


def include_properties(cs):
    """Return the names in ``cs`` that start with ``feat_``, in their original order.

    Passed as the ``include_properties`` filter to pd_from_entity_schema.
    """
    return [column for column in cs if column.startswith('feat_')]

# One node DataFrame per node label. Each frame is indexed by its 'id' column.
# The 'type' column is dropped because the dict key already names the label.
nodes_dfs = {
    label: (
        pd_from_entity_schema(
            entity_schema,
            prefix_id=prefix_id,
            unix_timestamp=unix_timestamp,
            include_properties=include_properties,
            explicit_timestamp=explicit_timestamp,
            explicit_label=explicit_label,
        )
        .set_index('id')
        .drop(columns=['type'])
    )
    for label, entity_schema in schema.nodes.items()
}

# One edge DataFrame per relationship label, with repeated
# (src, dst, timestamp) rows removed.
# NOTE(review): `.reset_index()` runs before `drop_duplicates`, so the
# resulting index is a RangeIndex with gaps wherever duplicates were dropped.
# Do not assume it is 0..len-1.
# NOTE(review): any column produced by `.reset_index()` other than
# src/dst/timestamp presumably becomes an edge feature in StellarGraph. Confirm
# this is intended.
edges_dfs = {
    label: (
        pd_from_entity_schema(
            entity_schema,
            prefix_id=prefix_id,
            unix_timestamp=unix_timestamp,
            include_properties=include_properties,
            explicit_timestamp=explicit_timestamp,
            explicit_label=explicit_label,
        )
        .reset_index()
        .drop(columns=['type'])
        .drop_duplicates(subset=['src', 'dst', 'timestamp'])
    )
    for label, entity_schema in schema.edges.items()
}

# StellarGraph uses each edge DataFrame's index as the edge IDs, so IDs must
# not overlap between edge types. Give every edge type its own contiguous
# block of IDs: [cursor, cursor + len(df)).
# A fresh RangeIndex is assigned rather than shifting the existing index.
# Shifting (`df.index += cursor`) breaks when `drop_duplicates` has left gaps:
# the largest label can then exceed len(df) - 1 and collide with the next
# type's IDs. Assigning also makes re-running this loop idempotent.
cursor = 0
for df in edges_dfs.values():
    df.index = pd.RangeIndex(cursor, cursor + len(df))
    cursor += len(df)

In [6]:
# Assemble the heterogeneous graph from the per-label node and edge frames.
# No direction is requested, so the graph is built as an undirected
# multigraph (see the printed summary). INTERACTIONS and MENTIONS edges are
# therefore treated as undirected.
graph = StellarGraph(
    edges=edges_dfs,
    nodes=nodes_dfs,
    target_column='dst',
    source_column='src',
)
print(graph.info())

StellarGraph: Undirected multigraph
 Nodes: 113, Edges: 2078

 Node types:
  Character: [113]
    Features: float32 vector, length 32
    Edge types: Character-INTERACTIONS->Character, Character-MENTIONS->Character

 Edge types:
    Character-MENTIONS->Character: [1120]
        Weights: all 1 (default)
        Features: float32 vector, length 2
    Character-INTERACTIONS->Character: [958]
        Weights: all 1 (default)
        Features: float32 vector, length 2


# Split the dataset into train, validation and test edges

In [26]:
from sklearn.model_selection import train_test_split

In [27]:
# Hold out 5% of the full graph's edges, plus an equal number of sampled
# negative pairs, as the test examples. `graph_sub_test` is the graph with
# those positive edges removed.
# The split is seeded so Restart & Run All reproduces the same test set.
TEST_SPLIT_SEED = 42

edge_splitter_test = EdgeSplitter(graph)
graph_sub_test, examples_test, labels_test = edge_splitter_test.train_test_split(
    p=0.05, method="global", seed=TEST_SPLIT_SEED
)

print(graph_sub_test.info())

** Sampled 103 positive and 103 negative edges. **
StellarGraph: Undirected multigraph
 Nodes: 113, Edges: 1975

 Node types:
  Character: [113]
    Features: float32 vector, length 32
    Edge types: Character-INTERACTIONS->Character, Character-MENTIONS->Character

 Edge types:
    Character-MENTIONS->Character: [1069]
        Weights: all 1 (default)
        Features: none
    Character-INTERACTIONS->Character: [906]
        Weights: all 1 (default)
        Features: none


In [28]:
train_size = 0.75
val_size = 0.25
# Seed for both random steps below, so the train/validation examples are
# reproducible across Restart & Run All.
TRAIN_SPLIT_SEED = 42

# Remove another 10% of edges, plus matching negative pairs, from the
# test-reduced graph. `graph_train` is what the model is allowed to see.
edge_splitter_train = EdgeSplitter(graph_sub_test)
graph_train, examples, labels = edge_splitter_train.train_test_split(
    p=0.1, method="global", seed=TRAIN_SPLIT_SEED
)

# Divide those held-out examples into train and validation parts. Stratify on
# the 0/1 labels so both parts keep the positive/negative balance.
(
    examples_train,
    examples_val,
    labels_train,
    labels_val,
) = train_test_split(
    examples,
    labels,
    train_size=train_size,
    test_size=val_size,
    random_state=TRAIN_SPLIT_SEED,
    stratify=labels,
)

print(graph_train.info())

** Sampled 197 positive and 197 negative edges. **
StellarGraph: Undirected multigraph
 Nodes: 113, Edges: 1778

 Node types:
  Character: [113]
    Features: float32 vector, length 32
    Edge types: Character-INTERACTIONS->Character, Character-MENTIONS->Character

 Edge types:
    Character-MENTIONS->Character: [962]
        Weights: all 1 (default)
        Features: none
    Character-INTERACTIONS->Character: [816]
        Weights: all 1 (default)
        Features: none


In [29]:
# Overview of the three example sets. For each set it shows which graph the
# positive edges were removed from and which graph they were sampled from.
pd.DataFrame(
    {
        "Split": ["Training Set", "Validation Set", "Test set"],
        "Number of Examples": [
            len(examples_train),
            len(examples_val),
            len(examples_test),
        ],
        "Hidden from": ["Train Graph", "Train Graph", "Test Graph"],
        "Picked from": ["Test Graph", "Test Graph", "Full Graph"],
        "Use": [
            "Train the Link Classifier",
            "Validate the Link Classifier",
            "Evaluate Link Classifier",
        ],
    }
).set_index("Split")


Unnamed: 0_level_0,Number of Examples,Hidden from,Picked from,Use
Split,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Training Set,295,Train Graph,Test Graph,Train the Link Classifier
Validation Set,99,Train Graph,Test Graph,Validate the Link Classifier
Test set,206,Test Graph,Full Graph,Evaluate Link Classifier


# Train HinSAGE

In [38]:
from tensorflow import keras

In [73]:
# Training hyper-parameters.
batch_size, epochs = 30, 30
# NOTE(review): `dimensions` is not referenced by any cell shown here; the
# HinSAGE layer sizes are hard-coded separately. Confirm which is intended.
dimensions = [128, 128]
# Neighbours sampled per hop, one entry per HinSAGE layer (the model inputs
# grow 10 -> 10*5 = 50 per head node).
num_samples = [10, 5]
# Random-walk settings used to build the UnsupervisedSampler.
walk_length, walk_number = 5, 1

In [74]:
# Mini-batches of (Character, Character) head-node pairs drawn from the
# training graph. Neighbourhoods are sampled per hop according to
# `num_samples`.
generator = HinSAGELinkGenerator(
    graph_train,
    batch_size,
    num_samples,
    head_node_types=["Character", "Character"],
)

# Training pairs come from random walks started at every node of the training
# graph. p = q = 1 puts no extra return/in-out bias on the biased walker.
unsupervised_samples = UnsupervisedSampler(
    graph_train,
    nodes=list(graph_train.nodes()),
    walker=BiasedRandomWalk(
        graph_train,
        n=walk_number,
        length=walk_length,
        p=1,
        q=1,
    ),
)

In [75]:
# Two HinSAGE layers of 32 units. There must be exactly one layer per
# sampling hop in `num_samples`.
# NOTE(review): these sizes are hard-coded and do not use the config cell's
# `dimensions`. Confirm which is intended.
hinsage_layer_sizes = [32] * 2
assert len(hinsage_layer_sizes) == len(num_samples)

hinsage = HinSAGE(
    generator=generator,
    layer_sizes=hinsage_layer_sizes,
    dropout=0.0,
    bias=True,
)

# The two head-node embeddings are combined with an inner product ('ip').
# The result passes through a single sigmoid unit to give a link probability.
x_inp, x_out = hinsage.in_out_tensors()
prediction = link_classification(
    edge_embedding_method="ip",
    output_act="sigmoid",
    output_dim=1,
)(x_out)

link_classification: using 'ip' method to combine node embeddings into edge embeddings


In [76]:
# Wrap the HinSAGE inputs and the sigmoid link prediction into a Keras model.
# Binary cross-entropy matches the single-probability output.
model = keras.Model(inputs=x_inp, outputs=prediction)
model.compile(
    loss=keras.losses.binary_crossentropy,
    metrics=[keras.metrics.binary_accuracy],
    optimizer=keras.optimizers.Adam(learning_rate=0.0001),
)

In [77]:
# Print the model's layers, output shapes and parameter counts.
model.summary()

Model: "model_6"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_59 (InputLayer)           [(None, 10, 32)]     0                                            
__________________________________________________________________________________________________
input_60 (InputLayer)           [(None, 10, 32)]     0                                            
__________________________________________________________________________________________________
input_63 (InputLayer)           [(None, 50, 32)]     0                                            
__________________________________________________________________________________________________
input_64 (InputLayer)           [(None, 50, 32)]     0                                            
____________________________________________________________________________________________

In [78]:
# NOTE(review): the model is fitted on the pairs produced by
# `unsupervised_samples` (presumably random-walk co-occurrences plus sampled
# negatives). It is NOT fitted on `examples_train` / `labels_train`, which the
# split cell computes but no cell shown here uses. Confirm this is intended.
# Supervised training would pass
# `generator.flow(examples_train, labels_train, shuffle=True)` instead.
#
# The returned History is kept so per-epoch loss/accuracy can be inspected or
# plotted later. Assigning it also stops its repr being printed as cell output.
history = model.fit(
    generator.flow(unsupervised_samples),
    validation_data=generator.flow(examples_val, labels_val),
    epochs=epochs,
    verbose=2,
    use_multiprocessing=False,
    workers=4,
    shuffle=True,
)

Epoch 1/30
31/31 - 5s - loss: 0.7228 - binary_accuracy: 0.4989 - val_loss: 0.6845 - val_binary_accuracy: 0.5152
Epoch 2/30
31/31 - 2s - loss: 0.7245 - binary_accuracy: 0.5011 - val_loss: 0.6885 - val_binary_accuracy: 0.4848
Epoch 3/30
31/31 - 2s - loss: 0.7166 - binary_accuracy: 0.5022 - val_loss: 0.6777 - val_binary_accuracy: 0.5051
Epoch 4/30
31/31 - 2s - loss: 0.7134 - binary_accuracy: 0.5033 - val_loss: 0.6775 - val_binary_accuracy: 0.4949
Epoch 5/30
31/31 - 2s - loss: 0.7024 - binary_accuracy: 0.5100 - val_loss: 0.6622 - val_binary_accuracy: 0.5051
Epoch 6/30
31/31 - 2s - loss: 0.7050 - binary_accuracy: 0.5155 - val_loss: 0.6585 - val_binary_accuracy: 0.5455
Epoch 7/30
31/31 - 2s - loss: 0.6924 - binary_accuracy: 0.5144 - val_loss: 0.6526 - val_binary_accuracy: 0.5758
Epoch 8/30
31/31 - 2s - loss: 0.6887 - binary_accuracy: 0.5155 - val_loss: 0.6487 - val_binary_accuracy: 0.5758
Epoch 9/30
31/31 - 2s - loss: 0.6917 - binary_accuracy: 0.5055 - val_loss: 0.6444 - val_binary_accuracy:

<keras.callbacks.History at 0x7f5c89562b80>