In [None]:
import os
import os.path
import numpy as np
import glob
import pandas as pd
import re
import seaborn as sns
import matplotlib.pyplot as plt
import scipy
from scipy import stats
from scipy import signal
from scipy.stats import linregress, sem
from scipy.signal import chirp, find_peaks, peak_widths, butter

import sys
import math
import time
from netneurotools import stats as st
import statsmodels.api as sm
from matplotlib.backends.backend_pdf import PdfPages
from scipy import stats
# Make helper modules in the current working directory importable
# (mem_rew_stimfuncs is expected to be located next to this notebook).
libfolder=os.getcwd()
sys.path.append(libfolder)
# NOTE(review): star import — presumably provides helpers used below that are not
# defined in this notebook (e.g. remove_spur_samples, extract_tiallist,
# test_sequentiality, rel_error_degrees, correct_resp_based_on_distance, mean).
# Explicit imports would make these dependencies visible.
from mem_rew_stimfuncs import *
print('done')

### Settings: all values below can be adjusted (see the inline comments)

In [None]:
# Root folder of the HRBP_MP experiment data. Set the HRBP_PATH environment
# variable to use another location without editing the notebook.
hrbp_path = os.environ.get('HRBP_PATH', 'C:\\Users\\lloydb\\surfdrive\\ExperimentData\\HRBP_MP')

# settings 
minimum_distance_centre = 150  # test-phase check: participants must place the item on/close to the circle; items left in the centre of the screen are invalid responses
sF = 40                        # pupil sampling frequency (Hz)
set_invalid_threshold = 50.0   # maximum % of invalid samples allowed in an event (and in its baseline)

# Start and end event times (seconds)
baseline_dur = 0.2              # duration of the pre-onset baseline window
stim_win_end = 6.0              # epoch length after stimulus onset (study phase)
fb_win_end = 3.0                # epoch length after feedback onset (study phase)
stim_dur = 3                    # seconds the stimulus stays in its location during encoding
duration_peak = 0.2             # width of the window averaged around the peak of the SP/FB timeseries (SP_peak_response, FB_peak_response)
duration_max_anticipation = 0.2 # width of the windows used for the pupil-change measure of the anticipation period
other_event_dur = 0.2           # width of other short event windows (e.g. choice onset)
resp_onset_check = 6.0          # 6.0 s window after the location (mouse) response; matches the resp_onset_event epoch length

# Runs missing altogether (reference only; the script only loops over data that is available):
# total_missing_runs = {'003': [6], '023': [6], '034': [1], '035': [4], '037': [5],
#                       '038': [5,6], '040': [1], '041': [5,6], '047': [3], '049': [4,6], '051': [4]}
# Runs with behavioural but no pupil data (run numbers as strings, like the 'run' column)
pup_missing_runs = {'003': ['6'], '035': ['4'], '037': ['5'],  '049': ['4','6'], '051': ['4']}

In [None]:
'''
Small helper functions (file lookup, epoching, baseline correction, peak
extraction, median split) used by the larger processing functions below.

'''

def get_behav_paths(sub, base_path=None):
    """
    Collect the behavioural data files of one subject.

    Parameters
    ----------
    sub : str
        Subject id as used in the folder name (e.g. '003' -> HRBP_MP003).
    base_path : str, optional
        Root of the experiment data. Defaults to the global `hrbp_path`.

    Returns
    -------
    dict
        'study_data' / 'test_data': lists of study (phase1) and test (phase2)
        .xpd files, sorted by file name; 'triallist_path': folder with the trial lists.
        The lists are sorted because glob returns files in arbitrary, OS-dependent
        order, and the caller pairs study and test files run by run.
    """
    root = hrbp_path if base_path is None else base_path
    sub_dir = '{}/raw_data/data/HRBP_MP{}'.format(root, sub)
    return {'study_data': sorted(glob.glob('{}/run*_HRBP_MP_phase1*.xpd'.format(sub_dir))),
            'test_data': sorted(glob.glob('{}/run*_HRBP_MP_phase2*.xpd'.format(sub_dir))),
            'triallist_path': '{}/triallist_ver2'.format(root)}

def get_pup_paths(sub, run, base_path=None):
    """
    Collect the pupil data files of one run of one subject.

    Parameters
    ----------
    sub : str
        Subject id as used in the folder name (e.g. '003' -> HRBP_MP003).
    run : str or int
        Run number, matched as 'run<run>' inside the file names.
    base_path : str, optional
        Root of the experiment data. Defaults to the global `hrbp_path`.

    Returns
    -------
    dict
        For each key ('pup_rawdata', 'pup_smthdata', 'markerfile', 'tsvdat_file',
        'FS_corr_pup_smthdata') the matching files, sorted by path, so that the
        `[0]` pick used by callers is deterministic when several files match.
    """
    root = hrbp_path if base_path is None else base_path
    tobii_dir = '{}/raw_data/tobii_data/HRBP_MP{}'.format(root, sub)
    patterns = {
        'pup_rawdata': 'PupCor_output/*run{}*_raw_pup*',
        'pup_smthdata': 'PupCor_output/*run{}*_expdata_smth_int_pup.txt',
        'markerfile': 'study_phase*run{}*_markers.tsv',
        'tsvdat_file': 'study_phase*run{}*_expdata.tsv',
        'FS_corr_pup_smthdata': 'foreshadow_correction/PupCor_output/*run{}*_expdata_smth_int_pup.txt',
    }
    return {key: sorted(glob.glob('{}/{}'.format(tobii_dir, pattern.format(run))))
            for key, pattern in patterns.items()}

def make_empty_array(time_start, time_end, sF):
    """
    Build an all-NaN placeholder epoch (as a list) for an excluded or missing event.

    The length matches an epoch running from `time_start` seconds before the onset to
    `time_end` seconds after it: int((time_end + time_start) * sF) samples.

    Parameters
    ----------
    time_start : float
        Seconds before onset (positive number, as used by chop_time_series).
    time_end : float
        Seconds after onset.
    sF : int
        Sampling frequency (Hz).

    Returns
    -------
    list of float
        NaN values of the epoch length.
    """
    # int() truncation kept so the length stays identical to the previous implementation
    n_samples = int((time_end + time_start) * sF)
    # np.nan (not the np.NaN alias, which was removed in NumPy 2.0)
    return np.full(n_samples, np.nan).tolist()

def chop_time_series(data, sF, onset, window_start, window_end):
    """
    Cut an epoch out of a sampled timeseries.

    Returns data[onset - int(window_start * sF) : onset + int(window_end * sF)],
    i.e. `window_start` seconds before and `window_end` seconds after sample `onset`.
    """
    first_sample = onset - int(window_start * sF)
    stop_sample = onset + int(window_end * sF)
    return data[first_sample:stop_sample]

def chop_time_series_flipped(data, sF, onset, window_start, window_end):
    """
    Cut an epoch with the window signs flipped relative to chop_time_series.

    Returns data[onset + int(window_start * sF) : onset - int(window_end * sF)].
    """
    begin = onset + int(window_start * sF)
    finish = onset - int(window_end * sF)
    return data[begin:finish]

def bl_correct_percChange(data, baseline):
    """
    Express pupil data as percentage change from a baseline value.

    Computes ((data - baseline) / baseline) * 100; element-wise for numpy arrays.
    """
    relative_change = (data - baseline) / baseline
    return relative_change * 100

def bl_correct_absolutePupil(data, baseline):
    """Subtract a baseline value from pupil data (absolute baseline correction)."""
    return data - baseline
    
# z-score helper
def zscore(x):
    """
    Standardise `x`: subtract its mean and divide by its standard deviation.

    Uses x.mean() and x.std(), so the ddof follows the input type
    (pandas Series: sample std, ddof=1; numpy array: population std, ddof=0).
    """
    centred = x - x.mean()
    return centred / x.std()

