In [1]:
import torch
from torch import nn
import data_utils
from training.MADGAN_train import MadGanTrainingPipeline
from models.MADGAN import Generator, Discriminator, AnomalyDetector
from utils import evaluation

# Parameters

In [2]:
# Pick the compute device once: the first CUDA GPU when available, else the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

In [3]:
# Experiment configuration: data windowing, training schedule and model sizes.
model_type = "MAD-GAN"
num_generated_features = 6  # per-time-step feature count (Generator output_dim / Discriminator input_dim)
seq_length = 30             # time steps per sequence; also passed to pipeline.train
seq_stride = 10             # presumably the step between consecutive windows in load_kdd99 — confirm

random_seed = 0
num_epochs = 100
batch_size = 256
lr = 0.001
latent_dim = 15
hidden_dim = 100

# Seed torch's global RNG (all devices) here, before the Generator and Discriminator
# are constructed further down, so their initial weights are reproducible on
# Restart & Run All. The same seed is still passed to pipeline.train.
torch.manual_seed(random_seed)
hidden_dim = 100

# Load data

In [4]:
# Build the train/test loaders for KDD99 (presumably DataLoaders; test_dl yields
# (X, Y, P) triples as consumed in the Evaluation section).
train_dl, test_dl = data_utils.load_kdd99(
    seq_length,
    seq_stride,
    num_generated_features,
    batch_size,
    model_type,
)

load kdd99_train from .npy
load kdd99_test from .npy


# Model

In [5]:
# MAD-GAN training driver; its `train` method is called in the Train section below.
pipeline = MadGanTrainingPipeline()

In [6]:
# Build the generator (latent_dim -> hidden_dim LSTM -> num_generated_features) on DEVICE.
# Module.to returns the module itself, so the call can be chained onto the constructor.
generator = Generator(
    latent_space_dim=latent_dim,
    hidden_units=hidden_dim,
    output_dim=num_generated_features,
).to(DEVICE)
generator

Generator(
  (lstm): LSTM(15, 100, num_layers=2, batch_first=True, dropout=0.1)
  (linear): Linear(in_features=100, out_features=6, bias=True)
)

In [7]:
# Build the discriminator (num_generated_features -> hidden_dim LSTM -> 1 sigmoid score) on DEVICE.
discriminator = Discriminator(
    input_dim=num_generated_features,
    hidden_units=hidden_dim,
    add_batch_mean=False,
).to(DEVICE)
discriminator

Discriminator(
  (lstm): LSTM(6, 100, num_layers=2, batch_first=True, dropout=0.1)
  (linear): Linear(in_features=100, out_features=1, bias=True)
  (activation): Sigmoid()
)

# Loss and Optimizer

In [8]:
def loss_function(inputs, targets):
    """Mean binary cross-entropy between predicted probabilities and targets.

    Functional equivalent of ``nn.BCELoss()`` with its defaults (no per-element
    weights, ``reduction='mean'``), without constructing a new module per call.
    ``inputs`` must already be probabilities in [0, 1] (the discriminator ends
    in a Sigmoid).
    """
    return nn.functional.binary_cross_entropy(inputs, targets, reduction="mean")

In [9]:
# One Adam optimizer per network, both sharing the learning rate `lr`.
generator_optim = torch.optim.Adam(params=generator.parameters(), lr=lr)
discriminator_optim = torch.optim.Adam(params=discriminator.parameters(), lr=lr)

# Train

In [10]:
# Arguments are positional, so their order must match MadGanTrainingPipeline.train.
pipeline.train(
    seq_length, latent_dim,
    train_dl, test_dl,
    discriminator, generator,
    discriminator_optim, generator_optim,
    loss_function,
    random_seed, num_epochs, DEVICE,
)

