In [86]:
import glob
import os
import numpy as np
import matplotlib.pyplot as plt
import numpy.ma as ma
import pandas as pd
import mne
data_dir = 'data_meg'
subj = "R2280"
dataqual = 'prepro' #or loc/exp
exp = 'loc' #or exp
dtype = "raw"
label_dir = 'data_log'
save_dir = 'data_meg'
bad_channels_dict = {
    "R2490": ['MEG 014', 'MEG 004', 'MEG 079', 'MEG 072', 'MEG 070', 'MEG 080', 'MEG 074', 'MEG 067', 'MEG 082', 'MEG 105', 'MEG 115', 'MEG 141', 'MEG 153'],
    "R2488": ['MEG 015', 'MEG 014', 'MEG 068', 'MEG 079', 'MEG 146', 'MEG 147', 'MEG 007', 'MEG 141'],
    "R2487": ['MEG 015', 'MEG 014', 'MEG 068', 'MEG 079', 'MEG 147', 'MEG 146', 'MEG 004'],
    "R2280": ['MEG 015', 'MEG 039', 'MEG 077', 'MEG 076', 'MEG 073', 'MEG 079', 'MEG 064', 'MEG 059', 'MEG 070']
}
bad_channels = bad_channels_dict.get(subj, [])


In [87]:
# raw = mne.io.read_raw_fif('data_meg/R2490/prepro/R2490_exp.fif', preload='temp_raw.fif')
raw = mne.io.read_raw_fif(f'{data_dir}/{subj}/{dataqual}/{subj}_{exp}.fif', preload=True)
raw.info['bads'].extend(bad_channels)
sfreq = raw.info['sfreq']
raw.filter(1, 40, method='iir')
downsample = 10
raw.resample(sfreq / downsample)
raw.drop_channels(bad_channels)

Opening raw data file data_meg/R2280/prepro/R2280_loc.fif...


  raw = mne.io.read_raw_fif(f'{data_dir}/{subj}/{dataqual}/{subj}_{exp}.fif', preload=True)


    Range : 0 ... 759999 =      0.000 ...   759.999 secs
Ready.
Reading 0 ... 759999  =      0.000 ...   759.999 secs...
Filtering raw data in 2 contiguous segments
Setting up band-pass filter from 1 - 40 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 1.00, 40.00 Hz: -6.02, -6.02 dB

425 events found on stim channel STI 014
Event IDs: [160 164 165 166]
364 events found on stim channel STI 014
Event IDs: [160 164 165 166]


  raw.resample(sfreq / downsample)


Unnamed: 0,General,General.1
,Filename(s),R2280_loc.fif
,MNE object type,Raw
,Measurement date,2024-11-15 at 20:38:55 UTC
,Participant,Unknown
,Experimenter,Unknown
,Acquisition,Acquisition
,Duration,00:12:40 (HH:MM:SS)
,Sampling frequency,100.00 Hz
,Time points,76000
,Channels,Channels


In [88]:
events = mne.find_events(raw, stim_channel='STI 014', output='onset', shortest_event=1)
event_id = {
    'start': 160,
    'move': 161,
    'reveal_red': 162,
    'reveal_white': 163,
    'done': 164,
    'choice': 165,
    'timeout': 166
}

# Define trials to remove
trials_to_remove = []

364 events found on stim channel STI 014
Event IDs: [160 164 165 166]


In [89]:
start_events = events[events[:, 2] == event_id['start']]
done_events = events[events[:, 2] == event_id['done']] 
timeout_events = events[events[:, 2] == event_id['timeout']]
choice_events = events[events[:, 2] == event_id['choice']]
sfreq = raw.info['sfreq']

In [90]:
if subj == "R2488" :
    done_events = done_events[1:]

In [91]:

# Initialize a list to store filtered done events
filtered_done_events = [done_events[0]]  # Start with the first event

