In [None]:
# imports
import os
import sys
import pandas as pd
import numpy as np
import importlib
sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))
import scipy.signal as ss
import scipy.stats as stat

import panel as pn
import holoviews as hv
from holoviews import opts, dim
from holoviews.operation import histogram
hv.extension('bokeh')
from bokeh.resources import INLINE

import paths
import processing_parameters
import functions_misc as fm
import functions_bondjango as bd
import snakemake_scripts.tc_calculate as tc
import snakemake_scripts.wf_tc_calculate as wf_tc
import functions_plotting as fp
import functions_data_handling as fdh
from functions_tuning import calculate_dff

In [None]:
def calculate_information(occupancy, tuning_curve, average):
    """
    Information carried by a tuning curve, after Stefanini et al. 2020.

    Evaluates sum_i occupancy_i * r_i * log2(r_i) with r_i = tuning_curve_i / average, ignoring NaN terms.
    :param occupancy: per-bin weights, broadcastable against tuning_curve
    :param tuning_curve: numpy array with the per-bin activity
    :param average: scalar the tuning curve is divided by
    :return: NaN-ignoring sum of the weighted terms
    """
    # activity of every bin relative to the average
    relative_rate = tuning_curve / average
    # bins with zero activity give 0 * log2(0) = NaN, which nansum leaves out
    return np.nansum(occupancy * relative_rate * np.log2(relative_rate))


def clipping_function(trace_in, threshold=8):
    """
    Clip a trace to the threshold-th percentile of its positive samples.

    Every sample below the baseline (the threshold-th percentile of the strictly positive samples) is set
    to 0. The array is modified in place and also returned.
    :param trace_in: 1D numpy array with the trace; edited in place
    :param threshold: percentile (0-100) of the positive samples used as baseline
    :return: the same array object, clipped
    """
    # skip if there are only zeros
    if np.sum(trace_in) == 0:
        return trace_in
    # also skip when no sample is above zero (e.g. an all-negative trace): there is nothing to take the
    # percentile of, and np.percentile on an empty array is not a usable baseline
    positive_samples = trace_in[trace_in > 0]
    if positive_samples.size == 0:
        return trace_in
    # get the baseline
    baseline = np.percentile(positive_samples, threshold)

    # clip the trace
    trace_in[trace_in < baseline] = 0
    return trace_in


def clip_calcium(pre_data):
    """
    Clip the calcium traces based on baseline.

    :param pre_data: iterable of sequences whose element 1 is a dataframe; columns with 'cell' in their name
                     are treated as calcium data, all other columns are passed through untouched
    :return: list with one dataframe per input entry (non-cell columns first, then the clipped cell columns)
    """

    # allocate memory for the cleaned up data
    data = []
    # define the clipping threshold in percentile of baseline
    clip_threshold = 8
    # for all the trials
    for el in pre_data:

        # get the current df
        current_df = el[1]
        labels = list(current_df.columns)
        cells = [label for label in labels if 'cell' in label]
        not_cells = [label for label in labels if 'cell' not in label]
        # get the non-cell data
        non_cell_data = current_df[not_cells]
        # get the current calcium data
        cell_data = current_df[cells].fillna(0)

        # do the cell clipping. The frame returned by apply is kept explicitly (instead of counting on
        # clipping_function writing through into the dataframe's buffer), and each row is copied first so
        # the in-place clipping never touches a shared or read-only array.
        # NOTE(review): axis=1 hands each row (all cells at one time point) to clipping_function, so the
        # percentile is taken across cells rather than along each cell's trace - confirm this is intended.
        cell_data = cell_data.apply(lambda trace: clipping_function(trace.copy(), threshold=clip_threshold),
                                    axis=1, raw=True)

        # assemble a new data frame with the non-cell data and the clipped cells
        data.append(pd.concat((non_cell_data, cell_data), axis=1))
    return data


def parse_features(data, feature_list, bin_number=10, filter_kernel=21):
    """
    Set up the feature and calcium matrices.

    For every trial, the requested features that are present in the dataframe are median filtered (columns
    with 'latent' in the name and the 'motifs' column are passed through unfiltered) and the columns with
    'cell' in their name are collected into a matrix.
    :param data: iterable of dataframes, one per trial
    :param feature_list: names of the features to extract; names missing from a trial are skipped
    :param bin_number: unused, kept so existing callers keep working
    :param filter_kernel: width in samples of the median filter (scipy.signal.medfilt expects an odd size)
    :return: (list of feature dataframes, list of calcium arrays indexed as [row, cell]), one entry per trial
    """

    # allocate memory for a data frame without the encoding model features
    feature_raw_trials = []
    # allocate memory for the calcium
    calcium_trials = []

    # get the features
    for el in data:
        # get the intersection of the labels
        label_intersect = [feat for feat in feature_list if feat in el.columns]

        # get the features of interest; explicit copy so the smoothing below is never written into
        # (or flagged as a write into) a slice of the input frame
        target_features = el.loc[:, label_intersect].copy()
        # get the original columns
        original_columns = target_features.columns

        # for all the columns
        for label in original_columns:
            # latents and motifs are not smoothed
            if ('latent' in label) or (label == 'motifs'):
                continue

            # smooth the feature
            target_features[label] = ss.medfilt(target_features[label].to_numpy(), filter_kernel)

        # store the features
        feature_raw_trials.append(target_features)

        # get the calcium data
        cell_columns = [column for column in el.columns if 'cell' in column]
        calcium_trials.append(el.loc[:, cell_columns].to_numpy())

    return feature_raw_trials, calcium_trials


