# Import Modules

In [None]:
# Notebook setup cell.
# IPython autoreload: re-import edited local modules (e.g. utils, io_dict_to_hdf5) without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os
import argparse
import glob
import sys 
import yaml 
import glob  # NOTE(review): duplicate of the `import glob` above
import h5py 
import ray
import logging 
import json
import gc
import cv2
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
# import io_dict_to_hdf5 as ioh5
import xarray as xr

from tqdm.notebook import tqdm, trange
from matplotlib.backends.backend_pdf import PdfPages
from scipy import interpolate 
from scipy import signal
from pathlib import Path
from scipy.interpolate import interp1d
from scipy.ndimage import shift as imshift
from sklearn.linear_model import LinearRegression


# import torch
# import torch.nn as nn
# import torch.optim as optim
# import torch.nn.functional as F
# from torch.utils.data import Dataset, DataLoader, TensorDataset, Subset
# torch.backends.cudnn.benchmark = True
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Make the parent of the current working directory importable; the two project-local
# imports below depend on this running first.
sys.path.append(str(Path('.').absolute().parent))
from utils import check_path, add_colorbar, sizeof_fmt
import io_dict_to_hdf5 as ioh5

# Show every row when displaying DataFrames.
pd.set_option('display.max_rows', None)
# check_path is a project helper from utils; presumably returns (and creates if needed)
# the 'Figures' subdirectory — TODO confirm against utils.check_path.
FigPath = check_path(Path('~/Research/SensoryMotorPred_Data').expanduser(),'Figures')

# Start Ray for the parallel worldcam registration below; ignore_reinit_error makes
# re-running this cell safe when Ray is already initialised.
ray.init(
    ignore_reinit_error=True,
    logging_level=logging.ERROR,
)
print(f'Dashboard URL: https://{ray.get_dashboard_url()}')
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Gather Data

In [None]:
# Recording directory for this session; file_dict.json holds the data-file paths
# (keys such as 'world', 'top', 'imu', 'ephys', 'eye' are read further down).
save_dir = Path('~/Research/SensoryMotorPred_Data/data/070921/J553RT/fm1').expanduser()
file_dict = json.loads((save_dir / 'file_dict.json').read_text())

In [None]:
# Display the loaded mapping of data-stream names to file paths.
file_dict

In [None]:

@ray.remote
def shift_vid_parallel(x, world_vid, warp_mode, criteria, dt):
    """Ray task: ECC-register frames ``x`` .. ``x + dt - 1`` of the worldcam.

    For each frame ``i``, ``cv2.findTransformECC`` estimates the transform that
    aligns frame ``i`` with frame ``i + 1``; the translation entries of the
    resulting 2x3 warp matrix are reported.

    Args:
        x: first frame index of this chunk.
        world_vid: video indexed (frame, row, col).
        warp_mode: OpenCV motion model passed to findTransformECC.
        criteria: OpenCV termination criteria tuple.
        dt: number of frames in this chunk.

    Returns:
        Three lists of length ``dt``: x shifts, y shifts and ECC correlation
        coefficients. Entries are NaN when ECC fails (``cv2.error``) or when
        ``i + 1`` runs past the end of the video.
    """
    xshift_t, yshift_t, cc_t = [], [], []
    for i in range(x, x + dt):
        warp_matrix = np.eye(2, 3, dtype=np.float32)
        try:
            cc, warp_matrix = cv2.findTransformECC(
                world_vid[i, :, :], world_vid[i + 1, :, :], warp_matrix,
                warp_mode, criteria, inputMask=None, gaussFiltSize=1)
        except (cv2.error, IndexError):
            # ECC did not converge, or the chunk ran past the last frame.
            # A bare `except:` here would also swallow KeyboardInterrupt and real bugs.
            cc, xshift, yshift = np.nan, np.nan, np.nan
        else:
            xshift, yshift = warp_matrix[0, 2], warp_matrix[1, 2]
        xshift_t.append(xshift)
        yshift_t.append(yshift)
        cc_t.append(cc)
    return xshift_t, yshift_t, cc_t

@ray.remote
def shift_world_pt2(f, dt, world_vid, thInterp, phiInterp, ycorrection, xcorrection):
    """Ray task: shift worldcam frames ``f`` .. ``f + dt - 1`` to undo eye movement.

    For frame ``k`` the (row, col) shift is the negated int8 cast of
    ``th*corr[0] + phi*corr[1]``, using ``ycorrection`` for rows and
    ``xcorrection`` for columns. The final chunk is truncated at the end of the
    video. Returns a float array of shape (frames_in_chunk, H, W).

    NOTE(review): eye angles outside the eye-camera time range are NaN, and the
    int8 cast of NaN is not well defined — confirm this case is harmless.
    """
    n_total = world_vid.shape[0]
    stop = f + dt if (f + dt) < n_total else n_total
    shifted = np.zeros((stop - f, world_vid.shape[1], world_vid.shape[2]))
    for out_idx, frame_idx in enumerate(range(f, stop)):
        theta_k = thInterp[frame_idx]
        phi_k = phiInterp[frame_idx]
        row_shift = -np.int8(theta_k * ycorrection[0] + phi_k * ycorrection[1])
        col_shift = -np.int8(theta_k * xcorrection[0] + phi_k * xcorrection[1])
        shifted[out_idx, :, :] = imshift(world_vid[frame_idx, :, :], (row_shift, col_shift))
    return shifted


