In [1]:
import numpy as np
from sklearn.metrics import normalized_mutual_info_score

import torch
import torch.nn.functional as F
import torch_geometric as tg
from torch_geometric.nn import DenseSAGEConv, dense_diff_pool, dense_mincut_pool
from torch_geometric.utils import to_dense_adj

from dataset import load_pyg_dataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
pooling = 'mincut' #(options: diffpool, mincut)
dataset_name = 'Cora'

#Pooling selector
pooling_selector = {
    'diffpool': dense_diff_pool,
    'mincut': dense_mincut_pool
}

In [7]:
data = load_pyg_dataset(
    data_name=dataset_name,
    device=device
)
# num_clusters = data.y.max().tolist()+1
num_clusters = int(0.05 * data.num_nodes)
data.adj = to_dense_adj(data.edge_index)

The obtained data Cora has 2708 nodes, 10556 edges, 1433 features, 7 labels, 


In [8]:
#GNN to compute transformed node features for pooling (for assignation matrix)
class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 normalize=False, lin=True):
        super(GNN, self).__init__()
        self.conv1 = DenseSAGEConv(in_channels, hidden_channels, normalize)
        self.bn1 = torch.nn.BatchNorm1d(hidden_channels)
        self.conv2 = DenseSAGEConv(hidden_channels, hidden_channels, normalize)
        self.bn2 = torch.nn.BatchNorm1d(hidden_channels)
        self.conv3 = DenseSAGEConv(hidden_channels, out_channels, normalize)
        self.bn3 = torch.nn.BatchNorm1d(out_channels)
        if lin is True:
            self.lin = torch.nn.Linear(
                2 * hidden_channels + out_channels,
                out_channels
            )
        else:
            self.lin = None
            
    def bn(self, i, x):
        batch_size, num_nodes, num_channels = x.size()
        x = x.view(-1, num_channels)
        x = getattr(self, 'bn{}'.format(i))(x)
        x = x.view(batch_size, num_nodes, num_channels)
        return x
    
    def forward(self, x, adj, mask=None):
        x0 = x
        if pooling == 'diffpool':
            x1 = self.bn(1, F.relu(self.conv1(x0, adj, mask)))
            x2 = self.bn(2, F.relu(self.conv2(x1, adj, mask)))
            x3 = self.bn(3, F.relu(self.conv3(x2, adj, mask)))
        elif pooling == 'mincut':
            x1 = F.relu(self.conv1(x0, adj, mask))
            x2 = F.relu(self.conv2(x1, adj, mask))
            x3 = F.relu(self.conv3(x2, adj, mask))
        x = torch.cat([x1, x2, x3], dim=-1)
        if self.lin is not None:
            x = F.relu(self.lin(x))
        return x

#Net to compute pooling in an unsupervised manner
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.gnn1_pool = GNN(data.num_node_features, 64, num_clusters)
        self.gnn1_embed = GNN(data.num_node_features, 64, 64, lin=False)
        self.pooling = pooling_selector[pooling]
        
    def forward(self, x, adj, mask=None):
        s = self.gnn1_pool(x, adj, mask)
        x = self.gnn1_embed(x, adj, mask)
        x, adj, l1, e1 = self.pooling(x, adj, s, mask)
        return torch.softmax(s, dim=-1), l1, e1 # returns assignation matrix, and auxiliary losses

In [9]:
###########
#Training
###########
#Optimizer, model
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

#Trains using auxiliary losses only
def train(epoch):
    model.train()
    optimizer.zero_grad()
    s, l1, e1 = model(data.x, data.adj)
    # link loss from DiffPool
    link_loss = data.adj - torch.matmul(s, s.transpose(1, 2))
    link_loss = torch.norm(link_loss, p=2)
    link_loss = link_loss / data.adj.numel()
    
#     loss = l1 + e1
#     loss = e1
    loss = e1 + link_loss
    loss.backward()
    optimizer.step()
    return loss

#Measures mutual information of clustering computed by pooling and ground truth
@torch.no_grad()
def test():
    model.eval()
