### Setup

In [1]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import pickle 
import torch as th
import torch.nn.functional as F
import torch
import networkx as nx
import pandas as pd
import numpy as np
import dgl 

import torch_geometric
from torch_geometric.explain import Explainer, CaptumExplainer, DummyExplainer, GNNExplainer
from torch_geometric.explain.metric import *
from torch_geometric.nn.models.basic_gnn import GraphSAGE
from torch_geometric.utils import from_dgl
from tqdm import tqdm
from torch_geometric.explain import ModelConfig

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ModuleNotFoundError: No module named 'matplotlib'

In [27]:
test = pd.read_csv('../../interm/BoT_test.csv')
attrs = [c for c in test.columns if c not in ("src", "dst", "Attack", "x")]
test['x'] = test[attrs].values.tolist()

def to_graph(data, linegraph=True):
    G = nx.from_pandas_edgelist(data, source='src', 
                                target='dst', 
                                edge_attr=['x', 'Attack'], 
                                create_using=nx.MultiGraph())
    
    G = G.to_directed()
    g = dgl.from_networkx(G, edge_attrs=[ 'x', 'Attack'])
    if linegraph:
        return g.line_graph(shared=True)
    else:
        return g

model = GraphSAGE(
    49,
    hidden_channels=256,
    out_channels=5,
    num_layers=3,
).to(device)
model.load_state_dict(th.load('../../interm/GraphSAGE_BoTIoT.pth'))
model.eval()

  model.load_state_dict(th.load('../../interm/GraphSAGE_BoTIoT.pth'))


GraphSAGE(49, 5, num_layers=3)

In [None]:
G = to_graph(test, linegraph=False)
G = from_dgl(G)

# scaning-stars
scanning_stars = []
for i in tqdm(range(len(G.x))):
    mask = G.edge_index[0] == i
    scanners = G.edge_index[1][mask]
    if len(scanners) > 10:
        scanning_stars.append(i)

len(scanning_stars)

100%|██████████| 29760/29760 [00:02<00:00, 13586.45it/s]


85

In [35]:
import copy

X = copy.deepcopy(G.x)

motif_mask = torch.zeros(X.shape[0])
motif_mask[scanning_stars] = 1

X = torch.cat((X, [motif_mask]))
X.shape

TypeError: expected Tensor as element 1 in argument 0, but got list

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GNNExplainer

class GNNExplainerPP(GNNExplainer):
    def __init__(self, model, epochs=100,
                 tv_coef=1e-2, motif_coef=1e-2, leak_coef=1e-2,
                 leak_feat_idx=None, motif_groups=None, node_times=None, **kwargs):
        super().__init__(model, epochs=epochs, **kwargs)
        self.tv_coef = tv_coef
        self.motif_coef = motif_coef
        self.leak_coef = leak_coef
        self.leak_feat_idx = leak_feat_idx      # LongTensor of node feature indices
        self.motif_groups = motif_groups or []  # list of LongTensors of *node indices*
        self.node_times = node_times            # 1D tensor aligned with nodes

    def additional_loss_terms(self, node_mask, feat_mask):
        reg = 0

        # (1) Temporal smoothness on node mask
        if self.node_times is not None:
            order = torch.argsort(self.node_times)
            diffs = torch.abs(node_mask[order][1:] - node_mask[order][:-1])
            reg = reg + self.tv_coef * diffs.sum()

        # (2) Motif group-lasso: encourage selecting / dropping entire node groups
        for g in self.motif_groups:
            reg = reg + self.motif_coef * torch.norm(node_mask[g], p=2)

        # (3) Leakage penalty on identifier-type features
        if self.leak_feat_idx is not None and feat_mask is not None:
            reg = reg + self.leak_coef * feat_mask[self.leak_feat_idx].abs().sum()

        return reg

    def loss(self, pred, target, node_mask=None, feat_mask=None):
        pred_loss = F.nll_loss(pred, target)
        reg_loss = self.additional_loss_terms(node_mask, feat_mask)
        return pred_loss + reg_loss


In [None]:
# explainer = GNNExplainerPP(
#     model, epochs=200,
#     tv_coef=1e-2, motif_coef=5e-3, leak_coef=1e-2,
#     leak_feat_idx=torch.tensor([0,1,2]),   # e.g. identifier features
#     motif_groups=[torch.tensor([0,3,7])], # toy motif groups
#     edge_times=torch.rand(data.edge_index.size(1)) # toy edge times
# )

# node_idx = 0
# node_feat_mask, edge_mask = explainer.explain_node(node_idx, data.x, data.edge_index)