In [51]:
import os
from pathlib import Path

import dgl
import networkx as nx
import numpy as np
import torch
from torch import nn

### Creating a DGL graph from the connectivity matrix

In [7]:
# Directory holding the preprocessed DEAP connectivity arrays. Point the
# DEAP_FILTERED_DIR environment variable elsewhere instead of editing a
# machine-specific absolute path into the notebook.
DATA_DIR = Path(
    os.environ.get(
        "DEAP_FILTERED_DIR",
        "/Users/h1de0us/uni/mer-eeg-analysis/data/deap_filtered",
    )
)
TEST_PATH = DATA_DIR / "s01_plv.npy"
connectivity_matrix = np.load(TEST_PATH)
connectivity_matrix.shape

(32, 32, 5)

In [8]:
# Select one slice along the last axis of the loaded (32, 32, 5) array.
# NOTE(review): presumably the last axis indexes frequency bands; which band
# index -1 corresponds to depends on the preprocessing — TODO confirm.
BAND_INDEX = -1
# Guard on ndim so that re-running this cell is a no-op instead of an IndexError
# (the variable is already 2-D after the first run).
if connectivity_matrix.ndim == 3:
    connectivity_matrix = connectivity_matrix[:, :, BAND_INDEX]
connectivity_matrix

array([[0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.57803585, 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.33041863, 0.56471143, 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.28629317, 0.28253806, 0.27283744, ..., 0.        , 0.        ,
        0.        ],
       [0.34106162, 0.56037416, 0.60070708, ..., 0.27732514, 0.        ,
        0.        ],
       [0.37746005, 0.45468193, 0.43395707, ..., 0.25628626, 0.71649264,
        0.        ]])

In [42]:
threshold = 0.3
# The loaded matrix appears strictly lower-triangular with a zero diagonal (see the
# output above). Rebuild the symmetric matrix from that lower triangle instead of
# adding in place: the original `+=` doubled every weight when the cell was re-run.
# `np.tril` returns a copy, so the raw values are not mutated here.
lower = np.tril(connectivity_matrix, k=-1)
lower[lower < threshold] = 0  # remove weak connections
connectivity_matrix = lower + lower.T  # mirror the lower triangle -> symmetric matrix

In [43]:
# Non-zero matrix entries become weighted edges ("weight" attribute); the
# directed copy stores every undirected edge in both directions.
nx_graph = nx.from_numpy_array(connectivity_matrix).to_directed()


In [44]:
len(nx_graph.edges)  # directed edge count (each undirected edge counted twice)

426

In [45]:
# Copy the networkx "weight" edge attribute into dgl_graph.edata["weight"].
dgl_graph = dgl.from_networkx(
    nx_graph,
    edge_attrs=["weight"],
)
dgl_graph

Graph(num_nodes=32, num_edges=426,
      ndata_schemes={}
      edata_schemes={'weight': Scheme(shape=(), dtype=torch.float32)})

### Centrality encoding

In [52]:
n_nodes = dgl_graph.number_of_nodes()
dim_feedforward = 64

# Seed torch so the randomly initialised embeddings created from here on (and the
# values shown in later outputs) are reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

# The embedding is indexed by in-degree, so it needs one row per possible degree.
# With no self-loops or parallel edges the in-degree is at most n_nodes - 1; check
# that explicitly rather than failing later with an opaque IndexError.
assert int(dgl_graph.in_degrees().max()) < n_nodes, "in-degree exceeds embedding size"
centrality_encoding = nn.Embedding(n_nodes, dim_feedforward)

In [47]:
dgl_graph.in_degrees().size()  # one in-degree per node

torch.Size([32])

In [53]:
node_in_degrees = dgl_graph.in_degrees()
centrality = centrality_encoding(node_in_degrees)  # embedding row per node, keyed by its in-degree
centrality.shape  # (n_nodes, dim_feedforward)

torch.Size([32, 64])

### Spatial encoding

In [54]:
n_heads = 4
# One learnable bias per head for every shortest-path distance 0 .. n_nodes - 1.
spatial_encoding = nn.Embedding(num_embeddings=n_nodes, embedding_dim=n_heads)

In [55]:
# All-pairs shortest-path distances (root=None means "every node as a source").
spd = dgl.shortest_dist(dgl_graph, root=None)
spd.size()

torch.Size([32, 32])

In [57]:
# `dgl.shortest_dist` marks unreachable pairs with -1, which nn.Embedding cannot
# look up (IndexError) — e.g. when the threshold isolates a node. Look up a clamped
# index and zero those pairs out so they get no spatial bias. For a connected graph
# every distance is >= 0 and this equals embedding `spd` directly.
reachable = (spd >= 0).unsqueeze(-1)
spatial = spatial_encoding(spd.clamp(min=0)) * reachable
spatial.shape  # (n_nodes, n_nodes, n_heads)

torch.Size([32, 32, 4])

### Edge encoding

In [64]:
spd, paths = dgl.shortest_dist(dgl_graph, return_paths=True)
# Per the dgl documentation, each path is a vector of edge IDs right-padded with -1.
src, dst = 0, 2
paths[src, dst]

tensor([ 1, -1, -1])

In [66]:
padded_path = paths[0, 2]
path = padded_path[padded_path >= 0]  # keep only real edge IDs, drop the -1 padding
path

tensor([1])

In [68]:
edge_features = dgl_graph.edata["weight"]  # one scalar weight per edge ID
# Embedding indexed by edge ID; n_nodes ** 2 rows is an upper bound on the edge count.
# NOTE(review): edge IDs are specific to this dgl_graph, so these rows would not mean
# the same thing on another subject's graph — confirm this is intended.
edge_encoder = nn.Embedding(num_embeddings=n_nodes ** 2, embedding_dim=n_heads)

In [79]:
i, j = 12, 24

# With a root node, `shortest_dist` returns paths of shape (n_nodes, max_len):
# row j lists the edge IDs (not node IDs) along a shortest path i -> j, right-padded
# with -1. Index [1] instead of unpacking into `_`, which would shadow IPython's
# last-output variable.
paths_from_i = dgl.shortest_dist(dgl_graph, i, return_paths=True)[1]
path = paths_from_i[j]
path = path[path >= 0]  # remove padding -> edge IDs on the path
edge_embeds = edge_encoder(path)  # (n_spd, n_heads)
spd_features = edge_features[path]  # (n_spd,)
if path.numel() > 0:
    # Average of the edge embeddings, each scaled by its edge weight.
    result = torch.mean(edge_embeds * spd_features.unsqueeze(-1), dim=0)
else:
    # i == j or j unreachable from i: there are no edges to average, and
    # torch.mean over an empty tensor would give NaN, so use a zero bias.
    result = torch.zeros(n_heads)
result, result.shape

(tensor([ 0.9240,  0.2532, -0.2192,  0.0430], grad_fn=<MeanBackward1>),
 torch.Size([4]))