In [1]:
from utils import load_and_prepare_nb15
from neural_network import *
from dense_branchynet import *
import time
from sklearn.metrics import classification_report


numerical_cols = [
    "NUM_PKTS_128_TO_256_BYTES",
    "RETRANSMITTED_OUT_PKTS",
    "SRC_TO_DST_IAT_STDDEV",
    "SRC_TO_DST_SECOND_BYTES",
    "IN_PKTS",
    "LONGEST_FLOW_PKT",
    "NUM_PKTS_256_TO_512_BYTES",
    "DST_TO_SRC_IAT_AVG",
    "OUT_BYTES",
    "NUM_PKTS_UP_TO_128_BYTES",
    "DURATION_OUT",
    "NUM_PKTS_512_TO_1024_BYTES",
    "SRC_TO_DST_IAT_AVG",
    "DURATION_IN",
    "SHORTEST_FLOW_PKT",
    "RETRANSMITTED_IN_PKTS",
    "FLOW_DURATION_MILLISECONDS",
    "IN_BYTES",
    "MIN_IP_PKT_LEN",
    "TCP_WIN_MAX_OUT",
    "SRC_TO_DST_IAT_MIN",
    "RETRANSMITTED_OUT_BYTES",
    "DST_TO_SRC_IAT_MAX",
    "DST_TO_SRC_SECOND_BYTES",
    "DNS_TTL_ANSWER",
    "NUM_PKTS_1024_TO_1514_BYTES",
    "SRC_TO_DST_AVG_THROUGHPUT",
    "DST_TO_SRC_IAT_STDDEV",
    "OUT_PKTS",
    "SRC_TO_DST_IAT_MAX",
    "TCP_WIN_MAX_IN",
    "MAX_IP_PKT_LEN",
    "DST_TO_SRC_AVG_THROUGHPUT",
    "DST_TO_SRC_IAT_MIN",
    "RETRANSMITTED_IN_BYTES"

    ]

categorical_cols = [
    "PROTOCOL",
    "L7_PROTO",
    "TCP_FLAGS",
    "CLIENT_TCP_FLAGS",
    "SERVER_TCP_FLAGS",
    "ICMP_TYPE",
    "ICMP_IPV4_TYPE",
    "DNS_QUERY_TYPE",
    "FTP_COMMAND_RET_CODE"
    ]

target_col = 'Attack'
num_target_classes = 10
dataset_path = 'datasets/NF-UNSW-NB15-v3.csv'
batch_size = 2018
epochs = 20


In [2]:
train_dataloader, valid_dataloader, test_dataloader, cat_cardinalities, cw, target_names = load_and_prepare_nb15(
    file_path=dataset_path,
    target_col=target_col,
    numerical_cols=numerical_cols,
    categorical_cols=categorical_cols,
    batch_size=batch_size,
)

