### **Axial LOB**

In [1]:
# Load necessary packages
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm 
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, ConfusionMatrixDisplay
import torch
from torch.utils import data
import torch.nn as nn
import torch.optim as optim

# Set device (GPU if available)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cpu


In [3]:
# Import necessary packages
import numpy as np

# Define paths based on your directory structure for NoAuction
train_file = 'C:/Users/monam/BenchmarkDatasets/NoAuction/1.NoAuction_Zscore/NoAuction_Zscore_Training/Train_Dst_NoAuction_ZScore_CF_7.txt'
test_file1 = 'C:/Users/monam/BenchmarkDatasets/NoAuction/1.NoAuction_Zscore/NoAuction_Zscore_Testing/Test_Dst_NoAuction_ZScore_CF_7.txt'
test_file2 = 'C:/Users/monam/BenchmarkDatasets/NoAuction/1.NoAuction_Zscore/NoAuction_Zscore_Testing/Test_Dst_NoAuction_ZScore_CF_8.txt'
test_file3 = 'C:/Users/monam/BenchmarkDatasets/NoAuction/1.NoAuction_Zscore/NoAuction_Zscore_Testing/Test_Dst_NoAuction_ZScore_CF_9.txt'

# Load training and validation data
dec_data = np.loadtxt(train_file)
dec_train = dec_data[:, :int(dec_data.shape[1] * 0.8)]
dec_val = dec_data[:, int(dec_data.shape[1] * 0.8):]

# Load test data and concatenate
dec_test1 = np.loadtxt(test_file1)
dec_test2 = np.loadtxt(test_file2)
dec_test3 = np.loadtxt(test_file3)
dec_test = np.hstack((dec_test1, dec_test2, dec_test3))

# Set parameters
W = 40                     # Number of features
dim = 40                   # Number of LOB states
horizon = 2                # Horizon for target calculation
T = 5                      # Time window size for dataset creation

# Prepare labels
y_train = dec_train[-horizon, :].flatten()
y_val = dec_val[-horizon, :].flatten()
y_test = dec_test[-horizon, :].flatten()

# Adjust labels for training, validation, and test sets
y_train = y_train[dim-1:] - 1
y_val = y_val[dim-1:] - 1
y_test = y_test[dim-1:] - 1 

# Prepare data for model input
dec_train = dec_train[:40, :].T
dec_val = dec_val[:40, :].T
dec_test = dec_test[:40, :].T

# Print shapes to verify data
print("Training data shape:", dec_train.shape)
print("Validation data shape:", dec_val.shape)
print("Testing data shape:", dec_test.shape)


Training data shape: (203800, 40)
Validation data shape: (50950, 40)
Testing data shape: (139587, 40)


In [5]:
import torch
from torch.utils import data

class Dataset(data.Dataset):
    """Characterizes a dataset for PyTorch"""
    def __init__(self, x, y, num_classes, dim):
        """Initialization""" 
        self.num_classes = num_classes
        self.dim = dim
        self.x = x   
        self.y = y

        # Compute length based on rolling window
        self.length = x.shape[0] - T - self.dim + 1
        print("Dataset length:", self.length)

        # Convert to PyTorch tensors
        x = torch.from_numpy(x).float()  # Ensure data is float for model input
        self.x = torch.unsqueeze(x, 1)   # Add channel dimension
        self.y = torch.from_numpy(y).long()  # Labels should be long type for classification

    def __len__(self):
        """Denotes the total number of samples"""
        return self.length

    def __getitem__(self, i):
        # Extract input with rolling window and adjust shape
        input = self.x[i:i+self.dim, :]
        input = input.permute(1, 0, 2)  # Adjust to expected shape [1, dim, features]
        
        return input, self.y[i]

# Set parameters
batch_size = 64
num_classes = 3  # Adjust based on your problem (e.g., 3 classes for LOB levels)
dim = 40  # Number of LOB states, adjust as needed

# Instantiate Dataset objects for train, validation, and test sets
dataset_train = Dataset(dec_train, y_train, num_classes, dim)
dataset_val = Dataset(dec_val, y_val, num_classes, dim)
dataset_test = Dataset(dec_test, y_test, num_classes, dim)

