## Imports

In [None]:
!pip install torch torchvision torchaudio transformers tqdm




In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import pandas as pd
from tqdm import tqdm
import math
from pathlib import Path

In [None]:
# Colab only: mount Google Drive so the paths under /content/drive used below resolve.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Project root on the mounted Drive.
BASE_DIR = Path("/content/drive/MyDrive/ProjectLabTMIT")
# Index table; later cells read its "class", "eeg_path" and "clip_emb_path" columns.
df_index = pd.read_csv(BASE_DIR / "df_index_with_clip.csv")

In [None]:
# Preview the first rows; a bare expression uses the notebook's rich table display
# instead of the wrapped plain-text dump produced by print().
df_index.head()


          base_id      class  \
0  n02510455_4616  n02510455   
1  n02510455_4616  n02510455   
2  n02510455_4616  n02510455   
3  n02510455_4616  n02510455   
4  n02510455_4616  n02510455   

                                            eeg_path  \
0  /content/drive/MyDrive/capstone/images/n025104...   
1  /content/drive/MyDrive/capstone/images/n025104...   
2  /content/drive/MyDrive/capstone/images/n025104...   
3  /content/drive/MyDrive/capstone/images/n025104...   
4  /content/drive/MyDrive/capstone/images/n025104...   

                                          image_path  \
0  /content/drive/MyDrive/capstone/images/n025104...   
1  /content/drive/MyDrive/capstone/images/n025104...   
2  /content/drive/MyDrive/capstone/images/n025104...   
3  /content/drive/MyDrive/capstone/images/n025104...   
4  /content/drive/MyDrive/capstone/images/n025104...   

                                        caption_path  \
0  /content/drive/MyDrive/capstone/images/n025104...   
1  /content/drive/MyD

## Dataset Module



In [None]:

class EEGDatasetV1(Dataset):
    """EEG-image dataset for the baseline encoder (224x224 RGB, mean/std 0.5).

    Each item is ``(eeg, clip_emb, label)``:

    * ``eeg``      - image at ``row.eeg_path`` converted to RGB, resized to
      224x224, passed through ``ToTensor`` and ``Normalize(mean=0.5, std=0.5)``.
    * ``clip_emb`` - array loaded from ``row.clip_emb_path`` as a float32 tensor.
    * ``label``    - integer id of ``row["class"]`` as a ``torch.long`` scalar.

    Args:
        df: Index frame with ``eeg_path``, ``clip_emb_path`` and ``class`` columns.
        labels: Optional ordered class names that define the label ids. Pass the
            same list to every split so that ids agree between train and
            validation. Defaults to the sorted unique classes found in ``df``.

    Raises:
        ValueError: If ``df`` contains a class that is not listed in ``labels``.
    """

    def __init__(self, df, labels=None):
        self.df = df.reset_index(drop=True)
        if labels is None:
            self.labels = sorted(df["class"].unique())
        else:
            self.labels = list(labels)
        self.label2id = {c: i for i, c in enumerate(self.labels)}

        # Fail early instead of raising a KeyError in the middle of an epoch.
        unknown = set(self.df["class"].unique()) - set(self.label2id)
        if unknown:
            raise ValueError(
                f"classes missing from labels: {sorted(map(str, unknown))}"
            )

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # The context manager closes the file handle; convert() returns a
        # fully loaded copy, so the image stays usable afterwards.
        with Image.open(row.eeg_path) as img:
            eeg_img = img.convert("RGB")
        eeg = self.transform(eeg_img)

        clip_emb = torch.tensor(np.load(row.clip_emb_path), dtype=torch.float32)
        label = torch.tensor(self.label2id[row["class"]], dtype=torch.long)

        return eeg, clip_emb, label


In [None]:

