In [1]:
from branch_utils import *
from neural_network import NeuralNetwork
import torch, time
from sklearn.metrics import classification_report, confusion_matrix, silhouette_score, calinski_harabasz_score, davies_bouldin_score
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from branches import *
from utils import load_data


dataset_path = './Dataset/NF-UNSW-NB15-v3.csv'

# Flow features passed to prepare_dl_nb15_branching as numerical columns.
numerical_cols = [
    "NUM_PKTS_128_TO_256_BYTES",
    "RETRANSMITTED_OUT_PKTS",
    "SRC_TO_DST_IAT_STDDEV",
    "SRC_TO_DST_SECOND_BYTES",
    "IN_PKTS",
    "LONGEST_FLOW_PKT",
    "NUM_PKTS_256_TO_512_BYTES",
    "DST_TO_SRC_IAT_AVG",
    "OUT_BYTES",
    "NUM_PKTS_UP_TO_128_BYTES",
    "DURATION_OUT",
    "NUM_PKTS_512_TO_1024_BYTES",
    "SRC_TO_DST_IAT_AVG",
    "DURATION_IN",
    "SHORTEST_FLOW_PKT",
    "RETRANSMITTED_IN_PKTS",
    "FLOW_DURATION_MILLISECONDS",
    "IN_BYTES",
    "MIN_IP_PKT_LEN",
    "TCP_WIN_MAX_OUT",
    "SRC_TO_DST_IAT_MIN",
    "RETRANSMITTED_OUT_BYTES",
    "DST_TO_SRC_IAT_MAX",
    "DST_TO_SRC_SECOND_BYTES",
    "DNS_TTL_ANSWER",
    "NUM_PKTS_1024_TO_1514_BYTES",
    "SRC_TO_DST_AVG_THROUGHPUT",
    "DST_TO_SRC_IAT_STDDEV",
    "OUT_PKTS",
    "SRC_TO_DST_IAT_MAX",
    "TCP_WIN_MAX_IN",
    "MAX_IP_PKT_LEN",
    "DST_TO_SRC_AVG_THROUGHPUT",
    "DST_TO_SRC_IAT_MIN",
    "RETRANSMITTED_IN_BYTES",
]

# Flow features passed to prepare_dl_nb15_branching as categorical columns
# (their cardinalities size the embedder's categorical embeddings).
categorical_cols = [
    "PROTOCOL",
    "L7_PROTO",
    "TCP_FLAGS",
    "CLIENT_TCP_FLAGS",
    "SERVER_TCP_FLAGS",
    "ICMP_TYPE",
    "ICMP_IPV4_TYPE",
    "DNS_QUERY_TYPE",
    "FTP_COMMAND_RET_CODE",
]

target_col = 'Attack'   # label column; its values are the class names used to build the tree
model_name = 'branchy_NB15'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 4096       # batch size for the data loaders

# Seed the RNGs this notebook relies on so that torch/numpy-driven randomness
# (e.g. weight initialisation) repeats across Restart & Run All.
# torch.manual_seed also seeds every CUDA device.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)


In [2]:
# Class hierarchy of the branchy classifier (values of the `Attack` column).
# A node outputs each of its *_included labels directly, and each *_groups entry
# as one aggregated output (printed as group_<i> in the node's class names).
root_included = ['Benign', 'Worms']  # root outputs Benign and Worms directly
root_groups = [
    ['DoS', 'Analysis'],
    ['Backdoor', 'Fuzzers', 'Exploits', 'Generic', 'Reconnaissance', 'Shellcode'],
]

node_1_groups = []  # leaf: classifies DoS and Analysis (root group 0)
node_1_included = list(root_groups[0])

node_2_groups = [['Backdoor', 'Fuzzers'], ['Generic', 'Reconnaissance', 'Shellcode']]
node_2_included = ['Exploits']  # node 2 outputs Exploits directly

node_21_groups = []  # leaf: classifies Backdoor and Fuzzers (node 2 group 0)
node_21_included = list(node_2_groups[0])
node_22_groups = []  # leaf: classifies Generic, Reconnaissance and Shellcode (node 2 group 1)
node_22_included = list(node_2_groups[1])

