In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import torch
import pandas as pd
import numpy as np
import lightning as L

## Data Module

In [3]:
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split

from dataset.dataset import ForexDataset

class ForexDataModule(L.LightningDataModule):
    def __init__(
        self,
        data,
        IDs: list,
        sequence_length: int,
        horizon: int,
        features: list,
        target: list,
        batch_size: int = 64,
        num_workers: int = 0,
        val_split: float = 0.2,
        shuffle: bool = True,
        random_state: int = 42
    ):
        super().__init__()
        self.data = data
        self.IDs = IDs
        self.sequence_length = sequence_length
        self.horizon = horizon
        self.features = features
        self.target = target
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.shuffle = shuffle
        self.random_state = random_state

    def setup(self, stage=None):

        train_idx, val_idx = train_test_split(
            self.IDs,
            test_size=self.val_split,
            shuffle=self.shuffle,
            random_state=self.random_state
        )

        self.train_dataset = ForexDataset(
            self.data, train_idx, self.sequence_length, self.horizon, self.features, self.target
        )

        self.val_dataset = ForexDataset(
            self.data, val_idx, self.sequence_length, self.horizon, self.features, self.target
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def test_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )


# GRU classification Model

We've tried different criterion.
    a. softmax + mse
    b. raw logits + cross entropy
We found that cross entropy performs better when using pytorch

In [4]:
import torch
from torch import nn
import lightning as L
from torchmetrics.classification import MulticlassAccuracy


class GRUModel(nn.Module):
    def __init__(self, n_features, output_size, n_hidden, n_layers, dropout):
        super().__init__()

        self.gru = nn.GRU(
            input_size=n_features,
            hidden_size=n_hidden,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout
        )
        self.linear = nn.Linear(n_hidden, output_size) # ouptput logits
        

    def forward(self, x):
        self.gru.flatten_parameters()
        _, hidden = self.gru(x)
        logits = self.linear(hidden[-1])
        return logits


class GRUModule(L.LightningModule):
    def __init__(self, n_features=1, output_size=1, n_hidden=64, n_layers=2, dropout=0.0):
        super().__init__()
        self.save_hyperparameters()

        self.model = GRUModel(
            n_features=self.hparams.n_features,
            output_size=self.hparams.output_size,
            n_hidden=self.hparams.n_hidden,
            n_layers=self.hparams.n_layers,
            dropout=self.hparams.dropout,
        )

        self.criterion = nn.CrossEntropyLoss()
        self.test_accuracy = MulticlassAccuracy(num_classes=output_size)

    def forward(self, x, labels=None):
        output = self.model(x)
        loss = 0
        if labels is not None:
            labels = labels.squeeze().long()
            loss = self.criterion(output, labels)
        return loss, output

    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        loss, out = self(x, y)

        self.log('train_loss', loss, prog_bar=True, logger=True)
        return {
            'loss': loss
        }

    def validation_step(self, batch, batch_idx):
        x, y, _ = batch
        loss, out = self(x, y)

        self.log('val_loss', loss, prog_bar=True, logger=True)
        return {
            'loss': loss
        }

    def test_step(self, batch, batch_idx):
        x, y, _ = batch
        loss, out = self(x, y)

        y = y.squeeze().long()
        preds = torch.argmax(out, dim=1)
        acc = self.test_accuracy(preds, y)

        self.log('test_loss', loss, prog_bar=True, logger=True)
        self.log('test_acc', acc, prog_bar=True, logger=True)
        return {'loss': loss, 'acc': acc}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return [optimizer], [scheduler]


Now we've defined all the classes we need for the training.
The following steps will use these to train a classification model to predict future movement of USDJPY close price

In [8]:
PKL_PATH = "../data/processed/usdjpy-h1-bar-2019-01-01-2025-05-12_processed.pkl"
SEQUENCE_LENGTH=30
HORIZON=1 # The next nth timeframe to predict
STRIDE=5 # Non-overlapping timeframe
FEATURES_COLS = ['close_log_return_scaled']
TARGET_COLS = ['train_label']

