In [1]:
import torch
from torch import nn
import data_utils
from models.FenceGAN import Generator, Discriminator
from training.FenceGAN_train import FenceGanTrainingPipeline
import matplotlib.pyplot as plt

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cuda


# Hyperparameters

In [3]:
# choose dataset (kdd99, aapl, gm, axp)
dataset = "axp"

In [4]:
# Data
batch_size = 8
random_seed = 0
num_features = 7
seq_length = 1
seq_stride = 1
gen_seq_len = seq_length
# Model
latent_dim = 30
# Training
gen_lr = 1e-4
gen_wd = 1e-3
dis_lr = 8e-6
dis_wd = 1e-3
dis_momentum = 0.9
num_epochs = 50

# Load data

In [5]:
if dataset == "kdd99":
    train_dl, test_dl = data_utils.kdd99(seq_length, seq_stride, num_features, gen_seq_len, batch_size)
else:
    file_path = 'data/financial_data/Stocks/'+dataset+'.us.txt'
    tscv_dl_list = data_utils.load_stock_as_crossvalidated_timeseries(file_path, seq_length, seq_stride, gen_seq_len, batch_size, normalise=True)

# Model

In [6]:
# use xavier initialization for weights
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)   

In [7]:
generator = Generator(input_dim=latent_dim,output_dim=num_features).to(device=DEVICE)
generator.apply(init_weights)

Generator(
  (linear1): Linear(in_features=30, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=128, bias=True)
  (linear3): Linear(in_features=128, out_features=7, bias=True)
)

In [8]:
discriminator = Discriminator(input_dim=num_features).to(device=DEVICE)
discriminator.apply(init_weights)

