To use PyTorch Geometric Temporal, make sure torch 1.9.0 is installed (uninstall torch 1.10.0 first if it is present).

In [2]:
import torch
# Print the installed PyTorch version; the wheel index URLs used for the
# torch-scatter/sparse/cluster/spline-conv installs target torch 1.9.0+cpu.
print(torch.__version__)

1.9.0+cpu


In [3]:
import torch
import numpy as np

In [None]:
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cpu.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cpu.html
!pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.9.0+cpu.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.9.0+cpu.html
!pip install torch-geometric
!pip install torch-geometric-temporal

In [4]:
from sklearn.preprocessing import normalize
from sklearn.preprocessing import MinMaxScaler

def transform_and_split(data):
    """Normalize node features, binarize the labels and cast tensors to float64.

    Despite its name, this function does not split the data: the train/test
    split that used to live here is disabled.

    Args:
        data: Graph object exposing ``x``, ``y`` and ``edge_attr``. ``y`` and
            ``edge_attr`` must be torch tensors and ``x`` must be accepted by
            ``sklearn.preprocessing.normalize``. The object is modified in place.

    Returns:
        The same ``data`` object after transformation.
    """
    # Row-wise max-norm scaling; sklearn hands back a NumPy array, so convert
    # it to a float64 tensor afterwards.
    data.x = normalize(data.x, axis=1, norm='max')
    data.x = torch.from_numpy(data.x).to(torch.float64)

    # Binary target for classification: 1.0 where y > 0, else 0.0. A vectorized
    # comparison replaces the per-element Python lambda passed to Tensor.apply_.
    data.y = (data.y > 0).to(torch.float64)
    data.edge_attr = data.edge_attr.to(torch.double)

    print_basic_info(data)
    return data

In [5]:
def print_basic_info(data):
    """Print a short structural summary of a graph object.

    Args:
        data: Object exposing ``num_nodes``, ``num_edges`` and the methods
            ``has_isolated_nodes()``, ``has_self_loops()`` and
            ``is_undirected()``. Its ``print`` representation is shown first.
    """
    print()
    print(data)
    print('===========================================================================================================')

    node_total = data.num_nodes
    print(f'Number of nodes: {node_total}')
    edge_total = data.num_edges
    print(f'Number of edges: {edge_total}')
    mean_degree = edge_total / node_total
    print('Average node degree: ' + format(mean_degree, '.2f'))

    # Each check is looked up and called only when its line is printed, so the
    # checks run in the listed order.
    for label, method_name in (
        ('Has isolated nodes', 'has_isolated_nodes'),
        ('Has self-loops', 'has_self_loops'),
        ('Is undirected', 'is_undirected'),
    ):
        print(f'{label}: {getattr(data, method_name)()}')

### Get and split data

In [6]:
# Load one processed quarter file and run its first element through the
# preprocessing helper (which also prints a summary of the graph).
# NOTE(review): torch.load unpickles the file; only load files from a trusted source.
path = "../data/processed/twitter/2018_q1.pt" # Customize...
dataset = torch.load(path)
data = dataset[0]
transformed_data = transform_and_split(data)


Data(x=[29, 61], edge_index=[2, 400], edge_attr=[400], y=[29])
Number of nodes: 29
Number of edges: 400
Average node degree: 13.79
Has isolated nodes: False
Has self-loops: True
Is undirected: True


In [7]:
from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader

# Load the Chickenpox dataset shipped with torch_geometric_temporal.
# Note that this rebinds `dataset`, which previously held the loaded Twitter file.
loader = ChickenpoxDatasetLoader()

dataset = loader.get_dataset()

In [8]:
# Inspect the edge weights of the Chickenpox dataset.
dataset.edge_weight

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [9]:
# Directory prefix that is concatenated with each quarter label to form a file path.
path = "../data/processed/twitter/"

In [10]:
# Quarter labels in chronological order: 2016_q4 first, then every quarter of
# 2017 onwards, ending at 2021_q3.
quarter = ['2016_q4']
for i in range(2017, 2022):
    for j in range(1, 5):
        # Stop before 2021_q4 so the list ends at 2021_q3.
        if (i, j) == (2021, 4):
            break
        quarter.append(f'{i}_q{j}')

In [11]:
# Inspect the generated quarter labels.
quarter

['2016_q4',
 '2017_q1',
 '2017_q2',
 '2017_q3',
 '2017_q4',
 '2018_q1',
 '2018_q2',
 '2018_q3',
 '2018_q4',
 '2019_q1',
 '2019_q2',
 '2019_q3',
 '2019_q4',
 '2020_q1',
 '2020_q2',
 '2020_q3',
 '2020_q4',
 '2021_q1',
 '2021_q2',
 '2021_q3']

