In [4]:
from utils import load_and_prepare_nb15
from neural_network import *
from dense_branchynet import *
import time
from sklearn.metrics import classification_report


numerical_cols = [
    "NUM_PKTS_128_TO_256_BYTES",
    "RETRANSMITTED_OUT_PKTS",
    "SRC_TO_DST_IAT_STDDEV",
    "SRC_TO_DST_SECOND_BYTES",
    "IN_PKTS",
    "LONGEST_FLOW_PKT",
    "NUM_PKTS_256_TO_512_BYTES",
    "DST_TO_SRC_IAT_AVG",
    "OUT_BYTES",
    "NUM_PKTS_UP_TO_128_BYTES",
    "DURATION_OUT",
    "NUM_PKTS_512_TO_1024_BYTES",
    "SRC_TO_DST_IAT_AVG",
    "DURATION_IN",
    "SHORTEST_FLOW_PKT",
    "RETRANSMITTED_IN_PKTS",
    "FLOW_DURATION_MILLISECONDS",
    "IN_BYTES",
    "MIN_IP_PKT_LEN",
    "TCP_WIN_MAX_OUT",
    "SRC_TO_DST_IAT_MIN",
    "RETRANSMITTED_OUT_BYTES",
    "DST_TO_SRC_IAT_MAX",
    "DST_TO_SRC_SECOND_BYTES",
    "DNS_TTL_ANSWER",
    "NUM_PKTS_1024_TO_1514_BYTES",
    "SRC_TO_DST_AVG_THROUGHPUT",
    "DST_TO_SRC_IAT_STDDEV",
    "OUT_PKTS",
    "SRC_TO_DST_IAT_MAX",
    "TCP_WIN_MAX_IN",
    "MAX_IP_PKT_LEN",
    "DST_TO_SRC_AVG_THROUGHPUT",
    "DST_TO_SRC_IAT_MIN",
    "RETRANSMITTED_IN_BYTES"

    ]

categorical_cols = [
    "PROTOCOL",
    "L7_PROTO",
    "TCP_FLAGS",
    "CLIENT_TCP_FLAGS",
    "SERVER_TCP_FLAGS",
    "ICMP_TYPE",
    "ICMP_IPV4_TYPE",
    "DNS_QUERY_TYPE",
    "FTP_COMMAND_RET_CODE"
    ]

target_col = 'Attack'
num_target_classes = 10
dataset_path = 'datasets/NF-UNSW-NB15-v3.csv'
batch_size = 4096
epochs = 20


In [5]:
train_dataloader, valid_dataloader, test_dataloader, cat_cardinalities, cw, target_names = load_and_prepare_nb15(
    file_path=dataset_path,
    target_col=target_col,
    numerical_cols=numerical_cols,
    categorical_cols=categorical_cols,
    batch_size=batch_size,
)