## Read data

In [9]:
import pandas as pd

In [10]:
df = pd.read_pickle(PKL_PATH)

df=df[df['timestamp'].dt.year <= 2024]
df

Unnamed: 0,timestamp,open,high,low,close,volume,time,log_volume,log_volume_scaled,time_group,close_delta,close_return,close_log_return,close_log_return_scaled,ret_mean_5,ret_mean_10,labels,train_label
10,2019-01-02 08:00:00,109.308,109.323,108.902,108.950,20474.2598,2019-01-02 08:00:00,9.926973,0.972262,1,-0.358,-0.003275,-0.003281,-2.853028,-0.000861,-0.000665,-1,0
11,2019-01-02 09:00:00,108.949,108.977,108.707,108.941,16183.7695,2019-01-02 09:00:00,9.691826,0.732430,1,-0.009,-0.000083,-0.000083,-0.078014,-0.000817,-0.000667,-1,0
12,2019-01-02 10:00:00,108.942,109.102,108.879,109.045,13739.5801,2019-01-02 10:00:00,9.528109,0.565451,1,0.104,0.000955,0.000954,0.821677,-0.000473,-0.000530,-1,0
13,2019-01-02 11:00:00,109.045,109.176,109.043,109.107,12859.2305,2019-01-02 11:00:00,9.461895,0.497918,1,0.062,0.000569,0.000568,0.486914,-0.000273,-0.000448,-1,0
14,2019-01-02 12:00:00,109.104,109.285,109.101,109.142,12204.0098,2019-01-02 12:00:00,9.409602,0.444583,1,0.035,0.000321,0.000321,0.271991,-0.000304,-0.000247,-1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
37425,2024-12-31 17:00:00,157.370,157.546,157.310,157.365,31525.6895,2024-12-31 17:00:00,10.358590,1.412479,319,-0.003,-0.000019,-0.000019,-0.022871,0.000605,0.000742,-1,0
37426,2024-12-31 18:00:00,157.368,157.393,157.234,157.295,24407.4902,2024-12-31 18:00:00,10.102686,1.151477,319,-0.070,-0.000445,-0.000445,-0.392415,0.000363,0.000731,-1,0
37427,2024-12-31 19:00:00,157.297,157.311,157.236,157.308,18772.5898,2024-12-31 19:00:00,9.840206,0.883767,319,0.013,0.000083,0.000083,0.065386,0.000684,0.000425,-1,0
37428,2024-12-31 20:00:00,157.308,157.376,157.293,157.344,29975.6406,2024-12-31 20:00:00,10.308174,1.361058,319,0.036,0.000229,0.000229,0.192235,0.000235,0.000324,-1,0


## Create Datamodule

In [11]:
from utils import get_sequence_start_indices

In [12]:
## get valid indices that wont create sequences crossing time gaps
IDs = get_sequence_start_indices(
    df,
    sequence_length=SEQUENCE_LENGTH,
    horizon=HORIZON,
    stride=STRIDE,
    group_col='time_group',
)
# Initialize Data Module
dm = ForexDataModule(
    data=df,
    IDs=IDs,
    sequence_length=SEQUENCE_LENGTH,
    target=TARGET_COLS,
    features=FEATURES_COLS,
    horizon=HORIZON,
    batch_size=64,
    val_split=0.2,
    num_workers=0,
)

In [13]:
model = GRUModule(
    n_features=len(FEATURES_COLS),
    output_size=3,
    n_hidden=256,
    n_layers=3,
    dropout=0.3,
)

# Training Script

In [14]:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.profilers import SimpleProfiler

### Logging

In [15]:
logger = TensorBoardLogger("lightning_logs", name="prob_gru")

In [16]:
profiler = SimpleProfiler(filename='profiler')

### Earlystopping

In [17]:
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=5,
    verbose=True
)

### Checkpoint

