In [2]:
from torcheeg.datasets import SEEDDataset
from torcheeg import transforms
from torcheeg.datasets.constants import SEED_CHANNEL_LOCATION_DICT

dataset = SEEDDataset(root_path='./SEED/SEED_EEG/Preprocessed_EEG',
                      offline_transform=transforms.Compose([
                          transforms.BandDifferentialEntropy(),
                          transforms.ToGrid(SEED_CHANNEL_LOCATION_DICT)
                      ]),
                      online_transform=transforms.ToTensor(),
                      label_transform=transforms.Compose([
                          transforms.Select('emotion'),
                          transforms.Lambda(lambda x: x + 1)
                      ]),
                      num_worker=4)
print(dataset[0])

[2024-11-26 22:37:45] INFO (torcheeg/MainThread) 🔍 | Processing EEG data. Processed EEG data has been cached to [92m.torcheeg\datasets_1732642665062_seDY1[0m.
[2024-11-26 22:37:45] INFO (torcheeg/MainThread) ⏳ | Monitoring the detailed processing of a record for debugging. The processing of other records will only be reported in percentage to keep it clean.
[PROCESS]: 100%|██████████| 45/45 [3:19:21<00:00, 265.80s/it]  
[2024-11-27 02:29:34] INFO (torcheeg/MainThread) ✅ | All processed EEG data has been cached to .torcheeg\datasets_1732642665062_seDY1.
[2024-11-27 02:29:34] INFO (torcheeg/MainThread) 😊 | Please set [92mio_path[0m to [92m.torcheeg\datasets_1732642665062_seDY1[0m for the next run, to directly read from the cache if you wish to skip the data processing step.


(tensor([[[ 0.0000,  0.0000,  0.0000,  5.1832,  5.0887,  4.9807,  0.0000,
           0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  4.5553,  0.0000,  4.2643,  0.0000,
           0.0000,  0.0000],
         [ 4.2014,  4.0099,  3.8065,  3.8590,  3.7828,  3.7195,  3.3290,
           3.8660,  3.5138],
         [ 3.5574,  3.8370,  3.1473,  3.1314,  3.4455,  2.8615,  2.8605,
           3.1997,  4.4979],
         [ 3.8301,  3.7472,  3.4873,  2.9300, -0.5460,  2.4503,  3.2057,
           3.7392,  4.5785],
         [ 3.8620,  3.5352,  3.4424,  3.1226, -0.6533,  2.3607,  3.8760,
           4.0473,  4.4404],
         [ 3.9283,  3.7993,  3.6490,  4.5985,  3.1423,  3.5070,  3.8388,
           4.3632,  4.5748],
         [ 0.0000,  4.0628,  4.0962,  3.8834,  3.8629,  4.1004,  4.5013,
           4.5144,  0.0000],
         [ 0.0000,  0.0000,  4.2134,  4.0241,  4.0828,  4.3369,  4.3554,
           0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  4.1442,  3.9311,  3.9702,  0.0000,
      

In [7]:
# Directory to save all samples
import os
save_dir = "./processed_eeg_data"
os.makedirs(save_dir, exist_ok=True)

# Iterate through the dataset (assuming `dataset` contains all preprocessed samples)
for idx, (eeg, label) in enumerate(dataset):
    sample = {
        'eeg': eeg,      # Preprocessed EEG Tensor
        'label': label   # Label
    }
    
    # Save each sample
    file_path = os.path.join(save_dir, f"sample_{idx}.pt")
    torch.save(sample, file_path)
    if idx % 100 == 0:
        print(f"Saved {idx + 1}/{len(dataset)} samples")

print(f"All samples saved in {save_dir}")


Saved 1/152730 samples
Saved 101/152730 samples
Saved 201/152730 samples
Saved 301/152730 samples
Saved 401/152730 samples
Saved 501/152730 samples
Saved 601/152730 samples
Saved 701/152730 samples
Saved 801/152730 samples
Saved 901/152730 samples
Saved 1001/152730 samples
Saved 1101/152730 samples
Saved 1201/152730 samples
Saved 1301/152730 samples
Saved 1401/152730 samples
Saved 1501/152730 samples
Saved 1601/152730 samples
Saved 1701/152730 samples
Saved 1801/152730 samples
Saved 1901/152730 samples
Saved 2001/152730 samples
Saved 2101/152730 samples
Saved 2201/152730 samples
Saved 2301/152730 samples
Saved 2401/152730 samples
Saved 2501/152730 samples
Saved 2601/152730 samples
Saved 2701/152730 samples
Saved 2801/152730 samples
Saved 2901/152730 samples
Saved 3001/152730 samples
Saved 3101/152730 samples
Saved 3201/152730 samples
Saved 3301/152730 samples
Saved 3401/152730 samples
Saved 3501/152730 samples
Saved 3601/152730 samples
Saved 3701/152730 samples
Saved 3801/152730 sample

In [46]:

from torch.utils.data import Dataset, DataLoader, random_split
import torch
import os

class EEGDataset(Dataset):
    def __init__(self, data_dir):
        self.file_list = sorted(os.listdir(data_dir))
        self.data_dir = data_dir

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = os.path.join(self.data_dir, self.file_list[idx])
        sample = torch.load(file_path)
        return sample['eeg'], sample['label']



In [47]:
data_dir = "./processed_eeg_data"
dataset = EEGDataset(data_dir)

# split dataset
train_size = int(0.7 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


In [48]:
## debugging 
for i in range(5):
    eeg, label = dataset[i]
    print(f"samples {i}: {eeg.shape}, {label}")

samples 0: torch.Size([4, 9, 9]), 2
samples 1: torch.Size([4, 9, 9]), 2
samples 2: torch.Size([4, 9, 9]), 2
samples 3: torch.Size([4, 9, 9]), 2
samples 4: torch.Size([4, 9, 9]), 1


  sample = torch.load(file_path)


In [49]:
## debugging
print(f"Training size: {len(train_dataset)}, Validation size: {len(val_dataset)}, Test size: {len(test_dataset)}")


Training size: 106911, Validation size: 30546, Test size: 15273


In [50]:
for eeg_batch, label_batch in train_loader:
    print(f"Batch shape: {eeg_batch.shape}, Labels: {label_batch}")
    break


  sample = torch.load(file_path)


Batch shape: torch.Size([64, 4, 9, 9]), Labels: tensor([2, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 2, 2, 1, 1, 2,
        1, 0, 1, 1, 0, 2, 0, 1, 0, 2, 1, 2, 1, 2, 0, 2, 2, 1, 0, 0, 2, 2, 0, 1,
        0, 1, 2, 2, 1, 0, 1, 0, 2, 2, 0, 2, 2, 2, 2, 0])


In [51]:
import torch.nn as nn

class DEBaseEncoder(nn.Module):
    def __init__(self):
        super(DEBaseEncoder, self).__init__()
        self.conv1 = nn.Conv2d(4, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(32*4,128)
        
    def forward(self, x):
         print(f"Input to encoder: {x.shape}")
         x = self.pool(nn.ReLU()(self.conv1(x)))
         print(f"After conv1: {x.shape}")
         x = self.pool(nn.ReLU()(self.conv2(x)))
         print(f"After conv2: {x.shape}")
         x = x.view(x.size(0), -1)  # Flatten
         print(f"Flattened Shape: {x.shape}")
         x = nn.ReLU()(self.fc(x))
         return x

In [52]:
## Debugging base encoder
encoder = DEBaseEncoder().cuda()
eeg_batch, _ = next(iter(train_loader))
eeg_batch = eeg_batch.cuda()

print(f"Input shape to encoder: {eeg_batch.shape}")
encoder_output = encoder(eeg_batch)
print(f"Encoder output shape: {encoder_output.shape}")

  sample = torch.load(file_path)


Input shape to encoder: torch.Size([64, 4, 9, 9])
Input to encoder: torch.Size([64, 4, 9, 9])
After conv1: torch.Size([64, 16, 4, 4])
After conv2: torch.Size([64, 32, 2, 2])
Flattened Shape: torch.Size([64, 128])
Encoder output shape: torch.Size([64, 128])


In [60]:
class projector (nn.Module):
    def __init__(self, input_dim=128, output_dim=64):
        super(projector, self).__init__()
        self.project = nn.Sequential(
            nn.Linear(input_dim,output_dim),
            nn.ReLU(),
            nn.Linear(output_dim,output_dim)
        )
    def forward(self, x):
        print(f"Input to projector: {x.shape}")
        x = self.project(x)
        print(f"Output of projector: {x.shape}")
        return x

In [54]:
## Debugging Projector 
projector = projector().cuda()
encoder_output = encoder(eeg_batch)  # Output from debugged encoder
print(f"Encoder output shape: {encoder_output.shape}")

projector_output = projector(encoder_output)
print(f"Projector output shape: {projector_output.shape}")  # Expected: [Batch, 64]


Input to encoder: torch.Size([64, 4, 9, 9])
After conv1: torch.Size([64, 16, 4, 4])
After conv2: torch.Size([64, 32, 2, 2])
Flattened Shape: torch.Size([64, 128])
Encoder output shape: torch.Size([64, 128])
Input to projector: torch.Size([64, 128])
Output of projector: torch.Size([64, 64])
Projector output shape: torch.Size([64, 64])


In [55]:
class EmotionClassifier(nn.Module):
    def __init__(self, input_dim=128, num_classes=3):
        super(EmotionClassifier, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        return self.classifier(x)


In [56]:
class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.1):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        z_i = nn.functional.normalize(z_i, dim=1)
        z_j = nn.functional.normalize(z_j, dim=1)
        similarity_matrix = torch.matmul(z_i, z_j.T) / self.temperature
        print(f"Similarity Matrix: {similarity_matrix.shape}")
        labels = torch.arange(z_i.size(0)).to(z_i.device)
        print(f"Labels: {labels.shape}")
        loss = nn.CrossEntropyLoss()(similarity_matrix, labels)
        print(f"Contrastive Loss: {loss}")
        return loss


In [57]:
## debugging contrastiveLoss
contrastive_loss = ContrastiveLoss().cuda()
z_i = projector(encoder(eeg_batch))  # Representation 1
z_j = projector(encoder(eeg_batch))  # Representation 2 (positive pair in this case)

# Print shapes
print(f"z_i shape: {z_i.shape}, z_j shape: {z_j.shape}")  # Both should be [Batch, 64]

# Compute similarity and loss
loss = contrastive_loss(z_i, z_j)
print(f"Contrastive loss: {loss.item()}")


Input to encoder: torch.Size([64, 4, 9, 9])
After conv1: torch.Size([64, 16, 4, 4])
After conv2: torch.Size([64, 32, 2, 2])
Flattened Shape: torch.Size([64, 128])
Input to projector: torch.Size([64, 128])
Output of projector: torch.Size([64, 64])
Input to encoder: torch.Size([64, 4, 9, 9])
After conv1: torch.Size([64, 16, 4, 4])
After conv2: torch.Size([64, 32, 2, 2])
Flattened Shape: torch.Size([64, 128])
Input to projector: torch.Size([64, 128])
Output of projector: torch.Size([64, 64])
z_i shape: torch.Size([64, 64]), z_j shape: torch.Size([64, 64])
Similarity Matrix: torch.Size([64, 64])
Labels: torch.Size([64])
Contrastive Loss: 3.9343695640563965
Contrastive loss: 3.9343695640563965


In [61]:
from torch.utils.tensorboard import SummaryWriter

encoder = DEBaseEncoder().cuda()
projector = projector().cuda()
contrastive_loss = ContrastiveLoss().cuda()
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(projector.parameters()), lr=0.001)

writer = SummaryWriter(log_dir="./runs/contrastive_pretraining")

for epoch in range(5):
    encoder.train()
    projector.train()
    total_loss = 0

    for batch_idx, (eeg, _) in enumerate(train_loader):
        eeg = eeg.cuda()

        # Positive pair (augmentations or duplicates in preprocessed dataset)
        z_i = projector(encoder(eeg))
        z_j = projector(encoder(eeg))

        # Compute contrastive loss
        loss = contrastive_loss(z_i, z_j)
        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        
        for name, param in encoder.named_parameters():
            if param.grad is not None:
                writer.add_scalar(f"Gradients/{name}", param.grad.norm(), epoch * len(train_loader) + batch_idx)
       
        optimizer.step()

        # Log loss to TensorBoard
        writer.add_scalar("Pretraining Loss", loss.item(), epoch * len(train_loader) + batch_idx)

    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}")
    
writer.close()


  sample = torch.load(file_path)


Input to encoder: torch.Size([64, 4, 9, 9])
After conv1: torch.Size([64, 16, 4, 4])
After conv2: torch.Size([64, 32, 2, 2])
Flattened Shape: torch.Size([64, 128])
Input to projector: torch.Size([64, 128])
Output of projector: torch.Size([64, 64])
Input to encoder: torch.Size([64, 4, 9, 9])
After conv1: torch.Size([64, 16, 4, 4])
After conv2: torch.Size([64, 32, 2, 2])
Flattened Shape: torch.Size([64, 128])
Input to projector: torch.Size([64, 128])
Output of projector: torch.Size([64, 64])
Similarity Matrix: torch.Size([64, 64])
Labels: torch.Size([64])
Contrastive Loss: 4.027973651885986
Input to encoder: torch.Size([64, 4, 9, 9])
After conv1: torch.Size([64, 16, 4, 4])
After conv2: torch.Size([64, 32, 2, 2])
Flattened Shape: torch.Size([64, 128])
Input to projector: torch.Size([64, 128])
Output of projector: torch.Size([64, 64])
Input to encoder: torch.Size([64, 4, 9, 9])
After conv1: torch.Size([64, 16, 4, 4])
After conv2: torch.Size([64, 32, 2, 2])
Flattened Shape: torch.Size([64, 1

In [62]:
# Save the pretrained encoder and projector
torch.save(encoder.state_dict(), "pretrained_encoder.pth")
torch.save(projector.state_dict(), "pretrained_projector.pth")

print("Pretraining completed. Models saved as 'pretrained_encoder.pth' and 'pretrained_projector.pth'.")


Pretraining completed. Models saved as 'pretrained_encoder.pth' and 'pretrained_projector.pth'.


In [63]:
torch.save({
    "epoch": epoch + 1,
    "encoder_state_dict": encoder.state_dict(),
    "projector_state_dict": projector.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": total_loss / len(train_loader),
}, "pretraining_checkpoint.pth")


In [64]:
# Load the checkpoint
checkpoint = torch.load("pretraining_checkpoint.pth")

# Restore states
encoder.load_state_dict(checkpoint["encoder_state_dict"])
projector.load_state_dict(checkpoint["projector_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"]
loss = checkpoint["loss"]

print(f"Checkpoint loaded. Resuming from epoch {start_epoch} with loss {loss:.4f}.")


  checkpoint = torch.load("pretraining_checkpoint.pth")


Checkpoint loaded. Resuming from epoch 5 with loss 0.0129.


In [None]:
classifier = EmotionClassifier().cuda()
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=0.001)
criterion = nn.CrossEntropyLoss()

writer = SummaryWriter(log_dir="./runs/fine_tuning")

for epoch in range(10):
    encoder.train()
    classifier.train()
    total_loss = 0
    correct = 0
    total = 0

    for batch_idx, (eeg, label) in enumerate(train_loader):
        eeg, label = eeg.cuda(), label.cuda()

        # Forward pass
        features = encoder(eeg)
        outputs = classifier(features)
        loss = criterion(outputs, label)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += label.size(0)
        correct += (predicted == label).sum().item()

        # Log loss and accuracy to TensorBoard
        writer.add_scalar("Training Loss", loss.item(), epoch * len(train_loader) + batch_idx)
        writer.add_scalar("Training Accuracy", 100 * correct / total, epoch * len(train_loader) + batch_idx)

    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}, Accuracy: {correct / total * 100:.2f}%")


In [None]:
encoder.eval()
classifier.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for eeg, label in test_loader:
        eeg, label = eeg.cuda(), label.cuda()
        features = encoder(eeg)
        outputs = classifier(features)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(label.cpu().numpy())

from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Generate metrics
print(classification_report(all_labels, all_preds))

# Confusion matrix
conf_matrix = confusion_matrix(all_labels, all_preds)
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()
