In [1]:
import torch
import torch_geometric

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

In [2]:
import matplotlib as plt
import numpy as np
import networkx
from torch_geometric.data import Data
from sklearn.manifold import TSNE

In [3]:
from torch_geometric.utils import to_dense_adj

## Data Preparation on MUTAG

In [4]:
# Load MUTAG (downloaded into DATA_ROOT on first use).
DATA_ROOT = "../dataset"
dataset = TUDataset(name='MUTAG', root=DATA_ROOT)

In [7]:
# Dense adjacency of one example graph (index 2) and of its complement graph.
sample_graph = dataset[2]
# Pin the matrix size to the real node count: without max_num_nodes the size is
# derived from the largest node index in edge_index, which drops trailing
# isolated nodes and makes the shapes disagree with the node-feature matrix.
adj_o = to_dense_adj(sample_graph.edge_index, max_num_nodes=len(sample_graph.x))
# Complement: 1 wherever the original has no edge, 0 on the diagonal (no self-loops).
# Comparing with 0 (instead of abs(A - 1)) stays in {0, 1} even when the input
# contains self-loops or duplicate edges.
adj_c = (adj_o == 0).float()
adj_c[0].fill_diagonal_(0)

print("Original:\n", adj_o)
print("Complementary:\n", adj_c)

Original:
 tensor([[[0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]]])
