In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import resnet50
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
import random
from tqdm import tqdm
import pickle

In [None]:
# Compute device: the default CUDA device when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
class SpectralNorm(nn.Module):
    """Wrap a module so one of its weights is divided by its estimated spectral norm.

    On construction the parameter ``name`` (default ``'weight'``) is removed
    from ``module`` and three parameters are registered on ``module`` instead:

    * ``<name>_bar`` -- the trainable, un-normalised weight;
    * ``<name>_u`` / ``<name>_v`` -- unit vectors (``requires_grad=False``)
      used as singular-vector estimates for power iteration.

    Every forward pass (in both train and eval mode) runs ``power_iterations``
    power-iteration steps on ``W = <name>_bar.view(W.shape[0], -1)``, estimates
    ``sigma = u . (W v)``, sets ``module.<name> = <name>_bar / sigma`` and then
    calls ``module.forward`` with the given positional and keyword arguments.

    Args:
        module: module that owns the parameter to normalise.
        name: attribute name of that parameter.
        power_iterations: number of power-iteration steps per forward pass.
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Refine u/v by power iteration and write the normalised weight onto the module."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        # Power iteration operates on detached data: u and v are running
        # estimates, not part of the autograd graph. The matrix does not change
        # inside the loop, so it is reshaped once.
        w_mat = w.view(height, -1).data
        for _ in range(self.power_iterations):
            v.data = F.normalize(torch.mv(torch.t(w_mat), u.data), dim=0)
            u.data = F.normalize(torch.mv(w_mat, v.data), dim=0)

        # sigma is computed from the live ``w`` so gradients flow through the
        # division; u and v do not require grad.
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """Return True when ``<name>_u``, ``<name>_v`` and ``<name>_bar`` already exist on the module."""
        return all(
            hasattr(self.module, self.name + suffix)
            for suffix in ("_u", "_v", "_bar")
        )

    def _make_params(self):
        """Replace ``module.<name>`` with the ``_bar`` weight plus random unit ``_u``/``_v`` vectors."""
        w = getattr(self.module, self.name)
        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = F.normalize(u.data, dim=0)
        v.data = F.normalize(v.data, dim=0)
        w_bar = nn.Parameter(w.data)

        # Drop the original parameter; forward() re-creates ``module.<name>``
        # as a plain tensor derived from ``<name>_bar``.
        del self.module._parameters[self.name]
        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args, **kwargs):
        """Normalise the weight, then run the wrapped module's ``forward``.

        Keyword arguments are forwarded too (e.g. ``output_size`` for
        ``nn.ConvTranspose2d``); previously they were rejected.
        """
        self._update_u_v()
        return self.module.forward(*args, **kwargs)

def spectral_norm(module, name='weight', power_iterations=1):
    """Convenience factory: wrap ``module`` in a ``SpectralNorm`` wrapper.

    Args:
        module: module whose parameter ``name`` should be spectrally normalised.
        name: attribute name of the parameter to normalise.
        power_iterations: power-iteration steps per forward pass.

    Returns:
        The ``SpectralNorm`` wrapper around ``module``.
    """
    wrapped = SpectralNorm(module, name=name, power_iterations=power_iterations)
    return wrapped

In [None]:
class ImageEncoder(nn.Module):
    """Map 1x28x28 grayscale (MNIST) images to L2-normalised embeddings.

    Three 3x3 convolution stages (1 -> 32 -> 64 -> 128 channels, ReLU) with
    two 2x2 max-pools and a final adaptive average pool to a 4x4 grid, then an
    MLP head (2048 -> 512 -> 256 -> ``latent_dim``) with ReLU and dropout 0.3.

    Args:
        latent_dim: size of the output embedding.
    """
    def __init__(self, latent_dim=512):
        super(ImageEncoder, self).__init__()

        conv_stack = []
        # Two downsampling stages: 28x28 -> 14x14 -> 7x7.
        for in_ch, out_ch in ((1, 32), (32, 64)):
            conv_stack.extend([
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ])
        # Last stage: widen to 128 channels and pool to a fixed 4x4 grid.
        conv_stack.extend([
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((4, 4)),
        ])
        self.conv_layers = nn.Sequential(*conv_stack)

        head = []
        for in_f, out_f in ((128 * 4 * 4, 512), (512, 256)):
            head.extend([nn.Linear(in_f, out_f), nn.ReLU(), nn.Dropout(0.3)])
        head.append(nn.Linear(256, latent_dim))
        self.projection_head = nn.Sequential(*head)

    def forward(self, x):
        """Encode ``x`` (batch, 1, 28, 28) into unit-norm (batch, latent_dim) vectors."""
        pooled = self.conv_layers(x)
        flat = torch.flatten(pooled, start_dim=1)
        # Unit-length outputs make dot products cosine similarities.
        return F.normalize(self.projection_head(flat), dim=1)

In [None]:
# Thoughtviz ImageNet and eggimagedecode

# class ImageEncoder(nn.Module):
#     """Encode images to latent features for contrastive learning"""
#     def __init__(self, latent_dim=512):
#         super(ImageEncoder, self).__init__()
        
#         resnet = resnet50(pretrained=True)
#         self.backbone = nn.Sequential(*list(resnet.children())[:-1])  
        
#         self.projection_head = nn.Sequential(
#             nn.Linear(2048, 1024),
#             nn.ReLU(),
#             nn.Dropout(0.3),
#             nn.Linear(1024, 512),
#             nn.ReLU(),
#             nn.Dropout(0.3),
#             nn.Linear(512, latent_dim)
#         )
        
#         for param in list(self.backbone.parameters())[:-20]:
#             param.requires_grad = False
    
#     def forward(self, x):
#         # x shape: (batch_size, 3, 128, 128)
#         features = self.backbone(x)
#         features = features.view(features.size(0), -1)  
#         projected = self.projection_head(features)
#         return F.normalize(projected, dim=1)  # L2 normalize for contrastive learning


In [None]:
class EEGEncoder(nn.Module):
    """Conv1d + bidirectional-LSTM encoder for EEG windows of shape (batch, channels, time, 1).

    Args:
        input_channels: number of EEG channels (Conv1d input channels).
        sequence_length: number of time steps; fixes the input width of ``fc1``.
        latent_dim: size of the returned feature vector.
    """
    def __init__(self, input_channels=14, sequence_length=32, latent_dim=512):
        super(EEGEncoder, self).__init__()

        # Temporal convolutions; kernel 3 with padding 1 keeps the time length.
        self.conv1 = nn.Conv1d(input_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)

        # BiLSTMs: 256 -> 2*128 = 256 -> 2*64 = 128 features per time step.
        self.lstm1 = nn.LSTM(256, 128, batch_first=True, bidirectional=True)
        self.lstm2 = nn.LSTM(256, 64, batch_first=True, bidirectional=True)

        # The full (time x 128) LSTM output is flattened into the MLP.
        self.fc1 = nn.Linear(128 * sequence_length, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, latent_dim)

        self.contrastive_head = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, latent_dim),
        )

        self.dropout = nn.Dropout(0.3)
        self.leaky_relu = nn.LeakyReLU(0.2)

    def forward(self, x, return_contrastive=False):
        """Encode a batch of EEG windows.

        Args:
            x: tensor of shape (batch, channels, time, 1); the trailing
                singleton axis is squeezed away.
            return_contrastive: if True, also return the L2-normalised output
                of ``contrastive_head``.

        Returns:
            ``features`` of shape (batch, latent_dim), or the pair
            ``(features, contrastive_features)`` when ``return_contrastive``.
        """
        act, drop = self.leaky_relu, self.dropout

        h = x.squeeze(-1)                 # (batch, channels, time)
        h = drop(act(self.conv1(h)))      # (batch, 64, time)
        h = drop(act(self.conv2(h)))      # (batch, 128, time)
        h = act(self.conv3(h))            # (batch, 256, time); no dropout here

        h = h.transpose(1, 2)             # (batch, time, 256) for batch_first LSTMs
        h = drop(self.lstm1(h)[0])
        h = self.lstm2(h)[0]              # (batch, time, 128)

        h = drop(act(self.fc1(h.flatten(1))))
        h = drop(act(self.fc2(h)))
        features = self.fc3(h)

        if not return_contrastive:
            return features
        projected = F.normalize(self.contrastive_head(features), dim=1)
        return features, projected

In [None]:
# eggimagedecode

# class EEGEncoder(nn.Module):
#     """Encode EEG signals to latent features"""
#     def __init__(self, input_channels=17, sequence_length=100, latent_dim=512):
#         super(EEGEncoder, self).__init__()
        
#         self.lstm1 = nn.LSTM(input_channels, 128, batch_first=True, bidirectional=True)
#         self.lstm2 = nn.LSTM(256, 64, batch_first=True, bidirectional=True)
        
#         self.fc1 = nn.Linear(128 * sequence_length, 1024)
#         self.fc2 = nn.Linear(1024, 512)
#         self.fc3 = nn.Linear(512, latent_dim)
        
#         self.contrastive_head = nn.Sequential(
#             nn.Linear(latent_dim, 512),
#             nn.ReLU(),
#             nn.Dropout(0.3),
#             nn.Linear(512, latent_dim)
#         )
        
#         self.dropout = nn.Dropout(0.3)
#         self.leaky_relu = nn.LeakyReLU(0.2)
        
#     def forward(self, x, return_contrastive=False):
#         # x shape: (batch_size, 4, 17, 100) -> (batch_size, 100, 17)
#         x = x.mean(dim=1)  # Average across 4 trials
#         x = x.permute(0, 2, 1)  # (batch_size, 100, 17)
        
#         x, _ = self.lstm1(x)
#         x = self.dropout(x)
#         x, _ = self.lstm2(x)
        
#         x = x.flatten(1)
#         x = self.leaky_relu(self.fc1(x))
#         x = self.dropout(x)
#         x = self.leaky_relu(self.fc2(x))
#         x = self.dropout(x)
#         features = self.fc3(x)
        
#         if return_contrastive:
#             contrastive_features = self.contrastive_head(features)
#             contrastive_features = F.normalize(contrastive_features, dim=1)
#             return features, contrastive_features
        
#         return features

In [None]:
class Generator(nn.Module):
    """Conditional generator producing 1x28x28 images in [-1, 1] (``tanh`` output).

    The EEG feature vector, a noise vector and a learned 50-d class embedding
    are concatenated, projected to a 256x7x7 map, upsampled by two spectrally
    normalised transposed convolutions (7 -> 14 -> 28) with BatchNorm and
    LeakyReLU, and mapped to one channel by a final 3x3 transposed conv.

    Args:
        eeg_dim: size of ``eeg_features``.
        noise_dim: size of ``noise``.
        n_classes: number of rows in the class-embedding table.
    """
    def __init__(self, eeg_dim=512, noise_dim=100, n_classes=10):
        super(Generator, self).__init__()

        self.class_embedding = nn.Embedding(n_classes, 50)

        input_dim = eeg_dim + noise_dim + 50

        self.fc = nn.Linear(input_dim, 7 * 7 * 256)

        self.convt1 = spectral_norm(nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False))  # 7x7 -> 14x14
        self.bn1 = nn.BatchNorm2d(128)

        self.convt2 = spectral_norm(nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False))  # 14x14 -> 28x28
        self.bn2 = nn.BatchNorm2d(64)

        self.convt3 = spectral_norm(nn.ConvTranspose2d(64, 1, 3, 1, 1, bias=False))  # 28x28, 1 channel

        self.leaky_relu = nn.LeakyReLU(0.2)
        self.tanh = nn.Tanh()

    def forward(self, eeg_features, noise, class_labels):
        """Generate images.

        Args:
            eeg_features: (batch, eeg_dim) tensor.
            noise: (batch, noise_dim) tensor.
            class_labels: integer labels of shape (batch,), (batch, 1), or a
                0-d tensor for a batch of one; flattened to (batch,).

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in [-1, 1].
        """
        # Flatten labels so the embedding is always (batch, 50): a (batch, 1)
        # label tensor would give (batch, 1, 50), and ``squeeze()`` on a batch
        # of one gives a 0-d tensor -> (50,); both broke the concatenation.
        class_emb = self.class_embedding(class_labels.reshape(-1))

        x = torch.cat([eeg_features, noise, class_emb], dim=1)

        x = self.fc(x)
        x = x.view(-1, 256, 7, 7)

        x = self.leaky_relu(self.bn1(self.convt1(x)))  # 14x14
        x = self.leaky_relu(self.bn2(self.convt2(x)))  # 28x28
        x = self.tanh(self.convt3(x))                   # 28x28x1

        return x

In [None]:
# Thoughtviz ImageNet

# class Generator(nn.Module):
#     """Generator for 10 classes"""
#     def __init__(self, eeg_dim=512, noise_dim=100, n_classes=10): # n_classes=1654 for eggimagedecode
#         super(Generator, self).__init__()
        
#         self.class_embedding = nn.Embedding(n_classes, 50)
        
#         input_dim = eeg_dim + noise_dim + 50
        
#         self.fc = nn.Linear(input_dim, 8 * 8 * 1024)
        
#         self.convt1 = spectral_norm(nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False))
#         self.bn1 = nn.BatchNorm2d(512)
        
#         self.convt2 = spectral_norm(nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False))
#         self.bn2 = nn.BatchNorm2d(256)
        
#         self.convt3 = spectral_norm(nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False))
#         self.bn3 = nn.BatchNorm2d(128)
        
#         self.convt4 = spectral_norm(nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False))
#         self.bn4 = nn.BatchNorm2d(64)
        
#         self.convt5 = spectral_norm(nn.ConvTranspose2d(64, 3, 3, 1, 1, bias=False))
        
#         self.leaky_relu = nn.LeakyReLU(0.2)
#         self.tanh = nn.Tanh()
        
#     def forward(self, eeg_features, noise, class_labels):
#         class_emb = self.class_embedding(class_labels)
        
#         x = torch.cat([eeg_features, noise, class_emb], dim=1)
        
#         x = self.fc(x)
#         x = x.view(-1, 1024, 8, 8)
        
#         x = self.leaky_relu(self.bn1(self.convt1(x)))  # 16x16
#         x = self.leaky_relu(self.bn2(self.convt2(x)))  # 32x32
#         x = self.leaky_relu(self.bn3(self.convt3(x)))  # 64x64
#         x = self.leaky_relu(self.bn4(self.convt4(x)))  # 128x128
#         x = self.tanh(self.convt5(x))                   # 128x128x3
        
#         return x

In [None]:
class Discriminator(nn.Module):
    """Class-conditional discriminator for 1x28x28 images.

    The class label is embedded (50-d), linearly projected to a 28x28 plane and
    stacked onto the image as a second channel. Three stride-2 convolutions
    (28 -> 14 -> 7 -> 4) and a final stride-1 conv to one channel produce a
    4x4 score map, averaged into one raw (un-squashed) score per sample.

    Args:
        n_classes: number of rows in the class-embedding table.
    """
    def __init__(self, n_classes=10):
        super(Discriminator, self).__init__()

        self.class_embedding = nn.Embedding(n_classes, 50)
        self.class_projection = nn.Linear(50, 28 * 28)

        self.conv1 = nn.Conv2d(2, 64, 3, 2, 1, bias=False)     # 28x28 -> 14x14
        self.conv2 = nn.Conv2d(64, 128, 3, 2, 1, bias=False)   # 14x14 -> 7x7
        self.bn2 = nn.BatchNorm2d(128)

        self.conv3 = nn.Conv2d(128, 256, 3, 2, 1, bias=False)  # 7x7 -> 4x4
        self.bn3 = nn.BatchNorm2d(256)

        self.conv4 = nn.Conv2d(256, 1, 3, 1, 1, bias=False)    # 4x4 -> 4x4, 1 channel

        self.leaky_relu = nn.LeakyReLU(0.2)

    def forward(self, x, class_labels):
        """Score images ``x`` (batch, 1, 28, 28) conditioned on ``class_labels``; returns (batch,)."""
        n = x.size(0)

        label_plane = self.class_projection(self.class_embedding(class_labels))
        h = torch.cat([x, label_plane.view(n, 1, 28, 28)], dim=1)

        h = self.leaky_relu(self.conv1(h))
        for conv, norm in ((self.conv2, self.bn2), (self.conv3, self.bn3)):
            h = self.leaky_relu(norm(conv(h)))
        score_map = self.conv4(h)

        return score_map.view(n, -1).mean(dim=1)

In [None]:
# Thoughtviz ImageNet

# class Discriminator(nn.Module):
#     """Discriminator for 10 classes"""
#     def __init__(self, n_classes=10): # n_classes=1654 for eggimagedecode
#         super(Discriminator, self).__init__()
        
#         self.class_embedding = nn.Embedding(n_classes, 50)
#         self.class_projection = nn.Linear(50, 128 * 128)
        
#         self.conv1 = nn.Conv2d(4, 64, 3, 2, 1, bias=False)  # 4 channels (3 RGB + 1 class)
#         self.conv2 = nn.Conv2d(64, 128, 3, 2, 1, bias=False)
#         self.bn2 = nn.BatchNorm2d(128)
        
#         self.conv3 = nn.Conv2d(128, 256, 3, 2, 1, bias=False)
#         self.bn3 = nn.BatchNorm2d(256)
        
#         self.conv4 = nn.Conv2d(256, 512, 3, 2, 1, bias=False)
#         self.bn4 = nn.BatchNorm2d(512)
        
#         self.conv5 = nn.Conv2d(512, 1024, 3, 1, 1, bias=False)
#         self.bn5 = nn.BatchNorm2d(1024)
        
#         self.conv6 = nn.Conv2d(1024, 1, 3, 1, 1, bias=False)
        
#         self.leaky_relu = nn.LeakyReLU(0.2)
        
#     def forward(self, x, class_labels):
#         batch_size = x.size(0)
        
#         class_emb = self.class_embedding(class_labels)
#         class_map = self.class_projection(class_emb)
#         class_map = class_map.view(batch_size, 1, 128, 128)
        
#         x = torch.cat([x, class_map], dim=1)
        
#         x = self.leaky_relu(self.conv1(x))
#         x = self.leaky_relu(self.bn2(self.conv2(x)))
#         x = self.leaky_relu(self.bn3(self.conv3(x)))
#         x = self.leaky_relu(self.bn4(self.conv4(x)))
#         x = self.leaky_relu(self.bn5(self.conv5(x)))
#         x = self.conv6(x)
        
#         return x.view(batch_size, -1).mean(dim=1)

In [None]:
def info_nce_loss(eeg_features, img_features, temperature=0.07, symmetric=False):
    """InfoNCE contrastive loss for row-paired embeddings.

    Row ``i`` of ``eeg_features`` is the positive for row ``i`` of
    ``img_features``; every other row of ``img_features`` in the batch is a
    negative. NOTE(review): samples of the same class in one batch are still
    treated as negatives.

    Args:
        eeg_features: (batch, dim) tensor; presumably L2-normalised so the dot
            product is a cosine similarity.
        img_features: (n_img, dim) tensor paired row-wise with
            ``eeg_features``; must be square with it when ``symmetric``.
        temperature: divisor applied to the similarity logits.
        symmetric: if True, average the EEG->image and image->EEG
            cross-entropies (CLIP-style). The default ``False`` keeps the
            one-directional EEG->image loss.

    Returns:
        Scalar loss tensor.
    """
    logits = torch.matmul(eeg_features, img_features.T) / temperature
    targets = torch.arange(logits.size(0), device=logits.device)
    loss = F.cross_entropy(logits, targets)
    if symmetric:
        loss = (loss + F.cross_entropy(logits.T, targets)) / 2
    return loss

def triplet_loss(anchor, positive, negative, margin=1.0):
    """Batch-mean margin triplet loss.

    Computes ``mean(relu(d(a, p) - d(a, n) + margin))``, where ``d`` is
    ``F.pairwise_distance`` (row-wise Euclidean distance with its default
    epsilon).
    """
    hinge = (
        F.pairwise_distance(anchor, positive)
        - F.pairwise_distance(anchor, negative)
        + margin
    )
    return F.relu(hinge).mean()

In [None]:
class EEGMNISTDataset(Dataset):
    """Pair each EEG trial with a random MNIST digit of the same class.

    Items are ``(eeg, image, negative_image, label)``:

    * ``eeg`` -- FloatTensor built from one row of the EEG array;
    * ``image`` -- an MNIST digit whose class equals the EEG trial's argmax label;
    * ``negative_image`` -- an MNIST digit of a different class;
    * ``label`` -- LongTensor of shape (1,) holding ``image``'s class.

    Args:
        eeg_pickle_path: pickle file with ``x_train``/``y_train`` and
            ``x_test``/``y_test`` arrays; labels are one-hot along axis 1.
        train: select the train or test split (used for both EEG and MNIST).
        transform: optional transform applied to MNIST images instead of the
            default ``ToTensor()`` + ``Normalize([0.5], [0.5])`` (maps to [-1, 1]).
    """
    def __init__(self, eeg_pickle_path, train=True, transform=None):
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(eeg_pickle_path, 'rb') as f:
            data_dict = pickle.load(f, encoding='latin1')

        if train:
            self.eeg_data = data_dict['x_train']  # (45390, 14, 32, 1)
            self.eeg_labels = data_dict['y_train']  # (45390, 10) - one-hot encoded
        else:
            self.eeg_data = data_dict['x_test']
            self.eeg_labels = data_dict['y_test']

        # One-hot rows -> integer class ids.
        self.class_labels = np.argmax(self.eeg_labels, axis=1)

        from torchvision.datasets import MNIST

        # Honor a caller-supplied transform (it used to be stored but never
        # applied), matching the ``transform or default`` pattern of the
        # sibling image datasets.
        if transform is None:
            mnist_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1]
            ])
        else:
            mnist_transform = transform

        self.mnist_dataset = MNIST(
            root='./mnist_data',
            train=train,
            download=True,
            transform=mnist_transform
        )

        # Read labels straight from the target tensor; iterating the dataset
        # would decode and transform every image just to get its label.
        self._mnist_targets = self.mnist_dataset.targets.tolist()
        self.class_to_mnist_images = {i: [] for i in range(10)}
        for idx, label in enumerate(self._mnist_targets):
            self.class_to_mnist_images[label].append(idx)

        self.transform = transform

        print(f"Dataset: {len(self.eeg_data)} EEG samples, {len(self.mnist_dataset)} MNIST images")
        print(f"EEG shape: {self.eeg_data.shape}")
        print(f"EEG labels shape: {self.eeg_labels.shape}")

    def __len__(self):
        """Number of EEG trials in the selected split."""
        return len(self.eeg_data)

    def __getitem__(self, idx):
        """Return ``(eeg, positive_image, negative_image, LongTensor([label]))`` for trial ``idx``."""
        eeg = torch.FloatTensor(self.eeg_data[idx])  # (14, 32, 1)
        eeg_class = self.class_labels[idx]

        same_class = self.class_to_mnist_images[eeg_class]
        if same_class:
            mnist_idx = random.choice(same_class)
        else:
            mnist_idx = random.randint(0, len(self.mnist_dataset) - 1)

        mnist_image, mnist_label = self.mnist_dataset[mnist_idx]

        available_classes = [c for c in range(10)
                             if c != mnist_label and self.class_to_mnist_images[c]]
        if available_classes:
            negative_class = random.choice(available_classes)
            negative_mnist_idx = random.choice(self.class_to_mnist_images[negative_class])
        else:
            # Scan cached labels rather than loading every image.
            negative_mnist_idx = random.choice([i for i, t in enumerate(self._mnist_targets)
                                                if t != mnist_label])

        negative_mnist_image, _ = self.mnist_dataset[negative_mnist_idx]

        return eeg, mnist_image, negative_mnist_image, torch.LongTensor([mnist_label])


In [None]:
# Thoughtviz ImageNet

# class EEGImageDataset(Dataset):
#     def __init__(self, eeg_pickle_path, images_dir, transform=None):
#         with open(eeg_pickle_path, 'rb') as f:
#             data_dict = pickle.load(f, encoding='latin1')
        
#         self.eeg_data = data_dict['x_train']  # (45390, 14, 32, 1)
#         self.eeg_labels = data_dict['y_train']  # (45390, 10) - one-hot encoded
        
#         self.class_labels = np.argmax(self.eeg_labels, axis=1)
        
#         self.images_dir = Path(images_dir)
#         self.class_dirs = sorted([d for d in self.images_dir.iterdir() if d.is_dir()])
#         self.class_to_idx = {d.name: i for i, d in enumerate(self.class_dirs)}
        
#         self.image_paths = []
#         self.image_labels = []
        
#         for class_idx, class_dir in enumerate(self.class_dirs):
#             image_files = list(class_dir.glob('*.jpg')) + list(class_dir.glob('*.JPEG')) + list(class_dir.glob('*.jpeg'))
#             for img_path in image_files:
#                 self.image_paths.append(img_path)
#                 self.image_labels.append(class_idx)
        
#         self.class_to_images = {}
#         for i, label in enumerate(self.image_labels):
#             if label not in self.class_to_images:
#                 self.class_to_images[label] = []
#             self.class_to_images[label].append(i)
        
#         self.transform = transform or transforms.Compose([
#             transforms.Resize((128, 128)),
#             transforms.ToTensor(),
#             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
#         ])
        
#         print(f"Dataset: {len(self.eeg_data)} EEG samples, {len(self.image_paths)} images")
#         print(f"Classes: {len(self.class_dirs)}")
#         print(f"EEG shape: {self.eeg_data.shape}")
#         print(f"EEG labels shape: {self.eeg_labels.shape}")
        
#     def __len__(self):
#         return len(self.eeg_data)
    
#     def __getitem__(self, idx):
#         eeg = torch.FloatTensor(self.eeg_data[idx])  # (14, 32, 1)
#         eeg_class = self.class_labels[idx]
        
#         if eeg_class in self.class_to_images and len(self.class_to_images[eeg_class]) > 0:
#             img_idx = random.choice(self.class_to_images[eeg_class])
#         else:
#             img_idx = random.randint(0, len(self.image_paths) - 1)
            
#         img_path = self.image_paths[img_idx]
#         img_class = self.image_labels[img_idx]
        
#         image = Image.open(img_path).convert('RGB')
#         image = self.transform(image)
        
#         available_classes = [c for c in self.class_to_images.keys() if c != img_class and len(self.class_to_images[c]) > 0]
#         if available_classes:
#             negative_class = random.choice(available_classes)
#             negative_img_idx = random.choice(self.class_to_images[negative_class])
#             negative_img_path = self.image_paths[negative_img_idx]
#             negative_image = Image.open(negative_img_path).convert('RGB')
#             negative_image = self.transform(negative_image)
#         else:
#             negative_img_idx = random.choice([i for i in range(len(self.image_paths)) if self.image_labels[i] != img_class])
#             negative_img_path = self.image_paths[negative_img_idx]
#             negative_image = Image.open(negative_img_path).convert('RGB')
#             negative_image = self.transform(negative_image)
        
#         return eeg, image, negative_image, torch.LongTensor([img_class])

In [None]:
# eggimagedecode
# class EEGImageDataset(Dataset):
#     def __init__(self, eeg_file_path, images_dir, transform=None):
#         data_dict = np.load(eeg_file_path, allow_pickle=True).item()
#         self.eeg_data = data_dict['preprocessed_eeg_data']
        
#         self.images_dir = Path(images_dir)
#         self.class_dirs = sorted([d for d in self.images_dir.iterdir() if d.is_dir()])
#         self.class_to_idx = {d.name: i for i, d in enumerate(self.class_dirs)}
        
#         self.image_paths = []
#         self.image_labels = []
        
#         for class_idx, class_dir in enumerate(self.class_dirs):
#             image_files = list(class_dir.glob('*.jpg')) + list(class_dir.glob('*.png'))
#             for img_path in image_files:
#                 self.image_paths.append(img_path)
#                 self.image_labels.append(class_idx)
        
#         self.class_to_images = {}
#         for i, label in enumerate(self.image_labels):
#             if label not in self.class_to_images:
#                 self.class_to_images[label] = []
#             self.class_to_images[label].append(i)
        
#         self.transform = transform or transforms.Compose([
#             transforms.Resize((128, 128)),
#             transforms.ToTensor(),
#             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
#         ])
        
#         print(f"Dataset: {len(self.eeg_data)} EEG samples, {len(self.image_paths)} images")
#         print(f"Classes: {len(self.class_dirs)}")
        
#     def __len__(self):
#         return min(len(self.eeg_data), len(self.image_paths))
    
#     def __getitem__(self, idx):
#         eeg = torch.FloatTensor(self.eeg_data[idx])
        
#         if idx < len(self.image_paths):
#             img_idx = idx
#         else:
#             img_idx = random.randint(0, len(self.image_paths) - 1)
            
#         img_path = self.image_paths[img_idx]
#         img_class = self.image_labels[img_idx]
        
#         image = Image.open(img_path).convert('RGB')
#         image = self.transform(image)
        
#         negative_class = random.choice([c for c in self.class_to_images.keys() if c != img_class])
#         negative_img_idx = random.choice(self.class_to_images[negative_class])
#         negative_img_path = self.image_paths[negative_img_idx]
#         negative_image = Image.open(negative_img_path).convert('RGB')
#         negative_image = self.transform(negative_image)
        
#         return eeg, image, negative_image, torch.LongTensor([img_class])

In [None]:
def load_test_data(eeg_pickle_path):
    """Load the held-out EEG split from a pickle file.

    Args:
        eeg_pickle_path: path to a pickle holding ``x_test`` (EEG trials) and
            ``y_test`` (one-hot labels along axis 1). Unpickling can execute
            arbitrary code, so only pass trusted files.

    Returns:
        ``(x_test, class_indices)`` where ``class_indices`` is
        ``argmax(y_test, axis=1)``.
    """
    with open(eeg_pickle_path, 'rb') as handle:
        payload = pickle.load(handle, encoding='latin1')
    return payload['x_test'], np.argmax(payload['y_test'], axis=1)

In [None]:
def discriminator_hinge_loss(real_output, fake_output):
    """Hinge loss for the discriminator.

    Averages ``mean(relu(1 - D(real)))`` and ``mean(relu(1 + D(fake)))``.
    """
    loss_on_real = F.relu(1.0 - real_output).mean()
    loss_on_fake = F.relu(1.0 + fake_output).mean()
    return 0.5 * (loss_on_real + loss_on_fake)

def generator_hinge_loss(fake_output):
    """Generator hinge loss: the negated mean discriminator score on fake samples."""
    return fake_output.mean().neg()

def diff_augment(x, policy="color,translation", noise_std=0.1, max_shift=4):
    """Lightweight augmentation for image batches shaped (batch, C, H, W) in [-1, 1].

    Policies are matched as substrings of ``policy``:

    * ``"color"`` -- adds Gaussian noise with std ``noise_std`` and clamps to
      [-1, 1]. Despite the name this is pixel noise, not a colour jitter.
    * ``"translation"`` -- with probability 1/2, circularly shifts the whole
      batch (same offsets for every sample) by independent random offsets in
      ``[-max_shift, max_shift]`` along height (dim 2) and width (dim 3).

    Args:
        x: image tensor.
        policy: comma-separated policy names.
        noise_std: std of the additive noise for ``"color"``.
        max_shift: largest absolute pixel offset for ``"translation"``.

    Returns:
        Augmented tensor with the same shape as ``x``.
    """
    if "color" in policy:
        x = torch.clamp(x + torch.randn_like(x) * noise_std, -1, 1)

    if "translation" in policy and random.random() > 0.5:
        # Draw the offsets independently; a single shared offset (as before)
        # could only ever translate along the diagonal.
        shift_h = random.randint(-max_shift, max_shift)
        shift_w = random.randint(-max_shift, max_shift)
        x = torch.roll(x, shifts=(shift_h, shift_w), dims=(2, 3))
    return x

In [None]:
def diff_augment_mnist(x, policy="color,translation", noise_std=0.05, max_shift=2):
    """Milder augmentation for 28x28 MNIST batches shaped (batch, C, H, W) in [-1, 1].

    Policies are matched as substrings of ``policy``:

    * ``"color"`` -- adds Gaussian noise with std ``noise_std`` (default 0.05)
      and clamps to [-1, 1]; this is pixel noise, not a colour jitter.
    * ``"translation"`` -- with probability 1/2, circularly shifts the whole
      batch by independent random offsets in ``[-max_shift, max_shift]``
      (default 2) along height (dim 2) and width (dim 3).

    Args:
        x: image tensor.
        policy: comma-separated policy names.
        noise_std: std of the additive noise for ``"color"``.
        max_shift: largest absolute pixel offset for ``"translation"``.

    Returns:
        Augmented tensor with the same shape as ``x``.
    """
    if "color" in policy:
        x = torch.clamp(x + torch.randn_like(x) * noise_std, -1, 1)

    if "translation" in policy and random.random() > 0.5:
        # Independent offsets per axis; one shared offset only moved digits diagonally.
        shift_h = random.randint(-max_shift, max_shift)
        shift_w = random.randint(-max_shift, max_shift)
        x = torch.roll(x, shifts=(shift_h, shift_w), dims=(2, 3))

    return x

In [None]:
def train(eeg_pickle_path, num_epochs=100, batch_size=32, lr=0.0002, num_workers=4,
          triplet_weight=0.5, mode_seek_weight=1.0, alignment_weight=0.1):
    """Jointly train EEG encoder, image encoder, generator and discriminator (EEG -> MNIST).

    Every batch runs three optimizer steps:
      1. Contrastive step (``eeg_encoder`` + ``image_encoder``):
         InfoNCE(EEG, real image) + ``triplet_weight`` * triplet(EEG, real, negative).
      2. Discriminator step: hinge loss on augmented real vs. augmented generated images.
      3. Generator step (``eeg_encoder`` + ``generator``): hinge generator loss on two
         noise draws + ``mode_seek_weight`` * mode-seeking term + ``alignment_weight`` *
         MSE between image-encoder features of fake and real images.

    Args:
        eeg_pickle_path: path passed to ``EEGMNISTDataset``.
        num_epochs: number of passes over the training set.
        batch_size: mini-batch size.
        lr: learning rate shared by the three Adam optimizers.
        num_workers: DataLoader worker processes.
        triplet_weight: weight of the triplet term in the contrastive loss.
        mode_seek_weight: weight of the mode-seeking term in the generator loss.
        alignment_weight: weight of the feature-alignment term in the generator loss.

    Returns:
        ``(eeg_encoder, image_encoder, generator, discriminator)``.

    Raises:
        ValueError: if the dataset is empty.

    Side effects:
        Every 10 epochs saves ``mnist_generated_samples_epoch_{N}.png``; at the end
        calls ``plot_training_history_with_contrastive``.
    """
    dataset = EEGMNISTDataset(eeg_pickle_path, train=True)
    if len(dataset) == 0:
        raise ValueError(f"No training samples found in {eeg_pickle_path!r}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    num_batches = len(dataloader)

    eeg_encoder = EEGEncoder(input_channels=14, sequence_length=32).to(device)
    image_encoder = ImageEncoder().to(device)
    generator = Generator(n_classes=10).to(device)
    discriminator = Discriminator(n_classes=10).to(device)

    # eeg_encoder is registered with both the contrastive and the generator optimizer,
    # so it is stepped twice per batch (each optimizer keeps its own Adam moments).
    contrastive_optimizer = optim.Adam(
        list(eeg_encoder.parameters()) + list(image_encoder.parameters()),
        lr=lr, betas=(0.9, 0.999)
    )
    g_optimizer = optim.Adam(
        list(eeg_encoder.parameters()) + list(generator.parameters()),
        lr=lr, betas=(0.0, 0.9)
    )
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.0, 0.9))

    contrastive_losses, g_losses, d_losses = [], [], []

    print("Starting training...")

    for epoch in range(num_epochs):
        contrastive_loss_total = 0.0
        g_loss_total = 0.0
        d_loss_total = 0.0

        pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for batch_idx, (eeg_batch, real_images, negative_images, class_labels) in enumerate(pbar):
            batch_size_actual = eeg_batch.size(0)

            eeg_batch = eeg_batch.to(device)
            real_images = real_images.to(device)
            negative_images = negative_images.to(device)
            # view(-1) instead of squeeze(): squeeze() turns a (1, 1) label batch
            # (a final batch with one sample) into a 0-d tensor.
            class_labels = class_labels.view(-1).to(device)

            # =============== Contrastive Learning Phase ===============
            contrastive_optimizer.zero_grad()

            _, eeg_contrastive_features = eeg_encoder(eeg_batch, return_contrastive=True)
            img_contrastive_features = image_encoder(real_images)
            negative_img_features = image_encoder(negative_images)

            info_nce = info_nce_loss(eeg_contrastive_features, img_contrastive_features)
            triplet = triplet_loss(eeg_contrastive_features, img_contrastive_features, negative_img_features)

            contrastive_loss = info_nce + triplet_weight * triplet
            contrastive_loss.backward()
            contrastive_optimizer.step()

            # =============== Discriminator step ===============
            noise = torch.randn(batch_size_actual, 100, device=device)

            d_optimizer.zero_grad()

            # The fakes are only D inputs here, so don't build a graph through E/G.
            with torch.no_grad():
                eeg_features = eeg_encoder(eeg_batch)
                fake_images = generator(eeg_features, noise, class_labels)

            real_images_aug = diff_augment_mnist(real_images)
            fake_images_aug = diff_augment_mnist(fake_images)

            real_output = discriminator(real_images_aug, class_labels)
            fake_output = discriminator(fake_images_aug, class_labels)

            d_loss = discriminator_hinge_loss(real_output, fake_output)
            d_loss.backward()
            d_optimizer.step()

            # =============== Generator step ===============
            g_optimizer.zero_grad()

            noise2 = torch.randn(batch_size_actual, 100, device=device)
            eeg_features = eeg_encoder(eeg_batch)
            fake_images = generator(eeg_features, noise, class_labels)
            fake_images2 = generator(eeg_features, noise2, class_labels)

            fake_images_aug = diff_augment_mnist(fake_images)
            fake_images2_aug = diff_augment_mnist(fake_images2)

            fake_output = discriminator(fake_images_aug, class_labels)
            fake_output2 = discriminator(fake_images2_aug, class_labels)

            g_loss = generator_hinge_loss(fake_output) + generator_hinge_loss(fake_output2)

            # Mode-seeking term: the inverse of (image difference / noise difference),
            # so low output diversity for different noise yields a large loss.
            diversity_ratio = torch.mean(torch.abs(fake_images2 - fake_images)) / (
                torch.mean(torch.abs(noise2 - noise)) + 1e-5)
            mode_loss = 1.0 / (diversity_ratio + 1e-5)

            # Pull fake-image features toward real-image features. image_encoder is
            # not in g_optimizer, so this step does not update it.
            with torch.no_grad():
                real_img_features = image_encoder(real_images)
            fake_img_features = image_encoder(fake_images)
            alignment_loss = F.mse_loss(fake_img_features, real_img_features)

            total_g_loss = g_loss + mode_seek_weight * mode_loss + alignment_weight * alignment_loss
            total_g_loss.backward()
            g_optimizer.step()

            contrastive_loss_total += contrastive_loss.item()
            g_loss_total += total_g_loss.item()
            d_loss_total += d_loss.item()

            seen = batch_idx + 1
            pbar.set_postfix({
                'Cont_Loss': f'{contrastive_loss_total/seen:.4f}',
                'G_Loss': f'{g_loss_total/seen:.4f}',
                'D_Loss': f'{d_loss_total/seen:.4f}'
            })

        if (epoch + 1) % 10 == 0:
            _save_epoch_samples(eeg_encoder, generator, eeg_batch, class_labels, epoch + 1)

        epoch_cont = contrastive_loss_total / num_batches
        epoch_g = g_loss_total / num_batches
        epoch_d = d_loss_total / num_batches
        contrastive_losses.append(epoch_cont)
        g_losses.append(epoch_g)
        d_losses.append(epoch_d)

        print(f'Epoch {epoch+1}: Contrastive_Loss = {epoch_cont:.4f}, '
              f'G_Loss = {epoch_g:.4f}, '
              f'D_Loss = {epoch_d:.4f}')

    plot_training_history_with_contrastive(contrastive_losses, g_losses, d_losses)

    return eeg_encoder, image_encoder, generator, discriminator


def _save_epoch_samples(eeg_encoder, generator, eeg_batch, class_labels, epoch_number):
    """Generate up to 8 images from ``eeg_batch`` and save them as a 2x4 grid PNG.

    The sample count is capped by the batch size, so a short final batch does not
    cause an IndexError or a noise/EEG batch-size mismatch.
    """
    n_show = min(8, eeg_batch.size(0))
    with torch.no_grad():
        sample_noise = torch.randn(n_show, 100, device=device)
        sample_classes = class_labels[:n_show]
        sample_features = eeg_encoder(eeg_batch[:n_show])
        sample_images = generator(sample_features, sample_noise, sample_classes)

    plt.figure(figsize=(12, 8))
    for i in range(n_show):
        plt.subplot(2, 4, i + 1)
        img = sample_images[i].cpu().squeeze()
        # Assumes generator output in [-1, 1]; map to [0, 1] for display.
        img = torch.clamp((img + 1) / 2, 0, 1)
        plt.imshow(img, cmap='gray')
        plt.axis('off')
        plt.title(f'Class: {sample_classes[i].item()}')

    plt.tight_layout()
    plt.savefig(f'mnist_generated_samples_epoch_{epoch_number}.png')
    plt.close()

In [None]:
# Thoughtviz ImageNet

# def train2(eeg_pickle_path, images_dir, num_epochs=100, batch_size=32, lr=0.0002):
#     dataset = EEGImageDataset(eeg_pickle_path, images_dir)
#     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    
#     eeg_encoder = EEGEncoder(input_channels=14, sequence_length=32).to(device)
#     image_encoder = ImageEncoder().to(device)
#     generator = Generator(n_classes=10).to(device)  
#     discriminator = Discriminator(n_classes=10).to(device) 
    
#     contrastive_optimizer = optim.Adam(
#         list(eeg_encoder.parameters()) + list(image_encoder.parameters()), 
#         lr=lr, betas=(0.9, 0.999)
#     )
#     g_optimizer = optim.Adam(
#         list(eeg_encoder.parameters()) + list(generator.parameters()), 
#         lr=lr, betas=(0.0, 0.9)
#     )
#     d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.0, 0.9))
    
#     contrastive_losses = []
#     g_losses = []
#     d_losses = []
    
#     print("Starting training...")
    
#     for epoch in range(num_epochs):
#         contrastive_loss_total = 0
#         g_loss_total = 0
#         d_loss_total = 0
        
#         pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')
        
#         for batch_idx, (eeg_batch, real_images, negative_images, class_labels) in enumerate(pbar):
#             batch_size_actual = eeg_batch.size(0)
            
#             eeg_batch = eeg_batch.to(device)
#             real_images = real_images.to(device)
#             negative_images = negative_images.to(device)
#             class_labels = class_labels.squeeze().to(device)
            
#             # =============== Contrastive Learning Phase ===============
#             contrastive_optimizer.zero_grad()
            
#             _, eeg_contrastive_features = eeg_encoder(eeg_batch, return_contrastive=True)
#             img_contrastive_features = image_encoder(real_images)
#             negative_img_features = image_encoder(negative_images)
            
#             info_nce = info_nce_loss(eeg_contrastive_features, img_contrastive_features)
#             triplet = triplet_loss(eeg_contrastive_features, img_contrastive_features, negative_img_features)
            
#             contrastive_loss = info_nce + 0.5 * triplet
#             contrastive_loss.backward()
#             contrastive_optimizer.step()
            
#             # =============== GAN Training Phase ===============
#             noise = torch.randn(batch_size_actual, 100).to(device)
            
#             d_optimizer.zero_grad()
            
#             eeg_features = eeg_encoder(eeg_batch).detach()
            
#             fake_images = generator(eeg_features, noise, class_labels)
            
#             real_images_aug = diff_augment(real_images)
#             fake_images_aug = diff_augment(fake_images.detach())
            
#             real_output = discriminator(real_images_aug, class_labels)
#             fake_output = discriminator(fake_images_aug, class_labels)
            
#             d_loss = discriminator_hinge_loss(real_output, fake_output)
#             d_loss.backward()
#             d_optimizer.step()
            
#             g_optimizer.zero_grad()
            
#             noise2 = torch.randn(batch_size_actual, 100).to(device)
#             eeg_features = eeg_encoder(eeg_batch)
#             fake_images = generator(eeg_features, noise, class_labels)
#             fake_images2 = generator(eeg_features, noise2, class_labels)
            
#             fake_images_aug = diff_augment(fake_images)
#             fake_images2_aug = diff_augment(fake_images2)
            
#             fake_output = discriminator(fake_images_aug, class_labels)
#             fake_output2 = discriminator(fake_images2_aug, class_labels)
            
#             g_loss = generator_hinge_loss(fake_output) + generator_hinge_loss(fake_output2)
            
#             mode_loss = torch.mean(torch.abs(fake_images2 - fake_images)) / (
#                 torch.mean(torch.abs(noise2 - noise)) + 1e-5)
#             mode_loss = 1.0 / (mode_loss + 1e-5)
            
#             with torch.no_grad():
#                 real_img_features = image_encoder(real_images)
#             fake_img_features = image_encoder(fake_images)
#             alignment_loss = F.mse_loss(fake_img_features, real_img_features)
            
#             total_g_loss = g_loss + 1.0 * mode_loss + 0.1 * alignment_loss
#             total_g_loss.backward()
#             g_optimizer.step()
            
#             contrastive_loss_total += contrastive_loss.item()
#             g_loss_total += total_g_loss.item()
#             d_loss_total += d_loss.item()
            
#             pbar.set_postfix({
#                 'Cont_Loss': f'{contrastive_loss_total/(batch_idx+1):.4f}',
#                 'G_Loss': f'{g_loss_total/(batch_idx+1):.4f}',
#                 'D_Loss': f'{d_loss_total/(batch_idx+1):.4f}'
#             })
        
#         if (epoch + 1) % 10 == 0:
#             with torch.no_grad():
#                 sample_eeg = eeg_batch[:8]
#                 sample_noise = torch.randn(8, 100).to(device)
#                 sample_classes = class_labels[:8]
#                 sample_features = eeg_encoder(sample_eeg)
#                 sample_images = generator(sample_features, sample_noise, sample_classes)
                
#                 plt.figure(figsize=(12, 8))
#                 for i in range(8):
#                     plt.subplot(2, 4, i+1)
#                     img = sample_images[i].cpu()
#                     img = (img + 1) / 2  
#                     img = torch.clamp(img, 0, 1)
#                     plt.imshow(img.permute(1, 2, 0))
#                     plt.axis('off')
#                     plt.title(f'Class: {sample_classes[i].item()}')
                
#                 plt.tight_layout()
#                 plt.savefig(f'generated_samples_epoch_{epoch+1}.png')
#                 plt.close()
        
#         contrastive_losses.append(contrastive_loss_total/len(dataloader))
#         g_losses.append(g_loss_total/len(dataloader))
#         d_losses.append(d_loss_total/len(dataloader))
        
#         print(f'Epoch {epoch+1}: Contrastive_Loss = {contrastive_loss_total/len(dataloader):.4f}, '
#               f'G_Loss = {g_loss_total/len(dataloader):.4f}, '
#               f'D_Loss = {d_loss_total/len(dataloader):.4f}')
    
#     plot_training_history_with_contrastive(contrastive_losses, g_losses, d_losses)
    
#     return eeg_encoder, image_encoder, generator, discriminator

In [None]:
# EEG image-decode dataset
# def train3(eeg_file_path, images_dir, num_epochs=100, batch_size=32, lr=0.0002):
#     dataset = EEGImageDataset(eeg_file_path, images_dir)
#     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    
#     eeg_encoder = EEGEncoder().to(device)
#     image_encoder = ImageEncoder().to(device)
#     generator = Generator(n_classes=len(dataset.class_dirs)).to(device)
#     discriminator = Discriminator(n_classes=len(dataset.class_dirs)).to(device)
    
#     contrastive_optimizer = optim.Adam(
#         list(eeg_encoder.parameters()) + list(image_encoder.parameters()), 
#         lr=lr, betas=(0.9, 0.999)
#     )
#     g_optimizer = optim.Adam(
#         list(eeg_encoder.parameters()) + list(generator.parameters()), 
#         lr=lr, betas=(0.0, 0.9)
#     )
#     d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.0, 0.9))
    
#     contrastive_losses = []
#     g_losses = []
#     d_losses = []
    
#     print("Starting training...")
    
#     for epoch in range(num_epochs):
#         contrastive_loss_total = 0
#         g_loss_total = 0
#         d_loss_total = 0
        
#         pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')
        
#         for batch_idx, (eeg_batch, real_images, negative_images, class_labels) in enumerate(pbar):
#             batch_size_actual = eeg_batch.size(0)
            
#             eeg_batch = eeg_batch.to(device)
#             real_images = real_images.to(device)
#             negative_images = negative_images.to(device)
#             class_labels = class_labels.squeeze().to(device)
            
#             # =============== Contrastive Learning Phase ===============
#             contrastive_optimizer.zero_grad()
            
#             _, eeg_contrastive_features = eeg_encoder(eeg_batch, return_contrastive=True)
#             img_contrastive_features = image_encoder(real_images)
#             negative_img_features = image_encoder(negative_images)
            
#             info_nce = info_nce_loss(eeg_contrastive_features, img_contrastive_features)
#             triplet = triplet_loss(eeg_contrastive_features, img_contrastive_features, negative_img_features)
            
#             contrastive_loss = info_nce + 0.5 * triplet
#             contrastive_loss.backward()
#             contrastive_optimizer.step()
            
#             # =============== GAN Training Phase ===============
#             noise = torch.randn(batch_size_actual, 100).to(device)
            
#             d_optimizer.zero_grad()
#             eeg_features = eeg_encoder(eeg_batch).detach()
            
#             fake_images = generator(eeg_features, noise, class_labels)
            
#             real_images_aug = diff_augment(real_images)
#             fake_images_aug = diff_augment(fake_images.detach())
            
#             real_output = discriminator(real_images_aug, class_labels)
#             fake_output = discriminator(fake_images_aug, class_labels)
            
#             d_loss = discriminator_hinge_loss(real_output, fake_output)
#             d_loss.backward()
#             d_optimizer.step()
            
#             g_optimizer.zero_grad()
            
#             noise2 = torch.randn(batch_size_actual, 100).to(device)
#             eeg_features = eeg_encoder(eeg_batch)
#             fake_images = generator(eeg_features, noise, class_labels)
#             fake_images2 = generator(eeg_features, noise2, class_labels)
            
#             fake_images_aug = diff_augment(fake_images)
#             fake_images2_aug = diff_augment(fake_images2)
            
#             fake_output = discriminator(fake_images_aug, class_labels)
#             fake_output2 = discriminator(fake_images2_aug, class_labels)
            
#             g_loss = generator_hinge_loss(fake_output) + generator_hinge_loss(fake_output2)
            
#             mode_loss = torch.mean(torch.abs(fake_images2 - fake_images)) / (
#                 torch.mean(torch.abs(noise2 - noise)) + 1e-5)
#             mode_loss = 1.0 / (mode_loss + 1e-5)
            
#             with torch.no_grad():
#                 real_img_features = image_encoder(real_images)
#             fake_img_features = image_encoder(fake_images)
#             alignment_loss = F.mse_loss(fake_img_features, real_img_features)
            
#             total_g_loss = g_loss + 1.0 * mode_loss + 0.1 * alignment_loss
#             total_g_loss.backward()
#             g_optimizer.step()
            
#             contrastive_loss_total += contrastive_loss.item()
#             g_loss_total += total_g_loss.item()
#             d_loss_total += d_loss.item()
            
#             pbar.set_postfix({
#                 'Cont_Loss': f'{contrastive_loss_total/(batch_idx+1):.4f}',
#                 'G_Loss': f'{g_loss_total/(batch_idx+1):.4f}',
#                 'D_Loss': f'{d_loss_total/(batch_idx+1):.4f}'
#             })
        
#         if (epoch + 1) % 10 == 0:
#             with torch.no_grad():
#                 sample_eeg = eeg_batch[:8]
#                 sample_noise = torch.randn(8, 100).to(device)
#                 sample_classes = class_labels[:8]
#                 sample_features = eeg_encoder(sample_eeg)
#                 sample_images = generator(sample_features, sample_noise, sample_classes)
                
#                 plt.figure(figsize=(12, 8))
#                 for i in range(8):
#                     plt.subplot(2, 4, i+1)
#                     img = sample_images[i].cpu()
#                     img = (img + 1) / 2 
#                     img = torch.clamp(img, 0, 1)
#                     plt.imshow(img.permute(1, 2, 0))
#                     plt.axis('off')
#                     plt.title(f'Class: {sample_classes[i].item()}')
                
#                 plt.tight_layout()
#                 plt.savefig(f'generated_samples_epoch_{epoch+1}.png')
#                 plt.close()
        
#         contrastive_losses.append(contrastive_loss_total/len(dataloader))
#         g_losses.append(g_loss_total/len(dataloader))
#         d_losses.append(d_loss_total/len(dataloader))
        
#         print(f'Epoch {epoch+1}: Contrastive_Loss = {contrastive_loss_total/len(dataloader):.4f}, '
#               f'G_Loss = {g_loss_total/len(dataloader):.4f}, '
#               f'D_Loss = {d_loss_total/len(dataloader):.4f}')
    
#     plot_training_history_with_contrastive(contrastive_losses, g_losses, d_losses)
    
#     return eeg_encoder, image_encoder, generator, discriminator

In [None]:
def visualize_samples(eeg_encoder, generator, dataset, num_samples=8, save_path='mnist_generated_samples.png'):
    """Plot EEG traces, real digits and generated digits for random dataset items.

    Picks ``num_samples`` random indices (capped at ``len(dataset)``) and draws a
    3-row figure: the EEG signal averaged over dim 0, the paired real image, and
    the image generated from that EEG with fresh noise. Saves the figure to
    ``save_path`` and shows it.

    Both models are put in eval mode for generation and then restored to the
    train/eval mode they were in before the call.

    Args:
        eeg_encoder: module called as ``eeg_encoder(eeg_batch)``.
        generator: module called as ``generator(features, noise, class_batch)``.
        dataset: indexable; items unpack as ``(eeg, image, _, class_label)``.
        num_samples: number of columns (examples) to plot.
        save_path: output image path.

    Raises:
        ValueError: if the dataset is empty or ``num_samples`` < 1.
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples < 1:
        raise ValueError("visualize_samples needs a non-empty dataset and num_samples >= 1")

    was_training = (eeg_encoder.training, generator.training)
    eeg_encoder.eval()
    generator.eval()

    indices = random.sample(range(len(dataset)), num_samples)

    # squeeze=False keeps `axes` 2-D so axes[row, col] also works for num_samples == 1.
    fig, axes = plt.subplots(3, num_samples, figsize=(20, 12), squeeze=False)
    fig.suptitle('EEG-to-MNIST Generation Results', fontsize=16, fontweight='bold')

    try:
        with torch.no_grad():
            for i, idx in enumerate(indices):
                eeg_sample, real_image, _, class_label = dataset[idx]

                eeg_batch = eeg_sample.unsqueeze(0).to(device)
                # class_label is used directly as the class batch, so it is assumed
                # to already have shape (1,) — TODO confirm against the dataset.
                class_batch = class_label.to(device)
                noise = torch.randn(1, 100, device=device)

                eeg_features = eeg_encoder(eeg_batch)
                generated_image = generator(eeg_features, noise, class_batch)

                # Assumes images are normalised to [-1, 1]; map to [0, 1] for display.
                real_img = torch.clamp((real_image.squeeze() + 1) / 2, 0, 1)
                gen_img = torch.clamp((generated_image.squeeze().cpu() + 1) / 2, 0, 1)

                # Average over dim 0 (presumably EEG channels) to get one trace.
                eeg_plot = eeg_sample.mean(dim=0).squeeze().numpy()

                axes[0, i].plot(eeg_plot)
                axes[0, i].set_title(f'EEG Signal\nClass: {class_label.item()}', fontsize=10)
                axes[0, i].set_xlabel('Time')
                axes[0, i].set_ylabel('Amplitude')
                axes[0, i].grid(True, alpha=0.3)

                axes[1, i].imshow(real_img, cmap='gray')
                axes[1, i].set_title('Real MNIST', fontsize=10)
                axes[1, i].axis('off')

                axes[2, i].imshow(gen_img, cmap='gray')
                axes[2, i].set_title('Generated', fontsize=10)
                axes[2, i].axis('off')
    finally:
        # Do not leave the caller's models stuck in eval mode.
        eeg_encoder.train(was_training[0])
        generator.train(was_training[1])

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"MNIST visualization saved as {save_path}")

