In [1]:
# Expose only GPU index 3 to this kernel; must be set before CUDA is initialised.
%env CUDA_VISIBLE_DEVICES=3

env: CUDA_VISIBLE_DEVICES=3


In [2]:
# ============================================================
# EEG Foundation Challenge 2025 - Challenge 1
# SuS pretraining → CCD RT regression (100 Hz preprocessed version)
# ------------------------------------------------------------
# - Reads *_eeg_pp.set EEGs (100 Hz) from per-subject folders (run optional)
# - Matches CCD events from BIDS (ds*/sub-*/eeg/*_events.tsv)
# - Caches normalized EEG (.npy)
# - Suppresses MNE/User/Future/Runtime warnings
# - Safe EEGConformer wrapper (version differences)
# ============================================================

import os, random, numpy as np, pandas as pd, warnings, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from glob import glob
import torch.nn.functional as F

# ---- suppress warnings/logs ----
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# MNE reads the EEGLAB .set files; only ERROR-level messages from its logger are shown.
import mne
mne.set_log_level("ERROR")

# ============================================================
# 0. Config (adjust the paths below for your machine)
# ============================================================
# Paths keep their original defaults. Each can be overridden through an environment
# variable so the notebook runs elsewhere without editing code.
BIDS_ROOT         = os.environ.get("HBN_BIDS_ROOT", "/data5/open_data/HBN/EEG_BIDS/")
PREPROCESSED_ROOT = os.environ.get("HBN_PREPROCESSED_ROOT",
                                   "/data5/open_data/HBN/Preprocessed_EEG/0922try_bySubject/")
CACHE_DIR         = os.environ.get("HBN_CACHE_DIR", "/data5/open_data/HBN/cache_eeg_100hz_noref2")
os.makedirs(CACHE_DIR, exist_ok=True)

TARGET_SFREQ = 100                         # sampling rate (Hz) assumed for the *_eeg_pp.set files (not verified)
WIN_S_SUS, WIN_S_CCD = 2.0, 2.0            # window length (seconds)
STRIDE_S_SUS = 1.0                         # window stride for SuS pretraining (seconds)
BATCH_SIZE, NUM_WORKERS = 64, 2
EPOCHS_SUS, EPOCHS_CCD, LR_SUS, LR_CCD = 5, 10, 1e-3, 1e-3
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
# Seed every RNG used below (python `random`, numpy, torch CPU and all CUDA devices).
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

# ============================================================
# 1) File collectors (handle files with or without a run label)
# ============================================================
def collect_preprocessed_files(root_path, task_name=None):
    """
    Recursively collect ``*_eeg_pp.set`` files under ``root_path``.

    Only directories whose path contains ``"sub-"`` are searched. Files with or
    without a run label are collected. The suffix match is case-insensitive.
    When ``task_name`` is truthy, only files whose (lower-cased) name contains
    ``task_name.lower()`` are kept.

    Returns
    -------
    list of (path, "") tuples, sorted by path.
    """
    needle = task_name.lower() if task_name else None
    found = []
    for folder, _subdirs, names in os.walk(root_path):
        # only explore subject folders
        if "sub-" not in folder:
            continue
        found.extend(
            os.path.join(folder, name)
            for name in names
            if name.lower().endswith("_eeg_pp.set")
            and (needle is None or needle in name.lower())
        )
    found.sort()
    print(f"[INFO] Found {len(found)} preprocessed EEG files ({'task: '+task_name if task_name else 'all'})")
    return [(path, "") for path in found]

def collect_ccd_event_files(bids_root):
    """
    Collect every CCD events file from the BIDS tree.

    Pattern: ``<bids_root>/ds*/sub-*/eeg/sub-*_task-contrastChangeDetection*_events.tsv``.

    Returns
    -------
    list[str]
        Matching paths, sorted. ``glob`` order is filesystem-dependent. Sorting makes
        downstream dict construction deterministic, including which file wins when
        two datasets contain the same basename.
    """
    pattern = os.path.join(
        bids_root, "ds*/sub-*", "eeg", "sub-*_task-contrastChangeDetection*_events.tsv"
    )
    ev_files = sorted(glob(pattern))
    print(f"‚úÖ Found {len(ev_files)} CCD event files.")
    return ev_files

