# Import dependencies and custom modules

In [1]:
import os
# Keep this assignment above `import tensorflow`: the variable is meant to be
# in the environment before TensorFlow's native runtime starts logging.
# NOTE(review): '3' is understood to be the most restrictive level — confirm
# against the installed TensorFlow version if log output still appears.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'  # Suppress TF log messages

import random
import numpy as np
import tensorflow as tf

In [2]:
from DatasetAPI.DataLoader import DatasetLoader

# Loading the dataset

In [3]:
# Experiment name, dataset location and the output directory for this model.
Model = 'Attention_based_Long_Short_Term_Memory'
DIR = 'DatasetAPI/EEG-Motor-Movement-Imagery-Dataset/'
SAVE = os.path.join('Saved_Files', Model)
os.makedirs(SAVE, exist_ok=True)  # no error if the directory already exists

# Enable on-demand GPU memory allocation (TF2 API) for every visible GPU.
# With no GPU present the list is empty and the loop body never runs.
# A RuntimeError raised while configuring is reported, not re-raised.
gpus = tf.config.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    print(err)

# Fetch the train/test splits from the project loader.
# Per the original note, the loader assumes appropriately named CSV files in DIR.
train_data, train_labels, test_data, test_labels = DatasetLoader(DIR=DIR)

# Convert class indices to one-hot vectors over 4 classes, then drop any
# size-1 dimensions left over from the label layout.
train_labels = tf.squeeze(tf.one_hot(train_labels, depth=4))
test_labels = tf.squeeze(tf.one_hot(test_labels, depth=4))

# Sanity check: show the training feature and label shapes, one per line.
print(train_data.shape, train_labels.shape, sep='\n')

(967680, 64)
(967680, 4)


In [None]:

# ---- Model and training hyper-parameters ----

# Input geometry
max_time = 64            # time steps in one sequence
n_input = 64             # input size at each time step

# Network architecture
lstm_size = 256          # LSTM units in each direction
attention_size = 8       # size of the attention layer
n_hidden = 64            # hidden units in the fully connected layer
n_class = 4              # number of output classes
keep_rate = 0.75         # dropout keep probability
# NOTE(review): this is a *keep* probability; if it is handed to an API that
# expects a *drop* rate, it must be converted — confirm at the usage site.

# Optimisation schedule
num_epoch = 300          # total number of training epochs
batch_size = 1024
initial_lr = 1e-4        # starting learning rate
lr_decay_epoch = 50      # apply the decay once every 50 epochs
lr_decay = 0.5           # multiplicative factor applied at each decay step

In [None]:
# Wrap the in-memory arrays in tf.data pipelines.
# Training: shuffle with a buffer covering every training example, then batch.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_data, train_labels))
    .shuffle(buffer_size=train_data.shape[0])
    .batch(batch_size)
)

# Evaluation: keep the original order, only batch.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_data, test_labels))
    .batch(batch_size)
)