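Addestriamo il generatore di maschere (`MaskGenerator`) sul classificatore già addestrato (`best_model.pt`) e verifichiamo se l'inferenza con le maschere riduce il tempo di inferenza senza peggiorare le metriche di classificazione sul test set.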

In [6]:
from utils import load_and_prepare_data
from neural_network import NeuralNetwork
from masknet import *
import torch
from sklearn.metrics import classification_report, confusion_matrix


In [7]:
# --- Reproducibility --------------------------------------------------------
# MaskGenerator initialisation and mask training are stochastic; without a seed
# every Restart & Run All produces different masks and metrics.
# torch.manual_seed seeds torch's RNG on all devices (CPU and CUDA).
# NOTE(review): any numpy/sklearn randomness inside load_and_prepare_data
# (e.g. the train/test split) is not covered by this — confirm in utils.
SEED = 42
torch.manual_seed(SEED)

# --- Feature schema -----------------------------------------------------------
# Numerical feature columns read from the dataset CSV.
numerical_cols = [
    "duration",
    "dst_bytes",
    "missed_bytes",
    "src_bytes",
    "src_ip_bytes",
    "src_pkts",
    "dst_pkts",
    "dst_ip_bytes",
    "http_request_body_len",
    "http_response_body_len",
]

# Categorical feature columns read from the dataset CSV.
categorical_cols = [
    "proto",
    "conn_state",
    "http_status_code",
    "http_method",
    "http_orig_mime_types",
    "http_resp_mime_types",
]

target_col = 'type'
# Rows to exclude, keyed by column name. The classification reports below
# accordingly contain no 'mitm' / 'dos' classes.
values_to_remove = {'type': ['mitm', 'dos']}

# --- Runtime ------------------------------------------------------------------
dataset_path = './Dataset/http_ton.csv'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Threshold passed both to mask training (fit) and to masked inference/benchmarking,
# so the masks are evaluated under the same cut-off they were trained with.
MASK_THRESHOLD = 0.5


In [8]:
# Restore the trained classifier directly onto the selected device.
# Its saved `config` carries the architecture hyperparameters that the
# mask generator is built from in the next cell.
MODEL_CHECKPOINT = 'best_model.pt'
best_model = NeuralNetwork.load(MODEL_CHECKPOINT, device=device)
config = best_model.config

In [9]:
# The mask generator mirrors the classifier's architecture, read from its saved config:
#  - same categorical cardinalities / embedding sizes / number of numerical inputs,
#  - mask sizes equal to the classifier's hidden-layer sizes,
#  - the classifier's own embedding modules are handed over via `share_embeddings`
#    (presumably shared rather than copied — confirm in masknet).
mask_generator_kwargs = dict(
    cat_cardinalities=config['cat_cardinalities'],
    embedding_dims=config['embedding_dims'],
    num_numerical_features=config['num_numerical_features'],
    mask_sizes=config['hidden_layers_sizes'],
    share_embeddings=best_model.embeddings,
)
mask_generator = MaskGenerator(**mask_generator_kwargs)

In [10]:
# Build train/valid/test loaders using the schema and row filtering from the config cell.
# `cw` is not used in this notebook (presumably class weights for training the classifier).
train_dataloader, valid_dataloader, test_dataloader, cat_cardinalities, cw, target_names = load_and_prepare_data(
    file_path=dataset_path,
    target_col=target_col,
    numerical_cols=numerical_cols,
    categorical_cols=categorical_cols,
    rows_to_remove=values_to_remove,
    batch_size=4096,
)

# The classifier and the mask generator were built from config['cat_cardinalities'].
# If this dataset yields different categorical cardinalities, the category codes would
# presumably no longer line up with the trained embeddings. Fail fast here instead of
# silently producing meaningless predictions.
assert list(cat_cardinalities) == list(config['cat_cardinalities']), (
    f"categorical cardinalities differ: data={cat_cardinalities} "
    f"model={config['cat_cardinalities']}"
)

In [11]:
# Mask-training hyperparameters. The log below shows alpha at 0.0001 during the
# warm-up epoch, then starting at 0.001 and growing by 0.0001 per epoch.
# alpha is presumably the sparsity-penalty weight — confirm in masknet.
mask_fit_kwargs = dict(
    alpha=0.001,
    epochs=70,
    warmup_epochs=1,
    lr=0.001,
)

# Train the masks against the already-trained classifier, using the same
# threshold that masked inference uses later.
mask_generator.fit(
    model=best_model,
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    device=device,
    threshold=MASK_THRESHOLD,
    **mask_fit_kwargs,
)
# NOTE(review): only epochs 0-9 are logged although epochs=70 — presumably early
# stopping on the validation loss. That loss includes a term that grows with alpha
# every epoch, so the retained "best" state may be an early, barely sparse one.
# Confirm the stopping criterion in masknet.

Epoch   0 | Valid Loss: 0.0223 | Valid Sparsity: 0.079 | F1: 0.9832 | Acc: 0.9938 | Alpha: 0.000100
Epoch   1 | Valid Loss: 0.0240 | Valid Sparsity: 0.099 | F1: 0.9844 | Acc: 0.9940 | Alpha: 0.001000
Epoch   2 | Valid Loss: 0.0241 | Valid Sparsity: 0.096 | F1: 0.9845 | Acc: 0.9940 | Alpha: 0.001100
Epoch   3 | Valid Loss: 0.0242 | Valid Sparsity: 0.125 | F1: 0.9846 | Acc: 0.9941 | Alpha: 0.001200
Epoch   4 | Valid Loss: 0.0243 | Valid Sparsity: 0.144 | F1: 0.9849 | Acc: 0.9942 | Alpha: 0.001300
Epoch   5 | Valid Loss: 0.0244 | Valid Sparsity: 0.156 | F1: 0.9847 | Acc: 0.9942 | Alpha: 0.001400
Epoch   6 | Valid Loss: 0.0245 | Valid Sparsity: 0.151 | F1: 0.9848 | Acc: 0.9941 | Alpha: 0.001500
Epoch   7 | Valid Loss: 0.0247 | Valid Sparsity: 0.159 | F1: 0.9846 | Acc: 0.9941 | Alpha: 0.001600
Epoch   8 | Valid Loss: 0.0249 | Valid Sparsity: 0.160 | F1: 0.9848 | Acc: 0.9942 | Alpha: 0.001700
Epoch   9 | Valid Loss: 0.0251 | Valid Sparsity: 0.183 | F1: 0.9848 | Acc: 0.9942 | Alpha: 0.001800


### Senza maschere

In [12]:
# Baseline: predictions of the plain classifier on the test set.
y_pred_plain = best_model.predict(test_dataloader, device)
# NOTE(review): labels come from a second pass over test_dataloader, so this
# assumes the loader is not shuffled (otherwise y_true and y_pred misalign) — confirm.
y_true = torch.cat([y for _, _, y in test_dataloader]).numpy()
assert len(y_true) == len(y_pred_plain), "predictions and labels differ in length"
print("\n=== Classification Report SENZA MASCHERE ===")
# .cpu() makes .numpy() safe when predict returns a tensor that is still on the GPU
# (it is a no-op for CPU tensors).
print(classification_report(y_true, y_pred_plain.cpu().numpy(), target_names=target_names, digits=4))


=== Classification Report SENZA MASCHERE ===
              precision    recall  f1-score   support

        ddos     0.9900    0.9949    0.9925     50615
   injection     0.9717    0.9871    0.9794     50967
      normal     0.9677    0.9906    0.9790      9197
    password     0.9980    0.9972    0.9976    189474
    scanning     0.7829    0.9942    0.8760      4686
         xss     0.9980    0.9867    0.9924    211140

    accuracy                         0.9916    516079
   macro avg     0.9514    0.9918    0.9695    516079
weighted avg     0.9922    0.9916    0.9917    516079



### Con maschere

In [13]:
# Predictions with the learned masks applied, using the same threshold as training.
y_pred_masked = best_model.predict_with_masks(test_dataloader, device, mask_generator, threshold=MASK_THRESHOLD)
# NOTE(review): labels come from a second pass over test_dataloader, so this
# assumes the loader is not shuffled (otherwise y_true and y_pred misalign) — confirm.
y_true = torch.cat([y for _, _, y in test_dataloader]).numpy()
assert len(y_true) == len(y_pred_masked), "predictions and labels differ in length"
print("\n=== Classification Report CON MASCHERE ===")
# .cpu() makes .numpy() safe when predictions are still on the GPU (no-op on CPU).
print(classification_report(y_true, y_pred_masked.cpu().numpy(), target_names=target_names, digits=4))


=== Classification Report CON MASCHERE ===
              precision    recall  f1-score   support

        ddos     0.9925    0.9943    0.9934     50615
   injection     0.9791    0.9834    0.9812     50967
      normal     0.9841    0.9867    0.9854      9197
    password     0.9975    0.9981    0.9978    189474
    scanning     0.9829    0.9298    0.9556      4686
         xss     0.9952    0.9942    0.9947    211140

    accuracy                         0.9938    516079
   macro avg     0.9885    0.9811    0.9847    516079
weighted avg     0.9938    0.9938    0.9938    516079



### Benchmark dei tempi di inferenza (senza vs con maschere)

In [14]:
# Compare wall-clock inference time of the plain classifier vs masked inference on
# the test set, over num_runs=10 runs. How the runs are aggregated is defined in masknet.
# The function prints its own summary and returns (plain_time, masked_time). Bind the
# results to names so they can be reused, instead of echoing the raw tuple as output.
baseline_seconds, masked_seconds = benchmark_inference_speed(
    model=best_model,
    mask_generator=mask_generator,
    dataloader=test_dataloader,
    device=device,
    threshold=MASK_THRESHOLD,
    num_runs=10,
)
# NOTE(review): the recorded run shows a small overhead rather than a speed-up.
# Presumably the masks zero activations but the dense layers are still fully computed,
# so no saving is expected unless masked units are actually skipped — confirm in masknet.

Inference normale:     5.6575s
Inference con maschere: 5.7193s
Overhead relativo:     1.1%


(5.657474994659424, 5.719270062446594)