def grab_aligned_data(goodcells, worldT, accT, img_norm, gz, groll, gpitch, th_interp, phi_interp, free_move=True, downsamp=0.25, model_dt=0.025):
    """Bin spikes, eye/head signals and worldcam video onto one model time base.

    The time base is ``model_t = arange(0, max(worldT), model_dt)``. Spikes are
    counted per bin, eye position and video are sampled at bin centres, and the
    IMU-derived signals are sampled at bin starts.

    Args:
        goodcells: DataFrame with one row per unit and a 'spikeT' column of spike times.
        worldT: worldcam frame timestamps (same clock as the spike times).
        accT: IMU timestamps; only used when ``free_move`` is True.
        img_norm: normalized worldcam video indexed (frame, row, col).
        gz: z-gyro trace (the notebook passes 'gyro_z_raw'); only used when ``free_move``.
        groll, gpitch: roll / pitch traces; only used when ``free_move``.
        th_interp, phi_interp: callables mapping time to eye theta / phi.
        free_move: compute IMU-derived signals. When False, ``model_roll``,
            ``model_pitch`` and ``model_active`` are returned as all-NaN arrays.
        downsamp: spatial resize factor applied to every video frame.
        model_dt: bin width in seconds.

    Returns:
        (model_vid_sm, model_nsp, model_t, model_th, model_phi, model_roll,
        model_pitch, model_active). ``model_vid_sm`` is (bins, h, w) with NaNs
        replaced by 0. ``model_nsp`` is (units, bins).
    """
    n_units = len(goodcells)
    print('doing GLM receptive field estimate')
    print('get timing')
    model_t = np.arange(0, np.max(worldT), model_dt)
    n_bins = len(model_t)
    model_nsp = np.zeros((n_units, n_bins))

    # get spikes / rate: histogram every unit's spike times into the model bins
    print('get spikes')
    bins = np.append(model_t, model_t[-1] + model_dt)
    for i, ind in enumerate(goodcells.index):
        model_nsp[i, :], _ = np.histogram(goodcells.at[ind, 'spikeT'], bins)

    # get eye position at the bin centres
    print('get eye')
    bin_centres = model_t + model_dt / 2
    model_th = th_interp(bin_centres)
    model_phi = phi_interp(bin_centres)

    # get active times (IMU-derived signals)
    if free_move:
        # NOTE(review): 7.5 looks like a raw-gyro scale factor — confirm its units
        gz_interp = interp1d(accT, (gz - np.mean(gz)) * 7.5, bounds_error=False)
        model_gz = gz_interp(model_t)
        # moving average of |gyro z| over int(1/model_dt) bins (~1 s)
        window = np.ones(int(1 / model_dt))
        model_active = np.convolve(np.abs(model_gz), window, 'same') / len(window)
        model_roll = interp1d(accT, groll, bounds_error=False)(model_t)
        model_pitch = interp1d(accT, gpitch, bounds_error=False)(model_t)
    else:
        # Without IMU data these were previously left unbound, so the return raised
        # UnboundLocalError; all-NaN traces keep the output shapes consistent.
        model_roll = np.full(n_bins, np.nan)
        model_pitch = np.full(n_bins, np.nan)
        model_active = np.full(n_bins, np.nan)

    # get video ready for GLM
    print('setting up video')
    mov_interp = interp1d(worldT, img_norm, 'nearest', axis=0, bounds_error=False)
    # cv2.resize takes (width, height)
    out_size = (int(np.shape(img_norm)[2] * downsamp), int(np.shape(img_norm)[1] * downsamp))
    # drop a 5-px border: it is affected by the eye-movement correction of the worldcam
    crop = (slice(5, -5), slice(5, -5))
    crop_shape = np.empty((out_size[1], out_size[0]))[crop].shape
    model_vid_sm = np.zeros((n_bins,) + crop_shape, dtype=float)
    for i in tqdm(range(n_bins)):
        frame = mov_interp(model_t[i] + model_dt / 2)
        model_vid_sm[i, :] = cv2.resize(frame, out_size, interpolation=cv2.INTER_AREA)[crop]
    # frames outside the worldcam time range interpolate to NaN
    model_vid_sm[np.isnan(model_vid_sm)] = 0
    del mov_interp
    gc.collect()
    return model_vid_sm, model_nsp, model_t, model_th, model_phi, model_roll, model_pitch, model_active