# Check for at least 3 seconds between each done event
for i in range(1, len(done_events)):
    time_diff = (done_events[i, 0] - done_events[i-1, 0]) / sfreq
    if time_diff < 2:
        print(f"Warning: Less than 3 seconds between done events at indices {i-1} and {i}")
    
    else:
        filtered_done_events.append(done_events[i])
done_events = filtered_done_events
# Use filtered_done_events for further processing
print(f"Filtered done events count: {len(filtered_done_events)}")



Filtered done events count: 121


In [99]:
choice_events = events[events[:, 2] == event_id['choice']]
# start_events = np.array([[done_event[0] - int(3.4 * sfreq), 0, event_id['start']] for done_event in done_events])
start_events = events[events[:, 2] == event_id['start']]
# start_events = start_events[-120:]
sfreq = raw.info['sfreq']  # Sampling frequency


# Initialize a list to store trial information
trial_info = []
start_idx = 0

# Iterate through each reconstructed start event to create trial information
for start_event, done_event in zip(start_events, done_events):
    start_sample = start_event[0]
    done_sample = done_event[0]
    # done_sample = start_sample + int(3.4 * sfreq)
    # Calculate tmin and tmax for the epoch
    tmin = -0.2  # 0.2 s before 'start'
    tmax = (done_sample - start_sample) / sfreq  # Duration from 'start' to 'done'
    # Find choice events within the trial
    choice_event = choice_events[(choice_events[:, 0] > start_sample) & 
                                 (choice_events[:, 0] < done_sample + int(0.5*sfreq))]
    choice_time = choice_event[-1, 0] if len(choice_event) > 0 else None
    
    # Store trial information
    trial_info.append({
        'event_sample': done_sample,
        'trial_index': start_idx,
        'duration': tmax,
        'tmin': tmin,
        'tmax': tmax,
        'done': len(done_events) > 0,
        'done_times': done_sample / sfreq,
        'start_times': start_sample / sfreq,
        'choice_event': len(choice_event) > 0,
        'choice_time': choice_time / sfreq if len(choice_event) > 0 else None
    })
    start_idx += 1

In [100]:
for info in trial_info:
    if info['choice_time'] is not None:
        print(info['trial_index'], info['choice_time'] - info['start_times'])
    elif info['done_times'] is not None:
        print(info['trial_index'], None )


0 2.1899999999999995
1 0.9199999999999999
2 0.7800000000000011
3 0.9100000000000001
4 1.0700000000000003
5 0.769999999999996
6 0.6900000000000048
7 0.6900000000000048
8 0.990000000000002
9 None
10 0.8000000000000043
11 1.5900000000000034
12 1.3900000000000006
13 0.7299999999999898
14 0.7000000000000028
15 1.1400000000000006
16 0.7800000000000011
17 0.5999999999999943
18 1.039999999999992
19 0.7399999999999949
20 0.75
21 0.9699999999999989
22 1.5999999999999943
23 1.519999999999996
24 0.9200000000000159
25 None
26 0.8600000000000136
27 1.1800000000000068
28 0.7800000000000011
29 None
30 None
31 0.7599999999999909
32 0.8599999999999852
33 1.1200000000000045
34 1.289999999999992
35 0.6699999999999875
36 1.6299999999999955
37 1.0199999999999818
38 0.960000000000008
39 None
40 1.1599999999999966
41 1.2700000000000102
42 None
43 None
44 2.6599999999999966
45 None
46 1.25
47 1.0600000000000023
48 0.9000000000000341
49 1.8400000000000318
50 0.7599999999999909
51 1.079999999999984
52 1.08999999

In [101]:
new_events = np.array([[info['event_sample'], 0, event_id['done']] for info in trial_info])
tmin = -0.2  
# Initialize lists to store individual epochs data and trial information
data_list = []

