In [1]:
import pandas as pd

from shared.schema import DatasetSchema, GraphSchema
from shared.graph.loading import pd_from_entity_schema

In [2]:
# Name of the dataset to load; change it here to run the notebook on another graph.
DATASET_NAME = 'star-wars'

DATASET = DatasetSchema.load_schema(DATASET_NAME)
# Graph view of the dataset: `schema.nodes` / `schema.edges` drive the extraction below.
schema = GraphSchema.from_dataset(DATASET)

In [3]:
# Extraction options forwarded to `pd_from_entity_schema` for both nodes and edges.
explicit_label = False
explicit_timestamp = True
unix_timestamp = True
prefix_id = None


def include_properties(columns):
    """Keep only learnable feature columns (``feat_*``) and the ``name`` column."""
    return [c for c in columns if c.startswith('feat_') or c == 'name']


loader_kwargs = dict(
    explicit_label=explicit_label,
    explicit_timestamp=explicit_timestamp,
    include_properties=include_properties,
    unix_timestamp=unix_timestamp,
    prefix_id=prefix_id,
)

# One frame per node type, indexed (and sorted) by the raw node id.
nodes_dfs = {
    label: pd_from_entity_schema(entity_schema, **loader_kwargs)
        .set_index('id')
        .drop(columns=['type'])
        .sort_index()
    for label, entity_schema in schema.nodes.items()
}
# Raw node id -> contiguous integer index (0..n-1) per node type, named 'nid'.
node_mappings_dfs = {
    label: pd.Series(range(len(df)), index=df.index, name='nid')
    for label, df in nodes_dfs.items()
}


def _remap_endpoint(df: pd.DataFrame, mapping: pd.Series, column: str) -> pd.DataFrame:
    """Replace the raw node ids in ``column`` by the contiguous indices in ``mapping``.

    Raises:
        AssertionError: if some edge references a node id absent from ``mapping``
            (the left join would otherwise leave NaN node indices behind).
    """
    remapped = (
        df.join(mapping, on=column)
        .drop(columns=[column])
        .rename(columns={mapping.name: column})
    )
    missing = remapped[column].isna()
    assert not missing.any(), f'{int(missing.sum())} edges reference unknown {column!r} nodes'
    return remapped


edges_dfs = {
    label: pd_from_entity_schema(entity_schema, **loader_kwargs)
        .reset_index()
        .drop(columns=['type'])
        .drop_duplicates(subset=['src', 'dst', 'timestamp'])
        # Re-number rows 0..n-1: drop_duplicates leaves gaps, which would make the
        # per-type offsets below overlap between edge types.
        .reset_index(drop=True)
        .pipe(_remap_endpoint, node_mappings_dfs[entity_schema.source_type], 'src')
        .pipe(_remap_endpoint, node_mappings_dfs[entity_schema.target_type], 'dst')
    for label, entity_schema in schema.edges.items()
}

# Give every edge a globally unique id by offsetting each type's contiguous index.
cursor = 0
for edf in edges_dfs.values():
    edf.index = edf.index + cursor
    cursor += len(edf)

In [4]:
import torch
import numpy as np
from torch_geometric.data import HeteroData, Data
from torch_geometric.utils import negative_sampling

[2022-02-07 15:19:37,300][git.cmd][DEBUG] Popen(['git', 'version'], cwd=/data/pella/projects/University/Thesis/Thesis/code/experiments/notebooks, universal_newlines=False, shell=None, istream=None)
[2022-02-07 15:19:37,318][git.cmd][DEBUG] Popen(['git', 'version'], cwd=/data/pella/projects/University/Thesis/Thesis/code/experiments/notebooks, universal_newlines=False, shell=None, istream=None)


In [5]:
def _feature_tensor(df: pd.DataFrame) -> torch.Tensor:
    """Stack the ``feat_*`` columns of ``df`` into a float32 tensor of shape (rows, n_feat)."""
    columns = [c for c in df.columns if c.startswith('feat_')]
    return torch.tensor(df[columns].values.astype(np.float32))


