In [3]:
import os, glob, csv, h5py
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

In [4]:
age_intervals = {
    'three': [8.1, 12.4],
    'six': [7.6, 12.4],
    'twelve': [10.15, 16.35]
}

In [5]:
def load_participant_info(csv_path):
    participant_info = {}
    with open(csv_path, 'r') as csvfile:
        reader = csv.DictReader(csvfile, delimiter=',')
        for row in reader:
            participant_info[(row['subject_id'], row['session'])] = {
                'age': row['age'],
                'system': row['eeg_system']
            }
    return participant_info

In [6]:
def compute_beta_power(age, test_subject=None, test_session=None):
    subject_age_path = '/home/common/bonaiuto/devmobeta/derivatives/'
    csv_path = os.path.join(subject_age_path, 'participants_v2.csv')
    participant_info = load_participant_info(csv_path)
    
    epochs = ['go', 'grsp']
    baseline_window = [-4.5, -1]
    
    # Dictionnaire pour stocker les données par sujet
    subject_data = {}
    epoch_time = {}

    if test_subject is not None:
        subject_files = [os.path.join(subject_age_path, test_subject)]
        print(f"test : {test_subject}/{test_session or '*'}")
    else:
        subject_files = glob.glob(os.path.join(subject_age_path, 'sub-*'))

    subjects_processed = 0
    sessions_skipped = 0

    # Traitement sujet par sujet
    for subject_dir in subject_files:
        subject = os.path.basename(subject_dir)
        sessions = [d for d in os.listdir(subject_dir) if os.path.isdir(os.path.join(subject_dir, d))]
        if subject == 'sub-259' and session == 'ses-01':
            continue
        if subject == 'sub-262' and session == 'ses-01':
            continue
        if test_session is not None:
            sessions = [s for s in sessions if s == test_session]

        for session in sessions:
            info = participant_info.get((subject, session))
            if info is None or info['age'] != age:
                sessions_skipped += 1
                continue

            #print(f" processing {subject}/{session} (age: {age})")
            superlet_path = os.path.join(subject_age_path, subject, session, 'eeg', 'NEARICA_NF')
            if not os.path.exists(superlet_path):
                continue

            session_key = f"{subject}_{session}"
            session_epochs = {}
            mean_base_go = None

            # Charger les données pour chaque epoch
            for epoch in epochs:
                fname = os.path.join(superlet_path, f'{subject}_{epoch}_processed_superlet_tf.mat')
                if not os.path.exists(fname):
                    continue

                try:
                    with h5py.File(fname, 'r') as f:
                        trial_tf = np.array(f['trial_tf'])  # (time, chan, freq, trial)
                        foi = np.array(f['foi']).squeeze()
                        if trial_tf.ndim != 4:
                            continue

                        times = np.linspace(-5, 5, trial_tf.shape[0])
                        if epoch not in epoch_time:
                            epoch_time[epoch] = times

                        beta_idx = np.where(
                            (foi >= age_intervals[age][0]) & (foi <= age_intervals[age][1])
                        )[0]
                        if len(beta_idx) == 0:
                            continue

                        # Moyenne sur chan, freq beta, trials
                        power = np.mean(trial_tf[:, :, beta_idx, :], axis=(1, 2, 3))
                        
                        session_epochs[epoch] = power

                except Exception as e:
                    print(f"Erreur {epoch} ({subject}/{session}) : {e}")
                    continue

            # Normalisation sur la baseline 'go' 
            if 'go' in session_epochs:
                times = epoch_time['go']
                baseline_mask = (times >= baseline_window[0]) & (times <= baseline_window[1])
                mean_base_go = np.mean(session_epochs['go'][baseline_mask])
                
                if np.isnan(mean_base_go) or mean_base_go == 0:
                    continue
                
                
                subj_normalized = {}
                for epoch in epochs:
                    if epoch in session_epochs:
                        power = session_epochs[epoch]
                        normalized_power = 100.0 * (power - mean_base_go) / mean_base_go
                        
                        # Ajouter uniquement si pas de NaN
                        if not np.any(np.isnan(normalized_power)):
                            subj_normalized[epoch] = normalized_power
                
                # Stocker uniquement si les deux epochs sont présents
                if len(subj_normalized) == len(epochs):
                    subject_data[session_key] = subj_normalized
                    subjects_processed += 1

    # Rejet des outliers
    age_beta_power = {epoch: [] for epoch in epochs}
    
    for epoch in epochs:
        # Collecter toutes les séries pour cet epoch
        all_series = [subj_data[epoch] for subj_data in subject_data.values() if epoch in subj_data]
        
        if len(all_series) == 0:
            continue
        
        # Calculer mean et std globaux
        all_values = np.concatenate(all_series)
        mean_val = np.mean(all_values)
        std_val = np.std(all_values)
        
        def is_outlier(signal):
            return np.any(np.abs(signal - mean_val) > 2.5 * std_val)
        
        # Filtrer les outliers
        rejected = 0
        for session_key, subj_data in subject_data.items():
            if epoch in subj_data:
                if not is_outlier(subj_data[epoch]):
                    age_beta_power[epoch].append(subj_data[epoch])
                else:
                    rejected += 1
        
        print(f"{epoch}: {rejected} rejected / {len(all_series)} sessions")
        
        # Conversion en array numpy
        age_beta_power[epoch] = np.array(age_beta_power[epoch])

    print(f" : {subjects_processed} sessions processed, {sessions_skipped} skipped\n")

    return epoch_time, age_beta_power


