In [2]:
import torch_scatter
import torch_sparse
import torch_cluster
import torch_spline_conv
import torch_geometric
from torch_geometric.utils import from_networkx, add_self_loops, degree
from torch_geometric.nn import MessagePassing
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.loader import NeighborSampler
import torch.nn as nn
import torch as th
import torch.nn.functional as F
# import dgl.function as fn
import networkx as nx
import pandas as pd
import socket
import struct
import random
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

In [3]:
data = pd.read_csv('./Dataset/NF-CSE-CIC-IDS2018.csv')
print(data['Attack'].value_counts())

Attack
Benign                      7190742
DDOS attack-HOIC             467052
DoS attacks-Hulk             187046
DDoS attacks-LOIC-HTTP       133048
Bot                           62116
Infilteration                 50089
SSH-Bruteforce                41152
DoS attacks-GoldenEye         12002
FTP-BruteForce                11307
DoS attacks-SlowHTTPTest       6063
DoS attacks-Slowloris          4086
Brute Force -Web                956
DDOS attack-LOIC-UDP            883
Brute Force -XSS                407
SQL Injection                   198
Name: count, dtype: int64


In [4]:
data.drop(columns=['PROTOCOL', 'L7_PROTO', 'TCP_FLAGS', 'CLIENT_TCP_FLAGS', 'SERVER_TCP_FLAGS', 'ICMP_TYPE', 'ICMP_IPV4_TYPE', \
                   'DNS_QUERY_ID', 'DNS_QUERY_TYPE', 'DNS_TTL_ANSWER', 'FTP_COMMAND_RET_CODE'],inplace=True)


In [5]:
print(data.Label.value_counts())

Label
0.0    7190742
1.0     976405
Name: count, dtype: int64


In [6]:
data['IPV4_SRC_ADDR'] = data.IPV4_SRC_ADDR.apply(str)
data['L4_SRC_PORT'] = data.L4_SRC_PORT.apply(int)
data['L4_SRC_PORT'] = data.L4_SRC_PORT.apply(str)
data['IPV4_DST_ADDR'] = data.IPV4_DST_ADDR.apply(str)
data['L4_DST_PORT'] = data.L4_DST_PORT.apply(int)
data['L4_DST_PORT'] = data.L4_DST_PORT.apply(str)
data['IPV4_SRC_ADDR'] = data['IPV4_SRC_ADDR'] + ':' + data['L4_SRC_PORT']
data['IPV4_DST_ADDR'] = data['IPV4_DST_ADDR'] + ':' + data['L4_DST_PORT']
data.rename(columns={"IPV4_SRC_ADDR": "saddr"},inplace = True)
data.rename(columns={"IPV4_DST_ADDR": "daddr"},inplace = True)
data.drop(columns=['L4_SRC_PORT', 'L4_DST_PORT'],inplace=True)

In [7]:
print(data.head)

<bound method NDFrame.head of                          saddr               daddr  IN_BYTES  IN_PKTS  \
0            13.58.98.64:40894     172.31.69.25:22    3164.0     23.0   
1        213.202.230.143:29622  172.31.66.103:3389    1919.0     14.0   
2            172.31.66.5:65456       172.31.0.2:53     116.0      2.0   
3           172.31.64.92:57918       172.31.0.2:53      70.0      1.0   
4           18.219.32.43:63269     172.31.69.25:80     232.0      5.0   
...                        ...                 ...       ...      ...   
8167142     172.31.66.16:59566       172.31.0.2:53      75.0      1.0   
8167143     172.31.68.26:61394       172.31.0.2:53      63.0      1.0   
8167144    52.14.136.135:61501     172.31.69.25:80     232.0      5.0   
8167145     172.31.69.24:60678  172.31.69.14:15002      44.0      1.0   
8167146     172.31.65.86:53240    23.36.33.118:443     971.0      9.0   

         OUT_BYTES  OUT_PKTS  FLOW_DURATION_MILLISECONDS  DURATION_IN  \
0           3765.0  

