In [4]:
import numpy as np
import os
from collections import defaultdict
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
from braindecode import EEGClassifier
from braindecode.models import EEGNet

from sklearn.base import BaseEstimator, TransformerMixin
from mne.decoding import CSP
from braindecode.models import ShallowFBCSPNet, Deep4Net
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import Pipeline

from sklearn.linear_model import LogisticRegression
import numpy as np
import scipy.linalg as la

import numpy as np
from scipy.linalg import eigh

# Path of the single EDF recording passed to load_data() further down in this notebook.
# NOTE(review): absolute, machine-specific path — presumably subject S001, run R03 of the
# PhysioNet "eegmmidb" dataset (inferred from the path only); point it at the local dataset root.
sample_file = "/var/datasets/physionet.org/files/MI/files/eegmmidb/1.0.0/S001/S001R03.edf"

def load_data(file_path, tmin=-1, tmax=5, verbose=True):
    """Load one EDF recording and return its epochs with integer class labels.

    The recording is read with MNE, band-pass filtered between 1 and 40 Hz,
    epoched around every event derived from the file's annotations, and the
    epochs of each recognised annotation code are stacked into a single array.

    Label mapping (first matching rule wins):
        * ``T0``, ``T1`` and ``T2`` all present -> labels 0, 1, 2
        * ``T3`` and ``T4`` both present        -> labels 4, 5

    Parameters
    ----------
    file_path : str
        Path to the EDF file.
    tmin, tmax : float
        Epoch window, in seconds relative to each event (passed to ``mne.Epochs``).
    verbose : bool, default True
        When true, print the annotation-to-id mapping and the full events array
        (the historical behaviour). Pass ``False`` when looping over many files.

    Returns
    -------
    tuple
        ``(X, y)`` where ``X`` is the concatenation of the per-class
        ``get_data()`` arrays (epochs x channels x samples) and ``y`` is a 1-D
        array with one label per epoch; or ``(None, None)`` when the file holds
        no recognised set of annotation codes or every class was skipped by the
        sampling-rate / epoch-length check.
    """
    # Local import kept on purpose: the visible import cell binds only
    # `mne.decoding.CSP`, not the top-level `mne` name.
    import mne

    expected_sfreq = 160  # only recordings sampled at 160 Hz are accepted

    # --- load the EDF file ---
    raw = mne.io.read_raw_edf(file_path, preload=True, stim_channel='Event marker', verbose=False)
    sfreq = raw.info['sfreq']

    # --- band-pass filter the data (1-40 Hz) ---
    raw.filter(1., 40., fir_design='firwin', verbose=False)

    # --- extract the events from the annotations ---
    events, event_id = mne.events_from_annotations(raw, verbose=False)

    if verbose:
        print(f"Event IDs found: {event_id}")
        print(f"events: {events}")

    # --- build the epochs ---
    epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax,
                        baseline=None, preload=True, verbose=False)

    # --- choose the annotation-code -> label mapping ---
    if all(code in event_id for code in ('T0', 'T1', 'T2')):
        # three-class recording
        label_map = {'T0': 0, 'T1': 1, 'T2': 2}
    elif 'T3' in event_id and 'T4' in event_id:
        # two-class recording
        label_map = {'T3': 4, 'T4': 5}
    else:
        return None, None

    # Expected samples per epoch (961 for the default window at 160 Hz).
    # Each endpoint is rounded on its own instead of truncating the product, so
    # non-integer windows are not rejected because of float truncation.
    # NOTE(review): per-endpoint rounding is assumed to match how MNE places the
    # epoch boundaries on the sample grid — confirm against the installed version.
    expected_n_times = int(round(tmax * sfreq)) - int(round(tmin * sfreq)) + 1

    # --- collect the data of every mapped event type ---
    data = []
    labels = []
    for code, label in label_map.items():
        if code not in epochs.event_id:
            continue
        ep_data = epochs[code].get_data()  # (n_epochs, n_channels, n_times)
        if ep_data.shape[2] != expected_n_times or sfreq != expected_sfreq:
            print(f"Skipping {file_path} due to unexpected n_times: {ep_data.shape[2]}")
            continue
        data.append(ep_data)
        labels.extend([label] * len(ep_data))

    # np.concatenate raises on an empty list; report "nothing usable" the same
    # way as the unrecognised-annotations case above.
    if not data:
        return None, None
    return np.concatenate(data), np.array(labels)




In [None]:
import torch

