In [1]:
import sys
sys.path.append('./src')

In [2]:
import time
import numpy as np

import torch
from torch_geometric.data import Batch
from torch_geometric.datasets import TUDataset

In [3]:
# Load ENZYMES and split it into a "fit" set (used to fit the WL kernels) and a
# "query" set (only transformed against the fitted kernels).
N_FIT = 300  # number of graphs in the fit split
dataset = TUDataset('tempdata/', name='ENZYMES')
data1 = Batch.from_data_list(dataset[:N_FIT])
data2 = Batch.from_data_list(dataset[N_FIT:])

In [4]:
# Convert graphs of a batched PyG object into GraKeL's (edge set, node-label dict) format.
def togk(data, bi):
    """Extract graph ``bi`` of a batched graph object as a GraKeL-style graph.

    Parameters
    ----------
    data : batched graph object (e.g. torch_geometric ``Batch``)
        Reads ``data.batch`` (graph id of every node), ``data.x`` (node
        features; the argmax over the last dim is used as the discrete node
        label) and ``data.edge_index`` (2 x E tensor of global node indices).
    bi : int
        Index of the graph to extract.

    Returns
    -------
    edges : set of (int, int)
        Edges with node ids shifted to 0..n-1, stored in both directions,
        plus a self-loop on every node.
    labels : dict[int, int]
        Local node id -> argmax of that node's feature vector.
        If graph ``bi`` has no nodes, ``(set(), {})`` is returned.
    """
    # Global indices of the nodes belonging to graph `bi`. Assumes these nodes
    # are stored contiguously (as in a PyG Batch), so [first, last] bounds them.
    node_idx = (data.batch == bi).nonzero(as_tuple=True)[0]
    if node_idx.numel() == 0:
        return set(), {}
    first, last = node_idx[0], node_idx[-1]

    # Keep edges whose both endpoints lie in this graph's node range,
    # then shift them to local ids. `.tolist()` yields plain Python ints
    # and also works for tensors living on a GPU.
    in_graph = ((data.edge_index >= first) & (data.edge_index <= last)).all(0)
    local_edges = (data.edge_index[:, in_graph] - first).t().tolist()

    edges = {(u, v) for u, v in local_edges}
    edges |= {(v, u) for u, v in local_edges}
    edges |= {(i, i) for i in range(node_idx.numel())}
    labels = dict(enumerate(data.x[node_idx].argmax(-1).tolist()))
    return edges, labels

# GraKeL-format copies of both splits, one entry per graph in each Batch.
# `num_graphs` counts graphs directly instead of relying on `y` being graph-level.
Gs1 = [togk(data1, bi) for bi in range(data1.num_graphs)]
Gs2 = [togk(data2, bi) for bi in range(data2.num_graphs)]


In [5]:
from WL_gpu import WL

# NOTE(review): WL_gpu is only exercised on CUDA here; CPU support is unverified.
device = 'cuda'
N_REPEATS = 10  # timing repetitions

# Node labels (argmax of one-hot features), edges and graph ids for each split.
X1 = data1.x.argmax(-1).to(device)
E1 = data1.edge_index.to(device)
B1 = data1.batch.to(device)

X2 = data2.x.argmax(-1).to(device)
E2 = data2.edge_index.to(device)
B2 = data2.batch.to(device)

wl = WL(num_layers=3)
# Warm-up run so one-time setup costs are excluded from the timings.
wl.fit((X1, E1, B1))
# CUDA kernels run asynchronously: synchronize before starting/stopping the
# clock so the measured time covers the GPU work, not just kernel launches.
torch.cuda.synchronize()

print("Fit & transform n times")
t = time.perf_counter()
for _ in range(N_REPEATS):
    wl.fit((X1, E1, B1))
    sim_self_gpu = wl.transform((X1, E1, B1))
    sim_gpu = wl.transform((X2, E2, B2))
torch.cuda.synchronize()
print(time.perf_counter() - t)

print("Fit once & transform n times")
t = time.perf_counter()
wl.fit((X1, E1, B1))
for _ in range(N_REPEATS):
    sim_self_gpu = wl.transform((X1, E1, B1))
    sim_gpu = wl.transform((X2, E2, B2))
torch.cuda.synchronize()
print(time.perf_counter() - t)


Fit & transform n times
5.066545724868774
Fit once & transform n times
2.264070749282837


In [6]:
from grakel.kernels import WeisfeilerLehman