In [8]:
label_ground_truth = data[["saddr", "daddr", "Label"]]
class_ground_truth = data[["saddr", "daddr", "Attack"]]
# data = pd.get_dummies(data, columns = ['flgs_number','state_number', 'proto_number']) # One Hot Encoding for categorical data

In [9]:
data = data.reset_index()
data.replace([np.inf, -np.inf], np.nan,inplace = True)
data.fillna(0,inplace = True)
data.drop(columns=['index'],inplace=True)
print(data.head)

<bound method NDFrame.head of                          saddr               daddr  IN_BYTES  IN_PKTS  \
0            13.58.98.64:40894     172.31.69.25:22    3164.0     23.0   
1        213.202.230.143:29622  172.31.66.103:3389    1919.0     14.0   
2            172.31.66.5:65456       172.31.0.2:53     116.0      2.0   
3           172.31.64.92:57918       172.31.0.2:53      70.0      1.0   
4           18.219.32.43:63269     172.31.69.25:80     232.0      5.0   
...                        ...                 ...       ...      ...   
8167142     172.31.66.16:59566       172.31.0.2:53      75.0      1.0   
8167143     172.31.68.26:61394       172.31.0.2:53      63.0      1.0   
8167144    52.14.136.135:61501     172.31.69.25:80     232.0      5.0   
8167145     172.31.69.24:60678  172.31.69.14:15002      44.0      1.0   
8167146     172.31.65.86:53240    23.36.33.118:443     971.0      9.0   

         OUT_BYTES  OUT_PKTS  FLOW_DURATION_MILLISECONDS  DURATION_IN  \
0           3765.0  

In [10]:
scaler = StandardScaler()
cols_to_norm = list(set(list(data.iloc[:, 2:].columns ))  - set(list(['Label']))  - set(list(['Attack'])) )
print(data[cols_to_norm].describe()) # Check if there's any too large value

  sqr = _ensure_numeric((avg - values) ** 2)


       MAX_IP_PKT_LEN       IN_PKTS  SRC_TO_DST_AVG_THROUGHPUT      IN_BYTES  \
count    8.167147e+06  8.167147e+06               8.167147e+06  8.167147e+06   
mean     5.642520e+02  2.292705e+01               4.750216e+06  1.785934e+03   
std      5.388275e+02  1.073211e+03               1.055848e+07  6.932865e+04   
min      2.800000e+01  1.000000e+00               0.000000e+00  2.800000e+01   
25%      1.000000e+02  1.000000e+00               5.200000e+05  6.900000e+01   
50%      1.910000e+02  4.000000e+00               9.920000e+05  1.800000e+02   
75%      1.189000e+03  9.000000e+00               9.096000e+06  1.460000e+03   
max      6.521200e+04  2.430310e+05               3.520027e+09  5.760577e+07   

       NUM_PKTS_1024_TO_1514_BYTES  SHORTEST_FLOW_PKT     OUT_BYTES  \
count                 8.167147e+06       8.167147e+06  8.167147e+06   
mean                  4.202104e+00       5.359188e+01  6.907992e+03   
std                   3.239484e+02       3.617786e+01  4.863405e+0

In [11]:
data[cols_to_norm] = data[cols_to_norm].clip(lower=-1e9, upper=1e9)
data[cols_to_norm] = scaler.fit_transform(data[cols_to_norm])

In [12]:
X_train, X_test, y_train, y_test = train_test_split(
     data, label_ground_truth, test_size=0.4, random_state=42, stratify=label_ground_truth.Label)
print(len(X_train))
print(len(X_test))

2450144
1143400


In [13]:
print(X_test.Label.value_counts())

Label
0.0    1006703
1.0     136697
Name: count, dtype: int64


In [None]:
X_train['h'] = X_train[ cols_to_norm ].values.tolist()

In [None]:
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
attack_labels = le.fit_transform(X_train['Attack'])
class_map = le.classes_
print(class_map)
print("Attack label mapping:", dict(zip(class_map, range(len(class_map)))))

In [None]:
# Convert NetworkX graph to PyG graph
G_nx = nx.from_pandas_edgelist(X_train, "saddr", "daddr", ['h', 'Label'], create_using=nx.MultiDiGraph())
G_pyg = from_networkx(G_nx)

