In [1]:
import torch
from torch import nn
import data_utils
from training.MADGAN_train import MadGanTrainingPipeline
from models.MADGAN import Generator, Discriminator, AnomalyDetector
from utils import evaluation
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)

cuda


# Hyperparameters

In [3]:
# Dataset selector: "kdd99", or a stock ticker such as "aapl", "gm" or "axp"
# (tickers are read from data/financial_data/Stocks/<ticker>.us.txt).
dataset = "aapl"

In [4]:
model_type = "MAD-GAN"  # label only; not referenced by the cells shown here
num_features = 7        # generator output_dim and discriminator input_dim
seq_len = 30            # window length passed to the data loaders and the pipeline
seq_stride = 10         # passed to the data loaders (presumably the sliding-window step — TODO confirm)
gen_seq_len = seq_len   # generated sequences use the same length as real windows

random_seed = 0
num_epochs = 100
batch_size = 8
lr = 1e-5               # Adam learning rate for both networks
wd = 5e-7               # Adam weight_decay for both networks

latent_dim = 100        # generator input dimension
hidden_dim = 250        # hidden_size for both generator and discriminator
anomaly_threshold = 0.5  # forwarded to the training pipeline; semantics defined there

# Seed the global RNGs here, *before* the data-loading and model-construction
# cells run, so weight initialisation and any data shuffling are reproducible
# under Restart & Run All. random_seed is additionally passed to the training
# pipeline, but that happens only after the networks are already initialised.
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Load data

In [5]:
if dataset != "kdd99":
    # Each stock ticker is stored as its own text file under the Stocks folder.
    file_path = f"data/financial_data/Stocks/{dataset}.us.txt"
    # NOTE(review): presumably one set of loaders per time-series CV fold — confirm in data_utils.
    tscv_dl_list = data_utils.load_stock_as_crossvalidated_timeseries(
        file_path,
        seq_len,
        seq_stride,
        gen_seq_len,
        batch_size,
        normalise=True,
    )
else:
    train_dl, test_dl = data_utils.kdd99(
        seq_len, seq_stride, num_features, gen_seq_len, batch_size
    )

# Model

In [6]:
# Generator: maps latent_dim-sized inputs to num_features-wide outputs,
# matching what the discriminator consumes.
generator = Generator(
    input_dim=latent_dim,
    hidden_size=hidden_dim,
    output_dim=num_features,
).to(device=DEVICE)

In [7]:
# Discriminator: consumes num_features-wide inputs (real data or generator output).
discriminator = Discriminator(
    input_dim=num_features,
    hidden_size=hidden_dim,
).to(device=DEVICE)

# Loss and optimizers

In [8]:
def loss_function(inputs, targets, reduction="mean"):
    """Binary cross-entropy between predicted probabilities and targets.

    Args:
        inputs: tensor of probabilities in [0, 1] (presumably sigmoid outputs
            of the discriminator — TODO confirm against the model definition).
        targets: tensor of the same shape as ``inputs`` holding target labels.
        reduction: ``"mean"`` (default, identical to ``nn.BCELoss()``),
            ``"sum"``, or ``"none"`` for per-element losses.

    Returns:
        The BCE loss tensor; a scalar unless ``reduction="none"``.
    """
    # The functional form computes the same loss as nn.BCELoss without
    # instantiating a new nn.Module on every call.
    return nn.functional.binary_cross_entropy(inputs, targets, reduction=reduction)

In [9]:
# Both networks are optimised with Adam using the same learning rate and weight decay.
_adam_settings = {"lr": lr, "weight_decay": wd}
discriminator_optim = torch.optim.Adam(discriminator.parameters(), **_adam_settings)
generator_optim = torch.optim.Adam(generator.parameters(), **_adam_settings)

# Train

In [10]:
# Single training-pipeline object; the following cell calls either its
# train_kdd99 or train_financial method depending on `dataset`.
pipeline = MadGanTrainingPipeline()

In [11]:
# Trailing positional arguments common to both pipeline entry points,
# kept in the exact order those methods receive them.
_shared_train_args = (
    discriminator,
    generator,
    discriminator_optim,
    generator_optim,
    anomaly_threshold,
    loss_function,
    random_seed,
    num_epochs,
    DEVICE,
)

if dataset == "kdd99":
    pipeline.train_kdd99(seq_len, latent_dim, train_dl, test_dl, *_shared_train_args)
else:
    pipeline.train_financial(seq_len, latent_dim, tscv_dl_list, *_shared_train_args)

