In [142]:
import numpy as np
import pytorch_lightning as pl
import torch
import torchmetrics
import pandas as pd
from torch_geometric.transforms import ToUndirected

from experiments import ClusteringModule, ClusterCohesionLoss, NegativeEntropyRegularizer, cosine_cdist
from shared.constants import TMP_PATH, BENCHMARKS_RESULTS
from shared.graph import CommunityAssignment
import ml
import experiments

In [143]:
# Notebook-level settings.
# `node_type` and `callbacks` are defined here but not referenced by the cells
# shown below; `save_path` is rebound further down before anything is written.
node_type = 'Character'
repr_dim = 32
save_path = TMP_PATH.joinpath('pyg-node2vec-comopt')
callbacks = [
    pl.callbacks.ModelSummary(),
    pl.callbacks.LearningRateMonitor(),
    # Stop once "val/loss" has not decreased for 5 consecutive checks.
    pl.callbacks.EarlyStopping(
        monitor="val/loss",
        mode="min",
        min_delta=0.00,
        patience=5,
        verbose=True,
    ),
]

In [144]:
# Instantiate the dataset and take its first item as the working graph data.
dataset = ml.StarWarsHomogenous()
# NOTE(review): `transform` is created but never applied in the cells shown
# here, so `data` is used exactly as the dataset returns it. Confirm whether
# `transform(dataset[0])` was intended.
transform = ToUndirected()
data = dataset[0]

