# Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 

import xarray as xr
import matplotlib.gridspec as gridspec
import matplotlib as mpl

from tqdm.notebook import tqdm
from matplotlib.backends.backend_pdf import PdfPages
from scipy.signal import medfilt
from pathlib import Path

# Y-axis labels (LaTeX) and matching line colours for the four behavioural
# traces plotted later in the notebook; the two lists are index-aligned.
titles = [r'$\theta$', r'$\phi$', r'$\rho$', r'$\omega$']
move_clrs = ['blue', 'orange', 'green', 'red']

# Global matplotlib styling, set one rc group at a time.
mpl.rc('font', size=24, family="sans-serif", **{'sans-serif': "Arial"})
mpl.rc('axes', linewidth=3)
mpl.rc('xtick.major', size=5, width=3)
mpl.rc('ytick.major', size=5, width=2)
# Hide the right and top spines so only the left/bottom axes are drawn.
mpl.rc('axes.spines', right=False, top=False)
# fonttype 42 embeds TrueType fonts in saved PDFs.
mpl.rc('pdf', fonttype=42)


In [None]:
# ---------------------------------------------------------------------------
# Session selection and file discovery
# ---------------------------------------------------------------------------
date_ani = '070921/J553RT'
stim_type = 'fm1'
# Output directory mirrors the <date>/<animal>/<stimulus> layout of the recording.
save_dir = Path('/home/seuss/Research/SensoryMotorPred_Data/data') / date_ani / stim_type


def _first_match(directory, pattern):
    """Return the first file in *directory* matching the glob *pattern*.

    Matches are sorted so the choice is deterministic when several files
    match.

    Raises:
        FileNotFoundError: if no file matches, naming the directory and
            pattern (instead of an uninformative IndexError).
    """
    matches = sorted(Path(directory).glob(pattern))
    if not matches:
        raise FileNotFoundError(f'no file matching {pattern!r} found in {directory}')
    return matches[0]


# Directory holding all preprocessed files for this session / stimulus.
rec_dir = Path('~/Goeppert/freely_moving_ephys/ephys_recordings/').expanduser() / date_ani / stim_type
fmimu_file = _first_match(rec_dir, '*imu.nc')
eye_file = _first_match(rec_dir, '*REYE.nc')
world_file = _first_match(rec_dir, '*world.nc')
ephys_file = _first_match(rec_dir, '*ephys_props.h5')

# ---------------------------------------------------------------------------
# Plot window: [Tmin, Tmax], with axis ticks every dT
# ---------------------------------------------------------------------------
fontsize = 24
dT = 5
Tmin = 1497
Tmax = Tmin + 10
dT_all = Tmax - Tmin

# ---------------------------------------------------------------------------
# Ephys: spike times per cell and the reference time t0
# ---------------------------------------------------------------------------
ephys_data = pd.read_hdf(ephys_file)
spikeT = ephys_data['spikeT']
# t0 is subtracted from the IMU / eye / world timestamps below so that they
# share a common time base.
t0 = ephys_data['t0'].iloc[0]

# ---------------------------------------------------------------------------
# IMU
# ---------------------------------------------------------------------------
imu_data = xr.open_dataset(fmimu_file)
accT = imu_data.IMU_data.sample  # IMU timestamps
acc_chans = imu_data.IMU_data    # IMU sample data
# raw gyro values
gx = np.array(acc_chans.sel(channel='gyro_x_raw'))
gy = np.array(acc_chans.sel(channel='gyro_y_raw'))
gz = np.array(acc_chans.sel(channel='gyro_z_raw'))
# gyro values in degrees
gx_deg = np.array(acc_chans.sel(channel='gyro_x'))
gy_deg = np.array(acc_chans.sel(channel='gyro_y'))
gz_deg = np.array(acc_chans.sel(channel='gyro_z'))
# pitch and roll in deg, smoothed with an 11-sample median filter
groll = medfilt(np.array(acc_chans.sel(channel='roll')), 11)
gpitch = medfilt(np.array(acc_chans.sel(channel='pitch')), 11)

# ---------------------------------------------------------------------------
# Eye camera
# ---------------------------------------------------------------------------
eye_data = xr.open_dataset(eye_file)
eye_params = eye_data['REYE_ellipse_params']
eyeT = eye_data.timestamps - t0

# Ellipse-fit angles scaled by 180/pi (radians -> degrees).
th = np.array(eye_params.sel(ellipse_params='theta') * 180 / np.pi)
phi = np.array(eye_params.sel(ellipse_params='phi') * 180 / np.pi)

# ---------------------------------------------------------------------------
# World camera
# ---------------------------------------------------------------------------
world_data = xr.open_dataset(world_file)
world_vid_raw = np.uint8(world_data['WORLD_video'])
worldT = world_data.timestamps - t0

# Some recordings have camera timestamps that are 8 hours behind; detect
# that by a strongly negative first timestamp and shift them back.
if eyeT[0] < -600:
    eyeT = eyeT + 8*60*60  # 8hr offset for some data
