In [3]:
%load_ext autoreload
%autoreload 2
# %matplotlib widget
import numpy as np
import os
import sys
import h5py
import cv2
import importlib
import holoviews as hv
hv.extension('bokeh')

import pandas as pd
import matplotlib.pyplot as plt
from skimage import color

sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))
import processing_parameters
import functions_bondjango as bd
import functions_misc as fmisc
import functions_matching as fm
import functions_data_handling as fdh
import functions_tuning as tuning
from wirefree_experiment import WirefreeExperiment, DataContainer
from functions_wirefree_trigger_fix import get_trial_duration_stats

In [1]:
def get_footprint_centroids(calcium_data):
    cents = []
    for cell in calcium_data:
        new_cell = cell.copy()
        new_cell[new_cell > 0] == 1 
        M = cv2.moments(new_cell)
        
        # centroid calciulation
        cX = int(M["m10"] / M["m00"])
        cY = int(M["m01"] / M["m00"])
        cents.append([cX, cY])
    return cents

def get_footprint_contours(calcium_data):
    contour_list = []
    contour_stats = []
    for frame in calcium_data:
        frame = frame * 255.
        frame = frame.astype(np.uint8)
        thresh = cv2.threshold(frame, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)[1]
        
      # get contours and filter out small defects
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Only take the largest contour
        cntr = max(contours, key=cv2.contourArea)

        # # get contours and filter out small defects
        # contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # # Only take the largest contour
        # cntr = contours[0]
        area = cv2.contourArea(cntr)
        perimeter = cv2.arcLength(cntr, True)
        compactness = 4*np.pi*area / (perimeter + 1e-16)**2
        
        contour_list.append(max(contours, key=cv2.contourArea))
        contour_stats.append((area, perimeter, compactness))

    return contour_list, np.array(contour_stats)

def make_contour_projection(contour_list, shape):
    contour_img = np.zeros(shape)
    for i, cntr in enumerate(contour_list):
        cv2.drawContours(contour_img[i, :], [cntr], 0, 1, 1)

    contour_img = np.sum(contour_img, axis=0)
    contour_img[contour_img > 0] = 1.0
    return contour_img

def get_binary_footprints(footprint_pic, threshold=0.1):
    bin_pic = np.zeros_like(footprint_pic)
    bin_pic[footprint_pic >= threshold] = 1
    return bin_pic

def make_rgb_overlay(max_proj, footprints, contour_img, channel='r'):
    max_proj -= max_proj.min()
    max_proj /= max_proj.max()

    # Make RGB max projection
    max_proj_rgb = np.dstack((max_proj, max_proj, max_proj))

    # mak RGB footprint image
    footprint_rgb = np.zeros((*max_proj.shape, 3))

    footprints /= footprints.max()
    if channel == 'r':
        footprint_rgb[:, :, 0] = footprints
    elif channel == 'g':
        footprint_rgb[:, :, 1] = footprints
    elif channel == 'b':
        footprint_rgb[:, :, 2] = footprints
    else:
        raise ValueError('channel must be r, g, or b')
    
    footprint_rgb[:] += np.expand_dims(contour_img, -1).astype(float)

    # Convert RGB max proj and RGB footprints to HSV colorspace
    max_proj_hsv = color.rgb2hsv(max_proj_rgb)
    footprint_mask_hsv = color.rgb2hsv(footprint_rgb)

    # Overlay the footprint mask on the max projection
    max_proj_hsv[..., 0] = footprint_mask_hsv[..., 0]
    max_proj_hsv[..., 1] = footprint_mask_hsv[..., 1] * 0.6

    # Return to RGB colorspace
    overlay = color.hsv2rgb(max_proj_hsv)
    return overlay