def load_ephys_data_aligned(file_dict, save_dir, free_move=True, has_imu=True, has_mouse=False, max_frames=60*60, model_dt=.025):
    """Load one recording and align spikes, eye, head and worldcam data in time.

    If ``save_dir / 'ModelData_dt{ms:03d}.h5'`` (ms = int(model_dt*1000)) exists,
    it is loaded and returned. Otherwise the following steps run, and the result
    is cached to that h5 file:

    1. Load the raw data listed in ``file_dict``.
    2. Re-reference camera and IMU timestamps to the ephys start time.
    3. Shift the worldcam to compensate for eye movements (cached in
       ``save_dir / 'corrected_worldcam.npy'``).
    4. Z-score every pixel.
    5. Bin all streams with ``grab_aligned_data``.

    Args:
        file_dict: file paths under the keys 'world', 'imu', 'ephys' and 'eye'.
            'imu' may be None. The 'top' entry is not read.
        save_dir: pathlib.Path for the cached corrected worldcam and model data.
        free_move: freely-moving session. Forwarded to ``grab_aligned_data``
            (previously hard-coded to True there).
        has_imu: re-reference IMU timestamps to ephys time (when IMU data exist).
        has_mouse: unused; kept for interface compatibility.
        max_frames: number of worldcam frames used to fit the eye-to-worldcam shift.
        model_dt: bin width in seconds of the model time base.

    Returns:
        dict with 'model_vid_sm', 'model_nsp' (bins x units), 'model_t',
        'model_th', 'model_phi', 'model_roll', 'model_pitch' and 'model_active'.
    """
    model_file = save_dir / 'ModelData_dt{:03d}.h5'.format(int(model_dt*1000))
    if model_file.exists():
        data = ioh5.load(model_file)
        print('Done Loading Aligned Data')
        return data

    ##### Loading ephys experiment data #####
    print('Starting to Load Data')
    world_data = xr.open_dataset(file_dict['world'])
    world_vid_raw = np.uint8(world_data['WORLD_video'])
    sz = world_vid_raw.shape  # (frames, height, width)
    if sz[1] > 160:
        # frames taller than 160 px are downsampled by 0.5
        downsamp = 0.5
        new_h, new_w = int(sz[1] * downsamp), int(sz[2] * downsamp)  # np.int was removed in NumPy 1.24
        world_vid = np.zeros((sz[0], new_h, new_w), dtype='uint8')
        for f in range(sz[0]):
            # cv2.resize takes (width, height)
            world_vid[f, :, :] = cv2.resize(world_vid_raw[f, :, :], (new_w, new_h))
    else:
        # already resized when the nc file was written in preprocessing; the raw
        # array is never modified, so no copy is needed
        world_vid = world_vid_raw
    del world_vid_raw
    gc.collect()

    # load IMU data (left as None when the recording has no IMU file)
    accT = gz = groll = gpitch = None
    if file_dict['imu'] is not None:
        print('opening imu data')
        imu_data = xr.open_dataset(file_dict['imu'])
        acc_chans = imu_data.IMU_data
        accT = acc_chans.sample  # imu timestamps
        gz = np.array(acc_chans.sel(channel='gyro_z_raw'))
        groll = np.array(acc_chans.sel(channel='roll'))
        gpitch = np.array(acc_chans.sel(channel='pitch'))

    print('opening ephys data')
    ephys_data = pd.read_json(file_dict['ephys'])
    # sort units by channel and renumber rows 0..n-1
    ephys_data = ephys_data.sort_values(by='ch', axis=0, ascending=True).reset_index(drop=True)
    ephys_data['spikeTraw'] = ephys_data['spikeT']
    print('getting good cells')
    # units labelled 'good' (the original notes say the labels come from phy2)
    goodcells = ephys_data.loc[ephys_data['group'] == 'good']

    print('opening eyecam data')
    eye_data = xr.open_dataset(file_dict['eye'])
    eye_params = eye_data['REYE_ellipse_params']
    # theta / phi zero-centred, then scaled by 180/3.14159 (radians -> degrees);
    # 3.14159 is kept so cached results stay reproducible
    th = np.array((eye_params.sel(ellipse_params='theta') - np.nanmean(eye_params.sel(ellipse_params='theta'))) * 180/3.14159)
    phi = np.array((eye_params.sel(ellipse_params='phi') - np.nanmean(eye_params.sel(ellipse_params='phi'))) * 180/3.14159)

    print('adjusting camera times to match ephys')
    # NOTE(review): column 12 of the ephys table is used as the ephys start time — confirm the schema
    ephysT0 = ephys_data.iloc[0, 12]
    eyeT = eye_data.timestamps - ephysT0
    if eyeT[0] < -600:
        eyeT = eyeT + 8*60*60  # 8hr offset for some data
    worldT = world_data.timestamps - ephysT0
    if worldT[0] < -600:
        worldT = worldT + 8*60*60
    if free_move and has_imu and accT is not None:
        accT = accT - ephysT0

    ##### Clear some memory #####
    del world_data, eye_data, ephys_data
    gc.collect()

    ##### Correction to world cam #####
    # eye angle as a function of time (NaN outside the eye-camera time range)
    th_interp = interp1d(eyeT, th, bounds_error=False)
    phi_interp = interp1d(eyeT, phi, bounds_error=False)
    corrected_file = save_dir / 'corrected_worldcam.npy'
    if corrected_file.exists():
        world_vid = np.load(corrected_file, mmap_mode='r')
    else:
        start = time.time()
        # eye displacement between consecutive worldcam frames
        dth = np.diff(th_interp(worldT))
        dphi = np.diff(phi_interp(worldT))
        # ECC settings for the frame-to-frame x-y shift
        criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 5000, 1e-4)
        warp_mode = cv2.MOTION_TRANSLATION

        # one Ray task per chunk of dt frames
        world_vid_r = ray.put(world_vid)
        warp_mode_r = ray.put(warp_mode)
        criteria_r = ray.put(criteria)
        dt = 60
        result_ids = [shift_vid_parallel.remote(i, world_vid_r, warp_mode_r, criteria_r, dt)
                      for i in range(0, max_frames, dt)]
        # (chunks, 3, dt) -> (chunks*dt, 3); trim the overshoot when max_frames % dt != 0
        results_p = np.array(ray.get(result_ids)).transpose(0, 2, 1).reshape(-1, 3)[:max_frames]
        xshift = results_p[:, 0]
        yshift = results_p[:, 1]
        cc = results_p[:, 2]

        # eye data as predictors, frame shifts as outputs
        eyeData = np.zeros((max_frames, 2))
        eyeData[:, 0] = dth[0:max_frames]
        eyeData[:, 1] = dphi[0:max_frames]
        # Keep only good samples:
        #   - no NaNs
        #   - high ECC correlation (cc > 0.95)
        #   - small eye movements (no saccades, only compensatory movements)
        #   - small frame shifts
        usedata = (~np.isnan(eyeData[:, 0]) & ~np.isnan(eyeData[:, 1]) & (cc > 0.95)
                   & (np.abs(eyeData[:, 0]) < 2) & (np.abs(eyeData[:, 1]) < 2)
                   & (np.abs(xshift) < 5) & (np.abs(yshift) < 5))
        xmodel = LinearRegression().fit(eyeData[usedata, :], xshift[usedata])
        ymodel = LinearRegression().fit(eyeData[usedata, :], yshift[usedata])
        xmap = xmodel.coef_
        ymap = ymodel.coef_
        print('eye-world fit R^2: x = {:.3f}, y = {:.3f}'.format(
            xmodel.score(eyeData[usedata, :], xshift[usedata]),
            ymodel.score(eyeData[usedata, :], yshift[usedata])))
        warp_mat_duration = time.time() - start
        print("warp mat duration =", warp_mat_duration)
        del results_p, warp_mode_r, criteria_r, result_ids
        gc.collect()

        start = time.time()
        print('estimating eye-world calibration')
        xcorrection_r = ray.put(xmap.copy())
        ycorrection_r = ray.put(ymap.copy())
        print('shifting worldcam for eyes')
        thInterp_r = ray.put(th_interp(worldT))
        phiInterp_r = ray.put(phi_interp(worldT))
        dt = 1000
        result_ids2 = [shift_world_pt2.remote(f, dt, world_vid_r, thInterp_r, phiInterp_r, ycorrection_r, xcorrection_r)
                       for f in range(0, world_vid.shape[0], dt)]
        world_vid = np.concatenate(ray.get(result_ids2), axis=0).astype(np.uint8)
        # drop object-store references to the uncorrected video and per-chunk results
        del world_vid_r, result_ids2
        print('saving worldcam video corrected for eye movements')
        np.save(file=corrected_file, arr=world_vid)
        shift_world_duration = time.time() - start
        print("shift world duration =", shift_world_duration)
        print('Total Duration:', warp_mat_duration + shift_world_duration)

    ##### Calculating image norm #####
    print('Calculating Image Norm')
    start = time.time()
    world_vid = world_vid.astype(float)
    std_im = np.std(world_vid, axis=0, dtype=float)
    # pixels whose temporal std is below 20 are treated as uninformative and left at 0
    valid = std_im >= 20
    img_norm = np.zeros_like(world_vid)
    # z-score per pixel; dividing only where valid avoids 0/0 = NaN for constant
    # pixels, which the INTER_AREA resize would otherwise spread to neighbours
    np.divide(world_vid - np.mean(world_vid, axis=0, dtype=float), std_im, out=img_norm, where=valid)
    del world_vid
    gc.collect()
    img_norm_duration = time.time() - start
    print("img_norm duration =", img_norm_duration)

    start = time.time()
    (model_vid_sm, model_nsp, model_t, model_th, model_phi,
     model_roll, model_pitch, model_active) = grab_aligned_data(
        goodcells, worldT, accT, img_norm, gz, groll, gpitch, th_interp, phi_interp,
        free_move=free_move, downsamp=0.25, model_dt=model_dt)

    data = {'model_vid_sm': model_vid_sm,
            'model_nsp': model_nsp.T,
            'model_t': model_t,
            'model_th': model_th,
            'model_phi': model_phi,
            'model_roll': model_roll,
            'model_pitch': model_pitch,
            'model_active': model_active}

    ioh5.save(model_file, data)
    align_data_duration = time.time() - start
    print("align_data_duration =", align_data_duration)
    print('Done Loading Aligned Data')
    return data



In [None]:
# Build the time-aligned model data for this recording, or load it from the
# ModelData_dt*.h5 cache in save_dir if it already exists.
data = load_ephys_data_aligned(file_dict, save_dir)

In [None]:
# Print the 10 largest names in the notebook namespace. sys.getsizeof is shallow:
# objects referenced by a value are not counted.
# Note: the loop variables `name` and `size` themselves become notebook globals.
for name, size in sorted(((name, sys.getsizeof(value)) for name, value in locals().items()), key= lambda x: -x[1])[:10]:
    print("{:>30}: {:>8}".format(name, sizeof_fmt(size)))

# Testing Data Alignment

In [None]:
# session flags
free_move = True
has_imu = True
has_mouse = False
##### Loading ephys experiment data #####
world_data = xr.open_dataset(file_dict['world'])
world_vid_raw = np.uint8(world_data['WORLD_video'])
# resize worldcam
sz = world_vid_raw.shape  # raw video size: (frames, height, width)
# frames taller than 160 px are downsampled by 0.5
if sz[1] > 160:
    downsamp = 0.5
    # np.int was removed in NumPy 1.24; the builtin int is equivalent here
    world_vid = np.zeros((sz[0], int(sz[1]*downsamp), int(sz[2]*downsamp)), dtype='uint8')
    for f in range(sz[0]):
        # cv2.resize takes (width, height)
        world_vid[f, :, :] = cv2.resize(world_vid_raw[f, :, :], (int(sz[2]*downsamp), int(sz[1]*downsamp)))
else:
    # if the worldcam has already been resized when the nc file was written in preprocessing, don't resize
    world_vid = world_vid_raw.copy()
# world timestamps
worldT = world_data.timestamps.copy()

