In [1]:
import os

os.environ["CUDA_VISIBLE_DEVICES"]="2"

In [None]:
# ============================================================
# EEG Foundation Challenge 2025 - Challenge 1
# SuS pretraining ‚Üí CCD RT regression
# (uses only cached 100 Hz .npy files)
# ============================================================

import os, random, numpy as np, pandas as pd, warnings, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from glob import glob
import torch.nn.functional as F
import mne
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
mne.set_log_level("ERROR")

# ---------------- CONFIG ----------------
CACHE_DIR         = "/data5/open_data/HBN/cache_eeg_100hz_noref"
BIDS_ROOT         = "/data5/open_data/HBN/EEG_BIDS/"
TARGET_SFREQ      = 100
WIN_S_SUS, WIN_S_CCD = 2.0, 2.0
STRIDE_S_SUS      = 1.0
BATCH_SIZE, NUM_WORKERS = 64, 4
EPOCHS_SUS, EPOCHS_CCD, LR_SUS, LR_CCD = 5, 10, 1e-3, 1e-3
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
random.seed(42); np.random.seed(42); torch.manual_seed(42); torch.cuda.manual_seed_all(42)

# ============================================================
# 1) File collectors (cache only)
# ============================================================
def collect_cached_files(cache_dir, task_name=None):
    """Î™®Îì† *_cached.npy ÌååÏùº ÏàòÏßë (task Ïù¥Î¶Ñ ÌïÑÌÑ∞ Í∞ÄÎä•)."""
    results = []
    for fn in os.listdir(cache_dir):
        if not fn.endswith("_cached.npy"):
            continue
        low = fn.lower()
        if task_name and (task_name.lower() not in low):
            continue
        results.append(os.path.join(cache_dir, fn))
    results = sorted(results)
    print(f"[INFO] Found {len(results)} cached EEG files ({'task: '+task_name if task_name else 'all'})")
    return [(f, "") for f in results]

# ============================================================
# 2) Cached loader & helper
# ============================================================
def cached_load_eeg(eeg_path):
    return np.load(eeg_path)

def make_window(x_ct, center_s, sfreq=TARGET_SFREQ, win_sec=2.0):
    t1 = int(center_s * sfreq)
    Tw = int(win_sec * sfreq)
    t0 = max(0, t1 - Tw)
    seg = x_ct[:, t0:t1]
    need = Tw - seg.shape[1]
    if need > 0:
        seg = np.pad(seg, ((0, 0), (need, 0)), mode="constant")
    return seg.astype(np.float32)

# ============================================================
# 3) CCD events + datasets
# ============================================================
def collect_ccd_event_files(bids_root):
    ev_files = glob(os.path.join(
        bids_root, "ds*/sub-*", "eeg", "sub-*_task-contrastChangeDetection*_events.tsv"
    ))
    print(f"‚úÖ Found {len(ev_files)} CCD event files.")
    return ev_files

def match_cached_to_event(cached_files, bids_root):
    ev_files = collect_ccd_event_files(bids_root)
    ev_dict = {os.path.basename(ef).replace("_events.tsv", ""): ef for ef in ev_files}
    pairs = []
    for eeg_path, _ in cached_files:
        base = os.path.basename(eeg_path)
        key = base.replace("_eeg_pp_cached.npy", "")
        if key in ev_dict:
            pairs.append((eeg_path, ev_dict[key]))
        else:
            key_no_run = key.replace("_run-1", "").replace("_run-2", "")
            matched = None
            for k, v in ev_dict.items():
                k_norm = k.replace("_run-1", "").replace("_run-2", "")
                if k_norm == key_no_run:
                    matched = v; break
            if matched:
                pairs.append((eeg_path, matched))
    print(f"üîó Matched {len(pairs)} cached EEG ‚Üî event pairs.")
    return pairs

def extract_ccd_trials(df):
    if df.empty or "onset" not in df.columns or "value" not in df.columns:
        return []
    trials = []
    on  = df["onset"].astype(float).values
    val = df["value"].astype(str).values
    fb  = df["feedback"].astype(str).values if "feedback" in df.columns else ["n/a"] * len(df)
    starts  = [i for i,v in enumerate(val) if "contrastTrial_start" in v]
    presses = [i for i,v in enumerate(val) if "buttonPress" in v]
    for ti in starts:
        t0 = on[ti]
        later = [pi for pi in presses if on[pi] > t0]
        if not later: 
            continue
        pi = later[0]
        rt = (on[pi]-t0) * 1000.0
        if 100 <= rt <= 3000 and "smiley" in fb[pi].lower():
            trials.append((t0, rt))
    return trials

