In [None]:
!pip install pytorch_metric_learning

Collecting pytorch_metric_learning
  Downloading pytorch_metric_learning-2.9.0-py3-none-any.whl.metadata (18 kB)
Downloading pytorch_metric_learning-2.9.0-py3-none-any.whl (127 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.8/127.8 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pytorch_metric_learning
Successfully installed pytorch_metric_learning-2.9.0


In [None]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import json, random, math, matplotlib.pyplot as plt

from torch.utils.data import TensorDataset, DataLoader, Dataset
from torch.optim import AdamW
from multiprocessing import freeze_support
from pytorch_metric_learning.losses import SupConLoss
from sklearn.metrics import confusion_matrix

class EEGEncoderCNN(nn.Module):
    """Variational CNN + self-attention encoder/classifier for EEG windows.

    Pipeline for an input of shape ``(B, 1, n_channels, T)``:
      * three 3x15 convolutions over the (electrode, time) plane, with a
        width-2 max-pool after the second and third (time axis -> T // 4);
      * a spatial convolution whose kernel spans all ``n_channels`` rows,
        collapsing the electrode axis to height 1;
      * a residual 1x3 conv + GroupNorm block;
      * 4-head self-attention over the remaining T // 4 time steps;
      * global average pooling and a 512-d projection;
      * two linear heads giving ``mu`` and ``logvar`` of a diagonal Gaussian;
        ``z`` is sampled with the reparameterisation trick and fed to
        ``classifier``.

    Args:
        n_classes: number of output classes of ``classifier``.
        time_window: nominal input length T. Not used by the layers (convs,
            attention and adaptive pooling accept any T >= 4); kept for
            interface compatibility.
        n_channels: number of EEG electrodes, i.e. the input height; the
            spatial conv kernel spans exactly this many rows.
        hidden_dim: size of ``z`` / ``mu`` / ``logvar``.
    """

    def __init__(self, n_classes, time_window=100, n_channels=62, hidden_dim=128):
        super().__init__()
        # 3x15 convs; padding (1, 7) preserves both height and width.
        self.conv1 = nn.Conv2d(1,  64,  (3,15), padding=(1,7), bias=False)
        self.conv2 = nn.Conv2d(64, 128, (3,15), padding=(1,7), bias=False)
        self.conv3 = nn.Conv2d(128,128, (3,15), padding=(1,7), bias=False)
        # Spatial conv over all electrodes: output height becomes 1.
        self.conv4 = nn.Conv2d(128,256, (n_channels,3), padding=(0,1), bias=False)
        # Each pool halves the time axis only.
        self.pool1, self.pool2 = nn.MaxPool2d((1,2)), nn.MaxPool2d((1,2))
        self.res_conv = nn.Conv2d(256,256,(1,3),padding=(0,1),bias=False)
        self.norm_res = nn.GroupNorm(32,256)

        # Self-attention across time steps, each a 256-d token.
        self.time_attn = nn.MultiheadAttention(embed_dim=256, num_heads=4, batch_first=True)

        self.adapt = nn.AdaptiveAvgPool2d((1,1))
        self.fc_pre = nn.Linear(256,512)
        self.fc_mu, self.fc_logvar = nn.Linear(512,hidden_dim), nn.Linear(512,hidden_dim)

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim,512), nn.GELU(), nn.LayerNorm(512), nn.Dropout(0.5),
            nn.Linear(512,n_classes)
        )

    def forward(self, x):
        """Encode a batch of EEG windows.

        Args:
            x: tensor of shape ``(B, 1, n_channels, T)``.

        Returns:
            ``(z, mu, logvar, logits)``: ``z``, ``mu`` and ``logvar`` are
            ``(B, hidden_dim)``; ``logits = classifier(z)`` is
            ``(B, n_classes)``. ``z`` is sampled on every call, in eval mode too.

        Raises:
            ValueError: if ``x`` is not 4-D or its height is not ``n_channels``
                (the spatial conv would not collapse it to height 1).
        """
        expected_h = self.conv4.kernel_size[0]
        if x.dim() != 4 or x.size(2) != expected_h:
            raise ValueError(
                f"expected input of shape (B, 1, {expected_h}, T), got {tuple(x.shape)}")

        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = self.pool1(x)
        x = F.gelu(self.conv3(x))
        x = self.pool2(x)
        x = F.gelu(self.conv4(x))                          # (B, 256, 1, T//4)
        x = F.gelu(self.norm_res(self.res_conv(x))) + x    # residual block
        x = x.squeeze(2).permute(0, 2, 1)                  # (B, T//4, 256)
        x, _ = self.time_attn(x, x, x)
        x = x.permute(0, 2, 1).unsqueeze(2)                # (B, 256, 1, T//4)
        h = self.adapt(x).flatten(1)                       # (B, 256)
        h = F.gelu(self.fc_pre(h))
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        # Reparameterisation trick: z ~ N(mu, exp(logvar)), differentiable in mu/logvar.
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        return z, mu, logvar, self.classifier(z)


