In [1]:
# ============================================================
# EEGConformer Masked Reconstruction (SSL)
# - 129ch Ï∫êÏãú ‚Üí Cz Ï†úÍ±∞(128ch)
# - Block masking (Ïó∞ÏÜç Íµ¨Í∞Ñ Í∞ÄÎ¶º)
# - 2-layer Conv1d decoder
# - Î≥µÏõêÏùÄ (B,C,T) Ï†ÑÏ≤¥, ÏÜêÏã§ÏùÄ 'Í∞ÄÎ†§ÏßÑ Íµ¨Í∞Ñ'Îßå
# ============================================================

import os, random, numpy as np, torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from braindecode.models import EEGConformer

# -----------------------------
# CONFIG
# -----------------------------
CACHE_DIR   = "/data5/open_data/HBN/cache_eeg_windows"
DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"
SFREQ       = 100
N_CHANS     = 128           # Cz Ï†úÍ±∞ ÌõÑ
CROP_T      = 200           # SSLÏùÄ ÏßßÍ≤å (Îç∞Ïù¥ÌÑ∞ Ïàò ‚Üë)
MASK_RATIO  = 0.30          # Ï†ÑÏ≤¥ ÏãúÍ∞ÑÏùò 30%Î•º Í∞ÄÎ¶º(Ïó∞ÏÜç Î∏îÎ°ù)
BLOCK_LEN   = 32            # Î∏îÎ°ùÎãπ Í∏∏Ïù¥(ÏÉòÌîå) ‚Äî 100Hz Í∏∞Ï§Ä 0.32s
BATCH_SIZE  = 16
EPOCHS      = 50
LR          = 3e-4
WD          = 1e-4
SEED        = 42
NUM_WORKERS = 4
EMBED_DIM   = 256           # EEGConformer latent dim
GRAD_CLIP   = 1.0

def set_seed(seed=SEED):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
set_seed()

# ============================================================
# 1) BLOCK MASKING
# ============================================================
def make_block_mask(T, mask_ratio=0.3, block_len=32):
    """
    True=Î≥¥Ïù¥Îäî ÏúÑÏπò, False=Í∞ÄÎ¶¨Îäî ÏúÑÏπò (Ïó∞ÏÜç Î∏îÎ°ù)
    Î∏îÎ°ùÏù¥ ÏòÅÏÉÅÏ≤òÎüº Í≤πÏπ† Ïàò ÏûàÏßÄÎßå, ÎπÑÏú®ÏùÑ ÎåÄÎûµ ÎßåÏ°±
    """
    keep = np.ones(T, dtype=bool)
    target_mask = int(T * mask_ratio)
    if target_mask <= 0:
        return keep
    # ÏµúÏÜå 1Î∏îÎ°ù
    n_blocks = max(1, target_mask // block_len)
    for _ in range(n_blocks):
        start = np.random.randint(0, max(1, T - block_len + 1))
        keep[start:start+block_len] = False
    # ÌïÑÏöî Ïãú ÏûîÏó¨ Î∂ÄÎ∂Ñ ÎûúÎç§ Î≥¥ÏôÑ(Ï°∞Í∏à Îçî Í∞ÄÎ¶¥ Ïàò ÏûàÏùå)
    masked_now = (~keep).sum()
    while masked_now < target_mask:
        idx = np.random.randint(0, T)
        if keep[idx]:
            keep[idx] = False
            masked_now += 1
    return keep

# ============================================================
# 2) DATASET
# ============================================================
class MaskedEEGDataset(Dataset):
    """
    - Ï∫êÏãú ÌååÏùºÏóêÏÑú ÎûúÎç§ ÏÑ∏Í∑∏Î®ºÌä∏ ÌïòÎÇò ÏÑ†ÌÉù ‚Üí Cz Ï†úÍ±∞(129->128)
    - ÎûúÎç§ ÌÅ¨Î°≠(CROP_T), ÏßßÏúºÎ©¥ zero-pad
    - per-channel z-score
    - block keep-mask ÏÉùÏÑ± (True=Î≥¥ÏûÑ, False=Í∞ÄÎ¶º)
    """
    def __init__(self, cache_dir, crop_len=200, mask_ratio=0.3, block_len=32):
        self.files = [os.path.join(cache_dir, f) for f in os.listdir(cache_dir)
                      if f.endswith(".npy")]
        self.crop_len = crop_len
        self.mask_ratio = mask_ratio
        self.block_len = block_len

    def __len__(self): return len(self.files)

    def __getitem__(self, i):
        arr = np.load(self.files[i], mmap_mode="r")               # (segments, C, T)
        seg = arr[random.randint(0, arr.shape[0]-1)]              # (C, T)

        # ‚úÖ Cz Ï†úÍ±∞ (129 ‚Üí 128)
        if seg.shape[0] == 129:
            seg = seg[:-1, :]

        C, T = seg.shape
        # ‚úÖ Í∏∏Ïù¥ Î∂ÄÏ°± Ïãú zero-pad
        if T < self.crop_len:
            pad = np.zeros((C, self.crop_len), dtype=np.float32)
            pad[:, :T] = seg
            seg = pad
            T = self.crop_len

        # ‚úÖ ÏïàÏ†ÑÌïú ÌÅ¨Î°≠
        start = random.randint(0, T - self.crop_len) if T > self.crop_len else 0
        x = seg[:, start:start + self.crop_len]                   # (C, crop_len)

        # ‚úÖ per-channel z-score
        x = (x - x.mean(axis=1, keepdims=True)) / (x.std(axis=1, keepdims=True) + 1e-6)

        # ‚úÖ block keep-mask (1, T)
        keep = make_block_mask(self.crop_len, mask_ratio=self.mask_ratio, block_len=self.block_len)
        keep = torch.tensor(keep, dtype=torch.bool).unsqueeze(0)  # (1, T)

        x = torch.tensor(x, dtype=torch.float32)                  # (C, T)
        return x, keep

# ============================================================
# 3) MODEL: Encoder + 2-layer Decoder
# ============================================================
class EEGMaskedAutoencoder(nn.Module):
    """
    - Encoder: EEGConformer(n_outputs=embed_dim)
      -> (B, embed_dim, T_enc) ÎòêÎäî (B, embed_dim)
    - Decoder: 2-layer Conv1d (embed_dim -> 2*embed_dim -> C)
    - ÏûÖÎ†•:  x: (B, C, T), keep: (B, 1, T)  (True=Î≥¥ÏûÑ, False=Í∞ÄÎ¶º)
    - Ï∂úÎ†•:  x_hat: (B, C, T), ÏÜêÏã§ÏùÄ ~keep(Í∞ÄÎ†§ÏßÑ Íµ¨Í∞Ñ)ÏóêÏÑúÎßå Í≥ÑÏÇ∞
    """
    def __init__(self, n_chans=128, n_times=200, sfreq=100, embed_dim=256):
        super().__init__()
        self.encoder = EEGConformer(
            n_chans=n_chans,
            n_outputs=embed_dim,   # latent channels
            n_times=n_times,
            sfreq=sfreq
        )
        self.decoder = nn.Sequential(
            nn.Conv1d(embed_dim, embed_dim * 2, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv1d(embed_dim * 2, n_chans, kernel_size=1)
        )

    def forward(self, x, keep_bool):
        # x: (B, C, T), keep_bool: (B, 1, T)
        B, C, T = x.shape

        # ‚úÖ Í∞ÄÎ†§ÏßÑ Í≥≥ zero-out (Î∏åÎ°úÎìúÏ∫êÏä§ÌåÖ)
        visible = x * keep_bool.float()                           # (B,C,T) * (B,1,T)

        # Encoder
        z = self.encoder(visible)                                 # (B, embed_dim, T_enc) or (B, embed_dim)
        if z.ndim == 2:
            z = z.unsqueeze(-1)                                   # (B, embed_dim, 1)

        # üîß ÏãúÍ∞Ñ Í∏∏Ïù¥ Î≥¥Ï†ï (ÌïÑÏöî Ïãú)
        T_enc = z.shape[-1]
        if T_enc != T:
            z = F.interpolate(z, size=T, mode="linear", align_corners=False)

        # 2-layer decoder
        x_hat = self.decoder(z)                                   # (B, C, T)
        return x_hat

def masked_mse(pred, target, keep_bool, eps=1e-8):
    """
    pred/target: (B,C,T), keep_bool: (B,1,T) (True=Î≥¥ÏûÑ)
    ÏÜêÏã§ÏùÄ 'Í∞ÄÎ†§ÏßÑ Íµ¨Í∞Ñ'(~keep)Îßå ÌèâÍ∑†
    """
    masked = (~keep_bool).float()                                 # (B,1,T)
    diff2 = (pred - target) ** 2                                  # (B,C,T)
    num = (diff2 * masked).sum()
    den = masked.sum() * pred.size(1) + eps                       # (maskÎêú ÌÉÄÏûÑÏàò √ó Ï±ÑÎÑêÏàò)
    return num / den

# ============================================================
# 4) TRAIN
# ============================================================
def main():
    ds = MaskedEEGDataset(CACHE_DIR, crop_len=CROP_T, mask_ratio=MASK_RATIO, block_len=BLOCK_LEN)
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True,
                        num_workers=NUM_WORKERS, pin_memory=True)

    model = EEGMaskedAutoencoder(n_chans=N_CHANS, n_times=CROP_T, sfreq=SFREQ, embed_dim=EMBED_DIM).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

    for ep in range(1, EPOCHS + 1):
        model.train(); total = 0.0
        for x, keep in tqdm(loader, desc=f"SSL Masked Ep{ep}"):
            x = x.to(DEVICE, non_blocking=True)                   # (B,C,T)
            keep = keep.to(DEVICE, non_blocking=True)             # (B,1,T)

            opt.zero_grad()
            x_hat = model(x, keep)
            loss = masked_mse(x_hat, x, keep)
            loss.backward()
            if GRAD_CLIP is not None:
                nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            opt.step()
            total += loss.item()

        sch.step()
        print(f"[Ep {ep:02d}] masked-recon MSE = {total / max(1,len(loader)):.5f}")

    # Encoder ÌååÎùºÎØ∏ÌÑ∞Îßå Ï†ÄÏû• ‚Üí downstreamÏóêÏÑú Î°úÎìú
    torch.save(model.encoder.state_dict(), "pretrained_eegconformer_masked_block2dec.pth")
    print("‚úÖ Saved: pretrained_eegconformer_masked_block2dec.pth")

if __name__ == "__main__":
    main()


  from pkg_resources import DefaultProvider, ResourceManager, \
SSL Masked Ep1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.40it/s]


[Ep 01] masked-recon MSE = 0.97043


SSL Masked Ep2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.01it/s]


[Ep 02] masked-recon MSE = 0.96906


SSL Masked Ep3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.47it/s]


[Ep 03] masked-recon MSE = 0.97256


SSL Masked Ep4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.51it/s]


[Ep 04] masked-recon MSE = 0.96907


SSL Masked Ep5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.62it/s]


[Ep 05] masked-recon MSE = 0.97671


SSL Masked Ep6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.28it/s]


[Ep 06] masked-recon MSE = 0.96710


SSL Masked Ep7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.16it/s]


[Ep 07] masked-recon MSE = 0.97920


SSL Masked Ep8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.43it/s]


[Ep 08] masked-recon MSE = 0.97800


SSL Masked Ep9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.51it/s]


[Ep 09] masked-recon MSE = 0.98730


SSL Masked Ep10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.05it/s]


[Ep 10] masked-recon MSE = 0.99754


SSL Masked Ep11: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.03it/s]


