In [1]:
from utils import load_and_prepare_data
from neural_network import NeuralNetwork
import torch

# --- Configuration -----------------------------------------------------------
# Location of the CSV file handed to `load_and_prepare_data`.
dataset_path = './Dataset/http_ton.csv'

# Seed PyTorch's global RNG so that everything drawing from it is repeatable
# across "Restart Kernel -> Run All".
# NOTE(review): randomness that comes from NumPy or Python's `random` module
# (e.g. inside `load_and_prepare_data`) is not covered by this call — confirm
# how the helper shuffles/splits and seed those sources separately if needed.
SEED = 42
torch.manual_seed(SEED)

# Columns passed to the loader as numerical features.
numerical_cols = [
    "duration",
    "dst_bytes",
    "missed_bytes",
    "src_bytes",
    "src_ip_bytes",
    "src_pkts",
    "dst_pkts",
    "dst_ip_bytes",
    "http_request_body_len",
    "http_response_body_len",
]

# Columns passed to the loader as categorical features.
categorical_cols = [
    "proto",
    "conn_state",
    "http_status_code",
    "http_method",
    "http_orig_mime_types",
    "http_resp_mime_types",
]

# Name of the label column.
target_col = 'type'

# Mapping of column name -> values, forwarded to the loader as `rows_to_remove`.
# NOTE(review): presumably rows whose `type` is 'mitm' or 'dos' are dropped —
# confirm against `load_and_prepare_data`.
values_to_remove = {'type': ['mitm', 'dos']}

In [2]:
# Load the CSV and unpack the four values returned by the project helper:
# the train / validation / test loaders and `cat_cardinalities`
# (presumably one cardinality per categorical column — verify in utils).
(
    train_dataloader,
    valid_dataloader,
    test_dataloader,
    cat_cardinalities,
) = load_and_prepare_data(
    file_path=dataset_path,
    numerical_cols=numerical_cols,
    categorical_cols=categorical_cols,
    target_col=target_col,
    rows_to_remove=values_to_remove,
    batch_size=1024,
)

In [None]:
# Build the network. `embedding_dims` holds one entry per value in
# `cat_cardinalities`: half the cardinality (rounded up), capped at 50.
# NOTE(review): num_target_classes=6 is hard-coded — confirm it equals the
# number of distinct `type` labels left after the row filtering above.
model = NeuralNetwork(
    num_numerical_features=len(numerical_cols),
    cat_cardinalities=cat_cardinalities,
    embedding_dims=[
        min(50, (n_categories + 1) // 2) for n_categories in cat_cardinalities
    ],
    hidden_layers_sizes=[256, 128, 64],
    num_target_classes=6,
)

In [None]:
# Use the GPU when CUDA reports itself available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Train with the train/validation loaders, learning rate 1e-3, 20 epochs.
model.fit(
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    epochs=20,
    lr=1e-3,
    device=device,
)

--- Epoch: 0  |  Loss: 0.1573  |  F1 Score: 0.3959  |  Accuracy: 0.4897 ---
--- Epoch: 1  |  Loss: 0.1255  |  F1 Score: 0.6893  |  Accuracy: 0.8134 ---
--- Epoch: 2  |  Loss: 0.0897  |  F1 Score: 0.6307  |  Accuracy: 0.7040 ---
--- Epoch: 3  |  Loss: 0.0670  |  F1 Score: 0.5648  |  Accuracy: 0.6303 ---
--- Epoch: 4  |  Loss: 0.0577  |  F1 Score: 0.4395  |  Accuracy: 0.5160 ---
--- Epoch: 5  |  Loss: 0.0699  |  F1 Score: 0.3627  |  Accuracy: 0.4304 ---
--- Epoch: 6  |  Loss: 0.0577  |  F1 Score: 0.5545  |  Accuracy: 0.6334 ---
--- Epoch: 7  |  Loss: 0.0458  |  F1 Score: 0.7094  |  Accuracy: 0.8340 ---
--- Epoch: 8  |  Loss: 0.0405  |  F1 Score: 0.3708  |  Accuracy: 0.4744 ---
--- Epoch: 9  |  Loss: 0.0412  |  F1 Score: 0.5997  |  Accuracy: 0.6544 ---
