In [2]:
from ml.data.datasets import StarWars

# Load the Star Wars dataset; `dataset[0]` is the HeteroData graph shown below
# (Character nodes with INTERACTIONS / MENTIONS edges) that the model trains on.
dataset = StarWars()
data = dataset[0]
data

[2022-02-10 11:54:46,956][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.Serializable'>
[2022-02-10 11:54:46,957][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'simple_parsing.helpers.serialization.serializable.SerializableMixin'>]
[2022-02-10 11:54:46,958][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.FrozenSerializable'>
[2022-02-10 11:54:46,958][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'simple_parsing.helpers.serialization.serializable.SerializableMixin'>]
[2022-02-10 11:54:46,960][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.SimpleSerializable'>
[2022-02-10 11:54:46,964][simple_parsing.helpers

HeteroData(
  [1mCharacter[0m={ x=[113, 32] },
  [1m(Character, INTERACTIONS, Character)[0m={
    edge_attr=[958, 0],
    edge_index=[2, 958],
    timestamp=[958]
  },
  [1m(Character, MENTIONS, Character)[0m={
    edge_attr=[1120, 0],
    edge_index=[2, 1120],
    timestamp=[1120]
  }
)

In [3]:
from ml.data import EdgeLoaderDataModule

# Mini-batches of sampled sub-graphs around 'Character' seed nodes.
data_module = EdgeLoaderDataModule(
    data,
    node_type='Character',
    batch_size=16,
    num_samples=[4, 4],  # presumably neighbours sampled per hop (2 hops) — TODO confirm
    num_workers=4,
)

In [4]:
from typing import Optional, Dict, Tuple, Any

import torch
import torch.nn.functional as F
from torch_geometric.typing import Metadata
from torch_geometric.nn import HGTConv, Linear
import torchmetrics
import pytorch_lightning as pl

from ml.data import BaseModule

In [7]:
from ml.metrics import LabelEntropyMetric
from pytorch_lightning.loggers import WandbLogger
import ipyparams
import datetime as dt

class HGTModule(torch.nn.Module):
    """Heterogeneous Graph Transformer encoder returning embeddings of one node type.

    Every node type listed in ``metadata[0]`` is first projected to
    ``hidden_channels`` with a lazily-sized ``Linear(-1, hidden_channels)``
    followed by ReLU, then the whole feature dict is passed through
    ``num_layers`` ``HGTConv`` layers.

    Args:
        node_type: node type whose embeddings ``forward`` returns.
        metadata: graph metadata (e.g. ``HeteroData.metadata()``); ``metadata[0]``
            is iterated as the node types, and the whole object is handed to ``HGTConv``.
        hidden_channels: width of the projections and of every ``HGTConv``.
        num_heads: attention heads per ``HGTConv``.
        num_layers: number of stacked ``HGTConv`` layers.
    """

    def __init__(
            self,
            node_type,
            metadata: Metadata,
            hidden_channels=64,
            num_heads=2,
            num_layers=1
    ):
        super().__init__()
        self.node_type = node_type

        # One input projection per node type. ``in_channels=-1`` lets the layer
        # infer its input size on the first forward pass.
        self.lin_dict = torch.nn.ModuleDict()
        for ntype in metadata[0]:
            self.lin_dict[ntype] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, metadata, num_heads, group='sum')
            self.convs.append(conv)

    def forward(self, x_dict, edge_index_dict):
        """Encode the graph and return the embeddings of ``self.node_type``.

        Args:
            x_dict: mapping node type -> node feature tensor.
            edge_index_dict: mapping edge type -> edge index, passed to every ``HGTConv``.

        Returns:
            The final-layer features of ``self.node_type``.
        """
        # Build a fresh dict instead of writing into ``x_dict`` so the caller's
        # feature dictionary is left untouched.
        h_dict = {
            ntype: self.lin_dict[ntype](x).relu_()
            for ntype, x in x_dict.items()
        }

        for conv in self.convs:
            h_dict = conv(h_dict, edge_index_dict)

        return h_dict[self.node_type]


class ClusterModule(torch.nn.Module):
    """Learnable cluster centres with cosine-similarity based (soft) assignment.

    Args:
        n_clusters: number of cluster centres.
        embedding_dim: dimensionality of the centres; must match the embeddings
            passed to the other methods.
        cluster_centers: optional ``(n_clusters, embedding_dim)`` tensor used to
            initialise the centres. When ``None`` the default ``torch.nn.Embedding``
            initialisation is kept.

    Raises:
        ValueError: if ``cluster_centers`` has the wrong shape.
    """

    def __init__(
            self,
            n_clusters: int = 5,
            embedding_dim: int = 64,
            cluster_centers: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()
        self.n_clusters = n_clusters
        self.embedding_dim = embedding_dim

        # Centres are stored as an Embedding, so ``self.cluster_centers(idx)`` looks
        # up the centre of each assigned cluster (row ``k`` of ``.weight`` is centre ``k``).
        self.cluster_centers = torch.nn.Embedding(self.n_clusters, self.embedding_dim)
        if cluster_centers is not None:
            expected = (self.n_clusters, self.embedding_dim)
            if tuple(cluster_centers.shape) != expected:
                raise ValueError(
                    f"cluster_centers must have shape {expected}, got {tuple(cluster_centers.shape)}"
                )
            # Overwrite the random init in place so the weight stays the same Parameter.
            with torch.no_grad():
                self.cluster_centers.weight.copy_(cluster_centers)

        self.cos_sim = torch.nn.CosineSimilarity(dim=1)

    def forward(self, batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the centre of each row's hard-assigned cluster and that cluster's probability.

        Args:
            batch: ``(B, embedding_dim)`` embeddings.

        Returns:
            ``(centres, probability)`` of shapes ``(B, embedding_dim)`` and ``(B,)``.
        """
        assignments, probability = self.assignments(batch)
        return self.cluster_centers(assignments), probability

    def assignments(self, batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Hard-assign each row to its most probable cluster.

        Returns:
            ``(cluster_index, probability)``, both of shape ``(B,)``, where
            ``probability`` is the soft-assignment value of the chosen cluster.
        """
        sac = self.soft_assignment(batch)
        result = torch.argmax(sac, dim=1)
        probability = torch.gather(sac, 1, result.unsqueeze(1)).squeeze(1)
        return result, probability

    def soft_assignment(self, batch: torch.Tensor) -> torch.Tensor:
        """Softmax over the cosine similarities between each row and every centre.

        Returns:
            ``(B, n_clusters)`` tensor whose rows sum to 1.
        """
        # (B, 1, D) vs (1, K, D) broadcasts to a (B, K) similarity matrix.
        sac = torch.nn.functional.cosine_similarity(
            batch.unsqueeze(1), self.cluster_centers.weight.unsqueeze(0), dim=2
        )
        return torch.softmax(sac, dim=1)

    def cluster_cohesion_loss(self, batch: torch.Tensor) -> torch.Tensor:
        """Per-row ``(1 - cos(v, assigned centre))**2``; pulls nodes towards their centre.

        Returns:
            ``(B,)`` tensor of per-row losses (not reduced).
        """
        assignments, _ = self.assignments(batch)
        assigned_centres = self.cluster_centers(assignments)
        sim = self.cos_sim(batch, assigned_centres)
        return torch.square(1 - sim)

    def cluster_quality_loss(self) -> torch.Tensor:
        """Per-centre squared max cosine similarity to any *other* centre.

        Minimising it pushes centres towards mutual orthogonality. With a single
        cluster there is no other centre and the value is constantly 1.

        Returns:
            ``(n_clusters,)`` tensor of per-centre losses (not reduced).
        """
        weight = self.cluster_centers.weight
        sim_mtx = torch.nn.functional.cosine_similarity(
            weight.unsqueeze(1), weight.unsqueeze(0), dim=2
        )
        # Subtract 2 on the diagonal: a centre's self-similarity (1) becomes -1,
        # so it is never picked as its own nearest neighbour by the max below.
        sim_mtx = sim_mtx - 2 * torch.eye(self.n_clusters, dtype=torch.float, device=weight.device)
        return torch.square(torch.max(sim_mtx, dim=1).values)

    def davies_bouldin_loss(self, batch: torch.Tensor) -> torch.Tensor:
        """Davies-Bouldin based loss — not implemented yet."""
        # Raise instead of silently returning None, which would only fail later
        # and far from here when the result is used in arithmetic.
        raise NotImplementedError("davies_bouldin_loss is not implemented yet")


class Net(BaseModule):
    """HGT node encoder trained jointly with a ``ClusterModule`` of cluster centres.

    Each training/validation batch is a triple ``(batch_l, batch_r, label)`` of two
    sampled sub-graphs and per-pair labels. ``_forward_step_small`` turns it into
    the loss terms listed in ``configure_metrics``. ``forward`` returns embeddings of
    the ``node_type`` seed nodes; presumably it is what ``trainer.predict`` uses.

    Args:
        node_type: node type whose embeddings are learned and clustered.
        metadata: graph metadata passed through to ``HGTModule``.
        n_clusters: number of cluster centres.
        embedding_dim: width of the node embeddings and of the cluster centres.
        num_heads: attention heads per HGT layer.
        num_layers: number of HGT layers.
        chp_weight: multiplier of the community homophily loss (``chp_loss``). The
            default reproduces the previously hard-coded factor of 7.
    """

    def __init__(
            self,
            node_type,
            metadata: Metadata,
            n_clusters: int = 5,
            embedding_dim=64,
            num_heads=2,
            num_layers=1,
            chp_weight: float = 7.0,
    ):
        # Set before ``super().__init__()``: presumably ``BaseModule.__init__`` calls
        # ``configure_metrics``, which reads ``self.n_clusters`` — TODO confirm.
        self.n_clusters = n_clusters
        super().__init__()
        self.node_type = node_type
        self.chp_weight = chp_weight
        self.node_embedding = HGTModule(node_type, metadata, embedding_dim, num_heads, num_layers)
        self.cluster_embedding = ClusterModule(n_clusters, embedding_dim)

        self.cos_sim = torch.nn.CosineSimilarity(dim=1)

    def configure_metrics(self) -> Dict[str, Tuple[torchmetrics.Metric, bool]]:
        """Metrics fed from the dict returned by ``_forward_step_small`` (same keys)."""
        return {
            'loss': (torchmetrics.MeanMetric(), True),
            'hp_loss': (torchmetrics.MeanMetric(), True),
            'chp_loss': (torchmetrics.MeanMetric(), True),
            'cc_loss': (torchmetrics.MeanMetric(), True),
            'cq_loss': (torchmetrics.MeanMetric(), True),
            'entropy': (LabelEntropyMetric(self.n_clusters), True),
        }

    def _forward_step_small(self, batch, mode='train'):
        """Link-prediction step with the cluster centres held fixed.

        Args:
            batch: ``(batch_l, batch_r, label)``. Presumably the first ``batch_size``
                ``node_type`` rows of each side are the seed nodes of the pairs.
            mode: ``'train'`` or ``'val'``; currently unused.

        Returns:
            A dict with the differentiable total ``'loss'``, the detached individual
            loss terms, and the detached hard cluster assignments of both sides
            under ``'entropy'``.
        """
        # NOTE(review): this permanently disables gradients for the cluster centres.
        # Nothing re-enables them and ``_forward_step_large`` is not implemented, so
        # the centres stay at their initial values and ``cq_loss`` adds no gradient.
        self.cluster_embedding.requires_grad_(False)

        # Small Step: Link prediction
        batch_l, batch_r, label = batch
        batch_size = batch_l[self.node_type].batch_size

        emb_l = self.node_embedding(batch_l.x_dict, batch_l.edge_index_dict)[:batch_size]
        emb_r = self.node_embedding(batch_r.x_dict, batch_r.edge_index_dict)[:batch_size]
        # Maps labels to the similarity target (assumes label in {-1, 1} -> {0, 1}; TODO confirm).
        label_rescaled = (label + 1.0) / 2.0

        # Homophily loss: embedding similarity should match the pair label.
        # TODO: do we need to cap the individual losses?
        sim = self.cos_sim(emb_l, emb_r)
        hp_loss = torch.mean(torch.square(sim - label_rescaled))

        # Community homophily loss: similarity of the assigned centres should match
        # the label, weighted by how confident both assignments are.
        c_emb_l, p_l = self.cluster_embedding(emb_l)
        c_emb_r, p_r = self.cluster_embedding(emb_r)
        weight = torch.multiply(p_l, p_r)
        c_sim = self.cos_sim(c_emb_l, c_emb_r)
        chp_loss = torch.mean(torch.square(c_sim - label_rescaled) * weight) * self.chp_weight

        # Cluster cohesion loss: pull nodes towards their assigned centre.
        cc_loss = torch.mean(torch.cat([
            self.cluster_embedding.cluster_cohesion_loss(emb_l),
            self.cluster_embedding.cluster_cohesion_loss(emb_r)
        ], dim=0))

        # Cluster quality loss: push centres apart.
        cq_loss = torch.mean(self.cluster_embedding.cluster_quality_loss())

        loss = hp_loss + chp_loss + cc_loss + cq_loss

        assignments = torch.cat([
            self.cluster_embedding.assignments(emb_l)[0],
            self.cluster_embedding.assignments(emb_r)[0],
        ], dim=0)

        return {
            'loss': loss,
            'hp_loss': hp_loss.detach(),
            'cc_loss': cc_loss.detach(),
            'cq_loss': cq_loss.detach(),
            'chp_loss': chp_loss.detach(),
            'entropy': assignments.detach(),
        }

    def _forward_step_large(self, batch):
        """Cluster-centre update step — not implemented yet."""
        # Raise instead of silently returning None so an accidental call fails loudly.
        raise NotImplementedError("_forward_step_large is not implemented yet")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, batch, batch_idx=None):
        return self._forward_step_small(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        return self._forward_step_small(batch, mode='val')

    def forward(self, batch):
        """Return the embeddings of the first ``batch_size`` ``node_type`` nodes of ``batch``."""
        batch_size = batch[self.node_type].batch_size
        emb = self.node_embedding(batch.x_dict, batch.edge_index_dict)[:batch_size]
        return emb


# Seed python / numpy / torch (and dataloader workers) so weight init, neighbour
# sampling and shuffling are reproducible on Restart & Run All.
pl.seed_everything(42, workers=True)

model = Net(node_type='Character', metadata=data.metadata(), n_clusters=5, embedding_dim=64, num_heads=2, num_layers=1)

# Run configuration: flip USE_WANDB to log to Weights & Biases; set GPUS = 1 to train on a GPU.
USE_WANDB = False
GPUS = None

logger = (
    WandbLogger(
        project='Thesis-Experiments',
        name=f'{ipyparams.notebook_name}-{dt.datetime.now().strftime("%Y%m%d-%H%M%S")}'
    )
    if USE_WANDB
    else True  # Lightning's default logger
)
trainer = pl.Trainer(
    gpus=GPUS,
    callbacks=[
        # Stop after 5 epochs without improvement of the validation loss
        # (assumes BaseModule logs it under "val/loss" — TODO confirm).
        pl.callbacks.EarlyStopping(monitor="val/loss", min_delta=0.00, patience=5, verbose=True, mode="min")
    ],
    max_epochs=50,
    enable_model_summary=True,
    logger=logger,
)
trainer.fit(model, data_module)

  rank_zero_deprecation(



Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")



In [23]:
# One embedding tensor per prediction batch, concatenated into one row per node.
predictions = trainer.predict(model, data_module)
emb = torch.cat(predictions, dim=0).cpu()
# Later cells pair row i with the i-th entry of `dataset.node_mapping()`, so there must
# be exactly one row per node (assumes the predict loader keeps node order — TODO confirm).
assert emb.shape[0] == data[model.node_type].num_nodes, (
    f'expected one embedding per {model.node_type} node, got {emb.shape[0]} rows'
)
# L2-normalise so k-means on Euclidean distance matches the cosine geometry used in training.
embeddings = torch.nn.functional.normalize(emb, p=2, dim=1).numpy()

Predicting: 182it [00:00, ?it/s]

# Extract and cluster the embeddings

In [24]:
from shared.constants import BENCHMARKS_RESULTS

# Output directory for this run's artefacts (community assignment, exports).
save_path = BENCHMARKS_RESULTS.joinpath('analysis').joinpath('pyg-hgt-comopt')
save_path.mkdir(exist_ok=True, parents=True)

In [25]:
import faiss
import pandas as pd

In [26]:
# Re-cluster the normalised embeddings with k-means. Use the same number of clusters
# the model was trained with, instead of a second hard-coded 5 that could drift.
k = model.cluster_embedding.n_clusters
# seed=1234 is faiss' default, made explicit so the run's reproducibility is visible.
kmeans = faiss.Kmeans(embeddings.shape[1], k, niter=20, verbose=True, nredo=5, seed=1234)
kmeans.train(embeddings)
# D: distance to the nearest centroid; I: id of that centroid (one column, k=1).
D, I = kmeans.index.search(embeddings, 1)


Clustering 113 points in 64D to 5 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.00 s, search 0.00 s): objective=19.7195 imbalance=1.209 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.00 s, search 0.00 s): objective=18.6964 imbalance=1.062 nsplit=0       
Objective improved: keep new clusters
Outer iteration 2 / 5
  Iteration 19 (0.01 s, search 0.01 s): objective=20.2445 imbalance=1.222 nsplit=0       
Outer iteration 3 / 5
  Iteration 19 (0.01 s, search 0.01 s): objective=19.2894 imbalance=1.138 nsplit=0       
Outer iteration 4 / 5
  Iteration 19 (0.01 s, search 0.01 s): objective=19.3788 imbalance=1.204 nsplit=0       



In [27]:
from shared.graph import CommunityAssignment

# Node id -> k-means cluster id, wrapped as a CommunityAssignment for saving / evaluation.
labeling = pd.Series(
    I.squeeze(),
    index=dataset.node_mapping(),
    name="cid",
).rename_axis("nid")
comlist = CommunityAssignment(labeling)

In [28]:
# Persist the k-means community assignment into this run's output directory.
comlist.save_comlist(save_path.joinpath('schema.comlist'))

In [None]:
from datasets.scripts import export_to_visualization

# Export this run (the directory holding schema.comlist) for visualisation
# against the 'base' version of the star-wars dataset.
export_to_visualization.run(
    export_to_visualization.Args(
        dataset='star-wars',
        version='base',
        run_paths=[str(save_path)]
    )
)

[2022-02-09 17:27:52,446][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetSchema'>, drop extra fields: True
[2022-02-09 17:27:52,446][simple_parsing.helpers.serialization.decoding][DEBUG] name = name, field_type = <class 'str'>
[2022-02-09 17:27:52,447][simple_parsing.helpers.serialization.decoding][DEBUG] name = database, field_type = <class 'str'>
[2022-02-09 17:27:52,448][simple_parsing.helpers.serialization.decoding][DEBUG] name = description, field_type = <class 'str'>
[2022-02-09 17:27:52,448][simple_parsing.helpers.serialization.decoding][DEBUG] name = versions, field_type = typing.Dict[str, shared.schema.dataset.DatasetVersion]
[2022-02-09 17:27:52,449][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetVersion'>, drop extra fields: True
[2022-02-09 17:27:52,449][simple_parsing.helpers.serialization.decoding][DEBUG] name = type, field_type = <enum 'DatasetVersi

# Calculate Evaluation Metrics

In [29]:
from shared.graph import DataGraph
from benchmarks.evaluation import get_metric_list

In [30]:
from shared.schema import GraphSchema, DatasetSchema

# Build the data graph G from the star-wars dataset schema; the evaluation
# metrics below are computed on G.
DATASET = DatasetSchema.load_schema('star-wars')
schema = GraphSchema.from_dataset(DATASET)
G = DataGraph.from_schema(schema)

[2022-02-09 17:27:53,702][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetSchema'>, drop extra fields: True
[2022-02-09 17:27:53,703][simple_parsing.helpers.serialization.decoding][DEBUG] name = name, field_type = <class 'str'>
[2022-02-09 17:27:53,703][simple_parsing.helpers.serialization.decoding][DEBUG] name = database, field_type = <class 'str'>
[2022-02-09 17:27:53,704][simple_parsing.helpers.serialization.decoding][DEBUG] name = description, field_type = <class 'str'>
[2022-02-09 17:27:53,704][simple_parsing.helpers.serialization.decoding][DEBUG] name = versions, field_type = typing.Dict[str, shared.schema.dataset.DatasetVersion]
[2022-02-09 17:27:53,705][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetVersion'>, drop extra fields: True
[2022-02-09 17:27:53,705][simple_parsing.helpers.serialization.decoding][DEBUG] name = type, field_type = <enum 'DatasetVersi

In [31]:
# Metrics that need neither ground truth nor overlapping communities.
metrics = get_metric_list(ground_truth=False, overlapping=False)

# One row per metric: its name and its score for the k-means partition of G.
results = pd.DataFrame(
    [
        dict(metric=metric.metric_name(), value=metric.calculate(G, comlist))
        for metric in metrics
    ]
)
results

Unnamed: 0,metric,value
0,community_count,
1,conductance,0.384977
2,expansion,5.200736
3,internal_edge_density,0.430615
4,avg_odf,10.480171
5,modularity_overlap,0.025714
6,link_modularity,0.037039
7,z_modularity,0.57352
8,modularity,-0.004529