def hv_plot_FOVs(rigs, binary_footprints, contour_images, labels=None, overlay=True):
    binary_images = []

    for i, (rig, bin_pic) in enumerate(zip(rigs, binary_footprints)):
        # Plot all binarized ROIS with contours
        binary_image = hv.RGB(bin_pic.astype(float), bounds=(0, 0, 320, 320)).opts(title=rig)
        
        if labels is not None:
            cents = labels[i][:, :2]
            cents[:,1] = 320 - cents[:,1]
            label = labels[i][:, -1]
            label_plot = hv.Labels({('x', 'y'): cents, 'text': label}, ['x', 'y'], 'text').opts(text_color='white', xoffset=0.05, yoffset=0.05, text_font_size='8pt')
            #  hv.Labels((cents[:, 0], 320-cents[:, 1]), str(label[i]))
            binary_image = binary_image * label_plot

        binary_images.append(binary_image)

    if overlay:
        binary_overlay = hv.RGB(np.dstack((contour_images[0], np.zeros_like(contour_images[0]), contour_images[1])), bounds=(0, 0, 320, 320)).opts(title='Overlay')
        layout = hv.Layout(binary_images[0] + binary_images[1] + binary_overlay).cols(3)
    else:
        layout = hv.Layout(binary_images[0] + binary_images[1]).cols(2)

    return layout


In [4]:
importlib.reload(processing_parameters)

# get the search string
search_string = processing_parameters.search_string
parsed_search = fdh.parse_search_string(search_string)

# get the paths from the database
file_infos = bd.query_database('analyzed_data', search_string)
preproc_paths = np.sort(np.array([el['analysis_path'] for el in file_infos if (el['analysis_type'] == 'preprocessing') and
                         (parsed_search['mouse'].lower() in el['slug'])]))
calcium_paths = np.sort(np.array([el['analysis_path'].replace('preproc', 'calciumraw') for el in file_infos if (el['analysis_type'] == 'preprocessing') and
                         (parsed_search['mouse'].lower() in el['slug'])]))
tc_paths = np.sort(np.array([el['analysis_path'] for el in file_infos if (el['analysis_type'] == 'tc_analysis') and
                         (parsed_search['mouse'].lower() in el['slug'])]))
cell_matching_path = [el['analysis_path'] for el in file_infos if ('daycellmatch' in el['slug']) and
                            (parsed_search['mouse'].lower() in el['slug'])]
rigs = np.array([os.path.basename(file).split('_')[6] for file in calcium_paths])
print(cell_matching_path)
print(calcium_paths)
print(preproc_paths)
print(tc_paths)
print(rigs)


['Z:\\Prey_capture\\AnalyzedData\\01_20_2023_MM_221109_a_dayCellMatch.hdf5']
['Z:\\Prey_capture\\AnalyzedData\\01_20_2023_12_04_38_VWheelWF_MM_221109_a_fixed1_gabor_calciumraw.hdf5']
['Z:\\Prey_capture\\AnalyzedData\\01_20_2023_12_04_38_VWheelWF_MM_221109_a_fixed1_gabor_preproc.hdf5']
['Z:\\Prey_capture\\AnalyzedData\\01_20_2023_12_04_38_VWheelWF_MM_221109_a_fixed1_gabor_tcday.hdf5']
['VWheelWF']


# Generate Contour Plots

In [None]:
# Load the assignments and find the column that corresponds to each file
assignments =  fm.match_cells(cell_matching_path[0])
new_cols = [col.split('_')[-2] for col in assignments.columns]
assignments.columns = new_cols
match_cols = [i for i, (rig, col) in enumerate(zip(rigs, assignments.columns)) if str(col) in rig]

# Use number of non-NaNs in each row to filter out components that were not registered in enough sessions
assignments_filtered = assignments.dropna().astype(int).to_numpy()
unassigned = np.array(assignments[np.sum(~np.isnan(assignments), axis=1) < 2])
unassigned = [unassigned[~np.isnan(unassigned[:, 0]), 0].astype(int), unassigned[~np.isnan(unassigned[:, 1]), 1].astype(int)]
unassigned = [np.sort(np.unique(unassigned[0])), np.sort(np.unique(unassigned[1]))]

In [None]:
# load the data for the matching
calcium_list = []
max_proj_list = []
footprint_list = []
contour_list = []
size_list = []
template_list = []
footprint_pics = []
countour_pics = []
centroids_list = []
binary_footprints = []
overlay_footprints = []
overlay_binary_footprints = []


