In [None]:
import sys, os
sys.path.insert(0, '/home/niell_lab/Documents/github/FreelyMovingEphys/')
import pandas as pd
import cv2
import xarray as xr
from utils.aux_funcs import find
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from skimage.exposure import adjust_gamma
import matplotlib as mpl
mpl.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'
from matplotlib.animation import FFMpegWriter
import subprocess
import wavio

In [None]:
import matplotlib
import matplotlib.patches as mpatches

In [None]:
import matplotlib as mpl
# Global figure style: hide the top/right axis spines and use 18 pt text everywhere
mpl.rcParams.update({
    'axes.spines.top': False,
    'axes.spines.right': False,
    'font.size': 18,
})

In [None]:
class EphysAnimation:
    """Synchronized camera / eye / IMU / spike-raster animation for one recording.

    Paths, probe name and the default window start are hard-coded in ``__init__``.
    Call ``setup()`` to load the recording, then ``normal_animation()`` to draw the
    summary figure and optionally write it to ``self.vidfile`` as an mp4.
    """
    def __init__(self):
        self.start = 396 # sec
        self.base = 'fm1_prey'
        self.recording_path = os.path.join('/home/niell_lab/data/ephys_preycap/ephys_recordings/111121/J586TT/', self.base)
        self.probe = 'DB_P64-3'
        self.vidfile = '/home/niell_lab/Desktop/111121_J586TT_'+self.base+'.mp4'
        self.audfile = '/home/niell_lab/Desktop/111121_J586TT_'+self.base+'.wav'
        self.merge_mp4_name = '/home/niell_lab/Desktop/111121_J586TT_'+self.base+'_merge.mp4'
        # recordings whose base name contains 'hf' are treated as having no top camera
        self.has_top = 'hf' not in self.base
