# Plots created for the publication

In [None]:
%reset # only run when space is needed (variables will be deleted from memory)

In [None]:
import warnings  # hide warnings
import pandas as pd  # operate with dataframes
import matplotlib.pyplot as plt  # mother of plots focr Python
import numpy as np  # array/matrix operations (e.g. linear algebra)
import seaborn as sns  # matplotlib plotting nice with shortcuts
from scipy.stats import ks_2samp
from scipy.stats import ttest_rel
from statsmodels.stats.multitest import multipletests
from statsmodels.stats.anova import AnovaRM
from matplotlib.lines import Line2D
import scipy.io  # to load Matlab data
import pingouin as pg
import math
from  matplotlib.patches import Arc
import matplotlib.patheffects as pe


from scipy.interpolate import interp1d
from shapely.geometry import Polygon
from matplotlib.patches import ConnectionPatch
from tqdm.notebook import tqdm, trange  # mother of progressbars
from itertools import groupby
import pyxdf  # read XDF files (LSL streams recordings)


In [None]:
# Locations of the raw recordings and of every processed data product.
PATH_RAW = "./data/raw"
PATH_PROC = "./data/processed"
PATH_FOREYE = "./data/processed/MAD_sacc"
PATH_EEG = "./data/processed/EEG"
PATH_HANDLABELS = "./data/processed/Data_HandLabeling"
PATH_TRG = "./data/processed/Trigger_MAD"

# Keep FutureWarning messages out of the cell outputs.
warnings.simplefilter("ignore", FutureWarning)


def pbar_fork_hack():
    """
    Write one blank character to stdout and flush it immediately.

    Workaround that lets forked worker processes display progress bars in
    IPython front-ends such as Jupyter Notebooks; it avoids
    "[IPKernelApp] WARNING | WARNING: attempted to send message from fork".

    Usage: hand this function to a multiprocessing pool as its initializer,
    otherwise it has no effect. E.g.:

    pool = Pool(processes=N_CORES, initializer=pbar_fork_hack)

    Source:
     - https://github.com/ipython/ipython/issues/11049#issue-306086846
     - https://github.com/tqdm/tqdm/issues/485#issuecomment-473338308
    """
    # no positional objects: only the `end` string (a single space) is written
    print(end=" ", flush=True)

In [None]:
# table of all recordings; the first CSV column becomes the index
recordings = pd.read_csv("./recordings_village.csv", index_col=0)

# final subjects for the paper -- every other recording is dropped below
subj_to_inlcude = [1,2,5,7,12,17,18,19,20,21,22,27,29,30,32,33,34,37,38,]

# the first two characters of each file name are parsed as the subject number
keep_row = np.array(
    [int(file_name[:2]) in subj_to_inlcude for file_name in recordings["file"]],
    dtype=bool,
)
recordings = recordings.loc[keep_row]

display(recordings)

In [None]:
# Color palette used in the paper (hex RGB strings shared by the plotting cells).
# NOTE(review): the commented-out hex codes and the trailing `#d` markers are
# kept from the original; they look like notes on alternative values -- confirm
# before removing them.

# blue: used further down for the "Gaze x" / "Gaze y" / "Gaze z" traces
gaze_color_1 = "#066da8"#"#0570b0"
gaze_color_2 = "#77aacf"#"#74a9cf" #d 
gaze_color_3 = "#bfcae0"#"#bdc9e1" #d

# red: used further down for the "Sacc x" / "Sacc y" / "Sacc z" traces
sacc_color_1 = "#e31a1c" #d
sacc_color_2 = "#fd8d3c" #d
sacc_color_3 = "#fecc5c" #d

# green
vel_eye_color = "#238443" #d
vel_head_color = "#d8b365" #"#c2e699"

# grey/other 525252
ten_thres_color = "#525252"
dd_thres_color = "#a4a1a4" #"#bdbdbd"
# vertical lines that mark a change of the hit collider
colliders_color = "#987284"

# Experiment Setup

In [None]:
fname = "Arial"  # font used for every text element
numbersize = 60  # font size of the panel letters (A, B, ...)

plt.figure(figsize=(30, 20), constrained_layout=True)

sns.set_style("white")
# 2 x 4 grid: panel A spans both rows of the left half, B and C are stacked on the right
grid_shape = (2, 4)
ax1 = plt.subplot2grid(shape=grid_shape, loc=(0, 0), rowspan=2, colspan=2)
ax2 = plt.subplot2grid(shape=grid_shape, loc=(0, 2), rowspan=1, colspan=2)
ax3 = plt.subplot2grid(shape=grid_shape, loc=(1, 2), rowspan=1, colspan=2)

# (axes, image file, panel letter, vertical position of the letter)
panels = [
    (ax1, "./images/IMG_4552.jpg", "A", 1.02),  # image credits: (c) Simone Reukauf
    (ax2, "./images/unity2.png", "B", 1.05),
    (ax3, "./images/birds-eye-view.png", "C", 1.05),
]
for panel_ax, image_file, letter, letter_y in panels:
    img = plt.imread(image_file)
    panel_ax.imshow(img)
    panel_ax.axis("off")  # no black box surrounding the image
    panel_ax.set_title(
        letter,
        fontsize=numbersize,
        fontweight="bold",
        loc="left",
        x=-0.07,
        y=letter_y,
        pad=-30,
        fontname=fname,
    )

# photo credit, placed in axes coordinates just below panel A
ax1.text(
    0.18,
    -0.015,
    'Photo: (c) Simone Reukauf',
    fontsize=30,
    horizontalalignment='center',
    verticalalignment='center',
    transform=ax1.transAxes,
    fontname=fname,
)



# Check % of data interpolated

In [None]:
''' 
Check % of data that has been interpolated.
Code taken from 7v - the interpolation
Adjusted to save the amount of interpolated data.
'''
# For every subject this cell
#   1. builds one combined validity flag per sample,
#   2. replaces the samples of every invalid period with NaN (periods of at
#      least `min_blink_duration` are widened by `dilate_nan` on both sides),
#   3. linearly interpolates the NaN gaps and blanks gaps longer than 0.25 again,
#   4. stores the percentages of invalid / interpolated samples in `interpolation`.

interpolation = {} # per-subject percentages, filled in the loop below

# replacement for nans
nanrep = np.nan

# to determine the area around the blinks (same unit as the timestamps)
min_blink_duration = 0.02
dilate_nan = 0.023  # add it as two samples before and after blink onset

ids = recordings.index.tolist()
idd = ids[:]
# loop through all subjects
for uid in idd:
    print(uid)
    interpolation[uid] = {}

    # get the coordinates
    hit_sort = pd.read_csv(f"{PATH_PROC}/Behavior_new_{uid}.csv", index_col=0)
    hit_pos = pd.read_csv(f"{PATH_PROC}/HitsSorted_new_{uid}.csv", index_col=0)

    # drop rows whose index value (timestamp) occurs more than once, keep the first
    hit_sort = hit_sort[~hit_sort.index.duplicated(keep="first")]
    hit_pos = hit_pos[~hit_pos.index.duplicated(keep="first")]

    # to check for wrong samples (not used further in this cell)
    x_origin = hit_sort["ETWoriginX"].tolist()

    # the data we interpolate
    x_coord = hit_sort["ETWdirectionX"].tolist()
    hpoo_x = hit_pos["HPOOX"].tolist()

    # and get the timestamps
    times = hit_sort.index.tolist()
    times_hpoo = hit_pos.index.tolist()

    # get valid data
    val = hit_sort["valid"].tolist()  # original valid stream

    # adjust valid: if both eyes are closed: change to 0.0 in valid
    no_valids_corr_df = hit_sort[
        (hit_sort.leftBlink == 1.000)
        & (hit_sort.rightBlink == 1.000)
        & (hit_sort.valid == 1.000)
    ]
    no_val = [
        0.0 if t in no_valids_corr_df.index else 1.0 for t in times
    ]  # get a list out of df

    # adjust valid: if ETWoriginX is below 400 the sample is treated as invalid (0.0)
    no_valids_corr_df = hit_sort[
        (hit_sort.ETWoriginX < 400.0) & (hit_sort.valid == 1.000)
    ]
    no_val1 = [
        0.0 if t in no_valids_corr_df.index else 1.0 for t in times
    ]  # get a list out of df
    # a sample is valid (1.0) only if none of the three criteria rejects it
    valid = [
        0.0
        if (no_val[v] == 0.0) or (no_val1[v] == 0.0) or (val[v] == 0.0)
        else 1.0
        for v in range(len(val))
    ]

    # save % of data that is invalid before we do the blink correction
    interpolation[uid]['inval_before_blinks'] = (len(valid) - sum(valid))/len(valid)*100

    # to check how many close samples there were (not used further in this cell)
    cnt = False
    cnt_item = 0.0

    # to get the blink onset
    blinking = False  # will be true during the blinking
    blink_time = 0.0  # will be updated to blink onset
    blinks = [0.0] * len(times)

    # go through the entire list of timestamps
    for t, item in enumerate(times):
        # if this is not part of the hpoo: add a NaN placeholder so that both
        # lists share the same timestamps
        if not item in times_hpoo:
            hpoo_x.insert(t, nanrep)
            times_hpoo.insert(t, item)

        # to check for blinks: get blink onset
        if (valid[t] != 1.0) and not blinking:
            blinking = True
            # get the time of blinking onset
            blink_time = item
        # if they are over, do:
        elif valid[t] == 1.0 and blinking:
            # reset the blinking parameter
            blinking = False
            # get the end of the blink (so the previous time stamp)
            it = times[t - 1]
            # if the blink duration is bigger than the min blink duration:
            # adapt the blink duration plus the dilate_nan
            if it - blink_time >= min_blink_duration:
                # we want to exchange all the items with nan, so get the time interval we need to change it in:
                ts = times[
                    times.index(
                        list(
                            filter(
                                lambda i: i >= (blink_time - dilate_nan), times
                            )
                        )[0]
                    ) : times.index(
                        list(filter(lambda i: i <= (it + dilate_nan), times))[
                            -1
                        ]
                    )
                    + 1
                ]  # get all timestamps in the important time window
                for t_s in ts:
                    pos = times.index(t_s)
                    # the window can reach past the current sample, i.e. to
                    # timestamps the outer loop has not added to the hpoo yet
                    if not t_s in times_hpoo:
                        hpoo_x.insert(pos, nanrep)
                        # register t_s itself (not the outer-loop timestamp);
                        # otherwise the outer loop would insert this sample a
                        # second time and shift hpoo_x against `times`
                        times_hpoo.insert(pos, t_s)

                    # ETW
                    x_coord[pos] = nanrep
                    hpoo_x[pos] = nanrep

                    # add the number to blinks
                    blinks[pos] = nanrep
            # if the blink duration is too small, we do not add an additional window around it
            else:
                ts = times[
                    times.index(
                        list(filter(lambda i: i >= (blink_time), times))[0]
                    ) : times.index(
                        list(filter(lambda i: i <= (it), times))[-1]
                    )
                    + 1
                ]  # get all timestamps in the important time window
                for t_s in ts:
                    pos = times.index(t_s)
                    # ETW
                    x_coord[pos] = nanrep

                    # hpoo
                    hpoo_x[pos] = nanrep

    # go through the other list and delete all elements that are not in time
    # so all the elements that we cannot add to the .csv like this
    to_del = list(set(times_hpoo) - set(times))
    for i in to_del:
        hpoo_x.pop(times_hpoo.index(i))
        times_hpoo.pop(times_hpoo.index(i))

    # create a df out of all the important lists:
    for_eye = list(
        zip(
            times,
            valid,
            x_coord,
            hpoo_x,
            blinks,
        )
    )
    # "time" deliberately stays a regular column: the interpolation loop below
    # iterates over the columns and skips it by name
    for_eye = pd.DataFrame(
        for_eye,
        columns=[
            "time",
            "valid",
            "xcoord",
            "xhpoo",
            "blinks",
        ],
    )

    # get % of data that is considered invalid after the blink correction
    v_nan = for_eye[~for_eye["xcoord"].isnull()]
    interpolation[uid]['inval_after_blinks'] = (len(for_eye) - len(v_nan))/len(for_eye)*100

    # interpolate
    for column_name in for_eye:
        # do not interpolate these columns
        if column_name not in [
            "time",
            "valid",
        ]:

            b = for_eye[column_name].values.tolist()
            # length of every run of consecutive NaN values
            v = [
                len(list(group))
                for key, group in groupby(b, key=pd.isnull)
                if key
            ]

            # position of the first NaN of every run (same order as v)
            idx = [
                idx + 1
                for idx in range(len(b) - 1)
                if not pd.isnull(b[idx]) and pd.isnull(b[idx + 1])
            ]
            if pd.isnull(b[0]):
                idx.insert(
                    0, 0
                )  # if the first element is nan, it will be added here

            # interpolate data
            for_eye[column_name] = for_eye[column_name].interpolate(
                method="linear", limit_direction="both"
            )
            # go through v: if the beginning and end difference is bigger than allowed, replace interpolated data with nan
            b = for_eye[column_name].values.tolist()
            b = np.array(
                b
            )  # for the filling in an array is needed instead of a list
            for t, item in enumerate(idx):
                # a run that reaches the last timestamp has no sample after it
                # to measure against; stop here and keep the filled values
                if item + v[t] == len(times):
                    break
                # if the distance is bigger than 250ms we do not want to interpolate --> replace values with nan
                if times[item + v[t]] - times[item] > 0.25:
                    b[item : item + v[t]] = np.nan

            # replace the column with interpolated one
            for_eye[column_name] = b.tolist()

    # get the amount of data that had been interpolated
    b_nan = for_eye[~for_eye["xcoord"].isnull()]
    # still NaN after interpolation (gaps that were too long)
    interpolation[uid]['not_interpolated'] = (len(for_eye) - len(b_nan))/len(for_eye)*100
    # NaN after the blink correction minus NaN after interpolation
    interpolation[uid]['interpolated'] = ((len(for_eye) - len(v_nan)) - (len(for_eye) - len(b_nan)))/len(for_eye)*100

# save it as a df: one row per subject, one column per stored percentage
interpolation = pd.DataFrame(interpolation).transpose()
display(interpolation)

# median of the interpolated percentage across subjects (NaN ignored)
print()
print(f"Median Amount of Interpolation: {np.nanmedian(interpolation['interpolated'])}")

# interquartile range of the same column; the label names the interpolation
# statistic (it used to carry a label copied from a different analysis)
q75, q25 = np.nanpercentile(interpolation['interpolated'].tolist(), [75, 25])
iqr_interp = q75 - q25
print(f"IQR Amount of Interpolation: {iqr_interp} ({q25}-{q75})")

# Schematic of translational movement

In [None]:
# sizes shared by both panels
labelsize = 40 # text
legendsize = 40 # legend
ticksize = 30 # ticks
numbersize = 60 # A, B etc.
fname = "Arial" # font name
linewidth = 5


def _boxed_label(ax, x, y, label, color="black", size=labelsize, font=fname):
    """Write bold, centred text in a white rounded box at data position (x, y) of `ax`."""
    ax.text(x, y, label, horizontalalignment='center', verticalalignment='center', rotation=0,
            color=color, size=size, weight="bold", fontname=font,
            bbox=dict(facecolor='w', edgecolor='black', boxstyle='round,pad=0.2'))


def _line_angle_deg(line):
    """Return the angle in degrees between a two-point Line2D and the x-axis (data coordinates)."""
    (x_start, y_start), (x_end, y_end) = line.get_xydata()
    # atan2 copes with the vertical line (x_end == x_start) without a division by zero
    return math.degrees(math.atan2(y_end - y_start, x_end - x_start))


# setting up the figure (sns + plt)
sns.set(rc={"figure.figsize": (30, 30)})

