In [None]:
import torch
import matplotlib.pyplot as plt

from utils import Dataset, load_data
from trainer import batch_gd, evaluate
from models import DeepLOB

In [None]:
# Central experiment configuration: every knob the data loaders, the trainer
# (batch_gd) and the evaluator read lives here, so the notebook can be re-run
# top to bottom from a single place.
config = {
    # Data configs
    'data_path' : './data/',
    'batch_size' : 32,
    'num_classes' : 3,
    'T' : 100,   # passed to Dataset as T (presumably the look-back window length) - TODO confirm
    'k' : 10,    # passed to Dataset as k (presumably the prediction horizon) - TODO confirm

    # Training configs
    'lr' : 0.01,   # Adam learning rate
    'eps' : 1.0,   # Adam epsilon (deliberately large)
    'epochs' : 150,
    'device' : torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'patience': 20,
    'min_delta': 1e-6,
    'print_freq': 10,

    # Reproducibility
    'seed': 0,
}

# Seed torch's global RNG (all devices) so that model weight initialisation and
# the shuffled training DataLoader are repeatable across Restart & Run All.
torch.manual_seed(config['seed'])

print(config['device'])

In [None]:
# Load the raw train / validation / test splits, then wrap each one in a
# Dataset and a DataLoader.
train, val, test = load_data(config['data_path'])


def _make_dataset(split):
    """Build a Dataset for one raw split using the shared k / num_classes / T settings."""
    return Dataset(data=split, k=config['k'], num_classes=config['num_classes'], T=config['T'])


dataset_train, dataset_val, dataset_test = map(_make_dataset, (train, val, test))

# Only the training split is shuffled; validation and test keep their original order.
train_loader, val_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=ds, batch_size=config['batch_size'], shuffle=do_shuffle)
    for ds, do_shuffle in ((dataset_train, True), (dataset_val, False), (dataset_test, False))
)

print(dataset_train.x.shape, dataset_train.y.shape)

In [None]:
# DeepLOB classifier whose output size matches the label count of the training Dataset.
model = DeepLOB(y_len=dataset_train.num_classes)
model.to(config['device'])

criterion = torch.nn.CrossEntropyLoss()

# Adam with lr / eps taken from the config cell: set as in the paper,
# or try lr=0.0001 as an alternative.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=config['lr'],
    eps=config['eps'],
)

In [None]:
# Train the model. batch_gd receives the loaders plus the config dict and hands
# back the model together with four history sequences used by the plots below.
(model, train_losses, val_losses,
 train_accs, val_accs) = batch_gd(model, criterion, optimizer,
                                  train_loader, val_loader, config)

In [None]:
def plot_learning_curves(train_losses, val_losses, train_accs, val_accs, figsize=(15, 6)):
    """Draw side-by-side loss and accuracy curves for the train and validation histories.

    Args:
        train_losses, val_losses: sequences of loss values, one point per epoch.
        train_accs, val_accs: sequences of accuracy values, one point per epoch.
        figsize: matplotlib figure size in inches.

    Returns:
        The matplotlib Figure holding the two panels.
    """
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=figsize)
    # One row per panel: axes, the two histories, the legend noun, the title, the y label.
    panels = (
        (ax_loss, train_losses, val_losses, 'loss', 'Losses', 'Loss'),
        (ax_acc, train_accs, val_accs, 'accuracy', 'Accuracies', 'Accuracy'),
    )
    for ax, train_hist, val_hist, metric, title, ylabel in panels:
        ax.plot(train_hist, label=f'train {metric}')
        ax.plot(val_hist, label=f'validation {metric}')
        ax.set(title=title, xlabel='Epoch', ylabel=ylabel)
        ax.legend()
    return fig


plot_learning_curves(train_losses, val_losses, train_accs, val_accs)
plt.show()

In [None]:
# Evaluate the trained model on the held-out test loader.
# NOTE(review): the variable name hard-codes k=10, but the k actually used comes
# from config['k']; rename this if that setting changes.
report_k10 = evaluate(
    model,
    test_loader,
    config,
)