[Ep 11] masked-recon MSE = 0.98990


SSL Masked Ep12: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:04<00:00,  4.86it/s]


[Ep 12] masked-recon MSE = 1.00225


SSL Masked Ep13: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:04<00:00,  4.95it/s]


[Ep 13] masked-recon MSE = 0.95086


SSL Masked Ep14: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:04<00:00,  4.94it/s]


[Ep 14] masked-recon MSE = 0.96132


SSL Masked Ep15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:04<00:00,  4.10it/s]


[Ep 15] masked-recon MSE = 0.98386


SSL Masked Ep16: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:06<00:00,  3.31it/s]


[Ep 16] masked-recon MSE = 0.96617


SSL Masked Ep17: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.12it/s]


[Ep 17] masked-recon MSE = 0.97987


SSL Masked Ep18: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:04<00:00,  4.43it/s]


[Ep 18] masked-recon MSE = 1.00034


SSL Masked Ep19: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:05<00:00,  3.46it/s]


[Ep 19] masked-recon MSE = 0.96902


SSL Masked Ep20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.11it/s]


[Ep 20] masked-recon MSE = 0.97084


SSL Masked Ep21: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.45it/s]


[Ep 21] masked-recon MSE = 0.96118


SSL Masked Ep22: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.50it/s]


[Ep 22] masked-recon MSE = 0.97007


SSL Masked Ep23: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.78it/s]


[Ep 23] masked-recon MSE = 0.98180


SSL Masked Ep24: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.66it/s]


[Ep 24] masked-recon MSE = 0.98498


SSL Masked Ep25: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.22it/s]


[Ep 25] masked-recon MSE = 0.99075


SSL Masked Ep26: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.08it/s]


[Ep 26] masked-recon MSE = 0.97790


SSL Masked Ep27: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.35it/s]


[Ep 27] masked-recon MSE = 0.98531


SSL Masked Ep28: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.17it/s]


[Ep 28] masked-recon MSE = 0.98068


SSL Masked Ep29: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.24it/s]


[Ep 29] masked-recon MSE = 0.95852


SSL Masked Ep30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.84it/s]


[Ep 30] masked-recon MSE = 0.96674


SSL Masked Ep31: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.40it/s]


[Ep 31] masked-recon MSE = 0.97736


SSL Masked Ep32: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.26it/s]


[Ep 32] masked-recon MSE = 0.96081


SSL Masked Ep33: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.70it/s]


[Ep 33] masked-recon MSE = 0.98676


SSL Masked Ep34: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:04<00:00,  4.80it/s]


[Ep 34] masked-recon MSE = 0.97514


SSL Masked Ep35: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.28it/s]


[Ep 35] masked-recon MSE = 0.98143


SSL Masked Ep36: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:04<00:00,  4.68it/s]


[Ep 36] masked-recon MSE = 0.96509


SSL Masked Ep37: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.79it/s]


[Ep 37] masked-recon MSE = 0.96173


SSL Masked Ep38: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.44it/s]


[Ep 38] masked-recon MSE = 0.98116


SSL Masked Ep39: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:05<00:00,  3.97it/s]


[Ep 39] masked-recon MSE = 0.99127


SSL Masked Ep40: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:05<00:00,  3.55it/s]


[Ep 40] masked-recon MSE = 0.96241


SSL Masked Ep41: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:04<00:00,  4.07it/s]


[Ep 41] masked-recon MSE = 0.93923


SSL Masked Ep42: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:04<00:00,  4.14it/s]


[Ep 42] masked-recon MSE = 0.97954


SSL Masked Ep43: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:04<00:00,  4.41it/s]


[Ep 43] masked-recon MSE = 0.95669


SSL Masked Ep44: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.28it/s]


[Ep 44] masked-recon MSE = 0.98024


SSL Masked Ep45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.74it/s]


[Ep 45] masked-recon MSE = 0.96839


SSL Masked Ep46: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.19it/s]


[Ep 46] masked-recon MSE = 0.94789


SSL Masked Ep47: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.12it/s]


[Ep 47] masked-recon MSE = 0.96438


SSL Masked Ep48: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  9.66it/s]


[Ep 48] masked-recon MSE = 0.96709


SSL Masked Ep49: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:01<00:00, 11.34it/s]


[Ep 49] masked-recon MSE = 0.96932


SSL Masked Ep50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  6.93it/s]

[Ep 50] masked-recon MSE = 0.95301
‚úÖ Saved: pretrained_eegconformer_masked_block2dec.pth





In [6]:
# ssl_pretrain_cbramod.py
import os, random, numpy as np, torch, sys
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# CBraMod import (Î†àÌè¨ ÌÅ¥Î°†Ìïú Í≤ΩÎ°úÏóêÏÑú Ïã§Ìñâ)
CBRAMOD_PATH = "/home/RA/EEG_Challenge/Challenge2/CBraMod"
if CBRAMOD_PATH not in sys.path:
    sys.path.append(CBRAMOD_PATH)
from models.cbramod import CBraMod
import torch.nn as nn

# ---------------- config ----------------
CACHE_DIR   = "/data5/open_data/HBN/cache_eeg_windows"
DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"
SFREQ       = 100
N_CHANS_IN  = 128           # 129Î©¥ Cz Ï†úÍ±∞ÌïòÏó¨ 128 ÏÇ¨Ïö©
S, P        = 4, 200        # (time_segments, points_per_patch) -> 800 samples
BATCH_SIZE  = 16
EPOCHS      = 100
LR          = 3e-4
WD          = 1e-4
MASK_RATIO  = 0.30
BLOCK_LEN   = 32            # Ïó∞ÏÜç Í∞ÄÎ¶º Í∏∏Ïù¥(ÏÉòÌîå)
SEED        = 42

random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

def make_block_mask(P_len, mask_ratio=0.3, block_len=32):
    """Ìïú Ìå®Ïπò Í∏∏Ïù¥ PÏóê ÎåÄÌï¥ True=Î≥¥ÏûÑ, False=Í∞ÄÎ¶º (Ïó∞ÏÜç Î∏îÎ°ù)"""
    keep = np.ones(P_len, dtype=bool)
    target = int(P_len * mask_ratio)
    n_blocks = max(1, target // block_len)
    for _ in range(n_blocks):
        st = np.random.randint(0, max(1, P_len - block_len + 1))
        keep[st:st+block_len] = False
    # Î∂ÄÏ°±Î∂Ñ Î≥¥ÏôÑ
    while (~keep).sum() < target:
        i = np.random.randint(0, P_len)
        keep[i] = False
    return keep  # (P,)

class HBNCBraModSSL(Dataset):
    """HBN cache ‚Üí (C,S,P), keep-mask: (1,S,P)"""
    def __init__(self, cache_dir, S=4, P=200):
        self.files = [os.path.join(cache_dir, f) for f in os.listdir(cache_dir) if f.endswith(".npy")]
        self.S, self.P = S, P
        self.win_len = S * P

    def __len__(self): return len(self.files)

    def __getitem__(self, i):
        arr = np.load(self.files[i], mmap_mode="r")        # (segments, C, T)
        seg = arr[np.random.randint(0, arr.shape[0])]      # (C, T)
        # Cz Ï†úÍ±∞ (129 -> 128)
        if seg.shape[0] == 129:
            seg = seg[:-1, :]
        C, T = seg.shape
        # Í∏∏Ïù¥ Î≥¥Ï†ï(>=win_len): ÎûúÎç§ ÌÅ¨Î°≠, (<win_len): zero-pad
        if T < self.win_len:
            pad = np.zeros((C, self.win_len), np.float32)
            pad[:, :T] = seg
            seg = pad
        else:
            st = np.random.randint(0, T - self.win_len + 1)
            seg = seg[:, st:st + self.win_len]

        # ÌëúÏ§ÄÌôî (Ï±ÑÎÑêÎ≥Ñ)
        seg = (seg - seg.mean(axis=1, keepdims=True)) / (seg.std(axis=1, keepdims=True) + 1e-6)
        # (C, S, P)
        x = torch.tensor(seg.reshape(C, self.S, self.P), dtype=torch.float32)

        # keep-mask ÎßåÎì§Í∏∞: Ìå®ÏπòÎßàÎã§ PÏ∂ï Î∏îÎ°ùÎßàÏä§ÌÇπ
        keep = np.stack([make_block_mask(self.P, MASK_RATIO, BLOCK_LEN) for _ in range(self.S)], axis=0)  # (S,P)
        keep = torch.tensor(keep, dtype=torch.bool).unsqueeze(0)  # (1, S, P)

        return x, keep

def masked_mse(pred, target, keep_bool, eps=1e-8):
    """
    pred/target: (B,C,S,P), keep_bool: (B,1,S,P)  (True=Î≥¥ÏûÑ)
    ÏÜêÏã§ÏùÄ 'Í∞ÄÎ†§ÏßÑ Íµ¨Í∞Ñ'(~keep)Îßå ÌèâÍ∑†
    """
    masked = (~keep_bool).float()                          # (B,1,S,P)
    diff2  = (pred - target) ** 2                          # (B,C,S,P)
    num = (diff2 * masked).sum()
    den = masked.sum() * pred.size(1) + eps               # (maskÎêú ÏãúÍ∞ÑÏàò√óÏ±ÑÎÑêÏàò)
    return num / den

def main():
    ds = HBNCBraModSSL(CACHE_DIR, S=S, P=P)
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)

    # ‚úÖ CBraMod Î°úÎìú: proj_out Ï†úÍ±∞Ìï¥ feature map (B,C,S,P) Î∞õÍ∏∞
    model = CBraMod().to(DEVICE)
    model.proj_out = nn.Identity()   # README Quick Start Í∂åÏû• Î∞©Ïãù :contentReference[oaicite:2]{index=2}

    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
    scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=="cuda"))

    for ep in range(1, EPOCHS+1):
        model.train(); running = 0.0
        for x, keep in tqdm(loader, desc=f"CBraMod-SSL Ep{ep}"):
            x    = x.to(DEVICE, non_blocking=True)              # (B,C,S,P)
            keep = keep.to(DEVICE, non_blocking=True)           # (B,1,S,P)
            # ÏûÖÎ†• Í∞ÄÎ¶¨Í∏∞
            visible = x * keep.float()

            opt.zero_grad()
            with torch.cuda.amp.autocast(enabled=(DEVICE=="cuda")):
                pred = model(visible)                           # (B,C,S,P) ‚Äî Quick StartÏôÄ ÎèôÏùºÌïú Ï∂úÎ†• Ìè¨Îß∑ :contentReference[oaicite:3]{index=3}
                loss = masked_mse(pred, x, keep)

            scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
            running += loss.item()

        print(f"[Ep {ep:02d}] masked MSE: {running/max(1,len(loader)):.5f}")

    # Ïù∏ÏΩîÎçî Ï†ÑÏ≤¥ Ï†ÄÏû•(Í∑∏ÎåÄÎ°ú Î°úÎìúÌï¥ÏÑú Îã§Ïö¥Ïä§Ìä∏Î¶ºÏóê ÏÇ¨Ïö©)
    torch.save(model.state_dict(), "cbramod_ssl_hbn.pth")
    print("‚úÖ Saved SSL checkpoint: cbramod_ssl_hbn.pth")

if __name__ == "__main__":
    main()


CBraMod-SSL Ep1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.23it/s]


[Ep 01] masked MSE: 19.43482


CBraMod-SSL Ep2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.56it/s]


[Ep 02] masked MSE: 7.34289


CBraMod-SSL Ep3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.03it/s]


[Ep 03] masked MSE: 4.37954