# load the calcium data
for files, channel in zip(calcium_paths, ['b', 'r']):

    with h5py.File(files, mode='r') as f:

        try:
            calcium_data = np.array(f['A'])
            max_proj = np.array(f['max_proj'])     

        except KeyError:
            continue

    # if there are no ROIs, skip
    if (type(calcium_data) == np.ndarray) and np.any(calcium_data.astype(str) == 'no_ROIs'):
        continue
        
    # clear the rois that don't pass the size or compactness criteria
    roi_stats = fmisc.get_roi_stats(calcium_data)
    contours, contour_stats = get_footprint_contours(calcium_data)

    if len(roi_stats.shape) == 1:
        roi_stats = roi_stats.reshape(1, -1)
        contour_stats = contour_stats.reshape(1, -1)

    areas = roi_stats[:, -1]
    compactness = contour_stats[:, -1]

    keep_vector = (areas > processing_parameters.roi_parameters['area_min']) & \
                  (areas < processing_parameters.roi_parameters['area_max']) & \
                  (compactness > 0.5)

    if np.all(keep_vector == False):
        continue

    calcium_data = calcium_data[keep_vector, :, :]
    contours = [contours[i] for i, keep in enumerate(keep_vector) if keep]

    centroids = get_footprint_centroids(calcium_data)
    footprint_proj = np.sum(calcium_data, axis=0)
    binary_footprint_proj = get_binary_footprints(footprint_proj)
    contour_proj = make_contour_projection(contours, calcium_data.shape)
    
    # format and masks and store for matching
    calcium_list.append(calcium_data)
    footprint_list.append(np.moveaxis(calcium_data, 0, -1).reshape((-1, calcium_data.shape[0])))
    contour_list.append(contours)
    countour_pics.append(contour_proj)

    size_list.append(calcium_data.shape[1:])
    template_list.append(max_proj)
    max_proj_list.append((max_proj - max_proj.min())/ max_proj.max())
    footprint_pics.append(footprint_proj)
    binary_footprints.append(binary_footprint_proj)
    centroids_list.append(np.array(centroids))

    overlay_footprints.append(make_rgb_overlay(max_proj, footprint_proj, contour_proj, channel=channel))
    overlay_binary_footprints.append(make_rgb_overlay(max_proj, binary_footprint_proj, contour_proj, channel=channel))

# Use filtered indices to select the corresponding spatial components
spatial_filtered = footprint_list[0][:, assignments_filtered[:, 0]]
matched_footprints = np.sum(spatial_filtered.reshape(320, 320, spatial_filtered.shape[-1]), axis=-1)

# Filter footprints and contours based on the matching
match_ca1 = calcium_list[0][assignments_filtered[:, match_cols[0]], :]
match_ca2 = calcium_list[1][assignments_filtered[:, match_cols[1]], :]

match_footprint_projs = [np.sum(match_ca1, axis=0),  np.sum(match_ca2, axis=0)]
match_binary_footprint_projs = [get_binary_footprints(fp_proj) for fp_proj in match_footprint_projs]
match_centroids = [np.array(get_footprint_centroids(match_ca1)), np.array(get_footprint_centroids(match_ca2))]

match_contours1, _ = get_footprint_contours(match_ca1)
match_contours2, _ = get_footprint_contours(match_ca2)
match_contour_proj1 = make_contour_projection(match_contours1, match_ca1.shape)
match_contour_proj2 = make_contour_projection(match_contours2, match_ca2.shape)
match_binary_contour_projs = [match_contour_proj1, match_contour_proj2]
match_contours = [match_contours1, match_contours2]

match_overlay_binary_footprints = []
for i, channel in enumerate(['b', 'r']):
    match_overlay_binary_footprints.append(make_rgb_overlay(max_proj_list[i], match_binary_footprint_projs[i], match_binary_contour_projs[i], channel=channel))

# Filter unmatched footprints and contours based on the matching
unmatch_ca1 = calcium_list[0][unassigned[match_cols[0]], :]
unmatch_ca2 = calcium_list[1][unassigned[match_cols[1]], :]

unmatch_footprint_projs = [np.sum(unmatch_ca1, axis=0),  np.sum(unmatch_ca2, axis=0)]
unmatch_binary_footprint_projs = [get_binary_footprints(fp_proj) for fp_proj in unmatch_footprint_projs]
unmatch_centroids = [np.array(get_footprint_centroids(unmatch_ca1)), np.array(get_footprint_centroids(unmatch_ca2))]