def match_eeg_to_event(preproc_files, bids_root):
    """
    Pair preprocessed CCD EEG files with their BIDS events files.

    Rule: ``<key>_eeg_pp.set`` -> ``<key>_events.tsv``. An exact key match is tried
    first. Otherwise the run label (``_run-<digits>``) is stripped from both keys, and
    the first events file with the same run-less key is used. "First" follows the
    order returned by ``collect_ccd_event_files``. EEG files without a match are
    silently dropped.

    Parameters
    ----------
    preproc_files : list of (eeg_path, tag) tuples (tag is ignored).
    bids_root : str, root searched by ``collect_ccd_event_files``.

    Returns
    -------
    list of (eeg_path, events_path) tuples.
    """
    import re
    run_label = re.compile(r"_run-\d+")

    ev_files = collect_ccd_event_files(bids_root)
    ev_dict = {os.path.basename(ef).replace("_events.tsv", ""): ef for ef in ev_files}
    # Build the run-less index once (O(n)) instead of rescanning ev_dict for every
    # EEG file. setdefault keeps the first entry in ev_dict order, which is what
    # the former linear scan returned.
    ev_by_runless_key = {}
    for k, v in ev_dict.items():
        ev_by_runless_key.setdefault(run_label.sub("", k), v)

    pairs = []
    for eeg_path, _ in preproc_files:
        base = os.path.basename(eeg_path)
        # drop the "_eeg_pp" suffix and the extension to build the key
        key = base.replace("_eeg_pp.set", "").replace(".set", "")
        matched = ev_dict.get(key) or ev_by_runless_key.get(run_label.sub("", key))
        if matched:
            pairs.append((eeg_path, matched))
    print(f"üîó Matched {len(pairs)} EEG ‚Üî event pairs.")
    return pairs

# ============================================================
# 2) Cached EEG loader (z-score only, no align/resample)
# ============================================================
def read_raw(eeg_path):
    """Read an EEGLAB ``.set`` recording with MNE, fully loaded into memory, without log output."""
    raw = mne.io.read_raw_eeglab(eeg_path, preload=True, verbose=False)
    return raw