In [7]:
def plot_beta_power_by_age(age, test_subject=None, test_session=None):
    epoch_time, age_beta_power = compute_beta_power(age, test_subject, test_session)

    fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
    age_labels = {'three': '3 m', 'six': '6 m', 'twelve': '12 m'}

    # Calculer les limites Y globales pour les deux epochs
    y_min, y_max = np.inf, -np.inf
    for epoch in ['go', 'grsp']:
        data = age_beta_power[epoch]
        if len(data) > 0:
            mean_power = np.mean(data, axis=0)
            sem_power = np.std(data, axis=0) / np.sqrt(data.shape[0])
            y_min = min(y_min, np.min(mean_power - sem_power))
            y_max = max(y_max, np.max(mean_power + sem_power))
    
    # Ajouter une marge
    y_range = y_max - y_min
    y_min -= 0.1 * y_range
    y_max += 0.1 * y_range

    for i, epoch in enumerate(['go', 'grsp']):
        ax = axes[i]
        data = age_beta_power[epoch]

        if len(data) == 0:
            ax.text(0.5, 0.5, 'no data', ha='center', va='center')
            continue

        times = epoch_time[epoch]
        mean_power = np.mean(data, axis=0)
        sem_power = np.std(data, axis=0) / np.sqrt(data.shape[0])

        # Tracés individuels: 
        # for subj_series in data:
        #     ax.plot(times, subj_series, color='gray', alpha=0.3, lw=1)

        # Moyenne avec SEM
        ax.plot(times, mean_power, color='blue' if epoch == 'go' else 'green',
                lw=2.5, label='mean')
        ax.fill_between(times, mean_power - sem_power, mean_power + sem_power,
                        color='blue' if epoch == 'go' else 'green', alpha=0.3)

        ax.axvline(0, color='red', linestyle='--', lw=1.5)
        ax.axhline(0, color='black', lw=0.8, alpha=0.5)
        ax.set_xlabel("Time (s)")
        if i == 0:
            ax.set_ylabel("Power % change")
        ax.set_title(f"{epoch.upper()} - {age_labels[age]} (n={data.shape[0]})")
        ax.set_ylim(y_min, y_max)
        ax.grid(alpha=0.3)
        ax.legend()

    plt.tight_layout()
    plt.show()

In [10]:
plot_beta_power_by_age('three')
#plot_beta_power_by_age('six')
#plot_beta_power_by_age('twelve')

IndexError: boolean index did not match indexed array along axis 0; size of axis is 3000 but size of corresponding boolean axis is 5000

In [12]:
def diagnose_file_dimensions(age):
    subject_age_path = '/home/common/bonaiuto/devmobeta/derivatives/'
    csv_path = os.path.join(subject_age_path, 'participants_v2.csv')
    participant_info = load_participant_info(csv_path)
    
    dimension_stats = {}
    
    for subject_dir in glob.glob(os.path.join(subject_age_path, 'sub-*')):
        subject = os.path.basename(subject_dir)
        sessions = [d for d in os.listdir(subject_dir) if os.path.isdir(os.path.join(subject_dir, d))]
        
        for session in sessions:
            info = participant_info.get((subject, session))
            if info is None or info['age'] != age:
                continue
                
            superlet_path = os.path.join(subject_age_path, subject, session, 'eeg', 'NEARICA_NF')
            
            for epoch in ['go', 'grsp']:
                fname = os.path.join(superlet_path, f'{subject}_{epoch}_processed_superlet_tf.mat')
                if not os.path.exists(fname):
                    continue
                    
                try:
                    with h5py.File(fname, 'r') as f:
                        trial_tf = np.array(f['trial_tf'])
                        n_time = trial_tf.shape[0]
                        
                        if n_time not in dimension_stats:
                            dimension_stats[n_time] = []
                        dimension_stats[n_time].append(f"{subject}/{session}/{epoch}")
                except:
                    continue
    
    print(f"\n=== Diagnostic des dimensions temporelles pour age={age} ===")
    for n_time, files in sorted(dimension_stats.items()):
        print(f"\n{n_time} points temporels ({len(files)} fichiers):")
        for f in files[:5]:  # Afficher les 5 premiers
            print(f"  - {f}")
        if len(files) > 5:
            print(f"  ... et {len(files)-5} autres")

# Appelez cette fonction avant vos plots
diagnose_file_dimensions('three')


=== Diagnostic des dimensions temporelles pour age=three ===

4 points temporels (16 fichiers):
  - sub-209/ses-01/grsp
  - sub-203/ses-02/grsp
  - sub-235/ses-01/grsp
  - sub-221/ses-02/grsp
  - sub-219/ses-01/grsp
  ... et 11 autres

3000 points temporels (1 fichiers):
  - sub-262/ses-01/go

5000 points temporels (89 fichiers):
  - sub-241/ses-01/go
  - sub-241/ses-01/grsp
  - sub-236/ses-01/go
  - sub-236/ses-01/grsp
  - sub-217/ses-01/go
  ... et 84 autres