#         self.population_pickle_path = '/home/niell_lab/data/freely_moving_ephys/batch_files/101921/pooled_ephys_population_update_102021.pickle'
#         self.session_name = '111121_J553RT_control_Rig2'
        # NOTE: not used by the methods below; the animation writer uses 30 fps
        self.fps = 60

    def pack_video_frames(self, video_path, dwsmpl=1):
        """Read a video file into a grayscale uint8 array.

        Parameters
        ----------
        video_path : str
            Path of the video to read (e.g. the deinterlaced eye-camera .avi).
        dwsmpl : float
            Spatial scale factor passed to ``cv2.resize`` as ``fx``/``fy``
            (1 keeps full resolution, 0.5 halves each dimension).

        Returns
        -------
        np.ndarray
            uint8 array of shape (n_frames_decoded, resized_height, resized_width).
            If the container reports more frames than can be decoded, only the
            decoded frames are returned (no uninitialized trailing frames).
        """
        vidread = cv2.VideoCapture(video_path)
        try:
            n_frames = int(vidread.get(cv2.CAP_PROP_FRAME_COUNT))
            all_frames = None
            n_read = 0
            for frame_num in tqdm(range(n_frames)):
                ret, frame = vidread.read()
                if not ret:
                    break
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                sframe = cv2.resize(gray, (0,0), fx=dwsmpl, fy=dwsmpl, interpolation=cv2.INTER_NEAREST)
                if all_frames is None:
                    # allocate from the first resized frame so the shape honours dwsmpl
                    all_frames = np.empty((n_frames,) + sframe.shape, dtype=np.uint8)
                all_frames[frame_num] = sframe
                n_read += 1
        finally:
            vidread.release()
        if all_frames is None:
            return np.empty((0, 0, 0), dtype=np.uint8)
        return all_frames[:n_read]

    def setup(self):
        """Load spikes, camera videos, eye-ellipse fits and IMU data for this recording.

        Camera timestamps are shifted by ``ephysT0`` (read from the ``*_ephys_merge.json``
        file). Each video is wrapped in ``scipy.interpolate.interp1d`` along the frame
        axis, so ``self.eye(t)``, ``self.world(t)`` and ``self.top(t)`` return a frame for
        any time ``t`` (sec); times outside the recorded range give NaN-filled frames
        (``bounds_error=False``).

        Sets: eyeT, eye, worldT, world, topT, top, th, longaxis, gz, accT, ephys_h5.
        When ``self.has_top`` is False the top camera is not loaded and
        ``self.top``/``self.topT`` are set to None.
        """
        ephys_h5 = pd.read_hdf(find('*ephys*.h5', self.recording_path)[0])
        world_nc = xr.open_dataset(find('*world*.nc', self.recording_path)[0])
        eye_nc = xr.open_dataset(find('*REYE*.nc', self.recording_path)[0])
        # only look for the TOP1 file when a top camera is expected
        top_nc = xr.open_dataset(find('*TOP1*.nc', self.recording_path)[0]) if self.has_top else None
        eye_avi_path = find('*REYEdeinter.avi', self.recording_path)[0]
        imu_data = xr.open_dataset(find('*imu*.nc', self.recording_path)[0])
        # NOTE(review): assumes column 12 of the merge json holds the ephys start time — confirm
        ephysT0 = pd.read_json(find('*_ephys_merge.json', self.recording_path)[0]).iloc[0,12]

        print('eye')
        eye_params = eye_nc.REYE_ellipse_params
        self.eyeT = eye_nc.timestamps - ephysT0
        del eye_nc
        # pack_video_frames already returns uint8; crop a fixed pixel window (rows 30:280, cols 260:590)
        self.eye = interp1d(self.eyeT, self.pack_video_frames(eye_avi_path)[:,30:280,260:590], axis=0, bounds_error=False)

        print('world')
        self.worldT = world_nc.timestamps - ephysT0
        self.world = interp1d(self.worldT, world_nc.WORLD_video.astype(np.uint8), axis=0, bounds_error=False)
        del world_nc

        print('top')
        if top_nc is not None:
            self.topT = top_nc.timestamps - ephysT0
            self.top = interp1d(self.topT, top_nc.TOP1_video.astype(np.uint8), axis=0, bounds_error=False)
            del top_nc
        else:
            self.topT = None
            self.top = None

        # eye-ellipse angle (radians; converted to degrees when plotted) and long axis
        self.th = eye_params.sel(ellipse_params='theta').values
        self.longaxis = eye_params.sel(ellipse_params='longaxis')
        del eye_params

        self.gz = np.array(imu_data.IMU_data.sel(channel='gyro_z'))
        self.accT = ephys_h5[self.base+'_accT'].iloc[0]
        del imu_data

        self.ephys_h5 = ephys_h5

    def _draw_camera_frames(self, ax_eyecam, ax_worldcam, ax_topcam, t):
        """Redraw the three camera panels with frames interpolated at time ``t`` (sec).

        A camera whose interpolator is None (no top camera) is left as a blank panel.
        """
        panels = ((ax_eyecam, self.eye), (ax_worldcam, self.world), (ax_topcam, self.top))
        for ax, video in panels:
            ax.cla(); ax.axis('off')
            if video is None:
                continue
            ax.imshow(video(t), 'gray', vmin=0, vmax=255, aspect='equal')

    def normal_animation(self, do_animation=True, duration=15):
        """Draw the summary figure for ``[self.start, self.start + duration]`` and optionally animate it.

        Panels: eye / world / top camera frames at the window midpoint, eye-ellipse
        theta (deg), IMU gyro_z (deg/s), normalized ellipse long axis, and a spike
        raster of every unit in ``self.ephys_h5``. Requires ``setup()`` to have run.

        Parameters
        ----------
        do_animation : bool
            If True, sweep a time cursor across the window at 30 fps, redrawing the
            camera panels each frame, and write the result to ``self.vidfile`` via ffmpeg.
        duration : float
            Window length in seconds (default 15).
        """
        fig = plt.figure(constrained_layout=True, figsize=(9,12))
        gs = fig.add_gridspec(7,4)
        ax_eyecam = fig.add_subplot(gs[0,0])
        ax_worldcam = fig.add_subplot(gs[0,1])
        ax_topcam = fig.add_subplot(gs[0,2])
        ax_theta = fig.add_subplot(gs[1,0:3])
        ax_gyro_z = fig.add_subplot(gs[2,0:3])
        ax_radius = fig.add_subplot(gs[3,0:3])
        ax_raster = fig.add_subplot(gs[4:6,0:3])

        tr = [self.start, self.start+duration]
        fr = np.mean(tr)
        # Ticks every 2 s, labelled in seconds relative to the window start. Positions are
        # set explicitly; set_xticklabels alone relabels whatever ticks matplotlib chose.
        tick_labels = np.arange(0, duration, 2)
        tick_pos = tr[0] + tick_labels

        # the videos are interpolators over time (sec): sample them at the window midpoint
        self._draw_camera_frames(ax_eyecam, ax_worldcam, ax_topcam, fr)

        ax_theta.plot(self.eyeT, np.rad2deg(self.th), color='k')
        ax_theta.set_xlim(tr[0], tr[1])
        ax_theta.set_ylabel('deg')
        ax_theta.set_ylim([-10,45])
        ax_theta.set_xticks(tick_pos)
        ax_theta.set_xticklabels(tick_labels)

        ax_radius.plot(self.eyeT, self.longaxis/np.nanmax(self.longaxis), color='k')
        ax_radius.set_xlim(tr[0], tr[1])
        ax_radius.set_ylabel('norm. radius')
        ax_radius.set_xticks(tick_pos)
        ax_radius.set_xticklabels(tick_labels)
        ax_radius.set_ylim([0,0.5])
        ax_radius.set_xlabel('sec')

        ax_gyro_z.plot(self.accT, self.gz, color='k')
        ax_gyro_z.set_xlim(tr[0], tr[1])
        ax_gyro_z.set_ylim(-500, 500)
        ax_gyro_z.set_ylabel('deg/s')
        ax_gyro_z.set_xticks(tick_pos)
        ax_gyro_z.set_xticklabels(tick_labels)

        # spike raster: the first len(sh0) units go on even rows 0, 2, 4, ...,
        # the remaining units on odd rows 1, 3, ...
        sh_num = 2
        sh0 = np.arange(0, len(self.ephys_h5.index)+sh_num, sh_num)
        full_raster = np.concatenate([sh0 + sh for sh in range(sh_num)])
        for row, ind in zip(full_raster, self.ephys_h5.index):
            ax_raster.vlines(self.ephys_h5.at[ind,'spikeT'], row-0.25, row+0.25, color='k', linewidth=0.5)

        n_units = len(self.ephys_h5)
        ax_raster.set_ylim(n_units, -.5)
        ax_raster.set_xlim(tr[0], tr[1])
        ax_raster.set_xticks(tick_pos)
        ax_raster.set_xticklabels(tick_labels)

        plt.tight_layout()

        if do_animation:
            writer = FFMpegWriter(fps=30, extra_args=['-vf','scale=3200:-2'])
            with writer.saving(fig, self.vidfile, dpi=500):
                for t in np.arange(tr[0], tr[1], 1/30):
                    self._draw_camera_frames(ax_eyecam, ax_worldcam, ax_topcam, t)
                    # time cursor on each trace; removed again after the frame is grabbed
                    cursors = [
                        ax_raster.vlines(t, -0.5, n_units, 'tab:gray'),
                        ax_theta.vlines(t, -30, 45, 'tab:gray'),
                        ax_radius.vlines(t, 0, 1, 'tab:gray'),
                        ax_gyro_z.vlines(t, -500, 500, 'tab:gray'),
                    ]
                    writer.grab_frame()
                    for cursor in cursors:
                        cursor.remove()

