In [None]:
import torch
import utils.data_utils as data_utils
from models.DeepAnT import DeepAnT_CNN, DeepAnT_LSTM
from training.DeepAnT_train import train_kdd99, train_financial
import matplotlib.pyplot as plt

In [None]:
# Seed PyTorch's RNGs (CPU and all CUDA devices) so weight initialisation and,
# when the loaders shuffle via torch's default generator, batch order are
# repeatable across Restart & Run All.
# NOTE(review): if data_utils draws from `random` or NumPy, seed those too.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

# Hyperparameters

In [None]:
# Choose between the CNN- and LSTM-based DeepAnT variants.
model_type = "LSTM"
# Choose the dataset: "kdd99", or a stock ticker such as "aapl", "gm", "axp".
dataset = "kdd99"

# Fail fast: the hyperparameter and model cells below only handle these two
# variants, and an unknown value would otherwise surface later as a NameError
# (or silently reuse variables left over from a previous run).
if model_type not in ("LSTM", "CNN"):
    raise ValueError(f"unknown model_type: {model_type!r} (expected 'LSTM' or 'CNN')")

In [None]:
if model_type == "LSTM":
    # Data settings, consumed by the data_utils loaders and the model.
    num_features, seq_length, seq_stride = 7, 30, 10
    gen_seq_len = seq_length  # generated window is as long as the input window
    batch_size = 8

    # Arguments for the DeepAnT_LSTM constructor.
    hidden_dim, layers = 128, 4
    anm_det_thr = 0.5

    # Epoch budget and Adam optimiser settings (learning rate, weight decay).
    num_epochs = 100
    lr, wd = 1e-5, 5e-6

In [None]:
if model_type == "CNN":
    # Data settings, consumed by the data_utils loaders and the model.
    num_features, seq_length = 34, 30
    seq_stride, gen_seq_len = 1, 1
    batch_size = 256

    # Arguments for the DeepAnT_CNN constructor.
    dense_dim, num_channels, kernel_size = 448, 64, 3
    anm_det_thr = 0.5

    # Epoch budget and Adam optimiser settings (learning rate, weight decay).
    num_epochs = 100
    lr, wd = 1e-5, 5e-6

# Dataset

In [None]:
if dataset != "kdd99":
    # Any other dataset name is treated as a stock ticker file; the loader
    # returns a list that later cells index by fold (presumably one
    # (train, eval) loader pair per fold — confirm in data_utils).
    file_path = "data/financial_data/Stocks/" + dataset + ".us.txt"
    tscv_dl_list = data_utils.load_stock_as_crossvalidated_timeseries(
        file_path, seq_length, seq_stride, gen_seq_len, batch_size, normalise=True
    )
else:
    # KDD99 comes with its own fixed train/test loaders.
    train_dl, test_dl = data_utils.kdd99(
        seq_length, seq_stride, num_features, gen_seq_len, batch_size, deepant=True
    )

# Model 

In [None]:
# Build the DeepAnT variant selected by `model_type` and move it to DEVICE.
if model_type == "LSTM":
    model = DeepAnT_LSTM(num_features, hidden_dim, layers, anm_det_thr).to(DEVICE)
elif model_type == "CNN":
    model = DeepAnT_CNN(seq_length, num_features, kernel_size, dense_dim, num_channels, anm_det_thr).to(DEVICE)
else:
    # Without this branch an unknown model_type leaves `model` undefined, or
    # worse, silently stale from a previous run in the same kernel.
    raise ValueError(f"unknown model_type: {model_type!r} (expected 'LSTM' or 'CNN')")

# Loss & Optimizer

In [None]:
# Mean-squared-error loss averaged over all elements, optimised with Adam
# (weight_decay adds an L2 penalty term to the gradients).
loss_function = torch.nn.MSELoss(reduction="mean")
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=wd,
)

# Training & Evaluation

In [None]:
# Dispatch to the dataset-specific training loop: stock datasets train over
# the cross-validation loader list, KDD99 over its train/test loaders.
if dataset != "kdd99":
    train_financial(tscv_dl_list, model, optimizer, loss_function, num_epochs, DEVICE)
else:
    train_kdd99(train_dl, test_dl, model, optimizer, loss_function, num_epochs, DEVICE)

# Generated vs. Real Samples

In [None]:
# Take one batch from an evaluation loader: the KDD99 test loader, or element
# [1] of entry 4 in the cross-validated stock loader list.
# NOTE(review): index 4 presumes at least five folds and that element [1] is
# the evaluation loader — confirm against load_stock_as_crossvalidated_timeseries.
if dataset == "kdd99":
    batch = next(iter(test_dl))
else:
    batch = next(iter(tscv_dl_list[4][1]))

# LSTM: take element [0] of batch[0] / batch[2] (presumably the first sample).
# CNN: feed all of batch[0] and keep the first `seq_length` entries of batch[2].
x = batch[0][0] if model_type == "LSTM" else batch[0]
y = batch[2][0] if model_type == "LSTM" else batch[2][:seq_length].squeeze()

# Inference only: disable training-time layer behaviour (e.g. dropout) and
# autograd bookkeeping, then restore the previous mode so the model is left
# exactly as the training cell left it, even if the forward pass raises.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        z = model(x.to(DEVICE)).cpu()
finally:
    model.train(was_training)

if model_type == "CNN":
    z = z[:seq_length]

In [None]:
# Generated sample: the model output `z` computed in the cell above.
fig, ax = plt.subplots()
ax.plot(z)
ax.set(title="Generated sample", xlabel="Step", ylabel="Value")
plt.show()

In [None]:
# Real sample: the ground-truth `y` taken from the same batch as `z`.
fig, ax = plt.subplots()
ax.plot(y)
ax.set(title="Real sample", xlabel="Step", ylabel="Value")
plt.show()