class EEGDatasetV2(Dataset):
    """EEG-image dataset yielding 224x224 RGB tensors without normalisation.

    Each item is ``(eeg, clip_emb, label)``:

    * ``eeg``      - image at ``row.eeg_path`` converted to RGB, resized to
      224x224 and passed through ``ToTensor`` (no ``Normalize`` step).
    * ``clip_emb`` - array loaded from ``row.clip_emb_path`` as a float32 tensor.
    * ``label``    - integer id of ``row["class"]`` as a ``torch.long`` scalar.

    Args:
        df: Index frame with ``eeg_path``, ``clip_emb_path`` and ``class`` columns.
        labels: Optional ordered class names that define the label ids. Pass the
            same list to every split so that ids agree between train and
            validation. Defaults to the sorted unique classes found in ``df``.

    Raises:
        ValueError: If ``df`` contains a class that is not listed in ``labels``.
    """

    def __init__(self, df, labels=None):
        self.df = df.reset_index(drop=True)
        if labels is None:
            self.labels = sorted(df["class"].unique())
        else:
            self.labels = list(labels)
        self.label2id = {c: i for i, c in enumerate(self.labels)}

        # Fail early instead of raising a KeyError in the middle of an epoch.
        unknown = set(self.df["class"].unique()) - set(self.label2id)
        if unknown:
            raise ValueError(
                f"classes missing from labels: {sorted(map(str, unknown))}"
            )

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # The context manager closes the file handle; convert() returns a
        # fully loaded copy, so the image stays usable afterwards.
        with Image.open(row.eeg_path) as img:
            eeg_img = img.convert("RGB")
        eeg = self.transform(eeg_img)

        clip_emb = torch.tensor(np.load(row.clip_emb_path), dtype=torch.float32)
        label = torch.tensor(self.label2id[row["class"]], dtype=torch.long)

        return eeg, clip_emb, label


In [None]:
class EEGDatasetV3(Dataset):
    """EEG-image dataset yielding single-channel tensors for ``EEGEncoderV3``.

    Each item is ``(eeg_tensor, clip_emb, label)``:

    * ``eeg_tensor`` - image at ``row.eeg_path`` converted to greyscale ("L"),
      resized to ``size``, divided by 255 and returned as a float32 tensor of
      shape ``(1, height, width)``.
    * ``clip_emb``   - array loaded from ``row.clip_emb_path`` as float32.
    * ``label``      - integer id of ``row["class"]`` as a ``torch.long`` scalar.

    Args:
        df: Index frame with ``eeg_path``, ``clip_emb_path`` and ``class`` columns.
        labels: Optional ordered class names that define the label ids. Pass the
            same list to every split so that ids agree between train and
            validation. Defaults to the sorted unique classes found in ``df``.
        size: ``(width, height)`` passed to ``Image.resize``; default (440, 128).

    Raises:
        ValueError: If ``df`` contains a class that is not listed in ``labels``.
    """

    def __init__(self, df, labels=None, size=(440, 128)):
        self.df = df.reset_index(drop=True)
        if labels is None:
            self.labels = sorted(df["class"].unique())
        else:
            self.labels = list(labels)
        self.label2id = {c: i for i, c in enumerate(self.labels)}
        self.size = tuple(size)   # (width, height)

        # Fail early instead of raising a KeyError in the middle of an epoch.
        unknown = set(self.df["class"].unique()) - set(self.label2id)
        if unknown:
            raise ValueError(
                f"classes missing from labels: {sorted(map(str, unknown))}"
            )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # The context manager closes the file handle; convert()/resize() return
        # loaded copies, so the image stays usable afterwards.
        with Image.open(row.eeg_path) as img:
            eeg = img.convert("L").resize(self.size)

        eeg = np.array(eeg).astype(np.float32) / 255.0

        eeg = eeg[np.newaxis, :, :]  # (1, H, W)

        eeg_tensor = torch.tensor(eeg, dtype=torch.float32)

        clip_emb = torch.tensor(np.load(row.clip_emb_path), dtype=torch.float32)
        label = torch.tensor(self.label2id[row["class"]], dtype=torch.long)

        return eeg_tensor, clip_emb, label


## Encoder Module

### Baseline V1

In [None]:

