In [1]:
%reload_ext autoreload
%autoreload 2

In [102]:
import pandas as pd
import networkx as nx
import numpy as np

import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import DCRNN
from torch_geometric_temporal.signal import (
    temporal_signal_split,
    StaticGraphTemporalSignal,
)

from tqdm import tqdm

# Read and process email network data

In [61]:
df = pd.read_csv('https://snap.stanford.edu/data/email-Eu-core-temporal-Dept1.txt.gz', compression='gzip', delim_whitespace=True, header=None, names=['source', 'target','timestamp'])

In [62]:
# Timestamp is in seconds. Generate day.
df['day'] = df.timestamp//(60*60*24)

# Drop last (not full) day as it is not representative
df = df[df.day!=df.day.max()]

# Aggregate to daily number of emails to have edge weights
aggregates = df.groupby([x for x in list(df) if x!='timestamp']).count().reset_index().rename(columns={'timestamp':'weight'})


# Construct network

In [63]:
# create node indices
translator = dict(zip(a:=pd.concat([aggregates.source, aggregates.target]).unique(), range(len(a))))
aggregates['id_source'] = aggregates.source.map(lambda x: translator[x])
aggregates['id_target'] = aggregates.target.map(lambda x: translator[x])

In [64]:
# Generate network
network = nx.from_pandas_edgelist(aggregates, source='id_source', target='id_target', edge_attr='weight', create_using=nx.DiGraph)

# Transform to line graph
line_graph = nx.line_graph(network)

In [65]:
# create final node indices
line_node_list = list(line_graph.nodes())
line_translator = dict(zip(a:=list(set(line_node_list)), range(len(a))))

In [66]:
# Create edge list and edge weights
# Edge weights are average number of emails of connecting node

edges = [[],[]]
edge_weights = []
for source, target in list(line_graph.edges):
    edges[0].append(line_translator[source])
    edges[1].append(line_translator[target])
    edge_weights.append(aggregates[((aggregates.id_source==source[1])|(aggregates.id_target==source[1]))].weight.mean())

edges = np.array(edges)
edge_weights = np.array(edge_weights)

# Create data for GNN

In [67]:
# Generate target arrays
numdays = aggregates.day.nunique()
aggregates = aggregates.set_index(['id_source', 'id_target', 'day'])
line_retranslator = {k:v for v, k in line_translator.items()}

In [89]:
targets = []
for day in range(numdays):
    daily_targets = []
    for node in range(len(line_node_list)):
        source, target = line_retranslator[node]
        try:
            daily_targets.append(aggregates.loc[(source, target, day)].weight)
        except KeyError:
            daily_targets.append(0)
    targets.append(daily_targets)

In [91]:
# Generate features (messages in last 10 days)
feat_num = 10
features = []

for day in range(numdays-feat_num):
    daily_features = []
    for node in range(len(line_node_list)):
        node_feat = []
        for feature_i in range(feat_num):
            node_feat.append(targets[day+feature_i][node])
        daily_features.append(node_feat)
    features.append(daily_features)

In [95]:
# To array
targets = np.array(targets[feat_num:])
features = np.array(features)

In [96]:
# Create data iterator
data = StaticGraphTemporalSignal(edges, edge_weights, features, targets)

# Define and train GNN

In [98]:
# Set up dcrnn model


class RecurrentGCN(torch.nn.Module):
    """Class for a pytorch neural network module"""

    def __init__(self, node_features: int, out_channels: int, filter_size: int):
        """
        Initialize a pytorch model with DCRNN architecture

        Args:
            node_features:
            Number of node features to use.
            out_channels:
            Number of DCRNN hidden features.
            filter_size:
            DCRNN filter size.
        """
        super().__init__()
        self.recurrent = DCRNN(node_features, out_channels, filter_size)
        self.linear = torch.nn.Linear(out_channels, filter_size)

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, edge_weight: torch.Tensor
    ) -> torch.Tensor:
        """Perform a forward feed
        Args:
            x:
            feature Pytorch Float Tensor
            edge_index:
            Pytorch Float Tensor of edge indices
            edge_weight:
            Pytorch Float Tensor of edge weights

        Returns:
            tens:
            Pytorch Float Tensor of Hidden state matrix for all nodes
        """
        tens = self.recurrent(x, edge_index, edge_weight)
        tens = F.relu(tens)
        tens = self.linear(tens)
        return tens

In [99]:
# Instantiate a model
model = RecurrentGCN(10, 32, 1)

In [118]:
# train with adam
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
model.train()

for _ in tqdm(range(20), "Optimization with temporal backpropagation"):
    cost = 0
    datapoints = 0
    for _, snapshot in enumerate(data):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        cost = cost + torch.mean((y_hat - snapshot.y) ** 2)
        datapoints += 1
    cost = cost / (datapoints)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()

Optimization with temporal backpropagation: 100%|██████████| 20/20 [10:57<00:00, 32.90s/it]


# Post mortem

In [119]:
# get predictions
yhats = []
for _, snapshot in enumerate(data):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        yhats.append(y_hat)

In [131]:
# check first period prediction
yhats[0]

tensor([[-0.2744],
        [-0.2744],
        [-0.2744],
        ...,
        [-0.4697],
        [-0.4911],
        [-0.2744]], grad_fn=<AddmmBackward0>)

In [132]:
# check first period fact
enumerate(data).__next__()[1].y

tensor([0, 0, 0,  ..., 0, 0, 0])