In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

import fmEphys

In [None]:
# Location of the pickled population dataset; edit DATA_DIR to run on another machine.
DATA_DIR = Path('/home/niell_lab/Data/freely_moving_ephys/batch_files/062022')
HFFM_PATH = DATA_DIR / 'hffm_062022_gt.pickle'

# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
hffm = pd.read_pickle(HFFM_PATH)

In [None]:
def apply_win_to_comp_sacc(comp, gazeshift, win=0.25):
    """Drop compensatory-movement times that fall close to any gaze-shift time.

    Parameters
    ----------
    comp : array-like
        1-D event times of candidate compensatory movements.
    gazeshift : array-like
        1-D event times of gaze shifts (any order).
    win : float
        Half-width of the exclusion window, in the same units as the times.
        A comp time ``c`` is removed when some gaze shift ``g`` satisfies
        ``c - win < g < c + win`` (strict on both sides).

    Returns
    -------
    numpy.ndarray
        The retained comp times, in their original order and dtype.
    """
    comp = np.asarray(comp)
    # Sorting the gaze shifts lets each comp time be tested with two binary
    # searches instead of comparing every (comp, gazeshift) pair.
    gaze_sorted = np.sort(np.asarray(gazeshift, dtype=float).ravel())
    probe = comp.astype(float)
    # Number of gaze shifts g < c + win ...
    upper = np.searchsorted(gaze_sorted, probe + win, side='left')
    # ... minus the number with g <= c - win, leaves those strictly inside the window.
    lower = np.searchsorted(gaze_sorted, probe - win, side='right')
    near_gazeshift = (upper - lower) > 0
    # Boolean indexing is unambiguous across numpy versions, unlike np.delete with a
    # boolean mask (older numpy treated it as integer indices 0/1).
    return comp[~near_gazeshift]

def keep_first_saccade(eventT, win=0.020):
    """Keep only the first event of each run of closely spaced events.

    An event time is dropped when some other event occurred strictly earlier
    and less than ``win`` before it. Each event is therefore compared with its
    immediate predecessor, so a chain of events that are each closer than
    ``win`` to the previous one collapses to its first member.

    Parameters
    ----------
    eventT : array-like
        Event times (any order).
    win : float
        Minimum gap, in the same units as ``eventT``, that an event must have
        from the preceding event to be kept.

    Returns
    -------
    numpy.ndarray
        Sorted, de-duplicated retained event times.
    """
    # np.unique sorts and de-duplicates, matching the set-difference semantics of
    # the original setdiff1d. It also avoids the assume_unique=True shortcut, which
    # gives wrong results when eventT contains repeated values.
    times = np.unique(np.asarray(eventT))
    gaps = np.diff(times)
    keep = np.ones(times.shape, dtype=bool)  # the earliest event is always kept
    # In a sorted unique array the nearest earlier event is the predecessor.
    keep[1:] = ~((gaps > 0) & (gaps < win))
    return times[keep]

In [None]:
# Velocity thresholds (deg/sec) used to classify head/gaze movements.
saccthresh = {
    'head_moved': 60,
    'gaze_stationary': 120,
    'gaze_moved': 240,
}

# NOTE(review): `tqdm` and `calc_kde_sdf` are neither imported nor defined in any
# visible cell; they must be brought into scope before this cell can run.
# NOTE(review): this cell goes through `hffm.data`, whereas the plotting cell indexes
# `hffm` directly as a DataFrame -- confirm which one the loaded pickle actually is.
pop = hffm.data

# Column-name stems for the four movement classes, in the same order as the
# event-time arrays assembled per session below.
movement_names = ['gazeshift_left', 'gazeshift_right', 'comp_left', 'comp_right']

for stim in ['FmLt']:
    # Output column names depend only on the stimulus, so build them once.
    movkeys = [stim+'_'+m+'_saccPSTH_dHead' for m in movement_names]
    timekeys = [stim+'_'+m+'_saccTimes_dHead' for m in movement_names]

    for name in pop['session'].unique():
        print('{} stim of {} recording'.format(stim, name))
        in_session = pop['session'] == name

        # Session-level traces are read from the first row belonging to the session.
        dHead = pop.loc[in_session, stim+'_dHead'].iloc[0]
        dGaze = pop.loc[in_session, stim+'_dGaze'].iloc[0]
        # Drop the final timestamp so eyeT can be masked by the velocity traces
        # (presumably one sample shorter than eyeT; confirm).
        eyeT = pop.loc[in_session, stim+'_eyeT'].iloc[0][:-1]

        head_pos = dHead > saccthresh['head_moved']
        head_neg = dHead < -saccthresh['head_moved']
        gaze_still = (dGaze > -saccthresh['gaze_stationary']) & (dGaze < saccthresh['gaze_stationary'])

        # Gaze shifts: head and gaze both exceed their thresholds with the same sign.
        gazeL = eyeT[head_pos & (dGaze > saccthresh['gaze_moved'])]
        gazeR = eyeT[head_neg & (dGaze < -saccthresh['gaze_moved'])]

        # Compensatory movements: head moves while gaze stays within +/- gaze_stationary;
        # drop those within the default window of a same-direction gaze shift.
        compL = apply_win_to_comp_sacc(eyeT[head_pos & gaze_still], gazeL)
        compR = apply_win_to_comp_sacc(eyeT[head_neg & gaze_still], gazeR)

        movements = [gazeL, gazeR, compL, compR]

        # Per-unit spike density functions aligned to each movement class.
        session_inds = pop.index.values[in_session.to_numpy()]
        for ind in tqdm(session_inds):
            spikeT = pop.loc[ind, stim+'_spikeT']
            for eventT, movkey, timekey in zip(movements, movkeys, timekeys):
                # save the spike density function
                _, sdf = calc_kde_sdf(spikeT, eventT)
                pop.at[ind, movkey] = sdf.astype(object)
                # save the event times used for it
                pop.at[ind, timekey] = eventT.astype(object)

