In [None]:
from matplotlib import pyplot as plt
import pandas as pd
import os.path as op
import mne
import scipy.io
import os
import glob
import csv
from mne.io import read_epochs_fieldtrip
from mne import create_info
import numpy as np
from burst_detection import extract_bursts
from burst_detection import extract_bursts_single_trial
from fooof import FOOOF
import mat73
import h5py
import json

In [None]:
def make_serializable(obj):
    """Fallback converter for ``json.dump(s)(..., default=make_serializable)``.

    The JSON encoder calls this only for objects it cannot encode natively.

    Parameters
    ----------
    obj : object
        The value the encoder could not serialize.

    Returns
    -------
    list or Python scalar
        ``obj.tolist()`` for NumPy arrays, ``obj.item()`` for any NumPy scalar.

    Raises
    ------
    TypeError
        If ``obj`` is neither a NumPy array nor a NumPy scalar.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # np.generic is the base class of every NumPy scalar type, so this covers
    # float32 / int32 / bool_ etc. in addition to int64 / float64.
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

In [None]:
# Per-age-group [low, high] band limits. detect_bursts compares them against
# the `foi` frequency axis and forwards them to extract_bursts.
# NOTE(review): presumably in Hz — confirm against the time-frequency files.
age_intervals = dict(
    three=[8.1, 12.4],
    six=[7.6, 12.4],
    twelve=[10.15, 16.35],
)

# Per-age-group value passed to extract_bursts as `w_size`.
# NOTE(review): presumably a waveform window length in seconds — confirm.
waveform_times = dict(
    three=0.488,
    six=0.5,
    twelve=0.377,
)
def detect_bursts(age, derivatives_dir='/home/common/bonaiuto/devmobeta/derivatives/'):
    """Detect bursts for every BrainVision recording of one age group.

    For each matching row of ``participants_v2.csv`` the function loads the
    re-referenced EEGLAB epochs and, for the ``'S  2'`` (go) and ``'S  3'``
    (grsp) events, the matching superlet time-frequency file. It then runs
    ``extract_bursts`` on each of the C3/C4 cluster channels.

    Side effects: prints progress and summaries, and writes
    ``bursts_<age>_BV.json`` to the current working directory.

    Parameters
    ----------
    age : str
        Key of ``age_intervals`` / ``waveform_times``; also the value matched
        against the ``age`` column of the participants file.
    derivatives_dir : str, optional
        Directory holding ``participants_v2.csv`` and the per-subject
        derivative folders. Defaults to the original hard-coded location.

    Returns
    -------
    list of dict
        One entry per (subject, session, event, channel) with keys
        ``subject``, ``epoch``, ``session``, ``channel``, ``system``,
        ``cluster``, ``bursts`` (the ``extract_bursts`` result) and
        ``beta_power``.

    Raises
    ------
    KeyError
        If ``age`` is not a key of ``age_intervals`` / ``waveform_times``.
    ValueError
        If one of the cluster channels is missing from a recording.
    """
    c3_chans = ["C3", "FC3", "CP3", "C1"]
    c4_chans = ["C4", "FC4", "CP4", "C2"]
    ch_names = c3_chans + c4_chans

    subject_age_path = derivatives_dir
    subjects = os.path.join(subject_age_path, 'participants_v2.csv')

    # Only rows whose age equals `age` are processed, so the band limits and
    # waveform window are constant for the whole run. Looking them up here
    # (instead of per row, before the age filter) means rows from age groups
    # that have no entry in the tables can no longer raise KeyError.
    beta_lims = age_intervals[age]
    w_size = waveform_times[age]

    age_bursts = []
    # NOTE(review): sampling rate is hard-coded; confirm it matches
    # data.info['sfreq'] of the loaded epochs.
    sfreq = 512

    # Read the participants table once; it is reused for the final summary.
    with open(subjects, 'r') as csvfile:
        rows = list(csv.DictReader(csvfile, delimiter=','))

    for row in rows:
        print("Processing row:", row)
        subject = row['subject_id']
        session = row['session']
        subject_age = row['age']
        system = row['eeg_system']

        if subject_age != age:
            continue
        if subject == 'sub-252' and session == 'ses-01':  # outlier
            continue
        if system != "BrainVision":
            continue

        print(f"Processing subject {subject} with session {session} (BrainVision)")

        ses_path = os.path.join(subject_age_path, subject, session)
        superlet_path = os.path.join(ses_path, 'eeg', 'NEARICA_NF')
        if not os.path.exists(superlet_path):
            print(f"NEARICA_NF folder does not exist for subject {subject}.")
            continue

        print(f"Processing subject {subject} with age {subject_age} and session {session}")
        fname = os.path.join(superlet_path, f'04_rereferenced_data/{subject}_task-devmobeta_grasp_eeg_rereferenced_data.set')
        if not os.path.exists(fname):
            print(f"File {fname} not found.")
            continue

        data = mne.read_epochs_eeglab(fname)
        print("Event IDs in data:", data.event_id)

        # NOTE(review): the same indices are used for the time-frequency
        # array below, which assumes its channel order matches data.ch_names.
        ch_indices = [data.ch_names.index(ch) for ch in ch_names]
        times = data.times
        # Fetch the epochs array once per recording rather than once per channel.
        epochs_data = data.get_data()

        # Map event markers to the labels used in the time-frequency file names.
        for epoch, event_label in zip(['S  2', 'S  3'], ['go', 'grsp']):
            if epoch not in data.event_id:
                print(f"Event {epoch} not found in data.event_id for subject {subject}.")
                continue

            event_code = data.event_id[epoch]
            trial_indices = np.where(data.events[:, 2] == event_code)[0]
            epoch_fname = os.path.join(superlet_path, f'{subject}_{event_label}_processed_superlet_tf.mat')
            if not os.path.exists(epoch_fname):
                continue

            # Both datasets are copied into memory, so the file can be closed
            # before the per-channel processing starts.
            with h5py.File(epoch_fname, 'r') as f:
                trial_tf = np.array(f['trial_tf'])
                foi = np.array(f['foi'])
            print('superlet ok')

            # Frequency selections depend only on `foi`, not on the channel.
            beta_idx = np.where((foi >= beta_lims[0]) & (foi <= beta_lims[1]))[0]
            search_range = np.where((foi >= beta_lims[0]-3) & (foi <= beta_lims[1]+3))[0]

            for ch_idx, ch_name in zip(ch_indices, ch_names):
                chan_raw = epochs_data[trial_indices, ch_idx, :]
                # Axis 1 of trial_tf is the channel; per the original author's
                # comment the remaining axes are time x freq x trial.
                chan_tf = trial_tf[:, ch_idx, :, :]

                # Average over the first and last axes, leaving one value per frequency.
                average_psd = np.mean(chan_tf, axis=(0, -1))

                if np.all(average_psd == 0):
                    print(f"Skipping channel {ch_name} because PSD is all zeros.")
                    continue

                beta_power = np.mean(chan_tf[:, beta_idx, :])

                chan_tf = np.transpose(chan_tf, (2, 1, 0))  # trial x freq x time

                print(f"Subject {subject} has age {subject_age}, interval: {beta_lims}")

                # Fit the aperiodic model once (the original fitted it twice
                # with identical inputs) and convert the log10 fit to linear power.
                # NOTE(review): `_ap_fit` is a private attribute evaluated on the
                # frequencies FOOOF keeps after trimming to [1, 50]; indexing it
                # with `search_range` (indices into `foi`) assumes `foi` lies
                # entirely inside that range — confirm.
                ff = FOOOF()
                ff.fit(np.squeeze(foi), average_psd, [1, 50])
                ap = 10 ** ff._ap_fit

                bursts = extract_bursts(
                    chan_raw, chan_tf[:, search_range, :], times,
                    foi[search_range], beta_lims,
                    ap[search_range].reshape(-1, 1), sfreq, w_size=w_size)

                # ch_names is exactly c3_chans + c4_chans, so anything that is
                # not a C3-cluster channel belongs to the C4 cluster.
                cluster = 'C3' if ch_name in c3_chans else 'C4'

                age_bursts.append({
                    'subject': subject,
                    'epoch': event_label,
                    'session': session,
                    'channel': ch_name,
                    'system': system,
                    'cluster': cluster,
                    'bursts': bursts,
                    'beta_power': beta_power
                })

    unique_subjects = {entry['subject'] for entry in age_bursts}
    print(f"N {age}: {len(unique_subjects)}")
    print(f"N {age}: {len(age_bursts)}")

    # Serialize fully before opening the output file so a serialization error
    # cannot leave a truncated JSON file behind.
    serialized = json.dumps(age_bursts, default=make_serializable)
    with open(f'bursts_{age}_BV.json', 'w') as json_file:
        json_file.write(serialized)

    # Total number of bursts per subject, summed over sessions/events/channels.
    subject_burst_counts = {}
    for entry in age_bursts:
        subj = entry['subject']
        subject_burst_counts[subj] = subject_burst_counts.get(subj, 0) + len(entry['bursts']['trial'])

    print(" Résumé des bursts par sujet")
    for subj, count in subject_burst_counts.items():
        print(f"Sujet {subj} : {count} bursts")

    all_subjects_with_age = {
        row['subject_id'] for row in rows
        if row['age'] == age and row['eeg_system'] == "BrainVision"
    }

    # Covers subjects that were never processed (missing files) as well as
    # subjects that were processed but yielded no burst at all.
    no_burst_subjects = {
        subj for subj in all_subjects_with_age
        if not subject_burst_counts.get(subj)
    }

    print("Subject with 0 burst")
    for subj in sorted(no_burst_subjects):
        print(f"{subj}")
    return age_bursts

In [None]:
# Run the burst-detection pipeline once per age group. Each call prints its
# progress and writes `bursts_<age>_BV.json` to the current working directory.
# The calls are kept as separate statements so that results already computed
# stay bound in the kernel if a later age group fails.
bursts_three = detect_bursts('three')
bursts_six = detect_bursts('six')
bursts_twelve = detect_bursts('twelve')

In [None]:
# Collect the `fwhm_time` value of every burst in the 'three' age group and
# show their distribution as a histogram.
fwhm_times = []
for entry in bursts_three:
    for raw_value in entry.get('bursts', {}).get('fwhm_time', []):
        try:
            duration = float(raw_value)
        except ValueError:
            # Report unparsable values instead of aborting the whole cell.
            print(f"Invalid fwhm_time value: {raw_value}")
        else:
            fwhm_times.append(duration)

print("Extracted fwhm_times:", fwhm_times)

f, ax = plt.subplots(figsize=(7, 5))
ax.hist(fwhm_times, bins=20, color="#DFFF00", edgecolor='black', linewidth=0.2)
ax.set(title="Burst Duration", xlabel="FWHM Time", ylabel="Frequency")
plt.show()

In [None]:
# Collect the `fwhm_freq` value of every burst in the 'three' age group and
# show their distribution as a histogram.
# This cell no longer resets `fwhm_times`: it never used that list, and the
# reset silently discarded the result of the burst-duration cell.
fwhm_freqs = []
for entry in bursts_three:
    bursts = entry.get('bursts', {})
    for item in bursts.get('fwhm_freq', []):
        # An item may be a single value or a list of values; handle both alike.
        values = item if isinstance(item, list) else [item]
        for fwhm_freq in values:
            try:
                fwhm_freqs.append(float(fwhm_freq))
            except ValueError:
                print(f"Invalid fwhm_freq value: {fwhm_freq}")

print("Extracted fwhm_freqs:", fwhm_freqs)

f, ax = plt.subplots(1, 1, figsize=(14, 5))
ax.hist(fwhm_freqs, bins=10, color="#DFFF00", edgecolor='black', linewidth=0.2)
ax.set_title("Frequency Span")
ax.set_xlabel("FWHM Frequency")
ax.set_ylabel("Frequency")
plt.show()

In [None]:
# Collect the `peak_time` value of every burst in the 'three' age group and
# show their distribution as a histogram.
# `peak_time` remains the name of the collected list; the loop variables use
# different names so they can no longer shadow (and replace) that list.
peak_time = []
for entry in bursts_three:
    bursts = entry.get('bursts', {})
    for item in bursts.get('peak_time', []):
        # An item may be a single value or a list of values; handle both alike.
        values = item if isinstance(item, list) else [item]
        for value in values:
            try:
                peak_time.append(float(value))
            except ValueError:
                print(f"Invalid peak_time_value: {value}")

print("Extracted peak_time:", peak_time)

f, ax = plt.subplots(1, 1, figsize=(7, 5))
ax.hist(peak_time, bins=20, color="#DFFF00", edgecolor='black', linewidth=0.2)
# No legend: the histogram has no labelled artist, so a legend call would
# only emit a warning and draw nothing.
ax.set_title("Peak time")
ax.set_xlabel("Peak time")
ax.set_ylabel("Frequency")
plt.show()

In [None]:
# Collect the `peak_freq` value of every burst in the 'three' age group and
# show their distribution as a histogram.
peak_freqs = []
for entry in bursts_three:
    bursts = entry.get('bursts', {})
    for item in bursts.get('peak_freq', []):
        # An item may be a single value or a list of values; handle both alike.
        values = item if isinstance(item, list) else [item]
        for peak_freq in values:
            try:
                peak_freqs.append(float(peak_freq))
            except ValueError:
                print(f"Invalid peak_freq value: {peak_freq}")

print("Extracted peak_freqs:", peak_freqs)

f, ax = plt.subplots(1, 1, figsize=(7, 5))
ax.hist(peak_freqs, bins=20, color="#DFFF00", edgecolor='black', linewidth=0.2)
# No legend: the histogram has no labelled artist, so a legend call would
# only emit a warning and draw nothing.
ax.set_title("Peak Frequency")
ax.set_xlabel("Peak Frequency")
ax.set_ylabel("Frequency")
plt.show()

In [None]:
# Collect the `peak_amp_base` value of every burst in the 'three' age group
# and show their distribution as a histogram.
peak_amp_bases = []
for entry in bursts_three:
    per_burst = entry.get('bursts', {}).get('peak_amp_base', [])
    for item in per_burst:
        # Treat a bare value as a one-element list so a single code path
        # handles both the scalar and the list case.
        candidates = item if isinstance(item, list) else [item]
        for amp in candidates:
            try:
                amplitude = float(amp)
            except ValueError:
                print(f"Invalid peak_amp_base value: {amp}")
            else:
                peak_amp_bases.append(amplitude)

print("Extracted peak_amp_bases:", peak_amp_bases)

f, ax = plt.subplots(figsize=(7, 5))
ax.hist(peak_amp_bases, bins=20, color="#DFFF00", edgecolor='black', linewidth=0.2)
ax.set(title="Peak Amplitude", xlabel="Peak Amplitude Base", ylabel="Frequency")
plt.show()