num_nodes = G_pyg.num_nodes
num_edges = G_pyg.num_edges

G_pyg.x = th.ones(num_nodes, len(X_train['h'].iloc[0])) 

edge_attr_list = []
edge_label_list = []

for u, v, key, data in G_nx.edges(keys=True, data=True):
    edge_attr_list.append(data['h']) 
    edge_label_list.append(data['Label']) 

G_pyg.edge_attr = th.tensor(edge_attr_list, dtype=th.float32)
G_pyg.edge_label = th.tensor(edge_label_list, dtype=th.long)
G_pyg.edge_class = th.tensor(attack_labels, dtype=th.long)

print("Number of edges in G_pyg:", G_pyg.num_edges)
print("Number of node in G_pyg:", G_pyg.num_nodes)
print("Shape of node in G_pyg:", G_pyg.x.shape)
print("Shape of edge attr in G_pyg:", G_pyg.edge_attr.shape)
print("Shape of edge label in G_pyg:", G_pyg.edge_label.shape)
print("Shape of edge class in G_pyg:", G_pyg.edge_class.shape)

In [14]:

class EGraphSAGEConv(MessagePassing):
    def __init__(self, node_in_channels, edge_in_channels, out_channels):
        super(EGraphSAGEConv, self).__init__(aggr='mean')  # mean aggregation
        self.lin_node = nn.Linear(node_in_channels, out_channels)
        self.lin_edge = nn.Linear(edge_in_channels, out_channels)
        self.lin_update = nn.Linear(node_in_channels + out_channels, out_channels) # out_channels * 2

    def forward(self, x, edge_index, edge_attr):
        # x: Node features, edge_attr: Edge features, edge_index: Connectivity
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        if edge_attr is not None:
            if edge_attr.size(0) != edge_index.size(1):
                loop_attr = th.zeros((edge_index.size(1) - edge_attr.size(0), edge_attr.size(1))).to(edge_attr.device)
                edge_attr = th.cat([edge_attr, loop_attr], dim=0)
        else:
            print("edge_attr is unexist")
        
        # Propagate and aggregate neighbor information
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        # x_j represents the adjacent nodes of x
        # Compute messages by combining node and edge features
        return self.lin_node(x_j) + self.lin_edge(edge_attr)

    def update(self, aggr_out, x):
        # Update node features after message passing
        return self.lin_update(th.cat([x, aggr_out], dim=1))

class MLPPredictor(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MLPPredictor, self).__init__()
        self.lin = nn.Linear(in_channels * 2, out_channels)

    def forward(self, data, z):
        row, col = data.edge_index
        # Concatenate the features of source and target nodes for each edge
        edge_feat = th.cat([z[row], z[col]], dim=1)
        return self.lin(edge_feat)

class EGraphSAGE(nn.Module):
    def __init__(self, node_in_channels, edge_in_channels, hidden_channels, out_channels):
        super(EGraphSAGE, self).__init__()
        self.conv1 = EGraphSAGEConv(node_in_channels, edge_in_channels, hidden_channels)
        self.conv2 = EGraphSAGEConv(hidden_channels, edge_in_channels, hidden_channels)
        self.mlp_predictor = MLPPredictor(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index, edge_attr)
        return self.mlp_predictor(data, x)




In [15]:
device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
print(device)

cuda:0


In [16]:
th.cuda.empty_cache()

In [None]:
model = EGraphSAGE(node_in_channels=G_pyg.num_node_features, 
                   edge_in_channels=G_pyg.num_edge_features,
                   hidden_channels=128, 
                   out_channels=2).to(device)

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0)

model.apply(init_weights)

labels = G_pyg.edge_label.cpu().numpy()
class_weights = class_weight.compute_class_weight('balanced',
                                                  classes=np.unique(labels),
                                                  y=labels)

class_weights = th.FloatTensor(class_weights).cuda()
criterion = nn.CrossEntropyLoss(weight = class_weights)
optimizer = th.optim.Adam(model.parameters(), lr=0.001)

