In [1]:
import numpy as np
import pytorch_lightning as pl
import torch
import random
from torch_geometric.transforms import ToUndirected
from torch_geometric.nn import Node2Vec
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import NeighborLoader
from torch_geometric.loader.base import BaseDataLoader
from scipy.sparse import csr_matrix

import experiments
import ml
from experiments import cosine_cdist, euclidean_cdist
from shared.constants import TMP_PATH

[2022-02-15 12:43:47,280][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.Serializable'>
[2022-02-15 12:43:47,282][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'simple_parsing.helpers.serialization.serializable.SerializableMixin'>]
[2022-02-15 12:43:47,283][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.FrozenSerializable'>
[2022-02-15 12:43:47,283][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'simple_parsing.helpers.serialization.serializable.SerializableMixin'>]
[2022-02-15 12:43:47,285][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.SimpleSerializable'>
[2022-02-15 12:43:47,288][simple_parsing.helpers

In [2]:
# Notebook-wide configuration. Several of these values (device, initialization,
# recluster, callbacks) are not referenced by the cells visible in this notebook.
device = 'cpu'
experiment_name = 'pyg-n2v-experiment'
node_type = 'Character'
initialization = 'louvain'  # other values noted by the author: 'k-means' or 'none'
repr_dim = 32  # embedding width; the training loop reshapes model output to this size
EPS = 1e-15  # added inside the log() terms of the loss to avoid log(0)
recluster = False

# Negative sampling (random.choices), loader shuffling and weight initialisation are
# all stochastic; seed Python, NumPy and torch once so Restart & Run All is repeatable.
SEED = 42
pl.seed_everything(SEED)

# NOTE: save_path is reassigned to a BENCHMARKS_RESULTS location in a later cell.
save_path = TMP_PATH.joinpath(experiment_name)
callbacks = [
    pl.callbacks.ModelSummary(),
    pl.callbacks.LearningRateMonitor(),
    pl.callbacks.EarlyStopping(monitor="val/loss", min_delta=0.00, patience=5, verbose=True, mode="min")
]

In [3]:
# Load the project's Star Wars dataset and take its first graph object.
dataset = ml.StarWarsHomogenous()
# NOTE(review): `transform` is instantiated but never applied in the cells shown here.
transform = ToUndirected()
data = dataset[0]
G = dataset.G
# NOTE(review): the return value is discarded. If to_undirected() returns a new
# graph rather than mutating G in place, this call has no effect — confirm.
G.to_undirected()

[2022-02-15 12:44:00,432][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetSchema'>, drop extra fields: True
[2022-02-15 12:44:00,433][simple_parsing.helpers.serialization.decoding][DEBUG] name = name, field_type = <class 'str'>
[2022-02-15 12:44:00,433][simple_parsing.helpers.serialization.decoding][DEBUG] name = database, field_type = <class 'str'>
[2022-02-15 12:44:00,434][simple_parsing.helpers.serialization.decoding][DEBUG] name = description, field_type = <class 'str'>
[2022-02-15 12:44:00,435][simple_parsing.helpers.serialization.decoding][DEBUG] name = versions, field_type = typing.Dict[str, shared.schema.dataset.DatasetVersion]
[2022-02-15 12:44:00,435][simple_parsing.helpers.serialization.decoding][DEBUG] Decoding a Dict field: typing.Dict[str, shared.schema.dataset.DatasetVersion]
[2022-02-15 12:44:00,436][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetVe

In [4]:
# Take the edge list of the first edge store and append the reversed edges so
# every interaction is present in both directions.
edge_index = data.edge_stores[0].edge_index
edge_index = torch.cat([
    edge_index,
    torch.stack([edge_index[1, :], edge_index[0, :]])
], dim=1)
row, col = edge_index

# Sparse adjacency matrix; duplicate (row, col) pairs are summed by scipy, which
# is harmless because only the sparsity pattern is used below.
csr = csr_matrix(
    (torch.ones(edge_index.shape[1], dtype=torch.int32).numpy(), (row.numpy(), col.numpy())),
    shape=(data.num_nodes, data.num_nodes),
)

# For each node, the candidate negatives are all nodes that are neither a
# neighbour nor the node itself. Without excluding the node itself it could be
# drawn as its own negative, a pair whose similarity can never be lowered.
# The universe set is built once instead of once per node, and the result is
# sorted so the candidate order (and thus seeded sampling) is deterministic.
# Caveat: a node adjacent to every other node ends up with an empty list.
all_nodes = set(range(data.num_nodes))
neg_neighbors = [
    sorted(all_nodes - set(csr[i, :].indices.tolist()) - {i})
    for i in range(data.num_nodes)
]

In [5]:
def neg_sample(batch: torch.Tensor, num_neg_samples: int = 1) -> torch.Tensor:
    """Draw negative node ids for the center node of every row in ``batch``.

    The center node is read from column 0 of each row. Negatives are drawn with
    replacement (``random.choices``) from the module-level ``neg_neighbors``
    list belonging to that center node.

    Args:
        batch: 2-D integer tensor; only column 0 is used.
        num_neg_samples: number of negatives to draw per row.

    Returns:
        Long tensor of shape ``(len(batch), num_neg_samples)``.

    Raises:
        IndexError: if a center node has no negative candidates.
    """
    # Plain Python ints are cheaper to iterate and index with than 0-dim tensors.
    centers = batch[:, 0].tolist()
    drawn = [random.choices(neg_neighbors[c], k=num_neg_samples) for c in centers]
    # The explicit reshape keeps the result 2-D for an empty batch, where
    # torch.tensor([]) alone would be 1-D and break a later cat along dim=1.
    return torch.tensor(drawn, dtype=torch.long).reshape(len(centers), num_neg_samples)


# Build the training index table. Each row is
# [center, positive, negative_1, ..., negative_num_neg_samples]:
# the (symmetrised) edge list is repeated `repeat_count` times so every edge
# receives that many independent draws of negatives.
repeat_count = 2
num_neg_samples = 3
pos_idx = edge_index.t().repeat(repeat_count, 1)
neg_idx = neg_sample(pos_idx, num_neg_samples=num_neg_samples)
data_idx = torch.cat([pos_idx, neg_idx], dim=1)

# Neighbourhood sampler over the nodes of `node_type`; the configured constant
# is used instead of repeating the 'Character' literal so the two cannot drift.
node_loader = NeighborLoader(
    data=data, num_neighbors=[4, 4], input_nodes=node_type, directed=False, replace=False
)

In [6]:
class SamplesLoader(BaseDataLoader):
    """Batches rows of node ids and expands them into sampled subgraphs.

    Every dataset item is one row of ``data`` (a list of node ids). A batch of
    rows is flattened into a single list of seed nodes and handed to the
    wrapped neighbour loader's ``neighbor_sampler``; the sampler output is then
    post-processed by that loader's ``transform_fn``.

    Args:
        data: tensor of node ids; ``data.tolist()`` becomes the dataset.
        neighbor_loader: loader providing ``neighbor_sampler`` and
            ``transform_fn``. Defaults to the notebook-level ``node_loader`` so
            existing ``SamplesLoader(data_idx, ...)`` calls keep working.
        **kwargs: forwarded to the base data loader (``batch_size``, ``shuffle``, ...).
    """

    def __init__(
            self,
            data,
            neighbor_loader=None,
            **kwargs
    ):
        # An explicit argument removes the hidden dependency on a global; the
        # global is only consulted when no loader is passed.
        self.node_loader = node_loader if neighbor_loader is None else neighbor_loader

        super().__init__(
            data.tolist(),
            collate_fn=self.sample,
            **kwargs,
        )

    def sample(self, indices):
        """Flatten a batch of id rows and sample their neighbourhoods."""
        indices = torch.tensor(indices, dtype=torch.long).view(-1).tolist()
        return self.node_loader.neighbor_sampler(indices)

    def transform_fn(self, out):
        """Delegate post-processing of the sampler output to the wrapped loader."""
        return self.node_loader.transform_fn(out)


# Batch 8 rows of data_idx at a time; with 5 node ids per row this yields the
# 40 seed nodes reported as batch_size=40 in the output below.
data_loader = SamplesLoader(data_idx, batch_size=8, shuffle=True)
# Smoke test: pull one batch and display the sampled subgraph.
next(iter(data_loader))

HeteroData(
  [1mCharacter[0m={
    x=[67, 32],
    batch_size=40
  },
  [1m(Character, INTERACTIONS, Character)[0m={
    edge_attr=[486, 0],
    edge_index=[2, 486],
    timestamp=[486]
  }
)

In [7]:
# The embedding network is the whole model here, so both names refer to the
# same module instance.
model = embedding_module = experiments.GraphSAGEModule(
    node_type, data.metadata(), repr_dim, n_layers=2
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Similarity is taken over the last axis of (group, slot, repr_dim) tensors.
cos_sim = torch.nn.CosineSimilarity(dim=2)

def train():
    """Run one optimisation pass over ``data_loader`` and return the mean batch loss.

    The model output for a batch is reshaped into groups of
    ``num_neg_samples + 2`` embeddings. Within a group, slot 0 is treated as the
    center node, slot 1 as its positive and the remaining slots as negatives.
    (Assumes the model returns seed-node embeddings in the order the ids were
    passed to the sampler — TODO confirm.)

    The loss is the sum of two mean log-sigmoid terms on cosine similarity: one
    pulling center/positive together, one pushing center/negatives apart.
    """
    model.train()
    group_size = num_neg_samples + 2
    running_loss = 0
    for batch in data_loader:
        optimizer.zero_grad()

        grouped = embedding_module(batch).view(-1, group_size, repr_dim)
        # Width-1 slices keep the slot axis so the center broadcasts against
        # every negative in the group.
        anchor = grouped[:, 0:1, :]
        positive = grouped[:, 1:2, :]
        negatives = grouped[:, 2:, :]

        pos_score = cos_sim(anchor, positive).view(-1)
        neg_score = cos_sim(anchor, negatives).view(-1)

        pos_loss = -torch.log(torch.sigmoid(pos_score) + EPS).mean()
        neg_loss = -torch.log(1 - torch.sigmoid(neg_score) + EPS).mean()

        loss = pos_loss + neg_loss
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    return running_loss / len(data_loader)

def get_embeddings():
    """Embed every batch produced by ``node_loader`` and concatenate the results.

    The model is switched to eval mode and the forward passes run under
    ``torch.no_grad()``: this is pure inference, so recording an autograd graph
    for every batch would only cost memory. The returned tensor therefore does
    not require grad; callers that still call ``.detach()`` are unaffected.

    Returns:
        Tensor with the per-batch embeddings stacked along dim 0, in loader order.
    """
    model.eval()
    with torch.no_grad():
        embs = [embedding_module(batch) for batch in node_loader]

    return torch.cat(embs, dim=0)


# Train for 19 epochs (range(1, 20)). No evaluation routine exists yet, so the
# accuracy column is a NaN placeholder.
for epoch in range(1, 20):
    loss, acc = train(), np.nan
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Acc: {acc:.4f}')

Epoch: 01, Loss: 1.1415, Acc: nan
Epoch: 02, Loss: 1.0982, Acc: nan
Epoch: 03, Loss: 1.0905, Acc: nan
Epoch: 04, Loss: 1.0851, Acc: nan
Epoch: 05, Loss: 1.0814, Acc: nan
Epoch: 06, Loss: 1.0789, Acc: nan
Epoch: 07, Loss: 1.0766, Acc: nan
Epoch: 08, Loss: 1.0780, Acc: nan
Epoch: 09, Loss: 1.0752, Acc: nan
Epoch: 10, Loss: 1.0729, Acc: nan
Epoch: 11, Loss: 1.0745, Acc: nan
Epoch: 12, Loss: 1.0726, Acc: nan
Epoch: 13, Loss: 1.0729, Acc: nan
Epoch: 14, Loss: 1.0711, Acc: nan
Epoch: 15, Loss: 1.0717, Acc: nan
Epoch: 16, Loss: 1.0702, Acc: nan
Epoch: 17, Loss: 1.0705, Acc: nan
Epoch: 18, Loss: 1.0704, Acc: nan
Epoch: 19, Loss: 1.0698, Acc: nan


In [8]:
from shared.constants import BENCHMARKS_RESULTS
import faiss
import pandas as pd

# Re-point save_path (previously a TMP_PATH location set in the configuration
# cell) at the benchmark analysis directory for this experiment, creating it
# and any missing parents.
save_path = BENCHMARKS_RESULTS.joinpath('analysis', experiment_name)
save_path.mkdir(parents=True, exist_ok=True)

In [9]:
embeddings = get_embeddings().detach()

# L2-normalise each row so Euclidean k-means behaves like cosine clustering.
# F.normalize clamps the norm from below, so an all-zero embedding stays zero
# instead of turning into NaN as a plain division by the norm would.
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

k = 5  # number of clusters / communities
kmeans = faiss.Kmeans(embeddings.shape[1], k, niter=20, verbose=True, nredo=10)
kmeans.train(embeddings.numpy())
# Look up the single nearest centroid of every point: D holds the distances,
# I the centroid indices (one column each).
D, I = kmeans.index.search(embeddings.numpy(), 1)

Clustering 113 points in 32D to 5 clusters, redo 10 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 10
  Iteration 19 (0.00 s, search 0.00 s): objective=28.6359 imbalance=1.145 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 10
  Iteration 19 (0.00 s, search 0.00 s): objective=28.6359 imbalance=1.145 nsplit=0       
Outer iteration 2 / 10
  Iteration 19 (0.01 s, search 0.00 s): objective=25.6957 imbalance=1.181 nsplit=0       
Objective improved: keep new clusters
Outer iteration 3 / 10
  Iteration 19 (0.01 s, search 0.01 s): objective=25.3886 imbalance=1.076 nsplit=0       
Objective improved: keep new clusters
Outer iteration 4 / 10
  Iteration 19 (0.01 s, search 0.01 s): objective=25.2893 imbalance=1.090 nsplit=0       
Objective improved: keep new clusters
Outer iteration 5 / 10
  Iteration 19 (0.01 s, search 0.01 s): objective=26.1573 imbalance=1.184 nsplit=0       
Outer iteration 6 / 10
  Iteration 19 (0.01 s, search 0.01 s): ob



In [10]:
from shared.graph import CommunityAssignment

# Turn the nearest-centroid indices into a node-id -> cluster-id Series.
# assumes dataset.node_mapping() lists nodes in the same order as the rows of
# `embeddings` — TODO confirm
labeling = pd.Series(I.squeeze(), index=dataset.node_mapping(), name="cid")
labeling.index.name = "nid"
# Wrap the assignment in the project's community container and write it to save_path.
comlist = CommunityAssignment(labeling)
comlist.save_comlist(save_path.joinpath('schema.comlist'))

In [11]:
from datasets.scripts import export_to_visualization
from shared.graph import DataGraph
from benchmarks.evaluation import get_metric_list
from shared.schema import GraphSchema, DatasetSchema

# Run the project's visualization export for the 'base' version of the
# 'star-wars' dataset, passing this experiment's save_path as the only run path.
export_to_visualization.run(
    export_to_visualization.Args(
        dataset='star-wars',
        version='base',
        run_paths=[str(save_path)]
    )
)

[2022-02-15 12:46:13,315][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'benchmarks.config.TunedParameterValue'>
[2022-02-15 12:46:13,316][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'simple_parsing.helpers.serialization.serializable.Serializable'>, <class 'simple_parsing.helpers.serialization.serializable.SerializableMixin'>]
[2022-02-15 12:46:13,317][simple_parsing.helpers.serialization.serializable][DEBUG] Parent class <class 'simple_parsing.helpers.serialization.serializable.Serializable'> has decode_into_subclasses = False
[2022-02-15 12:46:13,318][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'benchmarks.config.ParameterConfig'>
[2022-02-15 12:46:13,319][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'dict'>, <class 'simple_parsing.helpers.serialization.serializable.Serializable'>, <class 'simple_parsing

In [12]:
# Rebuild the graph from the dataset schema for evaluation.
DATASET = DatasetSchema.load_schema('star-wars')
schema = GraphSchema.from_dataset(DATASET)
G = DataGraph.from_schema(schema)
# NOTE(review): return value is discarded — confirm to_undirected() mutates G in place.
G.to_undirected()

# Score the k-means communities with every metric that needs neither ground
# truth nor overlapping communities; one table row per metric.
metrics = get_metric_list(ground_truth=False, overlapping=False)
records = []
for metric_cls in metrics:
    records.append({
        'metric': metric_cls.metric_name(),
        'value': metric_cls.calculate(G, comlist),
    })
results = pd.DataFrame(records)
results

[2022-02-15 12:46:29,881][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetSchema'>, drop extra fields: True
[2022-02-15 12:46:29,882][simple_parsing.helpers.serialization.decoding][DEBUG] name = name, field_type = <class 'str'>
[2022-02-15 12:46:29,883][simple_parsing.helpers.serialization.decoding][DEBUG] name = database, field_type = <class 'str'>
[2022-02-15 12:46:29,884][simple_parsing.helpers.serialization.decoding][DEBUG] name = description, field_type = <class 'str'>
[2022-02-15 12:46:29,885][simple_parsing.helpers.serialization.decoding][DEBUG] name = versions, field_type = typing.Dict[str, shared.schema.dataset.DatasetVersion]
[2022-02-15 12:46:29,885][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetVersion'>, drop extra fields: True
[2022-02-15 12:46:29,886][simple_parsing.helpers.serialization.decoding][DEBUG] name = type, field_type = <enum 'DatasetVersi

Unnamed: 0,metric,value
0,community_count,
1,conductance,0.560363
2,expansion,6.364586
3,internal_edge_density,0.309625
4,avg_odf,6.364586
5,modularity_overlap,0.038812
6,link_modularity,0.082726
7,z_modularity,0.533332
8,modularity,0.24947