Epoch 0: G_loss: 1.0045518159866333, D_loss_real: 0.852536928653717, D_loss_fake: 0.45769746899604796
Epoch 1: G_loss: 1.009397029876709, D_loss_real: 0.8372261166572571, D_loss_fake: 0.4553071022033691
Epoch 2: G_loss: 1.0145559072494508, D_loss_real: 0.819399356842041, D_loss_fake: 0.45327579975128174
Epoch 3: G_loss: 1.0211325407028198, D_loss_real: 0.8031409859657288, D_loss_fake: 0.44903733134269713
Epoch 4: G_loss: 1.0202515125274658, D_loss_real: 0.7882660984992981, D_loss_fake: 0.45080235600471497
Epoch 5: G_loss: 1.0266775608062744, D_loss_real: 0.7713093161582947, D_loss_fake: 0.44772478342056277
Epoch 6: G_loss: 1.0274111032485962, D_loss_real: 0.7571299791336059, D_loss_fake: 0.4432076156139374
Epoch 7: G_loss: 1.0283027410507202, D_loss_real: 0.7416645765304566, D_loss_fake: 0.4456164717674255
Epoch 8: G_loss: 1.0283642053604125, D_loss_real: 0.7297725081443787, D_loss_fake: 0.4428197741508484
Epoch 9: G_loss: 1.0364309549331665, D_loss_real: 0.7189203381538392, D_loss_fak

Epoch 80: G_loss: 0.7684701800346374, D_loss_real: 0.7611085653305054, D_loss_fake: 0.6257846593856812
Epoch 81: G_loss: 0.7760273337364196, D_loss_real: 0.7530265569686889, D_loss_fake: 0.6213193893432617
Epoch 82: G_loss: 0.780797815322876, D_loss_real: 0.7449196934700012, D_loss_fake: 0.6159611105918884
Epoch 83: G_loss: 0.7872031331062317, D_loss_real: 0.7358371138572692, D_loss_fake: 0.6110439896583557
Epoch 84: G_loss: 0.7886435031890869, D_loss_real: 0.7281290531158447, D_loss_fake: 0.6078667640686035
Epoch 85: G_loss: 0.7973419308662415, D_loss_real: 0.7201343655586243, D_loss_fake: 0.6014564514160157
Epoch 86: G_loss: 0.8018370985984802, D_loss_real: 0.7080768465995788, D_loss_fake: 0.6019129633903504
Epoch 87: G_loss: 0.8045929670333862, D_loss_real: 0.701444935798645, D_loss_fake: 0.598310649394989
Epoch 88: G_loss: 0.8145856857299805, D_loss_real: 0.694626533985138, D_loss_fake: 0.5931620955467224
Epoch 89: G_loss: 0.8180741786956787, D_loss_real: 0.6864175438880921, D_loss

Epoch 60: G_loss: 0.8378542065620422, D_loss_real: 0.6044990089204576, D_loss_fake: 0.5904137955771552
Epoch 61: G_loss: 0.8489077091217041, D_loss_real: 0.6082766652107239, D_loss_fake: 0.5858623584111532
Epoch 62: G_loss: 0.866818818781111, D_loss_real: 0.6052972012095981, D_loss_fake: 0.5742910040749444
Epoch 63: G_loss: 0.8647308084699843, D_loss_real: 0.5993110901779599, D_loss_fake: 0.5789912276797824
Epoch 64: G_loss: 0.8672464158799913, D_loss_real: 0.5971416797902849, D_loss_fake: 0.577690601348877
Epoch 65: G_loss: 0.8747036324607002, D_loss_real: 0.5919771989186605, D_loss_fake: 0.5738004181120131
Epoch 66: G_loss: 0.8822366661495633, D_loss_real: 0.5858766966395907, D_loss_fake: 0.5665432280964322
Epoch 67: G_loss: 0.8884385294384427, D_loss_real: 0.5774711006217532, D_loss_fake: 0.5651573671234978
Epoch 68: G_loss: 0.8931711382336087, D_loss_real: 0.5685934523741404, D_loss_fake: 0.5675294796625773
Epoch 69: G_loss: 0.8833408223258125, D_loss_real: 0.565541899866528, D_los

Epoch 40: G_loss: 0.8536425737234262, D_loss_real: 0.687139799961677, D_loss_fake: 0.6589103295252874
Epoch 41: G_loss: 0.8398884947483356, D_loss_real: 0.7207779849951084, D_loss_fake: 0.6850778964849619
Epoch 42: G_loss: 0.8232780465712914, D_loss_real: 0.7346746474504471, D_loss_fake: 0.6861465252362765
Epoch 43: G_loss: 0.7918177338746878, D_loss_real: 0.7379957655301461, D_loss_fake: 0.6969380195324237
Epoch 44: G_loss: 0.786931941142449, D_loss_real: 0.7304555418399664, D_loss_fake: 0.6992407395289495
Epoch 45: G_loss: 0.8084697631689218, D_loss_real: 0.7275312737776682, D_loss_fake: 0.6825631994467515
Epoch 46: G_loss: 0.8151640204282907, D_loss_real: 0.706986658848249, D_loss_fake: 0.6730814713698167
Epoch 47: G_loss: 0.8310265632776114, D_loss_real: 0.6869988590478897, D_loss_fake: 0.6559150081414443
Epoch 48: G_loss: 0.8521252045264611, D_loss_real: 0.6630640820815012, D_loss_fake: 0.63375068627871
Epoch 49: G_loss: 0.8786732004239008, D_loss_real: 0.6438640218514663, D_loss_