def process_event_data(raw_pup, smth_pup, sF, event_data, set_invalid_threshold, baseline_method):
    '''
    Cut one event epoch out of a run, clean it and baseline-correct it.

    Parameters
    ----------
    raw_pup : array-like
        Raw pupil trace of the run; invalid samples are coded as -1.
    smth_pup : array-like
        Smoothed pupil trace of the same run (indexed with the same onsets as raw_pup).
    sF : int
        Sampling frequency (Hz).
    event_data : dict
        Needs 'onset' (sample index), 'win_str' / 'win_end' (seconds before / after
        the onset), 'baseline' (baseline value) and 'BL_prop_invalid'
        (% invalid samples in the baseline window).
    set_invalid_threshold : float
        Maximum % of invalid samples allowed in both the event and its baseline.
    baseline_method : {None, 'percentage', 'absolute'}
        None: no correction; 'percentage': % change from the baseline;
        'absolute': baseline subtracted.

    Returns
    -------
    The epoch cut from smth_pup, cleaned with remove_spur_samples and baseline
    corrected, or a list of NaNs of the epoch length (make_empty_array) when the event
    or its baseline contains too many invalid samples.

    Raises
    ------
    ValueError
        If `baseline_method` is not one of the supported options (previously this
        surfaced as an UnboundLocalError).
    '''
    if baseline_method not in (None, 'percentage', 'absolute'):
        raise ValueError(f"unknown baseline_method: {baseline_method!r}")

    raw_trial_dat = np.asarray(chop_time_series(raw_pup, sF, event_data["onset"], event_data["win_str"], event_data["win_end"]))
    if raw_trial_dat.size == 0:
        prop_invalid = 100.0  # empty epoch (e.g. window outside the recording): treat as fully invalid
    else:
        prop_invalid = (np.sum(raw_trial_dat == -1) / raw_trial_dat.size) * 100  # % invalid samples in event

    # keep the event only if both the event and its baseline are valid enough
    if not ((prop_invalid < set_invalid_threshold) and (event_data['BL_prop_invalid'] < set_invalid_threshold)):
        return make_empty_array(event_data["win_str"], event_data["win_end"], sF)

    event_data_cut = chop_time_series(smth_pup, sF, event_data["onset"], event_data["win_str"], event_data["win_end"])
    event_data_clean = remove_spur_samples(event_data_cut)

    # baseline correct the event
    if baseline_method is None:
        return event_data_clean
    if baseline_method == 'percentage':
        return bl_correct_percChange(event_data_clean, event_data["baseline"])
    return bl_correct_absolutePupil(event_data_clean, event_data["baseline"])

def get_peak_pup_resp(pup_timeseries, duration_peak, sF):
    """
    Mean pupil size in a window of `duration_peak` seconds around the maximum.

    NaN samples are ignored, both when locating the maximum and when averaging.
    (np.argmax would pick the first NaN as the "peak" and turn the whole result into NaN.)
    The window runs from peak - duration_peak/2 to peak + duration_peak/2 (end
    exclusive) and is clipped to the bounds of the timeseries.

    Parameters
    ----------
    pup_timeseries : array-like
        Event timeseries (e.g. output of process_event_data).
    duration_peak : float
        Width of the averaging window (s).
    sF : int
        Sampling frequency (Hz).

    Returns
    -------
    float
        Mean around the peak, or NaN when the timeseries is empty or all-NaN
        (e.g. the NaN placeholder of an excluded event).
    """
    values = np.asarray(pup_timeseries, dtype=float)
    if values.size == 0 or np.all(np.isnan(values)):
        return np.nan

    peak_idx = np.nanargmax(values)
    half_window = (duration_peak / 2) * sF
    start_point = max(int(peak_idx - half_window), 0)
    end_point = min(int(peak_idx + half_window), len(values))
    return np.nanmean(values[start_point:end_point])

def apply_median_split(df, pupil):
    """
    Label each row of `df` as 'large' or 'small' relative to the median of column `pupil`.

    Values >= the median are 'large' and values below it 'small'. Missing values get
    NaN: real NaNs and entries that cannot be parsed as numbers (such as the string
    'nan', which pd.notna would otherwise treat as present and which cannot be
    compared with a float). The labels go into a new column '<pupil>_median_split'.
    `df` is modified in place and also returned, so the function can be used with
    groupby(...).apply.
    """
    values = pd.to_numeric(df[pupil], errors='coerce')
    median_size = values.median()
    labels = pd.Series(np.where(values >= median_size, 'large', 'small'), index=df.index, dtype=object)
    df[f'{pupil}_median_split'] = labels.where(values.notna(), np.nan)
    return df


In [None]:
def get_precision_data(sub, save = True):
    
    '''
    Load the behavioural data of one subject and build a trial-level dataframe.

    For each run with both a study (phase1) and a test (phase2) file this:
        - extracts study-phase information (location response and its time stamp,
          category correctness, target-number response, feedback type, ITI)
        - extracts the test-phase location response and corrects it for distance from
          the centre (correct_resp_based_on_distance with minimum_distance_centre)
        - computes the error between target and response (rel_error_degrees)
        - drops the new test items and reorders the old items to study order

    Parameters
    ----------
    sub : str
        Subject id as used in folder/file names (e.g. '003').
    save : bool, default True
        If True, write the trial-level dataframe to
        ../../stats/1_preprocessed/sub-{sub}/sub-{sub}_precision_data_old.csv
        (the folder is created when missing).

    Returns
    -------
    sub_Df : pandas.DataFrame
        Trial-level data of all runs (old items only, in study order).
    ave_Df : pandas.DataFrame
        One row: subject id and the mean proportion of rewarded trials over runs
        (NaN when no run could be loaded).

    Raises
    ------
    ValueError
        If the paired study and test files belong to different runs.
    '''
    
    # get study + test data 
    data = get_behav_paths(sub)
    # report missing study and test runs independently (both can be missing)
    if len(data['study_data']) < 6:
        print(f'missing study runs for subject {sub}')
    if len(data['test_data']) < 6:
        print(f'missing test runs for subject {sub}')

    run_frames = []
    prop_rew = []

    # loop over runs (study and test files are paired in file-name order)
    for study_path, test_path in zip(data['study_data'], data['test_data']):

        # run number = 4th character of the file name ('runN_HRBP_MP_...');
        # os.path.basename uses the platform's separators (split('\\') was Windows-only)
        run = os.path.basename(study_path)[3]
        if run != os.path.basename(test_path)[3]:
            raise ValueError(f'study and test data runs are not aligned for subject: {sub}')

        # read the log files (closing them again) and the matching trial lists
        with open(study_path) as fh:
            study_data = [line.rstrip() for line in fh]
        with open(test_path) as fh:
            test_data = [line.rstrip() for line in fh]
        study_triallist = extract_tiallist('{}/HRBP_MP{}_phase1_run{}.txt'.format(data['triallist_path'], sub, str(run)))
        test_triallist = extract_tiallist('{}/HRBP_MP{}_phase2_run{}.txt'.format(data['triallist_path'], sub, str(run)))

        # get information from study data
        #-----------------------------------------
        picture_study = [line[4] for line in study_triallist]
        ITI_study = [line[6] for line in study_triallist]

        # location response (last field) and its time stamp (3rd field); the time stamp
        # is kept to model the mouse response in the pupil data
        dot_resp_loc = [line.split(",")[-1] for line in study_data if "mouse_response" in line]
        time_stamp_dot_resp = [line.split(",")[2] for line in study_data if "mouse_response" in line]

        # a skipped location response shows up as 'end_trial' twice in a row (or a
        # leading 'end_trial'): pad the response lists with "NaN" at that position
        all_trials = [line.split(",")[1] for line in study_data if "mouse_response" in line or "end_trial" in line]
        check = test_sequentiality(all_trials, 'end_trial', 2)
        if check:
            for ind in check:
                new_index = round(float((ind+1)/2))
                dot_resp_loc.insert(new_index, "NaN")
                time_stamp_dot_resp.insert(new_index, 'NaN')
        if all_trials[0] == 'end_trial':
            dot_resp_loc.insert(0, "NaN")
            time_stamp_dot_resp.insert(0, 'NaN')

        # response time stamps relative to the start of the experiment
        start_exp_ts = int([line.split(",")[2] for line in study_data if "start_experiment" in line][0])
        corr_time_stamp_dot_resp = [int(ts) - start_exp_ts if ts != 'NaN' else 'NaN'
                                    for ts in time_stamp_dot_resp]

        # category correct: location response within 20 degrees of the target dot
        # NOTE(review): add_pupilData takes abs() of minimum_error, which suggests
        # rel_error_degrees can return signed errors; if so, large negative errors pass
        # this <= 20 test. Confirm and use abs() here if needed.
        target_dot = [line[9] for line in study_triallist]
        cat_corr = [0 if resp == "NaN" else int(rel_error_degrees(int(target), int(resp)) <= 20)
                    for target, resp in zip(target_dot, dot_resp_loc)]

        # target-number task: correct when the pressed key matches the trial list
        TN_resp = [line.split(",")[3] for line in study_data if "key_resp_TN" in line]
        TN_key = [int(line[-1]) for line in study_triallist]
        TN_corr = [1 if resp != "NaN" and str(key) == resp else 0 for resp, key in zip(TN_resp, TN_key)]
        TN_rt = [line.split(",")[4] for line in study_data if "key_resp_TN" in line]
        TN_rt = [int(rt) if rt != "NaN" else rt for rt in TN_rt]

        # feedback: feedback_type
        feedback_type = [line.split(",")[1] for line in study_data if "feedback_onset" in line]

        # get information from test data
        #-----------------------------------------
        # the precision question can (rarely) be skipped: pad with "NaN" as above
        response_pos = [line.split(",")[-1] for line in test_data if "mouse_response" in line]
        dis_from_centre = [line.split(",")[5] for line in test_data if "mouse_response" in line]
        all_trials = [line.split(",")[1] for line in test_data if "mouse_response" in line or "end_trial" in line]
        check = test_sequentiality(all_trials, 'end_trial', 2)
        if check:
            for ind in check:
                new_index = round(float((ind+1)/2))
                response_pos.insert(new_index, "NaN")
                dis_from_centre.insert(new_index, "NaN")
        if all_trials[0] == 'end_trial':
            response_pos.insert(0, "NaN")
            dis_from_centre.insert(0, "NaN")
        # correct precision response for distance from centre
        corr_response_pos = correct_resp_based_on_distance(dis_from_centre, response_pos, minimum_distance_centre)

        # error between target and (corrected) response; 'NaN' if either is missing
        target_pos = [line[5] for line in test_triallist]
        minimum_error = []
        for target, resp in zip(target_pos, corr_response_pos):
            if target == 'NaN' or resp == 'NaN':
                minimum_error.append('NaN')
            else:
                minimum_error.append(rel_error_degrees(int(target), int(resp)))

        # Create run dataframe
        run_Df = pd.DataFrame({'sub': len(test_triallist) * [sub],
                               'run': len(test_triallist) * [run],
                               'cat_cond': [int(line[1]) for line in test_triallist],
                               'trial_type': [line[3] for line in test_triallist],
                               'picture': [line[2] for line in test_triallist],
                               'old': [1 if line[4] == 'old' else 0 for line in test_triallist],
                               'response_old': [1 if int(line.split(",")[3]) == 0 else 0 for line in test_data if "old_new_resp" in line],
                               'target_pos': target_pos,
                               'response_pos': corr_response_pos,
                               'minimum_error': minimum_error})

        # get proportion of trials rewarded
        prop_rew_run = [line.split(",")[-1] for line in study_data if "PROP_REWARDED" in line]
        prop_rew.append(float(prop_rew_run[0]))

        # drop new trials (not analysed) and reorder the old ones to study order
        run_Df = run_Df.drop(run_Df[run_Df.old == 0].index)
        run_Df["indices"] = run_Df["picture"].map(picture_study.index)
        run_Df = run_Df.sort_values(by=["indices"]).drop(columns=['indices'])

        # add the columns from the study information (all in study order now);
        # list assignment is positional and raises if the lengths do not match
        run_Df['dot_resp_loc'] = dot_resp_loc
        run_Df['cat_corr'] = cat_corr
        run_Df['TN_corr'] = TN_corr
        run_Df['TN_rt'] = TN_rt
        run_Df['ITI_dur'] = ITI_study
        run_Df['feedback_type'] = feedback_type
        run_Df['trial'] = list(range(1, len(run_Df) + 1))
        run_Df['mouse_resp_timestamp'] = corr_time_stamp_dot_resp

        # previous-trial information; the first trial of a run gets 'csm' / the average ITI
        run_Df['trial_type_tm1'] = run_Df['trial_type'].shift(1)
        run_Df.loc[run_Df.index[0], 'trial_type_tm1'] = 'csm'
        run_Df['ITI_dur_tm1'] = run_Df['ITI_dur'].shift(1)
        run_Df.loc[run_Df.index[0], 'ITI_dur_tm1'] = 4500

        run_frames.append(run_Df)

    # combine runs once (DataFrame.append was removed in pandas 2.0)
    sub_Df = pd.concat(run_frames, ignore_index=True) if run_frames else pd.DataFrame()

    # save once, after all runs are collected
    if save and run_frames:
        save_fn = f'../../stats/1_preprocessed/sub-{sub}/sub-{sub}_precision_data_old.csv'
        os.makedirs(os.path.dirname(save_fn), exist_ok=True)
        sub_Df.to_csv(save_fn, index=False)

    # log the proportion of rewarded trials
    prop = round(float(sum(prop_rew) / len(prop_rew)), 2) if prop_rew else np.nan
    ave_Df = pd.DataFrame({'sub': sub,
                           'prop_rewarded': [prop]})

    return sub_Df, ave_Df