In [18]:
checkpoint_callback = ModelCheckpoint(
    filename='best_checkpoint',
    save_top_k=1,
    save_last=True,
    verbose=True,
    monitor='val_loss',
    mode='min'
)

### Trainer

In [19]:
trainer = Trainer(
    # accelerator="gpu",
    # precision='16-mixed',
    profiler=profiler,
    callbacks=[checkpoint_callback, early_stopping],
    max_epochs=200,
    logger=logger,
    gradient_clip_val=1.0
    # num_sanity_val_steps=0,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [20]:
trainer.fit(model, datamodule=dm)

You are using a CUDA device ('NVIDIA GeForce RTX 4060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params | Mode 
-------------------------------------------------------------
0 | model         | GRUModel           | 989 K  | train
1 | criterion     | CrossEntropyLoss   | 0      | train
2 | test_accuracy | MulticlassAccuracy | 0      | train
-------------------------------------------------------------
989 K     Trainable params
0         Non-trainable params
989 K     Total params
3.957     Total estimated model params size (MB)
5         Modules in train mode
0         Modules in eval mode


                timestamp     open     high      low    close      volume  \
30520 2023-11-21 19:00:00  148.343  148.594  148.238  148.315  20056.0996   
30521 2023-11-21 20:00:00  148.314  148.429  148.273  148.388   8018.2700   
30522 2023-11-21 21:00:00  148.386  148.416  148.335  148.380   3078.6599   
30523 2023-11-21 22:00:00  148.364  148.396  148.233  148.237   1425.9700   
30524 2023-11-21 23:00:00  148.237  148.275  148.017  148.147   6125.9800   
30525 2023-11-22 00:00:00  148.146  148.299  148.079  148.233  14284.9502   
30526 2023-11-22 01:00:00  148.238  148.355  148.029  148.298  22533.3691   
30527 2023-11-22 02:00:00  148.298  148.332  148.133  148.236  11424.4199   
30528 2023-11-22 03:00:00  148.235  148.280  148.095  148.212  10594.2002   
30529 2023-11-22 04:00:00  148.214  148.389  148.202  148.376  11211.9297   
30530 2023-11-22 05:00:00  148.376  148.803  148.368  148.752  16636.9902   
30531 2023-11-22 06:00:00  148.753  149.053  148.735  148.932  16393.5605   

Sanity Checking: |                                                                                            …

C:\Users\yoyo\miniconda3\envs\fxml\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
C:\Users\yoyo\miniconda3\envs\fxml\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Metric val_loss improved. New best score: 1.095
Epoch 0, global step 62: 'val_loss' reached 1.09510 (best 1.09510), saving model to 'lightning_logs\\prob_gru\\version_6\\checkpoints\\best_checkpoint.ckpt' as top 1


Validation: |                                                                                                 …

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 1.093
Epoch 1, global step 124: 'val_loss' reached 1.09325 (best 1.09325), saving model to 'lightning_logs\\prob_gru\\version_6\\checkpoints\\best_checkpoint.ckpt' as top 1


Validation: |                                                                                                 …

Epoch 2, global step 186: 'val_loss' was not in top 1


Validation: |                                                                                                 …

Metric val_loss improved by 0.012 >= min_delta = 0.0. New best score: 1.081
Epoch 3, global step 248: 'val_loss' reached 1.08121 (best 1.08121), saving model to 'lightning_logs\\prob_gru\\version_6\\checkpoints\\best_checkpoint.ckpt' as top 1


Validation: |                                                                                                 …

Metric val_loss improved by 0.050 >= min_delta = 0.0. New best score: 1.032
Epoch 4, global step 310: 'val_loss' reached 1.03154 (best 1.03154), saving model to 'lightning_logs\\prob_gru\\version_6\\checkpoints\\best_checkpoint.ckpt' as top 1


Validation: |                                                                                                 …

Epoch 5, global step 372: 'val_loss' was not in top 1


Validation: |                                                                                                 …

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 1.030
Epoch 6, global step 434: 'val_loss' reached 1.02986 (best 1.02986), saving model to 'lightning_logs\\prob_gru\\version_6\\checkpoints\\best_checkpoint.ckpt' as top 1


Validation: |                                                                                                 …

Epoch 7, global step 496: 'val_loss' was not in top 1


Validation: |                                                                                                 …

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 1.027
Epoch 8, global step 558: 'val_loss' reached 1.02661 (best 1.02661), saving model to 'lightning_logs\\prob_gru\\version_6\\checkpoints\\best_checkpoint.ckpt' as top 1


Validation: |                                                                                                 …

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 1.025
Epoch 9, global step 620: 'val_loss' reached 1.02471 (best 1.02471), saving model to 'lightning_logs\\prob_gru\\version_6\\checkpoints\\best_checkpoint.ckpt' as top 1


Validation: |                                                                                                 …

Epoch 10, global step 682: 'val_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 11, global step 744: 'val_loss' was not in top 1


Validation: |                                                                                                 …

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 1.024
Epoch 12, global step 806: 'val_loss' reached 1.02409 (best 1.02409), saving model to 'lightning_logs\\prob_gru\\version_6\\checkpoints\\best_checkpoint.ckpt' as top 1


Validation: |                                                                                                 …

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 1.023
Epoch 13, global step 868: 'val_loss' reached 1.02267 (best 1.02267), saving model to 'lightning_logs\\prob_gru\\version_6\\checkpoints\\best_checkpoint.ckpt' as top 1


Validation: |                                                                                                 …

Epoch 14, global step 930: 'val_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 15, global step 992: 'val_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 16, global step 1054: 'val_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 17, global step 1116: 'val_loss' was not in top 1


Validation: |                                                                                                 …

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 1.022
Epoch 18, global step 1178: 'val_loss' reached 1.02161 (best 1.02161), saving model to 'lightning_logs\\prob_gru\\version_6\\checkpoints\\best_checkpoint.ckpt' as top 1


Validation: |                                                                                                 …

Epoch 19, global step 1240: 'val_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 20, global step 1302: 'val_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 21, global step 1364: 'val_loss' was not in top 1


Validation: |                                                                                                 …

Epoch 22, global step 1426: 'val_loss' was not in top 1


Validation: |                                                                                                 …

Monitored metric val_loss did not improve in the last 5 records. Best score: 1.022. Signaling Trainer to stop.
Epoch 23, global step 1488: 'val_loss' was not in top 1


In [21]:
trainer.test(model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


                timestamp     open     high      low    close      volume  \
30520 2023-11-21 19:00:00  148.343  148.594  148.238  148.315  20056.0996   
30521 2023-11-21 20:00:00  148.314  148.429  148.273  148.388   8018.2700   
30522 2023-11-21 21:00:00  148.386  148.416  148.335  148.380   3078.6599   
30523 2023-11-21 22:00:00  148.364  148.396  148.233  148.237   1425.9700   
30524 2023-11-21 23:00:00  148.237  148.275  148.017  148.147   6125.9800   
30525 2023-11-22 00:00:00  148.146  148.299  148.079  148.233  14284.9502   
30526 2023-11-22 01:00:00  148.238  148.355  148.029  148.298  22533.3691   
30527 2023-11-22 02:00:00  148.298  148.332  148.133  148.236  11424.4199   
30528 2023-11-22 03:00:00  148.235  148.280  148.095  148.212  10594.2002   
30529 2023-11-22 04:00:00  148.214  148.389  148.202  148.376  11211.9297   
30530 2023-11-22 05:00:00  148.376  148.803  148.368  148.752  16636.9902   
30531 2023-11-22 06:00:00  148.753  149.053  148.735  148.932  16393.5605   

C:\Users\yoyo\miniconda3\envs\fxml\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Testing: |                                                                                                    …

──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.39674457907676697
        test_loss           1.0228668451309204
──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 1.0228668451309204, 'test_acc': 0.39674457907676697}]