unmatch_contours1, _ = get_footprint_contours(unmatch_ca1)
unmatch_contours2, _ = get_footprint_contours(unmatch_ca2)
unmatch_contour_proj1 = make_contour_projection(unmatch_contours1, unmatch_ca1.shape)
unmatch_contour_proj2 = make_contour_projection(unmatch_contours2, unmatch_ca2.shape)
unmatch_binary_contour_projs = [unmatch_contour_proj1, unmatch_contour_proj2]
unmatch_contours = [unmatch_contours1, unmatch_contours2]

unmatch_overlay_binary_footprints = []
for i, channel in enumerate(['b', 'r']):
    unmatch_overlay_binary_footprints.append(make_rgb_overlay(max_proj_list[i], unmatch_binary_footprint_projs[i], unmatch_binary_contour_projs[i], channel=channel))

## Plot all cells, matches, and not matched cells

In [None]:
all_cell_labels = [np.concatenate((centroids, np.arange(centroids.shape[0]).reshape(-1,1)), axis=1) for centroids in centroids_list]
all_cells = hv_plot_FOVs(rigs, overlay_binary_footprints, countour_pics, overlay=True, labels=all_cell_labels)

matched_cell_labels = [np.concatenate((match_centroids[i], assignments_filtered[:, i].reshape(-1,1)), axis=1) for i in np.arange(len(match_centroids))]
match_cells = hv_plot_FOVs(rigs, match_overlay_binary_footprints, match_binary_contour_projs, 
                           overlay=True, labels=matched_cell_labels).opts(hv.opts.RGB(title=''))

unmatched_cell_labels = [np.concatenate((unmatch_centroids[i], unassigned[i].reshape(-1,1)), axis=1) for i in np.arange(len(unmatch_centroids))]
unmatch_cells = hv_plot_FOVs(rigs, unmatch_overlay_binary_footprints, unmatch_binary_contour_projs, 
                             overlay=True, labels=unmatched_cell_labels).opts(hv.opts.RGB(title=''))

match_plot = hv.Layout(all_cells + match_cells + unmatch_cells).cols(3).opts(hv.opts.RGB(xlabel=None, ylabel=None, xaxis=None, yaxis=None))

match_plot

## Plot randomly selected matched cells

In [None]:
random_idxs = np.random.choice(np.arange(len(assignments_filtered)), 5)
random_matched_cell_labels = [np.concatenate((match_centroids[i][random_idxs, :], assignments_filtered[random_idxs, i].reshape(-1,1)), axis=1) for i in np.arange(len(match_centroids))]

In [None]:
random_match_cells = hv_plot_FOVs(rigs, match_overlay_binary_footprints, match_binary_contour_projs, 
                           overlay=True, labels=random_matched_cell_labels)
random_match_cells

# Visual Stimulus Triggered Responses

In [5]:
def parse_trial_frames(df, pre_trial=0, post_trial=0):
    trial_idx_frames = df[df.trial_num >= 1].groupby(['trial_num']).apply(
        lambda x: [x.index[0] - pre_trial * processing_parameters.wf_frame_rate, 
                   x.index[0], 
                   x.index[-1] + post_trial * processing_parameters.wf_frame_rate]
        ).to_numpy()
    trial_idx_frames = np.vstack(trial_idx_frames)

    if trial_idx_frames[0, 0] < 0:
        trial_idx_frames[0, 0] = 0

    if trial_idx_frames[-1, -1] > df.index[-1]:
        trial_idx_frames[-1, -1] = df.index[-1]

    # Get the shifts from the zero point (important for plotting)
    max_zero_idx_shift = np.max(trial_idx_frames[:, 1] - trial_idx_frames[:, 0])
    
    traces = []
    for i, frame in enumerate(trial_idx_frames):
        df_slice = df.iloc[frame[0]:frame[-1], :].copy()
        df_slice['trial_num'] = df_slice['trial_num'].max()
        df_slice['direction'] = df_slice['direction'].max()
        df_slice['direction_wrapped'] = df_slice['direction_wrapped'].max()
        df_slice['orientation'] = df_slice['orientation'].max()

        zero_idx_shift = np.abs((frame[1] - frame[0]) - max_zero_idx_shift)
        df_slice['zero_idx_shift'] = zero_idx_shift

        traces.append(df_slice)
    
    traces = pd.concat(traces, axis=0).reset_index(drop=True)
    return traces, trial_idx_frames


