In [None]:
from dotenv import load_dotenv
import os
from glob import glob
import mne
import numpy as np
import torch
import torch.nn as nn
import gc 
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import pandas as pd

load_dotenv()
root_dir = os.getenv("ROOT_DIR")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [15]:
# Settings

# Set seed for reproducibility (optional)
torch.manual_seed(42)

DECIMATED_SAMPLE_RATE_HZ = 256
# Seconds
WORD_SIGNAL_BEGIN = int(DECIMATED_SAMPLE_RATE_HZ * 1)
WORD_SIGNAL_COMPLETE = int(DECIMATED_SAMPLE_RATE_HZ * 3.5)


In [16]:
# Model definitions
class InnerSpeechDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [17]:
# CNN/LSTM hybrid
class InnerSpeechModel(nn.Module):
    def __init__(self):
        super().__init__()

        # CNN component: outputs 256 channels
        self.convolv = nn.Sequential(
            nn.Conv1d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=128, out_channels=256, kernel_size=3, padding=1),  # Fixed to 256 channels
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3)
        )

        # Bi-LSTM component (2 Layers)
        self.lstm = nn.LSTM(input_size=256, hidden_size=128, num_layers=2, batch_first=True, bidirectional=True)

        self.attn_weight = nn.Linear(2 * 128, 1, bias=False)

        # Fully connected layer
        self.fc = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(2*128, 4)  # Matches hidden_size=128
        )

    def forward(self, x):
        # Input shape: (batch, 20, 64)
        x = x.permute(0, 2, 1)  # Shape: (batch, 64, 20)
        x = self.convolv(x)      # Shape: (batch, 256, 20)
        x = x.permute(0, 2, 1)   # Shape: (batch, 20, 256)

        lstm_out, (h_n, c_n) = self.lstm(x)  # lstm_out shape: (batch, 20, 128)

        # Compute attention scores
        # Flatten across features: attn_score[i, t] = wT * h_{i, t}
        # Then softmax over t to get α_{i, t}
        attn_scores = self.attn_weight(lstm_out).squeeze(-1)
        attn_weights = torch.softmax(attn_scores, dim=1)
        # Weighted sum of LSTM outputs:
        attn_applied = torch.bmm(attn_weights.unsqueeze(1), lstm_out).squeeze(1)

        # Regression to 3D motion
        output = self.fc(attn_applied)
        return output

### Initial Analysis

In [18]:
sorted(glob(os.path.join(root_dir, "**", "sub-01*.fif"), recursive=True))

