In [1]:
import os

os.environ["CUDA_VISIBLE_DEVICES"]="3"

In [3]:
# ============================================================
# CBraMod (from CH2) â†’ CCD RT downstream (CH1) with
# temporal context, reaction-phase weighting, norm alignment,
# and multi-task auxiliary (RT tertile classification).
# Requires: 100 Hz .npy cache for EEG (.set -> *_cached.npy)
# ============================================================

import os, random, json, warnings
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

warnings.filterwarnings("ignore")

# ------------------
# Config
# ------------------
BIDS_ROOT         = "/data5/open_data/HBN/EEG_BIDS/"
PREPROCESSED_ROOT = "/data5/open_data/HBN/Preprocessed_EEG/0922try_bySubject/"
CACHE_DIR         = "/data5/open_data/HBN/cache_eeg_100hz_noref"   # where *_cached.npy lives
CH2_CKPT_PATH     = "/home/RA/EEG_Challenge/Challenge2/best_cbramod_cached_finetune.pth"  # âœ… your CH2 weights

TARGET_SFREQ = 100
BATCH_SIZE   = 64
NUM_WORKERS  = 4
EPOCHS       = 10
LR           = 1e-3
WD           = 1e-4
DEVICE       = "cuda" if torch.cuda.is_available() else "cpu"
SEED         = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

# Temporal context (seconds): pre + capped reaction + post
PRE_CTX_S     = 1.0
REACT_CAP_S   = 1.0   # include up to 1s of reaction (onsetâ†’press)
POST_CTX_S    = 0.5
TOTAL_LEN_S   = PRE_CTX_S + REACT_CAP_S + POST_CTX_S  # = 2.5s
TOTAL_SAMPLES = int(TOTAL_LEN_S * TARGET_SFREQ)       # 250

# Weighting for reaction length in the included segment
WEIGHT_GAMMA  = 1.0   # w = 1 + GAMMA * (min(rt, REACT_CAP_S)/REACT_CAP_S)

# Multi-task auxiliary
AUX_NUM_CLASSES = 3   # fast / medium / slow tertiles
AUX_LAMBDA      = 0.2

# ============================================================
# 1) Files and events
# ============================================================
def collect_preprocessed_files(root_path, task_name=None):
    results = []
    for dirpath, _, filenames in os.walk(root_path):
        if "sub-" not in dirpath:
            continue
        for fn in filenames:
            low = fn.lower()
            if not low.endswith("_eeg_pp.set"):
                continue
            if task_name and (task_name.lower() not in low):
                continue
            results.append(os.path.join(dirpath, fn))
    results = sorted(results)
    print(f"[INFO] Found {len(results)} preprocessed EEG files ({'task: '+task_name if task_name else 'all'})")
    return [(f, "") for f in results]

def collect_ccd_event_files(bids_root):
    ev_files = glob(os.path.join(
        bids_root, "ds*/sub-*", "eeg", "sub-*_task-contrastChangeDetection*_events.tsv"
    ))
    return ev_files

def match_eeg_to_event(preproc_files, bids_root):
    ev_files = collect_ccd_event_files(bids_root)
    ev_dict = {os.path.basename(ef).replace("_events.tsv", ""): ef for ef in ev_files}
    pairs = []
    for eeg_path, _ in preproc_files:
        base = os.path.basename(eeg_path)
        key = base.replace("_eeg_pp.set", "").replace(".set", "")
        if key in ev_dict:
            pairs.append((eeg_path, ev_dict[key]))
        else:
            key_no_run = key.replace("_run-1","").replace("_run-2","")
            matched = None
            for k,v in ev_dict.items():
                k_norm = k.replace("_run-1","").replace("_run-2","")
                if k_norm == key_no_run:
                    matched = v; break
            if matched:
                pairs.append((eeg_path, matched))
    print(f"ðŸ”— Matched {len(pairs)} CCD EEGâ†”event pairs.")
    return pairs

# ============================================================
# 2) Cache + window
# ============================================================
def cached_load_eeg(eeg_set_path):
    fname = os.path.basename(eeg_set_path).replace(".set", "_cached.npy")
    path  = os.path.join(CACHE_DIR, fname)
    if os.path.exists(path):
        return np.load(path)  # (C, T)
    return None