# Create DataLoader objects for batching and shuffling
train_loader = data.DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)
val_loader = data.DataLoader(dataset=dataset_val, batch_size=batch_size, shuffle=False)
test_loader = data.DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False)

# Verify DataLoader functionality
for inputs, labels in train_loader:
    print("Input batch shape:", inputs.shape)
    print("Label batch shape:", labels.shape)
    break  # Test with a single batch


Dataset length: 203756
Dataset length: 50906
Dataset length: 139543
Input batch shape: torch.Size([64, 1, 40, 40])
Label batch shape: torch.Size([64])


In [11]:
# Import necessary packages
import torch
from torch.utils import data

# Define hyperparameters
batch_size = 64
epochs = 50 
c_final = 4              # Channel output size of the second conv layer
n_heads = 4
c_in_axial = 32          # Channel output size of the first conv layer
c_out_axial = 32
pool_kernel = (1, 4)
pool_stride = (1, 4)

num_classes = 3

# Adjust label preparation without flattening the entire dataset
horizon = 2
dim = 40

# Define lengths based on data size and required horizon offset
train_len = dec_train.shape[0] - dim + 1 - horizon
val_len = dec_val.shape[0] - dim + 1 - horizon
test_len = dec_test.shape[0] - dim + 1 - horizon

# Slicing the labels to match the data lengths exactly
y_train = dec_train[dim-1:dim-1 + train_len, -horizon] - 1
y_val = dec_val[dim-1:dim-1 + val_len, -horizon] - 1
y_test = dec_test[dim-1:dim-1 + test_len, -horizon] - 1

# Confirm alignment
print("Training data shape:", dec_train[:train_len].shape)
print("Training labels shape:", y_train.shape)
print("Validation data shape:", dec_val[:val_len].shape)
print("Validation labels shape:", y_val.shape)
print("Testing data shape:", dec_test[:test_len].shape)
print("Testing labels shape:", y_test.shape)

# Create Dataset instances
dataset_train = Dataset(dec_train, y_train, num_classes, dim)
dataset_val = Dataset(dec_val, y_val, num_classes, dim)
dataset_test = Dataset(dec_test, y_test, num_classes, dim)

# Set up DataLoader instances
train_loader = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=dataset_val, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False)

# Verify DataLoader functionality with a sample batch
for inputs, labels in train_loader:
    print("Sample input batch shape:", inputs.shape)
    print("Sample label batch shape:", labels.shape)
    break  # Test with a single batch


Training data shape: (203759, 40)
Training labels shape: (203759,)
Validation data shape: (50909, 40)
Validation labels shape: (50909,)
Testing data shape: (139546, 40)
Testing labels shape: (139546,)
Dataset length: 203756
Dataset length: 50906
Dataset length: 139543
Sample input batch shape: torch.Size([64, 1, 40, 40])
Sample label batch shape: torch.Size([64])


In [12]:
# Import necessary packages
import torch
from torch.utils import data
import torch.nn as nn
import math

# Define hyperparameters
batch_size = 64
epochs = 50 
c_final = 4
n_heads = 4
c_in_axial = 32
c_out_axial = 32
pool_kernel = (1, 4)
pool_stride = (1, 4)
num_classes = 3

# Ensure labels align with data by slicing both consistently
horizon = 2
dim = 40
train_len = dec_train.shape[0] - dim + 1 - horizon
val_len = dec_val.shape[0] - dim + 1 - horizon
test_len = dec_test.shape[0] - dim + 1 - horizon
y_train = dec_train[dim-1:dim-1 + train_len, -horizon] - 1
y_val = dec_val[dim-1:dim-1 + val_len, -horizon] - 1
y_test = dec_test[dim-1:dim-1 + test_len, -horizon] - 1

# Dataset setup
dataset_train = Dataset(dec_train[:train_len], y_train, num_classes, dim)
dataset_val = Dataset(dec_val[:val_len], y_val, num_classes, dim)
dataset_test = Dataset(dec_test[:test_len], y_test, num_classes, dim)
train_loader = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=dataset_val, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False)

# Model Architecture
def _conv1d1x1(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
        nn.BatchNorm1d(out_channels)
    )