class EEGEncoderV1(nn.Module):
    """Baseline CNN encoder: two conv/ReLU/max-pool stages, then linear heads.

    The projection layer is sized for a 64x56x56 feature map, i.e. a 3-channel
    224x224 input halved twice by the pooling layers.

    Args:
        emb_dim: Size of the output embedding.
        num_classes: Number of classification logits.

    ``forward`` returns ``(embedding, logits)`` where the embedding is
    L2-normalised along the last dimension and the logits are computed from
    that normalised embedding.
    """

    def __init__(self, emb_dim=512, num_classes=40):
        super().__init__()

        # (in_channels, out_channels) per stage; each stage halves H and W.
        stages = []
        for c_in, c_out in ((3, 32), (32, 64)):
            stages.extend([
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
        self.features = nn.Sequential(*stages)

        flat_features = 64 * 56 * 56
        self.fc_emb = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, emb_dim),
        )

        self.classifier = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        embedding = F.normalize(self.fc_emb(self.features(x)), dim=-1)
        logits = self.classifier(embedding)
        return embedding, logits


### Improved Version - V2

In [None]:
class ConvLayer2D(nn.Sequential):
    """Pre-activation conv unit: BatchNorm2d -> ReLU -> Conv2d (with bias).

    Sub-modules are registered under the names ``bn``, ``relu`` and ``conv``.
    The batch norm operates on ``in_channels``; ``kernel``, ``stride``,
    ``padding`` and ``dilation`` are forwarded to ``nn.Conv2d`` unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel, stride, padding, dilation):
        super().__init__()
        named_steps = (
            ("bn", nn.BatchNorm2d(in_channels)),
            ("relu", nn.ReLU(inplace=True)),
            ("conv", nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel,
                stride=stride,
                padding=padding,
                dilation=dilation,
                bias=True,
            )),
        )
        for step_name, step in named_steps:
            self.add_module(step_name, step)


In [None]:
class TemporalBlock(nn.Module):
    """Parallel dilated convolutions along the last (width) axis.

    One ``ConvLayer2D`` branch is built per entry of ``dilation_list``; each
    branch pads the width by ``floor((kernel[1] * dilation - 1) / 2)`` and does
    not pad the height. ``forward`` runs every branch on the same input, crops
    all outputs to the narrowest width and concatenates them on the channel
    axis, giving ``len(dilation_list) * out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, dilation_list, kernel, stride):
        super().__init__()
        branches = [
            ConvLayer2D(
                in_channels, out_channels,
                kernel, stride,
                padding=(0, math.floor((kernel[1] * rate - 1) / 2)),
                dilation=(1, rate),
            )
            for rate in dilation_list
        ]
        self.layers = nn.ModuleList(branches)

    def forward(self, x):
        outputs = []
        for branch in self.layers:
            outputs.append(branch(x))
        # Branch widths can differ by one sample; align to the narrowest.
        width = min(t.shape[-1] for t in outputs)
        cropped = [t[..., :width] for t in outputs]
        return torch.cat(cropped, dim=1)