Complementary:
 tensor([[[0., 0., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1.],
         [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 0., 0., 0., 

In [8]:
def toComplementary(g):
    """Return the edge_index of the complement of graph ``g``.

    The complement has a directed edge (i, j) for every ordered pair of distinct
    nodes i != j that is NOT an edge of ``g``. The node count is taken from
    ``g.x``, so isolated nodes that never appear in ``g.edge_index`` are still
    part of the complement. Self-loops and duplicate edges in the input are
    treated as plain "edge present" and never produce spurious entries.

    Args:
        g: graph with ``x`` (one row per node) and ``edge_index`` given as a
            [2, E] tensor of (source, target) node indices (PyG convention).

    Returns:
        Long tensor of shape [2, E_c] with the complement edges, ordered
        row-major (by source, then target).
    """
    num_nodes = g.x.size(0)
    # Boolean adjacency sized from the node features, not from max(edge_index) + 1.
    adj = torch.zeros(num_nodes, num_nodes, dtype=torch.bool, device=g.edge_index.device)
    adj[g.edge_index[0], g.edge_index[1]] = True
    missing = ~adj
    missing.fill_diagonal_(False)  # complement graph has no self-loops
    return missing.nonzero().t().contiguous()

In [9]:
# Pair every MUTAG graph with its complement: same node features and label,
# complementary edge set.
dataset_c = [
    Data(edge_index=toComplementary(graph), x=graph.x, y=graph.y)
    for graph in dataset
]

In [27]:
# Graph labels for the whole MUTAG dataset (one 0/1 label per graph).
dataset.y

tensor([1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1,
        0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1,
        1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0,
        1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0,
        1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0,
        1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0])

In [31]:
# Labels of the complementary graphs, in dataset order.
ys = [d.y.item() for d in dataset_c]
# Building the complement dataset must not reorder or relabel graphs; check it
# explicitly instead of comparing printed lists by eye.
assert ys == dataset.y.tolist(), "complementary dataset labels differ from MUTAG labels"
print(ys)

[1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0]


In [74]:
# Positional train/test split (no shuffling). Both the original and the
# complementary datasets are cut at the same index so their graphs stay paired.
ratio = 0.5
total = len(dataset)
n_train = round(ratio * total)

g_train, g_test = dataset[:n_train], dataset[n_train:]
gc_train, gc_test = dataset_c[:n_train], dataset_c[n_train:]

In [129]:
# Sanity check: the original and complementary splits should have matching sizes.
print(f'g_train {g_train}')
print(f'g_test {g_test}')
print(f'gc_train {len(gc_train)}')
print(f'gc_test {len(gc_test)}')

g_train MUTAG(94)
g_test MUTAG(94)
gc_train 94
gc_test 94


In [199]:
# Labels of the original training graphs, for comparison with the complementary split.
g_train_labels = [graph.y.item() for graph in g_train]
print(g_train_labels)

[1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1]


In [200]:
# Labels of the complementary training graphs; should match the original split.
gc_train_labels = [graph.y.item() for graph in gc_train]
print(gc_train_labels)

[1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1]


In [206]:
bs = 32
seed = 12345  # NOTE(review): not used by any loader in this cell — confirm whether it is needed.

# Every loader keeps shuffle=False: the original and complementary loaders are
# zipped together later, so their batches must hold the same graphs in the same order.
g_train_loader, g_test_loader, gc_train_loader, gc_test_loader = (
    DataLoader(part, batch_size=bs, shuffle=False)
    for part in (g_train, g_test, gc_train, gc_test)
)

In [213]:
# Labels of every training batch (original graphs).
for batch in g_train_loader:
    print(batch.y)

tensor([1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,
        1, 1, 1, 1, 1, 0, 1, 1])
tensor([0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0,
        0, 1, 1, 1, 1, 1])


In [214]:
# Labels of the first test batch only (original graphs).
first_batch = next(iter(g_test_loader), None)
if first_batch is not None:
    print(first_batch.y)

tensor([1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1,
        0, 0, 1, 1, 0, 0, 1, 1])


In [215]:
# Labels of every training batch (complementary graphs); should mirror the original loader.
for batch in gc_train_loader:
    print(batch.y)

tensor([1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,
        1, 1, 1, 1, 1, 0, 1, 1])
tensor([0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0,
        0, 1, 1, 1, 1, 1])


In [220]:
# Labels of the first test batch only (complementary graphs).
first_batch = next(iter(gc_test_loader), None)
if first_batch is not None:
    print(first_batch.y)

tensor([1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1,
        0, 0, 1, 1, 0, 0, 1, 1])


## Building the model

In [221]:
from torch_geometric.nn import GCNConv
from torch.nn import Linear
from torch.nn import Linear
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import global_max_pool
from torch_geometric.nn import global_add_pool
import torch.nn.functional as F

In [253]:
class ComplementarySupCon(torch.nn.Module):
    """Two-branch GCN that contrasts a graph with its complement.

    One two-layer GCN (``conv1_o``/``conv2_o``) encodes the original graph and a
    second, independently parameterised two-layer GCN (``conv1_c``/``conv2_c``)
    encodes the complement graph. Each branch is sum-pooled per graph, and the
    complement pooling is subtracted from the original pooling.

    Args:
        dataset: object exposing ``num_node_features`` (here the MUTAG TUDataset).
        hidden_channels: width of every GCN layer in both branches.
        seed: value passed to ``torch.manual_seed`` right before the layers are
            created, so weight initialisation is reproducible. Defaults to 42,
            the previously hard-coded value.
    """

    def __init__(self, dataset, hidden_channels, seed=42):
        super().__init__()

        # Seed just before creating the layers for reproducible initial weights.
        # NOTE: this reseeds the global torch RNG as a side effect.
        torch.manual_seed(seed)
        self.conv1_o = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2_o = GCNConv(hidden_channels, hidden_channels)

        self.conv1_c = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2_c = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x_o, x_c, edge_index_o, edge_index_c, batch_o, batch_c):
        """Encode both graphs and return the pooled difference.

        Args:
            x_o, x_c: node features of the original / complementary batch.
            edge_index_o, edge_index_c: edge indices of the two batches.
            batch_o, batch_c: node-to-graph assignment vectors of the two batches.

        Returns:
            ``(h, x_o, x_c)``: ``h`` is sum-pooled original embeddings minus
            sum-pooled complement embeddings; ``x_o`` / ``x_c`` are the node
            embeddings produced by each branch.
        """
        x_o = self.conv1_o(x_o, edge_index_o).relu()
        x_o = self.conv2_o(x_o, edge_index_o)

        # The complement branch uses its own layers. Previously conv1_o/conv2_o
        # were reused here, which left conv1_c/conv2_c as dead parameters.
        x_c = self.conv1_c(x_c, edge_index_c).relu()
        x_c = self.conv2_c(x_c, edge_index_c)

        # Pool each branch with its own batch vector. Sum pooling is linear, so
        # this equals pooling (x_o - x_c) when batch_o == batch_c, but it does not
        # silently assume the two batches line up node-for-node.
        h = global_add_pool(x_o, batch_o) - global_add_pool(x_c, batch_c)

        return h, x_o, x_c

In [254]:
# 64 hidden channels in every GCN layer of both branches.
model = ComplementarySupCon(dataset, hidden_channels=64)

In [255]:
# Show the module tree: layer names with their GCNConv input/output sizes.
model

ComplementarySupCon(
  (conv1_o): GCNConv(7, 64)
  (conv2_o): GCNConv(64, 64)
  (conv1_c): GCNConv(7, 64)
  (conv2_c): GCNConv(64, 64)
)

In [264]:
# Run the model on the first paired (original, complementary) training batch only,
# to inspect output shapes.
h = x_o = x_c = None
for g_o, g_c in zip(g_train_loader, gc_train_loader):
    h, x_o, x_c = model(
        x_o=g_o.x, x_c=g_c.x,
        edge_index_o=g_o.edge_index, edge_index_c=g_c.edge_index,
        batch_o=g_o.batch, batch_c=g_c.batch,
    )
    break

In [273]:
# Shape of the pooled output `h` from the forward pass above.
# NOTE(review): the saved output reports 585 rows, which equals the node count of
# g_o.x rather than the number of graphs in the batch. It looks stale from an
# earlier version of forward(); re-run to confirm.
h.shape

torch.Size([585, 64])

In [278]:
# Node-feature matrix of the first training batch: (total nodes, node features).
g_o.x.shape

torch.Size([585, 7])

In [281]:
# Batch vector of the first training batch: for each node, the index of the graph
# (within this batch) that it belongs to.
g_o.batch

tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
         3,  3,  3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
         4,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,  6,  6,
         6,  6,  6,  6,  6,  6,  6,  6,  6,  7,  7,  7,  7,  7,  7,  7,  7,  7,
         7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  8,  8,  8,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,
         9,  9,  9,  9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,
        11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        12, 12, 12, 12, 12, 12, 12, 12, 