def cached_load_eeg(eeg_path, mmap_mode=None):
    """
    Load one recording as a per-channel z-scored float32 array, cached on disk.

    The cache file is ``CACHE_DIR/<basename, with ".set" replaced by "_cached.npy">``.
    On a cache miss:
    - the ``.set`` file is read via ``read_raw``;
    - EEG channels are picked;
    - each channel is z-scored over time (NaN/inf cleaned with ``np.nan_to_num``);
    - the result is saved atomically.

    Parameters
    ----------
    eeg_path : str
        Path to an EEGLAB ``.set`` file.
    mmap_mode : {None, "r", "r+", "w+", "c"}, optional
        Forwarded to ``np.load`` on a cache hit. ``"r"`` memory-maps the cached array,
        so callers needing only the shape or a short window avoid reading the whole
        file. The default ``None`` keeps the previous full in-memory load. Ignored on
        a cache miss, which returns the freshly computed in-memory array.

    Returns
    -------
    np.ndarray of shape (C, T), float32.

    NOTE(review): the cache key is the basename only; files with identical basenames
    in different folders would share one cache entry.
    """
    fname = os.path.basename(eeg_path).replace(".set", "_cached.npy")
    cache_path = os.path.join(CACHE_DIR, fname)
    if os.path.exists(cache_path):
        return np.load(cache_path, mmap_mode=mmap_mode)
    raw = read_raw(eeg_path)  # read_raw loads with preload=True, so no load_data() needed
    raw.pick_types(eeg=True, meg=False, eog=False, ecg=False, stim=False)
    X = raw.get_data(picks="eeg").astype(np.float32)          # (C, T)
    mean, std = X.mean(1, keepdims=True), X.std(1, keepdims=True) + 1e-6
    X = np.nan_to_num((X - mean) / std)
    # Write to a process-unique temp file, then rename atomically. DataLoader workers
    # may build the same entry concurrently. A reader that sees os.path.exists()==True
    # must never load a half-written .npy.
    tmp_path = f"{cache_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "wb") as fh:   # file handle: np.save won't append ".npy"
            np.save(fh, X)
        os.replace(tmp_path, cache_path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return X

def make_window(x_ct, center_s, sfreq=TARGET_SFREQ, win_sec=2.0):
    """
    Cut the ``win_sec``-long window that ENDS at ``center_s`` from a (C, T) array.

    Despite its name, ``center_s`` is the window's right edge. Samples
    ``[t1 - Tw, t1)`` are returned, with ``t1 = round(center_s * sfreq)`` and
    ``Tw = round(win_sec * sfreq)``. Any part of the window before the start of the
    recording is zero-padded on the left. Any part past its end is zero-padded on
    the right. The output is therefore always ``(C, Tw)`` and stays time-aligned.

    Parameters
    ----------
    x_ct : array of shape (C, T)
    center_s : float, window end time in seconds.
    sfreq : float, sampling rate in Hz.
    win_sec : float, window length in seconds.

    Returns
    -------
    np.ndarray of shape (C, Tw), float32.
    """
    # round() instead of int(): int(0.29 * 100) == 28 due to binary floating point,
    # which silently shifted windows by one sample.
    t1 = int(round(center_s * sfreq))
    Tw = int(round(win_sec * sfreq))
    t0 = t1 - Tw
    T = x_ct.shape[1]
    lo, hi = max(0, t0), min(T, t1)
    # Guard hi <= lo: a negative t1 would otherwise slice almost the whole recording.
    seg = x_ct[:, lo:hi] if hi > lo else x_ct[:, :0]
    pad_left = min(Tw, lo - t0)
    pad_right = Tw - pad_left - seg.shape[1]
    if pad_left or pad_right:
        seg = np.pad(seg, ((0, 0), (pad_left, pad_right)), mode="constant")
    return seg.astype(np.float32)

# ============================================================
# 3) CCD trial parser (correct only: feedback==smiley)
# ============================================================
def extract_ccd_trials(df):
    """
    Extract correct CCD trials as ``(trial_onset_s, rt_ms)`` pairs from a BIDS events table.

    A trial starts at a row whose ``value`` contains ``"contrastTrial_start"``. Its
    response is the first row, in table order, that meets all of:
    - its ``value`` contains ``"buttonPress"``;
    - its onset is strictly after this trial's start;
    - its onset is strictly before the next trial's start.

    A trial is kept only if the RT is within [100, 3000] ms and the press row's
    ``feedback`` contains ``"smiley"``.

    Bounding by the next trial start stops a trial without a response from borrowing
    the following trial's button press. That produced a bogus RT that could still
    pass the 3000 ms filter.

    Parameters
    ----------
    df : pandas.DataFrame with ``onset`` (seconds) and ``value`` columns, and optionally
        ``feedback``.

    Returns
    -------
    list of (onset_s, rt_ms) tuples. Empty when required columns are missing.
    """
    if df.empty or "onset" not in df.columns or "value" not in df.columns:
        return []
    # Non-numeric onsets (e.g. "n/a") become NaN and never satisfy a comparison.
    on = pd.to_numeric(df["onset"], errors="coerce").to_numpy(dtype=float)
    val = df["value"].astype(str).to_numpy()
    if "feedback" in df.columns:
        fb = df["feedback"].astype(str).to_numpy()
    else:
        fb = ["n/a"] * len(df)
    starts = [i for i, v in enumerate(val) if "contrastTrial_start" in v]
    presses = [i for i, v in enumerate(val) if "buttonPress" in v]
    start_onsets = np.sort(on[starts]) if starts else np.empty(0)
    start_onsets = start_onsets[~np.isnan(start_onsets)]

    trials = []
    for ti in starts:
        t0 = on[ti]
        nxt = np.searchsorted(start_onsets, t0, side="right")
        t_next = start_onsets[nxt] if nxt < len(start_onsets) else np.inf
        pi = next((p for p in presses if t0 < on[p] < t_next), None)
        if pi is None:
            continue  # no response inside this trial
        rt = (on[pi] - t0) * 1000.0
        if 100 <= rt <= 3000 and "smiley" in fb[pi].lower():
            trials.append((t0, rt))
    return trials

# ============================================================
# 4) Datasets
# ============================================================
class SusPretrainDataset(Dataset):
    """
    Self-supervised dataset over SuS (non-CCD) recordings; no events are used.

    Windows of ``win_s`` seconds are tiled over each recording with a step of
    ``stride_s`` seconds. Each item is two independently augmented views of the same
    window, plus a dummy zero label of shape ``[1]``.

    Parameters
    ----------
    eeg_files : iterable of (eeg_path, tag) tuples.
    win_s : float, window length in seconds.
    stride_s : float, step between consecutive window end points, in seconds.
    random_start : bool, stored but not used by this class.
    """
    def __init__(self, eeg_files, win_s=WIN_S_SUS, stride_s=STRIDE_S_SUS, random_start=True):
        self.items = []  # (eeg_path, window end time in seconds)
        self.win_s = win_s; self.stride_s = stride_s; self.random_start = random_start
        Tw = int(win_s * TARGET_SFREQ)
        stride = int(stride_s * TARGET_SFREQ)
        for p, _ in eeg_files:
            T = self._load(p).shape[1]
            # Window end points Tw, Tw+stride, ... up to and INCLUDING T. A window
            # ending at the last sample (x[:, T-Tw:T]) is valid. This also covers
            # recordings exactly one window long (T == Tw).
            for t1 in range(Tw, T + 1, stride):
                self.items.append((p, t1 / TARGET_SFREQ))
        random.shuffle(self.items)

    def __len__(self): return len(self.items)

    @staticmethod
    def _load(eeg_path):
        """
        Return the cached (C, T) array for ``eeg_path``, memory-mapped when cached.

        Memory-mapping reads only the requested window instead of the whole recording
        for every item. The cache path mirrors the naming used by ``cached_load_eeg``.
        A missing cache entry is built through ``cached_load_eeg``.
        """
        cache_path = os.path.join(
            CACHE_DIR, os.path.basename(eeg_path).replace(".set", "_cached.npy"))
        if os.path.exists(cache_path):
            return np.load(cache_path, mmap_mode="r")
        return cached_load_eeg(eeg_path)

    @staticmethod
    def _augment(x):
        # Simple augmentation: Gaussian noise + random time mask + random channel drop
        x = x + 0.01 * np.random.randn(*x.shape).astype(np.float32)
        if np.random.rand() < 0.5:
            L = max(1, int(x.shape[1]*0.1))
            s = np.random.randint(0, x.shape[1]-L+1)
            x[:, s:s+L] = 0.0
        if np.random.rand() < 0.5:
            drop = max(1, int(x.shape[0]*0.05))
            idx = np.random.choice(x.shape[0], drop, replace=False)
            x[idx] = 0.0
        return x

    def __getitem__(self, idx):
        p, c = self.items[idx]
        X = self._load(p)
        # np.asarray drops any np.memmap subclass so torch receives a plain ndarray.
        seg = np.asarray(make_window(X, c, win_sec=self.win_s))
        v1 = self._augment(seg.copy())
        v2 = self._augment(seg.copy())
        return torch.from_numpy(v1), torch.from_numpy(v2), torch.zeros(1)

class CcdRtDataset(Dataset):
    """
    CCD reaction-time regression dataset.

    Each sample is one trial returned by ``extract_ccd_trials``. It pairs:
    - the ``win_s``-second EEG window that ends at the trial start onset (via
      ``make_window``);
    - the RT z-scored with the dataset-wide ``rt_mean`` / ``rt_std`` (milliseconds).

    Keep both attributes to map predictions back to ms:
    ``rt_ms = pred * rt_std + rt_mean``.

    NOTE(review): the window lies entirely before the trial onset (pre-stimulus).
    Confirm this is intended for RT prediction.
    """
    def __init__(self, eeg_event_pairs, win_s=WIN_S_CCD):
        self.samples = []  # (eeg_path, onset_s, rt_ms)
        for eeg_path, ev_path in eeg_event_pairs:
            if not os.path.exists(ev_path):
                continue
            df = pd.read_csv(ev_path, sep="\t")
            for o, rt in extract_ccd_trials(df):
                self.samples.append((eeg_path, o, rt))

        # RT normalisation statistics (ms)
        all_rts = np.array([s[2] for s in self.samples]).astype(np.float32)
        if all_rts.size == 0:
            # mean/std of an empty array would be NaN; fall back to an identity transform.
            print("[WARN] CCD Dataset: no valid trials found; using rt_mean=0, rt_std=1.")
            self.rt_mean = np.float32(0.0)
            self.rt_std = np.float32(1.0)
        else:
            self.rt_mean = all_rts.mean()
            self.rt_std = all_rts.std() + 1e-6  # avoid division by zero
        print(f"‚úÖ CCD Dataset: {len(self.samples)} trials. RT(ms) Mean={self.rt_mean:.2f}, Std={self.rt_std:.2f}")

        self.win_s = win_s

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        p, o, rt = self.samples[idx]
        X = cached_load_eeg(p)
        seg = make_window(X, o, win_sec=self.win_s)

        # z-score the RT target
        rt_normalized = (rt - self.rt_mean) / self.rt_std

        return torch.from_numpy(seg), torch.tensor([rt_normalized], dtype=torch.float32)  # normalized target

# ============================================================
# 5) EEGConformer encoder (safe wrapper)
# ============================================================
from braindecode.models import EEGConformer

class SafeEEGConformerEncoder(nn.Module):
    """
    Version-tolerant wrapper around braindecode's ``EEGConformer``.

    Constructor keyword names differ between braindecode releases. Several keyword
    sets are tried in order until one does not raise ``TypeError``; if all fail, the
    last ``TypeError`` is reported. ``forward`` returns the backbone output (its
    first element if it is a tuple) flattened to ``(batch, -1)``.

    NOTE(review): the last fallback omits ``return_features``. If it is the one that
    succeeds, the backbone presumably returns its ``n_outputs=1`` prediction rather
    than features. Confirm the resulting feature dimension is not 1.
    """
    def __init__(self, n_chans, sfreq, input_window_samples):
        super().__init__()
        common = dict(n_chans=n_chans, n_outputs=1, sfreq=sfreq)
        candidate_kwargs = (
            {**common, "n_times": input_window_samples, "return_features": True},
            {**common, "input_window_samples": input_window_samples, "return_features": True},
            {**common, "n_times": input_window_samples},
        )
        last_err = None
        for kwargs in candidate_kwargs:
            try:
                backbone = EEGConformer(**kwargs)
            except TypeError as err:
                last_err = err
                continue
            self.backbone = backbone
            break
        else:
            raise TypeError(f"EEGConformer init failed. Last error: {last_err}")

    def forward(self, x):
        out = self.backbone(x)
        if isinstance(out, tuple):
            out = out[0]
        return torch.flatten(out, 1)

class ContrastiveHead(nn.Module):
    """
    Two-layer MLP projection head: ``in_dim -> in_dim // 2 -> proj_dim``, with ReLU between.

    Inputs with more than two dimensions are flattened from dim 1 first.
    """
    def __init__(self, in_dim, proj_dim=128):
        super().__init__()
        hidden = in_dim // 2
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, proj_dim),
        )

    def forward(self, x):
        flat = x.flatten(1) if x.dim() > 2 else x
        return self.net(flat)