#     pred_node_label = model(data.x, data.adj)[0].max(dim=-1)[1].detach().cpu().numpy() 
#     truth_node_labels = data.y.cpu().numpy()
#     nmi = normalized_mutual_info_score(truth_node_labels.flatten(), pred_node_label.flatten())
#     return nmi
    return 0

In [10]:
best_val_acc = best_test_acc = 0
for epoch in range(1, 151):
    train_loss = train(epoch)
    val_acc = test()
    test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_test_acc = test_acc
    print(model(data.x, data.adj)[0].max(dim=-1)[1].unique().shape)
    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.8f}, '
          f'Val NMI: {val_acc:.4f}, Test NMI: {test_acc:.4f}')
print(f'Best Val NMI: {best_val_acc:.4f}, Test NMI: {best_test_acc:.4f}')

tensor([  4,   5,   8,  13,  21,  31,  32,  33,  43,  51,  53,  94, 106, 110,
        113, 122, 126, 128, 130], device='cuda:0')
Epoch: 001, Train Loss: 1.35199463, Val NMI: 0.0000, Test NMI: 0.0000
tensor([  4,   5,   8,  13,  21,  29,  31,  33,  43,  51,  53,  55,  71,  78,
         94, 106, 110, 111, 113, 122, 126, 128, 130], device='cuda:0')
Epoch: 002, Train Loss: 1.35198998, Val NMI: 0.0000, Test NMI: 0.0000
tensor([  2,   4,   5,   8,  13,  21,  28,  29,  31,  33,  43,  50,  51,  53,
         55,  67,  71,  78,  80,  86,  94, 106, 110, 113, 122, 125, 126, 128,
        130], device='cuda:0')
Epoch: 003, Train Loss: 1.35198283, Val NMI: 0.0000, Test NMI: 0.0000
tensor([  2,   4,   5,   8,   9,  13,  21,  28,  29,  31,  33,  43,  50,  51,
         53,  55,  64,  66,  67,  70,  71,  78,  80,  86,  94,  99, 106, 110,
        111, 113, 117, 122, 125, 126, 128, 130], device='cuda:0')