CBraMod-SSL Ep4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.30it/s]


[Ep 04] masked MSE: 3.11617


CBraMod-SSL Ep5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.70it/s]


[Ep 05] masked MSE: 2.46341


CBraMod-SSL Ep6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.71it/s]


[Ep 06] masked MSE: 2.06488


CBraMod-SSL Ep7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.46it/s]


[Ep 07] masked MSE: 1.84002


CBraMod-SSL Ep8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  6.78it/s]


[Ep 08] masked MSE: 1.65546


CBraMod-SSL Ep9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  6.95it/s]


[Ep 09] masked MSE: 1.53078


CBraMod-SSL Ep10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.33it/s]


[Ep 10] masked MSE: 1.45172


CBraMod-SSL Ep11: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:04<00:00,  4.99it/s]


[Ep 11] masked MSE: 1.34525


CBraMod-SSL Ep12: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.49it/s]


[Ep 12] masked MSE: 1.29676


CBraMod-SSL Ep13: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.32it/s]


[Ep 13] masked MSE: 1.24967


CBraMod-SSL Ep14: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.10it/s]


[Ep 14] masked MSE: 1.18416


CBraMod-SSL Ep15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.96it/s]


[Ep 15] masked MSE: 1.16991


CBraMod-SSL Ep16: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.18it/s]


[Ep 16] masked MSE: 1.14413


CBraMod-SSL Ep17: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.00it/s]


[Ep 17] masked MSE: 1.10171


CBraMod-SSL Ep18: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.17it/s]


[Ep 18] masked MSE: 1.08866


CBraMod-SSL Ep19: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.73it/s]


[Ep 19] masked MSE: 1.05498


CBraMod-SSL Ep20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.38it/s]


[Ep 20] masked MSE: 1.02842


CBraMod-SSL Ep21: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.99it/s]


[Ep 21] masked MSE: 0.99219


CBraMod-SSL Ep22: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.98it/s]


[Ep 22] masked MSE: 0.98575


CBraMod-SSL Ep23: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.44it/s]


[Ep 23] masked MSE: 0.95077


CBraMod-SSL Ep24: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.54it/s]


[Ep 24] masked MSE: 0.95649


CBraMod-SSL Ep25: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.20it/s]


[Ep 25] masked MSE: 0.91023


CBraMod-SSL Ep26: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.18it/s]


[Ep 26] masked MSE: 0.94346


CBraMod-SSL Ep27: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.34it/s]


[Ep 27] masked MSE: 0.90113


CBraMod-SSL Ep28: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.09it/s]


[Ep 28] masked MSE: 0.91265


CBraMod-SSL Ep29: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.74it/s]


[Ep 29] masked MSE: 0.89849


CBraMod-SSL Ep30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.66it/s]


[Ep 30] masked MSE: 0.86131


CBraMod-SSL Ep31: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  6.92it/s]


[Ep 31] masked MSE: 0.85880


CBraMod-SSL Ep32: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.03it/s]


[Ep 32] masked MSE: 0.86334


CBraMod-SSL Ep33: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.60it/s]


[Ep 33] masked MSE: 0.83996


CBraMod-SSL Ep34: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.68it/s]


[Ep 34] masked MSE: 0.83075


CBraMod-SSL Ep35: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.38it/s]


[Ep 35] masked MSE: 0.83185


CBraMod-SSL Ep36: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.62it/s]


[Ep 36] masked MSE: 0.82462


CBraMod-SSL Ep37: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.94it/s]


[Ep 37] masked MSE: 0.84662


CBraMod-SSL Ep38: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.04it/s]


[Ep 38] masked MSE: 0.82587


CBraMod-SSL Ep39: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.85it/s]


[Ep 39] masked MSE: 0.79349


CBraMod-SSL Ep40: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.30it/s]


[Ep 40] masked MSE: 0.79392


CBraMod-SSL Ep41: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.94it/s]


[Ep 41] masked MSE: 0.79217


CBraMod-SSL Ep42: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.65it/s]


[Ep 42] masked MSE: 0.78902


CBraMod-SSL Ep43: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.30it/s]


[Ep 43] masked MSE: 0.78205


CBraMod-SSL Ep44: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.49it/s]


[Ep 44] masked MSE: 0.76765


CBraMod-SSL Ep45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.26it/s]


[Ep 45] masked MSE: 0.76225


CBraMod-SSL Ep46: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.67it/s]


[Ep 46] masked MSE: 0.75392


CBraMod-SSL Ep47: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.21it/s]


[Ep 47] masked MSE: 0.74830


CBraMod-SSL Ep48: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.94it/s]


[Ep 48] masked MSE: 0.74283


CBraMod-SSL Ep49: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.23it/s]


[Ep 49] masked MSE: 0.75597


CBraMod-SSL Ep50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.37it/s]


[Ep 50] masked MSE: 0.74063


CBraMod-SSL Ep51: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.66it/s]


[Ep 51] masked MSE: 0.71712


CBraMod-SSL Ep52: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.27it/s]


[Ep 52] masked MSE: 0.73423


CBraMod-SSL Ep53: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.66it/s]


[Ep 53] masked MSE: 0.72719


CBraMod-SSL Ep54: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:04<00:00,  4.61it/s]


[Ep 54] masked MSE: 0.71434


CBraMod-SSL Ep55: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.41it/s]


[Ep 55] masked MSE: 0.71705


CBraMod-SSL Ep56: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.94it/s]


[Ep 56] masked MSE: 0.71043


CBraMod-SSL Ep57: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.25it/s]


[Ep 57] masked MSE: 0.71947


CBraMod-SSL Ep58: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.49it/s]


[Ep 58] masked MSE: 0.70672


CBraMod-SSL Ep59: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.64it/s]


[Ep 59] masked MSE: 0.70131


CBraMod-SSL Ep60: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.08it/s]


[Ep 60] masked MSE: 0.72563


CBraMod-SSL Ep61: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.65it/s]


[Ep 61] masked MSE: 0.70687


CBraMod-SSL Ep62: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.83it/s]


[Ep 62] masked MSE: 0.68784


CBraMod-SSL Ep63: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.71it/s]


[Ep 63] masked MSE: 0.69341


CBraMod-SSL Ep64: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.76it/s]


[Ep 64] masked MSE: 0.70331


CBraMod-SSL Ep65: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.65it/s]


[Ep 65] masked MSE: 0.68780


CBraMod-SSL Ep66: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.67it/s]


[Ep 66] masked MSE: 0.68620


CBraMod-SSL Ep67: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.35it/s]


[Ep 67] masked MSE: 0.68985


CBraMod-SSL Ep68: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.05it/s]


[Ep 68] masked MSE: 0.67084


CBraMod-SSL Ep69: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.13it/s]


[Ep 69] masked MSE: 0.66929


CBraMod-SSL Ep70: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.30it/s]


[Ep 70] masked MSE: 0.67052


CBraMod-SSL Ep71: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  6.94it/s]


[Ep 71] masked MSE: 0.65169


CBraMod-SSL Ep72: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.17it/s]


[Ep 72] masked MSE: 0.69938


CBraMod-SSL Ep73: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.61it/s]


[Ep 73] masked MSE: 0.64463


CBraMod-SSL Ep74: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.51it/s]


[Ep 74] masked MSE: 0.65682


CBraMod-SSL Ep75: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.09it/s]


[Ep 75] masked MSE: 0.64640


CBraMod-SSL Ep76: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.46it/s]


[Ep 76] masked MSE: 0.65856


CBraMod-SSL Ep77: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.33it/s]


[Ep 77] masked MSE: 0.65887


CBraMod-SSL Ep78: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.40it/s]


[Ep 78] masked MSE: 0.64001


CBraMod-SSL Ep79: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.00it/s]


[Ep 79] masked MSE: 0.66417


CBraMod-SSL Ep80: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.31it/s]


[Ep 80] masked MSE: 0.65402


CBraMod-SSL Ep81: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.90it/s]


[Ep 81] masked MSE: 0.63619


CBraMod-SSL Ep82: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.10it/s]


[Ep 82] masked MSE: 0.63746


CBraMod-SSL Ep83: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.45it/s]


[Ep 83] masked MSE: 0.65745


CBraMod-SSL Ep84: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.97it/s]


[Ep 84] masked MSE: 0.63057


CBraMod-SSL Ep85: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  6.94it/s]


[Ep 85] masked MSE: 0.63754


CBraMod-SSL Ep86: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.40it/s]


[Ep 86] masked MSE: 0.64470


CBraMod-SSL Ep87: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:04<00:00,  4.66it/s]


[Ep 87] masked MSE: 0.63367


CBraMod-SSL Ep88: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.68it/s]


[Ep 88] masked MSE: 0.63655


CBraMod-SSL Ep89: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.10it/s]


[Ep 89] masked MSE: 0.63407


CBraMod-SSL Ep90: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.66it/s]


[Ep 90] masked MSE: 0.62041


CBraMod-SSL Ep91: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.24it/s]


[Ep 91] masked MSE: 0.62930


CBraMod-SSL Ep92: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.14it/s]


[Ep 92] masked MSE: 0.62628


CBraMod-SSL Ep93: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.70it/s]


[Ep 93] masked MSE: 0.63006


CBraMod-SSL Ep94: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.49it/s]


[Ep 94] masked MSE: 0.65490


CBraMod-SSL Ep95: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.76it/s]


[Ep 95] masked MSE: 0.60658


CBraMod-SSL Ep96: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.53it/s]


[Ep 96] masked MSE: 0.61904


CBraMod-SSL Ep97: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  6.86it/s]


[Ep 97] masked MSE: 0.62761


CBraMod-SSL Ep98: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.40it/s]


[Ep 98] masked MSE: 0.62267


CBraMod-SSL Ep99: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.47it/s]


[Ep 99] masked MSE: 0.61323


CBraMod-SSL Ep100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.71it/s]


[Ep 100] masked MSE: 0.62279
‚úÖ Saved SSL checkpoint: cbramod_ssl_hbn.pth


In [7]:
# ssl_finetune_hbn_from_pretrained.py
import os, random, numpy as np, torch, sys
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# ========== CBraMod import ==========
CBRAMOD_PATH = "/home/RA/EEG_Challenge/Challenge2/CBraMod"
if CBRAMOD_PATH not in sys.path:
    sys.path.append(CBRAMOD_PATH)
from models.cbramod import CBraMod
import torch.nn as nn

# ========== Config ==========
CACHE_DIR   = "/data5/open_data/HBN/cache_eeg_windows"
DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"
SFREQ       = 100
N_CHANS     = 128
S, P        = 4, 200
MASK_RATIO  = 0.3
BLOCK_LEN   = 32
BATCH_SIZE  = 16
EPOCHS      = 50
LR          = 3e-4
WD          = 1e-4
SEED        = 42
NUM_WORKERS = 4
PRETRAINED_PATH = "/home/RA/EEG_Challenge/Challenge2/CBraMod/pretrained_weights/pretrained_weights.pth"

random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

