In [100]:
import os
import glob
import mne
import numpy as np
import pandas as pd

INPUT_DIRS = [
    'control_counting',
    'control_resting_state',
    'nimitta',
    'rest_baseline',
    'resting_state',
    'resting_state_byjhana_concatenated'
]

OUTPUT_BASE = 'ALL_NoooAr'

os.makedirs(OUTPUT_BASE, exist_ok=True)

for folder in ['raw', '03s', '10s', 'raw_interpolated', '03s_ar', '10s_ar']:
    os.makedirs(os.path.join(OUTPUT_BASE, folder), exist_ok=True)

summary = []

def parse_metadata(path):
    fname = os.path.basename(path).lower()
    control = '1' if 'control' in path else '0'
    
    sub = fname.split('-')[0].replace('sub', '').zfill(2)
    day = 'NA'
    if 'day' in fname:
        try:
            day = fname.split('-')[1].replace('day', '')
            day = str(day).zfill(2)
        except IndexError:
            pass
    
    if 'counting' in path:
        condition = 'count'
    elif 'mindfulness' in path:
        condition = 'mindf'
    elif 'nimitta' in path:
        condition = 'nimit'
    elif 'rest_eyes_closed' in fname:
        condition = 'rposc' if 'post' in fname else 'rprec'
    elif 'rest_eyes_open' in fname:
        condition = 'rposo' if 'post' in fname else 'rpreo'
    elif 'jhana-j' in fname:
        jn = fname.split('-')[-1].replace('.fif', '')
        condition = f'jhan{jn[-1]}'
    elif 'jhana' in fname:
        condition = 'jhana'
    else:
        condition = 'unkn'
        
    return control, sub, day, condition

def build_filename(control, sub, day, condition, n_elem, length, bads, kind):
    day_str = day if day else ''
    return f"{control}_{sub}_{day_str}_{condition}_{n_elem}_{length}_{bads}_{kind}.fif"

def remove_bads(raw):
    
    bad_segments = []
    for annot in raw.annotations:
        if annot['description'].lower() == 'bad_':
            bad_segments.append((annot['onset']-raw.first_time, annot['onset'] + annot['duration']-raw.first_time))

    good_segments = []
    start_time = 0

    for bad_start, bad_end in bad_segments:
        if bad_start - start_time >= 10:
            good_segments.append((start_time, bad_start))
        start_time = bad_end 

    if start_time < raw.times[-1]:  
        good_segments.append((start_time, raw.times[-1]))

    good_raws = [raw.copy().crop(tmin, tmax) for tmin, tmax in good_segments]

    if len(good_raws) > 0:
        raw_nobad = mne.concatenate_raws(good_raws)
    else: 
        raw_nobad = None

    return raw_nobad, str(len(good_segments)).zfill(3), str(int(raw_nobad.times[-1])).zfill(4)

for folder in INPUT_DIRS:
    files = sorted(glob.glob(os.path.join(folder, "*.fif")))[:1]
    for file in files: 
        print(f"🔄 Processing {file}")
        control, sub, day, condition = parse_metadata(file)
        raw = mne.io.read_raw_fif(file, preload=True)
        num_bad_channels = str(len(raw.info['bads'])).zfill(2)
        print(raw.info['bads'])
        raw_nobad, n_elem, length = remove_bads(raw)    

        
        fname = build_filename(control, sub, day, condition, n_elem, length, num_bad_channels, 'raw')
        fpath = os.path.join(OUTPUT_BASE, 'raw', fname)
        raw_nobad.save(fpath, overwrite=True)
        
        epochs_3 = mne.make_fixed_length_epochs(raw, duration=3, preload=True)
        epochs_10 = mne.make_fixed_length_epochs(raw, duration=10, preload=True)
        
        n_epochs_3 = len(epochs_3)
        length_3 = n_epochs_3 * 3
        n_epochs_3 = str(n_epochs_3).zfill(3)
        length_3 = str(length_3).zfill(4)

        n_epochs_10 = len(epochs_10)
        length_10 = n_epochs_10 * 10
        n_epochs_10 = str(n_epochs_10).zfill(3)
        length_10 = str(length_10).zfill(4)

        fname_3 = build_filename(control, sub, day, condition, n_epochs_3, length_3, num_bad_channels, 'epo')
        fname_10 = build_filename(control, sub, day, condition, n_epochs_10, length_10, num_bad_channels, 'epo')
        
        fpath_3 = os.path.join(OUTPUT_BASE, '03s', fname_3)
        epochs_3.save(fpath_3, overwrite=True)
        
        fpath_10 = os.path.join(OUTPUT_BASE, '10s', fname_10)
        epochs_10.save(fpath_10, overwrite=True)

        # interpolate and run AR
        raw_nobad.interpolate_bads(reset_bads=True)
        fpath = os.path.join(OUTPUT_BASE, 'raw_interpolated', fname)
        raw_nobad.save(fpath, overwrite=True)

        ar = AutoReject()
        epochs_3_ar, reject_log_3 = ar.fit_transform(epochs_3, return_log=True)
        epochs_3_ar.interpolate_bads(reset_bads=True)
        fpath_3_ar = os.path.join(OUTPUT_BASE, '03s_ar', fname_3)
        epochs_3_ar.save(fpath_3_ar, overwrite=True)

        
        ar = AutoReject()
        epochs_10_ar, reject_log_10 = ar.fit_transform(epochs_10, return_log=True)
        epochs_10_ar.interpolate_bads(reset_bads=True)
        fpath_10_ar = os.path.join(OUTPUT_BASE, '10s_ar', fname_10)
        epochs_10_ar.save(fpath_10_ar, overwrite=True)
        
        n_dropped_3 = reject_log_3.bad_epochs.sum()
        n_interp_total_3 = np.nansum(reject_log_3.labels)
    
        n_dropped_10 = reject_log_10.bad_epochs.sum()
        n_interp_total_10 = np.nansum(reject_log_10.labels)

        fig = reject_log_3.plot(show=False)
        fig.savefig(os.path.join(OUTPUT_BASE, '03s_ar', fname_3.replace('.fif', '_rejlog.png')))
        fig = reject_log_10.plot(show=False)
        fig.savefig(os.path.join(OUTPUT_BASE, '10s_ar', fname_10.replace('.fif', '_rejlog.png')))
        
        summary.append({
            'original_name': file, 'filename': fname, 
            'control': control, 'sub': sub, 'day': day, 'condition': condition,
            'bad_channels': num_bad_channels,
            'num_03s_epochs': n_epochs_3, 'num_10s_epochs': n_epochs_10, 
            'length_raw': length, 'length_03s': length_3, 'length_10s': length_10,
            'num_03s_epochs_ar': str(len(epochs_3_ar)).zfill(3), 'num_10s_epochs_ar': str(len(epochs_10_ar)).zfill(3),
            'length_03s_ar': str(int(len(epochs_3_ar))*3).zfill(4), 'length_10s_ar': str(int(len(epochs_10_ar))*10).zfill(4),
            'n_dropped_3': n_dropped_3, 'n_dropped_10': n_dropped_10,
            'n_interp_total_3': n_interp_total_3, 'n_interp_total_10': n_interp_total_10,
            'percent_retained_03s': round(len(epochs_3_ar) / int(n_epochs_3) * 100, 2),
            'percent_retained_10s': round(len(epochs_10_ar) / int(n_epochs_10) * 100, 2),

        })



    # Save combined summary
    df = pd.DataFrame(summary)
    summary_path = os.path.join(OUTPUT_BASE, 'summary.csv')
    df.to_csv(summary_path, index=False)
    df.to_excel(os.path.join(OUTPUT_BASE, 'summary.xlsx'), index=False)
    print(f"📝 Summary saved to {summary_path}")
    
    print("✅ All files processed.")