Epoch 0training:
G_loss: 2.3129790980707514, D_loss_real: 0.6358286591246725, D_loss_fake: 0.593481320024214
Evaluation metrics: {'D_loss': 0.9408629876344315, 'G_acc': 0.37529979051703616, 'D_acc': 1.591405477239678}
Epoch 1training:
G_loss: 1.7785078894685615, D_loss_real: 0.6077956793491136, D_loss_fake: 0.52419513463974
Evaluation metrics: {'D_loss': 1.8326628344046638, 'G_acc': 0.5449261705492444, 'D_acc': 1.3679562726786718}
Epoch 2training:
G_loss: 1.925299668718468, D_loss_real: 0.5787332387133078, D_loss_fake: 0.4421982808207924
Evaluation metrics: {'D_loss': 3.4522700062687535, 'G_acc': 0.6379971280307968, 'D_acc': 0.4366723670397398}
Epoch 3training:
G_loss: 1.5060006084767255, D_loss_real: 0.6179408759894696, D_loss_fake: 0.519465546309948
Evaluation metrics: {'D_loss': 1.6627681076217808, 'G_acc': 0.6857035544560981, 'D_acc': 0.5284781857475715}
Epoch 4training:
G_loss: 1.0663603205572476, D_loss_real: 0.6277930578724904, D_loss_fake: 0.5592702503908764
Evaluation metrics:

Evaluation metrics: {'D_loss': 3.420116082374296, 'G_acc': 0.7997118814660169, 'D_acc': 0.9447450955070247}
Epoch 38training:
G_loss: 2.2598427274010398, D_loss_real: 0.3454078818512657, D_loss_fake: 0.21319458572701974
Evaluation metrics: {'D_loss': 3.066730226259775, 'G_acc': 0.7795368467742297, 'D_acc': 0.9799652674037558}
Epoch 39training:
G_loss: 1.913848277926445, D_loss_real: 0.3958087173459882, D_loss_fake: 0.2822732250798832
Evaluation metrics: {'D_loss': 2.812660430379482, 'G_acc': 0.7512957854630725, 'D_acc': 0.8982969548726947}
Epoch 40training:
G_loss: 1.8174130125479264, D_loss_real: 0.4101984558965672, D_loss_fake: 0.2995548855851997
Evaluation metrics: {'D_loss': 2.4818845215239054, 'G_acc': 0.7444361305452999, 'D_acc': 1.0390354345113502}
Epoch 41training:
G_loss: 1.726599861275066, D_loss_real: 0.4234100271270356, D_loss_fake: 0.3626384783197533
Evaluation metrics: {'D_loss': 2.5476464311075953, 'G_acc': 0.6851791684646062, 'D_acc': 1.135425120437701}
Epoch 42training

Epoch 75training:
G_loss: 1.7502370872280815, D_loss_real: 0.42417231040414083, D_loss_fake: 0.30127001655372704
Evaluation metrics: {'D_loss': 2.3671272980734472, 'G_acc': 0.7810317124264228, 'D_acc': 0.9638820693205675}
Epoch 76training:
G_loss: 1.6591102079911666, D_loss_real: 0.42686160778288135, D_loss_fake: 0.34308375201442026
Evaluation metrics: {'D_loss': 2.070668927128451, 'G_acc': 0.7130286436856101, 'D_acc': 1.0600872171048674}
Epoch 77training:
G_loss: 1.5940453759648583, D_loss_real: 0.4677292990074916, D_loss_fake: 0.36023798733949663
Evaluation metrics: {'D_loss': 1.9694404416751368, 'G_acc': 0.7131684796180132, 'D_acc': 1.0682970190063659}
Epoch 78training:
G_loss: 1.6919855670495467, D_loss_real: 0.47114130729301407, D_loss_fake: 0.36875035708600823
Evaluation metrics: {'D_loss': 2.4264113322440823, 'G_acc': 0.781253357422707, 'D_acc': 0.969884138901771}
Epoch 79training:
G_loss: 1.6977559813044287, D_loss_real: 0.4521961365233768, D_loss_fake: 0.35680637102235446
Eval

# Evaluation

In [11]:
def scoring_function(model, data):
    """Score a single sample with ``model.predict`` and return a NumPy array.

    ``data`` is converted to a float32 tensor, given a leading batch axis of
    size 1, passed to ``model.predict``, and the result has all size-1
    dimensions squeezed out.

    Args:
        model: object exposing ``predict(tensor) -> tensor``.
        data: array-like or tensor holding one sample (no batch axis).

    Returns:
        numpy.ndarray with the squeezed prediction.
    """
    # as_tensor accepts an existing tensor without torch.tensor's copy-construct warning.
    x = torch.as_tensor(data, dtype=torch.float32).unsqueeze(dim=0)
    out = model.predict(x).squeeze()
    # detach().cpu() so .numpy() also works when the output tracks gradients or lives on GPU.
    return out.detach().cpu().numpy()