sns.set_style("white")
f, (ax1,ax2) = plt.subplots(1,2) # two plots next to each other
img = plt.imread("images/S1.png") # we have the same background for both of them

######## positions ########
# note: 0/0 is in the upper left corner of the image
sub_1_x = 80 # x-position of left points
sub_2_x = 280 # x-position of the right points
sub_y = 380 # y-position of the subject position
sub_1_y = 190 # y-position of the left hitpoint
sub_2_y = 112 # y-position of the right hitpoint

######## panel 1 ########
ax1.imshow(img, alpha = 0.85) # display the background image

# points
ax1.scatter([sub_1_x,sub_2_x,sub_1_x,sub_2_x], [sub_y,sub_y,sub_1_y,sub_2_y], marker='o', facecolors='black',
            zorder=3, linewidth = 2, edgecolors='black', s=300)
# arrows
V = np.array([[0,0.535], [0,0.76], [0.575,0.76], [0.565,0]]) # direction of the arrows
origin = np.array([[sub_1_x, sub_2_x, sub_1_x, sub_1_x],[sub_y, sub_y, sub_y, sub_y]]) # origin points: [all x coords], [all y coords]
# create the arrows
ax1.quiver(*origin, V[:,0], V[:,1], color=[gaze_color_1,sacc_color_1,vel_eye_color, ten_thres_color],
           scale=1, linewidth = 0.75, edgecolor='black', width=0.011)
# add all text
_boxed_label(ax1, sub_1_x, sub_y+28, 'Subject\nPosition 1')
_boxed_label(ax1, sub_2_x, sub_y+28, 'Subject\nPosition 2')
_boxed_label(ax1, sub_1_x+(sub_2_x-sub_1_x)/2, sub_y, 'Translation', color=ten_thres_color)
_boxed_label(ax1, sub_1_x, sub_1_y-28, 'Hit\nPoint 1')
_boxed_label(ax1, sub_2_x, sub_2_y-28, 'Hit\nPoint 2')
# remove axis ticks
ax1.set_xticks([])
ax1.set_yticks([])


######## panel 2 ########
ax2.imshow(img, alpha = 0.85) # display the background image

# points
ax2.scatter([sub_1_x,sub_2_x,sub_1_x,sub_2_x], [sub_y,sub_y,sub_1_y,sub_2_y], marker='o', facecolors='black',
            zorder=3, linewidth = 2, edgecolors='black', s=300)

# gaze directions (lines instead of arrows)
ax2.plot([sub_2_x,sub_2_x], [sub_y,sub_2_y], linestyle='solid', marker='', color='k', alpha = 0.5,
         zorder=1, linewidth = linewidth+2.5) # outline
ax2.plot([sub_2_x,sub_2_x], [sub_y,sub_2_y], linestyle='solid', marker='', color=dd_thres_color, alpha = 0.5,
         zorder=1, linewidth = linewidth)
ax2.plot([sub_1_x+0.4,sub_1_x+0.4], [sub_y,sub_1_y], linestyle='solid', marker='', color='k', alpha = 1,
         zorder=1, linewidth = linewidth+2.5) # outline
line_1 = Line2D([sub_1_x,sub_1_x], [sub_y,sub_1_y], linestyle='solid', marker='', color=ten_thres_color, alpha = 0.8,
                zorder=1, linewidth = linewidth)
ax2.add_line(line_1)
line_2 = Line2D([sub_1_x,sub_2_x], [sub_y,sub_2_y], linestyle='solid', marker='', color=ten_thres_color, alpha = 0.8,
                zorder=1, linewidth = linewidth,
                path_effects=[pe.Stroke(linewidth=linewidth+2.5, foreground='black'), pe.Normal()])
ax2.add_line(line_2)

# arrow
V = np.array([[0,0.222], [0,0.222], [0.575,0.222]]) # direction of the arrows
origin = np.array([[sub_1_x, sub_1_x, sub_1_x],[sub_y, sub_1_y, sub_1_y]]) # origin points: [all x coords], [all y coords]
# plot the arrows
ax2.quiver(*origin, V[:,0], V[:,1], color=[sacc_color_1,sacc_color_1,vel_eye_color], scale=1, linewidth = 0.75,
           edgecolor='black', width=0.011)

# v_gaze_inplane - line
ax2.plot([sub_1_x,sub_2_x], [sub_2_y,sub_2_y], linestyle='solid', marker='', color=gaze_color_1, alpha = 1, zorder=1,
         linewidth = linewidth, path_effects=[pe.Stroke(linewidth=linewidth+2.5, foreground='black'), pe.Normal()])

# draw the angle between the two lines that start at subject position 1
angle1 = _line_angle_deg(line_1)
angle2 = _line_angle_deg(line_2)

# now use this information to draw an arc between the two lines
theta1 = min(angle1, angle2)
theta2 = max(angle1, angle2)
angle = theta2 - theta1
offset = 160
origin = [sub_1_x,sub_y]
len_x_axis = 1
len_y_axis = 1
# angle/theta1/theta2 are passed by keyword: they are keyword-only in current
# matplotlib releases, and keywords are accepted by older releases as well
angle_plot = Arc(origin, len_x_axis*offset, len_y_axis*offset, angle=0, theta1=theta1, theta2=theta2,
                 color=sacc_color_2, zorder=0.5,
                 label = str(angle)+u"\u00b0", linewidth = linewidth+8,
                 path_effects=[pe.Stroke(linewidth=linewidth+10.5, foreground='black'), pe.Normal()])
ax2.add_patch(angle_plot) # To display the angle arc

# add all text
_boxed_label(ax2, sub_1_x, sub_2_y+(sub_1_y-sub_2_y)/2, 'eye-vec', color=sacc_color_1)
_boxed_label(ax2, sub_1_x+(sub_2_x-sub_1_x)/2, sub_2_y, 'v-eye-in-plane', color=gaze_color_1)
_boxed_label(ax2, sub_1_x+(sub_2_x-sub_1_x)/2, sub_2_y+(sub_1_y-sub_2_y)/2, 'v-eye-vec', color=vel_eye_color)
_boxed_label(ax2, sub_1_x+14, sub_y-40, 'w-eye', color=sacc_color_2, size=labelsize+12)

ax2.set_xticks([])
ax2.set_yticks([])

ax1.set_title("A", fontsize=numbersize, fontweight="bold",loc="left", x=-0.1, y=1.05, pad=-30, fontname=fname)
ax2.set_title("B", fontsize=numbersize, fontweight="bold",loc="left", x=-0.1, y=1.05, pad=-30, fontname=fname)

# For LSL part of paper:

In [None]:
'''
Check the difference between the beginning and end.
'''
# For every recording the first/last timestamp of the EEG stream is compared
# with the earliest start / latest end of all other streams (marker stream
# excluded). `difference` is how much that offset changes between the
# beginning and the end of the recording.

ids = recordings.index.tolist()
idd = ids[:]

stats = {} # per-subject results, turned into a df below
# loop through all
for uid in idd:
    stats[uid] = {}
    # load raw data
    part = recordings.loc[uid].file
    data, _ = pyxdf.load_xdf(f"{PATH_RAW}/{part}")

    starts = [] # first timestamp of every non-EEG stream
    ends = [] # last timestamp of every non-EEG stream
    # reset per recording: without this, a file without EEG stream would
    # silently reuse the values of the previous subject
    time_eeg = None
    time_eeg_end = None
    # loop through all streams
    for s in data:
        # stream name
        s_name = s["info"]["name"][0]
        stamps = s["time_stamps"]
        # a stream without samples has neither a start nor an end
        if len(stamps) == 0:
            continue
        # if EEG signal, save start and end
        if "openvibeSignal" in s_name:
            time_eeg = stamps[0]
            time_eeg_end = stamps[-1]
        # every other stream except the EEG markers
        elif "openvibeMarkers" not in s_name:
            starts.append(stamps[0])
            ends.append(stamps[-1])
        # to compare difference between Unity start and ETW start
        # (kept from the original; not used further in this cell)
        if "EyeTrackingWorld" in s_name:
            et_start = stamps[0]

    # without an EEG stream or without any other stream nothing can be
    # compared; store NaN so the nan-aware summary below ignores this subject
    if time_eeg is None or not starts:
        print(f"{uid}: EEG stream or comparison streams missing - no temporal shift computed")
        stats[uid]['start'] = np.nan
        stats[uid]['end'] = np.nan
        stats[uid]['difference'] = np.nan
        continue

    # get the difference of the start and end
    start = min(starts) - time_eeg
    end = max(ends) - time_eeg_end

    stats[uid]['start'] = start
    stats[uid]['end'] = end
    stats[uid]['difference'] = end - start
stats = pd.DataFrame(stats).transpose()
display(stats)

# calculate median and IQR; everything is printed in milliseconds, converted
# before rounding so that no floating point artefacts show up
print()
median_shift = np.nanmedian(stats['difference'])
print(f"Median temporal shift: {median_shift * 1000:.0f} ms")
q75, q25 = np.nanpercentile(stats['difference'].tolist(), [75, 25])
iqr_seg = q75 - q25  # same unit as the timestamps, as before
print(f"IQR temporal shift: {iqr_seg * 1000:.0f} ms ({q25 * 1000:.0f}-{q75 * 1000:.0f} ms)")

# Data for Manual Classification

In [None]:
'''
For the manual classification, we take the recorded data before interpolation and blink detection.
'''
ids = recordings.index.tolist()
idd = ids[:1]
# iterate over the selected subjects (the slice above keeps only the first one)
for uid in idd:
    # the samples as recorded, i.e. before anything is interpolated
    hit_sort = pd.read_csv(f"{PATH_PROC}/Behavior_new_{uid}.csv", index_col=0)
    # timestamps are the index of the file
    times = hit_sort.index.tolist()

    flagged_valid = hit_sort.valid == 1.000

    # timestamps flagged valid although both eyes are marked as blinking
    both_closed_ts = hit_sort[
        (hit_sort.leftBlink == 1.000) & (hit_sort.rightBlink == 1.000) & flagged_valid
    ].index
    # timestamps flagged valid although ETWoriginX lies below 400
    origin_off_ts = hit_sort[(hit_sort.ETWoriginX < 400.0) & flagged_valid].index

    # a sample is rejected if its timestamp is in either set or if the
    # recorded flag is already 0.0; everything else counts as valid (1.0)
    rejected = (
        hit_sort.index.isin(both_closed_ts)
        | hit_sort.index.isin(origin_off_ts)
        | (hit_sort["valid"] == 0.0).to_numpy()
    )
    # keep the combined information in the single "valid" column
    hit_sort["valid"] = np.where(rejected, 0.0, 1.0)

    # column names consistent with for_eye, reduced to the columns needed
    hit_sort = hit_sort.assign(
        xcoord=hit_sort["ETWdirectionX"],
        ycoord=hit_sort["ETWdirectionY"],
        zcoord=hit_sort["ETWdirectionZ"],
        time=times,
    )[["time", "valid", "xcoord", "ycoord", "zcoord"]]

    # save df:
    hit_sort.to_csv(f"{PATH_HANDLABELS}/hand_labeling_{uid}.csv", index=True)

# Segment Statistics

In [None]:
# median duration of the data-driven intervals, one value per subject
list_interval_dur = []

ids = recordings.index.tolist()
idd = ids[:]
# loop through all data
for uid in idd:
    # load the interval df
    int_data = pd.read_csv(
        f"{PATH_FOREYE}/interval_mad_wobig_{uid}.csv", index_col=0
    )
    # duration of every interval, computed column-wise; this does not depend
    # on the index labels (the row-wise version used the labels as positions
    # for .iloc, which is only correct for an index of exactly 0..n-1)
    interval_dur = (int_data["end"] - int_data["start"]).tolist()

    # add the value to the list of all subjects
    list_interval_dur.append(np.median(interval_dur))


# transform the list_interval_dur list into a df for statistics
list_interval_dur_df = pd.DataFrame({"average_interval_len": list_interval_dur})
# display min, max, median and mean
print()
print(f"Min Data Segment Duration: {round(np.nanmin(list_interval_dur_df['average_interval_len']),3)}")
print(f"Max Data Segment Duration: {round(np.nanmax(list_interval_dur_df['average_interval_len']),3)}")
print(f"Median Data Segment Duration: {round(np.nanmedian(list_interval_dur_df['average_interval_len']),3)}")
print(f"Mean Data Segment Duration: {round(np.nanmean(list_interval_dur_df['average_interval_len']),3)}")
# Interquartile range:
q75, q25 = np.nanpercentile(list_interval_dur_df['average_interval_len'].tolist(), [75, 25])
iqr_seg = q75 - q25
print(f"IQR gaze DD: {iqr_seg} ({round(q25,3)}-{round(q75,3)})")

# Classification Fit

In [None]:
# this can be done for all subjects if desired
ids = recordings.index.tolist()
idd = ids[:1] # right now, we only plot one, but this code can also be used to check the individual subjects

# select a short window to plot
window_lower = 647.2 # s: time stamp of the start of the window
window_upper = 649.5 # s: time stamp of he end of the window

# define size of: 
labelsize = 40 # text
legendsize = 40 # ledgend
ticksize = 30 # ticks
numbersize = 60 # A, B etc.
fname = "Arial" # font name

# for the grid of the plot
nr_r = 32 # number of rwos
rs = 6 # rowspan

