In [1]:
import torch
import matplotlib.pyplot as plt

from utils import Dataset, load_data
from trainer import batch_gd, evaluate
from models import DeepLOB

In [2]:
# Seed torch's random number generators before anything stochastic runs, so that
# weight initialisation and the shuffled training DataLoader are repeatable after
# a kernel restart. Kept outside `config` so the dict handed to the trainer is
# unchanged. GPU kernels may still introduce some run-to-run nondeterminism.
SEED = 42
torch.manual_seed(SEED)

# Single place for every tunable value used by the notebook. The dict itself is
# also passed to batch_gd / evaluate, which read further keys from it.
config = {
    # Data configs
    'data_path' : './data/',   # directory handed to load_data
    'batch_size' : 32,         # batch size for the train / val / test DataLoaders
    'num_classes' : 3,         # forwarded to Dataset as num_classes
    'T' : 100,                 # forwarded to Dataset as T (presumably the look-back window length -- confirm in utils)
    'k' : 10,                  # forwarded to Dataset as k (presumably the prediction horizon -- confirm in utils)

    # Training configs
    'lr' : 0.01,               # Adam learning rate
    'eps' : 1.0,               # Adam epsilon (deliberately large; see the optimizer cell)
    'epochs' : 150,            # presumably the maximum number of epochs run by batch_gd
    'device' : torch.device('cuda' if torch.cuda.is_available() else 'cpu'),  # GPU when available, else CPU
    'patience': 20,            # NOTE(review): presumably early-stopping patience in epochs -- confirm in trainer
    'min_delta': 1e-6,         # NOTE(review): presumably the minimum improvement that counts -- confirm in trainer
    'print_freq': 10           # NOTE(review): presumably a logging interval -- confirm in trainer
}
print(config['device'])

cuda


In [3]:
# Read the train / validation / test splits from the configured data directory.
train, val, test = load_data(config['data_path'])

# All three splits are wrapped with the same k / num_classes / T settings.
dataset_kwargs = {
    'k': config['k'],
    'num_classes': config['num_classes'],
    'T': config['T'],
}
dataset_train = Dataset(data=train, **dataset_kwargs)
dataset_val = Dataset(data=val, **dataset_kwargs)
dataset_test = Dataset(data=test, **dataset_kwargs)

# Only the training loader shuffles; validation and test keep their original order.
batch_size = config['batch_size']
train_loader = torch.utils.data.DataLoader(
    dataset=dataset_train, batch_size=batch_size, shuffle=True
)
val_loader = torch.utils.data.DataLoader(
    dataset=dataset_val, batch_size=batch_size, shuffle=False
)
test_loader = torch.utils.data.DataLoader(
    dataset=dataset_test, batch_size=batch_size, shuffle=False
)

# Sanity check: shapes of the training inputs and labels.
print(dataset_train.x.shape, dataset_train.y.shape)

(149, 203800) (149, 50950) (149, 139587)
torch.Size([203701, 1, 100, 40]) torch.Size([203701])


In [4]:
# Build the network with one output per class and move it to the configured device.
device = config['device']
model = DeepLOB(y_len=dataset_train.num_classes)
model.to(device)

# Cross-entropy loss for the multi-class target.
criterion = torch.nn.CrossEntropyLoss()

# Adam hyper-parameters come from the config cell.
# set as in the paper or use lr=0.0001
adam_kwargs = {'lr': config['lr'], 'eps': config['eps']}
optimizer = torch.optim.Adam(model.parameters(), **adam_kwargs)

In [None]:
# Run the training loop. batch_gd hands back the model together with the
# loss / accuracy histories that the plotting cell consumes.
training_args = (
    model,
    criterion,
    optimizer,
    train_loader,
    val_loader,
    config,
)
model, train_losses, val_losses, train_accs, val_accs = batch_gd(*training_args)

  1%|          | 1/150 [02:22<5:52:41, 142.03s/it]

best model updated
Epoch 1/150, Train Loss: 0.9539,           Validation Loss: 0.9183, Duration: 0:02:22.023892, Best Val Epoch: 0


  1%|▏         | 2/150 [04:39<5:43:08, 139.11s/it]

best model updated
Epoch 2/150, Train Loss: 0.9443,           Validation Loss: 0.9180, Duration: 0:02:17.070747, Best Val Epoch: 1


  2%|▏         | 3/150 [06:41<5:21:34, 131.25s/it]

best model updated
Epoch 3/150, Train Loss: 0.9434,           Validation Loss: 0.9178, Duration: 0:02:01.897824, Best Val Epoch: 2


  3%|▎         | 4/150 [08:42<5:09:34, 127.22s/it]

best model updated
Epoch 4/150, Train Loss: 0.9428,           Validation Loss: 0.9178, Duration: 0:02:01.032930, Best Val Epoch: 3


  3%|▎         | 5/150 [10:44<5:03:32, 125.60s/it]

best model updated
Epoch 5/150, Train Loss: 0.9423,           Validation Loss: 0.9178, Duration: 0:02:02.725754, Best Val Epoch: 4


  4%|▍         | 6/150 [13:14<5:21:23, 133.91s/it]

best model updated
Epoch 6/150, Train Loss: 0.9396,           Validation Loss: 0.9177, Duration: 0:02:30.044187, Best Val Epoch: 5


  5%|▍         | 7/150 [15:27<5:18:06, 133.48s/it]

Epoch 7/150, Train Loss: 0.8921,           Validation Loss: 0.9179, Duration: 0:02:12.575179, Best Val Epoch: 5


  5%|▌         | 8/150 [17:33<5:10:17, 131.11s/it]

Epoch 8/150, Train Loss: 0.8650,           Validation Loss: 0.9453, Duration: 0:02:06.037049, Best Val Epoch: 5


  6%|▌         | 9/150 [19:42<5:06:21, 130.37s/it]

best model updated
Epoch 9/150, Train Loss: 0.8520,           Validation Loss: 0.8717, Duration: 0:02:08.729100, Best Val Epoch: 8


  6%|▌         | 9/150 [19:58<5:12:49, 133.12s/it]


KeyboardInterrupt: 

In [None]:
# Learning curves side by side: losses on the left, accuracies on the right.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 6))

# Each panel: (axes, title, y-axis label, [(series, legend label), ...]).
panels = [
    (loss_ax, 'Losses', 'Loss',
     [(train_losses, 'train loss'), (val_losses, 'validation loss')]),
    (acc_ax, 'Accuracies', 'Accuracy',
     [(train_accs, 'train accuracy'), (val_accs, 'validation accuracy')]),
]

for ax, title, y_label, curves in panels:
    for series, curve_label in curves:
        ax.plot(series, label=curve_label)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.legend()

plt.show()

In [6]:
# Evaluate the model returned by training on the test loader; whatever evaluate
# returns is kept as report_k10 (the name matches config['k'] = 10).
report_k10 = evaluate(
    model,
    test_loader,
    config,
)

KeyboardInterrupt: 