### Setup

In [1]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import pickle 
import torch as th
import torch.nn.functional as F
import torch
import networkx as nx
import pandas as pd
import numpy as np
import dgl 

import torch_geometric
from torch_geometric.explain import Explainer, CaptumExplainer, DummyExplainer, GNNExplainer
from torch_geometric.explain.metric import *
from torch_geometric.nn.models.basic_gnn import GraphSAGE
from torch_geometric.utils import from_dgl
from tqdm import tqdm
from torch_geometric.explain import ModelConfig
import scienceplots
from explanations import *
        

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
test = pd.read_csv('../../interm/BoT_test.csv')
attrs = [c for c in test.columns if c not in ("src", "dst", "Attack", "x", "IPV4_SRC_ADDR_metadata", "L4_SRC_PORT_metadata", 
                                              "IPV4_DST_ADDR_metadata", "L4_DST_PORT_metadata") 
        #  and not c.endswith('_metadata')
         ]
test['x'] = test[attrs].values.tolist()
test[:1]


Unnamed: 0,FLOW_START_MILLISECONDS,FLOW_END_MILLISECONDS,PROTOCOL,L7_PROTO,IN_BYTES,IN_PKTS,OUT_BYTES,OUT_PKTS,TCP_FLAGS,CLIENT_TCP_FLAGS,...,Attack,src,dst,x,FLOW_START_MILLISECONDS_metadata,FLOW_END_MILLISECONDS_metadata,IPV4_SRC_ADDR_metadata,L4_SRC_PORT_metadata,IPV4_DST_ADDR_metadata,L4_DST_PORT_metadata
0,-0.236904,-0.236926,-0.503789,106,-0.17034,-0.2804,-0.071477,-0.149842,1.257089,1.728738,...,0,192.168.100.3:-1.0586554,192.168.100.149:2.6106632,"[-0.23690394, -0.23692596, -0.5037887, 106.0, ...",1526968000000.0,1526968000000.0,192.168.100.3,80.0,192.168.100.149,34502.0


In [3]:
def to_graph(data, linegraph=True):
    G = nx.from_pandas_edgelist(data, source='src', 
                                target='dst', 
                                edge_attr=['x', 'Attack'], 
                                create_using=nx.MultiGraph()) 
    
    G = G.to_directed()
    g = dgl.from_networkx(G, edge_attrs=[ 'x', 'Attack'])
    if linegraph:
        return g.line_graph(shared=True)
    else:
        return g

model = GraphSAGE(49,
                  hidden_channels=256,
                  out_channels=5,
                  num_layers=3).to(device)

model.load_state_dict(th.load('../../interm/GraphSAGE_BoTIoT.pth'))
model.eval()

  model.load_state_dict(th.load('../../interm/GraphSAGE_BoTIoT.pth'))


GraphSAGE(49, 5, num_layers=3)

In [4]:
G = to_graph(test)
G.ndata['x'][0][-1], G.ndata['x'][0][-2] # unscaled start and stop times


(tensor(1.5270e+12), tensor(1.5270e+12))

### Motifs

In [19]:
import torch, networkx as nx, dgl
from torch_geometric.transforms import LineGraph
from torch_geometric.utils import from_dgl

# 1) Build NX, then RELABEL to 0..N-1 to avoid gaps/off-by-one
nx_g = nx.from_pandas_edgelist(
    test, source='src', target='dst',
    edge_attr=['x', 'Attack'],
    create_using=nx.DiGraph()
)
# nx_g = nx.convert_node_labels_to_integers(nx_g, ordering='sorted') # ! ?

# 2) DGL graph + edge motifs (on *edges*)
dgl_g = dgl.from_networkx(nx_g, edge_attrs=['x', 'Attack'])
src, dst = dgl_g.edges()
out_deg = dgl_g.out_degrees()
in_deg  = dgl_g.in_degrees()

scanning_star_nodes = (out_deg > 10).nonzero(as_tuple=True)[0]
fan_nodes           = (in_deg  > 10).nonzero(as_tuple=True)[0]