# open the topdown camera nc file
top_data = xr.open_dataset(file_dict['top'])
# Speed of the tail base in the topdown tracking. Most other points (e.g. head points)
# don't track well enough for this.
topx = top_data.TOP1_pts.sel(point_loc='tailbase_x').values
topy = top_data.TOP1_pts.sel(point_loc='tailbase_y').values
topdX = np.diff(topx)
topdY = np.diff(topy)
# Euclidean step length per frame. The previous np.sqrt(a, b) passed b as the
# `out` buffer and returned |dX| only.
top_speed = np.sqrt(topdX**2 + topdY**2)
topT = top_data.timestamps.copy()  # topdown timestamps
top_vid = np.uint8(top_data['TOP1_video'])  # full topdown video
# clear from memory
del top_data, world_vid_raw
gc.collect()

# load IMU data
if file_dict['imu'] is not None:
    print('opening imu data')
    imu_data = xr.open_dataset(file_dict['imu'])
    accT = imu_data.IMU_data.sample  # imu timestamps
    acc_chans = imu_data.IMU_data  # imu sample data
    # raw gyro values
    gx = np.array(acc_chans.sel(channel='gyro_x_raw'))
    gy = np.array(acc_chans.sel(channel='gyro_y_raw'))
    gz = np.array(acc_chans.sel(channel='gyro_z_raw'))
    # gyro values in degrees
    gx_deg = np.array(acc_chans.sel(channel='gyro_x'))
    gy_deg = np.array(acc_chans.sel(channel='gyro_y'))
    gz_deg = np.array(acc_chans.sel(channel='gyro_z'))
    # pitch and roll in deg
    groll = np.array(acc_chans.sel(channel='roll'))
    gpitch = np.array(acc_chans.sel(channel='pitch'))

print('opening ephys data')
# ephys data for this individual recording
ephys_data = pd.read_json(file_dict['ephys'])
# sort units by channel
ephys_data = ephys_data.sort_values(by='ch', axis=0, ascending=True)
ephys_data = ephys_data.reset_index()
ephys_data = ephys_data.drop('index', axis=1)
# spike times
ephys_data['spikeTraw'] = ephys_data['spikeT']
print('getting good cells')
# select good cells from phy2
goodcells = ephys_data.loc[ephys_data['group'] == 'good']
units = goodcells.index.values
# get number of good units
n_units = len(goodcells)

print('opening eyecam data')
# load eye data
eye_data = xr.open_dataset(file_dict['eye'])
eye_vid = np.uint8(eye_data['REYE_video'])
eyeT = eye_data.timestamps.copy()

# eye ellipse parameters across the recording
eye_params = eye_data['REYE_ellipse_params']

# theta, phi zero-centred and scaled by 180/3.14159 (radians -> degrees)
th = np.array((eye_params.sel(ellipse_params='theta') - np.nanmean(eye_params.sel(ellipse_params='theta'))) * 180/3.14159)
phi = np.array((eye_params.sel(ellipse_params='phi') - np.nanmean(eye_params.sel(ellipse_params='phi'))) * 180/3.14159)

print('adjusting camera times to match ephys')
# adjust eye/world/top times relative to ephys
# NOTE(review): column 12 of the ephys table is used as the ephys start time — confirm the schema
ephysT0 = ephys_data.iloc[0, 12]
eyeT = eye_data.timestamps - ephysT0
if eyeT[0] < -600:
    eyeT = eyeT + 8*60*60  # 8hr offset for some data
worldT = world_data.timestamps - ephysT0
if worldT[0] < -600:
    worldT = worldT + 8*60*60
if free_move and has_imu:
    accT = imu_data.IMU_data.sample - ephysT0
# `spd_tstamps` is never defined in this notebook, so this raised NameError whenever
# has_mouse was True; re-enable once running-speed timestamps are loaded.
# if not free_move and has_mouse:
#     speedT = spd_tstamps - ephysT0
if free_move:
    topT = topT - ephysT0


# Parallel Testing

In [None]:
@ray.remote
def shift_vid_parallel(x, world_vid, wrap_mode, criteria, dt):
    """Ray task: ECC-register frames ``x`` .. ``x + dt - 1`` of the worldcam.

    For each frame ``i``, ``cv2.findTransformECC`` aligns frame ``i`` with frame
    ``i + 1``; the translation entries of the 2x3 warp matrix are reported.

    Args:
        x: first frame index of this chunk.
        world_vid: video indexed (frame, row, col).
        wrap_mode: OpenCV motion model. The name is a historical misspelling of
            "warp_mode", kept for keyword callers. It used to be ignored in favour
            of a global ``warp_mode``, which may not exist inside a Ray worker.
        criteria: OpenCV termination criteria tuple.
        dt: number of frames in this chunk.

    Returns:
        Three lists of length ``dt``: x shifts, y shifts and ECC correlation
        coefficients. Entries are NaN when ECC fails or ``i + 1`` runs past the end.
    """
    xshift_t, yshift_t, cc_t = [], [], []
    for i in range(x, x + dt):
        warp_matrix = np.eye(2, 3, dtype=np.float32)
        try:
            cc, warp_matrix = cv2.findTransformECC(
                world_vid[i, :, :], world_vid[i + 1, :, :], warp_matrix,
                wrap_mode, criteria, inputMask=None, gaussFiltSize=1)
        except (cv2.error, IndexError):
            # ECC did not converge, or the chunk ran past the last frame
            cc, xshift, yshift = np.nan, np.nan, np.nan
        else:
            xshift, yshift = warp_matrix[0, 2], warp_matrix[1, 2]
        xshift_t.append(xshift)
        yshift_t.append(yshift)
        cc_t.append(cc)
    return xshift_t, yshift_t, cc_t

@ray.remote
def shift_world_pt2(f, dt, world_vid, thInterp, phiInterp, ycorrection, xcorrection):
    """Ray task: shift worldcam frames ``f`` .. ``f + dt - 1`` to undo eye movement.

    For frame ``k`` the (row, col) shift is the negated int8 cast of
    ``th*corr[0] + phi*corr[1]``, using ``ycorrection`` for rows and
    ``xcorrection`` for columns. The final chunk is truncated at the end of the
    video. Returns a float array of shape (frames_in_chunk, H, W).

    NOTE(review): eye angles outside the eye-camera time range are NaN, and the
    int8 cast of NaN is not well defined — confirm this case is harmless.
    """
    n_total = world_vid.shape[0]
    stop = f + dt if (f + dt) < n_total else n_total
    shifted = np.zeros((stop - f, world_vid.shape[1], world_vid.shape[2]))
    for out_idx, frame_idx in enumerate(range(f, stop)):
        theta_k = thInterp[frame_idx]
        phi_k = phiInterp[frame_idx]
        row_shift = -np.int8(theta_k * ycorrection[0] + phi_k * ycorrection[1])
        col_shift = -np.int8(theta_k * xcorrection[0] + phi_k * xcorrection[1])
        shifted[out_idx, :, :] = imshift(world_vid[frame_idx, :, :], (row_shift, col_shift))
    return shifted

In [None]:
start = time.time()
max_frames = 60*60
# get eye displacement for each worldcam frame
th_interp = interp1d(eyeT, th, bounds_error=False)
phi_interp = interp1d(eyeT, phi, bounds_error=False)
dth = np.diff(th_interp(worldT))
dphi = np.diff(phi_interp(worldT))
# ECC settings for the x-y shift between adjacent worldcam frames
number_of_iterations = 5000
termination_eps = 1e-4
criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
warp_mode = cv2.MOTION_TRANSLATION