def _timestamp_tensor(df: pd.DataFrame) -> torch.Tensor:
    """Return the ``timestamp`` column as an int64 tensor."""
    # int64 rather than int32: int32 silently wraps for unix timestamps in
    # milliseconds or past 2038.
    return torch.tensor(df['timestamp'].values.astype(np.int64))


# Assemble the heterogeneous graph: node features per node type, and
# features / connectivity / timestamps per (src_type, relation, dst_type).
data = HeteroData()
for ntype, ndf in nodes_dfs.items():
    data[ntype].x = _feature_tensor(ndf)
    if 'timestamp' in ndf.columns:
        data[ntype].timestamp = _timestamp_tensor(ndf)

for etype, edf in edges_dfs.items():
    edge_schema = schema.edges[etype]
    edge_type = (edge_schema.source_type, edge_schema.get_type(), edge_schema.target_type)
    data[edge_type].edge_attr = _feature_tensor(edf)
    # (2, num_edges) index built from the contiguous node indices of cell 3.
    data[edge_type].edge_index = torch.tensor(edf[['src', 'dst']].T.values.astype(np.int64))
    if 'timestamp' in edf.columns:
        data[edge_type].timestamp = _timestamp_tensor(edf)

In [6]:
# Rich display of the assembled graph (last expression of the cell).
data

HeteroData(
  Character={ x=[113, 32] },
  (Character, INTERACTIONS, Character)={
    edge_attr=[958, 0],
    edge_index=[2, 958],
    timestamp=[958]
  },
  (Character, MENTIONS, Character)={
    edge_attr=[1120, 0],
    edge_index=[2, 1120],
    timestamp=[1120]
  }
)


In [7]:
import pytorch_lightning as pl

In [8]:
from ml.data import LinkSplitter, EdgeLoader


