In [1]:
from ml.data.datasets import StarWars

# Instantiate the project's StarWars dataset and take its first (index 0) entry.
dataset = StarWars()
data = dataset[0]
# Bare expression: shown as the cell output (a HeteroData summary in the recorded run).
data

[2022-02-09 17:26:43,215][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.Serializable'>
[2022-02-09 17:26:43,216][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'simple_parsing.helpers.serialization.serializable.SerializableMixin'>]
[2022-02-09 17:26:43,217][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.FrozenSerializable'>
[2022-02-09 17:26:43,218][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'simple_parsing.helpers.serialization.serializable.SerializableMixin'>]
[2022-02-09 17:26:43,219][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.SimpleSerializable'>
[2022-02-09 17:26:43,222][simple_parsing.helpers

HeteroData(
  Character={ x=[113, 32] },
  (Character, INTERACTIONS, Character)={
    edge_attr=[958, 0],
    edge_index=[2, 958],
    timestamp=[958]
  },
  (Character, MENTIONS, Character)={
    edge_attr=[1120, 0],
    edge_index=[2, 1120],
    timestamp=[1120]
  }
)

In [2]:
from ml.data import EdgeLoaderDataModule

# Wrap the graph in the project's EdgeLoaderDataModule; this object is what is later
# passed to `trainer.fit` and `trainer.predict`.
# NOTE(review): `num_samples=[4] * 2` presumably means 4 sampled neighbours per hop
# over 2 hops - confirm against EdgeLoaderDataModule.
data_module = EdgeLoaderDataModule(data, batch_size=16, num_samples=[4] * 2, num_workers=4, node_type='Character')

In [3]:
from typing import Optional

import torch
import torch.nn.functional as F
from torch_geometric.typing import Metadata
from torch_geometric.nn import HGTConv, Linear
import torchmetrics
import pytorch_lightning as pl

In [4]:
class HGTModule(torch.nn.Module):
    """HGT encoder that returns the embeddings of a single node type.

    Every node type listed in ``metadata[0]`` gets its own linear projection
    followed by ReLU; the projected features are then passed through
    ``num_layers`` stacked ``HGTConv`` layers. ``forward`` returns only the
    entry for ``node_type``.

    Args:
        node_type: Key of the node type whose embeddings ``forward`` returns.
        metadata: Graph metadata (the notebook passes ``data.metadata()``).
            ``metadata[0]`` is iterated to create the per-type projections and
            the whole object is handed to every ``HGTConv``.
        hidden_channels: Output width of the projections and of each convolution.
        num_heads: Number of attention heads passed to ``HGTConv``.
        num_layers: Number of stacked ``HGTConv`` layers.
    """

    def __init__(
            self,
            node_type,
            metadata: Metadata,
            hidden_channels=64,
            num_heads=2,
            num_layers=1
    ):
        super().__init__()
        self.node_type = node_type
        self.lin_dict = torch.nn.ModuleDict()
        # Use a distinct loop name so the `node_type` argument is not shadowed.
        for input_type in metadata[0]:
            # -1 input channels: the PyG Linear layer infers its input size lazily.
            self.lin_dict[input_type] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, metadata, num_heads, group='sum')
            self.convs.append(conv)

    def forward(self, x_dict, edge_index_dict):
        """Project the per-type features, apply the convolutions, return one node type.

        Args:
            x_dict: Mapping from node type to feature tensor. It is NOT modified.
            edge_index_dict: Mapping passed unchanged to every ``HGTConv``.

        Returns:
            The tensor stored under ``self.node_type`` after the last convolution.
        """
        # Build a fresh dict instead of overwriting the caller's entries: writing the
        # projected features back into `x_dict` would make a second call with the same
        # dict feed already-projected tensors into the input projections.
        hidden_dict = {
            input_type: self.lin_dict[input_type](features).relu_()
            for input_type, features in x_dict.items()
        }

        for conv in self.convs:
            hidden_dict = conv(hidden_dict, edge_index_dict)

        return hidden_dict[self.node_type]


class ClusterModule(torch.nn.Module):
    """Soft cluster assignment of embeddings using a Student's t kernel.

    Holds a learnable ``(n_clusters, embedding_dim)`` matrix of cluster centers
    and maps every input row to a distribution over the clusters.

    Args:
        n_clusters: Number of cluster centers.
        embedding_dim: Dimensionality of the embeddings and of each center.
        alpha: Degrees-of-freedom parameter of the t kernel.
        cluster_centers: Optional initial centers of shape
            ``(n_clusters, embedding_dim)``. They are copied, so training the
            module never modifies the tensor that was passed in. When omitted,
            the centers are Xavier-uniform initialised.

    Raises:
        ValueError: If ``cluster_centers`` does not have the expected shape.
    """

    def __init__(
            self,
            n_clusters: int = 5,
            embedding_dim: int = 64,
            alpha: float = 1.0,
            cluster_centers: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()
        self.n_clusters = n_clusters
        self.embedding_dim = embedding_dim
        self.alpha = alpha

        expected_shape = (self.n_clusters, self.embedding_dim)
        if cluster_centers is None:
            initial_cluster_centers = torch.zeros(
                self.n_clusters, self.embedding_dim, dtype=torch.float
            )
            torch.nn.init.xavier_uniform_(initial_cluster_centers)
        else:
            # Explicit check instead of `assert`, which is removed under `python -O`.
            if tuple(cluster_centers.shape) != expected_shape:
                raise ValueError(
                    f"cluster_centers must have shape {expected_shape}, "
                    f"got {tuple(cluster_centers.shape)}"
                )
            # nn.Parameter shares storage with the tensor it wraps; copy so that
            # optimiser steps do not write into the caller's tensor.
            initial_cluster_centers = cluster_centers.detach().clone()
        self.cluster_centers = torch.nn.Parameter(initial_cluster_centers)

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Return the soft assignment of every row of ``batch`` to every cluster.

        Args:
            batch: Tensor of shape ``(N, embedding_dim)``.

        Returns:
            Tensor of shape ``(N, n_clusters)`` whose rows sum to 1.
        """
        # Squared Euclidean distance between each row and each center: (N, n_clusters).
        norm_squared = torch.sum(torch.square(batch.unsqueeze(1) - self.cluster_centers), dim=2)
        # Unnormalised kernel (1 + d^2 / alpha) ** (-(alpha + 1) / 2).
        numerator = 1.0 / (1.0 + (norm_squared / self.alpha))
        power = float(self.alpha + 1) / 2
        numerator = torch.pow(numerator, power)
        # Normalise over the cluster axis.
        return numerator / torch.sum(numerator, dim=1, keepdim=True)

In [5]:
class Net(pl.LightningModule):
    """Lightning module that trains the HGT encoder on pairwise link prediction.

    Both sides of a sampled pair are embedded with the same ``HGTModule``; the
    Euclidean distance between matching rows goes through a ``Linear(1, 2)``
    layer and the resulting logits are scored with cross-entropy against the
    pair labels. A ``ClusterModule`` is created as well, but the code in this
    class only freezes it and never evaluates it.
    """

    def __init__(
            self,
            node_type,
            metadata: Metadata,
            n_clusters: int = 5,
            embedding_dim=64,
            num_heads=2,
            num_layers=1
    ):
        super().__init__()
        self.node_type = node_type

        # Encoder and (currently unused) cluster head.
        self.node_embedding = HGTModule(
            node_type,
            metadata,
            hidden_channels=embedding_dim,
            num_heads=num_heads,
            num_layers=num_layers,
        )
        self.cluster_embedding = ClusterModule(n_clusters=n_clusters, embedding_dim=embedding_dim)

        # Link-prediction head: L2 distance -> two-class logits -> cross-entropy.
        self.dist = torch.nn.PairwiseDistance(p=2)
        self.lin = torch.nn.Linear(1, 2)
        self.ce_loss = torch.nn.CrossEntropyLoss()

        # Separate accuracy trackers for the training and validation loops.
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()

    def _forward_step_small(self, batch, mode='train'):
        """Small step: link prediction on one batch of node pairs.

        Args:
            batch: Triple ``(left, right, label)``. ``left`` and ``right`` must
                expose ``x_dict`` and ``edge_index_dict``; the number of rows
                kept from both embeddings is ``left[self.node_type].batch_size``.
            mode: ``'train'`` or ``'val'`` picks the accuracy tracker to update;
                any other value updates neither.

        Returns:
            The cross-entropy loss of the two-class logits against ``label``.
        """
        # The cluster centers take no part in this step, so keep them frozen.
        self.cluster_embedding.requires_grad_(False)

        left, right, target = batch
        # NOTE(review): presumably the first `batch_size` rows are the seed nodes of
        # the sampled sub-graph - confirm against the data module.
        n_rows = left[self.node_type].batch_size
        z_left = self.node_embedding(left.x_dict, left.edge_index_dict)[:n_rows]
        z_right = self.node_embedding(right.x_dict, right.edge_index_dict)[:n_rows]

        # Homophily-based loss. TODO: do we need to cap the individual losses?
        logits = self.lin(self.dist(z_left, z_right).unsqueeze(1))
        link_loss = self.ce_loss(logits, target)

        predicted = logits.argmax(dim=-1)
        tracker = None
        if mode == 'train':
            tracker = self.train_acc
        elif mode == 'val':
            tracker = self.val_acc
        if tracker is not None:
            tracker.update(predicted, target)

        return link_loss

    def _forward_step_large(self, batch):
        """Placeholder for a second training step; not implemented, returns None."""
        return None

    def configure_optimizers(self):
        """Return an Adam optimiser over all module parameters (lr = 1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch):
        """Compute the link-prediction loss and update the training accuracy."""
        return self._forward_step_small(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        """Compute the link-prediction loss and update the validation accuracy."""
        return self._forward_step_small(batch, mode='val')

    def training_epoch_end(self, outputs) -> None:
        """Log the accumulated training accuracy, then reset the tracker."""
        self.log('mean_train_acc', self.train_acc.compute(), prog_bar=True)
        self.train_acc.reset()

    def validation_epoch_end(self, outputs) -> None:
        """Log the accumulated validation accuracy, then reset the tracker."""
        self.log('mean_val_acc', self.val_acc.compute(), prog_bar=True)
        self.val_acc.reset()

    def forward(self, batch):
        """Embed one sampled batch and return its first ``batch_size`` rows."""
        n_rows = batch[self.node_type].batch_size
        return self.node_embedding(batch.x_dict, batch.edge_index_dict)[:n_rows]


# Seed the Python, NumPy and PyTorch RNGs before the model is built so that weight
# initialisation and training are repeatable after Restart & Run All.
# `workers=True` also derives seeds for the dataloader worker processes.
pl.seed_everything(42, workers=True)

model = Net(node_type='Character', metadata=data.metadata(), n_clusters=5, embedding_dim=64, num_heads=2, num_layers=1)
trainer = pl.Trainer(
    gpus=1,
    callbacks=[
        # Stop once `mean_val_acc` (logged at the end of every validation epoch)
        # has not improved for 8 consecutive validation runs.
        pl.callbacks.EarlyStopping(monitor="mean_val_acc", min_delta=0.00, patience=8, verbose=True, mode="max")
    ],
    max_epochs=50,
)
trainer.fit(model, data_module)




Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [6]:
# Run the trainer's prediction loop over the data module; the result is a list with
# one entry per batch.
# NOTE(review): Net defines no predict_step, so this presumably falls back to
# Net.forward for each batch - confirm for the installed Lightning version.
predictions = trainer.predict(model, data_module)
# Concatenate the per-batch outputs along dim 0 and convert to a CPU NumPy array.
embeddings = torch.cat(predictions, dim=0).cpu().numpy()

Predicting: 182it [00:00, ?it/s]

# Extract and cluster the embeddings

In [7]:
from shared.constants import BENCHMARKS_RESULTS

# Output directory for this analysis, nested under the project's BENCHMARKS_RESULTS path.
save_path = BENCHMARKS_RESULTS.joinpath('analysis', 'pyg-hgt-homophily-clustering')
# Create the directory and any missing parents; do nothing if it already exists.
save_path.mkdir(parents=True, exist_ok=True)

In [8]:
import faiss
import pandas as pd

[2022-02-09 17:34:07,731][faiss.loader][INFO] Loading faiss with AVX2 support.
[2022-02-09 17:34:07,732][faiss.loader][INFO] Could not load library with AVX2 support due to:
ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'")
[2022-02-09 17:34:07,732][faiss.loader][INFO] Loading faiss.
[2022-02-09 17:34:07,763][faiss.loader][INFO] Successfully loaded faiss.


In [9]:
# Number of clusters to fit.
k = 5
# faiss k-means over vectors of width embeddings.shape[1]: 20 iterations per run,
# 5 restarts (nredo), with progress printed.
kmeans = faiss.Kmeans(embeddings.shape[1], k, niter=20, verbose=True, nredo=5)
# Last expression of the cell, so the value returned by train() is displayed.
kmeans.train(embeddings)

Clustering 113 points in 64D to 5 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.01 s, search 0.00 s): objective=4165.91 imbalance=1.423 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.01 s, search 0.01 s): objective=4135.97 imbalance=1.565 nsplit=0       
Objective improved: keep new clusters
Outer iteration 2 / 5
  Iteration 19 (0.01 s, search 0.01 s): objective=3151.19 imbalance=1.359 nsplit=0       
Objective improved: keep new clusters
Outer iteration 3 / 5
  Iteration 19 (0.01 s, search 0.01 s): objective=3178.59 imbalance=1.177 nsplit=0       
Outer iteration 4 / 5
  Iteration 19 (0.01 s, search 0.01 s): objective=3178.59 imbalance=1.177 nsplit=0       



3151.1923828125

In [10]:
# Query the fitted centroid index for the single nearest centroid of every embedding:
# D receives the distances and I the centroid indices (one column each, since k=1).
D, I = kmeans.index.search(embeddings, 1)

In [11]:
from shared.graph import CommunityAssignment

# Series named "cid" holding the assigned centroid index per node, indexed by
# dataset.node_mapping().
# NOTE(review): assumes node_mapping() is ordered like the rows of `embeddings` - confirm.
labeling = pd.Series(I.squeeze(), index=dataset.node_mapping(), name="cid")
labeling.index.name = "nid"
# Wrap the labels in the project's CommunityAssignment type.
comlist = CommunityAssignment(labeling)

In [12]:
# Write the community assignment to `schema.comlist` inside the output directory.
comlist.save_comlist(save_path.joinpath('schema.comlist'))

In [None]:
from datasets.scripts import export_to_visualization

# Invoke the project's visualization export script for the 'star-wars' dataset
# ('base' version), pointing it at this run's output directory.
export_to_visualization.run(
    export_to_visualization.Args(
        dataset='star-wars',
        version='base',
        run_paths=[str(save_path)]  # the Args field receives string paths, not Path objects
    )
)

[2022-02-09 17:34:07,826][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetSchema'>, drop extra fields: True
[2022-02-09 17:34:07,827][simple_parsing.helpers.serialization.decoding][DEBUG] name = name, field_type = <class 'str'>
[2022-02-09 17:34:07,827][simple_parsing.helpers.serialization.decoding][DEBUG] name = database, field_type = <class 'str'>
[2022-02-09 17:34:07,827][simple_parsing.helpers.serialization.decoding][DEBUG] name = description, field_type = <class 'str'>
[2022-02-09 17:34:07,828][simple_parsing.helpers.serialization.decoding][DEBUG] name = versions, field_type = typing.Dict[str, shared.schema.dataset.DatasetVersion]
[2022-02-09 17:34:07,828][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetVersion'>, drop extra fields: True
[2022-02-09 17:34:07,828][simple_parsing.helpers.serialization.decoding][DEBUG] name = type, field_type = <enum 'DatasetVersi

# Calculate Evaluation Metrics

In [13]:
from shared.graph import DataGraph
from benchmarks.evaluation import get_metric_list

[2022-02-09 17:34:09,344][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'benchmarks.config.TunedParameterValue'>
[2022-02-09 17:34:09,345][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'simple_parsing.helpers.serialization.serializable.Serializable'>, <class 'simple_parsing.helpers.serialization.serializable.SerializableMixin'>]
[2022-02-09 17:34:09,346][simple_parsing.helpers.serialization.serializable][DEBUG] Parent class <class 'simple_parsing.helpers.serialization.serializable.Serializable'> has decode_into_subclasses = False
[2022-02-09 17:34:09,347][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'benchmarks.config.ParameterConfig'>
[2022-02-09 17:34:09,349][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'dict'>, <class 'simple_parsing.helpers.serialization.serializable.Serializable'>, <class 'simple_parsing

In [14]:
from shared.schema import GraphSchema, DatasetSchema

# Load the 'star-wars' dataset schema, derive its graph schema, and build the
# DataGraph that the evaluation metrics are computed on.
DATASET = DatasetSchema.load_schema('star-wars')
schema = GraphSchema.from_dataset(DATASET)
G = DataGraph.from_schema(schema)

ERROR! Session/line number was not unique in database. History logging moved to new session 141
[2022-02-09 17:34:09,374][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetSchema'>, drop extra fields: True
[2022-02-09 17:34:09,375][simple_parsing.helpers.serialization.decoding][DEBUG] name = name, field_type = <class 'str'>
[2022-02-09 17:34:09,375][simple_parsing.helpers.serialization.decoding][DEBUG] name = database, field_type = <class 'str'>
[2022-02-09 17:34:09,375][simple_parsing.helpers.serialization.decoding][DEBUG] name = description, field_type = <class 'str'>
[2022-02-09 17:34:09,376][simple_parsing.helpers.serialization.decoding][DEBUG] name = versions, field_type = typing.Dict[str, shared.schema.dataset.DatasetVersion]
[2022-02-09 17:34:09,376][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetVersion'>, drop extra fields: True
[2022-02-09 17:34:09,377][sim

In [15]:
# Metric classes for the setting without ground truth and without overlapping communities.
metrics = get_metric_list(ground_truth=False, overlapping=False)

# Evaluate every metric on the graph and the k-means assignment, one row per metric.
metric_rows = []
for metric_cls in metrics:
    metric_rows.append({
        'metric': metric_cls.metric_name(),
        'value': metric_cls.calculate(G, comlist),
    })

results = pd.DataFrame(metric_rows)
results

Unnamed: 0,metric,value
0,community_count,
1,conductance,0.400199
2,expansion,3.407707
3,internal_edge_density,0.362462
4,avg_odf,6.95326
5,modularity_overlap,0.037033
6,link_modularity,0.048038
7,z_modularity,0.584812
8,modularity,-0.048813