[2022-02-14 12:00:17,966][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetSchema'>, drop extra fields: True
[2022-02-14 12:00:17,967][simple_parsing.helpers.serialization.decoding][DEBUG] name = name, field_type = <class 'str'>
[2022-02-14 12:00:17,968][simple_parsing.helpers.serialization.decoding][DEBUG] name = database, field_type = <class 'str'>
[2022-02-14 12:00:17,969][simple_parsing.helpers.serialization.decoding][DEBUG] name = description, field_type = <class 'str'>
[2022-02-14 12:00:17,970][simple_parsing.helpers.serialization.decoding][DEBUG] name = versions, field_type = typing.Dict[str, shared.schema.dataset.DatasetVersion]
[2022-02-14 12:00:17,971][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetVersion'>, drop extra fields: True
[2022-02-14 12:00:17,972][simple_parsing.helpers.serialization.decoding][DEBUG] name = type, field_type = <enum 'DatasetVersi

In [145]:
from torch_geometric.nn import Node2Vec

# Seed torch's global RNG before the model is built so that the random
# embedding initialisation (and any later sampling that draws from torch's
# RNG) is repeatable across "Restart Kernel -> Run All" executions.
seed = 42
torch.manual_seed(seed)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Node2Vec over the edge index of the first edge type found in `data`.
# The embedding width comes from `repr_dim`; sparse=True is why the optimizer
# cell uses SparseAdam.
model = Node2Vec(
    data[data.edge_types[0]].edge_index,
    embedding_dim=repr_dim,
    walk_length=8,
    context_size=8,
    walks_per_node=10,
    num_negative_samples=3,
    p=1,
    q=1,
    sparse=True,
).to(device)

In [146]:
# Data loader produced by the model itself; the training loop unpacks every
# batch it yields as a (pos_rw, neg_rw) pair.
loader = model.loader(batch_size=16, shuffle=True, num_workers=4)
# SparseAdam over all model parameters (the model is constructed with sparse=True).
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

In [147]:
import numpy as np

def train():
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Relies on the module-level ``model``, ``loader``, ``optimizer`` and
    ``device``. The mean is taken over ``len(loader)``.
    """
    model.train()
    batch_losses = []
    for positive_walks, negative_walks in loader:
        optimizer.zero_grad()
        batch_loss = model.loss(positive_walks.to(device), negative_walks.to(device))
        batch_loss.backward()
        optimizer.step()
        # Keep only the Python scalar so the graph of each batch can be freed.
        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(loader)

@torch.no_grad()
def test():
    """Score the current embeddings with ``model.test`` and return its result.

    The same embeddings and the same labels (``data.y``) are passed as both
    the training and the evaluation split, so the returned value is measured
    on the data it was fitted to. Relies on the module-level ``model`` and
    ``data``.
    """
    model.eval()
    node_embeddings = model()
    labels = data.y
    return model.test(node_embeddings, labels, node_embeddings, labels, max_iter=150)

# Accuracy evaluation (the `test()` call) is switched off, so a NaN placeholder
# is reported for every epoch.
acc = np.nan
for epoch in range(1, 300):  # upper bound is exclusive: epochs 1..299
    loss = train()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Acc: {acc:.4f}')

Epoch: 01, Loss: 4.2531, Acc: nan
Epoch: 02, Loss: 3.7992, Acc: nan
Epoch: 03, Loss: 3.7021, Acc: nan
Epoch: 04, Loss: 3.3091, Acc: nan
Epoch: 05, Loss: 3.2170, Acc: nan
Epoch: 06, Loss: 2.9813, Acc: nan
Epoch: 07, Loss: 2.8890, Acc: nan
Epoch: 08, Loss: 2.6932, Acc: nan
Epoch: 09, Loss: 2.5513, Acc: nan
Epoch: 10, Loss: 2.5215, Acc: nan
Epoch: 11, Loss: 2.3989, Acc: nan
Epoch: 12, Loss: 2.3405, Acc: nan
Epoch: 13, Loss: 2.2661, Acc: nan
Epoch: 14, Loss: 2.2588, Acc: nan
Epoch: 15, Loss: 2.0568, Acc: nan
Epoch: 16, Loss: 2.0453, Acc: nan
Epoch: 17, Loss: 2.0082, Acc: nan
Epoch: 18, Loss: 1.9872, Acc: nan
Epoch: 19, Loss: 1.9859, Acc: nan
Epoch: 20, Loss: 1.7767, Acc: nan
Epoch: 21, Loss: 1.7581, Acc: nan
Epoch: 22, Loss: 1.7861, Acc: nan
Epoch: 23, Loss: 1.7286, Acc: nan
Epoch: 24, Loss: 1.6200, Acc: nan
Epoch: 25, Loss: 1.5861, Acc: nan
Epoch: 26, Loss: 1.6284, Acc: nan
Epoch: 27, Loss: 1.5679, Acc: nan
Epoch: 28, Loss: 1.5433, Acc: nan
Epoch: 29, Loss: 1.5065, Acc: nan
Epoch: 30, Los

In [148]:
from shared.constants import BENCHMARKS_RESULTS

# Rebinds `save_path` (set earlier from TMP_PATH) to the directory the
# remaining cells write to and export from; the directory is created if missing.
save_path = BENCHMARKS_RESULTS.joinpath('analysis', 'pyg-node2vec')
save_path.mkdir(parents=True, exist_ok=True)

In [149]:
import faiss
import pandas as pd

In [150]:
# Take the model's embedding weight matrix, detach it from autograd, move it to
# the CPU and expose it as a NumPy array for clustering.
embeddings = model.embedding.weight.detach().cpu().numpy()

In [151]:
# Cluster the embeddings into `k` groups with faiss k-means, then query the
# fitted centroid index for the single nearest centroid of every embedding.
k = 5
kmeans = faiss.Kmeans(
    embeddings.shape[1],  # dimensionality of one embedding vector
    k,
    niter=20,
    nredo=5,
    verbose=True,
)
kmeans.train(embeddings)
# `search(..., 1)` returns two arrays; `I` (the second) is used below as the
# per-node cluster id.
D, I = kmeans.index.search(embeddings, 1)

Clustering 113 points in 32D to 5 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.00 s, search 0.00 s): objective=384.696 imbalance=2.107 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.01 s, search 0.00 s): objective=430.144 imbalance=1.436 nsplit=0       
Outer iteration 2 / 5
  Iteration 19 (0.01 s, search 0.01 s): objective=422.185 imbalance=1.683 nsplit=0       
Outer iteration 3 / 5
  Iteration 19 (0.01 s, search 0.01 s): objective=422.095 imbalance=1.896 nsplit=0       
Outer iteration 4 / 5
  Iteration 19 (0.01 s, search 0.01 s): objective=424.694 imbalance=1.970 nsplit=0       



In [152]:
from shared.graph import CommunityAssignment

# One cluster id per node: the k-means assignments `I` (squeezed), indexed by
# dataset.node_mapping(). The series is named "cid" and its index "nid".
# NOTE(review): assumes node_mapping() is ordered like the embedding rows — TODO confirm.
labeling = pd.Series(I.squeeze(), index=dataset.node_mapping(), name="cid")
labeling.index.name = "nid"
# Wrap the labeling in the project's CommunityAssignment type for saving and evaluation.
comlist = CommunityAssignment(labeling)

In [153]:
# Write the community assignment into the results directory as 'schema.comlist'.
comlist.save_comlist(save_path / 'schema.comlist')

In [154]:
from datasets.scripts import export_to_visualization

# Run the project's visualization export for the 'star-wars' dataset, version
# 'base', pointing it at the results directory in `save_path`.
export_to_visualization.run(
    export_to_visualization.Args(
        dataset='star-wars',
        version='base',
        # Passed as a list of strings, so the path object is converted with str().
        run_paths=[str(save_path)]
    )
)

[2022-02-14 12:01:40,169][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetSchema'>, drop extra fields: True
[2022-02-14 12:01:40,170][simple_parsing.helpers.serialization.decoding][DEBUG] name = name, field_type = <class 'str'>
[2022-02-14 12:01:40,170][simple_parsing.helpers.serialization.decoding][DEBUG] name = database, field_type = <class 'str'>
[2022-02-14 12:01:40,171][simple_parsing.helpers.serialization.decoding][DEBUG] name = description, field_type = <class 'str'>
[2022-02-14 12:01:40,171][simple_parsing.helpers.serialization.decoding][DEBUG] name = versions, field_type = typing.Dict[str, shared.schema.dataset.DatasetVersion]
[2022-02-14 12:01:40,172][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetVersion'>, drop extra fields: True
[2022-02-14 12:01:40,173][simple_parsing.helpers.serialization.decoding][DEBUG] name = type, field_type = <enum 'DatasetVersi

In [155]:
from shared.graph import DataGraph
from benchmarks.evaluation import get_metric_list

In [156]:
from shared.schema import GraphSchema, DatasetSchema

# Build the evaluation graph: dataset schema -> graph schema -> DataGraph.
DATASET = DatasetSchema.load_schema('star-wars')
schema = GraphSchema.from_dataset(DATASET)
G = DataGraph.from_schema(schema)
# NOTE(review): the return value is discarded. If to_undirected() returns a new
# graph instead of converting `G` in place, `G` stays directed — TODO confirm.
G.to_undirected()

[2022-02-14 12:01:40,414][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetSchema'>, drop extra fields: True
[2022-02-14 12:01:40,415][simple_parsing.helpers.serialization.decoding][DEBUG] name = name, field_type = <class 'str'>
[2022-02-14 12:01:40,415][simple_parsing.helpers.serialization.decoding][DEBUG] name = database, field_type = <class 'str'>
[2022-02-14 12:01:40,416][simple_parsing.helpers.serialization.decoding][DEBUG] name = description, field_type = <class 'str'>
[2022-02-14 12:01:40,417][simple_parsing.helpers.serialization.decoding][DEBUG] name = versions, field_type = typing.Dict[str, shared.schema.dataset.DatasetVersion]
[2022-02-14 12:01:40,418][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetVersion'>, drop extra fields: True
[2022-02-14 12:01:40,419][simple_parsing.helpers.serialization.decoding][DEBUG] name = type, field_type = <enum 'DatasetVersi

In [157]:
# Metric classes requested with ground_truth=False and overlapping=False.
metrics = get_metric_list(ground_truth=False, overlapping=False)

# One row per metric: its name and its value for the k-means communities on G.
results = pd.DataFrame(
    data=[
        dict(metric=metric.metric_name(), value=metric.calculate(G, comlist))
        for metric in metrics
    ]
)
results

Unnamed: 0,metric,value
0,community_count,
1,conductance,0.623867
2,expansion,13.824482
3,internal_edge_density,0.320072
4,avg_odf,13.824482
5,modularity_overlap,0.028717
6,link_modularity,0.075683
7,z_modularity,0.3136
8,modularity,0.151597