In [12]:
# One file path per quarter label: <directory prefix><label>.pt
paths = []
for i in quarter:
    paths.append(f'{path}{i}.pt')

In [13]:
# Inspect the generated file paths.
paths

['../data/processed/twitter/2016_q4.pt',
 '../data/processed/twitter/2017_q1.pt',
 '../data/processed/twitter/2017_q2.pt',
 '../data/processed/twitter/2017_q3.pt',
 '../data/processed/twitter/2017_q4.pt',
 '../data/processed/twitter/2018_q1.pt',
 '../data/processed/twitter/2018_q2.pt',
 '../data/processed/twitter/2018_q3.pt',
 '../data/processed/twitter/2018_q4.pt',
 '../data/processed/twitter/2019_q1.pt',
 '../data/processed/twitter/2019_q2.pt',
 '../data/processed/twitter/2019_q3.pt',
 '../data/processed/twitter/2019_q4.pt',
 '../data/processed/twitter/2020_q1.pt',
 '../data/processed/twitter/2020_q2.pt',
 '../data/processed/twitter/2020_q3.pt',
 '../data/processed/twitter/2020_q4.pt',
 '../data/processed/twitter/2021_q1.pt',
 '../data/processed/twitter/2021_q2.pt',
 '../data/processed/twitter/2021_q3.pt']

In [14]:
# Accumulates one transformed graph per quarterly file (filled by the loading loop).
data_list = []

In [15]:
# Load every quarterly file and preprocess its first graph.
# The list is rebuilt here so that re-running this cell does not append
# duplicates, and the loop variable is deliberately not called `path` so the
# base-directory string of that name is not overwritten (which would break a
# re-run of the cell that builds `paths`).
# NOTE(review): torch.load unpickles its input; only load trusted files.
data_list = []
for snapshot_path in paths:
    dataset = torch.load(snapshot_path)
    data = dataset[0]
    data_list.append(transform_and_split(data))


Data(x=[29, 63], edge_index=[2, 760], edge_attr=[760], y=[29])
Number of nodes: 29
Number of edges: 760
Average node degree: 26.21
Has isolated nodes: False
Has self-loops: True
Is undirected: True

Data(x=[29, 62], edge_index=[2, 312], edge_attr=[312], y=[29])
Number of nodes: 29
Number of edges: 312
Average node degree: 10.76
Has isolated nodes: False
Has self-loops: True
Is undirected: True

Data(x=[29, 63], edge_index=[2, 400], edge_attr=[400], y=[29])
Number of nodes: 29
Number of edges: 400
Average node degree: 13.79
Has isolated nodes: False
Has self-loops: True
Is undirected: True

Data(x=[29, 63], edge_index=[2, 552], edge_attr=[552], y=[29])
Number of nodes: 29
Number of edges: 552
Average node degree: 19.03
Has isolated nodes: False
Has self-loops: True
Is undirected: True

Data(x=[29, 63], edge_index=[2, 805], edge_attr=[805], y=[29])
Number of nodes: 29
Number of edges: 805
Average node degree: 27.76
Has isolated nodes: False
Has self-loops: True
Is undirected: True

Data

In [16]:
# Number of loaded snapshots.
len(data_list)

20

In [17]:
# Feature-matrix shape of the second snapshot.
data_list[1].x.shape

torch.Size([29, 62])

In [18]:
"""
edge_indices = [i.edge_index.double() for i in data_list]
edge_weights = [i.edge_attr.double() for i in data_list]
features = [i.x.double() for i in data_list]
targets = [i.y.double() for i in data_list]
"""
"""
edge_indices = [i.edge_index.cpu().detach().numpy() for i in data_list]
edge_weights = [i.edge_attr.cpu().detach().numpy() for i in data_list]
features = [i.x.cpu().detach().numpy() for i in data_list]
targets = [i.y.cpu().detach().numpy() for i in data_list]
"""
edge_indices = [i.edge_index.numpy() for i in data_list]
edge_weights = [i.edge_attr.numpy() for i in data_list]
features = [i.x.numpy() for i in data_list]
targets = [i.y.numpy() for i in data_list]

In [19]:
# Shape of the second snapshot's feature array after conversion to NumPy.
features[1].shape

(29, 62)

In [20]:
# Right-pad every feature matrix along the column axis up to 64 columns so all
# snapshots share one width (the model is later built with node_features = 64).
# No rows are added. np.pad's 'mean' mode fills the new columns with a mean
# taken along the padded axis - presumably each row's own mean; confirm this
# is the intended fill value.
padded_features = []
for i in features:
    padded_features.append(np.pad(i, [(0, 0), (0, 64-i.shape[1])], 'mean'))

In [21]:
from torch_geometric_temporal.signal import DynamicGraphTemporalSignal