In [None]:

# Seed every RNG the notebook draws from: Python `random` (augmentations),
# NumPy (mixup's Beta sample) and torch (init, dropout, latent sampling).
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
# Let cuDNN auto-tune conv algorithms for speed; note this can make GPU runs
# not bit-for-bit reproducible even with the seeds above.
torch.backends.cudnn.benchmark = True

def plot_confusion_matrix(cm, classes, percent,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Print a confusion matrix and draw it on the current matplotlib figure.

    Args:
        cm: square confusion matrix as produced by
            ``sklearn.metrics.confusion_matrix`` (rows: true labels,
            columns: predicted labels).
        classes: tick labels, ordered like the rows/columns of ``cm``.
        percent: unused; kept so existing callers keep working.
        normalize: if True, divide each row by its total so rows sum to 1.
            Rows with no samples stay all-zero instead of becoming NaN.
        title: figure title.
        cmap: matplotlib colormap for the image.

    Drawing uses the pyplot state machine on the current figure; the caller
    is responsible for saving and closing it.
    """
    if normalize:
        cm = np.asarray(cm, dtype=float)
        row_sums = cm.sum(axis=1, keepdims=True)
        # Classes that never occur in the ground truth would give 0/0 = NaN.
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    plt.xlabel('Predicted label')
    plt.ylabel('True label')


class RandomTimeMask:
    """Zero a random contiguous block of time samples (time-masking augmentation).

    With probability ``p`` a block of ``min(L, T - 1)`` consecutive samples
    along the last (time) axis is zeroed across all channels; otherwise the
    input is returned unchanged (same object). Any input whose last axis is
    time works, e.g. ``(1, C, T)`` per sample or ``(B, 1, C, T)`` per batch
    (for a batch the same block is masked in every element).
    """
    def __init__(self, p=0.5, L=15):
        self.p, self.L = p, L

    def __call__(self, x):
        if random.random() > self.p:
            return x
        T = x.shape[-1]
        # Never mask the whole window: leave at least one sample intact.
        l = max(0, min(self.L, T - 1))
        # randint is inclusive, so an upper bound of T - l lets the block end
        # exactly on the last sample.
        t0 = random.randint(0, T - l)
        x = x.clone()
        x[..., t0:t0 + l] = 0.
        return x

class RandomChannelDropout:
    """Zero out a random subset of EEG channels.

    With probability ``p_channel``, between 1 and ``max_drop`` distinct
    channels (axis 1 of a ``(1, C, T)`` tensor) are set to zero in a copy of
    the input; otherwise the input is returned unchanged.
    """
    def __init__(self, p_channel=0.2, max_drop=4):
        self.p = p_channel
        self.max_drop = max_drop

    def __call__(self, x):
        apply_it = not (random.random() > self.p)
        if not apply_it:
            return x
        _, n_ch, _ = x.shape
        n_drop = random.randint(1, self.max_drop)
        dropped = torch.randperm(n_ch, device=x.device)[:n_drop]
        out = x.clone()
        out[:, dropped, :] = 0.
        return out

class RandomTimeJitter:
    """Circularly shift the input along its last (time) axis.

    The shift is drawn uniformly from ``[-max_shift, max_shift]`` (inclusive);
    samples pushed off one end wrap around to the other. A ``max_shift`` of 0
    disables the transform.
    """
    def __init__(self, max_shift=8):
        self.max_shift = max_shift

    def __call__(self, x):
        if self.max_shift == 0:
            return x
        offset = random.randint(-self.max_shift, self.max_shift)
        return x.roll(offset, -1)

class RandomAmplitudeScale:
    """Multiply the whole input by a single gain drawn uniformly from [min_s, max_s]."""
    def __init__(self, min_s=0.9, max_s=1.1):
        self.a = min_s
        self.b = max_s

    def __call__(self, x):
        gain = torch.empty(1, device=x.device)
        gain.uniform_(self.a, self.b)
        return x * gain

class RandomGaussianNoise:
    """Add element-wise zero-mean Gaussian noise with standard deviation ``std``."""
    def __init__(self, std=0.03):
        self.std = std

    def __call__(self, x):
        noise = torch.randn_like(x)
        return x + noise * self.std


class EEGAugment:
    """Training-time augmentation pipeline for a single EEG window.

    The transforms in ``self.ops`` run in this order, each on the previous
    one's output: time jitter (max 12 samples), channel dropout (p=0.6, up to
    8 channels), time mask (p=0.8, 25 samples), amplitude scale in
    [0.8, 1.2], and Gaussian noise (std 0.03).
    """
    def __init__(self):
        jitter  = RandomTimeJitter(12)
        dropout = RandomChannelDropout(0.6, max_drop=8)
        mask    = RandomTimeMask(0.8, L=25)
        scale   = RandomAmplitudeScale(0.8, 1.2)
        noise   = RandomGaussianNoise(0.03)
        self.ops = [jitter, dropout, mask, scale, noise]

    def __call__(self, x):
        out = x
        for transform in self.ops:
            out = transform(out)
        return out

class EEGWindowDataset(Dataset):
    """Map-style dataset of preprocessed EEG windows with their metadata.

    Each item is ``(x, y, image, subject, fname)``:
      * ``x``: the window cast to float32 and z-scored (zero mean, unit std
        over all of its elements), then passed through ``transform`` if given;
      * ``y``: integer class index (``torch.long``);
      * ``image`` / ``subject`` / ``fname``: the per-window metadata at the
        same index.

    Args:
        eeg_tensor: indexable collection of windows; presumably shaped
            (N, 1, n_channels, T) — only ``len`` and integer indexing are
            used here.
        wnid_list: class identifier (wnid) for each window.
        img: per-window image identifiers, indexed like ``eeg_tensor``.
        sub: per-window subject identifiers, indexed like ``eeg_tensor``.
        class_to_idx: mapping from wnid to integer class index.
        transform: optional callable applied to the normalised window.
        fname: optional per-window file names. When omitted, ``""`` is
            returned in their place so items keep their 5-tuple shape.
    """
    def __init__(self, eeg_tensor, wnid_list, img, sub, class_to_idx, transform=None, fname=None):
        self.X = eeg_tensor
        self.y = torch.tensor([class_to_idx[w] for w in wnid_list],
                              dtype=torch.long)
        self.transform = transform
        self.image = img
        self.subject = sub
        self.fname = fname

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x = self.X[idx].float()
        # Per-window z-score; the epsilon guards against constant windows.
        x = (x - x.mean()) / (x.std() + 1e-6)

        if self.transform is not None:
            x = self.transform(x)

        fname = self.fname[idx] if self.fname is not None else ""
        return x, self.y[idx], self.image[idx], self.subject[idx], fname


def latent_mixup(z, y, alpha=0.2):
    """Mixup in latent space.

    Draws ``lam ~ Beta(alpha, alpha)`` from NumPy's global RNG, pairs every
    row of ``z`` with a randomly permuted partner, and blends them.

    Args:
        z: (B, d) latent vectors.
        y: (B,) integer labels.
        alpha: Beta-distribution concentration.

    Returns:
        ``(z_mix, y_a, y_b, lam)`` where ``z_mix = lam * z + (1 - lam) * z[perm]``,
        ``y_a`` is ``y`` itself and ``y_b = y[perm]``.
    """
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(z.size(0), device=z.device)
    partner = z[perm]
    mixed = lam * z + (1 - lam) * partner
    return mixed, y, y[perm], lam


def _kl_schedule(epoch, warmup, ramp, beta_max, gamma_max):
    """Return the ``(beta, gamma)`` loss weights for ``epoch``.

    ``beta`` (KL weight) is 0 while ``epoch < warmup``, grows linearly from 0
    towards ``beta_max`` over the next ``ramp`` epochs, then stays at
    ``beta_max``. ``gamma`` (SupCon weight) follows the same ramp scaled to
    ``gamma_max``.
    """
    if epoch < warmup:
        beta = 0.0
    elif epoch < warmup + ramp:
        beta = beta_max * (epoch - warmup) / ramp
    else:
        beta = beta_max
    # Guard the ratio so a disabled KL term (beta_max == 0) cannot divide by zero.
    gamma = gamma_max * (beta / beta_max) if beta_max else 0.0
    return beta, gamma


def _set_backbone_trainable(model, trainable):
    """Enable or disable gradients for the convolutional backbone of ``model``."""
    for module in (model.conv1, model.conv2, model.conv3,
                   model.conv4, model.res_conv, model.norm_res):
        for param in module.parameters():
            param.requires_grad = trainable


@torch.no_grad()
def _collect_outputs(model, loader, device):
    """Run ``model`` in eval mode (no autograd) over every batch of ``loader``.

    Returns a dict of numpy arrays in loader order: ``latents`` (mu),
    ``labels``, ``images``, ``subjects``, ``preds`` (argmax of
    ``model.classifier(mu)``) and ``fnames`` (object array).
    """
    model.eval()
    lat, lbl, img, sub, pred, fnames = [], [], [], [], [], []
    for X, y, image, subject, fname in loader:
        _, mu, _, _ = model(X.to(device))
        logits = model.classifier(mu)
        lat.append(mu.cpu().numpy())
        lbl.append(y.cpu().numpy())
        img.append(image)
        sub.append(subject.cpu().numpy())
        pred.append(logits.argmax(dim=1).cpu().numpy())
        fnames.extend(fname)
    return {
        "latents":  np.concatenate(lat, axis=0),
        "labels":   np.concatenate(lbl, axis=0),
        "images":   np.concatenate(img, axis=0),
        "subjects": np.concatenate(sub, axis=0),
        "preds":    np.concatenate(pred, axis=0),
        "fnames":   np.array(fnames, dtype=object),
    }


def train_and_dump():
    """Train ``EEGEncoderCNN`` on the preprocessed splits, then dump its latents.

    Steps:
      1. Load train/val/test tensors from ``split_dir`` and the wnid -> index
         map (built from the training wnids and saved if the JSON is missing).
      2. Train with cross-entropy (plus latent mixup after
         ``early_stop_start``), a free-bits KL term and a SupCon term. The KL
         and SupCon weights ramp in after ``warmup``, and the conv backbone
         is frozen for the duration of that ramp.
      3. Keep the checkpoint with the best validation accuracy. Stop early
         after ``patience`` non-improving epochs, counted only once
         ``early_stop_start`` is reached.
      4. Reload the best checkpoint and print test accuracy. Then, for each
         split, save mu latents, labels, images, subjects and file names as
         ``.npy`` files plus a confusion-matrix PNG. Training latents are
         computed on the un-augmented windows.
    """
    # ── 1) load preprocessed splits ──────────────────────────────────────
    channel = "all_channels"
    granularity = "fine0"
    split_dir = f"/content/drive/MyDrive/ImageNet_Images/preprocessed_splits/granularity/Time/{channel}/{granularity}"

    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
    train_tensor, train_wnids, train_imgs, train_subj, train_fname = torch.load(
        f"{split_dir}/train_timeNewHop32.pt", weights_only=False)

    json_path = f"{split_dir}/wnid_to_idx.json"
    if os.path.exists(json_path):
        with open(json_path) as f:
            class_to_idx = json.load(f)
    else:
        class_to_idx = {w: i for i, w in enumerate(sorted(set(train_wnids)))}
        with open(json_path, "w") as f:
            json.dump(class_to_idx, f, indent=2)
    n_classes = len(class_to_idx)

    val_tensor, val_wnids, val_imgs, val_subj, val_fname = torch.load(
        f"{split_dir}/val_timeNewHop32.pt", weights_only=False)
    test_tensor, test_wnids, test_imgs, test_subj, test_fname = torch.load(
        f"{split_dir}/test_timeNewHop32.pt", weights_only=False)

    train_ds = EEGWindowDataset(train_tensor, train_wnids, train_imgs, train_subj,
                                class_to_idx, transform=EEGAugment(), fname=train_fname)
    # Same training windows without augmentation: used only for the latent dump.
    train_eval_ds = EEGWindowDataset(train_tensor, train_wnids, train_imgs, train_subj,
                                     class_to_idx, transform=None, fname=train_fname)
    val_ds = EEGWindowDataset(val_tensor, val_wnids, val_imgs, val_subj,
                              class_to_idx, transform=None, fname=val_fname)
    test_ds = EEGWindowDataset(test_tensor, test_wnids, test_imgs, test_subj,
                               class_to_idx, transform=None, fname=test_fname)

    # ── 2) dataloaders ───────────────────────────────────────────────────
    batch_size = 256
    loaders = dict(
        train=DataLoader(train_ds, batch_size, shuffle=True, num_workers=4),
        val=DataLoader(val_ds, batch_size, shuffle=False, num_workers=4),
        test=DataLoader(test_ds, batch_size, shuffle=False, num_workers=4),
    )
    dump_loaders = dict(
        train=DataLoader(train_eval_ds, batch_size, shuffle=False, num_workers=4),
        val=loaders["val"],
        test=loaders["test"],
    )

    # ── 3) model, losses, optimiser, scheduler ───────────────────────────
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EEGEncoderCNN(n_classes, n_channels=62, hidden_dim=128).to(device)
    model = model.to(memory_format=torch.channels_last)

    num_epochs = 150
    criterion = nn.CrossEntropyLoss(label_smoothing=0.15)
    optimiser = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-3)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=num_epochs, eta_min=1e-6)
    supcon = SupConLoss(temperature=0.07)

    warmup, ramp, patience = 10, 30, 15
    early_stop_start = warmup + ramp + patience
    beta_max, gamma_max = 0.0003, 5e-5
    max_logvar = 1.5   # upper clamp on logvar before the KL term
    free_bit = 0.05    # per-dimension KL floor ("free bits")
    save_path = f"{split_dir}/best_eeg_cnn.pt"
    best_val_acc = -1.0  # below any accuracy, so epoch 1 always writes a checkpoint
    epochs_no_improve = 0

    # ── 4) train + validation loop ───────────────────────────────────────
    for epoch in range(1, num_epochs + 1):
        beta, gamma = _kl_schedule(epoch, warmup, ramp, beta_max, gamma_max)
        # Freeze the conv backbone while the KL/SupCon weights ramp up and
        # release it when the ramp completes (once per epoch, not per phase).
        if epoch == warmup:
            _set_backbone_trainable(model, False)
        if epoch == warmup + ramp:
            _set_backbone_trainable(model, True)

        for phase in ("train", "val"):
            is_train = phase == "train"
            model.train(is_train)
            loss_sum = ce_sum = kl_sum = supcon_sum = 0.0
            correct = total = 0

            # Build autograd graphs only when they will be back-propagated.
            with torch.set_grad_enabled(is_train):
                for X, y, _, _, _ in loaders[phase]:
                    B = X.size(0)
                    X = X.to(device, memory_format=torch.channels_last)
                    y = y.to(device)

                    _, mu, logvar, _ = model(X)
                    logvar = torch.clamp(logvar, max=max_logvar)

                    if is_train and epoch > early_stop_start:
                        # Late-stage regularisation: mixup on the latent means.
                        z, y_a, y_b, lam = latent_mixup(mu, y, alpha=0.1)
                        logits = model.classifier(z)
                        loss_ce = (lam * criterion(logits, y_a)
                                   + (1 - lam) * criterion(logits, y_b))
                    else:
                        logits = model.classifier(mu)
                        loss_ce = criterion(logits, y)

                    # Gaussian KL per latent dim, floored at free_bit,
                    # summed over dims and averaged over the batch.
                    kl_per_dim = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
                    kl = torch.clamp(kl_per_dim, min=free_bit).sum(dim=1).mean()
                    contrastive_loss = supcon(mu, y)
                    loss = loss_ce + beta * kl + gamma * contrastive_loss

                    if is_train:
                        optimiser.zero_grad()
                        loss.backward()
                        optimiser.step()

                    loss_sum += loss.item() * B
                    ce_sum += loss_ce.item() * B
                    kl_sum += kl.item() * B
                    supcon_sum += contrastive_loss.item() * B
                    correct += (logits.argmax(dim=1) == y).sum().item()
                    total += B

            epoch_acc = correct / total
            print(f"{phase.title()} Epoch {epoch:03d} | "
                  f"Loss: {loss_sum / total:.4f} Acc: {epoch_acc:.4f} "
                  f"CE: {ce_sum / total:.4f}  KL: {kl_sum / total:.4f}  SupCon: {supcon_sum / total:.4f}  ")

            if phase == "val":
                if epoch_acc > best_val_acc:
                    best_val_acc = epoch_acc
                    epochs_no_improve = 0
                    torch.save({
                        'model'     : model.state_dict(),
                        'optimiser' : optimiser.state_dict(),
                        'scheduler' : scheduler.state_dict(),
                        'epoch'     : epoch,
                        'best_acc'  : best_val_acc,
                    }, save_path)
                elif epoch >= early_stop_start:
                    epochs_no_improve += 1

        if epochs_no_improve >= patience and epoch > early_stop_start:
            print(f"No improvement in {patience} epochs, stopping early.")
            break
        scheduler.step()

    # ── 5) test with the best checkpoint ─────────────────────────────────
    ckpt = torch.load(save_path, map_location=device)
    model.load_state_dict(ckpt["model"])

    outputs = {"test": _collect_outputs(model, dump_loaders["test"], device)}
    test_out = outputs["test"]
    print(f"Test Acc: {(test_out['preds'] == test_out['labels']).mean():.4f}")

    # ── 6) dump latents + confusion matrices for each split ──────────────
    latent_savePath = f"/content/drive/MyDrive/ImageNet_Images/latent_dumps/granularity/Time/{channel}/{granularity}/"
    os.makedirs(latent_savePath, exist_ok=True)
    class_names = [k for k, _ in sorted(class_to_idx.items(), key=lambda kv: kv[1])]

    for split in ("train", "val", "test"):
        if split in outputs:
            out = outputs[split]
        else:
            out = _collect_outputs(model, dump_loaders[split], device)

        np.save(f"{latent_savePath}{split}_latentsNew.npy", out["latents"])
        np.save(f"{latent_savePath}{split}_labelsNew.npy", out["labels"])
        np.save(f"{latent_savePath}{split}_imagesNew.npy", out["images"])
        np.save(f"{latent_savePath}{split}_subjectNew.npy", out["subjects"])
        np.save(f"{latent_savePath}{split}_filenamesNew.npy", out["fnames"])

        cm = confusion_matrix(out["labels"], out["preds"], labels=list(class_to_idx.values()))
        fig = plt.figure(figsize=(10, 10))
        plot_confusion_matrix(cm,
                              classes=class_names,
                              percent=False,
                              title=f"{split} confusion")
        fn = os.path.join(split_dir, f"{split}_confusion.png")
        plt.tight_layout()
        fig.savefig(fn)
        plt.close(fig)

        print(f"Dumped {split}: latents {out['latents'].shape}, labels {out['labels'].shape}, "
              f"images {out['images'].shape}, subjects {out['subjects'].shape}, CM at {fn}")


if __name__ == "__main__":
    # freeze_support() only has an effect for frozen Windows executables that
    # spawn processes (e.g. the DataLoader workers); elsewhere it does nothing.
    freeze_support()
    train_and_dump()

Train Epoch 001 | Loss: 2.2183 Acc: 0.1316 CE: 2.2183  KL: 6.4349  SupCon: 5.5674  
Val Epoch 001 | Loss: 2.0682 Acc: 0.1800 CE: 2.0682  KL: 6.5906  SupCon: 5.5902  
Train Epoch 002 | Loss: 2.1374 Acc: 0.1583 CE: 2.1374  KL: 7.2903  SupCon: 5.7148  
Val Epoch 002 | Loss: 2.0216 Acc: 0.2040 CE: 2.0216  KL: 9.5137  SupCon: 6.1657  
Train Epoch 003 | Loss: 2.0139 Acc: 0.2258 CE: 2.0139  KL: 14.2442  SupCon: 6.4645  
Val Epoch 003 | Loss: 1.8562 Acc: 0.2880 CE: 1.8562  KL: 23.0837  SupCon: 7.1823  
Train Epoch 004 | Loss: 1.8473 Acc: 0.3241 CE: 1.8473  KL: 23.5887  SupCon: 7.0309  
Val Epoch 004 | Loss: 1.6073 Acc: 0.4520 CE: 1.6073  KL: 34.8137  SupCon: 7.4455  
Train Epoch 005 | Loss: 1.6187 Acc: 0.4701 CE: 1.6187  KL: 23.5490  SupCon: 7.1509  
Val Epoch 005 | Loss: 1.3092 Acc: 0.6440 CE: 1.3092  KL: 42.6493  SupCon: 7.2420  
Train Epoch 006 | Loss: 1.4038 Acc: 0.5947 CE: 1.4038  KL: 24.4456  SupCon: 6.7575  
Val Epoch 006 | Loss: 1.1211 Acc: 0.7400 CE: 1.1211  KL: 35.7764  SupCon: 6.318