In [1]:
from utils import load_and_prepare_data
from neural_network import NeuralNetwork, FocalLoss
import torch

# --- Configuration -----------------------------------------------------------
# Dataset CSV, relative to the notebook's working directory.
dataset_path = './Dataset/http_ton.csv'

# Fixed seed so weight initialisation and any torch-driven shuffling are
# reproducible under Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)  # seeds torch's RNG on all devices (CPU and CUDA)
# NOTE(review): if `load_and_prepare_data` splits/shuffles with numpy or
# scikit-learn, that randomness must be seeded separately — TODO confirm.

# Continuous feature columns; their count sizes the model's numeric input
# (`num_numerical_features` in the model cell).
numerical_cols = [
    "duration",
    "dst_bytes",
    "missed_bytes",
    "src_bytes",
    "src_ip_bytes",
    "src_pkts",
    "dst_pkts",
    "dst_ip_bytes",
    "http_request_body_len",
    "http_response_body_len",
]

# Categorical feature columns; presumably encoded by `load_and_prepare_data`,
# whose returned cardinalities size the model's embeddings.
categorical_cols = [
    "proto",
    "conn_state",
    "http_status_code",
    "http_method",
    "http_orig_mime_types",
    "http_resp_mime_types",
]

# Label column the classifier predicts.
target_col = 'type'
# Values to exclude, keyed by column (passed as `rows_to_remove`); presumably
# rows whose `type` is 'mitm' or 'dos' are dropped before splitting.
values_to_remove = {'type': ['mitm', 'dos']}

In [None]:
# Load the CSV and build the train / validation / test DataLoaders.
# Besides the loaders the helper returns:
#   cat_cardinalities - one value per categorical column; used in the model
#                       cell to size the embeddings
#   cw                - passed to `model.fit(weights=...)`; presumably per-class
#                       loss weights (TODO confirm in utils)
(
    train_dataloader,
    valid_dataloader,
    test_dataloader,
    cat_cardinalities,
    cw,
) = load_and_prepare_data(
    file_path=dataset_path,
    target_col=target_col,
    numerical_cols=numerical_cols,
    categorical_cols=categorical_cols,
    rows_to_remove=values_to_remove,
    batch_size=4096,
)

In [None]:
# Classifier over the numeric features plus the embedded categorical features,
# with hidden layers of width 512 -> 256 -> 128.
model = NeuralNetwork(
    num_numerical_features=len(numerical_cols),
    cat_cardinalities=cat_cardinalities,
    # Embedding width per categorical column: ceil(cardinality / 2) for integer
    # cardinalities, capped at 50.
    embedding_dims=[min(50, (cardinality + 1) // 2) for cardinality in cat_cardinalities],
    hidden_layers_sizes=[512, 256, 128],
    # NOTE(review): hard-coded; must equal the number of distinct `type` labels
    # left after `values_to_remove` is applied — TODO derive from the data.
    num_target_classes=6,
)

In [None]:
# --- Training ------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to its final device BEFORE building the optimizer: the
# torch.optim docs ask that optimized parameters already live where they will be
# used when the optimizer is constructed. Calling .to() again inside `fit` is a no-op.
model.to(device)

# AdamW applies decoupled weight decay (1e-4) at a base learning rate of 1e-3.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
# Scales the LR by 0.85 once the monitored value has not improved for more than
# 5 consecutive `scheduler.step(...)` calls.
# NOTE(review): mode='max' assumes `fit` steps the scheduler with a
# higher-is-better metric (e.g. validation F1). If it passes the loss, this must
# be mode='min' — TODO confirm against NeuralNetwork.fit.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.85, patience=5)

model.fit(
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    device=device,
    optimizer=optimizer,
    lr_scheduler=scheduler,
    epochs=100,
    weights=cw,  # presumably per-class loss weights from load_and_prepare_data
)

--- Epoch: 0  |  Loss: 0.4823  |  F1 Score: 0.6587  |  Accuracy: 0.7538 ---
--- Epoch: 1  |  Loss: 0.3070  |  F1 Score: 0.8236  |  Accuracy: 0.9274 ---
--- Epoch: 2  |  Loss: 0.2332  |  F1 Score: 0.8263  |  Accuracy: 0.9297 ---
--- Epoch: 3  |  Loss: 0.1926  |  F1 Score: 0.8920  |  Accuracy: 0.9555 ---
--- Epoch: 4  |  Loss: 0.1531  |  F1 Score: 0.8962  |  Accuracy: 0.9590 ---
--- Epoch: 5  |  Loss: 0.1371  |  F1 Score: 0.9155  |  Accuracy: 0.9621 ---
--- Epoch: 6  |  Loss: 0.1340  |  F1 Score: 0.9069  |  Accuracy: 0.9631 ---
--- Epoch: 7  |  Loss: 0.1774  |  F1 Score: 0.8562  |  Accuracy: 0.9421 ---
--- Epoch: 8  |  Loss: 0.1221  |  F1 Score: 0.9163  |  Accuracy: 0.9658 ---
--- Epoch: 9  |  Loss: 0.1176  |  F1 Score: 0.9199  |  Accuracy: 0.9671 ---
--- Epoch: 10  |  Loss: 0.1160  |  F1 Score: 0.9207  |  Accuracy: 0.9650 ---
--- Epoch: 11  |  Loss: 0.0937  |  F1 Score: 0.9350  |  Accuracy: 0.9739 ---
--- Epoch: 12  |  Loss: 0.1064  |  F1 Score: 0.9218  |  Accuracy: 0.9705 ---
--- Epoch

KeyboardInterrupt: 