# loop through the subejcts you want to plot
for i, uid in enumerate(idd):
    # prepare the figure layout
    f = plt.figure(figsize=(30, 34), constrained_layout=True)
    sns.set_style("white") 
    ax0 = plt.subplot2grid(shape=(nr_r, 1), loc=(0, 0), rowspan=rs) # 10-ssecond hit points
    ax1 = plt.subplot2grid(shape=(nr_r, 1), loc=(6, 0), rowspan=rs) # data-driven hit points
    ax2 = plt.subplot2grid(shape=(nr_r, 1), loc=(13, 0), rowspan=rs) # 10-second gaze direction
    ax3 = plt.subplot2grid(shape=(nr_r, 1), loc=(19, 0), rowspan=rs) # data-driven gaze direction
    ax4 = plt.subplot2grid(shape=(nr_r, 1), loc=(26, 0), rowspan=rs) # velocities 
    

    # creat text 
    ax0.text(647.15,60, '10-Second Method', horizontalalignment='left', verticalalignment='center', rotation=0,
            size=labelsize, fontname=fname, bbox=dict(facecolor='w', boxstyle='round,pad=0.2'))
    ax1.text(647.15,60, 'Data-Driven Method', horizontalalignment='left', verticalalignment='center', rotation=0,
            size=labelsize, fontname=fname, bbox=dict(facecolor='w', boxstyle='round,pad=0.2'))
    ax2.text(647.15,0.86, '10-Second Method', horizontalalignment='left', verticalalignment='center', rotation=0,
            size=labelsize, fontname=fname, bbox=dict(facecolor='w', boxstyle='round,pad=0.2'))
    ax3.text(647.15,0.86, 'Data-Driven Method', horizontalalignment='left', verticalalignment='center', rotation=0,
            size=labelsize, fontname=fname, bbox=dict(facecolor='w', boxstyle='round,pad=0.2'))
    
    # to allow for loops
    axis = [ax0, ax1, ax2, ax3, ax4]
    
    # a loop for the 10-second and data-driven method:
    for fe in range(2):
        # loop for hit points and gaze direction
        for gh in range(2):
            # load data
            # 10-second
            if fe == 0:
                for_eye = pd.read_csv(
                    f"{PATH_FOREYE}/correTS__10sec_{uid}.csv", index_col="time"
                )
                half_t = "10-Second" # second half of title
            # data-driven
            else:
                for_eye = pd.read_csv(
                    f"{PATH_FOREYE}/correTS_mad_wobig_{uid}.csv",
                    index_col="time", 
                )
                half_t = "Data-Driven" # second half of title
            if gh == 0:
                titel = "Hit Points: " + half_t 
            else:
                titel = "Gaze Directions: " + half_t
                
            # get the direction vector
            if gh == 1:
                # as we do not save this, we have to recompute, considering the translational movement
                # get individual coordinates
                # eye position
                Xcorr_position = for_eye["xcoord_orig"].tolist()
                Ycorr_position = for_eye["ycoord_orig"].tolist()
                Zcorr_position = for_eye["zcoord_orig"].tolist()
                subj = list(
                    zip(Xcorr_position, Ycorr_position, Zcorr_position)
                )

                # hit points
                hpooX = for_eye["xhpoo"].tolist()
                hpooY = for_eye["yhpoo"].tolist()
                hpooZ = for_eye["zhpoo"].tolist()
                hpoo = list(zip(hpooX, hpooY, hpooZ))

                # gaze_vec(t) is a unit vector in the direction of the gaze (eye+head) in world coordinates
                g_vec = [
                    np.array(hpoo[v] - np.array(subj[v]))
                    for v in range(len(subj))
                ]
                gaze_vec = [
                    np.array(v) / np.linalg.norm(np.array(v)) for v in g_vec
                ]
                
                # create df our of the direction vector to plot it
                gaze_vec = pd.DataFrame(
                    gaze_vec, columns=["gvx", "gvy", "gvz"]
                )
                gaze_vec["time"] = for_eye.index.tolist()
                gaze_vec = gaze_vec.set_index("time")
                # add it to for_eye so we can differentiate between gaze and saccade
                for_eye = pd.concat([for_eye, gaze_vec], axis=1)
            
            # get time:
            ts = for_eye.index.tolist()  # to make it easier
            # get a shot time interval, defined by window_lower and window_upper
            time = ts[
                ts.index(
                    list(filter(lambda i: i > window_lower, ts))[0]
                ) : ts.index(list(filter(lambda i: i < window_upper, ts))[-1])
                + 1
            ]  # get all timestamps in the important time window

            # now use the short time list to shorten the df
            for_eye = for_eye.iloc[
                ts.index(time[0]) : (ts.index(time[-1]) + 1)
            ]
            
            
            # hit points:
            if gh == 0:
                # substract 600 from the x and y coordinates for easier plotting (this is due to the
                # coordinate system of the Unity project)
                for_eye["xhpoo"] = list(
                    map(lambda x: x - 600, for_eye["xhpoo"].tolist())
                )  
                for_eye["zhpoo"] = list(
                    map(lambda x: x - 600, for_eye["zhpoo"].tolist())
                )  
            
            # get the timepoints whenever a new collider was hit (plottet as faint lines)
            hon = for_eye["hon"].tolist()
            hon_ts = [
                ti for cnt, ti in enumerate(time) if isinstance(hon[cnt], str)
            ]  # timestamps

            # separate between gazes and saccades to be plotted in different colors
            # get gazes:
            gaze = for_eye[~for_eye["isFix"].isnull()]
            gaze = gaze[~gaze["long_events"].isnull()]
            # get saccades:
            sacc = for_eye[~for_eye.index.isin(gaze.index)]
            sacc = sacc[~sacc["long_events"].isnull()]

            # rename the columns, so that the plotting can be donefor gaze direction and hit points with the same code
            # and for the ledgend to be the same
            if gh == 0:
                gaze = gaze.rename(
                    {"xhpoo": "Gaze x", "yhpoo": "Gaze y", "zhpoo": "Gaze z"}, axis=1
                )
                sacc = sacc.rename(
                    {"xhpoo": "Sacc x", "yhpoo": "Sacc y", "zhpoo": "Sacc z"}, axis=1
                )
            else:
                gaze = gaze.rename(
                    {"gvx": "Gaze x", "gvy": "Gaze y", "gvz": "Gaze z"}, axis=1
                )

                sacc = sacc.rename(
                    {"gvx": "Sacc x", "gvy": "Sacc y", "gvz": "Sacc z"}, axis=1
                )

            # plot long events as outliers
            long_events = for_eye[for_eye["long_events"].isnull()]
            long_events = long_events.rename({"xhpoo": "outliers"}, axis=1)
            
            # get the blins
            blinks = for_eye[for_eye["blinks"].isnull()]

            # get the axis to plot:
            if gh == 0:
                axis_nr = fe # hit points
            else:
                axis_nr = 2+fe # gaze direction

            # plot collider changes:
            for x, xc in enumerate(hon_ts):
                if not np.isnan(xc):
                    axis[axis_nr].axvline(
                        x=xc, color=colliders_color, alpha=0.4, label="_Hidden label"
                    )

            # assign colors to the different columns
            color_gaze = {
                "Gaze x": gaze_color_1,
                "Gaze y": gaze_color_2,
                "Gaze z": gaze_color_3,
            }
            color_sacc = {
                "Sacc x": sacc_color_1,
                "Sacc y": sacc_color_2, 
                "Sacc z": sacc_color_3,
            }
            
            # plot the data
            gaze[["Gaze x", "Gaze y", "Gaze z"]].plot(
                color=[
                    color_gaze.get(x, "#333333")
                    for x in gaze[["Gaze x", "Gaze y", "Gaze z"]]
                ],
                ax=axis[axis_nr],
                marker="o", ms=8,
                ls="",
            )
            sacc[["Sacc x", "Sacc y", "Sacc z"]].plot(
                color=[
                    color_sacc.get(x, "#333333")
                    for x in sacc[["Sacc x", "Sacc y", "Sacc z"]]
                ],
                ax=axis[axis_nr],
                marker="o", ms=8,
                ls="",
            )

            # we only want x-ticks for the last plot
            axis[axis_nr].set_xticklabels([])
            
            # we only want the legend in the first and last plot
            if axis_nr > 0 :
                axis[axis_nr].get_legend().remove()
            else:
                handles, labels = axis[0].get_legend_handles_labels()
                axis[axis_nr].get_legend().remove()
                
            axis[axis_nr].xaxis.label.set_visible(False) # no x-axis label either
            axis[axis_nr].set_ylabel("Coordinates", fontsize=labelsize, fontname=fname) # y-axis label
            # change the ticksize and fontname 
            for label in axis[axis_nr].get_yticklabels():
                label.set_fontproperties(fname)
            axis[axis_nr].yaxis.set_tick_params(labelsize=ticksize)  # change tick size
            
            # we want to plot the last plot only once
            if fe == 0 and gh == 0:
                # head velcoity
                axis[4].plot(
                    time, for_eye["HT_combined_vel"].tolist(), vel_head_color,linestyle =":",linewidth=5, label="Angular Velocity (Head)"
                )
                # eye velocity
                axis[4].plot(
                    time,
                    for_eye["combined_vel"].tolist(),
                    vel_eye_color,
                    linestyle = "--",
                    linewidth=3,
                    label="Angular Velocity (Eye)",
                    alpha=0.9
                )
                # get the colors for the thresholds
                c = [ten_thres_color, dd_thres_color]
            if gh == 0:
                # plot the threshold for the 10-second and data-driven intervals
                axis[4].plot(
                    time,
                    for_eye["thresh"].tolist(),
                    c[fe],
                    linewidth=4,
                    label=half_t + " Threshold",
                )


    f.legend(handles, labels, bbox_to_anchor=(0.5, 0.486, 0.5, 0.5), fontsize=legendsize, fancybox=True, markerscale=2.) #, framealpha=1)
    
    axis[4].set_ylim(0, 600)

    # add a ledgent to the last plot
    legend = axis[4].legend(loc="upper right", fontsize=legendsize, markerscale=2.)
    # add axis labels to the last plot
    axis[4].set_xlabel("Time (sec)", fontsize=labelsize, fontname=fname)
    axis[4].set_ylabel("Veloctiy", fontsize=labelsize, fontname=fname)
    # the the ticksize and font of the last plot
    for label in axis[4].get_xticklabels():
        label.set_fontproperties(fname)
    for label in axis[4].get_yticklabels():
        label.set_fontproperties(fname)
    axis[4].yaxis.set_tick_params(labelsize=ticksize)  # change tick size
    axis[4].xaxis.set_tick_params(labelsize=ticksize)

    # plot labels
    axis[0].set_title("A", fontsize=numbersize, fontweight="bold",loc="left", x=-0.08, y=1.05, pad=-30, fontname=fname)
    axis[2].set_title("B", fontsize=numbersize, fontweight="bold",loc="left", x=-0.08, y=1.05, pad=-30, fontname=fname)
    axis[4].set_title("C", fontsize=numbersize, fontweight="bold",loc="left", x=-0.08, y=1.05, pad=-30, fontname=fname)
    
    
    ####### Add boxes ########
    y_up, y_down = axis[0].get_ylim(), axis[4].get_ylim() # mind and max y-limit
    
    # blue square:
    axis[0].hlines(max(y_up), 647.6, 647.65, linewidth=5, color=gaze_color_1)
    axis[4].hlines(min(y_down)+20, 647.6, 647.65, linewidth=5, color=gaze_color_1)
    line1 = ConnectionPatch(xyA=[647.6,min(y_down)+20], xyB=[647.6,max(y_up)], coordsA="data", coordsB="data",
                          axesA= axis[4], axesB=axis[0], color=gaze_color_1, lw=5,zorder=2)
    line2 = ConnectionPatch(xyA=[647.65,min(y_down)+20], xyB=[647.65,max(y_up)], coordsA="data", coordsB="data",
                          axesA= axis[4], axesB=axis[0], color=gaze_color_1, lw=5,zorder=2)
    axis[4].add_artist(line1)
    axis[4].add_artist(line2)

    # red square:
    axis[0].hlines(max(y_up), 648.3, 648.5, linewidth=5, color=sacc_color_1)
    axis[4].hlines(min(y_down)+20, 648.3, 648.5, linewidth=5, color=sacc_color_1)
    line1 = ConnectionPatch(xyA=[648.3,min(y_down)+20], xyB=[648.3,max(y_up)], coordsA="data", coordsB="data",
                          axesA= axis[4], axesB=axis[0], color=sacc_color_1, lw=5,zorder=2)
    line2 = ConnectionPatch(xyA=[648.5,min(y_down)+20], xyB=[648.5,max(y_up)], coordsA="data", coordsB="data",
                          axesA= axis[4], axesB=axis[0], color=sacc_color_1, lw=5,zorder=2)
    axis[4].add_artist(line1)
    axis[4].add_artist(line2)
    
    # green square:
    y_down1 =  axis[1].get_ylim()
    axis[0].hlines(max(y_up), 648.85, 648.9, linewidth=5, color=vel_eye_color)
    axis[1].hlines(min(y_down1)+5, 648.85, 648.9, linewidth=5, color=vel_eye_color)
    line1 = ConnectionPatch(xyA=[648.85,min(y_down1)+5], xyB=[648.85,max(y_up)], coordsA="data", coordsB="data",
                          axesA= axis[1], axesB=axis[0], color=vel_eye_color, lw=5,zorder=3)
    line2 = ConnectionPatch(xyA=[648.9,min(y_down1)+5], xyB=[648.9,max(y_up)], coordsA="data", coordsB="data",
                          axesA= axis[1], axesB=axis[0], color=vel_eye_color, lw=5,zorder=3)
    axis[1].add_artist(line1)
    axis[1].add_artist(line2)



    plt.show()

# Eye and Head Movements

In [None]:
"""
Check whether the distance of gazes and the distance of saccades differ.

For each participant the median of "avg_dist" is taken separately for
saccade onsets and gaze onsets (long events excluded); the two lists of
per-participant medians are then compared with a two-sample KS test.
Only the data-driven interval files are read here.
"""

ids = recordings.index.tolist()
idd = list(ids)

gze = []  # per-participant median gaze distance
sac = []  # per-participant median saccade distance
for uid in idd:
    for_eye = pd.read_csv(f"{PATH_FOREYE}/correTS_mad_wobig_{uid}.csv", index_col=0)
    # rows whose "long_events" entry is NaN are long events and are dropped
    not_long = for_eye["long_events"].notna()

    # saccades are the rows with event code 1.0
    sacc = for_eye[(for_eye["events"] == 1.0) & not_long]
    sac.append(np.nanmedian(sacc["avg_dist"]))

    # gazes are the rows with event code 2.0
    gaze = for_eye[(for_eye["events"] == 2.0) & not_long]
    gze.append(np.nanmedian(gaze["avg_dist"]))

print("data-driven interval")
# the medians are not completely normally distributed, so the group-level
# summary is a median with its interquartile range
med_gaze = np.nanmedian(gze)
med_sacc = np.nanmedian(sac)
q75_g, q25_g = np.nanpercentile(gze, [75, 25])
q75_s, q25_s = np.nanpercentile(sac, [75, 25])
iqr_gaze = q75_g - q25_g
iqr_sacc = q75_s - q25_s
print(f"Median gaze distance: {med_gaze}")
print(f"Median sacc distance: {med_sacc}")
print(f"IQR gaze distance: {iqr_gaze} ({q25_g}-{q75_g})")
print(f"IQR sacc distance: {iqr_sacc} ({q25_s}-{q75_s})")

# two-sample Kolmogorov-Smirnov test on the per-participant medians
print()
alpha = 0.05
statistic, pvalue = ks_2samp(gze, sac)
print(f"KS statistic: {statistic:.4f}")
print(f"P-value: {pvalue:.4f}")
print(f"Alpha: {alpha:.4f}")

print()

In [None]:
# Functions to calculate the saccade amplitude
def compute_centroid(df, columns):
    """
    Return the column-wise mean of selected columns of a dataframe.

    Parameters:
        df: dataframe holding the samples (one row per sample)
        columns: names of the columns to average

    Returns:
        numpy array with one mean per entry of `columns`, in that order
    """
    column_means = df[columns].mean()
    return column_means.values

def compute_saccade_amplitude(prev_gaze_centroid, prev_subject_centroid, gaze_centroid, subject_centroid):
    """
    Angular size (in degrees) of the shift between two consecutive hit positions,
    corrected for translational movement.

    The displacement between the previous and the current hit position is
    stripped of its component along the previous line of sight; the remaining
    lateral displacement is turned into an angle using the previous
    eye-to-hit-point distance.

    Parameters:
        prev_gaze_centroid: the previous mean hit position
        prev_subject_centroid: the previous mean eye position
        gaze_centroid: the current mean hit position
        subject_centroid: the current mean eye position (accepted but not used
            in the computation)

    Returns:
        the saccade amplitude in degrees
    """
    # previous line of sight (eye -> hit point) and its length
    line_of_sight = prev_gaze_centroid - prev_subject_centroid
    viewing_distance = np.linalg.norm(prev_subject_centroid - prev_gaze_centroid)
    sight_unit = line_of_sight / np.linalg.norm(line_of_sight)

    # displacement of the hit position between the two events
    displacement = gaze_centroid - prev_gaze_centroid
    # remove the part of the displacement that points along the line of sight
    along_sight = np.dot(displacement, sight_unit) * sight_unit
    lateral_shift = np.linalg.norm(displacement - along_sight)

    return np.degrees(np.arctan2(lateral_shift, viewing_distance))

In [None]:
'''
Plot the eye and head movements, position and distance.

This first part of the cell creates the 3x3 panel figure, draws the
bird's-eye-view background image in the last panel and overlays the red
outline on it. The per-subject data are added further below.
'''
ids = recordings.index.tolist()
# NOTE(review): only the first recording is selected here although the loop
# further below is described as running over all subjects - confirm intended.
idd = ids[:1]
gze = []
sac = []
cnt = 0 # to adjust line style for the distance plot
# used for the position plot
min_x = 505
max_x = 642
min_y = 519
max_y = 655

# set plotting parameters for:
labelsize = 40 # text
legendsize = 40 # legend
ticksize = 30 # ticks
numbersize = 60 # A, B etc.
fname = "Arial" # font name