🔄 Processing control_counting/sub0-control-counting-raw.fif
Opening raw data file control_counting/sub0-control-counting-raw.fif...
    Range : 512 ... 30976 =      2.000 ...   121.000 secs
Ready.
Reading 0 ... 30464  =      0.000 ...   119.000 secs...
[]
Writing /Users/jonasmago/PhD_code_data/github/eeg_jhana/notebooks/hand_cleaning/ALL_NoooAr/raw/1_00_NA_count_003_0081_00_raw.fif
Closing /Users/jonasmago/PhD_code_data/github/eeg_jhana/notebooks/hand_cleaning/ALL_NoooAr/raw/1_00_NA_count_003_0081_00_raw.fif
[done]
Not setting metadata
39 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 39 events and 768 original time points ...
13 bad epochs dropped
Not setting metadata
11 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 11 events and 2560 original time points ...
5 bad epochs dropped
Setting channel interpolation method to {'eeg': 'spline'}.
Writing /U

  raw_nobad.interpolate_bads(reset_bads=True)


  0%|          | Creating augmented epochs : 0/32 [00:00<?,       ?it/s]

  0%|          | Computing thresholds ... : 0/32 [00:00<?,       ?it/s]

  0%|          | Repairing epochs : 0/26 [00:00<?,       ?it/s]

  0%|          | n_interp : 0/3 [00:00<?,       ?it/s]

  0%|          | Repairing epochs : 0/26 [00:00<?,       ?it/s]

  0%|          | Fold : 0/10 [00:00<?,       ?it/s]

  0%|          | Repairing epochs : 0/26 [00:00<?,       ?it/s]

  0%|          | Fold : 0/10 [00:00<?,       ?it/s]

  0%|          | Repairing epochs : 0/26 [00:00<?,       ?it/s]

  0%|          | Fold : 0/10 [00:00<?,       ?it/s]





Estimated consensus=1.00 and n_interpolate=31


  0%|          | Repairing epochs : 0/26 [00:00<?,       ?it/s]

No bad epochs were found for your data. Returning a copy of the data you wanted to clean. Interpolation may have been done.
Setting channel interpolation method to {'eeg': 'spline'}.
Running autoreject on ch_type=eeg


  epochs_3_ar.interpolate_bads(reset_bads=True)


  0%|          | Creating augmented epochs : 0/32 [00:00<?,       ?it/s]

  0%|          | Computing thresholds ... : 0/32 [00:00<?,       ?it/s]

  0%|          | Repairing epochs : 0/6 [00:00<?,       ?it/s]

  0%|          | n_interp : 0/3 [00:00<?,       ?it/s]

  0%|          | Repairing epochs : 0/6 [00:00<?,       ?it/s]

  0%|          | Fold : 0/10 [00:00<?,       ?it/s]

ValueError: Cannot have number of splits n_splits=10 greater than the number of samples: n_samples=6.