def make_context_window(x_ct, onset_s, rt_ms,
                        pre_s=PRE_CTX_S, react_cap_s=REACT_CAP_S, post_s=POST_CTX_S,
                        sfreq=TARGET_SFREQ):
    """
    Extract [onset - pre, onset + min(rt,cap) + post] as a fixed length segment (pad if needed).
    Returns seg(C, TOTAL_SAMPLES), react_fracâˆˆ[0,1].
    """
    Ttot = int((pre_s + react_cap_s + post_s) * sfreq)
    if x_ct is None:
        return np.zeros((128, Ttot), np.float32), 0.0

    onset = int(onset_s * sfreq)
    react_len = int(min(rt_ms/1000.0, react_cap_s) * sfreq)
    t0 = max(0, onset - int(pre_s * sfreq))
    t1_desired = onset + int(react_cap_s * sfreq) + int(post_s * sfreq)
    seg = x_ct[:, t0:t1_desired]
    need = Ttot - seg.shape[1]
    if need > 0:
        seg = np.pad(seg, ((0,0),(0,need)), mode="constant")
    else:
        seg = seg[:, :Ttot]

    react_frac = (react_len / int(react_cap_s * sfreq)) if react_cap_s > 0 else 0.0
    react_frac = float(np.clip(react_frac, 0.0, 1.0))
    return seg.astype(np.float32), react_frac

# ============================================================
# 3) Parse trials
# ============================================================
def extract_ccd_trials(df):
    if df is None or df.empty or "onset" not in df.columns or "value" not in df.columns:
        return []
    trials = []
    on  = df["onset"].astype(float).values
    val = df["value"].astype(str).values
    fb  = df["feedback"].astype(str).values if "feedback" in df.columns else ["n/a"] * len(df)
    starts  = [i for i,v in enumerate(val) if "contrastTrial_start" in v]
    presses = [i for i,v in enumerate(val) if "buttonPress" in v]
    for ti in starts:
        t0 = on[ti]
        later = [pi for pi in presses if on[pi] > t0]
        if not later: 
            continue
        pi = later[0]
        rt = (on[pi]-t0) * 1000.0
        if 100 <= rt <= 3000 and "smiley" in fb[pi].lower():
            trials.append((t0, rt))
    return trials

# ============================================================
# 4) Dataset (context window + weighting + tertile class)
# ============================================================
class CCDContextDataset(Dataset):
    def __init__(self, eeg_event_pairs):
        raw_samples = []
        for eeg_path, ev_path in tqdm(eeg_event_pairs, desc="Parse CCD events"):
            if not os.path.exists(ev_path): 
                continue
            df = pd.read_csv(ev_path, sep="\t")
            for onset, rt in extract_ccd_trials(df):
                raw_samples.append((eeg_path, onset, rt))

        # keep only samples with cache present
        self.samples = []
        for eeg_path, onset, rt in tqdm(raw_samples, desc="Verify cache"):
            if cached_load_eeg(eeg_path) is not None:
                self.samples.append((eeg_path, onset, float(rt)))
        if len(self.samples) == 0:
            raise RuntimeError("No valid CCD samples with cache.")

        # stats for z-score RT label and tertiles
        rts = np.array([rt for _,_,rt in self.samples], np.float32)
        self.rt_mean = float(rts.mean())
        self.rt_std  = float(rts.std() + 1e-6)
        self.t1, self.t2 = np.percentile(rts, [33.3, 66.6]).tolist()
        print(f"âœ… CCD samples: {len(self.samples)} | RT mean={self.rt_mean:.1f} ms, std={self.rt_std:.1f} ms, tertiles=({self.t1:.0f},{self.t2:.0f})")

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        eeg_path, onset_s, rt_ms = self.samples[idx]
        X = cached_load_eeg(eeg_path)  # (C,T), z-scored already
        seg, react_frac = make_context_window(X, onset_s, rt_ms)

        # labels
        rt_norm = (rt_ms - self.rt_mean) / self.rt_std
        # tertile class
        if rt_ms < self.t1: cls = 0
        elif rt_ms < self.t2: cls = 1
        else: cls = 2

        weight = 1.0 + WEIGHT_GAMMA * react_frac
        return (
            torch.from_numpy(seg),                    # (C, Ttot)
            torch.tensor([rt_norm], dtype=torch.float32),  # (1,)
            torch.tensor(cls, dtype=torch.long),      # ()
            torch.tensor([weight], dtype=torch.float32)    # (1,)
        )

# ============================================================
# 5) CBraMod encoder (same as CH2) + context/normalization
# ============================================================
class ChannelAffine(nn.Module):
    """Learnable per-channel scale/bias (alignment)."""
    def __init__(self, C):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(C))
        self.beta  = nn.Parameter(torch.zeros(C))
    def forward(self, x):  # (B,C,T)
        return x * self.gamma[:,None] + self.beta[:,None]