# Parallel Testing: one Ray task per chunk of dt frames
world_vid_r = ray.put(world_vid)
warp_mode_r = ray.put(warp_mode)
criteria_r = ray.put(criteria)
dt = 60
result_ids = [shift_vid_parallel.remote(i, world_vid_r, warp_mode_r, criteria_r, dt)
              for i in range(0, max_frames, dt)]
results_p = ray.get(result_ids)
# (chunks, 3, dt) -> (chunks*dt, 3). Trim the overshoot when max_frames % dt != 0,
# so that cc lines up with the other arrays in `usedata`.
results_p = np.array(results_p).transpose(0, 2, 1).reshape(-1, 3)[:max_frames]

xshift = results_p[:, 0]
yshift = results_p[:, 1]
cc = results_p[:, 2]

xmodel = LinearRegression()
ymodel = LinearRegression()

# eye data as predictors
eyeData = np.zeros((max_frames, 2))
eyeData[:, 0] = dth[0:max_frames]
eyeData[:, 1] = dphi[0:max_frames]
# shift in x and y as outputs
xshiftdata = xshift[0:max_frames]
yshiftdata = yshift[0:max_frames]
# Only use good data:
#   - no NaNs
#   - good correlation between frames
#   - small eye movements (no saccades, only compensatory movements)
#   - small frame shifts
usedata = (~np.isnan(eyeData[:, 0]) & ~np.isnan(eyeData[:, 1]) & (cc > 0.95)
           & (np.abs(eyeData[:, 0]) < 2) & (np.abs(eyeData[:, 1]) < 2)
           & (np.abs(xshiftdata) < 5) & (np.abs(yshiftdata) < 5))

# fit xshift
xmodel.fit(eyeData[usedata, :], xshiftdata[usedata])
xmap = xmodel.coef_
xrscore = xmodel.score(eyeData[usedata, :], xshiftdata[usedata])
# fit yshift
ymodel.fit(eyeData[usedata, :], yshiftdata[usedata])
ymap = ymodel.coef_
yrscore = ymodel.score(eyeData[usedata, :], yshiftdata[usedata])
par_time = time.time() - start
print("duration =", par_time)
del results_p, warp_mode_r, criteria_r, result_ids
gc.collect()

In [None]:
# Diagnostic plots: frame shift vs. eye displacement for each (eye angle, shift axis)
# pair, with a red reference line of slope -1. Only the first row carries titles.
fig = plt.figure(figsize=(8, 6))
diag_panels = (
    (dth, xshift, 'dtheta', 'xshift', 'xmap = ' + str(xmap)),
    (dth, yshift, 'dtheta', 'yshift', 'ymap = ' + str(ymap)),
    (dphi, xshift, 'dphi', 'xshift', None),
    (dphi, yshift, 'dphi', 'yshift', None),
)
for panel_no, (eye_delta, frame_shift, x_label, y_label, panel_title) in enumerate(diag_panels, start=1):
    plt.subplot(2, 2, panel_no)
    plt.plot(eye_delta[0:max_frames], frame_shift[0:max_frames], '.')
    plt.plot([-5, 5], [5, -5], 'r')
    plt.xlim(-12, 12)
    plt.ylim(-12, 12)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    if panel_title is not None:
        plt.title(panel_title)
plt.tight_layout()

In [None]:
start = time.time()
print('estimating eye-world calibration')
xcorrection_r = ray.put(xmap.copy())
ycorrection_r = ray.put(ymap.copy())
print('shifting worldcam for eyes')

# eye angles sampled at every worldcam timestamp
thInterp_r = ray.put(th_interp(worldT))
phiInterp_r = ray.put(phi_interp(worldT))
dt = 1000

# one Ray task per chunk of dt frames, shifting the uncorrected video in world_vid_r
result_ids2 = [shift_world_pt2.remote(f, dt, world_vid_r, thInterp_r, phiInterp_r, ycorrection_r, xcorrection_r)
               for f in range(0, world_vid.shape[0], dt)]
results = ray.get(result_ids2)
world_vid = np.concatenate(results, axis=0).astype(np.uint8)
print('saving worldcam video corrected for eye movements')
# NOTE: saving is currently disabled; uncomment to write the corrected video to disk
# np.save(file=os.path.join(file_dict['save'], 'corrected_worldcam.npy'), arr=world_vid)
par_time = time.time() - start
print("duration =", par_time)

In [None]:
# Re-load the eye-corrected worldcam cached on disk (memory-mapped, read-only).
save_dir = Path.home() / 'Research' / 'SensoryMotorPred_Data' / 'data' / '070921' / 'J553RT' / 'fm1'
# np.save(file= save_dir / 'corrected_worldcam.npy', arr=world_vid)
world_vid = np.load(save_dir.joinpath('corrected_worldcam.npy'), mmap_mode='r')

In [None]:
start = time.time()
world_vid = world_vid.astype(float)
std_im = np.std(world_vid, axis=0, dtype=float)
# pixels whose temporal std is below 20 are treated as uninformative and left at 0
std_im[std_im < 20] = 0
valid = std_im > 0
img_norm = np.zeros_like(world_vid)
# z-score each pixel over time. Dividing only where valid avoids 0/0 = NaN for
# constant pixels; the later INTER_AREA resize would otherwise spread that NaN
# into neighbouring pixels.
np.divide(world_vid - np.mean(world_vid, axis=0, dtype=float), std_im, out=img_norm, where=valid)
del world_vid
gc.collect()
print('Done Loading!')
print("duration =", time.time() - start)

In [None]:
# Bin spikes, eye position, head signals and worldcam video onto a model_dt time base.
# Every name assigned here is a notebook global that later cells may rely on.
# get number of good units
n_units = len(goodcells)
print('doing GLM receptive field estimate')
# simplified setup for GLM
# these are general parameters (spike rates, eye position)
n_units = len(goodcells)  # NOTE(review): duplicate of the assignment above
print('get timing')
model_dt = 0.025
model_t = np.arange(0,np.max(worldT),model_dt)
model_nsp = np.zeros((n_units,len(model_t)))

# get spikes / rate: spike counts per bin, one row per good unit
print('get spikes')
bins = np.append(model_t,model_t[-1]+model_dt)
for i,ind in enumerate(goodcells.index):
    model_nsp[i,:],bins = np.histogram(goodcells.at[ind,'spikeT'],bins)

# get eye position at the bin centres
print('get eye')
model_th = th_interp(model_t+model_dt/2)
model_phi = phi_interp(model_t+model_dt/2)
# del thInterp, phiInterp

# get active times
if free_move:
    interp = interp1d(accT,(gz-np.mean(gz))*7.5,bounds_error=False)
    model_gz = interp(model_t)
    # NOTE(review): moving *sum* over int(1/model_dt) bins; grab_aligned_data divides by
    # the window length (moving average), so the > 40 threshold means something different here
    model_active = np.convolve(np.abs(model_gz),np.ones(int(1/model_dt)),'same')
    # frames with small eye angles (|th|, |phi| < 10) while the animal is active
    use = np.where((np.abs(model_th)<10) & (np.abs(model_phi)<10)& (model_active>40))[0]
    roll_interp = interp1d(accT,groll,bounds_error=False)
    pitch_interp = interp1d(accT,gpitch,bounds_error=False)
    model_roll = roll_interp(model_t)
    model_pitch = pitch_interp(model_t)
