In [1]:
import torch
from torch import nn
import data_utils
from training.MADGAN_train import MadGanTrainingPipeline
from models.MADGAN import Generator, Discriminator, AnomalyDetector
from utils import evaluation
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)

cuda


# Hyperparameters

In [3]:
# Dataset selector: "kdd99", or one of the stock tickers "aapl", "gm", "axp"
# (a ticker is read from data/financial_data/Stocks/<ticker>.us.txt).
dataset = "kdd99"

In [4]:
# --- Model / windowing hyperparameters ---
model_type = "MAD-GAN"
num_features = 6         # generator output_dim, discriminator input_dim; also passed to the kdd99 loader
seq_len = 30             # window length given to the data loaders and to the training pipeline
seq_stride = 10          # presumably the step between consecutive windows — confirm in data_utils
gen_seq_len = seq_len    # passed to the loaders as gen_seq_len; kept equal to seq_len

# --- Optimisation hyperparameters ---
random_seed = 0
num_epochs = 100
batch_size = 256
lr = 1e-5                # Adam learning rate (shared by both networks)
wd = 5e-7                # Adam weight_decay (shared by both networks)

latent_dim = 100         # generator input_dim
hidden_dim = 250         # hidden_size of both the generator and the discriminator
anomaly_threshold = 0.5  # forwarded to the training pipeline; its exact use is not visible here

# Seed the global RNGs now. The data loaders and both networks are built before
# the training pipeline ever receives `random_seed`, so without seeding here any
# random weight initialisation / shuffling would differ from run to run.
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Load data

In [5]:
if dataset != "kdd99":
    # Stock datasets: one text file per ticker under data/financial_data/Stocks/,
    # loaded as cross-validated time-series loaders.
    file_path = f"data/financial_data/Stocks/{dataset}.us.txt"
    tscv_dl_list = data_utils.load_stock_as_crossvalidated_timeseries(
        file_path, seq_len, seq_stride, gen_seq_len, batch_size, normalise=True
    )
else:
    # The kdd99 loader returns a train loader and a test loader.
    train_dl, test_dl = data_utils.kdd99(
        seq_len, seq_stride, num_features, gen_seq_len, batch_size
    )

# Model

In [6]:
# Generator: latent_dim-sized inputs -> num_features-sized outputs, moved to DEVICE.
generator = Generator(
    input_dim=latent_dim,
    hidden_size=hidden_dim,
    output_dim=num_features,
).to(device=DEVICE)

In [7]:
# Discriminator: consumes num_features-sized inputs (the generator's output width), moved to DEVICE.
discriminator = Discriminator(
    input_dim=num_features,
    hidden_size=hidden_dim,
).to(device=DEVICE)

# Loss and Optimizer

In [8]:
def loss_function(inputs, targets):
    """Mean binary cross-entropy between predicted probabilities and targets.

    Numerically equivalent to ``nn.BCELoss()(inputs, targets)`` with default
    arguments (``reduction="mean"``), but uses the functional form so a new
    loss module is not constructed on every call.

    Args:
        inputs: tensor of predicted probabilities. BCE requires values in
            [0, 1], so the discriminator output is presumably already
            sigmoid-activated — TODO confirm in models.MADGAN.
        targets: tensor of the same shape as ``inputs`` holding the labels.

    Returns:
        Scalar tensor with the mean loss over all elements.
    """
    return nn.functional.binary_cross_entropy(inputs, targets)

In [9]:
# Both networks use Adam with the same learning rate and L2 weight decay.
generator_optim = torch.optim.Adam(
    params=generator.parameters(),
    lr=lr,
    weight_decay=wd,
)
discriminator_optim = torch.optim.Adam(
    params=discriminator.parameters(),
    lr=lr,
    weight_decay=wd,
)

# Train

In [10]:
# Training driver; the next cell calls either its train_kdd99 or its
# train_financial method depending on `dataset`.
pipeline = MadGanTrainingPipeline()

In [None]:
# Both entry points take the same trailing arguments (models, optimisers,
# threshold, loss, seed, epochs, device); only the data arguments differ.
if dataset != "kdd99":
    pipeline.train_financial(
        seq_len, latent_dim, tscv_dl_list,
        discriminator, generator, discriminator_optim, generator_optim,
        anomaly_threshold, loss_function, random_seed, num_epochs, DEVICE,
    )
else:
    pipeline.train_kdd99(
        seq_len, latent_dim, train_dl, test_dl,
        discriminator, generator, discriminator_optim, generator_optim,
        anomaly_threshold, loss_function, random_seed, num_epochs, DEVICE,
    )

Epoch 0: G_loss: 1.0056039821017873, D_loss_real: 0.8997917329723185, D_loss_fake: 0.539320739290931
Epoch 1: G_loss: 0.9012151143767617, D_loss_real: 0.8551822410388427, D_loss_fake: 0.5479047230698846
Epoch 2: G_loss: 0.6457203054969961, D_loss_real: 0.7357653288678689, D_loss_fake: 0.7694859986955469
Epoch 3: G_loss: 0.731800566478209, D_loss_real: 0.6718300944024866, D_loss_fake: 0.6722782866521315
Epoch 4: G_loss: 0.7432030008597807, D_loss_real: 0.7361921372738751, D_loss_fake: 0.669095099514181
Epoch 5: G_loss: 0.6770262368700721, D_loss_real: 0.7397203518585725, D_loss_fake: 0.7315865245732394
Epoch 6: G_loss: 0.7483059587803754, D_loss_real: 0.7018042896281589, D_loss_fake: 0.66213627836921
Epoch 7: G_loss: 0.7408932211724195, D_loss_real: 0.6840869690884244, D_loss_fake: 0.6636316443031485
Epoch 8: G_loss: 0.6979421648112211, D_loss_real: 0.7617347385395657, D_loss_fake: 0.7161209385503422
Epoch 9: G_loss: 0.7313909376209432, D_loss_real: 0.7129828341982581, D_loss_fake: 0.67

Epoch 80: G_loss: 0.6164416878060861, D_loss_real: 0.42176526104184714, D_loss_fake: 1.038955408334732
Epoch 81: G_loss: 1.05760613842444, D_loss_real: 0.7053654785860669, D_loss_fake: 0.5481032918800007
Epoch 82: G_loss: 0.8009182783690366, D_loss_real: 0.7299985538829457, D_loss_fake: 0.6400485810908404
Epoch 83: G_loss: 0.7411862419410186, D_loss_real: 0.7127690756862813, D_loss_fake: 0.7029784482988444
Epoch 84: G_loss: 0.7394547243009914, D_loss_real: 0.732217704707926, D_loss_fake: 0.6866128154776313
Epoch 85: G_loss: 0.7280404126102274, D_loss_real: 0.7491680884903128, D_loss_fake: 0.68779144463214
Epoch 86: G_loss: 0.7310237808661028, D_loss_real: 0.7432824725454504, D_loss_fake: 0.6785609998486258
Epoch 87: G_loss: 0.7446167447350241, D_loss_real: 0.7361321541396054, D_loss_fake: 0.6683426495302808
Epoch 88: G_loss: 0.7285148750651966, D_loss_real: 0.7124298970807682, D_loss_fake: 0.6828074088150805
Epoch 89: G_loss: 0.7677741072394632, D_loss_real: 0.7232976593754509, D_loss_