def drop_partial_or_long_trials(df, min_trial_length=4.5, max_trial_length=5.5):
    trial_lengths = df[df.trial_num >= 1].groupby('trial_num').apply(lambda x: x.shape[0] / processing_parameters.wf_frame_rate)

    # Drop trials that are shorter than min_trial_length (partial trials)
    short_trials = trial_lengths[trial_lengths < min_trial_length].index
    df = df.drop(df[df.trial_num.isin(short_trials)].index)

    # Drop trials that are longer than max_trial_length (errors in trial number indexing)
    long_trials = trial_lengths[trial_lengths > max_trial_length].index
    df = df.drop(df[df.trial_num.isin(long_trials)].index)

    return df

In [6]:
exp_query = bd.query_database('vr_experiment', search_string)
preproc_query = bd.query_database('analyzed_data', search_string + r', analysis_type:preprocessing')
tc_query = bd.query_database('analyzed_data', search_string + r', analysis_type:tc_analysis')

In [7]:
exp = WirefreeExperiment(exp_info=exp_query[0], preproc_info=preproc_query[0], tc_info=tc_query[0])
exp._load_preprocessing()

In [28]:
random_idxs = np.random.choice(np.arange(len(exp.cell_matches)), 5)
cells_to_match = exp.cell_matches[exp.metadata.rig].iloc[random_idxs]
cells_to_match = list(np.unique([f'cell_{idx:04d}' for idx in cells_to_match]))
print(cells_to_match)

['cell_0000']


In [11]:
exp.dff = tuning.calculate_dff(exp.raw_fluor.copy(), baseline_type='quantile', quantile=0.25)
exp.norm_dff = tuning.normalize_responses(exp.dff.copy())
# exp.norm_spikes = tuning.normalize_responses(exp.raw_spikes.copy(), remove_baseline=True)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  ds[cells][ds[cells] < 0] = 0
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  ds[cells][ds[cells] < 0] = 0


In [12]:
exp.norm_dff = drop_partial_or_long_trials(exp.norm_dff)
# exp.raw_spikes = drop_partial_or_long_trials(exp.raw_spikes)

pre_trial_period = 2
post_trial_period = 2
dff_trials, tc_frames = parse_trial_frames(exp.norm_dff.copy(), pre_trial=pre_trial_period, post_trial=post_trial_period)

In [13]:
a = get_trial_duration_stats(dff_trials, 'trial_num', 'time_vector')

Min. trial. dur.: 6.87
Max. trial. dur.: 8.93
Mean. trial. dur.: 8.89


In [14]:
#Plot dFF
trials_on = exp.norm_dff['trial_num'] > 1
trials_on_roll = np.roll(trials_on, int(-2.5 * processing_parameters.wf_frame_rate))
time = exp.norm_dff['time_vector']
trials_plot = hv.Area((time, trials_on)).opts(color='gray', alpha=0.25)
trials_roll_plot = hv.Area((time, trials_on_roll)).opts(color='blue', alpha=0.25)


norm_running = tuning.normalize(exp.kinematics['wheel_speed_abs'].to_numpy())
running_plot = hv.Curve((time, norm_running)).opts(color='red', alpha=0.5)

plt_list = []
for cell in cells_to_match:
    dff_plot = hv.Curve(exp.norm_dff[['time_vector', cell]]).opts(color='black')
    plt_list.append(trials_plot * running_plot * dff_plot )
dff_layout = hv.Layout(plt_list).opts(hv.opts.Area(xlabel='Time (s)', ylabel=f'dF/F {cell}', show_legend=False), hv.opts.Curve(height=200, width=800)).cols(1)
dff_layout

In [20]:
def trial_average_response(ds, stim_type):
    ds.dropna(inplace=True)

    idxs_shifts = ds.groupby(['trial_num']).apply(lambda x: np.unique(x.zero_idx_shift)[0]).reset_index()
    idxs_shifts = idxs_shifts.rename({0: 'zero_idx_shift'}, axis=1)	

    trials_per_stim = ds.groupby([stim_type, 'trial_num'])[cells_to_match].agg(list).reset_index()
    trials_per_stim = trials_per_stim.join(idxs_shifts.set_index('trial_num'), on='trial_num')
    trials_per_stim = trials_per_stim.groupby([stim_type]).agg(list)

    idxs_shifts = trials_per_stim['zero_idx_shift'].copy()

    trial_array = trials_per_stim.copy()

    for i, row in trials_per_stim.iterrows():
        for cell in cells_to_match:
            shifts = list(idxs_shifts.loc[row.name])
            trial_array.loc[i, cell] = fmisc.list_lists_to_array(row[cell], alignment='on', on=shifts)

    trial_averages = trial_array.applymap(np.nanmean, axis=0) 
    trial_averages = trial_averages.drop('zero_idx_shift', axis=1)
    trial_array = trial_array.drop('zero_idx_shift', axis=1) 
    
    return trial_averages, trial_array