def extract_tc_parts(current_feature_0, cell_number, calcium_trials, feature_counts, bins, num_splits=2):
    """
    Tuning curves of every cell on consecutive chunks of the trace, used for the consistency calculation.
    Generalized version of extract_half_tc.
    :param current_feature_0: 1D array with the feature value of every frame (NaN frames are left out)
    :param cell_number: number of cells, i.e. columns of calcium_trials that are processed
    :param calcium_trials: 2D array with the activity, indexed as [frame, cell]
    :param feature_counts: per-bin counts that the summed activity of each chunk is divided by
    :param bins: bin specification handed to scipy.stats.binned_statistic
    :param num_splits: number of chunks the frames are divided into (np.array_split)
    :return: list with one entry per chunk, each a list with one tuning curve (1D array) per cell
    """
    all_frames = np.arange(current_feature_0.shape[0])
    tc_parts = []

    for chunk_frames in np.array_split(all_frames, num_splits):
        # feature values of this chunk, without the NaN frames
        chunk_feature = current_feature_0[chunk_frames]
        valid = ~np.isnan(chunk_feature)
        chunk_feature = chunk_feature[valid]

        chunk_tcs = []
        for cell in np.arange(cell_number):
            # activity of the cell on the same frames
            chunk_activity = calcium_trials[chunk_frames, cell][valid]
            summed = stat.binned_statistic(chunk_feature, chunk_activity, statistic='sum', bins=bins)[0]
            # normalize by the counts; bins with zero counts come out as nan/inf and are set to 0
            chunk_tcs.append(np.nan_to_num(summed / feature_counts, nan=0, posinf=0, neginf=0))

        tc_parts.append(chunk_tcs)

    return tc_parts


def extract_half_tc(current_feature_0, cell_number, calcium_trials, feature_counts, bins):
    """
    Tuning curves of every cell on the first and the second half of the trace, for the consistency calculation.

    Both halves are floor(n_frames / 2) frames long, so with an odd number of frames the last frame is left out.
    :return: list [first half, second half], each a list with one tuning curve (1D array) per cell
    """
    half_length = int(np.floor(current_feature_0.shape[0] / 2))
    tc_halves = []

    # the halves start at frame 0 and at frame half_length
    for first_frame in (0, half_length):
        frames = np.arange(first_frame, first_frame + half_length)
        # feature on this half, without the NaN frames
        feature_half = current_feature_0[frames]
        not_nan = ~np.isnan(feature_half)
        feature_half = feature_half[not_nan]

        cell_tcs = []
        for cell in np.arange(cell_number):
            activity_half = calcium_trials[frames, cell][not_nan]
            summed = stat.binned_statistic(feature_half, activity_half, statistic='sum', bins=bins)[0]
            # normalize by the counts and zero whatever is not finite (nan and inf)
            ratio = summed / feature_counts
            cell_tcs.append(np.where(np.isfinite(ratio), ratio, 0))

        tc_halves.append(cell_tcs)

    return tc_halves


def shuffle_random(cell, counts_feature_0, feature_counts, bins, tc_idx, shuffle_number=100):
    """
    Null distributions of tuning-curve information and prediction quality from randomly resampled activity.

    On every iteration the activity of the cell is resampled with replacement (np.random.choice), a tuning
    curve is computed from the resampled trace, and both its information and the Spearman correlation
    between the resampled trace and the tuning-curve based prediction are stored.
    :param cell: 1D array with the activity of one cell
    :param counts_feature_0: 1D array with the feature values, binned against cell
    :param feature_counts: per-bin counts used to normalize the tuning curve
    :param bins: bin specification handed to scipy.stats.binned_statistic
    :param tc_idx: per-sample bin numbers of counts_feature_0, as returned by binned_statistic
    :param shuffle_number: number of resampling iterations
    :return: (shuffle_array, shuffle_prediction), both of shape (shuffle_number, 1)
    """
    # allocate memory for the shuffles
    shuffle_array = np.zeros((shuffle_number, 1))
    shuffle_prediction = np.zeros((shuffle_number, 1))

    # TODO change randomization to wrapping with lag (or rather add it in addition to the current one)
    # use np.take
    # shuffle the calcium activity
    for shuffle in np.arange(shuffle_number):

        # randomize the calcium activity (sampling with replacement)
        random_cell = np.random.choice(cell, cell.shape[0])

        # Get the sum of the bins
        tc_random = \
            stat.binned_statistic(counts_feature_0, random_cell, statistic='sum', bins=bins)[0]

        # get the information
        shuffle_array[shuffle] = calculate_information(feature_counts, tc_random, np.mean(random_cell))

        # process the TC
        tc_random = tc_random / feature_counts
        tc_random[np.isnan(tc_random)] = 0
        tc_random[np.isinf(tc_random)] = 0

        # use the tc_idx to regenerate the activity
        # NOTE(review): binned_statistic bin numbers start at 1 for the first bin, so an offset of 1 would
        # be expected here; the offset of 2 is kept because the other shuffles and the real-data score in
        # this file use it too and must stay comparable - confirm
        predicted_calcium = tc_random[tc_idx-2]

        # get the correlation with the real calcium
        random_quality = stat.spearmanr(random_cell, predicted_calcium, nan_policy='omit')[0]
        shuffle_prediction[shuffle] = random_quality

    # return only once every shuffle has been computed (returning from inside the loop left all rows but
    # the first at zero)
    return shuffle_array, shuffle_prediction
    