# ============================================================
# 6) Heads & Losses
# ============================================================
class RtHead(nn.Module):
    """
    MLP regression head mapping encoder features to one value per sample.

    Hidden width is ``max(64, feat_dim // 2)``. Inputs with more than two dimensions
    are flattened from dim 1. The output has shape ``(batch, 1)``.
    """
    def __init__(self, feat_dim):
        super().__init__()
        width = max(64, feat_dim // 2)
        self.mlp = nn.Sequential(
            nn.Linear(feat_dim, width),
            nn.ReLU(),
            nn.Linear(width, 1),
        )

    def forward(self, z):
        return self.mlp(torch.flatten(z, 1) if z.ndim > 2 else z)

def nt_xent_loss(z1, z2, temperature: float = 0.5):
    """
    Symmetric InfoNCE-style contrastive loss between two batches of embeddings.

    Row i of ``z1`` and row i of ``z2`` form the positive pair. The other rows of the
    opposite view are the negatives; same-view negatives are not used, unlike
    SimCLR's full NT-Xent. Embeddings are L2-normalised and similarities are divided
    by ``temperature``. The result is the mean of the z1->z2 and z2->z1
    cross-entropies.

    Parameters
    ----------
    z1, z2 : Tensor of shape (N, D)
    temperature : float

    Returns
    -------
    Scalar tensor.
    """
    a = F.normalize(z1, dim=1)
    b = F.normalize(z2, dim=1)
    targets = torch.arange(a.size(0), device=a.device)  # positives on the diagonal

    def _directional_ce(query, keys):
        logits = torch.matmul(query, keys.T) / temperature   # (N, N)
        # Row-wise shift for numerical stability; cross-entropy is shift-invariant.
        logits = logits - logits.max(dim=1, keepdim=True).values
        return F.cross_entropy(logits, targets)

    return (_directional_ce(a, b) + _directional_ce(b, a)) / 2


# ============================================================
# 7) Train loops
# ============================================================
def train_pretrain_sus(dl, encoder, epochs=EPOCHS_SUS, lr=LR_SUS, temperature=2.0):
    """
    Two-view contrastive pretraining of ``encoder`` on SuS batches.

    A fresh MLP projection head (``feat_dim -> feat_dim // 2 -> 128``) is trained
    jointly with the encoder using ``nt_xent_loss``. Only the encoder is returned;
    the head is discarded.

    Parameters
    ----------
    dl : DataLoader yielding ``(view1, view2, _)`` batches.
    encoder : nn.Module returning ``(batch, feat_dim)``. Presumably already on
        ``DEVICE``; the caller in this file moves it there.
    epochs : int, number of passes over ``dl``.
    lr : float, Adam learning rate.
    temperature : float, contrastive temperature. Defaults to 2.0, the previously
        hard-coded value. NOTE(review): this is far above the 0.1-0.5 range usually
        used with NT-Xent; consider tuning.

    Returns
    -------
    The same ``encoder`` object, trained in place.
    """
    # Probe the encoder's output width on one sample. eval() keeps this extra
    # forward pass from updating BatchNorm running statistics (if the backbone has
    # any). train() is re-enabled at the start of every epoch.
    encoder.eval()
    with torch.no_grad():
        dummy, _, _ = next(iter(dl))
        feat_dim = encoder(dummy[:1].float().to(DEVICE)).shape[1]
    # Projection head used only for the contrastive objective
    proj_head = nn.Sequential(
        nn.Linear(feat_dim, feat_dim // 2),
        nn.ReLU(),
        nn.Linear(feat_dim // 2, 128)
    ).to(DEVICE)

    opt = torch.optim.Adam(list(encoder.parameters()) + list(proj_head.parameters()), lr=lr)
    for ep in range(epochs):
        encoder.train(); proj_head.train(); losses = []
        for x1, x2, _ in tqdm(dl, desc=f"[SuS pretrain] {ep+1}/{epochs}"):
            x1, x2 = x1.float().to(DEVICE), x2.float().to(DEVICE)
            z1, z2 = encoder(x1), encoder(x2)
            p1, p2 = proj_head(z1), proj_head(z2)
            loss = nt_xent_loss(p1, p2, temperature=temperature)
            opt.zero_grad(); loss.backward(); opt.step()
            losses.append(loss.item())
        print(f"Epoch {ep+1}: contrastive loss={np.mean(losses):.4f}")
    return encoder


def train_ccd_rt(dl_tr, encoder, rt_head, epochs=EPOCHS_CCD, lr=LR_CCD):
    """
    Fine-tune ``encoder`` and ``rt_head`` jointly with an L1 loss (Adam).

    ``CcdRtDataset`` yields z-scored RT targets, so the raw L1 loss is in
    standard-deviation units, not milliseconds. If the loader's dataset exposes
    ``rt_std``, the epoch MAE is also converted to ms. This is exact because
    ``|pred_ms - y_ms| = |pred_z - y_z| * rt_std``. Both modules are trained in
    place and nothing is returned.
    """
    rt_std = getattr(getattr(dl_tr, "dataset", None), "rt_std", None)
    opt = torch.optim.Adam(list(encoder.parameters()) + list(rt_head.parameters()), lr=lr)
    for ep in range(epochs):
        encoder.train(); rt_head.train(); losses = []
        for x, y in tqdm(dl_tr, desc=f"[CCD train] {ep+1}/{epochs}"):
            x, y = x.float().to(DEVICE), y.float().to(DEVICE)
            yhat = rt_head(encoder(x))
            loss = F.l1_loss(yhat, y)
            opt.zero_grad(); loss.backward(); opt.step()
            losses.append(loss.item())
        mae = float(np.mean(losses))
        if rt_std is not None:
            print(f"Epoch {ep+1}: MAE={mae:.3f} (z-score) ~ {mae * float(rt_std):.1f} ms")
        else:
            print(f"Epoch {ep+1}: MAE={mae:.3f} (target units)")

# ============================================================
# 8) Main
# ============================================================
def main():
    """
    Run the Challenge 1 pipeline end to end.

    1. Contrastive pretraining of an EEGConformer encoder on all non-CCD recordings.
    2. Fine-tuning encoder + ``RtHead`` on z-scored CCD reaction times.
    3. Saving the ``encoder`` / ``rt_head`` state dicts to
       ``weights_all_task_ch1.pth``, together with the RT normalisation statistics
       (``rt_mean`` / ``rt_std``, ms). The head predicts z-scores, so inference
       needs these stats: ``rt_ms = pred * rt_std + rt_mean``.

    Raises
    ------
    ValueError
        If the CCD window length in samples differs from the SuS window the encoder
        is built for. This is checked before the long pretraining starts.
    """
    # --- 1) SuS pretraining ---
    all_preproc_files = collect_preprocessed_files(PREPROCESSED_ROOT)

    # Pretrain on every task except CCD (filtered by filename)
    sus_like_files = [
        (p, "") for (p, _) in all_preproc_files
        if "contrastchangedetection" not in p.lower()
    ]

    print(f"[INFO] Pretrain files (excluding CCD): {len(sus_like_files)}")

    ds_sus = SusPretrainDataset(
        sus_like_files,
        win_s=WIN_S_SUS,
        stride_s=STRIDE_S_SUS
    )
    dl_sus = DataLoader(
        ds_sus,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )
    # Initialise the encoder (infer C, T from one SuS batch)
    x_demo, _, _ = next(iter(dl_sus))
    _, C, T = x_demo.shape
    # The encoder is built for exactly T input samples. Fail now, not after
    # pretraining, if CCD windows would have a different length.
    ccd_T = int(round(WIN_S_CCD * TARGET_SFREQ))
    if ccd_T != T:
        raise ValueError(
            f"CCD window has {ccd_T} samples but the encoder is built for {T} (SuS window)."
        )
    encoder = SafeEEGConformerEncoder(C, TARGET_SFREQ, T).to(DEVICE)

    # pretrain
    encoder = train_pretrain_sus(dl_sus, encoder, epochs=EPOCHS_SUS, lr=LR_SUS)

    # --- 2) CCD fine-tuning (RT) ---
    ccd_eeg_files = collect_preprocessed_files(PREPROCESSED_ROOT, task_name="contrastChangeDetection")
    matched_pairs = match_eeg_to_event(ccd_eeg_files, BIDS_ROOT)
    ds_ccd = CcdRtDataset(matched_pairs, win_s=WIN_S_CCD)
    dl_ccd = DataLoader(ds_ccd, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)

    # Build the RT head. The feature-width probe runs in eval mode so this extra
    # forward pass does not alter BatchNorm running statistics (if any).
    encoder.eval()
    with torch.no_grad():
        feat_dim = encoder(x_demo[:1].float().to(DEVICE)).shape[1]
    print(f"[INFO] Encoder feature dim: {feat_dim}")
    rt_head = RtHead(feat_dim).to(DEVICE)

    train_ccd_rt(dl_ccd, encoder, rt_head, epochs=EPOCHS_CCD, lr=LR_CCD)

    # Save encoder + rt_head, plus the stats needed to de-normalise predictions
    torch.save({
        "encoder": encoder.state_dict(),
        "rt_head": rt_head.state_dict(),
        "rt_mean": float(ds_ccd.rt_mean),
        "rt_std": float(ds_ccd.rt_std),
    }, "weights_all_task_ch1.pth")

    print("‚úÖ Saved Challenge 1 weights to weights_all_task_ch1.pth")


# Launch the full pipeline. In a Jupyter kernel __name__ is "__main__", so running this cell starts training.
if __name__ == "__main__":
    main()


[INFO] Found 25944 preprocessed EEG files (all)
[INFO] Pretrain files (excluding CCD): 20558


[SuS pretrain] 1/5:  37%|███▋      | 31594/86084 [19:32:21<71:01:55,  4.69s/it] 