# ========== Helper functions ==========
def make_block_mask(P_len, mask_ratio=0.3, block_len=32):
    keep = np.ones(P_len, dtype=bool)
    target = int(P_len * mask_ratio)
    n_blocks = max(1, target // block_len)
    for _ in range(n_blocks):
        st = np.random.randint(0, max(1, P_len - block_len + 1))
        keep[st:st+block_len] = False
    while (~keep).sum() < target:
        keep[np.random.randint(0, P_len)] = False
    return keep

class HBNCBraModSSL(Dataset):
    """HBN EEG cache ‚Üí (C,S,P), keep-mask"""
    def __init__(self, cache_dir, S=4, P=200):
        self.files = [os.path.join(cache_dir, f) for f in os.listdir(cache_dir) if f.endswith(".npy")]
        self.S, self.P = S, P
        self.win_len = S * P
    def __len__(self): return len(self.files)
    def __getitem__(self, i):
        arr = np.load(self.files[i], mmap_mode="r")
        seg = arr[np.random.randint(0, arr.shape[0])]
        if seg.shape[0] == 129: seg = seg[:-1,:]  # Cz Ï†úÍ±∞
        C,T = seg.shape
        if T < self.win_len:
            pad = np.zeros((C, self.win_len), np.float32); pad[:,:T]=seg; seg=pad
        else:
            st = np.random.randint(0, T - self.win_len + 1)
            seg = seg[:, st:st+self.win_len]
        seg = (seg - seg.mean(axis=1, keepdims=True)) / (seg.std(axis=1, keepdims=True)+1e-6)
        x = torch.tensor(seg.reshape(C, self.S, self.P), dtype=torch.float32)
        keep = np.stack([make_block_mask(self.P, MASK_RATIO, BLOCK_LEN) for _ in range(self.S)], axis=0)
        keep = torch.tensor(keep, dtype=torch.bool).unsqueeze(0)
        return x, keep

def masked_mse(pred, target, keep_bool, eps=1e-8):
    masked = (~keep_bool).float()
    diff2  = (pred - target)**2
    num = (diff2 * masked).sum()
    den = masked.sum() * pred.size(1) + eps
    return num/den

# ========== Main ==========
def main():
    ds = HBNCBraModSSL(CACHE_DIR, S, P)
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)

    # 1. Load pretrained CBraMod
    model = CBraMod().to(DEVICE)
    model.proj_out = nn.Identity()

    if os.path.exists(PRETRAINED_PATH):
        ckpt = torch.load(PRETRAINED_PATH, map_location="cpu")
        model.load_state_dict(ckpt, strict=False)
        print(f"‚úÖ Loaded pretrained CBraMod weights from {PRETRAINED_PATH}")
    else:
        print("‚ö†Ô∏è Pretrained model not found, training from scratch.")

    # 2. Continue SSL with HBN EEG
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
    scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=="cuda"))

    for ep in range(1, EPOCHS+1):
        model.train(); total=0.0
        for x, keep in tqdm(loader, desc=f"CBraMod HBN-SSL Ep{ep}"):
            x, keep = x.to(DEVICE), keep.to(DEVICE)
            visible = x * keep.float()
            opt.zero_grad()
            with torch.cuda.amp.autocast(enabled=(DEVICE=="cuda")):
                pred = model(visible)
                loss = masked_mse(pred, x, keep)
            scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
            total += loss.item()
        print(f"[Ep{ep:02d}] masked MSE = {total/len(loader):.6f}")

    torch.save(model.state_dict(), "cbramod_ssl_hbn_from_pretrained.pth")
    print("‚úÖ Saved fine-tuned SSL model: cbramod_ssl_hbn_from_pretrained.pth")

if __name__ == "__main__":
    main()


‚úÖ Loaded pretrained CBraMod weights from /home/RA/EEG_Challenge/Challenge2/CBraMod/pretrained_weights/pretrained_weights.pth


CBraMod HBN-SSL Ep1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.28it/s]


[Ep01] masked MSE = 0.970809


CBraMod HBN-SSL Ep2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.36it/s]


[Ep02] masked MSE = 0.844289


CBraMod HBN-SSL Ep3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.41it/s]


[Ep03] masked MSE = 0.752549


CBraMod HBN-SSL Ep4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.63it/s]


[Ep04] masked MSE = 0.718108


CBraMod HBN-SSL Ep5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.46it/s]


[Ep05] masked MSE = 0.690187


CBraMod HBN-SSL Ep6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.96it/s]


[Ep06] masked MSE = 0.656835


CBraMod HBN-SSL Ep7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.02it/s]


[Ep07] masked MSE = 0.653044


CBraMod HBN-SSL Ep8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.52it/s]


[Ep08] masked MSE = 0.619441


CBraMod HBN-SSL Ep9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.62it/s]


[Ep09] masked MSE = 0.609711


CBraMod HBN-SSL Ep10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.44it/s]


[Ep10] masked MSE = 0.626462


CBraMod HBN-SSL Ep11: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.31it/s]


[Ep11] masked MSE = 0.583463


CBraMod HBN-SSL Ep12: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.03it/s]


[Ep12] masked MSE = 0.607272


CBraMod HBN-SSL Ep13: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  6.90it/s]


[Ep13] masked MSE = 0.596146


CBraMod HBN-SSL Ep14: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.51it/s]


[Ep14] masked MSE = 0.571267


CBraMod HBN-SSL Ep15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.34it/s]


[Ep15] masked MSE = 0.602610


CBraMod HBN-SSL Ep16: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.06it/s]


[Ep16] masked MSE = 0.598155


CBraMod HBN-SSL Ep17: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.66it/s]


[Ep17] masked MSE = 0.578831


CBraMod HBN-SSL Ep18: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.07it/s]


[Ep18] masked MSE = 0.576058


CBraMod HBN-SSL Ep19: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.17it/s]


[Ep19] masked MSE = 0.577157


CBraMod HBN-SSL Ep20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.00it/s]


[Ep20] masked MSE = 0.576164


CBraMod HBN-SSL Ep21: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  9.07it/s]


[Ep21] masked MSE = 0.553744


CBraMod HBN-SSL Ep22: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.64it/s]


[Ep22] masked MSE = 0.559030


CBraMod HBN-SSL Ep23: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.08it/s]


[Ep23] masked MSE = 0.543769


CBraMod HBN-SSL Ep24: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.14it/s]


[Ep24] masked MSE = 0.552027


CBraMod HBN-SSL Ep25: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.03it/s]


[Ep25] masked MSE = 0.527633


CBraMod HBN-SSL Ep26: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  5.93it/s]


[Ep26] masked MSE = 0.573872


CBraMod HBN-SSL Ep27: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.17it/s]


[Ep27] masked MSE = 0.532977


CBraMod HBN-SSL Ep28: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.04it/s]


[Ep28] masked MSE = 0.547562


CBraMod HBN-SSL Ep29: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.20it/s]


[Ep29] masked MSE = 0.556649


CBraMod HBN-SSL Ep30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:03<00:00,  6.56it/s]


[Ep30] masked MSE = 0.525067


CBraMod HBN-SSL Ep31: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:04<00:00,  4.10it/s]


[Ep31] masked MSE = 0.540127


CBraMod HBN-SSL Ep32: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.23it/s]


[Ep32] masked MSE = 0.535363


CBraMod HBN-SSL Ep33: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.63it/s]


[Ep33] masked MSE = 0.523683


CBraMod HBN-SSL Ep34: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.33it/s]


[Ep34] masked MSE = 0.530437


CBraMod HBN-SSL Ep35: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.42it/s]


[Ep35] masked MSE = 0.533464


CBraMod HBN-SSL Ep36: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.67it/s]


[Ep36] masked MSE = 0.523001


CBraMod HBN-SSL Ep37: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.23it/s]


[Ep37] masked MSE = 0.551005


CBraMod HBN-SSL Ep38: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.37it/s]


[Ep38] masked MSE = 0.529522


CBraMod HBN-SSL Ep39: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.65it/s]


[Ep39] masked MSE = 0.506054


CBraMod HBN-SSL Ep40: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.59it/s]


[Ep40] masked MSE = 0.518768


CBraMod HBN-SSL Ep41: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.55it/s]


[Ep41] masked MSE = 0.508010


CBraMod HBN-SSL Ep42: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.46it/s]


[Ep42] masked MSE = 0.523164


CBraMod HBN-SSL Ep43: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.67it/s]


[Ep43] masked MSE = 0.515965


CBraMod HBN-SSL Ep44: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.53it/s]


[Ep44] masked MSE = 0.504159


CBraMod HBN-SSL Ep45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.56it/s]


[Ep45] masked MSE = 0.510754


CBraMod HBN-SSL Ep46: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.67it/s]


[Ep46] masked MSE = 0.497441


CBraMod HBN-SSL Ep47: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.06it/s]


[Ep47] masked MSE = 0.507021


CBraMod HBN-SSL Ep48: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.36it/s]


[Ep48] masked MSE = 0.513052


CBraMod HBN-SSL Ep49: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  8.42it/s]


[Ep49] masked MSE = 0.518703


CBraMod HBN-SSL Ep50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  7.65it/s]


[Ep50] masked MSE = 0.514092
‚úÖ Saved fine-tuned SSL model: cbramod_ssl_hbn_from_pretrained.pth


In [15]:
# finetune_hbn_regressor_from_ssl.py
import os, random, numpy as np, pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from collections import defaultdict

# ---- CBraMod import ----
CBRAMOD_PATH = "/home/RA/EEG_Challenge/Challenge2/CBraMod"
import sys
if CBRAMOD_PATH not in sys.path:
    sys.path.append(CBRAMOD_PATH)
from models.cbramod import CBraMod

# ---- Config ----
CACHE_DIR = "/data5/open_data/HBN/cache_eeg_windows"
BIDS_ROOT = "/data5/open_data/HBN/EEG_BIDS"
DATASETS  = [f"ds0055{i:02d}" for i in range(5,17)]
SSL_CKPT  = "cbramod_ssl_hbn.pth"  # Stage1 output
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
S, P = 4, 200
N_CH = 128
K = 4
BATCH = 8
EPOCHS = 30
LR = 2e-4
WD = 1e-3
NUM_WORKERS = 4

# ---- Metadata load ----
def load_all_participants(bids_root=BIDS_ROOT, datasets=DATASETS):
    dfs=[]
    for ds in datasets:
        pfile = os.path.join(bids_root, ds, "participants.tsv")
        if os.path.exists(pfile):
            df = pd.read_csv(pfile, sep="\t"); df["release"]=ds; dfs.append(df)
    df = pd.concat(dfs, ignore_index=True)
    df.columns=[c.lower() for c in df.columns]
    df=df.drop_duplicates(subset=["participant_id"]).set_index("participant_id")
    if "p_factor" not in df.columns:
        alt=[c for c in df.columns if "p" in c and "factor" in c]
        if alt: df=df.rename(columns={alt[0]:"p_factor"})
    for c in ["age","sex","ehq_total"]:
        if c not in df.columns: df[c]=np.nan
    return df

meta_df=load_all_participants()
labels_df=meta_df[["p_factor"]].dropna()
meta_extra=meta_df[["age","sex","ehq_total"]].copy()
meta_extra["sex"]=meta_extra["sex"].map({"M":1,"F":0,"m":1,"f":0}).fillna(0.5)
for col in ["age","ehq_total"]:
    m,s=meta_extra[col].mean(),meta_extra[col].std()+1e-6
    meta_extra[col]=(meta_extra[col]-m)/s

# ---- Indexing ----
def subj_from_fname(f): return f.split("_")[0]
def build_index(cache_dir, files, label_idx):
    idx=[]
    for f in files:
        subj=subj_from_fname(f)
        if subj not in label_idx: continue
        arr=np.load(os.path.join(cache_dir,f),mmap_mode="r")
        for s in range(arr.shape[0]): idx.append((f,s,subj))
    return idx

all_files=[f for f in os.listdir(CACHE_DIR) if f.endswith(".npy")]
train_files=[f for f in all_files if "_train_" in f]
val_files=[f for f in all_files if "_val_" in f]
train_index=build_index(CACHE_DIR,train_files,labels_df.index)
val_index=build_index(CACHE_DIR,val_files,labels_df.index)

