In [1]:
# Cell 0
import os, sys, random
import numpy as np
import torch

# The notebook is expected to be launched from the repository root: the current
# working directory is taken as the repo root for the `src.*` imports below.
REPO_ROOT = os.getcwd()   # or simply "."
# Keep the repo root as the first sys.path entry, but only insert it when it is
# not already there, so re-running this cell does not stack duplicate entries.
if sys.path[:1] != [REPO_ROOT]:
    sys.path.insert(0, REPO_ROOT)

print("Repo root:", REPO_ROOT)
print("CUDA:", torch.cuda.is_available())

Repo root: /Users/kazishahrukhomar/Documents/MISC/COMBINEX
CUDA: False


In [2]:
# Set WANDB_MODE=disabled in this process's environment before the project's
# training code is imported in the next cell.
# NOTE(review): presumably this stops the Trainer from logging to Weights &
# Biases — confirm that the training code actually goes through wandb.
os.environ["WANDB_MODE"] = "disabled"

In [3]:
# Cell 1
from omegaconf import OmegaConf
from hydra import initialize, compose

from torch.nn import functional as F

from src.utils.dataset import get_dataset
from src.datasets.dataset import DataInfo
from src.utils.models import get_model

# If you want to reuse your training pipeline:
from src.oracles.train.train import Trainer

In [4]:
# Cell 2
# Path of the root config file. It is not needed to build `cfg` (Hydra resolves
# the config through `config_path`/`config_name` below); kept for reference.
CFG_PATH = os.path.join(REPO_ROOT, "config", "config.yaml")  # adjust if different

# Build the config with Hydra's compose API so the config groups selected in
# `overrides` are applied. `initialize` is used as a context manager, so the
# Hydra global state only lives for the duration of the `with` block.
# (The earlier plain OmegaConf.load of CFG_PATH was removed: its result was
# overwritten by compose() and never used.)
with initialize(config_path="config", version_base="1.3"):
    cfg = compose(
        config_name="config",
        overrides=[
            "task=node",
            "dataset=citeseer",   # or cora / pubmed / etc
            "device=cpu",        # or cuda
        ],
    )


print(cfg.task.name)      # "Node"
print(cfg.dataset.name)   # "citeseer"

Node
citeseer


In [5]:
# Cell 3