Discriminator(
  (linear1): Linear(in_features=7, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=128, bias=True)
  (linear3): Linear(in_features=128, out_features=128, bias=True)
  (linear4): Linear(in_features=128, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

# Loss and Optimizer

In [9]:
generator_optim = torch.optim.Adam(generator.parameters(), lr=gen_lr, weight_decay=gen_wd)
discriminator_optim = torch.optim.SGD(discriminator.parameters(), lr=dis_lr, weight_decay=dis_wd)

In [10]:
def dispersion_loss(G_out, y_pred, y_true):
    dispersion_weight = 30
    loss_b = nn.BCELoss()(y_pred, y_true)
    center = G_out.mean(dim=0, keepdims=True)
    distance_xy = torch.square(torch.subtract(G_out, center))
    if G_out.dim() > 1:
        distance = distance_xy.sum(dim=1)
    else:
        distance = distance_xy.sum()
    avg_distance = torch.sqrt(distance).mean()
    loss_d = torch.reciprocal(avg_distance)
    loss = loss_b + dispersion_weight*loss_d
    return loss

In [11]:
def disc_loss(real_pred, real_true, fake_pred, fake_true):
    gen_weight = 0.5
    loss_real = nn.BCELoss()(real_pred, real_true)
    loss_gen = nn.BCELoss()(fake_pred, fake_true)
    loss = loss_real + gen_weight * loss_gen
    return loss

# Training

In [12]:
pipeline = FenceGanTrainingPipeline()

In [13]:
if dataset == "kdd99":
    pipeline.train_kdd99_small(seq_length, latent_dim,  train_dl, test_dl, discriminator, generator, discriminator_optim, generator_optim, disc_loss, dispersion_loss, random_seed, num_epochs, DEVICE)
else:
    pipeline.train(seq_length, latent_dim, tscv_dl_list, discriminator, generator, discriminator_optim, generator_optim, disc_loss, dispersion_loss, random_seed, num_epochs, DEVICE)

Epoch 0: G_loss: 29.92186709148128, D_loss_real: 1.0154574498897646
Epoch 1: G_loss: 29.10731738951148, D_loss_real: 1.00639479770893
Epoch 2: G_loss: 25.743796185749332, D_loss_real: 1.012747033340175
Epoch 3: G_loss: 25.25482526639613, D_loss_real: 1.005496935146611
Epoch 4: G_loss: 23.283972344747404, D_loss_real: 1.0101549232878335
Epoch 5: G_loss: 22.92787570488162, D_loss_real: 1.0128223023763516
Epoch 6: G_loss: 22.059006714239352, D_loss_real: 1.0062012279905923
Epoch 7: G_loss: 20.574369197938502, D_loss_real: 1.012354148597252
Epoch 8: G_loss: 19.225924422101276, D_loss_real: 1.0148338006763924
Epoch 9: G_loss: 18.530715616737925, D_loss_real: 1.0176356144067717
Epoch 10: G_loss: 17.95960349571414, D_loss_real: 1.020706998138893
Epoch 11: G_loss: 17.877934060445646, D_loss_real: 1.0135226336921133
Epoch 12: G_loss: 16.232094206461092, D_loss_real: 1.0166252386279222
Epoch 13: G_loss: 15.603293186280785, D_loss_real: 1.022705448836815
Epoch 14: G_loss: 15.760701947095917, D_lo

Epoch 18: G_loss: 2.940221057674749, D_loss_real: 0.866848066570313
Epoch 19: G_loss: 2.8231382486296863, D_loss_real: 0.8657662446905927
Epoch 20: G_loss: 2.8973135676810413, D_loss_real: 0.8808841225577564
Epoch 21: G_loss: 2.9253182895784455, D_loss_real: 0.876339300376613
Epoch 22: G_loss: 2.8249681499915393, D_loss_real: 0.880386323463626
Epoch 23: G_loss: 2.673241995214447, D_loss_real: 0.889183366686348
Epoch 24: G_loss: 2.6324719704263577, D_loss_real: 0.9177477393693071
Epoch 25: G_loss: 2.541871724089956, D_loss_real: 0.9631193362600435
Epoch 26: G_loss: 2.412976973425082, D_loss_real: 0.9888836058174691
Epoch 27: G_loss: 2.3172958498078633, D_loss_real: 1.0243067528174175
Epoch 28: G_loss: 2.3416755344809554, D_loss_real: 1.0367878869297058
Epoch 29: G_loss: 2.3407413019397394, D_loss_real: 1.052089599089894
Epoch 30: G_loss: 2.302420023010998, D_loss_real: 1.0531055491145065
Epoch 31: G_loss: 2.4202400920836906, D_loss_real: 1.0135304685530624
Epoch 32: G_loss: 2.5493445347

Epoch 36: G_loss: 3.2420969032659763, D_loss_real: 0.7701922800482773
Epoch 37: G_loss: 3.2734563897295694, D_loss_real: 0.7676740445741793
Epoch 38: G_loss: 3.2488741444378366, D_loss_real: 0.7666985000052103
Epoch 39: G_loss: 3.233262556355174, D_loss_real: 0.7705791781588298
Epoch 40: G_loss: 3.2997622734162864, D_loss_real: 0.7714510510607464
Epoch 41: G_loss: 3.2524163804403163, D_loss_real: 0.7715100238962871
Epoch 42: G_loss: 3.2683356308355562, D_loss_real: 0.7641523933992154
Epoch 43: G_loss: 3.281598377227783, D_loss_real: 0.7692254025761674
Epoch 44: G_loss: 3.301019718588852, D_loss_real: 0.7682534057919572
Epoch 45: G_loss: 3.242855088303729, D_loss_real: 0.7717110471027654
Epoch 46: G_loss: 3.273510550289619, D_loss_real: 0.7710723978717152
Epoch 47: G_loss: 3.2669238369639326, D_loss_real: 0.7669783717248498
Epoch 48: G_loss: 3.26959867942624, D_loss_real: 0.769632228700126
Epoch 49: G_loss: 3.3255721336457786, D_loss_real: 0.7639395548076164
EM: 0.004513555497761758, MV