class DataModule(pl.LightningDataModule):
    """Lightning data module for link classification on a heterogeneous graph.

    ``data`` is split by ``LinkSplitter`` into train / val / test graphs. Each
    loader yields candidate edges as ``(src, dst, label)`` rows: label 1 for an
    observed edge of the requested ``edge_partitions`` value, label 0 for a
    negatively sampled edge.
    """

    def __init__(
            self,
            data: HeteroData,
            node_type: str,
            num_val=0.3,
            num_test=0.0,
            neg_sample_ratio=1.0,
            num_samples=None,
            batch_size=8,
            num_workers=0,
            **kwargs,
    ):
        """
        Args:
            data: full heterogeneous graph to split.
            node_type: node type whose nodes are passed to ``EdgeLoader`` as ``input_nodes``.
            num_val: forwarded to ``LinkSplitter`` (validation share).
            num_test: forwarded to ``LinkSplitter`` (test share).
            neg_sample_ratio: negatives drawn per positive edge.
            num_samples: neighbours sampled per hop; defaults to ``[4, 4]``.
            batch_size: candidate edges per batch.
            num_workers: loader worker processes.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.data = data
        self.neg_sample_ratio = neg_sample_ratio
        self.batch_size = batch_size
        self.num_samples = num_samples or [4] * 2
        self.num_workers = num_workers
        self.node_type = node_type

        transform = LinkSplitter(
            num_val=num_val,
            num_test=num_test,
            edge_types=data.edge_types,
        )

        self.train_data, self.val_data, self.test_data = transform(data)

    @staticmethod
    def _with_label(edges: torch.Tensor, label: int) -> torch.Tensor:
        """Append a constant ``label`` row to a (2, E) edge index, giving (3, E)."""
        labels = torch.full((1, edges.shape[1]), label, dtype=torch.long)
        return torch.cat([edges, labels], dim=0)

    def _get_edges_partition(self, data: HeteroData, partition: int) -> dict:
        """Return the labelled candidate edges of one partition, per edge type.

        Positives are the edges whose ``edge_partitions`` equals ``partition``;
        ``int(num_pos * neg_sample_ratio)`` negatives are drawn with
        ``negative_sampling`` against all edges of that type in ``data``.

        Returns:
            dict mapping edge type -> LongTensor of shape (3, num_pos + num_neg)
            with rows (src, dst, label).
        """
        edge_index = {}
        for edge_type in data.edge_types:
            src_type, _, dst_type = edge_type
            store = data[edge_type]
            pos = store.edge_index[:, store.edge_partitions == partition]

            # Pass the true node counts: otherwise negative_sampling infers them
            # from the largest index present, so trailing/isolated nodes could
            # never be drawn as negatives.
            if src_type == dst_type:
                num_nodes = data[src_type].num_nodes
            else:
                num_nodes = (data[src_type].num_nodes, data[dst_type].num_nodes)
            neg = negative_sampling(
                store.edge_index,
                num_nodes=num_nodes,
                num_neg_samples=int(pos.shape[1] * self.neg_sample_ratio),
            )

            edge_index[edge_type] = torch.cat(
                [self._with_label(pos, 1), self._with_label(neg, 0)], dim=1,
            )
        return edge_index

    def _input_nodes(self):
        """All nodes of ``self.node_type`` as the ``(type, indices)`` pair EdgeLoader takes."""
        # arange stays int64 even for zero nodes (torch.tensor(range(0)) is float32).
        return self.node_type, torch.arange(self.data[self.node_type].num_nodes)

    def _edge_loader(self, graph: HeteroData, edge_index: dict, **loader_kwargs):
        """Build an EdgeLoader sampling neighbourhoods from ``graph`` for the labelled edges."""
        return EdgeLoader(
            graph,
            num_samples=self.num_samples,
            input_nodes=self._input_nodes(),
            # (num_edges, 3) rows of (src, dst, label) across all edge types.
            input_edges=torch.cat(list(edge_index.values()), dim=1).t(),
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            **loader_kwargs,
        )

    def train_dataloader(self):
        edge_index = self._get_edges_partition(self.train_data, partition=0)
        return self._edge_loader(self.train_data, edge_index, shuffle=True)

    def val_dataloader(self):
        edge_index = self._get_edges_partition(self.val_data, partition=1)
        # NOTE(review): neighbourhoods are sampled from train_data, not val_data,
        # presumably to keep validation edges out of message passing; confirm.
        return self._edge_loader(self.train_data, edge_index)

data_module = DataModule(
    data,
    node_type='Character',
    batch_size=16,
    num_samples=[4, 4],
    num_workers=4,
)
# Smoke test: uncomment to inspect one (left graph, right graph, labels) batch.
# next(iter(data_module.train_dataloader()))

(HeteroData(
   Character={
     x=[24, 32],
     batch_size=16
   },
   (Character, INTERACTIONS, Character)={
     edge_attr=[98, 0],
     edge_index=[2, 98],
     timestamp=[98],
     edge_partitions=[98]
   },
   (Character, MENTIONS, Character)={
     edge_attr=[81, 0],
     edge_index=[2, 81],
     timestamp=[81],
     edge_partitions=[81]
   }
 ),
 HeteroData(
   Character={
     x=[24, 32],
     batch_size=16
   },
   (Character, INTERACTIONS, Character)={
     edge_attr=[103, 0],
     edge_index=[2, 103],
     timestamp=[103],
     edge_partitions=[103]
   },
   (Character, MENTIONS, Character)={
     edge_attr=[104, 0],
     edge_index=[2, 104],
     timestamp=[104],
     edge_partitions=[104]
   }
 ),
 tensor([1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))

In [22]:
import torch.nn.functional as F

from torch_geometric.nn import HGTConv, Linear
import torchmetrics

In [23]:
class HGTModel(pl.LightningModule):
    """Edge classifier built on a heterogeneous graph transformer (HGT) encoder.

    Each side of a candidate edge comes with its own sampled subgraph. Both are
    embedded by the same encoder; the embeddings of ``node_type`` nodes of the
    two sides are concatenated and mapped to ``out_channels`` logits.
    """

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers,
                 metadata=None, node_type='Character'):
        """
        Args:
            hidden_channels: width of the per-type projections and HGT layers.
            out_channels: number of output logits (classes).
            num_heads: attention heads per HGT layer.
            num_layers: number of stacked HGT layers.
            metadata: ``(node_types, edge_types)`` of the graph; defaults to the
                notebook-global ``data.metadata()`` for backward compatibility.
            node_type: node type whose embeddings are classified.
        """
        super().__init__()
        if metadata is None:
            metadata = data.metadata()
        node_types, _ = metadata
        self.node_type = node_type

        # One metric object per stage: sharing a single torchmetrics object
        # between training and validation mixes their accumulated state.
        self.train_accuracy = torchmetrics.Accuracy()
        self.val_accuracy = torchmetrics.Accuracy()

        # Per-type input projection; in_channels=-1 lets PyG infer it lazily.
        self.lin_dict = torch.nn.ModuleDict({
            nt: Linear(-1, hidden_channels) for nt in node_types
        })

        self.convs = torch.nn.ModuleList([
            HGTConv(hidden_channels, hidden_channels, metadata, num_heads, group='sum')
            for _ in range(num_layers)
        ])

        # Classifier over the concatenated left/right embeddings.
        self.lin = Linear(hidden_channels * 2, out_channels)

    def forward_embed(self, x_dict, edge_index_dict):
        """Embed one (sub)graph and return the embeddings of ``self.node_type`` nodes."""
        # Build a fresh dict instead of overwriting the caller's entries.
        h_dict = {nt: self.lin_dict[nt](x).relu() for nt, x in x_dict.items()}
        for conv in self.convs:
            h_dict = conv(h_dict, edge_index_dict)
        return h_dict[self.node_type]

    def forward(self, larg, rarg, batch_size=None):
        """Score candidate edges.

        Args:
            larg, rarg: ``(x_dict, edge_index_dict)`` of the left / right subgraph.
            batch_size: if given, keep only the first ``batch_size`` node
                embeddings of each side (presumably the seed nodes) before
                concatenating, so subgraphs of different sizes can be combined.

        Returns:
            logits of shape (rows, out_channels).
        """
        lemb = self.forward_embed(*larg)
        remb = self.forward_embed(*rarg)
        if batch_size is not None:
            lemb, remb = lemb[:batch_size], remb[:batch_size]
        return self.lin(torch.cat([lemb, remb], dim=-1))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def _step(self, batch, batch_idx, metric=None):
        """Shared step: return the cross-entropy loss and update ``metric`` if given."""
        batch_l, batch_r, label = batch
        batch_size = batch_l[self.node_type].batch_size

        out = self(
            (batch_l.x_dict, batch_l.edge_index_dict),
            (batch_r.x_dict, batch_r.edge_index_dict),
            batch_size=batch_size,
        )
        loss = F.cross_entropy(out, label)

        if metric is not None:
            metric(out.argmax(dim=-1), label)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._step(batch, batch_idx, self.train_accuracy)
        self.log('train_loss', loss)
        self.log('train_acc', self.train_accuracy, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._step(batch, batch_idx, self.val_accuracy)
        self.log('val_acc', self.val_accuracy, prog_bar=True)
        self.log('val_loss', loss, prog_bar=True)
        return loss

# Labels are binary (1 = observed edge, 0 = negative sample), so two logits suffice.
model = HGTModel(
    hidden_channels=64, out_channels=2, num_heads=2, num_layers=1,
    metadata=data.metadata(), node_type='Character',
)

In [25]:
trainer = pl.Trainer(
    # Fall back to CPU when no CUDA device is visible so the notebook runs anywhere.
    gpus=1 if torch.cuda.is_available() else 0,
    callbacks=[
        # Stop after 5 consecutive validation checks without val_acc improvement.
        pl.callbacks.EarlyStopping(monitor="val_acc", min_delta=0.00, patience=5, verbose=True, mode="max"),
    ],
)
trainer.fit(model, data_module)

  rank_zero_deprecation(



Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]