is_star = torch.isin(src, scanning_star_nodes).to(torch.uint8)
is_fan  = torch.isin(dst, fan_nodes).to(torch.uint8)
new = torch.vstack([is_star, is_fan])

print(dgl_g.edata['x'].shape)
dgl_g.edata['x'] = torch.hstack([dgl_g.edata['x'], new.T])
print(dgl_g.edata['x'].shape)

dgl_lg = dgl_g.line_graph(shared=True)
pyg_lg = from_dgl(dgl_g)
pyg_lg.num_nodes = int(pyg_lg.edge_index.max()) + 1

torch.Size([12631, 51])
torch.Size([12631, 53])


In [37]:
# star_motifs = []
# for hub in scanning_star_nodes.tolist():
#     # find edges with this hub as source
#     lg_nodes = (src == hub).nonzero(as_tuple=True)[0].tolist()
#     if lg_nodes:  # only if non-empty
#         star_motifs.append(lg_nodes)

# fan_motifs = []
# for sink in fan_nodes.tolist():
#     # find edges with this sink as target
#     lg_nodes = (dst == sink).nonzero(as_tuple=True)[0].tolist()
#     if lg_nodes:
#         fan_motifs.append(lg_nodes)

In [39]:
pyg_lg.x[:, 51].sum(), pyg_lg.x[:, 52].sum()

(tensor(1975.), tensor(4727.))

### NIDS-GNNExplainer
- motif coherence reward $= - \lambda_{mc} \cdot \sum_{g \in \text{motifs}} || m_g ||_2$
- temporal smoothness penalty =  $\lambda_{ts} \cdot \text{exp}(\frac{-(t_i - t_j)^2}{2\sigma^2}) ||m_i - m_j||^2$
- threshhold sparsity penalty = $1/N \sum \mathcal{I}(n > k) $

In [40]:

import torch
import torch.nn.functional as F
from torch_geometric.explain import GNNExplainer