for idx, event in enumerate(new_events):
    event_id_code = event[2]
    choice_time = trial_info[idx]['choice_time']  # Retrieve choice_time for the current trial
    print(choice_time)
    if choice_time is not None:
        start_sample = int(choice_time * sfreq)  # Convert choice_time to sample index
        tmin = -0.5  # 0.5 s before 'choice'
        tmax = 0.5  # 0.5 s after 'choice'
    else:
        done_sample = int(trial_info[idx]['done_times'] * sfreq)  # Convert done_times to sample index
        tmin = -1  # 1 s before 'done'
        tmax = 0  # At 'done'
    picks = mne.pick_types(raw.info, meg=True, exclude='bads')

    epochs = mne.Epochs(
        raw, [event], event_id={f'event_{event_id_code}': event_id_code},
        tmin=tmin, tmax=tmax, preload=True,picks=picks,
        reject_by_annotation=False, reject=None, verbose=True
    )
    data_list.append(epochs.get_data())

12.85
Not setting metadata
1 matching events found
Setting baseline interval to [-0.5, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 1 events and 101 original time points ...
0 bad epochs dropped
16.52
Not setting metadata
1 matching events found
Setting baseline interval to [-0.5, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 1 events and 101 original time points ...
0 bad epochs dropped
21.32
Not setting metadata
1 matching events found
Setting baseline interval to [-0.5, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 1 events and 101 original time points ...
0 bad epochs dropped
26.38
Not setting metadata
1 matching events found
Setting baseline interval to [-0.5, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 1 events and 10

In [95]:
print(len(data_list))

121


In [102]:
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
import joblib  # For saving the model

In [103]:
# Assuming X is your feature matrix with shape (n_epochs, n_channels, n_times)
X = np.array([md.data for md in data_list])  # Ensure this is 3D
X = X.squeeze(axis=1)  # This might reduce dimensions, ensure it's still 3D

# Load the CSV file
locolizer = pd.read_csv(f'{label_dir}/{subj}/loc_data.csv')

# Extract valid trial indices
valid_trial_indices = {info['trial_index'] for info in trial_info}

# Extract labels and group information
rule_label = locolizer.loc[locolizer['trial_index'].isin(valid_trial_indices), 'rule'].values
group_start = locolizer.loc[locolizer['trial_index'].isin(valid_trial_indices), 'num_start'].values

# Identify unique group_start labels
unique_groups = np.unique(group_start)

# Prepare data for each unique group_start
group_data = {group: [] for group in unique_groups}
group_labels = {group: [] for group in unique_groups}

# Populate group_data and group_labels
for idx, info in enumerate(trial_info):
    trial_index = info['trial_index']
    if trial_index in valid_trial_indices:
        group = group_start[idx]
        group_data[group].append(X[idx])  # Ensure X[idx] is 3D
        group_labels[group].append(rule_label[idx])

def flatten_data(X):
    n_samples, n_channels, n_timepoints = X.shape
    return X.reshape(n_samples, n_channels * n_timepoints)

# Train and save decoders for each group
for group in unique_groups:
    X_group = np.array(group_data[group])
    y_group = np.array(group_labels[group])
    
    # Debugging: Check the shape of X_group
    print(f"Group: {group}, X_group shape: {X_group.shape}")
    # Flatten the data
    X_group_flat = flatten_data(X_group)
    
    # Train the decoder
    clf = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
    clf.fit(X_group_flat, y_group)
    
    # Save the trained model
    model_filename = f"locolizer/{subj}_decoder_{group}.joblib"
    joblib.dump(clf, model_filename)
    print(f"Model for group {group} saved as {model_filename}")

Group: 4, X_group shape: (30, 148, 101)
Model for group 4 saved as locolizer/R2280_decoder_4.joblib
Group: 9, X_group shape: (30, 148, 101)
Model for group 9 saved as locolizer/R2280_decoder_9.joblib
Group: 16, X_group shape: (30, 148, 101)
Model for group 16 saved as locolizer/R2280_decoder_16.joblib
Group: 25, X_group shape: (31, 148, 101)
Model for group 25 saved as locolizer/R2280_decoder_25.joblib