In [None]:
class SpatialBlock(nn.Module):
    """Parallel convolutions along the height axis with kernels 32, 16, 8, 4.

    One ``ConvLayer2D`` branch is built per kernel height ``k`` (kernel
    ``(k, 1)``, stride 1, height padding ``k // 2``). ``forward`` runs every
    branch on the same input, crops the outputs to the smallest height and
    width and concatenates them on the channel axis, giving
    ``n_layers * out_channels`` channels.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by each branch.
        height: Unused; kept so existing callers keep working.
        n_layers: Number of branches, taken from the front of
            ``[32, 16, 8, 4]``. The default of 4 uses all kernels.

    Raises:
        ValueError: If ``n_layers`` is not between 1 and 4.
    """

    def __init__(self, in_channels, out_channels, height, n_layers=4):
        super().__init__()

        kernel_sizes = [32, 16, 8, 4]
        # n_layers used to be accepted but ignored; honour it explicitly.
        if not 1 <= n_layers <= len(kernel_sizes):
            raise ValueError(
                f"n_layers must be between 1 and {len(kernel_sizes)}, got {n_layers}"
            )

        layers = [
            ConvLayer2D(
                in_channels, out_channels,
                kernel=(k_h, 1),
                stride=(1, 1),
                padding=(k_h // 2, 0),
                dilation=1
            )
            for k_h in kernel_sizes[:n_layers]
        ]

        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        feats = [layer(x) for layer in self.layers]
        min_h = min(f.shape[-2] for f in feats)
        min_w = min(f.shape[-1] for f in feats)
        feats = [f[..., :min_h, :min_w] for f in feats]
        return torch.cat(feats, dim=1)


In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an identity skip connection.

    Channel count and spatial size are preserved (``padding=1``, no bias on the
    convolutions). The skip is added before the final ReLU.
    """

    def __init__(self, ch):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(ch, ch, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(ch)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(ch, ch, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(ch)

    def forward(self, x):
        branch = self.conv1(x)
        branch = self.bn1(branch)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        branch = self.bn2(branch)
        # The sum is a new tensor, so the in-place ReLU never touches `x`.
        merged = branch + x
        return self.relu(merged)


### ChannelNet encoder (V3)



In [None]:
class EEGEncoderV3(nn.Module):
    """ChannelNet-style EEG encoder producing an embedding and class logits.

    Pipeline: ``TemporalBlock`` (5 dilations x 10 channels) -> ``SpatialBlock``
    (4 kernels x 50 channels) -> 4 ``ResidualBlock``s on 200 channels ->
    stride-2 ``ConvLayer2D`` down to 50 channels -> flatten -> two-layer MLP ->
    L2-normalised embedding -> linear classifier.

    Args:
        embedding_dim: Size of the output embedding.
        num_classes: Number of classification logits.
        in_channels: Channels of the input tensor.
        height: Input height used to size the first linear layer.
        width: Input width used to size the first linear layer.

    ``forward`` returns ``(emb, cls)``; ``emb`` is L2-normalised along the last
    dimension and ``cls`` is computed from the normalised embedding.

    NOTE(review): the first linear layer is sized from the flattened conv
    output, which looks very large for the default 128x440 input - confirm
    the resulting parameter count is intended.
    """

    def __init__(self, embedding_dim=512, num_classes=40,
                 in_channels=1, height=128, width=440):
        super().__init__()

        self.temp = TemporalBlock(
            in_channels=in_channels,
            out_channels=10,
            dilation_list=[1, 2, 4, 8, 16],
            kernel=(1, 33),
            stride=(1, 2)
        )

        self.spatial = SpatialBlock(
            in_channels=10 * 5,
            out_channels=50,
            height=height
        )

        res_in = 50 * 4
        self.res_blocks = nn.ModuleList([ResidualBlock(res_in) for _ in range(4)])

        self.down = ConvLayer2D(
            res_in, 50,
            kernel=3,
            stride=2,
            padding=1,
            dilation=1
        )

        flat_dim = self._infer_flat_dim(in_channels, height, width)

        self.embedding_proj = nn.Sequential(
            nn.Linear(flat_dim, 1024),
            nn.GELU(),
            nn.Linear(1024, embedding_dim)
        )

        self.classifier = nn.Linear(embedding_dim, num_classes)

    def _features(self, x):
        """Run the convolutional trunk and return the un-flattened feature map."""
        h = self.temp(x)
        h = self.spatial(h)
        for rb in self.res_blocks:
            h = rb(h)
        return self.down(h)

    def _infer_flat_dim(self, in_channels, height, width):
        """Return the flattened trunk output size for one input of the given shape.

        The probe runs in eval mode so the BatchNorm layers do not fold the
        all-zero dummy batch into their running statistics; the previous
        training flag is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, in_channels, height, width)
                flat_dim = self._features(probe).flatten(1).size(1)
        finally:
            self.train(was_training)
        return flat_dim

    def forward(self, x):
        h = self._features(x).flatten(1)
        emb = self.embedding_proj(h)
        emb = F.normalize(emb, dim=-1)
        cls = self.classifier(emb)
        return emb, cls


## Training module

### Training modules for V1, V2, and V3