class GatedAxialAttention(nn.Module):
    def __init__(self, in_channels, out_channels, heads, dim, flag):
        super().__init__()
        assert (in_channels % heads == 0) and (out_channels % heads == 0)
        self.in_channels, self.out_channels, self.heads = in_channels, out_channels, heads
        self.dim_head_v = out_channels // heads
        self.flag, self.dim = flag, dim
        self.dim_head_qk = self.dim_head_v // 2
        self.qkv_channels = self.dim_head_v + self.dim_head_qk * 2

        self.to_qkv = _conv1d1x1(in_channels, heads * self.qkv_channels)
        self.bn_qkv, self.bn_similarity = nn.BatchNorm1d(heads * self.qkv_channels), nn.BatchNorm2d(heads * 3)
        self.bn_output = nn.BatchNorm1d(heads * self.qkv_channels)

        # Gating mechanism
        self.f_qr, self.f_kr = nn.Parameter(torch.tensor(0.3), False), nn.Parameter(torch.tensor(0.3), False)
        self.f_sve, self.f_sv = nn.Parameter(torch.tensor(0.3), False), nn.Parameter(torch.tensor(0.5), False)

        # Position embedding
        self.relative = nn.Parameter(torch.randn(self.dim_head_v * 2, dim * 2 - 1), requires_grad=True)
        query_index = torch.arange(dim).unsqueeze(0)
        key_index = torch.arange(dim).unsqueeze(1)
        relative_index = key_index - query_index + dim - 1
        self.register_buffer('flatten_index', relative_index.view(-1))

    def forward(self, x):
        if self.flag:
            x = x.permute(0, 2, 1, 3)
        else:
            x = x.permute(0, 3, 1, 2)  # N, W, C, H
        N, W, C, H = x.shape
        x = x.view(N * W, C, H)

        x = self.to_qkv(x)
        qkv = self.bn_qkv(x)
        q, k, v = torch.split(qkv.reshape(N * W, self.heads, self.dim_head_v * 2, H),
                              [self.dim_head_v // 2, self.dim_head_v // 2, self.dim_head_v], dim=2)

        all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.dim_head_v * 2, self.dim, self.dim)
        q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, [self.dim_head_qk, self.dim_head_qk, self.dim_head_v], dim=0)
        qr, kr, qk = torch.einsum('bgci,cij->bgij', q, q_embedding), torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3), torch.einsum('bgci, bgcj->bgij', q, k)
        qr, kr = torch.mul(qr, self.f_qr), torch.mul(kr, self.f_kr)

        similarity = torch.softmax(self.bn_similarity(torch.cat([qk, qr, kr], dim=1)).view(N * W, 3, self.heads, H, H).sum(dim=1), dim=3)
        sv, sve = torch.mul(torch.einsum('bgij,bgcj->bgci', similarity, v), self.f_sv), torch.mul(torch.einsum('bgij,cij->bgci', similarity, v_embedding), self.f_sve)

        output = self.bn_output(torch.cat([sv, sve], dim=-1).view(N * W, self.out_channels * 2, H)).view(N, W, self.out_channels, 2, H).sum(dim=-2)
        return output.permute(0, 2, 3, 1) if not self.flag else output.permute(0, 2, 1, 3)

