# Intro

The goal of this notebook is to demonstrate how to set up each model used in this study, make predictions and visualize model structure.

# Kedro load

In [1]:
%load_ext kedro.ipython

In [2]:
%reload_kedro

# Load libs

In [49]:
import torch as th
import torch_geometric as tg
import torch_geometric.nn as tgnn
import pytorch_lightning as pl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch_geometric.data as pyg_data

import hexgin.lib.models.common as cmn
import hexgin.pipelines.fincen_experiment.fincen_model_utils as fmu
import hexgin.pipelines.fincen_experiment.fincen_models as fm
import hexgin.lib.models.experimental.hexgin_model as hg
import hexgin.pipelines.fincen_experiment.nodes as experiment_nodes
import hexgin.lib.models.common as cmn
import hexgin.lib.training_commons as train_cmn

from torch_geometric.explain import Explainer, CaptumExplainer
from importlib import reload
from pyvis.network import Network

from hexgin.lib.models import lit_wrappers as litw
from hexgin import consts as cc, utils as util

## Reload libs if needed

In [4]:
cmn = reload(cmn)
fmu = reload(fmu)
fm = reload(fm)
experiment_nodes = reload(experiment_nodes)

# Load data

In [5]:
train_g = catalog.load("train_graph")
val_g = catalog.load("val_graph")
test_g = catalog.load("test_graph")

In [6]:
entity_encoder = catalog.load("entity_encoder")

In [7]:
train_g