In [None]:
# Example recording for the eye- vs head-velocity scatter.
DEMO_SESSION = '102621_J558NC_control_Rig2'
PLOT_EVERY = 25  # plot a random ~1/25 of frames to keep the scatter legible
rng = np.random.default_rng(0)  # seeded so the plotted subsample is reproducible

demo = hffm[hffm['session']==DEMO_SESSION].iloc[0]

# Take every trace from the demo session itself. Without this, `eyeT` would leak from
# the last iteration of the saccade loop (a different session), and `dHead_data`,
# `dEye_data`, `colors` and `ax_dEyeHead` are not defined anywhere in the notebook.
eyeT = np.asarray(demo['FmLt_eyeT'][:-1])
dHead_data = np.asarray(demo['FmLt_dHead'])
# The classification below treats gaze velocity as dHead + dEye, so eye velocity is
# recovered as dGaze - dHead. NOTE(review): assumes the stored dGaze was computed that
# way -- confirm, or load a stored eye-velocity column instead.
dEye_data = np.asarray(demo['FmLt_dGaze']) - dHead_data

# TODO(review): placeholder palette; the original `colors` dict is not defined in this notebook.
colors = {'gaze': 'tab:orange', 'comp': 'tab:blue'}

# NOTE(review): these read the `*_dHead1` columns, not the `*_dHead` columns written by
# the saccade cell -- confirm which version is intended.
left = demo['FmLt_gazeshift_left_saccTimes_dHead1']
right = demo['FmLt_gazeshift_right_saccTimes_dHead1']
# Event times may be stored as object arrays; cast to float so np.isin compares numerically.
gazemovs = np.hstack([left, right]).astype(float)
comp = np.hstack([demo['FmLt_comp_left_saccTimes_dHead1'],
                  demo['FmLt_comp_right_saccTimes_dHead1']]).astype(float)

plotinds = np.sort(rng.choice(eyeT.size, size=int(np.ceil(eyeT.size/PLOT_EVERY)), replace=False))
head_s = dHead_data[plotinds]
eye_s = dEye_data[plotinds]
time_s = eyeT[plotinds]
gaze_speed = np.abs(head_s + eye_s)

# Precedence: gaze shift, then compensatory, then intermediate. Frames matching none of
# these (e.g. speed exactly on a threshold) are not drawn.
is_gaze = np.isin(time_s, gazemovs) | (gaze_speed > saccthresh['gaze_moved'])
is_comp = ~is_gaze & (np.isin(time_s, comp) | (gaze_speed < saccthresh['gaze_stationary']))
is_mid = (~is_gaze & ~is_comp
          & (gaze_speed > saccthresh['gaze_stationary'])
          & (gaze_speed < saccthresh['gaze_moved']))

fig, ax_dEyeHead = plt.subplots(figsize=(4, 4))
# One plot call per class rather than one per point.
for mask, c in [(is_comp, colors['comp']), (is_mid, 'dimgray'), (is_gaze, colors['gaze'])]:
    ax_dEyeHead.plot(head_s[mask], eye_s[mask], '.', color=c, markersize=2)

ax_dEyeHead.set_aspect('equal','box')
ax_dEyeHead.set_xlim([-600,600])
ax_dEyeHead.set_ylim([-600,600])
ax_dEyeHead.set_xlabel('head velocity (deg/sec)')
ax_dEyeHead.set_ylabel('eye velocity (deg/sec)')
# Line dEye = -dHead, where dHead + dEye (the gaze velocity used above) is zero.
ax_dEyeHead.plot([-500,500],[500,-500], linestyle='dashed', color='k', linewidth=1)
ax_dEyeHead.set_xticks(np.linspace(-600,600,5))
ax_dEyeHead.set_yticks(np.linspace(-600,600,5))
plt.show()