# Introduction

This notebook demonstrates a toy scenario of a network with users, systems and resources, and how structural information can be identified using graph explainers in order to make strategic red-team and blue-team decisions.

The network graph is built programmatically and randomly based on various parameters. The graph is homogeneous; metadata stores the node type (user, system, resource).
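
As a rough illustration of what "built programmatically and randomly" can mean, the sketch below generates a small homogeneous edge list and keeps the node types in a separate metadata dictionary. The function name `toy_access_graph` and its parameters are hypothetical and much simpler than the `generate_access_graph` helper used later in this notebook; it only draws user-to-system and system-to-resource edges.

```python
import random

def toy_access_graph(n_users=4, n_systems=3, n_resources=5,
                     p_login=0.5, p_sys_access=0.5, seed=0):
    """Hypothetical sketch: one shared node-id space, node types kept in metadata."""
    rng = random.Random(seed)
    users = list(range(n_users))
    systems = list(range(n_users, n_users + n_systems))
    first_res = n_users + n_systems
    resources = list(range(first_res, first_res + n_resources))

    edges = []
    # each user logs in to each system with probability p_login
    for u in users:
        for s in systems:
            if rng.random() < p_login:
                edges.append((u, s))
    # each system can reach each resource with probability p_sys_access
    for s in systems:
        for r in resources:
            if rng.random() < p_sys_access:
                edges.append((s, r))

    meta = {"users": users, "systems": systems, "resources": resources}
    return edges, meta

toy_edges, toy_meta = toy_access_graph(seed=99)
print(len(toy_edges), "edges;", toy_meta)
```

Fixing the seed makes the graph reproducible, which is why the notebook passes `seed=99` to its own generator.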

The graph learning in this notebook is transductive, meaning the model learns about a single given graph. This was chosen for simplicity; the code could be extended to be inductive by training on many randomly generated graphs. GraphSAGE, one of the model types used here, is a framework suited to inductive learning on graphs.

Importantly, in the generated graphs there is a notion of compromised user(s) and high-value resources. From a red-teaming perspective, a compromised user might be a known target, and a high-value resource might be identified as high-value after initial reconnaissance.

For the purposes of this toy example, the compromised-user label provides a classification problem for transductive learning. This is useful because it provides a way to learn node vectors. Node vectors could also be learned without labels, for example with Deep Graph Infomax (DGI), which is implemented in pytorch_geometric, as is GraphSAGE.

After transductive learning on the generated graph with a 2-layer GNN, we use GNNExplainer, also from pytorch_geometric, to identify nodes that: (1) contribute most to the target node’s prediction, (2) are bottlenecks in information flow, (3) determine classification outcomes (e.g., malicious vs. benign).

Conceptually, GNNExplainer works by turning edges and features off: it learns soft masks over them, and the edges and features it can’t suppress without changing the model’s answer are the important parts of the graph.
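
The edge-removal intuition can be made concrete with a leave-one-edge-out sketch. This is a simpler occlusion-style stand-in, not GNNExplainer itself (which learns continuous masks by gradient descent). The "model" is a single round of mean aggregation, and the features and edges are made up for illustration.

```python
def target_score(edges, feats, target):
    """One round of mean aggregation over the target's in-neighbours plus itself."""
    neigh = [src for src, dst in edges if dst == target] + [target]
    return sum(feats[n] for n in neigh) / len(neigh)

def occlusion_importance(edges, feats, target):
    """Score each edge by how much removing it changes the target's output."""
    base = target_score(edges, feats, target)
    scores = {}
    for e in edges:
        kept = [x for x in edges if x != e]
        scores[e] = abs(base - target_score(kept, feats, target))
    return scores

# illustrative scalar features and directed edges (src, dst)
feats = {0: 1.0, 1: 0.0, 2: 4.0, 3: 0.0}
edges = [(1, 0), (2, 0), (3, 1)]

imp = occlusion_importance(edges, feats, target=0)
for e, s in sorted(imp.items(), key=lambda kv: -kv[1]):
    print(e, round(s, 3))
```

Edge `(3, 1)` scores zero here because a one-layer aggregation at node 0 never sees it; with a 2-layer GNN, edges two hops away can matter as well.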

The overall steps are:
1) build a random graph
2) learn to classify user nodes in the graph by risk
3) run the graph explainer to identify salient structure in the graph
4) extract the explainer's edge info about a high-risk user and high-value resource(s)
5) call the OpenAI API to translate the explainer details into red-team and blue-team recommendations
6) visualize the graph

In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.explain.config import ModelConfig, ModelMode

## Build the toy "access" graph

In [2]:
from random_access_graph import generate_access_graph

# build a random graph
data, meta = generate_access_graph(
    n_users=8, n_systems=6, n_resources=10,
    p_login=0.3, p_lateral=0.04, p_sys_access=0.4,
    p_cross_user_cluster=0.15, n_user_clusters=3,
    high_value_ratio=0.15,
    seed=99
)

## Create a tiny GCN or GraphSAGE model

In [None]:
from gnn_factory import build_gnn

MODEL_NAME = "sage"          # <── swap "gcn"  /  "sage"
model = build_gnn(MODEL_NAME,
                  in_dim=3,
                  hidden=32,
                  n_classes=2)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
model.train()
for _ in range(200):
    opt.zero_grad()
    out  = model(data.x, data.edge_index)
    # supervise only on the labelled training nodes
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    opt.step()
model.eval()   # switch to inference mode before explaining

In [None]:
# Single-graph TRAIN accuracy (1.0 here because the only malicious user is in train_mask).
# The score is not a useful metric; we train purely to obtain embeddings
# that GNNExplainer will analyse.  Unsupervised objectives (e.g. DGI) could serve the
# same purpose if no labels were available.
with torch.no_grad():
    logits = model(data.x, data.edge_index)
    pred   = logits.argmax(dim=1)
    acc    = (pred[data.train_mask] == data.y[data.train_mask]).float().mean()
print(f"Train accuracy: {acc:.3f}")

## Run GNNExplainer

In [4]:
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=500, lr=0.005),        # more epochs → clearer masks
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=ModelConfig(
        mode=ModelMode.multiclass_classification,
        task_level='node',
        return_type='raw',                     # logits
    ),
)

## Calculate risk scores and broadcast scores

In [None]:
from helper_tabular import get_node_types
from rank_user_compromise import rank_users, broadcast_scores

top_users = rank_users(data, model, k=3, node_types=get_node_types(meta))
print("=== highest-risk users ===")
print(top_users)

bc = broadcast_scores(data, explainer, top_users["node_id"])
print("\n=== broadcast score ===")
print(bc)

In [6]:
# for now, only keep the explanation for the user node with the highest risk score
node_ids = top_users["node_id"].tolist()
node_id = node_ids[0]
explanation = explainer(data.x, data.edge_index, index=node_id)

## Review the explanation and related data

In [None]:
# look at the importance of the features for the high risk user node
feat_names = meta["data_feature_names"]
row = explanation.node_mask[node_id]           # (3,) tensor
for w, n in sorted(zip(row.tolist(), feat_names), reverse=True):
    print(f"{n:<15} {w:6.3f}")

In [None]:
# Look at influential edges related to the high risk user node (not necessarily connected to that node)

print(f"\nTop-5 influential edges for node {node_id}:")
edge_scores = explanation.edge_mask
edge_idx_T  = data.edge_index.t()
top = torch.topk(edge_scores, 5)
for s, idx in zip(top.values, top.indices):
    u, v = edge_idx_T[idx].tolist()
    print(f"{u:>2} → {v:<2}  score={s:.3f}")
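
Since the top edges above are not necessarily connected to the node being explained, it can help to also look only at edges that touch it. The sketch below shows one way to do that with plain Python lists standing in for the tensors; the helper name `top_incident_edges` and the scores are illustrative assumptions.

```python
def top_incident_edges(edge_list, scores, node, k=5):
    """Return up to k (score, (u, v)) pairs for edges touching `node`, best first."""
    incident = [(s, e) for s, e in zip(scores, edge_list) if node in e]
    return sorted(incident, reverse=True)[:k]

edge_list = [(0, 1), (1, 2), (2, 3), (3, 0)]   # illustrative edges
scores = [0.9, 0.2, 0.7, 0.4]                  # illustrative mask values

print(top_incident_edges(edge_list, scores, node=0, k=2))
```

With the notebook's objects, the same idea could be applied to `data.edge_index.t().tolist()` and `explanation.edge_mask.tolist()`.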

In [9]:
from helper_tabular import TabularData, NodeCategories

# Get tabular data for easier, further inspection
node_categories = NodeCategories(meta)
tabular_data = TabularData(data, meta, explanation, node_id)
node_info = tabular_data.get_per_node_info()
edge_info = tabular_data.get_per_edge_info()

In [None]:
print(node_info)

In [None]:
print(edge_info)

## Define Targets

In [None]:
# We want to be able to exploit the explainer's info about the graph.
# Above we looked at high risk user info. Now let's look at high value resources.
# Then we'll get actionable insights for both.

# 1) decide resource target(s)
targets = meta.get('high_value', []) or meta['resources']
targets

In [None]:
from helper_tabular import explain_resource_paths

# 2) for each high-value resource, get the top K conduit edges
all_hv_reports = {}
for rid in targets:
    hv_df, hv_nodes = explain_resource_paths(data, meta, explainer, rid, top_k=10)
    all_hv_reports[rid] = hv_df        # store for dashboards / bullets
    print(f"\nTop paths to resource {rid}")
    print(hv_df[['src','dst','importance','kind']])

# also, get the top K user edges
topk_edges = tabular_data.get_topk_edges()

## Create an LLM explanation summary from top edges

In [14]:
from helper_llm_explain import explain_edges_with_llm, build_edge_sentence_fn
import json

to_sentence = build_edge_sentence_fn(node_categories.users, node_categories.systems, node_categories.resources)

# make API call to the LLM to get resource and user edge actions
resource_bullets = "\n".join(
    "• "+to_sentence(r)
    for df in all_hv_reports.values()   # every high-value target, not just the last hv_df
    for _, r in df.iterrows()
)
user_bullets = "\n".join("• "+to_sentence(r) for _, r in topk_edges.iterrows())
report_resource = explain_edges_with_llm(resource_bullets)
report_user = explain_edges_with_llm(user_bullets)
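
The internals of `explain_edges_with_llm` are not shown in this notebook. As a hedged sketch of one plausible shape, the helper might wrap the bullets in a prompt that requests JSON and then parse the reply defensively. The function names, the prompt wording and the `red_team` / `blue_team` keys below are all assumptions, and no API call is made.

```python
import json

def build_prompt(bullets: str) -> str:
    """Wrap edge bullets in instructions asking for a JSON reply (hypothetical wording)."""
    return (
        "You are assisting a security exercise on a toy access graph.\n"
        "Given these influential edges:\n"
        f"{bullets}\n"
        'Reply with JSON of the form {"red_team": [...], "blue_team": [...]}.'
    )

def parse_report(raw: str) -> dict:
    """Parse the model's reply; fall back to the raw text if it is not a JSON object."""
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return {"raw": raw}
    return parsed if isinstance(parsed, dict) else {"raw": raw}

prompt = build_prompt("• user 3 logs in to system 9")
print(prompt)
print(parse_report('{"red_team": ["pivot via system 9"], "blue_team": ["review logins"]}'))
```

Falling back to `{"raw": ...}` keeps the later `json.dumps(...)` calls working even when the reply is not valid JSON.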

In [None]:
print("=== Resource Report ===")
print(resource_bullets)
print("-"*100)
print(json.dumps(report_resource, indent=2))

In [None]:
print("=== User Report ===")
print(user_bullets)
print("-"*100)
print(json.dumps(report_user, indent=2))

## Visualise the graph

In [None]:
from visualize_graph import visualize_graph
visualize_graph(data, meta, top_users, bc)