In [None]:
def train_with_mse(model, dataloader, lr=1e-4, epochs=10, mse_w=0.5):
    """Train ``model`` with a weighted sum of embedding MSE and cross-entropy.

    The loss is ``mse_w * MSE(emb_pred, clip_emb) + (1 - mse_w) * CE(cls_pred,
    labels)``. The mean loss per batch is printed after every epoch.

    Args:
        model: Module whose forward returns ``(embedding, logits)``.
        dataloader: Iterable of ``(eeg, clip_emb, labels)`` batches with a length.
        lr: Learning rate for AdamW.
        epochs: Number of passes over ``dataloader``.
        mse_w: Weight of the MSE term; the cross-entropy term gets ``1 - mse_w``.

    The model is moved to CUDA when available, otherwise it trains on CPU.
    Returns ``None``; the model is updated in place.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    mse = nn.MSELoss()
    ce = nn.CrossEntropyLoss()
    # Guard the per-epoch average against an empty dataloader.
    n_batches = max(len(dataloader), 1)

    for ep in range(epochs):
        model.train()
        total = 0
        for eeg, clip_emb, labels in dataloader:
            eeg, clip_emb, labels = eeg.to(device), clip_emb.to(device), labels.to(device)
            opt.zero_grad()

            emb_pred, cls_pred = model(eeg)
            # mse_w weights the MSE term (the two weights used to be swapped).
            loss = mse_w * mse(emb_pred, clip_emb) + (1 - mse_w) * ce(cls_pred, labels)
            loss.backward()
            opt.step()

            total += loss.item()

        print(f"[MSE] Epoch {ep+1} Loss = {total/n_batches:.4f}")


In [None]:
def cosine_loss(pred, target):
    """Return ``1 - mean cosine similarity`` between ``pred`` and ``target``.

    Both inputs are L2-normalised along the last dimension, multiplied
    element-wise and summed over that dimension; the result is averaged over
    all remaining dimensions.
    """
    unit_pred = F.normalize(pred, dim=-1)
    unit_target = F.normalize(target, dim=-1)
    similarity = (unit_pred * unit_target).sum(dim=-1)
    return 1 - similarity.mean()

In [None]:
def train_eeg_encoder(model, dataloader, lr=1e-4, epochs=20, ce_weight=0.05):
    """Train ``model`` with cosine embedding loss plus weighted cross-entropy.

    The loss is ``cosine_loss(emb_pred, clip_emb) + ce_weight * CE(cls_pred,
    labels)``. The mean loss per batch is printed after every epoch.

    Args:
        model: Module whose forward returns ``(embedding, logits)``.
        dataloader: Iterable of ``(eeg, clip_emb, labels)`` batches with a length.
        lr: Learning rate for AdamW.
        epochs: Number of passes over ``dataloader``.
        ce_weight: Multiplier applied to the cross-entropy term.

    The model is moved to CUDA when available, otherwise it trains on CPU.
    Returns ``None``; the model is updated in place.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    ce = nn.CrossEntropyLoss()
    # Guard the per-epoch average against an empty dataloader.
    n_batches = max(len(dataloader), 1)

    for ep in range(epochs):
        total = 0
        model.train()

        for eeg, clip_emb, labels in dataloader:
            eeg = eeg.to(device)
            clip_emb = clip_emb.to(device)
            labels = labels.to(device)

            opt.zero_grad()

            emb_pred, cls_pred = model(eeg)
            loss_emb = cosine_loss(emb_pred, clip_emb)
            loss_cls = ce(cls_pred, labels)
            loss = loss_emb + ce_weight * loss_cls

            loss.backward()
            opt.step()

            total += loss.item()

        print(f"[COSINE] Epoch {ep+1}: Loss = {total/n_batches:.4f}")

