In [1]:
from Recordings import AEPFeedbackRecording, Recordings, Recording
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Load the data: every recording found under `basepath` is read into one `Recordings` collection.
# Judging by the captured log below, construction reads each .xdf file and applies
# band-stop / band-pass filtering itself -- TODO confirm inside `Recordings`.
# NOTE(review): the log shows it also tries to read ./Recordings/1/.DS_Store; presumably
# skipped, but filtering on the .xdf extension inside `Recordings` would be cleaner.
basepath = './Recordings'
recordings = Recordings(basepath)
# Optional: list what was loaded (here passed subject_id as a string, whereas later cells
# pass an int -- check which type print_info expects).
# recordings.print_info(group='*', session='*', subject_id='1', experiment_id='aep_feedback')

Reading recording: ./Recordings/1/.DS_Store
Reading recording: ./Recordings/1/7/1/7_aep_2024-03-28_14-34-11_1.xdf
Reading recording: ./Recordings/1/7/1/7_aep_feedback_2024-03-28_14-41-11_1.xdf
Creating RawArray with float64 data, n_channels=8, n_times=108836
    Range : 0 ... 108835 =      0.000 ...   435.340 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 1651 samples (6.604 s)
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 60 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window wit

In [3]:
# All aep_feedback recordings of subject 1 ('*' presumably matches any group / session).
aep_feedback_recordings = recordings.filter_by(group='*', session='*', subject_id=1, experiment_id='aep_feedback')

In [4]:
# Pick the first matching recording for quick manual inspection
# (fails if subject 1 has no aep_feedback recordings).
test_recording = aep_feedback_recordings[0]

In [5]:
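# Expected event markers within one trial, in order
# ('/' separates alternatives; parentheses mark the variable part):
# trial-begin
# standard/oddball
# response-received-(arrow_up/arrow_down)/response-was-missed
# response-was-(correct/incorrect)/None
# rt-(XXXms)
# trial-end
# ---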

In [6]:
# Optional: browse the raw signal of `test_recording` with its event markers (uncomment to use).
# %matplotlib widget
# fig = test_recording._raw.plot(
#     events=test_recording.mne_events,
#     start=5,
#     duration=10,
#     color="gray",
#     event_color={1: "r", 2: "g", 3: "b", 4: "m"}
# )

In [7]:
def filter_trials(trials, stimulus='both', response='both'):
    """Select trials that have a reaction time, optionally narrowed by stimulus/response label.

    Parameters
    ----------
    trials : iterable of dict
        Each trial is read through the keys 'reaction_time', 'stimulus' and 'response'.
        'stimulus' and 'response' are indexed at position 1 to get their label
        (presumably a (timestamp, label) pair -- TODO confirm against Recording.trials).
    stimulus : str
        'both' keeps every stimulus; any other value must equal trial['stimulus'][1]
        (e.g. 'standard' or 'oddball').
    response : str
        'both' keeps every response; any other value must equal trial['response'][1]
        (e.g. 'correct' or 'incorrect').

    Returns
    -------
    list of dict
        The matching trials in their original order. Trials whose reaction time is
        falsy (e.g. None) are always dropped.
    """
    selected = []
    for trial in trials:
        # Trials without a reaction time are always excluded, so every downstream
        # statistic (reaction times and accuracy) only sees answered trials.
        if not trial['reaction_time']:
            continue
        if stimulus != 'both' and trial['stimulus'][1] != stimulus:
            continue
        if response != 'both' and trial['response'][1] != response:
            continue
        selected.append(trial)
    return selected

In [8]:
def get_subject_reaction_times(subject_id, stimulus='both', response='both'):
    """Return one subject's reaction times pooled over all of their aep_feedback recordings.

    Trials are selected per recording with ``filter_trials(recording.trials, stimulus,
    response)``; the result is a flat list ordered by recording, then by trial.
    """
    subject_recordings = recordings.filter_by(group='*', session='*', subject_id=subject_id, experiment_id='aep_feedback')
    return [
        trial['reaction_time']
        for recording in subject_recordings
        for trial in filter_trials(recording.trials, stimulus, response)
    ]

In [9]:
def get_subject_accuracy(subject_id, stimulus='both'):
    """Fraction of a subject's aep_feedback trials whose response label is 'correct'.

    Only trials kept by ``filter_trials`` are counted, so trials without a reaction
    time are excluded. Results can optionally be restricted to one stimulus label
    (e.g. 'standard' / 'oddball').

    Returns
    -------
    float
        A value in [0, 1], or ``nan`` when the subject has no matching trials
        (previously this case raised ZeroDivisionError).
    """
    subject_recordings = recordings.filter_by(group='*', session='*', subject_id=subject_id, experiment_id='aep_feedback')
    n_correct = 0
    n_total = 0
    for recording in subject_recordings:
        for trial in filter_trials(recording.trials, stimulus, 'both'):
            n_total += 1
            if trial['response'][1] == 'correct':
                n_correct += 1
    if n_total == 0:
        # No trials for this subject/stimulus: accuracy is undefined, not an error.
        return float('nan')
    return n_correct / n_total

In [10]:
for i in range(1, 8):
    accuracy_pct = get_subject_accuracy(i, "both") * 100
    print(f'{i}: {accuracy_pct:.2f}%')

1: 96.40%
2: 92.91%
3: 99.25%
4: 96.43%
5: 95.44%
6: 86.27%
7: 96.95%


In [11]:
for i in range(1, 8):
    accuracy_pct = get_subject_accuracy(i, "oddball") * 100
    print(f'{i}: {accuracy_pct:.2f}%')

1: 76.27%
2: 69.62%
3: 96.67%
4: 79.63%
5: 72.88%
6: 22.03%
7: 81.67%


In [12]:
for i in range(1, 8):
    accuracy_pct = get_subject_accuracy(i, "standard") * 100
    print(f'{i}: {accuracy_pct:.2f}%')

1: 100.00%
2: 97.20%
3: 99.71%
4: 99.65%
5: 99.40%
6: 97.86%
7: 99.70%


In [13]:
def add_subject_reaction_plot(reaction_times, axes, title, bins=30, hist_range=(200, 1300)):
    """Draw a reaction-time histogram with mean and +/-1 SD markers onto ``axes``.

    Parameters
    ----------
    reaction_times : sequence of float
        Values to plot (units as stored in the trials; presumably ms -- TODO confirm).
    axes : matplotlib Axes
        Target axes, drawn onto in place.
    title : str
        Axes title.
    bins : int
        Number of histogram bins (default 30, as before).
    hist_range : (float, float)
        Lower/upper histogram edges; values outside this range are not counted.
    """
    axes.set_title(title)
    if len(reaction_times) == 0:
        # Nothing to draw; avoids NaN mean/std (and a RuntimeWarning) for empty input.
        return

    values = np.asarray(reaction_times, dtype=float)
    # Same bar heights as np.histogram(values, bins, range=hist_range).
    axes.hist(values, bins=bins, range=hist_range)

    mean = np.mean(values)
    std_dev = np.std(values)

    # Label in axes-fraction coordinates so it always stays inside the panel,
    # independent of the count scale (a fixed data y of 100 could fall off-chart
    # or overlap the bars, especially with sharey=True).
    axes.text(0.98, 0.95, f'Mean: {mean:.2f}', color='r', fontsize=10,
              ha='right', va='top', transform=axes.transAxes)

    # Mean (dashed red) and +/-1 standard deviation (dotted green) indicators.
    axes.axvline(mean, color='r', linestyle='--', label='Mean')
    axes.axvline(mean + std_dev, color='g', linestyle=':', label='+1 Std Dev')
    axes.axvline(mean - std_dev, color='g', linestyle=':', label='-1 Std Dev')

    axes.set_xlabel('Reaction time')
    axes.set_ylabel('Trials')

In [14]:
# One reaction-time histogram per subject (correct responses only), sharing the y-axis.
SUBJECT_IDS = range(1, 8)
fig, axs = plt.subplots(2, 4, sharey=True, tight_layout=True)
assert len(SUBJECT_IDS) <= axs.size, "more subjects than subplot panels"
for ax, subject_id in zip(axs.flat, SUBJECT_IDS):
    rts = get_subject_reaction_times(subject_id, 'both', 'correct')
    add_subject_reaction_plot(rts, ax, f'Subject ID: {subject_id}')
# The 2x4 grid has more panels than subjects; hide the leftovers instead of
# leaving empty, tick-labelled axes in the figure.
for ax in axs.flat[len(SUBJECT_IDS):]:
    ax.set_visible(False)
fig.show()