def hv_plot_trial_averages(trial_averages, trials, cells):

    plt_list = []

    for i, cell in enumerate(cells):

        cell_resps = trials[cell]
        cell_mean = trial_averages[cell]
        
        for k in range(cell_resps.shape[-1]):
            resps = cell_resps.iloc[k]
            mean = cell_mean.iloc[k]
            
            trials_list = [hv.Curve(resps[r, :]).opts(color='k', alpha=0.25) for r in np.arange(resps.shape[0])]
            mean_list = [hv.Curve(mean).opts(color='r', xticks=[(40, 0), (90, 2.5), (150, 5)], xlabel='', ylabel='')]

            if i == 0:
                [mean_plot.opts(title=f"{cell_resps.index[k]:.1f}") for mean_plot in mean_list]
                    
            plt_list.append(hv.Overlay(trials_list + mean_list))

    return hv.Layout(plt_list)
    

In [21]:
ori_averages, ori_trials = trial_average_response(dff_trials.copy(), 'orientation')

time_vector = np.linspace(-pre_trial_period, 5 + post_trial_period, (5 + pre_trial_period + post_trial_period) * processing_parameters.wf_frame_rate) 
trial_vector = np.zeros_like(time_vector)
trial_vector[(time_vector >= 0) & (time_vector <= 5)] = 1
trial_plot = hv.Area(trial_vector).opts(color='gray', alpha=0.25)

norm_dff_oris = hv_plot_trial_averages(ori_averages, ori_trials, cells_to_match).cols(12).opts(hv.opts.Curve(width=100, height=160))
norm_dff_oris = norm_dff_oris * trial_plot
norm_dff_oris

In [22]:
dir_averages, dir_trials = trial_average_response(dff_trials.copy(), 'direction_wrapped')

time_vector = np.linspace(-pre_trial_period, 5 + post_trial_period, (5 + pre_trial_period + post_trial_period) * processing_parameters.wf_frame_rate) 
trial_vector = np.zeros_like(time_vector)
trial_vector[(time_vector >= 0) & (time_vector <= 5)] = 1
trial_plot = hv.Area(trial_vector).opts(color='gray', alpha=0.25)

norm_dff_dirs = hv_plot_trial_averages(dir_averages, dir_trials, cells_to_match).cols(12).opts(hv.opts.Curve(width=100, height=160))
norm_dff_dirs = norm_dff_dirs * trial_plot
norm_dff_dirs


# Running Triggered Responses

In [23]:
def running_triggered_averages(activity, running, speed_cutoff, min_duration=0.5, pre_trial=0, post_trial=0):

    activity.dropna(inplace=True)

    # Group contiuous bouts
    min_duration *= processing_parameters.wf_frame_rate    # convert to frames
    is_running = running[running >= speed_cutoff].index.to_numpy()
    bouts = [bout for bout in fmisc.consecutive(is_running, stepsize=min_duration) if len(bout) >= min_duration]
    bouts = np.array([[bout[0], bout[0] + post_trial*processing_parameters.wf_frame_rate] for bout in bouts])

    bouts[:, 0] -= pre_trial * processing_parameters.wf_frame_rate

    if bouts[-1, 1] > activity.index[-1]:
        bouts[-1, 1] = activity.index[-1]
    if bouts[0, 0] < 0:
        bouts[0, 0] = 0

    # Get the shifts from the zero point (important for plotting)
    max_zero_idx_shift = np.max(bouts[:, 1] - bouts[:, 0])

    traces = []
    for i, frame in enumerate(bouts):
        ds_slice = activity.iloc[frame[0]:frame[1], :].copy()
        ds_slice['trial_num'] = i

        zero_idx_shift = np.abs((frame[1] - frame[0]) - max_zero_idx_shift)
        ds_slice['zero_idx_shift'] = zero_idx_shift

        traces.append(ds_slice)


    traces = pd.concat(traces, axis=0).reset_index(drop=True)
    # traces.drop(['trial_num', 'direction', 'direction_wrapped', 'orientation'], axis=1, inplace=True)
    # traces = traces.groupby('trial_num').agg(list)
    return traces, bouts

