## Configuration

In [51]:
CONFIG = {
    "data_path": "data\PSD_DE\imaging", # or "data/PSD_DE/watching"
    "label_path": "data/metadata/GT_label.npy",
    "save_path": "checkpoints/eeg2label",

    "seed": 42, 
    "train_valid": [0.8, 0.2], 
    "batch_size": 32,
    "num_workers": 0,

    "input_dim": 62*5,
    "emb_dim": 64,
    "out_dim": 50,

    "learning_rate": 0.001, 
    "valid_steps": 10,
    "epochs": 100, 
}

In [52]:
import torch
import numpy as np
import random
import os

def set_seed(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

## Load Data

In [53]:
from torch.utils.data import Dataset

from einops import rearrange, repeat

class myDataset(Dataset): 
    '''Generate dataset'''
    def __init__(self, data_path, label_path):
        self.data, self.label = self.data_label(data_path, label_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx].astype(np.float32), self.label[idx].astype(np.int64)-1

    def data_label(self, data_path, label_path):
        '''concat the watching or imaging data and label into structured data
        Args:
            data_path: the path of watching or imaging data, corresponding data's shape (2, 5, 50, 62, 5)
                2: PSD and DE, 5: 5 videos, 50: 50 clips per video, 62: 62 electrodes, 5: 5 frequency bands
            label_path: the path of label, corresponding label's shape (5, 50)
                5: 5 videos, 50: 50 clips per video
        Returns:
            structured_data: the structured data, (60*2*5*50, )
                60: 60 experiments, 2: PSD and DE, 5: 5 videos, 50: 50 clips per video
                each elements: 
                    'features':(62*5, ), 62 electrodes, 5 frequency bands
                    'label':int"
        '''
        data = []
        for file in os.listdir(data_path):
            file_path = os.path.join(data_path, file)
            data.append(np.load(file_path))
        data = np.stack(data, axis=0)
        label = np.load(label_path)
        rearrange_data = rearrange(data, 'a b c d e f -> (a b c d) (e f)') # a is number of experiments
        rearrange_label = repeat(label, 'a b -> (repeat a b)', repeat=data.shape[0]*data.shape[1])

        return rearrange_data, rearrange_label

## Load Data

In [54]:
from torch.utils.data import DataLoader, random_split

def get_dataloader(data_path, label_path, batch_size, num_workers=0):
    '''Generate dataloader'''
    dataset = myDataset(data_path, label_path)
    trainset, validset = random_split(dataset, CONFIG["train_valid"])

    train_loader = DataLoader(
        trainset, 
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=num_workers, 
        pin_memory=True
    )
    
    valid_loader = DataLoader(
        validset, 
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=num_workers, 
        pin_memory=True
    )

    return train_loader, valid_loader

## Define Model

In [55]:
from models.eeg2label import glfnet

## Loss function

In [56]:
def loss_acc(batch, model, criterion, device):
    '''Forward a batch through the model'''
    input, labels = batch
    input = input.to(device)
    labels = labels.to(device)

    output = model(input)

    loss = criterion(output, labels)

    preds = output.argmax(1)
    accuracy = torch.mean((preds == labels).float())

    return loss, accuracy

## Valid function

In [57]:
from tqdm import tqdm

def valid(data_loader, model, criterion, device): 
    '''Validate the model on the validation set'''
    model.eval()
    running_loss = 0
    running_accuracy = 0
    pbar = tqdm(total = len(data_loader.dataset), ncols=0, desc='Valid')

    for i, batch in enumerate(data_loader):
        with torch.no_grad():
            loss, accuracy = loss_acc(batch, model, criterion, device)
            running_loss += loss.item()
            running_accuracy += accuracy.item()
        pbar.update(data_loader.batch_size)
        pbar.set_postfix(
			loss=f"{running_loss / (i+1):.2f}",
			accuracy=f"{running_accuracy / (i+1):.2f}",
		)
    pbar.close()
    model.train()

    return running_accuracy / len(data_loader)

# Main Function

In [58]:
from torch import nn

def main():
    set_seed(CONFIG["seed"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Info]: Use {device} now!")

    train_loader, valid_loader = get_dataloader(CONFIG["data_path"], CONFIG["label_path"], CONFIG["batch_size"], num_workers=CONFIG["num_workers"])
    print(f"[Info]: Finish loading data!")

    model = glfnet(input_dim=CONFIG["input_dim"], emb_dim=CONFIG["emb_dim"], out_dim=CONFIG["out_dim"]).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["learning_rate"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200 * len(train_loader))

    best_accuracy = 0
    best_state_dict = None

    pbar = tqdm(range(CONFIG['epochs']), desc='Train', unit='epoch', dynamic_ncols=True)
    for epoch in pbar:
        running_loss = 0.0
        running_accuracy = 0.0
        for i, batch in enumerate(train_loader):
            loss, accuracy = loss_acc(batch, model, criterion, device)
            batch_loss = loss.item()
            batch_accuracy = accuracy.item()

            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            running_loss += batch_loss
            running_accuracy += batch_accuracy

        train_loss = running_loss/len(train_loader)
        train_acc = running_accuracy/len(train_loader)
        pbar.set_postfix(train_loss=train_loss, train_acc=train_acc)

        if (epoch + 1) % CONFIG['valid_steps'] == 0:
            valid_acc = valid(valid_loader, model, criterion, device)
            pbar.write(f"[Info]: Valid acc: {valid_acc:.4f}")
            if valid_acc > best_accuracy:
                best_accuracy = valid_acc
                best_state_dict = model.state_dict()
                pbar.write(f"[Info]: 😄 Best acc updated: {best_accuracy:.4f}")

    pbar.close()

    torch.save(best_state_dict, os.path.join(CONFIG["save_path"], "best_model.pth"))
    print("="*50, f"\n[Info]: Best model saved to {os.path.join(CONFIG['save_path'], 'best_model.pth')}")

    

## Start training!

In [59]:
if __name__ == '__main__':
    main()

[Info]: Use cpu now!
[Info]: Finish loading data!


Valid: 608it [00:00, 12276.98it/s, accuracy=0.04, loss=3.01]ain_acc=0.055, train_loss=3]    
Train:  10%|█         | 10/100 [00:03<00:28,  3.12epoch/s, train_acc=0.055, train_loss=3]

[Info]: Valid acc: 0.0428
[Info]: 😄 Best acc updated: 0.0428


Valid: 608it [00:00, 13387.88it/s, accuracy=0.04, loss=3.01]rain_acc=0.0483, train_loss=3]
Train:  20%|██        | 20/100 [00:06<00:24,  3.21epoch/s, train_acc=0.0483, train_loss=3]