if worldT[0] < -600:
    worldT = worldT + 8*60*60
# NOTE(review): the 8 hr correction is not applied to the IMU timestamps;
# confirm that the IMU time base never needs it.
accT2 = accT.sample.data - t0

# ---------------------------------------------------------------------------
# Sample indices closest to t = 0, Tmin and Tmax in each data stream
# ---------------------------------------------------------------------------
moveT0 = np.nanargmin(np.abs(accT2 - 0))
moveTmin = np.nanargmin(np.abs(accT2 - Tmin))
moveTmax = np.nanargmin(np.abs(accT2 - Tmax))

eyeT0 = np.nanargmin(np.abs(eyeT - 0))
eyeTmin = np.nanargmin(np.abs(eyeT - Tmin))
eyeTmax = np.nanargmin(np.abs(eyeT - Tmax))

worldT0 = np.nanargmin(np.abs(worldT - 0))
worldTmin = np.nanargmin(np.abs(worldT - Tmin))
worldTmax = np.nanargmin(np.abs(worldT - Tmax))

# Spike times of every cell restricted to the open interval (Tmin, Tmax).
# Series.items() replaces Series.iteritems(), which was removed in pandas 2.0.
cell_spikes = [spks[(spks > Tmin) & (spks < Tmax)] for _, spks in spikeT.items()]


In [None]:

# Figure layout: three stacked rows
#   row 0 - four world-camera frames side by side
#   row 1 - four stacked traces (eye theta, eye phi, head pitch, head roll)
#   row 2 - spike raster of all cells
fig = plt.figure(constrained_layout=False, figsize=(10,15))
gs0 = gridspec.GridSpec(ncols=1, nrows=3, figure=fig, wspace=.1, hspace=.2)
gs00 = gridspec.GridSpecFromSubplotSpec(1, 4, subplot_spec=gs0[:1,:], wspace=.05, hspace=.05)
gs01 = gridspec.GridSpecFromSubplotSpec(4, 1, subplot_spec=gs0[1,0], wspace=.5, hspace=.7)
gs02 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs0[2,0], wspace=.5, hspace=.7)
axs1a = np.array([fig.add_subplot(gs00[0,m]) for m in range(gs00.ncols)])
axs1b = np.array([fig.add_subplot(gs01[n,0]) for n in range(4)])
axs1c = np.array([fig.add_subplot(gs02[0,0])])

# One set of time ticks for every time-series panel: a tick every dT from
# Tmin up to Tmax, labelled relative to the start of the window.
tick_pos = np.arange(Tmin, Tmax + 1, dT)
tick_lbl = tick_pos - Tmin

# World-camera frames at skipT seconds after the start of the window.
# The frame is looked up by its own timestamp (Tmin + offset); adding the
# index of an absolute time to worldTmin is only correct if worldT starts
# exactly at zero.
skipT = [0,2,4,6]
for ax, offset in zip(axs1a, skipT):
    frame_idx = np.nanargmin(np.abs(worldT - (Tmin + offset)))
    ax.imshow(world_vid_raw[frame_idx], cmap='gray')
    ax.axis('off')

# (time, signal) pairs for the four trace panels, ordered like `titles`
# and `move_clrs`.
traces = [
    (eyeT[eyeTmin:eyeTmax],    th[eyeTmin:eyeTmax]),
    (eyeT[eyeTmin:eyeTmax],    phi[eyeTmin:eyeTmax]),
    (accT2[moveTmin:moveTmax], gpitch[moveTmin:moveTmax]),
    (accT2[moveTmin:moveTmax], groll[moveTmin:moveTmax]),
]
for n, (ax, (t, y)) in enumerate(zip(axs1b, traces)):
    ax.plot(t, y, c=move_clrs[n])
    ax.set_ylabel(titles[n])
    ax.set_yticks([-50,0,50])
    # Identical limits on every panel keep the eye, IMU and spike time axes
    # aligned even though they are sampled on different clocks.
    ax.set_xlim(Tmin, Tmax)
    if n < len(traces) - 1:
        ax.set_xticks([])  # only the bottom trace shows time ticks
    else:
        ax.set_xticks(tick_pos)
        ax.set_xticklabels(tick_lbl)

# Spike raster: one row per cell, restricted to the same time window.
ax = axs1c[0]
ax.eventplot(cell_spikes, color='k', linelengths=.5)
ax.set_xlabel('time (sec)', fontsize=fontsize)
ax.set_ylabel('cell #', fontsize=fontsize)
# Ticks stop at Tmax; a tick beyond the window would stretch this axis past
# the trace panels above it.
ax.set_xlim(Tmin, Tmax)
ax.set_xticks(tick_pos)
ax.set_xticklabels(tick_lbl, fontsize=fontsize)

plt.show()

# NOTE: paper_fig_dir is not defined in the cells shown here; define it (or use save_dir) before enabling.
# fig.savefig(paper_fig_dir/'SpikeRaster.pdf', facecolor='white', transparent=True, bbox_inches='tight')