In [24]:
speed_column = 'wheel_speed_abs'

speed_cutoff = np.percentile(exp.kinematics[speed_column], 80)
hv.Curve(exp.kinematics[['time_vector', speed_column]]) * hv.HLine(speed_cutoff).opts(color='r')

In [27]:
cells_to_match + ['trial_num', 'zero_idx_shift']

UFuncTypeError: ufunc 'add' did not contain a loop with signature matching types (dtype('<U9'), dtype('<U14')) -> None

In [29]:
run_traces, run_bouts = running_triggered_averages(exp.norm_dff.copy(), exp.kinematics[speed_column], speed_cutoff, min_duration=1.0, pre_trial=2, post_trial=5)
idxs_shifts = run_traces.groupby(['trial_num']).apply(lambda x: np.unique(x.zero_idx_shift)[0]).reset_index().rename({0: 'zero_idx_shift'}, axis=1)	
run_traces = run_traces[cells_to_match + ['trial_num', 'zero_idx_shift']].groupby('trial_num').agg(list)

plot_list = []

for cell in cells_to_match:
    resps = fmisc.list_lists_to_array(run_traces[cell].tolist(), alignment='on', on=idxs_shifts['zero_idx_shift'].to_numpy())
    mean = np.nanmean(resps, axis=0)

    trials_list = [hv.Curve(resps[r, :]).opts(color='black', alpha=0.25) for r in np.arange(resps.shape[0])]
    mean_list = [hv.Curve((np.arange(0, mean.shape[0]), mean)).opts(color='red', xticks=[(0, -2), (40, 0), (90, 2.5), (140, 5)], xlabel='', ylabel='', title=cell, fontsize={'xticks': 10, 'yticks': 10,})]
    vlines = [hv.VLine(40).opts(color='k', line_width=2)]
    plot_list.append(hv.Overlay(trials_list + mean_list + vlines).opts(width=150, height=150))

running_trig_avg = hv.Layout(plot_list).cols(5)
running_trig_avg

# Still Trials

In [30]:
still_trials = exp.kinematics.groupby('trial_num').filter(lambda x: x[speed_column].mean() < speed_cutoff).trial_num.unique()

In [31]:
still_ori_averages, still_ori_trials = trial_average_response(dff_trials.loc[dff_trials.trial_num.isin(still_trials)].copy(), 'orientation')

time_vector = np.linspace(-pre_trial_period, 5 + post_trial_period, (5 + pre_trial_period + post_trial_period) * processing_parameters.wf_frame_rate) 
trial_vector = np.zeros_like(time_vector)
trial_vector[(time_vector >= 0) & (time_vector <= 5)] = 1
trial_plot = hv.Area(trial_vector).opts(color='gray', alpha=0.25)

norm_dff_oris_still = hv_plot_trial_averages(still_ori_averages, still_ori_trials, cells_to_match).cols(12).opts(hv.opts.Curve(width=100, height=160))
norm_dff_oris_still = norm_dff_oris_still * trial_plot
norm_dff_oris_still

In [32]:
still_dir_averages, still_dir_trials = trial_average_response(dff_trials.loc[dff_trials.trial_num.isin(still_trials)].copy(), 'direction_wrapped')

time_vector = np.linspace(-pre_trial_period, 5 + post_trial_period, (5 + pre_trial_period + post_trial_period) * processing_parameters.wf_frame_rate) 
trial_vector = np.zeros_like(time_vector)
trial_vector[(time_vector >= 0) & (time_vector <= 5)] = 1
trial_plot = hv.Area(trial_vector).opts(color='gray', alpha=0.25)

norm_dff_dirs_still = hv_plot_trial_averages(still_dir_averages, still_dir_trials, cells_to_match).cols(12).opts(hv.opts.Curve(width=100, height=160))
norm_dff_dirs_still = norm_dff_dirs_still * trial_plot
norm_dff_dirs_still