In [None]:
def _percent_invalid(raw_segment):
    """Percentage of samples in a raw pupil segment flagged invalid (coded -1); 100 for an empty segment."""
    raw_segment = np.asarray(raw_segment)
    if raw_segment.size == 0:
        return 100.0
    return (np.sum(raw_segment == -1) / raw_segment.size) * 100


def _proportion_flagged(flags):
    """Mean of a list of 0/1 flags rounded to 2 decimals, or NaN for an empty list."""
    if not flags:
        return np.nan
    return round(float(sum(flags) / len(flags)), 2)


def _marker_onset_samples(markers, marker_name, t0, sF):
    """Sample indices of `marker_name` events relative to t0 (time stamps converted with /1000 * sF, i.e. assumed ms)."""
    stamps = markers.query(f"Marker == '{marker_name}'")['TimeStamp']
    return round((stamps - t0) / 1000 * sF).astype(int)


def add_pupilData(sub, overwrite = False, foreshortening_correction = True):
    
    '''
    Add pupil measures to the behavioural data of one subject.

    First runs get_precision_data(sub). Then, for every run that has pupil data
    (runs listed in pup_missing_runs keep NaNs), it:
        1. loads the smoothed pupil trace (foreshortening/gaze corrected if
           `foreshortening_correction`) and the raw trace (invalid samples = -1).
           A 4th-order 10 Hz low-pass Butterworth filter is computed too, but see the
           NOTE(review) below: the filtered trace is not used.
        2. flags events/baselines with > set_invalid_threshold % invalid samples
        3. removes spurious samples within events (remove_spur_samples)
        4. cuts the trace into events (baseline corrected where needed):
             - SP_timeseries / FB_timeseries: stimulus / feedback epochs (absolute
               baseline correction) and the mean around their peak
             - resp_onset_event: epoch after the location response
             - anticipation: mean of the last duration_max_anticipation s before
               target-number onset minus the mean of the duration_max_anticipation s
               starting stim_dur s after the location response;
               start_/end_anticipation: those two means minus the pre-stimulus baseline
             - choice_onset_event: mean of other_event_dur s starting stim_dur s after
               stimulus onset, minus the pre-stimulus baseline
    Trials without a location response keep NaN for the response-locked measures.

    Parameters
    ----------
    sub : str
        Subject id (e.g. '003').
    overwrite : bool
        Currently unused; kept for backward compatibility.
    foreshortening_correction : bool
        Use the foreshortening (gaze) corrected smoothed trace.

    Returns
    -------
    behav_data : pandas.DataFrame
        Trial-level behavioural + pupil data sorted by run and trial, with a total
        trial counter, per-run median splits, the absolute error and a subject
        exclusion flag (mean absolute error > 75 degrees).
    ave_Df : pandas.DataFrame
        Subject averages including the proportion of invalid baseline / choice /
        anticipation events (NaN when no pupil data could be analysed).
    '''

    behav_data, ave_Df = get_precision_data(sub)

    # new columns for the pupil data (timeseries columns hold one list/array per trial)
    n_trials = behav_data.shape[0]
    behav_data["SP_timeseries"] = [[] for _ in range(n_trials)]
    behav_data["FB_timeseries"] = [[] for _ in range(n_trials)]
    for name in ("SP_peak_response", "FB_peak_response", "preStim_bl",
                 "anticipation",         # end - start of the anticipation period
                 "start_anticipation",   # start window of the anticipation period (pre-stim baseline corrected)
                 "end_anticipation",     # end window of the anticipation period (pre-stim baseline corrected)
                 "choice_onset_event"):  # window at choice onset (pre-stim baseline corrected)
        behav_data[name] = np.nan
    behav_data["resp_onset_event"] = [[] for _ in range(n_trials)]
    loc_of = behav_data.columns.get_loc

    invalid_BL = []
    invalid_choice = []
    invalid_anticipation = []

    bl_samples = int(baseline_dur * sF)
    ant_samples = int(duration_max_anticipation * sF)
    choice_samples = int(other_event_dur * sF)
    stim_samples = stim_dur * sF  # samples the stimulus stays in its location

    # only runs present in the behavioural data are analysed
    for run in sorted(set(behav_data['run'])):

        if sub in pup_missing_runs and run in pup_missing_runs[sub]:
            print(f'missing pupil data for run {run}, sub:{sub}, keeping NaNs in df!')
            continue
        print(f'going on with run: {run}')

        all_data = get_pup_paths(sub, run)
        if foreshortening_correction:   # gaze-corrected and blink interpolated/smoothed
            smth_pup = np.loadtxt(all_data['FS_corr_pup_smthdata'][0])
        else:                           # only blink interpolated/smoothed
            smth_pup = np.loadtxt(all_data['pup_smthdata'][0])
        raw_pup = np.loadtxt(all_data['pup_rawdata'][0])
        markers = pd.read_csv(all_data['markerfile'][0], delimiter="\t")
        tsvdat = pd.read_csv(all_data['tsvdat_file'][0], delimiter="\t")

        # 1. low-pass filter the run (4th-order 10 Hz Butterworth, forward-backward)
        # NOTE(review): smth_pup_filt is not used by any measure below (all events are
        # cut from smth_pup) — confirm whether the filtered trace should be used instead.
        bet, ab = signal.butter(4, 10, 'lowpass', fs=sF)
        smth_pup_filt = signal.filtfilt(bet, ab, smth_pup, padlen=0)

        # onset samples of each event, relative to the first time stamp of the data file
        t0 = tsvdat['TimeStamp'][0]
        stim_onset = _marker_onset_samples(markers, 'stim_onset', t0, sF)
        resp_onset = list(_marker_onset_samples(markers, 'mouse_response', t0, sF))
        tn_onset = _marker_onset_samples(markers, 'target_num_onset', t0, sF)
        fb_onset = _marker_onset_samples(markers, 'feedback_onset', t0, sF)

        # align responses to trials: a trial without a location response between its
        # stimulus and target-number onset gets None
        for i, (a, b) in enumerate(zip(stim_onset, tn_onset)):
            if len(resp_onset) <= i:
                resp_onset.append(None)
            elif not a < resp_onset[i] < b:
                resp_onset.insert(i, None)

        for trial, (st_on, resp_on, tn_on, fb_on) in enumerate(zip(stim_onset, resp_onset, tn_onset, fb_onset), start=1):
            st_on, tn_on, fb_on = int(st_on), int(tn_on), int(fb_on)
            row = int(behav_data.index[(behav_data['run'] == run) & (behav_data['trial'] == trial)][0])

            # pre-stimulus baseline (baseline_dur s before onset)
            BL_stim_prop_invalid = _percent_invalid(raw_pup[st_on - bl_samples:st_on])
            mean_preStim_baseline = mean(remove_spur_samples(smth_pup[st_on - bl_samples:st_on]))
            invalid_BL.append(1 if BL_stim_prop_invalid > set_invalid_threshold else 0)

            # pre-feedback baseline (baseline_dur s before onset)
            BL_fb_prop_invalid = _percent_invalid(raw_pup[fb_on - bl_samples:fb_on])
            mean_preFb_baseline = mean(remove_spur_samples(smth_pup[fb_on - bl_samples:fb_on]))

            # raw pre-stimulus baseline per trial (to look at pupil drift)
            behav_data.iat[row, loc_of('preStim_bl')] = mean_preStim_baseline

            # stimulus- and feedback-locked epochs (absolute baseline correction) + peak
            for name, peak_name, onset, win_end, baseline, bl_invalid in (
                    ("SP_timeseries", "SP_peak_response", st_on, stim_win_end, mean_preStim_baseline, BL_stim_prop_invalid),
                    ("FB_timeseries", "FB_peak_response", fb_on, fb_win_end, mean_preFb_baseline, BL_fb_prop_invalid)):
                event = {"onset": onset, "win_str": baseline_dur, "win_end": win_end,
                         "baseline": baseline, "BL_prop_invalid": bl_invalid}
                output = process_event_data(raw_pup, smth_pup, sF, event, set_invalid_threshold, baseline_method='absolute')
                behav_data.iat[row, loc_of(name)] = output
                behav_data.iat[row, loc_of(peak_name)] = get_peak_pup_resp(output, duration_peak, sF)

            # response-locked epoch (NaN placeholder when there was no location response)
            if resp_on is None:
                resp_output = make_empty_array(0, resp_onset_check, sF)
            else:
                event = {"onset": resp_on, "win_str": 0, "win_end": resp_onset_check,
                         "baseline": mean_preStim_baseline, "BL_prop_invalid": BL_stim_prop_invalid}
                resp_output = process_event_data(raw_pup, smth_pup, sF, event, set_invalid_threshold, baseline_method='absolute')
            behav_data.iat[row, loc_of('resp_onset_event')] = resp_output

            if resp_on is None:
                continue  # anticipation and choice measures stay NaN without a location response

            # anticipation: start window begins stim_dur s after the response, end window
            # is the last duration_max_anticipation s before target-number onset
            ant_start = int(resp_on) + stim_samples
            start_win = slice(ant_start, ant_start + ant_samples)
            end_win = slice(tn_on - ant_samples, tn_on)
            prop_invalid_min = _percent_invalid(raw_pup[start_win])
            prop_invalid_max = _percent_invalid(raw_pup[end_win])  # checked on its own window
            if (prop_invalid_min < set_invalid_threshold) and (prop_invalid_max < set_invalid_threshold):
                min_ant = np.mean(smth_pup[start_win])
                max_ant = np.mean(smth_pup[end_win])
                behav_data.iat[row, loc_of('anticipation')] = max_ant - min_ant  # pupil change index
                behav_data.iat[row, loc_of('start_anticipation')] = min_ant - mean_preStim_baseline
                behav_data.iat[row, loc_of('end_anticipation')] = max_ant - mean_preStim_baseline
                invalid_anticipation.append(0)
            else:
                invalid_anticipation.append(1)  # too many invalid samples: columns stay NaN

            # choice onset: other_event_dur s starting stim_dur s after stimulus onset
            start = st_on + stim_samples
            end = start + choice_samples
            choice_prop_invalid = _percent_invalid(raw_pup[start:end])
            invalid_choice.append(1 if choice_prop_invalid > set_invalid_threshold else 0)
            if (choice_prop_invalid < set_invalid_threshold) and (BL_stim_prop_invalid < set_invalid_threshold):
                behav_data.iat[row, loc_of('choice_onset_event')] = np.mean(smth_pup[start:end]) - mean_preStim_baseline

    # proportions of invalid events (NaN if no pupil data was analysed)
    ave_Df['choice_event_invalid'] = [_proportion_flagged(invalid_choice)]
    ave_Df['BL_event_invalid'] = [_proportion_flagged(invalid_BL)]
    ave_Df['anticipation_invalid'] = [_proportion_flagged(invalid_anticipation)]

    # sort by run and then trial, so the total trial counter follows the trial order
    behav_data = behav_data.sort_values(by=['run', 'trial'])
    behav_data['trial_tot'] = list(range(1, len(behav_data) + 1))

    # per-run median splits; group_keys=False keeps the original (flat) index
    for pupil in ('anticipation', 'preStim_bl', 'choice_onset_event'):
        behav_data = behav_data.groupby('run', group_keys=False).apply(apply_median_split, pupil=pupil)

    # subject exclusion: chance-level precision memory (mean absolute error > 75 degrees)
    behav_data['abs_minimum_error'] = pd.to_numeric(behav_data['minimum_error'], errors='coerce').abs()
    mean_abs = behav_data['abs_minimum_error'].mean(skipna=True)
    behav_data['subj_excl'] = 0 if mean_abs <= 75 else 1

    return behav_data, ave_Df