def subj_to_idx(idx):
    d=defaultdict(list)
    for i,(_,_,s) in enumerate(idx): d[s].append(i)
    return d
train_map=subj_to_idx(train_index); val_map=subj_to_idx(val_index)

# ---- Dataset ----
class WindowDS(Dataset):
    def __init__(self, cache_dir, index, labels, S=4,P=200,train=True):
        self.cache_dir=cache_dir; self.index=index; self.labels=labels
        self.S=S; self.P=P; self.W=S*P; self.train=train
    def __getitem__(self,i):
        f,s,subj=self.index[i]
        y=float(self.labels.loc[subj])
        x=np.load(os.path.join(self.cache_dir,f),mmap_mode="r")[s]
        if x.shape[0]==129: x=x[:-1,:]
        C,T=x.shape
        if T<self.W:
            pad=np.zeros((C,self.W),np.float32); pad[:,:T]=x; x=pad
        else:
            st=np.random.randint(0,T-self.W+1) if self.train else (T-self.W)//2
            x=x[:,st:st+self.W]
        x=(x-x.mean(1,keepdims=True))/(x.std(1,keepdims=True)+1e-6)
        return torch.tensor(x.reshape(C,self.S,self.P),dtype=torch.float32),torch.tensor(y),subj
    def __len__(self): return len(self.index)

class SubjectDS(Dataset):
    def __init__(self, base, subj_map, meta, K=4,train=True):
        self.base=base; self.map=subj_map; self.subjs=list(self.map.keys()); self.meta=meta; self.K=K; self.train=train
    def __getitem__(self,i):
        subj=self.subjs[i]; idxs=self.map[subj]
        sel=random.choices(idxs,k=self.K)
        Xs,ys=[],[]
        for j in sel:
            x,y,s=self.base[j]; Xs.append(x); ys.append(y)
        X=torch.stack(Xs,0); meta_vec=torch.tensor(self.meta.loc[subj].values,dtype=torch.float32)
        return X,ys[0],meta_vec,subj
    def __len__(self): return len(self.subjs)

# ---- Model ----
class CBraModRegressor(nn.Module):
    def __init__(self, meta_dim=3):
        super().__init__()
        self.backbone=CBraMod(); self.backbone.proj_out=nn.Identity()
        self.meta_fc=nn.Sequential(nn.Linear(meta_dim,32),nn.ReLU())
        self.regressor=nn.Sequential(nn.Linear(N_CH+32,128),nn.ReLU(),nn.Linear(128,1))
    def forward(self,X,meta):
        B,K=X.shape[:2]
        X=X.view(B*K,N_CH,S,P)
        F=self.backbone(X)
        F=F.mean((-1,-2))        # GAP
        F=F.view(B,K,N_CH).mean(1)
        M=self.meta_fc(meta)
        return self.regressor(torch.cat([F,M],1)).squeeze(-1)

# ---- Train/Eval ----
def evaluate(model, loader, mean, std):
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for X, y, meta, _ in loader:
            X, meta = X.to(DEVICE), meta.to(DEVICE)
            y = y.to(torch.float32).to(DEVICE)

            p = model(X, meta)                   # (B,)
            # üîç Ïú†Ìö®Í∞íÎßå ÏÇ¨Ïö© (ÏπòÌôò X)
            mask = torch.isfinite(p) & torch.isfinite(y)
            if mask.sum() == 0:
                continue

            p = p[mask]
            y = y[mask]

            p_denorm = p * std + mean
            preds.extend(p_denorm.cpu().numpy())
            trues.extend(y.cpu().numpy())

    preds, trues = np.asarray(preds), np.asarray(trues)
    if len(trues) == 0:
        return np.nan, np.nan, np.nan

    # Î∂ïÍ¥¥ ÏßÑÎã® Î°úÍ∑∏(Ìïú Î≤àÎßå Î≥¥Îäî Ïö©ÎèÑ)
    # print("pred mean/std:", preds.mean(), preds.std())

    r2 = 0.0 if np.var(trues) < 1e-8 else r2_score(trues, preds)
    return mean_absolute_error(trues, preds), mean_squared_error(trues, preds), r2


def main():
    tr_base=WindowDS(CACHE_DIR,train_index,labels_df["p_factor"],S,P,train=True)
    va_base=WindowDS(CACHE_DIR,val_index,labels_df["p_factor"],S,P,train=False)
    tr_ds=SubjectDS(tr_base,train_map,meta_extra,K,train=True)
    va_ds=SubjectDS(va_base,val_map,meta_extra,K,train=False)
    tr_loader=DataLoader(tr_ds,batch_size=BATCH,shuffle=True,num_workers=NUM_WORKERS)
    va_loader=DataLoader(va_ds,batch_size=BATCH,shuffle=False,num_workers=NUM_WORKERS)

    model = CBraModRegressor().to(DEVICE)
    if os.path.exists(SSL_CKPT):
        ckpt = torch.load(SSL_CKPT, map_location="cpu")
        model.backbone.load_state_dict(ckpt, strict=False)
        print("‚úÖ Loaded SSL backbone weights.")
    else:
        print("‚ö†Ô∏è No SSL checkpoint found. Training from scratch.")

    # üîì ÌòπÏãúÎùºÎèÑ freeze ÎêòÏñ¥ ÏûàÏúºÎ©¥ Ï†ÑÎ∂Ä ÌíÄÍ∏∞
    for p in model.parameters():
        p.requires_grad = False

    # üß™ ÏßÑÏßúÎ°ú ÌíÄÎ†∏ÎäîÏßÄ Ìïú Î≤àÎßå Î°úÍπÖ
    n_trainable = sum(p.requires_grad for p in model.parameters())
    print(f"Trainable params tensors: {n_trainable}")

    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)

    mean,std=labels_df["p_factor"].mean(),labels_df["p_factor"].std()+1e-6
    best={"mae":1e9,"mse":1e9,"r2":-1}
    for ep in range(1,EPOCHS+1):
        model.train()
        for X, y, m, subj in tqdm(tr_loader, desc=f"Train Ep{ep}"):
            X, m = X.to(DEVICE), m.to(DEVICE)
            y = y.to(torch.float32).to(DEVICE)
            y_n = ((y - mean) / std)

            opt.zero_grad()
            p = model(X, m)  # (B,)

            # üîç ÏàòÏπò ÏßÑÎã®
            nan_frac = torch.isnan(p).float().mean().item()
            if nan_frac > 0:
                # Ïù¥ Î∞∞Ïπò Ïä§ÌÇµ (ÎçÆÏñ¥Ïì∞Í∏∞ ÎåÄÏã† Î≤ÑÎ¶¨Í∏∞)
                # ÌïÑÏöîÌïòÎ©¥ Ïó¨Í∏∞ÏÑú gradient clipping/ LR ÎÇÆÏ∂îÍ∏∞ Îì±ÏùÑ Í≥†Î†§
                continue

            loss = 0.5 * (F.l1_loss(p, y_n) + F.mse_loss(p, y_n))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Ìè≠Ï£º Î∞©ÏßÄ
            opt.step()

        val_mae,val_mse,val_r2=evaluate(model,va_loader,mean,std)
        print(f"[Ep{ep:02d}] Val MAE={val_mae:.4f} MSE={val_mse:.4f} R¬≤={val_r2:.3f}")
        if val_mae<best["mae"]:
            best={"mae":val_mae,"mse":val_mse,"r2":val_r2}
            torch.save(model.state_dict(),"cbramod_hbn_regressor_from_ssl_best.pth")
            print(" ‚úì Saved best model")

if __name__=="__main__":
    main()


‚úÖ Loaded SSL backbone weights.
Trainable params tensors: 0


Train Ep1:   0%|          | 0/38 [00:02<?, ?it/s]


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [16]:
# domain_adapt_regression_cbramod_adversarial.py
import os, sys, random, numpy as np, pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from collections import defaultdict
from itertools import cycle

# ------------------------------------------------------------
# Config
# ------------------------------------------------------------
CBRAMOD_PATH = "/home/RA/EEG_Challenge/Challenge2/CBraMod"
if CBRAMOD_PATH not in sys.path:
    sys.path.append(CBRAMOD_PATH)
from models.cbramod import CBraMod  # repoÏùò Î™®Îç∏

CACHE_DIR = "/data5/open_data/HBN/cache_eeg_windows"
BIDS_ROOT = "/data5/open_data/HBN/EEG_BIDS"
DATASETS  = [f"ds0055{i:02d}" for i in range(5, 17)]
SSL_CKPT  = "/home/RA/EEG_Challenge/Challenge2/CBraMod/pretrained_weights/pretrained_weights.pth"  # ÏÇ¨Ï†ÑÌïôÏäµ Î∞±Î≥∏ Í≤ΩÎ°ú

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
S, P    = 4, 200
WIN_T   = S * P
N_CH    = 128          # 129Î©¥ Cz Ï†úÍ±∞
BATCH   = 8
EPOCHS  = 30
LR      = 2e-4
WD      = 1e-3
NUM_WORKERS = 4
LAMBDA_DA   = 0.2      # domain loss Í∞ÄÏ§ëÏπò
SEED    = 42

def set_seed(seed=SEED):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
set_seed()

# ------------------------------------------------------------
# Utilities
# ------------------------------------------------------------
def subj_from_fname(f): return f.split("_")[0]

def load_all_participants(bids_root=BIDS_ROOT, datasets=DATASETS):
    dfs=[]
    for ds in datasets:
        pfile=os.path.join(bids_root, ds, "participants.tsv")
        if os.path.exists(pfile):
            df=pd.read_csv(pfile, sep="\t"); df["release"]=ds; dfs.append(df)
    df=pd.concat(dfs, ignore_index=True)
    df.columns=[c.lower() for c in df.columns]
    df=df.drop_duplicates(subset=["participant_id"]).set_index("participant_id")
    if "p_factor" not in df.columns:
        alt=[c for c in df.columns if "p" in c and "factor" in c]
        if alt: df=df.rename(columns={alt[0]:"p_factor"})
    for c in ["age","sex","ehq_total"]:
        if c not in df.columns: df[c]=np.nan
    return df

meta_df   = load_all_participants()
labels_df = meta_df[["p_factor"]].dropna()

# ------------------------------------------------------------
# Indexing
# ------------------------------------------------------------
def build_index(cache_dir, files, label_idx):
    idx=[]
    for f in files:
        subj=subj_from_fname(f)
        if subj not in label_idx: 
            continue
        arr=np.load(os.path.join(cache_dir,f), mmap_mode="r")
        for s in range(arr.shape[0]):
            idx.append((f, s, subj))
    return idx

all_files = [f for f in os.listdir(CACHE_DIR) if f.endswith(".npy")]
src_files = [f for f in all_files if "_train_" in f]  # source domain (labeled)
tgt_files = [f for f in all_files if "_val_" in f]    # target domain

src_index = build_index(CACHE_DIR, src_files, labels_df.index)
tgt_index = build_index(CACHE_DIR, tgt_files, labels_df.index)

