In [1]:
from utils import load_and_prepare_data
from neural_network import *
from dense_branchynet import *
import time
from sklearn.metrics import classification_report

numerical_cols = [
        "duration",
        "dst_bytes",
        "missed_bytes",
        "src_bytes",
        "src_ip_bytes",
        "src_pkts",
        "dst_pkts",
        "dst_ip_bytes",
        "http_request_body_len",
        "http_response_body_len"

    ]

categorical_cols = [
        "proto",
        "conn_state",
        "http_status_code",
        "http_method",
        "http_orig_mime_types",
        "http_resp_mime_types",
    ]


target_col = 'type'
num_target_classes = 8
dataset_path = 'datasets/http_ton.csv'
batch_size = 2048
epochs = 20

In [2]:
train_dataloader, valid_dataloader, test_dataloader, cat_cardinalities, cw, target_names = load_and_prepare_data(
    file_path=dataset_path,
    target_col=target_col,
    numerical_cols=numerical_cols,
    categorical_cols=categorical_cols,
    batch_size=batch_size,
    rows_to_remove={}
)

embedding_dims = [min(50, (card + 1) // 2) for card in cat_cardinalities]

In [3]:
model = NeuralNetwork(
    hidden_layers_sizes=[256, 256, 256], 
    cat_cardinalities=cat_cardinalities,
    embedding_dims=embedding_dims,
    num_numerical_features=len(numerical_cols),
    num_target_classes=num_target_classes,
)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.85, patience=5)

In [5]:
model.fit(
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    device=device,
    optimizer=optimizer,
    lr_scheduler=scheduler,
    epochs=epochs,
)

--- Epoch: 0  |  Loss: 0.6041  |  F1 Score: 0.5374  |  Accuracy: 0.7383 ---
--- Epoch: 1  |  Loss: 0.6508  |  F1 Score: 0.6078  |  Accuracy: 0.8669 ---
--- Epoch: 2  |  Loss: 0.1581  |  F1 Score: 0.7201  |  Accuracy: 0.9621 ---
--- Epoch: 3  |  Loss: 0.1465  |  F1 Score: 0.7324  |  Accuracy: 0.9618 ---
--- Epoch: 4  |  Loss: 0.1324  |  F1 Score: 0.7317  |  Accuracy: 0.9662 ---
--- Epoch: 5  |  Loss: 0.1547  |  F1 Score: 0.7509  |  Accuracy: 0.9660 ---
--- Epoch: 6  |  Loss: 0.2331  |  F1 Score: 0.7300  |  Accuracy: 0.9645 ---
--- Epoch: 7  |  Loss: 0.1963  |  F1 Score: 0.7416  |  Accuracy: 0.9650 ---
--- Epoch: 8  |  Loss: 0.1678  |  F1 Score: 0.7037  |  Accuracy: 0.9646 ---
--- Epoch: 9  |  Loss: 0.2172  |  F1 Score: 0.7401  |  Accuracy: 0.9652 ---
--- Epoch: 10  |  Loss: 0.2624  |  F1 Score: 0.6573  |  Accuracy: 0.9233 ---
--- Epoch: 11  |  Loss: 0.1312  |  F1 Score: 0.7165  |  Accuracy: 0.9835 ---
--- Epoch: 12  |  Loss: 0.1390  |  F1 Score: 0.7618  |  Accuracy: 0.9825 ---
--- Epoch

In [6]:
def make_optimizer(model, wd=1e-4,
                   lr_trunk=1e-3, lr_heads=2e-3, lr_emb=None):
    if lr_emb is None:
        lr_emb = lr_trunk * 0.5

    trunk_params = list(model.fc1.parameters()) + \
                   list(model.fc2.parameters()) + \
                   list(model.fc3.parameters())
    heads12_params = list(model.head1.parameters()) + list(model.head2.parameters())
    head3_params   = list(model.head3.parameters())

    opt = torch.optim.AdamW([
        {"params": model.embeddings.parameters(), "lr": lr_emb},
        {"params": trunk_params, "lr": lr_trunk},
        {"params": heads12_params, "lr": lr_heads},
        {"params": head3_params, "lr": lr_heads},
    ], weight_decay=wd)
    return opt


In [8]:
branchynet = DenseBranchyNet(
    hidden_layers_sizes=[256, 256, 256],
    taus=[1.4, 1.6],
    alphas=[0.2, 0.8, 0.9],
    cat_cardinalities=cat_cardinalities,
    embedding_dims=embedding_dims,
    num_numerical=len(numerical_cols),
    num_target_classes=num_target_classes
    ).to(device)

In [9]:
branchynet.set_stage("stage0_trunk_final")
optimizer = make_optimizer(branchynet, lr_trunk=1e-3, lr_heads=2e-3, lr_emb=5e-4)
hist0 = branchynet.fit(train_dataloader, valid_dataloader, optimizer, device, epochs=epochs//2)

[Ep 001] train_loss=1.4338 (l1=1.3430 l2=0.8546 l3=0.5350)  | val_acc=0.9251  val_f1=0.9169
[Ep 002] train_loss=0.9660 (l1=1.2144 l2=0.6479 l3=0.2275)  | val_acc=0.9240  val_f1=0.9200
[Ep 003] train_loss=0.8119 (l1=1.1535 l2=0.5337 l3=0.1713)  | val_acc=0.9387  val_f1=0.9335
[Ep 004] train_loss=0.7157 (l1=1.1067 l2=0.4525 l3=0.1472)  | val_acc=0.9518  val_f1=0.9483
[Ep 005] train_loss=0.6517 (l1=1.0694 l2=0.3915 l3=0.1385)  | val_acc=0.8792  val_f1=0.8769
[Ep 006] train_loss=0.5862 (l1=1.0373 l2=0.3417 l3=0.1171)  | val_acc=0.9625  val_f1=0.9592
[Ep 007] train_loss=0.5938 (l1=1.0134 l2=0.3292 l3=0.1419)  | val_acc=0.9635  val_f1=0.9604
[Ep 008] train_loss=0.5772 (l1=0.9943 l2=0.3177 l3=0.1380)  | val_acc=0.9549  val_f1=0.9520
[Ep 009] train_loss=0.5361 (l1=0.9742 l2=0.2961 l3=0.1160)  | val_acc=0.9626  val_f1=0.9603
[Ep 010] train_loss=0.5194 (l1=0.9570 l2=0.2746 l3=0.1204)  | val_acc=0.9656  val_f1=0.9640


In [10]:
branchynet.set_stage("stage1_heads_only")
optimizer = make_optimizer(branchynet, lr_trunk=0.0, lr_heads=3e-3, lr_emb=0.0) 
hist1 = branchynet.fit(train_dataloader, valid_dataloader, optimizer, device, epochs=epochs//2)

[Ep 001] train_loss=0.3172 (l1=0.4326 l2=0.1659 l3=0.1089)  | val_acc=0.9613  val_f1=0.9583
[Ep 002] train_loss=0.2682 (l1=0.3016 l2=0.1374 l3=0.1088)  | val_acc=0.9639  val_f1=0.9618
[Ep 003] train_loss=0.2574 (l1=0.2673 l2=0.1326 l3=0.1088)  | val_acc=0.9623  val_f1=0.9594
[Ep 004] train_loss=0.2478 (l1=0.2476 l2=0.1254 l3=0.1088)  | val_acc=0.9630  val_f1=0.9609
[Ep 005] train_loss=0.2433 (l1=0.2340 l2=0.1233 l3=0.1088)  | val_acc=0.9646  val_f1=0.9626
[Ep 006] train_loss=0.2391 (l1=0.2242 l2=0.1205 l3=0.1088)  | val_acc=0.9643  val_f1=0.9623
[Ep 007] train_loss=0.2361 (l1=0.2163 l2=0.1186 l3=0.1088)  | val_acc=0.9643  val_f1=0.9622
[Ep 008] train_loss=0.2349 (l1=0.2097 l2=0.1188 l3=0.1088)  | val_acc=0.9642  val_f1=0.9620
[Ep 009] train_loss=0.2306 (l1=0.2047 l2=0.1148 l3=0.1088)  | val_acc=0.9650  val_f1=0.9632
[Ep 010] train_loss=0.2294 (l1=0.2021 l2=0.1137 l3=0.1088)  | val_acc=0.9649  val_f1=0.9630


In [11]:
branchynet.set_stage("stage2_finetune_all")
optimizer = make_optimizer(branchynet, lr_trunk=1e-4, lr_heads=3e-4, lr_emb=5e-5)
hist2 = branchynet.fit(train_dataloader, valid_dataloader, optimizer, device, epochs=epochs//2)

[Ep 001] train_loss=0.2110 (l1=0.1961 l2=0.1074 l3=0.0954)  | val_acc=0.9671  val_f1=0.9651
[Ep 002] train_loss=0.1983 (l1=0.1937 l2=0.1026 l3=0.0861)  | val_acc=0.9683  val_f1=0.9666
[Ep 003] train_loss=0.1907 (l1=0.1918 l2=0.0994 l3=0.0808)  | val_acc=0.9683  val_f1=0.9666
[Ep 004] train_loss=0.1840 (l1=0.1901 l2=0.0965 l3=0.0764)  | val_acc=0.9687  val_f1=0.9670
[Ep 005] train_loss=0.1791 (l1=0.1884 l2=0.0942 l3=0.0734)  | val_acc=0.9698  val_f1=0.9681
[Ep 006] train_loss=0.1740 (l1=0.1866 l2=0.0915 l3=0.0705)  | val_acc=0.9702  val_f1=0.9684
[Ep 007] train_loss=0.1698 (l1=0.1851 l2=0.0894 l3=0.0681)  | val_acc=0.9702  val_f1=0.9685
[Ep 008] train_loss=0.1656 (l1=0.1834 l2=0.0873 l3=0.0656)  | val_acc=0.9711  val_f1=0.9694
[Ep 009] train_loss=0.1623 (l1=0.1820 l2=0.0859 l3=0.0636)  | val_acc=0.9708  val_f1=0.9691
[Ep 010] train_loss=0.1585 (l1=0.1806 l2=0.0839 l3=0.0614)  | val_acc=0.9717  val_f1=0.9700


In [12]:
@torch.no_grad()
def measure_stage_costs(model, loader, device, n_batches=5, warmup=1):
    """
    Misura i tempi medi dei tre stadi:
      c1   = embed + fc1 + head1
      c2   = (fc2 + head2) incrementale (cioè oltre c1)
      c3   = (fc3 + head3) incrementale (cioè oltre c1+c2)
    Ritorna: dict {"c1": ..., "c2": ..., "c3": ...} in secondi.
    """
    model.eval()

    # Timer containers
    t_c1, t_c12, t_c123 = [], [], []

    # Limita n_batches a quello che c'è davvero
    max_batches = min(n_batches + warmup, len(loader)) if hasattr(loader, "__len__") else n_batches + warmup
    batches_done = 0

    # Utility per timing
    use_cuda = (device.type == "cuda")

    if use_cuda:
        starter = torch.cuda.Event(enable_timing=True)
        ender   = torch.cuda.Event(enable_timing=True)

    for i, (x_num, x_cat, _) in enumerate(loader):
        if batches_done >= max_batches:
            break

        x_num = x_num.to(device, non_blocking=True)
        x_cat = x_cat.to(device, non_blocking=True).long()

        # --- C1 ---
        if use_cuda:
            torch.cuda.synchronize()
            starter.record()
            x  = model._embed_input(x_num, x_cat)
            h1 = F.relu(model.fc1(x)); l1 = model.head1(h1)
            ender.record(); torch.cuda.synchronize()
            c1 = starter.elapsed_time(ender) / 1000.0
        else:
            t0 = time.perf_counter()
            x  = model._embed_input(x_num, x_cat)
            h1 = F.relu(model.fc1(x)); l1 = model.head1(h1)
            c1 = time.perf_counter() - t0

        # --- C1 + C2 ---
        if use_cuda:
            starter.record()
            h2 = F.relu(model.fc2(h1)); l2 = model.head2(h2)
            ender.record(); torch.cuda.synchronize()
            c12 = starter.elapsed_time(ender) / 1000.0 + c1
        else:
            t1 = time.perf_counter()
            h2 = F.relu(model.fc2(h1)); l2 = model.head2(h2)
            c12 = (time.perf_counter() - t1) + c1

        # --- C1 + C2 + C3 ---
        if use_cuda:
            starter.record()
            h3 = F.relu(model.fc3(h2)); l3 = model.head3(h3)
            ender.record(); torch.cuda.synchronize()
            c123 = starter.elapsed_time(ender) / 1000.0 + c12
        else:
            t2 = time.perf_counter()
            h3 = F.relu(model.fc3(h2)); l3 = model.head3(h3)
            c123 = (time.perf_counter() - t2) + c12

        # Salta i warm-up
        if i >= warmup:
            t_c1.append(c1)
            t_c12.append(c12)
            t_c123.append(c123)

        batches_done += 1

    if len(t_c1) == 0:
        raise RuntimeError("measure_stage_costs: nessun batch misurato (loader troppo corto?)")

    c1   = sum(t_c1)  / len(t_c1)
    c12  = sum(t_c12) / len(t_c12)
    c123 = sum(t_c123)/ len(t_c123)

    # tempi incrementali
    return {"c1": c1, "c2": max(1e-9, c12 - c1), "c3": max(1e-9, c123 - c12)}

In [13]:
@torch.no_grad()
def f1_baseline_head3(model, loader, device):
    from sklearn.metrics import f1_score
    model.eval(); y_true, y_pred = [], []
    for x_num, x_cat, y in loader:
        x_num = x_num.to(device); x_cat = x_cat.to(device).long()
        _, _, out3 = model(x_num, x_cat)
        y_true += y.tolist(); y_pred += out3.argmax(1).cpu().tolist()
    return f1_score(y_true, y_pred, average="weighted")

use_margin = True  # stesso setting di evaluate/predict
base_f1 = f1_baseline_head3(branchynet, valid_dataloader, device)
cost = measure_stage_costs(branchynet, valid_dataloader, device, n_batches=5)

best = branchynet.calibrate_taus(
    valid_dataloader, device, use_margin=use_margin,
    n_grid=21, mode="min_cost_at_f1",
    f1_min=base_f1 - 0.01,                 # es. max -1% drop F1
    cost=cost
)
branchynet.taus = [best["t1"], best["t2"]]
print("τ:", branchynet.taus, "F1:", best["f1"], "rates:", best["rates"], "cost_norm:", best["cost"])


τ: [3.703824996948242, 4.516445159912109] F1: 0.9817587132513133 rates: {'r1': 0.5, 'r2': 0.25, 'r3': 0.25} cost_norm: 0.5165398151965035


In [19]:
y_true = torch.cat([y for _, _, y in test_dataloader]).numpy()

In [33]:
start_model = time.time()
model_preds = model.predict(test_dataloader,device)
end_model = time.time()

In [32]:
start_branchynet = time.time()
branchy_preds = branchynet.predict(test_dataloader, device, early_exit=True, use_margin=use_margin)
end_branchynet = time.time()

In [34]:
dur_model   = end_model - start_model
dur_branchy = end_branchynet - start_branchynet

print(f"Classic DNN: {dur_model:.6f} s")
print(f"BranchyMLP:  {dur_branchy:.6f} s")

Classic DNN: 6.849227 s
BranchyMLP:  5.452931 s


In [35]:
overhead = 100 - ((100 / dur_branchy) * dur_model)
print(f'BranchyNet Overhead: {overhead:.2f}%')

BranchyNet Overhead: -25.61%


In [20]:
print("\n=== Classification Report DNN===")
print(classification_report(y_true, model_preds.numpy(), target_names=target_names, digits=4))
print("\n=== Classification Report BRANCHYNET===")
print(classification_report(y_true, branchy_preds.numpy(), target_names=target_names, digits=4))



=== Classification Report DNN===
              precision    recall  f1-score   support

        ddos     0.9878    0.9845    0.9862     50614
         dos     0.2125    0.6538    0.3208        26
   injection     0.9702    0.9674    0.9688     50968
        mitm     0.0000    0.0000    0.0000         7
      normal     0.9339    0.9426    0.9382      9197
    password     0.9919    0.9944    0.9932    189474
    scanning     0.9923    0.9304    0.9604      4686
         xss     0.9929    0.9929    0.9929    211140

    accuracy                         0.9886    516112
   macro avg     0.7602    0.8083    0.7700    516112
weighted avg     0.9887    0.9886    0.9886    516112


=== Classification Report BRANCHYNET===
              precision    recall  f1-score   support

        ddos     0.9761    0.9806    0.9784     50614
         dos     0.2609    0.6923    0.3789        26
   injection     0.9517    0.9582    0.9549     50968
        mitm     0.0000    0.0000    0.0000         7
   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
