<a href="https://colab.research.google.com/github/camillabocciolone/Leonardo-project/blob/main/ll2egpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Trying out EEGPT with OpenBCI data

## import libraries

In [None]:
!pip -q install einops tqdm scikit-learn scipy

import os, sys, numpy as np, torch
from pathlib import Path
from scipy.signal import resample_poly
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report


## --- Repo EEGPT ---

In [4]:
%cd /content/EEGPT
!grep -vE '^av==' requirements.txt > requirements_noav.txt
!pip install -r requirements_noav.txt


/content/EEGPT
Collecting accelerate==0.25.0 (from -r requirements_noav.txt (line 2))
  Using cached accelerate-0.25.0-py3-none-any.whl.metadata (18 kB)
Collecting adan-pytorch==0.1.0 (from -r requirements_noav.txt (line 3))
  Using cached adan_pytorch-0.1.0-py3-none-any.whl.metadata (661 bytes)
Collecting aiohttp==3.8.4 (from -r requirements_noav.txt (line 4))
  Using cached aiohttp-3.8.4.tar.gz (7.3 MB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting aiosignal==1.3.1 (from -r requirements_noav.txt (line 5))
  Using cached aiosignal-1.3.1-py3-none-any.whl.metadata (4.0 kB)
Collecting albumentations==1.4.1 (from -r requirements_noav.txt (line 6))
  Using cached albumentations-1.4.1-py3-none-any.whl.metadata (36 kB)
Collecting alembic==1.10.3 (from -r requirements_noav.txt (line 7))
  Using cached alembic-1.10.3-py3-none-any.whl.metadata (7.2 kB)
Collecti

In [3]:

%cd /content
if not Path("EEGPT").exists():
    !git clone -q https://github.com/BINE022/EEGPT.git
%cd EEGPT
!pip -q in
!pip install -r requirements.txt


/content
/content/EEGPT
ERROR: unknown command "in"
Collecting accelerate==0.25.0 (from -r requirements.txt (line 2))
  Using cached accelerate-0.25.0-py3-none-any.whl.metadata (18 kB)
Collecting adan-pytorch==0.1.0 (from -r requirements.txt (line 3))
  Using cached adan_pytorch-0.1.0-py3-none-any.whl.metadata (661 bytes)
Collecting aiohttp==3.8.4 (from -r requirements.txt (line 4))
  Using cached aiohttp-3.8.4.tar.gz (7.3 MB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting aiosignal==1.3.1 (from -r requirements.txt (line 5))
  Using cached aiosignal-1.3.1-py3-none-any.whl.metadata (4.0 kB)
Collecting albumentations==1.4.1 (from -r requirements.txt (line 6))
  Using cached albumentations-1.4.1-py3-none-any.whl.metadata (36 kB)
Collecting alembic==1.10.3 (from -r requirements.txt (line 7))
  Using cached alembic-1.10.3-py3-none-any.whl.metadata (7.2 kB)
C

In [1]:

# NOTE: this cell used to start with the stray fragment
# "stall -r requirements.txt" (the tail of a truncated pip command), which is a
# SyntaxError. It was removed; dependencies are installed by the setup cells.

# --- Import the EEGPT model ---
# Make the current working directory importable; it is expected to be the
# EEGPT repo root after the `%cd` in the setup cells.
if str(Path.cwd()) not in sys.path:
    sys.path.append(str(Path.cwd()))
from downstream.Modules.models.EEGPT_mcae_finetune import EEGPTClassifier, CHANNEL_DICT

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# --------- 1) Channel & sampling configuration ----------
# Channels recorded by the 8-channel OpenBCI setup (matched case-insensitively).
my_channels = ["Fp1","Fp2","F7","F3","FZ","F4","F8","C2"]

# Validate every name against the repo's CHANNEL_DICT using its upper-cased form.
# The original spelling is kept in `channels`; per the original author's note the
# wrapper upper-cases names itself (TODO confirm).
channels = []
for ch in my_channels:
    if ch.upper() in CHANNEL_DICT:
        channels.append(ch)
        continue
    raise ValueError(f"Canale non riconosciuto da EEGPT: {ch}. Controlla la nomenclatura 10-20.")

print("Canali passati a EEGPT:", channels)

# Sampling rates: source device rate and the rate the data is resampled to.
FS_SRC = 250    # OpenBCI
FS_TGT = 256    # EEGPT pretraining rate (per the original author's note)

# Window geometry, expressed in seconds and converted to samples at FS_TGT.
WIN_SEC = 4.0     # window length: 4 seconds
STRIDE_SEC = 2.0  # hop between windows: 50% overlap
TGT_LEN = int(FS_TGT * WIN_SEC)    # 1024 samples per window
STRIDE = int(FS_TGT * STRIDE_SEC)  # 512 samples between window starts

# --------- 2) Funzioni: resample 250->256 e segmentazione (C,T)->(N,C,1024) ----------
def resample_to_256(x_8xT, fs_src=250, fs_tgt=256):
    """Resample a (channels, time) array from ``fs_src`` Hz to ``fs_tgt`` Hz.

    The polyphase up/down factors are derived from the two rates (reduced to
    lowest terms), so the ``fs_src``/``fs_tgt`` arguments are actually honored.
    With the defaults (250 -> 256) this yields up=128, down=125, exactly the
    factors that were previously hard-coded.

    Args:
        x_8xT: array of shape (C, T); resampling is applied along axis 1.
        fs_src: source sampling rate (positive number).
        fs_tgt: target sampling rate (positive number).

    Returns:
        The resampled array as returned by ``scipy.signal.resample_poly``.

    Raises:
        ValueError: if either sampling rate is not strictly positive.
    """
    from fractions import Fraction  # local import keeps the block self-contained

    if fs_src <= 0 or fs_tgt <= 0:
        raise ValueError(
            f"Sampling rates must be positive, got fs_src={fs_src}, fs_tgt={fs_tgt}."
        )
    # Exact rational ratio; limit_denominator only kicks in for awkward float rates.
    ratio = (Fraction(fs_tgt) / Fraction(fs_src)).limit_denominator(10_000)
    up, down = ratio.numerator, ratio.denominator
    return resample_poly(x_8xT, up, down, axis=1)

def make_windows(x_8xT_256, win_len=TGT_LEN, stride=STRIDE):
    """Cut a (C, T) array into fixed-length windows along the time axis.

    Window starts are 0, stride, 2*stride, ... as long as a full window of
    ``win_len`` samples still fits.

    Returns:
        Array of shape (N, C, win_len) stacking the N windows, or ``None`` when
        the input is too short to hold a single window.
    """
    _, n_samples = x_8xT_256.shape
    window_starts = range(0, n_samples - win_len + 1, stride)
    segments = [x_8xT_256[:, begin:begin + win_len] for begin in window_starts]
    if not segments:
        return None
    return np.stack(segments, axis=0)

# --------- 3) Qui simulo il CARICAMENTO dei tuoi dati grezzi ---------
# Sostituisci questa parte con il tuo loader reale (file .csv/.txt/.edf già sincronizzati).
# Creo un dummy "per sessione": 90 s @250 Hz -> 22500 campioni
# Seeded generator shared by all dummy sessions so the fake data is reproducible.
rng = np.random.default_rng(42)
def make_dummy_session(seconds=90, n_classes=4):
    """Generate one fake recording session.

    Returns:
        A tuple ``(signal, label)`` where ``signal`` is float32 Gaussian noise of
        shape (8, FS_SRC * seconds) and ``label`` is a single random class index
        in ``[0, n_classes)`` for the whole session.
    """
    n_samples = int(FS_SRC * seconds)
    # Draw order matters for reproducibility: signal first, then the label.
    signal = rng.normal(size=(8, n_samples)).astype(np.float32)
    session_label = rng.integers(0, n_classes)
    return signal, session_label

# Build the fake dataset: N_SESS sessions, each resampled and cut into windows.
N_SESS = 10
N_CLASSES = 4

window_chunks, window_labels = [], []
for _ in range(N_SESS):
    raw, y = make_dummy_session(seconds=90, n_classes=N_CLASSES)  # (8, T250)
    Xw = make_windows(resample_to_256(raw))                       # (N, 8, 1024) or None
    if Xw is not None:
        window_chunks.append(Xw)
        # Every window inherits the label of the session it came from.
        window_labels.extend([y] * len(Xw))

X_all = np.concatenate(window_chunks, axis=0)  # (N, 8, 1024)
y_all = np.array(window_labels, dtype=np.int64)
print("Dataset dummy finestrato:", X_all.shape, y_all.shape)

# --------- 4) Split e DataLoader ----------
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import TensorDataset, DataLoader

# Stratified 60/20/20 split: first hold out 20% for test, then 25% of the rest for validation.
# NOTE(review): windows overlap by 50% and every window of a session shares that
# session's label, so a random window-level split can place near-duplicate windows
# in different splits. For real data consider splitting by session instead.
X_tr, X_te, y_tr, y_te = train_test_split(X_all, y_all, test_size=0.2, stratify=y_all, random_state=0)
X_tr, X_va, y_tr, y_va = train_test_split(X_tr, y_tr, test_size=0.25, stratify=y_tr, random_state=0)  # 60/20/20

BATCH = 64
# Pin host memory only when a CUDA device exists; on CPU-only runtimes
# pin_memory=True brings no benefit and makes PyTorch emit a warning.
PIN_MEMORY = torch.cuda.is_available()

def _make_loader(X, y, shuffle):
    """Wrap a pair of arrays in a DataLoader using the shared batch settings."""
    dataset = TensorDataset(torch.tensor(X), torch.tensor(y))
    return DataLoader(dataset, batch_size=BATCH, shuffle=shuffle, pin_memory=PIN_MEMORY)

train_loader = _make_loader(X_tr, y_tr, shuffle=True)
val_loader   = _make_loader(X_va, y_va, shuffle=False)
test_loader  = _make_loader(X_te, y_te, shuffle=False)

# --------- 5) EEGPT model configured with the 8 recorded channels ----------
# The pretrained checkpoint is optional: it is used only if the file exists at this
# relative path; otherwise ckpt_path=None is passed to the constructor.
# NOTE(review): the recorded run of this notebook ends with
# "TypeError: 'NoneType' object is not iterable" and never prints the
# "EEGPT init OK" message below. Verify that the keyword names used here
# (ckpt_path, channels, sample_rate, input_time_length) match the actual
# EEGPTClassifier signature in the repo — TODO confirm.
ckpt_path = Path("checkpoint/eegpt_mcae_58chs_4s_large4E.ckpt")  # place the checkpoint here if you have it
use_pretrained = ckpt_path.exists()

model = EEGPTClassifier(
    num_classes=N_CLASSES,
    ckpt_path=str(ckpt_path) if use_pretrained else None,
    channels=channels,        # the 8 recorded channels; per the original note the rest is "masked" (TODO confirm)
    sample_rate=FS_TGT,       # 256 Hz (after resampling)
    input_time_length=int(WIN_SEC)
).to(device)
print("EEGPT init OK. Pretrained:", use_pretrained)

# --------- 6) Linear probing (freeze the backbone, train only the head) ----------
def _select_head_params(net, num_classes):
    """Return the parameters of the classification head of ``net``.

    The head is taken to be the last ``torch.nn.Linear`` (or subclass) whose
    output size equals ``num_classes``.
    NOTE(review): this assumes EEGPTClassifier ends in such a layer — confirm
    against the repo and adjust the selection if the head is named differently.

    Raises:
        RuntimeError: if no matching layer is found, instead of silently
            building an optimizer with nothing to train.
    """
    candidates = [
        module for module in net.modules()
        if isinstance(module, torch.nn.Linear) and module.out_features == num_classes
    ]
    if not candidates:
        raise RuntimeError(
            "No Linear layer with out_features == num_classes found; "
            "select the classification head parameters manually."
        )
    return list(candidates[-1].parameters())

# Freeze everything first ...
for p in model.parameters():
    p.requires_grad = False
# ... then re-enable gradients for the head only. Previously every parameter stayed
# frozen, so the optimizer below received an empty parameter list.
head_params = _select_head_params(model, N_CLASSES)
for p in head_params:
    p.requires_grad = True

# forward_logits also handles a dict output that carries a "logits" entry.
criterion = torch.nn.CrossEntropyLoss()
opt = torch.optim.AdamW(head_params, lr=1e-3, weight_decay=1e-4)

def forward_logits(xb):
    """Run ``model`` on a batch and return its logits.

    The batch is moved to ``device`` and cast to float before the forward pass.
    If the model returns a dict containing a "logits" entry, that entry is
    returned; any other output is returned unchanged.
    """
    batch = xb.to(device).float()  # (B, 8, 1024)
    result = model(batch)
    wraps_logits = isinstance(result, dict) and "logits" in result
    return result["logits"] if wraps_logits else result

@torch.no_grad()
def eval_loader(loader):
    """Evaluate ``model`` on every batch of ``loader`` without tracking gradients.

    Puts the model in eval mode, collects the arg-max prediction for each sample,
    and returns ``(accuracy, confusion_matrix)``; the confusion matrix always
    covers the labels 0..N_CLASSES-1.
    """
    model.eval()
    predicted_batches, target_batches = [], []
    for inputs, targets in loader:
        batch_pred = forward_logits(inputs).argmax(1)
        predicted_batches.append(batch_pred.cpu().numpy())
        target_batches.append(targets.numpy())
    y_pred = np.concatenate(predicted_batches)
    y_true = np.concatenate(target_batches)
    accuracy = float(accuracy_score(y_true, y_pred))
    matrix = confusion_matrix(y_true, y_pred, labels=list(range(N_CLASSES)))
    return accuracy, matrix

# Training schedule: at most EPOCHS epochs, stopping early after PATIENCE
# consecutive epochs without a validation-accuracy improvement.
EPOCHS, PATIENCE = 5, 3
best_val, noimp = -1.0, 0
print("\n=== Linear Probing (backbone congelato) ===")
for ep in range(1, EPOCHS + 1):
    model.train()
    # Running sum of per-sample loss and number of samples seen this epoch.
    run, seen = 0.0, 0
    pbar = tqdm(train_loader, desc=f"[LP] {ep}/{EPOCHS}")
    for xb, yb in pbar:
        xb = xb.to(device)
        yb = yb.to(device)
        opt.zero_grad(set_to_none=True)
        loss = criterion(forward_logits(xb), yb)
        loss.backward()
        opt.step()
        bs = xb.size(0)
        run += loss.item() * bs
        seen += bs
        # Show the sample-weighted mean loss so far on the progress bar.
        pbar.set_postfix(loss=run / max(1, seen))

    va_acc, _ = eval_loader(val_loader)
    print(f"Val ACC: {va_acc:.3f}")

    improved = va_acc > best_val
    if improved:
        best_val, noimp = va_acc, 0
    else:
        noimp += 1
        if noimp >= PATIENCE:
            print("Early stop.")
            break

# --------- 7) Test ----------
# Run the test set through the model exactly once, in eval mode and without
# gradient tracking, then derive every metric from the same predictions.
# (Previously the report triggered a second forward pass outside no_grad.)
model.eval()
test_true, test_pred = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        test_pred.append(forward_logits(xb).argmax(1).cpu().numpy())
        test_true.append(yb.numpy())
test_true = np.concatenate(test_true)
test_pred = np.concatenate(test_pred)

# Fixed label list keeps the confusion matrix and the report aligned even if a
# class is never predicted.
class_labels = list(range(N_CLASSES))
te_acc = float(accuracy_score(test_true, test_pred))
cm = confusion_matrix(test_true, test_pred, labels=class_labels)

print("\n=== TEST ===")
print("Test ACC:", te_acc)
print("Confusion Matrix:\n", cm)
print(classification_report(
    test_true,
    test_pred,
    labels=class_labels,
    digits=3,
    zero_division=0,
))



/content
/content/EEGPT
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.3/7.3 MB[0m [31m54.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m38.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mGetting requirements to build wheel[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Getting requirements to build wheel ... [?25l[?25herror
[1;31merror[0m: [1msubprocess-exited-with-error[0m

[31m×[0m [32mGetting requirements to build wheel[0m did not run

  @autocast(True)
  @autocast(True)


Device: cpu
Canali passati a EEGPT: ['Fp1', 'Fp2', 'F7', 'F3', 'FZ', 'F4', 'F8', 'C2']
Dataset dummy finestrato: (440, 8, 1024) (440,)


TypeError: 'NoneType' object is not iterable