[Info]: Valid acc: 0.0411


Valid: 608it [00:00, 13487.93it/s, accuracy=0.05, loss=3.00]rain_acc=0.0612, train_loss=2.99]
Train:  30%|███       | 30/100 [00:09<00:21,  3.27epoch/s, train_acc=0.0612, train_loss=2.99]

[Info]: Valid acc: 0.0461
[Info]: 😄 Best acc updated: 0.0461


Valid: 608it [00:00, 14080.36it/s, accuracy=0.05, loss=3.00]rain_acc=0.0571, train_loss=2.99]
Train:  40%|████      | 40/100 [00:12<00:18,  3.25epoch/s, train_acc=0.0571, train_loss=2.99]

[Info]: Valid acc: 0.0515
[Info]: 😄 Best acc updated: 0.0515


Valid: 608it [00:00, 14229.88it/s, accuracy=0.05, loss=3.00]rain_acc=0.0658, train_loss=2.99]
Train:  50%|█████     | 50/100 [00:15<00:15,  3.19epoch/s, train_acc=0.0658, train_loss=2.99]

[Info]: Valid acc: 0.0466


Valid: 608it [00:00, 13869.39it/s, accuracy=0.04, loss=3.01]rain_acc=0.0592, train_loss=2.99]
Train:  60%|██████    | 60/100 [00:18<00:12,  3.28epoch/s, train_acc=0.0592, train_loss=2.99]

[Info]: Valid acc: 0.0422