plt.figure(figsize=(32, 32), constrained_layout=True)

sns.set_style("white") 
plt.rcParams["font.family"] = fname
# define figure grid
ax11 = plt.subplot2grid(shape=(6, 6), loc=(0, 0), rowspan=2, colspan=2)
ax12 = plt.subplot2grid(shape=(6, 6), loc=(0, 2), rowspan=2, colspan=2)
ax13 = plt.subplot2grid(shape=(6, 6), loc=(0, 4), rowspan=2, colspan=2)
ax21 = plt.subplot2grid(shape=(6, 6), loc=(2, 0), rowspan=2, colspan=2)
ax22 = plt.subplot2grid(shape=(6, 6), loc=(2, 2), rowspan=2, colspan=2)
ax23 = plt.subplot2grid(shape=(6, 6), loc=(2, 4), rowspan=2, colspan=2)
ax5 = plt.subplot2grid(shape=(6, 6), loc=(4, 0), rowspan=2, colspan=2)
ax3 = plt.subplot2grid(shape=(6, 6), loc=(4, 2), rowspan=2, colspan=2)
ax4 = plt.subplot2grid(shape=(6, 6), loc=(4, 4), rowspan=2, colspan=2)
# ensure that the first 6 plots are all squares; every panel uses its OWN
# data ratio (the second row previously reused the ratios of the first row)
for square_ax in (ax11, ax12, ax13, ax21, ax22, ax23):
    square_ax.set_aspect(1.0 / square_ax.get_data_ratio(), adjustable='box')


####### Gazes in Space #######
# plot the background image (bird-eye-view of the city center)
img = plt.imread("unity_scene_bird_quad.png")
ax4.imshow(img, extent=[min_x, max_x, min_y, max_y])
# define window and axis
ax4.set_xlim(min_x, max_x)
ax4.set_ylim(min_y, max_y)
ax4.set_xticks([])
ax4.set_yticks([])
ax4.set_title("G", fontsize=numbersize, fontweight="bold",loc="left", x=-0.09, y=1.05, pad=-30, fontname=fname)

# NOTE(review): these values are re-assigned after the image extent and the
# axis limits have already been set, and max_y changes from 655 to 657; they
# are not read again in this cell - confirm whether 657 was meant for the
# extent above.
min_x = 505
max_x = 642
min_y = 519
max_y = 657

# plot the red outline:
# this was done manually, the points are in relation to the window of the plot (so how much one zooms in)
trianglex = [516, 610.75, 610.75, 630, 630, 612.6, 614.8, 578,   563,   559,   
             556,   550,   545,   541, 538,   535, 529, 526,   527, 516, 516] 
triangley = [565, 565,    555,    555, 621, 621,   611.5, 611.5, 609.5, 608.8, 
             608.5, 607.9, 607.8, 608, 608.5, 609.0, 611.0, 612.5, 621, 621, 565]
# draw the outline once, explicitly on the map panel, instead of three times
# on whatever axes pyplot currently considers active
ax4.plot(trianglex, triangley, color=sacc_color_1, linestyle='-', linewidth=4)

    
### Add the data from all subjects
amplitude = []  # saccade amplitude (deg), one entry per saccade onset
peak_vel = []  # peak eye velocity, one entry per saccade onset
# loop through all subjects
for i, uid in enumerate(idd):
    # load data:
    for_eye = pd.read_csv(f"{PATH_FOREYE}/correTS_mad_wobig_{uid}.csv", index_col=0)

    ####### Direction Vectors World #######
    # eye position
    Xcorr_position = for_eye["xcoord_orig"].tolist()
    Ycorr_position = for_eye["ycoord_orig"].tolist()
    Zcorr_position = for_eye["zcoord_orig"].tolist()
    subj = list(zip(Xcorr_position, Ycorr_position, Zcorr_position))
    # hit position
    hpooX = for_eye["xhpoo"].tolist()
    hpooY = for_eye["yhpoo"].tolist()
    hpooZ = for_eye["zhpoo"].tolist()
    hpoo = list(zip(hpooX, hpooY, hpooZ))
    # gaze_vec(t) is a unit vector in the direction of the gaze (eye+head) in world coordinates
    g_vec = [np.array(hpoo[v] - np.array(subj[v])) for v in range(len(subj))]
    gaze_vec = [np.array(v) / np.linalg.norm(np.array(v)) for v in g_vec]
    # create df to plot
    gaze_vec = pd.DataFrame(gaze_vec, columns=["gvx", "gvy", "gvz"])
    # add it to for_eye so we can differentiate between gaze and saccade
    for_eye_new = pd.concat([for_eye, gaze_vec], axis=1)
    # keep the gaze samples (non-null "isFix") that are not long events
    gaze = for_eye_new[~for_eye_new["isFix"].isnull()]
    gaze = gaze[~gaze["long_events"].isnull()]
    # create three plots as we have 3 coordinates
    # X vs Y
    ax11.scatter(gaze["gvx"].tolist(), gaze["gvy"].tolist(), color =gaze_color_1,s=4,edgecolors='none', alpha=0.04)
    # Z vs Y
    ax12.scatter(gaze["gvz"].tolist(), gaze["gvy"].tolist(), color =gaze_color_1,s=4,edgecolors='none', alpha=0.04)
    # X vs Z
    ax13.scatter(gaze["gvx"].tolist(), gaze["gvz"].tolist(), color =gaze_color_1,s=4,edgecolors='none', alpha=0.04)


    ####### Direction Vectors Local, Head, Saccade Vector #######
    # gaze samples (non-null "isFix") without long events
    gaze = for_eye[~for_eye["isFix"].isnull()]
    gaze = gaze[~gaze["long_events"].isnull()]
    # get the saccade vectors: (difference between beginning and end of saccade)
    # at first, define saccade onset and offset (correction needed in case the saccade is only one sample long):
    # a sample coded 2.0 that directly follows a sample coded 1.0 is re-coded as -1.0
    events = [-1.0 if (for_eye["events"][ev - 1] == 1.0 and for_eye["events"][ev] == 2.0)
        else for_eye["events"][ev]
        for ev in for_eye.index.tolist()[2:]
    ]
    events = [for_eye["events"][0]] + [for_eye["events"][1]]  + events
    for_eye_new["events"] = events
    # get the saccade vectors: (diff between beginning and end of saccade)
    sacc_start = for_eye_new[for_eye_new["events"] == 1.0]
    sacc_end = for_eye_new[for_eye_new["events"] == -1.0]
    # for one subject, the beginning is messed up, so correct it:
    if sacc_end.index[0] < sacc_start.index[0]:
        sacc_end = sacc_end[1:]
    # now get the diff between beginning and end
    sacc_x = (
        np.array(sacc_start["xcoord"].tolist())
        - np.array(sacc_end["xcoord"].tolist())
    ).tolist()
    sacc_y = (
        np.array(sacc_start["ycoord"].tolist())
        - np.array(sacc_end["ycoord"].tolist())
    ).tolist()
    sacc_z = (
        np.array(sacc_start["zcoord"].tolist())
        - np.array(sacc_end["zcoord"].tolist())
    ).tolist()
    sacc_v = list(zip(sacc_x, sacc_y, sacc_z))
    # normalised saccade vectors; not used by the plots in this cell
    sacc_v_norm = [np.array(v) / np.linalg.norm(np.array(v)) for v in sacc_v]
    sacc_v_norm_df = pd.DataFrame(sacc_v_norm, columns = ['x', 'y', 'z'])
    # Head: plot the x vs y coordinate
    ax21.scatter(gaze["xhead"].tolist(), gaze["yhead"].tolist(), color =gaze_color_1,s=4,edgecolors='none', alpha=0.04)
    # local eye-tracking data: plot the local eye-in-head direction vector
    ax22.scatter(gaze["xlocal_dir"].tolist(), gaze["ylocal_dir"].tolist(), color =gaze_color_1,s=4,edgecolors='none', alpha=0.04)
    # sacc vectors: plot them as lines starting in the origin
    for mk in range(len(sacc_x)):
        ax23.plot([0, sacc_x[mk]], [0, sacc_y[mk]], sacc_color_1, alpha=0.15)



    ####### Sacc Amp vs. Peak Vel #######
    # --- SACCADE AMPLITUDE ---
    prev_gaze_centroid = None
    prev_subject_centroid = None
    # create both result columns up front (all NaN) so that they exist even
    # when no event window gets a value assigned
    for_eye['saccade_amplitude'] = np.nan
    for_eye['peak_velocity'] = np.nan
    # get the event onsets and offsets
    start_indices = for_eye[for_eye['events'] == 2.0].index
    end_indices = for_eye[for_eye['events'] == -2.0].index
    start_sacc = for_eye[for_eye['events'] == 1.0].index # get the start of saccades to appropriately add the saccade amplitudes to the conditions
    # correct for two cases:
    # check if first start_sacc is smaller than first start_indices --> if not, add the first start_indices to the start_sacc list
    if start_sacc[0] > start_indices[0]:
        # Index.insert prepends one element; `[value] + Index` would instead
        # add `value` to every saccade onset (element-wise addition)
        start_sacc = start_sacc.insert(0, start_indices[0])
    # if the last start_sacc is bigger than the last start_indices --> remove it
    if start_sacc[-1] > start_indices[-1]:
        start_sacc = start_sacc[:-1]
    if len(end_indices) > len(start_indices):
        # in case there has been an end index without a start one, get rid of this
        if end_indices[0] < start_indices[0]:
            end_indices = end_indices[1:]
        else: 
            print(uid)
    # loop through all events
    for start, end, start_s in zip(start_indices, end_indices, start_sacc):
        # get a smaller df
        group_df = for_eye.loc[start:end]
        # Compute the centroid of the hit positions and of the eye positions
        gaze_centroid = compute_centroid(group_df, ['xhpoo', 'yhpoo', 'zhpoo'])
        subject_centroid = compute_centroid(group_df, ['xcoord_orig', 'ycoord_orig', 'zcoord_orig'])
        # Compute the saccade amplitude and store it in the dataframe
        if prev_gaze_centroid is not None and prev_subject_centroid is not None:
            for_eye.loc[start_s:end, 'saccade_amplitude'] = compute_saccade_amplitude(prev_gaze_centroid, prev_subject_centroid, gaze_centroid, subject_centroid)
        # Update the previous centroids
        prev_gaze_centroid = gaze_centroid
        prev_subject_centroid = subject_centroid
    # --- SACCADE AMPLITUDE ---
    # --- PEAK VELOCITY ---
    # loop through all events
    for start_s, start in zip(start_sacc, start_indices):
        # samples from the saccade onset up to (and including) the sample
        # before the gaze onset; the same label-based window is used for
        # reading the velocities and for writing the result (reading with
        # iloc dropped the last sample and left one-sample saccades empty)
        cur = for_eye.loc[start_s : start - 1]
        if cur.empty:
            # nothing between the two onsets -> leave peak_velocity as NaN
            continue
        # I expect to see RuntimeWarnings in this block (all-NaN windows)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            # compute the maximum velocity
            for_eye.loc[start_s : start - 1, 'peak_velocity'] = np.nanmax(cur["combined_vel"].tolist())
    # --- PEAK VELOCITY ---
    # only take the saccade onsets, so we plot one datapoint per event
    sacc = for_eye[for_eye["events"] == 1.0]
    sacc = sacc[~sacc["long_events"].isnull()]
    amplitude = amplitude + sacc['saccade_amplitude'].tolist()
    peak_vel = peak_vel + sacc['peak_velocity'].tolist()


    ####### Distance Distribution #######
    # saccade onsets without long events (same selection as above);
    # replace +-inf by NaN on a new frame instead of in place on a slice
    sacc = sacc.replace([np.inf, -np.inf], np.nan)
    sac.append(np.nanmedian(sacc["avg_dist"]))
    # plot saccades
    sns.kdeplot(
        sacc["avg_dist"],
        color=sacc_color_1,
        fill=False,
        clip=[0, 100],
        alpha=0.6,
        label="Saccade Distance",
        ax = ax3
    )
    cnt = cnt + 1
    # gaze onsets
    gaze = for_eye[for_eye["events"] == 2.0]
    # exclude long events
    gaze = gaze[~gaze["long_events"].isnull()]
    gze.append(np.nanmedian(gaze["avg_dist"]))
    # plot gazes
    sns.kdeplot(
        gaze["avg_dist"],
        color=gaze_color_1,
        fill=False,
        clip=[0, 100],
        alpha=0.9,
        label="Gaze Distance",
        ax = ax3
    )
    ax3.lines[cnt].set_linestyle("--") # set the line style for gazes
    cnt = cnt + 1


    ####### Gazes in Space #######
    # `gaze` already holds the gaze onsets without long events
    # plot the hit points on top of the image displayed at the beginning of this cell;
    # draw on ax4 explicitly and keep the name bound to the Axes
    ax4.scatter(
        x=gaze["xhpoo"].tolist(),
        y=gaze["zhpoo"].tolist(),
        color="k",
        marker=".",
        alpha=0.2,
        linewidth=0,
    )


# Set axis styles:

# The six direction panels receive the same styling calls and differ only in
# their axis names and tick positions, so they are described in one table:
# (axes, x-axis name, y-axis name, tick positions used for both axes)
direction_panels = [
    ####### Direction Vectors World #######
    (ax11, "X-Coordinates", "Y-Coordinates", [-1, 0, 1]),
    # Z vs Y
    (ax12, "Z-Coordinates", "Y-Coordinates", [-1, 0, 1]),
    # X vs Z
    (ax13, "X-Coordinates", "Z-Coordinates", [-2, -1, 0, 1, 2]),
    ####### Direction Vectors Local, Head, Saccade Vector #######
    # Head
    (ax21, "X-Coordinates", "Y-Coordinates", [-1, 0, 1]),
    # Local ET
    (ax22, "X-Coordinates", "Y-Coordinates", [-2, -1, 0, 1, 2]),
    # Sacc Vector
    (ax23, "X-Coordinates", "Y-Coordinates", [-2, -1, 0, 1, 2]),
]
for panel, x_name, y_name, tick_positions in direction_panels:
    panel.set_xlabel(x_name, fontsize=labelsize, fontname=fname)
    panel.set_ylabel(y_name, fontsize=labelsize, fontname=fname)
    # change tick font
    for label in panel.get_xticklabels():
        label.set_fontproperties(fname)
    for label in panel.get_yticklabels():
        label.set_fontproperties(fname)
    # change tick size
    panel.yaxis.set_tick_params(labelsize=ticksize)
    panel.xaxis.set_tick_params(labelsize=ticksize)
    # fixed ticks and a window slightly larger than the unit range
    panel.set_xticks(tick_positions)
    panel.set_yticks(tick_positions)
    panel.set_xlim(-1.1, 1.1)
    panel.set_ylim(-1.1, 1.1)

# the x ticks of the first panel are additionally set from a linspace
ax11.set_xticks(np.linspace(-1.0, 1.0, num=3))

    
####### Sacc Amp vs. Peak Vel #######
# collect both measures in one dataframe and plot them against each other
stats = pd.DataFrame({"saccade_amplitude": amplitude, "peak_velocity": peak_vel})
sns.scatterplot(
    data=stats,
    x="saccade_amplitude",
    y="peak_velocity",
    ax=ax5,
    marker="x",
    color=sacc_color_1,
    alpha=0.3,
)
ax5.set_xlabel("Saccade Amplitude (deg)", fontsize=labelsize)
ax5.set_ylabel("Peak Velocity (deg/s)", fontsize=labelsize)
# change tick font
for label in ax5.get_xticklabels():
    label.set_fontproperties(fname)
for label in ax5.get_yticklabels():
    label.set_fontproperties(fname)
# change tick size
ax5.xaxis.set_tick_params(labelsize=ticksize)
ax5.yaxis.set_tick_params(labelsize=ticksize)
# fixed plotting window and x ticks
ax5.set_xlim([-1, 43])
ax5.set_ylim([-30, 900])
ax5.set_xticks([0, 10, 20, 30, 40])
    