In [22]:
# Bundle the per-snapshot arrays (using the padded features) into one dynamic-graph temporal signal.
temporal_signal = DynamicGraphTemporalSignal(edge_indices = edge_indices , edge_weights = edge_weights, features = padded_features, targets = targets)

In [23]:
# Display the signal object (only its default repr is shown).
temporal_signal

<torch_geometric_temporal.signal.dynamic_graph_temporal_signal.DynamicGraphTemporalSignal at 0x294251cf0d0>

In [24]:
from torch_geometric_temporal.signal import temporal_signal_split

# Split the signal into a training part and a test part with train_ratio=0.8.
train_dataset, test_dataset = temporal_signal_split(temporal_signal, train_ratio=0.8)

In [54]:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import DCRNN
from torch_geometric_temporal import TemporalConv
from torch_geometric_temporal import EvolveGCNO
class RecurrentGCN(torch.nn.Module):
    """EvolveGCN-O layer followed by a linear read-out giving two scores per node.

    Args:
        node_features: Number of input features per node. It sizes both the
            EvolveGCN-O layer and the input of the linear read-out.
    """

    def __init__(self, node_features):
        super().__init__()
        self.evol = EvolveGCNO(node_features)
        # `recurrent` and `dropout` are not applied in forward(); they are kept
        # so the module's attributes and state_dict keys stay unchanged.
        self.recurrent = DCRNN(node_features, 32, 1)
        # The read-out consumes the EvolveGCN-O output directly, so its input
        # width follows node_features instead of a hard-coded 64 (which only
        # worked when node_features == 64).
        self.linear = torch.nn.Linear(node_features, 2)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x, edge_index, edge_weight):
        """Return per-node scores of shape (num_nodes, 2), each squashed into (0, 1).

        Args:
            x: Node feature matrix with ``node_features`` columns.
            edge_index: Edge index tensor passed through to EvolveGCN-O.
            edge_weight: Edge weights passed through to EvolveGCN-O.
        """
        h = self.evol(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        # NOTE(review): the training cell feeds these sigmoid outputs to
        # CrossEntropyLoss, which expects unnormalized logits. The sigmoid is
        # kept because the evaluation cells threshold these (0, 1) values.
        h = torch.sigmoid(h)
        return h

In [55]:
# Print the index of every snapshot in the signal together with the shape of
# its edge_index tensor (the number of edges differs between snapshots).
for time, snapshot in enumerate(temporal_signal):
    print(time)
    print(snapshot.edge_index.shape)

0
torch.Size([2, 760])
1
torch.Size([2, 312])
2
torch.Size([2, 400])
3
torch.Size([2, 552])
4
torch.Size([2, 805])
5
torch.Size([2, 400])
6
torch.Size([2, 616])
7
torch.Size([2, 585])
8
torch.Size([2, 805])
9
torch.Size([2, 552])
10
torch.Size([2, 552])
11
torch.Size([2, 672])
12
torch.Size([2, 777])
13
torch.Size([2, 480])
14
torch.Size([2, 585])
15
torch.Size([2, 697])
16
torch.Size([2, 805])
17
torch.Size([2, 616])
18
torch.Size([2, 720])
19
torch.Size([2, 741])


In [58]:
from tqdm import tqdm

model = RecurrentGCN(node_features = 64)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Build the loss module once instead of on every snapshot of every epoch.
criterion = torch.nn.CrossEntropyLoss()

model.train()

# One optimizer step per epoch, on the loss summed over all training snapshots.
for epoch in tqdm(range(200)):
    optimizer.zero_grad()
    loss = 0
    for snapshot in train_dataset:
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        # Targets are stored as floats; CrossEntropyLoss needs integer class indices.
        loss += criterion(y_hat, snapshot.y.long())

    loss.backward()
    optimizer.step()

100%|██████████| 200/200 [00:05<00:00, 38.89it/s]


In [72]:
# Collect the model's per-snapshot predictions on the test snapshots.
y_hat_l = []
model.eval()
# Inference only: disable autograd so no computation graph is built or kept.
with torch.no_grad():
    for snapshot in test_dataset:
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        y_hat_l.append(y_hat)


In [70]:
# Convert each snapshot's prediction tensor into a list of per-node NumPy arrays.
# This rebinds y_hat_l, so its elements are plain lists (not tensors) afterwards.
y_hat_l = [list(x.detach().numpy()) for x in y_hat_l]

In [71]:
# Inspect the per-snapshot predictions.
y_hat_l

[[array([1.4112265e-05, 9.5221007e-01], dtype=float32),
  array([2.3635634e-04, 9.2230040e-01], dtype=float32),
  array([1.6694033e-04, 9.2669183e-01], dtype=float32),
  array([1.5047046e-14, 9.9902546e-01], dtype=float32),
  array([6.4532976e-08, 9.8286659e-01], dtype=float32),
  array([6.0020342e-05, 9.4057894e-01], dtype=float32),
  array([2.3637977e-04, 9.2248428e-01], dtype=float32),
  array([4.2942323e-05, 9.4137961e-01], dtype=float32),
  array([1.6239964e-07, 9.7796118e-01], dtype=float32),
  array([9.656785e-10, 9.921604e-01], dtype=float32),
  array([5.9070389e-06, 9.6132946e-01], dtype=float32),
  array([2.0623754e-06, 9.6632743e-01], dtype=float32),
  array([2.0013958e-04, 9.2452854e-01], dtype=float32),
  array([1.6824447e-04, 9.2715538e-01], dtype=float32),
  array([2.3627274e-04, 9.2238325e-01], dtype=float32),
  array([1.9963564e-04, 9.2460525e-01], dtype=float32),
  array([8.04485e-10, 9.92780e-01], dtype=float32),
  array([5.007305e-05, 9.422055e-01], dtype=float32),


In [73]:
# Flatten the per-snapshot predictions into one list of per-node score arrays.
# Elements of y_hat_l are tensors straight after the evaluation loop, but plain
# lists of arrays once the `list(x.detach().numpy())` conversion cell has run.
# Calling .detach() unconditionally raised AttributeError in that second case
# (i.e. on a top-to-bottom run), so both forms are accepted here.
y_hat_l = [
    list(np.squeeze(out.detach().numpy() if torch.is_tensor(out) else np.asarray(out)))
    for out in y_hat_l
]
y_hat_l = [z for y in y_hat_l for z in y]

In [75]:
# Keep only the second score (index 1) of every node; this is the value that is
# thresholded into a 0/1 label further down.
y_hat_l = [y[1] for y in y_hat_l]

In [76]:
# Inspect the flattened per-node scores.
y_hat_l

[0.95221007,
 0.9223004,
 0.92669183,
 0.99902546,
 0.9828666,
 0.94057894,
 0.9224843,
 0.9413796,
 0.9779612,
 0.9921604,
 0.96132946,
 0.9663274,
 0.92452854,
 0.9271554,
 0.92238325,
 0.92460525,
 0.99278,
 0.9422055,
 0.9962095,
 0.9223467,
 0.95540285,
 0.939182,
 0.9531975,
 0.92244595,
 0.9269326,
 0.9999999,
 0.931438,
 0.9224272,
 0.9317685,
 0.8711903,
 0.87229556,
 0.87374264,
 0.9860931,
 0.9848681,
 0.87131196,
 0.87335646,
 0.8891789,
 0.9630519,
 0.98769194,
 0.9068682,
 0.90559715,
 0.873081,
 0.87322354,
 0.8721459,
 0.8736636,
 0.93945384,
 0.8730093,
 0.8907325,
 0.8720679,
 0.87244177,
 0.96595395,
 0.9477264,
 0.87302196,
 0.8731846,
 0.99999905,
 0.8901004,
 0.8732694,
 0.9416565,
 0.8062335,
 0.75295025,
 0.7683748,
 0.9676894,
 0.95532334,
 0.83741784,
 0.7544498,
 0.7542448,
 0.8525401,
 0.88807017,
 0.7862438,
 0.7700982,
 0.7669716,
 0.7530386,
 0.7536409,
 0.7685701,
 0.8201763,
 0.7696135,
 0.75365984,
 0.7542463,
 0.76697147,
 0.9181942,
 0.7697987,
 0.75

In [130]:
# Label a node 1 when its score is strictly above 0.82, otherwise 0.
# NOTE(review): 0.82 is a hard-coded cut-off; confirm it was not tuned on the
# very test predictions it is applied to.
y_hat_list = [int(score > 0.82) for score in y_hat_l]

In [131]:
# Collect the ground-truth labels of every test snapshot (one list per snapshot).
true_label = []
for time, snapshot in enumerate(test_dataset):
    true_label.append(list(snapshot.y.detach().numpy()))

In [132]:
# Flatten into a single list of Python ints, in snapshot-then-node order.
true_label = [int(z) for y in true_label for z in y]

In [133]:
from sklearn.metrics import classification_report
# Per-class precision / recall / F1 of the thresholded predictions against the
# true labels of the test snapshots.
y_true = true_label
target_names = ['class 0', 'class 1']
print(classification_report(y_true, y_hat_list, target_names=target_names))

              precision    recall  f1-score   support

     class 0       0.57      0.64      0.61        42
     class 1       0.78      0.73      0.76        74

    accuracy                           0.70       116
   macro avg       0.68      0.69      0.68       116
weighted avg       0.71      0.70      0.70       116

