In [1]:
import torch
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import torch
import numpy as np
from tqdm.notebook import tqdm  # Use notebook-friendly tqdm
import matplotlib.pyplot as plt
from argparse import Namespace
from torch_geometric.explain import Explainer, GNNExplainer
import sys
sys.path.append('..')

# from src.explainability.gnn_explainer import GNNExplainer
from main_transductive import pretrain
from src.utils import set_random_seed, create_optimizer, WBLogger
from src.datasets.data_util import load_dataset, load_processed_graph
from src.models import build_model, PreModel
from src.evaluation import node_classification_evaluation
from src.utils import build_args, load_best_configs  # if needed

In [3]:
# ----------------------
# Runtime configuration
# ----------------------
# Run on the GPU when torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# A fixed seed keeps stochastic steps (init, masking, dropout) repeatable.
seed = 0
set_random_seed(seed)

In [4]:
# ----------------------
# Hyperparameters (grouped by concern)
# ----------------------

# -- Dataset / graph --
dataset_name = "CPDB_cdgps"
num_edge_types = 6

# -- Encoder / decoder architecture --
encoder_type = "rgcn"
decoder_type = "rgcn"
num_hidden = 64
num_layers = 3
num_heads = 4
activation = "prelu"
weight_decomposition = {'type': 'basis', 'num_bases': 2}
vertical_stacking = True

# -- Dropout and masking --
in_drop = 0.2
attn_drop = 0.1
mask_rate = 0.5
replace_rate = 0.05
drop_edge_rate = 0.0

# -- Optimisation --
max_epoch = 100          # total training epochs
optimizer = "adam"
loss_fn = "sce"
alpha_l = 3
lr = 0.01
weight_decay = 1e-3
use_scheduler = True     # set True to use a learning-rate scheduler

# -- "_f"-suffixed settings (presumably the downstream evaluation stage; confirm) --
max_epoch_f = 200
lr_f = 0.005
weight_decay_f = 1e-4
linear_prob = False

# -- Checkpointing and logging --
load_model = False       # set True to load a checkpoint
save_model = True        # set True to save the trained model
logs = True              # set True to use WBLogger

In [5]:
# ----------------------
# Bundle the configuration into an argparse-style Namespace
# ----------------------
# `args` is what `build_model` receives further down. The mapping is written
# as a dict literal so it reads like a table of option -> value.
args = Namespace(**{
    "device":               device,
    "seeds":                [seed],
    "dataset":              dataset_name,
    "max_epoch":            max_epoch,
    "max_epoch_f":          max_epoch_f,
    "num_hidden":           num_hidden,
    "num_layers":           num_layers,
    "encoder":              encoder_type,
    "decoder":              decoder_type,
    "activation":           activation,
    "in_drop":              in_drop,
    "attn_drop":            attn_drop,
    "mask_rate":            mask_rate,
    "drop_edge_rate":       drop_edge_rate,
    "alpha_l":              alpha_l,
    "num_heads":            num_heads,
    "weight_decomposition": weight_decomposition,
    "vertical_stacking":    vertical_stacking,
    "replace_rate":         replace_rate,
    "num_edge_types":       num_edge_types,
    "optimizer":            optimizer,
    "loss_fn":              loss_fn,
    "lr":                   lr,
    "weight_decay":         weight_decay,
    "lr_f":                 lr_f,
    "weight_decay_f":       weight_decay_f,
    "linear_prob":          linear_prob,
    "load_model":           load_model,
    "save_model":           save_model,
    "logging":              logs,
    "scheduler":            use_scheduler,
    # Placeholder; the data-loading cell overwrites it with graph.x.shape[1].
    "num_features":         6,
    "num_out_heads":        1,
    "residual":             False,
    "norm":                 None,
    "negative_slope":       0.2,
    "concat_hidden":        False,
})

In [6]:
# ----------------------
# Load the processed graph, then build the model from `args`
# ----------------------
graph_path = f'../data/real/multidim_graph/6d/{dataset_name}_multiomics.pt'
graph = load_processed_graph(graph_path)

# Size the model from the data actually loaded, replacing the placeholder
# feature count that was put into `args` earlier.
num_features = graph.x.shape[1]
num_classes = graph.y.max().item() + 1
args.num_features = num_features

model = build_model(args)
model.to(device)

PreModel(
  (encoder): RGCN(
    (rgcn_layers): ModuleList(
      (0-2): 3 x RGCNConv(64, 64)
    )
    (activation): PReLU(num_parameters=1)
    (head): Identity()
  )
  (decoder): RGCN(
    (rgcn_layers): ModuleList(
      (0): RGCNConv(64, 64)
    )
    (activation): PReLU(num_parameters=1)
    (head): Identity()
  )
  (encoder_to_decoder): Linear(in_features=64, out_features=64, bias=False)
)

In [7]:
# Restore the pretrained weights.
# - map_location=device remaps stored tensors onto the device chosen in the
#   config cell, so a checkpoint written on a GPU still loads on a CPU-only box.
# - weights_only=True limits unpickling to tensors/plain containers, which is
#   all a state_dict needs, and avoids executing arbitrary pickled code.
# NOTE(review): this cell loads unconditionally; the `load_model` config flag
# is not consulted here — confirm whether that is intended.
model.load_state_dict(
    torch.load(
        '../checkpoints/emb_extraction_model.pt',
        map_location=device,
        weights_only=True,
    )
)

<All keys matched successfully>

In [14]:

# ----------------------
# Explain one node's prediction with GNNExplainer
# ----------------------
model.eval()  # inference mode: disables dropout in the model's layers

# Single knob for which node to explain. `Explainer` also accepts a 1-D
# LongTensor of node indices here to explain several nodes at once.
target_node = 10

# NOTE(review): PyG's Explainer invokes `model(x, edge_index, **kwargs)` and,
# with the model_config below, expects per-node probabilities back. This
# model's forward instead reads `.num_nodes` from its first argument. It fails
# with "'Tensor' object has no attribute 'num_nodes'" when handed `x`, so it
# wants a graph object there. It presumably also needs the graph's edge types
# for its RGCN layers.
# TODO: pass an nn.Module adapter with a forward(x, edge_index, ...) signature
# that returns per-node probabilities (e.g. encoder + a trained classifier
# head) as `model=` instead of the raw pretrained model.
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=100),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='binary_classification',
        task_level='node',
        return_type='probs',
    ),
)

# The model was moved to `device` when it was built; the inputs must match.
explanation = explainer(
    x=graph.x.to(device),
    edge_index=graph.edge_index.to(device),
    index=target_node,
)

AttributeError: 'Tensor' object has no attribute 'num_nodes'