In [None]:
# Drop the subjects flagged by the cut-off check. The earlier note listed
# '018', '003' and '034'; only '034' is currently excluded.
# NOTE(review): finTrial_dat is not created in this part of the notebook — make sure
# the cell that builds it runs before this one.
exclude_cutoff = ['034']
finTrial_dat = finTrial_dat.loc[~finTrial_dat['sub'].isin(exclude_cutoff)]
len(set(finTrial_dat['sub']))  # number of subjects left

## Plotting functions

In [None]:
def plot_pupil_drift(sub):
    """Plot the pre-stimulus baseline pupil ('preStim_bl') per trial, one panel per run.

    Parameters
    ----------
    sub : str
        Subject identifier, as accepted by ``add_pupilData``.
    """
    # add_pupilData returns (trial-level frame, subject-average frame); only the first is plotted
    dat, _ = add_pupilData(sub)
    dat = dat.sort_values(['run', 'trial'], ascending=[True, True])

    # sorted so panels appear in run order (iterating a set gives an arbitrary order)
    runs = sorted(dat['run'].unique())
    # squeeze=False keeps `axes` 2-D, so indexing also works when there is a single run
    fig, axes = plt.subplots(figsize=(30, 5), nrows=1, ncols=len(runs), squeeze=False)
    fig.suptitle(f"sub-{sub}", fontsize=16)
    for i, run_nr in enumerate(runs):
        run_data = dat.loc[dat['run'] == run_nr]
        sns.barplot(data=run_data, x="trial", y="preStim_bl", ax=axes[0, i]).set(title=f'baseline pup run {run_nr}')

