In [1]:
from ml.data.datasets import StarWars

# Instantiate the project's StarWars dataset wrapper; `dataset` is kept around
# because a later cell calls `dataset.node_mapping()` to label the assignments.
dataset = StarWars()
# Take the first item of the dataset (displayed below as a HeteroData graph).
data = dataset[0]
# Bare expression so the notebook renders the object's repr as the cell output.
data

[2022-02-11 17:30:44,533][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.Serializable'>
[2022-02-11 17:30:44,534][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'simple_parsing.helpers.serialization.serializable.SerializableMixin'>]
[2022-02-11 17:30:44,535][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.FrozenSerializable'>
[2022-02-11 17:30:44,535][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'simple_parsing.helpers.serialization.serializable.SerializableMixin'>]
[2022-02-11 17:30:44,536][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.SimpleSerializable'>
[2022-02-11 17:30:44,538][simple_parsing.helpers

HeteroData(
  [1mCharacter[0m={ x=[113, 32] },
  [1m(Character, INTERACTIONS, Character)[0m={
    edge_attr=[958, 0],
    edge_index=[2, 958],
    timestamp=[958]
  },
  [1m(Character, MENTIONS, Character)[0m={
    edge_attr=[1120, 0],
    edge_index=[2, 1120],
    timestamp=[1120]
  }
)

In [2]:
from ml.data import EdgeLoaderDataModule

# Wrap the graph in the project's edge-based Lightning data module.
# `node_type='Character'` is the same node type later passed to HGTModule.
# NOTE(review): `num_samples=[4] * 2` evaluates to [4, 4]; presumably the number of
# sampled neighbours per hop — confirm against EdgeLoaderDataModule.
data_module = EdgeLoaderDataModule(data, batch_size=16, num_samples=[4] * 2, num_workers=4, node_type='Character')

In [3]:
from typing import Optional, Dict, Tuple, Any

import torch
import torch.nn.functional as F
from torch_geometric.typing import Metadata
from torch_geometric.nn import HGTConv, Linear
import torchmetrics
import pytorch_lightning as pl

import ml

In [4]:
class HGTModule(torch.nn.Module):
    """Heterogeneous Graph Transformer encoder returning embeddings of one node type.

    Every node type listed in ``metadata[0]`` gets its own input projection
    (``Linear(-1, hidden_channels)`` followed by an in-place ReLU); the
    projected features are then passed through ``num_layers`` ``HGTConv``
    layers that all use ``hidden_channels`` as input and output width.

    Args:
        node_type: Key of the node type whose embeddings ``forward`` returns.
        metadata: Graph metadata; ``metadata[0]`` is iterated to create one
            input projection per node type and the whole object is forwarded
            to ``HGTConv``.
        hidden_channels: Width of the input projections and of each HGT layer.
        num_heads: Number of attention heads handed to each ``HGTConv``.
        num_layers: Number of stacked ``HGTConv`` layers.
    """

    def __init__(
            self,
            node_type,
            metadata: Metadata,
            hidden_channels=64,
            num_heads=2,
            num_layers=1
    ):
        super().__init__()
        self.node_type = node_type
        self.lin_dict = torch.nn.ModuleDict()
        # Use a distinct loop name so the `node_type` argument is not shadowed.
        for input_type in metadata[0]:
            self.lin_dict[input_type] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, metadata, num_heads, group='sum')
            self.convs.append(conv)

    def forward(self, x_dict, edge_index_dict):
        """Encode the graph and return the embeddings of ``self.node_type``.

        Args:
            x_dict: Mapping from node type to that type's feature tensor.
            edge_index_dict: Mapping of edge types to edge indices, forwarded
                unchanged to every ``HGTConv`` layer.

        Returns:
            The entry for ``self.node_type`` of the last layer's output dict.
        """
        # Build a fresh dict rather than overwriting the entries of the
        # caller's `x_dict`, so the input mapping is left untouched.
        hidden_dict = {
            key: self.lin_dict[key](features).relu_()
            for key, features in x_dict.items()
        }

        for conv in self.convs:
            hidden_dict = conv(hidden_dict, edge_index_dict)

        return hidden_dict[self.node_type]

class ClusterAssignment(torch.nn.Module):
    """Soft cluster assignment with a Student's t kernel around learnable centers.

    Args:
        cluster_number: Number of clusters (rows of the center matrix).
        embedding_dimension: Dimensionality of the embeddings / centers.
        alpha: Degrees-of-freedom parameter of the kernel.
        cluster_centers: Optional initial centers of shape
            ``(cluster_number, embedding_dimension)``. They are copied, so
            optimising this module never modifies the object passed in. When
            omitted, centers are drawn with Xavier-uniform initialisation.

    Raises:
        ValueError: If ``cluster_centers`` is given with a shape other than
            ``(cluster_number, embedding_dimension)``.
    """

    def __init__(self, cluster_number, embedding_dimension, alpha, cluster_centers=None):
        super().__init__()
        self.embedding_dimension = embedding_dimension
        self.cluster_number = cluster_number
        self.alpha = alpha
        if cluster_centers is None:
            initial_cluster_centers = torch.zeros(self.cluster_number, self.embedding_dimension, dtype=torch.float)
            torch.nn.init.xavier_uniform_(initial_cluster_centers)
        else:
            # Copy the supplied centers: wrapping the caller's tensor directly
            # in a Parameter would share storage, so every optimiser step
            # would silently rewrite the caller's data as well.
            initial_cluster_centers = torch.as_tensor(cluster_centers).detach().clone()
            expected_shape = (self.cluster_number, self.embedding_dimension)
            if tuple(initial_cluster_centers.shape) != expected_shape:
                raise ValueError(
                    f"cluster_centers has shape {tuple(initial_cluster_centers.shape)}, "
                    f"expected {expected_shape}"
                )
        self.cluster_centers = torch.nn.Parameter(initial_cluster_centers)

    def forward(self, inputs):
        """Return soft assignments of each input row to each cluster.

        Args:
            inputs: 2-D tensor of embeddings, one row per sample.

        Returns:
            Tensor with one row per sample and one column per cluster; every
            row is normalised to sum to 1.
        """
        # Squared Euclidean distance between every sample and every center.
        norm_squared = torch.sum((inputs.unsqueeze(1) - self.cluster_centers) ** 2, 2)
        # Student's t kernel: (1 + d^2 / alpha) ** (-(alpha + 1) / 2).
        numerator = 1.0 / (1.0 + (norm_squared / self.alpha))
        power = float(self.alpha + 1) / 2
        numerator = numerator ** power
        # Normalise over the cluster axis.
        return numerator / torch.sum(numerator, dim=1, keepdim=True)

In [5]:
class PretrainNet(ml.BaseModule):
    """Pre-training model: binary classification of node pairs from their embeddings.

    The cosine similarity of the two embeddings in a pair is mapped to two
    logits by a ``Linear(1, 2)`` layer and trained with cross-entropy against
    the pair's label.

    Args:
        embedding_module: Encoder producing node embeddings; its ``node_type``
            attribute selects the node store whose ``batch_size`` is read.
    """

    def __init__(
            self,
            embedding_module: HGTModule,
    ):
        super().__init__()
        self.embedding_module = embedding_module

        self.cos_sim = torch.nn.CosineSimilarity(dim=1)
        self.lin = torch.nn.Linear(1, 2)
        self.ce_loss = torch.nn.CrossEntropyLoss()

    def configure_metrics(self) -> Dict[str, Tuple[torchmetrics.Metric, bool]]:
        """Return the metrics, keyed by the names used in ``_step``'s result dict."""
        # NOTE(review): the meaning of the boolean in each tuple is defined by
        # ml.BaseModule and is not visible here — confirm.
        return {
            'loss': (torchmetrics.MeanMetric(), True),
            'hp_loss': (torchmetrics.MeanMetric(), True),
            'accuracy': (torchmetrics.Accuracy(), True),
            'entropy': (ml.LabelEntropyMetric(2), True),
        }

    def _step(self, batch: Tuple[Any, Any, torch.Tensor]):
        """Compute the loss and metric inputs for one ``(left, right, label)`` batch.

        Returns:
            Dict with the differentiable ``'loss'`` plus detached values for
            the ``'hp_loss'``, ``'accuracy'`` and ``'entropy'`` metrics.
        """
        batch_l, batch_r, label = batch
        # The left batch's size is used to slice both sides.
        # NOTE(review): assumes batch_r has the same batch_size — confirm.
        batch_size = batch_l[self.embedding_module.node_type].batch_size

        emb_l = self.embedding_module(batch_l.x_dict, batch_l.edge_index_dict)[:batch_size]
        emb_r = self.embedding_module(batch_r.x_dict, batch_r.edge_index_dict)[:batch_size]

        # Compute homophily based loss.
        sim = self.cos_sim(emb_l, emb_r)
        out = self.lin(torch.unsqueeze(sim, 1))
        hp_loss = self.ce_loss(out, label)

        loss = hp_loss

        # Only 'loss' needs to stay attached to the graph; values that are
        # reported purely as metrics are detached, matching TrainNet.
        pred = out.argmax(dim=-1).detach()
        return {
            'loss': loss,
            'hp_loss': hp_loss.detach(),
            'accuracy': (pred, label),
            'entropy': pred
        }

    def forward(self, batch):
        """Embed a single batch and return its first ``batch_size`` rows."""
        batch_size = batch[self.embedding_module.node_type].batch_size
        emb = self.embedding_module(batch.x_dict, batch.edge_index_dict)[:batch_size]
        return emb

    def training_step(self, batch):
        """Training hook; delegates to ``_step``."""
        return self._step(batch)

    def validation_step(self, batch, batch_idx):
        """Validation hook; delegates to ``_step`` (``batch_idx`` is unused)."""
        return self._step(batch)

    def configure_optimizers(self):
        """Return an Adam optimiser over all parameters with learning rate 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [6]:
# Build the HGT encoder for 'Character' nodes. The same `embedding_module` instance is
# reused later when constructing TrainNet, so its pre-trained weights carry over.
embedding_module = HGTModule(node_type='Character', metadata=data.metadata(), hidden_channels=64, num_heads=2, num_layers=1)
# Wrap the encoder in the pair-classification pre-training model.
pretrain_model = PretrainNet(embedding_module)

In [8]:
# Pre-training run: at most 20 epochs, stopping early once "val/loss" has not
# improved for 5 consecutive checks.
trainer = pl.Trainer(
    # gpus=1,
    callbacks=[
        pl.callbacks.EarlyStopping(monitor="val/loss", min_delta=0.00, patience=5, verbose=True, mode="min")
    ],
    max_epochs=20,
    enable_model_summary=True,
    # logger=wandb_logger
)
# Fit the pre-training model on the loaders provided by `data_module`.
trainer.fit(pretrain_model, data_module)

  rank_zero_deprecation(



Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f574134c0>
Traceback (most recent call last):
  File "/data/pella/projects/University/Thesis/Thesis/code/env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/data/pella/projects/University/Thesis/Thesis/code/env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/data/pella/projects/University/Thesis/Thesis/code/env/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f574134c0>
Traceback (most recent call last):
  File "/data/pella/projects/University/Thesis/Thesis/code/env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
  

Validating: 0it [00:00, ?it/s]

In [18]:
# Run the pre-trained model over the data module's prediction loader.
predictions = trainer.predict(pretrain_model, data_module)
# Concatenate the per-batch outputs along dim 0 and move the result to the CPU
# so it can be converted to NumPy for k-means.
embeddings = torch.cat(predictions, dim=0).cpu()

Predicting: 135it [00:00, ?it/s]

  rank_zero_deprecation(



Traceback (most recent call last):
  File "/data/pella/projects/University/Thesis/Thesis/code/env/lib/python3.9/multiprocessing/queues.py", line 251, in _feed
    send_bytes(obj)
  File "/data/pella/projects/University/Thesis/Thesis/code/env/lib/python3.9/multiprocessing/connection.py", line 205, in send_bytes
    self._send_bytes(m[offset:offset + size])
Traceback (most recent call last):
  File "/data/pella/projects/University/Thesis/Thesis/code/env/lib/python3.9/multiprocessing/queues.py", line 251, in _feed
    send_bytes(obj)
  File "/data/pella/projects/University/Thesis/Thesis/code/env/lib/python3.9/multiprocessing/connection.py", line 205, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/data/pella/projects/University/Thesis/Thesis/code/env/lib/python3.9/multiprocessing/connection.py", line 416, in _send_bytes
    self._send(header + buf)
  File "/data/pella/projects/University/Thesis/Thesis/code/env/lib/python3.9/multiprocessing/connection.py", line 373, in

In [26]:
import faiss

# k-means on the pre-trained embeddings: dimensionality taken from the embedding
# width, 5 clusters, 20 iterations per run, 5 restarts (`nredo`).
kmeans = faiss.Kmeans(embeddings.shape[1], k=5, niter=20, verbose=True, nredo=5)
kmeans.train(embeddings.numpy())
# The fitted centroids are used later as initial cluster centers of ClusterAssignment.
centroids = kmeans.centroids


Clustering 113 points in 64D to 5 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.01 s, search 0.00 s): objective=109.558 imbalance=1.171 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.01 s, search 0.01 s): objective=109.854 imbalance=1.115 nsplit=0       
Outer iteration 2 / 5
  Iteration 19 (0.01 s, search 0.01 s): objective=109.588 imbalance=1.140 nsplit=0       
Outer iteration 3 / 5
  Iteration 19 (0.01 s, search 0.01 s): objective=116.466 imbalance=1.116 nsplit=0       
Outer iteration 4 / 5
  Iteration 19 (0.02 s, search 0.01 s): objective=109.558 imbalance=1.171 nsplit=0       



In [27]:
class TrainNet(ml.BaseModule):
    """Joint model: pair-classification loss plus a clustering (KL) loss.

    The total loss is ``hp_loss + gamma * ca_loss`` where ``hp_loss`` is the
    cross-entropy of a ``Linear(1, 2)`` classifier on the cosine similarity of
    a pair's embeddings and ``ca_loss`` is the summed KL divergence between
    the hardened (one-hot arg-max) assignments and the soft assignments
    produced by ``clustering_module``.

    Args:
        embedding_module: Encoder producing node embeddings; its ``node_type``
            attribute selects the node store whose ``batch_size`` is read.
        clustering_module: Module mapping embeddings to soft cluster
            assignments; its ``cluster_number`` sets the one-hot width.
        gamma: Weight of the clustering loss in the total loss.
    """

    def __init__(
            self,
            embedding_module: HGTModule,
            clustering_module: ClusterAssignment,
            gamma: float = 0.001
    ):
        super().__init__()
        self.gamma = gamma

        self.embedding_module = embedding_module
        self.cos_sim = torch.nn.CosineSimilarity(dim=1)
        self.lin = torch.nn.Linear(1, 2)
        self.ce_loss = torch.nn.CrossEntropyLoss()

        self.clustering_module = clustering_module
        # `reduction='sum'` is the non-deprecated spelling of the legacy
        # `size_average=False` argument and yields the same summed loss.
        self.kl_loss = torch.nn.KLDivLoss(reduction='sum')

    def configure_metrics(self) -> Dict[str, Tuple[torchmetrics.Metric, bool]]:
        """Return the metrics, keyed by the names used in ``_step``'s result dict."""
        # NOTE(review): the meaning of the boolean in each tuple is defined by
        # ml.BaseModule and is not visible here — confirm.
        # NOTE(review): LabelEntropyMetric(5) is hard-coded; presumably it must
        # equal clustering_module.cluster_number — confirm.
        return {
            'loss': (torchmetrics.MeanMetric(), True),
            'hp_loss': (torchmetrics.MeanMetric(), True),
            'ca_loss': (torchmetrics.MeanMetric(), True),
            'accuracy': (torchmetrics.Accuracy(), True),
            'entropy': (ml.LabelEntropyMetric(5), True),
        }

    def _step(self, batch: Tuple[Any, Any, torch.Tensor]):
        """Compute the loss and metric inputs for one ``(left, right, label)`` batch.

        Returns:
            Dict with the differentiable ``'loss'`` plus detached values for
            the ``'hp_loss'``, ``'ca_loss'``, ``'accuracy'`` and ``'entropy'``
            metrics.
        """
        batch_l, batch_r, label = batch
        # The left batch's size is used to slice both sides.
        # NOTE(review): assumes batch_r has the same batch_size — confirm.
        batch_size = batch_l[self.embedding_module.node_type].batch_size

        emb_l = self.embedding_module(batch_l.x_dict, batch_l.edge_index_dict)[:batch_size]
        emb_r = self.embedding_module(batch_r.x_dict, batch_r.edge_index_dict)[:batch_size]

        # Compute homophily based loss.
        sim = self.cos_sim(emb_l, emb_r)
        out = self.lin(torch.unsqueeze(sim, 1))
        hp_loss = self.ce_loss(out, label)

        # Compute clustering loss
        emb = torch.cat([emb_l, emb_r], dim=0)
        q = self.clustering_module(emb)
        # Target distribution: one-hot of the currently most likely cluster,
        # built from a detached copy so no gradient flows through the target.
        p = torch.nn.functional.one_hot(
            torch.argmax(q.detach(), dim=1), self.clustering_module.cluster_number
        ).to(dtype=torch.float)
        # KLDivLoss expects log-probabilities as its first argument.
        ca_loss = self.kl_loss(torch.log(q), p)

        loss = hp_loss + self.gamma * ca_loss

        pred = out.argmax(dim=-1).detach()
        assignments = q.argmax(dim=-1).detach()
        return {
            'loss': loss,
            'hp_loss': hp_loss.detach(),
            'ca_loss': ca_loss.detach(),
            'accuracy': (pred, label),
            'entropy': assignments
        }

    def forward(self, batch):
        """Embed a single batch and return its first ``batch_size`` rows."""
        batch_size = batch[self.embedding_module.node_type].batch_size
        emb = self.embedding_module(batch.x_dict, batch.edge_index_dict)[:batch_size]
        return emb

    def training_step(self, batch):
        """Training hook; delegates to ``_step``."""
        return self._step(batch)

    def validation_step(self, batch, batch_idx):
        """Validation hook; delegates to ``_step`` (``batch_idx`` is unused)."""
        return self._step(batch)

    def configure_optimizers(self):
        """Return an Adam optimiser over all parameters with learning rate 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [28]:
# Take the cluster count and embedding width from the k-means centroids rather
# than repeating the literals used for `k` and `hidden_channels`, so the
# clustering head cannot drift out of sync with the fitted centroids.
clustering_module = ClusterAssignment(
    cluster_number=centroids.shape[0],
    embedding_dimension=centroids.shape[1],
    alpha=1,
    cluster_centers=torch.tensor(centroids),
)
# Reuse the pre-trained encoder together with the k-means-initialised clustering head.
main_model = TrainNet(embedding_module, clustering_module)




In [29]:
# Joint-training run: a fresh Trainer (rebinding `trainer`) limited to 10 epochs,
# with the same early-stopping rule on "val/loss" (patience 5).
trainer = pl.Trainer(
    # gpus=1,
    callbacks=[
        pl.callbacks.EarlyStopping(monitor="val/loss", min_delta=0.00, patience=5, verbose=True, mode="min")
    ],
    max_epochs=10,
    enable_model_summary=True,
    # logger=wandb_logger
)
# Fit the joint model on the loaders provided by `data_module`.
trainer.fit(main_model, data_module)

  rank_zero_warn(

  rank_zero_deprecation(



Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  rank_zero_deprecation(



In [34]:
from shared.constants import BENCHMARKS_RESULTS

# Output directory for this run's artifacts, below the shared benchmark results root.
save_path = BENCHMARKS_RESULTS.joinpath('analysis', 'pyg-hgt-comoptkl')
# Create the directory (and any missing parents); a no-op if it already exists.
save_path.mkdir(parents=True, exist_ok=True)

In [30]:
# Run the jointly trained model over the data module's prediction loader.
predictions = trainer.predict(main_model, data_module)
# Concatenate the per-batch outputs along dim 0 on the CPU; this rebinds the
# `embeddings` name that previously held the pre-training embeddings.
embeddings = torch.cat(predictions, dim=0).cpu()

Predicting: 182it [00:00, ?it/s]

In [31]:
# Hard cluster label per node: arg-max over the soft assignment columns.
# Evaluated without gradient tracking since only the labels are needed.
with torch.no_grad():
    assignments = clustering_module(embeddings).cpu().argmax(dim=-1).to(dtype=torch.long)

In [32]:
from shared.graph import CommunityAssignment
import pandas as pd

# Convert the label tensor to a NumPy array before handing it to pandas: how a
# raw torch tensor is coerced by `pd.Series` depends on the pandas version, and
# an explicit array guarantees a plain integer column of community ids.
labeling = pd.Series(assignments.squeeze().numpy(), index=dataset.node_mapping(), name="cid")
labeling.index.name = "nid"
# Wrap the node-id -> community-id series in the project's assignment container.
comlist = CommunityAssignment(labeling)

In [35]:
# Persist the community assignment as 'schema.comlist' inside the run directory.
comlist.save_comlist(save_path.joinpath('schema.comlist'))

In [None]:
from datasets.scripts import export_to_visualization

# Invoke the project's visualization export for the 'base' version of the
# 'star-wars' dataset, passing this run's output directory as the only run path.
export_to_visualization.run(
    export_to_visualization.Args(
        dataset='star-wars',
        version='base',
        run_paths=[str(save_path)]
    )
)

[2022-02-11 17:42:36,103][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetSchema'>, drop extra fields: True
[2022-02-11 17:42:36,104][simple_parsing.helpers.serialization.decoding][DEBUG] name = name, field_type = <class 'str'>
[2022-02-11 17:42:36,104][simple_parsing.helpers.serialization.decoding][DEBUG] name = database, field_type = <class 'str'>
[2022-02-11 17:42:36,105][simple_parsing.helpers.serialization.decoding][DEBUG] name = description, field_type = <class 'str'>
[2022-02-11 17:42:36,105][simple_parsing.helpers.serialization.decoding][DEBUG] name = versions, field_type = typing.Dict[str, shared.schema.dataset.DatasetVersion]
[2022-02-11 17:42:36,106][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetVersion'>, drop extra fields: True
[2022-02-11 17:42:36,107][simple_parsing.helpers.serialization.decoding][DEBUG] name = type, field_type = <enum 'DatasetVersi

# Calculate Evaluation Metrics

In [36]:
from shared.graph import DataGraph
from benchmarks.evaluation import get_metric_list

ERROR! Session/line number was not unique in database. History logging moved to new session 151
[2022-02-11 17:42:56,828][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'benchmarks.config.TunedParameterValue'>
[2022-02-11 17:42:56,829][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'simple_parsing.helpers.serialization.serializable.Serializable'>, <class 'simple_parsing.helpers.serialization.serializable.SerializableMixin'>]
[2022-02-11 17:42:56,830][simple_parsing.helpers.serialization.serializable][DEBUG] Parent class <class 'simple_parsing.helpers.serialization.serializable.Serializable'> has decode_into_subclasses = False
[2022-02-11 17:42:56,832][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'benchmarks.config.ParameterConfig'>
[2022-02-11 17:42:56,833][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'dict'>, 

In [37]:
from shared.schema import GraphSchema, DatasetSchema

# Load the 'star-wars' dataset schema, derive the graph schema from it and build
# the DataGraph `G` that the evaluation metrics are computed on.
DATASET = DatasetSchema.load_schema('star-wars')
schema = GraphSchema.from_dataset(DATASET)
G = DataGraph.from_schema(schema)

[2022-02-11 17:43:01,663][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetSchema'>, drop extra fields: True
[2022-02-11 17:43:01,664][simple_parsing.helpers.serialization.decoding][DEBUG] name = name, field_type = <class 'str'>
[2022-02-11 17:43:01,665][simple_parsing.helpers.serialization.decoding][DEBUG] name = database, field_type = <class 'str'>
[2022-02-11 17:43:01,665][simple_parsing.helpers.serialization.decoding][DEBUG] name = description, field_type = <class 'str'>
[2022-02-11 17:43:01,666][simple_parsing.helpers.serialization.decoding][DEBUG] name = versions, field_type = typing.Dict[str, shared.schema.dataset.DatasetVersion]
[2022-02-11 17:43:01,667][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetVersion'>, drop extra fields: True
[2022-02-11 17:43:01,668][simple_parsing.helpers.serialization.decoding][DEBUG] name = type, field_type = <enum 'DatasetVersi

In [38]:
# Metric classes applicable without ground truth and for non-overlapping communities.
metrics = get_metric_list(ground_truth=False, overlapping=False)

# One row per metric: its name and its value on graph `G` for the found communities.
results = pd.DataFrame(
    [
        dict(metric=metric.metric_name(), value=metric.calculate(G, comlist))
        for metric in metrics
    ]
)
results

Unnamed: 0,metric,value
0,community_count,
1,conductance,0.425956
2,expansion,3.232474
3,internal_edge_density,0.361209
4,avg_odf,6.562917
5,modularity_overlap,0.021889
6,link_modularity,0.050799
7,z_modularity,0.506997
8,modularity,-0.099927