# ============================================================
# 4) Datasets
# ============================================================
class SusPretrainDataset(Dataset):
    def __init__(self, eeg_files, win_s=WIN_S_SUS, stride_s=STRIDE_S_SUS):
        self.items = []
        self.win_s = win_s; self.stride_s = stride_s
        for p,_ in eeg_files:
            X = cached_load_eeg(p)
            # (Ï±ÑÎÑê, ÏÉòÌîåÏàò)
            T = X.shape[1] # ÏÉòÌîå Ïàò = Ï†ÑÏ≤¥ Í∏∏Ïù¥(Ï¥à)
            Tw = int(win_s * TARGET_SFREQ) # ÏúàÎèÑÏö∞ Í∏∏Ïù¥Î•º Ï£ºÌååÏàò Îã®ÏúÑÎ°ú
            stride = int(stride_s * TARGET_SFREQ)
            centers = [t1 / TARGET_SFREQ for t1 in range(Tw, T, stride)] # centerÎ•º Íµ¨Ìï¥ÏÑú Ï¥à Îã®ÏúÑÎ°ú Ï†ÄÏû•
            if len(centers) == 0 and T >= Tw:
                centers = [Tw / TARGET_SFREQ]
            for c in centers:
                self.items.append((p, c))
        random.shuffle(self.items)
        # [(ÌååÏùº Í≤ΩÎ°ú, center Ï¥à), ...]

    def __len__(self): return len(self.items)

    def _augment(self, x):
        x = x + 0.01 * np.random.randn(*x.shape).astype(np.float32)
        # ÎûúÎç§ÏúºÎ°ú ÏãúÍ∞ÑÏ∂ï ÏùºÎ∂ÄÎ•º 0ÏúºÎ°ú ÎßàÏä§ÌÇπ
        if np.random.rand() < 0.5:
            L = max(1, int(x.shape[1]*0.1)) # ÎßàÏä§ÌÅ¨ Íµ¨Í∞Ñ Í∏∏Ïù¥
            s = np.random.randint(0, x.shape[1]-L+1) # ÎûúÎç§ ÏãúÏûë Íµ¨Í∞Ñ
            x[:, s:s+L] = 0.0
        # ÎûúÎç§ÏúºÎ°ú Ï±ÑÎÑê ÏùºÎ∂ÄÎ•º 0ÏúºÎ°ú ÎßàÏä§ÌÇπ
        if np.random.rand() < 0.5:
            drop = max(1, int(x.shape[0]*0.05))
            idx = np.random.choice(x.shape[0], drop, replace=False)
            x[idx] = 0.0
        return x

    def __getitem__(self, idx):
        p, c = self.items[idx] # ÌååÏùº Í≤ΩÎ°ú, center Ï¥à
        X = cached_load_eeg(p)
        seg = make_window(X, c, win_sec=self.win_s) # (Ï±ÑÎÑê, ÏúàÎèÑÏö∞ Í∏∏Ïù¥)
        v1 = self._augment(seg.copy())
        v2 = self._augment(seg.copy())
        return torch.from_numpy(v1), torch.from_numpy(v2), torch.zeros(1) # ÎÑòÌååÏù¥ ‚Üí ÌÜ†Ïπò ÌÖêÏÑú

class CcdRtDataset(Dataset):
    def __init__(self, eeg_event_pairs, win_s=WIN_S_CCD):
        self.samples = []
        for eeg_path, ev_path in eeg_event_pairs:
            if not os.path.exists(ev_path): 
                continue
            df = pd.read_csv(ev_path, sep="\t")
            for o, rt in extract_ccd_trials(df):
                self.samples.append((eeg_path, o, rt))
        all_rts = np.array([s[2] for s in self.samples]).astype(np.float32)
        self.rt_mean = all_rts.mean()
        self.rt_std = all_rts.std() + 1e-6
        print(f"‚úÖ CCD Dataset: {len(self.samples)} trials. RT(ms) Mean={self.rt_mean:.2f}, Std={self.rt_std:.2f}")
        self.win_s = win_s

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        p, o, rt = self.samples[idx]
        X = cached_load_eeg(p)
        seg = make_window(X, o, win_sec=self.win_s)
        rt_normalized = (rt - self.rt_mean) / self.rt_std
        return torch.from_numpy(seg), torch.tensor([rt_normalized], dtype=torch.float32)

