In [1]:
import numpy as np
from sklearn.metrics import normalized_mutual_info_score

import torch
import torch.nn.functional as F
import torch_geometric as tg
from torch_geometric.nn import DenseSAGEConv, dense_diff_pool, dense_mincut_pool
from torch_geometric.utils import to_dense_adj

from dataset import load_pyg_dataset

# Use CUDA whenever torch reports it is available; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
pooling = 'mincut'  # options: 'diffpool', 'mincut'
dataset_name = 'Twitch-RU'
seed = 0  # fixes weight initialisation so the NMI curve is reproducible across runs

# Pooling selector: maps the `pooling` option to the PyG dense pooling operator it uses.
pooling_selector = {
    'diffpool': dense_diff_pool,
    'mincut': dense_mincut_pool,
}
# Fail fast on a typo here instead of deep inside model construction / forward.
if pooling not in pooling_selector:
    raise ValueError(
        f"Unknown pooling {pooling!r}; expected one of {sorted(pooling_selector)}."
    )

# Seed every RNG the notebook touches before the model is built.
torch.manual_seed(seed)
np.random.seed(seed)

In [3]:
data = load_pyg_dataset(
    data_name=dataset_name,
    device=device
)
# One pooling cluster per ground-truth class; assumes labels are 0..C-1 — TODO confirm.
num_clusters = int(data.y.max()) + 1
# Without max_num_nodes, to_dense_adj sizes the matrix from edge_index.max() + 1,
# so nodes with the highest ids that have no edges would be silently dropped and
# the adjacency would no longer line up with data.x.
data.adj = to_dense_adj(data.edge_index, max_num_nodes=data.num_nodes)

The obtained data Twitch-RU has 4385 nodes, 78993 edges, 128 features, 2 labels, 