def shuffle_random_bin(cell, counts_feature_0, feature_counts, bins, tc_idx, time_bin_width=0.5, shuffle_number=100):
    """
    Null distributions of tuning-curve information and prediction quality from block-resampled activity.

    The frames are grouped into time bins of time_bin_width (frame index divided by
    processing_parameters.wf_frame_rate gives the time). On every iteration time bins are drawn with
    replacement and their frames concatenated, so the activity is shuffled while frames of the same time
    bin stay together.
    :param cell: 1D array with the activity of one cell
    :param counts_feature_0: 1D array with the feature values, binned against cell
    :param feature_counts: per-bin counts used to normalize the tuning curve
    :param bins: bin specification handed to scipy.stats.binned_statistic
    :param tc_idx: per-sample bin numbers of counts_feature_0, as returned by binned_statistic
    :param time_bin_width: width of the time bins, in the units of 1 / wf_frame_rate
    :param shuffle_number: number of resampling iterations
    :return: (shuffle_array, shuffle_prediction), both of shape (shuffle_number, 1)
    """

    # allocate memory for the shuffles
    shuffle_array = np.zeros((shuffle_number, 1))
    shuffle_prediction = np.zeros((shuffle_number, 1))

    # generate a time vector and bin it
    frame_number = cell.shape[0]
    time_vector = np.arange(frame_number, dtype=float) / processing_parameters.wf_frame_rate
    bin_edges = np.arange(time_vector[0], time_vector[-1], time_bin_width)
    binned_time_idxs = np.digitize(time_vector, bin_edges)
    unique_time_bins = np.unique(binned_time_idxs)
    # frames belonging to every time bin; does not change between shuffles
    frames_per_bin = {time_bin: np.flatnonzero(binned_time_idxs == time_bin) for time_bin in unique_time_bins}

    # shuffle the calcium activity
    for shuffle in np.arange(shuffle_number):

        # Shuffle the time while maintaining the binning. Deliberately oversample to ensure we have enough
        random_time_bins = np.random.choice(unique_time_bins, int(unique_time_bins.shape[0] * 1.2), replace=True)
        random_time_idxs = np.concatenate([frames_per_bin[el] for el in random_time_bins])
        # the time bins are not all the same size, so the oversampling can still fall short on short
        # traces; keep drawing bins until every frame of the trace can be filled
        while random_time_idxs.shape[0] < frame_number:
            extra_bin = np.random.choice(unique_time_bins)
            random_time_idxs = np.concatenate((random_time_idxs, frames_per_bin[extra_bin]))

        # Trim the indexes to size of calcium activity and randomize the calcium activity
        random_cell = cell[random_time_idxs[:frame_number]]

        # Get the sum of the bins
        tc_random = \
            stat.binned_statistic(counts_feature_0, random_cell, statistic='sum', bins=bins)[0]

        # get the information
        shuffle_array[shuffle] = calculate_information(feature_counts, tc_random, np.mean(random_cell))

        # process the TC
        tc_random = tc_random / feature_counts
        tc_random[np.isnan(tc_random)] = 0
        tc_random[np.isinf(tc_random)] = 0

        # use the tc_idx to regenerate the activity
        # NOTE(review): binned_statistic bin numbers start at 1 for the first bin, so an offset of 1 would
        # be expected here; the offset of 2 is kept because the other shuffles and the real-data score in
        # this file use it too and must stay comparable - confirm
        predicted_calcium = tc_random[tc_idx-2]

        # get the correlation with the real calcium
        random_quality = stat.spearmanr(random_cell, predicted_calcium, nan_policy='omit')[0]
        shuffle_prediction[shuffle] = random_quality

    return shuffle_array, shuffle_prediction