In [12]:
def torch_scoring_function(model, data):
    """Return ``model.predict(data)`` unchanged.

    Used as the scoring callback for ``evaluation.torch_emmv_scores`` with a test
    batch ``X``. Unlike ``scoring_function``, it does no tensor conversion,
    reshaping, or NumPy export.
    """
    scores = model.predict(data)
    return scores

In [14]:
# Wrap the trained networks in the MAD-GAN anomaly detector.
# NOTE(review): anomaly_threshold=0.5 is a magic number; consider moving it to the Parameters cell.
ad = AnomalyDetector(discriminator=discriminator, generator=generator, latent_space_dim=latent_dim, anomaly_threshold=0.5)
total_em = total_mv = total_acc = 0
for X, Y, _ in test_dl:  # the third item yielded by test_dl is not used here
    prediction = ad.predict(X)
    # Fraction of matching labels in this batch. mean() divides by the number of
    # compared elements, so it stays correct for a short final batch and for any
    # seq_length (dividing by batch_size * 30 under-counts a partial batch).
    # NOTE(review): assumes prediction and Y share a shape such as
    # (batch, seq_length) — TODO confirm; broadcasting would inflate the count.
    acc = (prediction == Y).float().mean()
    print(acc.item())
    scores = evaluation.torch_emmv_scores(ad, X, torch_scoring_function)
    print(scores)
    total_acc += acc.item()
    total_em += scores['em']
    total_mv += scores['mv']
# Unweighted per-batch means of EM, MV and accuracy.
num_batches = len(test_dl)
print(total_em / num_batches, total_mv / num_batches, total_acc / num_batches)

0.79296875
{'em': 0.009643218458333332, 'mv': 14253.036804199219}
0.8197916746139526
{'em': 0.0255307735125, 'mv': 11815.361022949219}
0.7920572757720947
{'em': 0.0001620652083333333, 'mv': 8953.750305175781}
0.9951822757720947
{'em': 0.0008959864666666667, 'mv': 7061.455993652344}
0.9139322638511658
{'em': 0.0009405876250000002, 'mv': 1937.8691417480475}
0.7670572996139526
{'em': 0.00025033855416666667, 'mv': 6019.761657714844}
0.8186197876930237
{'em': 0.000441806, 'mv': 30586.137084960938}
0.833984375
{'em': 0.00043890708333333333, 'mv': 2632.2827911376953}
0.7444010376930237
{'em': 0.0002626013125, 'mv': 61709.28955078125}
0.7621093988418579
{'em': 0.002436067174999999, 'mv': 1064.2021179199219}
0.7005208134651184
{'em': 0.0018949774374999997, 'mv': 30360.598754882812}
0.7307291626930237
{'em': 0.00025055534999999997, 'mv': 1038.9903259277344}
0.811718761920929
{'em': 0.00015595075, 'mv': 9793.584594726562}
0.7419270873069763
{'em': 0.00219353175, 'mv': 2186.026840209961}
0.7723958

{'em': 0.00021995007500000003, 'mv': 2.429849556641779e-12}
0.9997395873069763
{'em': 0.000219954275, 'mv': 2.4297928512682815e-12}
1.0
{'em': 0.0001, 'mv': 1.829496873653616e-11}
0.9996093511581421
{'em': 0.00021987314999999996, 'mv': 1.829545669508126e-11}
0.9998697638511658
{'em': 0.000219945275, 'mv': 1.8296066643262627e-11}
1.0
{'em': 0.0002199194125, 'mv': 1.8294724757263613e-11}
0.9998697638511658
{'em': 0.0001, 'mv': 1.829557868471754e-11}
0.9997395873069763
{'em': 0.0002199146125, 'mv': 1.8295212715808717e-11}
0.9998697638511658
{'em': 0.000219948875, 'mv': 1.829570067435381e-11}
0.9998697638511658
{'em': 0.0002198286875, 'mv': 1.8296005648444497e-11}
0.9998697638511658
{'em': 0.00021995127499999998, 'mv': 7.255341689588256e-12}
0.9998697638511658
{'em': 0.0001, 'mv': 1.8295212715808717e-11}
1.0
{'em': 0.00021995667499999998, 'mv': 6.369552139113448e-13}
0.9897135496139526
{'em': 0.0029942720416666667, 'mv': 1406.5110421875}
0.5721353888511658
{'em': 0.0027918217083333333, 'mv