def plot_reward_ts(data, event, baseline_dur, SEM=True, sF=40, ax = None,ylabel=None):
    """Plot group-average pupil timeseries for reward ('csp') vs neutral ('csm') trials.

    Trials are kept only when their timeseries has the expected number of samples and
    ``cat_corr`` is not 0. Timeseries are averaged within subject and condition, then
    across subjects. Samples where the conditions differ (point-wise paired permutation
    test, FDR-corrected at alpha = 0.05) are marked by a blue line drawn from the first
    to the last significant sample.

    Parameters
    ----------
    data : pandas.DataFrame
        Trial-level data with columns 'sub', 'trial_type', 'cat_corr' and `event`.
    event : {'SP_timeseries', 'FB_timeseries', 'resp_onset_event'}
        Column holding one timeseries per trial.
    baseline_dur : float
        Pre-onset window (s) at the start of the SP/FB timeseries.
    SEM : bool, default True
        Shade mean +/- 1 SEM for each condition.
    sF : int, default 40
        Sample frequency (Hz).
    ax : matplotlib Axes, optional
        Axis to draw on; defaults to the current axis.
    ylabel : bool, optional
        If truthy, label the y axis "Pupil size [mm]".

    Raises
    ------
    ValueError
        If `event` is not one of the supported timeseries columns.
    """
    if ax is None:
        ax = plt.gca()  # Get the current axis if ax is not provided

    # event -> (pre-onset window [s], post-onset window [s])
    windows = {
        'SP_timeseries': (baseline_dur, 6.0),
        'FB_timeseries': (baseline_dur, 3.0),
        'resp_onset_event': (0, 6.0),
    }
    if event not in windows:
        raise ValueError(f"unsupported event {event!r}; expected one of {sorted(windows)}")
    win_start, win_end = windows[event]
    # round() so float error in e.g. 40 * (6.0 + 0.2) cannot change the expected length
    corr_len = int(round(sF * (win_end + win_start)))
    # sample index; identical to the implicit x that ax.plot(y) uses for the mean lines
    x = np.arange(corr_len)

    # Keep trials whose timeseries has the expected length. np.size() is 1 for a scalar NaN,
    # so trials without a timeseries are dropped instead of raising a TypeError in len().
    dat = data[data[event].apply(np.size) == corr_len]

    # exclude trials which were wrongly classified at the start of the trial
    dat = dat[dat['cat_corr'] != 0]

    # collapse across trials within subject and condition
    grouped = dat.groupby(['sub', 'trial_type'])[event].apply(lambda ts: np.nanmean(ts.tolist(), axis=0))
    df = grouped.reset_index()
    df.columns = ['sub', 'trial_type', event]

    # one subject-average timeseries per row, indexed by subject
    rew = df[df['trial_type'] == 'csp'].set_index('sub')[event]
    neu = df[df['trial_type'] == 'csm'].set_index('sub')[event]
    mean_rew = np.nanmean(rew.to_list(), axis=0)
    mean_neu = np.nanmean(neu.to_list(), axis=0)

    # Calculate the derivative or rate of change
    grad_neu = np.gradient(mean_neu, range(0, len(mean_neu)))
    # Find the maximum rate of change (steepness) rising to the peak
    peak_steepness_neu = np.max(grad_neu)
    print("Neutral: Steepness of slope rising to the peak:", peak_steepness_neu)

    rew_SEM = scipy.stats.sem(rew.to_list(), axis=0, nan_policy='omit')
    rew_SEM1 = mean_rew - rew_SEM
    rew_SEM2 = mean_rew + rew_SEM

    neu_SEM = scipy.stats.sem(neu.to_list(), axis=0, nan_policy='omit')
    neu_SEM1 = mean_neu - neu_SEM
    neu_SEM2 = mean_neu + neu_SEM

    # Statistical test: point-wise paired permutation test (netneurotools.stats.permtest_rel),
    # restricted to subjects with both conditions and aligned by subject, then FDR-corrected
    paired_subs = rew.index.intersection(neu.index)
    stat, p_val_uncor = st.permtest_rel(rew.loc[paired_subs].to_list(), neu.loc[paired_subs].to_list(), axis=0)
    p_val_cor = sm.stats.fdrcorrection(p_val_uncor, alpha=0.05)
    sig_point = np.where(p_val_cor[0])

    # plot data
    ax.plot(mean_rew, label='reward', color='green')
    ax.plot(mean_neu, label='neutral', color='grey')
    if SEM:
        # same x as the mean lines, so the bands are not shifted by one sample
        ax.fill_between(x, rew_SEM1, rew_SEM2, alpha=0.4, color='green')
        ax.fill_between(x, neu_SEM1, neu_SEM2, alpha=0.4, color='grey')

    # one tick per second, with 0 s at event onset (i.e. after the pre-onset window)
    tick_pos = np.arange(sF * win_start, corr_len, sF)
    ax.set_xticks(tick_pos.tolist())
    ax.set_xticklabels([str(sec) for sec in range(len(tick_pos))])
    ax.legend().remove()
    ax.set_xlabel("Time [s]")
    if ylabel:
        ax.set_ylabel("Pupil size [mm]")
    else:
        ax.set_ylabel("")

    # nanmax/nanmin: sem(..., nan_policy='omit') may yield NaN/masked entries
    y_top = np.nanmax(rew_SEM2) + 0.05
    if event == 'SP_timeseries':
        # dashed line over the 3 s stimulus window after onset; faint markers at +3.0 s and +3.2 s
        ax.hlines(y=y_top, xmin=sF * win_start, xmax=sF * (win_start + 3.0), linestyles='dashed')
        y_low = np.nanmin(neu_SEM2) - 0.05
        for t_marker in (3.0, 3.2):
            ax.vlines(ymax=y_top, ymin=y_low, x=sF * (win_start + t_marker), colors='blue', linewidth=1, alpha=0.2)
    elif event == 'FB_timeseries':
        ax.hlines(y=y_top, xmin=sF * win_start, xmax=sF * (win_start + 0.5), linestyles='dashed')
    elif event == 'resp_onset_event':
        ax.hlines(y=y_top, xmin=sF * (win_start + 0.5), xmax=sF * (win_start + 3.0), linestyles='dashed')
    if len(sig_point[0]) > 0:
        ax.hlines(y=y_top, xmin=sig_point[0][0], xmax=sig_point[0][-1], color='blue')  # sig line

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # save the plot
    #plt.savefig(f'../../stats/figures/pupil/timeseries_plots/{event}_split_reward.pdf')
    #plt.savefig(f'../../stats/figures/pupil/timeseries_plots/{event}_split_reward.png', dpi=300)


def plot_fb_ts(data, event, baseline_dur, sF=40,ax = None, ylabel=None):
    """Plot group-average pupil timeseries split by feedback type (positive/negative/neutral).

    Trials are kept only when their timeseries has the expected number of samples and
    ``cat_corr`` is not 0. Timeseries are averaged within subject and feedback type, then
    across subjects, and drawn with +/- 1 SEM bands.

    Parameters
    ----------
    data : pandas.DataFrame
        Trial-level data with columns 'sub', 'feedback_type', 'cat_corr' and `event`.
    event : {'SP_timeseries', 'FB_timeseries'}
        Column holding one timeseries per trial.
    baseline_dur : float
        Pre-onset window (s) at the start of each timeseries.
    sF : int, default 40
        Sample frequency (Hz).
    ax : matplotlib Axes, optional
        Axis to draw on; defaults to the current axis.
    ylabel : bool, optional
        If truthy, label the y axis "Pupil size [mm]".

    Raises
    ------
    ValueError
        If `event` is not one of the supported timeseries columns.
    """
    if ax is None:
        ax = plt.gca()  # Get the current axis if ax is not provided

    # event -> (pre-onset window [s], post-onset window [s])
    windows = {
        'SP_timeseries': (baseline_dur, 6.0),
        'FB_timeseries': (baseline_dur, 3.0),
    }
    if event not in windows:
        raise ValueError(f"unsupported event {event!r}; expected one of {sorted(windows)}")
    win_start, win_end = windows[event]
    # round() so float error in e.g. 40 * (3.0 + 0.2) cannot change the expected length
    corr_len = int(round(sF * (win_end + win_start)))
    # sample index; identical to the implicit x that ax.plot(y) uses for the mean lines
    x = np.arange(corr_len)

    # Keep trials whose timeseries has the expected length; np.size() is 1 for a scalar NaN,
    # so trials without a timeseries are dropped instead of raising a TypeError in len().
    dat = data[data[event].apply(np.size) == corr_len]
    # exclude trials which were wrongly classified at the start of the trial
    dat = dat[dat['cat_corr'] != 0]

    # collapse across trials within subject and feedback type
    grouped = dat.groupby(['sub', 'feedback_type'])[event].apply(lambda ts: np.nanmean(ts.tolist(), axis=0))
    df = grouped.reset_index()
    df.columns = ['sub', 'feedback_type', event]

    # (feedback_type value, line label, colour), drawn in this order
    conditions = [
        ('pos_feedback_onset', 'pos', 'green'),
        ('neg_feedback_onset', 'neg', 'orange'),
        ('neu_feedback_onset', 'neu', 'grey'),
    ]
    upper_band = {}
    for fb_type, label, colour in conditions:
        subject_ts = df[df['feedback_type'] == fb_type][event].to_list()
        mean_ts = np.nanmean(subject_ts, axis=0)
        sem_ts = scipy.stats.sem(subject_ts, axis=0, nan_policy='omit')
        ax.plot(mean_ts, label=label, color=colour)
        # same x as the mean lines, so the bands are not shifted by one sample
        ax.fill_between(x, mean_ts - sem_ts, mean_ts + sem_ts, alpha=0.4, color=colour)
        upper_band[fb_type] = mean_ts + sem_ts

    # one tick per second, with 0 s at event onset (i.e. after the pre-onset window);
    # labels are derived from the ticks so their counts always match
    tick_pos = np.arange(sF * win_start, corr_len, sF)
    ax.set_xticks(tick_pos.tolist())
    ax.set_xticklabels([str(sec) for sec in range(len(tick_pos))])
    ax.legend().remove()
    ax.set_xlabel("Time [s]")
    if ylabel:
        ax.set_ylabel("Pupil size [mm]")
    else:
        ax.set_ylabel("")

    # dashed reference windows, positioned relative to onset rather than at fixed sample indices
    y_top = np.nanmax(upper_band['neg_feedback_onset']) + 0.05
    if event == 'SP_timeseries':
        ax.hlines(y=y_top, xmin=sF * win_start, xmax=sF * (win_start + 3.0), linestyles='dashed')
    elif event == 'FB_timeseries':
        ax.hlines(y=y_top, xmin=sF * win_start, xmax=sF * (win_start + 0.5), linestyles='dashed')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # save the plot
    #plt.savefig(f'../../stats/figures/pupil/timeseries_plots/{event}_split_feedback.pdf')
    #plt.savefig(f'../../stats/figures/pupil/timeseries_plots/{event}_split_feedback.png', dpi=300)
    