# ============================================================
# 5) Models
# ============================================================
from braindecode.models import EEGConformer

class SafeEEGConformerEncoder(nn.Module):
    def __init__(self, n_chans, sfreq, input_window_samples):
        super().__init__()
        last_err = None
        trials = [
            dict(n_chans=n_chans, n_outputs=1, n_times=input_window_samples,              sfreq=sfreq, return_features=True),
            dict(n_chans=n_chans, n_outputs=1, input_window_samples=input_window_samples, sfreq=sfreq, return_features=True),
            dict(n_chans=n_chans, n_outputs=1, n_times=input_window_samples,              sfreq=sfreq),
        ]
        for kw in trials:
            try:
                self.backbone = EEGConformer(**kw)
                break
            except TypeError as e:
                last_err = e
        if not hasattr(self, "backbone"):
            raise TypeError(f"EEGConformer init failed. Last error: {last_err}")

    def forward(self, x):
        z = self.backbone(x)
        if isinstance(z, tuple): z = z[0]
        return torch.flatten(z, 1)

class RtHead(nn.Module):
    def __init__(self, feat_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(feat_dim, max(64, feat_dim//2)),
            nn.ReLU(),
            nn.Linear(max(64, feat_dim//2), 1)
        )
    def forward(self, z):
        if z.ndim > 2: z = torch.flatten(z, 1)
        return self.mlp(z)

def nt_xent_loss(z1, z2, temperature=0.5):
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    logits12 = torch.matmul(z1, z2.T) / temperature
    logits21 = torch.matmul(z2, z1.T) / temperature
    logits12 -= logits12.max(dim=1, keepdim=True).values
    logits21 -= logits21.max(dim=1, keepdim=True).values
    labels = torch.arange(z1.size(0), device=z1.device)
    loss = (F.cross_entropy(logits12, labels) + F.cross_entropy(logits21, labels)) / 2
    return loss

# ============================================================
# 6) Train loops
# ============================================================
def train_pretrain_sus(dl, encoder, epochs=EPOCHS_SUS, lr=LR_SUS):
    with torch.no_grad():
        dummy, _, _ = next(iter(dl))
        feat_dim = encoder(dummy[:1].float().to(DEVICE)).shape[1]
    proj_head = nn.Sequential(
        nn.Linear(feat_dim, feat_dim // 2),
        nn.ReLU(),
        nn.Linear(feat_dim // 2, 128)
    ).to(DEVICE)
    opt = torch.optim.Adam(list(encoder.parameters()) + list(proj_head.parameters()), lr=lr)
    for ep in range(epochs):
        encoder.train(); proj_head.train(); losses=[]
        for x1, x2, _ in tqdm(dl, desc=f"[SuS pretrain] {ep+1}/{epochs}"):
            x1, x2 = x1.float().to(DEVICE), x2.float().to(DEVICE)
            z1, z2 = encoder(x1), encoder(x2)
            p1, p2 = proj_head(z1), proj_head(z2)
            loss = nt_xent_loss(p1, p2, temperature=2.0)
            opt.zero_grad(); loss.backward(); opt.step()
            losses.append(loss.item())
        print(f"Epoch {ep+1}: contrastive loss={np.mean(losses):.4f}")
    return encoder

def train_ccd_rt(dl_tr, encoder, rt_head, epochs=EPOCHS_CCD, lr=LR_CCD):
    opt = torch.optim.Adam(list(encoder.parameters()) + list(rt_head.parameters()), lr=lr)
    for ep in range(epochs):
        encoder.train(); rt_head.train(); losses=[]
        for x, y in tqdm(dl_tr, desc=f"[CCD train] {ep+1}/{epochs}"):
            x, y = x.float().to(DEVICE), y.to(DEVICE)
            yhat = rt_head(encoder(x))
            loss = nn.functional.l1_loss(yhat, y)
            opt.zero_grad(); loss.backward(); opt.step()
            losses.append(loss.item())
        print(f"Epoch {ep+1}: MAE={np.mean(losses):.3f} (z-score scale)")

# ============================================================
# 7) Main
# ============================================================
def main():
    # 1Ô∏è‚É£ SuS pretrain
    all_cached = collect_cached_files(CACHE_DIR)
    sus_like = [(p,_) for (p,_) in all_cached if "contrastchangedetection" not in p.lower()]
    print(f"[INFO] Pretrain files (excluding CCD): {len(sus_like)}")

    ds_sus = SusPretrainDataset(sus_like)
    dl_sus = DataLoader(ds_sus, batch_size=BATCH_SIZE, shuffle=True,
                        num_workers=NUM_WORKERS, pin_memory=True)

    x_demo, _, _ = next(iter(dl_sus))
    _, C, T = x_demo.shape
    encoder = SafeEEGConformerEncoder(C, TARGET_SFREQ, T).to(DEVICE)

    # -------- SSL pretraining --------
    encoder = train_pretrain_sus(dl_sus, encoder)

    # ‚úÖ Ïù∏ÏΩîÎçîÎßå Îî∞Î°ú Ï†ÄÏû•
    encoder_path = "encoder_sus_pretrained.pth"
    torch.save(encoder.state_dict(), encoder_path)
    print(f"‚úÖ Saved pretrained encoder to {encoder_path}")
    # --------------------------------

    # 2Ô∏è‚É£ CCD fine-tuning (RT regression)
    ccd_files = collect_cached_files(CACHE_DIR, task_name="contrastChangeDetection")
    matched_pairs = match_cached_to_event(ccd_files, BIDS_ROOT)
    ds_ccd = CcdRtDataset(matched_pairs)
    dl_ccd = DataLoader(ds_ccd, batch_size=BATCH_SIZE, shuffle=True,
                        num_workers=NUM_WORKERS, pin_memory=True)

    with torch.no_grad():
        feat_dim = encoder(x_demo[:1].float().to(DEVICE)).shape[1]

    rt_head = RtHead(feat_dim).to(DEVICE)
    train_ccd_rt(dl_ccd, encoder, rt_head)

    # ‚úÖ Ï†ÑÏ≤¥ Î™®Îç∏ Ï†ÄÏû• (encoder + head)
    ckpt_path = "weights_ch1_cached.pth"
    torch.save({"encoder": encoder.state_dict(),
                "rt_head": rt_head.state_dict()}, ckpt_path)
    print(f"‚úÖ Saved Challenge-1 weights to {ckpt_path}")


if __name__ == "__main__":
    main()


[INFO] Found 25944 cached EEG files (all)
[INFO] Pretrain files (excluding CCD): 20558


[SuS pretrain] 1/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 86084/86084 [11:22:52<00:00,  2.10it/s]   


Epoch 1: contrastive loss=3.6997


[SuS pretrain] 2/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 86084/86084 [13:43:17<00:00,  1.74it/s]   


Epoch 2: contrastive loss=3.6948


[SuS pretrain] 3/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 86084/86084 [9:00:56<00:00,  2.65it/s]   


Epoch 3: contrastive loss=3.6936


[SuS pretrain] 4/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 86084/86084 [9:48:18<00:00,  2.44it/s]   


Epoch 4: contrastive loss=3.6928


[SuS pretrain] 5/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 86084/86084 [9:35:59<00:00,  2.49it/s]   


Epoch 5: contrastive loss=3.6924
‚úÖ Saved pretrained encoder to encoder_sus_pretrained.pth
[INFO] Found 5386 cached EEG files (task: contrastChangeDetection)
‚úÖ Found 5390 CCD event files.
üîó Matched 5386 cached EEG ‚Üî event pairs.
‚úÖ CCD Dataset: 341 trials. RT(ms) Mean=2863.26, Std=168.62


[CCD train] 1/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:53<00:00,  8.98s/it]


Epoch 1: MAE=4.301 (z-score scale)


[CCD train] 2/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:02<00:00,  2.78it/s]


Epoch 2: MAE=3.469 (z-score scale)


[CCD train] 3/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:03<00:00,  1.79it/s]


Epoch 3: MAE=3.069 (z-score scale)


[CCD train] 4/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:01<00:00,  3.56it/s]


Epoch 4: MAE=2.372 (z-score scale)


[CCD train] 5/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:02<00:00,  2.16it/s]


Epoch 5: MAE=2.176 (z-score scale)


[CCD train] 6/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:02<00:00,  2.03it/s]


Epoch 6: MAE=2.080 (z-score scale)


[CCD train] 7/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:04<00:00,  1.33it/s]


Epoch 7: MAE=1.583 (z-score scale)


[CCD train] 8/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:04<00:00,  1.47it/s]


Epoch 8: MAE=1.455 (z-score scale)


[CCD train] 9/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:02<00:00,  2.09it/s]


Epoch 9: MAE=1.677 (z-score scale)


[CCD train] 10/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 6/6 [00:02<00:00,  2.30it/s]


Epoch 10: MAE=1.414 (z-score scale)
‚úÖ Saved Challenge-1 weights to weights_ch1_cached.pth
