In [7]:
import os
from IPython.display import display
from typing import Union

import torch as T
import torch.nn as nn
import torch.optim as optim
import torchinfo
from sklearn.metrics import normalized_mutual_info_score, f1_score

from admon.model import DMoN
from admon.utils import load_npz, modularity, conductance

# Alias
_PathLike = Union[str, 'os.PathLike[str]']
CORA_DIR: _PathLike = './data/cora'

In [2]:
adj, features, labels, label_indices = load_npz(os.path.join(CORA_DIR, 'cora.npz'))
adj_tensor = T.tensor(adj.todense()).unsqueeze(0).float()
features_tensor = T.tensor(features.todense()).unsqueeze(0).float()

In [3]:
# hyperparameter
n_clusters: int = 16
num_epochs: int = 200
hidden: int = 64
depths: int = 1
dropout: float = 0.
inflation: int = 1
activation: str= 'selu'
collapse_regularization = .01
device = T.device('cpu')

lr: float = 1e-3
weight_decay: float = 5e-4
lr_decay_step: int = 5
lr_decay_gamma: float = 0.3

model = DMoN(features_tensor.size(-1), n_clusters,
             hidden, depths, dropout, inflation,
             collapse_regularization=collapse_regularization)
model: nn.Module = model.to(device)
torchinfo.summary(model)

  self.predict.weight.copy_(nn.init.orthogonal(self.predict.weight.data))


Layer (type:depth-idx)                   Param #
DMoN                                     --
├─Single: 1-1                            --
│    └─Sequential: 2-1                   --
│    │    └─GCN: 3-1                     91,712
│    └─Linear: 2-2                       1,040
│    └─Dropout: 2-3                      --
Total params: 92,752
Trainable params: 92,752
Non-trainable params: 0

In [4]:
# Train main
optimizer = optim.Adam(model.parameters(), lr=lr)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, lr_decay_step, lr_decay_gamma)

for epoch in range(num_epochs):
  model.train()
  optimizer.zero_grad()

  pooled_features, assignments, m_loss, c_loss = model((features_tensor, adj_tensor))
  loss: T.Tensor = m_loss + c_loss
  loss.backward()
  optimizer.step()
  lr_scheduler.step()

  print(f'Epoch [{epoch+1:d}/{num_epochs:d}] Loss: {loss.item():.2f}')

Epoch [1/200] Loss: 27.06
Epoch [2/200] Loss: 27.06
Epoch [3/200] Loss: 27.06
Epoch [4/200] Loss: 27.06
Epoch [5/200] Loss: 27.06
Epoch [6/200] Loss: 27.06
Epoch [7/200] Loss: 27.06
Epoch [8/200] Loss: 27.06
Epoch [9/200] Loss: 27.06
Epoch [10/200] Loss: 27.05
Epoch [11/200] Loss: 27.05
Epoch [12/200] Loss: 27.05
Epoch [13/200] Loss: 27.05
Epoch [14/200] Loss: 27.05
Epoch [15/200] Loss: 27.05
Epoch [16/200] Loss: 27.05
Epoch [17/200] Loss: 27.05
Epoch [18/200] Loss: 27.05
Epoch [19/200] Loss: 27.05
Epoch [20/200] Loss: 27.05
Epoch [21/200] Loss: 27.05
Epoch [22/200] Loss: 27.05
Epoch [23/200] Loss: 27.05
Epoch [24/200] Loss: 27.05
Epoch [25/200] Loss: 27.05
Epoch [26/200] Loss: 27.05
Epoch [27/200] Loss: 27.05
Epoch [28/200] Loss: 27.05
Epoch [29/200] Loss: 27.05
Epoch [30/200] Loss: 27.05
Epoch [31/200] Loss: 27.05
Epoch [32/200] Loss: 27.05
Epoch [33/200] Loss: 27.05
Epoch [34/200] Loss: 27.05
Epoch [35/200] Loss: 27.05
Epoch [36/200] Loss: 27.05
Epoch [37/200] Loss: 27.05
Epoch [38/

In [9]:
cluster_labels = assignments.detach().cpu().numpy().argmax(axis=-1)
display(modularity(adj, cluster_labels), conductance(adj, cluster_labels))

-0.038974025608526655

1.0

In [13]:
labels.shape, cluster_labels.shape

((1640,), (1, 2708))

In [14]:
display(normalized_mutual_info_score(labels, cluster_labels[0, label_indices]))

0.060144042121365406