In [1]:
import torch
from torch import nn
import data_utils
from models.FenceGAN import Generator, Discriminator
from training.FenceGAN_train import FenceGanTrainingPipeline
import matplotlib.pyplot as plt

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cuda


# Hyperparameters

In [3]:
# Data
batch_size = 256
random_seed = 0
num_features = 6
seq_length = 1
seq_stride = 10
gen_seq_len = seq_length
# Model
latent_dim = 30
# Training
gen_lr = 1e-4
gen_wd = 1e-3
dis_lr = 8e-6
dis_wd = 1e-3
dis_momentum = 0.9
num_epochs = 50

# Load data

In [4]:
dataset = "kdd99_small"

In [5]:
if dataset == "kdd99_small":
    train_dl, test_dl = data_utils.kdd99(seq_length, seq_stride, num_features, gen_seq_len, batch_size)
elif dataset == "kdd99_large":
    train_dl_normal = data_utils.large_kdd99('data/kdd99/X_train_normal.npy', seq_length, seq_stride, num_features, gen_seq_len,batch_size)
    train_dl_anomaly = data_utils.large_kdd99('data/kdd99/X_train_anomaly.npy', seq_length, seq_stride, num_features, gen_seq_len,batch_size)
    test_dl_normal = data_utils.large_kdd99('data/kdd99/X_test_normal.npy', seq_length, seq_stride, num_features, gen_seq_len,batch_size)
    test_dl_anomaly = data_utils.large_kdd99('data/kdd99/X_test_anomaly.npy', seq_length, seq_stride, num_features, gen_seq_len,batch_size)
elif dataset == "apple":
    file_path = './data/Stocks/aapl.us.txt'
    tscv_dl_list = data_utils.load_stock_as_crossvalidated_timeseries(file_path, seq_length, seq_stride, gen_seq_len, batch_size, normalise=True)

# Model

In [6]:
# use xavier initialization for weights
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)   

In [7]:
generator = Generator(input_dim=latent_dim,output_dim=num_features).to(device=DEVICE)
generator.apply(init_weights)

Generator(
  (linear1): Linear(in_features=30, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=128, bias=True)
  (linear3): Linear(in_features=128, out_features=6, bias=True)
)

In [8]:
discriminator = Discriminator(input_dim=num_features).to(device=DEVICE)
discriminator.apply(init_weights)

Discriminator(
  (linear1): Linear(in_features=6, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=128, bias=True)
  (linear3): Linear(in_features=128, out_features=128, bias=True)
  (linear4): Linear(in_features=128, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

# Loss and Optimizer

In [9]:
generator_optim = torch.optim.Adam(generator.parameters(), lr=gen_lr, weight_decay=gen_wd)
discriminator_optim = torch.optim.SGD(discriminator.parameters(), lr=dis_lr, weight_decay=dis_wd)

In [10]:
def dispersion_loss(G_out, y_pred, y_true):
    dispersion_weight = 30
    loss_b = nn.BCELoss()(y_pred, y_true)
    center = G_out.mean(dim=0, keepdims=True)
    distance_xy = torch.square(torch.subtract(G_out, center))
    distance = distance_xy.sum(dim=1)
    avg_distance = torch.sqrt(distance).mean()
    loss_d = torch.reciprocal(avg_distance)
    loss = loss_b + dispersion_weight*loss_d
    return loss

In [11]:
def disc_loss(real_pred, real_true, fake_pred, fake_true):
    gen_weight = 0.5
    loss_real = nn.BCELoss()(real_pred, real_true)
    loss_gen = nn.BCELoss()(fake_pred, fake_true)
    loss = loss_real + gen_weight * loss_gen
    return loss

# Training

In [12]:
pipeline = FenceGanTrainingPipeline()

In [13]:
if dataset == "kdd99_small":
    pipeline.train_kdd99_small(seq_length, latent_dim,  train_dl, test_dl, discriminator, generator, discriminator_optim, generator_optim, disc_loss, dispersion_loss, random_seed, num_epochs, DEVICE)
elif dataset == "kdd99_large":
    pipeline.train_kdd99_large(seq_length, latent_dim,  train_dl_normal, train_dl_anomaly, test_dl_anomaly, test_dl_normal, discriminator, generator, discriminator_optim, generator_optim, disc_loss, dispersion_loss, random_seed, num_epochs, DEVICE)
else:
    pipeline.train(seq_length, latent_dim, tscv_dl_list, discriminator, generator, discriminator_optim, generator_optim, disc_loss, dispersion_loss, random_seed, num_epochs, DEVICE)

Epoch 0 training:
G_loss: 19.338057899475096, D_loss: 1.1276006395166571
Epoch 1 training:
G_loss: 10.747882652282716, D_loss: 1.1266415829008276
Epoch 2 training:
G_loss: 7.354594523256475, D_loss: 1.1357076281850988
Epoch 3 training:
G_loss: 5.50076459754597, D_loss: 1.156045286763798
Epoch 4 training:
G_loss: 4.354316354881633, D_loss: 1.1847474645484577
Epoch 5 training:
G_loss: 3.5926332809708335, D_loss: 1.2097719685597854
Epoch 6 training:
G_loss: 3.047925609892065, D_loss: 1.2274696967818521
Epoch 7 training:
G_loss: 2.6655534083192998, D_loss: 1.235485921122811
Epoch 8 training:
G_loss: 2.4161304235458374, D_loss: 1.2257977864959024
Epoch 9 training:
G_loss: 2.262792724912817, D_loss: 1.2103874385356903
Epoch 10 training:
G_loss: 2.1753327505155045, D_loss: 1.2017200811342759
Epoch 11 training:
G_loss: 2.1139880245382137, D_loss: 1.2020989954471588
Epoch 12 training:
G_loss: 2.031764817237854, D_loss: 1.224377325448123
Epoch 13 training:
G_loss: 1.9367716805501418, D_loss: 1.2

KeyboardInterrupt: 