In [1]:
from tqdm import tqdm

import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import EvolveGCNH

from torch_geometric_temporal.dataset import TwitterTennisDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split

In [2]:
# Instantiate the Twitter Tennis dataset loader with its default arguments.
loader = TwitterTennisDatasetLoader()

# Build the temporal graph signal; the cells below index it per snapshot via
# .features / .edge_indices / .edge_weights / .targets.
dataset = loader.get_dataset()

# Split the snapshot sequence into a train part and a test part.
# NOTE(review): train_ratio=0.2 puts only 20% of the snapshots in the training
# split (24 according to the recorded output below) -- confirm this is
# intentional and not a swapped 0.8.
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.2)

In [3]:
# Sanity-check the training split by inspecting its first snapshot (index 0):
# snapshot count, value range and shape of the node features, the raw edge
# list and edge weights, and the value range and shape of the targets.
print("Num timestamps:", train_dataset.snapshot_count)
feats = train_dataset.features[0]
print("feats:", np.min(feats), np.max(feats))
print(feats.shape)
print("edge_index:", train_dataset.edge_indices[0])
print("edge_weight:", train_dataset.edge_weights[0])
targets = train_dataset.targets[0]
print("targets:", np.min(targets), np.max(targets))
print(targets.shape)

Num timestamps: 24
feats: 0.0 1.0
(1000, 16)
edge_index: [[ 42 909 909 909 233 233 450 256 256 256 256 256 434 434 434 233 233 233
  233 233 233 233   9   9 355  84  84  84  84 140 140 140 140   0 140 238
  238 238 649 875 875 234  73  73 341 341 341 341 341 417 293 991  74 581
  282 162 144 383 383 135 135 910 910 910 910 910  87  87  87  87   9   9
  934 934 162 225  42 911 911 911 911 911 911 911 911 498 498  64 435]
 [  0 138   5   0   1 121 389 133 124 103   0 436   0   2 926 256 103 133
    0 436 124 155   0   8   0  15 457   0  17   3   8  74   0   5   9   0
    6  15   2   0   5   0  31  73   8   9  13  32   0   0   0   1   0   0
    0   0   0   0   9   0   4 670 846   4   0 679 564   0 119 128  32  74
    5   8 244   0  11   0  79  54  32 119   8 225  19  24   0   7   0]]
edge_weight: [2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 1 1 1 1 1 1 1 1
 1 1 3 2 1 1 1 1 2 2 2 1 1 1 3]
targets: 0.0 

In [4]:
# Link prediction

def gather_node_embs(x, edge_index):
    """Concatenate the embeddings of every edge's endpoints into one row per edge.

    Args:
        x: Node embedding tensor, indexed by node id along its first dimension.
        edge_index: Sequence of index rows (one row per edge endpoint role).
            In this notebook it is a ``[2, E]`` tensor, so the two rows hold
            the first and second endpoint of each edge.

    Returns:
        Tensor obtained by gathering ``x`` with each index row and joining the
        gathered pieces along dimension 1 (``[E, 2 * F]`` for a ``[2, E]``
        ``edge_index`` and ``[N, F]`` ``x``).
    """
    endpoint_embs = [x[node_ids] for node_ids in edge_index]
    return torch.cat(endpoint_embs, dim=1)

class Classifier(torch.nn.Module):
    """Two-layer perceptron head: Linear -> ReLU -> Linear.

    Args:
        in_feats: Width of each input row.
        cls_feats: Width of the hidden layer.
        out_class: Number of output scores produced per input row.
    """

    def __init__(self, in_feats, cls_feats, out_class):
        super().__init__()
        # Built in the order first Linear, ReLU, second Linear so parameter
        # initialization and the ``mlp.0`` / ``mlp.2`` state-dict keys are the
        # same as when the layers are created one by one.
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=in_feats, out_features=cls_feats),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=cls_feats, out_features=out_class),
        )

    def forward(self, x):
        """Return the raw output scores of the MLP (no softmax is applied here)."""
        scores = self.mlp(x)
        return scores

class EGCNH_Link(torch.nn.Module):
    """EvolveGCNH node encoder followed by an MLP that scores each input edge.

    Args:
        node_count: Number of nodes, forwarded to ``EvolveGCNH``.
        node_features: Per-node feature width, forwarded to ``EvolveGCNH``.
    """

    def __init__(self, node_count, node_features):
        super().__init__()
        self.recurrent = EvolveGCNH(node_count, node_features)
        # An edge is described by the embeddings of both endpoints joined
        # together, so the classifier input is twice the per-node width.
        # The hidden layer uses the same width and the head emits 2 scores.
        pair_width = node_features * 2
        self.classifier = Classifier(pair_width, pair_width, 2)

    def forward(self, x, edge_index, edge_weight):
        """Score every edge listed in ``edge_index``.

        Node embeddings are computed by the recurrent graph layer, the
        embeddings of each edge's endpoints are concatenated, and the result
        is passed through the classifier. Returns one row of 2 scores per edge.
        """
        node_embs = self.recurrent(x, edge_index, edge_weight)
        edge_inputs = gather_node_embs(node_embs, edge_index)
        return self.classifier(edge_inputs)

In [5]:
# Derive the model dimensions from the first snapshot's feature matrix
# (rows = nodes, columns = features per node).
num_nodes, num_feats = dataset.features[0].shape
print(num_nodes, num_feats)
model = EGCNH_Link(node_count=num_nodes, node_features=num_feats)
# Adam over all model parameters; the optimizer is only created here, it is
# not stepped in this cell.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Switch to training mode. As the last expression of the cell this also
# displays the module structure shown in the output below.
model.train()

1000 16


EGCNH_Link(
  (recurrent): EvolveGCNH(
    (pooling_layer): TopKPooling(16, ratio=0.016, multiplier=1.0)
    (recurrent_layer): GRU(16, 16)
    (conv_layer): GCNConv_Fixed_W(16, 16)
  )
  (classifier): Classifier(
    (mlp): Sequential(
      (0): Linear(in_features=32, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)

In [6]:
# Smoke test: run one forward pass on the first training snapshot only and
# print the input/output shapes plus the first ten rows of edge scores.
# No loss, backward pass, or optimizer step happens here.
for step, snapshot in enumerate(train_dataset):
    # snapshot.edge_attr is passed as the edge weights of the graph layer.
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    print("step {}:".format(step), snapshot.x.shape, snapshot.edge_index.shape, y_hat.shape)
    print(y_hat[:10])
    # Stop after the first snapshot; this loop is a shape check, not training.
    break

step 0: torch.Size([1000, 16]) torch.Size([2, 89]) torch.Size([89, 2])
tensor([[ 0.1029, -0.1360],
        [ 0.1103, -0.1215],
        [ 0.0992, -0.1524],
        [ 0.0994, -0.1319],
        [ 0.1033, -0.0747],
        [ 0.0990, -0.0608],
        [ 0.0444, -0.0596],
        [ 0.1102, -0.1043],
        [ 0.1102, -0.1043],
        [ 0.1102, -0.1043]], grad_fn=<SliceBackward0>)