def convert_pvalue_to_asterisks(pvalue):
    """Translate a p-value into a significance-star label.

    Cut-offs are inclusive and checked from strictest to loosest:
    <= 0.0001 -> '****', <= 0.001 -> '***', <= 0.01 -> '**', <= 0.05 -> '*'.
    Anything larger (including NaN, which fails every comparison) gives 'ns'.
    """
    star_table = (
        (0.0001, "****"),
        (0.001, "***"),
        (0.01, "**"),
        (0.05, "*"),
    )
    for cutoff, stars in star_table:
        if pvalue <= cutoff:
            return stars
    return "ns"

def plot_ave_pup(data, event, ax = None, ylabel=None, show_title=False):
    """Bar + swarm plot of subject-average pupil measures for reward vs neutral trials.

    Trials with ``cat_corr`` == 0 are dropped, values are averaged per subject and
    condition, condition means/SEMs are printed, and a paired t-test (reward vs neutral,
    paired by subject) is printed and drawn as significance stars between the bars.

    Parameters
    ----------
    data : pandas.DataFrame
        Trial-level data with columns 'sub', 'trial_type' ('csp'/'csm'), 'cat_corr' and `event`.
    event : str
        Column with one pupil value per trial.
    ax : matplotlib Axes, optional
        Axis to draw on; defaults to the current axis.
    ylabel : bool, optional
        If truthy, set a y-axis label appropriate for `event`.
    show_title : bool, default False
        If True, set a descriptive axis title (off by default, matching the figure layout).
    """
    if ax is None:
        ax = plt.gca()  # Get the current axis if ax is not provided
    titles = {
        'SP_peak_response': 'stimulus presentation',
        'FB_peak_response': 'feedback presentation',
        'anticipation': 'reward anticipation',
        'start_anticipation': 'early reward anticipation',
        'end_anticipation': 'late reward anticipation',
        'preStim_bl': 'pre-stimulus baseline (0.5s)',
        'choice_onset_event': 'pupil at choice onset (0.5s)',
        'resp_onset_event': 'pupil at mouse response (0.5s)',
    }
    title = titles.get(event, event)  # fall back to the column name for other events
    if event in ('SP_peak_response', 'FB_peak_response'):
        plt_title = f'peak pupil response: {title}'
    else:
        plt_title = title

    # exclude trials which were wrongly classified at the start of the trial
    data = data[data['cat_corr'] != 0]

    grouped = data.groupby(['sub', 'trial_type'])[event].apply(lambda vals: np.nanmean(vals.tolist(), axis=0))
    df = grouped.reset_index()
    df.columns = ['sub', 'trial_type', event]
    df['trial_type'] = df['trial_type'].replace({'csp': 'reward', 'csm': 'neutral'})

    # Group by 'trial_type' and calculate mean and standard error
    means = df.groupby('trial_type')[event].mean()
    standard_errors = df.groupby('trial_type')[event].apply(sem)
    print('Mean:')
    print(means)
    print('\nStandard Error:')
    print(standard_errors)

    # paired t-test: align reward/neutral by subject and keep subjects with both values
    paired = df.pivot(index='sub', columns='trial_type', values=event).dropna(subset=['reward', 'neutral'])
    stat, pvalue = scipy.stats.ttest_rel(paired['reward'], paired['neutral'])
    sig_lev = convert_pvalue_to_asterisks(pvalue)
    print(f'p = {pvalue:.3g}')

    # explicit order so the palette always maps neutral -> grey and reward -> green
    order = ['neutral', 'reward']
    sns.set(rc={'figure.figsize':(2,4)})
    sns.set_style('white')
    sns.barplot(data=df, x="trial_type", y=event, order=order, capsize=.1, edgecolor=".1", palette=['grey', 'green'], alpha=.8, ax=ax)
    sns.swarmplot(data=df, x="trial_type", y=event, order=order, color="0", alpha=.35, ax=ax)

    # stars centred between the two bars, at the highest subject mean (NaN-safe)
    ax.text(0.5, df[event].max(), sig_lev, horizontalalignment='center', verticalalignment='top', fontsize=13)

    if ylabel:
        ax.set_ylabel("Change in pupil size [mm]" if event == 'anticipation' else "Pupil size [mm]")
    else:
        ax.set_ylabel("")
    ax.set_xlabel("")
    if show_title:
        ax.set_title(plt_title)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # save the plot
    #plt.savefig(f'../../stats/figures/pupil/average_plots/{event}_split_reward.pdf')
    #plt.savefig(f'../../stats/figures/pupil/average_plots/{event}_split_reward.png', bbox_inches='tight', dpi = 300)
    
    
def plot_trial_pup(data, event):
    """Strip plot of trial-level values of `event` per condition, coloured by subject.

    Parameters
    ----------
    data : pandas.DataFrame
        Trial-level data with columns 'sub', 'trial_type' ('csp'/'csm') and `event`.
        It is not modified.
    event : str
        Column with one pupil value per trial.
    """
    titles = {
        'SP_peak_response': 'stimulus presentation',
        'FB_peak_response': 'feedback presentation',
        'anticipation': 'reward anticipation',
        'preStim_bl': 'pre-stimulus baseline',
    }
    title = titles.get(event, event)  # fall back to the column name for other events

    # relabel on a copy so the caller's 'csp'/'csm' labels (used by other functions) survive
    plot_data = data.copy()
    plot_data['trial_type'] = plot_data['trial_type'].replace({'csp': 'reward', 'csm': 'neutral'})

    sns.stripplot(data=plot_data, x="trial_type", y=event, alpha=.35, size=2, hue='sub').set(title=f'trial-level peak pupil response: {title}')

    plt.ylabel("Pupil size [% change]")
    plt.xlabel("")
    plt.xlabel("")

