### Setup

In [2]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import pickle 
import torch as th
import torch.nn.functional as F
import torch
import networkx as nx
import pandas as pd
import numpy as np
import dgl 

import torch_geometric
from torch_geometric.explain import Explainer, CaptumExplainer, DummyExplainer, GNNExplainer
from torch_geometric.explain.metric import *
from torch_geometric.nn.models.basic_gnn import GraphSAGE
from torch_geometric.utils import from_dgl
from tqdm import tqdm
from torch_geometric.explain import ModelConfig
import scienceplots

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Label encoders saved by the preprocessing step; `encoders['Attack']` is used
# below to translate integer class ids to attack names.
# NOTE: pickle.load can execute arbitrary code - only load trusted local artifacts.
with open('../../interm/label_encoders.pkl', 'rb') as fh:
    encoders = pickle.load(fh)
    
def view_metrics(metrics_list, legend=None, s=3):
    """Print fidelity tables and plot sparsity curves for one or more explanations.

    Args:
        metrics_list: list of metric dicts as built in the "Performance" section.
            Each dict must contain
            - 'softmask fidelity': (fid+, fid-) of the soft mask,
            - f'softmask fidelity {attack}': (fid+, fid-, char) for every class
              name in ``encoders['Attack'].classes_``,
            - 's', 'fid+', 'fid-', 'c': equal-length sequences (sparsity levels
              and the metric measured at each level).
        legend: optional list of labels, one per entry of ``metrics_list``; used
            as section headers in the printed tables and as the plot legend.
        s: marker size passed to ``scatter``.
    """
    # Text summary: overall soft-mask fidelity, then a per-class table.
    for i, metrics in enumerate(metrics_list):
        if legend:
            print(f'\n{legend[i]}')

        fp, fn = metrics['softmask fidelity']
        print(f'fid+ : {fp:.4f}\tfid- : {fn:.4f}\n')
        print(f"{'Class':<15}{'fid+':>10}{'fid-':>10}{'char':>10}")
        for attack in encoders['Attack'].classes_:
            fp, fn, c = metrics[f'softmask fidelity {attack}']
            print(f"{attack:<15}{fp:>10.3f}{fn:>10.3f}{c:>10.3f}")

    # One figure per metric: (metrics key, title, y-axis label).
    panels = (
        ('fid-', 'Sparsity Vs Fidelity-', 'Fidelity-'),
        ('fid+', 'Sparsity Vs Fidelity+', 'Fidelity+'),
        ('c', 'Sparsity Vs Characterisation Score', 'Characterisation score'),
    )
    # The 'science' style is registered by the `scienceplots` import.
    with plt.style.context('science'):
        for key, title, ylabel in panels:
            _, ax = plt.subplots()
            for i, metrics in enumerate(metrics_list):
                # Label only the line so each legend entry maps to exactly one
                # series (labelling by position would also hit the scatter
                # artists); the markers reuse the line's colour.
                plot_kw = {'label': legend[i]} if legend else {}
                (line,) = ax.plot(metrics['s'], metrics[key], **plot_kw)
                ax.scatter(metrics['s'], metrics[key], s=s, color=line.get_color())
            ax.set(title=title, xlabel='Sparsity', ylabel=ylabel)
            if legend:
                ax.legend()
            plt.show()
        
def masked_prediction(mask, model, G, hardmask=True, n_features=49):
    """Predict classes for the full, masked and inverse-masked node features.

    Args:
        mask: feature mask broadcastable against ``G.x[:, :n_features]``.
        model: callable ``model(x, edge_index)`` returning per-node class scores.
        G: graph object exposing ``x`` and ``edge_index``.
        hardmask: if True, ``mask`` is a keep/drop mask. Non-zero entries are
            kept and it is inverted with a logical NOT, so bool, int and float
            0/1 masks all work. If False, it is a soft mask inverted as
            ``1 - mask``.
        n_features: number of leading columns of ``G.x`` the model consumes.

    Returns:
        ``(y_pred, ym_pred, ymi_pred)``: argmax predictions on the full, masked
        and inverse-masked inputs, placed on ``G.x``'s device.
    """
    # Run where the model's weights live (fall back to the data's device), so
    # model, features and mask never end up on different devices.
    param = next(model.parameters(), None) if hasattr(model, 'parameters') else None
    run_device = param.device if param is not None else G.x.device

    x = G.x[:, :n_features].to(run_device)
    edge_index = G.edge_index.to(run_device)
    mask = mask.to(run_device)
    if hardmask:
        # `~` on float masks raises and on int masks is a bitwise NOT, so binarise first.
        mask = mask.bool()
        inv_mask = ~mask
    else:
        inv_mask = 1 - mask

    # Pure inference: avoid building three autograd graphs over the whole graph.
    with torch.no_grad():
        y_pred = model(x, edge_index).argmax(dim=1)
        ym_pred = model(x * mask, edge_index).argmax(dim=1)
        ymi_pred = model(x * inv_mask, edge_index).argmax(dim=1)

    out_device = G.x.device
    return y_pred.to(out_device), ym_pred.to(out_device), ymi_pred.to(out_device)


def fidelities(y_pred, y_mask, y_imask, y):
    """Phenomenon-style fidelity from precomputed predictions.

    fid+ (first value) is the fraction of nodes whose correctness changes when
    the explanation is removed (inverse-mask predictions ``y_imask``); fid-
    (second value) is the fraction whose correctness changes when only the
    explanation is kept (mask predictions ``y_mask``).

    Returns:
        ``(fp, fn)`` as 0-d float tensors.
    """
    correct_full = (y_pred == y).float()
    changed_when_kept = (correct_full - (y_mask == y).float()).abs()
    changed_when_dropped = (correct_full - (y_imask == y).float()).abs()
    return changed_when_dropped.mean(), changed_when_kept.mean()
        

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
test = pd.read_csv('../../interm/BoT_test.csv')

# Columns that must not become features: graph endpoints, the label, the packed
# feature column itself and raw address/port metadata. The
# FLOW_{START,END}_MILLISECONDS_metadata columns are not excluded, so they end up
# as the last two entries of `x` (unscaled timestamps).
non_feature_cols = {
    "src", "dst", "Attack", "x",
    "IPV4_SRC_ADDR_metadata", "L4_SRC_PORT_metadata",
    "IPV4_DST_ADDR_metadata", "L4_DST_PORT_metadata",
}
attrs = [col for col in test.columns if col not in non_feature_cols]

# Pack each row's features into one list-valued column (used as the edge attribute `x`).
test['x'] = test[attrs].values.tolist()
test.head(1)


Unnamed: 0,FLOW_START_MILLISECONDS,FLOW_END_MILLISECONDS,PROTOCOL,L7_PROTO,IN_BYTES,IN_PKTS,OUT_BYTES,OUT_PKTS,TCP_FLAGS,CLIENT_TCP_FLAGS,...,Attack,src,dst,x,FLOW_START_MILLISECONDS_metadata,FLOW_END_MILLISECONDS_metadata,IPV4_SRC_ADDR_metadata,L4_SRC_PORT_metadata,IPV4_DST_ADDR_metadata,L4_DST_PORT_metadata
0,-0.236904,-0.236926,-0.503789,106,-0.17034,-0.2804,-0.071477,-0.149842,1.257089,1.728738,...,0,192.168.100.3:-1.0586554,192.168.100.149:2.6106632,"[-0.23690394, -0.23692596, -0.5037887, 106.0, ...",1526968000000.0,1526968000000.0,192.168.100.3,80.0,192.168.100.149,34502.0


In [4]:
def to_graph(data, linegraph=True):
    """Build a DGL graph from a flow table.

    Every row of ``data`` becomes an edge ``src -> dst`` carrying that row's
    ``x`` and ``Attack`` values. The multigraph is converted with
    ``to_directed()``, which yields an edge in each direction per row.

    Args:
        data: DataFrame with ``src``, ``dst``, ``x`` and ``Attack`` columns.
        linegraph: if True (default), return the line graph (one node per
            directed flow, with the flow's edge features copied onto the node
            via ``shared=True``); otherwise return the flow graph itself.

    Returns:
        A ``dgl.DGLGraph``.
    """
    flows = nx.from_pandas_edgelist(
        data,
        source='src',
        target='dst',
        edge_attr=['x', 'Attack'],
        create_using=nx.MultiGraph(),
    ).to_directed()
    graph = dgl.from_networkx(flows, edge_attrs=['x', 'Attack'])
    return graph.line_graph(shared=True) if linegraph else graph

# 3-layer GraphSAGE node classifier: 49 input features -> 5 output classes.
model = GraphSAGE(49,
                  hidden_channels=256,
                  out_channels=5,
                  num_layers=3).to(device)

# map_location lets a checkpoint saved on one device load on another (e.g. GPU
# checkpoint on a CPU-only machine). weights_only=True restricts unpickling to
# tensors/containers, which is all a state_dict needs, and avoids executing
# arbitrary code from the checkpoint file.
state_dict = th.load('../../interm/GraphSAGE_BoTIoT.pth',
                     map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

  model.load_state_dict(th.load('../../interm/GraphSAGE_BoTIoT.pth'))


GraphSAGE(49, 5, num_layers=3)

In [5]:
G = to_graph(test)
# The last two feature columns are the unscaled flow timestamps (the final kept
# columns of `x`): [-2] = FLOW_START_MILLISECONDS_metadata, [-1] = FLOW_END_MILLISECONDS_metadata.
first_node_x = G.ndata['x'][0]
first_node_x[-1], first_node_x[-2]


(tensor(1.5270e+12), tensor(1.5270e+12))

### Motifs

In [19]:
# Motif features: flag edges that belong to a "scanning star" (source out-degree
# > 10) or a "fan-in" (destination in-degree > 10) and expose the flags as the
# first two node-feature columns of a PyG line graph.
# NOTE(review): this cell builds its own DiGraph, which is a different graph from
# `G = to_graph(test)` (MultiGraph -> to_directed). Edge ids here are therefore not
# guaranteed to correspond to node ids of `G` - confirm before using them there.
import torch, networkx as nx, dgl
from torch_geometric.transforms import LineGraph
from torch_geometric.utils import from_dgl

# 1) Build NX, then RELABEL to 0..N-1 to avoid gaps/off-by-one.
#    A DiGraph keeps at most one edge per (src, dst) pair, so repeated flows between
#    the same endpoints are presumably collapsed into one edge - confirm if that matters.
nx_g = nx.from_pandas_edgelist(
    test, source='src', target='dst',
    edge_attr=['x', 'Attack'],
    create_using=nx.DiGraph()
)
nx_g = nx.convert_node_labels_to_integers(nx_g, ordering='sorted')

# 2) DGL graph + edge motifs (on *edges*)
dgl_g = dgl.from_networkx(nx_g, edge_attrs=['x', 'Attack'])
src, dst = dgl_g.edges()
out_deg = dgl_g.out_degrees()
in_deg  = dgl_g.in_degrees()

# Hubs / sinks: nodes whose out- / in-degree exceeds the hard-coded threshold of 10.
scanning_star_nodes = (out_deg > 10).nonzero(as_tuple=True)[0]
fan_nodes           = (in_deg  > 10).nonzero(as_tuple=True)[0]

# Per-edge 0/1 flags: edge leaves a hub (star) / enters a sink (fan).
is_star = torch.isin(src, scanning_star_nodes).to(torch.uint8)
is_fan  = torch.isin(dst, fan_nodes).to(torch.uint8)
dgl_g.edata['is_star'] = is_star
dgl_g.edata['is_fan']  = is_fan

# 3) Convert to PyG and ensure num_nodes is consistent
pyg_g = from_dgl(dgl_g)
pyg_g.num_nodes = int(pyg_g.edge_index.max()) + 1  # guard against off-by-one

# 4) Line graph (each original edge -> one LG node)
pyg_lg = LineGraph(force_directed=True)(pyg_g)

# 5) Map edge motifs to LG node features (reuses DGL edata).
#    NOTE(review): this assumes LG node i corresponds to DGL edge i. PyG's LineGraph
#    may coalesce (sort) edge_index first, in which case the order differs - verify.
#    base_x has 0 columns whenever the line graph carries no node features.
E = pyg_g.edge_index.size(1)
base_x = pyg_lg.x if pyg_lg.x is not None else torch.zeros((E, 0), dtype=torch.float)
motifs = torch.stack([dgl_g.edata['is_star'], dgl_g.edata['is_fan']], dim=1).float()
pyg_lg.x = torch.cat([motifs, base_x], dim=1)

# Sanity check: first two columns are the motif flags (is_star, is_fan)
pyg_lg.x[:, 0], pyg_lg.x[:, 1]


(tensor([1., 1., 1.,  ..., 0., 0., 0.]),
 tensor([0., 0., 0.,  ..., 1., 0., 1.]))

In [20]:
def _incident_edge_groups(endpoints, nodes):
    """For each node in `nodes`, the non-empty list of edge ids whose `endpoints` entry equals it."""
    groups = []
    for node in nodes.tolist():
        edge_ids = torch.nonzero(endpoints == node, as_tuple=True)[0].tolist()
        if edge_ids:
            groups.append(edge_ids)
    return groups


# Star motif: all edges leaving one high out-degree hub.
# Fan motif: all edges entering one high in-degree sink.
# Indices are edge ids of `dgl_g` from the previous cell.
star_motifs = _incident_edge_groups(src, scanning_star_nodes)
fan_motifs = _incident_edge_groups(dst, fan_nodes)

In [18]:
# Number of star / fan motif groups found.
tuple(map(len, (star_motifs, fan_motifs)))

(45, 47)

In [21]:
# Number of line-graph nodes flagged as star / fan members.
tuple(pyg_lg.x[:, col].sum() for col in (0, 1))

(tensor(1975.), tensor(4727.))

### NIDS-GNNExplainer
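- Motif coherence reward: $-\lambda_{mc} \sum_{g \in \text{motifs}} \lVert m_g \rVert_2$ (subtracted from the loss, so keeping whole motifs in the explanation is rewarded)
- Temporal smoothness penalty: $\lambda_{tv} \sum_{i} \lVert m_{\pi(i+1)} - m_{\pi(i)} \rVert_1$, where $\pi$ orders the nodes by flow start time
- Leakage penalty: $\lambda_{leak} \sum_{j \in \text{leak}} \lVert m_{:,j} \rVert_1$ over the feature columns that must not be relied on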

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.explain import GNNExplainer

class CustomGNNExplainer(GNNExplainer):
    """GNNExplainer with NIDS-specific regularisers added to its training loss.

    On top of the GNNExplainer objective, three optional terms are added. Each
    is skipped while its coefficient is 0 or its input is unset:

    1. Temporal smoothness: ``tv_coef * sum |m[t+1] - m[t]|`` over nodes sorted
       by ``node_times`` (total variation of the mask along time).
    2. Motif coherence: ``-motif_coef * sum_g ||m[g]||_2`` over ``motif_groups``.
       It is *subtracted*, i.e. a reward for keeping whole motifs; a
       group-lasso sparsity penalty would add it instead.
    3. Leakage: ``leak_coef * sum |m[..., leak_feat_idx]|`` discourages relying
       on the listed feature columns.

    The terms are injected in ``_loss(y_hat, y)``, the per-step hook of PyG's
    ``torch_geometric.explain`` GNNExplainer. A method named ``loss`` is never
    called by its training loop. NOTE(review): verify ``_loss`` exists in the
    installed PyG version.

    Configure via constructor keywords, or by setting the attributes on this
    algorithm object (``explainer.algorithm``), not on the wrapping ``Explainer``.
    """

    def __init__(self, *, tv_coef=0, motif_coef=0, leak_coef=0,
                 node_times=None, motif_groups=None, leak_feat_idx=None,
                 **kwargs):
        # Remaining kwargs (epochs, lr, coefficient overrides) go to GNNExplainer.
        super().__init__(**kwargs)
        self.tv_coef = tv_coef
        self.motif_coef = motif_coef
        self.leak_coef = leak_coef

        # External information the caller must provide.
        self.node_times = node_times  # 1-D tensor, one timestamp per node
        # list of lists of node indices
        self.motif_groups = [] if motif_groups is None else list(motif_groups)
        self.leak_feat_idx = leak_feat_idx  # feature columns that should not leak

    def additional_loss_terms(self, node_mask, feat_mask):
        """Return the sum of the enabled extra penalties for the given masks.

        Args:
            node_mask: mask whose first dimension indexes nodes (``[N, F]`` for
                ``node_mask_type='attributes'``).
            feat_mask: mask whose last dimension indexes features; may be the
                same tensor as ``node_mask``, or ``None`` to skip the leakage term.

        Returns:
            A scalar tensor, or ``0`` when no term is active.
        """
        reg = 0

        # (1) Temporal smoothness: total variation of the mask along time order.
        if self.tv_coef and self.node_times is not None:
            times = torch.as_tensor(self.node_times, device=node_mask.device)
            ordered = node_mask[torch.argsort(times)]
            reg = reg + self.tv_coef * (ordered[1:] - ordered[:-1]).abs().sum()

        # (2) Motif coherence reward (subtracted group L2 norm; see class docstring).
        if self.motif_coef:
            for group in self.motif_groups:
                reg = reg - self.motif_coef * torch.linalg.vector_norm(node_mask[group])

        # (3) Leakage penalty. Indexing the *last* dimension selects feature
        # columns for both 1-D [F] and 2-D [N, F] masks; indexing the first
        # dimension of an [N, F] mask would select nodes instead.
        if self.leak_coef and self.leak_feat_idx is not None and feat_mask is not None:
            idx = torch.as_tensor(self.leak_feat_idx, device=feat_mask.device)
            reg = reg + self.leak_coef * feat_mask[..., idx].abs().sum()

        return reg

    def _loss(self, y_hat, y):
        """GNNExplainer training hook: base loss plus the extra penalties.

        The penalties act on ``sigmoid(self.node_mask)``, i.e. the raw mask
        parameter mapped to (0, 1).
        """
        loss = super()._loss(y_hat, y)
        if self.node_mask is None:
            return loss
        mask = self.node_mask.sigmoid()
        return loss + self.additional_loss_terms(mask, mask)

    def loss(self, log_logits, pred_label, node_mask, feat_mask):
        """Base GNNExplainer loss plus the extra penalties for explicit masks.

        Kept for backward compatibility; PyG's training loop calls ``_loss``.
        """
        base_loss = super()._loss(log_logits, pred_label)
        return base_loss + self.additional_loss_terms(node_mask, feat_mask)


In [27]:
# Shape of the model's input features. On a fresh run `G` is still the DGL line
# graph here; it only becomes a PyG `Data` in the explainer cell below.
(G.x if isinstance(G, torch_geometric.data.Data) else G.ndata['x'])[:, :49].shape

torch.Size([29760, 49])

In [None]:
# 'phenomenon' explains the model's behaviour with respect to the true labels
# passed as `target`; the mask is learned per node and per feature.
explainer = Explainer(
    model=model,
    algorithm=CustomGNNExplainer(epochs=50),
    explanation_type='phenomenon',
    node_mask_type='attributes',
    edge_mask_type=None,
    model_config=ModelConfig(
        mode='multiclass_classification',
        task_level='node',
        return_type='raw',
    ),
)

# Convert the DGL line graph to PyG only once, so re-running this cell is safe.
if not isinstance(G, torch_geometric.data.Data):
    G = from_dgl(G)

# The regulariser settings are read by the algorithm object, so they must be set
# on `explainer.algorithm`; attributes set on the Explainer wrapper are ignored.
algo = explainer.algorithm
# Second-to-last column = unscaled FLOW_START_MILLISECONDS_metadata
# (the last column is FLOW_END_MILLISECONDS_metadata).
algo.node_times = G.x[:, -2].to(device)
# NOTE(review): star/fan motifs are edge ids of `dgl_g` (the DiGraph built in the
# Motifs section), not node ids of this line graph (built by `to_graph` from a
# MultiGraph). Rebuild the motifs on `G` before trusting the coherence term.
algo.motif_groups = star_motifs + fan_motifs
algo.leak_feat_idx = torch.tensor([0], device=device)  # pretend feature 0 leaks

algo.tv_coef = 1
algo.motif_coef = 0.5
algo.leak_coef = 2

explanation = explainer(
    x=G.x[:, :49].to(device),
    edge_index=G.edge_index.to(device),
    target=G.Attack.to(device),  # must sit on the model's device for the loss
)

explanation

Explanation(node_mask=[29760, 49], target=[29760], x=[29760, 49], edge_index=[2, 5405122])

### Performance

In [None]:
import copy
from torch_geometric.explain.metric import fidelity, characterization_score

metrics = {'fid+': [], 'fid-': [], 's': [], 'c': [], 'k': []}


def _char_score(fp, fn, **weights):
    """Characterisation score, or 0 where it is undefined.

    The score divides by fid+ and by (1 - fid-), so it is only defined for
    fid+ > 0 and fid- < 1. fid- == 0 is the *best* case and must not be zeroed.
    """
    return characterization_score(fp, fn, **weights) if fp > 0 and fn < 1 else 0


# Hard-mask sweep: keep the top-s fraction of soft-mask entries (s = 0.1 ... 0.9)
# and evaluate PyG's fidelity on the resulting binary mask.
explanation_cp = copy.deepcopy(explanation)
sorted_mask = explanation.node_mask.flatten().sort(descending=True).values  # loop-invariant

for s in tqdm(np.arange(0.1, 1, 0.1)):
    k = max(1, int(s * sorted_mask.numel()))
    threshold = sorted_mask[k - 1]  # k-th largest mask value
    # Ties with the threshold are kept too, so slightly more than k entries may survive.
    explanation_cp.node_mask = (explanation.node_mask >= threshold).float()

    fp, fn = fidelity(explainer, explanation_cp)
    metrics['fid+'].append(fp)
    metrics['fid-'].append(fn)
    metrics['c'].append(_char_score(fp, fn))
    metrics['s'].append(s)
    metrics['k'].append(k)

# Soft-mask fidelity over all nodes.
metrics['softmask fidelity'] = fidelity(explainer, explanation)

# Per-class soft-mask fidelity. `explanation` holds the exact explainer inputs
# (first 49 features and edge_index, already on `device`), so it serves as the
# graph here and keeps model, features and mask on one device.
with torch.no_grad():
    y_pred, ym_pred, ymi_pred = masked_prediction(
        explanation.node_mask, model, explanation, hardmask=False)
labels = G.Attack.to(y_pred.device)

for idx, attack in enumerate(encoders['Attack'].classes_):
    is_class = labels == idx
    fp, fn = fidelities(y_pred=y_pred == idx,
                        y_mask=ym_pred == idx,
                        y_imask=ymi_pred == idx,
                        y=is_class)
    # Weight the two fidelity terms by the class's prevalence.
    w = is_class.float().mean()
    metrics[f'softmask fidelity {attack}'] = fp, fn, _char_score(
        fp, fn, pos_weight=w, neg_weight=1 - w)

view_metrics([metrics])

  0%|          | 0/9 [00:00<?, ?it/s]

100%|██████████| 9/9 [01:13<00:00,  8.15s/it]