# Sanity check: every class must end in exactly one leaf, otherwise it would be
# unreachable or ambiguous in the tree.
_leaf_labels = root_included + node_1_included + node_2_included + node_21_included + node_22_included
assert len(_leaf_labels) == len(set(_leaf_labels)), "a class is assigned to more than one leaf"
assert set(root_groups[1]) == set(node_2_included + node_21_included + node_22_included), \
    "node 2's subtree does not cover root group 1 exactly"


In [3]:
# Load the dataset once; this same `data` object is passed to every
# prepare_dl_nb15_branching call below.
data = load_data(dataset_path)

### Root

In [4]:
# Root split: no include/remove filters; the two root_groups are collapsed into
# aggregate labels (printed as group_0 / group_1).
# NOTE(review): `enc` and `scaler` are re-bound by every node's preparation cell,
# so after the last one runs they belong to that node's split. Keep per-node
# copies if the root's encoder/scaler are needed later.
root_train_dataloader, root_valid_dataloader, root_test_dataloader, root_cat_cardinalities, root_cw, root_class_names, enc, scaler = prepare_dl_nb15_branching(
    data=data,
    target_col=target_col,
    numerical_cols=numerical_cols,
    categorical_cols=categorical_cols,
    batch_size=batch_size,  # config constant, not a duplicated literal
    groups=root_groups
)

print(root_class_names)

['Benign', 'Worms', 'group_0', 'group_1']