def plot_trial_level_regression_conditions(data, event_pup, event_behav):
    """Trial-level linear regression of a pupil measure on a behavioural measure, per condition.

    Draws two side-by-side panels (reward 'csp', neutral 'csm'), each with a scatter of
    the trials and the least-squares line; the legend shows the fitted equation and r.

    Parameters
    ----------
    data : pandas.DataFrame
        Trial-level data with columns 'trial_type', `event_pup` and `event_behav`.
        Missing values may be real NaN or the string 'NaN'; both are dropped.
    event_pup : str
        Pupil column (y axis).
    event_behav : str
        Behavioural column (x axis); 'minimum_error' is plotted as its absolute value.
    """
    title_y = 'Pupil size [z-score]' if event_pup == 'preStim_bl_zscore' else 'Pupil size [% change]'
    behav_titles = {'minimum_error': 'Minimum Error (abs)', 'TN_rt': 'Reaction time (ms)'}
    title_x = behav_titles.get(event_behav, event_behav)

    # Create a figure with two subplots side by side
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    fig.suptitle(f"{event_pup}: trial-level data", fontsize=16)

    # Drop trials missing either measure (real NaN or the string 'NaN'); done once, not per condition.
    # NOTE: unlike the group-average plots, trials with cat_corr == 0 are not excluded here.
    trial_dat_full = data.dropna(subset=[event_pup, event_behav])
    trial_dat_full = trial_dat_full[(trial_dat_full[event_behav] != 'NaN') & (trial_dat_full[event_pup] != 'NaN')]

    reward_conditions = ['csp', 'csm']
    colours = ['green', 'grey']
    for ax, condition, colour in zip((ax1, ax2), reward_conditions, colours):
        trial_dat_condition = trial_dat_full[trial_dat_full['trial_type'] == condition]

        x_condition = trial_dat_condition[event_behav].astype(float)
        if event_behav == 'minimum_error':
            x_condition = x_condition.abs()
        y_condition = trial_dat_condition[event_pup].astype(float)

        slope, intercept, r_value, p_value, std_err = linregress(x_condition, y_condition)
        reg_line = slope * x_condition + intercept

        ax.scatter(x_condition, y_condition, s=3, color=colour)
        ax.plot(x_condition, reg_line, color=colour, linewidth=2,
                label=f'y = {slope:.2f}x + {intercept:.2f}\n r = {r_value:.2f}')
        ax.set_xlabel(title_x)
        ax.set_ylabel(title_y)
        ax.set_title(condition.capitalize())
        ax.legend()

    plt.tight_layout()
    plt.show()


def plot_trial_level_regression_collapsed(data, event_pup, event_behav):
    """Trial-level linear regression of a pupil measure on a behavioural measure, all trials pooled.

    Parameters
    ----------
    data : pandas.DataFrame
        Trial-level data with columns `event_pup` and `event_behav`; missing values may be
        real NaN or the string 'NaN' and are dropped.
    event_pup : str
        Pupil column (y axis).
    event_behav : str
        Behavioural column (x axis); 'minimum_error' is plotted as its absolute value.
    """
    title_y = 'Pupil size [z-score]' if event_pup == 'preStim_bl_zscore' else 'Pupil size [% change]'
    behav_titles = {'minimum_error': 'Minimum Error (abs)', 'TN_rt': 'Reaction time (ms)'}
    title_x = behav_titles.get(event_behav, event_behav)

    # drop trials missing either measure (real NaN or the string 'NaN')
    trial_dat_full = data.dropna(subset=[event_pup, event_behav])
    trial_dat_full = trial_dat_full[(trial_dat_full[event_behav] != 'NaN') & (trial_dat_full[event_pup] != 'NaN')]

    x_condition = trial_dat_full[event_behav].astype(float)
    if event_behav == 'minimum_error':
        x_condition = x_condition.abs()
    y_condition = trial_dat_full[event_pup].astype(float)

    slope, intercept, r_value, p_value, std_err = linregress(x_condition, y_condition)
    reg_line = slope * x_condition + intercept

    plt.scatter(x_condition, y_condition, s=3)
    plt.plot(x_condition, reg_line, linewidth=2,
             label=f'y = {slope:.2f}x + {intercept:.2f}\n r = {r_value:.2f}')
    plt.xlabel(title_x)
    plt.ylabel(title_y)
    plt.legend()
    plt.show()

    
def subject_level_regression_conditions(data, event):
    """Subject-level regression of a pupil measure on absolute minimum error, per condition.

    Values are averaged per subject and condition ('csp' = reward, 'csm' = neutral); for
    each condition the regression p-value is printed and the scatter plus least-squares
    line is drawn on the current axes.

    Parameters
    ----------
    data : pandas.DataFrame
        Trial-level data with columns 'sub', 'trial_type', 'minimum_error' and `event`.
        It is not modified.
    event : str
        Pupil column (y axis).
    """
    title_plot = 'Pupil size [z-score]' if event == 'preStim_bl_zscore' else 'Pupil size [% change]'

    # remove rows whose values are the string 'NaN'; copy so the new column is not
    # written into a view of the caller's frame
    data = data[(data['minimum_error'] != 'NaN') & (data[event] != 'NaN')].copy()
    data['minimum_error_abs'] = data['minimum_error'].astype(float).abs()

    # list (not tuple) column selection: tuple selection on groupby was removed in pandas 2.0
    averaged = data.groupby(['sub', 'trial_type'])[[event, 'minimum_error_abs']].mean().reset_index()
    # subjects whose average is NaN would make linregress return NaN statistics
    averaged = averaged.dropna(subset=[event, 'minimum_error_abs'])

    for condition, label, colour in (('csp', 'reward', 'green'), ('csm', 'neutral', 'grey')):
        subset = averaged[averaged['trial_type'] == condition]
        slope, intercept, r_value, p_value, std_err = linregress(subset['minimum_error_abs'], subset[event])
        reg_line = slope * subset['minimum_error_abs'] + intercept
        print(f'p-val for {label} regression: {p_value}')
        plt.scatter(subset['minimum_error_abs'], subset[event], color=colour, s=20)
        plt.plot(subset['minimum_error_abs'], reg_line, color=colour, linewidth=2,
                 label=f'y = {slope:.2f}x + {intercept:.2f}\n r = {r_value:.2f}')

    # Add axis labels and legend
    plt.xlabel('Minimum Error (abs)')
    plt.ylabel(title_plot)
    plt.legend()

    plt.title(f"{event}: participant-level average")
    
def plot_start_end_ant(data, ax = None, ylabel=True):
    """Per-subject early vs late anticipatory pupil size, one grey line per subject.

    Trials with ``cat_corr`` == 0 are dropped and 'start_anticipation' /
    'end_anticipation' are averaged per subject before plotting.

    Parameters
    ----------
    data : pandas.DataFrame
        Trial-level data with columns 'sub', 'cat_corr', 'start_anticipation', 'end_anticipation'.
    ax : matplotlib Axes, optional
        Axis to draw on; defaults to the current axis.
    ylabel : bool, default True
        If truthy, label the y axis "Pupil size [mm]".
    """
    # exclude trials which were wrongly classified at the start of the trial
    data = data[data['cat_corr'] != 0]

    # list (not tuple) column selection: tuple selection on groupby was removed in pandas 2.0
    df = data.groupby('sub')[['start_anticipation', 'end_anticipation']].mean().reset_index()
    df = pd.melt(df, id_vars='sub', value_vars=['start_anticipation', 'end_anticipation'])
    df.columns = ['sub', 'event', 'pupil_size']
    palette = sns.color_palette(['grey'], len(df['sub'].unique()))

    sns.set(rc={'figure.figsize':(2,4)})
    sns.set_style('white')
    # resolved after sns.set so a newly created figure picks up the figsize set above
    if ax is None:
        ax = plt.gca()
    sns.swarmplot(data=df, x="event", y='pupil_size', color="0", alpha=.35, size=5, ax=ax)
    sns.lineplot(x="event", y="pupil_size", hue='sub', sort=False, palette=palette, data=df, legend=False, ax=ax)
    ax.set_ylabel("Pupil size [mm]" if ylabel else "")

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    