####### Distance Distribution #######
# plot the legend: every subject adds one saccade and one gaze line, so only
# the first handle (saccade) and the last handle (gaze) are shown
handles, labels = ax3.get_legend_handles_labels()
leg = ax3.legend(
    [handles[0], handles[-1]],
    [labels[0], labels[-1]],
    loc="upper right",
    fontsize=legendsize,
)
# set alpha of legend symbols to 1
# `Legend.legendHandles` was renamed to `legend_handles` in Matplotlib 3.7 and
# the old name was removed later; use the new name when it exists and fall
# back to the old one on older installations
legend_symbols = getattr(leg, "legend_handles", None)
if legend_symbols is None:
    legend_symbols = leg.legendHandles
for lh in legend_symbols:
    lh.set_alpha(1)
ax3.set_xlabel("Distance (in Unity Units)", fontsize=labelsize, fontname=fname)
ax3.set_ylabel("Density of Distances", fontsize=labelsize, fontname=fname)
for label in ax3.get_xticklabels(): # change tick font
    label.set_fontproperties(fname)
for label in ax3.get_yticklabels():
    label.set_fontproperties(fname)
ax3.yaxis.set_tick_params(labelsize=ticksize)  # change tick size
ax3.xaxis.set_tick_params(labelsize=ticksize)
ax3.set_yticks([0.01,0.03,0.05]) 
ax3.locator_params(nbins=6, axis='y')

# panel letters (the letter of the map panel is set where the map is drawn)
panel_letters = [(ax11, "A"), (ax21, "B"), (ax22, "C"), (ax23, "D"), (ax5, "E"), (ax3, "F")]
for lettered_ax, letter in panel_letters:
    lettered_ax.set_title(
        letter,
        fontsize=numbersize,
        fontweight="bold",
        loc="left",
        x=-0.1,
        y=1.05,
        pad=-30,
        fontname=fname,
    )

plt.show()

##### Stats Saccade Amplitude vs. Peak Vel.
# Pearson correlation between saccade amplitude and peak velocity;
# rows containing NaN in either measure are dropped first
stats = stats.dropna()
r, p = scipy.stats.pearsonr(stats["saccade_amplitude"], stats["peak_velocity"])
n_pairs = len(stats)
print(f"statistic: {r}")
print(f"pvalue: {round(p,3)}")
print(f"degress of freedom : {n_pairs - 2}")

In [None]:
##### Largest absolute local eye-in-head direction component per subject (x and y)
ids = recordings.index.tolist()
idd = list(ids)

max_x = []  # one maximum of |xlocal_dir| per subject
max_y = []  # one maximum of |ylocal_dir| per subject

for i, uid in enumerate(idd):
    # load data
    for_eye = pd.read_csv(f"{PATH_FOREYE}/correTS_mad_wobig_{uid}.csv", index_col=0)
    # gaze samples (non-null "isFix") that are not long events
    gaze = for_eye[for_eye["isFix"].notna() & for_eye["long_events"].notna()]

    # the sign is irrelevant, only the magnitude counts
    abs_x = gaze["xlocal_dir"].abs().tolist()
    abs_y = gaze["ylocal_dir"].abs().tolist()
    # store the maximum of this subject
    max_x.append(np.nanmax(abs_x))
    max_y.append(np.nanmax(abs_y))

columns = ['max_x', 'max_y']
df = pd.DataFrame({columns[0]: max_x, columns[1]: max_y}, index=idd)
# give an overview over the data
display(df.describe())
print()
# paired t-test between the maxima of the two dimensions
scipy.stats.ttest_rel(max_x, max_y)

# Data Distribution 

In [None]:
''' Calculate and compare the number of events for gazes and saccades.
This is done for both, the 10-second and data-driven method.

Per subject and method the number of gaze onsets (event code 2.0) and
saccade onsets (event code 1.0) is counted, long events excluded. The
medians and interquartile ranges are printed and the two methods are
compared with two-sample KS tests.
'''
ids = recordings.index.tolist()
idd = ids[:]

stats = {} # per subject: number of gaze / saccade events for each method

for uid in idd:
    stats[uid] = {}
    for fe in range(2):
        # load data
        # 10 sec
        if fe == 0:
            for_eye = pd.read_csv(
                f"{PATH_FOREYE}/correTS__10sec_{uid}.csv", index_col=0
            )
            condition = "10"
        # data-driven
        else:
            for_eye = pd.read_csv(
                f"{PATH_FOREYE}/correTS_mad_wobig_{uid}.csv", index_col=0
            )
            condition = "dd"

        # get gazes, but not outliers
        gaze = for_eye[for_eye["events"] == 2.0]
        gaze = gaze[~gaze["long_events"].isnull()]

        # get saccades but not outliers
        sacc = for_eye[for_eye["events"] == 1.0]
        sacc = sacc[~sacc["long_events"].isnull()]

        # save the number of events
        stats[uid][condition + "_gaze"] = gaze.index.size
        stats[uid][condition + "_sacc"] = sacc.index.size

# one row per subject, one column per "<method>_<event type>"
stats = pd.DataFrame(stats).transpose()
print()
print("Number of Events:")
print()

# print the median number of events
print(f"Median Gaze DD: {np.nanmedian(stats['dd_gaze'])}")
print(f"Median Gaze 10: {np.nanmedian(stats['10_gaze'])}")
print(f"Median Sacc DD: {np.nanmedian(stats['dd_sacc'])}")
print(f"Median Sacc 10: {np.nanmedian(stats['10_sacc'])}")
print()

# Interquartile range:
# computed in one loop so that every printed IQR uses the quartiles of its own
# column (the 10-second gaze IQR previously reused the upper quartile of the
# data-driven saccades because of a mistyped variable name)
iqr_by_column = {}
for iqr_label, column in (
    ("gaze DD", "dd_gaze"),
    ("sacc DD", "dd_sacc"),
    ("gaze 10", "10_gaze"),
    ("sacc 10", "10_sacc"),
):
    q75, q25 = np.nanpercentile(stats[column].tolist(), [75, 25])
    iqr_by_column[column] = q75 - q25
    print(f"IQR {iqr_label}: {iqr_by_column[column]} ({q25}-{q75})")
iqr_gaze_dd = iqr_by_column["dd_gaze"]
iqr_sacc_dd = iqr_by_column["dd_sacc"]
iqr_gaze_10 = iqr_by_column["10_gaze"]
iqr_sacc_10 = iqr_by_column["10_sacc"]
print()

# KS-tests
print()
print('KS-Test')
alpha = 0.05
# NOTE(review): alpha is divided by 4 although two tests are run in this
# cell - presumably further tests elsewhere are counted; confirm.
adjusted_alpha = alpha / 4  # multiple tests were performed
print(f"Adjusted alpha: {adjusted_alpha:.4f}")
print()

statistic, p1 = ks_2samp(stats['10_gaze'].tolist(), stats['dd_gaze'].tolist())
print(f"p-value gaze: {p1:.4f}")
statistic, p2 = ks_2samp(stats['10_sacc'].tolist(), stats['dd_sacc'].tolist())
print(f"p-value sacc: {p2:.4f}")


In [None]:
'''
Plot the percentage of samples classified as gazes, saccades, long events
(outliers) and invalid data.
One datapoint is one participant.
The data is shown for both data segmentation intervals
(left panel: 10 second files, right panel: "mad_wobig" files).
'''

ids = recordings.index.tolist()
# Use every participant: the plot promises one datapoint per participant and
# the statistics cell for the same data also uses all of them. The previous
# `ids[:2]` restricted the figure to the first two recordings.
idd = ids[:]

# set up figure
sns.set(rc={"figure.figsize": (30, 15)})
sns.set_style("white")
f, (ax) = plt.subplots(1, 2, sharey=True, constrained_layout=True)

# define
labelsize = 40  # text
legendsize = 40  # legend
ticksize = 30  # ticks
numbersize = 60  # A, B etc.
fname = "Arial"  # font name
plt.rcParams["font.family"] = fname  # set font name

# colour per category; identical for both panels, so it is defined once
pallet = {
    "Gaze": gaze_color_1,
    "Sacc": sacc_color_1,
    "Out Gaze": gaze_color_2,
    "Out Sacc": sacc_color_2,
    "Invalid": colliders_color,
}
# categories that are drawn ("Total" is only kept as a sanity check)
plot_columns = ["Gaze", "Sacc", "Out Gaze", "Out Sacc", "Invalid"]

# go through both data segmentation intervals
for fe in range(2):
    # lists to save the percentages of every subject
    nr_gaze = []
    nr_sacc = []
    nr_gz_out = []
    nr_sc_out = []
    non_val = []
    total_nr = []
    # loop through all subjects (once per segmentation method)
    for uid in idd:
        # load data
        if fe == 0:  # 10 sec
            for_eye = pd.read_csv(
                f"{PATH_FOREYE}/correTS__10sec_{uid}.csv", index_col=0
            )
        else:  # data-driven
            for_eye = pd.read_csv(
                f"{PATH_FOREYE}/correTS_mad_wobig_{uid}.csv", index_col=0
            )

        # total number of samples (rows) of this recording
        total = for_eye.index.size

        # split off the invalid samples and continue with the valid ones only
        outbig = for_eye[for_eye["valid"] == 0.0]
        for_eye = for_eye[for_eye["valid"] == 1.0]

        # valid samples: a non-null "isFix" marks a gaze, a null one a saccade
        gaze = for_eye[~for_eye["isFix"].isnull()]
        sacc = for_eye[for_eye["isFix"].isnull()]

        # a null "long_events" entry marks the sample as part of an outlier event
        gaze_out = gaze[gaze["long_events"].isnull()]
        sacc_out = sacc[sacc["long_events"].isnull()]
        gaze = gaze[~gaze["long_events"].isnull()]
        sacc = sacc[~sacc["long_events"].isnull()]

        # percentage of all samples that falls into each category
        pct_gaze = gaze.index.size * 100 / total
        pct_sacc = sacc.index.size * 100 / total
        pct_gaze_out = gaze_out.index.size * 100 / total
        pct_sacc_out = sacc_out.index.size * 100 / total
        pct_invalid = outbig.index.size * 100 / total

        nr_gaze.append(pct_gaze)
        nr_sacc.append(pct_sacc)
        nr_gz_out.append(pct_gaze_out)
        nr_sc_out.append(pct_sacc_out)
        non_val.append(pct_invalid)
        # sanity check: the five categories should add up to 100 %
        total_nr.append(
            pct_gaze + pct_sacc + pct_gaze_out + pct_sacc_out + pct_invalid
        )

    # add all numbers to a df (one row per participant)
    stats = pd.DataFrame(
        {
            "Gaze": nr_gaze,
            "Sacc": nr_sacc,
            "Out Gaze": nr_gz_out,
            "Out Sacc": nr_sc_out,
            "Invalid": non_val,
            "Total": total_nr,
        }
    )
    # display(stats) # enable this, when wanting to print the individual values

    # plot the results: one subplot per segmentation method.
    # The target axes is passed explicitly instead of relying on the
    # pyplot "current axes" state.
    sns.violinplot(
        data=stats[plot_columns],
        palette=pallet,
        orient="v",
        inner="box",
        alpha=0.2,
        ax=ax[fe],
    )
    # plot the individual datapoints on top of the violinplot
    sns.swarmplot(
        data=stats[plot_columns],
        color="black",
        marker="o",
        size=6,
        ax=ax[fe],
    )

    # add a y-axis label only for the left plot
    if fe == 0:
        ax[fe].set_ylabel("Distribution in %", fontsize=labelsize, fontname=fname)

    # set axis ticks
    for label in ax[fe].get_xticklabels():  # change tick font
        label.set_fontproperties(fname)
    for label in ax[fe].get_yticklabels():
        label.set_fontproperties(fname)
    ax[fe].xaxis.set_tick_params(labelsize=labelsize)  # change tick size
    ax[fe].yaxis.set_tick_params(labelsize=ticksize)  # change tick size

# panel letters
ax[0].set_title("A", fontsize=numbersize, fontweight="bold",loc="left", x=-0.07, y=1.05, pad=-30, fontname=fname)
ax[1].set_title("B", fontsize=numbersize, fontweight="bold",loc="left", x=-0.07, y=1.05, pad=-30, fontname=fname)

plt.show()

In [None]:
'''
Statistical counterpart of the percentage plot:
for gazes, saccades, long events (outliers) and invalid data the per-participant
percentages are summarised (median, IQR) for both segmentation intervals and the
two intervals are compared with KS tests. Nothing is plotted here.
'''

print("% of events and outliers")
print()

ids = recordings.index.tolist()
idd = ids[:]

# do this for each segmentation algorithm (0: 10 second files, 1: data-driven files)
for fe in range(2):
    condition = "10 second intervals" if fe == 0 else "Data driven intervals"

    # per-participant percentages, one list per category
    nr_gaze, nr_sacc, nr_gz_out, nr_sc_out = [], [], [], []
    non_val, total_nr, nr_out = [], [], []

    # loop through all subjects (once per segmentation method)
    for uid in idd:
        # pick the file that belongs to the current segmentation method
        csv_name = (
            f"correTS__10sec_{uid}.csv" if fe == 0 else f"correTS_mad_wobig_{uid}.csv"
        )
        for_eye = pd.read_csv(f"{PATH_FOREYE}/{csv_name}", index_col=0)

        # number of samples of the whole recording
        total = for_eye.index.size

        # invalid samples are set aside, the rest is analysed further
        outbig = for_eye[for_eye["valid"] == 0.0]
        for_eye = for_eye[for_eye["valid"] == 1.0]

        # "isFix" filled -> gaze, "isFix" empty -> saccade
        is_gaze = for_eye["isFix"].notnull()
        gaze = for_eye[is_gaze]
        sacc = for_eye[~is_gaze]

        # "long_events" empty -> outlier event, filled -> regular event
        gaze_out = gaze[gaze["long_events"].isnull()]
        sacc_out = sacc[sacc["long_events"].isnull()]
        gaze = gaze[gaze["long_events"].notnull()]
        sacc = sacc[sacc["long_events"].notnull()]

        # share of each category in percent of all samples
        pct_gaze = gaze.index.size * 100 / total
        pct_sacc = sacc.index.size * 100 / total
        pct_gaze_out = gaze_out.index.size * 100 / total
        pct_sacc_out = sacc_out.index.size * 100 / total
        pct_invalid = outbig.index.size * 100 / total

        nr_gaze.append(pct_gaze)
        nr_sacc.append(pct_sacc)
        nr_gz_out.append(pct_gaze_out)
        nr_sc_out.append(pct_sacc_out)
        # everything that is not a regular event
        nr_out.append(pct_invalid + pct_gaze_out + pct_sacc_out)
        non_val.append(pct_invalid)
        # sanity check that the numbers add up
        total_nr.append(
            pct_gaze + pct_sacc + pct_gaze_out + pct_sacc_out + pct_invalid
        )

    # mean + std across all subjects, one row per category
    per_category = {
        "nr_gaze": nr_gaze,
        "nr_sacc": nr_sacc,
        "nr_gz_out": nr_gz_out,
        "nr_sc_out": nr_sc_out,
        "nr_out": nr_out,
        "non_val": non_val,
    }
    mean_std = pd.DataFrame(
        {
            "mean": {key: np.mean(vals) for key, vals in per_category.items()},
            "std": {key: np.std(vals) for key, vals in per_category.items()},
        }
    )

    # display the results of median and IQR
    print(f"% events and outliers - {condition}")
    print()
    for name, values in (
        ("Gaze", nr_gaze),
        ("Sacc", nr_sacc),
        ("Gaze_out", nr_gz_out),
        ("Sacc_out", nr_sc_out),
    ):
        print(f"Median % {name}: {np.nanmedian(values)}")
        q75, q25 = np.nanpercentile(values, [75, 25])
        iqr = q75 - q25
        print(f"IQR % {name}: {iqr} ({q25}-{q75})")
    print(f"Median % Total Outliers: {np.nanmedian(nr_out)}")
    print()

    # collect the per-participant values of this segmentation method in a df
    segment_stats = pd.DataFrame(
        {
            "Gaze": nr_gaze,
            "Sacc": nr_sacc,
            "Out Gaze": nr_gz_out,
            "Out Sacc": nr_sc_out,
            "Invalid": non_val,
            "Total": total_nr,
            "total_out": nr_out,
        }
    )
    if fe == 0:
        stats_10 = segment_stats
    else:
        stats_dd = segment_stats