# ------------------------------------------------------------
# Datasets
# ------------------------------------------------------------
class WindowDS(Dataset):
    """(segments, C, T) Ï∫êÏãú ‚Üí (C, S, P), y"""
    def __init__(self, cache_dir, index, labels, S=4, P=200, train=True):
        self.cache_dir=cache_dir; self.index=index; self.labels=labels
        self.S=S; self.P=P; self.W=S*P; self.train=train

    def __len__(self): return len(self.index)

    def __getitem__(self, i):
        f, s, subj = self.index[i]
        y = float(self.labels.loc[subj])
        x = np.load(os.path.join(self.cache_dir, f), mmap_mode="r")[s]  # (C, T)

        # Cz Ï†úÍ±∞ (129 -> 128)
        if x.shape[0] == 129: 
            x = x[:-1, :]

        C, T = x.shape
        # pad/crop to W
        if T < self.W:
            pad = np.zeros((C, self.W), np.float32)
            pad[:, :T] = x; x = pad
        else:
            st = np.random.randint(0, T - self.W + 1) if self.train else (T - self.W)//2
            x = x[:, st:st+self.W]

        # per-channel z-score
        x = (x - x.mean(axis=1, keepdims=True)) / (x.std(axis=1, keepdims=True) + 1e-6)

        return torch.tensor(x.reshape(C, self.S, self.P), dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

src_ds = WindowDS(CACHE_DIR, src_index, labels_df["p_factor"], S, P, train=True)
tgt_ds = WindowDS(CACHE_DIR, tgt_index, labels_df["p_factor"], S, P, train=False)

src_loader = DataLoader(src_ds, batch_size=BATCH, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=True)
tgt_loader = DataLoader(tgt_ds, batch_size=BATCH, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=True)
tgt_eval_loader = DataLoader(tgt_ds, batch_size=BATCH, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

# ------------------------------------------------------------
# Model: CBraMod backbone + GRL domain head + Regressor
# ------------------------------------------------------------
class GradientReversalFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lamb): 
        ctx.lamb = lamb
        return x.view_as(x)
    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lamb * grad_output, None

class GRL(nn.Module):
    def __init__(self, lamb=1.0): super().__init__(); self.lamb=float(lamb)
    def forward(self, x, lamb=None):
        if lamb is None: lamb = self.lamb
        return GradientReversalFn.apply(x, float(lamb))

class DomainAdaptRegressor(nn.Module):
    def __init__(self, meta_dim=0):  # meta Ïïà Ïì∞Î©¥ 0
        super().__init__()
        self.backbone = CBraMod()
        self.backbone.proj_out = nn.Identity()      # (B, C, S, P)
        self.grl = GRL()
        self.domain_head = nn.Sequential(
            nn.Linear(N_CH, 64), nn.ReLU(),
            nn.Linear(64, 2)                         # source / target
        )
        # Í∞ÑÎã® ÌöåÍ∑ÄÌó§Îìú: GAP ÌõÑ Î∞îÎ°ú ÌöåÍ∑Ä
        self.regressor = nn.Sequential(
            nn.Linear(N_CH + meta_dim, 128), nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x, lamb_da=0.0, meta=None):
        f = self.backbone(x)             # (B, C, S, P)
        f = f.mean(dim=(-1, -2))         # (B, C) = (B, 128)
        # regression
        if meta is None:
            z = f
        else:
            z = torch.cat([f, meta], dim=1)
        y_hat = self.regressor(z).squeeze(-1)
        # domain logits
        d_logits = self.domain_head(self.grl(f, lamb_da))
        return y_hat, d_logits

# ------------------------------------------------------------
# Metrics
# ------------------------------------------------------------
def safe_r2(y_true, y_pred):
    y_true = np.asarray(y_true); y_pred = np.asarray(y_pred)
    if y_true.size < 2 or np.var(y_true) < 1e-8:
        return 0.0
    return r2_score(y_true, y_pred)

@torch.no_grad()
def evaluate(model, loader, y_mean, y_std):
    model.eval()
    preds, trues = [], []
    for x, y in loader:
        x = x.to(DEVICE); y = y.to(torch.float32).to(DEVICE)
        p, _ = model(x, lamb_da=0.0, meta=None)     # eval: ÎèÑÎ©îÏù∏ Î°úÏßì Î¨¥Ïãú
        # Ïó≠Ï†ïÍ∑úÌôî
        p = p * y_std + y_mean
        # Ïú†Ìö®Í∞íÎßå
        mask = torch.isfinite(p) & torch.isfinite(y)
        if mask.sum() == 0: 
            continue
        preds.extend(p[mask].cpu().numpy())
        trues.extend(y[mask].cpu().numpy())
    if len(trues) == 0:
        return np.nan, np.nan, np.nan
    preds, trues = np.array(preds), np.array(trues)
    return mean_absolute_error(trues, preds), mean_squared_error(trues, preds), safe_r2(trues, preds)

# ------------------------------------------------------------
# Train
# ------------------------------------------------------------
def main():
    model = DomainAdaptRegressor(meta_dim=0).to(DEVICE)
    if os.path.exists(SSL_CKPT):
        ckpt = torch.load(SSL_CKPT, map_location="cpu")
        model.backbone.load_state_dict(ckpt, strict=False)
        print("‚úÖ Loaded SSL pretrained CBraMod backbone.")
    else:
        print("‚ö†Ô∏è No SSL checkpoint found. Training backbone from scratch.")

    # ÌïÑÏöîÌïòÎ©¥ backbone ÏùºÎ∂ÄÎßå ÌíÄ ÏàòÎèÑ ÏûàÏùå. Ïó¨Í∏∞ÏÑ† Ï†ÑÏ≤¥ ÎØ∏ÏÑ∏Ï°∞Ï†ï ÌóàÏö©
    for p in model.backbone.parameters():
        p.requires_grad = True

    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
    ce  = nn.CrossEntropyLoss()

    y_mean = labels_df["p_factor"].mean()
    y_std  = labels_df["p_factor"].std() + 1e-6

    best = {"mae": np.inf, "mse": np.inf, "r2": -1.0}

    # target loaderÍ∞Ä Îçî ÏßßÏúºÎ©¥ cycleÎ°ú Î∞òÎ≥µ
    tgt_iter = cycle(tgt_loader)

    for ep in range(1, EPOCHS+1):
        model.train()
        total = 0.0
        for x_s, y_s in tqdm(src_loader, desc=f"[Train DA] Ep{ep}"):
            x_t, _ = next(tgt_iter)

            x_s = x_s.to(DEVICE); y_s = y_s.to(torch.float32).to(DEVICE)
            x_t = x_t.to(DEVICE)

            # label Ï†ïÍ∑úÌôî
            y_s_n = (y_s - y_mean) / y_std

            opt.zero_grad()

            # source: regression + domain(0)
            p_s, d_s = model(x_s, lamb_da=LAMBDA_DA, meta=None)
            loss_reg = 0.5 * (F.l1_loss(p_s, y_s_n) + F.mse_loss(p_s, y_s_n))
            loss_dom_s = ce(d_s, torch.zeros(len(x_s), dtype=torch.long, device=DEVICE))

            # target: domain(1)Îßå
            _, d_t = model(x_t, lamb_da=LAMBDA_DA, meta=None)
            loss_dom_t = ce(d_t, torch.ones(len(x_t), dtype=torch.long, device=DEVICE))

            loss = loss_reg + LAMBDA_DA * (loss_dom_s + loss_dom_t)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()
            total += loss.item()

        # ---- Evaluation on target domain (with labels) ----
        val_mae, val_mse, val_r2 = evaluate(model, tgt_eval_loader, y_mean, y_std)
        print(f"[Ep{ep:02d}] TrainLoss={total/max(1,len(src_loader)):.4f} | "
              f"TGT Val MAE={val_mae:.4f} MSE={val_mse:.4f} R¬≤={val_r2:.3f}")

        if val_mae < best["mae"]:
            best = {"mae": val_mae, "mse": val_mse, "r2": val_r2}
            torch.save(model.state_dict(), "cbramod_da_regressor_best.pth")
            print("‚úì Saved best: cbramod_da_regressor_best.pth")

    print(f"\nBest Target-domain Val ‚Äî MAE={best['mae']:.4f} | MSE={best['mse']:.4f} | R¬≤={best['r2']:.3f}")

if __name__ == "__main__":
    main()


‚úÖ Loaded SSL pretrained CBraMod backbone.


[Train DA] Ep1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [31:37<00:00,  6.55it/s] 


[Ep01] TrainLoss=0.6597 | TGT Val MAE=0.9641 MSE=1.2424 R¬≤=-0.174
‚úì Saved best: cbramod_da_regressor_best.pth


[Train DA] Ep2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [31:29<00:00,  6.58it/s]


[Ep02] TrainLoss=0.3646 | TGT Val MAE=1.0240 MSE=1.3978 R¬≤=-0.321


[Train DA] Ep3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:39<00:00,  6.76it/s]


[Ep03] TrainLoss=0.2958 | TGT Val MAE=1.0006 MSE=1.3518 R¬≤=-0.278


[Train DA] Ep4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:44<00:00,  6.74it/s]


[Ep04] TrainLoss=0.2568 | TGT Val MAE=0.9504 MSE=1.2183 R¬≤=-0.152
‚úì Saved best: cbramod_da_regressor_best.pth


[Train DA] Ep5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [31:14<00:00,  6.63it/s]


[Ep05] TrainLoss=0.2312 | TGT Val MAE=0.9373 MSE=1.1644 R¬≤=-0.101
‚úì Saved best: cbramod_da_regressor_best.pth


[Train DA] Ep6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:38<00:00,  6.76it/s]


[Ep06] TrainLoss=0.2154 | TGT Val MAE=0.9844 MSE=1.2698 R¬≤=-0.200


[Train DA] Ep7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [31:09<00:00,  6.65it/s]


[Ep07] TrainLoss=0.2030 | TGT Val MAE=0.9916 MSE=1.2680 R¬≤=-0.199


[Train DA] Ep8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:31<00:00,  6.79it/s]


[Ep08] TrainLoss=0.1924 | TGT Val MAE=0.9541 MSE=1.1817 R¬≤=-0.117


[Train DA] Ep9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:55<00:00,  6.70it/s]


[Ep09] TrainLoss=0.1846 | TGT Val MAE=0.9302 MSE=1.1191 R¬≤=-0.058
‚úì Saved best: cbramod_da_regressor_best.pth


[Train DA] Ep10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:29<00:00,  6.79it/s]


[Ep10] TrainLoss=0.1767 | TGT Val MAE=0.9962 MSE=1.2450 R¬≤=-0.177


[Train DA] Ep11: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [31:36<00:00,  6.55it/s]


[Ep11] TrainLoss=0.1703 | TGT Val MAE=0.9250 MSE=1.1073 R¬≤=-0.047
‚úì Saved best: cbramod_da_regressor_best.pth


[Train DA] Ep12: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [31:09<00:00,  6.65it/s]


[Ep12] TrainLoss=0.1641 | TGT Val MAE=0.9415 MSE=1.1125 R¬≤=-0.052


[Train DA] Ep13: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:58<00:00,  6.69it/s]


[Ep13] TrainLoss=0.1608 | TGT Val MAE=0.9516 MSE=1.1523 R¬≤=-0.089


[Train DA] Ep14: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:56<00:00,  6.70it/s]


[Ep14] TrainLoss=0.1576 | TGT Val MAE=0.8847 MSE=1.0353 R¬≤=0.021
‚úì Saved best: cbramod_da_regressor_best.pth


[Train DA] Ep15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [31:34<00:00,  6.56it/s]


[Ep15] TrainLoss=0.1532 | TGT Val MAE=0.9201 MSE=1.1179 R¬≤=-0.057


[Train DA] Ep16: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:57<00:00,  6.69it/s]


[Ep16] TrainLoss=0.1501 | TGT Val MAE=0.9241 MSE=1.1422 R¬≤=-0.080


[Train DA] Ep17: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:17<00:00,  6.84it/s]