In [None]:
# def visualize_generated_samples(eeg_encoder, generator, dataset, num_samples=8, save_path='generated_samples.png'):
#     eeg_encoder.eval()
#     generator.eval()
    
#     indices = random.sample(range(len(dataset)), num_samples)
    
#     fig, axes = plt.subplots(3, num_samples, figsize=(20, 12))
#     fig.suptitle('EEG-to-Image Generation Results with Contrastive Learning', fontsize=16, fontweight='bold')
    
#     with torch.no_grad():
#         for i, idx in enumerate(indices):
#             eeg_sample, real_image, _, class_label = dataset[idx]
            
#             eeg_batch = eeg_sample.unsqueeze(0).to(device)
#             class_batch = class_label.to(device)
#             noise = torch.randn(1, 100).to(device)
            
#             eeg_features = eeg_encoder(eeg_batch)
#             generated_image = generator(eeg_features, noise, class_batch)
            
#             real_img = (real_image + 1) / 2 
#             real_img = torch.clamp(real_img, 0, 1)
            
#             gen_img = (generated_image.squeeze().cpu() + 1) / 2
#             gen_img = torch.clamp(gen_img, 0, 1)
            
#             eeg_plot = eeg_sample.mean(dim=0).mean(dim=0).numpy()  
            
#             axes[0, i].plot(eeg_plot)
#             axes[0, i].set_title(f'EEG Signal\nClass: {class_label.item()}', fontsize=10)
#             axes[0, i].set_xlabel('Time')
#             axes[0, i].set_ylabel('Amplitude')
#             axes[0, i].grid(True, alpha=0.3)
            