Valid: 608it [00:00, 13983.47it/s, accuracy=0.04, loss=3.01]rain_acc=0.0629, train_loss=2.99]
Train:  70%|███████   | 70/100 [00:21<00:09,  3.16epoch/s, train_acc=0.0629, train_loss=2.99]

[Info]: Valid acc: 0.0362


Valid: 608it [00:00, 14060.64it/s, accuracy=0.04, loss=3.01]rain_acc=0.0671, train_loss=2.99]
Train:  80%|████████  | 80/100 [00:24<00:06,  3.05epoch/s, train_acc=0.0671, train_loss=2.99]

[Info]: Valid acc: 0.0400


Valid: 608it [00:00, 13803.33it/s, accuracy=0.03, loss=3.01]rain_acc=0.0638, train_loss=2.99]
Train:  90%|█████████ | 90/100 [00:27<00:03,  3.25epoch/s, train_acc=0.0638, train_loss=2.99]

[Info]: Valid acc: 0.0345


Valid: 608it [00:00, 14056.92it/s, accuracy=0.04, loss=3.01]rain_acc=0.0617, train_loss=2.98]
Train: 100%|██████████| 100/100 [00:30<00:00,  3.27epoch/s, train_acc=0.0617, train_loss=2.98]

[Info]: Valid acc: 0.0378
[Info]: Best model saved to checkpoints/eeg2label\best_model.pth





---

---

---

---

In [60]:
# 初始化模型、损失函数、优化器等
from src.eeg_encoders.models import glfnet
import torch.nn as nn
import torch.optim as optim

# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[Info] Using device: {device}")

# 模型参数
input_dim = 62 * 5  # 62个电极 * 5个频段 = 310
emb_dim = 128       # 嵌入维度
out_dim = 50        # 50个视频类别

# use glfnet to analyze both global and local features
model = glfnet(input_dim, emb_dim, out_dim)
model = model.to(device)

# use to handle classification task
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

# TODO: find the best scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.9)

print(f"[Info] Model initialized: {model.__class__.__name__}")
print(f"[Info] Total parameters: {sum(p.numel() for p in model.parameters())}")


ModuleNotFoundError: No module named 'src.eeg_encoders'

## Model Function

In [None]:
def main():
    set_seed(config['seed'])
    
    best_valid_acc = -1
    train_loader, valid_loader = get_dataloader(config['data_path'], config['label_path'], config['batch_size'])
    
    print(f"[Info] Training dataset size: {len(train_loader.dataset)}")
    print(f"[Info] Validation dataset size: {len(valid_loader.dataset)}")
    print(f"[Info] Starting training for {config['epochs']} epochs...")

    pbar = tqdm(range(config['epochs']), desc='Train', unit='epoch', dynamic_ncols=True)
    for epoch in pbar:
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, scheduler, device)
        pbar.set_postfix(train_loss=train_loss, train_acc=train_acc)
        if (epoch + 1) % config['valid_steps'] == 0:
            valid_acc = valid_epoch(model, valid_loader, criterion, device)
            
            # 保存最佳模型
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                torch.save(model.state_dict(), config['save_path'])
                print("="*50, f"\n[Info] Best model saved! Valid accuracy: {valid_acc:.4f}\n", "="*50)
    
    pbar.close()

    print("="*50, f"\n[Info] Training completed! Best validation accuracy: {best_valid_acc:.4f}\n", "="*50)
    return best_valid_acc

## Strat training!

In [None]:
# 开始训练
if __name__ == "__main__":
    best_acc = main()
    print(f"\n[Final] Training finished with best validation accuracy: {best_acc:.4f}")


[Info] Training dataset size: 2400
[Info] Validation dataset size: 600
[Info] Starting training for 100 epochs...


Valid: 608it [00:00, 5677.92it/s, accuracy=0.06, loss=3.01]rain_acc=0.0475, train_loss=3]   
Train:  10%|█         | 10/100 [00:06<00:58,  1.55epoch/s, train_acc=0.0475, train_loss=3]

[Info] Best model saved! Valid accuracy: 0.0565


Train:  13%|█▎        | 13/100 [00:08<00:58,  1.48epoch/s, train_acc=0.0488, train_loss=3]


KeyboardInterrupt: 