In [None]:
from torch_geometric.utils import subgraph
from torch_geometric.data import Data

def compute_accuracy(pred, labels):
    return (pred.argmax(1) == labels).float().mean().item()

G_pyg.edge_label = G_pyg.edge_label.to(device)
G_pyg.edge_attr = G_pyg.edge_attr.to(device)

def generate_edge_based_batches_with_node_expansion(graph, batch_size, min_nodes):
    num_edges = graph.edge_index.size(1)  # Get total number of edges
    edge_indices = th.arange(num_edges)   # Create list of edge indices
    num_edges_processed = 0
    
    while num_edges_processed < num_edges:
        # Select a batch of edges
        batch_edge_indices = edge_indices[num_edges_processed : min(num_edges_processed + batch_size, num_edges)]
        edge_index = graph.edge_index[:, batch_edge_indices]
        
        # Update the number of edges processed
        num_edges_processed += batch_size
        
        # Get the unique nodes associated with these edges
        batch_nodes = th.cat([edge_index[0], edge_index[1]]).unique()

        # Check if the batch has enough unique nodes
        while batch_nodes.size(0) < min_nodes:
            # Sample additional neighboring nodes to ensure diversity
            additional_edges = int(batch_size / 8)  # Ensure additional_edges is an integer
            batch_edge_indices = th.cat([batch_edge_indices, edge_indices[num_edges_processed : min(num_edges_processed + additional_edges, num_edges)]])
            edge_index = graph.edge_index[:, batch_edge_indices]
            batch_nodes = th.cat([edge_index[0], edge_index[1]]).unique()
            num_edges_processed += additional_edges

            # Avoid potential infinite loops by breaking if no more edges can be added
            if num_edges_processed >= num_edges:
                break

        # Create subgraph from the selected nodes and edges
        edge_index, _, edge_mask = subgraph(batch_nodes, graph.edge_index, relabel_nodes=True, return_edge_mask=True)

        # Use edge_mask to select edge attributes and labels
        edge_attr = graph.edge_attr[edge_mask]
        edge_label = graph.edge_label[edge_mask]

        yield batch_nodes, edge_index, edge_attr, edge_label

batch_size = 64
for epoch in range(5):
    print(f'epoch : {epoch}')
    all_preds = []
    all_labels = []
    
    try:
        for batch_idx, (batch_nodes, edge_index, edge_attr, edge_label) in enumerate(generate_edge_based_batches_with_node_expansion(G_pyg, batch_size, 20)):
            # print(f"Processing epoch {epoch}, batch {batch_idx} with {batch_nodes.size(0)} nodes and {edge_index.size(1)} edges")
            batch = Data(x=G_pyg.x[batch_nodes], edge_index=edge_index, edge_attr=edge_attr, edge_label=edge_label)
            
            if batch.edge_index.size(1) == 0 or batch.edge_label.size(0) == 0:
                print(f"Warning: Empty batch at batch {batch_idx}")
                continue
                
            if batch is None or batch.num_nodes == 0:
                print(f"Warning: Empty batch at Batch {batch_idx}")
                continue 
    
            if th.isnan(batch.x).any() or th.isinf(batch.x).any() or th.isnan(batch.edge_attr).any() or th.isinf(batch.edge_attr).any():
                print(f"Warning: batch x and edge_attr contains NaN or Inf at Batch {batch_idx}")
                continue 
                
            try:
                batch = batch.to(device)
            except Exception as batch_error:
                print(f"Error moving batch to device at Batch {batch_idx}: {batch_error}")
                continue
            
            try:
                out = model(batch)
    
                if th.isnan(out).any() or th.isinf(out).any():
                    print(f"Warning: out contains NaN or Inf at Batch {batch_idx}")
                    continue 
                all_preds.append(out)
                all_labels.append(batch.edge_label)
    
                loss = criterion(out, batch.edge_label)
                if th.isnan(loss):
                    print(f"loss: {loss}")
                    print(f"out: {out}")
                    print(f"edge_labels: {batch.edge_label}")
                    
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
            except Exception as forward_error:
                print(f"Error during forward/backward pass at Epoch {epoch}, Batch {batch_idx}: {forward_error}")
                continue
        
        all_preds = th.cat(all_preds)
        all_labels = th.cat(all_labels)
        
        epoch_accuracy = compute_accuracy(all_preds, all_labels)
        print(f'Epoch {epoch}, Loss: {loss:.4f}, Accuracy: {epoch_accuracy:.4f}')
        print(all_labels.shape)

    except Exception as e:
        print(f"An error occurred at epoch {epoch}, batch {batch_idx}: {str(e)}")