['/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-01/ses-01/sub-01_ses-01_baseline-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-01/ses-01/sub-01_ses-01_eeg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-01/ses-01/sub-01_ses-01_exg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-01/ses-02/sub-01_ses-02_baseline-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-01/ses-02/sub-01_ses-02_eeg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-01/ses-02/sub-01_ses-02_exg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-01/ses-03/sub-01_ses-03_baseline-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-01/ses-03/sub-01_ses-03_eeg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-01/ses-03/sub-0

In [19]:
data_l_epochs = sorted(glob(os.path.join(root_dir, "**", 'sub-*_ses-*_eeg-epo.fif'), recursive=True))

In [20]:
data_l_epochs = [data for data in data_l_epochs if "ses-03" not in data]

In [21]:
data_l_epochs

['/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-01/ses-01/sub-01_ses-01_eeg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-01/ses-02/sub-01_ses-02_eeg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-02/ses-01/sub-02_ses-01_eeg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-02/ses-02/sub-02_ses-02_eeg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-03/ses-01/sub-03_ses-01_eeg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-03/ses-02/sub-03_ses-02_eeg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-04/ses-01/sub-04_ses-01_eeg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-04/ses-02/sub-04_ses-02_eeg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-05/ses-01/sub-05_ses-01_eeg-ep

In [22]:
data_l_epochs.__len__()

20

In [23]:
data_l_epochs[:6]

['/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-01/ses-01/sub-01_ses-01_eeg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-01/ses-02/sub-01_ses-02_eeg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-02/ses-01/sub-02_ses-01_eeg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-02/ses-02/sub-02_ses-02_eeg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-03/ses-01/sub-03_ses-01_eeg-epo.fif',
 '/media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-03/ses-02/sub-03_ses-02_eeg-epo.fif']

In [None]:
epochs = mne.read_epochs(data_l_epochs[0])

In [None]:
data_l_epochs[0]

In [None]:
epoch = mne.read_epochs(data_l_epochs[1], preload=True)

In [None]:
X = epoch.get_data()

In [None]:
X.shape


In [None]:
event_map = {v: k for k, v in epochs.event_id.items()}

In [None]:
event_map

In [None]:
y = [event_map[event[-1]] for event in epochs.events]

In [None]:
y[0]

In [None]:
256 * 3.5

In [None]:
X.shape

In [None]:
# This is the signal
X[0][:, WORD_SIGNAL_BEGIN:WORD_SIGNAL_COMPLETE].shape

In [None]:
device

In [None]:
asdf

In [24]:
X_all = []
y_all = []
for data in data_l_epochs:
    epochs = mne.read_epochs(data, preload=True)
    epochs.apply_baseline(baseline=(-0.5, 0))
    X = epochs.get_data()
    event_map = {v: k for k, v in epochs.event_id.items()}
    y = [event_map[event[-1]] for event in epochs.events]
    X = torch.tensor(X[:, :, WORD_SIGNAL_BEGIN:WORD_SIGNAL_COMPLETE], device=device)

    X_all.append(X)
    y_all.append(y)

Reading /media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-01/ses-01/sub-01_ses-01_eeg-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =    -500.00 ...    4000.00 ms
        0 CTF compensation matrices available
Not setting metadata
200 matching events found
No baseline correction applied
0 projection items activated
Applying baseline correction (mode: mean)
Reading /media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-01/ses-02/sub-01_ses-02_eeg-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =    -500.00 ...    4000.00 ms
        0 CTF compensation matrices available
Not setting metadata
200 matching events found
No baseline correction applied
0 projection items activated
Applying baseline correction (mode: mean)
Reading /media/linux-pc/Stargate/inner-speech/ds003626-download/derivatives/sub-02/ses-01/sub-02_ses-01_eeg-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =    -50

#### Exploring the Channel Placement:

In [None]:

# Load standard BioSemi 128 montage
montage = mne.channels.make_standard_montage('biosemi128')

# Get the dictionary: channel name → (x, y, z) in meters
ch_pos = montage.get_positions()['ch_pos']

my_channel_names = epochs.ch_names

# Extract 3D coordinates in meters
channel_xyz = {ch: ch_pos[ch] for ch in my_channel_names if ch in ch_pos}

# Convert to mm or cm if desired
channel_xyz_mm = {ch: tuple(1000 * x for x in pos) for ch, pos in channel_xyz.items()}

# Optional: Convert to a DataFrame
df_channels = pd.DataFrame.from_dict(channel_xyz_mm, orient='index', columns=['x_mm', 'y_mm', 'z_mm'])

print(df_channels.head())




            x_mm       y_mm       z_mm
A1  0.000000e+00   0.000000  95.000000
A2  1.109950e-15 -18.126855  93.254582
A3  2.272911e-15 -37.119457  87.447961
A4  3.252866e-15 -53.123326  78.758569
A5 -2.670150e+01 -62.904799  65.992545


In [None]:
# df_channels.to_csv("./data/eeg_channel_physical_location.csv")

In [None]:
X_all[0][0].shape

In [None]:
y_all[0][0]

In [None]:
# Saving Labels

# y_all = np.array(y_all)
# unique_labels = sorted(set(y_all.flatten()))
# label_mapping = {label: index for index, label in enumerate(unique_labels)}
# int_labels = np.vectorize(label_mapping.get)(y_all)
# tensor_labels = torch.tensor(int_labels, dtype=torch.long, device=device)
# torch.save(tensor_labels,"data/labels.pth")

In [None]:
y_all.__len__()

In [None]:
X_all.__len__()

In [None]:
y_all[0].__len__()

In [None]:
# Channels x Samples @ 256 Hz (from seconds 1 to 3.5)
X_all[0][0].shape

In [None]:
y_all[0][0]

In [None]:
X_all[0].shape

In [None]:
# Concatenating the data into a single object:
X_all = torch.cat(X_all, dim=0)  # shape: [200 * 20 = 4000, 128, 640]

In [None]:
X_all.shape

In [None]:
y_flat = [label for sublist in y_all for label in sublist]  # Length: 4000

In [None]:
y_flat

In [None]:
torch.save(y_flat, "data/y_labels.pth")
torch.save(X_all, "data/X_all.pth")

In [None]:
del y_flat, X_all 


In [None]:
del y_all

In [None]:
# gc.collect()

#### Loading Data

In [None]:
# Loading the Labels
tensor_labels = torch.load("data/y_labels.pth")
X_all = torch.load("data/X_all.pth")

In [None]:
# Split into Train and Test data
# Total samples
num_samples = X_all.shape[0]
num_train = int(0.8 * num_samples)

# Generate a random permutation of indices
indices = torch.randperm(num_samples)

# Split indices
train_indices = indices[:num_train]
test_indices = indices[num_train:]

# Create train and test splits
train_data = X_all[train_indices]
test_data = X_all[test_indices]

# Print shapes to verify
print("Train shape:", train_data.shape)  # Should be [3200, 128, 640]
print("Test shape:", test_data.shape)    # Should be [800, 128, 640]

# Train and Test labels
train_indices_list = train_indices.tolist()
train_indices = [tensor_labels[index] for index in train_indices_list]
test_indices = [tensor_labels[index] for index in test_indices.tolist()]

In [None]:
torch.save(test_indices, "data/y_test.pth")
torch.save(test_data, "data/X_test.pth")

In [None]:
torch.save(train_indices, "data/y_train.pth")
torch.save(train_data, "data/X_train.pth")

In [None]:
# Future Progess:
# Load the data into a Neural Network
# Train the Neural Network on the data
# Output predictions from the neural network

# Generate synthetic neural signal based on this data
# stream the synthetic neural signal to chat-studio using websockets
# use an api with the trained neural network to decode the neural signal 
# display the neural signal as a suggested selection for text input

## Training a Neural Network

In [5]:
X_train = torch.load("data/X_train.pth")
y_train = torch.load("data/y_train.pth")
X_test = torch.load("data/X_test.pth")
y_test = torch.load("data/y_test.pth")

In [6]:
y_train = y_train.cpu().numpy()

AttributeError: 'list' object has no attribute 'cpu'

In [None]:
y_train = y_train.cpu().numpy()
X_train = X_train.cpu().numpy()

### Model Definition

In [None]:
X_train.device

In [None]:
X_train_split, X_val_split, y_train_split, y_validation_split = train_test_split(
    X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
)

In [None]:
train_loader = InnerSpeechDataset(X_train_split, y_train_split)
val_loader = InnerSpeechDataset(X_val_split, y_validation_split)

In [None]:
train_loader = InnerSpeechDataset()