In [None]:
def train_eeg_v3_optimized(
    model,
    dataloader,
    lr=1e-4,
    epochs=20,
    ce_weight=0.05,
    save_dir="/content/drive/MyDrive/ProjectLabTMIT/checkpoints_v3"
):
    """Train ``model`` with mixed precision, gradient clipping and checkpoints.

    The loss is ``cosine_loss(emb_pred, clip_emb) + ce_weight * CE(cls_pred,
    labels)``. AdamW (weight decay 1e-4) is paired with a cosine-annealing
    schedule stepped once per epoch; gradients are clipped to norm 1.0. The
    model's ``state_dict`` is saved to ``save_dir`` after every epoch as
    ``eeg_encoder_v3_epoch_XX.pt``.

    Args:
        model: Module whose forward returns ``(embedding, logits)``.
        dataloader: Iterable of ``(eeg, clip_emb, labels)`` batches.
        lr: Initial learning rate.
        epochs: Number of epochs (also ``T_max`` of the scheduler).
        ce_weight: Multiplier applied to the cross-entropy term.
        save_dir: Directory for checkpoints; created if missing.

    Returns:
        The trained model (same object, moved to the training device).

    Mixed precision is enabled only when CUDA is available; on CPU the loop
    runs in full precision with the gradient scaler disabled.
    """

    os.makedirs(save_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"
    model = model.to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    ce = nn.CrossEntropyLoss()

    # torch.amp replaces the deprecated torch.cuda.amp entry points.
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    for ep in range(1, epochs+1):
        model.train()
        running_loss = 0.0

        loop = tqdm(dataloader, desc=f"Epoch {ep}/{epochs}", leave=True)

        for step, (eeg, clip_emb, labels) in enumerate(loop, start=1):
            eeg = eeg.to(device)
            clip_emb = clip_emb.to(device)
            labels = labels.to(device)

            opt.zero_grad()

            # ---- MIXED-PRECISION FORWARD ----
            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                emb_pred, cls_pred = model(eeg)

                loss_emb = cosine_loss(emb_pred, clip_emb)
                loss_cls = ce(cls_pred, labels)
                loss = loss_emb + ce_weight * loss_cls

            # ---- GRADIENT SCALING ----
            scaler.scale(loss).backward()

            # ---- GRADIENT CLIPPING ----
            # Unscale first so the clip threshold applies to the true gradients.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            scaler.step(opt)
            scaler.update()

            running_loss += loss.item()
            # Average over the batches seen so far, not the epoch's total count.
            loop.set_postfix(loss=f"{running_loss/step:.4f}")

        scheduler.step()

        # ---- SAVE CHECKPOINT EACH EPOCH ----
        ckpt_path = os.path.join(save_dir, f"eeg_encoder_v3_epoch_{ep:02d}.pt")
        torch.save(model.state_dict(), ckpt_path)

        print(f"✔ Saved checkpoint: {ckpt_path}\n")

    return model

## Training scripts

### Data Split

In [None]:
from sklearn.model_selection import train_test_split
import os
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import os

In [None]:
# Work on a copy of the loaded index so `df_index` itself stays as read from disk.
df = df_index.copy()


In [None]:
# 80/20 train/validation split, stratified on "class" with a fixed seed for repeatability.
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df["class"],
    random_state=42
)

In [None]:
# Report the split sizes and the number of distinct classes in the full frame.
print("Train size:", len(train_df))
print("Val size:", len(val_df))
print("Unique classes:", df["class"].nunique())

Train size: 9572
Val size: 2393
Unique classes: 40


### Option A — Train Baseline V1

In [None]:
def evaluate(model, loader):
    """Print and return the top-1 accuracy (percent) of ``model``'s logits on ``loader``.

    ``model`` must return ``(embedding, logits)``; ``loader`` yields
    ``(eeg, clip_emb, labels)`` batches. Inputs follow the model's device.
    """
    model.eval()
    device = next(model.parameters()).device
    correct = total = 0
    with torch.no_grad():
        for eeg, _, labels in loader:
            _, logits = model(eeg.to(device))
            correct += (logits.argmax(dim=1) == labels.to(device)).sum().item()
            total += labels.size(0)
    acc = correct / total * 100
    print(f"Validation Accuracy: {acc:.2f}%  ({correct}/{total})")
    return acc


train_dataset = EEGDatasetV1(train_df)
val_dataset   = EEGDatasetV1(val_df)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=16, shuffle=False)