Epoch 20: G_loss: 0.9840103723108768, D_loss_real: 0.5661254115402699, D_loss_fake: 0.5528805758804083
Epoch 21: G_loss: 0.987640805542469, D_loss_real: 0.560398967936635, D_loss_fake: 0.5404492579400539
Epoch 22: G_loss: 0.9631064683198929, D_loss_real: 0.549043757840991, D_loss_fake: 0.5516470037400723
Epoch 23: G_loss: 0.963819932192564, D_loss_real: 0.5480208862572908, D_loss_fake: 0.5497039817273617
Epoch 24: G_loss: 0.9577808864414692, D_loss_real: 0.54969054274261, D_loss_fake: 0.5535768307745457
Epoch 25: G_loss: 0.9233168922364712, D_loss_real: 0.5562433786690235, D_loss_fake: 0.5744868628680706
Epoch 26: G_loss: 0.9240431003272533, D_loss_real: 0.5639439150691032, D_loss_fake: 0.5775330327451229
Epoch 27: G_loss: 0.9258337765932083, D_loss_real: 0.5813708435744047, D_loss_fake: 0.5836319141089916
Epoch 28: G_loss: 0.9160209186375141, D_loss_real: 0.5850105173885822, D_loss_fake: 0.5848170593380928
Epoch 29: G_loss: 0.9154998697340488, D_loss_real: 0.5945979245007038, D_loss_f

EM: 0.0002500277444444444, MV: 247540.45778145755
Epoch 0: G_loss: 1.097877648472786, D_loss_real: 0.5682435035705566, D_loss_fake: 0.509650219976902
Epoch 1: G_loss: 1.0303472310304642, D_loss_real: 0.5684812046587467, D_loss_fake: 0.5401949152350426
Epoch 2: G_loss: 1.0170018792152404, D_loss_real: 0.5967227272689343, D_loss_fake: 0.5510829746723175
Epoch 3: G_loss: 1.002763044834137, D_loss_real: 0.6054095454514027, D_loss_fake: 0.5523543491959572
Epoch 4: G_loss: 0.9870935201644897, D_loss_real: 0.6101661637425423, D_loss_fake: 0.5489535078406333
Epoch 5: G_loss: 1.016166526079178, D_loss_real: 0.6113621063530446, D_loss_fake: 0.5469211876392365
Epoch 6: G_loss: 1.0200626134872437, D_loss_real: 0.6193337127566337, D_loss_fake: 0.5569493666291236
Epoch 7: G_loss: 1.0363053768873214, D_loss_real: 0.6130607523024082, D_loss_fake: 0.5559684753417968
Epoch 8: G_loss: 1.0541617214679717, D_loss_real: 0.6088127613067627, D_loss_fake: 0.5505159422755241
Epoch 9: G_loss: 1.0795358538627624,

Epoch 80: G_loss: 0.7809051007032395, D_loss_real: 0.690264168381691, D_loss_fake: 0.6523868978023529
Epoch 81: G_loss: 0.757995268702507, D_loss_real: 0.7018009155988694, D_loss_fake: 0.6773413985967636
Epoch 82: G_loss: 0.7363401770591735, D_loss_real: 0.7022120535373688, D_loss_fake: 0.6913903385400773
Epoch 83: G_loss: 0.7347515761852265, D_loss_real: 0.7159253001213074, D_loss_fake: 0.6906785130500793
Epoch 84: G_loss: 0.7323300212621688, D_loss_real: 0.7060212135314942, D_loss_fake: 0.6923205196857453
Epoch 85: G_loss: 0.7447802782058716, D_loss_real: 0.6945411771535873, D_loss_fake: 0.6793056637048721
Epoch 86: G_loss: 0.7717454046010971, D_loss_real: 0.6701499849557877, D_loss_fake: 0.655437046289444
Epoch 87: G_loss: 0.7943775504827499, D_loss_real: 0.6513190716505051, D_loss_fake: 0.6384333938360214
Epoch 88: G_loss: 0.8045241922140122, D_loss_real: 0.6368367880582809, D_loss_fake: 0.6318967491388321
Epoch 89: G_loss: 0.778816157579422, D_loss_real: 0.6429203927516938, D_loss