[Ep17] TrainLoss=0.1459 | TGT Val MAE=0.9147 MSE=1.1317 R¬≤=-0.070


[Train DA] Ep18: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [31:17<00:00,  6.62it/s]


[Ep18] TrainLoss=0.1443 | TGT Val MAE=0.9170 MSE=1.0843 R¬≤=-0.025


[Train DA] Ep19: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:45<00:00,  6.73it/s]


[Ep19] TrainLoss=0.1415 | TGT Val MAE=0.9212 MSE=1.0890 R¬≤=-0.029


[Train DA] Ep20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:32<00:00,  6.78it/s]


[Ep20] TrainLoss=0.1393 | TGT Val MAE=0.9101 MSE=1.1028 R¬≤=-0.042


[Train DA] Ep21: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:25<00:00,  6.81it/s]


[Ep21] TrainLoss=0.1364 | TGT Val MAE=0.9316 MSE=1.1559 R¬≤=-0.093


[Train DA] Ep22: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [31:29<00:00,  6.58it/s]


[Ep22] TrainLoss=0.1354 | TGT Val MAE=0.9170 MSE=1.1019 R¬≤=-0.042


[Train DA] Ep23: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:32<00:00,  6.78it/s]


[Ep23] TrainLoss=0.1332 | TGT Val MAE=0.9410 MSE=1.1502 R¬≤=-0.087


[Train DA] Ep24: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:13<00:00,  6.85it/s]


[Ep24] TrainLoss=0.1305 | TGT Val MAE=0.9081 MSE=1.1036 R¬≤=-0.043


[Train DA] Ep25: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [31:19<00:00,  6.61it/s]


[Ep25] TrainLoss=0.1289 | TGT Val MAE=0.8922 MSE=1.0558 R¬≤=0.002


[Train DA] Ep26: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:53<00:00,  6.70it/s]


[Ep26] TrainLoss=0.1273 | TGT Val MAE=0.9029 MSE=1.1017 R¬≤=-0.041


[Train DA] Ep27: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [31:46<00:00,  6.52it/s]


[Ep27] TrainLoss=0.1250 | TGT Val MAE=0.8892 MSE=1.0712 R¬≤=-0.013


[Train DA] Ep28: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [31:28<00:00,  6.58it/s]


[Ep28] TrainLoss=0.1238 | TGT Val MAE=0.8714 MSE=1.0084 R¬≤=0.047
‚úì Saved best: cbramod_da_regressor_best.pth


[Train DA] Ep29: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:36<00:00,  6.77it/s]


[Ep29] TrainLoss=0.1216 | TGT Val MAE=0.9081 MSE=1.0893 R¬≤=-0.030


[Train DA] Ep30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12430/12430 [30:40<00:00,  6.75it/s]


[Ep30] TrainLoss=0.1203 | TGT Val MAE=0.9077 MSE=1.1073 R¬≤=-0.047

Best Target-domain Val ‚Äî MAE=0.8714 | MSE=1.0084 | R¬≤=0.047


## Self-supervised learning raw data

In [1]:
# ssl_pretrain_raw_eegconformer_from_cache_and_missing_v2.py
import os, random, numpy as np, torch, mne, warnings
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from braindecode.models import EEGConformer

# ---------------- CONFIG ----------------
CACHE_DIR = "/data5/open_data/HBN/cache_eeg_windows"
BIDS_ROOT = "/data5/open_data/HBN/EEG_BIDS"
MISSING_TXT = "/home/RA/EEG_Challenge/Challenge2/logs/missing_cache_files.txt"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SFREQ = 100
N_CHANS = 128
CROP_T = 1000
MASK_RATIO = 0.3
BLOCK_LEN = 100

BATCH_SIZE = 8
EPOCHS = 50
LR = 3e-4
WD = 1e-4
SEED = 42

warnings.filterwarnings("ignore")
mne.set_log_level("CRITICAL")
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

# ============================================================
# 1. ÌååÏùº Î¶¨Ïä§Ìä∏ Íµ¨ÏÑ± (cache + missing ÌÖçÏä§Ìä∏ Í∏∞Î∞ò)
# ============================================================
def build_eeg_filelist(cache_dir, missing_txt, bids_root):
    cache_files = [f for f in os.listdir(cache_dir) if f.endswith(".npy")]
    found_from_cache, not_found = [], []

    # -----------------------------
    # ‚úÖ 1) CACHE FILES Ï≤òÎ¶¨
    # -----------------------------
    for f in cache_files:
        subj = f.split("_")[0]  # sub-XXXX
        task = f.split("_task-")[1].split("_")[0]
        run = None
        if "_run-" in f:
            run = f.split("_run-")[1].split("_")[0]
        matched = False

        for ds in os.listdir(bids_root):
            eeg_dir = os.path.join(bids_root, ds, subj, "eeg")
            if not os.path.exists(eeg_dir):
                continue

            for fname in os.listdir(eeg_dir):
                if not fname.endswith(".set"):
                    continue
                # run ÏóÜÎäî Í≤ΩÏö∞
                if run is None:
                    if f"task-{task}" in fname:
                        found_from_cache.append(os.path.join(eeg_dir, fname))
                        matched = True
                        break
                else:
                    # run ÏûàÎäî Í≤ΩÏö∞: 1 / 01 Ìå®Îî© Î™®Îëê ÌóàÏö©
                    run_variants = [f"run-{run}", f"run-0{run}"]
                    if f"task-{task}" in fname and any(r in fname for r in run_variants):
                        found_from_cache.append(os.path.join(eeg_dir, fname))
                        matched = True
                        break
            if matched:
                break
        if not matched:
            not_found.append(f)

    # -----------------------------
    # ‚úÖ 2) MISSING FILES Ï≤òÎ¶¨
    # -----------------------------
    found_from_missing = []
    if os.path.exists(missing_txt):
        with open(missing_txt, "r") as f:
            for line in f:
                base = line.strip()
                if len(base) == 0:
                    continue
                subj = base.split("_task-")[0]
                task = base.split("_task-")[1].split("_")[0]
                run = None
                if "_run-" in base:
                    run = base.split("_run-")[1].split("_")[0]

                for ds in os.listdir(bids_root):
                    eeg_dir = os.path.join(bids_root, ds, subj, "eeg")
                    if not os.path.exists(eeg_dir):
                        continue

                    for fname in os.listdir(eeg_dir):
                        if not fname.endswith(".set"):
                            continue
                        if run is None:
                            if f"task-{task}" in fname:
                                found_from_missing.append(os.path.join(eeg_dir, fname))
                                break
                        else:
                            run_variants = [f"run-{run}", f"run-0{run}"]
                            if f"task-{task}" in fname and any(r in fname for r in run_variants):
                                found_from_missing.append(os.path.join(eeg_dir, fname))
                                break

    # -----------------------------
    # ‚úÖ 3) ÏµúÏ¢Ö Í≤∞Í≥º
    # -----------------------------
    all_files = sorted(list(set(found_from_cache + found_from_missing)))
    print(f"‚úÖ Total EEG files: {len(all_files)} "
          f"(cache={len(found_from_cache)}, missing={len(found_from_missing)})")
    if len(not_found) > 0:
        print(f"‚ö†Ô∏è {len(not_found)} cache-based files not matched to .set:")
        print("\n".join(not_found[:15]), "...")
    return all_files


# ============================================================
# 2. Dataset
# ============================================================
def make_block_mask(T, mask_ratio=0.3, block_len=50):
    mask = np.ones(T, dtype=bool)
    n_mask = int(T * mask_ratio)
    n_blocks = max(1, n_mask // block_len)
    for _ in range(n_blocks):
        st = np.random.randint(0, T - block_len)
        mask[st:st + block_len] = False
    return mask

class RawEEGConformerSSL(Dataset):
    def __init__(self, file_list, crop_len=1000, mask_ratio=0.3):
        self.files = file_list
        self.crop_len = crop_len
        self.mask_ratio = mask_ratio

    def __len__(self): return len(self.files)

    def __getitem__(self, i):
        path = self.files[i]
        raw = mne.io.read_raw_eeglab(path, preload=True, verbose=False)
        raw.filter(0.5, 50., fir_design="firwin")
        raw.set_eeg_reference("average", projection=False)
        raw.resample(SFREQ)

        x = raw.get_data()
        if x.shape[0] == 129:
            x = x[:-1, :]  # Cz Ï†úÍ±∞
        x = (x - x.mean(axis=1, keepdims=True)) / (x.std(axis=1, keepdims=True) + 1e-6)

        T = x.shape[1]
        if T < self.crop_len:
            pad = np.zeros((x.shape[0], self.crop_len))
            pad[:, :T] = x
            x = pad
        else:
            st = np.random.randint(0, T - self.crop_len + 1)
            x = x[:, st:st + self.crop_len]

        mask = make_block_mask(self.crop_len, self.mask_ratio, BLOCK_LEN)
        keep = torch.tensor(mask, dtype=torch.bool).unsqueeze(0)
        return torch.tensor(x, dtype=torch.float32), keep

# ============================================================
# 3. EEG-Conformer Masked Autoencoder
# ============================================================
class EEGMaskedAutoencoder(nn.Module):
    def __init__(self, n_chans=128, n_times=1000, sfreq=100, embed_dim=256):
        super().__init__()
        self.encoder = EEGConformer(
            n_chans=n_chans, n_outputs=embed_dim,
            n_times=n_times, sfreq=sfreq
        )
        self.decoder = nn.Sequential(
            nn.Conv1d(embed_dim, embed_dim, 3, padding=1),
            nn.GELU(),
            nn.Conv1d(embed_dim, n_chans, 1)
        )

    def forward(self, x, keep):
        visible = x * keep.float()
        z = self.encoder(visible)
        if z.ndim == 2:
            z = z.unsqueeze(-1)
        if z.shape[-1] != x.shape[-1]:
            z = F.interpolate(z, size=x.shape[-1], mode="linear", align_corners=False)
        return self.decoder(z)

def masked_mse(pred, target, keep, eps=1e-8):
    mask = (~keep).float()
    diff2 = (pred - target) ** 2
    num = (diff2 * mask).sum()
    den = mask.sum() * pred.size(1) + eps
    return num / den

# ============================================================
# 4. Train
# ============================================================
def main():
    file_list = build_eeg_filelist(CACHE_DIR, MISSING_TXT, BIDS_ROOT)
    ds = RawEEGConformerSSL(file_list, crop_len=CROP_T, mask_ratio=MASK_RATIO)
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)

    model = EEGMaskedAutoencoder(N_CHANS, CROP_T, SFREQ, 256).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)

    for ep in range(1, EPOCHS + 1):
        model.train(); total = 0.0
        for x, keep in tqdm(loader, desc=f"EEGConformer SSL Ep{ep}"):
            x, keep = x.to(DEVICE), keep.to(DEVICE)
            opt.zero_grad()
            x_hat = model(x, keep)
            loss = masked_mse(x_hat, x, keep)
            loss.backward(); opt.step()
            total += loss.item()
        print(f"[Ep{ep:02d}] masked MSE={total/len(loader):.6f}")

    torch.save(model.encoder.state_dict(), "eegconformer_ssl_raw_from_cache_missing.pth")
    print("‚úÖ Saved pretrained encoder -> eegconformer_ssl_raw_from_cache_missing.pth")

if __name__ == "__main__":
    main()


  from pkg_resources import DefaultProvider, ResourceManager, \


‚úÖ Total EEG files: 340 (cache=318, missing=23)


EEGConformer SSL Ep1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [42:00<00:00, 58.62s/it]  


[Ep01] masked MSE=0.601568