print("Training is over")

In [None]:
th.save(model.state_dict(), "./Weights/GNN_model_weights_CICIDS2018_subset_2.pth")

In [17]:
from sklearn.preprocessing import LabelEncoder

test_le = LabelEncoder()
X_test['Attack'] = test_le.fit_transform(X_test['Attack'])
test_class_map = test_le.classes_
print(test_class_map)
print("Attack label mapping:", dict(zip(test_class_map, range(len(test_class_map)))))

['Benign' 'Bot' 'Brute Force -Web' 'Brute Force -XSS' 'DDOS attack-HOIC'
 'DDOS attack-LOIC-UDP' 'DDoS attacks-LOIC-HTTP' 'DoS attacks-GoldenEye'
 'DoS attacks-Hulk' 'DoS attacks-SlowHTTPTest' 'DoS attacks-Slowloris'
 'FTP-BruteForce' 'Infilteration' 'SQL Injection' 'SSH-Bruteforce']
Attack label mapping: {'Benign': 0, 'Bot': 1, 'Brute Force -Web': 2, 'Brute Force -XSS': 3, 'DDOS attack-HOIC': 4, 'DDOS attack-LOIC-UDP': 5, 'DDoS attacks-LOIC-HTTP': 6, 'DoS attacks-GoldenEye': 7, 'DoS attacks-Hulk': 8, 'DoS attacks-SlowHTTPTest': 9, 'DoS attacks-Slowloris': 10, 'FTP-BruteForce': 11, 'Infilteration': 12, 'SQL Injection': 13, 'SSH-Bruteforce': 14}


In [20]:
X_test['h'] = X_test[ cols_to_norm ].values.tolist()

G_nx_test = nx.from_pandas_edgelist(X_test, "saddr", "daddr", ['h', 'Label', 'Attack'], create_using=nx.MultiDiGraph())

G_pyg_test = from_networkx(G_nx_test)

test_num_nodes = G_pyg_test.num_nodes
test_num_edges = G_pyg_test.num_edges

G_pyg_test.x = th.ones(test_num_nodes, len(X_test['h'].iloc[0]))

test_edge_attr_list = []
test_edge_label_list = []
test_edge_class_list = []

for u, v, key, data in G_nx_test.edges(keys=True, data=True):
    test_edge_attr_list.append(data['h']) 
    test_edge_label_list.append(data['Label']) 
    test_edge_class_list.append(data['Attack'])

G_pyg_test.edge_attr = th.tensor(test_edge_attr_list, dtype=th.float32)
G_pyg_test.edge_label = th.tensor(test_edge_label_list, dtype=th.long)
G_pyg_test.edge_class = th.tensor(test_edge_class_list, dtype=th.long)

print("Number of edges in G_pyg_test:", G_pyg_test.num_edges)
print("Number of node in G_pyg_test:", G_pyg_test.num_nodes)
print("Shape of node in G_pyg_test:", G_pyg_test.x.shape)
print("Shape of edge attr in G_pyg_test:", G_pyg_test.edge_attr.shape)
print("Shape of edge label in G_pyg_test:", G_pyg_test.edge_label.shape)
print("Shape of edge class in G_pyg_test:", G_pyg_test.edge_class.shape)

Number of edges in G_pyg_test: 1143400
Number of node in G_pyg_test: 1080734
Shape of node in G_pyg_test: torch.Size([1080734, 28])
Shape of edge attr in G_pyg_test: torch.Size([1143400, 28])
Shape of edge label in G_pyg_test: torch.Size([1143400])
Shape of edge class in G_pyg_test: torch.Size([1143400])