embedding_dims = [min(50, (card + 1) // 2) for card in cat_cardinalities]

In [3]:
model = NeuralNetwork(
    hidden_layers_sizes=[256, 256, 256], 
    cat_cardinalities=cat_cardinalities,
    embedding_dims=embedding_dims,
    num_numerical_features=len(numerical_cols),
    num_target_classes=num_target_classes,
)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

In [None]:
model.fit(
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    device=device,
    optimizer=optimizer,
    lr_scheduler=scheduler,
    epochs=epochs,
)

--- Epoch: 0  |  Loss: 0.0397  |  F1 Score: 0.4062  |  Accuracy: 0.9869 ---
--- Epoch: 1  |  Loss: 0.0378  |  F1 Score: 0.4324  |  Accuracy: 0.9873 ---
--- Epoch: 2  |  Loss: 0.0369  |  F1 Score: 0.4326  |  Accuracy: 0.9875 ---
--- Epoch: 3  |  Loss: 0.0358  |  F1 Score: 0.4874  |  Accuracy: 0.9881 ---
--- Epoch: 4  |  Loss: 0.0355  |  F1 Score: 0.5167  |  Accuracy: 0.9880 ---
--- Epoch: 5  |  Loss: 0.0353  |  F1 Score: 0.4985  |  Accuracy: 0.9881 ---
--- Epoch: 6  |  Loss: 0.0352  |  F1 Score: 0.4995  |  Accuracy: 0.9883 ---
--- Epoch: 7  |  Loss: 0.0340  |  F1 Score: 0.5500  |  Accuracy: 0.9888 ---
--- Epoch: 8  |  Loss: 0.0338  |  F1 Score: 0.5717  |  Accuracy: 0.9888 ---
--- Epoch: 9  |  Loss: 0.0338  |  F1 Score: 0.5725  |  Accuracy: 0.9889 ---
--- Epoch: 10  |  Loss: 0.0341  |  F1 Score: 0.5607  |  Accuracy: 0.9887 ---
--- Epoch: 11  |  Loss: 0.0339  |  F1 Score: 0.5544  |  Accuracy: 0.9889 ---
--- Epoch: 12  |  Loss: 0.0343  |  F1 Score: 0.5772  |  Accuracy: 0.9887 ---
--- Epoch

In [7]:
def make_optimizer(model, wd=1e-4,
                   lr_trunk=1e-3, lr_heads=2e-3, lr_emb=None):
    if lr_emb is None:
        lr_emb = lr_trunk * 0.5

    trunk_params = list(model.fc1.parameters()) + \
                   list(model.fc2.parameters()) + \
                   list(model.fc3.parameters())
    heads12_params = list(model.head1.parameters()) + list(model.head2.parameters())
    head3_params   = list(model.head3.parameters())

    opt = torch.optim.AdamW([
        {"params": model.embeddings.parameters(), "lr": lr_emb},
        {"params": trunk_params, "lr": lr_trunk},
        {"params": heads12_params, "lr": lr_heads},
        {"params": head3_params, "lr": lr_heads},
    ], weight_decay=wd)
    return opt


In [8]:
branchynet = DenseBranchyNet(
    hidden_layers_sizes=[256, 256, 256],
    taus=[1.4, 1.6],
    alphas=[0.2, 0.8, 0.9],
    cat_cardinalities=cat_cardinalities,
    embedding_dims=embedding_dims,
    num_numerical=len(numerical_cols),
    num_target_classes=num_target_classes
    ).to(device)

In [None]:
branchynet.set_stage("stage0_trunk_final")
optimizer = make_optimizer(branchynet, lr_trunk=1e-3, lr_heads=2e-3, lr_emb=5e-4)
hist0 = branchynet.fit(train_dataloader, valid_dataloader, optimizer, device, epochs=epochs/2)

[Ep 001] train_loss=0.2059 (l1=0.1846 l2=0.0975 l3=0.1010)  | val_acc=0.9847  val_f1=0.9818
[Ep 002] train_loss=0.0860 (l1=0.0643 l2=0.0438 l3=0.0424)  | val_acc=0.9866  val_f1=0.9843
[Ep 003] train_loss=0.0789 (l1=0.0555 l2=0.0405 l3=0.0394)  | val_acc=0.9868  val_f1=0.9848
[Ep 004] train_loss=0.0754 (l1=0.0516 l2=0.0387 l3=0.0379)  | val_acc=0.9871  val_f1=0.9850
[Ep 005] train_loss=0.0730 (l1=0.0490 l2=0.0375 l3=0.0369)  | val_acc=0.9872  val_f1=0.9853
[Ep 006] train_loss=0.0712 (l1=0.0473 l2=0.0366 l3=0.0361)  | val_acc=0.9874  val_f1=0.9855
[Ep 007] train_loss=0.0697 (l1=0.0459 l2=0.0359 l3=0.0353)  | val_acc=0.9876  val_f1=0.9858
[Ep 008] train_loss=0.0685 (l1=0.0448 l2=0.0353 l3=0.0348)  | val_acc=0.9875  val_f1=0.9858
[Ep 009] train_loss=0.0675 (l1=0.0440 l2=0.0348 l3=0.0343)  | val_acc=0.9877  val_f1=0.9861
[Ep 010] train_loss=0.0668 (l1=0.0433 l2=0.0344 l3=0.0340)  | val_acc=0.9875  val_f1=0.9855
[Ep 011] train_loss=0.0657 (l1=0.0426 l2=0.0339 l3=0.0334)  | val_acc=0.9880  va

In [None]:
branchynet.set_stage("stage1_heads_only")
optimizer = make_optimizer(branchynet, lr_trunk=0.0, lr_heads=3e-3, lr_emb=0.0) 
hist1 = branchynet.fit(train_dataloader, valid_dataloader, optimizer, device, epochs=epochs/2)

[Ep 001] train_loss=0.0556 (l1=0.0352 l2=0.0300 l3=0.0273)  | val_acc=0.9887  val_f1=0.9877
[Ep 002] train_loss=0.0549 (l1=0.0324 l2=0.0298 l3=0.0273)  | val_acc=0.9887  val_f1=0.9878
[Ep 003] train_loss=0.0548 (l1=0.0318 l2=0.0298 l3=0.0273)  | val_acc=0.9887  val_f1=0.9878
[Ep 004] train_loss=0.0545 (l1=0.0314 l2=0.0296 l3=0.0273)  | val_acc=0.9887  val_f1=0.9878
[Ep 005] train_loss=0.0538 (l1=0.0306 l2=0.0289 l3=0.0272)  | val_acc=0.9888  val_f1=0.9878
[Ep 006] train_loss=0.0538 (l1=0.0305 l2=0.0289 l3=0.0273)  | val_acc=0.9887  val_f1=0.9878
[Ep 007] train_loss=0.0538 (l1=0.0304 l2=0.0290 l3=0.0272)  | val_acc=0.9887  val_f1=0.9878
[Ep 008] train_loss=0.0537 (l1=0.0304 l2=0.0289 l3=0.0272)  | val_acc=0.9887  val_f1=0.9878
[Ep 009] train_loss=0.0534 (l1=0.0300 l2=0.0286 l3=0.0273)  | val_acc=0.9887  val_f1=0.9878
[Ep 010] train_loss=0.0534 (l1=0.0300 l2=0.0286 l3=0.0273)  | val_acc=0.9887  val_f1=0.9878
[Ep 011] train_loss=0.0534 (l1=0.0300 l2=0.0286 l3=0.0272)  | val_acc=0.9887  va

In [None]:
branchynet.set_stage("stage2_finetune_all")
optimizer = make_optimizer(branchynet, lr_trunk=1e-4, lr_heads=3e-4, lr_emb=5e-5)
hist2 = branchynet.fit(train_dataloader, valid_dataloader, optimizer, device, epochs=epochs/2)

[Ep 001] train_loss=0.0538 (l1=0.0299 l2=0.0287 l3=0.0277)  | val_acc=0.9888  val_f1=0.9878
[Ep 002] train_loss=0.0537 (l1=0.0299 l2=0.0286 l3=0.0276)  | val_acc=0.9887  val_f1=0.9878
[Ep 003] train_loss=0.0536 (l1=0.0298 l2=0.0285 l3=0.0276)  | val_acc=0.9887  val_f1=0.9877
[Ep 004] train_loss=0.0535 (l1=0.0298 l2=0.0285 l3=0.0275)  | val_acc=0.9887  val_f1=0.9878
[Ep 005] train_loss=0.0530 (l1=0.0296 l2=0.0282 l3=0.0273)  | val_acc=0.9887  val_f1=0.9877
[Ep 006] train_loss=0.0529 (l1=0.0296 l2=0.0281 l3=0.0272)  | val_acc=0.9887  val_f1=0.9877
[Ep 007] train_loss=0.0528 (l1=0.0295 l2=0.0281 l3=0.0272)  | val_acc=0.9887  val_f1=0.9878
[Ep 008] train_loss=0.0525 (l1=0.0294 l2=0.0279 l3=0.0271)  | val_acc=0.9887  val_f1=0.9878
[Ep 009] train_loss=0.0525 (l1=0.0294 l2=0.0279 l3=0.0270)  | val_acc=0.9887  val_f1=0.9878
[Ep 010] train_loss=0.0525 (l1=0.0294 l2=0.0279 l3=0.0270)  | val_acc=0.9887  val_f1=0.9877
[Ep 011] train_loss=0.0524 (l1=0.0293 l2=0.0278 l3=0.0270)  | val_acc=0.9887  va

In [12]:
@torch.no_grad()
def measure_stage_costs(model, loader, device, n_batches=5):
    model.eval()
    if device.type == "cuda":
        starter = torch.cuda.Event(enable_timing=True); ender = torch.cuda.Event(enable_timing=True)
    t1, t12, t123 = [], [], []
    for i, (x_num, x_cat, _) in enumerate(loader):
        x_num = x_num.to(device, non_blocking=True); x_cat = x_cat.to(device, non_blocking=True).long()
        if device.type == "cuda": torch.cuda.synchronize()
        # c1
        if device.type == "cuda": starter.record()
        x  = model._embed_input(x_num, x_cat)
        h1 = F.relu(model.fc1(x)); l1 = model.head1(h1)
        if device.type == "cuda": ender.record(); torch.cuda.synchronize(); t1.append(starter.elapsed_time(ender)/1000.0)
        else: t1.append(0)  # su CPU puoi usare time.perf_counter, ma su GPU è più preciso così
        # c1+c2
        if device.type == "cuda": starter.record()
        h2 = F.relu(model.fc2(h1)); l2 = model.head2(h2)
        if device.type == "cuda": ender.record(); torch.cuda.synchronize(); t12.append(starter.elapsed_time(ender)/1000.0)
        # c1+c2+c3
        if device.type == "cuda": starter.record()
        h3 = F.relu(model.fc3(h2)); l3 = model.head3(h3)
        if device.type == "cuda": ender.record(); torch.cuda.synchronize(); t123.append(starter.elapsed_time(ender)/1000.0)

        if i+1 >= n_batches: break

    c1 = sum(t1)/len(t1); c12 = sum(t12)/len(t12); c123 = sum(t123)/len(t123)
    return {"c1": c1, "c2": c12 - c1, "c3": c123 - c12}


In [13]:
@torch.no_grad()
def f1_baseline_head3(model, loader, device):
    from sklearn.metrics import f1_score
    model.eval(); y_true, y_pred = [], []
    for x_num, x_cat, y in loader:
        x_num = x_num.to(device); x_cat = x_cat.to(device).long()
        _, _, out3 = model(x_num, x_cat)
        y_true += y.tolist(); y_pred += out3.argmax(1).cpu().tolist()
    return f1_score(y_true, y_pred, average="weighted")

use_margin = True  # stesso setting di evaluate/predict
base_f1 = f1_baseline_head3(branchynet, valid_dataloader, device)
cost = measure_stage_costs(branchynet, valid_dataloader, device, n_batches=5)

best = branchynet.calibrate_taus(
    valid_dataloader, device, use_margin=use_margin,
    n_grid=21, mode="min_cost_at_f1",
    f1_min=base_f1 - 0.001,                 # es. max -0.1% drop F1
    cost=cost
)
branchynet.taus = [best["t1"], best["t2"]]
print("τ:", branchynet.taus, "F1:", best["f1"], "rates:", best["rates"], "cost_norm:", best["cost"])


τ: [81.59566497802734, 167.44451904296875] F1: 0.9878117999140423 rates: {'r1': 0.005002378020435572, 'r2': 0.004975627176463604, 'r3': 0.9900219948031008} cost_norm: 1.0051191961050954


In [13]:
y_true = torch.cat([y for _, _, y in test_dataloader]).numpy()

In [28]:
start_model = time.time()
model_preds = model.predict(test_dataloader,device)
end_model = time.time()

In [None]:
start_branchynet = time.time()
branchy_preds = branchynet.predict(test_dataloader, device, early_exit=True, use_margin=use_margin)
end_branchynet = time.time()

In [16]:
dur_model   = end_model - start_model
dur_branchy = end_branchynet - start_branchynet

print(f"Classic DNN: {dur_model:.6f} s")
print(f"BranchyNet:  {dur_branchy:.6f} s")

Classic DNN: 3.020097 s
BranchyNet:  3.310293 s


In [17]:
overhead = 100 - ((100 / dur_branchy) * dur_model)
print(f'BranchyNet Overhead: {overhead:.2f}%')

BranchyNet Overhead: 8.77%


In [18]:
print("\n=== Classification Report DNN===")
print(classification_report(y_true, model_preds.numpy(), target_names=target_names, digits=4))
print("\n=== Classification Report BRANCHYNET===")
print(classification_report(y_true, branchy_preds.numpy(), target_names=target_names, digits=4))



=== Classification Report DNN===
                precision    recall  f1-score   support

      Analysis     0.2973    0.2391    0.2651       184
      Backdoor     0.3923    0.0986    0.1577       517
        Benign     1.0000    1.0000    1.0000    322654
           DoS     0.5357    0.1781    0.2673       758
      Exploits     0.7816    0.7944    0.7879      5823
       Fuzzers     0.6922    0.9403    0.7974      3838
       Generic     0.7633    0.6232    0.6862       714
Reconnaissance     0.7284    0.5968    0.6561      1694
     Shellcode     0.7761    0.4370    0.5591       238
         Worms     0.6250    0.7500    0.6818        20

      accuracy                         0.9888    336440
     macro avg     0.6592    0.5658    0.5859    336440
  weighted avg     0.9883    0.9888    0.9879    336440


=== Classification Report BRANCHYNET===
                precision    recall  f1-score   support

      Analysis     0.3309    0.2500    0.2848       184
      Backdoor     0.3081