# Classification Experiment on Inner Speech dataset

Run this experiment to obtain results similar to those published in the paper:

```bibtex
@article{gallo2024eeg,
  title={Thinking is Like a Sequence of Words},
  author={Gallo, Ignazio and Corchs, Silvia},
  journal={IJCNN},
  volume={??},
  pages={??--??},
  year={2024},
  publisher={IEEE}
}
```

## Import libraries

In [1]:
import os
import random
import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import yaml
import mne
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import datetime as dt
from tqdm import tqdm

In [2]:
def create_save_dir(args):
    if False: #"subject_num" in args:
        if type(args['subject_num']) is not list:
            args['subject_num'] = [args['subject_num']]

        subjs_str = ','.join(str(x) for x in args['subject_num'])
        args['save_dir'] = os.path.join(args['save_dir'], subjs_str, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
    else:
        args['save_dir'] = os.path.join(args['save_dir'], datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
    # Create the output directory (no error if it already exists)
    os.makedirs(args['save_dir'], exist_ok=True)

In [3]:
class Metrics:
    def __init__(self, column_names):
        # Build a new list so the caller's list is not mutated
        columns = ["time_stamp"] + list(column_names)
        self.df = pd.DataFrame(columns=columns)

    def add_row(self, row_list):
        # Prepend the timestamp without mutating the caller's list
        row = [str(dt.datetime.now())] + list(row_list)
        self.df.loc[len(self.df)] = row

    def save_to_csv(self, filepath):
        self.df.to_csv(filepath, index=False)

## Load raw dataset

In [4]:
def sub_name(sub_id):
    # Standardize the subject name
    return f"sub-{sub_id:02}"

def ses_name(ses_id):
    # Standardize the session name
    return f"ses-{ses_id:02}"

"""
Events data. Each event data file (in .dat format) contains a four column matrix where each row corresponds
to one trial. The first two columns were obtained from the raw events, by deleting the trigger column (second
column of the raw events) and renumbering the classes 31, 32, 33, 34 as 0, 1, 2, 3, respectively. Finally, the last
two columns correspond to condition and session number, respectively. Thus, the resulting final structure of the
events data file is as depicted in Table 5.
----------------------------------------------------------------------------------------------------------------
Sample                                     | Trial’s class         | Trials’ condition        | Trials’ session
----------------------------------------------------------------------------------------------------------------
Sample at which the event occurred         | 0 = “Arriba” (up)     | 0 = Pronounced speech    | 1 = session 1
(Numbered starting at n = 0, corresponding | 1 = “Abajo” (down)    | 1 = Inner speech         | 2 = session 2
to the beginning of the recording)         | 2 = “Derecha” (right) | 2 = Visualized condition | 3 = session 3
                                           | 3 = “Izquierda” (left)|                          |
----------------------------------------------------------------------------------------------------------------
"""
def load_events(root_dir, subject_id, N_B):
    subject_str = sub_name(subject_id)
    session_str = ses_name(N_B)
    # Create file Name
    # file_name =root_dir+"/derivatives/"+subject_str+"/ses-0"+str(N_B)+"/"+subject_str+"_ses-0"+str(N_B)+"_events.dat"
    file_name = os.path.join(root_dir, subject_str, session_str, subject_str+"_"+session_str+"_events.dat")
    # Load Events
    events = np.load(file_name,allow_pickle=True)
    
    return events    


def select_time_window_single(X, t_start=1, t_end=2.5, fs=256):
    s_max=X.shape[0]
    start = max(round(t_start * fs), 0)
    end = min(round(t_end * fs), s_max)

    #Copy interval
    X = X[start:end, :]
    return X


# Code from: https://github.com/N-Nieto/Inner_Speech_Dataset
def extract_data_from_subject(root_dir, subject_id, datatype):
    """
    Load all sessions (blocks) for one subject and stack the results in X and Y.
    Reads from the 'derivatives' directory, which contains five files obtained after the proposed processing:
    EEG data, Baseline data, External electrodes data, Events data and a Report file.
    For more details about the processing, see page 5 of https://www.nature.com/articles/s41597-022-01147-2 (PDF).
    """
    data=dict()
    y=dict()
    session_id_arr=[1,2,3]
    datatype=datatype.lower()
    
    for session_id in session_id_arr:
        # Zero-pad the subject and session IDs (e.g. 2 -> "sub-02", "ses-02")
        subject_str = sub_name(subject_id)   
        session_str = ses_name(session_id)
            
        y[session_id] = load_events(root_dir, subject_id, session_id)
        
        # three consecutive sessions for each participant: baseline, inner speech, visualized condition
        if datatype=="eeg": # 128 active EEG channels
            #  load data and events
            file_name = os.path.join(root_dir, subject_str, session_str, subject_str+'_'+session_str+'_eeg-epo.fif')
            X= mne.read_epochs(file_name,verbose='WARNING')
            data[session_id]= X._data
            
        elif datatype=="exg": # 8 External electrodes
            file_name = os.path.join(root_dir, subject_str, session_str, subject_str+'_'+session_str+'_exg-epo.fif')
            X= mne.read_epochs(file_name,verbose='WARNING')
            data[session_id]= X._data
        
        elif datatype=="baseline":
            file_name = os.path.join(root_dir, subject_str, session_str, subject_str+'_'+session_str+'_baseline-epo.fif')
            X= mne.read_epochs(file_name,verbose='WARNING')
            data[session_id]= X._data

        else:
            raise ValueError(f"Invalid datatype: {datatype!r} (expected 'eeg', 'exg' or 'baseline')")
         
    X = np.vstack((data.get(1),data.get(2),data.get(3))) 
    Y = np.vstack((y.get(1),y.get(2),y.get(3))) 

    return X, Y
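
The events layout described in the docstring above can be turned into a small decoder. The following is a minimal sketch, not part of the original notebook: `decode_event`, `select_inner_speech` and the two lookup tables are hypothetical names, and the example rows are made up for illustration rather than taken from the dataset.

```python
# Hypothetical helpers: names and sample rows below are illustrative assumptions.
CLASS_NAMES = {0: "Arriba (up)", 1: "Abajo (down)", 2: "Derecha (right)", 3: "Izquierda (left)"}
CONDITION_NAMES = {0: "Pronounced speech", 1: "Inner speech", 2: "Visualized condition"}


def decode_event(row):
    """Map one four-column events row to readable labels.

    Columns: sample index, class, condition, session.
    """
    sample, cls, condition, session = (int(v) for v in row)
    return {
        "sample": sample,
        "class": CLASS_NAMES[cls],
        "condition": CONDITION_NAMES[condition],
        "session": session,
    }


def select_inner_speech(rows):
    """Keep only the rows whose condition column (index 2) equals 1."""
    return [row for row in rows if int(row[2]) == 1]


# Made-up rows in the documented column order.
rows = [
    [1024, 0, 0, 1],
    [5632, 3, 1, 1],
    [10240, 2, 2, 2],
    [14848, 1, 1, 3],
]
inner = select_inner_speech(rows)
print([decode_event(r)["class"] for r in inner])
```

`read_innerspeech_raw` applies the same condition filter on the stacked array with `np.where(Y[:,2]==1)` and then keeps only the class column.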

In [5]:

class InnerSpeechDataset(Dataset):
    def __init__(self, X, Y, sampling_rate, task, transform=None, t_start=1.5, t_end=3.5, t_win=1.0): 
        self.sampling_rate = sampling_rate
        self.transform = transform
        # Select the useful part of each trial. Time in seconds
        self.t_start = t_start
        self.t_end = t_end  
        self.t_win = t_win
        self.task = task
        # sliding window on the source timeseries
        self.data, self.targets = X, Y # sliding_window(X, Y, 
        #    self.t_start, self.t_end, self.sampling_rate, t_stride=0.05, t_padding=0.05)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        y = self.targets[idx]
        
        # Cut the useful time, i.e. the action interval plus a random offset
        # s = random.random() - 0.5 # aug1: select a 2 sec. random win range [1.0, 4.0]
        # x = select_time_window_single(X=x, t_start=self.t_start+s, t_end=self.t_end+s, fs=self.sampling_rate)

        if self.task == 'train':
            s = random.uniform(self.t_start, self.t_end-self.t_win) # aug2: select a t_win sec. random win range [t_start, t_end-t_win]
            # print(f"selected win: [{s}s,{s+1}s] -> [{int(s*self.sampling_rate)},{int((s+1)*self.sampling_rate)}], max: [0,{x.shape[0]-1}]")
            x = select_time_window_single(X=x, t_start=s, t_end=s+self.t_win, fs=self.sampling_rate)
        else:
            x = select_time_window_single(X=x, t_start=self.t_start+self.t_win, t_end=self.t_start+2*self.t_win, fs=self.sampling_rate)

        # Convert NumPy arrays to PyTorch tensors
        x = torch.tensor(x, dtype=torch.float32)
        y = torch.tensor(y, dtype=torch.long)

        # if self.transform:
        #     x = self.transform(x)

        return x, y
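
The windowing in `__getitem__` is easier to follow as sample-index arithmetic. The sketch below mirrors `select_time_window_single` in plain Python; `window_indices` is a hypothetical helper, and the values `fs = 1024` and 4609 samples per trial are taken from the configuration and the shapes printed elsewhere in this notebook.

```python
import random


def window_indices(t_start, t_end, fs, n_samples):
    """Return the (start, end) sample indices for a time window in seconds,
    clamped to the trial length, as select_time_window_single does."""
    start = max(round(t_start * fs), 0)
    end = min(round(t_end * fs), n_samples)
    return start, end


fs, n_samples = 1024, 4609          # sampling rate (Hz) and samples per trial
t_start, t_end, t_win = 1.5, 3.5, 1.0

# Evaluation: fixed window [t_start + t_win, t_start + 2 * t_win] = [2.5 s, 3.5 s].
test_start, test_end = window_indices(t_start + t_win, t_start + 2 * t_win, fs, n_samples)
print(test_start, test_end, test_end - test_start)  # 2560 3584 1024

# Training: the window start is drawn uniformly from [t_start, t_end - t_win].
random.seed(0)
s = random.uniform(t_start, t_end - t_win)
train_start, train_end = window_indices(s, s + t_win, fs, n_samples)
print(train_start, train_end, train_end - train_start)
```

A one-second window at 1024 Hz yields 1024 samples, which matches the `Input pattern shape: torch.Size([1024, 128])` line in the run output.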

In [6]:
def read_innerspeech_raw(args):
    # Data Type
    datatype = "EEG"

    # Sampling rate
    fs = args['sampling_rate'] # 1024 # 256

    # Select the useful part of each trial. Time in seconds
    t_start = 1.5
    t_end = 3.5   

    X = np.array([])
    Y = np.array([])
    for subj in args['subject_num']:
        # Load all trials for each subject
        xs, ys = extract_data_from_subject(args['data_dir'], subj, datatype)
        X = np.concatenate([X, xs]) if X.size else xs
        Y = np.concatenate([Y, ys]) if Y.size else ys

    # Select Trials’ condition 1 = Inner speech 
    Y_inner=np.where(Y[:,2]==1)
    Y1=Y[Y_inner]
    X1=X[Y_inner]
    Y1 = Y1[:,1] # get class info only
    X1 = np.transpose(X1, (0, 2, 1)) # (trials, channels, samples) -> (trials, samples, channels), i.e. (batch, vocab_size, embed_dim)

    SCALE = 1000
    X1=X1*SCALE

    # Random split using fixed random_state
    X_train, X_test, Y_train, Y_test = train_test_split(np.float32(X1), Y1.astype(int), test_size=0.2, random_state=1)
    print("Training orig. shape:", X_train.shape)
    print("Test orig. shape:", X_test.shape)

    # Convert data to DataLoader
    train_dataset = InnerSpeechDataset(X_train, Y_train, fs, task='train', t_start=1.5, t_end=3.5, t_win=1.0)
    print("Input pattern shape:", train_dataset[0][0].shape)
    train_loader = DataLoader(train_dataset, batch_size=args['batch_size'], shuffle=True)

    test_dataset = InnerSpeechDataset(X_test, Y_test, fs, task='test', t_start=1.5, t_end=3.5, t_win=1.0) 
    test_loader = DataLoader(test_dataset, batch_size=args['batch_size'], shuffle=False)

    return train_loader, test_loader
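
The split sizes printed during a run follow from `test_size=0.2`. The sketch below assumes that `train_test_split` rounds the test count up and assigns the remainder to the training set; `split_sizes` is a hypothetical helper, and the trial counts are the sums of the training and test sizes printed in the run output.

```python
import math


def split_sizes(n_samples, test_size=0.2):
    """Return (n_train, n_test) for a fractional test_size.

    Assumption: the test count is rounded up and training gets the rest.
    """
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


# Inner-speech trial counts seen in the run output (train + test per subject).
for n_trials in (200, 216, 220, 240):
    print(n_trials, split_sizes(n_trials))
```

Under this assumption, 216 trials give 172 training and 44 test samples, which agrees with the sizes reported for subject 6.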

## The model

Deep neural network based on a basic Transformer.

In [7]:
class NetTraST(nn.Module):
    def __init__(self, args): 
        super(NetTraST, self).__init__()
        self.batch_norm1 = nn.BatchNorm1d(args['vocab_size'])
        p = args['kernel_size'] // 2
        self.conv1 = nn.Conv1d(in_channels=args['vocab_size'], out_channels=args['kernel_num'], kernel_size=args['kernel_size'], stride=1, padding=p)
        
        #self.conv2 = nn.Conv1d(in_channels=args['embed_dim'], out_channels=args['kernel_num'], kernel_size=args['kernel_size'], stride=1, padding=p)
        self.upsamp = nn.Upsample((args['embed_dim']))
        
        self.rrelu = nn.RReLU(0.1, 0.3)
        nl=3 
        self.spatial_tra = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=args['embed_dim'],
                nhead=args['nhead'],
                dim_feedforward=args['dim_feedforward'],
            ),
            num_layers=nl,
        )
        #self.temporal_tra = nn.TransformerEncoder(
        #    nn.TransformerEncoderLayer(
        #        d_model=args['vocab_size'],
        #        nhead=args['nhead'],
        #        dim_feedforward=args['dim_feedforward'],
        #    ),
        #    num_layers=nl,
        #)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=args['kernel_num'],
                nhead=args['nhead'],
                dim_feedforward=args['dim_feedforward'],
            ),
            num_layers=args['num_layers'],
        )
        self.batch_norm3 = nn.BatchNorm1d(args['kernel_num'])
        self.fl = nn.Flatten()
        self.fc1 = nn.Linear(args['kernel_num']*args['embed_dim'], args['kernel_num'])
        self.dropout = nn.Dropout(args['dropout'])
        self.fc2 = nn.Linear(args['kernel_num'], args['class_num'])

    def forward(self, x): 
        x = self.batch_norm1(x)
        
        x1 = self.conv1(x) 
        x1 = self.spatial_tra(x1)
        
        #x2 = x.permute(0, 2, 1)
        #x2 = self.conv2(x2) 
        #x2 = self.temporal_tra(x2)
        #x2 = self.upsamp(x2)

        x = x1 #x1+x2 
        #x = torch.cat((x1, x2), 1)
        
        # Reshape the input for the Transformer layer
        x = x.permute(2, 0, 1)  # Change the shape to (sequence_length, batch_size, input_size)
        x = self.transformer(x)
        # Reshape the output back to the original shape
        x = x.permute(1, 2, 0)  # Change the shape to (batch_size, input_size, sequence_length)
        x = self.batch_norm3(x)
        x = self.fl(x)
        x = self.rrelu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x
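
The tensor shapes inside `NetTraST.forward` can be checked with plain arithmetic. This sketch uses the standard output-length formula for a 1-D convolution with dilation 1; `conv1d_out_len` is a hypothetical helper, and the parameter values are the defaults from `get_default_args`.

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0):
    """Output length of a 1-D convolution with dilation 1."""
    return (length + 2 * padding - kernel_size) // stride + 1


# Defaults from get_default_args().
vocab_size, embed_dim = 1024, 128      # time samples per window, EEG channels
kernel_num, kernel_size = 128, 3

# Input batch: (batch, vocab_size, embed_dim). conv1 treats vocab_size as channels
# and convolves along the embed_dim axis with padding = kernel_size // 2.
length_after_conv = conv1d_out_len(embed_dim, kernel_size, stride=1, padding=kernel_size // 2)
conv_out_shape = ("batch", kernel_num, length_after_conv)
print(conv_out_shape)

# The Transformer encoders preserve shape, so Flatten sees kernel_num * length features.
flat_features = kernel_num * length_after_conv
print(flat_features, flat_features == kernel_num * embed_dim)
```

With an odd `kernel_size` and `padding = kernel_size // 2` the length is preserved, so `fc1` can be sized as `kernel_num * embed_dim`. An even kernel size would change the length by one (for example 128 becomes 129 with kernel size 4 and padding 2), and the input size of `fc1` would no longer match.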

## Training and Evaluation functions

In [8]:
def evaluation_raw(args, model, test_loader, criterion):
    # Evaluation
    model.eval()
    with torch.no_grad():
        tot_loss = 0
        test_corrects = torch.tensor(0, device=args['device'])
        for inputs, labels in test_loader:
            inputs = inputs.to(args['device'])
            labels = labels.to(args['device'])
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Predicted class = index of the largest logit
            _, predicted = torch.max(outputs, 1)
            corrects = (predicted == labels).sum()
            test_corrects += corrects
            tot_loss += loss

        ts_acc = 100.0 * test_corrects/len(test_loader.dataset)
        #print(f'Test Accuracy: {ts_acc:.4f}, Test loss: {tot_loss:.6f}')  
    return ts_acc.cpu().item(), tot_loss


def train_raw(args, model, train_loader, optimizer, criterion, test_loader, metrics, subj, scheduler=None): 
    #metrics = Metrics(["epoch", "lr", "train_loss", "train_acc", "test_loss", "test_acc", "best_test_acc"])
    # Training loop
    best_acc = 0.0
    patience_counter = 0
    steps = 0
    loop_obj = tqdm(range(args['epochs']))
    loop_obj.set_postfix_str(f"Best val. acc.: {best_acc:.4f}")  # Adds text after progressbar
    for epoch in loop_obj:
        loop_obj.set_description(f"Subj.: {subj}, Training epoch: {epoch+1}")  # Adds text before progress bar
        train_corrects = torch.tensor(0, device=args['device'])
        tot_loss = 0
        model.train()
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            inputs = inputs.to(args['device'])

            labels = labels.to(args['device'])
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            corrects = (torch.max(outputs, 1)[1].view(labels.size()).data == labels.data).sum()
            train_corrects += corrects
            tot_loss += loss.detach()  # detach so the autograd graph is not kept alive across the epoch
            loss.backward()
            optimizer.step()

        if scheduler: scheduler.step()

        tr_acc = 100.0 * train_corrects/len(train_loader.dataset)
        # Validation
        dev_acc, test_loss = evaluation_raw(args, model, test_loader, criterion)

        if dev_acc > best_acc:
            best_acc = dev_acc
            patience_counter = 0
            #torch.save(model, os.path.join(args['save_dir'], f"{model.__class__.__name__}_model_best.pt"))
            loop_obj.set_postfix_str(f"Best val. acc.: {best_acc:.4f}")
        else:
            patience_counter += 1

        #EP=args['epochs']
        #print(f'Epoch [{epoch+1}/{EP}] Tr. Loss: {tot_loss.item():.4f} Val. Accuracy: {dev_acc:.4f} Best Val. Accuracy: {best_acc:.4f}')
        lr=optimizer.param_groups[0]["lr"]
        metrics.add_row([epoch+1, lr, tot_loss.cpu().item(), tr_acc.cpu().item(), test_loss.cpu().item(), dev_acc, best_acc])
        metrics.save_to_csv(os.path.join(args['save_dir'], "metrics_classification.csv"))

        if patience_counter > args['early_stopping_patience']:
            print(f"Early stopping... {patience_counter} > {args['early_stopping_patience']}")
            break
    return best_acc
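
The early-stopping rule in `train_raw` can be isolated from the training loop. The sketch below is a simplified stand-in, not the notebook's code: `stopping_epoch` is a hypothetical function and the accuracy sequence is invented.

```python
def stopping_epoch(accuracies, patience):
    """Return (last_epoch_run, best_acc) under the rule used in train_raw:
    stop once the number of epochs without improvement exceeds `patience`.
    Epochs are 1-indexed."""
    best_acc = 0.0
    counter = 0
    epoch = 0
    for epoch, acc in enumerate(accuracies, start=1):
        if acc > best_acc:
            best_acc = acc
            counter = 0
        else:
            counter += 1
        if counter > patience:
            break
    return epoch, best_acc


# Invented validation accuracies: the best value occurs at epoch 3.
accs = [25.0, 30.0, 35.0] + [30.0] * 10
print(stopping_epoch(accs, patience=4))  # (8, 35.0)
```

Because the comparison is strict (`>`), training continues for `patience + 1` epochs after the last improvement, which is why the log prints `301 > 300`.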

## Run an experiment

- Change the '*subject_num*' entry in the *args* dictionary to select a subject. Note that `single_run()` overwrites this entry and loops over subjects 1 to 10. Valid values are:
  - 'subject_num': [1]
  - 'subject_num': [2]
  - ...
  - 'subject_num': [10]  

In [9]:
def get_default_args():
    args = {
        'class_num': 4,
        'dropout': 0.1,
        'nhead': 4,
        'dim_feedforward': 256,
        'num_layers': 5,
        'embed_dim': 128,
        'vocab_size': 1024,
        'kernel_num': 128,
        'kernel_size': 3,
        'batch_size': 128,
        'epochs': 1000,
        'early_stopping_patience': 300,
        'lr': 0.001,
        'log_interval': 1,
        'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
        'data_dir': "/home/jovyan/nfs/igallo/datasets/EEG/Inner_Speech_Dataset/derivatives_no_filter/", 
        'save_dir': 'experiments/transformer/inner_speech/no_filter_swin1sec/',
        'save_best': True,
        'verbose': True,
        'test_interval': 100,
        'save_interval': 500,
        'sampling_rate': 1024,
        'subject_num': [2] # [1,2,3,4,5,6,7,8,9,10], #  
    }

    return args

### Ablation 2
- NO Conv2
- NO Temporal transformer T2

In [10]:
def single_run():
    args = get_default_args()
    create_save_dir(args)
    with open(os.path.join(args['save_dir'], "config.yaml"), "w") as f:
        yaml.dump(
            args, stream=f, default_flow_style=False, sort_keys=False
        )

    # For all the results
    metrics = Metrics(["epoch", "lr", "train_loss", "train_acc", "test_loss", "test_acc", "best_test_acc"])
    results = {}
    for sub in range(1,11):
        args['subject_num'] = [sub]
        model = NetTraST(args)
        model = model.to(args['device'])

        # Define the loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=args['lr'])  # 0.001 is also Adam's default

        train_loader, test_loader = read_innerspeech_raw(args)
        print("Training size:", len(train_loader.dataset))
        print("Test size:", len(test_loader.dataset))

        best_acc = train_raw(args, model, train_loader, optimizer, criterion, test_loader, metrics, sub) 
        #torch.save(model, os.path.join(args['save_dir'], f"{model.__class__.__name__}_model_last.pt"))
        results[sub] = best_acc
        accs = np.array(list(results.values()))
        print(f'acc: mean={np.mean(accs):.4f}%, std={np.std(accs):.4f}%')

    # Print subject results (avoid shadowing the built-in `str`)
    summary = "RESULTS FOR 10 SUBJECTS\n"
    summary += '--------------------------------\n'
    for key, value in results.items():
        summary += f'Subject {key}: {value:.4f} %\n'
    accs = np.array(list(results.values()))
    summary += f'mean: {np.mean(accs):.4f}%, std: {np.std(accs):.4f}%\n'
    print(summary)
    with open(os.path.join(args['save_dir'], "mean_results.txt"), "w") as f:
        f.write(summary)

In [11]:
single_run()

Training orig. shape: (160, 4609, 128)
Test orig. shape: (40, 4609, 128)
Input pattern shape: torch.Size([1024, 128])
Training size: 160
Test size: 40


Subj.: 1, Training epoch: 308:  31%|███       | 307/1000 [01:46<04:00,  2.88it/s, Best val. acc.: 35.0000]

Early stopping... 301 > 300
acc: mean=35.0000%, std=0.0000%





Training orig. shape: (192, 4609, 128)
Test orig. shape: (48, 4609, 128)
Input pattern shape: torch.Size([1024, 128])
Training size: 192
Test size: 48


Subj.: 2, Training epoch: 307:  31%|███       | 306/1000 [01:56<04:23,  2.63it/s, Best val. acc.: 47.9167]

Early stopping... 301 > 300
acc: mean=41.4583%, std=6.4583%





Training orig. shape: (176, 4609, 128)
Test orig. shape: (44, 4609, 128)
Input pattern shape: torch.Size([1024, 128])
Training size: 176
Test size: 44


Subj.: 3, Training epoch: 606:  60%|██████    | 605/1000 [03:33<02:19,  2.84it/s, Best val. acc.: 40.9091]

Early stopping... 301 > 300
acc: mean=41.2753%, std=5.2796%





Training orig. shape: (192, 4609, 128)
Test orig. shape: (48, 4609, 128)
Input pattern shape: torch.Size([1024, 128])
Training size: 192
Test size: 48


Subj.: 4, Training epoch: 303:  30%|███       | 302/1000 [01:58<04:33,  2.56it/s, Best val. acc.: 37.5000]

Early stopping... 301 > 300
acc: mean=40.3314%, std=4.8557%





Training orig. shape: (192, 4609, 128)
Test orig. shape: (48, 4609, 128)
Input pattern shape: torch.Size([1024, 128])
Training size: 192
Test size: 48


Subj.: 5, Training epoch: 309:  31%|███       | 308/1000 [02:02<04:36,  2.51it/s, Best val. acc.: 41.6667]

Early stopping... 301 > 300
acc: mean=40.5985%, std=4.3758%





Training orig. shape: (172, 4609, 128)
Test orig. shape: (44, 4609, 128)
Input pattern shape: torch.Size([1024, 128])
Training size: 172
Test size: 44


Subj.: 6, Training epoch: 514:  51%|█████▏    | 513/1000 [03:20<03:10,  2.56it/s, Best val. acc.: 43.1818]

Early stopping... 301 > 300
acc: mean=41.0290%, std=4.1089%





Training orig. shape: (192, 4609, 128)
Test orig. shape: (48, 4609, 128)
Input pattern shape: torch.Size([1024, 128])
Training size: 192
Test size: 48


Subj.: 7, Training epoch: 453:  45%|████▌     | 452/1000 [05:16<06:23,  1.43it/s, Best val. acc.: 33.3333]

Early stopping... 301 > 300
acc: mean=39.9297%, std=4.6608%





Training orig. shape: (160, 4609, 128)
Test orig. shape: (40, 4609, 128)
Input pattern shape: torch.Size([1024, 128])
Training size: 160
Test size: 40


Subj.: 8, Training epoch: 350:  35%|███▍      | 349/1000 [04:28<08:20,  1.30it/s, Best val. acc.: 37.5000]

Early stopping... 301 > 300
acc: mean=39.6259%, std=4.4332%





Training orig. shape: (192, 4609, 128)
Test orig. shape: (48, 4609, 128)
Input pattern shape: torch.Size([1024, 128])
Training size: 192
Test size: 48


Subj.: 9, Training epoch: 496:  50%|████▉     | 495/1000 [08:01<08:10,  1.03it/s, Best val. acc.: 47.9167]

Early stopping... 301 > 300
acc: mean=40.5471%, std=4.9253%





Training orig. shape: (192, 4609, 128)
Test orig. shape: (48, 4609, 128)
Input pattern shape: torch.Size([1024, 128])
Training size: 192
Test size: 48


Subj.: 10, Training epoch: 323:  32%|███▏      | 322/1000 [05:16<11:06,  1.02it/s, Best val. acc.: 33.3333]

Early stopping... 301 > 300
acc: mean=39.8258%, std=5.1494%
RESULTS FOR 10 SUBJECTS
--------------------------------
Subject 1: 35.0000 %
Subject 2: 47.9167 %
Subject 3: 40.9091 %
Subject 4: 37.5000 %
Subject 5: 41.6667 %
Subject 6: 43.1818 %
Subject 7: 33.3333 %
Subject 8: 37.5000 %
Subject 9: 47.9167 %
Subject 10: 33.3333 %
mean: 39.8258%, std: 5.1494%
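
The summary line can be reproduced from the ten per-subject values. `np.std` uses the population formula (`ddof=0`) by default, so the sketch below uses `statistics.pstdev`; the sample standard deviation is shown only for comparison. The inputs are the rounded accuracies printed above, so the last digit may differ slightly from the notebook's output.

```python
import statistics

# Per-subject best accuracies (%) as printed in the results table.
accs = [35.0000, 47.9167, 40.9091, 37.5000, 41.6667,
        43.1818, 33.3333, 37.5000, 47.9167, 33.3333]

mean = statistics.fmean(accs)
pop_std = statistics.pstdev(accs)      # population std, same convention as np.std (ddof=0)
sample_std = statistics.stdev(accs)    # sample std (ddof=1), larger for small n

print(f"mean: {mean:.4f}%, population std: {pop_std:.4f}%, sample std: {sample_std:.4f}%")
```

With only ten subjects the two conventions differ noticeably, so it helps to state which one is used when comparing these numbers with other reports.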