class TemporalContextBlock(nn.Module):
    """Depthwise dilated conv pyramid to expand temporal RF before SincConv."""
    def __init__(self, C, dilations=(1,2,4), k=5):
        super().__init__()
        self.convs = nn.ModuleList([
            nn.Conv1d(C, C, kernel_size=k, padding=d*(k//2), dilation=d, groups=C)
            for d in dilations
        ])
        self.proj = nn.Conv1d(C, C, kernel_size=1)
        self.act  = nn.ReLU()
    def forward(self, x):  # (B,C,T)
        y = 0
        for conv in self.convs:
            y = y + conv(x)
        y = self.proj(self.act(y))
        return x + y

class SincConv1d(nn.Module):
    def __init__(self, out_channels=64, kernel_size=129, sample_rate=100, min_hz=0.3, max_hz=45.0):
        super().__init__()
        self.out_channels = out_channels
        self.kernel_size  = kernel_size
        self.sample_rate  = sample_rate
        self.min_hz       = float(min_hz)
        self.max_hz       = float(max_hz)
        low  = torch.linspace(self.min_hz, self.max_hz - 5.0, out_channels)
        band = torch.ones(out_channels) * 5.0
        self.low_hz_  = nn.Parameter(low)
        self.band_hz_ = nn.Parameter(band)
        n = torch.arange(-(kernel_size // 2), kernel_size // 2 + 1).float()
        self.register_buffer("n", n)

    def forward(self, x):  # x: (B, C, T)
        B, C, T = x.shape
        device, dtype = x.device, x.dtype
        low  = torch.clamp(torch.abs(self.low_hz_), min=self.min_hz, max=self.max_hz - 1.0)
        raw_high = low + torch.abs(self.band_hz_)
        min_v, max_v = low + 1.0, torch.full_like(low, self.max_hz)
        high = torch.clamp(raw_high, min=min_v, max=max_v)
        n = self.n.to(device=device, dtype=dtype)
        window = torch.hamming_window(self.kernel_size, periodic=False, dtype=dtype, device=device)
        nyq = self.sample_rate / 2.0
        filters = []
        for i in range(self.out_channels):
            f1, f2 = low[i]/nyq, high[i]/nyq
            h1 = 2 * f2 * torch.sinc(2 * f2 * n)
            h2 = 2 * f1 * torch.sinc(2 * f1 * n)
            bandpass = (h1 - h2) * window
            filters.append(bandpass)
        filt = torch.stack(filters, dim=0).unsqueeze(1)  # (out, 1, K)
        x_dw = x.view(B * C, 1, T)
        y = F.conv1d(x_dw, filt, stride=1, padding=self.kernel_size // 2)
        y = y.view(B, C, self.out_channels, y.shape[-1]).sum(dim=1)  # (B, out, T)
        return y

class SEBlock(nn.Module):
    def __init__(self, c, r=8):
        super().__init__()
        self.fc1, self.fc2 = nn.Linear(c, c//r), nn.Linear(c//r, c)
    def forward(self, x):
        s = x.mean(-1)
        e = torch.sigmoid(self.fc2(F.relu(self.fc1(s)))).unsqueeze(-1)
        return x * e

class CBraModBackbone(nn.Module):
    def __init__(self, out_dim=512):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(64,128,7,padding=3), nn.ReLU(),
            nn.Conv1d(128,256,5,padding=2), nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)
        )
        self.fc = nn.Linear(256,out_dim)
        self.out_dim = out_dim
    def forward(self,x):
        return self.fc(self.conv(x).squeeze(-1))

class EncoderCBraMod(nn.Module):
    """Input alignment â†’ temporal context â†’ sinc â†’ SE â†’ backbone â†’ 512-d feature"""
    def __init__(self, in_chans=128, sr=100, out_dim=512):
        super().__init__()
        self.affine = ChannelAffine(in_chans)
        self.bn     = nn.BatchNorm1d(in_chans, affine=True)
        self.ctx    = TemporalContextBlock(in_chans, dilations=(1,2,4), k=5)
        self.front  = SincConv1d(out_channels=64, kernel_size=129, sample_rate=sr)
        self.se     = SEBlock(64)
        self.backbone = CBraModBackbone(out_dim=out_dim)
        self.out_dim = out_dim

    def forward(self, x):  # (B,C,T)
        x = self.affine(x)
        x = self.bn(x)
        x = self.ctx(x)
        x = self.front(x)
        x = self.se(x)
        z = self.backbone(x)  # (B,512)
        return z

def load_ch2_weights_into_encoder(encoder: EncoderCBraMod, ckpt_path: str):
    if not os.path.exists(ckpt_path):
        print(f"[WARN] CH2 ckpt not found: {ckpt_path}")
        return
    sd = torch.load(ckpt_path, map_location="cpu")
    # keys: "front", "backbone", maybe "head"
    missing=[]
    if "front" in sd: encoder.front.load_state_dict(sd["front"], strict=False)
    else: missing.append("front")
    if "backbone" in sd: encoder.backbone.load_state_dict(sd["backbone"], strict=False)
    else: missing.append("backbone")
    if missing:
        print(f"[WARN] Missing parts in ckpt: {missing}")
    else:
        print(f"[INFO] Loaded CH2 weights from {ckpt_path}")

# ============================================================
# 6) Heads + loss
# ============================================================
class RtHead(nn.Module):
    def __init__(self, feat_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(feat_dim, max(64, feat_dim//2)), nn.ReLU(),
            nn.Linear(max(64, feat_dim//2), 1)
        )
    def forward(self, z): return self.mlp(z)  # (B,1)

class AuxClsHead(nn.Module):
    def __init__(self, feat_dim, n_classes=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(feat_dim, max(64, feat_dim//2)), nn.ReLU(),
            nn.Linear(max(64, feat_dim//2), n_classes)
        )
    def forward(self, z): return self.net(z)  # (B,K)

def weighted_mae(pred, target, weight):  # pred,target (B,1), weight (B,1)
    return (weight * (pred - target).abs()).mean()

# ============================================================
# 7) Train / Eval
# ============================================================
@torch.no_grad()
def evaluate(dl, encoder, rt_head, aux_head):
    encoder.eval(); rt_head.eval(); aux_head.eval()
    preds, trues = [], []
    for x, y, cls, w in dl:
        x = x.to(DEVICE).float()
        y = y.to(DEVICE).float()
        z = encoder(x)
        yhat = rt_head(z)
        preds.append(yhat.squeeze(1).cpu().numpy())
        trues.append(y.squeeze(1).cpu().numpy())
    preds = np.concatenate(preds); trues = np.concatenate(trues)
    mae = np.mean(np.abs(preds - trues))
    mse = np.mean((preds - trues)**2)
    nrmse = np.sqrt(mse) / (trues.std() + 1e-8)
    r2 = 1 - np.sum((preds-trues)**2) / np.sum((trues-trues.mean())**2 + 1e-8)
    return dict(MAE=float(mae), MSE=float(mse), NRMSE=float(nrmse), R2=float(r2))

def train(train_dl, val_dl, encoder, rt_head, aux_head,
          epochs=EPOCHS, lr=LR, wd=WD, aux_lambda=AUX_LAMBDA):
    params = list(encoder.parameters()) + list(rt_head.parameters()) + list(aux_head.parameters())
    opt = torch.optim.AdamW(params, lr=lr, weight_decay=wd)
    scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=="cuda"))
    best = {"NRMSE": 9e9, "epoch": -1}

    for ep in range(1, epochs+1):
        encoder.train(); rt_head.train(); aux_head.train()
        losses=[]
        for x, y, cls, w in tqdm(train_dl, desc=f"[Train] ep{ep}/{epochs}", leave=False):
            x = x.to(DEVICE).float()
            y = y.to(DEVICE).float()      # (B,1) normalized RT
            cls = cls.to(DEVICE).long()   # (B,)
            w = w.to(DEVICE).float()      # (B,1)

            opt.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=(DEVICE=="cuda")):
                z = encoder(x)                 # (B,512)
                yhat = rt_head(z)             # (B,1)
                logits = aux_head(z)          # (B,K)
                loss_reg = weighted_mae(yhat, y, w)
                loss_aux = F.cross_entropy(logits, cls)
                loss = loss_reg + aux_lambda * loss_aux

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            losses.append(loss.item())

        valm = evaluate(val_dl, encoder, rt_head, aux_head)
        print(f"[ep{ep:02d}] train={np.mean(losses):.4f} | "
              f"val MAE={valm['MAE']:.4f} MSE={valm['MSE']:.4f} "
              f"NRMSE={valm['NRMSE']:.4f} R2={valm['R2']:.3f}")

        if valm["NRMSE"] < best["NRMSE"]:
            best = {**valm, "epoch": ep}
            torch.save({
                "encoder": encoder.state_dict(),
                "rt_head": rt_head.state_dict(),
                "aux_head": aux_head.state_dict(),
                "val_metrics": valm
            }, "cbramod_ssl_ctx_react_aux_best.pth")
            with open("cbramod_ssl_ctx_react_aux_best.json","w") as f:
                json.dump(best, f, indent=2)
            print("  âœ… Saved checkpoint: cbramod_ssl_ctx_react_aux_best.pth")
    print("Best:", best)

# ============================================================
# 8) Main
# ============================================================
def main():
    # a) build CCD dataset from cached files
    ccd_files = collect_preprocessed_files(PREPROCESSED_ROOT, task_name="contrastChangeDetection")
    pairs = match_eeg_to_event(ccd_files, BIDS_ROOT)
    ds = CCDContextDataset(pairs)

    # b) split train/val quickly (files randomized; for strict subject split replace with subj-based)
    n_total = len(ds)
    n_val = max(64, int(0.1*n_total))
    n_tr  = n_total - n_val
    train_ds, val_ds = random_split(ds, [n_tr, n_val], generator=torch.Generator().manual_seed(SEED))
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
    val_dl   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=max(1,NUM_WORKERS//2), pin_memory=True)

    # c) model
    encoder = EncoderCBraMod(in_chans=128, sr=TARGET_SFREQ, out_dim=512).to(DEVICE)
    load_ch2_weights_into_encoder(encoder, CH2_CKPT_PATH)
    rt_head  = RtHead(encoder.out_dim).to(DEVICE)
    aux_head = AuxClsHead(encoder.out_dim, n_classes=AUX_NUM_CLASSES).to(DEVICE)

    # d) train
    train(train_dl, val_dl, encoder, rt_head, aux_head, epochs=EPOCHS, lr=LR, wd=WD, aux_lambda=AUX_LAMBDA)

    # e) export final (last state too)
    torch.save({
        "encoder": encoder.state_dict(),
        "rt_head": rt_head.state_dict(),
        "aux_head": aux_head.state_dict()
    }, "cbramod_ssl_ctx_react_aux_last.pth")
    print("Saved: cbramod_ssl_ctx_react_aux_last.pth")

if __name__ == "__main__":
    main()


[INFO] Found 5386 preprocessed EEG files (task: contrastChangeDetection)
ðŸ”— Matched 5386 CCD EEGâ†”event pairs.


Parse CCD events: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5386/5386 [01:35<00:00, 56.21it/s] 
Verify cache: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 341/341 [00:09<00:00, 36.16it/s]


âœ… CCD samples: 341 | RT mean=2863.3 ms, std=168.6 ms, tertiles=(2870,2940)
[INFO] Loaded CH2 weights from /home/RA/EEG_Challenge/Challenge2/best_cbramod_cached_finetune.pth


                                                             

[ep01] train=1.3907 | val MAE=0.4946 MSE=1.0774 NRMSE=0.9999 R2=0.000
  âœ… Saved checkpoint: cbramod_ssl_ctx_react_aux_best.pth


                                                             

[ep02] train=1.3861 | val MAE=0.4739 MSE=1.0866 NRMSE=1.0042 R2=-0.008


                                                             

[ep03] train=1.3624 | val MAE=0.4857 MSE=1.0856 NRMSE=1.0038 R2=-0.008


                                                             

[ep04] train=1.3127 | val MAE=0.4917 MSE=1.0928 NRMSE=1.0071 R2=-0.014


                                                             

[ep05] train=1.3413 | val MAE=0.4816 MSE=1.0965 NRMSE=1.0088 R2=-0.018


                                                             

[ep06] train=1.3863 | val MAE=0.4799 MSE=1.0968 NRMSE=1.0089 R2=-0.018


                                                             

[ep07] train=1.3763 | val MAE=0.5028 MSE=1.1051 NRMSE=1.0128 R2=-0.026


                                                             

[ep08] train=1.3669 | val MAE=0.4989 MSE=1.1047 NRMSE=1.0125 R2=-0.025


                                                             

[ep09] train=1.3232 | val MAE=0.4967 MSE=1.1035 NRMSE=1.0120 R2=-0.024


                                                              

[ep10] train=1.3744 | val MAE=0.4932 MSE=1.0987 NRMSE=1.0098 R2=-0.020
Best: {'MAE': 0.49458539485931396, 'MSE': 1.0773568153381348, 'NRMSE': 0.9999401571923082, 'R2': 0.00011962652206420898, 'epoch': 1}
Saved: cbramod_ssl_ctx_react_aux_last.pth