EEGConformer SSL Ep2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [39:28<00:00, 55.09s/it]  


[Ep02] masked MSE=2.843229


EEGConformer SSL Ep3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [38:53<00:00, 54.27s/it]  


[Ep03] masked MSE=0.659530


EEGConformer SSL Ep4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [40:26<00:00, 56.44s/it]  


[Ep04] masked MSE=0.945477


EEGConformer SSL Ep5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [45:31<00:00, 63.52s/it]  


[Ep05] masked MSE=0.569224


EEGConformer SSL Ep6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [57:21<00:00, 80.03s/it]  


[Ep06] masked MSE=0.721789


EEGConformer SSL Ep7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [56:56<00:00, 79.45s/it]  


[Ep07] masked MSE=0.925053


EEGConformer SSL Ep8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [1:01:06<00:00, 85.26s/it]


[Ep08] masked MSE=0.601216


EEGConformer SSL Ep9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [59:24<00:00, 82.90s/it]  


[Ep09] masked MSE=0.550655


EEGConformer SSL Ep10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [58:50<00:00, 82.11s/it]  


[Ep10] masked MSE=0.595056


EEGConformer SSL Ep11: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [59:10<00:00, 82.58s/it]   


[Ep11] masked MSE=0.646879


EEGConformer SSL Ep12: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [1:01:20<00:00, 85.60s/it] 


[Ep12] masked MSE=1.047933


EEGConformer SSL Ep13: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [55:19<00:00, 77.21s/it]  


[Ep13] masked MSE=0.799290


EEGConformer SSL Ep14: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [53:05<00:00, 74.08s/it]  


[Ep14] masked MSE=0.556182


EEGConformer SSL Ep15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [53:40<00:00, 74.90s/it]  


[Ep15] masked MSE=0.576571


EEGConformer SSL Ep16: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [43:15<00:00, 60.36s/it]  


[Ep16] masked MSE=0.697548


EEGConformer SSL Ep17: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [39:39<00:00, 55.35s/it]  


[Ep17] masked MSE=0.581924


EEGConformer SSL Ep18: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 43/43 [40:24<00:00, 56.38s/it]  


[Ep18] masked MSE=1.073441


EEGConformer SSL Ep19:  19%|‚ñà‚ñä        | 8/43 [09:14<40:25, 69.31s/it]   


KeyboardInterrupt: 

In [4]:
# ssl_pretrain_eegconformer_fast.py
import os, random, numpy as np, torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from braindecode.models import EEGConformer

# ---------------- Config ----------------
CACHE_DIR = "/data5/open_data/HBN/cache_eeg_windows"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ==== SPEED-OPTIMIZED PARAMETERS ====
SFREQ = 100
N_CHANS_TOTAL = 129
N_CHANS_USED = 64          # ÎûúÎç§ ÏÉòÌîåÎßÅÌï† Ï±ÑÎÑê Ïàò
CROP_T = 400               # crop Îã®Ï∂ï
MASK_RATIO = 0.3
BATCH_SIZE = 16
EPOCHS = 30
LR = 3e-4
WD = 1e-4
SEED = 42
NUM_WORKERS = 8
PREFETCH = 4
ACCUM_STEPS = 2            # gradient accumulation (Í∞ÄÏÉÅ batch = BATCH_SIZE*ACCUM_STEPS)

torch.manual_seed(SEED); np.random.seed(SEED); random.seed(SEED)
torch.backends.cudnn.benchmark = True

# ============================================================
# Dataset: cache ‚Üí masked reconstruction
# ============================================================
def make_block_mask(T, mask_ratio=0.3, block_len=40):
    mask = np.ones(T, dtype=bool)
    n_mask = int(T * mask_ratio)
    n_blocks = max(1, n_mask // block_len)
    for _ in range(n_blocks):
        start = np.random.randint(0, T - block_len)
        mask[start:start + block_len] = False
    return mask

class MaskedEEGDataset(Dataset):
    def __init__(self, cache_dir, crop_len=400, mask_ratio=0.3, n_chans_used=64):
        self.files = [os.path.join(cache_dir, f) for f in os.listdir(cache_dir) if f.endswith(".npy")]
        self.crop_len = crop_len
        self.mask_ratio = mask_ratio
        self.n_chans_used = n_chans_used

    def __len__(self): return len(self.files)

    def __getitem__(self, i):
        arr = np.load(self.files[i], mmap_mode="r")          # (segments, C, T)
        seg = arr[random.randint(0, arr.shape[0]-1)]         # (C, T)
        if seg.shape[0] == N_CHANS_TOTAL:                    # Cz Ï†úÍ±∞
            seg = seg[:-1, :]

        # ÎûúÎç§ Ï±ÑÎÑê ÏÉòÌîåÎßÅ (128‚Üí64)
        if seg.shape[0] > self.n_chans_used:
            sel = np.random.choice(seg.shape[0], self.n_chans_used, replace=False)
            seg = seg[sel, :]

        C, T = seg.shape
        if T < self.crop_len:
            pad = np.zeros((C, self.crop_len), np.float32)
            pad[:, :T] = seg
            seg = pad
        else:
            start = np.random.randint(0, T - self.crop_len + 1)
            seg = seg[:, start:start+self.crop_len]

        # normalization
        seg = (seg - seg.mean(axis=1, keepdims=True)) / (seg.std(axis=1, keepdims=True) + 1e-6)
        x = torch.tensor(seg, dtype=torch.float32)           # (C, T)

        keep = make_block_mask(self.crop_len, self.mask_ratio)
        keep = torch.tensor(keep, dtype=torch.bool).unsqueeze(0)  # (1, T)
        return x, keep

# ============================================================
# Model: EEGConformer encoder + 2-layer decoder
# ============================================================
class EEGMaskedAutoencoder(nn.Module):
    def __init__(self, n_chans, n_times, sfreq, embed_dim=256):
        super().__init__()
        self.encoder = EEGConformer(n_chans=n_chans, n_outputs=embed_dim, n_times=n_times, sfreq=sfreq)
        self.decoder = nn.Sequential(
            nn.Conv1d(embed_dim, embed_dim//2, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv1d(embed_dim//2, n_chans, kernel_size=1)
        )

    def forward(self, x, keep):
        B, C, T = x.shape
        visible = x * keep.float()
        z = self.encoder(visible)
        if z.ndim == 2:
            z = z.unsqueeze(-1)
        if z.shape[-1] != T:
            z = F.interpolate(z, size=T, mode="linear", align_corners=False)
        return self.decoder(z)

def masked_mse(pred, target, keep, eps=1e-8):
    masked = (~keep).float()
    diff2 = (pred - target) ** 2
    num = (diff2 * masked).sum()
    den = masked.sum() * pred.size(1) + eps
    return num / den

# ============================================================
# Train Loop (with AMP + Grad Accum)
# ============================================================
def main():
    ds = MaskedEEGDataset(CACHE_DIR, crop_len=CROP_T, mask_ratio=MASK_RATIO, n_chans_used=N_CHANS_USED)
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True,
                        num_workers=NUM_WORKERS, prefetch_factor=PREFETCH,
                        pin_memory=True, persistent_workers=True, drop_last=True)

    model = EEGMaskedAutoencoder(N_CHANS_USED, CROP_T, SFREQ).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
    scaler = torch.cuda.amp.GradScaler(enabled=True)

    for ep in range(1, EPOCHS + 1):
        model.train(); running = 0.0
        opt.zero_grad(set_to_none=True)
        for step, (x, keep) in enumerate(tqdm(loader, desc=f"SSL Ep{ep}")):
            x = x.to(DEVICE, non_blocking=True)
            keep = keep.to(DEVICE, non_blocking=True)

            with torch.cuda.amp.autocast():
                x_hat = model(x, keep)
                loss = masked_mse(x_hat, x, keep) / ACCUM_STEPS

            scaler.scale(loss).backward()
            if (step + 1) % ACCUM_STEPS == 0:
                scaler.step(opt); scaler.update(); opt.zero_grad(set_to_none=True)

            running += loss.item() * ACCUM_STEPS  # Î≥¥Ï†ï

        print(f"[Ep {ep:02d}] masked-recon MSE = {running / len(loader):.5f}")

    torch.save(model.encoder.state_dict(), f"pretrained_eegconformer_fast_{N_CHANS_USED}ch_crop{CROP_T}.pth")
    print(f"‚úÖ Saved: pretrained_eegconformer_fast_{N_CHANS_USED}ch_crop{CROP_T}.pth")

if __name__ == "__main__":
    main()


SSL Ep1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:10<00:00,  1.81it/s]


[Ep 01] masked-recon MSE = 0.99343


SSL Ep2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.12it/s]


[Ep 02] masked-recon MSE = 0.98232


SSL Ep3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:09<00:00,  2.05it/s]


[Ep 03] masked-recon MSE = 0.98091


SSL Ep4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:14<00:00,  1.34it/s]


[Ep 04] masked-recon MSE = 0.99101


SSL Ep5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:07<00:00,  2.62it/s]


[Ep 05] masked-recon MSE = 0.98254


SSL Ep6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  4.78it/s]


[Ep 06] masked-recon MSE = 0.95982


SSL Ep7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  4.85it/s]


[Ep 07] masked-recon MSE = 0.97306


SSL Ep8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.72it/s]


[Ep 08] masked-recon MSE = 0.98195


SSL Ep9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  5.13it/s]


[Ep 09] masked-recon MSE = 0.98839


SSL Ep10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:05<00:00,  3.20it/s]


[Ep 10] masked-recon MSE = 0.95493


SSL Ep11: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  5.10it/s]


[Ep 11] masked-recon MSE = 0.95773


SSL Ep12: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  4.79it/s]


[Ep 12] masked-recon MSE = 0.96595


SSL Ep13: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.65it/s]


[Ep 13] masked-recon MSE = 0.94084


SSL Ep14: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:04<00:00,  4.74it/s]


[Ep 14] masked-recon MSE = 0.95277


SSL Ep15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  5.25it/s]


[Ep 15] masked-recon MSE = 0.95060


SSL Ep16: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  5.36it/s]


[Ep 16] masked-recon MSE = 0.96864


SSL Ep17: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  5.75it/s]


[Ep 17] masked-recon MSE = 0.96628


SSL Ep18: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  5.29it/s]


[Ep 18] masked-recon MSE = 0.97154


SSL Ep19: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  5.23it/s]


[Ep 19] masked-recon MSE = 0.99264


SSL Ep20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  5.30it/s]


[Ep 20] masked-recon MSE = 0.95936


SSL Ep21: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  5.27it/s]


[Ep 21] masked-recon MSE = 0.97573


SSL Ep22: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  5.59it/s]


[Ep 22] masked-recon MSE = 0.95204


SSL Ep23: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  5.55it/s]


[Ep 23] masked-recon MSE = 0.97813


SSL Ep24: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  4.87it/s]


[Ep 24] masked-recon MSE = 0.98978


SSL Ep25: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  5.53it/s]


[Ep 25] masked-recon MSE = 0.97026


SSL Ep26: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  5.14it/s]


[Ep 26] masked-recon MSE = 0.93572


SSL Ep27: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  4.85it/s]


[Ep 27] masked-recon MSE = 0.97203


SSL Ep28: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  5.25it/s]


[Ep 28] masked-recon MSE = 0.95642


SSL Ep29: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  5.17it/s]


[Ep 29] masked-recon MSE = 0.95923


SSL Ep30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19/19 [00:03<00:00,  5.36it/s]


[Ep 30] masked-recon MSE = 0.94153
‚úÖ Saved: pretrained_eegconformer_fast_64ch_crop400.pth