In [5]:
# Embedder for the root's raw inputs: numeric_dim is the number of numerical
# feature columns; each categorical column gets an embedding of width
# ceil(cardinality / 2), capped at 50.
# NOTE(review): cardinalities come from the root split's preparation; confirm the
# node splits encode categories with the same indices.
embedder = RawEmbedder(
    numeric_dim=len(numerical_cols),
    cat_cardinalities=root_cat_cardinalities,
    embed_dims=[min(50, (n_levels + 1) // 2) for n_levels in root_cat_cardinalities],
).to(device)

In [6]:
# Root node: reads the embedder output and has one output per entry of
# root_class_names (Benign, Worms and the two aggregated groups).
# NOTE(review): `tau` is passed straight to Node; presumably a per-class routing /
# confidence threshold — confirm in `branches`.
root = Node(
    in_dim=embedder.out_dim,
    hidden=256,
    n_classes=len(root_class_names),
    tau={0: 0.99, 1: 0.89, 2: 0, 3: 0},
    embedder=embedder,
).to(device)

# Flatten the root's train/valid loaders through the [root] chain.
# NOTE(review): the names are re-bound to flatten_dl's output, so re-running this
# cell on its own passes already-flattened loaders back in — re-run the data cell first.
# The embedder is only frozen after root training; if flatten_dl precomputes
# embeddings here, confirm that the embedder is still meant to be trained.
root_train_dataloader, root_valid_dataloader = [
    flatten_dl(dataloader=split_dl, chain_node=[root], embedder=embedder,
               batch_size=batch_size, device=device)
    for split_dl in (root_train_dataloader, root_valid_dataloader)
]

### Node 1

In [7]:
# Node 1 split: only DoS / Analysis (node_1_included), as separate labels.
# remove_classes lists every class handled elsewhere in the tree, including
# 'Exploits' (node_2_included), which belongs to node 2. If include_classes
# already filters it out, listing it is redundant but harmless — TODO confirm
# prepare_dl_nb15_branching semantics.
node_1_train_dataloader, node_1_valid_dataloader, node_1_test_dataloader, node_1_cat_cardinalities, node_1_cw, node_1_class_names, enc, scaler = prepare_dl_nb15_branching(
    data=data,
    target_col=target_col,
    numerical_cols=numerical_cols,
    categorical_cols=categorical_cols,
    batch_size=batch_size,  # config constant, not a duplicated literal
    groups=node_1_groups,
    include_classes=node_1_included,
    remove_classes=root_included + node_2_included + node_2_groups[0] + node_2_groups[1]
)

print(node_1_class_names)

['Analysis', 'DoS']


In [8]:
# Node 1: refines root's group_0 (DoS / Analysis). Its input width 256 is
# assumed to equal the root's hidden=256 — TODO confirm what root feeds children.
node_1 = Node(256, 128, n_classes=len(node_1_class_names), tau={0: 0, 1: 0}).to(device)

# NOTE(review): root is only trained in the training cell further down. If
# flatten_dl materialises root's features here, node 1 is trained on features from
# an untrained root. Re-running this cell alone also re-flattens flattened loaders.
node_1_train_dataloader, node_1_valid_dataloader = [
    flatten_dl(dataloader=split_dl, chain_node=[root, node_1], batch_size=batch_size, device=device)
    for split_dl in (node_1_train_dataloader, node_1_valid_dataloader)
]

### Node 2

In [9]:
# Node 2 split: root's group_1. Exploits is kept as its own label, and the two
# node_2_groups are collapsed into aggregate labels (group_0 / group_1).
# Classes handled by the root directly or by node 1 are removed.
node_2_train_dataloader, node_2_valid_dataloader, node_2_test_dataloader, node_2_cat_cardinalities, node_2_cw, node_2_class_names, enc, scaler = prepare_dl_nb15_branching(
    data=data,
    target_col=target_col,
    numerical_cols=numerical_cols,
    categorical_cols=categorical_cols,
    batch_size=batch_size,  # config constant, not a duplicated literal
    groups=node_2_groups,
    include_classes=node_2_included,
    remove_classes=root_included + node_1_included
)

print(node_2_class_names)

['Exploits', 'group_0', 'group_1']


In [10]:
# Node 2: refines root's group_1 into Exploits plus two sub-groups
# (node_2_class_names). Its input width 256 is assumed to equal the root's
# hidden=256 — TODO confirm.
node_2 = Node(256, 128, n_classes=len(node_2_class_names), tau={0: 0.8, 1: 0, 2: 0}).to(device)

# NOTE(review): root is only trained in the training cell further down. If
# flatten_dl materialises root's features here, node 2 sees features from an
# untrained root. Re-running this cell alone also re-flattens flattened loaders.
node_2_train_dataloader, node_2_valid_dataloader = [
    flatten_dl(dataloader=split_dl, chain_node=[root, node_2], batch_size=batch_size, device=device)
    for split_dl in (node_2_train_dataloader, node_2_valid_dataloader)
]

#### Node 21

In [11]:
# Node 21 split: only Backdoor / Fuzzers (node_21_included), as separate labels.
# remove_classes lists every class handled elsewhere in the tree, including
# 'Exploits' (node_2_included), which node 2 outputs directly. If include_classes
# already filters it out, listing it is redundant but harmless — TODO confirm.
node_21_train_dataloader, node_21_valid_dataloader, node_21_test_dataloader, node_21_cat_cardinalities, node_21_cw, node_21_class_names, enc, scaler = prepare_dl_nb15_branching(
    data=data,
    target_col=target_col,
    numerical_cols=numerical_cols,
    categorical_cols=categorical_cols,
    batch_size=batch_size,  # config constant, not a duplicated literal
    groups=node_21_groups,
    include_classes=node_21_included,
    remove_classes=root_included + node_1_included + node_2_included + node_22_included
)

print(node_21_class_names)

['Backdoor', 'Fuzzers']


In [12]:
# Node 21: refines node 2's group_0 (Backdoor / Fuzzers). Its input width 128 is
# assumed to equal node 2's hidden=128 — TODO confirm.
node_21 = Node(128, 64, n_classes=len(node_21_class_names), tau={0: 0, 1: 0}).to(device)

# NOTE(review): root and node 2 are only trained in the training cell further down;
# if flatten_dl materialises their features here, node 21 sees untrained features.
# Re-running this cell alone also re-flattens flattened loaders.
node_21_train_dataloader, node_21_valid_dataloader = [
    flatten_dl(dataloader=split_dl, chain_node=[root, node_2, node_21], batch_size=batch_size, device=device)
    for split_dl in (node_21_train_dataloader, node_21_valid_dataloader)
]

#### Node 22

In [13]:
# Node 22 split: only Generic / Reconnaissance / Shellcode (node_22_included).
# remove_classes lists every class handled elsewhere in the tree, including
# 'Exploits' (node_2_included), which node 2 outputs directly. If include_classes
# already filters it out, listing it is redundant but harmless — TODO confirm.
node_22_train_dataloader, node_22_valid_dataloader, node_22_test_dataloader, node_22_cat_cardinalities, node_22_cw, node_22_class_names, enc, scaler = prepare_dl_nb15_branching(
    data=data,
    target_col=target_col,
    numerical_cols=numerical_cols,
    categorical_cols=categorical_cols,
    batch_size=batch_size,  # config constant, not a duplicated literal
    groups=node_22_groups,
    include_classes=node_22_included,
    remove_classes=root_included + node_1_included + node_2_included + node_21_included
)

print(node_22_class_names)

['Generic', 'Reconnaissance', 'Shellcode']


In [14]:
# Node 22: refines node 2's group_1 (Generic / Reconnaissance / Shellcode). Its
# input width 128 is assumed to equal node 2's hidden=128 — TODO confirm.
node_22 = Node(128, 64, n_classes=len(node_22_class_names), tau={0: 0, 1: 0, 2: 0}).to(device)

# NOTE(review): root and node 2 are only trained in the training cell further down;
# if flatten_dl materialises their features here, node 22 sees untrained features.
# Re-running this cell alone also re-flattens flattened loaders.
node_22_train_dataloader, node_22_valid_dataloader = [
    flatten_dl(dataloader=split_dl, chain_node=[root, node_2, node_22], batch_size=batch_size, device=device)
    for split_dl in (node_22_train_dataloader, node_22_valid_dataloader)
]

### Creo l'albero

In [15]:
# Wire the tree: each child hangs off the output index of the aggregated group
# it refines. Indices are looked up by label instead of being hard-coded, so a
# change in the class ordering returned by prepare_dl_nb15_branching cannot
# silently mis-wire the tree (list.index raises ValueError if a label is missing).
root.add_child(class_idx=root_class_names.index('group_0'), child=node_1)
root.add_child(class_idx=root_class_names.index('group_1'), child=node_2)
node_2.add_child(class_idx=node_2_class_names.index('group_0'), child=node_21)
node_2.add_child(class_idx=node_2_class_names.index('group_1'), child=node_22)

nodes = [root, node_1, node_2, node_21, node_22]

In [16]:
def train_node(node, dl_train, dl_val, cw, epochs=8, lr=1e-3):
    """Train a single tree node on its own (routing OFF), then freeze it.

    Args:
        node: node to train; must expose ``parameters()``, ``fit(...)`` and
            ``freeze()``, as the ``Node`` instances in this notebook do.
        dl_train: training loader, forwarded to ``node.fit``.
        dl_val: validation loader, forwarded to ``node.fit``.
        cw: per-class weights for ``CrossEntropyLoss``. Accepts a tensor, numpy
            array or sequence of floats, or ``None`` for an unweighted loss.
        epochs: number of epochs forwarded to ``node.fit``.
        lr: AdamW learning rate.

    Uses the notebook-level ``device`` for the loss weights and for ``node.fit``.
    """
    optimizer = torch.optim.AdamW(list(node.parameters()), lr=lr)
    # as_tensor accepts tensors, arrays and lists. float32 assumes the model runs in
    # torch's default dtype; a float64 weight would not match float32 logits.
    weight = None if cw is None else torch.as_tensor(cw, dtype=torch.float32, device=device)
    loss_fn = torch.nn.CrossEntropyLoss(weight=weight)
    node.fit(dl_train, dl_val, epochs,
             optimizer=optimizer, loss_fn=loss_fn, device=device)
    # Freeze once trained; presumably stops later nodes' training from updating it.
    node.freeze()

In [18]:
# Train the tree top-down: root first (10 epochs), then each child with the
# train_node defaults. train_node calls freeze() on every node it trains.
print("==> Training ROOT")
train_node(node=root, dl_train=root_train_dataloader, dl_val=root_valid_dataloader, cw=root_cw, epochs=10)
# The embedder is passed to the root; stop its updates before the children train.
for emb_param in embedder.parameters():
    emb_param.requires_grad = False

child_schedule = (
    ("==> Training Node 1", node_1, node_1_train_dataloader, node_1_valid_dataloader, node_1_cw),
    ("==> Training Node 2", node_2, node_2_train_dataloader, node_2_valid_dataloader, node_2_cw),
    ("==> Training Node 21", node_21, node_21_train_dataloader, node_21_valid_dataloader, node_21_cw),
    ("==> Training Node 22", node_22, node_22_train_dataloader, node_22_valid_dataloader, node_22_cw),
)
for banner, child_node, train_dl, valid_dl, class_weights in child_schedule:
    print(banner)
    train_node(child_node, train_dl, valid_dl, cw=class_weights)

==> Training ROOT


KeyboardInterrupt: 