#             axes[1, i].imshow(real_img.permute(1, 2, 0))
#             axes[1, i].set_title('Real Image', fontsize=10)
#             axes[1, i].axis('off')
            
#             axes[2, i].imshow(gen_img.permute(1, 2, 0))
#             axes[2, i].set_title('Generated Image', fontsize=10)
#             axes[2, i].axis('off')
    
#     plt.tight_layout()
#     plt.savefig(save_path, dpi=300, bbox_inches='tight')
#     plt.show()
    
#     print(f"Visualization saved as {save_path}")

In [None]:
def plot_training_history_with_contrastive(contrastive_losses, g_losses, d_losses, save_path='training_history.png'):
    """Plot per-epoch contrastive, generator and discriminator losses and save the figure.

    Three panels:
      1. the contrastive loss;
      2. the generator and discriminator losses;
      3. all three smoothed with a moving average whose window is
         ``len(g_losses) // 20`` (at least 1).
    Epochs are numbered from 1. Each smoothed value is drawn at the last epoch of
    its window, so the smoothed curves line up with the raw ones instead of being
    shifted ``window - 1`` epochs to the left.

    Args:
        contrastive_losses: per-epoch contrastive loss values.
        g_losses: per-epoch generator loss values.
        d_losses: per-epoch discriminator loss values.
        save_path: output image path.
    """
    if len(g_losses) == 0:
        print("No training history to plot.")
        return

    window = max(1, len(g_losses) // 20)
    kernel = np.ones(window) / window

    def _epochs(series):
        # 1-based epoch numbers for a raw per-epoch series.
        return np.arange(1, len(series) + 1)

    def _smoothed(series):
        # 'valid' convolution drops the first window-1 epochs; pair each average
        # with the last epoch it covers.
        values = np.convolve(series, kernel, mode='valid')
        return np.arange(window, window + len(values)), values

    fig, (ax_cont, ax_gan, ax_smooth) = plt.subplots(1, 3, figsize=(15, 5))

    ax_cont.plot(_epochs(contrastive_losses), contrastive_losses, label='Contrastive Loss', color='green')
    ax_cont.set_title('Contrastive Learning Loss')

    ax_gan.plot(_epochs(g_losses), g_losses, label='Generator Loss', color='blue')
    ax_gan.plot(_epochs(d_losses), d_losses, label='Discriminator Loss', color='red')
    ax_gan.set_title('GAN Training Losses')

    for series, label, color in (
        (contrastive_losses, 'Contrastive Loss (Smoothed)', 'green'),
        (g_losses, 'Generator Loss (Smoothed)', 'blue'),
        (d_losses, 'Discriminator Loss (Smoothed)', 'red'),
    ):
        ax_smooth.plot(*_smoothed(series), label=label, color=color)
    ax_smooth.set_title('Smoothed Training Losses')

    for ax in (ax_cont, ax_gan, ax_smooth):
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Training history saved as {save_path}")

In [None]:
def evaluate_contrastive_alignment(eeg_encoder, image_encoder, dataset, num_samples=100,
                                   save_path='similarity_distribution.png'):
    """Measure cosine similarity between paired EEG and image contrastive embeddings.

    Uses the first ``min(num_samples, len(dataset))`` items (not a random subset).
    Prints mean ± std, saves a histogram to ``save_path`` and shows it. Both
    encoders are put in eval mode for the pass and then restored to the
    train/eval mode they were in before the call.

    Args:
        eeg_encoder: module called as ``eeg_encoder(x, return_contrastive=True)``;
            the second returned value is used as the EEG embedding.
        image_encoder: module mapping an image batch to embeddings.
        dataset: indexable; items unpack as ``(eeg, image, _, _)``.
        num_samples: maximum number of items to evaluate.
        save_path: output path for the histogram.

    Returns:
        ``(avg_similarity, similarities)``: the mean cosine similarity and the list
        of per-item similarities.

    Raises:
        ValueError: if no items are evaluated (empty dataset or ``num_samples`` < 1).
    """
    was_training = (eeg_encoder.training, image_encoder.training)
    eeg_encoder.eval()
    image_encoder.eval()

    similarities = []
    try:
        with torch.no_grad():
            for i in range(min(num_samples, len(dataset))):
                eeg_sample, real_image, _, _ = dataset[i]

                eeg_batch = eeg_sample.unsqueeze(0).to(device)
                img_batch = real_image.unsqueeze(0).to(device)

                _, eeg_features = eeg_encoder(eeg_batch, return_contrastive=True)
                img_features = image_encoder(img_batch)

                similarity = F.cosine_similarity(eeg_features, img_features, dim=1)
                similarities.append(similarity.item())
    finally:
        eeg_encoder.train(was_training[0])
        image_encoder.train(was_training[1])

    if not similarities:
        raise ValueError("evaluate_contrastive_alignment: no samples were evaluated")

    avg_similarity = np.mean(similarities)
    std_similarity = np.std(similarities)

    print(f"Average EEG-Image similarity: {avg_similarity:.4f} ± {std_similarity:.4f}")

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(similarities, bins=20, alpha=0.7, edgecolor='black')
    ax.axvline(avg_similarity, color='red', linestyle='--',
               label=f'Mean: {avg_similarity:.4f}')
    ax.set_xlabel('Cosine Similarity')
    ax.set_ylabel('Frequency')
    ax.set_title('Distribution of EEG-Image Feature Similarities')
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    return avg_similarity, similarities

In [None]:
import os

# ThoughtViz digit EEG pickle. Defaults to the Kaggle input path; set the
# EEG_PICKLE_PATH environment variable to run elsewhere.
eeg_pickle_path = os.environ.get(
    "EEG_PICKLE_PATH", "/kaggle/input/thoughtviz/data/data/eeg/digit/data.pkl"
)
# images_dir = "/kaggle/input/thoughtviz/images/images/ImageNet-Filtered"  

In [None]:
# Seed Python's `random` (used by the augmentations and visualize_samples), NumPy and
# torch so a fresh Restart & Run All reproduces the run. torch.manual_seed also seeds
# CUDA; cuDNN kernels may still be non-deterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

eeg_encoder, image_encoder, generator, discriminator = train(
    eeg_pickle_path=eeg_pickle_path,
    num_epochs=50,
    batch_size=64,
    lr=0.0002
)
# eeg_encoder, image_encoder, generator, discriminator = train2(
#         eeg_pickle_path=eeg_pickle_path,
#         images_dir=images_dir,
#         num_epochs=10,
#         batch_size=32,
#         lr=0.0002
#     )

In [None]:
# train() returns only the models, not its dataset, so build a fresh instance for
# visualization. NOTE: this is the training split (train=True), so the plots show
# training examples rather than held-out ones.
dataset = EEGMNISTDataset(eeg_pickle_path, train=True)
# dataset = EEGImageDataset(eeg_pickle_path, images_dir)

In [None]:
# Plot 8 random items from `dataset`: EEG trace, real digit and generated digit.
visualize_samples(
    eeg_encoder=eeg_encoder,
    generator=generator,
    dataset=dataset,
    num_samples=8,
)

In [None]:
# eeg_file_path = "/kaggle/input/dongyangli-deleeg-image-decode/sub-01/sub-01/preprocessed_eeg_training.npy"
# images_dir = "/kaggle/input/dongyangli-deleeg-image-decode/osfstorage-archive/training_images/training_images"

# print("Training EEG-to-Image GAN with Contrastive Learning...")
# print("="*60)

# eeg_encoder, image_encoder, generator, discriminator = train_gan_with_contrastive(
#     eeg_file_path=eeg_file_path,
#     images_dir=images_dir,
#     num_epochs=80,  
#     batch_size=64, 
#     lr=0.0002
# )

# torch.save({
#     'eeg_encoder': eeg_encoder.state_dict(),
#     'image_encoder': image_encoder.state_dict(),
#     'generator': generator.state_dict(),
#     'discriminator': discriminator.state_dict()
# }, 'eeg_to_image_gan_contrastive.pth')

# print("\nTraining completed and models saved!")

# dataset = EEGImageDataset(eeg_file_path, images_dir)

# print("\nEvaluating EEG-Image feature alignment...")
# avg_similarity, similarities = evaluate_contrastive_alignment(
#     eeg_encoder, image_encoder, dataset, num_samples=200
# )

# print("\nGenerating visualizations...")
# visualize_samples(eeg_encoder, generator, dataset, num_samples=8)

# print("All tasks completed successfully!")
# print(f"Average feature alignment score: {avg_similarity:.4f}")