model = EEGEncoderV1(emb_dim=512, num_classes=len(train_dataset.labels))
train_with_mse(model, train_loader, epochs=10, mse_w=0.5)

# `evaluate` was called without being defined anywhere in the notebook; it is
# now defined at the top of this cell so the cell runs on a fresh kernel.
evaluate(model, val_loader)


### Option B — ChannelNet V3

In [None]:
# Rebuild the datasets/loaders with the single-channel V3 dataset.
# NOTE: this rebinds train_dataset, val_dataset, train_loader and val_loader from Option A.
train_dataset = EEGDatasetV3(train_df)
val_dataset   = EEGDatasetV3(val_df)

# Only the training loader is shuffled; validation keeps the frame order.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [None]:
# Build the V3 encoder for 1x128x440 inputs; the class count comes from the training split.
model = EEGEncoderV3(
    embedding_dim=512,
    num_classes=len(train_dataset.labels),
    in_channels=1,
    height=128,
    width=440
)

# Train for 10 epochs; a checkpoint is written to the function's default save_dir after each epoch.
trained_model = train_eeg_v3_optimized(
    model,
    train_loader,
    lr=1e-4,
    epochs=10,
    ce_weight=0.05
)


  scaler = GradScaler()   # for mixed precision
  with autocast():
Epoch 1/10: 100%|██████████| 599/599 [2:34:37<00:00, 15.49s/it, loss=0.3765]


✔ Saved checkpoint: /content/drive/MyDrive/ProjectLabTMIT/checkpoints_v3/eeg_encoder_v3_epoch_01.pt



Epoch 2/10: 100%|██████████| 599/599 [03:24<00:00,  2.92it/s, loss=0.3711]


✔ Saved checkpoint: /content/drive/MyDrive/ProjectLabTMIT/checkpoints_v3/eeg_encoder_v3_epoch_02.pt



Epoch 3/10: 100%|██████████| 599/599 [03:36<00:00,  2.77it/s, loss=0.3662]


✔ Saved checkpoint: /content/drive/MyDrive/ProjectLabTMIT/checkpoints_v3/eeg_encoder_v3_epoch_03.pt



Epoch 4/10: 100%|██████████| 599/599 [03:49<00:00,  2.61it/s, loss=0.3608]


✔ Saved checkpoint: /content/drive/MyDrive/ProjectLabTMIT/checkpoints_v3/eeg_encoder_v3_epoch_04.pt



Epoch 5/10: 100%|██████████| 599/599 [03:55<00:00,  2.54it/s, loss=0.3510]


✔ Saved checkpoint: /content/drive/MyDrive/ProjectLabTMIT/checkpoints_v3/eeg_encoder_v3_epoch_05.pt



Epoch 6/10: 100%|██████████| 599/599 [04:02<00:00,  2.47it/s, loss=0.3324]


✔ Saved checkpoint: /content/drive/MyDrive/ProjectLabTMIT/checkpoints_v3/eeg_encoder_v3_epoch_06.pt



Epoch 7/10: 100%|██████████| 599/599 [03:19<00:00,  3.01it/s, loss=0.3088]


✔ Saved checkpoint: /content/drive/MyDrive/ProjectLabTMIT/checkpoints_v3/eeg_encoder_v3_epoch_07.pt



Epoch 8/10: 100%|██████████| 599/599 [03:02<00:00,  3.29it/s, loss=0.2889]


✔ Saved checkpoint: /content/drive/MyDrive/ProjectLabTMIT/checkpoints_v3/eeg_encoder_v3_epoch_08.pt



Epoch 9/10: 100%|██████████| 599/599 [03:02<00:00,  3.29it/s, loss=0.2747]


✔ Saved checkpoint: /content/drive/MyDrive/ProjectLabTMIT/checkpoints_v3/eeg_encoder_v3_epoch_09.pt



Epoch 10/10: 100%|██████████| 599/599 [03:01<00:00,  3.30it/s, loss=0.2676]


✔ Saved checkpoint: /content/drive/MyDrive/ProjectLabTMIT/checkpoints_v3/eeg_encoder_v3_epoch_10.pt



