This file is used to train EEG embedding model, which transforms EEG signals into a low-dimensional semantic representation.

To get the ground truth for semantic representation, We first use gemini-2.5-flash to transform  video to descriptive text. Then we use "openai/clip-vit-large-patch14" to transform text into embedding vectors. 

## CONFIGURATIONS

In [50]:
CONFIG = {
    "eeg_data_path": "data\\PSD_DE\\imaging", # or "data\\PSD_DE\\watching"
    "text_embedding_path": "data\\metadata\\text_embedding.pt",
    "save_checkpoint_path": "checkpoints\\eeg_embedding",

    "seed": 42,
    "train_valid": [0.8, 0.2], 
    "batch_size": 32, 
    "num_workers": 0,

    "epochs": 200, 
    "learning_rate": 5e-4
}

## Set seeds

In [51]:
import os
import random
import numpy as np
import torch

def set_seed(seed):
    """Sets random seeds for reproducibility across all libraries.

    Args:
        seed (int): The seed to use for all random number generators.
        cudnn_deterministic (bool): If True, sets `torch.backends.cudnn.deterministic`
                                    to True. This can slow down training but ensures
                                    reproducibility for convolutional operations.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

## Define data

In [52]:
class EEGTextDataset(torch.utils.data.Dataset):
    """PyTorch Dataset for EEG and corresponding text embeddings.

    Args:
        eeg (np.ndarray): A numpy array of EEG data, with shape
                          (n_samples, n_features).
        text (torch.Tensor): A torch tensor of text embeddings, with shape
                             (n_samples, n_embedding_dim).
    """
    def __init__(self, eeg: np.ndarray, text: torch.Tensor):
        self.eeg = eeg
        self.text = text
        self.len = eeg.shape[0]

    def __len__(self) -> int:
        return self.len

    def __getitem__(self, item: int) -> tuple[np.ndarray, torch.Tensor]:
        return self.eeg[item], self.text[item]

## load data

In [53]:
from einops import rearrange, repeat
from sklearn import preprocessing
from torch.utils.data import DataLoader, random_split

def prepare_dataloader(eeg_data_path: str, text_embed_path: str,
                       batch_size: int, num_workers: int):
    """Loads, preprocesses, and prepares the EEG and Text data.

    This function handles:
    1. Loading the raw EEG data.
    2. Selecting specific trials based on `gt_label` and `chosen_labels`.
    3. Loading corresponding text embeddings.
    4. Normalizing the EEG data using StandardScaler.
    5. Creating a DataLoader for training.

    Args:
        eeg_data_path (str): Path to the EEG numpy file.
        text_embed_path_template (str): A format string for the path to text
                                        embedding files, e.g., 'path/to/block{i}.pt'.
        gt_label (np.ndarray): A ground truth label matrix to find trial indices.
        chosen_labels (list[int]): A list of class labels to include.
        batch_size (int): The batch size for the DataLoader.

    Returns:
        DataLoader: A DataLoader for the prepared training data.
    """
    print("Preparing data...")
    eeg_data = []
    for file in os.listdir(eeg_data_path):
        file_path = os.path.join(eeg_data_path, file)
        eeg_data.append(np.load(file_path))
    eeg_data = np.stack(eeg_data, axis=0)
    print(f"Loaded EEG data with shape: {eeg_data.shape}") # (60, 2, 5, 50, 62, 5)


    # Process each session's EEG data
    eeg = rearrange(eeg_data, 'a b c d e f -> (a b c d) (e f)') # shape(60*2*5*50, 62*5)
    
    # Process each session's text data
    text = torch.load(text_embed_path) # shape(250, 77, 768)
    text = repeat(text, 'a b c -> (num a) (b c)', num = eeg_data.shape[0]) # shape(60*250, 59136)

    # Normalize EEG data
    normalize = preprocessing.StandardScaler()
    eeg = normalize.fit_transform(eeg)
    print(f"Final EEG data shape: {eeg.shape}")
    print(f"Final Text embedding shape: {text.shape}")

    dataset = EEGTextDataset(eeg, text)
    trainset, validset = random_split(dataset, CONFIG["train_valid"])
    train_loader = DataLoader(
        trainset, 
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=num_workers, 
        pin_memory=True
    )
    
    valid_loader = DataLoader(
        validset, 
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=num_workers, 
        pin_memory=True
    )
    return train_loader, valid_loader

## Define the model

In [None]:
from torch import nn

class EEG_Embedding(nn.Module):
    """A simple MLP model to map EEG signals to text embedding space.

    This model takes flattened EEG data and projects it into a high-dimensional
    space matching the dimensions of CLIP's text embeddings (77 * 768).

    Attributes:
        mlp (nn.Sequential): A multi-layer perceptron consisting of linear
                             layers and ReLU activations.
    """
    def __init__(self):
        super(EEG_Embedding, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(310, 1000),
            nn.ReLU(),
            nn.Linear(1000, 10000),
            nn.ReLU(),
            nn.Linear(1000, 1000),
            nn.ReLU(),
            nn.Linear(1000, 1000),
            nn.ReLU(),
            nn.Linear(1000, 77 * 768)
        )

    def forward(self, eeg: torch.Tensor) -> torch.Tensor:
        """Forward pass of the model.

        Args:
            eeg (torch.Tensor): A batch of EEG data with shape (batch_size, 310).

        Returns:
            torch.Tensor: The predicted text embeddings with shape
                          (batch_size, 77 * 768).
        """
        eeg_embeddings = self.mlp(eeg)
        return eeg_embeddings

## Define loss function

In [55]:
def loss(input, target, criterion, device): 
    input = input.to(device)
    target = target.to(device)
    loss = criterion(input, target)
    return loss

## Valid function

In [56]:
from tqdm import tqdm

def valid(data_loader, model, criterion, device): 
    '''Validate the model on the validation set'''
    model.eval()
    running_loss = 0
    pbar = tqdm(total = len(data_loader.dataset), ncols=0, desc='Valid')

    for i, batch in enumerate(data_loader):
        with torch.no_grad():
            loss = loss(batch, model, criterion, device)
            running_loss += loss.item()
        pbar.update(data_loader.batch_size)
        pbar.set_postfix(loss=f"{running_loss / (i+1):.2f}")
    pbar.close()
    model.train()

    return running_loss / len(data_loader)

## Main function

In [57]:
import torch.nn.functional as F

def main(): 
    set_seed(CONFIG["seed"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Info]: Use {device} now!")

    train_loader, valid_loader = prepare_dataloader(CONFIG["eeg_data_path"], CONFIG["text_embedding_path"], CONFIG["batch_size"], CONFIG["num_workers"])

    model = EEG_Embedding().to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["learning_rate"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG["epochs"] * len(train_loader))
    criterion = F.mse_loss

    best_accuracy = 0
    best_state_dict = None

    pbar = tqdm(range(CONFIG['epochs']), desc='Train', unit='epoch', dynamic_ncols=True)
    for epoch in pbar:
        running_loss = 0.0
        for i, batch in enumerate(train_loader):
            loss = loss(batch, model, criterion, device)
            batch_loss = loss.item()

            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            running_loss += batch_loss

        train_loss = running_loss/len(train_loader)
        pbar.set_postfix(train_loss=train_loss, train_acc=1-train_loss)

        if (epoch + 1) % CONFIG['valid_steps'] == 0:
            valid_loss = valid(valid_loader, model, criterion, device)
            valid_acc = 1-valid_loss
            pbar.write(f"[Info]: Valid acc: {valid_acc:.4f}")
            if valid_acc > best_accuracy:
                best_accuracy = valid_acc
                best_state_dict = model.state_dict()
                pbar.write(f"[Info]: 😄 Best acc updated: {best_accuracy:.4f}")

    pbar.close()

    torch.save(best_state_dict, os.path.join(CONFIG["save_path"], "best_model.pth"))
    print("="*50, f"\n[Info]: Best model saved to {os.path.join(CONFIG['save_path'], 'best_model.pth')}")


## Start Training!

In [58]:
if __name__ == '__main__':
    main()

[Info]: Use cpu now!
Preparing data...
Loaded EEG data with shape: (6, 2, 5, 50, 62, 5)
Final EEG data shape: (3000, 310)
Final Text embedding shape: torch.Size([1500, 59136])


RuntimeError: [enforce fail at alloc_cpu.cpp:116] data. DefaultCPUAllocator: not enough memory: you tried to allocate 2365440000 bytes.