N_REPEATS = 10  # timing repetitions (same as the GPU benchmark)
wl = WeisfeilerLehman(n_iter=3, normalize=True)

# `.T` flips GraKeL's output so rows index the fitted graphs, which lines up
# with the GPU matrices displayed below (presumably (n_fit, n_query); verify).
print("Fit & transform n times")
t = time.perf_counter()
for _ in range(N_REPEATS):
    wl.fit(Gs1)
    sim_self_gk = wl.transform(Gs1).T
    sim_gk = wl.transform(Gs2).T
print(time.perf_counter() - t)

print("Fit once & transform n times")
t = time.perf_counter()
wl.fit(Gs1)
for _ in range(N_REPEATS):
    sim_self_gk = wl.transform(Gs1).T
    sim_gk = wl.transform(Gs2).T
print(time.perf_counter() - t)


Fit & transform n times
6.560385227203369
Fit once & transform n times
5.238116025924683


In [7]:
display(sim_self_gpu, torch.tensor(sim_self_gk))
# Largest element-wise gap between the GPU and GraKeL self-similarity matrices.
self_max_abs_diff = (sim_self_gpu.cpu().double() - torch.tensor(sim_self_gk)).abs().max().item()
print(f"max |sim_self_gpu - sim_self_gk| = {self_max_abs_diff:.4f}")

tensor([[1.0000, 0.8005, 0.8063,  ..., 0.5853, 0.7059, 0.6917],
        [0.8005, 1.0000, 0.7670,  ..., 0.5342, 0.7214, 0.6610],
        [0.8063, 0.7670, 1.0000,  ..., 0.4188, 0.6705, 0.5641],
        ...,
        [0.5853, 0.5342, 0.4188,  ..., 1.0000, 0.6883, 0.7977],
        [0.7059, 0.7214, 0.6705,  ..., 0.6883, 1.0000, 0.8200],
        [0.6917, 0.6610, 0.5641,  ..., 0.7977, 0.8200, 1.0000]],
       device='cuda:0')

tensor([[1.0000, 0.8005, 0.8063,  ..., 0.5853, 0.7059, 0.6921],
        [0.8005, 1.0000, 0.7670,  ..., 0.5342, 0.7214, 0.6614],
        [0.8063, 0.7670, 1.0000,  ..., 0.4188, 0.6705, 0.5644],
        ...,
        [0.5853, 0.5342, 0.4188,  ..., 1.0000, 0.6883, 0.7982],
        [0.7059, 0.7214, 0.6705,  ..., 0.6883, 1.0000, 0.8205],
        [0.6921, 0.6614, 0.5644,  ..., 0.7982, 0.8205, 1.0000]],
       dtype=torch.float64)

In [8]:
# To save computation, the GPU kernel approximates the normalization; when there is no isomorphic graph, this can lead to slightly different results.
display(sim_gpu, torch.tensor(sim_gk))
# Largest element-wise gap between the GPU and GraKeL fit-vs-query matrices.
max_abs_diff = (sim_gpu.cpu().double() - torch.tensor(sim_gk)).abs().max().item()
print(f"max |sim_gpu - sim_gk| = {max_abs_diff:.4f}")

tensor([[0.8881, 0.8775, 0.8761,  ..., 0.6911, 0.8061, 0.7627],
        [0.8280, 0.8281, 0.8240,  ..., 0.6453, 0.7636, 0.7271],
        [0.8516, 0.8308, 0.8496,  ..., 0.5540, 0.7277, 0.6864],
        ...,
        [0.4890, 0.5035, 0.4528,  ..., 0.8353, 0.7198, 0.7040],
        [0.6952, 0.7050, 0.6684,  ..., 0.7858, 0.8230, 0.8022],
        [0.6495, 0.6625, 0.6165,  ..., 0.8899, 0.7959, 0.7758]],
       device='cuda:0')

tensor([[0.8881, 0.8734, 0.8728,  ..., 0.6915, 0.8046, 0.7571],
        [0.8280, 0.8243, 0.8209,  ..., 0.6456, 0.7621, 0.7217],
        [0.8516, 0.8269, 0.8464,  ..., 0.5543, 0.7263, 0.6813],
        ...,
        [0.4890, 0.5011, 0.4511,  ..., 0.8357, 0.7184, 0.6989],
        [0.6952, 0.7017, 0.6659,  ..., 0.7862, 0.8214, 0.7963],
        [0.6499, 0.6598, 0.6145,  ..., 0.8909, 0.7949, 0.7706]],
       dtype=torch.float64)