In [1]:
import scipy.sparse as sp
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import TQDMProgressBar
from sknetwork.hierarchy.metrics import tree_sampling_divergence

from fph_clustering.algorithms.agglomerative_linkage import agglomerative_linkage
from fph_clustering.algorithms.hierarchy_compression import compress_hierarchy_dasgupta, compress_hierarchy_tsd
from fph_clustering.data.data_modules import FPHDataModule
from fph_clustering.models.direct_constrained_parameterization import FPHConstrainedDirectParameterization
from fph_clustering.util import utils
from fph_clustering.util.utils import networkx_from_torch_sparse




## Configuration

In [2]:
# Every knob a reader might want to tune lives in this cell; the later cells
# only read these names.

# Data: names starting with 'ogb' are loaded through the OGB loader, anything
# else is read from `<data_path>/<dataset>.pkl.gzip`.
dataset = 'citeseer'
data_path = '../data'

# Objective and tree size. The initialization cell handles 'TSD' and 'DASGUPTA'.
loss = 'TSD'
internal_nodes = 512
# 'avg' enables the average-linkage initialization; any other value skips it.
tree_init = 'avg'

# Optimizer settings, bundled in the nested dict that is handed to the model.
optimizer = 'PGD'
learning_rate = 150
optimizer_params = dict(
    optimizer_type=optimizer,
    opt_params=dict(lr=learning_rate),
)

# Trainer settings.
max_epochs = 400
val_every = 10  # run validation every `val_every` epochs
use_gpu = True  # a GPU is only requested if CUDA is actually available

## Initialize data module

In [3]:
# OGB datasets are selected by their name prefix; every other dataset is
# read from a gzipped pickle located under `data_path`.
data_module = (
    FPHDataModule.from_ogb_dataset(dataset)
    if dataset.startswith('ogb')
    else FPHDataModule.from_pickle(f'{data_path}/{dataset}.pkl.gzip')
)

## Optionally perform average linkage initialization

In [4]:
# `init_from` stays None (no tree initialization) unless average linkage
# was requested in the configuration.
init_from = None
if tree_init == 'avg':
    adjacency = data_module.dataset.adjacency
    graph = networkx_from_torch_sparse(adjacency)
    # Full average-linkage dendrogram over the input graph.
    den = agglomerative_linkage(graph, affinity='unitary', linkage='average', check=True)

    # Each supported loss has its own compression routine; both take the
    # same (graph, dendrogram, internal_nodes) arguments.
    compress_fn = {
        'TSD': compress_hierarchy_tsd,
        'DASGUPTA': compress_hierarchy_dasgupta,
    }.get(loss)
    if compress_fn is None:
        raise NotImplementedError()
    compressed = compress_fn(graph, den, internal_nodes)

    # Convert the compressed tree into the value passed as `initialize_from`.
    init_from = utils.tree_to_A_B(compressed, adjacency.shape[0], internal_nodes)

## Initialize model

In [5]:
# Keyword options are gathered first so the constructor call stays short.
# `init_from` is None when no average-linkage initialization was computed.
model_options = dict(
    optimizer_params=optimizer_params,
    loss=loss,
    initialize_from=init_from,
)
model = FPHConstrainedDirectParameterization(
    internal_nodes,
    data_module.num_nodes,
    **model_options,
)

## Training

In [6]:
# Build the Lightning trainer.
# - A GPU is requested only when enabled in the config AND CUDA is available.
# - Checkpointing is disabled via `enable_checkpointing`, which replaces the
#   deprecated `checkpoint_callback` flag.
# - The progress-bar refresh rate is set through a `TQDMProgressBar` callback,
#   which replaces the deprecated `progress_bar_refresh_rate` argument.
trainer = Trainer(
    gpus=1 if use_gpu and torch.cuda.is_available() else 0,
    max_epochs=max_epochs,
    enable_checkpointing=False,
    check_val_every_n_epoch=int(val_every),
    callbacks=[TQDMProgressBar(refresh_rate=1)],
    num_sanity_val_steps=1,
)

  rank_zero_deprecation(
  rank_zero_deprecation(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [7]:
# Run training; the data module supplies the dataloaders.
trainer.fit(
    model,
    datamodule=data_module,
)

Missing logger folder: /Users/danielzuegner/Documents/TUM/fph-clustering/notebooks/lightning_logs

  | Name | Type      | Params
-----------------------------------
0 | A_u  | Embedding | 1.1 M 
-----------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.370     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

## Extract learned hierarchy

In [8]:
model.eval()

if model.best_A is None:
    # The model holds no stored best tree: derive one from the current
    # parameters and convert it back to (A, B) form.
    A, B = model.compute_A_B()
    A = A.detach().cpu().numpy()
    B = B.detach().cpu().numpy()
    T = utils.best_tree(A, B)
    A, B = utils.tree_to_A_B(T, A.shape[0], A.shape[1])
    # Use the LightningModule's own device. `A_u` is reported as an Embedding
    # module in the model summary, and modules do not expose `.device`.
    A = torch.tensor(A).to(model.device)
    B = torch.tensor(B).to(model.device)
else:
    # Reuse the best matrices and tree stored on the model.
    A = model.best_A.detach()
    B = model.best_B.detach()
    T = model.best_tree

# Evaluate both metrics on the first batch of the test dataloader.
graph = next(iter(data_module.test_dataloader()))
with torch.no_grad():
    res_tsd = model.compute_TSD(graph, A=A, B=B)
    res_das = model.compute_dasgupta(graph, A=A, B=B)

# `skn_TSD` stays None when the sknetwork reference computation is skipped.
skn_TSD = None
if model.num_nodes < 1e6:
    # sknetwork's TSD implementation is inefficient and takes very long for large datasets.
    adj = data_module.dataset.adjacency
    # Hand SciPy plain NumPy arrays rather than torch tensors.
    adj_values = adj.values().cpu().numpy()
    adj_indices = adj.indices().cpu().numpy()
    A_sp = sp.csr_matrix((adj_values, adj_indices), shape=adj.shape)
    den, _ = utils.tree_to_dendrogram(T, A.shape[0])
    skn_TSD_raw = tree_sampling_divergence(A_sp, den, normalized=False)
    # Scale the raw value by the graph's mutual information, as a percentage.
    skn_TSD = (100 * skn_TSD_raw / graph.mutual_information).item()


In [9]:
# In the paper we report the sknetwork TSD results.
# However, due to a bug in sknetwork 0.24.0, the results
# are slightly different from our own TSD metric.
# See: https://github.com/sknetwork-team/scikit-network/issues/504.
if skn_TSD is None:
    # The sknetwork computation is skipped for graphs with >= 1e6 nodes, so
    # there is no number to format; `f'{None:.2f}'` would raise a TypeError.
    print('TSD (sknetwork): n/a (skipped for large graphs)')
else:
    print(f'TSD (sknetwork): {skn_TSD:.2f}')
print(f'TSD: {res_tsd.metric.item():.2f}')
print(f'Dasgupta: {res_das.metric.item():.2f}')

TSD (sknetwork): 69.37
TSD: 67.58
Dasgupta: 86.81