def seed_everything(seed: int):
    """Seed every random number generator this notebook relies on.

    Applies ``seed`` to Python's ``random`` module, NumPy's legacy global
    generator, and PyTorch (default generator plus all CUDA devices), in
    that order.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed all RNGs from the config before any data is loaded or split.
seed_everything(int(cfg.general.seed))

# Use CUDA only when it is both available and requested in the config;
# otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() and cfg.device == "cuda" else "cpu"
print("Device:", device)

# Load the dataset named in the config and move it to the selected device.
# NOTE(review): `test_size` is presumably the test-split fraction — confirm
# against src.utils.dataset.get_dataset.
data = get_dataset(cfg.dataset.name, test_size=cfg.test_size)
data = data.to(device)

# DataInfo provides the dataset-level sizes used to build the oracle below.
datainfo = DataInfo(cfg, data)  # note: your DataInfo deletes self.data internally
print("num_features:", datainfo.num_features)
print("num_classes:", datainfo.num_classes)

Device: cpu
num_features: 3703
num_classes: 6


In [6]:
# Look up the model class for the configured model name and task, then
# instantiate it with the dataset's feature/class counts on the chosen device.
OracleClass = get_model(name=cfg.model.name, task=cfg.task.name)
oracle = OracleClass(
    num_features=datainfo.num_features,
    num_classes=datainfo.num_classes,
    cfg=cfg,
).to(device)

# Train the oracle with the project's Trainer, using cross-entropy as the loss.
# This is the expensive cell of the notebook: it prints one log line per epoch.
trainer = Trainer(cfg=cfg, dataset=data, model=oracle, loss=F.cross_entropy)
trainer.start_training()

# Read the model back from the trainer and switch it to evaluation mode
# for the inference and explanation cells that follow.
oracle = trainer.model
oracle.eval()

print("Oracle ready.")

name='CHEB', task='Node'
Epoch:    0 Train Loss: 1.7733 Train Acc: 0.2135 Test Loss: 1.7721 Test Acc: 0.2132
Epoch:    1 Train Loss: 1.7711 Train Acc: 0.2041 Test Loss: 1.7660 Test Acc: 0.2012
Epoch:    2 Train Loss: 1.7640 Train Acc: 0.2218 Test Loss: 1.7647 Test Acc: 0.2132
Epoch:    3 Train Loss: 1.7569 Train Acc: 0.2440 Test Loss: 1.7589 Test Acc: 0.2177
Epoch:    4 Train Loss: 1.7528 Train Acc: 0.2492 Test Loss: 1.7614 Test Acc: 0.2207
Epoch:    5 Train Loss: 1.7488 Train Acc: 0.2350 Test Loss: 1.7498 Test Acc: 0.2342
Epoch:    6 Train Loss: 1.7381 Train Acc: 0.2586 Test Loss: 1.7434 Test Acc: 0.2523
Epoch:    7 Train Loss: 1.7350 Train Acc: 0.2564 Test Loss: 1.7426 Test Acc: 0.2402
Epoch:    8 Train Loss: 1.7279 Train Acc: 0.2590 Test Loss: 1.7256 Test Acc: 0.2718
Epoch:    9 Train Loss: 1.7186 Train Acc: 0.2714 Test Loss: 1.7303 Test Acc: 0.2568
Epoch:   10 Train Loss: 1.7107 Train Acc: 0.2647 Test Loss: 1.7237 Test Acc: 0.2462
Epoch:   11 Train Loss: 1.7039 Train Acc: 0.2744 Te

In [7]:
@torch.no_grad()
def predict_node(oracle, data, node_idx: int, edge_weights=None):
    """Return the oracle's predicted class and class probabilities for one node.

    The oracle is run on the whole graph (``data.x``, ``data.edge_index``) with
    gradients disabled; only the row of logits for ``node_idx`` is kept.

    Args:
        oracle: Callable model invoked as
            ``oracle(x, edge_index, edge_weights=...)``.
        data: Object exposing ``x`` and ``edge_index``.
        node_idx: Row of the logits to read.
        edge_weights: Forwarded unchanged to the oracle (``None`` by default).

    Returns:
        A ``(predicted_class, probabilities)`` tuple: the arg-max class as a
        Python ``int`` and the softmax probabilities as a CPU tensor.
    """
    all_logits = oracle(data.x, data.edge_index, edge_weights=edge_weights)
    node_probs = all_logits[node_idx].softmax(dim=-1)
    predicted_class = int(torch.argmax(node_probs).item())
    return predicted_class, node_probs.detach().cpu()

In [8]:
# Sanity check: run the trained oracle on a single node (index 0 of the full
# graph) and show its predicted class and class probabilities.
node_idx = 0
y_orig, p_orig = predict_node(oracle, data, node_idx)

print("node:", node_idx)
print("pred:", y_orig)
print("probs:", p_orig.numpy())

node: 0
pred: 3
probs: [5.5762658e-11 3.1528758e-16 4.7524772e-08 1.0000000e+00 1.4729031e-08
 2.9926211e-11]


In [9]:
# Cell 6
# Counterfactual target for the node inspected above: the class after the
# predicted one, wrapping around at num_classes.
# NOTE(review): `target` is not read by any later cell visible here; the
# explainer cell recomputes per-node targets with the same rule.
target = (y_orig + 1) % datainfo.num_classes
print("Target class:", target)

Target class: 4


In [36]:
# Node to explain. This replaces the node_idx = 0 used for the sanity check, so
# y_orig / p_orig / target computed earlier still describe node 0, not this node.
node_idx = 30  # choose
# Mutates `data` in place: test_mask becomes a plain Python list holding only
# the chosen node. NOTE(review): presumably downstream code iterates test_mask
# as a list of node ids — confirm; re-run the dataset cell to restore the mask.
data.test_mask = [int(node_idx)]    # override to a list with 1 element

In [37]:
from src.node_level_explainer.utils.utils import build_factual_graph, check_graphs
from src.utils.explainer import get_node_explainer
import torch

# Re-derive the device (same rule as the data-loading cell) and make sure the
# oracle is in eval mode and on the same device as the data.
device = "cuda" if torch.cuda.is_available() and cfg.device == "cuda" else "cpu"
oracle = oracle.to(device).eval()
data = data.to(device)

# Predict a label for every node of the full graph, without gradients, and
# derive each node's counterfactual target as the next class (mod num_classes).
with torch.no_grad():
    out = oracle(data.x, data.edge_index)
    predicted_labels = torch.argmax(out, dim=1)
    target_labels = (1 + predicted_labels) % datainfo.num_classes

# Neighbourhood radius used for the subgraph: one hop per configured hidden
# layer, plus one. NOTE(review): presumably this matches the oracle's
# receptive field — confirm against the model definition.
n_hops = len(cfg.model.hidden_layers) + 1
node_idx = int(node_idx)

# Build the factual graph around node_idx.
# NOTE(review): device="cpu" is hard-coded here even though `data` and
# `oracle` were just moved to `device`; this only agrees when device == "cpu".
factual_graph = build_factual_graph(
    mask_index=node_idx,
    data=data,
    n_hops=n_hops,
    oracle=oracle,
    predicted_labels=predicted_labels,
    target_labels=target_labels,
    device="cpu"
)

# check_graphs is expected to return a falsy value for a usable graph.
assert not check_graphs(factual_graph.edge_index), "Invalid factual graph."

# From here on everything runs on CPU. This mutates the shared cfg object, so
# any later cell that reads cfg.device will see "cpu".
cfg.device = "cpu"

oracle = oracle.to("cpu").eval()
factual_graph = factual_graph.to("cpu")

# Instantiate the "combined" node explainer and search for a counterfactual
# of the factual graph against the trained oracle.
ExplainerCls = get_node_explainer("combined")
explainer = ExplainerCls(cfg, datainfo)

counterfactual = explainer.explain(graph=factual_graph, oracle=oracle)
# Bare expression so the notebook displays the resulting object.
counterfactual

Data(x=[449, 3703], edge_index=[2, 2242], y=[449], sub_index=5, x_projection=[64])

In [38]:
import torch
from collections import deque

def hop_distances(edge_index: torch.Tensor, center: int, num_nodes: int):
    """Breadth-first hop distance from ``center`` to every node.

    Every column of ``edge_index`` is treated as an undirected edge.

    Args:
        edge_index: Tensor whose columns are ``(source, target)`` node pairs.
        center: Node the distances are measured from.
        num_nodes: Number of nodes; also the length of the returned list.

    Returns:
        A list of length ``num_nodes`` where entry ``i`` is the hop count from
        ``center`` to node ``i``, or ``-1`` if node ``i`` is unreachable.
    """
    # Symmetric adjacency lists: each edge is registered in both directions.
    neighbours = [[] for _ in range(num_nodes)]
    for src, dst in edge_index.detach().cpu().t().tolist():
        neighbours[src].append(dst)
        neighbours[dst].append(src)

    # -1 marks "not visited yet"; the centre is at distance zero.
    hops = [-1] * num_nodes
    hops[center] = 0

    frontier = deque([center])
    while frontier:
        current = frontier.popleft()
        next_hop = hops[current] + 1
        for nb in neighbours[current]:
            if hops[nb] != -1:
                continue
            hops[nb] = next_hop
            frontier.append(nb)
    return hops  # list length=num_nodes

In [39]:
def print_node_feature_changes(
    factual_graph,
    counterfactual,
    nodes,                      # list of local node ids
    topk=8,
    show_all_if_dim_le=10,
    float_fmt="{:.4f}",
    show_node_header=False,
):
    """Print the features that differ between a factual graph and its counterfactual.

    For every node in ``nodes`` the per-feature difference
    ``counterfactual.x[n] - factual_graph.x[n]`` is computed. If the feature
    dimension is at most ``show_all_if_dim_le`` every feature is inspected in
    index order; otherwise only the ``topk`` features with the largest absolute
    change are inspected, largest first. Features whose change is exactly zero
    are never printed.

    Args:
        factual_graph: Object with a 2-D tensor attribute ``x`` (nodes x features).
        counterfactual: Object with a tensor attribute ``x`` indexed the same way.
        nodes: Iterable of row indices into ``x`` to report on.
        topk: Maximum number of features inspected per node in top-k mode.
        show_all_if_dim_le: Feature-dimension threshold below or at which all
            features are inspected instead of the top-k.
        float_fmt: ``str.format`` template used for each printed value.
        show_node_header: When True, print a header line per node with the
            node id and the L1/L2 norm of its change, so that the lines that
            follow can be attributed to a node. Off by default, which keeps
            the output identical to the previous behaviour.

    Returns:
        None. All results are written to stdout.
    """
    x0 = factual_graph.x.detach().cpu()
    x1 = counterfactual.x.detach().cpu()
    Fdim = x0.size(1)
    show_all = Fdim <= show_all_if_dim_le

    for n in nodes:
        d = x1[n] - x0[n]
        absd = torch.abs(d)

        if show_node_header:
            print(f"\nNode(local)={n}  |  L1_change={absd.sum().item():.4f}  L2_change={torch.norm(d).item():.4f}")

        # Pick which feature indices to inspect; the printing below is shared.
        if show_all:
            feature_ids = range(Fdim)
        else:
            k = min(topk, Fdim)
            feature_ids = torch.topk(absd, k=k).indices.tolist()

        for j in feature_ids:
            dj = float(d[j].item())
            # Skip unchanged features (top-k may include zeros when fewer than
            # k features actually changed).
            if abs(dj) > 0:
                b = float(x0[n, j].item())
                a = float(x1[n, j].item())
                print(f"  f{j:03d}: {float_fmt.format(b)} -> {float_fmt.format(a)}  (Δ={float_fmt.format(dj)})")


In [40]:
# Index of the centre node inside the factual subgraph.
# NOTE(review): presumably `new_idx` is the explained node's local id — confirm
# against build_factual_graph.
center_local = int(factual_graph.new_idx)
num_nodes = factual_graph.num_nodes

dist = hop_distances(factual_graph.edge_index, center_local, num_nodes)

# Group the subgraph's nodes by hop distance from the centre (0, 1 and 2 hops).
hop0, hop1, hop2 = (
    [node for node, hops in enumerate(dist) if hops == k] for k in range(3)
)

print("Center(local):", center_local)
for label, members in (
    ("Hop-0 nodes:", hop0),
    ("Hop-1 nodes:", hop1),
    ("Hop-2 nodes:", hop2),
):
    print(label, members)

# One report section per hop ring: a banner followed by the ten largest
# feature changes of every node in that ring.
hop_sections = (
    ("\n==================== HOP 0 (center) ====================", hop0),
    ("\n==================== HOP 1 neighbors ====================", hop1),
    ("\n==================== HOP 2 neighbors ====================", hop2),
)
for banner, hop_nodes in hop_sections:
    print(banner)
    print_node_feature_changes(factual_graph, counterfactual, hop_nodes, topk=10)

Center(local): 5
Hop-0 nodes: [5]
Hop-1 nodes: [81, 189]
Hop-2 nodes: [1, 3, 4, 9, 11, 16, 24, 30, 37, 42, 46, 48, 67, 72, 74, 78, 79, 83, 86, 87, 92, 96, 111, 113, 114, 115, 120, 128, 129, 130, 131, 132, 136, 140, 141, 144, 148, 154, 160, 167, 171, 172, 174, 180, 188, 194, 205, 208, 212, 216, 220, 224, 225, 230, 241, 252, 262, 263, 266, 267, 269, 274, 275, 280, 286, 298, 299, 301, 304, 305, 320, 322, 331, 333, 336, 342, 357, 363, 370, 371, 372, 374, 375, 379, 382, 390, 391, 392, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405]

  f002: 0.0000 -> 1.0000  (Δ=1.0000)
  f009: 0.0000 -> 1.0000  (Δ=1.0000)
  f013: 0.0000 -> 1.0000  (Δ=1.0000)
  f012: 0.0000 -> 1.0000  (Δ=1.0000)
  f008: 0.0000 -> 1.0000  (Δ=1.0000)
  f004: 0.0000 -> 1.0000  (Δ=1.0000)
  f014: 0.0000 -> 1.0000  (Δ=1.0000)
  f011: 0.0000 -> 1.0000  (Δ=1.0000)
  f000: 0.0000 -> 1.0000  (Δ=1.0000)
  f001: 0.0000 -> 1.0000  (Δ=1.0000)