In [4]:
#GNN to compute transformed node features for pooling (e.g. the cluster-assignment matrix)
class GNN(torch.nn.Module):
    """Three stacked ``DenseSAGEConv`` layers whose outputs are concatenated.

    Each conv is followed by ReLU and, optionally, batch norm. The outputs of
    all three layers are concatenated along the last axis and, when ``lin``
    is enabled, projected back to ``out_channels`` by a ``Linear`` + ReLU.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Output size of the first two conv layers.
        out_channels: Output size of the third conv layer (and of ``lin``).
        normalize: Passed positionally to every ``DenseSAGEConv``.
        lin: If truthy, add the final projection; the output then has
            ``out_channels`` features, otherwise
            ``2 * hidden_channels + out_channels``.
        batch_norm: ``True``/``False`` forces batch norm on/off. ``None``
            (default) keeps the original behaviour of deciding on every
            forward call from the global ``pooling`` option: batch norm for
            ``'diffpool'``, none for ``'mincut'``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 normalize=False, lin=True, batch_norm=None):
        super().__init__()
        self.conv1 = DenseSAGEConv(in_channels, hidden_channels, normalize)
        self.bn1 = torch.nn.BatchNorm1d(hidden_channels)
        self.conv2 = DenseSAGEConv(hidden_channels, hidden_channels, normalize)
        self.bn2 = torch.nn.BatchNorm1d(hidden_channels)
        self.conv3 = DenseSAGEConv(hidden_channels, out_channels, normalize)
        self.bn3 = torch.nn.BatchNorm1d(out_channels)
        # The BN layers are always created, even when unused, so the parameter
        # set / state_dict keys are the same whether or not BN is applied.
        self.use_batch_norm = batch_norm
        if lin:
            self.lin = torch.nn.Linear(
                2 * hidden_channels + out_channels,
                out_channels
            )
        else:
            self.lin = None

    def bn(self, i, x):
        """Apply ``bn{i}`` to a 3-D (batch, nodes, channels) tensor.

        ``BatchNorm1d`` normalises over (N, C) input, so batch and node axes
        are merged before normalisation and restored afterwards.
        """
        batch_size, num_nodes, num_channels = x.size()
        x = x.view(-1, num_channels)
        x = getattr(self, 'bn{}'.format(i))(x)
        return x.view(batch_size, num_nodes, num_channels)

    def _uses_batch_norm(self):
        """Return whether batch norm is applied on the current forward pass."""
        # getattr keeps modules pickled before this attribute existed working.
        override = getattr(self, 'use_batch_norm', None)
        if override is not None:
            return bool(override)
        if pooling == 'diffpool':
            return True
        if pooling == 'mincut':
            return False
        raise ValueError(
            f"Unknown pooling {pooling!r}; expected 'diffpool' or 'mincut'."
        )

    def forward(self, x, adj, mask=None):
        """Encode node features with the three conv layers.

        Args:
            x: Dense node features; presumably (nodes, F) or (batch, nodes, F)
                — TODO confirm. ``bn`` needs the conv outputs to be 3-D.
            adj: Dense adjacency passed to every conv layer.
            mask: Optional node mask passed to every conv layer.

        Returns:
            Concatenated layer outputs, projected by ``lin`` when present.

        Raises:
            ValueError: if batch norm is left to the global ``pooling`` option
                and that option is neither 'diffpool' nor 'mincut'.
        """
        use_bn = self._uses_batch_norm()
        h = x
        layer_outputs = []
        for i in (1, 2, 3):
            h = F.relu(getattr(self, 'conv{}'.format(i))(h, adj, mask))
            if use_bn:
                h = self.bn(i, h)
            layer_outputs.append(h)
        x = torch.cat(layer_outputs, dim=-1)
        if self.lin is not None:
            x = F.relu(self.lin(x))
        return x

#Net to compute pooling in an unsupervised manner
class Net(torch.nn.Module):
    """Assignment GNN + embedding GNN followed by a dense pooling operator.

    Args:
        in_channels: Input node-feature size. ``None`` (default) uses the
            global ``data.num_node_features``.
        n_clusters: Number of clusters to pool into. ``None`` (default) uses
            the global ``num_clusters``.
        hidden_channels: Hidden width of both GNNs and the output width of the
            embedding GNN (previously hard-coded to 64).

    The pooling operator is looked up as ``pooling_selector[pooling]`` when
    the model is constructed.
    """

    def __init__(self, in_channels=None, n_clusters=None, hidden_channels=64):
        super().__init__()
        if in_channels is None:
            in_channels = data.num_node_features
        if n_clusters is None:
            n_clusters = num_clusters
        self.gnn1_pool = GNN(in_channels, hidden_channels, n_clusters)
        self.gnn1_embed = GNN(in_channels, hidden_channels, hidden_channels, lin=False)
        self.pooling = pooling_selector[pooling]

    def forward(self, x, adj, mask=None):
        """Run both GNNs and the pooling operator.

        Returns:
            tuple: ``(softmax(s, dim=-1), l1, e1)`` — the soft cluster-
            assignment matrix and the two auxiliary losses, i.e. the third
            and fourth values returned by the pooling operator.
        """
        s = self.gnn1_pool(x, adj, mask)
        x = self.gnn1_embed(x, adj, mask)
        # The unsupervised objective only needs the auxiliary losses; the pooled
        # features and adjacency are discarded.
        _, _, l1, e1 = self.pooling(x, adj, s, mask)
        return torch.softmax(s, dim=-1), l1, e1

In [5]:
###########
#Training
###########
# Build the pooling network on the chosen device; Adam optimises all of its parameters.
model = Net()
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

#Trains using auxiliary losses only
def train(epoch=None):
    """Run one full-batch optimisation step on the pooling auxiliary losses.

    No labels are used: the objective is the sum of the two auxiliary losses
    returned by the model.

    Args:
        epoch: Current epoch number. Not used by the step; kept (now optional)
            so existing ``train(epoch)`` calls keep working.

    Returns:
        torch.Tensor: the loss computed before the parameter update, detached
        from the autograd graph.
    """
    model.train()
    optimizer.zero_grad()
    _, aux_loss_a, aux_loss_b = model(data.x, data.adj)
    loss = aux_loss_a + aux_loss_b
    loss.backward()
    optimizer.step()
    # Detach so the caller's reference does not keep this step's graph (grad_fn chain) alive.
    return loss.detach()

#Measures mutual information of clustering computed by pooling and ground truth
@torch.no_grad()
def test():
    """Score the current hard clustering against the ground-truth labels.

    Each node gets the cluster with the largest soft-assignment probability;
    those ids are compared with ``data.y`` via normalized mutual information.

    Returns:
        NMI between the ground-truth labels and the predicted cluster ids.
    """
    model.eval()
    soft_assignment = model(data.x, data.adj)[0]
    # Runs under no_grad, so the result carries no graph and can go straight to numpy.
    cluster_ids = soft_assignment.argmax(dim=-1).cpu().numpy().ravel()
    labels = data.y.cpu().numpy().ravel()
    return normalized_mutual_info_score(labels, cluster_ids)

In [6]:
# Each epoch: one optimisation step, then the NMI of the current hard clustering.
# test() scores every node (there is no validation/test split), so the "Val" and
# "Test" NMI are the same number; it is computed once per epoch and reused
# instead of running a second identical evaluation pass.
num_epochs = 150
best_val_acc = best_test_acc = 0
for epoch in range(1, num_epochs + 1):
    train_loss = train(epoch)
    nmi = test()
    val_acc = test_acc = nmi
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_test_acc = test_acc
    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.8f}, '
          f'Val NMI: {val_acc:.4f}, Test NMI: {test_acc:.4f}')
print(f'Best Val NMI: {best_val_acc:.4f}, Test NMI: {best_test_acc:.4f}')

Epoch: 001, Train Loss: -0.23608363, Val NMI: 0.0001, Test NMI: 0.0001
Epoch: 002, Train Loss: -0.23774678, Val NMI: 0.0002, Test NMI: 0.0002
Epoch: 003, Train Loss: -0.23969305, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 004, Train Loss: -0.24214876, Val NMI: 0.0002, Test NMI: 0.0002
Epoch: 005, Train Loss: -0.24522078, Val NMI: 0.0003, Test NMI: 0.0003
Epoch: 006, Train Loss: -0.24906558, Val NMI: 0.0003, Test NMI: 0.0003
Epoch: 007, Train Loss: -0.25378245, Val NMI: 0.0004, Test NMI: 0.0004
Epoch: 008, Train Loss: -0.25953442, Val NMI: 0.0004, Test NMI: 0.0004
Epoch: 009, Train Loss: -0.26649177, Val NMI: 0.0002, Test NMI: 0.0002
Epoch: 010, Train Loss: -0.27478069, Val NMI: 0.0001, Test NMI: 0.0001
Epoch: 011, Train Loss: -0.28451955, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 012, Train Loss: -0.29575616, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 013, Train Loss: -0.30857784, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 014, Train Loss: -0.32299548, Val NMI: 0.0000, Test NMI: 0.0000
Epoch:

Epoch: 118, Train Loss: -0.75166291, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 119, Train Loss: -0.75181323, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 120, Train Loss: -0.75168252, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 121, Train Loss: -0.75166380, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 122, Train Loss: -0.75178993, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 123, Train Loss: -0.75214320, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 124, Train Loss: -0.75228494, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 125, Train Loss: -0.75221425, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 126, Train Loss: -0.75221086, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 127, Train Loss: -0.75232643, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 128, Train Loss: -0.75257736, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 129, Train Loss: -0.75269455, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 130, Train Loss: -0.75267452, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: 131, Train Loss: -0.75268054, Val NMI: 0.0000, Test NMI: 0.0000
Epoch: