In [None]:
%load_ext autoreload
%autoreload 2

# Imports
import numpy as np
import pandas as pd
import xarray as xr
import panel as pn
import holoviews as hv
import datashader as dshade
import ipywidgets as widgets
from ipywidgets import interact

import os
import sys
import random
import importlib
import datetime
import warnings
import math
import cmath
warnings.filterwarnings('ignore')

from pprint import pprint
from itertools import zip_longest
from scipy.stats import norm, binned_statistic, ttest_1samp, ttest_ind

from bokeh.resources import INLINE
from bokeh.io import export_svgs, export_png
from holoviews import opts, dim
from holoviews.operation import histogram
from holoviews.operation.datashader import datashade, shade
hv.extension('bokeh')
# hv.extension('matplotlib')
import matplotlib.pyplot as plt

sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))
%aimport paths
import processing_parameters
import functions_bondjango as bd
import functions_plotting as fp
import functions_data_handling as fdh
import functions_kinematic as fk
import functions_tuning as tuning
import functions_misc as misc
from wirefree_experiment import WirefreeExperiment, DataContainer

importlib.reload(fp)
# Project-wide plotting theme first, then override the font family
fp.set_theme()
plt.rcParams.update({"font.family": "Arial"})
# plt.rcParams['font.size'] = 15
# Multiply by this factor to convert centimetres to inches (matplotlib figsize is in inches)
cm = 1 / 2.54

In [None]:
def plot_free_ethogram(experiment, cells=None, num_cells=15, **kwargs):
    """Plot a stacked ethogram of trial state, kinematics and neural activity.

    The figure has one row per kinematic trace (trial indicator, speed,
    heading, angular speed, head direction, head height, head pitch/yaw/roll)
    followed by one row per cell showing normalized dF/F (black) and
    normalized inferred spikes (green). All rows share the time axis.

    Parameters
    ----------
    experiment : WirefreeExperiment
        Must expose ``kinematics`` (a DataFrame with ``time_vector``,
        ``trial_num``, ``mouse_speed``, ``mouse_heading``,
        ``mouse_angular_speed``, ``head_direction``, ``head_height``,
        ``head_pitch``, ``head_yaw`` and ``head_roll`` columns), ``cells``,
        ``norm_dff`` and ``norm_spikes``.
    cells : iterable of str, optional
        Cell columns to plot. If None, ``num_cells`` distinct cells are drawn
        at random from ``experiment.cells``.
    num_cells : int, default 15
        Number of cells to sample when ``cells`` is None; capped at the number
        of available cells.
    **kwargs
        Forwarded to ``plt.subplots`` (e.g. ``figsize``).

    Returns
    -------
    matplotlib.figure.Figure
    """
    kinem = experiment.kinematics
    time = kinem.time_vector

    # Select cells. Sample without replacement so no cell appears twice, and
    # materialize as a list so a generator argument is not consumed by len().
    if cells is None:
        available = np.asarray(experiment.cells)
        cells = np.random.choice(available, min(num_cells, len(available)), replace=False)
    cells = list(cells)

    # Raw string: '\c' inside a normal string literal is an invalid escape sequence
    deg = r'$^\circ$'
    angle_lim = (-180, 180)

    # (y label, trace, y limits or None, line alpha) for each kinematic row
    kine_rows = [
        ('trials', np.where(kinem.trial_num > 0, 1, 0), None, 1.0),
        ('speed\n [m/s]', np.abs(kinem.mouse_speed), None, 0.8),
        (f'heading\n [{deg}]', fk.smooth_trace(kinem.mouse_heading, range=angle_lim), angle_lim, 0.8),
        (f'ahv\n [{deg}/sec]', kinem.mouse_angular_speed, None, 0.8),
        (f'head dir.\n [{deg}]', kinem.head_direction, angle_lim, 0.8),
        ('head height\n [m]', kinem.head_height, None, 0.8),
        (f'head pitch\n [{deg}]', kinem.head_pitch, angle_lim, 0.8),
        (f'head yaw\n [{deg}]', kinem.head_yaw, angle_lim, 0.8),
        (f'head roll\n [{deg}]', kinem.head_roll, angle_lim, 0.8),
    ]
    n_kine = len(kine_rows)

    fig, axes = plt.subplots(nrows=n_kine + len(cells), ncols=1, sharex=True, **kwargs)

    # Kinematic rows
    for ax, (_, trace, ylim, alpha) in zip(axes, kine_rows):
        ax.plot(time, trace, 'k', alpha=alpha)
        if ylim is not None:
            ax.set_ylim(ylim)

    # Neural rows: dF/F and inferred spikes of the same cell, overlaid.
    # Use the experiment passed in, not a notebook-global variable.
    for ax, cell in zip(axes[n_kine:], cells):
        ax.plot(time, experiment.norm_dff[cell], 'k', alpha=0.8)
        ax.plot(time, experiment.norm_spikes[cell], 'g', alpha=0.6)
        ax.set_ylim((0, 1))

    # Label every row and strip spines/ticks; only the bottom row keeps its x axis
    axes_labels = [label for label, *_ in kine_rows] + [f'cell {n + 1}' for n in range(len(cells))]
    last = len(axes) - 1
    for i, (ax, label) in enumerate(zip(axes, axes_labels)):
        ax.set_ylabel(label.capitalize(), rotation=0, va='center', wrap=True, labelpad=25)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.yaxis.set_tick_params(pad=5)
        if i < last:
            ax.spines["bottom"].set_visible(False)
            ax.xaxis.set_tick_params(bottom=False)

    axes[-1].set_xlabel("Time (s)")
    fig.align_ylabels()

    return fig

In [None]:
def plot_free_ethogram_hv(experiment, cells=None, num_cells=15, **kwargs):
    """HoloViews (bokeh) version of the free-session ethogram.

    Returns a single-column ``hv.Layout`` with one row per kinematic trace
    (trial indicator, speed, heading, angular speed, head direction, head
    height, head pitch/yaw/roll) followed by one row per cell (normalized
    dF/F in black, normalized inferred spikes in green).

    Every row uses the same ``time`` key dimension, so with
    ``shared_axes=True`` the x axes are linked. Each row gets its own value
    dimension name, because HoloViews links axes by dimension name. With the
    default ``'y'`` name, all rows (0/1 trial indicator, speed, angles in
    degrees) would be forced onto one common y range.

    Parameters
    ----------
    experiment : WirefreeExperiment
        Must expose ``kinematics`` (a DataFrame with ``time_vector``,
        ``trial_num``, ``mouse_speed``, ``mouse_heading``,
        ``mouse_angular_speed``, ``head_direction``, ``head_height``,
        ``head_pitch``, ``head_yaw`` and ``head_roll`` columns), ``cells``,
        ``norm_dff`` and ``norm_spikes``.
    cells : iterable of str, optional
        Cell columns to plot. If None, ``num_cells`` distinct cells are drawn
        at random from ``experiment.cells``.
    num_cells : int, default 15
        Number of cells to sample when ``cells`` is None; capped at the number
        of available cells.
    **kwargs
        Accepted for call compatibility with the matplotlib version (e.g.
        ``figsize``) but ignored; size the result with ``layout.opts(...)``.

    Returns
    -------
    holoviews.Layout
    """
    kinem = experiment.kinematics
    time = np.asarray(kinem.time_vector)

    # Select distinct cells; list() so a generator argument can be reused
    if cells is None:
        available = np.asarray(experiment.cells)
        cells = np.random.choice(available, min(num_cells, len(available)), replace=False)
    cells = list(cells)

    time_dim = hv.Dimension('time', label='Time (s)')
    angle_lim = (-180, 180)

    # (value-dimension name, axis label, trace, y limits or None).
    # Unicode degree sign instead of LaTeX, since these are bokeh axis labels.
    kine_rows = [
        ('trials', 'Trials', np.where(kinem.trial_num > 0, 1, 0), None),
        ('speed', 'Speed [m/s]', np.abs(kinem.mouse_speed), None),
        ('heading', 'Heading [°]', fk.smooth_trace(kinem.mouse_heading, range=angle_lim), angle_lim),
        ('ahv', 'AHV [°/s]', kinem.mouse_angular_speed, None),
        ('head_direction', 'Head dir. [°]', kinem.head_direction, angle_lim),
        ('head_height', 'Head height [m]', kinem.head_height, None),
        ('head_pitch', 'Head pitch [°]', kinem.head_pitch, angle_lim),
        ('head_yaw', 'Head yaw [°]', kinem.head_yaw, angle_lim),
        ('head_roll', 'Head roll [°]', kinem.head_roll, angle_lim),
    ]
    last_row = len(kine_rows) + len(cells) - 1

    def _curve(trace, value_dim, row, **style):
        """Build one time-series curve; only the bottom row shows its x axis."""
        return hv.Curve((time, np.asarray(trace)), kdims=[time_dim], vdims=[value_dim]).opts(
            xaxis='bottom' if row == last_row else None, **style)

    plots = []
    for row, (name, label, trace, ylim) in enumerate(kine_rows):
        style = {'color': 'k'} if ylim is None else {'color': 'k', 'ylim': ylim}
        plots.append(_curve(trace, hv.Dimension(name, label=label), row, **style))

    for n, cell in enumerate(cells):
        row = len(kine_rows) + n
        cell_dim = hv.Dimension(f'{cell}_activity', label=f'Cell {n + 1}')
        dff = _curve(experiment.norm_dff[cell], cell_dim, row, color='k', alpha=0.8, ylim=(0, 1))
        spikes = _curve(experiment.norm_spikes[cell], cell_dim, row, color='g', alpha=0.6, ylim=(0, 1))
        plots.append(dff * spikes)

    return hv.Layout(plots).cols(1).opts(shared_axes=True)

In [None]:
importlib.reload(processing_parameters)

# Query string: the project-wide search string restricted to preprocessing
# entries recorded on the VTuningWF rig
search_string = processing_parameters.search_string + r", rig:VTuningWF, analysis_type:preprocessing"

# Fetch the matching preprocessing entries from the database
file_path, paths_all, parsed_query, date_list, animal_list = fdh.fetch_preprocessing(search_string)

# Keep only entries whose animal equals the lower-cased mouse name from the query
target_mouse = parsed_query['mouse'].lower()
animal_idxs = [i for i, animal in enumerate(animal_list) if animal == target_mouse]
good_entries = [file_path[index] for index in animal_idxs]
input_path = [paths_all[index] for index in animal_idxs]

# Fail here with a clear message instead of an IndexError on input_path[0] later
assert input_path, f'No preprocessing entries found for mouse {target_mouse!r} (query: {search_string!r})'
print(input_path)

In [None]:
# Load the first matching session and derive the activity tables used below
exp = WirefreeExperiment(input_path[0], use_xarray=False)
exp.dff = tuning.calculate_dff(exp.raw_fluor)
exp.norm_spikes, exp.norm_fluor, exp.norm_dff = (
    tuning.normalize_responses(exp.raw_spikes),
    tuning.normalize_responses(exp.raw_fluor),
    tuning.normalize_responses(exp.dff),
)

# Derive smoothed head pitch/yaw/roll once (skipped if the columns already exist).
# Pitch is sign-flipped relative to the x rotation.
if 'head_pitch' not in exp.kinematics.columns:
    pitch, yaw, roll = (
        -fk.wrap_negative(exp.kinematics.mouse_xrot_m.values),
        fk.wrap_negative(exp.kinematics.mouse_zrot_m.values),
        fk.wrap_negative(exp.kinematics.mouse_yrot_m.values),
    )
    # NOTE(review): range is in degrees but discont=2*np.pi looks like radians — confirm against fk.smooth_trace
    smoothing = dict(range=(-180, 180), kernel_size=10, discont=2*np.pi)
    for column, angle in (('head_pitch', pitch), ('head_yaw', yaw), ('head_roll', roll)):
        exp.kinematics[column] = fk.smooth_trace(angle, **smoothing)

In [None]:
%matplotlib widget
fig = plot_free_ethogram(exp, num_cells=10, figsize=(15,15))
save_path = os.path.join(r'C:\Users\mmccann\Dropbox\bonhoeffer lab\SFN 2023\poster', f'free_ethogram.png')
# fp.save_figure(fig, fig_width=8, save_path=save_path, fontsize='poster', target='both')
# plt.savefig(save_path, dpi=600)

In [None]:
# Interactive bokeh ethogram; the figsize keyword is not used by the hv function
fig = plot_free_ethogram_hv(exp, num_cells=10, figsize=(15, 15))
# Very wide, short rows so the whole session can be inspected by panning/zooming
fig.opts(width=10000, height=200)

In [None]:
%matplotlib inline
a = exp.kinematics.mouse_speed*100
hist = a.hist(bins=20)

# Filter by head pitch
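Keep only trials in which the head pitch stays within `[pitch_lower_cutoff, pitch_upper_cutoff]` for more than `view_fraction` of the frames. This drops trials where the mouse mostly looks above the upper edge of the arena (or steeply downward).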

In [None]:
# Head-pitch window (degrees) treated as "looking into the arena"
pitch_upper_cutoff = 20
pitch_lower_cutoff = -90
# Minimum fraction of a trial's frames that must fall inside that window
view_fraction = 0.7

# Per-frame flag: pitch inside the window, both bounds inclusive
exp.kinematics['viewed'] = exp.kinematics.head_pitch.between(pitch_lower_cutoff, pitch_upper_cutoff)

# Trials whose fraction of in-window frames exceeds view_fraction
viewed_trials = (
    exp.kinematics
    .groupby('trial_num')
    .filter(lambda trial: trial['viewed'].mean() > view_fraction)
    .trial_num
    .unique()
)

exp.norm_spikes_viewed = exp.norm_spikes.loc[exp.norm_spikes.trial_num.isin(viewed_trials)]
norm_spikes_viewed = exp.norm_spikes_viewed

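# Filter by running speed
Among the viewed trials, keep only those in which the mouse is mostly still. A trial counts as still when its mean absolute speed is below the running cutoff (the 80th percentile of absolute speed).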

In [None]:
# Get an idea of the cutoff to be considered still
speed_cutoff = np.percentile(np.abs(exp.kinematics.mouse_speed), 80)
exp.kinematics['is_running'] = np.abs(exp.kinematics.mouse_speed) >= speed_cutoff
exp.kinematics['mouse_speed_abs'] = np.abs(exp.kinematics['mouse_speed'])
# fig = plt.figure()
# plt.plot(exp.kinematics.time_vector, np.abs(exp.kinematics.mouse_speed)) 
# plt.hlines(speed_cutoff, 0, exp.kinematics.time_vector.max(), colors='r')

In [None]:
# A trial counts as "still" when its mean absolute speed is below the running cutoff
still_trials = (
    exp.kinematics
    .groupby('trial_num')
    .filter(lambda trial: trial['mouse_speed_abs'].mean() < speed_cutoff)
    .trial_num
    .unique()
)
# Intersect with the viewed trials, keeping viewed_trials' order.
# np.isin replaces np.in1d, which is deprecated as of NumPy 2.0.
still_trials = viewed_trials[np.isin(viewed_trials, still_trials)]
still_spikes = exp.norm_spikes_viewed.loc[exp.norm_spikes_viewed.trial_num.isin(still_trials)]
exp.norm_spikes_viewed_still = still_spikes.copy()

# Run the tuning-curve (TC) calculation loop

In [None]:
from snakemake_scripts.wf_tc_calculate import calculate_visual_tuning

In [None]:
%%time
datasets = ['norm_spikes', 'norm_spikes_viewed', 'norm_spikes_viewed_still'] # 'raw_spikes 
for dataset in datasets: 
    activity = getattr(exp, dataset)
    
    for tuning_type in ['direction_wrapped', 'orientation']:
        props = calculate_visual_tuning(activity, tuning_type, bootstrap_shuffles=1)
        
        tuning_label = tuning_type.split('_')[0]
        setattr(exp, f'{dataset}_{tuning_label}_props', props)

In [None]:
[name for name in exp.list_attributes() if 'spikes' in name and 'props' not in name]

# Check out tuning curves

## Visual tuning

In [None]:
%matplotlib inline
interact(plot_visual_tuning_curves,
         experiment = widgets.fixed(exp),
         cell=exp.cells, 
         data_type=datasets, 
         error = ['std', 'sem'],
         polar=[True, False],
         norm=widgets.fixed(True),
         axes=widgets.fixed(None))

In [None]:
# Compare head direction with the sign-flipped yaw trace.
# NOTE(review): this notebook creates 'head_yaw', not 'yaw' — confirm the 'yaw' column exists in kinematics.
fig, ax = plt.subplots()
ax.plot(exp.kinematics.head_direction)
ax.plot(-exp.kinematics.yaw)

In [None]:
cell = 'cell_0000'
# Collect each trial's trace per orientation, then stack them into one array per orientation
a = (
    exp.norm_spikes_viewed
    .groupby(['orientation', 'trial_num'])[cell]
    .agg(list)
    .droplevel('trial_num')
    .groupby('orientation')
    .agg(list)
    .apply(misc.list_lists_to_array)
    .reset_index()
)
# Drop orientation -1000 (presumably a "no stimulus" sentinel — confirm)
a = a.drop(index=a.index[a['orientation'] == -1000])
# Trial-averaged trace per orientation, ignoring NaNs
a = a.assign(mean_trace=a[cell].map(lambda trace: np.nanmean(trace, axis=0)))

fig = plt.figure(figsize=(5, 5))
for mean_trace in a['mean_trace']:
    plt.plot(mean_trace[:120])  # first 120 samples only
plt.legend([f'{orientation:.1f}' for orientation in a['orientation'].values], ncol=3)

In [None]:
# Heatmap: one row per orientation, first 120 samples of the mean response
directions = a['orientation'].to_numpy()
mean_resp = misc.list_lists_to_array(a['mean_trace'].to_list())

fig = plt.figure(figsize=(20, 20))
ax = plt.axes()
im = ax.imshow(mean_resp[:, :120], cmap='Reds')
# plt.xticks(np.arange(0, 115), np.arange(0, 115*0.05, 0.05))
ax.set_yticks(np.arange(mean_resp.shape[0]))
ax.set_yticklabels(directions)

# Thin colorbar axis placed just to the right of the heatmap
pos = ax.get_position()
cax = fig.add_axes([pos.x1 + 0.01, pos.y0, 0.02, pos.height])
cbar = fig.colorbar(im, cax=cax)
cbar.set_label('Activity (a.u.)', rotation=270)

In [None]:
# Sampling interval of the activity data. The viewed-trial subset is not
# contiguous (rows of dropped trials are removed), so the time jumps across the
# gaps would inflate a mean of the frame-to-frame differences. The median is
# robust to them, assuming time_vector increases monotonically.
np.median(np.diff(exp.norm_spikes_viewed.time_vector))