[1;35mHeteroData[0m[1m([0m
  [33mentity[0m=[1m{[0m [33mx[0m=[1m[[0m[1;36m2823[0m, [1;36m2[0m[1m][0m [1m}[0m,
  [33mfiling[0m=[1m{[0m [33mx[0m=[1m[[0m[1;36m4507[0m[1m][0m [1m}[0m,
  [1m([0mentity, sends, filing[1m)[0m=[1m{[0m [33medge_index[0m=[1m[[0m[1;36m2[0m, [1;36m4507[0m[1m][0m [1m}[0m,
  [1m([0mfiling, benefits, entity[1m)[0m=[1m{[0m
    [33medge_index[0m=[1m[[0m[1;36m2[0m, [1;36m1894[0m[1m][0m,
    [33medge_label[0m=[1m[[0m[1;36m811[0m[1m][0m,
    [33medge_label_index[0m=[1m[[0m[1;36m2[0m, [1;36m811[0m[1m][0m,
  [1m}[0m,
  [1m([0mfiling, concerns, entity[1m)[0m=[1m{[0m [33medge_index[0m=[1m[[0m[1;36m2[0m, [1;36m24491[0m[1m][0m [1m}[0m
[1m)[0m

In [8]:
filing_encoder = catalog.load("filing_encoder")
country_encoder = catalog.load("country_encoder")
entity_encoder = catalog.load("entity_encoder")

# Building models

Section below will demonstrate how to build each model from configs. Each model will be:

1. Instantiated from config.
2. Visualized.
3. Tested on a small dataset, to check if works properly.

## SAGE

### Build model from config

In [9]:
sage_conf = {
    "sage_dims": [64, 32],
    "hidden_act_f": 'relu',
    "out_act_f": 'relu',
    "linpred_act_f": 'relu',
    "embedding_params": {
        "ent_embed_dim": 64,
        "filing_embed_dim": 64,
        "country_emebd_dim": 12,
    }
}

In [10]:
tg.seed_everything(123)
sage = fmu.build_graph_sage_model(
    train_g,
    dim_ents=entity_encoder.classes_.shape[0],
    dim_countries=country_encoder.classes_.shape[0],
    dim_filings=filing_encoder.classes_.shape[0],
    sage_params=sage_conf
)

### Inspect model

In [11]:
list(sage.preproc_model.parameters())


[1m[[0m
    Parameter containing:
[1;35mtensor[0m[1m([0m[1m[[0m[1m[[0m [1;36m0.3374[0m, [1;36m-0.1778[0m, [1;36m-0.3035[0m,  [33m...[0m, [1;36m-0.7979[0m,  [1;36m0.1838[0m,  [1;36m0.2293[0m[1m][0m,
        [1m[[0m [1;36m0.5146[0m,  [1;36m0.9938[0m, [1;36m-0.2587[0m,  [33m...[0m,  [1;36m1.2774[0m, [1;36m-1.4596[0m, [1;36m-2.1595[0m[1m][0m,
        [1m[[0m[1;36m-0.2582[0m, [1;36m-2.0407[0m, [1;36m-0.8016[0m,  [33m...[0m,  [1;36m0.1132[0m,  [1;36m0.8365[0m,  [1;36m0.0285[0m[1m][0m,
        [33m...[0m,
        [1m[[0m [1;36m0.1044[0m,  [1;36m0.6398[0m, [1;36m-0.7032[0m,  [33m...[0m, [1;36m-0.3725[0m,  [1;36m1.2763[0m,  [1;36m0.6706[0m[1m][0m,
        [1m[[0m [1;36m0.6796[0m, [1;36m-1.1518[0m,  [1;36m0.0343[0m,  [33m...[0m, [1;36m-0.8727[0m, [1;36m-0.5842[0m,  [1;36m0.5277[0m[1m][0m,
        [1m[[0m[1;36m-1.4733[0m, [1;36m-0.8799[0m,  [1;36m1.3983[0m,  [33m...[0m, [1;36m-1.2414[0

In [12]:
with th.no_grad():
    x_preproc = sage.preproc_model(train_g.x_dict)
    print(tgnn.summary(sage.gnn_conv, x_preproc, train_g.edge_index_dict))

+-------------------------------------------+---------------+----------------+----------+
| Layer                                     | Input Shape   | Output Shape   | #Param   |
|-------------------------------------------+---------------+----------------+----------|
| GraphModule                               |               |                | -1       |
| ├─(conv1)ModuleDict                       | --            | --             | -1       |
| │    └─(entity__sends__filing)SAGEConv    | [2, 4507]     | [4507, 64]     | -1       |
| │    └─(filing__benefits__entity)SAGEConv | [2, 1894]     | [2823, 64]     | -1       |
| │    └─(filing__concerns__entity)SAGEConv | [2, 24491]    | [2823, 64]     | -1       |
| ├─(hidden_act)ModuleDict                  | --            | --             | --       |
| │    └─(entity)ReLU                       | [2823, 64]    | [2823, 64]     | --       |
| │    └─(filing)ReLU                       | [4507, 64]    | [4507, 64]     | --       |
| ├─(batch

### Test on a simple output

In [13]:
x_dict = train_g.x_dict.copy()
with th.no_grad():
    out = sage.forward(train_g, train_g.edge_index_dict)
out.shape

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m811[0m, [1;36m1[0m[1m][0m[1m)[0m

## MLP

### Build model from config

In [14]:
mlp_conf = {
    'hdims': [256, 128, 64],
    "embedding_params": {
        "ent_embed_dim": 64,
        "filing_embed_dim": 64,
        "country_emebd_dim": 12,
    }
}

tg.seed_everything(123)

mlp = fmu.build_mlp(
    dim_ents=entity_encoder.classes_.shape[0],
    dim_countries=country_encoder.classes_.shape[0],
    dim_filings=filing_encoder.classes_.shape[0],
    mlp_params=mlp_conf
)

list(mlp.h_layers.parameters())


[1m[[0m
    Parameter containing:
[1;35mtensor[0m[1m([0m[1m[[0m[1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m.,
        [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m.,
        [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m.,
        [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;36m1[0m., [1;

### Inspect model

In [15]:
print(tgnn.summary(mlp, train_g))

+-------------------------------------+---------------+----------------+----------+
| Layer                               | Input Shape   | Output Shape   | #Param   |
|-------------------------------------+---------------+----------------+----------|
| MlpModel                            | [7330, 7330]  | [811, 1]       | 544,137  |
| ├─(embed_model)FinCENPreprocModel   |               |                | 465,648  |
| │    └─(embedding_layers)ModuleDict | --            | --             | 465,648  |
| │    │    └─(entity)Embedding       | [2823]        | [2823, 64]     | 175,552  |
| │    │    └─(filing)Embedding       | [4507]        | [4507, 64]     | 288,512  |
| │    │    └─(country)Embedding      | [2823]        | [2823, 12]     | 1,584    |
| ├─(h_layers)Sequential              | [811, 140]    | [811, 1]       | 78,489   |
| │    └─(0)BatchNorm1d               | [811, 140]    | [811, 140]     | 280      |
| │    └─(1)Linear                    | [811, 140]    | [811, 256]     | 36,

### Test model on simple output

In [16]:
with th.no_grad():
    out2 = mlp(train_g)
out2


[1;35mtensor[0m[1m([0m[1m[[0m[1m[[0m [1;36m4.9739e-01[0m[1m][0m,
        [1m[[0m[1;36m-6.6340e-02[0m[1m][0m,
        [1m[[0m[1;36m-5.6424e-01[0m[1m][0m,
        [1m[[0m[1;36m-1.4783e-01[0m[1m][0m,
        [1m[[0m[1;36m-7.1502e-01[0m[1m][0m,
        [1m[[0m [1;36m5.4690e-01[0m[1m][0m,
        [1m[[0m [1;36m7.0940e-01[0m[1m][0m,
        [1m[[0m [1;36m8.0337e-01[0m[1m][0m,
        [1m[[0m [1;36m3.4420e-01[0m[1m][0m,
        [1m[[0m[1;36m-2.2426e-01[0m[1m][0m,
        [1m[[0m[1;36m-2.6177e-01[0m[1m][0m,
        [1m[[0m[1;36m-5.8978e-01[0m[1m][0m,
        [1m[[0m [1;36m3.4636e-01[0m[1m][0m,
        [1m[[0m [1;36m5.8572e-01[0m[1m][0m,
        [1m[[0m[1;36m-1.1593e+00[0m[1m][0m,
        [1m[[0m [1;36m1.0988e+00[0m[1m][0m,
        [1m[[0m [1;36m5.9922e-01[0m[1m][0m,
        [1m[[0m[1;36m-5.2914e-01[0m[1m][0m,
        [1m[[0m [1;36m3.8228e-02[0m[1m][0m,
        [1m[[0m[1;3

## HexGIN building

### Build model from config

In [17]:
hexgin_config = catalog.load("params:model_params.HexGIN_params")
hexgin_config


[1m{[0m
    [32m'embedding_params'[0m: [1m{[0m[32m'ent_embed_dim'[0m: [1;36m64[0m, [32m'filing_embed_dim'[0m: [1;36m64[0m, [32m'country_emebd_dim'[0m: [1;36m12[0m[1m}[0m,
    [32m'linkpred_act_f'[0m: [32m'leakyReLu'[0m,
    [32m'linkpred_dims'[0m: [1m[[0m[1;36m64[0m, [1;36m32[0m, [1;36m1[0m[1m][0m,
    [32m'conv_params'[0m: [1m[[0m
        [1m{[0m
            [32m'activation'[0m: [32m'leakyReLu'[0m,
            [32m'aggregation'[0m: [32m'multi'[0m,
            [32m'batch_norm'[0m: [3;92mTrue[0m,
            [32m'aggregation_params'[0m: [1m{[0m[32m'aggrs'[0m: [1m[[0m[32m'add'[0m, [32m'mean'[0m[1m][0m, [32m'mode'[0m: [32m'cat'[0m[1m}[0m,
            [32m'relations'[0m: [1m{[0m
                [32m'entity__sends__filing'[0m: [1m[[0m[1m[[0m[1;36m216[0m, [1;36m128[0m[1m][0m, [1m[[0m[1;36m128[0m, [1;36m64[0m[1m][0m[1m][0m,
                [32m'filing__benefits__entity'[0m: [1m[[0m[1m[[0

In [18]:
hexgin = fmu.build_hexgin_net(
    dim_ents=entity_encoder.classes_.shape[0],
    dim_countries=country_encoder.classes_.shape[0],
    dim_filings=filing_encoder.classes_.shape[0],
    hexgin_params=hexgin_config

)

### Inspect model

In [19]:
print(tgnn.summary(hexgin, train_g))

+-------------------------------------+---------------+----------------+----------+
| Layer                               | Input Shape   | Output Shape   | #Param   |
|-------------------------------------+---------------+----------------+----------|
| HeteroGNNLinkPredModel              | [7330, 7330]  | [811, 1]       | 631,347  |
| ├─(preproc_model)FinCENPreprocModel |               |                | 465,648  |
| │    └─(embedding_layers)ModuleDict | --            | --             | 465,648  |
| │    │    └─(entity)Embedding       | [2823]        | [2823, 64]     | 175,552  |
| │    │    └─(filing)Embedding       | [4507]        | [4507, 64]     | 288,512  |
| │    │    └─(country)Embedding      | [2823]        | [2823, 12]     | 1,584    |
| ├─(gnn_conv)HexGINModel             |               |                | 163,586  |
| │    └─(hexgin_layers)ModuleList    | --            | --             | 163,586  |
| │    │    └─(0)HexGINLayer          |               |                | 117

### Test model on a simple output

In [20]:
with th.no_grad():
    hexgin_out = hexgin(val_g)
hexgin_out[:10]


[1;35mtensor[0m[1m([0m[1m[[0m[1m[[0m[1;36m-0.0359[0m[1m][0m,
        [1m[[0m[1;36m-0.1236[0m[1m][0m,
        [1m[[0m [1;36m0.0209[0m[1m][0m,
        [1m[[0m [1;36m0.0290[0m[1m][0m,
        [1m[[0m [1;36m0.0120[0m[1m][0m,
        [1m[[0m [1;36m0.0290[0m[1m][0m,
        [1m[[0m [1;36m0.0341[0m[1m][0m,
        [1m[[0m [1;36m0.0984[0m[1m][0m,
        [1m[[0m [1;36m0.0418[0m[1m][0m,
        [1m[[0m [1;36m0.0167[0m[1m][0m[1m][0m[1m)[0m

# Check trianing functions for deterministic outputs

In [21]:
model_params = catalog.load("params:model_params")
training_params = catalog.load("params:training_params")
crossval_params = catalog.load("params:crossval_params")

In [22]:
hexgin_test_report, hexgin_metric = experiment_nodes.train_hexgin(
    train_g,
    val_g,
    test_g,
    entity_encoder,
    filing_encoder,
    country_encoder,
    model_params,
    training_params
)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

                                                                           

Epoch 11: 100%|██████████| 13/13 [00:00<00:00, 23.38it/s, v_num=65_1]


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 27.29it/s]


In [24]:
print(hexgin_test_report)

         class  precision    recall  f1-score  support
0            0   0.707488  0.897869  0.791389   1126.0
1            1   0.860267  0.628774  0.726526   1126.0
2  avg / total   0.783878  0.763321  0.758958   2252.0


In [25]:
print(hexgin_metric)

         f1  precision    recall    rocauc   model
0  0.726526   0.860267  0.628774  0.815136  HexGIN


# Hexgin model predictions visualizataion

Secion below presents how to visualize predictions of HexGIN model using CAPTUM and SHAP tools.

Note, that, as this is a dedicated fuctionality for heterogeneous graph NNs, part of this work is being now moved to the separate project: **HEXTRACTOR** that
will be made available as a separate publication.

The project can be now accessed at: [HEXTRACTOR repo](https://github.com/maddataanalyst/hextractor).

### Prepare model for explanation visualization

In [27]:
hexgin = fmu.build_hexgin_net(
    dim_ents=entity_encoder.classes_.shape[0],
    dim_countries=country_encoder.classes_.shape[0],
    dim_filings=filing_encoder.classes_.shape[0],
    hexgin_params=hexgin_config

)
hexgin_lit = litw.LinkPredLitWrapper(hexgin, "HexGIN", cc.TARGET_RELATION, learning_rate=0.001,
                                        params=hexgin_config)
early_stop = pl.callbacks.early_stopping.EarlyStopping(monitor='val_f1', patience=2, mode='max')

trainer = pl.Trainer(
        check_val_every_n_epoch=3,
        max_epochs=20,
        enable_checkpointing=True,
        callbacks=[early_stop],
        accelerator='cuda',
        log_every_n_steps=5,
        deterministic=True)
pl.seed_everything(123)

[1;36m123[0m

In [30]:
train_loader = tg.loader.LinkNeighborLoader(
        data=train_g,
        num_neighbors={r: [-1, -1]for r in train_g.edge_types},
        edge_label_index=(cc.TARGET_RELATION, train_g[cc.TARGET_RELATION].edge_label_index),
        edge_label=train_g[cc.TARGET_RELATION].edge_label,
        directed=True,
        neg_sampling_ratio=1.0,
        batch_size=64,
        shuffle=True
    )
next(iter(train_loader))


[1;35mHeteroData[0m[1m([0m
  [33mentity[0m=[1m{[0m
    [33mx[0m=[1m[[0m[1;36m676[0m, [1;36m2[0m[1m][0m,
    [33mn_id[0m=[1m[[0m[1;36m676[0m[1m][0m,
  [1m}[0m,
  [33mfiling[0m=[1m{[0m
    [33mx[0m=[1m[[0m[1;36m2663[0m[1m][0m,
    [33mn_id[0m=[1m[[0m[1;36m2663[0m[1m][0m,
  [1m}[0m,
  [1m([0mentity, sends, filing[1m)[0m=[1m{[0m
    [33medge_index[0m=[1m[[0m[1;36m2[0m, [1;36m2663[0m[1m][0m,
    [33me_id[0m=[1m[[0m[1;36m2663[0m[1m][0m,
  [1m}[0m,
  [1m([0mfiling, benefits, entity[1m)[0m=[1m{[0m
    [33medge_index[0m=[1m[[0m[1;36m2[0m, [1;36m609[0m[1m][0m,
    [33medge_label[0m=[1m[[0m[1;36m128[0m[1m][0m,
    [33medge_label_index[0m=[1m[[0m[1;36m2[0m, [1;36m128[0m[1m][0m,
    [33me_id[0m=[1m[[0m[1;36m609[0m[1m][0m,
    [33minput_id[0m=[1m[[0m[1;36m64[0m[1m][0m,
  [1m}[0m,
  [1m([0mfiling, concerns, entity[1m)[0m=[1m{[0m
    [33medge_index[0m=[1m[[0m[1;36

In [31]:
trainer.fit(hexgin_lit, train_loader)

Epoch 19: 100%|██████████| 13/13 [00:01<00:00,  9.35it/s, v_num=2]

Epoch 19: 100%|██████████| 13/13 [00:01<00:00,  9.26it/s, v_num=2]


### Build wrapper explainer model

In [33]:
class ExplanationsWrapper(th.nn.Module):

    def __init__(self, hexgin):
        super().__init__()
        self.hexgin = hexgin

    def forward(self, x_dict, edge_index_dict, *args, **kwargs):
        print(kwargs.keys())
        data = pyg_data.HeteroData()
        for k, x in x_dict.items():
            data[k].x = x.to(th.long)
        for rel, edge_index in edge_index_dict.items():
            data[rel].edge_index = edge_index.to(th.long)
        if 'edge_label_index_dict' in kwargs:
            for rel, label_index in kwargs['edge_label_index_dict'].items():
                data[rel].edge_label_index = label_index.to(th.long)
        return self.hexgin(data, use_labeled_edges=False)

In [34]:
explanation_example_loader = tg.loader.LinkNeighborLoader(
        data=test_g,
        num_neighbors={r: [10, 10]for r in test_g.edge_types},
        edge_label_index=(cc.TARGET_RELATION, test_g[cc.TARGET_RELATION].edge_label_index),
        edge_label=test_g[cc.TARGET_RELATION].edge_label,
        directed=True,
        neg_sampling_ratio=0.0,
        batch_size=16,
        shuffle=True
    )
next(iter(explanation_example_loader))


[1;35mHeteroData[0m[1m([0m
  [33mentity[0m=[1m{[0m
    [33mx[0m=[1m[[0m[1;36m73[0m, [1;36m2[0m[1m][0m,
    [33mn_id[0m=[1m[[0m[1;36m73[0m[1m][0m,
  [1m}[0m,
  [33mfiling[0m=[1m{[0m
    [33mx[0m=[1m[[0m[1;36m120[0m[1m][0m,
    [33mn_id[0m=[1m[[0m[1;36m120[0m[1m][0m,
  [1m}[0m,
  [1m([0mentity, sends, filing[1m)[0m=[1m{[0m
    [33medge_index[0m=[1m[[0m[1;36m2[0m, [1;36m110[0m[1m][0m,
    [33me_id[0m=[1m[[0m[1;36m110[0m[1m][0m,
  [1m}[0m,
  [1m([0mfiling, benefits, entity[1m)[0m=[1m{[0m
    [33medge_index[0m=[1m[[0m[1;36m2[0m, [1;36m46[0m[1m][0m,
    [33medge_label[0m=[1m[[0m[1;36m16[0m[1m][0m,
    [33medge_label_index[0m=[1m[[0m[1;36m2[0m, [1;36m16[0m[1m][0m,
    [33me_id[0m=[1m[[0m[1;36m46[0m[1m][0m,
    [33minput_id[0m=[1m[[0m[1;36m16[0m[1m][0m,
  [1m}[0m,
  [1m([0mfiling, concerns, entity[1m)[0m=[1m{[0m
    [33medge_index[0m=[1m[[0m[1;36m2[0m, [

In [35]:
it = iter(explanation_example_loader)
next(it)
batch = next(it)


fake_model = ExplanationsWrapper(hexgin_lit.model)
with th.no_grad():
    fake_model(batch.x_dict, batch.edge_index_dict, edge_label_index_dict=batch.edge_label_index_dict)

dict_keys(['edge_label_index_dict'])


### Build explanations

In [36]:
explainer = Explainer(
    fake_model,  # It is assumed that model outputs a single tensor.
    algorithm=CaptumExplainer('IntegratedGradients'),
    explanation_type='model',
    node_mask_type=None,
    edge_mask_type='object',
    model_config = dict(
        mode='binary_classification',
        task_level='edge',
        return_type='probs',  # Model returns probabilities.
    )
)

In [37]:
hetero_explanation = explainer(
    batch.x_dict,
    batch.edge_index_dict,
    edge_label_index_dict=batch.edge_label_index_dict,
)

dict_keys(['edge_label_index_dict'])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])
dict_keys([])


In [41]:
hetero_explanation


[1;35mHeteroExplanation[0m[1m([0m
  [33mprediction[0m=[1m[[0m[1;36m49[0m, [1;36m1[0m[1m][0m,
  [33mtarget[0m=[1m[[0m[1;36m49[0m[1m][0m,
  [33mentity[0m=[1m{[0m [33mx[0m=[1m[[0m[1;36m67[0m[1m][0m [1m}[0m,
  [33mfiling[0m=[1m{[0m [33mx[0m=[1m[[0m[1;36m102[0m[1m][0m [1m}[0m,
  [1m([0mentity, sends, filing[1m)[0m=[1m{[0m
    [33medge_mask[0m=[1m[[0m[1;36m92[0m[1m][0m,
    [33medge_index[0m=[1m[[0m[1;36m2[0m, [1;36m92[0m[1m][0m,
  [1m}[0m,
  [1m([0mfiling, benefits, entity[1m)[0m=[1m{[0m
    [33medge_mask[0m=[1m[[0m[1;36m49[0m[1m][0m,
    [33medge_index[0m=[1m[[0m[1;36m2[0m, [1;36m49[0m[1m][0m,
    [33medge_label_index[0m=[1m[[0m[1;36m2[0m, [1;36m16[0m[1m][0m,
  [1m}[0m,
  [1m([0mfiling, concerns, entity[1m)[0m=[1m{[0m
    [33medge_mask[0m=[1m[[0m[1;36m39[0m[1m][0m,
    [33medge_index[0m=[1m[[0m[1;36m2[0m, [1;36m39[0m[1m][0m,
  [1m}[0m
[1m)[0m

### Build an interactive visualization chart

This part builds an interactive visualization chart that can be used to explore the model predictions and explanations.

Note, that this part is being moved to the [**Hextractor project**](https://github.com/maddataanalyst/hextractor).
Code presented below is a preview, work-in-progress version, that will be updated in the future.

In [52]:
hetero_node_type_counts = {
    hetero_explanation['entity']['x'].shape[0]: 'entity',
    hetero_explanation['filing']['x'].shape[0]: 'filing'
}

hetero_edge_type_counts = {
}
for rel in hetero_explanation.edge_types:
    hetero_edge_type_counts[hetero_explanation[rel]['edge_index'].shape[1]] = rel

expl_homogeneous = hetero_explanation.to_homogeneous()
homogeneous_node_type_counts = pd.value_counts(expl_homogeneous.node_type.numpy()).to_dict()

node_type_mapping = {}
node_type_names = []
for i, cnt in homogeneous_node_type_counts.items():
    node_type_mapping[i] = hetero_node_type_counts[cnt]

for node_type in expl_homogeneous.node_type.numpy():
    node_type_names.append(node_type_mapping[node_type])


edge_type_mapping = {}
edge_type_names =[]
for i, cnt in pd.value_counts(expl_homogeneous.edge_type.numpy()).items():
    edge_type_mapping[i] = hetero_edge_type_counts[cnt][1]

for edge_type in expl_homogeneous.edge_type.numpy():
    edge_type_names.append(edge_type_mapping[edge_type])

node_type_names = np.array(node_type_names)
edge_type_names = np.array(edge_type_names)

def remap_name(type_name, idx):
    if type_name == 'entity':
        return idx2entity[idx]
    else:
        return type_name

nt = Network(notebook=True)

source_2_color = {
    'entity': 'lightgreen',
    'filing': 'lightblue'
}

for idx, (s, t) in enumerate(expl_homogeneous.edge_index.T.numpy()):
    s = int(s)
    t = int(t)
    rel_type = edge_type_names[idx]
    source_name = node_type_names[s]
    source_color = source_2_color[source_name]
    source_name = remap_name(source_name, s)

    target_name = node_type_names[t]
    target_color = source_2_color[target_name]
    target_name = remap_name(target_name, t)
    weight = th.sigmoid(expl_homogeneous.edge_mask[idx]).item()
    nt.add_node(s, label=source_name, color=source_color)
    nt.add_node(t, label=target_name, color=target_color)
    nt.add_edge(s, t, label=rel_type, width=weight*5, color='black')

for idx, (s, t) in enumerate(expl_homogeneous.edge_label_index.T.numpy()):
    s, t = int(s), int(t)
    pred = th.sigmoid(expl_homogeneous.prediction[idx]).item() >= 0.5
    target = expl_homogeneous.target[idx].item()
    link_color = 'red' if pred != target else 'green'
    source_name = node_type_names[s]
    source_color = source_2_color[source_name]
    target_name = node_type_names[t]
    target_color = source_2_color[target_name]

    source_name = remap_name(source_name, s)
    target_name = remap_name(target_name, t)
    weight = 1.
    nt.add_node(int(s), label=source_name, color=source_color)
    nt.add_node(int(t), label=target_name, color=target_color)
    nt.add_edge(s, t, label=rel_type, width=5, color=link_color)
nt.show('predicions_visualization.html', notebook=True)

predicions_visualization.html
