In [1]:
import torch
from torch import nn
import data_utils
from models.FenceGAN import Generator, Discriminator
from training.FenceGAN_train import FenceGanTrainingPipeline
import matplotlib.pyplot as plt

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cuda


# Hyperparameters

In [3]:
# choose dataset (kdd99, aapl, gm, axp)
dataset = "kdd99"

In [4]:
# Data
batch_size = 256
random_seed = 0
num_features = 34
seq_length = 30
seq_stride = 10
gen_seq_len = seq_length
# Model
latent_dim = 30
# Training
gen_lr = 1e-4
gen_wd = 1e-3
dis_lr = 8e-6
dis_wd = 1e-3
dis_momentum = 0.9
num_epochs = 50

# Load data

In [5]:
if dataset == "kdd99":
    train_dl, test_dl = data_utils.kdd99(seq_length, seq_stride, num_features, gen_seq_len, batch_size)
else:
    file_path = 'data/financial_data/Stocks/'+dataset+'.us.txt'
    tscv_dl_list = data_utils.load_stock_as_crossvalidated_timeseries(file_path, seq_length, seq_stride, gen_seq_len, batch_size, normalise=True)

# Model

In [6]:
# use xavier initialization for weights
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)   

In [7]:
generator = Generator(input_dim=latent_dim,output_dim=num_features).to(device=DEVICE)
generator.apply(init_weights)

Generator(
  (linear1): Linear(in_features=30, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=128, bias=True)
  (linear3): Linear(in_features=128, out_features=34, bias=True)
)

In [8]:
discriminator = Discriminator(input_dim=num_features).to(device=DEVICE)
discriminator.apply(init_weights)

Discriminator(
  (linear1): Linear(in_features=34, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=128, bias=True)
  (linear3): Linear(in_features=128, out_features=128, bias=True)
  (linear4): Linear(in_features=128, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

# Loss and Optimizer

In [9]:
generator_optim = torch.optim.Adam(generator.parameters(), lr=gen_lr, weight_decay=gen_wd)
discriminator_optim = torch.optim.SGD(discriminator.parameters(), lr=dis_lr, weight_decay=dis_wd)

In [10]:
def dispersion_loss(G_out, y_pred, y_true):
    dispersion_weight = 30
    loss_b = nn.BCELoss()(y_pred, y_true)
    center = G_out.mean(dim=0, keepdims=True)
    distance_xy = torch.square(torch.subtract(G_out, center))
    if G_out.dim() > 1:
        distance = distance_xy.sum(dim=1)
    else:
        distance = distance_xy.sum()
    avg_distance = torch.sqrt(distance).mean()
    loss_d = torch.reciprocal(avg_distance)
    loss = loss_b + dispersion_weight*loss_d
    return loss

In [11]:
def disc_loss(real_pred, real_true, fake_pred, fake_true):
    gen_weight = 0.5
    loss_real = nn.BCELoss()(real_pred, real_true)
    loss_gen = nn.BCELoss()(fake_pred, fake_true)
    loss = loss_real + gen_weight * loss_gen
    return loss

# Training

In [12]:
pipeline = FenceGanTrainingPipeline()

In [13]:
if dataset == "kdd99":
    pipeline.train_kdd99(seq_length, latent_dim,  train_dl, test_dl, discriminator, generator, discriminator_optim, generator_optim, disc_loss, dispersion_loss, random_seed, num_epochs, DEVICE)
else:
    pipeline.train(seq_length, latent_dim, tscv_dl_list, discriminator, generator, discriminator_optim, generator_optim, disc_loss, dispersion_loss, random_seed, num_epochs, DEVICE)

Epoch 0: G_loss: 9.709105465628884, D_loss_real: 1.2365699925205924
Epoch 1: G_loss: 5.25057809569619, D_loss_real: 1.4078323217955502
Epoch 2: G_loss: 2.9035993240096354, D_loss_real: 1.8773425887931476
Epoch 3: G_loss: 1.762805793502114, D_loss_real: 2.2901965152133594
Epoch 4: G_loss: 1.2812760250134902, D_loss_real: 2.3914693052118476
Epoch 5: G_loss: 1.1439002941955219, D_loss_real: 2.2341750394214284
Epoch 6: G_loss: 1.1894169677387585, D_loss_real: 1.99483581347899
Epoch 7: G_loss: 1.3010012529113075, D_loss_real: 1.7535684910687532
Epoch 8: G_loss: 1.414782965183258, D_loss_real: 1.5567764298482374
Epoch 9: G_loss: 1.523707706819881, D_loss_real: 1.458597771687941
Epoch 10: G_loss: 1.6771016175096685, D_loss_real: 1.4085985324599526
Epoch 11: G_loss: 1.9581425477157939, D_loss_real: 1.3021559753201224
Epoch 12: G_loss: 2.1215349013155156, D_loss_real: 1.2397004214200107
Epoch 13: G_loss: 2.031945520097559, D_loss_real: 1.224591152234511
Epoch 14: G_loss: 1.8250966044989498, D_l