else:
    # NOTE(review): model_roll / model_pitch are not defined on this branch
    use = np.where((np.abs(model_th)<10) & (np.abs(model_phi)<10))[0]
# get video ready for GLM
downsamp = 0.25
print('setting up video') 
movInterp = interp1d(worldT,img_norm,'nearest',axis = 0,bounds_error = False) 
# testimg is only used to determine the output frame shape
testimg = movInterp(model_t[0])
testimg = cv2.resize(testimg,(int(np.shape(testimg)[1]*downsamp), int(np.shape(testimg)[0]*downsamp)))
testimg = testimg[5:-5,5:-5]; #remove area affected by eye movement correction
model_vid_sm = np.zeros((len(model_t),int(np.shape(testimg)[0]),int(np.shape(testimg)[1])),dtype=float)
for i in tqdm(range(len(model_t))):
    # nearest worldcam frame at the bin centre, area-downsampled; cv2 dsize is (width, height)
    model_vid = movInterp(model_t[i] + model_dt/2)
    smallvid = cv2.resize(model_vid,(int(np.shape(img_norm)[2]*downsamp),int(np.shape(img_norm)[1]*downsamp)),interpolation = cv2.INTER_AREA)
    smallvid = smallvid[5:-5,5:-5]
    #smallvid = smallvid - np.mean(smallvid)
    model_vid_sm[i,:] = smallvid
# nks: frame shape (rows, cols); nk: pixels per frame
nks = np.shape(smallvid); nk = nks[0]*nks[1]
# bins outside the worldcam time range interpolate to NaN
model_vid_sm[np.isnan(model_vid_sm)]=0
del movInterp
gc.collect()


In [None]:
# Sanity-check the shapes of the aligned model arrays.
model_vid_sm.shape, model_nsp.shape, model_t.shape, model_th.shape, model_phi.shape, model_roll.shape, model_pitch.shape

In [None]:
# Collect the aligned arrays for saving. model_nsp is stored as computed above
# (one row per unit, one column per time bin).
data = dict(
    model_vid_sm=model_vid_sm,
    model_nsp=model_nsp,
    model_t=model_t,
    model_th=model_th,
    model_phi=model_phi,
    model_roll=model_roll,
    model_pitch=model_pitch,
)

In [None]:
# Write the aligned arrays to save_dir / 'ModelData.h5' using the project's io_dict_to_hdf5 helper.
ioh5.save(save_dir / 'ModelData.h5', data)

In [None]:
# Print the 10 largest names in the notebook namespace. sys.getsizeof is shallow:
# objects referenced by a value are not counted.
# Note: the loop variables `name` and `size` themselves become notebook globals.
for name, size in sorted(((name, sys.getsizeof(value)) for name, value in locals().items()), key= lambda x: -x[1])[:10]:
    print("{:>30}: {:>8}".format(name, sizeof_fmt(size)))

# Sequential Testing

In [None]:

def shift_vid(i, world_vid, wrap_mode, criteria):
    """Estimate the translation between worldcam frames ``i`` and ``i + 1`` with ECC.

    Args:
        i: frame index.
        world_vid: video indexed (frame, row, col).
        wrap_mode: OpenCV motion model. The name is a historical misspelling of
            "warp_mode", kept for keyword callers. It used to be ignored in favour
            of a global ``warp_mode``.
        criteria: OpenCV termination criteria tuple.

    Returns:
        (xshift, yshift, cc): translation entries of the ECC warp matrix and the
        ECC correlation coefficient. All NaN when ECC fails or ``i + 1`` is out of range.
    """
    warp_matrix = np.eye(2, 3, dtype=np.float32)
    try:
        cc, warp_matrix = cv2.findTransformECC(
            world_vid[i, :, :], world_vid[i + 1, :, :], warp_matrix,
            wrap_mode, criteria, inputMask=None, gaussFiltSize=1)
    except (cv2.error, IndexError):
        # ECC did not converge, or i is the last frame
        return np.nan, np.nan, np.nan
    return warp_matrix[0, 2], warp_matrix[1, 2], cc

In [None]:
def fit_glm_vid(model_vid, model_nsp, model_dt, use, nks):
    """
    Calculate GLM (ridge-regression) spatial receptive fields.

    Pixels are z-scored over the ``use`` frames and a bias column is appended.
    For every unit and spike lag, MAP estimates are fitted with a prior that
    combines an L2 penalty and a smoothness penalty. Lambda is selected by
    test-set MSE on the first 30% of the ``use`` frames.

    INPUTS
        model_vid: video as a 2-D array (frames, pixels), with pixels in
            row-major order of ``nks``
        model_nsp: binned number of spikes, (units, frames)
        model_dt: bin width in seconds
        use: indices of the frames to fit (e.g. when the animal is active)
        nks: (rows, cols) of one video frame
    OUTPUTS
        sta_all: receptive fields, (units, lags, rows, cols)
        cc_all: Pearson correlation between smoothed actual and predicted test
            firing rates, (units, lags)
        fig: figure of the receptive fields
    """
    # scipy.linalg (block_diag) and scipy.sparse (diags) were used without being imported
    from scipy import linalg, sparse

    nT = np.shape(model_nsp)[1]
    x = model_vid.copy()
    # image dimensions
    nk = nks[0] * nks[1]
    n_units = np.shape(model_nsp)[0]
    # z-score every pixel over the frames in use
    mn_img = np.mean(x[use, :], axis=0)
    x = x - mn_img
    sd_img = np.std(x[use, :], axis=0)
    # Constant pixels (e.g. masked to 0 upstream) have zero std. 0/0 would make the
    # design matrix, and therefore every ridge solution, NaN; leave them at 0 instead.
    sd_img[sd_img == 0] = 1
    x = x / sd_img
    x = np.append(x, np.ones((nT, 1)), axis=1)  # bias column
    x = x[use, :]

    # set up prior matrix (regularizer); the bias term is never penalized
    # L2 prior
    Imat = linalg.block_diag(np.eye(nk), np.zeros((1, 1)))
    # smoothness prior: first differences between horizontal neighbours in the same
    # row (Dxx) and between vertical neighbours (Dxy, offset of one row = nks[1])
    consecutive = np.ones((nk, 1))
    consecutive[nks[1]-1::nks[1]] = 0  # no horizontal difference across row ends
    diff = np.array([[-1.0, 1.0]])
    Dxx = sparse.diags((consecutive @ diff).T, np.array([0, 1]), (nk-1, nk))
    Dxy = sparse.diags((np.ones((nk, 1)) @ diff).T, np.array([0, nks[1]]), (nk-nks[1], nk))
    Dx = Dxx.T @ Dxx + Dxy.T @ Dxy
    D = linalg.block_diag(Dx.toarray(), np.zeros((1, 1)))
    # summed prior matrix
    Cinv = D + Imat

    lag_list = [-4, -2, 0, 2, 4]
    lambdas = 1024 * (2**np.arange(0, 16))
    nlam = len(lambdas)
    test_frac = 0.3
    bin_length = 80  # smoothing window (bins) for the rate comparison
    smooth_kernel = np.ones(bin_length)
    # set up empty arrays for receptive field and correlation
    sta_all = np.zeros((n_units, len(lag_list), nks[0], nks[1]))
    cc_all = np.zeros((n_units, len(lag_list)))
    for celln in tqdm(range(n_units)):
        for lag_ind, lag in enumerate(lag_list):
            # roll by -lag: frame t is paired with the spike count at t + lag
            sps = np.roll(model_nsp[celln, :], -lag)[use]
            # split into test (first test_frac) and training data
            ntest = int(len(sps) * test_frac)
            x_train, sps_train = x[ntest:, :], sps[ntest:]
            x_test, sps_test = x[:ntest, :], sps[:ntest]
            XXtr = x_train.T @ x_train
            XYtr = x_train.T @ sps_train
            msetest = np.zeros(nlam)
            w_ridge = np.zeros((nk+1, nlam))
            # loop over regularization strength
            for l, lam in enumerate(lambdas):
                # MAP estimate (equivalent of \ (left divide) in matlab)
                w = np.linalg.solve(XXtr + lam*Cinv, XYtr)
                w_ridge[:, l] = w
                msetest[l] = np.mean((sps_test - x_test @ w)**2)
            # select best cross-validated lambda for RF
            ridge_rf = w_ridge[:, np.argmin(msetest)]
            sta_all[celln, lag_ind, :, :] = np.reshape(ridge_rf[:-1], nks)
            # compare smoothed predicted vs actual firing rate on the test set
            sp_pred = x_test @ ridge_rf
            sp_smooth = np.convolve(sps_test, smooth_kernel, 'same') / (bin_length * model_dt)
            pred_smooth = np.convolve(sp_pred, smooth_kernel, 'same') / (bin_length * model_dt)
            cc_all[celln, lag_ind] = np.corrcoef(sp_smooth, pred_smooth)[0, 1]

    # figure of receptive fields (np.int was removed in NumPy 1.24)
    fig = plt.figure(figsize=(10, int(np.ceil(n_units/3))), dpi=50)
    for celln in tqdm(range(n_units)):
        # shared symmetric colour scale across all lags of this unit
        crange = np.max(np.abs(sta_all[celln, :, :, :]))
        for lag_ind in range(len(lag_list)):
            plt.subplot(n_units, 6, (celln*6) + lag_ind + 1)
            plt.imshow(sta_all[celln, lag_ind, :, :], vmin=-crange, vmax=crange, cmap='jet')
            plt.title('cc={:.2f}'.format(cc_all[celln, lag_ind]), fontsize=5)
    return sta_all, cc_all, fig