In [None]:
# Instantiate with the hard-coded session paths from __init__; the data are loaded later by setup()
ea = EphysAnimation()

In [None]:
# Load the whole right-eye video (grayscale) to choose a crop window by eye.
# pack_video_frames already returns a uint8 array, so no extra .astype copy is made.
eyevid = ea.pack_video_frames(find('*REYEdeinter.avi', ea.recording_path)[0])

In [None]:
# Preview a single eye-camera frame with the same crop window setup() applies to ea.eye
f = 1000  # frame index to inspect
plt.imshow(eyevid[f, 30:280, 260:590], cmap='gray', vmin=0, vmax=255, aspect='equal')

In [None]:
# Load spikes, camera videos, eye-ellipse fits and IMU data (slow: the camera videos are read into memory)
ea.setup()

In [None]:
# Summary figure for a 30 s window: camera panels, eye-ellipse theta, IMU gyro_z,
# normalized ellipse long axis and the spike raster. With do_animation a time cursor
# is swept across the window and the frames are written to ea.vidfile.
do_animation = True
startmin = 2
startsec = 50
ea.start = (startmin*60)+startsec
window_sec = 30  # length of the plotted/animated window

fig = plt.figure(constrained_layout=True, figsize=(20,11))
gs = fig.add_gridspec(7,3)
ax_eyecam = fig.add_subplot(gs[0:3,0])
ax_worldcam = fig.add_subplot(gs[0:3,1])
ax_topcam = fig.add_subplot(gs[0:3,2])
ax_theta = fig.add_subplot(gs[3,0:3])
ax_gyro_z = fig.add_subplot(gs[4,0:3])
ax_radius = fig.add_subplot(gs[5,0:3])
ax_raster = fig.add_subplot(gs[6,0:3])

tr = [ea.start, ea.start+window_sec]
fr = np.mean(tr)
# Ticks every 5 s, labelled in seconds relative to the window start. Positions are set
# explicitly; set_xticklabels alone relabels whatever ticks matplotlib happened to choose.
tick_labels = np.linspace(0, window_sec, 7).astype(int)
tick_pos = tr[0] + tick_labels

# ea.eye / ea.world / ea.top are interp1d objects over time (sec), so sample them at the
# window midpoint `fr` rather than passing a frame index.
ax_eyecam.axis('off')
ax_eyecam.imshow(ea.eye(fr), 'gray', vmin=0, vmax=255, aspect='equal')

ax_worldcam.axis('off')
ax_worldcam.imshow(ea.world(fr), 'gray', vmin=0, vmax=255, aspect='equal')

ax_topcam.axis('off')
if ea.top is not None:
    ax_topcam.imshow(ea.top(fr), 'gray', vmin=0, vmax=255, aspect='equal')