embedding_dims = [min(50, (card + 1) // 2) for card in cat_cardinalities]

In [3]:
model = NeuralNetwork(
    hidden_layers_sizes=[256, 256, 256], 
    cat_cardinalities=cat_cardinalities,
    embedding_dims=embedding_dims,
    num_numerical_features=len(numerical_cols),
    num_target_classes=num_target_classes,
)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

In [6]:
model.fit(
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    device=device,
    optimizer=optimizer,
    lr_scheduler=scheduler,
    epochs=epochs,
)

--- Epoch: 0  |  Loss: 0.0408  |  F1 Score: 0.3490  |  Accuracy: 0.9869 ---
--- Epoch: 1  |  Loss: 0.0383  |  F1 Score: 0.4018  |  Accuracy: 0.9873 ---
--- Epoch: 2  |  Loss: 0.0369  |  F1 Score: 0.4563  |  Accuracy: 0.9876 ---
--- Epoch: 3  |  Loss: 0.0362  |  F1 Score: 0.4586  |  Accuracy: 0.9879 ---
--- Epoch: 4  |  Loss: 0.0359  |  F1 Score: 0.5191  |  Accuracy: 0.9881 ---
--- Epoch: 5  |  Loss: 0.0355  |  F1 Score: 0.5313  |  Accuracy: 0.9882 ---
--- Epoch: 6  |  Loss: 0.0349  |  F1 Score: 0.4677  |  Accuracy: 0.9885 ---
--- Epoch: 7  |  Loss: 0.0344  |  F1 Score: 0.5492  |  Accuracy: 0.9886 ---
--- Epoch: 8  |  Loss: 0.0341  |  F1 Score: 0.5571  |  Accuracy: 0.9887 ---
--- Epoch: 9  |  Loss: 0.0338  |  F1 Score: 0.5570  |  Accuracy: 0.9887 ---
--- Epoch: 10  |  Loss: 0.0338  |  F1 Score: 0.5499  |  Accuracy: 0.9888 ---
--- Epoch: 11  |  Loss: 0.0339  |  F1 Score: 0.5697  |  Accuracy: 0.9889 ---
--- Epoch: 12  |  Loss: 0.0336  |  F1 Score: 0.5715  |  Accuracy: 0.9890 ---
--- Epoch

In [7]:
def make_optimizer(model, wd=1e-4,
                   lr_trunk=1e-3, lr_heads=2e-3, lr_emb=None):
    if lr_emb is None:
        lr_emb = lr_trunk * 0.5

    trunk_params = list(model.fc1.parameters()) + \
                   list(model.fc2.parameters()) + \
                   list(model.fc3.parameters())
    heads12_params = list(model.head1.parameters()) + list(model.head2.parameters())
    head3_params   = list(model.head3.parameters())

    opt = torch.optim.AdamW([
        {"params": model.embeddings.parameters(), "lr": lr_emb},
        {"params": trunk_params, "lr": lr_trunk},
        {"params": heads12_params, "lr": lr_heads},
        {"params": head3_params, "lr": lr_heads},
    ], weight_decay=wd)
    return opt


In [8]:
branchynet = DenseBranchyNet(
    hidden_layers_sizes=[256, 256, 256],
    taus=[1.4, 1.6],
    alphas=[0.2, 0.8, 0.9],
    cat_cardinalities=cat_cardinalities,
    embedding_dims=embedding_dims,
    num_numerical=len(numerical_cols),
    num_target_classes=num_target_classes
    ).to(device)

In [10]:
branchynet.set_stage("stage0_trunk_final")
optimizer = make_optimizer(branchynet, lr_trunk=1e-3, lr_heads=2e-3, lr_emb=5e-4)
hist0 = branchynet.fit(train_dataloader, valid_dataloader, optimizer, device, epochs=epochs//2)

[Ep 001] train_loss=0.1480 (l1=0.1443 l2=0.0692 l3=0.0708)  | val_acc=0.9860  val_f1=0.9834
[Ep 002] train_loss=0.0801 (l1=0.0552 l2=0.0409 l3=0.0404)  | val_acc=0.9868  val_f1=0.9847
[Ep 003] train_loss=0.0748 (l1=0.0493 l2=0.0383 l3=0.0381)  | val_acc=0.9870  val_f1=0.9852
[Ep 004] train_loss=0.0719 (l1=0.0462 l2=0.0370 l3=0.0368)  | val_acc=0.9872  val_f1=0.9853
[Ep 005] train_loss=0.0699 (l1=0.0442 l2=0.0360 l3=0.0359)  | val_acc=0.9878  val_f1=0.9860
[Ep 006] train_loss=0.0682 (l1=0.0429 l2=0.0352 l3=0.0350)  | val_acc=0.9878  val_f1=0.9860
[Ep 007] train_loss=0.0669 (l1=0.0420 l2=0.0345 l3=0.0344)  | val_acc=0.9877  val_f1=0.9861
[Ep 008] train_loss=0.0657 (l1=0.0412 l2=0.0339 l3=0.0337)  | val_acc=0.9882  val_f1=0.9868
[Ep 009] train_loss=0.0648 (l1=0.0406 l2=0.0334 l3=0.0333)  | val_acc=0.9881  val_f1=0.9868
[Ep 010] train_loss=0.0640 (l1=0.0402 l2=0.0330 l3=0.0328)  | val_acc=0.9883  val_f1=0.9868


In [12]:
branchynet.set_stage("stage1_heads_only")
optimizer = make_optimizer(branchynet, lr_trunk=0.0, lr_heads=3e-3, lr_emb=0.0) 
hist1 = branchynet.fit(train_dataloader, valid_dataloader, optimizer, device, epochs=epochs//2)

[Ep 001] train_loss=0.0612 (l1=0.0343 l2=0.0323 l3=0.0316)  | val_acc=0.9885  val_f1=0.9871
[Ep 002] train_loss=0.0607 (l1=0.0335 l2=0.0320 l3=0.0316)  | val_acc=0.9886  val_f1=0.9872
[Ep 003] train_loss=0.0608 (l1=0.0334 l2=0.0320 l3=0.0317)  | val_acc=0.9885  val_f1=0.9872
[Ep 004] train_loss=0.0607 (l1=0.0332 l2=0.0319 l3=0.0317)  | val_acc=0.9886  val_f1=0.9872
[Ep 005] train_loss=0.0604 (l1=0.0331 l2=0.0317 l3=0.0316)  | val_acc=0.9886  val_f1=0.9873
[Ep 006] train_loss=0.0607 (l1=0.0330 l2=0.0319 l3=0.0317)  | val_acc=0.9885  val_f1=0.9873
[Ep 007] train_loss=0.0605 (l1=0.0330 l2=0.0318 l3=0.0316)  | val_acc=0.9886  val_f1=0.9873
[Ep 008] train_loss=0.0596 (l1=0.0320 l2=0.0309 l3=0.0316)  | val_acc=0.9887  val_f1=0.9873
[Ep 009] train_loss=0.0596 (l1=0.0320 l2=0.0309 l3=0.0316)  | val_acc=0.9886  val_f1=0.9873
[Ep 010] train_loss=0.0596 (l1=0.0320 l2=0.0309 l3=0.0316)  | val_acc=0.9886  val_f1=0.9873


In [13]:
branchynet.set_stage("stage2_finetune_all")
optimizer = make_optimizer(branchynet, lr_trunk=1e-4, lr_heads=3e-4, lr_emb=5e-5)
hist2 = branchynet.fit(train_dataloader, valid_dataloader, optimizer, device, epochs=epochs//2)

[Ep 001] train_loss=0.0583 (l1=0.0314 l2=0.0305 l3=0.0307)  | val_acc=0.9888  val_f1=0.9876
[Ep 002] train_loss=0.0575 (l1=0.0312 l2=0.0301 l3=0.0302)  | val_acc=0.9889  val_f1=0.9878
[Ep 003] train_loss=0.0572 (l1=0.0312 l2=0.0300 l3=0.0300)  | val_acc=0.9889  val_f1=0.9879
[Ep 004] train_loss=0.0568 (l1=0.0310 l2=0.0298 l3=0.0298)  | val_acc=0.9889  val_f1=0.9879
[Ep 005] train_loss=0.0567 (l1=0.0310 l2=0.0297 l3=0.0297)  | val_acc=0.9889  val_f1=0.9879
[Ep 006] train_loss=0.0563 (l1=0.0309 l2=0.0295 l3=0.0296)  | val_acc=0.9889  val_f1=0.9879
[Ep 007] train_loss=0.0558 (l1=0.0306 l2=0.0292 l3=0.0292)  | val_acc=0.9890  val_f1=0.9879
[Ep 008] train_loss=0.0556 (l1=0.0306 l2=0.0291 l3=0.0292)  | val_acc=0.9890  val_f1=0.9880
[Ep 009] train_loss=0.0555 (l1=0.0305 l2=0.0291 l3=0.0291)  | val_acc=0.9890  val_f1=0.9879
[Ep 010] train_loss=0.0554 (l1=0.0305 l2=0.0290 l3=0.0290)  | val_acc=0.9890  val_f1=0.9880


In [16]:
@torch.no_grad()
def measure_stage_costs(model, loader, device, n_batches=5, warmup=1):
    """
    Misura i tempi medi dei tre stadi:
      c1   = embed + fc1 + head1
      c2   = (fc2 + head2) incrementale (cioè oltre c1)
      c3   = (fc3 + head3) incrementale (cioè oltre c1+c2)
    Ritorna: dict {"c1": ..., "c2": ..., "c3": ...} in secondi.
    """
    model.eval()

    # Timer containers
    t_c1, t_c12, t_c123 = [], [], []

    # Limita n_batches a quello che c'è davvero
    max_batches = min(n_batches + warmup, len(loader)) if hasattr(loader, "__len__") else n_batches + warmup
    batches_done = 0

    # Utility per timing
    use_cuda = (device.type == "cuda")

    if use_cuda:
        starter = torch.cuda.Event(enable_timing=True)
        ender   = torch.cuda.Event(enable_timing=True)

    for i, (x_num, x_cat, _) in enumerate(loader):
        if batches_done >= max_batches:
            break

        x_num = x_num.to(device, non_blocking=True)
        x_cat = x_cat.to(device, non_blocking=True).long()

        # --- C1 ---
        if use_cuda:
            torch.cuda.synchronize()
            starter.record()
            x  = model._embed_input(x_num, x_cat)
            h1 = F.relu(model.fc1(x)); l1 = model.head1(h1)
            ender.record(); torch.cuda.synchronize()
            c1 = starter.elapsed_time(ender) / 1000.0
        else:
            t0 = time.perf_counter()
            x  = model._embed_input(x_num, x_cat)
            h1 = F.relu(model.fc1(x)); l1 = model.head1(h1)
            c1 = time.perf_counter() - t0

        # --- C1 + C2 ---
        if use_cuda:
            starter.record()
            h2 = F.relu(model.fc2(h1)); l2 = model.head2(h2)
            ender.record(); torch.cuda.synchronize()
            c12 = starter.elapsed_time(ender) / 1000.0 + c1
        else:
            t1 = time.perf_counter()
            h2 = F.relu(model.fc2(h1)); l2 = model.head2(h2)
            c12 = (time.perf_counter() - t1) + c1

        # --- C1 + C2 + C3 ---
        if use_cuda:
            starter.record()
            h3 = F.relu(model.fc3(h2)); l3 = model.head3(h3)
            ender.record(); torch.cuda.synchronize()
            c123 = starter.elapsed_time(ender) / 1000.0 + c12
        else:
            t2 = time.perf_counter()
            h3 = F.relu(model.fc3(h2)); l3 = model.head3(h3)
            c123 = (time.perf_counter() - t2) + c12

        # Salta i warm-up
        if i >= warmup:
            t_c1.append(c1)
            t_c12.append(c12)
            t_c123.append(c123)

        batches_done += 1

    if len(t_c1) == 0:
        raise RuntimeError("measure_stage_costs: nessun batch misurato (loader troppo corto?)")

    c1   = sum(t_c1)  / len(t_c1)
    c12  = sum(t_c12) / len(t_c12)
    c123 = sum(t_c123)/ len(t_c123)

    # tempi incrementali
    return {"c1": c1, "c2": max(1e-9, c12 - c1), "c3": max(1e-9, c123 - c12)}

In [None]:
@torch.no_grad()
def f1_baseline_head3(model, loader, device):
    from sklearn.metrics import f1_score
    model.eval(); y_true, y_pred = [], []
    for x_num, x_cat, y in loader:
        x_num = x_num.to(device); x_cat = x_cat.to(device).long()
        _, _, out3 = model(x_num, x_cat)
        y_true += y.tolist(); y_pred += out3.argmax(1).cpu().tolist()
    return f1_score(y_true, y_pred, average="weighted")

use_margin = True  # stesso setting di evaluate/predict
base_f1 = f1_baseline_head3(branchynet, valid_dataloader, device)
cost = measure_stage_costs(branchynet, valid_dataloader, device, n_batches=5)

best = branchynet.calibrate_taus(
    valid_dataloader, device, use_margin=use_margin,
    n_grid=21, mode="min_cost_at_f1",
    f1_min=base_f1 - 0.01,                 # es. max -1% drop F1
    cost=cost
)
branchynet.taus = [best["t1"], best["t2"]]
print("τ:", branchynet.taus, "F1:", best["f1"], "rates:", best["rates"], "cost_norm:", best["cost"])


τ: [37.200836181640625, 49.46376037597656] F1: 0.9879852739580331 rates: {'r1': 0.5, 'r2': 0.25564736127853394, 'r3': 0.24435263872146606} cost_norm: 0.6522680845615498


In [18]:
y_true = torch.cat([y for _, _, y in test_dataloader]).numpy()

In [28]:
start_model = time.time()
model_preds = model.predict(test_dataloader,device)
end_model = time.time()

In [32]:
start_branchynet = time.time()
branchy_preds = branchynet.predict(test_dataloader, device, early_exit=True, use_margin=use_margin)
end_branchynet = time.time()

In [33]:
dur_model   = end_model - start_model
dur_branchy = end_branchynet - start_branchynet

print(f"Classic DNN: {dur_model:.6f} s")
print(f"BranchyNet:  {dur_branchy:.6f} s")

Classic DNN: 4.830464 s
BranchyNet:  4.041070 s


In [34]:
overhead = 100 - ((100 / dur_branchy) * dur_model)
print(f'BranchyNet Overhead: {overhead:.2f}%')

BranchyNet Overhead: -19.53%


In [23]:
print("\n=== Classification Report DNN===")
print(classification_report(y_true, model_preds.numpy(), target_names=target_names, digits=4))
print("\n=== Classification Report BRANCHYNET===")
print(classification_report(y_true, branchy_preds.numpy(), target_names=target_names, digits=4))



=== Classification Report DNN===
                precision    recall  f1-score   support

      Analysis     0.2800    0.0380    0.0670       184
      Backdoor     0.4688    0.0870    0.1468       517
        Benign     1.0000    1.0000    1.0000    322654
           DoS     0.5841    0.1649    0.2572       758
      Exploits     0.7765    0.8032    0.7896      5823
       Fuzzers     0.6879    0.9502    0.7980      3838
       Generic     0.7357    0.6317    0.6797       714
Reconnaissance     0.7424    0.5939    0.6599      1694
     Shellcode     0.7652    0.4244    0.5459       238
         Worms     0.5714    0.8000    0.6667        20

      accuracy                         0.9889    336440
     macro avg     0.6612    0.5493    0.5611    336440
  weighted avg     0.9883    0.9889    0.9878    336440


=== Classification Report BRANCHYNET===
                precision    recall  f1-score   support

      Analysis     0.3587    0.1793    0.2391       184
      Backdoor     0.5054