# Invalid: median and IQR, taken from the lists of the last loop pass
# (i.e. the data-driven files)
print(f"Median % Invalid: {np.nanmedian(non_val)}")
q75, q25 = np.nanpercentile(non_val, [75, 25])
iqr = q75 - q25
print(f"IQR % Invalid: {iqr} ({q25}-{q75})")
print()

alpha = 0.05
adjusted_alpha = alpha / 4  # Four tests were performed
print(f"Adjusted alpha: {adjusted_alpha:.4f}")
print()

# KS test between the two segmentation intervals for the four event categories
not_tested = ['Invalid', 'Total', 'total_out']
for column in stats_dd:
    if column in not_tested:
        continue
    statistic, pvalue = ks_2samp(stats_10[column], stats_dd[column])
    # flag results that stay significant at the corrected alpha level
    note = (
        " - this test is significant after Bonferroni correction"
        if pvalue <= adjusted_alpha
        else ""
    )
    print(f"Adjusted p-value {column}: {pvalue:.4f}{note}")


# Event Durations

In [None]:
'''
Plot the durations of gazes and saccades (data-driven segmentation, long
events excluded): one density curve per participant and event type.
Additionally, a KS test is calculated on the per-participant median durations
to see if gazes and saccades differ significantly.
'''

# plot the data without outliers across all subjects
ids = recordings.index.tolist()
idd = ids[:]


# prepare the plot
sns.set(rc={"figure.figsize": (30,20)})
sns.set_style("white")
f, (axis) = plt.subplots(1, 1, sharey=True,)

# define:
labelsize = 40  # text
legendsize = 40  # legend
ticksize = 30  # ticks
numbersize = 60  # A, B etc.
fname = "Arial"  # font name
plt.rcParams["font.family"] = fname

# per-participant median durations, used for the statistical tests
gze = []
sac = []
# per-participant duration columns, kept so that every file is read only once
sacc_lengths = []
gaze_lengths = []

for uid in idd:
    # load the data
    for_eye = pd.read_csv(
        f"{PATH_FOREYE}/correTS_mad_wobig_{uid}.csv", index_col=0
    )

    # separate between saccade (events == 1) and gaze (events == 2)
    sacc = for_eye[for_eye["events"] == 1.0]
    gaze = for_eye[for_eye["events"] == 2.0]
    # exclude long events (rows without a "long_events" entry)
    sacc = sacc[~sacc["long_events"].isnull()]
    gaze = gaze[~gaze["long_events"].isnull()]

    # add the medians for statistics
    sac.append(np.nanmedian(sacc["length"]))
    gze.append(np.nanmedian(gaze["length"]))
    sacc_lengths.append(sacc["length"])
    gaze_lengths.append(gaze["length"])

# all saccade curves are drawn first and all gaze curves afterwards, so the
# gazes end up on top and the legend can pick the first / last handle
for lengths in sacc_lengths:
    sns.kdeplot(
        lengths,
        color=sacc_color_1,
        fill=False,
        alpha=0.2,
        clip=[0, 1],
        linewidth=3.5,
        label="Average Saccade Durations",
        ax=axis,
    )

for lengths in gaze_lengths:
    sns.kdeplot(
        lengths,
        color=gaze_color_1,
        fill=False,
        alpha=0.5,
        cut = 0,
        clip=[0, 1],
        linewidth=3.5,
        label="Average Gaze Durations",
        ax=axis,
    )

# calculate and display the statistics
print("data-driven interval")
# as they are not completely normally distributed, we will use median to get the average results:
med_gaze = np.nanmedian(gze)
med_sacc = np.nanmedian(sac)
print(f"Median gaze duration: {med_gaze}")
print(f"Median sacc duration: {med_sacc}")
# Interquartile range:
q75_g, q25_g = np.nanpercentile(gze, [75, 25])
iqr_gaze = q75_g - q25_g
q75_s, q25_s = np.nanpercentile(sac, [75, 25])
iqr_sacc = q75_s - q25_s
print(f"IQR gaze duration: {iqr_gaze} ({q25_g}-{q75_g})")
print(f"IQR sacc duration: {iqr_sacc} ({q25_s}-{q75_s})")

# Perform the KS test
print()
statistic, pvalue = ks_2samp(gze, sac)
print(f"KS statistic: {statistic:.4f}")
print(f"P-value: {pvalue:.4f}")
alpha = 0.05
print(f"Alpha: {alpha:.4f}")
# Now also get the mean to compare to other studies.
# nanstd (instead of std) keeps the spread NaN-aware like the nanmean next to
# it; otherwise a single NaN median would turn the std into NaN.
mean_gaze = np.nanmean(gze)
mean_sacc = np.nanmean(sac)
std_gaze = np.nanstd(gze)
std_sacc = np.nanstd(sac)
print(f"Mean gaze duration: {mean_gaze}; std: {std_gaze}")
print(f"Mean sacc duration: {mean_sacc}; std: {std_sacc}")
print()

# back to the plot:
# set legend: first handle is a saccade curve, last handle is a gaze curve
handles, labels = axis.get_legend_handles_labels()

leg = axis.legend([handles[0], handles[-1]],
                [labels[0], labels[-1]],
                loc="upper right", fontsize=legendsize)
# `Legend.legendHandles` was renamed to `legend_handles` in newer matplotlib
# releases and the old name was removed later; support both.
legend_handles = getattr(leg, "legend_handles", None)
if legend_handles is None:
    legend_handles = leg.legendHandles
# legend entries should be fully opaque although the curves are transparent
for lh in legend_handles:
    lh.set_alpha(1)

# set axis labels and ticks
axis.set_xlabel("Duration (sec)", fontsize=labelsize, fontname=fname)
axis.set_ylabel("Density of Durations", fontsize=labelsize, fontname=fname)
for label in axis.get_xticklabels():  # change tick font
    label.set_fontproperties(fname)
for label in axis.get_yticklabels():
    label.set_fontproperties(fname)
axis.yaxis.set_tick_params(labelsize=ticksize)  # change tick size
axis.xaxis.set_tick_params(labelsize=ticksize)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

# Velocity distribution in relation to gaze onset

In [None]:
'''
Create a plot for the velocity distributions around gaze onset:
A. velocity of a subset of trials of one subject (heatmap)
B. median velocity per subject and across subjects
C. density of the position of the peak velocity
'''
ids = recordings.index.tolist()
# Use every participant: panels B and C summarise "across participants" and
# the very low alpha of the per-subject curves only makes sense with many
# overlapping subjects. The previous `ids[:2]` limited the figure to two
# recordings.
idd = ids[:]

# set up the frames before and after gaze onset
start_fr = 20  # how many frames before gaze onset
end_fr = 35  # how many frames to include after gaze onset

data = {}  # median velocity trace per subject
trials_all = []  # peak positions of all subjects
start_fr_peak = 10  # how many frames before gaze onset (peak analysis)

# set up figure
sns.set(rc={"figure.figsize": (30, 30)})
sns.set_style("white")
f, (ax) = plt.subplots(3,1, constrained_layout=True)

# define:
labelsize = 40  # text
legendsize = 40  # legend
ticksize = 30  # ticks
numbersize = 60  # A, B etc.
fname = "Arial"  # font name
plt.rcParams["font.family"] = fname

# loop through all subjects
# NOTE(review): the gaze index labels are used as row positions for `.iloc`
# below, which assumes the CSV index is a 0..n-1 range - confirm.
for i, uid in enumerate(idd):
    # load data
    for_eye = pd.read_csv(
        f"{PATH_FOREYE}/correTS_mad_wobig_{uid}.csv", index_col=0
    )

    ####### Velocity over trials #######
    # we only want to show the individual subject for the first one
    if i == 0:
        # Work on a real copy: the velocities are capped at 200 for this plot
        # only. The previous plain assignment aliased `for_eye`, so the cap
        # leaked into the median and peak analyses of this subject below.
        for_eye_new = for_eye.copy()
        # replace all velocities over 200 with 200
        combined_vel = for_eye["combined_vel"].tolist()
        c_v = [200 if cv > 200 else cv for cv in combined_vel]
        for_eye_new["combined_vel"] = c_v  # add velocities back to new df
        # get gazes
        gaze = for_eye_new[for_eye_new["events"] == 2.0]
        gaze = gaze[~gaze["long_events"].isnull()]
        # get a subset of trials
        gaze = gaze.iloc[50:250]
        # now go through the trials
        trials = {}
        for g, gz in enumerate(gaze.index.tolist()):
            # get the (capped) data segment around each gaze onset
            cur = for_eye_new.iloc[gz - start_fr : gz + end_fr]
            # add them to a trials df
            trials[g] = cur["combined_vel"].tolist()

        trials = pd.DataFrame(trials).transpose()
        trials = trials.to_numpy()
        # take out all nan samples and replace them with 0 (for plotting)
        trials[np.isnan(trials)] = 0
        # plot the results
        sns.heatmap(trials, cmap="YlGnBu", xticklabels=10, yticklabels=50, ax=ax[0])
        # set a line at gaze onset
        ax[0].axvline(x=start_fr, linewidth=3, color=sacc_color_1)
        # set axis labels + font
        ax[0].set_xlabel("Time (ms)", fontsize=labelsize, fontname=fname)
        ax[0].set_ylabel("Trials", fontsize=labelsize, fontname=fname)
        # set axis ticks + font
        for label in ax[0].get_xticklabels():  # change tick font
            label.set_fontproperties(fname)
        for label in ax[0].get_yticklabels():
            label.set_fontproperties(fname)
        ax[0].yaxis.set_tick_params(labelsize=ticksize)  # change tick size
        ax[0].xaxis.set_tick_params(labelsize=ticksize)
        # set color bar
        cbar = ax[0].collections[0].colorbar
        cbar.ax.tick_params(labelsize=ticksize)
        cbar.set_label("Angular Velocity", fontsize=labelsize, fontname=fname)
        # set x tick labels (relabels every 10th frame as a time)
        ax[0].set_xticklabels([-220,-110,0,110,220,330],)

        # adjust the frames after gaze onset, so that the onset lines in the next plots are aligned
        end_fr = end_fr + 3

    ####### Velocity over subjects #######
    # get gazes
    gaze = for_eye[for_eye["events"] == 2.0]
    gaze = gaze[~gaze["long_events"].isnull()]
    # now go through the trials
    trials = {}
    for g, gz in enumerate(gaze.index.tolist()):
        # only use trials that have datapoints for all frames needed
        if (gz - start_fr) > 0 and (gz + end_fr) < len(for_eye):
            # get the current segment
            cur = for_eye.iloc[gz - start_fr : gz + end_fr]
            # add it to trial df
            trials[g] = cur["combined_vel"].tolist()
    trials = pd.DataFrame(trials).transpose()
    # add the median for this subject to data (will be used to plot median across subjects)
    # NOTE(review): the key is the first 5 characters of the id - assumes
    # these are unique per participant.
    data[uid[:5]] = trials.median(skipna=True)
    # plot this result
    ax[1].plot(list(range(-start_fr,end_fr)), data[uid[:5]], color=gaze_color_1,
                linestyle='solid',linewidth=1.5, alpha=0.4, label="Average Velocity per Participant")

    ####### Peak Velocities #######
    # get gazes
    gaze = for_eye[for_eye["events"] == 2.0]
    gaze = gaze[~gaze["long_events"].isnull()]
    trials = []
    # loop through all gazes
    for g, gz in enumerate(gaze.index.tolist()):
        # only use trials that have datapoints for all frames needed.
        # The upper bound is the length of the recording; it used to be the
        # number of gazes, which wrongly discarded most trials.
        if gz - start_fr_peak > 0 and gz + 3 < len(for_eye):
            # get data segment
            cur = for_eye.iloc[gz - start_fr_peak : gz + 3]
            segment_vel = cur["combined_vel"].tolist()
            # a window without any valid velocity has no peak
            if np.all(np.isnan(segment_vel)):
                continue
            # position (within the window) of the first occurrence of the peak velocity
            peak_pos = segment_vel.index(np.nanmax(segment_vel))
            trials.append(peak_pos)  # this subject
            trials_all.append(peak_pos)  # all subjects
    # plot the peak vels for each subject
    sns.kdeplot(
        trials,
        bw_adjust=0.5, #0.4,
        color=gaze_color_1,
        alpha=0.06,
        fill=True,
        ax = ax[2]
    )


####### Velocity over subjects #######
data = pd.DataFrame(data).transpose()
# get median across subjects
finished_data = data.median(skipna=True)
# plot the result
ax[1].plot(list(range(-start_fr,end_fr)), finished_data, color='k',
    linestyle='solid', marker='o',linewidth=4, label="Average Velocity across Participants")
# legend: last per-participant line (replaced by an opaque proxy) + across-participants line
handles, labels = ax[1].get_legend_handles_labels()
line2d_obj = Line2D([0], [0], linewidth=3, linestyle='solid', alpha = 1.0)
handles[-2] = line2d_obj
ax[1].legend(handles[-2:],labels[-2:],loc="upper right", fontsize=legendsize)
# plot a line at gaze onset
ax[1].axvline(x=0, linewidth=3, color=sacc_color_1)
# axis labels
ax[1].set_xlabel("Time (ms)", fontsize=labelsize, fontname=fname)
ax[1].set_ylabel("Average Velocity", fontsize=labelsize, fontname=fname)
# ticks
for label in ax[1].get_xticklabels():  # change tick font
    label.set_fontproperties(fname)
for label in ax[1].get_yticklabels():
    label.set_fontproperties(fname)
ax[1].yaxis.set_tick_params(labelsize = ticksize)  # change tick size
ax[1].xaxis.set_tick_params(labelsize = ticksize)
# NOTE(review): the labels are assigned to the automatically chosen tick
# positions, so they are only correct while those ticks stay where they are.
ax[1].set_xticklabels([-330,-220,-110,0,110,220,330],)

####### Peak Velocities #######
# plot the peak velocities across subjects
sns.kdeplot(
    trials_all,
    bw_adjust=0.5,
    color="k",
    linewidth = 3,
    ax = ax[2]
)
# plot a line at gaze onset
ax[2].axvline(x=start_fr_peak, linewidth=3, color=sacc_color_1)
# axis labels
ax[2].set_xlabel("Time (ms)", fontsize=labelsize, fontname=fname)
ax[2].set_ylabel("Density of Peak Velocities", fontsize=labelsize, fontname=fname)
for label in ax[2].get_xticklabels():  # change tick font
    label.set_fontproperties(fname)
for label in ax[2].get_yticklabels():
    label.set_fontproperties(fname)
# ticks: one tick per frame, labelled in steps of 11
ax[2].yaxis.set_tick_params(labelsize=ticksize)  # change tick size
ax[2].xaxis.set_tick_params(labelsize=ticksize)
ax[2].set_xticks(list(np.arange(-3, 15, 1)),)
ax[2].set_xticklabels(list(np.arange(-143, 55, 11)),)
xticks = ax[2].xaxis.get_major_ticks()
# hide every tick label except those at tick indices 3, 8 and 13
ticks = [0,1,2,4,5,6,7,9,10,11,12,14,15,16,17]
for x in ticks:
    xticks[x].label1.set_visible(False)
