In [1]:
import torch
import data_utils
from models.DeepAnT import DeepAnT_CNN, DeepAnT_LSTM
from training.DeepAnT_train import DeepAntTrainingPipeline

In [2]:
# Select the compute device once, up front: CUDA when torch reports it as
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Hyperparameters

In [3]:
# Choose between the CNN- and LSTM-based DeepAnT models.
# Supported values: "DeepAnT_CNN", "DeepAnT_LSTM".
model_type = "DeepAnT_CNN"

# Hyperparameters shared by both model variants.
num_features = 6      # passed to data_utils.load_kdd99 and to the model constructor
src_seq_length = 30   # passed to data_utils.load_kdd99 and to the model constructor
src_seq_stride = 10   # passed to data_utils.load_kdd99
batch_size = 256      # passed to data_utils.load_kdd99
# NOTE(review): `lr` is not referenced by the optimizer cell, which passes
# lr=1e-5 directly -- confirm which learning rate is actually intended.
lr = 0.001

# Hyperparameters that differ per model variant.
if model_type == "DeepAnT_LSTM":
    gen_seq_len = src_seq_length  # LSTM variant: same length as the source sequence
    anm_det_thr = 0.5             # presumably the anomaly-detection threshold -- confirm
    num_epochs = 50
elif model_type == "DeepAnT_CNN":
    gen_seq_len = 1
    anm_det_thr = 0.8             # presumably the anomaly-detection threshold -- confirm
    num_epochs = 100
else:
    # Fail here, next to the misconfigured value, instead of with a NameError
    # in a later cell when an undefined hyperparameter is first used.
    raise ValueError(
        "Unknown model_type {0!r}; expected 'DeepAnT_CNN' or 'DeepAnT_LSTM'".format(model_type)
    )

# Load data

In [4]:
# Load the KDD99 train/test splits via data_utils.load_kdd99, using the
# hyperparameters configured for the selected model.
train_dl, test_dl = data_utils.load_kdd99(
    src_seq_length,
    src_seq_stride,
    num_features,
    gen_seq_len,
    batch_size,
)

load kdd99_train from .npy
load kdd99_test from .npy


# Model

In [5]:
# Instantiate the DeepAnT variant selected by `model_type`. Both constructors
# take the same arguments: source sequence length, number of features and the
# detection threshold.
if model_type == "DeepAnT_LSTM":
    model = DeepAnT_LSTM(src_seq_length, num_features, anm_det_thr)
elif model_type == "DeepAnT_CNN":
    model = DeepAnT_CNN(src_seq_length, num_features, anm_det_thr)
else:
    # Without this branch `model` would stay unbound and the failure would
    # only show up later as a NameError when the optimizer is created.
    raise ValueError(
        "Unknown model_type {0!r}; expected 'DeepAnT_CNN' or 'DeepAnT_LSTM'".format(model_type)
    )

In [6]:
# Mean-reduced MSE is the training objective.
criterion = torch.nn.MSELoss(reduction="mean")

# Adam over all parameters of the model.
# NOTE(review): the learning rate is hard-coded to 1e-5 here; the `lr`
# hyperparameter defined in the configuration cell (0.001) is not used --
# confirm which value is intended.
optimizer = torch.optim.Adam(
    list(model.parameters()),
    lr=1e-5,
)

# Training

In [7]:
# Create the training pipeline and fit the model for `num_epochs` epochs.
# NOTE(review): the recorded output of this cell echoes a line from
# F.mse_loss twice, which looks like a truncated warning -- check the full
# message (e.g. a prediction/target shape mismatch) before trusting the loss.
pipeline = DeepAntTrainingPipeline()

pipeline.train(
    model=model,
    loss=criterion,
    optimizer=optimizer,
    train_dl=train_dl,
    test_dl=test_dl,
    num_epochs=num_epochs,
    DEVICE=DEVICE,
)

  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Training Loss: 1.4579969584941863 - Epoch: 1