Epoch: 004, Train Loss: 1.35197330, Val NMI: 0.0000, Test NMI: 0.0000
tensor([  1,   2,   4,   5,   8, 

tensor([ 18,  21,  33,  43,  64,  81,  91, 106, 111, 113, 126],
       device='cuda:0')
Epoch: 047, Train Loss: 1.26824105, Val NMI: 0.0000, Test NMI: 0.0000
tensor([ 21,  33,  43,  64,  81,  91, 106, 111, 113, 126], device='cuda:0')
Epoch: 048, Train Loss: 1.26303339, Val NMI: 0.0000, Test NMI: 0.0000
tensor([ 21,  33,  43,  58,  64,  81,  91, 106, 111, 113, 126],
       device='cuda:0')
Epoch: 049, Train Loss: 1.25803852, Val NMI: 0.0000, Test NMI: 0.0000
tensor([ 21,  33,  43,  58,  64,  81,  91, 106, 111, 113, 126],
       device='cuda:0')
Epoch: 050, Train Loss: 1.25499177, Val NMI: 0.0000, Test NMI: 0.0000
tensor([ 18,  21,  33,  43,  58,  64,  81,  91, 106, 111, 113, 126],
       device='cuda:0')
Epoch: 051, Train Loss: 1.25036657, Val NMI: 0.0000, Test NMI: 0.0000
tensor([ 18,  21,  33,  43,  50,  58,  64,  81,  91, 106, 111, 113, 126],
       device='cuda:0')
Epoch: 052, Train Loss: 1.24664247, Val NMI: 0.0000, Test NMI: 0.0000
tensor([ 18,  21,  33,  43,  50,  58,  64,  81,  

tensor([  4,   5,   7,   9,  12,  13,  17,  18,  21,  23,  24,  32,  33,  36,
         37,  43,  46,  50,  52,  53,  58,  61,  64,  71,  72,  79,  81,  84,
         88,  89,  91,  97,  99, 103, 106, 111, 113, 115, 117, 123, 124, 125,
        126, 133], device='cuda:0')
Epoch: 083, Train Loss: 1.05571425, Val NMI: 0.0000, Test NMI: 0.0000
tensor([  4,   5,   7,   9,  12,  13,  17,  18,  21,  23,  24,  32,  33,  37,
         43,  46,  50,  52,  53,  58,  61,  64,  71,  78,  81,  88,  91,  97,
         99, 103, 106, 111, 113, 115, 117, 123, 124, 125, 126, 133],
       device='cuda:0')
Epoch: 084, Train Loss: 1.05856740, Val NMI: 0.0000, Test NMI: 0.0000
tensor([  4,   5,   7,   9,  12,  13,  17,  18,  21,  24,  32,  33,  37,  43,
         46,  50,  52,  53,  58,  61,  64,  71,  78,  81,  88,  91,  97,  99,
        102, 103, 106, 111, 113, 115, 117, 123, 124, 125, 126, 133],
       device='cuda:0')
Epoch: 085, Train Loss: 1.05029702, Val NMI: 0.0000, Test NMI: 0.0000
tensor([  4,   5,   7,

tensor([  4,   5,   6,   7,   9,  10,  12,  17,  18,  21,  24,  33,  35,  43,
         46,  47,  50,  52,  53,  58,  61,  63,  64,  68,  71,  81,  84,  88,
         89,  91,  92,  93,  97,  99, 100, 102, 103, 106, 111, 113, 115, 117,
        121, 125, 126], device='cuda:0')
Epoch: 107, Train Loss: 0.98677534, Val NMI: 0.0000, Test NMI: 0.0000
tensor([  4,   5,   6,   7,   9,  10,  12,  17,  18,  21,  24,  33,  35,  43,
         46,  50,  52,  53,  58,  61,  63,  64,  68,  71,  81,  84,  88,  89,
         91,  92,  97,  99, 100, 102, 103, 106, 111, 113, 115, 117, 121, 125,
        126, 132], device='cuda:0')
Epoch: 108, Train Loss: 0.97835970, Val NMI: 0.0000, Test NMI: 0.0000
tensor([  4,   5,   7,   9,  10,  12,  15,  17,  18,  21,  24,  33,  35,  43,
         46,  50,  52,  53,  58,  61,  63,  64,  68,  71,  81,  84,  88,  89,
         91,  92,  97,  99, 100, 102, 103, 106, 108, 111, 113, 115, 117, 121,
        125, 126, 132], device='cuda:0')
Epoch: 109, Train Loss: 0.98057729, Val 

tensor([  4,   5,   7,   9,  10,  12,  15,  17,  18,  21,  24,  33,  35,  43,
         46,  50,  52,  53,  58,  61,  63,  64,  68,  71,  81,  84,  88,  89,
         91,  92,  97,  99, 100, 102, 103, 106, 108, 111, 113, 115, 117, 120,
        121, 125, 126, 132], device='cuda:0')
Epoch: 131, Train Loss: 0.93321484, Val NMI: 0.0000, Test NMI: 0.0000
tensor([  4,   5,   7,   9,  10,  12,  15,  17,  18,  21,  24,  33,  35,  43,
         46,  50,  52,  53,  58,  61,  63,  64,  68,  71,  81,  84,  88,  89,
         91,  92,  97,  99, 100, 102, 103, 106, 108, 111, 113, 115, 117, 120,
        121, 125, 126, 132], device='cuda:0')
Epoch: 132, Train Loss: 0.93195468, Val NMI: 0.0000, Test NMI: 0.0000
tensor([  4,   5,   7,   9,  10,  12,  15,  17,  18,  21,  24,  33,  35,  43,
         46,  50,  52,  53,  58,  61,  63,  64,  68,  71,  81,  84,  88,  89,
         91,  92,  97,  99, 100, 102, 103, 106, 108, 111, 113, 115, 117, 120,
        121, 125, 126, 132], device='cuda:0')
Epoch: 133, Train Lo

In [17]:
model(data.x, data.adj)[0].max(dim=-1)[1].unique().shape

torch.Size([48])

In [18]:
num_clusters

135