def eye_shift_estimation(th, phi, eyeT, world_vid, worldT, max_frames=3600):
    """
    do a simple shift of the worldcam using eye parameters
    aim is to approximate the visual scene

    Adjacent worldcam frames are registered with OpenCV ECC (pure
    translation), and a linear regression is fit from the frame-to-frame
    change in eye position (dtheta, dphi) to the measured x/y frame shift.

    INPUTS
        th: theta as an array (sampled at eyeT)
        phi: phi as an array (sampled at eyeT)
        eyeT: eye timestamps
        world_vid: worldcam video as an array indexed [frame, row, col]
        worldT: worldcam timestamps
        max_frames: maximum number of adjacent-frame pairs to use in the
            estimation; clipped to the number of pairs world_vid and worldT
            actually provide
    OUTPUTS
        fig: diagnostic figure (eye change vs. measured frame shift)
        xmap: worldcam x-correction factor (regression coefs for dtheta, dphi)
        ymap: worldcam y-correction factor (regression coefs for dtheta, dphi)
    RAISES
        ValueError: if no frame pair passes the quality filter, so there is
            nothing to fit
    """
    # get eye displacement for each worldcam frame (NaN outside the eye recording)
    th_interp = interp1d(eyeT, th, bounds_error=False)
    phi_interp = interp1d(eyeT, phi, bounds_error=False)
    dth = np.diff(th_interp(worldT))
    dphi = np.diff(phi_interp(worldT))

    # Frame i is registered against frame i+1, so at most len(world_vid)-1
    # pairs exist; dth/dphi have len(worldT)-1 entries.
    n_pairs = max(0, min(max_frames, len(world_vid) - 1, len(dth)))

    # ECC registration settings: translation-only warp between adjacent frames
    number_of_iterations = 5000
    termination_eps = 1e-4
    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
    warp_mode = cv2.MOTION_TRANSLATION
    # NaN marks frame pairs whose registration failed
    cc = np.full(n_pairs, np.nan)
    xshift = np.full(n_pairs, np.nan)
    yshift = np.full(n_pairs, np.nan)
    for i in tqdm(range(n_pairs)):
        warp_matrix = np.eye(2, 3, dtype=np.float32)
        try:
            cc[i], warp_matrix = cv2.findTransformECC(world_vid[i,:,:], world_vid[i+1,:,:], warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=1)
        except cv2.error:
            # ECC failed (e.g. did not converge): leave this pair as NaN
            continue
        xshift[i] = warp_matrix[0,2]
        yshift[i] = warp_matrix[1,2]

    # eye data as predictors: columns are (dtheta, dphi)
    eyeData = np.column_stack((dth[:n_pairs], dphi[:n_pairs]))
    # only use good data:
    # not nans, good correlation between frames, small eye movements
    # (no saccades, only compensatory movements), small frame shifts
    usedata = (~np.isnan(eyeData).any(axis=1)
               & (cc > 0.95)
               & np.all(np.abs(eyeData) < 2, axis=1)
               & (np.abs(xshift) < 5)
               & (np.abs(yshift) < 5))
    if not np.any(usedata):
        raise ValueError('no frame pairs passed the quality filter; cannot fit eye-world calibration')

    # regression predicting frame shift from eye shift
    xmodel = LinearRegression().fit(eyeData[usedata], xshift[usedata])
    ymodel = LinearRegression().fit(eyeData[usedata], yshift[usedata])
    xmap = xmodel.coef_
    ymap = ymodel.coef_

    # diagnostic plots: each eye-position change vs. each measured frame shift
    fig = plt.figure(figsize=(8,6))
    panels = [
        (dth, xshift, 'dtheta', 'xshift', 'xmap = ' + str(xmap)),
        (dth, yshift, 'dtheta', 'yshift', 'ymap = ' + str(ymap)),
        (dphi, xshift, 'dphi', 'xshift', None),
        (dphi, yshift, 'dphi', 'yshift', None),
    ]
    for k, (eye_change, frame_shift, xlabel, ylabel, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, k)
        plt.plot(eye_change[:n_pairs], frame_shift, '.')
        # reference line y = -x
        plt.plot([-5, 5], [5, -5], 'r')
        plt.xlim(-12, 12)
        plt.ylim(-12, 12)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        if title is not None:
            plt.title(title)
    plt.tight_layout()
    return fig, xmap, ymap