class AxialLOB(nn.Module):
    def __init__(self, W, H, c_in, c_out, c_final, n_heads, pool_kernel, pool_stride):
        super().__init__()
        self.CNN_in, self.CNN_out = nn.Conv2d(1, c_in, kernel_size=1), nn.Conv2d(c_out, c_final, kernel_size=1)
        self.CNN_res2, self.CNN_res1 = nn.Conv2d(c_out, c_final, kernel_size=1), nn.Conv2d(1, c_out, kernel_size=1)
        self.norm, self.res_norm2, self.res_norm1, self.norm2 = nn.BatchNorm2d(c_in), nn.BatchNorm2d(c_final), nn.BatchNorm2d(c_out), nn.BatchNorm2d(c_final)
        self.axial_height_1, self.axial_width_1 = GatedAxialAttention(c_out, c_out, n_heads, H, False), GatedAxialAttention(c_out, c_out, n_heads, W, True)
        self.axial_height_2, self.axial_width_2 = GatedAxialAttention(c_out, c_out, n_heads, H, False), GatedAxialAttention(c_out, c_out, n_heads, W, True)
        self.activation, self.linear = nn.ReLU(), nn.Linear(1600, 3)
        self.pooling = nn.AvgPool2d(kernel_size=pool_kernel, stride=pool_stride)

    def forward(self, x):
        y = self.activation(self.norm(self.CNN_in(x)))
        y, x = self.axial_width_1(y), self.activation(self.res_norm1(self.CNN_res1(x)))
        y, y_copy = y + x, y.detach().clone()
        y = self.axial_width_2(y + x)
        y = self.activation(self.res_norm2(self.CNN_out(self.axial_height_2(self.axial_height_1(y)))))
        return torch.softmax(self.linear(torch.flatten(self.pooling(y + self.activation(self.norm2(self.CNN_res2(y_copy)))), 1)), dim=1)


Dataset length: 203715
Dataset length: 50865
Dataset length: 139502


In [None]:
# Import necessary packages
import numpy as np
import torch
from torch.utils import data
import torch.nn as nn
import torch.optim as optim
from datetime import datetime
from tqdm import tqdm
import math
from collections import Counter
from sklearn.model_selection import train_test_split

# Set device (GPU if available)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Define paths based on your directory structure for NoAuction
train_file = 'C:/Users/monam/BenchmarkDatasets/NoAuction/1.NoAuction_Zscore/NoAuction_Zscore_Training/Train_Dst_NoAuction_ZScore_CF_7.txt'
test_file1 = 'C:/Users/monam/BenchmarkDatasets/NoAuction/1.NoAuction_Zscore/NoAuction_Zscore_Testing/Test_Dst_NoAuction_ZScore_CF_7.txt'
test_file2 = 'C:/Users/monam/BenchmarkDatasets/NoAuction/1.NoAuction_Zscore/NoAuction_Zscore_Testing/Test_Dst_NoAuction_ZScore_CF_8.txt'
test_file3 = 'C:/Users/monam/BenchmarkDatasets/NoAuction/1.NoAuction_Zscore/NoAuction_Zscore_Testing/Test_Dst_NoAuction_ZScore_CF_9.txt'

# Load the full dataset
dec_data = np.loadtxt(train_file)

# Set parameters
W = 40                     # Number of features
dim = 40                   # Number of LOB states (sequence length)
horizon = 2                # Horizon for target calculation

# Prepare data for model input
dec_data = dec_data[:40, :].T  # Shape: (num_samples, features)
N = dec_data.shape[0]

# Create sequences and corresponding labels
sequences = []
labels = []
for i in range(N - dim - horizon + 1):
    seq = dec_data[i:i+dim]  # Sequence of length 'dim'
    label = dec_data[i+dim+horizon-1, -horizon]  # Label corresponding to the sequence
    sequences.append(seq)
    labels.append(label)

# Convert to numpy arrays
X_full = np.array(sequences)
y_full_raw = np.array(labels)

# Map labels to classes using fixed thresholds
mean_label = np.mean(y_full_raw)
std_label = np.std(y_full_raw)
lower_threshold = mean_label - std_label
upper_threshold = mean_label + std_label

# Function to map continuous labels to class indices using fixed thresholds
def map_labels_fixed(y, lower_threshold, upper_threshold):
    y_mapped = np.zeros_like(y, dtype=int)
    y_mapped[y <= lower_threshold] = 0
    y_mapped[(y > lower_threshold) & (y <= upper_threshold)] = 1
    y_mapped[y > upper_threshold] = 2
    return y_mapped

# Map labels
y_full = map_labels_fixed(y_full_raw, lower_threshold, upper_threshold)

# Perform stratified splitting
X_train_seq, X_val_seq, y_train_seq, y_val_seq = train_test_split(
    X_full, y_full, test_size=0.2, stratify=y_full, random_state=42)

# Load test data and concatenate
dec_test1 = np.loadtxt(test_file1)
dec_test2 = np.loadtxt(test_file2)
dec_test3 = np.loadtxt(test_file3)