class CustomGNNExplainer(GNNExplainer):
    
    params = {
        'ts_coef': 0,
        'motif_coef': 0,
        'sparsity_coef': 0,
        'sparsity_threshold': 0
    }
    
    epoch_metrics = {
        'temporal smoothness penalty': [],
        'sparsity penalty': [],
        'motif coherance reward': [],
        'base loss': [],
    }
    
    def __init__(
        self, 
        node_times, 
        motif_groups, 
        **kwargs
    ):
        super().__init__(**kwargs)
        self.params.update(kwargs)
        self.node_times = node_times 
        self.motif_groups = motif_groups

    def temporal_smoothness(self, node_mask):
        order = torch.argsort(self.node_times)
        times = self.node_times[order]
        time_diffs = (times[1:] - times[:-1])
        ordered_node_importances = node_mask[order].mean(axis=1)
        w_ij = torch.exp(- (time_diffs**2) / (2 * time_diffs.std()**2))
        spread = (w_ij * (ordered_node_importances[1:] 
                          - ordered_node_importances[:-1]  ** 2)).sum()
        
        return self.params['ts_coef'] * spread
    
    def motif_coherance(self, node_mask):
        return sum([
            self.params['motif_coef'] * torch.norm(node_mask[g], p=2)
            for g in self.motif_groups
        ]) / len(self.motif_groups)

    def sparsity(self, node_mask):
        sparsity = (node_mask > self.params['sparsity_threshhold']).float().mean()
        return self.params['sparsity_coef'] * sparsity

    def additional_loss_terms(self, node_mask):
        reg = 0

        ts = self.temporal_smoothness(node_mask)
        self.epoch_metrics['temporal smoothness penalty'].append(ts)
        print(f'temporal smoothness penalty {ts}')
        reg += ts

        mc = self.motif_coherance(node_mask)
        self.epoch_metrics['motif coherance reward'].append(mc)
        print(f'motif coherance reward: {mc}')
        reg -= mc
        
        sp = self.sparsity(node_mask)
        self.epoch_metrics['sparsity penalty'].append(sp)
        print(f'sparsity penalty {sp}')
        reg += sp
        
        return reg

    def plot_descent(self):
        with plt.style.context('science'): 
            for l in self.epoch_metrics.values():
                plt.plot(l)
            plt.legend(self.epoch_metrics.keys())
            plt.show()

    # def _loss(self, log_logits, pred_label, node_mask, feat_mask):
    def _loss(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        base_loss = super()._loss(y_hat, y)
        reg_loss = self.additional_loss_terms(self.node_mask)
        print(f'base_loss: {base_loss}')
        print(f'total loss: {base_loss + reg_loss}\n')
        self.epoch_metrics['base loss'].append(base_loss)
        return base_loss + reg_loss
    


In [None]:

metrics = {}
for attack, subG in yield_class_graphs(pyg_lg):
    if attack == 'Benign': 
        continue
    
    print(attack)

    is_star, is_fan = subG.x[:, 51], subG.x[:, 52]
    end_times, start_times = subG.x[:, 50], subG.x[:, 49]
    
    scanning_star_nodes = is_star.nonzero(as_tuple=True)[0]
    fan_nodes = is_fan.nonzero(as_tuple=True)[0]
    
    star_motifs = []
    for hub in scanning_star_nodes.tolist():
        lg_nodes = (src == hub).nonzero(as_tuple=True)[0].tolist()
        if lg_nodes: 
            star_motifs.append(lg_nodes)

    fan_motifs = []
    for sink in fan_nodes.tolist():
        lg_nodes = (dst == sink).nonzero(as_tuple=True)[0].tolist()
        if lg_nodes:
            fan_motifs.append(lg_nodes)
    
    explainer = Explainer(
        model=model,
        algorithm=CustomGNNExplainer(
            epochs=100, 
            node_times = start_times,
            motif_groups = (star_motifs + fan_motifs),
            tv_coef = 1.0,
            motif_coef = 0.01,
            sparsity_coef = 1.0,
            sparsity_threshhold = 0.5,
        ),
        explanation_type='phenomenon',
        node_mask_type='attributes',
        edge_mask_type=None,
        model_config=ModelConfig(
            mode='multiclass_classification',
            task_level='node',
            return_type='raw',
        ),
    )  
    
    print(subG.x.shape)
    print(subG.Attack.shape)

    x = subG.x[:, :49]
    explanation = explainer(
        x=x,
        edge_index=subG.edge_index.to(device),
        target=subG.Attack,
    )
    
    metrics[attack] = explanation 
    
    # softmask metrics
    fp, fn, c = evaluate_softmask(model, subG, explanation.node_mask)
    metrics[f'{attack} softmask metrics'] = fp, fn, c
    print(f'\tfp: {fp:.3f}')
    print(f'\tfn: {fn:.3f}')
    print(f'\tc: {c:.3f}')
    
    # sparsity curve
    metrics[f'{attack} sparsity curve'] = evaluate_sparsity_threshholds(
        model, subG, explanation.node_mask)
    
    # regularization curves
    print(attack)
    with plt.style.context('science'): 
        for m, l in explainer.algorithm.epoch_metrics.items():
            y = [ll.detach().numpy() for ll in l]
            plt.plot(y)

        plt.legend(explainer.algorithm.epoch_metrics.keys())
        plt.show()
    

DDoS
torch.Size([10168, 53])
torch.Size([10168])
temporal smoothness penalty -0.0


IndexError: index 10179 is out of bounds for dimension 0 with size 10168

In [None]:
# plot sparsity curves
with plt.style.context('science'): 
    for metric in ('fid-', 'fid+', 'c'):
        print(metric)
        for attack in le['Attack'].classes_:
            if attack == 'Benign': 
                continue

            m = metrics[f'{attack} sparsity curve']
            plt.plot(m['s'], m[metric])
        
        plt.legend(le['Attack'].classes_[1:])
        plt.show()   

In [23]:
with open('../../interm/nidse_metrics', 'wb') as f:
    pickle.dump(metrics, f)