def inspect_weights(file_path):
    """Print the owning component, name and shape of each entry in a checkpoint.

    The file is loaded onto the CPU and must contain a ``'model_state_dict'``
    key; if it does not, an error line is printed and the function returns
    without listing anything. Each entry is tagged by substring match on its
    name: ``'feature_extractor'`` -> backbone, ``'classifier'`` -> classifier
    head, anything else -> other.

    Parameters
    ----------
    file_path : str
        Path to the checkpoint file readable by ``torch.load``.

    Returns
    -------
    None
        All results are written to stdout.
    """
    # Load on CPU so the inspection also works on machines without a GPU.
    # NOTE(review): torch.load unpickles the file — only use it on trusted checkpoints.
    ckpt = torch.load(file_path, map_location=torch.device('cpu'))

    if 'model_state_dict' in ckpt:
        weights = ckpt['model_state_dict']
    else:
        print("Errore: La chiave 'model_state_dict' non è stata trovata nel file.")
        return

    print(f"Ispezione dei pesi del modello nel file: {file_path}\n")

    def _owner_of(param_name):
        # Backbone entries are checked first, so a name containing both
        # substrings is reported as backbone.
        if 'feature_extractor' in param_name:
            return "BACKBONE (CBraMod)"
        if 'classifier' in param_name:
            return "CLASSIFIER HEAD"
        return "ALTRO/DVAE"

    # One output line per stored entry: [component] name Shape: [...]
    for param_name, weight in weights.items():
        owner = _owner_of(param_name)
        print(f"[{owner:<20}] {param_name:<60} Shape: {list(weight.shape)}")

# Usage example:
# inspect_weights('/home/burger/canWeReally/experiments/Mode_F_MI_PORCA_MISA_Misaaaa_continue_learning_rrr/last_model_weights.pt')

In [3]:
# Print component, name and shape for every entry of this checkpoint's 'model_state_dict'.
# NOTE(review): absolute, machine-specific path; the cell also needs inspect_weights()
# to have been defined by an earlier-executed cell.
path_to_weights = '/home/burger/canWeReally/experiments/Mode_F_MI_PORCA_MISA_Misaaaa_continue_learning_rrrrrr/last_model_weights.pt'
inspect_weights(path_to_weights)

Ispezione dei pesi del modello nel file: /home/burger/canWeReally/experiments/Mode_F_MI_PORCA_MISA_Misaaaa_continue_learning_rrrrrr/last_model_weights.pt

[BACKBONE (CBraMod)  ] feature_extractor.model.patch_embedding.mask_encoding        Shape: [200]
[BACKBONE (CBraMod)  ] feature_extractor.model.patch_embedding.positional_encoding.0.weight Shape: [200, 1, 19, 7]
[BACKBONE (CBraMod)  ] feature_extractor.model.patch_embedding.positional_encoding.0.bias Shape: [200]
[BACKBONE (CBraMod)  ] feature_extractor.model.patch_embedding.proj_in.0.weight     Shape: [25, 1, 1, 49]
[BACKBONE (CBraMod)  ] feature_extractor.model.patch_embedding.proj_in.0.bias       Shape: [25]
[BACKBONE (CBraMod)  ] feature_extractor.model.patch_embedding.proj_in.1.weight     Shape: [25]
[BACKBONE (CBraMod)  ] feature_extractor.model.patch_embedding.proj_in.1.bias       Shape: [25]
[BACKBONE (CBraMod)  ] feature_extractor.model.patch_embedding.proj_in.3.weight     Shape: [25, 25, 1, 3]
[BACKBONE (CBraMod)  ] feature

In [5]:
# Load the sample recording with load_data's default epoch window (tmin=-1, tmax=5).
# X receives the stacked epoch data and y the matching integer labels
# (both are None when load_data finds no recognised annotation codes).
X, y = load_data(sample_file)


Event IDs found: {np.str_('T0'): 1, np.str_('T1'): 2, np.str_('T2'): 3}
events: [[    0     0     1]
 [  672     0     3]
 [ 1328     0     1]
 [ 2000     0     2]
 [ 2656     0     1]
 [ 3328     0     2]
 [ 3984     0     1]
 [ 4656     0     3]
 [ 5312     0     1]
 [ 5984     0     3]
 [ 6640     0     1]
 [ 7312     0     2]
 [ 7968     0     1]
 [ 8640     0     2]
 [ 9296     0     1]
 [ 9968     0     3]
 [10624     0     1]
 [11296     0     2]
 [11952     0     1]
 [12624     0     3]
 [13280     0     1]
 [13952     0     3]
 [14608     0     1]
 [15280     0     2]
 [15936     0     1]
 [16608     0     2]
 [17264     0     1]
 [17936     0     3]
 [18592     0     1]
 [19264     0     2]]