In [None]:
def plot_behaviour_checks(data, event,ax = None, ylabel= None, show_title=False):
    """Bar + swarm plot of a subject-average behavioural measure, reward vs neutral trials.

    `event` is coerced to numeric (unparseable values become NaN), averaged per subject
    and condition, and compared with a paired t-test (paired by subject) whose p-value is
    printed and drawn as significance stars.

    Parameters
    ----------
    data : pandas.DataFrame
        Trial-level data with columns 'sub', 'trial_type' ('csp'/'csm') and `event`.
        It is not modified.
    event : {'cat_corr', 'TN_corr', 'TN_rt'}
        Behavioural column to plot.
    ax : matplotlib Axes, optional
        Axis to draw on; defaults to the current axis.
    ylabel : bool, optional
        If truthy, set "Proportion correct" (accuracy) or "Reaction time (ms)".
    show_title : bool, default False
        If True, set a descriptive axis title.
    """
    if ax is None:
        ax = plt.gca()  # Get the current axis if ax is not provided
    titles = {
        'cat_corr': 'category classification',
        'TN_corr': 'target number classification',
        'TN_rt': 'target number reaction time',
    }
    title = titles.get(event, event)  # fall back to the column name for other events

    # coerce on a copy so the caller's column dtype is left untouched
    data = data.assign(**{event: pd.to_numeric(data[event], errors='coerce')})
    grouped = data.groupby(['sub', 'trial_type'])[event].apply(lambda vals: np.nanmean(vals.tolist(), axis=0))
    df = grouped.reset_index()
    df.columns = ['sub', 'trial_type', event]
    df['trial_type'] = df['trial_type'].replace({'csp': 'reward', 'csm': 'neutral'})

    # Group by 'trial_type' and calculate mean and standard error
    means = df.groupby('trial_type')[event].mean()
    standard_errors = df.groupby('trial_type')[event].apply(sem)
    print('Mean:')
    print(means)
    print('\nStandard Error:')
    print(standard_errors)

    # paired t-test: align reward/neutral by subject and keep subjects with both values
    paired = df.pivot(index='sub', columns='trial_type', values=event).dropna(subset=['reward', 'neutral'])
    stat, pvalue = scipy.stats.ttest_rel(paired['reward'], paired['neutral'])
    sig_lev = convert_pvalue_to_asterisks(pvalue)
    print(f'{event} p = {pvalue}')

    # explicit order so the palette always maps neutral -> grey and reward -> green
    order = ['neutral', 'reward']
    sns.set(rc={'figure.figsize':(2,4)})
    sns.set_style('white')
    sns.barplot(data=df, x="trial_type", y=event, order=order, capsize=.1, edgecolor=".1", palette=['grey', 'green'], alpha=.8, ax=ax)
    sns.swarmplot(data=df, x="trial_type", y=event, order=order, color="0", alpha=.35, ax=ax)

    # place the stars just above the highest subject mean; RTs are in ms, accuracies are proportions
    text_offset = 50 if event == 'TN_rt' else 0.1
    ax.text(0.5, df[event].max() + text_offset, sig_lev, horizontalalignment='center', verticalalignment='top', fontsize=13)
    if ylabel:
        if event in ('cat_corr', 'TN_corr'):
            ax.set_ylabel("Proportion correct")
        else:
            ax.set_ylabel("Reaction time (ms)")
    else:
        ax.set_ylabel("")
    ax.set_xlabel("")
    if show_title:
        ax.set_title(title)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    # save the plot
    #plt.savefig(f'../../stats/figures/behaviour/{event}_split_reward.pdf')
    #plt.savefig(f'../../stats/figures/behaviour/{event}_split_reward.png',bbox_inches='tight', dpi = 300)
    


In [None]:
# Run the preprocessing here.
# `hrbp_path` comes from the settings cell at the top of the notebook. It is deliberately not
# redefined here, so a path changed there is not silently overridden.
file_list = sorted(glob.glob(os.path.join(hrbp_path, 'raw_data', 'data', 'HRBP*')))  # sorted: deterministic subject order
subjectlist = [os.path.split(file)[1][-3:] for file in file_list]

sub_frames = []
ave_frames = []
for sub in subjectlist:
    print(f'running subject: {sub}')
    get_precision_data(sub)
    sub_dat, ave_sub = add_pupilData(sub)
    sub_frames.append(sub_dat)
    ave_frames.append(ave_sub)

# DataFrame.append was removed in pandas 2.0; concatenate once instead (also avoids quadratic copying)
trial_dat = pd.concat(sub_frames, ignore_index=True) if sub_frames else pd.DataFrame()
ave_dat = pd.concat(ave_frames, ignore_index=True) if ave_frames else pd.DataFrame()
    

### Run the exclusion checks

In [None]:
## Exclude participants performing at chance level on the location-precision test (12 excluded):
## mean absolute error > 75 degrees (chance = 90 degrees), as flagged in the `subj_excl` column.
# .copy() makes finTrial_dat an independent frame, so later column assignments on it
# neither touch trial_dat nor raise SettingWithCopyWarning.
finTrial_dat = trial_dat[trial_dat['subj_excl'] == 0].copy()
print(sorted(set(finTrial_dat['sub'])))  # retained subjects (a bare set(...) here would not be displayed)
len(set(finTrial_dat['sub']))

##### Sub-034 is additionally excluded: after removing incorrectly classified trials (‘Do you expect a reward on this trial?’), too few trials remained for model fitting.
##### Note: following revisions, we ran an outlier-detection analysis that flagged sub-018 and sub-003. All analyses can be re-checked without these two subjects by adding them to the exclusion list below.

In [None]:
# Manually excluded subjects. Currently only sub-034 (too few trials left for model fitting).
# For the outlier robustness check, use ['018', '003', '034'] instead.
exclude_cutoff = ['034']
finTrial_dat = finTrial_dat.loc[~finTrial_dat['sub'].isin(exclude_cutoff)]
len(set(finTrial_dat['sub']))

## Figure plots (other plots are made in R)
### Figure 2

In [None]:
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(14, 3))
# (axis, pupil measure, show y label) -- y labels only on the two left-most panels
for bar_ax, bar_event, bar_ylabel in ((ax1, 'preStim_bl', True),
                                      (ax2, 'SP_peak_response', True),
                                      (ax3, 'choice_onset_event', False),
                                      (ax4, 'anticipation', False)):
    plot_ave_pup(data=finTrial_dat, event=bar_event, ax=bar_ax, ylabel=bar_ylabel)
plt.tight_layout()
plt.savefig('../../stats/figures/Figure2_pupBarplots.pdf')

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14, 3))
plot_reward_ts(data=finTrial_dat, event='SP_timeseries', baseline_dur=baseline_dur, ax=ax1, ylabel=True)
plot_reward_ts(data=finTrial_dat, event='FB_timeseries', baseline_dur=baseline_dur, ax=ax2, ylabel=False)
# same post-exclusion data as the other panels (trial_dat still contains the excluded subjects)
plot_fb_ts(data=finTrial_dat, event='FB_timeseries', baseline_dur=baseline_dur, ax=ax3, ylabel=False)
plt.tight_layout()
plt.savefig('../../stats/figures/Figure2_TSplots.pdf')

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(6, 3))

# (axis, behavioural measure, show y label); TN_rt keeps its own label since its unit (ms) differs
for behav_ax, behav_event, behav_ylabel in ((ax1, 'cat_corr', True),
                                            (ax2, 'TN_corr', False),
                                            (ax3, 'TN_rt', True)):
    plot_behaviour_checks(finTrial_dat, event=behav_event, ax=behav_ax, ylabel=behav_ylabel)

plt.tight_layout()
plt.savefig('../../stats/figures/Figure2A.pdf')

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(7, 3))
# NOTE(review): sub_regression_plot is not defined in this part of the notebook --
# presumably provided by `from mem_rew_stimfuncs import *`; verify.
for reg_ax, reg_event, reg_ylabel in ((ax1, 'preStim_bl', True),
                                      (ax2, 'choice_onset_event', False),
                                      (ax3, 'anticipation', False)):
    sub_regression_plot(finTrial_dat, reg_event, ax=reg_ax, ylabel=reg_ylabel)
plt.tight_layout()
plt.savefig('../../stats/figures/Fig2B_new_regression_plots.pdf')
plt.show()

### Descriptive statistics (retained participants)

In [None]:

final_subs = set(finTrial_dat['sub'])
ave_dat_fin = ave_dat[ave_dat['sub'].isin(final_subs)]

# Mean and standard error across the retained participants, one summary column at a time;
# every block after the first is preceded by a blank line
summary_columns = ['prop_rewarded', 'choice_event_invalid', 'BL_event_invalid', 'anticipation_invalid']
for col_idx, summary_col in enumerate(summary_columns):
    means = ave_dat_fin[summary_col].mean()
    standard_errors = ave_dat_fin[summary_col].sem()
    lead = '' if col_idx == 0 else '\n'
    print(f'{lead}Mean {summary_col}:')
    print(round(means, 2))
    print(f'Standard Error {summary_col}:')
    print(round(standard_errors, 2))

### Save data for further analysis in R

In [None]:
# Save the post-exclusion trial data. The path is built with os.path.join (as for the raw
# data paths above) instead of hard-coded Windows separators, and the folder is created if missing.
group_data_dir = os.path.join(hrbp_path, 'stats', '1_preprocessed', 'group_data')
os.makedirs(group_data_dir, exist_ok=True)
finTrial_dat.to_csv(os.path.join(group_data_dir, 'HRBP_trial_dataframe.csv'), index=True)

In [None]:
# Select all columns except the array-valued timeseries columns (missing columns are ignored).
# NOTE: this uses trial_dat, i.e. the data before subject exclusions (see the `subj_excl` column).
non_timeseries_dat = trial_dat.drop(columns=['SP_timeseries', 'FB_timeseries'], errors='ignore')

# save csv (portable path, folder created if missing)
group_data_dir = os.path.join(hrbp_path, 'stats', '1_preprocessed', 'group_data')
os.makedirs(group_data_dir, exist_ok=True)
non_timeseries_dat.to_csv(os.path.join(group_data_dir, 'behav_pupil_trial_data.csv'), index=True)