# Concatenate and slice to the first 40 features
dec_test = np.hstack((dec_test1, dec_test2, dec_test3))
dec_test = dec_test[:40, :].T  # Shape: (num_samples, 40)

# Prepare test data
N_test = dec_test.shape[0]
sequences_test = []
labels_test = []
for i in range(N_test - dim - horizon + 1):
    seq = dec_test[i:i+dim]  # Sequence of shape (40, 40)
    label = dec_test[i+dim+horizon-1, -horizon]
    sequences_test.append(seq)
    labels_test.append(label)

# Convert to numpy arrays
X_test_seq = np.array(sequences_test, dtype=np.float32)
y_test_raw = np.array(labels_test)
y_test_seq = map_labels_fixed(y_test_raw, lower_threshold, upper_threshold)

# Confirm alignment
print("Training data shape:", X_train_seq.shape)
print("Training labels shape:", y_train_seq.shape)
print("Validation data shape:", X_val_seq.shape)
print("Validation labels shape:", y_val_seq.shape)
print("Testing data shape:", X_test_seq.shape)
print("Testing labels shape:", y_test_seq.shape)

# Define Dataset class
class Dataset(data.Dataset):
    def __init__(self, X_seq, y_seq):
        self.X_seq = X_seq
        self.y_seq = y_seq

    def __len__(self):
        return len(self.y_seq)

    def __getitem__(self, index):
        X = torch.tensor(self.X_seq[index], dtype=torch.float32).unsqueeze(0)  # Add channel dimension
        y = torch.tensor(self.y_seq[index], dtype=torch.long)
        return X, y

# Define AxialLOB model with Gated Axial Attention
# Include the AxialLOB and GatedAxialAttention classes as provided in your previous code

# Model, loss function, optimizer, and scheduler setup
model = AxialLOB(W=40, H=40, c_in=c_in_axial, c_out=c_out_axial, c_final=c_final, n_heads=n_heads,
                 pool_kernel=pool_kernel, pool_stride=pool_stride).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.00001)

# Create Dataset instances
dataset_train = Dataset(X_train_seq, y_train_seq)
dataset_val = Dataset(X_val_seq, y_val_seq)
dataset_test = Dataset(X_test_seq, y_test_seq)

# Set up DataLoader instances
train_loader = data.DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)
val_loader = data.DataLoader(dataset=dataset_val, batch_size=batch_size, shuffle=False)
test_loader = data.DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False)

# Define batch_gd function
def batch_gd(model, criterion, optimizer, epochs):
    train_losses = np.zeros(epochs)
    val_losses = np.zeros(epochs)
    best_val_loss = np.inf
    best_epoch = 0

    for it in tqdm(range(epochs)):
        model.train()
        train_loss = []

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
        
        train_losses[it] = np.mean(train_loss)

        # Validation phase
        model.eval()
        val_loss = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss.append(loss.item())
        val_losses[it] = np.mean(val_loss)

        scheduler.step()
        
        # Save the best model based on validation loss
        if val_losses[it] < best_val_loss:
            torch.save(model.state_dict(), 'model_best.pth')
            best_val_loss = val_losses[it]
            best_epoch = it
            print('Model saved')

        print(f"Epoch {it+1}/{epochs}, Train Loss: {train_losses[it]:.4f}, "
              f"Validation Loss: {val_losses[it]:.4f}, Best Val Epoch: {best_epoch+1}")

    return train_losses, val_losses

# Execute training
train_losses, val_losses = batch_gd(model, criterion, optimizer, epochs)


Using device: cpu
Training data shape: (203767, 40, 40)
Training labels shape: (203767,)
Validation data shape: (50942, 40, 40)
Validation labels shape: (50942,)
Testing data shape: (139546, 40, 40)
Testing labels shape: (139546,)


  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
# Import necessary packages
import numpy as np
import torch
from torch.utils import data
import torch.nn as nn
import torch.optim as optim
from datetime import datetime
from tqdm import tqdm
import math
from collections import Counter
from sklearn.model_selection import train_test_split