In [22]:
from torch_geometric.utils import subgraph
from torch_geometric.data import Data

def compute_accuracy(pred, labels):
    return (pred.argmax(1) == labels).float().mean().item()

new_model_2 = EGraphSAGE(node_in_channels=G_pyg_test.num_node_features, 
                       edge_in_channels=G_pyg_test.num_edge_features,
                       hidden_channels=128, 
                       out_channels=2).to(device)

new_model_2.load_state_dict(th.load("./Weights/GNN_model_weights_CICIDS2018_subset_2.pth", weights_only=True))

def generate_edge_based_batches_with_node_expansion(graph, batch_size, min_nodes):
    num_edges = graph.edge_index.size(1) 
    edge_indices = th.arange(num_edges)  
    num_edges_processed = 0
    
    while num_edges_processed < num_edges:
        # Select a batch of edges
        batch_edge_indices = edge_indices[num_edges_processed : min(num_edges_processed + batch_size, num_edges)]
        edge_index = graph.edge_index[:, batch_edge_indices]
        
        # Update the number of edges processed
        num_edges_processed += batch_size
        
        # Get the unique nodes associated with these edges
        batch_nodes = th.cat([edge_index[0], edge_index[1]]).unique()

        # Check if the batch has enough unique nodes
        while batch_nodes.size(0) < min_nodes:
            # Sample additional neighboring nodes to ensure diversity
            additional_edges = int(batch_size / 8)  # Ensure additional_edges is an integer
            batch_edge_indices = th.cat([batch_edge_indices, edge_indices[num_edges_processed : min(num_edges_processed + additional_edges, num_edges)]])
            edge_index = graph.edge_index[:, batch_edge_indices]
            batch_nodes = th.cat([edge_index[0], edge_index[1]]).unique()
            num_edges_processed += additional_edges

            # Avoid potential infinite loops by breaking if no more edges can be added
            if num_edges_processed >= num_edges:
                break

        # Create subgraph from the selected nodes and edges
        edge_index, _, edge_mask = subgraph(batch_nodes, graph.edge_index, relabel_nodes=True, return_edge_mask=True)

        # Use edge_mask to select edge attributes and labels
        edge_attr = graph.edge_attr[edge_mask]
        edge_label = graph.edge_label[edge_mask]
        edge_class = graph.edge_class[edge_mask]

        yield batch_nodes, edge_index, edge_attr, edge_label, edge_class


new_model_2.eval()

all_test_preds = []
all_test_labels = []
all_test_classes = []
attack_class_performance = {attack_type: {'correct': 0, 'incorrect': 0} for attack_type in test_class_map}

batch_size = 64

print("inference start")
with th.no_grad():
    for batch_idx, (batch_nodes, edge_index, edge_attr, edge_label, edge_class) in enumerate(generate_edge_based_batches_with_node_expansion(G_pyg_test, batch_size, 20)):
        # print(f"Processing batch {batch_idx} with {batch_nodes.size(0)} nodes and {edge_index.size(1)} edges")
        batch = Data(x=G_pyg_test.x[batch_nodes], edge_index=edge_index, edge_attr=edge_attr, edge_label=edge_label)
        
        if batch.edge_index.size(1) == 0 or batch.edge_label.size(0) == 0:
            print(f"Warning: Empty batch at batch {batch_idx}")
            continue
            
        if batch is None or batch.num_nodes == 0:
            print(f"Warning: Empty batch at Batch {batch_idx}")
            continue

        if th.isnan(batch.x).any() or th.isinf(batch.x).any() or th.isnan(batch.edge_attr).any() or th.isinf(batch.edge_attr).any():
            print(f"Warning: batch x and edge_attr contains NaN or Inf at Batch {batch_idx}")
            continue
            
        try:
            batch = batch.to(device)
        except Exception as batch_error:
            print(f"Error moving batch to device at Batch {batch_idx}: {batch_error}")
            continue
        
        try:
            out = new_model_2(batch)

            if th.isnan(out).any() or th.isinf(out).any():
                print(f"Warning: out contains NaN or Inf at Batch {batch_idx}")
                continue 
            
            all_test_preds.append(out)
            all_test_labels.append(edge_label)
            all_test_classes.append(edge_class)

            pred = out.argmax(dim=1) 

            for i in range(len(pred)):
                true_label = edge_label[i].item()
                predicted_label = pred[i].item()
                attack_type = test_le.inverse_transform([edge_class[i].item()])[0] 

                if true_label == 0 and attack_type != 'Benign':
                    print('this sample is Benign but label is wrong')
                
                if true_label == predicted_label:
                    attack_class_performance[attack_type]['correct'] += 1
                else:
                    attack_class_performance[attack_type]['incorrect'] += 1
        except Exception as forward_error:
            print(f"Error during forward/backward pass at Batch {batch_idx}: {forward_error}")
            continue

