In [2]:
import os
import torch
import torch_geometric
from torch_geometric.explain import GNNExplainer, Explainer
from main import Proteo, AttrDict
from proteo.datasets.ftd import ROOT_DIR, FTDDataset
from config_utils import CONFIG_FILE, read_config_from_file
from models.gat_v4 import GATv4
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np



#Load model checkpoint
module = Proteo.load_from_checkpoint("/home/lcornelis/code/proteo/proteo/checkpoints/ckpt14-05-2024epoch=41.ckpt")

root = os.path.join(ROOT_DIR, "data", "ftd")
config = read_config_from_file(CONFIG_FILE)
model_parameters = getattr(config, config.model)
model_parameters = AttrDict(model_parameters)
test_dataset = FTDDataset(root, "test", config)


# test_loader = DataLoader(  # makes into one big graph
#         test_dataset,
#         batch_size=config.batch_size,
#         shuffle=True,
#         num_workers=config.num_workers,
#         pin_memory=config.pin_memory,
#     )

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Explainer
explainer = Explainer(
model=module.model.to(device),
algorithm=GNNExplainer(epochs=5),
explanation_type='model',
model_config=dict(
    mode='regression',
    task_level='graph',  # Explain why the model predicts a certain property or label for the entire graph (nodes + edges)
    return_type='raw'
),
node_mask_type='object', # Generate masks that indicate the importance of individual node features
edge_mask_type=None,
threshold_config=dict(
    threshold_type='topk',
    value=200,
),
)

# Function to visualize node importance
def visualize_node_importance(data, node_importance):
    # Convert to networkx graph
    G = torch_geometric.utils.to_networkx(data, to_undirected=True)

    # Get node positions
    pos = nx.spring_layout(G)

    # Plot nodes with color based on importance
    plt.figure(figsize=(10, 8))
    nx.draw(
        G, pos,
        node_color=node_importance,
        node_size=300,
        cmap=plt.cm.Reds,
        with_labels=True
    )
    plt.title('Node Importance Visualization')
    plt.colorbar(plt.cm.ScalarMappable(cmap=plt.cm.Reds), label='Importance')
    plt.show()

i = 0 
for data in test_dataset:
    if i > 2:
        break
    data_attributes = data.keys
    # Ensure data.x and data.edge_index are tensors
    if not isinstance(data.x, torch.Tensor) or not isinstance(data.edge_index, torch.Tensor):
        raise TypeError("data.x and data.edge_index must be torch.Tensor")
    
    print(f'Batch actual attributes: {data_attributes}')
    explanation = explainer(
        data.x,
        data.edge_index,
        data=data,
        target=None,
        index=None
    )
    print(f'Generated explanations in {explanation.available_explanations}')
    node_importance = explanation.node_mask.cpu().detach().numpy()
    nonzeroind = np.nonzero(node_importance)[0]
    print(nonzeroind)
    i += 1


conv1 is:GATConv(1, 8, heads=4)
conv2 is:GATConv(32, 16, heads=3)


Batch actual attributes: <bound method BaseData.keys of Data(x=[7289, 1], edge_index=[2, 1297081], y=[1])>
Generated explanations in ['node_mask']
[  83   88  175  247  320  321  344  349  439  472  474  484  513  520
  521  528  572  595  642  644  707  774  810  846  931  969 1004 1054
 1177 1183 1234 1255 1319 1324 1344 1434 1483 1533 1656 1665 1689 1698
 1820 1845 1852 1867 1909 1910 1918 2001 2083 2086 2091 2114 2184 2199
 2205 2212 2243 2244 2312 2370 2458 2481 2492 2523 2537 2559 2566 2575
 2577 2595 2672 2676 2678 2724 2727 2757 2763 2774 2833 2852 2863 2867
 2895 2911 2913 3002 3014 3040 3070 3140 3175 3185 3222 3275 3323 3333
 3414 3539 3646 3683 3707 3709 3720 3743 3818 3827 3889 3930 3951 4012
 4076 4079 4104 4209 4246 4261 4312 4324 4375 4392 4434 4448 4479 4500
 4527 4609 4610 4750 4847 4857 4858 4863 4869 4914 4916 4940 4973 4979
 4993 5037 5054 5103 5118 5128 5171 5229 5233 5246 5267 5305 5320 5332
 5413 5418 5542 5629 5631 5724 5733 5756 5798 5818 5849 5853 5953 5970
 