# Set device (GPU if available)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Define paths based on your directory structure for NoAuction
train_file = 'C:/Users/monam/BenchmarkDatasets/NoAuction/1.NoAuction_Zscore/NoAuction_Zscore_Training/Train_Dst_NoAuction_ZScore_CF_7.txt'
test_file1 = 'C:/Users/monam/BenchmarkDatasets/NoAuction/1.NoAuction_Zscore/NoAuction_Zscore_Testing/Test_Dst_NoAuction_ZScore_CF_7.txt'
test_file2 = 'C:/Users/monam/BenchmarkDatasets/NoAuction/1.NoAuction_Zscore/NoAuction_Zscore_Testing/Test_Dst_NoAuction_ZScore_CF_8.txt'
test_file3 = 'C:/Users/monam/BenchmarkDatasets/NoAuction/1.NoAuction_Zscore/NoAuction_Zscore_Testing/Test_Dst_NoAuction_ZScore_CF_9.txt'

# Load the full dataset
dec_data = np.loadtxt(train_file)

# Set parameters
W = 40                     # Number of features
dim = 40                   # Number of LOB states (sequence length)
horizon = 2                # Horizon for target calculation

# Prepare data for model input
dec_data = dec_data[:40, :].T  # Shape: (num_samples, features)
N = dec_data.shape[0]

# Create sequences and corresponding labels
sequences = []
labels = []
for i in range(N - dim - horizon + 1):
    seq = dec_data[i:i+dim]  # Sequence of length 'dim'
    label = dec_data[i+dim+horizon-1, -horizon]  # Label corresponding to the sequence
    sequences.append(seq)
    labels.append(label)

# Convert to numpy arrays
X_full = np.array(sequences)
y_full_raw = np.array(labels)

# Map labels to classes using fixed thresholds
mean_label = np.mean(y_full_raw)
std_label = np.std(y_full_raw)
lower_threshold = mean_label - std_label
upper_threshold = mean_label + std_label

# Function to map continuous labels to class indices using fixed thresholds
def map_labels_fixed(y, lower_threshold, upper_threshold):
    y_mapped = np.zeros_like(y, dtype=int)
    y_mapped[y <= lower_threshold] = 0
    y_mapped[(y > lower_threshold) & (y <= upper_threshold)] = 1
    y_mapped[y > upper_threshold] = 2
    return y_mapped

# Map labels
y_full = map_labels_fixed(y_full_raw, lower_threshold, upper_threshold)

# Perform stratified splitting
X_train_seq, X_val_seq, y_train_seq, y_val_seq = train_test_split(
    X_full, y_full, test_size=0.2, stratify=y_full, random_state=42)

# Load test data and concatenate
dec_test1 = np.loadtxt(test_file1)
dec_test2 = np.loadtxt(test_file2)
dec_test3 = np.loadtxt(test_file3)

# Concatenate and slice to the first 40 features
dec_test = np.hstack((dec_test1, dec_test2, dec_test3))
dec_test = dec_test[:40, :].T  # Shape: (num_samples, 40)

# Prepare test data
N_test = dec_test.shape[0]
sequences_test = []
labels_test = []
for i in range(N_test - dim - horizon + 1):
    seq = dec_test[i:i+dim]  # Sequence of shape (40, 40)
    label = dec_test[i+dim+horizon-1, -horizon]
    sequences_test.append(seq)
    labels_test.append(label)

# Convert to numpy arrays
X_test_seq = np.array(sequences_test, dtype=np.float32)
y_test_raw = np.array(labels_test)
y_test_seq = map_labels_fixed(y_test_raw, lower_threshold, upper_threshold)

# Confirm alignment
print("Training data shape:", X_train_seq.shape)
print("Training labels shape:", y_train_seq.shape)
print("Validation data shape:", X_val_seq.shape)
print("Validation labels shape:", y_val_seq.shape)
print("Testing data shape:", X_test_seq.shape)
print("Testing labels shape:", y_test_seq.shape)

# Define Dataset class
class Dataset(data.Dataset):
    def __init__(self, X_seq, y_seq):
        self.X_seq = X_seq
        self.y_seq = y_seq

    def __len__(self):
        return len(self.y_seq)

    def __getitem__(self, index):
        X = torch.tensor(self.X_seq[index], dtype=torch.float32).unsqueeze(0)  # Add channel dimension
        y = torch.tensor(self.y_seq[index], dtype=torch.long)
        return X, y