In [None]:
# Shape smoke test on a fresh, untrained encoder. Separate names are used so the
# global `model` is not rebound to an untrained network, and autograd is disabled
# because only the output shapes matter here.
smoke_model = EEGEncoderV3().cuda().eval()
with torch.no_grad():
    smoke_emb, smoke_logits = smoke_model(torch.zeros(1, 1, 128, 440).cuda())
assert smoke_emb.shape == (1, 512), smoke_emb.shape
assert smoke_logits.shape == (1, 40), smoke_logits.shape
del smoke_model  # release the extra copy of the network

## Evaluation

In [None]:
def evaluate_v3(model, loader):
    """Print and return the top-1 classification accuracy (percent) on ``loader``.

    Args:
        model: Module whose forward returns ``(embedding, logits)``.
        loader: Iterable of ``(eeg, clip_emb, labels)`` batches; the middle
            element is ignored.

    Returns:
        Accuracy in percent, or ``nan`` when ``loader`` yields no samples.

    The model is switched to eval mode and left there. Inputs are moved to the
    device of the model's first parameter (CPU if it has none), so the function
    works whether or not the model lives on a GPU.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    total = 0
    correct = 0

    with torch.no_grad():
        for eeg, _, labels in loader:
            eeg = eeg.to(device)
            labels = labels.to(device)

            _, logits = model(eeg)
            preds = logits.argmax(dim=1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # An empty loader has no defined accuracy; report nan instead of crashing.
    acc = correct / total * 100 if total else float("nan")
    print(f"Validation Accuracy: {acc:.2f}%  ({correct}/{total})")
    return acc


In [None]:
# NOTE(review): despite the name, this is the final-epoch checkpoint; nothing visible
# in this notebook selects it by a validation metric.
best_ckpt = "/content/drive/MyDrive/ProjectLabTMIT/checkpoints_v3/eeg_encoder_v3_epoch_10.pt"

# Recreate the architecture with the same arguments used for training so the state dict matches.
model = EEGEncoderV3(
    embedding_dim=512,
    num_classes=len(train_dataset.labels),
    in_channels=1,
    height=128,
    width=440
)

# torch.load unpickles the file; only load checkpoints you produced yourself.
model.load_state_dict(torch.load(best_ckpt, map_location="cuda"))
model = model.cuda().eval()

print("Loaded:", best_ckpt)


Loaded: /content/drive/MyDrive/ProjectLabTMIT/checkpoints_v3/eeg_encoder_v3_epoch_10.pt


In [None]:
# Top-1 accuracy (percent) of the loaded checkpoint's classifier on the validation loader.
val_acc = evaluate_v3(model, val_loader)
val_acc


Validation Accuracy: 11.95%  (286/2393)


11.951525282072712

In [None]:
def avg_cosine_similarity(model, loader):
    """Return the mean cosine similarity between predicted and target embeddings.

    Args:
        model: Module whose forward returns ``(embedding, logits)``.
        loader: Iterable of ``(eeg, clip_emb, labels)`` batches; labels are
            ignored.

    Returns:
        The similarity averaged over every sample, or ``nan`` when ``loader``
        yields no samples.

    The average is taken per sample rather than per batch, so a smaller final
    batch no longer carries the same weight as a full one. The model is
    switched to eval mode; inputs follow the device of its first parameter
    (CPU if it has none).
    """
    cos = nn.CosineSimilarity(dim=-1)
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    sim_sum = 0.0
    n_samples = 0

    model.eval()
    with torch.no_grad():
        for eeg, clip, _ in loader:
            eeg = eeg.to(device)
            clip = clip.to(device)

            emb_pred, _ = model(eeg)

            sims = cos(emb_pred, clip)
            sim_sum += sims.sum().item()
            n_samples += sims.numel()

    return sim_sum / n_samples if n_samples else float("nan")

In [None]:

# Mean cosine similarity between the model's embeddings and the stored CLIP embeddings on validation data.
sim = avg_cosine_similarity(model, val_loader)
print("Average EEG→CLIP cosine similarity:", sim)

Average EEG→CLIP cosine similarity: 0.7726423684755961