ax[2].locator_params(nbins=5, axis='y')

# panel letters
ax[0].set_title("A", fontsize=numbersize, fontweight="bold",loc="left", x=-0.08, y=1.05, pad=-30, fontname=fname)
ax[1].set_title("B", fontsize=numbersize, fontweight="bold",loc="left", x=-0.08, y=1.05, pad=-30, fontname=fname)
ax[2].set_title("C", fontsize=numbersize, fontweight="bold",loc="left", x=-0.08, y=1.05, pad=-30, fontname=fname)

plt.show()

# Dispersion distribution in relation to gaze onset

In [None]:
# Dispersion around gaze onset: for every participant the median sample-to-sample
# distance of the (xhpoo, yhpoo, zhpoo) point is drawn over a window around gaze
# onset, plus the median across participants.
ids = recordings.index.tolist()
idd = ids[:]

start_fr = 20  # how many frames before gaze onset
end_fr = 38  # how many frames to include after gaze onset

# save the data
data_driven = {}


# figure set-up
sns.set(rc={"figure.figsize": (30, 7)})
sns.set_style("white")
f, (ax) = plt.subplots(1)

# sizes and font
labelsize = 40  # text
legendsize = 40  # legend
ticksize = 30  # ticks
numbersize = 60  # A, B etc.
fname = "Arial"  # font name
plt.rcParams["font.family"] = fname

# frame offsets relative to gaze onset, shared by all curves
frame_offsets = list(range(-start_fr, end_fr))

data = {}  # median trace per participant
for i, uid in enumerate(idd):
    # load data
    for_eye = pd.read_csv(
        f"{PATH_FOREYE}/correTS_mad_wobig_{uid}.csv", index_col=0
    )

    # Euclidean distance between consecutive datapoints
    dx = for_eye['xhpoo'].diff()
    dy = for_eye['yhpoo'].diff()
    dz = for_eye['zhpoo'].diff()
    distances = np.sqrt(np.square(dx) + np.square(dy) + np.square(dz))
    for_eye["distance_hpoo"] = distances

    # gazes (events == 2) that are not long events
    gaze = for_eye[for_eye["events"] == 2.0]
    gaze = gaze[gaze["long_events"].notnull()]

    # one window per gaze onset, skipping onsets too close to either end
    n_samples = len(for_eye)
    trials = {}
    for g, gz in enumerate(gaze.index.tolist()):
        window_start = gz - start_fr
        window_stop = gz + end_fr
        if not (window_start > 0 and window_stop < n_samples):
            continue
        cur = for_eye.iloc[window_start:window_stop]
        trials[g] = cur["distance_hpoo"].tolist()

    # rows = trials, columns = frames
    trials = pd.DataFrame(trials).transpose()
    # median over trials for this subject
    subject_key = uid[:5]
    data[subject_key] = trials.median(skipna=True)
    # draw the curve of this subject
    ax.plot(
        frame_offsets,
        data[subject_key],
        color=gaze_color_1,
        linestyle="solid",
        linewidth=1,
        alpha=0.4,
        label="Average Dispersion per Participant"
    )

# rows = subjects, columns = frames
data = pd.DataFrame(data).transpose()
# median across subjects
finished_data = data.median(skipna=True)
# draw the curve across subjects
ax.plot(
    frame_offsets,
    finished_data,
    color="k",
    linestyle="solid",
    marker="o",
    linewidth=4,
    label="Average Dispersion across Participants"
)

# legend: the last two entries, the per-participant one shown with an opaque proxy line
handles, labels = ax.get_legend_handles_labels()
line2d_obj = Line2D([0], [0], linewidth=3, linestyle='solid', alpha = 1.0)
handles[-2] = line2d_obj
ax.legend(handles[-2:],labels[-2:],loc="upper right", fontsize=legendsize)
# vertical line at gaze onset
ax.axvline(x=0, linewidth=3, color=sacc_color_1)
# axis labels
ax.set_xlabel("Time (ms)", fontsize=labelsize, fontname=fname)
ax.set_ylabel("Change in Dispersion", fontsize=labelsize, fontname=fname)
# tick font
for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
    tick_label.set_fontproperties(fname)
# tick size and tick text
ax.yaxis.set_tick_params(labelsize=ticksize)
ax.xaxis.set_tick_params(labelsize=ticksize)
ax.set_xticklabels([-330,-220,-110,0,110,220,330],)
plt.show()

# EEG results

## ERP

In [None]:
'''
Compute the number of trials going into the average ERPs and the number of
trials that were rejected.
The trial counts were calculated in figures_ET_EEG_paper.m
'''

ids = recordings.index.tolist()
idd = ids[:]

# trial counts exported from Matlab; column 0 is the count that is summarised
# below, and row `nr` is matched to the nr-th id
# NOTE(review): assumes the rows are in the same order as `recordings.index` - confirm.
mat = scipy.io.loadmat(f"{PATH_EEG}/nr_trials_erp.mat")
mat = mat["nr_events"]
mat = pd.DataFrame(mat)


print(f"Median Amount of Trials for one fERP: {np.nanmedian(mat[0])}")
# Interquartile range:
q75, q25 = np.nanpercentile(mat[0].tolist(), [75, 25])
iqr_interp = q75 - q25
# label fixed: this used to print "IQR gaze DD" (copied from another cell)
print(f"IQR Amount of Trials for one fERP: {iqr_interp} ({q25}-{q75})")


# rejected trials = original number of trials (rows of the trigger file)
# minus the number of trials that are left
tr = []
for nr, uid in enumerate(idd):
    trigger_file = pd.read_csv(f"{PATH_TRG}/TriggerFile_newTSdd_{uid}.csv", index_col=0)
    # take the scalar from column 0 (the column summarised above) instead of
    # the whole row, so `tr` is a flat list of numbers
    tr.append(len(trigger_file) - mat.iloc[nr, 0])


print(f"Median Amount of Trials rejected: {np.nanmedian(tr)}")
# Interquartile range:
q75, q25 = np.nanpercentile(tr, [75, 25])
iqr_interp = q75 - q25
print(f"IQR: {iqr_interp} ({q25}-{q75})")

In [None]:
'''
Plot ERPs:
A. individual ERPs and their average (axes label: Oz)
B. topoplot image of the ERPs
C. topoplot images for four time windows + colorbar
All images and the ERP data were created with Matlab and are only loaded here.
'''
# set up figure
# rc used to move the axis labels into the middle of the plot
rc = {"xtick.direction" : "inout", "ytick.direction" : "inout",
      "xtick.major.size" : 30, "ytick.major.size" : 30,
      "xtick.major.width" : 2, "ytick.major.width" : 2,
     "figure.figsize": (30, 32)}
with plt.rc_context(rc):
    plt.figure(constrained_layout=True)

    # define
    labelsize = 40  # text
    legendsize = 40  # legend
    ticksize = 30  # ticks
    numbersize = 60  # A, B etc.
    fname = "Arial"  # font name
    plt.rcParams["font.family"] = fname
    # set up grid: ERP plot (2 rows), topoplot (2 rows), row of small topoplots + colorbar
    ax1 = plt.subplot2grid(shape=(5, 9), loc=(0, 0), rowspan=2, colspan=9)
    ax2 = plt.subplot2grid(shape=(5, 9), loc=(2, 2), rowspan=2, colspan=4)
    ax31 = plt.subplot2grid(shape=(5, 9), loc=(4, 0), rowspan=1, colspan=2)
    ax32 = plt.subplot2grid(shape=(5, 9), loc=(4, 2), rowspan=1, colspan=2)
    ax33 = plt.subplot2grid(shape=(5, 9), loc=(4, 4), rowspan=1, colspan=2)
    ax34 = plt.subplot2grid(shape=(5, 9), loc=(4, 6), rowspan=1, colspan=2)
    ax35 = plt.subplot2grid(shape=(5, 9), loc=(4, 8), rowspan=1, colspan=1)

    ####### Average ERPs #######
    # load the appropriate file: created with Matlab
    mat = scipy.io.loadmat(f"{PATH_EEG}/avg_erps.mat")
    times = mat["times"].tolist()[0]  # get time
    mat = mat["avg_erps_no"]  # get average erps (one row per subject)
    mat = pd.DataFrame(mat)
    # plot each subject ERP
    for i in range(len(mat.index[:])):
        # very first subject is the sample one, plot it in the saccade color
        if i == 0:
            ax1.plot(times,mat.iloc[i].values.tolist(),color=sacc_color_1,alpha=0.5,label="Individual ERP (sample)")
        # second one, save the label for the legend
        elif i ==1:
            ax1.plot(times,mat.iloc[i].values.tolist(),color=gaze_color_1,alpha=0.2,label="Individual ERPs")
        # for the rest, don't save the labels (a leading underscore hides them)
        else:
            ax1.plot(times,mat.iloc[i].values.tolist(),color=gaze_color_1,alpha=0.2,label="_Hidden label")
    # plot the average ERP
    ax1.plot(times,mat.mean().tolist(),color="k",linewidth=4,label="Average ERP")

    # legend:
    handles, labels = ax1.get_legend_handles_labels()
    # make sure to only append three labels
    new_handlers, new_labels = [], []
    for handle, legend_label in zip(handles, labels):
        if legend_label in ['Individual ERP (sample)','Individual ERPs',"Average ERP"]:
            new_handlers.append(handle)
            new_labels.append(legend_label)
    # change the legend item alpha to 1.0 by swapping in opaque proxy lines
    # NOTE(review): needs at least two subjects, otherwise index -3 does not exist
    line2d_obj1 = Line2D([0], [0], linewidth=3, linestyle='solid', alpha = 1.0, color = sacc_color_1)
    line2d_obj2 = Line2D([0], [0], linewidth=3, linestyle='solid', alpha = 1.0, color = gaze_color_1)
    new_handlers[-3] = line2d_obj1
    new_handlers[-2] = line2d_obj2
    legend = ax1.legend(new_handlers,new_labels,loc="upper right", fontsize=legendsize, frameon=False)

    # axis labels
    ax1.set_xlabel("ms", fontsize=labelsize, fontname=fname)
    ax1.xaxis.set_label_coords(0.985,0.376)
    # the backslash of the mathtext command is escaped explicitly; the bare
    # "\i" was an invalid escape sequence (same resulting text)
    ax1.set_ylabel('$\\it{\u03bc}$' + 'V', fontsize=labelsize, fontname=fname, loc='bottom', rotation = 0)
    ax1.yaxis.set_label_coords(0.315,-0.005)
    ax1.text(-227,-0.1,"Oz", fontsize=labelsize)

    # ticks (0 is left out on both axes because the spines cross there)
    x = [-200,-100,100,200,300,400,500]
    ax1.set_xticks(x)
    y = [-2,-1,1,2,3]
    ax1.set_yticks(y)
    for label in ax1.get_xticklabels():  # change tick font
        label.set_fontproperties(fname)
    for label in ax1.get_yticklabels():
        label.set_fontproperties(fname)
    ax1.yaxis.set_tick_params(labelsize=ticksize)  # change tick size
    ax1.xaxis.set_tick_params(labelsize=ticksize)
    # text: A
    ax1.set_title("A", fontsize=50, fontweight="bold",loc="left", x=-0.07, y=1.05, pad=-30, fontname=fname)

    # adjust spines: hide top/right, move bottom/left to the data origin
    ax1.spines[['top','right']].set_visible(False)
    ax1.spines[['bottom','left']].set_linewidth(2)
    ax1.spines[['bottom','left']].set_position(('data', 0))
    ax1.spines[['bottom','left']].set_clip_on(True)
    ax1.spines[['left']].set_bounds([- 2,3])
    ax1.spines[['bottom']].set_bounds([- 200,500])

    ax1.xaxis.set_ticks_position('bottom')
    ax1.yaxis.set_ticks_position('left')


    ####### Topoplot ######
    # load the image (created with Matlab)
    img = plt.imread(f"{PATH_EEG}/Topoplot_ERPs.png")
    ax2.imshow(img)
    ax2.axis('off')
    ax2.set_title("B", fontsize=50, fontweight="bold",loc="left", x=-0.07, y=1.05, pad=-30, fontname=fname)


    ####### Topoplots per time window ######
    # (axes, image file created with Matlab, time interval description)
    topo_panels = [
        (ax31, "Topoplot_0.png", '0 : 20 ms'),
        (ax32, "Topoplot_0.08.png", '80 : 100 ms'),
        (ax33, "Topoplot_0.15.png", '150 : 170 ms'),
        (ax34, "Topoplot_0.28.png", '280 : 300 ms'),
    ]
    for topo_ax, topo_file, topo_text in topo_panels:
        # load the image
        img = plt.imread(f"{PATH_EEG}/{topo_file}")
        topo_ax.imshow(img)
        topo_ax.axis('off')
        # add time interval description below the image
        topo_ax.text(0.5, -0.05, topo_text,fontsize=27, horizontalalignment='center', verticalalignment='center', transform=topo_ax.transAxes, fontname=fname)
    # text: C (on the first small topoplot only)
    ax31.set_title("C", fontsize=50, fontweight="bold",loc="left", x=-0.07, y=1.05, pad=-30, fontname=fname)

    ####### Colorbar ######
    # load the image (created with Matlab)
    img = plt.imread(f"{PATH_EEG}/Topoplot_0_colorbar.png")
    ax35.imshow(img)
    ax35.axis('off')
    # add colorbar label
    ax35.text(0.25, -0.05, u'\u03bc' + 'V', fontsize=labelsize, horizontalalignment='center', verticalalignment='center', transform=ax35.transAxes, fontname=fname)


    plt.show()

## ERSPs

In [None]:
# ERSP figure assembled from images that were rendered with Matlab:
#   A / B : ERSP at Oz (single subject / all subjects)
#   C - F : scalp topographies for two time windows (single / all subjects)
#   right : colorbar image for the topographies

# font sizes and font family used for every annotation in this figure
labelsize = 40  # text
legendsize = 40  # legend
ticksize = 30  # ticks
numbersize = 60  # panel letters (A, B, ...)
fname = "Arial"  # font name

# set up figure; the font family is set *after* sns.set_style, which
# itself rewrites font-related rcParams
plt.figure(figsize=(30, 22), constrained_layout=True)
sns.set_style("white")
plt.rcParams["font.family"] = fname

# define grid: two large panels on top (rows 0-2), five small ones in row 3
grid_shape = (4, 9)
ax1 = plt.subplot2grid(shape=grid_shape, loc=(0, 0), rowspan=3, colspan=4)
ax2 = plt.subplot2grid(shape=grid_shape, loc=(0, 5), rowspan=3, colspan=4)
ax3 = plt.subplot2grid(shape=grid_shape, loc=(3, 0), rowspan=1, colspan=2)
ax4 = plt.subplot2grid(shape=grid_shape, loc=(3, 2), rowspan=1, colspan=2)
ax5 = plt.subplot2grid(shape=grid_shape, loc=(3, 4), rowspan=1, colspan=2)
ax6 = plt.subplot2grid(shape=grid_shape, loc=(3, 6), rowspan=1, colspan=2)
ax7 = plt.subplot2grid(shape=grid_shape, loc=(3, 8), rowspan=1, colspan=1)

# keyword arguments shared by every text annotation below
centered = dict(horizontalalignment='center', verticalalignment='center', fontname=fname)