# Define AxialLOB model with Gated Axial Attention
def _conv1d1x1(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
        nn.BatchNorm1d(out_channels)
    )

class GatedAxialAttention(nn.Module):
    def __init__(self, in_channels, out_channels, heads, dim, flag):
        super().__init__()
        assert (in_channels % heads == 0) and (out_channels % heads == 0)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.dim_head_v = out_channels // heads
        self.flag = flag
        self.dim = dim
        self.dim_head_qk = self.dim_head_v // 2
        self.qkv_channels = self.dim_head_v + self.dim_head_qk * 2

        # Multi-head self-attention
        self.to_qkv = _conv1d1x1(in_channels, self.heads * self.qkv_channels)
        self.bn_qkv = nn.BatchNorm1d(self.heads * self.qkv_channels)
        self.bn_similarity = nn.BatchNorm2d(heads * 3)
        self.bn_output = nn.BatchNorm1d(self.heads * self.qkv_channels)

        # Gating mechanism
        self.f_qr = nn.Parameter(torch.tensor(0.3), requires_grad=False)
        self.f_kr = nn.Parameter(torch.tensor(0.3), requires_grad=False)
        self.f_sve = nn.Parameter(torch.tensor(0.3), requires_grad=False)
        self.f_sv = nn.Parameter(torch.tensor(0.5), requires_grad=False)

        # Position embedding
        self.relative = nn.Parameter(torch.randn(self.dim_head_v * 2, dim * 2 - 1), requires_grad=True)
        query_index = torch.arange(dim).unsqueeze(0)
        key_index = torch.arange(dim).unsqueeze(1)
        relative_index = key_index - query_index + dim - 1
        self.register_buffer('flatten_index', relative_index.view(-1))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.relative, 0., math.sqrt(1. / self.dim_head_v))

    def forward(self, x):
        if self.flag:
            x = x.permute(0, 2, 1, 3)
        else:
            x = x.permute(0, 3, 1, 2)  # N, W, C, H
        N, W, C, H = x.shape
        x = x.reshape(N * W, C, H)

        # Transformations
        x = self.to_qkv(x)
        qkv = self.bn_qkv(x)
        q, k, v = torch.split(qkv.reshape(N * W, self.heads, self.dim_head_v * 2, H),
                              [self.dim_head_v // 2, self.dim_head_v // 2, self.dim_head_v], dim=2)

        # Calculate position embedding
        all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).reshape(
            self.dim_head_v * 2, self.dim, self.dim)
        q_embedding, k_embedding, v_embedding = torch.split(all_embeddings,
                                                            [self.dim_head_qk, self.dim_head_qk, self.dim_head_v],
                                                            dim=0)
        qr = torch.einsum('bgci,cij->bgij', q, q_embedding)
        kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)
        qk = torch.einsum('bgci, bgcj->bgij', q, k)

        # Multiply by factors
        qr = torch.mul(qr, self.f_qr)
        kr = torch.mul(kr, self.f_kr)
        stacked_similarity = torch.cat([qk, qr, kr], dim=1)
        stacked_similarity = self.bn_similarity(stacked_similarity).reshape(N * W, 3, self.heads, H, H).sum(dim=1)
        similarity = torch.softmax(stacked_similarity, dim=3)
        sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
        sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)

        # Multiply by factors
        sv = torch.mul(sv, self.f_sv)
        sve = torch.mul(sve, self.f_sve)
        stacked_output = torch.cat([sv, sve], dim=-1).reshape(N * W, self.out_channels * 2, H)
        output = self.bn_output(stacked_output).reshape(N, W, self.out_channels, 2, H).sum(dim=-2)

        if self.flag:
            output = output.permute(0, 2, 1, 3)
        else:
            output = output.permute(0, 2, 3, 1)

        return output

