In [2]:
from ml.data.datasets import StarWars

# Instantiate the Star Wars dataset and take its first element; the cell
# output below shows it to be a HeteroData graph.
dataset = StarWars()
data = dataset[0]
# Bare last expression so the notebook renders the graph summary.
data

[2022-02-12 15:51:06,083][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.Serializable'>
[2022-02-12 15:51:06,084][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'simple_parsing.helpers.serialization.serializable.SerializableMixin'>]
[2022-02-12 15:51:06,085][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.FrozenSerializable'>
[2022-02-12 15:51:06,086][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'simple_parsing.helpers.serialization.serializable.SerializableMixin'>]
[2022-02-12 15:51:06,087][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.SimpleSerializable'>
[2022-02-12 15:51:06,091][simple_parsing.helpers

HeteroData(
  [1mCharacter[0m={ x=[113, 32] },
  [1m(Character, INTERACTIONS, Character)[0m={
    edge_attr=[958, 0],
    edge_index=[2, 958],
    timestamp=[958]
  },
  [1m(Character, MENTIONS, Character)[0m={
    edge_attr=[1120, 0],
    edge_index=[2, 1120],
    timestamp=[1120]
  }
)

In [3]:
from ml.data import EdgeLoaderDataModule

# Data module built on the graph loaded above, restricted to `Character` nodes.
data_module = EdgeLoaderDataModule(
    data,
    batch_size=16,
    num_samples=[4] * 2,
    num_workers=4,
    node_type='Character',
    neg_sample_ratio=1,
)

In [4]:
from typing import Optional, Dict, Tuple, Any

import torch
import torch.nn.functional as F
from torch_geometric.typing import Metadata
from torch_geometric.nn import HGTConv, Linear
import torchmetrics
import pytorch_lightning as pl
import math

import ml

In [5]:
class HGTModule(torch.nn.Module):
    """Heterogeneous graph encoder: per-node-type projection followed by HGTConv layers.

    Args:
        node_type: Node type whose embeddings ``forward`` returns.
        metadata: Graph metadata. ``metadata[0]`` is iterated to create one
            input projection per node type; the whole object is handed to
            every ``HGTConv``.
        hidden_channels: Output width of the projections and of every conv.
        num_heads: Head count passed to each ``HGTConv``.
        num_layers: Number of stacked ``HGTConv`` layers.
    """

    def __init__(
            self,
            node_type,
            metadata: Metadata,
            hidden_channels=64,
            num_heads=2,
            num_layers=1
    ):
        super().__init__()
        self.node_type = node_type

        # One input projection per node type. The loop variable deliberately
        # has its own name so it cannot be confused with the `node_type`
        # argument stored above. `-1` is PyG's lazy in_channels setting
        # (input width resolved on first use).
        self.lin_dict = torch.nn.ModuleDict()
        for input_type in metadata[0]:
            self.lin_dict[input_type] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, metadata, num_heads, group='sum')
            self.convs.append(conv)

    def forward(self, x_dict, edge_index_dict):
        """Encode the graph and return the embeddings of ``self.node_type``.

        Args:
            x_dict: Mapping from node type to that type's feature tensor.
                It is left untouched.
            edge_index_dict: Edge indices, forwarded unchanged to every conv.

        Returns:
            The entry for ``self.node_type`` of the last conv's output dict.
        """
        # Build a fresh dict rather than writing the projected features back
        # into the caller's `x_dict`: the input must survive the call intact.
        # `relu_` is in place, but only on the projection's own output.
        h_dict = {
            input_type: self.lin_dict[input_type](x).relu_()
            for input_type, x in x_dict.items()
        }

        for conv in self.convs:
            h_dict = conv(h_dict, edge_index_dict)

        return h_dict[self.node_type]

class ClusteringModule(torch.nn.Module):
    """Soft cluster-assignment head: a linear layer followed by a row-wise softmax.

    Args:
        rep_dim: Width of the incoming representations.
        n_clusters: Number of clusters, i.e. columns of the output.
    """

    def __init__(self, rep_dim, n_clusters):
        super().__init__()
        self.lin = torch.nn.Linear(rep_dim, n_clusters)
        self.activation = torch.nn.Softmax(dim=1)

    def forward(self, batch: torch.Tensor):
        """Return the soft assignment matrix; each row sums to one."""
        cluster_logits = self.lin(batch)
        return self.activation(cluster_logits)

    def forward_assign(self, batch: torch.Tensor):
        """Return, for every row, the index of its most probable cluster."""
        soft_assignment = self(batch)
        return torch.argmax(soft_assignment, dim=1)

class LinkPredictionLoss(torch.nn.Module):
    """Cross-entropy objective applied to the link-prediction scores."""

    def __init__(self):
        super().__init__()
        self.ce_loss = torch.nn.CrossEntropyLoss()

    def forward(self, pred: torch.Tensor, label: torch.Tensor):
        """Return the cross-entropy between ``pred`` (one score per class and row) and ``label``."""
        loss = self.ce_loss(pred, label)
        return loss

class ClusteringLoss(torch.nn.Module):
    """Pairwise cluster-agreement loss plus a cluster-balance (negative-entropy) term."""

    def __init__(self):
        super().__init__()
        self.ce_loss = torch.nn.CrossEntropyLoss()

    def forward(self, q_l: torch.Tensor, q_r: torch.Tensor, label: torch.Tensor):
        """Compute the agreement loss and the balance term for a batch of node pairs.

        Args:
            q_l: Soft cluster assignments of the left nodes, one row per pair.
            q_r: Soft cluster assignments of the right nodes, same shape as ``q_l``.
            label: Class index (0 or 1) per pair, as expected by ``CrossEntropyLoss``.

        Returns:
            ``(loss, ne)``: ``loss`` is the cross-entropy of the pairwise agreement
            against ``label``; ``ne`` is ``log(K) + sum(p * log(p))`` for the batch's
            average cluster usage ``p`` over ``K`` clusters. It is 0 when usage is
            uniform and ``log(K)`` when a single cluster holds all the mass.
        """
        # Average cluster usage over both sides of every pair in the batch.
        p = torch.sum(q_l, dim=0) + torch.sum(q_r, dim=0)
        p = p / p.sum()
        # A cluster whose mass is exactly 0 would give 0 * log(0) = NaN. Clamping
        # only the log argument makes that term contribute 0 and leaves every
        # strictly positive (normal-range) entry unchanged.
        tiny = torch.finfo(p.dtype).tiny
        ne = math.log(p.size(0)) + (p * torch.log(p.clamp_min(tiny))).sum()

        # Row-wise dot product of the two assignments. Reducing over dim=1 always
        # yields a 1-D tensor, whereas the former bmm(...).squeeze() collapsed a
        # batch of one pair to a 0-dim tensor and made the stack below raise.
        sim = torch.sigmoid((q_l * q_r).sum(dim=1))
        pred = torch.stack([1 - sim, sim], dim=1)
        # NOTE(review): CrossEntropyLoss applies its own log-softmax, yet `pred`
        # already holds probabilities. Kept as-is to preserve training behaviour;
        # confirm this is intended.
        loss = self.ce_loss(pred, label)

        return loss, ne

In [6]:
class Net(ml.BaseModule):
    """Joint link-prediction and clustering model over pairs of sampled subgraphs.

    Args:
        embedding_module: Encoder producing node embeddings; its ``node_type``
            attribute selects which store of a batch carries ``batch_size``.
        clustering_module: Head mapping embeddings to soft cluster assignments.
        cc_weight: Weight of the clustering loss in the total loss. The default
            of 2.0 reproduces the previously hard-coded factor.
    """

    def __init__(
            self,
            embedding_module: HGTModule,
            clustering_module: ClusteringModule,
            cc_weight: float = 2.0,
    ):
        super().__init__()
        self.embedding_module = embedding_module
        self.clustering_module = clustering_module
        self.cc_weight = cc_weight
        self.cos_sim = torch.nn.CosineSimilarity(dim=1)
        # Maps the scalar cosine similarity of a pair to two class scores.
        self.lin = torch.nn.Linear(1, 2)

        self.link_prediction_loss = LinkPredictionLoss()
        self.clustering_loss = ClusteringLoss()

    def configure_metrics(self) -> Dict[str, Tuple[torchmetrics.Metric, bool]]:
        """Declare one metric per key returned by ``_step``.

        The boolean in each tuple is consumed by ``ml.BaseModule``; its meaning
        is not visible from this notebook.
        """
        return {
            'loss': (torchmetrics.MeanMetric(), True),
            'hp_loss': (torchmetrics.MeanMetric(), True),
            'cc_loss': (torchmetrics.MeanMetric(), True),
            'accuracy': (torchmetrics.Accuracy(), True),
            'ne': (torchmetrics.MeanMetric(), True),
        }

    def _step(self, batch: Tuple[Any, Any, torch.Tensor]):
        """Shared train/validation step.

        Args:
            batch: ``(batch_l, batch_r, label)`` holding the left and right
                subgraph batches of each pair and the pair labels.

        Returns:
            Dict with the total loss, its components, and the
            ``(prediction, label)`` pair for the accuracy metric.
        """
        batch_l, batch_r, label = batch
        batch_size = batch_l[self.embedding_module.node_type].batch_size

        # Only the first `batch_size` rows are kept; presumably these are the
        # seed nodes of the sampled subgraph, verify against the loader.
        emb_l = self.embedding_module(batch_l.x_dict, batch_l.edge_index_dict)[:batch_size]
        emb_r = self.embedding_module(batch_r.x_dict, batch_r.edge_index_dict)[:batch_size]
        sim = self.cos_sim(emb_l, emb_r)
        out = self.lin(torch.unsqueeze(sim, 1))
        hp_loss = self.link_prediction_loss(out, label)

        q_l = self.clustering_module(emb_l)
        q_r = self.clustering_module(emb_r)
        cc_loss, ne = self.clustering_loss(q_l, q_r, label)

        loss = hp_loss + self.cc_weight * cc_loss + ne

        pred = out.argmax(dim=-1)
        return {
            'loss': loss,
            'hp_loss': hp_loss,
            'cc_loss': cc_loss,
            'accuracy': (pred, label),
            'ne': ne,
        }

    def forward(self, batch):
        """Return ``(embeddings, soft_assignments)`` for the first ``batch_size`` nodes of ``batch``."""
        batch_size = batch[self.embedding_module.node_type].batch_size
        emb = self.embedding_module(batch.x_dict, batch.edge_index_dict)[:batch_size]
        q = self.clustering_module(emb)
        return emb, q

    def training_step(self, batch, batch_idx: Optional[int] = None):
        # `batch_idx` is accepted (and ignored) so the hook matches
        # `validation_step` and Lightning's `training_step(batch, batch_idx)`
        # signature; the default keeps single-argument calls working.
        return self._step(batch)

    def validation_step(self, batch, batch_idx):
        return self._step(batch)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at a learning rate of 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [7]:
# The encoder's output width and the clustering head's input width must be the
# same number, so both are taken from a single constant instead of two literals.
HIDDEN_CHANNELS = 32

embedding_module = HGTModule(
    node_type='Character',
    metadata=data.metadata(),
    hidden_channels=HIDDEN_CHANNELS,
    num_heads=2,
    num_layers=2,
)
clustering_module = ClusteringModule(rep_dim=HIDDEN_CHANNELS, n_clusters=5)

In [8]:
# Combine the encoder and the clustering head into the module that is trained below.
model = Net(embedding_module, clustering_module)

In [9]:
# Train for at most 50 epochs; early stopping watches `val/loss` (lower is better).
trainer = pl.Trainer(
    max_epochs=50,
    enable_model_summary=True,
    callbacks=[
        pl.callbacks.EarlyStopping(
            monitor="val/loss",
            mode="min",
            min_delta=0.0,
            patience=5,
            verbose=True,
        ),
    ],
    # gpus=1,
    # logger=wandb_logger
)
trainer.fit(model, data_module)

  rank_zero_warn(




Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_deprecation(



Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")



In [112]:
from shared.constants import BENCHMARKS_RESULTS

# Directory receiving this run's artefacts; created (with parents) if missing.
save_path = BENCHMARKS_RESULTS / 'analysis' / 'pyg-hgt-comopt-contrastive'
save_path.mkdir(parents=True, exist_ok=True)

In [113]:
predictions = trainer.predict(model, data_module)
# Every prediction is an (embeddings, soft assignments) pair; transpose the
# list of pairs and concatenate each side along the node axis, on the CPU.
embeddings, assignments = (
    torch.cat(parts, dim=0).detach().cpu()
    for parts in zip(*predictions)
)
# Hard assignment: index of the most probable cluster per node.
assignments = assignments.argmax(dim=1)

Predicting: 182it [00:00, ?it/s]

In [114]:
from shared.graph import CommunityAssignment
import pandas as pd

# `assignments` is the 1-D, detached CPU tensor produced by argmax(dim=1) in the
# prediction cell. Convert it to NumPy explicitly so pandas stores plain
# integers rather than relying on how it happens to iterate a torch tensor.
# The former `.squeeze()` is dropped: it was a no-op for 1-D input and would
# have collapsed a single-node result into a 0-dim value.
labeling = pd.Series(assignments.numpy(), index=dataset.node_mapping(), name="cid")
labeling.index.name = "nid"
comlist = CommunityAssignment(labeling)

In [115]:
# Persist the community assignment under the run directory as `schema.comlist`.
comlist.save_comlist(save_path.joinpath('schema.comlist'))

In [None]:
from datasets.scripts import export_to_visualization

# Export the `base` version of the star-wars dataset for visualization, passing
# this run's directory so its saved community list is included (the log output
# below reports it being added as `prediction_0`).
export_to_visualization.run(
    export_to_visualization.Args(
        dataset='star-wars',
        version='base',
        run_paths=[str(save_path)]
    )
)

[2022-02-12 15:33:53,978][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetSchema'>, drop extra fields: True
[2022-02-12 15:33:53,979][simple_parsing.helpers.serialization.decoding][DEBUG] name = name, field_type = <class 'str'>
[2022-02-12 15:33:53,979][simple_parsing.helpers.serialization.decoding][DEBUG] name = database, field_type = <class 'str'>
[2022-02-12 15:33:53,980][simple_parsing.helpers.serialization.decoding][DEBUG] name = description, field_type = <class 'str'>
[2022-02-12 15:33:53,980][simple_parsing.helpers.serialization.decoding][DEBUG] name = versions, field_type = typing.Dict[str, shared.schema.dataset.DatasetVersion]
[2022-02-12 15:33:53,980][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetVersion'>, drop extra fields: True
[2022-02-12 15:33:53,981][simple_parsing.helpers.serialization.decoding][DEBUG] name = type, field_type = <enum 'DatasetVersi

Traceback (most recent call last):
  File "/data/pella/projects/University/Thesis/Thesis/code/env/lib/python3.9/multiprocessing/queues.py", line 251, in _feed
    send_bytes(obj)
  File "/data/pella/projects/University/Thesis/Thesis/code/env/lib/python3.9/multiprocessing/connection.py", line 205, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/data/pella/projects/University/Thesis/Thesis/code/env/lib/python3.9/multiprocessing/connection.py", line 416, in _send_bytes
    self._send(header + buf)
  File "/data/pella/projects/University/Thesis/Thesis/code/env/lib/python3.9/multiprocessing/connection.py", line 373, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/data/pella/projects/University/Thesis/Thesis/code/env/lib/python3.9/multiprocessing/queues.py", line 251, in _feed
    send_bytes(obj)
  File "/data/pella/projects/University/Thesis/Thesis/code/env/lib/python3.9/multiprocessing/connec

[2022-02-12 15:33:54,176][export_to_visualization.py][INFO] Adding prediction /data/pella/projects/University/Thesis/Thesis/code/storage/results/analysis/pyg-hgt-comopt-contrastive/schema.comlist as prediction_0
[2022-02-12 15:33:54,180][export_to_visualization.py][INFO] Writing /data/pella/projects/University/Thesis/Thesis/code/storage/datasets/export/star-wars/versions/base/schema.graphml


# Calculate Evaluation Metrics

In [116]:
from shared.graph import DataGraph
from benchmarks.evaluation import get_metric_list

In [117]:
from shared.schema import GraphSchema, DatasetSchema

# Load the star-wars dataset schema, derive its graph schema, and build the
# graph `G` that the evaluation metrics below are computed on.
DATASET = DatasetSchema.load_schema('star-wars')
schema = GraphSchema.from_dataset(DATASET)
G = DataGraph.from_schema(schema)

ERROR! Session/line number was not unique in database. History logging moved to new session 166
[2022-02-12 15:33:54,566][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetSchema'>, drop extra fields: True
[2022-02-12 15:33:54,567][simple_parsing.helpers.serialization.decoding][DEBUG] name = name, field_type = <class 'str'>
[2022-02-12 15:33:54,567][simple_parsing.helpers.serialization.decoding][DEBUG] name = database, field_type = <class 'str'>
[2022-02-12 15:33:54,568][simple_parsing.helpers.serialization.decoding][DEBUG] name = description, field_type = <class 'str'>
[2022-02-12 15:33:54,569][simple_parsing.helpers.serialization.decoding][DEBUG] name = versions, field_type = typing.Dict[str, shared.schema.dataset.DatasetVersion]
[2022-02-12 15:33:54,569][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetVersion'>, drop extra fields: True
[2022-02-12 15:33:54,570][sim

In [118]:
# Metrics that need no ground truth and assume non-overlapping communities.
metrics = get_metric_list(ground_truth=False, overlapping=False)

# One row per metric: its name and its value on graph `G` for `comlist`.
results = pd.DataFrame([
    dict(metric=metric_cls.metric_name(), value=metric_cls.calculate(G, comlist))
    for metric_cls in metrics
])
results

Unnamed: 0,metric,value
0,community_count,
1,conductance,0.734932
2,expansion,9.306583
3,internal_edge_density,0.215144
4,avg_odf,18.763866
5,modularity_overlap,-0.067691
6,link_modularity,0.016956
7,z_modularity,-0.109585
8,modularity,-0.166699