####### Images ######
# (axis, image file inside PATH_EEG); every image is shown on a bare axis
# without tick labels and without a frame
image_panels = [
    (ax1, "NEW_ERSP_at_Oz_single_Subject.png"),       # ERSP single subject
    (ax2, "NEW_ERSP_at_Oz_all_Subject.png"),          # ERSP all subjects
    (ax3, "ERSP_topo_single_Subject_0-0.15s.jpg"),    # topo 1 single subject
    (ax4, "ERSP_topo_single_Subject_0.16-0.3s.jpg"),  # topo 2 single subject
    (ax5, "ERSP_topo_all_Subject_0-0.15s.jpg"),       # topo 1 all subjects
    (ax6, "ERSP_topo_all_Subject_0.16-0.3s.jpg"),     # topo 2 all subjects
    (ax7, "NEW_Topoplot_all_Subject_colorbar.png"),   # colorbar (all-subject file)
]
for panel_ax, image_file in image_panels:
    img = plt.imread(f"{PATH_EEG}/{image_file}")
    panel_ax.imshow(img)
    panel_ax.set_yticklabels([])
    panel_ax.set_xticklabels([])
    for side in ('top', 'right', 'bottom', 'left'):
        panel_ax.spines[side].set_visible(False)

####### Panel letters ######
# (axis, letter, vertical title position); the colorbar axis gets no letter
panel_letters = [
    (ax1, "A", 1.05),
    (ax2, "B", 1.05),
    (ax3, "C", 1.2),
    (ax4, "D", 1.2),
    (ax5, "E", 1.2),
    (ax6, "F", 1.2),
]
for panel_ax, letter, title_y in panel_letters:
    panel_ax.set_title(letter, fontsize=numbersize, fontweight="bold", loc="left",
                       x=-0.07, y=title_y, pad=-30, fontname=fname)

####### Axis labels: ERSP panels ######
# the images carry no matplotlib axis labels, so they are written as text
# in axes coordinates
for panel_ax in (ax1, ax2):
    panel_ax.text(0.45, -0.03, 'time (s)', fontsize=labelsize,
                  transform=panel_ax.transAxes, **centered)
    panel_ax.text(-0.03, 0.5, 'frequency (Hz)', fontsize=labelsize, rotation=90,
                  transform=panel_ax.transAxes, **centered)
    panel_ax.text(0.89, 0.04, 'Power', fontsize=labelsize - 10,
                  transform=panel_ax.transAxes, **centered)
    panel_ax.text(0.89, 0.005, '(db)', fontsize=labelsize - 10,
                  transform=panel_ax.transAxes, **centered)

####### Axis labels: topographies ######
# (axis, time-window caption shown above the topography)
topo_windows = [
    (ax3, '0 ms - 150 ms'),
    (ax4, '160 ms - 300 ms'),
    (ax5, '0 ms - 150 ms'),
    (ax6, '160 ms - 300 ms'),
]
for panel_ax, window in topo_windows:
    panel_ax.text(0.15, -0.03, 'Freq = [4 15]', fontsize=labelsize - 10,
                  transform=panel_ax.transAxes, **centered)
    panel_ax.text(0.5, 1.05, window, fontsize=labelsize,
                  transform=panel_ax.transAxes, **centered)

####### Axis labels: colorbar ######
ax7.text(0.2, -0.03, 'Power', fontsize=labelsize - 10, transform=ax7.transAxes, **centered)
ax7.text(0.2, -0.09, '(db)', fontsize=labelsize - 10, transform=ax7.transAxes, **centered)

plt.show()

## Correlation ERPs

In [None]:
# ERP correlation, no-shift condition: take the median correlation of every
# subject, then display the mean of those medians.
ids = recordings.index.tolist()
idd = ids[:]

no_shift = []
# one iteration per entry of recordings.index; the file itself is selected
# through subj_to_inlcude[i]
for i in range(len(idd)):
    cond = "dd"
    # load the correlation file of this subject
    mat = scipy.io.loadmat(f"{PATH_EEG}/correlation_shift_{cond}_{subj_to_inlcude[i]}.mat")
    # after the transpose, column 0 is the one used for the no-shift condition
    mat = pd.DataFrame(mat["corr_coef"]).transpose()

    unshifted = mat[0]
    # entries equal to 0 are left out of the median
    # (presumably unused/padding entries -- TODO confirm)
    no_shift.append(np.median(unshifted[unshifted != 0].tolist()))

display(np.mean(no_shift))



In [None]:
# Compute a repeated-measures ANOVA to test whether the ERP correlations
# differ between the three shift conditions (no / positive / negative shift).

ids = recordings.index.tolist()
idd = ids[:]

no_shift = []
pos_shift = []
neg_shift = []

# one iteration per entry of recordings.index; the file itself is selected
# through subj_to_inlcude[i]
for i in range(len(idd)):
    cond = "dd"
    # load the correlation file of this subject
    mat = scipy.io.loadmat(f"{PATH_EEG}/correlation_shift_{cond}_{subj_to_inlcude[i]}.mat")
    # after the transpose, columns 0 / 1 / 2 are used as the
    # no-shift / positive-shift / negative-shift condition
    mat = pd.DataFrame(mat["corr_coef"]).transpose()

    # save the median correlation for each shift condition; entries equal
    # to 0 are excluded (presumably unused/padding entries -- TODO confirm)
    for column, medians in enumerate((no_shift, pos_shift, neg_shift)):
        values = mat[column]
        medians.append(np.median(values[values != 0].tolist()))

# Long-format table: one row per (subject, shift condition).
# The subject ids are derived from the number of loaded subjects instead of
# being hard-coded to 19, so the table stays consistent for any cohort size.
n_subjects = len(no_shift)
corr_long = pd.DataFrame({
    'subject': list(range(n_subjects)) * 3,  # subject nr
    'shift': ['no_shift'] * n_subjects + ['pos_shift'] * n_subjects + ['neg_shift'] * n_subjects,
    'corr_coef': no_shift + pos_shift + neg_shift,
})

### Repeated measure ANOVA
print("Repeated measure ANOVA")
print(AnovaRM(data=corr_long, depvar='corr_coef',
              subject='subject', within=['shift']).fit())
print()
### Check assumptions:
# 1. Sphericity using the Mauchly test
mauchly_result = pg.sphericity(corr_long, dv='corr_coef', subject='subject',
                               within='shift')
print(mauchly_result)
print()


### Pairwise comparisons: paired t-tests between the shift conditions,
### reported uncorrected first and Bonferroni-corrected afterwards
p1 = scipy.stats.ttest_rel(no_shift, pos_shift).pvalue
p2 = scipy.stats.ttest_rel(no_shift, neg_shift).pvalue
p3 = scipy.stats.ttest_rel(pos_shift, neg_shift).pvalue

p_values = [p1, p2, p3]

# uncorrected p-values
print([round(pv, 3) for pv in p_values])
print()
print('corrected p-values')
# element [1] of the multipletests result holds the corrected p-values
corrected_p_values = multipletests(p_values, alpha=0.05, method='bonferroni')[1]
print(corrected_p_values)
corrected_p_values = [round(pv, 3) for pv in corrected_p_values]
# print the rounded corrected p-values
print(corrected_p_values)

## Correlation ERSPs

In [None]:
### Finished Version --> use this (20.03.23)
# ERSP correlation, no-shift condition: take the median correlation of every
# subject, then display the mean of those medians.
ids = recordings.index.tolist()
idd = ids[:]

no_shift = []

# one iteration per entry of recordings.index; the file itself is selected
# through subj_to_inlcude[i]
for i in range(len(idd)):
    cond = "dd"
    # load the ERSP correlation file of this subject
    mat = scipy.io.loadmat(f"{PATH_EEG}/correlation_shift_ERSP_{cond}_{subj_to_inlcude[i]}.mat")

    # after the transpose, column 0 is the one used for the no-shift condition
    mat = pd.DataFrame(mat["corr_coef"]).transpose()

    # median correlation of this subject; entries equal to 0 are left out
    # (presumably unused/padding entries -- TODO confirm)
    unshifted = mat[0]
    no_shift.append(np.median(unshifted[unshifted != 0].tolist()))

# display the mean across subjects
display(np.mean(no_shift))


In [None]:
# Compute a repeated-measures ANOVA to test whether the ERSP correlations
# differ between the three shift conditions (no / positive / negative shift).

ids = recordings.index.tolist()
idd = ids[:]

no_shift = []
pos_shift = []
neg_shift = []

# one iteration per entry of recordings.index; the file itself is selected
# through subj_to_inlcude[i]
for i in range(len(idd)):
    cond = "dd"
    # load the ERSP correlation file of this subject
    mat = scipy.io.loadmat(f"{PATH_EEG}/correlation_shift_ERSP_{cond}_{subj_to_inlcude[i]}.mat")

    # after the transpose, columns 0 / 1 / 2 are used as the
    # no-shift / positive-shift / negative-shift condition
    mat = pd.DataFrame(mat["corr_coef"]).transpose()

    # save the median correlation for each shift condition; entries equal
    # to 0 are excluded (presumably unused/padding entries -- TODO confirm)
    for column, medians in enumerate((no_shift, pos_shift, neg_shift)):
        values = mat[column]
        medians.append(np.median(values[values != 0].tolist()))


# Long-format table: one row per (subject, shift condition).
# The subject ids are derived from the number of loaded subjects instead of
# being hard-coded to 19, so the table stays consistent for any cohort size.
n_subjects = len(no_shift)
corr_long = pd.DataFrame({
    'subject': list(range(n_subjects)) * 3,  # subject nr
    'shift': ['no_shift'] * n_subjects + ['pos_shift'] * n_subjects + ['neg_shift'] * n_subjects,
    'corr_coef': no_shift + pos_shift + neg_shift,
})

### Repeated measure ANOVA
print("Repeated measure ANOVA")
print(AnovaRM(data=corr_long, depvar='corr_coef',
              subject='subject', within=['shift']).fit())

print()
### Check assumptions:
# 1. Sphericity using the Mauchly test
mauchly_result = pg.sphericity(corr_long, dv='corr_coef', subject='subject',
                               within='shift')
print(mauchly_result)
print()

### Pairwise comparisons: paired t-tests between the shift conditions,
### reported uncorrected first and Bonferroni-corrected afterwards
p1 = scipy.stats.ttest_rel(no_shift, pos_shift).pvalue
p2 = scipy.stats.ttest_rel(no_shift, neg_shift).pvalue
p3 = scipy.stats.ttest_rel(pos_shift, neg_shift).pvalue
p_values = [p1, p2, p3]

# uncorrected p-values
print([round(pv, 3) for pv in p_values])
print()
print('corrected p-values')
# element [1] of the multipletests result holds the corrected p-values
corrected_p_values = multipletests(p_values, alpha=0.05, method='bonferroni')[1]
print(corrected_p_values)
corrected_p_values = [round(pv, 3) for pv in corrected_p_values]
# print the rounded corrected p-values
print(corrected_p_values)



## Plot for both correlations

In [None]:
# Plot the correlations of ERPs and ERSPs for one example subject:
#   A : median correlation per shift condition (bars), ERP next to ERSP
#   B : distribution of the ERP correlations per shift condition
#   C : distribution of the ERSP correlations per shift condition

# define
labelsize = 40  # text
legendsize = 40  # legend
ticksize = 30  # ticks
numbersize = 60  # A, B etc.
fname = "Arial"  # font name
pallet = [gaze_color_1, sacc_color_1, vel_eye_color]  # one color per shift condition

# set up figure
plt.figure(figsize=(30, 20), constrained_layout=True)
sns.set_style("white")
ax1 = plt.subplot2grid(shape=(2, 3), loc=(0, 0), rowspan=2, colspan=1)
ax2 = plt.subplot2grid(shape=(2, 3), loc=(0, 1), rowspan=1, colspan=2)
ax3 = plt.subplot2grid(shape=(2, 3), loc=(1, 1), rowspan=1, colspan=2)

# NOTE(review): these lists are never filled or read in this cell; the
# assignments only reset the names (overwriting medians left by earlier
# cells). Kept so the notebook namespace is unchanged -- consider removing.
no_shift = []
pos_shift = []
neg_shift = []

# Both measures are plotted for the SAME subject (the first included one).
# The loop counter used to double as the subject index, so the ERP panels
# showed subj_to_inlcude[0] while the ERSP panels showed subj_to_inlcude[1].
example_subject = subj_to_inlcude[0]

# column name and legend label of each shift condition, in the column order
# of the transposed "corr_coef" matrix
shift_columns = ["no_shift", "pos_shift", "neg_shift"]
shift_labels = ["No Shift", "Positive Shift", "Negative Shift"]

# (file name stem, axis for the density plot) per measure: ERP first, ERSP second
measures = [
    ("correlation_shift_dd", ax2),
    ("correlation_shift_ERSP_dd", ax3),
]

# plot ERPs and ERSPs
for i, (file_stem, axis) in enumerate(measures):
    mat = scipy.io.loadmat(f"{PATH_EEG}/{file_stem}_{example_subject}.mat")

    # get the data out of the mat file; after the transpose, columns 0 / 1 / 2
    # are the no-shift / positive-shift / negative-shift condition
    mat = pd.DataFrame(mat["corr_coef"]).transpose()
    mat = mat.rename(columns={0: "no_shift", 1: "pos_shift", 2: "neg_shift"})

    # bars: median correlation per condition (entries equal to 0 excluded),
    # grouped at x = i (ERP at 0, ERSP at 1) with 0.2-wide bars side by side
    for k, (column, shift_label) in enumerate(zip(shift_columns, shift_labels)):
        values = mat[column]
        ax1.bar(i + 0.2 * k, values[values != 0].median(), width=0.2,
                color=pallet[k], alpha=0.7, label=shift_label)

    # densities: distribution of the correlations per condition
    # NOTE(review): unlike the bars, the densities do not exclude entries
    # equal to 0 -- confirm whether that difference is intended.
    for k, (column, shift_label) in enumerate(zip(shift_columns, shift_labels)):
        sns.kdeplot(mat[column], color=pallet[k], ax=axis, lw=4, fill=True,
                    alpha=0.2, label=shift_label)

    # set the axis labels and ticks of the density panel
    for tick_label in axis.get_xticklabels():  # change tick font
        tick_label.set_fontproperties(fname)
    for tick_label in axis.get_yticklabels():
        tick_label.set_fontproperties(fname)
    axis.set_xlim(-1, 1)
    axis.set_ylim(0, 3.1)
    axis.set_xlabel("Correlation", fontsize=labelsize, fontname=fname)
    axis.set_ylabel("Density", fontsize=labelsize, fontname=fname)
    axis.set_xticks(np.linspace(-1, 1, num=5))
    axis.set_yticks(np.linspace(0.0, 3, num=4))
    axis.xaxis.set_tick_params(labelsize=ticksize)  # change tick size
    axis.yaxis.set_tick_params(labelsize=ticksize)  # change tick size

# bar panel: label and ticks (set once, after all bars have been drawn)
ax1.set_ylabel("Correlation of Trials and Avg. ERP/ERSP", fontsize=labelsize, fontname=fname)
for tick_label in ax1.get_xticklabels():  # change tick font
    tick_label.set_fontproperties(fname)
for tick_label in ax1.get_yticklabels():
    tick_label.set_fontproperties(fname)
# one tick under the middle bar of each group
ax1.set_xticks(np.linspace(0.2, 1.2, num=2), ['ERP', 'ERSP'], fontsize=labelsize)
ax1.set_yticks(np.linspace(0, 0.6, num=7))
ax1.yaxis.set_tick_params(labelsize=ticksize)  # change tick size

# the legend is shown in one panel only
ax2.legend(loc="upper right", fontsize=legendsize)

# panel letters
ax1.set_title("A", fontsize=numbersize, fontweight="bold", loc="left", x=-0.09, y=1.02199, pad=-30, fontname=fname)
ax2.set_title("B", fontsize=numbersize, fontweight="bold", loc="left", x=-0.055, y=1.05, pad=-30, fontname=fname)
ax3.set_title("C", fontsize=numbersize, fontweight="bold", loc="left", x=-0.055, y=1.05, pad=-30, fontname=fname)
plt.show()