## Setup

In [None]:
!pip install einops ema_pytorch mat73 numpy scikit_learn torch tqdm wandb --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m31.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m44.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m29.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Mount Google Drive and make the project folder the working directory, so the
# relative paths used later ("data_eeg", "saved_models") and the DiffE package resolve.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/IDL_Project

Mounted at /content/drive
/content/drive/.shortcut-targets-by-id/1UTmW6iNsuZDOK5pC8i70YZC-1owajRkF/IDL_Project


In [None]:
import sys

# Make the project package (e.g. DiffE.utils) importable. Guarded so re-running the
# cell does not keep appending duplicate entries to sys.path.
PROJECT_DIR = '/content/drive/MyDrive/IDL_Project'
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

## Load and Preprocess Data

In [None]:
import numpy as np
import torch
import pandas as pd
import os
import random

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from DiffE.utils import zscore_norm, EEGDataset


In [None]:
def load_eeg_data(root_dir, subject_id, data_type='thinking', eeg_channels=None, window_size=256, step_size=128,
                  label_map=None):
    """Load one subject's CSV recording and cut it into (possibly overlapping) windows.

    Args:
        root_dir: dataset root. The file read is ``<root_dir>/<subject_id zero-padded to 2>/<data_type>.csv``.
        subject_id: subject number or string; zero-padded to two characters.
        data_type: CSV file stem (e.g. 'thinking' or 'stimuli').
        eeg_channels: column names used as channels. By default, every column except
            'Time:256Hz', 'Epoch', 'Label', 'Stage' and 'Flag'.
        window_size: samples per window.
        step_size: stride between window starts. Windows overlap when step_size < window_size.
        label_map: optional dict mapping raw label -> class index. When None (the original
            behaviour), the mapping is built from this subject's own sorted unique labels. That
            only gives consistent indices across subjects if every subject has the same label set.

    Returns:
        (X, Y). X is built as a float32 tensor of shape (n_windows, n_channels, window_size)
        and then passed through ``zscore_norm``. Y is an int64 tensor (n_windows,) holding the
        label at each window's centre sample, so a window straddling a label change gets its
        centre label.

    Raises:
        ValueError: if the recording has fewer than ``window_size`` samples.
        KeyError: if ``label_map`` is given and lacks a label present in the file.
    """
    subject_id = str(subject_id).zfill(2)
    file_path = os.path.join(root_dir, subject_id, f"{data_type}.csv")

    df = pd.read_csv(file_path)
    if eeg_channels is None:
        non_eeg_cols = ['Time:256Hz', 'Epoch', 'Label', 'Stage', 'Flag']
        eeg_channels = [col for col in df.columns if col not in non_eeg_cols]

    eeg_data = df[eeg_channels].values  # shape: (n_samples, n_channels)
    raw_labels = df['Label'].values
    if label_map is None:
        label_map = {label: i for i, label in enumerate(np.unique(raw_labels))}
    labels = np.array([label_map[label] for label in raw_labels])

    n_samples = len(eeg_data)
    if n_samples < window_size:
        # Previously this silently produced an empty, wrongly-shaped tensor.
        raise ValueError(
            f"Subject {subject_id}: recording has {n_samples} samples, "
            f"fewer than window_size={window_size}"
        )

    starts = range(0, n_samples - window_size + 1, step_size)
    # Each window is transposed to (n_channels, n_timepoints).
    X = np.stack([eeg_data[s:s + window_size, :].T for s in starts])
    Y = labels[[s + window_size // 2 for s in starts]]

    X = torch.tensor(X, dtype=torch.float32)
    Y = torch.tensor(Y, dtype=torch.long)

    # Apply z-score normalization
    X = zscore_norm(X)

    return X, Y

In [None]:
def load_multiple_subjects(root_dir, subject_ids, data_type='thinking', **load_kwargs):
    """Load windowed EEG data for several subjects and concatenate it along dim 0.

    Loading is best-effort: a subject whose file fails to load is reported and skipped.

    Args:
        root_dir: dataset root passed to ``load_eeg_data``.
        subject_ids: iterable of subject ids.
        data_type: CSV file stem passed to ``load_eeg_data``.
        **load_kwargs: extra keyword arguments forwarded to ``load_eeg_data``
            (e.g. window_size, step_size, label_map).

    Returns:
        (X_combined, Y_combined): concatenation of every successfully loaded subject.

    Raises:
        RuntimeError: if no subject could be loaded (including an empty ``subject_ids``),
            listing the individual errors.
    """
    all_X = []
    all_Y = []
    failures = []

    for subject_id in subject_ids:
        try:
            X, Y = load_eeg_data(root_dir, subject_id, data_type, **load_kwargs)
        except Exception as e:  # deliberate best-effort: skip this subject, keep the reason
            print(f"Error loading data from subject {subject_id}: {e}")
            failures.append(f"{subject_id}: {e}")
            continue
        all_X.append(X)
        all_Y.append(Y)
        print(f"Loaded data from subject {subject_id}")

    if not all_X:
        # torch.cat([]) would otherwise fail with an opaque "non-empty list" error.
        details = f" (errors: {'; '.join(failures)})" if failures else ""
        raise RuntimeError(
            f"No EEG data could be loaded from {root_dir!r} for subjects {list(subject_ids)}{details}"
        )

    # Combine data from all subjects
    X_combined = torch.cat(all_X, dim=0)
    Y_combined = torch.cat(all_Y, dim=0)

    return X_combined, Y_combined

In [None]:
def split_subjects_and_data(root_dir, data_type='thinking', seen_ratio=0.9, train_ratio=0.9, val_ratio=0.15, seed=42,
                            n_subjects=21, batch_size_train=32, batch_size_test=64):
    """Split subjects into seen/unseen groups, then split seen-subject windows into train/val/test.

    Subjects "01".."<n_subjects>" are shuffled with ``seed``. The first
    ``int(n_subjects * seen_ratio)`` are "seen" and the rest "unseen". Seen-subject windows are
    split (stratified by label) into train (``train_ratio``), val (``val_ratio``) and seen-test
    (the remainder). All unseen-subject windows form the unseen test set.

    NOTE: seen-subject windows are split at random. With load_eeg_data's defaults
    (window_size=256, step_size=128) consecutive windows overlap by half, so near-duplicate
    windows can land in both train and val/seen-test. Seen-test scores are therefore
    optimistic; the unseen-subject loader is the leakage-free estimate.

    Args:
        root_dir: dataset root.
        data_type: CSV file stem to load.
        seen_ratio: fraction of subjects treated as seen.
        train_ratio: fraction of seen windows used for training.
        val_ratio: fraction of seen windows used for validation. Must satisfy
            train_ratio + val_ratio < 1 so the seen test split is non-empty.
        seed: seed for numpy, random and torch, and for the sklearn splits.
        n_subjects: number of subject folders (ids 1..n_subjects).
        batch_size_train: batch size of the (shuffled) train loader.
        batch_size_test: batch size of the val / test loaders.

    Returns:
        (train_loader, val_loader, seen_test_loader, unseen_test_loader)

    Raises:
        ValueError: on ratios that leave an empty split. The default train_ratio=0.9 with
            val_ratio=0.15 is one such case; it used to fail inside sklearn.
    """
    # Validate everything before doing any I/O.
    if not 0 < train_ratio < 1:
        raise ValueError(f"train_ratio must be in (0, 1), got {train_ratio}")
    if val_ratio <= 0 or train_ratio + val_ratio >= 1:
        raise ValueError(
            f"need val_ratio > 0 and train_ratio + val_ratio < 1 (got train_ratio={train_ratio}, "
            f"val_ratio={val_ratio}); the remainder is used as the seen-subject test split"
        )
    num_seen = int(n_subjects * seen_ratio)
    if not 0 < num_seen < n_subjects:
        raise ValueError(
            f"seen_ratio={seen_ratio} gives {num_seen} seen of {n_subjects} subjects; "
            f"both the seen and unseen groups must be non-empty"
        )

    # Set random seed for reproducibility
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # Get all subject IDs and randomly shuffle them
    all_subject_ids = [str(i).zfill(2) for i in range(1, n_subjects + 1)]
    random.shuffle(all_subject_ids)

    # Split into seen and unseen
    seen_subjects = all_subject_ids[:num_seen]
    unseen_subjects = all_subject_ids[num_seen:]

    print(f"Seen subjects: {seen_subjects}")
    print(f"Unseen subjects: {unseen_subjects}")

    # Load data for seen subjects
    seen_X, seen_Y = load_multiple_subjects(root_dir, seen_subjects, data_type)

    # Split seen subjects' data into train, val, test
    train_size = train_ratio
    val_size = val_ratio / (1 - train_ratio)  # fraction of the held-out remainder that becomes val

    X_train, X_temp, Y_train, Y_temp = train_test_split(
        seen_X, seen_Y, test_size=(1-train_size), random_state=seed, stratify=seen_Y
    )

    X_val, X_seen_test, Y_val, Y_seen_test = train_test_split(
        X_temp, Y_temp, test_size=(1-val_size), random_state=seed, stratify=Y_temp
    )

    # Load data for unseen subjects (all for testing)
    X_unseen_test, Y_unseen_test = load_multiple_subjects(root_dir, unseen_subjects, data_type)

    train_dataset = EEGDataset(X_train, Y_train)
    val_dataset = EEGDataset(X_val, Y_val)
    seen_test_dataset = EEGDataset(X_seen_test, Y_seen_test)
    unseen_test_dataset = EEGDataset(X_unseen_test, Y_unseen_test)

    train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size_test, shuffle=False)
    seen_test_loader = DataLoader(seen_test_dataset, batch_size=batch_size_test, shuffle=False)
    unseen_test_loader = DataLoader(unseen_test_dataset, batch_size=batch_size_test, shuffle=False)

    print(f"Train samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Seen test samples: {len(seen_test_dataset)}")
    print(f"Unseen test samples: {len(unseen_test_dataset)}")

    return train_loader, val_loader, seen_test_loader, unseen_test_loader

In [None]:
# Batch sizes used by split_subjects_and_data for its DataLoaders (32 / 64); these
# variables are not passed to it, so keep them in sync with the function.
batch_size_train = 32
batch_size_test = 64
seed = 42
# Seed every RNG used below (split_subjects_and_data also re-seeds with the same value).
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_loader, validation_loader, seen_test_loader, unseen_test_loader = split_subjects_and_data(
    root_dir="data_eeg",
    data_type='stimuli',
    seen_ratio=0.9,  # int(21 * 0.9) = 18 seen, 3 unseen
    train_ratio=0.7,
    val_ratio=0.15,  # the remaining 0.15 of seen windows becomes the seen test split
    seed=seed,
)

Seen subjects: ['20', '06', '15', '05', '10', '14', '16', '19', '07', '13', '18', '11', '02', '12', '03', '17', '08', '09']
Unseen subjects: ['01', '04', '21']
Loaded data from subject 20
Loaded data from subject 06
Loaded data from subject 15
Loaded data from subject 05
Loaded data from subject 10
Loaded data from subject 14
Loaded data from subject 16
Loaded data from subject 19
Loaded data from subject 07
Loaded data from subject 13
Loaded data from subject 18
Loaded data from subject 11
Loaded data from subject 02
Loaded data from subject 12
Loaded data from subject 03
Loaded data from subject 17
Loaded data from subject 08
Loaded data from subject 09
Loaded data from subject 01
Loaded data from subject 04
Loaded data from subject 21
Train samples: 19811
Validation samples: 4245
Seen test samples: 4246
Unseen test samples: 4797


## Model

In [None]:
import math
import numpy as np
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import reduce

In [None]:
def get_padding(kernel_size, dilation=1):
    """Return the one-sided padding that keeps a stride-1 convolution's length unchanged
    for an odd kernel: (kernel_size - 1) * dilation / 2, truncated to int."""
    span = kernel_size * dilation - dilation
    return int(span / 2)


# Swish activation function
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.sigmoid(x)
        return gate * x


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of (possibly fractional) positions/timesteps.

    forward(x) with x of shape (B,) returns (B, dim): ``dim // 2`` sine features followed
    by ``dim // 2`` cosine features at geometrically spaced frequencies. For an odd ``dim``
    a zero column is appended so the output always has exactly ``dim`` columns.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        # max(..., 1): for dim in {2, 3} half_dim - 1 == 0 and the original division raised.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Pad odd dims so callers doing .view(-1, dim, 1) get the shape they expect.
            emb = F.pad(emb, (0, 1))
        return emb


class WeightStandardizedConv1d(nn.Conv1d):
    """Conv1d whose kernel is standardised on every forward pass.

    https://arxiv.org/abs/1903.10520
    Weight standardization is reported to work well together with group normalization.
    Each output channel's weights are shifted to zero mean and scaled to unit (biased)
    variance before the convolution runs.
    """

    def forward(self, x):
        # Looser epsilon for any non-float32 input dtype.
        eps = 1e-3
        if x.dtype == torch.float32:
            eps = 1e-5

        w = self.weight
        # Statistics over (in_channels, kernel) per output channel, kept as (out, 1, 1).
        w_mean = w.mean(dim=(1, 2), keepdim=True)
        w_var = torch.var(w, dim=(1, 2), unbiased=False, keepdim=True)
        standardized = (w - w_mean) * (w_var + eps).rsqrt()

        return F.conv1d(
            x, standardized, self.bias, self.stride, self.padding, self.dilation, self.groups
        )


class ResidualConvBlock(nn.Module):
    """Weight-standardised Conv1d -> GroupNorm -> PReLU, with an averaged skip connection.

    When the conv keeps both the channel count and the sequence length, the output is
    ``(x + conv(x)) / 2``; otherwise it is just ``conv(x)``.

    Args:
        inc / outc: input / output channels.
        kernel_size: conv kernel size. Padding comes from ``get_padding``, so odd sizes keep the length.
        stride: conv stride.
        gn: number of GroupNorm groups (must divide ``outc``).
    """

    def __init__(self, inc: int, outc: int, kernel_size: int, stride=1, gn=8):
        super().__init__()
        self.same_channels = inc == outc
        self.ks = kernel_size
        self.conv = nn.Sequential(
            WeightStandardizedConv1d(inc, outc, self.ks, stride, get_padding(self.ks)),
            nn.GroupNorm(gn, outc),
            nn.PReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.conv(x)
        # The skip needs matching shapes. With stride > 1 or an even kernel the conv changes
        # the length and the addition used to raise, so fall back to no skip in that case.
        if self.same_channels and x1.shape[-1] == x.shape[-1]:
            return (x + x1) / 2
        return x1


class UnetDown(nn.Module):
    """Downsampling stage: ResidualConvBlock, then max-pooling by ``factor``."""

    def __init__(self, in_channels, out_channels, kernel_size, gn=8, factor=2):
        super().__init__()
        self.pool = nn.MaxPool1d(factor)
        self.layer = ResidualConvBlock(in_channels, out_channels, kernel_size, gn=gn)

    def forward(self, x):
        return self.pool(self.layer(x))


class UnetUp(nn.Module):
    """Upsampling stage: nearest-neighbour upsampling by ``factor``, then ResidualConvBlock."""

    def __init__(self, in_channels, out_channels, kernel_size, gn=8, factor=2):
        super().__init__()
        self.pool = nn.Upsample(scale_factor=factor, mode="nearest")
        self.layer = ResidualConvBlock(in_channels, out_channels, kernel_size, gn=gn)

    def forward(self, x):
        return self.layer(self.pool(x))


class EmbedFC(nn.Module):
    """Small MLP (Linear -> PReLU -> Linear) for embedding a vector into ``emb_dim``.

    forward flattens its input to (-1, input_dim) first. ``reshape`` is used rather than
    ``view`` so non-contiguous inputs (e.g. permuted tensors) are accepted too.
    """

    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.PReLU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = x.reshape(-1, self.input_dim)
        return self.model(x)


class ConditionalUNet(nn.Module):
    """Time-conditioned 1-D U-Net used as the DDPM's network.

    Three UnetDown stages (length L -> L/2 -> L/4 -> L/8; channels n_feat, 2*n_feat,
    3*n_feat) and three UnetUp stages back to length L, with skip concatenations.
    A sinusoidal embedding of ``t`` (n_feat channels, broadcast over length) is added to
    the first two up-path outputs. All blocks use kernel_size=1.

    forward(x, t) returns ``(out, (down1, down2, down3), (up1, up2, up3))``, where ``out`` has
    ``in_channels`` channels and x's length.
    NOTE(review): assumes the input length is divisible by 8 so the skip tensors line up
    (true for 256-sample windows) — confirm for other window sizes.
    """

    def __init__(self, in_channels, n_feat=256):
        super(ConditionalUNet, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat

        self.d1_out = n_feat * 1
        self.d2_out = n_feat * 2
        self.d3_out = n_feat * 3
        self.d4_out = n_feat * 4  # not used by any layer below

        self.u1_out = n_feat  # not used by any layer below
        self.u2_out = n_feat
        self.u3_out = n_feat
        self.u4_out = in_channels

        self.sin_emb = SinusoidalPosEmb(n_feat)
        # self.timeembed1 = EmbedFC(n_feat, self.u1_out)
        # self.timeembed2 = EmbedFC(n_feat, self.u2_out)
        # self.timeembed3 = EmbedFC(n_feat, self.u3_out)

        self.down1 = UnetDown(in_channels, self.d1_out, 1, gn=2, factor=2)
        self.down2 = UnetDown(self.d1_out, self.d2_out, 1, gn=2, factor=2)
        self.down3 = UnetDown(self.d2_out, self.d3_out, 1, gn=2, factor=2)

        # Layer names are offset by one from the forward outputs: self.up2 produces up1, etc.
        self.up2 = UnetUp(self.d3_out, self.u2_out, 1, gn=2, factor=2)
        self.up3 = UnetUp(self.u2_out + self.d2_out, self.u3_out, 1, gn=2, factor=2)
        self.up4 = UnetUp(self.u3_out + self.d1_out, self.u4_out, 1, gn=2, factor=2)
        # Final 1x1 conv over [up-path output, original input].
        self.out = nn.Conv1d(self.u4_out + in_channels, in_channels, 1)

    def forward(self, x, t):
        down1 = self.down1(x)  # L -> L/2
        down2 = self.down2(down1)  # L/2 -> L/4
        down3 = self.down3(down2)  # L/4 -> L/8

        temb = self.sin_emb(t).view(-1, self.n_feat, 1)  # [b, n_feat, 1]

        up1 = self.up2(down3)  # L/8 -> L/4
        up2 = self.up3(torch.cat([up1 + temb, down2], 1))  # L/4 -> L/2
        up3 = self.up4(torch.cat([up2 + temb, down1], 1))  # L/2 -> L
        out = self.out(torch.cat([up3, x], 1))  # L -> L

        down = (down1, down2, down3)
        up = (up1, up2, up3)
        return out, down, up

class AttentionPool1d(nn.Module):
    """Pool a (B, C, L) sequence to (B, C) with one learned query vector.

    Each time step is scored by the dot product of its channel vector with the query. The
    scores are softmax-normalised over time and used to take a weighted sum over L.
    """

    def __init__(self, in_channels):
        super(AttentionPool1d, self).__init__()
        # Learnable query, one weight per channel.
        self.query = nn.Parameter(torch.randn(in_channels))

    def forward(self, x):
        """x: (B, C, L) tensor. Returns the (B, C) attention-weighted mean over time."""
        batch, channels, length = x.shape
        time_major = x.permute(0, 2, 1)  # (B, L, C)
        scores = torch.einsum('blc,c->bl', time_major, self.query)  # (B, L)
        attn = scores.softmax(dim=-1)
        return (x * attn[:, None, :]).sum(dim=2)

class Encoder(nn.Module):
    """Diff-E encoder: three UnetDown stages (each halves the length), then attention pooling.

    forward(x0) returns ``((dn1, dn2, dn3), z)``. dn1..dn3 have ``dim`` channels at lengths
    L/2, L/4 and L/8, and z is the (B, dim) attention-pooled latent of dn3.
    """

    def __init__(self, in_channels, dim=512):
        super().__init__()

        self.in_channels = in_channels
        self.e1_out = dim
        self.e2_out = dim
        self.e3_out = dim

        self.down1 = UnetDown(in_channels, self.e1_out, 1, gn=2, factor=2)
        self.down2 = UnetDown(self.e1_out, self.e2_out, 1, gn=2, factor=2)
        self.down3 = UnetDown(self.e2_out, self.e3_out, 1, gn=2, factor=2)

        self.att_pooling = AttentionPool1d(self.e3_out)
        # NOTE(review): `act` is registered but never applied in forward.
        self.act = nn.Tanh()

    def forward(self, x0):
        stage_outputs = []
        hidden = x0
        for stage in (self.down1, self.down2, self.down3):
            hidden = stage(hidden)  # halves the length each time
            stage_outputs.append(hidden)
        z = self.att_pooling(hidden)
        return tuple(stage_outputs), z


class Decoder(nn.Module):
    """Diff-E decoder: upsamples the encoder's deepest features, fused with the DDPM
    network's (detached) down-path features, back to the input resolution.

    Channel/length bookkeeping (C = in_channels; every UnetUp doubles the length):
      up1: cat(dn3 [encoder_dim], dn33 [3*n_feat])                          -> n_feat, L/8 -> L/4
      up2: cat(up1 [n_feat], dn22 [2*n_feat])                               -> n_feat, L/4 -> L/2
      up3: cat(pool(x0) [C], pool(x_hat) [C], up2 [n_feat], dn11 [n_feat])  -> C,      L/2 -> L
    so ``n_feat`` must equal the ConditionalUNet's ``n_feat``.
    NOTE(review): assumes the input length is divisible by 8 — confirm for other window sizes.
    ``n_classes`` is stored but not used.
    """

    def __init__(self, in_channels, n_feat=256, encoder_dim=512, n_classes=13):
        super(Decoder, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_classes = n_classes
        self.e1_out = encoder_dim
        self.e2_out = encoder_dim
        self.e3_out = encoder_dim
        self.d1_out = n_feat
        self.d2_out = n_feat * 2
        self.d3_out = n_feat * 3
        self.u1_out = n_feat
        self.u2_out = n_feat
        self.u3_out = n_feat
        self.u4_out = in_channels

        # self.sin_emb = SinusoidalPosEmb(n_feat)
        # self.timeembed1 = EmbedFC(n_feat, self.e3_out)
        # self.timeembed2 = EmbedFC(n_feat, self.u2_out)
        # self.timeembed3 = EmbedFC(n_feat, self.u3_out)
        # self.contextembed1 = EmbedFC(self.e3_out, self.e3_out)
        # self.contextembed2 = EmbedFC(self.e3_out, self.u2_out)
        # self.contextembed3 = EmbedFC(self.e3_out, self.u3_out)

        # Unet up sampling
        self.up1 = UnetUp(self.d3_out + self.e3_out, self.u2_out, 1, gn=2, factor=2)
        self.up2 = UnetUp(self.d2_out + self.u2_out, self.u3_out, 1, gn=2, factor=2)
        # Last stage: upsample, then a 1x1 conv projecting back to in_channels.
        # in_channels * 2 accounts for pooled x0 and pooled x_hat.
        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv1d(
                self.d1_out + self.u3_out + in_channels * 2, in_channels, 1, 1, 0
            ),
        )

        # self.out = nn.Conv1d(self.u4_out+in_channels, in_channels, 1)
        # Halves x0 / x_hat so they match the L/2 length of the last skip input.
        self.pool = nn.AvgPool1d(2)

    def forward(self, x0, encoder_out, diffusion_out):
        # Encoder output: ((dn1, dn2, dn3), z); only dn3 is used below.
        down, z = encoder_out
        dn1, dn2, dn3 = down

        # DDPM output: (x_hat, down-path features, up-path features, t); `up` and `t` are unused.
        x_hat, down_ddpm, up, t = diffusion_out
        dn11, dn22, dn33 = down_ddpm

        # embed context, time step
        # temb = self.sin_emb(t).view(-1, self.n_feat, 1) # [b, n_feat, 1]
        # temb1 = self.timeembed1(temb).view(-1, self.e3_out, 1) # [b, features]
        # temb2 = self.timeembed2(temb).view(-1, self.u2_out, 1) # [b, features]
        # temb3 = self.timeembed3(temb).view(-1, self.u3_out, 1) # [b, features]
        # ct2 = self.contextembed2(z).view(-1, self.u2_out, 1) # [b, n_feat, 1]
        # ct3 = self.contextembed3(z).view(-1, self.u3_out, 1) # [b, n_feat, 1]

        # Up sampling. DDPM tensors are detached so the decoder loss does not backprop
        # into the DDPM network.
        up1 = self.up1(torch.cat([dn3, dn33.detach()], 1))
        up2 = self.up2(torch.cat([up1, dn22.detach()], 1))
        out = self.up3(
            torch.cat([self.pool(x0), self.pool(x_hat.detach()), up2, dn11.detach()], 1)
        )
        return out


class DiffE(nn.Module):
    """Bundle of encoder, decoder and classifier head.

    forward(x0, ddpm_out) encodes x0, decodes using both the encoder output and the DDPM
    output, and classifies the latent. It returns ``(decoder_out, fc_out, z)``, where z is
    element [1] of the encoder output.
    """

    def __init__(self, encoder, decoder, fc):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.fc = fc

    def forward(self, x0, ddpm_out):
        encoded = self.encoder(x0)
        latent = encoded[1]
        reconstruction = self.decoder(x0, encoded, ddpm_out)
        logits = self.fc(latent)
        return reconstruction, logits, latent


class DecoderNoDiff(nn.Module):
    """Decoder variant that reconstructs from encoder features plus x0 / x_hat only
    (no DDPM down-path features).

    forward(x0, x_hat, encoder_out, t), with C = in_channels:
      up2: dn3 [encoder_dim]                                                -> n_feat, L/8 -> L/4
      up3: cat(up2 [n_feat], dn2 [encoder_dim])                             -> n_feat, L/4 -> L/2
      up4: cat(pool(x0) [C], pool(x_hat) [C], up3 [n_feat], dn1 [encoder_dim]) -> C,   L/2 -> L

    The final conv used to be sized for only one C of input while forward concatenates two
    (x0 and x_hat), so every forward call raised a channel-mismatch error.

    NOTE: sin_emb, timeembed*, contextembed* and ``out`` are constructed but not used by
    forward. ``t`` is accepted for interface compatibility and is unused.
    """

    def __init__(self, in_channels, n_feat=256, encoder_dim=512, n_classes=13):
        super(DecoderNoDiff, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_classes = n_classes
        self.e1_out = encoder_dim
        self.e2_out = encoder_dim
        self.e3_out = encoder_dim
        self.u1_out = n_feat
        self.u2_out = n_feat
        self.u3_out = n_feat
        self.u4_out = n_feat

        # Unused by forward (see class docstring).
        self.sin_emb = SinusoidalPosEmb(n_feat)
        self.timeembed1 = EmbedFC(n_feat, self.e3_out)
        self.timeembed2 = EmbedFC(n_feat, self.u2_out)
        self.timeembed3 = EmbedFC(n_feat, self.u3_out)
        self.contextembed1 = EmbedFC(self.e3_out, self.e3_out)
        self.contextembed2 = EmbedFC(self.e3_out, self.u2_out)
        self.contextembed3 = EmbedFC(self.e3_out, self.u3_out)

        # Unet up sampling
        self.up2 = UnetUp(self.e3_out, self.u2_out, 1, gn=2, factor=2)
        self.up3 = UnetUp(self.e2_out + self.u2_out, self.u3_out, 1, gn=2, factor=2)
        # in_channels * 2: forward concatenates both pooled x0 and pooled x_hat.
        self.up4 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv1d(self.u3_out + self.e1_out + in_channels * 2, in_channels, 1, 1, 0),
        )

        self.out = nn.Conv1d(self.u4_out, in_channels, 1)  # unused by forward
        self.pool = nn.AvgPool1d(2)

    def forward(self, x0, x_hat, encoder_out, t):
        down, z = encoder_out
        dn1, dn2, dn3 = down

        # Up sampling
        up2 = self.up2(dn3)  # L/8 -> L/4
        up3 = self.up3(torch.cat([up2, dn2], 1))  # L/4 -> L/2
        out = self.up4(
            torch.cat([self.pool(x0), self.pool(x_hat), up3, dn1], 1)
        )  # L/2 -> L
        return out


class LinearClassifier(nn.Module):
    """Three-layer MLP classifier head.

    Layout: [Linear -> GroupNorm(4) -> PReLU] x 2, then Linear to ``emb_dim`` outputs.
    ``latent_dim`` must be divisible by 4 for the GroupNorm layers.
    """

    def __init__(self, in_dim, latent_dim, emb_dim):
        super().__init__()
        blocks = []
        width = in_dim
        for _ in range(2):
            blocks += [
                nn.Linear(in_features=width, out_features=latent_dim),
                nn.GroupNorm(4, latent_dim),
                nn.PReLU(),
            ]
            width = latent_dim
        blocks.append(nn.Linear(in_features=latent_dim, out_features=emb_dim))
        self.linear_out = nn.Sequential(*blocks)

    def forward(self, x):
        return self.linear_out(x)


def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine noise schedule (https://openreview.net/forum?id=-NEXDKk8gZ).

    Returns a float64 tensor of ``timesteps`` betas, clipped to [0, 0.999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64) / timesteps
    f = torch.cos((grid + s) / (1 + s) * math.pi * 0.5) ** 2
    f = f / f[0]
    ratios = f[1:] / f[:-1]
    return (1 - ratios).clamp(0, 0.999)


def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
    """
    sigmoid schedule
    proposed in https://arxiv.org/abs/2212.11972 - Figure 8

    Returns a float64 tensor of ``timesteps`` betas clipped to [0, 0.999].
    ``clamp_min`` is accepted but not used.
    """
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps, dtype=torch.float64) / timesteps
    # Compute the endpoints in float64 like `t`. With the default float32 tensors,
    # alphas_cumprod[0] / [-1] could deviate from exactly 1 / 0 by float32 rounding.
    v_start = torch.tensor(start / tau, dtype=torch.float64).sigmoid()
    v_end = torch.tensor(end / tau, dtype=torch.float64).sigmoid()
    alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (
        v_end - v_start
    )
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)


def ddpm_schedules(beta1, beta2, T):
    """Precompute the forward-diffusion coefficients for T steps.

    NOTE: ``beta1`` and ``beta2`` are currently ignored; betas always come from
    ``cosine_beta_schedule(T, s=0.008)`` (cast to float32).

    Returns:
        dict with "sqrtab" = sqrt(alpha_bar_t) and "sqrtmab" = sqrt(1 - alpha_bar_t), where
        alpha_bar_t is the cumulative product of (1 - beta), computed in log space.
    """
    betas = cosine_beta_schedule(T, s=0.008).float()
    alphas = 1 - betas
    alpha_bar = torch.exp(torch.cumsum(torch.log(alphas), dim=0))
    return {
        "sqrtab": alpha_bar.sqrt(),  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": (1 - alpha_bar).sqrt(),  # \sqrt{1-\bar{\alpha_t}}
    }


class DDPM(nn.Module):
    """Forward-noising wrapper around a denoising network.

    Registers the ``sqrtab`` / ``sqrtmab`` schedule tensors as buffers. Each forward call:
    samples one integer timestep per batch element, noises ``x``, and runs ``nn_model`` on
    the noisy input with the timestep scaled to (0, 1).
    Returns ``(output, down, up, noise, times)``.
    """

    def __init__(self, nn_model, betas, n_T, device):
        super().__init__()
        self.nn_model = nn_model.to(device)

        schedule = ddpm_schedules(betas[0], betas[1], n_T)
        for buffer_name, buffer_value in schedule.items():
            self.register_buffer(buffer_name, buffer_value)

        self.n_T = n_T
        self.device = device

    def forward(self, x):
        batch = x.shape[0]
        # t ~ Uniform{1, ..., n_T - 1}: randint's upper bound is exclusive.
        _ts = torch.randint(1, self.n_T, (batch,)).to(self.device)
        noise = torch.randn_like(x)  # eps ~ N(0, 1)
        signal_scale = self.sqrtab[_ts, None, None]
        noise_scale = self.sqrtmab[_ts, None, None]
        x_t = signal_scale * x + noise_scale * noise
        times = _ts / self.n_T
        output, down, up = self.nn_model(x_t, times)
        return output, down, up, noise, times

In [None]:
# Model hyperparameters.
num_classes = 16   # classifier output size (emb_dim of the LinearClassifier); must cover every label index
channels = 14      # EEG input channels for every model; must equal the channel count of the loaded windows
n_T = 1000         # number of diffusion timesteps in the DDPM schedule
ddpm_dim = 128     # n_feat of the ConditionalUNet; the Decoder consumes its features, so it reuses this value
encoder_dim = 256  # Encoder width / latent size; train_epoch's ProjectionHead hardcodes input_dim=256 to match
fc_dim = 512       # hidden width of the LinearClassifier

In [None]:
# Build the DDPM (denoiser + noising wrapper) and the Diff-E model (encoder, decoder, classifier).
# NOTE: betas=(1e-6, 1e-2) are forwarded to ddpm_schedules, which currently ignores them and
# always uses cosine_beta_schedule.
ddpm_model = ConditionalUNet(in_channels=channels, n_feat=ddpm_dim).to(device)
ddpm = DDPM(nn_model=ddpm_model, betas=(1e-6, 1e-2), n_T=n_T, device=device).to(device)
encoder = Encoder(in_channels=channels, dim=encoder_dim).to(device)
# The decoder's n_feat must match the ConditionalUNet's, since it concatenates the DDPM's
# down-path features; its n_classes argument is left at the default and is unused.
decoder = Decoder(in_channels=channels, n_feat=ddpm_dim, encoder_dim=encoder_dim).to(device)
fc = LinearClassifier(encoder_dim, fc_dim, emb_dim=num_classes).to(device)
diffe = DiffE(encoder, decoder, fc).to(device)

In [None]:
# Report the number of parameters in each component.
for _label, _module in (("ddpm", ddpm), ("encoder", encoder), ("decoder", decoder), ("fc", fc)):
    print(f"{_label} size: ", sum(p.numel() for p in _module.parameters()))
del _label, _module

ddpm size:  238278
encoder size:  137219
decoder size:  135832
fc size:  404498


## Train

In [None]:
import os
import wandb

# Never commit API keys: read it from the environment, or prompt for it interactively.
# NOTE(review): the key previously hardcoded in this cell has been exposed and should be revoked.
_wandb_key = os.environ.get("WANDB_API_KEY")
if not _wandb_key:
    from getpass import getpass
    _wandb_key = getpass("Weights & Biases API key: ")
wandb.login(key=_wandb_key)
del _wandb_key

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# tqdm.auto renders a notebook widget when available. The previous chain
# (`import tqdm`; `from tqdm.auto import tqdm`; `from tqdm import tqdm`) ended up binding
# the plain console tqdm.
from tqdm.auto import tqdm
from ema_pytorch import EMA
from sklearn.metrics import (
    f1_score,
    roc_auc_score,
    precision_score,
    recall_score,
    top_k_accuracy_score,
)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
# Criteria.
# Element-wise L1 with no reduction; the consumer averages it itself
# (presumably passed to train_epoch as criterion_recon, which calls .mean()).
criterion = nn.L1Loss(reduction='none')
# Cross-entropy on the classifier logits.
criterion_class = nn.CrossEntropyLoss()

In [None]:
# Two independent AdamW optimizers: optimizer1 for the DDPM, optimizer2 for Diff-E
# (encoder + decoder + fc).
# NOTE(review): the ProjectionHead used for the SupCon loss in train_epoch is not part of
# `diffe`, so its parameters are not in optimizer2.
lr_ddpm = 1e-4
lr_diffe = 1e-4
weight_decay = 0.01

optimizer1 = optim.AdamW(ddpm.parameters(), lr=lr_ddpm, weight_decay=weight_decay)
optimizer2 = optim.AdamW(diffe.parameters(), lr=lr_diffe, weight_decay=weight_decay)

In [None]:
# Cosine LR decay from the initial lr down to eta_min over T_max=num_epochs scheduler steps.
# Assumes the training loop calls scheduler.step() once per epoch — TODO confirm.
num_epochs = 100
scheduler1 = optim.lr_scheduler.CosineAnnealingLR(optimizer1, T_max=num_epochs, eta_min=1e-7)
scheduler2 = optim.lr_scheduler.CosineAnnealingLR(optimizer2, T_max=num_epochs, eta_min=1e-7)

In [None]:
# Exponential-moving-average copy of the classifier head (diffe.fc); presumably the
# ema_classifier passed to train_epoch / evaluate. See the ema_pytorch docs for the exact
# semantics of update_after_step / update_every.
fc_ema = EMA(diffe.fc, beta=0.95, update_after_step=100, update_every=10)

In [None]:
# Create the checkpoint directory if it doesn't exist. The path is relative to the
# working directory set by %cd in the setup cell.
models_path = "saved_models"
os.makedirs(models_path, exist_ok=True)

In [None]:
# Best validation accuracy seen so far; presumably updated by the training loop to decide
# when to checkpoint — TODO confirm.
best_val_accuracy = 0.0

In [None]:
class SupConLoss(nn.Module):
    """Supervised contrastive loss over a batch of embeddings.

    For each anchor, every other sample with the same label is a positive, and the anchor
    itself is excluded from both numerator and denominator. Anchors without any positive
    contribute 0 to the final mean.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """
        Args:
            features: [B, D] embeddings (e.g. the encoder's z); L2-normalised here along dim 1.
            labels:   [B] integer class labels.

        Returns:
            Scalar loss tensor.
        """
        device = features.device
        unit = F.normalize(features, dim=1)
        n = unit.shape[0]

        label_col = labels.contiguous().view(-1, 1)
        same_label = torch.eq(label_col, label_col.T).float().to(device)  # [B, B]

        logits = torch.matmul(unit, unit.T) / self.temperature  # [B, B]

        # Drop self-pairs: remove them from the positives and push their logits to ~-inf.
        not_self = torch.ones_like(same_label) - torch.eye(n).to(device)
        positives = same_label * not_self
        logits = logits - 1e9 * (1 - not_self)

        log_denominator = torch.log(torch.exp(logits).sum(1, keepdim=True) + 1e-6)
        log_prob = logits - log_denominator

        per_anchor = (positives * log_prob).sum(1) / (positives.sum(1) + 1e-6)
        return -per_anchor.mean()

In [None]:
class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) whose outputs are L2-normalised per row.

    Maps ``input_dim`` latents to ``proj_dim``-dimensional unit vectors for the
    contrastive loss.
    """

    def __init__(self, input_dim=256, proj_dim=128):
        super().__init__()
        layers = [nn.Linear(input_dim, proj_dim), nn.ReLU(), nn.Linear(proj_dim, proj_dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        projected = self.net(z)
        return F.normalize(projected, dim=1)

In [None]:
def train_epoch(ddpm_model, diffe_model, ema_classifier, train_loader,
                optimizer1, optimizer2, scheduler1, scheduler2,
                criterion_recon, criterion_class, epoch, device, num_classes,
                proj_head=None, supcon_loss=None):
    """Run a single training epoch of the DDPM and the Diff-E model.

    Each batch does three things:
      1. One optimizer1 step on the DDPM's mean reconstruction loss ``criterion_recon(x_hat, x)``.
      2. One optimizer2 step on ``alpha * CE + beta * SupCon + gamma * L1(decoder_out, x_hat)``.
         beta ramps linearly to 0.2 over 50 epochs and gamma to 0.05 over 100 epochs.
      3. An EMA update of the classifier.
    The schedulers are only read for logging; stepping them is left to the caller.

    Args:
        ddpm_model, diffe_model: models to train (set to train mode here).
        ema_classifier: EMA wrapper whose ``update()`` is called after every step.
        train_loader: yields ``(x, y)`` batches.
        optimizer1 / optimizer2: optimizers for the DDPM / Diff-E.
        scheduler1 / scheduler2: LR schedulers (only read via ``get_last_lr()``).
        criterion_recon: element-wise reconstruction loss (averaged here with ``.mean()``).
        criterion_class: classification loss on the Diff-E logits.
        epoch: current epoch index; drives the beta/gamma warm-up.
        device: device batches are moved to.
        num_classes: unused; kept for interface compatibility.
        proj_head: projection head applied to the latent z before SupCon. Pass a persistent
            module whose parameters are in ``optimizer2`` so it is actually trained. If None, a
            freshly initialised ProjectionHead is built for this call (the original behaviour).
            It is never stepped by any optimizer, and a warning is emitted.
        supcon_loss: SupCon criterion; defaults to ``SupConLoss(temperature=0.07)``.

    Returns:
        dict of per-batch averages: loss_decoder, loss_supcon, loss_c, loss_total, loss_ddpm.
    """
    ddpm_model.train()
    diffe_model.train()

    alpha = 1
    beta = min(1.0, epoch / 50) * 0.2     # SupCon weight warm-up
    gamma = min(1.0, epoch / 100) * 0.05  # decoder weight warm-up

    if supcon_loss is None:
        supcon_loss = SupConLoss(temperature=0.07)
    if proj_head is None:
        import warnings
        warnings.warn(
            "train_epoch: no proj_head given; using a freshly initialised ProjectionHead that "
            "is not optimised. Pass a persistent head whose parameters are in optimizer2."
        )
        # Size the head from the encoder instead of a hardcoded 256 (kept as the fallback).
        latent_dim = getattr(getattr(diffe_model, "encoder", None), "e3_out", 256)
        proj_head = ProjectionHead(input_dim=latent_dim, proj_dim=128).to(device)
    proj_head.train()

    totals = dict.fromkeys(("loss_decoder", "loss_supcon", "loss_c", "loss_total", "loss_ddpm"), 0.0)

    # Iterate the progress bar itself; the original iterated train_loader, so the bar never advanced.
    progress_bar = tqdm(train_loader, desc="Training", leave=True)
    for x, y in progress_bar:
        x = x.to(device)
        y = y.to(device=device, dtype=torch.long)

        # --- Train DDPM ---
        optimizer1.zero_grad()
        x_hat, down, up, noise, t = ddpm_model(x)
        loss_ddpm = criterion_recon(x_hat, x).mean()  # criterion_recon has reduction='none'
        loss_ddpm.backward()
        optimizer1.step()
        ddpm_out = x_hat, down, up, t

        # --- Train Diff-E ---
        optimizer2.zero_grad()
        decoder_out, fc_out, z = diffe_model(x, ddpm_out)
        loss_decoder = F.l1_loss(decoder_out, x_hat.detach())
        loss_c = criterion_class(fc_out, y)
        loss_supcon = supcon_loss(proj_head(z), y)
        loss = alpha * loss_c + beta * loss_supcon + gamma * loss_decoder
        loss.backward()
        optimizer2.step()

        # --- Update EMA ---
        ema_classifier.update()

        batch_losses = {
            "loss_decoder": loss_decoder.item(),
            "loss_supcon": loss_supcon.item(),
            "loss_c": loss_c.item(),
            "loss_total": loss.item(),
            "loss_ddpm": loss_ddpm.item(),
        }

        # --- Logging (Wandb) ---
        wandb.log({
            "train/batch_loss_decoder": batch_losses["loss_decoder"],
            "train/batch_loss_supcon": batch_losses["loss_supcon"],
            "train/batch_loss_c": batch_losses["loss_c"],
            "train/batch_loss_total": batch_losses["loss_total"],
            "train/batch_loss_ddpm": batch_losses["loss_ddpm"],
            "train/learning_rate1": scheduler1.get_last_lr()[0],
            "train/learning_rate2": scheduler2.get_last_lr()[0]
        })

        for key, value in batch_losses.items():
            totals[key] += value
        progress_bar.set_postfix(loss=f"{batch_losses['loss_total']:.4f}")

    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    num_batches = max(len(train_loader), 1)
    return {key: value / num_batches for key, value in totals.items()}

In [None]:
def evaluate(diffe_model, ddpm_model, ema_classifier, dataloader,
             criterion_recon, criterion_class, epoch, device, num_classes,
             proj_head=None, supcon_loss=None):
    """Evaluate DDPM + Diff-E on ``dataloader`` and return ``(metrics, avg_losses)``.

    All three models are switched to eval mode (and left there on return); the
    forward passes run under ``torch.no_grad()``.

    Args:
        diffe_model: Diff-E model, called as ``diffe_model(x, ddpm_out)`` returning
            ``(decoder_out, fc_out, z)``; its ``encoder(x)[1]`` output feeds
            ``ema_classifier``.
        ddpm_model: DDPM, called as ``ddpm_model(x)`` returning
            ``(x_hat, down, up, noise, t)``.
        ema_classifier: EMA classifier head; its softmax outputs are used for
            every reported metric.
        dataloader: iterable of ``(x, y)`` batches.
        criterion_recon: unused; kept so the signature matches existing callers.
        criterion_class: classification loss applied to ``(fc_out, y)``.
        epoch: current epoch; only drives the warm-up of the SupCon weight
            (``beta``, over 50 epochs) and decoder weight (``gamma``, over 100).
        device: device the batches are moved to.
        num_classes: unused; kept for signature compatibility.
        proj_head: projection head applied to ``z`` before the SupCon loss. Pass
            the head used during training so the validation SupCon loss is
            comparable to the training one. When ``None`` a freshly initialised
            (untrained) ``ProjectionHead(256, 128)`` is built, as before.
        supcon_loss: SupCon criterion; defaults to ``SupConLoss(temperature=0.07)``.

    Returns:
        (metrics, avg_losses) where ``metrics`` has ``accuracy``, ``f1``,
        ``recall``, ``precision`` (macro) and ``auc`` (macro one-vs-one; NaN when
        undefined for this split), and ``avg_losses`` has ``loss_decoder``,
        ``loss_supcon``, ``loss_c`` and ``loss_total`` averaged per batch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    diffe_model.eval()
    ddpm_model.eval()
    ema_classifier.eval()

    # Loss weights: classification always on; SupCon and decoder terms ramp
    # linearly to their final weights.
    alpha = 1
    beta = min(1.0, epoch / 50) * 0.2
    gamma = min(1.0, epoch / 100) * 0.05

    if supcon_loss is None:
        supcon_loss = SupConLoss(temperature=0.07)

    # Only a caller-provided head has a mode worth restoring afterwards.
    head_was_training = None
    if proj_head is None:
        # Fallback (previous behaviour): an untrained head, so the SupCon term
        # is not comparable to the value seen during training.
        proj_head = ProjectionHead(input_dim=256, proj_dim=128).to(device)
    else:
        head_was_training = proj_head.training
        proj_head.eval()

    all_labels = []
    all_preds = []
    totals = {"loss_decoder": 0.0, "loss_supcon": 0.0, "loss_c": 0.0, "loss_total": 0.0}
    num_batches = 0

    try:
        with torch.no_grad():
            for x, y in dataloader:
                x, y = x.to(device), y.type(torch.LongTensor).to(device)

                # DDPM forward pass; `noise` is not needed for evaluation.
                x_hat, down, up, noise, t = ddpm_model(x)
                ddpm_out = x_hat, down, up, t

                # Diff-E forward pass
                decoder_out, fc_out, z = diffe_model(x, ddpm_out)

                # Predictions for metrics come from the EMA classifier applied to
                # the encoder's second output (presumably the pooled embedding —
                # TODO confirm against the encoder implementation).
                y_hat = ema_classifier(diffe_model.encoder(x)[1])

                loss_decoder = F.l1_loss(decoder_out, x_hat)
                loss_c = criterion_class(fc_out, y)
                loss_supcon = supcon_loss(proj_head(z), y)
                loss = alpha * loss_c + beta * loss_supcon + gamma * loss_decoder

                all_labels.append(y.cpu())
                all_preds.append(F.softmax(y_hat, dim=1).cpu())

                totals["loss_decoder"] += loss_decoder.item()
                totals["loss_supcon"] += loss_supcon.item()
                totals["loss_c"] += loss_c.item()
                totals["loss_total"] += loss.item()
                # Count batches actually seen (works for iterable datasets too).
                num_batches += 1
    finally:
        if head_was_training is not None:
            proj_head.train(head_was_training)

    if num_batches == 0:
        raise ValueError("evaluate() received a dataloader that yielded no batches")

    avg_losses = {name: total / num_batches for name, total in totals.items()}

    # Convert predictions and labels for metric calculation
    all_labels = torch.cat(all_labels, dim=0).numpy()
    all_preds = torch.cat(all_preds, dim=0).numpy()
    pred_classes = all_preds.argmax(axis=1)

    accuracy = (pred_classes == all_labels).mean()
    f1 = f1_score(all_labels, pred_classes, average="macro", zero_division=0)
    recall = recall_score(all_labels, pred_classes, average="macro", zero_division=0)
    precision = precision_score(all_labels, pred_classes, average="macro", zero_division=0)

    # ROC-AUC can be undefined for a split (e.g. a class is absent). Only AUC
    # falls back in that case; the other metrics remain valid and are kept.
    # sklearn expects 1-D positive-class scores in the binary case.
    auc_scores = all_preds[:, 1] if all_preds.shape[1] == 2 else all_preds
    try:
        auc = roc_auc_score(all_labels, auc_scores, average="macro", multi_class="ovo")
    except ValueError as e:
        print(f"Warning in metric calculation: {e}")
        auc = float("nan")

    metrics = {
        "accuracy": accuracy,
        "f1": f1,
        "recall": recall,
        "precision": precision,
        "auc": auc,
    }

    return metrics, avg_losses

In [None]:
# Run configuration recorded with the wandb run. The values come from settings
# defined elsewhere in the notebook; this cell only collects them.
config = dict(
    # Model parameters
    model_type="DiffE-EEG",
    num_classes=num_classes,
    channels=channels,
    n_T=n_T,
    ddpm_dim=ddpm_dim,
    encoder_dim=encoder_dim,
    fc_dim=fc_dim,
    # Training parameters
    epochs=num_epochs,
    train_batch_size=batch_size_train,
    test_batch_size=batch_size_test,
    seed=seed,
    # Optimizer parameters
    optimizer="AdamW",
    learning_rate_ddpm=lr_ddpm,
    learning_rate_diffe=lr_diffe,
    weight_decay=weight_decay,
    # Scheduler parameters
    scheduler="CosineAnnealingLR",
    scheduler_T_max=num_epochs,
    scheduler_eta_min=1e-7,
    # EMA parameters (literal values; keep in sync with the EMA construction)
    ema_beta=0.95,
    ema_update_after_step=100,
    ema_update_every=10,
)

# Start (or restart, when the cell is re-run in the same kernel) the wandb run.
run = wandb.init(
    project="DiffE-EEG",
    name="diffE-FEIS-thinking",
    config=config,
    reinit=True,  # allow wandb.init to be called again within this notebook session
)

In [None]:
print("Starting training...")
epoch_progress = tqdm(range(num_epochs), desc="Epochs", position=0)

best_val_accuracy = 0.0
# Defined up front so later cells can detect "no checkpoint saved" instead of
# failing with a NameError when no epoch ever beats the initial accuracy.
best_model_epoch = None
os.makedirs(models_path, exist_ok=True)

for epoch in epoch_progress:
    # --- Training ---
    epoch_progress.set_description(f"Training Epoch {epoch+1}/{num_epochs}")
    train_losses = train_epoch(
        ddpm, diffe, fc_ema, train_loader,
        optimizer1, optimizer2, scheduler1, scheduler2,
        criterion, criterion_class, epoch, device, num_classes
    )

    # Step schedulers after full epoch
    scheduler1.step()
    scheduler2.step()

    # --- Evaluation ---
    epoch_progress.set_description(f"Evaluating Epoch {epoch+1}/{num_epochs}")
    val_metrics, val_losses = evaluate(diffe, ddpm, fc_ema, validation_loader, criterion, criterion_class, epoch, device, num_classes)

    # Update progress bar with metrics
    epoch_progress.set_postfix({
        'Train Loss': f"{train_losses['loss_total']:.4f}",
        'Val Loss': f"{val_losses['loss_total']:.4f}",
        'Val Acc': f"{val_metrics['accuracy']:.4f}",
        'Val F1': f"{val_metrics['f1']:.4f}"
    })

    # --- Best-model tracking ---
    # Done before logging so the running best goes into the same wandb step as
    # the epoch metrics (every wandb.log call advances the step counter).
    current_accuracy = val_metrics['accuracy']
    is_new_best = current_accuracy > best_val_accuracy
    if is_new_best:
        best_val_accuracy = current_accuracy
        best_model_epoch = epoch + 1

    # --- Wandb Logging (Epoch Level) ---
    wandb.log({
        "epoch": epoch + 1,
        # Training losses
        "train/epoch_loss_decoder": train_losses['loss_decoder'],
        "train/epoch_loss_supcon": train_losses['loss_supcon'],
        "train/epoch_loss_c": train_losses['loss_c'],
        "train/epoch_loss_total": train_losses['loss_total'],
        # Validation losses
        "val/loss_decoder": val_losses['loss_decoder'],
        "val/loss_supcon": val_losses['loss_supcon'],
        "val/loss_c": val_losses['loss_c'],
        "val/loss_total": val_losses['loss_total'],
        # Validation metrics
        "val/accuracy": val_metrics['accuracy'],
        "val/f1_score": val_metrics['f1'],
        "val/recall": val_metrics['recall'],
        "val/precision": val_metrics['precision'],
        "val/auc": val_metrics['auc'],
        "val/best_accuracy": best_val_accuracy,
    })

    # --- Checkpoint Saving ---
    if is_new_best:
        epoch_progress.write(f"*** New best validation accuracy: {best_val_accuracy:.4f} at epoch {epoch+1} ***")

        torch.save(diffe.state_dict(), os.path.join(models_path, f"best_diffe_model_epoch{epoch+1}.pt"))
        torch.save(ddpm.state_dict(), os.path.join(models_path, f"best_ddpm_model_epoch{epoch+1}.pt"))
        # evaluate() computes its predictions with the EMA classifier, which is not
        # part of diffe.state_dict(); save it too so reloading the best epoch does
        # not pair the best encoder with final-epoch EMA weights.
        torch.save(fc_ema.state_dict(), os.path.join(models_path, f"best_fc_ema_model_epoch{epoch+1}.pt"))
        epoch_progress.write(f"Models saved to {models_path}")

# --- End of Training ---
print("Training finished.")
print(f"Best validation accuracy achieved: {best_val_accuracy:.4f}")

Starting training...


Epochs:   0%|          | 0/100 [00:00<?, ?it/s]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 1/100:   1%|          | 1/100 [00:16<26:28, 16.04s/it, Train Loss=2.7896, Val Loss=2.7753, Val Acc=0.0775, Val F1=0.0726]

*** New best validation accuracy: 0.0775 at epoch 1 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:22<?, ?it/s]
Evaluating Epoch 2/100:   2%|▏         | 2/100 [00:39<32:59, 20.20s/it, Train Loss=2.7787, Val Loss=2.7874, Val Acc=0.0735, Val F1=0.0728]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 3/100:   3%|▎         | 3/100 [00:54<28:56, 17.91s/it, Train Loss=2.7808, Val Loss=2.8149, Val Acc=0.0784, Val F1=0.0743]

*** New best validation accuracy: 0.0784 at epoch 3 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 4/100:   4%|▍         | 4/100 [01:09<26:58, 16.86s/it, Train Loss=2.7784, Val Loss=2.8226, Val Acc=0.0810, Val F1=0.0801]

*** New best validation accuracy: 0.0810 at epoch 4 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 5/100:   5%|▌         | 5/100 [01:24<25:43, 16.24s/it, Train Loss=2.7728, Val Loss=2.8513, Val Acc=0.0843, Val F1=0.0847]

*** New best validation accuracy: 0.0843 at epoch 5 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 6/100:   6%|▌         | 6/100 [01:39<24:55, 15.91s/it, Train Loss=2.7636, Val Loss=2.8577, Val Acc=0.0947, Val F1=0.0937]

*** New best validation accuracy: 0.0947 at epoch 6 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 7/100:   7%|▋         | 7/100 [01:55<24:23, 15.74s/it, Train Loss=2.7437, Val Loss=2.8666, Val Acc=0.1062, Val F1=0.1062]

*** New best validation accuracy: 0.1062 at epoch 7 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:15<?, ?it/s]
Evaluating Epoch 8/100:   8%|▊         | 8/100 [02:11<24:30, 15.98s/it, Train Loss=2.7059, Val Loss=2.8793, Val Acc=0.1135, Val F1=0.1128]

*** New best validation accuracy: 0.1135 at epoch 8 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 9/100:   9%|▉         | 9/100 [02:27<23:55, 15.78s/it, Train Loss=2.6664, Val Loss=2.9068, Val Acc=0.1246, Val F1=0.1251]

*** New best validation accuracy: 0.1246 at epoch 9 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:15<?, ?it/s]
Evaluating Epoch 10/100:  10%|█         | 10/100 [02:43<23:55, 15.95s/it, Train Loss=2.6055, Val Loss=2.9081, Val Acc=0.1336, Val F1=0.1328]

*** New best validation accuracy: 0.1336 at epoch 10 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 11/100:  11%|█         | 11/100 [02:58<23:22, 15.76s/it, Train Loss=2.5333, Val Loss=2.9238, Val Acc=0.1390, Val F1=0.1387]

*** New best validation accuracy: 0.1390 at epoch 11 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 12/100:  12%|█▏        | 12/100 [03:14<22:54, 15.62s/it, Train Loss=2.4699, Val Loss=2.9169, Val Acc=0.1552, Val F1=0.1548]

*** New best validation accuracy: 0.1552 at epoch 12 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 13/100:  13%|█▎        | 13/100 [03:29<22:28, 15.50s/it, Train Loss=2.3830, Val Loss=2.9392, Val Acc=0.1625, Val F1=0.1623]

*** New best validation accuracy: 0.1625 at epoch 13 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 14/100:  14%|█▍        | 14/100 [03:44<22:05, 15.41s/it, Train Loss=2.2906, Val Loss=2.9770, Val Acc=0.1845, Val F1=0.1840]

*** New best validation accuracy: 0.1845 at epoch 14 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 15/100:  15%|█▌        | 15/100 [04:01<22:32, 15.91s/it, Train Loss=2.2119, Val Loss=2.9717, Val Acc=0.1976, Val F1=0.1970]

*** New best validation accuracy: 0.1976 at epoch 15 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 16/100:  16%|█▌        | 16/100 [04:17<22:09, 15.82s/it, Train Loss=2.1028, Val Loss=2.9733, Val Acc=0.2130, Val F1=0.2126]

*** New best validation accuracy: 0.2130 at epoch 16 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 17/100:  17%|█▋        | 17/100 [04:32<21:38, 15.64s/it, Train Loss=2.0089, Val Loss=2.9906, Val Acc=0.2233, Val F1=0.2237]

*** New best validation accuracy: 0.2233 at epoch 17 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 18/100:  18%|█▊        | 18/100 [04:47<21:13, 15.53s/it, Train Loss=1.9052, Val Loss=3.0251, Val Acc=0.2410, Val F1=0.2408]

*** New best validation accuracy: 0.2410 at epoch 18 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 19/100:  19%|█▉        | 19/100 [05:03<20:50, 15.43s/it, Train Loss=1.8218, Val Loss=3.0501, Val Acc=0.2563, Val F1=0.2563]

*** New best validation accuracy: 0.2563 at epoch 19 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 20/100:  20%|██        | 20/100 [05:18<20:29, 15.37s/it, Train Loss=1.7146, Val Loss=3.0552, Val Acc=0.2598, Val F1=0.2601]

*** New best validation accuracy: 0.2598 at epoch 20 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 21/100:  21%|██        | 21/100 [05:33<20:10, 15.32s/it, Train Loss=1.6142, Val Loss=3.1106, Val Acc=0.2730, Val F1=0.2735]

*** New best validation accuracy: 0.2730 at epoch 21 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 22/100:  22%|██▏       | 22/100 [05:48<19:52, 15.29s/it, Train Loss=1.5315, Val Loss=3.1057, Val Acc=0.2956, Val F1=0.2955]

*** New best validation accuracy: 0.2956 at epoch 22 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 23/100:  23%|██▎       | 23/100 [06:04<19:47, 15.42s/it, Train Loss=1.4537, Val Loss=3.1165, Val Acc=0.3093, Val F1=0.3089]

*** New best validation accuracy: 0.3093 at epoch 23 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 24/100:  24%|██▍       | 24/100 [06:19<19:35, 15.47s/it, Train Loss=1.3678, Val Loss=3.1939, Val Acc=0.3150, Val F1=0.3148]

*** New best validation accuracy: 0.3150 at epoch 24 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 25/100:  25%|██▌       | 25/100 [06:35<19:16, 15.42s/it, Train Loss=1.3013, Val Loss=3.2896, Val Acc=0.3296, Val F1=0.3294]

*** New best validation accuracy: 0.3296 at epoch 25 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 26/100:  26%|██▌       | 26/100 [06:50<18:56, 15.36s/it, Train Loss=1.2197, Val Loss=3.3003, Val Acc=0.3383, Val F1=0.3378]

*** New best validation accuracy: 0.3383 at epoch 26 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 27/100:  27%|██▋       | 27/100 [07:05<18:36, 15.30s/it, Train Loss=1.1651, Val Loss=3.2999, Val Acc=0.3489, Val F1=0.3486]

*** New best validation accuracy: 0.3489 at epoch 27 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 28/100:  28%|██▊       | 28/100 [07:20<18:18, 15.26s/it, Train Loss=1.1021, Val Loss=3.2950, Val Acc=0.3656, Val F1=0.3649]

*** New best validation accuracy: 0.3656 at epoch 28 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 29/100:  29%|██▉       | 29/100 [07:36<18:02, 15.25s/it, Train Loss=1.0413, Val Loss=3.3389, Val Acc=0.3708, Val F1=0.3704]

*** New best validation accuracy: 0.3708 at epoch 29 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 30/100:  30%|███       | 30/100 [07:51<17:45, 15.22s/it, Train Loss=0.9867, Val Loss=3.4086, Val Acc=0.3804, Val F1=0.3803]

*** New best validation accuracy: 0.3804 at epoch 30 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 31/100:  31%|███       | 31/100 [08:06<17:36, 15.31s/it, Train Loss=0.9598, Val Loss=3.4663, Val Acc=0.3795, Val F1=0.3795]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 32/100:  32%|███▏      | 32/100 [08:22<17:25, 15.38s/it, Train Loss=0.9343, Val Loss=3.4506, Val Acc=0.3896, Val F1=0.3894]

*** New best validation accuracy: 0.3896 at epoch 32 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 33/100:  33%|███▎      | 33/100 [08:37<17:10, 15.39s/it, Train Loss=0.8987, Val Loss=3.5214, Val Acc=0.3967, Val F1=0.3964]

*** New best validation accuracy: 0.3967 at epoch 33 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 34/100:  34%|███▍      | 34/100 [08:52<16:53, 15.36s/it, Train Loss=0.8569, Val Loss=3.5180, Val Acc=0.4052, Val F1=0.4052]

*** New best validation accuracy: 0.4052 at epoch 34 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 35/100:  35%|███▌      | 35/100 [09:08<16:33, 15.28s/it, Train Loss=0.8332, Val Loss=3.5557, Val Acc=0.4148, Val F1=0.4147]

*** New best validation accuracy: 0.4148 at epoch 35 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 36/100:  36%|███▌      | 36/100 [09:23<16:14, 15.22s/it, Train Loss=0.8297, Val Loss=3.6261, Val Acc=0.4115, Val F1=0.4113]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 37/100:  37%|███▋      | 37/100 [09:38<15:56, 15.18s/it, Train Loss=0.7971, Val Loss=3.6199, Val Acc=0.4212, Val F1=0.4211]

*** New best validation accuracy: 0.4212 at epoch 37 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 38/100:  38%|███▊      | 38/100 [09:53<15:38, 15.14s/it, Train Loss=0.7754, Val Loss=3.7102, Val Acc=0.4210, Val F1=0.4210]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 39/100:  39%|███▉      | 39/100 [10:08<15:24, 15.15s/it, Train Loss=0.7742, Val Loss=3.7755, Val Acc=0.4238, Val F1=0.4236]

*** New best validation accuracy: 0.4238 at epoch 39 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 40/100:  40%|████      | 40/100 [10:23<15:13, 15.23s/it, Train Loss=0.7597, Val Loss=3.7862, Val Acc=0.4231, Val F1=0.4225]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 41/100:  41%|████      | 41/100 [10:39<15:04, 15.34s/it, Train Loss=0.7716, Val Loss=3.8323, Val Acc=0.4290, Val F1=0.4290]

*** New best validation accuracy: 0.4290 at epoch 41 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 42/100:  41%|████      | 41/100 [10:54<15:04, 15.34s/it, Train Loss=0.7371, Val Loss=3.8443, Val Acc=0.4342, Val F1=0.4339]

*** New best validation accuracy: 0.4342 at epoch 42 ***


Evaluating Epoch 42/100:  42%|████▏     | 42/100 [10:56<15:12, 15.73s/it, Train Loss=0.7371, Val Loss=3.8443, Val Acc=0.4342, Val F1=0.4339]

Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 43/100:  43%|████▎     | 43/100 [11:11<14:46, 15.54s/it, Train Loss=0.7600, Val Loss=3.8862, Val Acc=0.4342, Val F1=0.4338]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 44/100:  44%|████▍     | 44/100 [11:26<14:23, 15.41s/it, Train Loss=0.7536, Val Loss=3.8780, Val Acc=0.4424, Val F1=0.4425]

*** New best validation accuracy: 0.4424 at epoch 44 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 45/100:  45%|████▌     | 45/100 [11:41<14:02, 15.31s/it, Train Loss=0.7180, Val Loss=3.9635, Val Acc=0.4403, Val F1=0.4403]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 46/100:  46%|████▌     | 46/100 [11:56<13:40, 15.20s/it, Train Loss=0.7420, Val Loss=4.0152, Val Acc=0.4419, Val F1=0.4419]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 47/100:  47%|████▋     | 47/100 [12:11<13:22, 15.14s/it, Train Loss=0.7553, Val Loss=3.9938, Val Acc=0.4405, Val F1=0.4402]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 48/100:  48%|████▊     | 48/100 [12:26<13:08, 15.17s/it, Train Loss=0.7317, Val Loss=4.0593, Val Acc=0.4396, Val F1=0.4395]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 49/100:  49%|████▉     | 49/100 [12:41<12:57, 15.25s/it, Train Loss=0.7320, Val Loss=4.1089, Val Acc=0.4415, Val F1=0.4413]
Training:   0%|          | 0/620 [00:14<?, ?

*** New best validation accuracy: 0.4457 at epoch 50 ***


Evaluating Epoch 50/100:  50%|█████     | 50/100 [12:58<13:05, 15.72s/it, Train Loss=0.7580, Val Loss=4.1342, Val Acc=0.4457, Val F1=0.4458]

Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 51/100:  50%|█████     | 50/100 [13:13<13:05, 15.72s/it, Train Loss=0.7425, Val Loss=4.1606, Val Acc=0.4464, Val F1=0.4467]

*** New best validation accuracy: 0.4464 at epoch 51 ***


Evaluating Epoch 51/100:  51%|█████     | 51/100 [13:15<13:05, 16.04s/it, Train Loss=0.7425, Val Loss=4.1606, Val Acc=0.4464, Val F1=0.4467]

Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 52/100:  52%|█████▏    | 52/100 [13:30<12:35, 15.74s/it, Train Loss=0.7537, Val Loss=4.2074, Val Acc=0.4365, Val F1=0.4363]
Training:   0%|          | 0/620 [00:13<?, ?it/s]
Evaluating Epoch 53/100:  53%|█████▎    | 53/100 [13:45<12:08, 15.50s/it, Train Loss=0.7387, Val Loss=4.2177, Val Acc=0.4412, Val F1=0.4413]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 54/100:  54%|█████▍    | 54/100 [14:00<11:45, 15.35s/it, Train Loss=0.7267, Val Loss=4.2022, Val Acc=0.4464, Val F1=0.4463]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 55/100:  55%|█████▌    | 55/100 [14:15<11:26, 15.25s/it, Train Loss=0.7265, Val Loss=4.2364, Val Acc=0.4452, Val F1=0.4455]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 56/100:  56%|█████▌    | 56/100 [14:31<11:14, 15.34s/it, Train Loss=0.7155, Val Loss=4.2696, Val Acc=0.4424, Val F1=0.4424]
Training:   0%|          | 0/620 [00:14<?, ?

*** New best validation accuracy: 0.4478 at epoch 57 ***


Evaluating Epoch 57/100:  57%|█████▋    | 57/100 [14:48<11:19, 15.80s/it, Train Loss=0.7145, Val Loss=4.2446, Val Acc=0.4478, Val F1=0.4477]

Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 58/100:  57%|█████▋    | 57/100 [15:02<11:19, 15.80s/it, Train Loss=0.6979, Val Loss=4.2864, Val Acc=0.4485, Val F1=0.4483]

*** New best validation accuracy: 0.4485 at epoch 58 ***


Evaluating Epoch 58/100:  58%|█████▊    | 58/100 [15:04<11:09, 15.94s/it, Train Loss=0.6979, Val Loss=4.2864, Val Acc=0.4485, Val F1=0.4483]

Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 59/100:  59%|█████▉    | 59/100 [15:19<10:42, 15.67s/it, Train Loss=0.6968, Val Loss=4.2945, Val Acc=0.4514, Val F1=0.4514]

*** New best validation accuracy: 0.4514 at epoch 59 ***
Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 60/100:  59%|█████▉    | 59/100 [15:34<10:42, 15.67s/it, Train Loss=0.6898, Val Loss=4.3016, Val Acc=0.4535, Val F1=0.4534]

*** New best validation accuracy: 0.4535 at epoch 60 ***


Evaluating Epoch 60/100:  60%|██████    | 60/100 [15:36<10:40, 16.02s/it, Train Loss=0.6898, Val Loss=4.3016, Val Acc=0.4535, Val F1=0.4534]

Models saved to saved_models



Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 61/100:  61%|██████    | 61/100 [15:51<10:12, 15.71s/it, Train Loss=0.6968, Val Loss=4.3278, Val Acc=0.4481, Val F1=0.4484]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 62/100:  62%|██████▏   | 62/100 [16:06<09:56, 15.68s/it, Train Loss=0.6838, Val Loss=4.3575, Val Acc=0.4488, Val F1=0.4485]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 63/100:  63%|██████▎   | 63/100 [16:22<09:40, 15.69s/it, Train Loss=0.6920, Val Loss=4.3547, Val Acc=0.4483, Val F1=0.4484]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 64/100:  64%|██████▍   | 64/100 [16:37<09:20, 15.56s/it, Train Loss=0.6955, Val Loss=4.3562, Val Acc=0.4504, Val F1=0.4504]
Training:   0%|          | 0/620 [00:14<?, ?it/s]
Evaluating Epoch 65/100:  65%|██████▌   | 65/100 [16:53<09:02, 15.51s/it, Train Loss=0.6792, Val Loss=4.3891, Val Acc=0.4518, Val F1=0.4517]
Training:   0%|          | 0/620 [00:14<?, ?

Training finished.
Best validation accuracy achieved: 0.4535





### Visualize EEG Signals

In [None]:
import matplotlib.pyplot as plt

def visualize_signals_grid(original, masked, reconstructed, sample_idxs, channel=0):
    """
    Build a (len(sample_idxs) x 3) grid of line plots for one EEG channel.

    Each row is one sample; the columns are, left to right:
      "Signal Before" (``original``), "Signal Masked" (``masked``) and
      "Signal After" (``reconstructed``).

    All three inputs are tensors indexed as ``[sample, channel]`` and converted
    with ``.detach().cpu().numpy()`` before plotting. Returns the figure; the
    caller is responsible for closing it.
    """
    panels = (
        (original, 'blue', "Signal Before"),
        (masked, 'orange', "Signal Masked"),
        (reconstructed, 'green', "Signal After"),
    )
    n_rows = len(sample_idxs)
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 4 * n_rows), squeeze=False)

    for row, idx in enumerate(sample_idxs):
        for col, (signals, color, label) in enumerate(panels):
            ax = axes[row, col]
            ax.plot(signals[idx, channel].detach().cpu().numpy(), color=color)
            ax.set_title(f"Sample {idx} - {label}")
            ax.set_xlabel("Time")
            ax.set_ylabel("Amplitude")

    fig.tight_layout()
    return fig

In [None]:
# Reload the best checkpoint written by the training loop, if one exists.
if best_model_epoch is None:
    print("No best checkpoint was recorded; keeping the current in-memory weights.")
else:
    best_diffe_path = os.path.join(models_path, f"best_diffe_model_epoch{best_model_epoch}.pt")
    best_ddpm_path = os.path.join(models_path, f"best_ddpm_model_epoch{best_model_epoch}.pt")
    best_fc_ema_path = os.path.join(models_path, f"best_fc_ema_model_epoch{best_model_epoch}.pt")

    if os.path.exists(best_diffe_path) and os.path.exists(best_ddpm_path):
        # map_location lets a checkpoint written on another device load here;
        # weights_only avoids unpickling arbitrary objects from the files.
        diffe.load_state_dict(torch.load(best_diffe_path, map_location=device, weights_only=True))
        ddpm.load_state_dict(torch.load(best_ddpm_path, map_location=device, weights_only=True))
        print(f"Loaded best models from epoch {best_model_epoch}")

        # evaluate() predicts with the EMA classifier, so restore the matching
        # EMA weights when they were saved; otherwise say so explicitly.
        if os.path.exists(best_fc_ema_path):
            fc_ema.load_state_dict(torch.load(best_fc_ema_path, map_location=device, weights_only=True))
            print(f"Loaded EMA classifier from epoch {best_model_epoch}")
        else:
            print("Warning: no EMA classifier checkpoint for that epoch; fc_ema keeps its final-epoch weights.")
    else:
        print(f"Warning: best checkpoint files for epoch {best_model_epoch} not found in {models_path}")

Loaded best models from epoch 60


In [None]:
# Log DDPM reconstructions for a few samples/channels of the first seen-subject
# test batch, then close the wandb run.
# NOTE(review): the middle panel receives `noise` from the DDPM forward pass,
# which is presumably the sampled diffusion noise rather than a masked copy of
# the input — confirm against the DDPM implementation before trusting the
# "Signal Masked" panel title.
with torch.no_grad():
    first_batch = next(iter(seen_test_loader), None)  # only one batch is needed
    if first_batch is not None:
        x, _ = first_batch
        x = x.to(device)
        x_hat, down, up, noise, t = ddpm(x)

        n_show = min(3, x.size(0))
        n_channels = min(3, x.size(1))
        shown_idxs = list(range(n_show))  # same sample rows for every channel

        for channel_idx in range(n_channels):
            fig = visualize_signals_grid(
                x[:n_show],
                noise[:n_show],
                x_hat[:n_show],
                shown_idxs,
                channel=channel_idx,
            )
            wandb.log({f"Final_EEG_Signals/Channel_{channel_idx}": wandb.Image(fig)})
            plt.close(fig)  # free figure memory inside the loop

# --- Finish Wandb Run ---
wandb.finish()

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
train/batch_loss_c,████▇▇▇▅▅▅▃▃▄▄▃▃▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/batch_loss_decoder,█▅▄▇▅▇▃▄▄▃▄▄▃▄▂▃▃▅▃▃▄▂▃▄▂▄▂▂▄▄▃▂▁▃▂▃▂▂▂▄
train/batch_loss_supcon,▇█▆▆▄▂▁▅▂▄▃▄▃▅▅▃▄▅▄▅▅▃▅▄▅▃▆▄▄▄▅▄▆▄▅▃▅▃▁▃
train/batch_loss_total,█▇████▇▇█▅▃▂▃▁▁▂▁▁▂▁▁▂▁▁▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁
train/epoch_loss_c,█████▆▆▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/epoch_loss_decoder,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/epoch_loss_supcon,█▇▄▄▄▃▃▅▃▃▄▃▃▃▂▂▂▃▂▂▂▁▂▂▂▄▂▂▃▂▂▂▁▂▂▁▂▂▃▂
train/epoch_loss_total,█████▇▇▇▆▆▅▄▃▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate1,██████▇▇▇▇▇▇▇▇▆▆▆▆▅▅▄▄▄▄▄▄▄▄▃▃▃▃▂▂▂▁▁▁▁▁

0,1
epoch,100.0
train/batch_loss_c,0.00417
train/batch_loss_decoder,0.09588
train/batch_loss_supcon,0.0
train/batch_loss_total,0.00892
train/epoch_loss_c,0.01044
train/epoch_loss_decoder,0.10224
train/epoch_loss_supcon,3.08118
train/epoch_loss_total,0.63174
train/learning_rate1,0.0


In [None]:
# Final test-set evaluation with the currently loaded weights.
# `epoch` only sets evaluate()'s loss weighting (the beta/gamma warm-up). Use the
# last training epoch explicitly rather than the loop variable leaked from the
# training cell, which does not exist on a fresh kernel.
final_epoch = num_epochs - 1

# Evaluate on seen subjects
seen_metrics, seen_losses = evaluate(diffe, ddpm, fc_ema, seen_test_loader,
                                    criterion, criterion_class, final_epoch, device, num_classes)
print(f"Seen subjects test results: Accuracy={seen_metrics['accuracy']:.4f}, F1={seen_metrics['f1']:.4f}")

# Evaluate on unseen subjects
unseen_metrics, unseen_losses = evaluate(diffe, ddpm, fc_ema, unseen_test_loader,
                                        criterion, criterion_class, final_epoch, device, num_classes)
print(f"Unseen subjects test results: Accuracy={unseen_metrics['accuracy']:.4f}, F1={unseen_metrics['f1']:.4f}")

Seen subjects test results: Accuracy=0.4385, F1=0.4385
Unseen subjects test results: Accuracy=0.0638, F1=0.0628