Training Loss: 1.2936990364031358 - Epoch: 2
Training Loss: 1.0992158144712447 - Epoch: 3
Training Loss: 0.8939644301479513 - Epoch: 4
Training Loss: 0.699322712150487 - Epoch: 5
Training Loss: 0.5593394841660153 - Epoch: 6
Training Loss: 0.4816588120027022 - Epoch: 7
Training Loss: 0.4505217581987381 - Epoch: 8
Training Loss: 0.4380418522791429 - Epoch: 9
Training Loss: 0.43364796110174875 - Epoch: 10
Training Loss: 0.43019406402652915 - Epoch: 11
Training Loss: 0.42675626454028215 - Epoch: 12
Training Loss: 0.42343502139503303 - Epoch: 13
Training Loss: 0.42137133011763744 - Epoch: 14
Training Loss: 0.418591419133273 - Epoch: 15
Training Loss: 0.4151797882535241 - Epoch: 16
Training Loss: 0.4133163653991439 - Epoch: 17
Training Loss: 0.411122111027891 - Epoch: 18
Training Loss: 0.4088273936374621 - Epoch: 19
Training Loss: 0.40589532533829864 - Epoch: 20
Training Loss: 0.40322964760390195 - Epoch: 21
Training Loss: 0.4013939634642818 - Epo

# Evaluation

In [8]:
# Evaluate the trained model on the test data; evaluate() returns two values.
# NOTE(review): EM / MV are presumably Excess-Mass and Mass-Volume scores --
# confirm against DeepAntTrainingPipeline.evaluate.
test_em, test_mv = pipeline.evaluate(model, test_dl)

final_report = "Final results - EM: {0}, MV: {1}".format(test_em, test_mv)
print(final_report)

{'em': 0.00020406829999999995, 'mv': 2850.6073608398438}
{'em': 0.00020204626666666666, 'mv': 2363.0722045898438}
{'em': 0.00022125287499999996, 'mv': 1790.7500610351562}
{'em': 0.0005449498, 'mv': 1412.2911987304688}
{'em': 0.0003466159583333333, 'mv': 6220.423095703125}
{'em': 0.00023643075416666663, 'mv': 1203.9523315429688}
{'em': 0.0001897176291666666, 'mv': 6117.2274169921875}
{'em': 0.00021600474999999996, 'mv': 526.4565582275391}
{'em': 0.00022052762499999998, 'mv': 12341.85791015625}
{'em': 0.0002334837083333333, 'mv': 212.84042358398438}
{'em': 0.00026864666666666663, 'mv': 6072.1197509765625}
{'em': 0.00027904920833333324, 'mv': 207.79806518554688}
{'em': 0.000252915625, 'mv': 1958.7169189453125}
{'em': 0.0002544186875, 'mv': 437.2053680419922}
{'em': 0.00023090970833333332, 'mv': 344.9389343261719}
{'em': 0.00023210147735226478, 'mv': 9520.83544921875}
{'em': 0.0004719408124999999, 'mv': 25935.7265625}
{'em': 0.000492, 'mv': 3.1112010670753695e-17}
{'em': 0.0004974808, 'mv'

{'em': 0.000492, 'mv': 5.5362293096550275e-06}
{'em': 0.000492, 'mv': 5.5362293096550275e-06}
{'em': 0.000492, 'mv': 5.5362293096550275e-06}
{'em': 0.000492, 'mv': 1.2741065802401863e-07}
{'em': 0.000492, 'mv': 3.075387667195173e-06}
{'em': 0.000492, 'mv': 3.075387667195173e-06}
{'em': 0.000492, 'mv': 3.075387667195173e-06}
{'em': 0.000492, 'mv': 3.075387667195173e-06}
{'em': 0.000492, 'mv': 3.075387667195173e-06}
{'em': 0.000492, 'mv': 1.2741065802401863e-07}
{'em': 0.000492, 'mv': 4.707435209638788e-06}
{'em': 0.000492, 'mv': 4.707435209638788e-06}
{'em': 0.0005633047124999999, 'mv': 4609.7838134765625}
{'em': 0.00031098068750000005, 'mv': 24824.8212890625}
{'em': 0.00031618225, 'mv': 7996.9951171875}
{'em': 0.000301881075, 'mv': 15306.38232421875}
{'em': 0.0005222954, 'mv': 23767.20263671875}
{'em': 0.000492, 'mv': 0.001988718460779637}
{'em': 0.000492, 'mv': 0.2193032205104828}
{'em': 0.000492, 'mv': 0.0014593927189707756}
{'em': 0.000492, 'mv': 0.0017851361772045493}
{'em': 0.0004