def add_lag(cell, counts_feature_0, feature_counts, bins, tc_idx, lag=0.5):
    """
    Distributions of tuning-curve information and prediction quality from circularly shifted activity.

    The activity is rolled (np.roll, wrapping around) by multiples of the lag step and the tuning curve
    metrics are recomputed for every shift.
    :param cell: 1D array with the activity of one cell
    :param counts_feature_0: 1D array with the feature values, binned against cell
    :param feature_counts: per-bin counts used to normalize the tuning curve
    :param bins: bin specification handed to scipy.stats.binned_statistic
    :param tc_idx: per-sample bin numbers of counts_feature_0, as returned by binned_statistic
    :param lag: shift between consecutive iterations; multiplied by processing_parameters.wf_frame_rate to
                get frames
    :return: (shuffle_array, shuffle_prediction), both of shape (number of lags, 1)
    """

    # calculate how many indices to lag on each iteration; a lag shorter than one frame is rounded up to a
    # single frame, since a step of 0 cannot tile the trace (division by zero below)
    lag_step = max(1, int(lag * processing_parameters.wf_frame_rate))

    # get the number of lags to calculate
    num_lags = int(cell.shape[0] // lag_step)

    # allocate memory for the shuffles
    shuffle_array = np.zeros((num_lags, 1))
    shuffle_prediction = np.zeros((num_lags, 1))

    # lag the calcium activity
    # NOTE(review): the first iteration shifts by 0 frames, so the unshifted trace is part of the returned
    # distribution - confirm this is intended
    for shuffle in np.arange(num_lags, dtype=int):

        # lag the calcium activity
        lag_cell = np.roll(cell, shuffle * lag_step)

        # Get the sum of the bins
        tc_random = \
            stat.binned_statistic(counts_feature_0, lag_cell, statistic='sum', bins=bins)[0]

        # get the information
        shuffle_array[shuffle] = calculate_information(feature_counts, tc_random, np.mean(lag_cell))

        # process the TC
        tc_random = tc_random / feature_counts
        tc_random[np.isnan(tc_random)] = 0
        tc_random[np.isinf(tc_random)] = 0

        # use the tc_idx to regenerate the activity
        # NOTE(review): binned_statistic bin numbers start at 1 for the first bin, so an offset of 1 would
        # be expected here; the offset of 2 is kept because the other shuffles and the real-data score in
        # this file use it too and must stay comparable - confirm
        predicted_calcium = tc_random[tc_idx-2]

        # get the correlation with the real calcium
        random_quality = stat.spearmanr(lag_cell, predicted_calcium, nan_policy='omit')[0]
        shuffle_prediction[shuffle] = random_quality

    return shuffle_array, shuffle_prediction


def extract_full_tc(counts_feature_0, feature_counts, cell_number, calcium_trials,
                    bins, keep_vector_full, shuffle_number, percentile, shuffle_kind='random', lag_or_bin=1):
    """
    Full-trace tuning curve of every cell plus its information and prediction-quality tests.

    :param counts_feature_0: 1D array with the feature values (already restricted to keep_vector_full)
    :param feature_counts: per-bin counts used to normalize the tuning curves
    :param cell_number: number of cells, i.e. columns of calcium_trials that are processed
    :param calcium_trials: 2D array with the activity, indexed as [frame, cell]
    :param bins: bin specification handed to scipy.stats.binned_statistic
    :param keep_vector_full: index selecting the frames of calcium_trials that match counts_feature_0
    :param shuffle_number: number of shuffles for the 'random' and 'random_bin' kinds
    :param percentile: percentile of the absolute shuffle values used as cutoff
    :param shuffle_kind: 'random', 'random_bin' or 'lag_wrap'
    :param lag_or_bin: time bin width ('random_bin') or lag ('lag_wrap') forwarded to the shuffle function
    :return: (list with one tuning curve per cell, array (cell_number, 4) with the columns
             information, information test, quality, quality test)
    :raises ValueError: if shuffle_kind is not recognized
    """
    full_tcs = []
    responsivity = np.zeros((cell_number, 4))

    for cell in np.arange(cell_number):
        activity = calcium_trials[keep_vector_full, cell]

        # summed activity per bin and the bin number of every frame
        binned = stat.binned_statistic(counts_feature_0, activity, statistic='sum', bins=bins)
        summed_tc, tc_idx = binned[0], binned[2]

        # the information is computed on the summed (not yet normalized) curve
        information_content = calculate_information(feature_counts, summed_tc, np.mean(activity))

        # normalize by the counts and zero whatever is not finite (nan and inf)
        tc_cell = summed_tc / feature_counts
        tc_cell[~np.isfinite(tc_cell)] = 0

        # rebuild the activity from the tuning curve and correlate it with the measured one
        # NOTE(review): bin numbers start at 1 for the first bin, so an offset of 1 would be expected;
        # the offset of 2 matches what the shuffle functions of this file apply - confirm
        predicted_calcium = tc_cell[tc_idx - 2]
        tc_quality = stat.spearmanr(activity, predicted_calcium, nan_policy='omit')[0]

        # null distributions for both metrics
        if shuffle_kind == 'random':
            null_info, null_quality = shuffle_random(activity, counts_feature_0, feature_counts, bins, tc_idx,
                                                     shuffle_number=shuffle_number)
        elif shuffle_kind == 'random_bin':
            null_info, null_quality = shuffle_random_bin(activity, counts_feature_0, feature_counts, bins, tc_idx,
                                                         time_bin_width=lag_or_bin, shuffle_number=shuffle_number)
        elif shuffle_kind == 'lag_wrap':
            null_info, null_quality = add_lag(activity, counts_feature_0, feature_counts, bins, tc_idx,
                                              lag=lag_or_bin)
        else:
            raise ValueError('Shuffle kind not recognized')

        # cutoffs from the absolute shuffle values
        resp_threshold = np.percentile(np.abs(null_info.flatten()), percentile)
        qual_threshold = np.percentile(np.abs(null_quality.flatten()), percentile)

        # metric and whether it exceeds its cutoff
        responsivity[cell, :] = (information_content,
                                 np.abs(information_content) > resp_threshold,
                                 tc_quality,
                                 np.abs(tc_quality) > qual_threshold)

        full_tcs.append(tc_cell)

    return full_tcs, responsivity


def extract_tcs_responsivity(feature_raw_trials, calcium_trials, target_variables, cell_number,
                             percentile=99, bin_number=10, shuffle_kind='random'):
    '''
    Extract the tuning curves (full and per split) and their responsivity index

    feature_raw_trials: (pd.DataFrame) The kinematic features, one column per feature
    calcium_trials: (np.array)  The calcium traces, indexed as [time, cell]
    target_variables: (list of str) names of the variables to extract
    cell_number: (int) number of cells
    percentile: (int) percentile of the shuffle distributions used as cutoff for the responsivity tests
    bin_number: (int) number of bins for the tuning curves (bins tile the range of the TCs)
    shuffle_kind: (str) shuffling scheme, forwarded to extract_full_tc

    returns: tc_half, tc_full, tc_resp, tc_counts, tc_edges, all dicts keyed by feature name. Features that
    are missing from feature_raw_trials get NaN-filled entries shaped like those of the last feature that
    was found, or stay as empty lists when no feature was found at all.
    '''

    # get the number of pairs
    var_number = len(target_variables)

    # define the number of calcium shuffles
    shuffle_number = 100

    # number of chunks the trace is split into for the consistency calculation
    tc_splits = processing_parameters.tc_consistency_splits

    # allocate memory for the trial TCs
    tc_half = {}
    tc_full = {}
    tc_resp = {}
    tc_counts = {}
    tc_edges = {}
    # initialize the template_idx
    template_idx = -1
    # for all the features
    for var_idx in np.arange(var_number):
        # get the current feature
        feature_name = target_variables[var_idx]
        # skip the feature and save an empty if it is not present
        try:
            current_feature_0 = feature_raw_trials.loc[:, feature_name].to_numpy()
            # save the index of the feature
            template_idx = var_idx
        except KeyError:
            tc_half[feature_name] = []
            tc_full[feature_name] = []
            tc_resp[feature_name] = []
            tc_counts[feature_name] = []
            tc_edges[feature_name] = []
            continue

        # get the bins from the parameters file (bins based on range of the data)
        try:
            bin_ranges = processing_parameters.tc_params[feature_name]
            # calculate the bin edges based on the ranges
            if len(bin_ranges) == 1:
                bins = np.arange(bin_ranges[0] + 1) - 0.5
            else:
                bins = np.linspace(bin_ranges[0], bin_ranges[1], num=bin_number + 1)

        except KeyError:
            # if not in the parameters, go for default and report
            print(f'Feature {feature_name} not found, default to 10 bins ad hoc')
            bins = 10

        # exclude nan values
        keep_vector_full = ~np.isnan(current_feature_0)
        counts_feature_0 = current_feature_0[keep_vector_full]

        # get the counts for each range bin
        feature_counts_raw, tc_current_edges, _ = \
            stat.binned_statistic(counts_feature_0, counts_feature_0, statistic='count', bins=bins)
        feature_counts = feature_counts_raw.copy()

        # zero the positions with less than 3 counts
        feature_counts[feature_counts < 3] = 0

        # from here on use the explicit edges: they equal `bins` when edges were given, and with the ad hoc
        # integer fallback they keep every split on the same bins the counts were computed on
        # get the split tuning curves
        tc_half_temp = extract_tc_parts(current_feature_0, cell_number, calcium_trials, feature_counts,
                                        tc_current_edges, num_splits=tc_splits)

        # get the full tuning curves
        tc_cell_full, tc_cell_resp = extract_full_tc(counts_feature_0, feature_counts, cell_number, calcium_trials,
                                                     tc_current_edges, keep_vector_full, shuffle_number, percentile,
                                                     shuffle_kind,
                                                     lag_or_bin=processing_parameters.tc_lags[feature_name])
        # store the splits and fulls
        tc_half[feature_name] = tc_half_temp
        tc_full[feature_name] = tc_cell_full
        tc_resp[feature_name] = tc_cell_resp
        tc_counts[feature_name] = feature_counts_raw
        tc_edges[feature_name] = tc_current_edges

    # run through the features and fill up the non-populated ones with nan, using the last feature that was
    # found as a shape template (nothing to do if no feature was found)
    if template_idx >= 0:
        template = target_variables[template_idx]
        for feat in tc_half.keys():
            if len(tc_half[feat]) == 0:
                # build new lists; reusing the template's list object here would overwrite the template's
                # own split tuning curves with nan
                tc_half[feat] = [[el * np.nan for el in split] for split in tc_half[template]]
                tc_full[feat] = [el * np.nan for el in tc_full[template]]
                tc_resp[feat] = tc_resp[template] * np.nan
                tc_counts[feat] = tc_counts[template] * np.nan
                tc_edges[feat] = tc_edges[template] * np.nan

    return tc_half, tc_full, tc_resp, tc_counts, tc_edges


def extract_consistency(tc_half, target_variables, cell_number, shuffle_kind='random', comp_kind='halves', percentile=95):
    """
    Calculate the consistency of the tuning curves between two groups of trace splits.

    For every feature and cell, the split tuning curves are gathered into two vectors (first vs second half
    of the splits, or even vs odd splits), their correlation (np.corrcoef) is computed and compared against
    the percentile-th percentile of the correlations obtained after shuffling the second vector.
    :param tc_half: dict feature -> list (one entry per split) of lists (one tuning curve per cell)
    :param target_variables: names of the features to process
    :param cell_number: number of cells
    :param shuffle_kind: 'random' (values resampled with replacement) or 'random_bin' (blocks of values
                         resampled with replacement)
    :param comp_kind: 'halves' or 'odd_even'
    :param percentile: percentile of the shuffle distribution used as cutoff
    :return: dict feature -> array (cell_number, 2) with the columns [correlation, test passed], or an
             empty list for features without tuning curves
    :raises ValueError: if comp_kind or shuffle_kind is not recognized
    """

    # define the number of shuffles
    shuffle_number = 100

    # get the number of pairs
    var_number = len(target_variables)

    # get the number of splits
    num_splits = processing_parameters.tc_consistency_splits

    # allocate memory for the trial TCs
    tc_cons = {}

    # for all the features
    for var_idx in np.arange(var_number):

        # get the name
        feature_name = target_variables[var_idx]

        # get the split tuning curves
        halves = tc_half[feature_name]

        # if empty, skip
        if len(halves) == 0:
            tc_cons[feature_name] = []
            continue

        # allocate an array for the correlations and tests: always two columns, independent of the number
        # of splits (only columns 0 and 1 are ever filled)
        tc_half_temp = np.zeros([cell_number, 2])

        # calculate the real and shuffle correlation
        for cell in np.arange(cell_number):

            # get the current cell first and second half
            if comp_kind == 'halves':
                half_idx = num_splits//2
                current_first = np.array([halves[i][cell] for i in np.arange(half_idx)]).flatten()
                current_second = np.array([halves[i][cell] for i in np.arange(half_idx, num_splits)]).flatten()

            # get the odd and even splits
            elif comp_kind == 'odd_even':
                even_idx = np.arange(0, num_splits, 2)
                odd_idx = np.arange(1, num_splits, 2)
                current_first = np.array([halves[i][cell] for i in even_idx]).flatten()
                current_second = np.array([halves[i][cell] for i in odd_idx]).flatten()

            else:
                raise ValueError('Comparison kind not recognized')

            # real correlation
            real_correlation = np.corrcoef(current_first, current_second)[1][0]

            # shuffle array
            shuffle_array = np.zeros([shuffle_number, 1])

            # block structure, only needed by the random_bin shuffle
            # NOTE(review): the "time" axis here is the position within the concatenated tuning curves, not
            # acquisition frames - confirm that binning it by frame rate and tc_lags is intended
            if shuffle_kind == 'random_bin':
                time_vector = np.arange(current_second.shape[0], dtype=float) / processing_parameters.wf_frame_rate
                bin_edges = np.arange(time_vector[0], time_vector[-1], processing_parameters.tc_lags[feature_name])
                binned_time_idxs = np.digitize(time_vector, bin_edges)
                unique_time_bins = np.unique(binned_time_idxs)
                idxs_per_bin = {time_bin: np.flatnonzero(binned_time_idxs == time_bin)
                                for time_bin in unique_time_bins}

            # calculate the confidence interval
            for shuffle in np.arange(shuffle_number):
                random_second = current_second.copy().flatten()

                # shuffle the second half calcium activity
                if shuffle_kind == 'random':
                    random_second = np.random.choice(random_second, random_second.shape[0])

                elif shuffle_kind == 'random_bin':
                    # Shuffle the time while maintaining the binning. Deliberately oversample to ensure we have enough
                    random_time_bins = np.random.choice(unique_time_bins, int(unique_time_bins.shape[0] * 1.2),
                                                        replace=True)
                    random_time_idxs = np.concatenate([idxs_per_bin[el] for el in random_time_bins])
                    # the bins are not all the same size, so keep drawing until the vector can be filled
                    while random_time_idxs.shape[0] < random_second.shape[0]:
                        extra_bin = np.random.choice(unique_time_bins)
                        random_time_idxs = np.concatenate((random_time_idxs, idxs_per_bin[extra_bin]))

                    # Trim the indexes to size of calcium activity and randomize the calcium activity
                    random_time_idxs = random_time_idxs[:random_second.shape[0]]
                    random_second = random_second[random_time_idxs]

                else:
                    # NOTE(review): 'lag_wrap' is accepted by extract_full_tc but has no counterpart here
                    raise ValueError('Shuffle kind not recognized')

                shuffle_array[shuffle] = np.corrcoef(current_first, random_second)[1][0]
            # turn nans into 0
            shuffle_array[np.isnan(shuffle_array)] = 0

            # get the confidence interval
            conf_interval = np.percentile(shuffle_array, percentile)
            # store the correlation and whether it passes the criterion
            tc_half_temp[cell, 0] = real_correlation
            tc_half_temp[cell, 1] = (real_correlation > conf_interval) & (real_correlation > 0) & \
                                    (conf_interval > 0)

        # store for the variable
        tc_cons[feature_name] = tc_half_temp
    return tc_cons


def convert_to_dataframe(half_in, full_in, counts_in, resp_in, cons_in, edges_in, date, mouse, setup):
    """
    Convert the TCs and their metrics into dataframe format.

    All *_in arguments are dicts keyed by feature name; features whose entry in half_in is empty are skipped.
    :param half_in: feature -> list (one entry per split) of lists with one tuning curve per cell
    :param full_in: feature -> list with one full tuning curve per cell
    :param counts_in: feature -> 1D array with the counts per bin
    :param resp_in: feature -> array (cells, 4) with the responsivity/quality indexes and tests
    :param cons_in: feature -> array (cells, 2) with the consistency index and test
    :param edges_in: feature -> 1D array with the bin edges
    :param date: value written to the 'day' column
    :param mouse: value written to the 'animal' column
    :param setup: value written to the 'rig' column
    :return: (out_dict, count_dict, edges_dict), dicts feature -> dataframe. out_dict has one row per cell
             with the columns half_<split>_bin_<i>, bin_<i>, Resp_*, Qual_*, Cons_*; the other two hold a
             single row with count_<i> / edge_<i>. All three carry day, animal and rig.
    """
    # allocate an output dict
    out_dict = {}
    # also one for the counts and edges
    count_dict = {}
    edges_dict = {}
    # cycle through features
    for feat in half_in.keys():
        # get all the components
        c_half = half_in[feat]
        c_full = full_in[feat]
        c_count = counts_in[feat]
        c_resp = resp_in[feat]
        c_cons = cons_in[feat]
        c_edges = edges_in[feat]

        # if the feature is not present, skip
        if len(c_half) == 0:
            continue
        # flatten the tcs and generate labels, for every split that is present (not only the first two)
        flat_half = []
        labels_half = []
        for half in np.arange(len(c_half)):
            flat_half.append(np.array([el.flatten() for el in c_half[half]]))
            labels_half.append(['half_'+str(half)+'_bin_'+str(el) for el in np.arange(flat_half[half].shape[1])])

        flat_full = np.array([el.flatten() for el in c_full])
        labels_full = ['bin_'+str(el) for el in np.arange(flat_full.shape[1])]

        # counts and edges become a single row each
        flat_count = np.array([el.flatten() for el in c_count]).T
        labels_count = ['count_' + str(el) for el in np.arange(flat_count.shape[1])]

        flat_edges = np.array([el.flatten() for el in c_edges]).T
        labels_edges = ['edge_' + str(el) for el in np.arange(flat_edges.shape[1])]
        # turn everything into dataframes
        df_half = pd.DataFrame(np.hstack(flat_half), columns=np.hstack(labels_half), dtype=np.float32)
        df_full = pd.DataFrame(flat_full, columns=labels_full, dtype=np.float32)
        df_resp = pd.DataFrame(c_resp, columns=['Resp_index', 'Resp_test', 'Qual_index', 'Qual_test'], dtype=np.float32)
        df_cons = pd.DataFrame(c_cons, columns=['Cons_index', 'Cons_test'], dtype=np.float32)
        # concatenate
        df_concat = pd.concat((df_half, df_full, df_resp, df_cons), axis=1)
        df_count = pd.DataFrame(flat_count, columns=labels_count, dtype=np.float32)
        df_edges = pd.DataFrame(flat_edges, columns=labels_edges, dtype=np.float32)

        # generate columns for date, animal and rig on all three frames
        for frame in (df_concat, df_count, df_edges):
            frame['day'] = date
            frame['animal'] = mouse
            frame['rig'] = setup

        # store
        out_dict[feat] = df_concat
        count_dict[feat] = df_count
        edges_dict[feat] = df_edges

    return out_dict, count_dict, edges_dict


In [None]:
# re-import the plotting helpers so that edits to the module are picked up without restarting the kernel
importlib.reload(fp)
# set up the figure theme
fp.set_theme()

In [None]:
importlib.reload(processing_parameters)
# get the search query
search_string = processing_parameters.search_string + ', analysis_type:preprocessing'
parsed_search_string = fdh.parse_search_string(search_string)

# get the paths from the database
all_path = bd.query_database('analyzed_data', search_string)
# fail with a readable message instead of an IndexError on all_path[0] below
if len(all_path) == 0:
    raise ValueError(f'No analyzed_data entries found for the query: {search_string}')
input_path = [el['analysis_path'] for el in all_path if ('_preproc' in el['slug']) and (parsed_search_string['mouse'].lower() in el['slug'])]
# get the day, animal and rig (taken from the slug and rig of the first entry of the query)
day = '_'.join(all_path[0]['slug'].split('_')[0:3])
rig = all_path[0]['rig']
animal = all_path[0]['slug'].split('_')[3:6]
animal = '_'.join([animal[0].upper()] + animal[1:])

# assemble the output path
out_path = os.path.join(paths.analysis_path, '_'.join((day, animal, rig, 'tcday.hdf5')))

# allocate memory for the data
raw_data = []
# allocate memory for excluded trials
excluded_trials = []
# for all the files
for files in input_path:
    # load the data
    with pd.HDFStore(files, mode='r') as h:
        if ('/matched_calcium' in h.keys()):

            # get the matched calcium table
            dataframe = h['matched_calcium']
            # store
            raw_data.append(dataframe)
            # report the file that was just loaded (not the whole list on every iteration)
            print(files)
        else:
            # files without a matched_calcium table are only recorded
            excluded_trials.append(files)
print(f'Number of files loaded: {len(raw_data)}')

In [None]:
importlib.reload(tc)
importlib.reload(processing_parameters)

variable_list = processing_parameters.variable_list_free

# work on a copy of the first loaded file: the edits below (dropped columns, added column, unit
# conversion) would otherwise be written into raw_data[0] and compound every time this cell is re-run
# NOTE(review): only the first entry of raw_data is analyzed - confirm that is intended
ds = raw_data[0].copy()

# define ca activity type
ca_type = 'spikes'    #'spikes' or 'fluor'

# drop activity not of correct type
cols_to_drop = [el for el in ds.columns if ('cell' in el) and (ca_type not in el)]
ds = ds.drop(cols_to_drop, axis='columns')

# If using fluorescence data, calculate dF/F
if ca_type == 'fluor':
    # NOTE(review): relies on calculate_dff returning the dataframe even with inplace=True - confirm
    ds = calculate_dff(ds, baseline_type='iti', inplace=True)

# define the pairs to quantify
if rig in ['VWheel', 'VWheelWF']:
    variable_names = processing_parameters.variable_list_fixed
    ds['wheel_speed_abs'] = np.abs(ds['wheel_speed'])
else:
    variable_names = processing_parameters.variable_list_free

# Convert to cm (factor of 100), for the columns that are present
for col in ['wheel_speed', 'wheel_speed_abs', 'wheel_acceleration', 
            'mouse_y_m', 'mouse_z_m', 'mouse_x_m',
            'head_height', 'mouse_speed', 'mouse_acceleration']:
    if col in ds.columns:
        ds[col] = ds[col] * 100.

# clip the calcium traces
clipped_data = tc.clip_calcium([('', ds)])

# parse the features
# NOTE(review): the features are parsed with variable_list (always the free list) while the tuning curves
# are later requested for variable_names, which depends on the rig - confirm for the wheel rigs
features, calcium = tc.parse_features(clipped_data, variable_list, bin_number=20)

# concatenate all the trials
features = pd.concat(features)
calcium = np.concatenate(calcium)

# calculate the delta heading angle (i.e. direction) for the cricket with respect to the mouse
# get the heading
# prey_heading = features.loc[:, 'cricket_0_delta_heading']
# prey_direction = np.concatenate(([0], np.diff(prey_heading)), axis=0)
# prey_distance = features.loc[:, 'cricket_0_mouse_distance']
# prey_loom = np.concatenate(([0], np.diff(prey_distance)), axis=0)
# prey_visual_angle = features.loc[:, 'cricket_0_visual_angle']
# prey_delta_visual = np.concatenate(([0], np.diff(prey_visual_angle)), axis=0)
# # add the variables to the features
# features.loc[:, 'cricket_0_direction'] = prey_direction
# features.loc[:, 'cricket_0_loom'] = prey_loom
# features.loc[:, 'cricket_0_delta_visual'] = prey_delta_visual

# define the variable to quantify
# variable_pairs = ['mouse_angular_speed']
# print(raw_data[0][1].keys()[:30])
variable_pairs = processing_parameters.variable_list_free
# variable_pairs = ['mouse_x']
# variable_pairs = ['latent_0']

# get the number of cells
cell_num = calcium.shape[1]

In [None]:
# plot the calcium
# shape of the activity matrix; its second axis is the one counted as cells (cell_num = calcium.shape[1])
print(calcium.shape)
# raster of the transposed matrix, so cells run along the first axis of the image data
hv.Raster(calcium.T).opts(width=1000, height=600, tools=['hover'])

In [None]:
# %%time
# reload the project modules so that edits are picked up without restarting the kernel
importlib.reload(tc)
importlib.reload(processing_parameters)

# define the parameters from processing parameters
bin_number = processing_parameters.bin_number
shuffle_kind = processing_parameters.tc_shuffle_kind
resp_qual_percentile = processing_parameters.tc_resp_qual_cutoff

# get the TCs and their responsivity
# (this calls the extract_tcs_responsivity defined in this notebook, not a function of the tc module)
tcs_half, tcs_full, tcs_resp, tc_count, tc_bins = extract_tcs_responsivity(features, calcium, variable_names, cell_num, 
                                                                            percentile=resp_qual_percentile, bin_number=bin_number, shuffle_kind=shuffle_kind)


In [None]:
# read the consistency settings from the processing parameters
consistency_comp = processing_parameters.tc_consistency_comp
consistency_percentile = processing_parameters.tc_consistency_cutoff

# get the TC consistency from the two half tuning curves
tcs_cons = extract_consistency(
    tcs_half,
    variable_names,
    cell_num,
    shuffle_kind=shuffle_kind,
    comp_kind=consistency_comp,
    percentile=consistency_percentile,
)

# convert the outputs into dataframes, tagged with the session identifiers
tcs_dict, tcs_counts_dict, tcs_bins_dict = tc.convert_to_dataframe(
    tcs_half, tcs_full, tc_count, tcs_resp, tcs_cons, tc_bins, day, animal, rig,
)

In [None]:
# Plot the criteria distributions: one consistency-vs-responsivity scatter per
# feature, with the cells grouped by which of the two criteria they pass.
plot_list = []
# not populated at the moment (the per-trial / per-cell maps are disabled)
map_dict = {}

# Iterate over the features that are actually present in the results instead of
# indexing them with len(variable_pairs): tcs_cons was computed from
# variable_names, so the two lists are not guaranteed to have the same length
# (a longer variable_pairs raised IndexError, a shorter one skipped features).
for feature_name in tcs_cons.keys():
    across_cons = tcs_cons[feature_name]
    across_resp = tcs_resp[feature_name]

    # remove the inf (this edits the arrays stored in tcs_cons / tcs_resp in place)
    across_cons[np.isinf(across_cons)] = np.nan
    across_resp[np.isinf(across_resp)] = np.nan

    # TODO: remove outliers (NEED TO FIX THIS IN PREPROCESSING)

    # column 1 is read as the pass flag (1 = passed, 0 = failed) and column 0
    # as the value that gets plotted
    cons_pass = (across_cons[:, 1] == 1) & (across_resp[:, 1] == 0)
    resp_pass = (across_resp[:, 1] == 1) & (across_cons[:, 1] == 0)
    both_pass = (across_cons[:, 1] == 1) & (across_resp[:, 1] == 1)
    none_pass = (across_cons[:, 1] == 0) & (across_resp[:, 1] == 0)

    # one scatter per group, overlaid in a fixed order so the legend is stable
    group_masks = (
        ('Both', both_pass),
        ('None', none_pass),
        ('Cons', cons_pass),
        ('Resp', resp_pass),
    )
    group_plots = [
        hv.Scatter((across_cons[mask, 0], across_resp[mask, 0]),
                   kdims=['Consistency'], vdims=['Responsivity'], label=group_label)
        for group_label, mask in group_masks
    ]
    # the first element carries the title and the size of the panel
    group_plots[0].opts(title=feature_name, width=600, height=600)
    plot_list.append(hv.Overlay(group_plots))

    print(f'Number of cells passing the thresholds for {feature_name}: {np.sum(both_pass)}')

# the masks have one entry per cell, so the last one gives the cell count
print(f'Total number of cells: {both_pass.shape[0]}')
hv.Layout(plot_list).opts(opts.Scatter(size=10))

In [None]:
# plot the tuning curves (full and both halves) of the cells passing the criteria

# define the target feature
target_feature = variable_pairs[1]

# Select the cells that pass BOTH criteria for the target feature itself.
# Reusing both_pass from the criteria-plot cell would pick the cells of
# whichever feature that loop visited last, not those of target_feature.
target_cons = tcs_cons[target_feature]
target_resp = tcs_resp[target_feature]
target_pass = (target_cons[:, 1] == 1) & (target_resp[:, 1] == 1)
# get the indexes of the target cells
cell_idx = np.argwhere(target_pass).flatten()

# the x axis only depends on the feature, so build it once outside the loop
current_bins = processing_parameters.tc_params[target_feature]
bins0 = np.linspace(current_bins[0], current_bins[1], bin_number)

# allocate memory for the plots
cell_plots = []

# for all the cells (the loop variable is deliberately not called target_cell,
# so it does not overwrite the cell chosen for the single-cell plots)
for passing_cell in cell_idx:
    # gather the full tuning curve and the two half-session curves
    labelled_curves = (
        ('Full', tcs_full[target_feature][passing_cell]),
        ('half 0', tcs_half[target_feature][0][passing_cell]),
        ('half 1', tcs_half[target_feature][1][passing_cell]),
    )
    # overlay the three curves for this cell and store the panel
    cell_plots.append(hv.Overlay([
        hv.Curve((bins0, curve), kdims=target_feature, vdims='Activity (a.u.)', label=curve_label)
        for curve_label, curve in labelled_curves
    ]))

# plot the layout
hv.Layout(cell_plots).opts(shared_axes=True)

In [None]:
# plot the occupancy TCs: one curve per feature showing tc_count across the bins
plot_list = []

# for all features
for var_names in variable_pairs:
    # build the x axis from the limits stored for this feature
    current_bins = processing_parameters.tc_params[var_names]
    bins0 = np.linspace(current_bins[0], current_bins[1], bin_number)

    # the y axis is labelled as occupancy: this curve shows tc_count, not the
    # activity of a cell (the previous 'Activity (a.u.)' label was a leftover
    # from the tuning-curve plots)
    full_map = hv.Curve((bins0, tc_count[var_names]), kdims=[var_names], vdims='Occupancy (a.u.)')
    plot_list.append(full_map)

hv.Layout(plot_list).cols(5)
    

In [None]:
# index of the cell whose tuning curves are shown, one panel per feature
target_cell = 21
# colormap for the image-style plot, which is currently disabled
cmap = 'Purples'
pair_number = len(variable_pairs)

plot_list = []
# for all features
for var_names in variable_pairs:
    # x axis: evenly spaced points between the limits stored for this feature
    lower_limit = processing_parameters.tc_params[var_names][0]
    upper_limit = processing_parameters.tc_params[var_names][1]
    bins0 = np.linspace(lower_limit, upper_limit, bin_number)

    # full tuning curve of the target cell for this feature
    cell_curve = tcs_full[var_names][target_cell]
    plot_list.append(
        hv.Curve((bins0, cell_curve), kdims=[var_names], vdims='Activity (a.u.)')
    )

hv.Layout(plot_list).cols(5)


In [None]:
# Also plot its calcium activity: the column of the target cell against the sample index
sample_index = np.arange(calcium.shape[0])
target_trace = calcium[:, target_cell]
calcium_plot = hv.Curve((sample_index, target_trace)).opts(width=800, xrotation=45)
calcium_plot

In [None]:
# overlay the first and second half of every feature trace, one panel per feature
plot_list = []
for feature in features.columns:
    # split the trace into two (nearly) equal consecutive parts
    first_half, second_half = np.array_split(features[feature].to_numpy(), 2)
    halves_overlay = hv.Overlay([hv.Curve(first_half), hv.Curve(second_half)])
    plot_list.append(halves_overlay.opts(title=feature, xrotation=45))
hv.Layout(plot_list).cols(5).opts(shared_axes=False)