In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

import utils.data_utils as data_utils
from models.RSRAE import RSRLayer, RSRAutoEncoder
from training.RSRAE_train import train_kdd99, train_financial

In [None]:
# All computation is pinned to the CPU. To pick a GPU automatically when one
# is available, use instead:
#     torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = "cpu"
print(DEVICE)

# Hyperparameters

In [None]:
# Dataset to run: "kdd99", or a stock ticker ("aapl", "gm", "axp") that is read
# from data/financial_data/Stocks/<ticker>.us.txt in the loading cell below.
dataset = "kdd99"

In [None]:
# Data
num_features = 7
seq_len = 1
seq_stride = 10
gen_seq_len = seq_len
# Training
random_seed = 0
num_epochs = 50
batch_size = 8
lr = 0.01  # peak learning rate (used as OneCycleLR's max_lr)
wd = 5e-7  # weight decay passed to the optimizer
# Model
d = 128          # rows of the RSR matrix A (RSRLoss compares A @ A.T with I_d)
D = 128 * 4      # width of the code z that A projects (A is d x D)
lambda1 = 1.0    # weight of the projection-residual term in RSRLoss
lambda2 = 1.0    # weight of the orthogonality term in RSRLoss
threshold = 0.8  # passed through to train_kdd99 / train_financial

# Seed the RNGs before the model is built and data is shuffled, so a
# Restart-and-Run-All reproduces the same weights and batches.
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Load Data

In [None]:
import numpy as np
import torch.utils.data as data

if dataset == "kdd99":
    train_dl, test_dl = data_utils.kdd99(seq_len, seq_stride, num_features, gen_seq_len, batch_size)
    # OneCycleLR needs the number of optimizer steps taken per epoch.
    steps_per_epoch = len(train_dl)
else:
    file_path = f"data/financial_data/Stocks/{dataset}.us.txt"
    tscv_dl_list = data_utils.load_stock_as_crossvalidated_timeseries(
        file_path, seq_len, seq_stride, gen_seq_len, batch_size, normalise=True
    )
    # NOTE(review): this sizes the schedule from element [1] of the first fold
    # only; confirm that is the loader stepped during training and that later
    # folds do not have more batches (OneCycleLR raises if stepped too often).
    steps_per_epoch = len(tscv_dl_list[0][1])

# Model

In [None]:
# Input width is num_features * seq_len (presumably one flattened window per
# sample — TODO confirm); d and D size the RSR layer.
model = RSRAutoEncoder(num_features * seq_len, d, D).to(DEVICE)
print(model)

# Loss and Optimizer

In [None]:
# AdamW with the configured weight decay. Without weight_decay=wd, PyTorch
# applies its default of 1e-2 and the `wd` hyperparameter is silently ignored.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
# One-cycle schedule peaking at `lr` over num_epochs * steps_per_epoch steps.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, epochs=num_epochs, steps_per_epoch=steps_per_epoch)

In [None]:
def RSRLoss(z, A, lambda1=lambda1, lambda2=lambda2, d=d, D=D):
    """Robust Subspace Recovery penalty on the code ``z``.

    Computes
        lambda1 * sum_i ||z_i - A^T A z_i||_2  +  lambda2 * ||A A^T - I_d||_F^2

    Args:
        z: tensor whose first dimension indexes samples and whose remaining
            elements form a length-D vector per sample (reshaped to (batch, D)).
        A: matrix of shape (d, D).
        lambda1: weight of the projection-residual term.
        lambda2: weight of the orthogonality term.
        d: number of rows of A (size of the identity it is compared against).
        D: per-sample width of z.

    Returns:
        Scalar tensor.
    """
    # Allocate the identity where A lives so the loss also works off-CPU.
    Id = torch.eye(d, device=A.device, dtype=A.dtype)
    z = z.reshape(z.size(0), D)
    # Row-wise A^T A z_i: project each code onto the d-dim subspace and back.
    AtAz = z @ A.T @ A

    # Sum of per-sample L2 residual norms (dim=1), not one norm over the batch.
    term1 = torch.norm(z - AtAz, p=2, dim=1).sum()
    # Squared Frobenius distance of A A^T from the identity.
    term2 = torch.norm(A @ A.T - Id, p="fro") ** 2

    return lambda1 * term1 + lambda2 * term2

In [None]:
def L2p_loss(y_hat, y, p=1.0):
    """Reconstruction loss sum_i ||y_i - y_hat_i||_2 ** p.

    The first dimension is treated as the batch; every other dimension of a
    sample is flattened before its L2 norm is taken. A 0-D or 1-D input is
    treated as a single sample.

    Args:
        y_hat: reconstruction, same shape as ``y``.
        y: target tensor.
        p: exponent applied to each per-sample norm.

    Returns:
        Scalar tensor.
    """
    diff = y - y_hat
    if diff.dim() <= 1:
        diff = diff.reshape(1, -1)
    else:
        diff = diff.reshape(diff.size(0), -1)
    # Norm per sample first, then raise to p and sum over the batch.
    per_sample = torch.norm(diff, p=2, dim=1)
    return torch.sum(torch.pow(per_sample, p))

# Train & Evaluate

In [None]:
if dataset != "kdd99":
    # Stock tickers: hand over the fold list from load_stock_as_crossvalidated_timeseries.
    train_financial(model, tscv_dl_list, num_epochs, L2p_loss, RSRLoss, optimizer, scheduler, threshold, DEVICE)
else:
    # KDD99: single train/test split.
    train_kdd99(model, train_dl, test_dl, num_epochs, L2p_loss, RSRLoss, optimizer, scheduler, threshold, DEVICE)

# Inspect Model Output on One Batch

In [None]:
# Take one batch for inspection: the first kdd99 test batch, or element [1]
# of fold 4 for stock data (NOTE(review): assumes at least five folds).
if dataset == "kdd99":
    batch = next(iter(test_dl))
else:
    batch = next(iter(tscv_dl_list[4][1]))
x = batch[0].squeeze()
y = batch[1].squeeze()

# Inference only: put training-mode layers into eval mode and skip building
# an autograd graph.
model.eval()
with torch.no_grad():
    enc, z, latent, A = model(x.to(DEVICE))
z = z.cpu().detach()

In [None]:
# One line per column of z, plotted against its row index.
fig, ax = plt.subplots()
ax.plot(z)
ax.set(title="Model output z for the inspected batch", xlabel="row index", ylabel="value")
plt.show()

In [None]:
# One line per column of the input x, plotted against its row index.
fig, ax = plt.subplots()
ax.plot(x)
ax.set(title="Input x for the inspected batch", xlabel="row index", ylabel="value")
plt.show()