In [None]:
start = time.time()
max_frames = 60*60
# get eye displacement for each worldcam frame (NaN outside the eye recording)
th_interp = interp1d(eyeT, th, bounds_error=False)
phi_interp = interp1d(eyeT, phi, bounds_error=False)
dth = np.diff(th_interp(worldT))
dphi = np.diff(phi_interp(worldT))
# calculate x-y shift for each worldcam frame
number_of_iterations = 5000
termination_eps = 1e-4
criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
warp_mode = cv2.MOTION_TRANSLATION
# NaN marks frame pairs whose registration failed or that run past the end of the video
cc = np.full(max_frames, np.nan)
xshift = np.full(max_frames, np.nan)
yshift = np.full(max_frames, np.nan)
warp_all = np.zeros((6,max_frames))
# get shift between adjacent frames; frame i is registered against frame i+1,
# so only len(world_vid)-1 pairs exist
n_pairs = max(0, min(max_frames, len(world_vid) - 1))
for i in tqdm(range(n_pairs)):
    warp_matrix = np.eye(2, 3, dtype=np.float32)
    try:
        cc[i], warp_matrix = cv2.findTransformECC(world_vid[i,:,:], world_vid[i+1,:,:], warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=1)
    except cv2.error:
        # ECC failed (e.g. did not converge): leave this pair as NaN
        continue
    xshift[i] = warp_matrix[0,2]
    yshift[i] = warp_matrix[1,2]

# perform regression to predict frameshift based on eye shifts
# set up models
xmodel = LinearRegression()
ymodel = LinearRegression()
# eye data as predictors: columns are (dtheta, dphi)
eyeData = np.zeros((max_frames,2))
eyeData[:,0] = dth[0:max_frames]
eyeData[:,1] = dphi[0:max_frames]
# shift in x and y as outputs
xshiftdata = xshift[0:max_frames]
yshiftdata = yshift[0:max_frames]
# only use good data
# not nans, good correlation between frames, small eye movements (no saccades, only compensatory movements)
usedata = ~np.isnan(eyeData[:,0]) & ~np.isnan(eyeData[:,1]) & (cc>0.95)  & (np.abs(eyeData[:,0])<2) & (np.abs(eyeData[:,1])<2) & (np.abs(xshiftdata)<5) & (np.abs(yshiftdata)<5)
# fit xshift
xmodel.fit(eyeData[usedata,:],xshiftdata[usedata])
xmap = xmodel.coef_
xrscore = xmodel.score(eyeData[usedata,:],xshiftdata[usedata])
# fit yshift
ymodel.fit(eyeData[usedata,:],yshiftdata[usedata])
ymap = ymodel.coef_
yrscore = ymodel.score(eyeData[usedata,:],yshiftdata[usedata])
seq_time = time.time() - start
print("duration =", seq_time)

In [None]:
np.shape(xshiftdata)  # number of frame pairs in the x-shift regression target

In [None]:
eyeData[usedata].shape  # (frame pairs kept by the quality filter, 2 predictors)

In [None]:
# diagnostic plots: each eye-position change (rows of the table below) plotted
# against each measured frame shift, with a y = -x reference line
fig = plt.figure(figsize=(8, 6))
diag_panels = (
    (dth, xshift, 'dtheta', 'xshift', 'xmap = ' + str(xmap)),
    (dth, yshift, 'dtheta', 'yshift', 'ymap = ' + str(ymap)),
    (dphi, xshift, 'dphi', 'xshift', None),
    (dphi, yshift, 'dphi', 'yshift', None),
)
for panel_idx, (eye_delta, frame_shift, x_name, y_name, panel_title) in enumerate(diag_panels, start=1):
    plt.subplot(2, 2, panel_idx)
    plt.plot(eye_delta[0:max_frames], frame_shift[0:max_frames], '.')
    plt.plot([-5, 5], [5, -5], 'r')
    plt.xlim(-12, 12)
    plt.ylim(-12, 12)
    plt.xlabel(x_name)
    plt.ylabel(y_name)
    if panel_title is not None:
        plt.title(panel_title)
plt.tight_layout()

In [None]:
start = time.time()

print('estimating eye-world calibration')
# fig, xmap, ymap = eye_shift_estimation(th, phi, eyeT, world_vid,worldT,60*60)
xcorrection = xmap.copy()
ycorrection = ymap.copy()
print('shifting worldcam for eyes')
thInterp = interp1d(eyeT, th, bounds_error=False)
phiInterp = interp1d(eyeT, phi, bounds_error=False)
thWorld = thInterp(worldT)
phiWorld = phiInterp(worldT)
# Predicted per-frame pixel shift (rows from ycorrection, cols from xcorrection),
# truncated toward zero as the previous np.int8 cast did, but in a wide integer
# type: int8 silently wraps past +/-127 px. Frames with no eye data (NaN, i.e.
# worldT outside eyeT) are explicitly left unshifted.
row_shift = -np.trunc(np.nan_to_num(thWorld*ycorrection[0] + phiWorld*ycorrection[1], nan=0.0)).astype(int)
col_shift = -np.trunc(np.nan_to_num(thWorld*xcorrection[0] + phiWorld*xcorrection[1], nan=0.0)).astype(int)
# number of frames to correct; use np.shape(world_vid)[0] to correct the whole video
n_shift_frames = 3000
# NOTE: world_vid is modified in place -- re-running this cell shifts frames again
for f in tqdm(range(min(n_shift_frames, len(world_vid), len(worldT)))):
    world_vid[f,:,:] = imshift(world_vid[f,:,:], (row_shift[f], col_shift[f]))

seq_time = time.time() - start
print("duration =", seq_time)
print('saving worldcam video corrected for eye movements')
# NOTE: the save below is commented out, so nothing is actually written
# np.save(file=os.path.join(file_dict['save'], 'corrected_worldcam.npy'), arr=world_vid)
# std_im = np.std(world_vid,axis=0)
# img_norm = (world_vid-np.mean(world_vid,axis=0))/std_im
# std_im[std_im<20] = 0
# img_norm = img_norm * (std_im>0)

# Load Other Data

In [None]:
DataPath = check_path(Path('~/Research/SensoryMotorPred_Data').expanduser(),'data/070921/J553RT/fm1')
# Path.glob yields files in filesystem-dependent order; sort so the chosen
# file is deterministic, and fail with a clear message if there is none.
h5_files = sorted(DataPath.glob('*.h5'))
if not h5_files:
    raise FileNotFoundError(f'no .h5 files found in {DataPath}')
DataFile = h5_files[0]

In [None]:
df = pd.read_hdf(path_or_buf=DataFile)  # load the stored object from the selected .h5 file

In [None]:
# Show the available keys of df (for a DataFrame these are the column labels).
df.keys()

In [None]:
# shapes of several per-row arrays for the entry labelled 7, to compare their lengths
tuple(df[col][7].shape for col in ('fm1_dEye', 'spikeT', 'fm1_pitch_interp', 'fm1_theta', 'rate'))

In [None]:
n = 3  # positional row of df to inspect
ex_n, ex_n_rate = df['spikeT'].iloc[n], df['rate'].iloc[n]

In [None]:
t = 100
dt = 1000
window = slice(t, t + dt)
# roll trace in the window, scaled by its max |value| and offset to start at 0
roll_seg = df['fm1_roll_interp'].iloc[n][window]
norm_roll = roll_seg / np.max(np.abs(roll_seg))
norm_roll = norm_roll - norm_roll[0]
# rate in the same window, scaled by its maximum
rate_seg = ex_n_rate[window]
norm_rate = rate_seg / np.max(rate_seg)
plt.plot(norm_rate)
plt.plot(norm_roll)

In [None]:
plt.scatter(x=norm_roll, y=norm_rate)  # normalized rate as a function of normalized roll