In [1]:
from ml.data.datasets import StarWars

# Load the Star Wars dataset and take the graph stored at index 0
# (displayed below as a HeteroData object).
dataset = StarWars()
data = dataset[0]
data

[2022-02-08 23:36:07,386][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.Serializable'>
[2022-02-08 23:36:07,387][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'simple_parsing.helpers.serialization.serializable.SerializableMixin'>]
[2022-02-08 23:36:07,387][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.FrozenSerializable'>
[2022-02-08 23:36:07,388][simple_parsing.helpers.serialization.serializable][DEBUG] parents: [<class 'simple_parsing.helpers.serialization.serializable.SerializableMixin'>]
[2022-02-08 23:36:07,389][simple_parsing.helpers.serialization.serializable][DEBUG] Registering a new Serializable subclass: <class 'simple_parsing.helpers.serialization.serializable.SimpleSerializable'>
[2022-02-08 23:36:07,392][simple_parsing.helpers

HeteroData(
  [1mCharacter[0m={ x=[113, 32] },
  [1m(Character, INTERACTIONS, Character)[0m={
    edge_attr=[958, 0],
    edge_index=[2, 958],
    timestamp=[958]
  },
  [1m(Character, MENTIONS, Character)[0m={
    edge_attr=[1120, 0],
    edge_index=[2, 1120],
    timestamp=[1120]
  }
)

In [2]:
from ml.data import EdgeLoaderDataModule

# Edge-level data module over `data`; node embeddings are produced for 'Character' nodes.
data_module = EdgeLoaderDataModule(
    data,
    batch_size=16,
    num_samples=[4, 4],  # presumably neighbours sampled per hop (2 hops) — confirm in EdgeLoaderDataModule
    num_workers=4,
    node_type='Character',
)

In [3]:
from typing import Optional, Dict, Tuple, Any

import torch
import torch.nn.functional as F
from torch_geometric.typing import Metadata
from torch_geometric.nn import HGTConv, Linear
import torchmetrics
import pytorch_lightning as pl

from ml.data import BaseModule

In [4]:
class HGTModule(torch.nn.Module):
    """Heterogeneous Graph Transformer encoder returning embeddings of one node type.

    Every node type listed in ``metadata[0]`` is first projected to
    ``hidden_channels`` by a lazily initialised ``Linear(-1, hidden_channels)``
    followed by ReLU. The result then goes through ``num_layers`` stacked
    ``HGTConv`` layers (``group='sum'``).

    Args:
        node_type: node type whose embeddings ``forward`` returns.
        metadata: graph metadata; ``metadata[0]`` is iterated as the node types
            and the whole object is passed to every ``HGTConv``.
        hidden_channels: width of the input projections and of every HGT layer.
        num_heads: number of attention heads per HGT layer.
        num_layers: number of stacked HGT layers.
    """

    def __init__(
            self,
            node_type,
            metadata: Metadata,
            hidden_channels=64,
            num_heads=2,
            num_layers=1
    ):
        super().__init__()
        self.node_type = node_type
        # Distinct loop name so the `node_type` argument is not shadowed.
        self.lin_dict = torch.nn.ModuleDict({
            ntype: Linear(-1, hidden_channels) for ntype in metadata[0]
        })
        self.convs = torch.nn.ModuleList(
            HGTConv(hidden_channels, hidden_channels, metadata, num_heads, group='sum')
            for _ in range(num_layers)
        )

    def forward(self, x_dict, edge_index_dict):
        """Encode all node types and return the features of ``self.node_type``.

        Args:
            x_dict: mapping node type -> feature tensor. It is NOT modified.
            edge_index_dict: mapping edge type -> edge index, passed to every conv.

        Returns:
            The output features of ``self.node_type`` after the last HGT layer.
        """
        # Build a fresh dict instead of overwriting entries in the caller's `x_dict`.
        h_dict = {ntype: self.lin_dict[ntype](x).relu_() for ntype, x in x_dict.items()}

        for conv in self.convs:
            h_dict = conv(h_dict, edge_index_dict)

        return h_dict[self.node_type]


class ClusterModule(torch.nn.Module):
    """Learnable cluster centres with Student's t soft assignments.

    For an embedding ``z_i`` and centre ``mu_j`` the soft assignment is
    ``q_ij ∝ (1 + ||z_i - mu_j||^2 / alpha) ** (-(alpha + 1) / 2)``,
    normalised over the clusters.

    Args:
        n_clusters: number of cluster centres.
        embedding_dim: dimensionality of the embeddings and centres.
        alpha: degrees-of-freedom parameter of the t kernel.
        cluster_centers: optional initial centres of shape
            ``(n_clusters, embedding_dim)``. If omitted, the centres are
            Xavier-uniform initialised.
    """

    def __init__(
            self,
            n_clusters: int = 5,
            embedding_dim: int = 64,
            alpha: float = 1.0,
            cluster_centers: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()
        self.n_clusters = n_clusters
        self.embedding_dim = embedding_dim
        self.alpha = alpha

        if cluster_centers is None:
            initial_cluster_centers = torch.zeros(
                self.n_clusters, self.embedding_dim, dtype=torch.float
            )
            torch.nn.init.xavier_uniform_(initial_cluster_centers)
        else:
            assert cluster_centers.shape == (self.n_clusters, self.embedding_dim), (
                f"cluster_centers must have shape {(self.n_clusters, self.embedding_dim)}, "
                f"got {tuple(cluster_centers.shape)}"
            )
            initial_cluster_centers = cluster_centers
        self.cluster_centers = torch.nn.Parameter(initial_cluster_centers)

    def _squared_distances(self, batch: torch.Tensor) -> torch.Tensor:
        """Squared Euclidean distance of each row of ``batch`` (N, D) to each centre -> (N, n_clusters)."""
        return torch.sum(torch.square(batch.unsqueeze(1) - self.cluster_centers), dim=2)

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Soft assignments: a (N, n_clusters) tensor whose rows sum to 1."""
        norm_squared = self._squared_distances(batch)
        numerator = 1.0 / (1.0 + (norm_squared / self.alpha))
        power = float(self.alpha + 1) / 2
        numerator = torch.pow(numerator, power)
        return numerator / torch.sum(numerator, dim=1, keepdim=True)

    def assignments(self, batch: torch.Tensor) -> torch.Tensor:
        """Hard assignments: index of the nearest centre for every row of ``batch``."""
        return torch.argmin(self._squared_distances(batch), dim=1)

    def cluster_cohesion_loss(self, batch: torch.Tensor) -> torch.Tensor:
        """Per-row Euclidean distance to the nearest (hard-assigned) centre, shape (N,)."""
        nearest_centers = self.cluster_centers[self.assignments(batch)]
        return torch.pairwise_distance(batch, nearest_centers, p=2)

    def davies_bouldin_loss(self, batch: torch.Tensor) -> torch.Tensor:
        """Davies-Bouldin based loss (not implemented yet)."""
        # Previously `pass`, which silently returned None to any caller.
        raise NotImplementedError("davies_bouldin_loss is not implemented yet")


In [19]:
from pytorch_lightning.loggers import WandbLogger
import ipyparams
import datetime as dt

class Net(BaseModule):
    """HGT node-embedding model trained with a homophily (pair-similarity) loss.

    ``node_embedding`` is an ``HGTModule`` producing embeddings for
    ``node_type``. ``cluster_embedding`` is a ``ClusterModule`` that is kept
    frozen during the small step. Its cohesion loss is computed for monitoring
    only and is not part of the optimised loss.

    Args:
        node_type: node type whose embeddings are learned and returned.
        metadata: graph metadata passed to ``HGTModule``.
        n_clusters: number of centres in the cluster module.
        embedding_dim: embedding width (HGT hidden size and cluster centre size).
        num_heads: attention heads per HGT layer.
        num_layers: number of HGT layers.
    """

    def __init__(
            self,
            node_type,
            metadata: Metadata,
            n_clusters: int = 5,
            embedding_dim=64,
            num_heads=2,
            num_layers=1
    ):
        super().__init__()
        self.node_type = node_type
        self.node_embedding = HGTModule(node_type, metadata, embedding_dim, num_heads, num_layers)
        self.cluster_embedding = ClusterModule(n_clusters, embedding_dim)

        self.cos_sim = torch.nn.CosineSimilarity(dim=1)

    def configure_metrics(self) -> Dict[str, Tuple[torchmetrics.Metric, bool]]:
        return {
            'loss': (torchmetrics.MeanMetric(), True),
            'hp_loss': (torchmetrics.MeanMetric(), True),
            'cc_loss': (torchmetrics.MeanMetric(), True),
        }

    def _forward_step_small(self, batch, mode='train'):
        """Pair-similarity step on a ``(batch_l, batch_r, label)`` batch.

        Returns a dict with the optimised ``loss`` and detached ``hp_loss`` /
        ``cc_loss`` values, one per metric declared in ``configure_metrics``.
        ``mode`` is accepted for symmetry with the callers but is not used here.
        """
        # Cluster centres are not trained during the small step.
        self.cluster_embedding.requires_grad_(False)

        batch_l, batch_r, label = batch
        batch_size = batch_l[self.node_type].batch_size

        # Keep only the first `batch_size` rows (presumably the seed nodes of the sampled subgraph).
        emb_l = self.node_embedding(batch_l.x_dict, batch_l.edge_index_dict)[:batch_size]
        emb_r = self.node_embedding(batch_r.x_dict, batch_r.edge_index_dict)[:batch_size]

        # Homophily loss: squared error between each pair's cosine similarity and its label.
        # TODO: do we need to cap the individual losses?
        sim = self.cos_sim(emb_l, emb_r)
        hp_loss = torch.mean(torch.square(sim - label))

        # Cluster cohesion is not optimised yet (the loss below is hp_loss only), but it is
        # declared as a metric. Compute it without gradients so the metric is actually fed.
        with torch.no_grad():
            cc_loss = torch.mean(torch.cat([
                self.cluster_embedding.cluster_cohesion_loss(emb_l),
                self.cluster_embedding.cluster_cohesion_loss(emb_r),
            ], dim=0))

        loss = hp_loss

        return {
            'loss': loss,
            'hp_loss': hp_loss.detach(),
            'cc_loss': cc_loss,
        }

    def _forward_step_large(self, batch):
        """Cluster-level optimisation step (not implemented yet)."""
        # Previously `pass`, which silently returned None.
        raise NotImplementedError("_forward_step_large is not implemented yet")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, batch, batch_idx=None):
        # `batch_idx` is optional so the signature matches `validation_step`.
        return self._forward_step_small(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        return self._forward_step_small(batch, mode='val')

    def forward(self, batch):
        """Return embeddings for the first ``batch_size`` nodes of ``self.node_type`` in ``batch``."""
        batch_size = batch[self.node_type].batch_size
        emb = self.node_embedding(batch.x_dict, batch.edge_index_dict)[:batch_size]
        return emb


# Hyper-parameters for this run, gathered in one place so they are easy to tune.
SEED = 42
N_CLUSTERS = 5
EMBEDDING_DIM = 64
NUM_HEADS = 2
NUM_LAYERS = 1
MAX_EPOCHS = 50
EARLY_STOPPING_PATIENCE = 5

# Seed python/numpy/torch (and dataloader workers) before any parameters are initialised,
# so the training run is reproducible.
pl.seed_everything(SEED, workers=True)

model = Net(
    node_type='Character',
    metadata=data.metadata(),
    n_clusters=N_CLUSTERS,
    embedding_dim=EMBEDDING_DIM,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
)

wandb_logger = WandbLogger(
    project='Thesis-Experiments',
    name=f'{ipyparams.notebook_name}-{dt.datetime.now().strftime("%Y%m%d-%H%M%S")}'
)
trainer = pl.Trainer(
    gpus=1,
    callbacks=[
        pl.callbacks.EarlyStopping(
            monitor="val/loss", min_delta=0.00, patience=EARLY_STOPPING_PATIENCE, verbose=True, mode="min"
        )
    ],
    max_epochs=MAX_EPOCHS,
    logger=wandb_logger
)
trainer.fit(model, data_module)

<IPython.core.display.Javascript object>

Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

[34m[1mwandb[0m: Currently logged in as: [33megordm[0m (use `wandb login --relogin` to force relogin)





Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  rank_zero_deprecation(



In [20]:
# Run the trained model over the data module's predict loader and L2-normalise the
# embeddings. For unit vectors, squared Euclidean distance = 2 - 2 * cosine similarity,
# so the Euclidean k-means below effectively clusters by cosine similarity.
predictions = trainer.predict(model, data_module)
embeddings = torch.nn.functional.normalize(torch.cat(predictions, dim=0), p=2, dim=1).cpu().numpy()
# The cluster labels are later indexed by the dataset's node mapping, which requires
# exactly one embedding per node.
assert embeddings.shape[0] == data[model.node_type].num_nodes, embeddings.shape

  rank_zero_deprecation(



Predicting: 182it [00:00, ?it/s]

# Extract and cluster the embeddings

In [27]:
from shared.constants import BENCHMARKS_RESULTS

# Output directory for this analysis run (created on demand, no error if it already exists).
save_path = BENCHMARKS_RESULTS.joinpath('analysis').joinpath('pyg-hgt-comopt')
save_path.mkdir(exist_ok=True, parents=True)

In [28]:
import faiss
import pandas as pd

In [29]:
k = 3  # number of k-means clusters
kmeans = faiss.Kmeans(
    embeddings.shape[1],  # embedding dimensionality
    k,
    nredo=5,   # keep the best of 5 restarts
    niter=20,
    verbose=True,
)
# The returned objective is the cell's displayed output.
kmeans.train(embeddings)

Clustering 113 points in 64D to 3 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.01 s, search 0.01 s): objective=58.143 imbalance=1.047 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.01 s, search 0.01 s): objective=58.1223 imbalance=1.063 nsplit=0       
Objective improved: keep new clusters
Outer iteration 2 / 5
  Iteration 19 (0.02 s, search 0.01 s): objective=57.1643 imbalance=1.087 nsplit=0       
Objective improved: keep new clusters
Outer iteration 3 / 5
  Iteration 19 (0.02 s, search 0.01 s): objective=58.0796 imbalance=1.100 nsplit=0       
Outer iteration 4 / 5
  Iteration 19 (0.02 s, search 0.01 s): objective=59.5923 imbalance=1.142 nsplit=0       



57.16433334350586

In [30]:
# Nearest centroid per embedding: D holds the distances, I the centroid ids (both shape (n, 1)).
D, I = kmeans.index.search(embeddings, k=1)




In [31]:
from shared.graph import CommunityAssignment

# One cluster id per node, indexed by the dataset's node mapping.
# `I[:, 0]` stays 1-D even for a single node (unlike `squeeze()`), and `rename_axis`
# returns a copy instead of renaming an Index that may be shared with `dataset`.
labeling = (
    pd.Series(I[:, 0], index=dataset.node_mapping(), name="cid")
    .rename_axis("nid")
)
comlist = CommunityAssignment(labeling)

In [32]:
# Write the community assignment to `schema.comlist` inside `save_path`.
comlist.save_comlist(save_path.joinpath('schema.comlist'))

In [None]:
from datasets.scripts import export_to_visualization

# Export the saved run directory for visualisation against the 'star-wars' / 'base' dataset.
export_args = export_to_visualization.Args(
    dataset='star-wars',
    version='base',
    run_paths=[str(save_path)],
)
export_to_visualization.run(export_args)

[2022-02-09 00:40:06,597][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetSchema'>, drop extra fields: True
[2022-02-09 00:40:06,597][simple_parsing.helpers.serialization.decoding][DEBUG] name = name, field_type = <class 'str'>
[2022-02-09 00:40:06,598][simple_parsing.helpers.serialization.decoding][DEBUG] name = database, field_type = <class 'str'>
[2022-02-09 00:40:06,599][simple_parsing.helpers.serialization.decoding][DEBUG] name = description, field_type = <class 'str'>
[2022-02-09 00:40:06,599][simple_parsing.helpers.serialization.decoding][DEBUG] name = versions, field_type = typing.Dict[str, shared.schema.dataset.DatasetVersion]
[2022-02-09 00:40:06,600][simple_parsing.helpers.serialization.serializable][DEBUG] from_dict for <class 'shared.schema.dataset.DatasetVersion'>, drop extra fields: True
[2022-02-09 00:40:06,601][simple_parsing.helpers.serialization.decoding][DEBUG] name = type, field_type = <enum 'DatasetVersi