ax_theta.plot(ea.eyeT, np.rad2deg(ea.th), color='k')
ax_theta.set_xlim(tr[0], tr[1])
ax_theta.set_ylabel('deg')
ax_theta.set_ylim([-10,45])
ax_theta.set_xticks(tick_pos)
ax_theta.set_xticklabels(tick_labels)

ax_radius.plot(ea.eyeT, ea.longaxis/np.nanmax(ea.longaxis), color='k')
ax_radius.set_xlim(tr[0], tr[1])
ax_radius.set_ylabel('norm. radius')
ax_radius.set_xticks(tick_pos)
ax_radius.set_xticklabels(tick_labels)
ax_radius.set_ylim([0,0.5])

ax_gyro_z.plot(ea.accT, ea.gz, color='k')
ax_gyro_z.set_xlim(tr[0], tr[1])
ax_gyro_z.set_ylim(-500, 500)
ax_gyro_z.set_ylabel('deg/s')
ax_gyro_z.set_xticks(tick_pos)
ax_gyro_z.set_xticklabels(tick_labels)

# spike raster: the first len(sh0) units go on even rows, the remaining units on odd rows
sh_num = 2
sh0 = np.arange(0, len(ea.ephys_h5.index)+sh_num, sh_num)
full_raster = np.concatenate([sh0 + sh for sh in range(sh_num)])
for row, ind in zip(full_raster, ea.ephys_h5.index):
    ax_raster.vlines(ea.ephys_h5.at[ind,'spikeT'], row-0.25, row+0.25, color='k', linewidth=0.5)

n_units = len(ea.ephys_h5)
ax_raster.set_ylim(n_units, -.5)
ax_raster.set_xlim(tr[0], tr[1])
ax_raster.set_xticks(tick_pos)
ax_raster.set_xticklabels(tick_labels)
ax_raster.set_xlabel('sec')

plt.tight_layout()

if do_animation:
    writer = FFMpegWriter(fps=30, extra_args=['-vf','scale=3200:-2'])
    with writer.saving(fig, ea.vidfile, dpi=500):
        for t in np.arange(tr[0], tr[1], 1/30):
            ax_eyecam.cla(); ax_eyecam.axis('off')
            ax_eyecam.imshow(ea.eye(t), 'gray', vmin=0, vmax=255, aspect='equal')
            ax_worldcam.cla(); ax_worldcam.axis('off')
            ax_worldcam.imshow(ea.world(t), 'gray', vmin=0, vmax=255, aspect='equal')
            ax_topcam.cla(); ax_topcam.axis('off')
            if ea.top is not None:
                ax_topcam.imshow(ea.top(t).astype(np.uint8), 'gray', vmin=0, vmax=255, aspect='equal')
            # time cursor on each trace; removed again after the frame is grabbed
            cursors = [
                ax_raster.vlines(t, -0.5, n_units, 'tab:gray'),
                ax_theta.vlines(t, -30, 45, 'tab:gray'),
                ax_radius.vlines(t, 0, 1, 'tab:gray'),
                ax_gyro_z.vlines(t, -500, 500, 'tab:gray'),
            ]
            writer.grab_frame()
            for cursor in cursors:
                cursor.remove()

In [None]:
# Standalone spike raster for the same 30 s window. Drawn on an explicit figure/axes so
# this cell does not depend on axes created in earlier cells.
startmin = 2
startsec = 50
ea.start = (startmin*60)+startsec
tr = [ea.start, ea.start+30]
sh_num = 2
sh0 = np.arange(0, len(ea.ephys_h5.index)+sh_num, sh_num)
# the first len(sh0) units go on even rows, the remaining units on odd rows
full_raster = np.concatenate([sh0 + sh for sh in range(sh_num)])

fig_raster, ax_spikes = plt.subplots()
for row, ind in zip(full_raster, ea.ephys_h5.index):
    ax_spikes.vlines(ea.ephys_h5.at[ind,'spikeT'], row-0.25, row+0.25, color='k', linewidth=0.5)

n_units = len(ea.ephys_h5)
ax_spikes.set_ylim(n_units, -.5)
ax_spikes.set_xlim(tr[0], tr[1])
ax_spikes.set_xlabel('sec')
ax_spikes.set_ylabel('unit row')

In [None]:
# Scratch cell: intentionally empty (a bare `plt.imshow` reference here only echoed the function repr).

In [None]:
# Number of frames the 30 fps animation loop writes for the window `tr`
np.arange(tr[0], tr[1], 1 / 30).size

In [None]:
# Spot-check the interpolated world camera at one animation time step
f = 250  # index into the animation's 30 fps time grid
t = np.arange(tr[0], tr[1], 1 / 30)[f]
plt.imshow(ea.world(t), cmap='gray', vmin=0, vmax=255, aspect='equal')