class AxialLOB(nn.Module):
    def __init__(self, W, H, c_in, c_out, c_final, n_heads, pool_kernel, pool_stride):
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.c_final = c_final

        self.CNN_in = nn.Conv2d(in_channels=1, out_channels=c_in, kernel_size=1)
        self.CNN_out = nn.Conv2d(in_channels=c_out, out_channels=c_final, kernel_size=1)
        self.CNN_res2 = nn.Conv2d(in_channels=c_out, out_channels=c_final, kernel_size=1)
        self.CNN_res1 = nn.Conv2d(in_channels=1, out_channels=c_out, kernel_size=1)

        self.norm = nn.BatchNorm2d(c_in)
        self.res_norm2 = nn.BatchNorm2d(c_final)
        self.res_norm1 = nn.BatchNorm2d(c_out)
        self.norm2 = nn.BatchNorm2d(c_final)
        self.axial_height_1 = GatedAxialAttention(c_out, c_out, n_heads, dim=W, flag=False)
        self.axial_width_1 = GatedAxialAttention(c_out, c_out, n_heads, dim=W, flag=True)
        self.axial_height_2 = GatedAxialAttention(c_out, c_out, n_heads, dim=W, flag=False)
        self.axial_width_2 = GatedAxialAttention(c_out, c_out, n_heads, dim=W, flag=True)

        self.activation = nn.ReLU()
        self.linear = nn.Linear(1600, num_classes)
        self.pooling = nn.AvgPool2d(kernel_size=pool_kernel, stride=pool_stride)

    def forward(self, x):
        y = self.activation(self.norm(self.CNN_in(x)))
        y, x_res = self.axial_width_1(y), self.activation(self.res_norm1(self.CNN_res1(x)))
        y, y_copy = y + x_res, y.detach().clone()
        y = self.axial_width_2(self.axial_height_1(y))
        y = self.activation(self.res_norm2(self.CNN_out(self.axial_height_2(y))))
        pooled = self.pooling(y + self.activation(self.norm2(self.CNN_res2(y_copy))))
        flattened = torch.flatten(pooled, 1)
        logits = self.linear(flattened)
        return logits

# Model, loss function, optimizer, and scheduler setup
batch_size = 16
epochs = 5
c_final = 4
n_heads = 2
c_in_axial = 16
c_out_axial = 16
pool_kernel = (1, 4)
pool_stride = (1, 4)
num_classes = 3

model = AxialLOB(W=40, H=40, c_in=c_in_axial, c_out=c_out_axial, c_final=c_final, n_heads=n_heads,
                 pool_kernel=pool_kernel, pool_stride=pool_stride).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.00001)

# Create Dataset instances
dataset_train = Dataset(X_train_seq, y_train_seq)
dataset_val = Dataset(X_val_seq, y_val_seq)
dataset_test = Dataset(X_test_seq, y_test_seq)

# Set up DataLoader instances
train_loader = data.DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)
val_loader = data.DataLoader(dataset=dataset_val, batch_size=batch_size, shuffle=False)
test_loader = data.DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False)

# Define batch_gd function
def batch_gd(model, criterion, optimizer, epochs):
    train_losses = np.zeros(epochs)
    val_losses = np.zeros(epochs)
    best_val_loss = np.inf
    best_epoch = 0

    for it in tqdm(range(epochs)):
        model.train()
        train_loss = []

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
        
        train_losses[it] = np.mean(train_loss)

        # Validation phase
        model.eval()
        val_loss = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss.append(loss.item())
        val_losses[it] = np.mean(val_loss)

        scheduler.step()
        
        # Save the best model based on validation loss
        if val_losses[it] < best_val_loss:
            torch.save(model.state_dict(), 'model_best.pth')
            best_val_loss = val_losses[it]
            best_epoch = it
            print('Model saved')

        print(f"Epoch {it+1}/{epochs}, Train Loss: {train_losses[it]:.4f}, "
              f"Validation Loss: {val_losses[it]:.4f}, Best Val Epoch: {best_epoch+1}")

    return train_losses, val_losses

# Execute training
train_losses, val_losses = batch_gd(model, criterion, optimizer, epochs)


Using device: cpu
Training data shape: (203767, 40, 40)
Training labels shape: (203767,)
Validation data shape: (50942, 40, 40)
Validation labels shape: (50942,)
Testing data shape: (139546, 40, 40)
Testing labels shape: (139546,)


  0%|          | 0/5 [00:00<?, ?it/s]