print("inference done")
all_test_preds = th.cat(all_test_preds).to(device)
all_test_labels = th.cat(all_test_labels).to(device)
all_test_classes = th.cat(all_test_classes).to(device)

test_accuracy = compute_accuracy(all_test_preds, all_test_labels)
print(f'Test Accuracy: {test_accuracy:.4f}')



inference start
inference done
torch.Size([1146184, 2])
torch.Size([1146184])
torch.Size([1146184])
Test Accuracy: 0.5270


In [23]:
from sklearn.metrics import confusion_matrix

pred_labels = all_test_preds.argmax(dim=1)

pred_labels = pred_labels.cpu()
all_test_labels = all_test_labels.cpu()

cm = confusion_matrix(all_test_labels, pred_labels)

TN, FP, FN, TP = cm.ravel()

print(f'True Positives (TP): {TP}')
print(f'False Positives (FP): {FP}')
print(f'True Negatives (TN): {TN}')
print(f'False Negatives (FN): {FN}')

accuracy = (TP + TN) / (TP + TN + FP + FN)
precision = TP / (TP + FP) if (TP + FP) > 0 else 0
recall = TP / (TP + FN) if (TP + FN) > 0 else 0
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1_score:.4f}')


True Positives (TP): 5623
False Positives (FP): 409411
True Negatives (TN): 598395
False Negatives (FN): 132755
Accuracy: 0.5270
Precision: 0.0135
Recall: 0.0406
F1 Score: 0.0203


In [26]:
for attack_type, performance in attack_class_performance.items():
    total_samples = performance['correct'] + performance['incorrect']
    if attack_type != 'Benign':
        sum += total_samples
    accuracy = performance['correct'] / total_samples if total_samples > 0 else 0
    print(f"Attack Type: {attack_type}, Accuracy: {accuracy:.4f}, Total Samples: {total_samples}, Correct Samples: {performance['correct']}, Incorrect Samples: {performance['incorrect']}")

Attack Type: Benign, Accuracy: 0.5938, Total Samples: 1007806, Correct Samples: 598395, Incorrect Samples: 409411
Attack Type: Bot, Accuracy: 0.0002, Total Samples: 8772, Correct Samples: 2, Incorrect Samples: 8770
Attack Type: Brute Force -Web, Accuracy: 0.0000, Total Samples: 149, Correct Samples: 0, Incorrect Samples: 149
Attack Type: Brute Force -XSS, Accuracy: 0.0000, Total Samples: 51, Correct Samples: 0, Incorrect Samples: 51
Attack Type: DDOS attack-HOIC, Accuracy: 0.0000, Total Samples: 65972, Correct Samples: 1, Incorrect Samples: 65971
Attack Type: DDOS attack-LOIC-UDP, Accuracy: 0.0000, Total Samples: 135, Correct Samples: 0, Incorrect Samples: 135
Attack Type: DDoS attacks-LOIC-HTTP, Accuracy: 0.1121, Total Samples: 18991, Correct Samples: 2129, Incorrect Samples: 16862
Attack Type: DoS attacks-GoldenEye, Accuracy: 0.0369, Total Samples: 1706, Correct Samples: 63, Incorrect Samples: 1643
Attack Type: DoS attacks-Hulk, Accuracy: 0.0110, Total Samples: 26815, Correct Samples