# TODO
- [x] Responsivity - bootstrap for metric
- [ ] Mean spike rate per trial
- [x] Tuning Width
- [x] Tuning Peak
- [x] Smooth tuning curves
- [ ] Running vs not

Notes:
- [x] Distribution-based metrics of responsivity (bootstrapping, etc.) over arbitrary values (Baden 2016 style)
- [ ] Population measures
    - [ ] distribution of widths/prefs
    - [ ] pop. vector decoding

# Functions

## Utilities

In [None]:
def normalize_rows(data_in):
    """Min-max scale every row of a 2D array to [0, 1], in place.

    NaNs are ignored when finding each row's minimum and maximum (and stay
    NaN in the output). The array passed in is overwritten and also returned.
    """
    for row_idx, row in enumerate(data_in):
        lo = np.nanmin(row)
        hi = np.nanmax(row)
        data_in[row_idx, :] = (row - lo) / (hi - lo)
    return data_in

def normalize(data_in):
    """Min-max scale ``data_in`` to [0, 1] using its global NaN-ignoring extremes.

    NaN entries stay NaN. If all finite values are equal (zero range), the
    zero-shifted data is returned instead of dividing by zero, which would
    otherwise turn every value into NaN/inf.
    """
    data_min = np.nanmin(data_in)
    shifted = data_in - data_min
    data_range = np.nanmax(data_in) - data_min
    if data_range == 0:
        # Constant signal: map it to 0 rather than 0/0.
        return shifted
    return shifted / data_range

## Gaussian Fitting

In [None]:
import functions_tuning as tuning

## Plotting

In [None]:
def spike_raster(data, cells=None):
    """Render cell activity as a holoviews image (time on x, cell index on y).

    data: table with a ``time_vector`` column plus one column per cell.
    cells: columns to show; defaults to every column whose name contains 'cell'.
    """
    cell_cols = cells if cells is not None else [col for col in data.columns if 'cell' in col]
    activity = data.loc[:, cell_cols].values.T
    cell_axis = np.arange(len(cell_cols))

    raster = hv.Image((data.time_vector, cell_axis, activity),
                      kdims=['Time (s)', 'Cells'], vdims=['Activity (a.u.)'])
    raster.opts(width=600)
    return raster
    
def trace_raster(data, cells=None, ds_factor=1):
    """Plot one trace per cell, stacked vertically, as a holoviews NdOverlay.

    Parameters
    ----------
    data : table with a ``time_vector`` column and one column per cell.
    cells : columns to plot; defaults to every column whose name contains 'cell'.
    ds_factor : keep every ``ds_factor``-th sample to lighten rendering of long
        recordings. 1 (default) keeps all samples.

    Trace ``i`` is offset by ``i * max_std``, where ``max_std`` is the largest
    per-cell standard deviation, so neighboring traces do not overlap much.

    Raises
    ------
    ValueError if ``ds_factor`` is smaller than 1.
    """
    if cells is None:
        cells = [el for el in data.columns if 'cell' in el]
    step = int(ds_factor)
    if step < 1:
        raise ValueError(f"ds_factor must be a positive integer, got {ds_factor}")

    trace = data.loc[:, cells]
    # Offsets come from the full-resolution data so spacing doesn't depend on ds_factor.
    max_std = trace.std().max()

    time = data.time_vector.to_numpy()[::step]
    values = trace.to_numpy()[::step]

    lines = {i: hv.Curve((time, values[:, i] + i * max_std)) for i in range(len(cells))}
    return hv.NdOverlay(lines, kdims=['Time (s)']).opts(height=500, width=800)


def rand_jitter(arr):
    """Return ``arr`` plus Gaussian noise whose std is 1% of the data range.

    Used to spread overlapping scatter points; draws from numpy's global RNG.
    """
    spread = max(arr) - min(arr)
    noise = np.random.randn(len(arr))
    return arr + noise * (.01 * spread)


def plot_tuning_curve(tuning_curve, error, fit=None, trials=None, pref_angle=None, ax=None, **kwargs):
    """Draw a Cartesian tuning curve with error bars on a matplotlib axis.

    tuning_curve: 2-column array (angle, response).
    error: y error passed to ``ax.errorbar``.
    fit: optional 2-column array (angle, fitted response), drawn in red.
    trials: optional 2-column array (angle, single-trial response), drawn as
        jittered black dots.
    pref_angle: optional angle marked with a vertical black line.
    ax: target axis; a new 10x6 in, 300 dpi figure is created when omitted.
    **kwargs: forwarded to ``ax.errorbar``.

    Returns the container produced by ``ax.errorbar``.
    """
    if ax is None:
        ax = plt.figure(dpi=300, figsize=(10, 6)).add_subplot(111)

    container = ax.errorbar(tuning_curve[:, 0], tuning_curve[:, 1], yerr=error, elinewidth=0.5, **kwargs)

    if fit is not None:
        ax.plot(fit[:, 0], fit[:, 1], c='r')
    if pref_angle is not None:
        ax.axvline(pref_angle, color='k', linewidth=1)
    if trials is not None:
        jittered = rand_jitter(trials[:, 1])
        ax.scatter(trials[:, 0], jittered, marker='.', c='k', alpha=0.5)

    return container


def plot_polar_tuning_curve(tuning_curve, error, fit=None, trials=None, pref_angle=None, ax=None, **kwargs):
    """Draw a tuning curve with error bars on a polar matplotlib axis.

    Angles in ``tuning_curve``, ``fit``, ``trials`` and ``pref_angle`` are in
    degrees and converted to radians for plotting.

    tuning_curve: 2-column array (angle, response).
    error: radial error passed to ``ax.errorbar``.
    fit: optional 2-column fitted curve, drawn in red.
    trials: optional 2-column single-trial responses, drawn as jittered dots.
    pref_angle: optional angle marked with a black radial line.
    ax: polar axis to draw on; a new one is created when omitted.
    **kwargs: forwarded to ``ax.errorbar``, except ``theta_max`` (degrees,
        default 360), which sets the angular extent of the axis.

    Returns the container produced by ``ax.errorbar``.
    """
    theta_max = kwargs.pop('theta_max', 360)

    if ax is None:
        ax = plt.figure(dpi=300, figsize=(10, 6)).add_subplot(111, projection='polar')

    to_rad = np.deg2rad
    container = ax.errorbar(to_rad(tuning_curve[:, 0]), tuning_curve[:, 1], yerr=error, elinewidth=0.5, **kwargs)

    if fit is not None:
        ax.plot(to_rad(fit[:, 0]), fit[:, 1], c='r')
    if trials is not None:
        ax.scatter(to_rad(trials[:, 0]), rand_jitter(trials[:, 1]), marker='.', color='k', alpha=0.5)
    if pref_angle is not None:
        ax.axvline(to_rad(pref_angle), color='k', linewidth=1)

    # 0 deg on the left (west), angles increasing clockwise, radial labels at 180 deg
    ax.set_thetamax(theta_max)
    ax.set_theta_zero_location("W")
    ax.set_theta_direction(-1)
    ax.set_rlabel_position(180)

    return container

    
def plot_tuning_curve_hv(tuning_curve, fit=None, error=None, trials=None, pref_angle=None, **kwargs):
    """Build a holoviews overlay of a tuning curve with optional annotations.

    tuning_curve: 2-column array (angle, response).
    fit: optional 2-column array for a fitted curve.
    error: optional per-point error, drawn as a shaded band when given.
    trials: optional 2-column array of single-trial responses (black dots).
    pref_angle: optional angle marked with a vertical black line.
    **kwargs: forwarded to the main curve's ``.opts``.

    Returns an ``hv.Overlay``.
    """
    overlay = [hv.Curve(tuning_curve).opts(width=600, height=300, **kwargs)]

    # `error` is optional (defaults to None), so only build the band when it is given.
    if error is not None:
        overlay.append(hv.Spread((*tuning_curve.T, error)).opts(fill_alpha=0.25))
    if fit is not None:
        overlay.append(hv.Curve(fit))
    if trials is not None:
        overlay.append(hv.Scatter(trials).opts(color='k', size=3))
    if pref_angle is not None:
        overlay.append(hv.VLine(pref_angle).opts(color='k', line_width=1))

    return hv.Overlay(overlay)
def plot_visual_tuning_curves(cell, data_type='norm_spikes', polar=True, axes=None):
    """Plot direction (axes[0]) and orientation (axes[1]) tuning for one cell.

    Reads the one-row DataFrames ``<data_type>_direction_props`` and
    ``<data_type>_orientation_props`` from ``cell`` and plots their
    normalized tuning curve, normalized std, fit curve, shown preferred angle
    and normalized single-trial responses.

    polar: use polar axes (orientation panel limited to 180 deg) when True,
        Cartesian axes otherwise.
    axes: two axes to draw on; a new 1x2 figure is created when omitted.

    Returns ``(container, axes)``, where ``container`` is whatever the
    orientation-panel plotting call returned.
    """
    panel_props = (
        getattr(cell, data_type + '_direction_props'),
        getattr(cell, data_type + '_orientation_props'),
    )

    if axes is None:
        extra = {'subplot_kw': dict(projection="polar")} if polar else {}
        _, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 4), **extra)

    plot_func = plot_polar_tuning_curve if polar else plot_tuning_curve
    # Only the polar orientation panel gets a half-circle extent.
    panel_kwargs = ({}, {'theta_max': 180} if polar else {})

    result = None
    for panel_idx, (props, extra_kw) in enumerate(zip(panel_props, panel_kwargs)):
        result = plot_func(props['tuning_curve_norm'][0],
                           props['std_norm'][0],
                           fit=props['fit_curve'][0],
                           pref_angle=props['shown_pref'][0],
                           trials=props['trial_resp_norm'][0],
                           ax=axes[panel_idx],
                           **extra_kw)

    return result, axes
def plot_fixed_ethogram(experiment, num_cells=15, **kwargs):
    """Plot trial markers, kinematics and a random subset of cells as stacked rows.

    Rows, top to bottom: trial indicator, |wheel speed|, normalized pupil
    diameter, mean-centered pupil x (black) and y (red), then one row per
    sampled cell (``norm_dff`` in black, ``norm_spikes`` in green).

    experiment: object exposing ``kinematics``, ``cells`` and ``cell_props``.
    num_cells: number of distinct cells to sample; capped at the number available.
    **kwargs: forwarded to ``plt.subplots`` (e.g. ``figsize``).

    Returns the matplotlib Figure.
    """
    kinem = experiment.kinematics
    time = kinem.time_vector

    # Select distinct cells (sampling with replacement could repeat a row); cap the
    # count so np.random.choice does not raise when fewer cells are available.
    num_cells = min(num_cells, len(experiment.cells))
    cells = np.random.choice(experiment.cells, num_cells, replace=False)

    fig, axes = plt.subplots(nrows=num_cells+4, ncols=1, sharex=True, **kwargs)
    axes_labels = ['trials', 'run speed', 'pupil dia.', 'pupil pos.'] + list(cells)

    # Trial indicator: 1 while trial_num > 0, else 0
    trial_tick = np.where(kinem.trial_num > 0, 1, 0)
    axes[0].plot(time, trial_tick, 'k')

    axes[1].plot(time, np.abs(kinem.wheel_speed), 'k')
    axes[1].set_ylim((0, 1))

    # Clean (fk.jump_killer), unwrap and normalize the pupil diameter to [0, 1]
    smoothed_diameter = normalize(np.unwrap(fk.jump_killer(kinem.pupil_diameter, 2), 10))
    axes[2].plot(time, smoothed_diameter, 'k')
    axes[2].set_ylim((0, 1))

    # Mean-centered pupil position: x in black, y in red
    smoothed_x = np.unwrap(fk.jump_killer(kinem.fit_pupil_center_x, 3), 15)
    smoothed_x -= np.mean(smoothed_x)
    smoothed_y = np.unwrap(fk.jump_killer(kinem.fit_pupil_center_y, 3), 15)
    smoothed_y -= np.mean(smoothed_y)
    axes[3].plot(time, smoothed_x, 'k', time, smoothed_y, 'r')

    # Plot neural data from `experiment` itself; the notebook-global `exp`
    # may hold a different session.
    for i, cell in enumerate(cells):
        cell_props = experiment.cell_props[cell]
        axes[i+4].plot(time, cell_props.norm_dff[cell], 'k', alpha=0.7)
        axes[i+4].plot(time, cell_props.norm_spikes[cell], 'g', alpha=0.5)
        axes[i+4].set_ylim((0, 1))

    # Label rows; drop top/right spines everywhere and bottom spines/ticks
    # on all but the last row so the rows read as one stacked plot.
    for i, a in enumerate(axes):
        a.set_ylabel(axes_labels[i], rotation=45, va='center', wrap=True, x=-50)
        a.spines["top"].set_visible(False)
        a.spines["right"].set_visible(False)
        if i < len(axes)-1:
            a.spines["bottom"].set_visible(False)
            a.xaxis.set_tick_params(bottom=False)

    axes[-1].set_xlabel("Time (s)")

    return fig

# Load the Data

In [None]:
# Reload so edits to processing_parameters take effect without restarting the kernel
importlib.reload(processing_parameters)

# get the search string
search_string = processing_parameters.search_string + r", analysis_type:preprocessing"

# get the paths from the database
file_path, paths_all, parsed_query, date_list, animal_list = fdh.fetch_preprocessing(search_string)

# Keep only entries whose animal matches the queried mouse
# (the query value is lowercased; animal_list entries are compared as-is)
animal_idxs = [i for i,d in enumerate(animal_list) if d==parsed_query['mouse'].lower()]
good_entries = [file_path[index] for index in animal_idxs]
input_paths = [paths_all[index] for index in animal_idxs]

# Show the selected input paths
print(input_paths)

In [None]:
# NOTE(review): assumes input_paths[0] is the `fixed` session and input_paths[1] the
# `free` session; this depends on the order returned by the database query — confirm.
fixed = WirefreeExperiment(input_paths[0], use_xarray=False)
free = WirefreeExperiment(input_paths[1], use_xarray=False)

# Remove baselines from the data for later processing
Also normalize the inferred spike responses.

In [None]:
def normalize_spike_responses(ds, quantile=0.07):
    """Fill NaNs, min-max normalize each cell, and subtract a per-cell baseline.

    Parameters
    ----------
    ds : pandas.DataFrame or xarray.Dataset
        Activity data. For DataFrames, cell columns are those whose name
        contains "cell", and a ``direction`` column is used for grouping.
    quantile : float
        Quantile of each cell's activity, computed per ``direction`` group,
        used as the baseline (default 0.07, i.e. the 7th percentile).

    Returns
    -------
    A copy of ``ds`` with NaNs replaced by 0. For DataFrames, each cell column
    is also scaled to [0, 1] and then has its baseline subtracted. The baseline
    is the ``quantile`` of that cell in the first ``direction`` group. xarray
    input is only NaN-filled (normalization not implemented yet).
    """
    # fillna returns a new object, so the caller's data is never modified.
    ds_norm = ds.fillna(0)

    if isinstance(ds, pd.DataFrame):
        cells = [el for el in ds.columns if "cell" in el]

        # Normalize each cell's responses across the whole session.
        # DataFrame methods return new objects, so the result must be assigned back.
        ds_norm[cells] = ds_norm[cells].apply(normalize)

        # Low quantile of activity per cell for each stimulus (try 7th/8th percentile)
        percentiles = ds_norm.groupby(['direction'])[cells].quantile(quantile)

        # Baseline = first (lowest-sorted) direction group, taken to be the
        # inter-trial interval. NOTE(review): confirm the ITI code sorts first.
        baselines = percentiles.iloc[0, :]

        # Subtract each cell's baseline (aligned on column names)
        ds_norm[cells] = ds_norm[cells].subtract(baselines, axis=1)

    elif isinstance(ds, xr.Dataset):
        # TODO: normalization / baseline subtraction for xarray input
        pass

    return ds_norm

In [None]:
# Build normalized, baseline-subtracted spike tables for both sessions
free.norm_spikes = normalize_spike_responses(free.raw_spikes)
fixed.norm_spikes = normalize_spike_responses(fixed.raw_spikes)

# Populate cell objects with their neural activity

In [None]:
def add_activity_to_cell(experiment):
    """Attach per-cell activity tables to every cell object of ``experiment``.

    For each ``cell_id`` in ``experiment.cell_props``, the cell object gets
    ``raw_spikes``, ``norm_spikes`` and ``raw_fluor`` attributes. Each one is a
    slice of the experiment-level DataFrame of the same name, holding its
    first 6 columns plus that cell's own column.
    """
    def _cell_view(frame, cell_id):
        # First 6 columns are shared per-frame variables (presumably time/trial
        # metadata — confirm), followed by this cell's trace.
        column_positions = list(range(6)) + [frame.columns.get_loc(cell_id)]
        return frame.iloc[:, column_positions]

    for cell_id, cell_obj in experiment.cell_props.items():
        cell_obj.raw_spikes = _cell_view(experiment.raw_spikes, cell_id)
        cell_obj.norm_spikes = _cell_view(experiment.norm_spikes, cell_id)
        cell_obj.raw_fluor = _cell_view(experiment.raw_fluor, cell_id)

In [None]:
# Give every cell object its own slices of the session-level activity tables
add_activity_to_cell(free)
add_activity_to_cell(fixed)

# Extract matched cells

In [None]:
def create_matched_datasets(exp, dataset_names, cell_ids, matched_ids):
    """Store per-dataset tables restricted to matched cells under shared IDs.

    For every name in ``dataset_names``, the DataFrame ``getattr(exp, name)`` is
    reduced to its first 6 columns plus the columns ``cell_{id:04}`` for each
    id in ``cell_ids``. Those cell columns are renamed to ``cell_{id:04}`` of
    the matching entry in ``matched_ids`` (paired by position). The result is
    stored on ``exp`` as ``<name>_matched``; the source DataFrames are unchanged.
    """
    old_names = [f"cell_{el:04}" for el in cell_ids]
    new_names = [f"cell_{el:04}" for el in matched_ids]
    mapper = dict(zip(old_names, new_names))

    for name in dataset_names:
        frame = getattr(exp, name)
        renamed_traces = frame[old_names].rename(columns=mapper)
        setattr(exp, f'{name}_matched', pd.concat([frame.iloc[:, :6], renamed_traces], axis=1))


# For each session, find the column of `matches` that belongs to it, then build
# `<dataset>_matched` tables whose cell columns are renamed to the shared
# match IDs (the index of `matches`).
# NOTE(review): `matches` is not defined in this section — presumably a DataFrame
# with one column per experiment type; confirm.
for exp in [fixed, free]:
    cell_nums = [int(cell.split("_")[-1]) for cell in exp.cells]
    match_col = np.argwhere([exp.metadata.exp_type in col for col in list(matches.columns)]).flatten()[0]

    matched_cells = matches.iloc[:, match_col].to_numpy()
    # Cells of this session that have no cross-session match
    unmatched_cells = np.setdiff1d(cell_nums, matched_cells)

    rename_match = matches.index.to_numpy()

    create_matched_datasets(exp, ['raw_spikes', 'norm_spikes', 'raw_fluor'], matched_cells, rename_match)

# Determine responsivity

In [None]:
# NOTE(review): `matched_data` is not defined in this section; presumably a list of
# per-session DataFrames — confirm.
ds = matched_data[0]
# get the number of cells
# NOTE(review): selects columns containing 'spikes'; cell columns elsewhere are named
# "cell_XXXX", so this may select nothing — confirm the column naming.
cell_spikes = [el for el in ds.columns if 'spikes' in el]

# Get response per trial for each cell
for cell in cell_spikes:
    # Mean activity of this cell within each trial with trial_num >= 1
    trial_responses = ds[ds.trial_num >= 1].groupby(['trial_num'])[cell].mean().to_numpy()
    # Mean activity over all trial_num == 0 frames (a single scalar)
    isi_responses = ds[ds.trial_num == 0][cell].mean()
#     isi_responses = np.ones(trial_responses.shape) * isi_responses
    
    # NOTE(review): WIP — tests a scalar against the per-trial means, and `ttest` is
    # overwritten on every iteration, so only the last cell's result survives.
    ttest = ttest_ind(isi_responses, trial_responses)

# trial_start = ds[ds.trial_num >= 1].groupby(['trial_num']).time_vector.nth(0)
# trial_end = ds[ds.trial_num >= 1].groupby(['trial_num']).time_vector.nth(-2)
# trial_times = np.array((trial_start, trial_end)).T

In [None]:
# Inspect the trial_num == 0 mean left over from the last loop iteration above
isi_responses

# Some simple exploration

## Plot activity

In [None]:
# Inferred spikes
spike_plots = [trace_raster(free.norm_spikes), trace_raster(fixed.norm_spikes)]
hv.Layout(spike_plots).cols(2).opts(shared_axes=False)

In [None]:
# allocate memory for the plots
# Raw fluorescence: stacked per-cell traces for the `free` (left) and `fixed` (right) sessions
fluor_plots = [trace_raster(free.raw_fluor), trace_raster(fixed.raw_fluor)]
hv.Layout(fluor_plots).cols(2).opts(shared_axes=False)

In [None]:
# average across repeats and time


## Visualize the activity for all units during the inter-trial interval

In [None]:
# Keep only inter-trial-interval frames (trial_num == 0) of the `free` session and plot all cells
free_iti = free.norm_spikes.loc[free.norm_spikes['trial_num'] == 0, ['time_vector'] + free.cells].reset_index(drop=True)
trace_raster(free_iti).opts(ylabel="Cells", width=900, height=500)

In [None]:
# Same as above for the `fixed` session
fixed_iti = fixed.norm_spikes.loc[fixed.norm_spikes['trial_num'] == 0, ['time_vector'] + fixed.cells].reset_index(drop=True)
trace_raster(fixed_iti).opts(ylabel="Cells", width=900, height=500)

## Plot the activity of each cell 

In [None]:
# Mean +/- SEM within-trial time course, per stimulus direction, for the first two
# cells of the `free` session. Each trial becomes one row; shorter trials are
# NaN-padded so trials can be averaged sample by sample.
direction_plots = []
for cell in free.cells[:2]:
    for direction in free.exp_params['direction']:
        # Only frames shown with this direction; without this filter every panel
        # would show the same average over all trials.
        direction_frames = free.norm_spikes[free.norm_spikes['direction'] == direction]
        individual_responses = direction_frames.groupby(['trial_num'])[cell].agg(list)
        # np.nan (np.NaN was removed in NumPy 2.0)
        resp_array = np.array(list(zip_longest(*individual_responses, fillvalue=np.nan))).T
        current_direction = np.nanmean(resp_array, axis=0)
        current_sem = sem(resp_array, axis=0, nan_policy='omit')

        x = np.arange(resp_array.shape[-1])
        plot = hv.Curve((x, current_direction)).opts(ylabel=str(direction), xlabel="time")
        plot2 = hv.Spread((x, current_direction, current_sem))
        plot.opts(xrotation=45)
        direction_plots.append(plot*plot2)

hv.Layout(direction_plots).opts(shared_axes=False).cols(len(free.exp_params['direction']))

## Plot direction selectivity

In [None]:
def plot_trial_responses(ds, key, cells=None):
    """Build per-cell holoviews tuning plots (single trials, mean, SEM).

    Parameters
    ----------
    ds : pandas.DataFrame with ``key``, ``trial_num`` and cell columns.
        It is not modified.
    key : grouping column, e.g. 'direction' or 'orientation'. The first
        (lowest-sorted) group is treated as the inter-trial interval. It only
        provides the baseline and is left out of the plots.
    cells : cell columns to plot; defaults to every column containing 'cell'.

    Returns
    -------
    list of holoviews overlays (trial scatter * mean curve * SEM band), one
    per cell, computed on baseline-subtracted activity.
    """
    cell_plots = []

    if cells is None:
        cells = [el for el in ds.columns if 'cell' in el]

    # Baseline: 7th percentile of each cell within the first `key` group (the ITI)
    percentiles = ds.groupby([key])[cells].quantile(0.07)
    baselines = percentiles.iloc[0]

    # Subtract the baseline on a copy so the caller's DataFrame is left untouched,
    # and compute every statistic below from the subtracted data.
    ds = ds.copy()
    ds[cells] = ds[cells].subtract(baselines, axis=1)

    # Named tuning_mean to avoid shadowing the `tuning` module
    tuning_mean = ds.groupby([key])[cells].mean()
    tuning_sem = ds.groupby([key])[cells].sem()

    # Mean response on each trial, then gather all trials of each `key` value into a list
    trial_responses = ds.groupby([key, 'trial_num'])[cells].agg(np.nanmean)
    trial_responses = trial_responses.droplevel(['trial_num']).groupby([key]).agg(list)

    x_lim = (-180, 180) if key == 'direction' else (0, 180)

    for cell in cells:
        # Row 0 is the ITI group; skip it everywhere
        current_cell = tuning_mean[cell].iloc[1:]
        current_sem = tuning_sem[cell].iloc[1:]

        # NOTE(review): list_lists_to_array is defined elsewhere; presumably pads the
        # ragged per-angle trial lists into a 2D array (angles x trials) — confirm.
        current_cell_trials = list_lists_to_array(trial_responses[cell].iloc[1:].to_list())

        # Angle of every single-trial observation, for the scatter
        label = current_cell.index.to_numpy()
        label_by_observation = np.multiply(np.ones(current_cell_trials.shape), label[:, np.newaxis])
        X = np.vstack((np.ravel(label_by_observation), np.ravel(current_cell_trials))).T

        plot_scatter = hv.Scatter(X).opts(color='r', size=3, xlabel=key, xlim=x_lim, xrotation=45)
        plot_mean = hv.Curve((label, current_cell))
        plot_sem = hv.Spread((label, current_cell.values, current_sem))
        cell_plots.append(plot_scatter * plot_mean * plot_sem)

    return cell_plots

In [None]:
# allocate memory for the cell plots
cell_plots = []

# for all the files
for ds in matched_data:
    ds_plots = plot_trial_responses(ds, 'direction')
    cell_plots.append(ds_plots)
        
num_cells = len(cell_plots[0])
cell_plots = sum(cell_plots, [])
hv.Layout(cell_plots).opts(shared_axes=False).cols(num_cells)

## Plot orientation selectivity

In [None]:
# allocate memory for the cell plots
cell_plots = []

# for all the files
for ds in matched_data:
    ds_plots = plot_trial_responses(ds, 'orientation')
    cell_plots.append(ds_plots)
        
num_cells = len(cell_plots[0])
cell_plots = sum(cell_plots, [])
hv.Layout(cell_plots).opts(shared_axes=False).cols(num_cells)

# Tuning with circular variance
As in the papers by Carandini, Rose, and Schumacher.

## Circular Functions

In [None]:
def map_statistic(x, func):
    """Apply reducer ``func`` along axis 1 of ``x`` after converting it to an array."""
    return func(np.array(x), axis=1)

## Main Loop

In [None]:
def tuning_loop(experiment, key, **kwargs):
    """Run ``calculate_tuning(cell, key, **kwargs)`` on every cell of ``experiment``.

    Results are stored on each cell object by ``calculate_tuning``; extra
    keyword arguments (e.g. ``data_type``, ``tuning_fit``) are forwarded.
    """
    cell_objects = list(experiment.cell_props.values())
    for cell in cell_objects:
        calculate_tuning(cell, key, **kwargs)
        
def calculate_tuning(cell, tuning_kind, data_type='norm_spikes', tuning_fit='von_mises'):
    """Compute tuning metrics for one cell and store them on the cell.

    Parameters
    ----------
    cell : cell object with an ``id`` and a DataFrame attribute named
        ``data_type`` that has ``tuning_kind``, ``trial_num`` and ``cell.id``
        columns.
    tuning_kind : stimulus column to tune against, e.g. 'direction_wrapped' or
        'orientation'. Rows whose value is -1000 (the inter-trial interval
        code) are dropped.
    data_type : name of the activity attribute to read (default 'norm_spikes').
    tuning_fit : for direction tuning, 'von_mises' selects
        ``tuning.calculate_pref_direction_vm``; any other value selects
        ``tuning.calculate_pref_direction``. Orientation tuning always uses
        ``tuning.calculate_pref_orientation``.

    Side effects
    ------------
    Sets ``cell.<data_type>_<label>_props`` to a one-row, object-dtype DataFrame.
    ``label`` is the part of ``tuning_kind`` before the first underscore. The
    row holds:
    - per-trial responses;
    - mean/std/SEM tuning vectors, raw and normalized;
    - the resultant vector (length, angle in degrees);
    - circular variance and responsivity (1 - circular variance);
    - fit outputs, fit R^2 and the preferred angles;
    - ``responsivity_p``, the permutation-test percentile p.
    """
    # Mean response per trial; drop the inter-trial interval rows (coded -1000)
    activity = getattr(cell, data_type).copy()
    trial_activity = activity.groupby([tuning_kind, 'trial_num'])[cell.id].agg(np.nanmean)
    trial_activity = trial_activity.droplevel(['trial_num'])
    trial_activity = trial_activity.drop(trial_activity[trial_activity.index == -1000].index)
    trial_responses = trial_activity.reset_index()

    #-- Create the response vectors --#
    mean_resp, angles = tuning.generate_response_vector(trial_responses, np.nanmean)
    sem_resp, _ = tuning.generate_response_vector(trial_responses, sem, nan_policy='omit')
    std_resp, _ = tuning.generate_response_vector(trial_responses, np.std)

    # Normalize if the cell is responsive: scale single-trial responses to [0, 1]
    # using the cell's min/max response over all trials.
    if np.max(mean_resp) > 0:
        norm_trial_resp = trial_responses.copy()
        norm_trial_resp[cell.id] = normalize(trial_responses[cell.id])

        norm_mean_resp, _ = tuning.generate_response_vector(norm_trial_resp, np.nanmean)
        norm_sem_resp, _ = tuning.generate_response_vector(norm_trial_resp, sem, nan_policy='omit')
        norm_std_resp, _ = tuning.generate_response_vector(norm_trial_resp, np.std)
    else:
        norm_mean_resp = mean_resp
        norm_trial_resp = trial_responses
        norm_sem_resp = sem_resp
        norm_std_resp = std_resp

    # -- Fit tuning curves to get preference -- #
    if 'direction' in tuning_kind:
        if tuning_fit == 'von_mises':
            fit_function = tuning.calculate_pref_direction_vm
        else:
            fit_function = tuning.calculate_pref_direction
    else:
        fit_function = tuning.calculate_pref_orientation

    # Fit on the whole dataset and get R^2
    fit, fit_curve, pref_angle, real_pref_angle = fit_function(norm_trial_resp[tuning_kind], norm_trial_resp[cell.id])
    fit_r2 = tuning.fit_r2(norm_trial_resp[tuning_kind], norm_trial_resp[cell.id], fit_curve[:, 0], fit_curve[:, 1])

    # -- Resultant vector and response variance -- #
    thetas = np.deg2rad(norm_trial_resp[tuning_kind])
    magnitudes = norm_trial_resp[cell.id]
    # NOTE(review): `thetas` has one angle per trial, sorted by angle (groupby sorts
    # its keys), so this is (max - min) / (n_trials - 1) rather than the spacing
    # between distinct stimulus angles that circ's bin correction `d` expects.
    # Left unchanged so values stay comparable with the shuffled distribution —
    # confirm which spacing tuning.bootstrap_responsivity assumes.
    angle_sep = np.mean(np.diff(thetas))

    resultant_length = circ.resultant_vector_length(thetas, w=magnitudes, d=angle_sep)
    resultant_angle = np.rad2deg(circ.mean(thetas, w=magnitudes, d=angle_sep))

    circ_var = circ.var(thetas, w=magnitudes, d=angle_sep)
    responsivity = 1 - circ_var

    # -- Permutation test -- #
    # Shuffle trial IDs and locate the real responsivity within the shuffled distribution
    _, shuffled_responsivity = tuning.bootstrap_responsivity(thetas, magnitudes, num_shuffles=500)
    p = percentileofscore(shuffled_responsivity, responsivity, kind='mean') / 100.

    # TODO: try leave-one-out

    # -- Store results on the cell -- #
    tuning_label = tuning_kind.split('_')[0]

    prop_names = ['trial_resp', 'trial_resp_norm', 'mean', 'mean_norm', 'std', 'std_norm', 'sem', 'sem_norm',
                  'tuning_curve', 'tuning_curve_norm', 'resultant', 'circ_var', 'responsivity',
                  'fit', 'fit_curve', 'fit_r2', 'pref', 'shown_pref', 'responsivity_p']

    prop_values = [trial_responses.to_numpy(), norm_trial_resp.to_numpy(), mean_resp, norm_mean_resp,
                   std_resp, norm_std_resp, sem_resp, norm_sem_resp,
                   np.vstack([angles, mean_resp]).T, np.vstack([angles, norm_mean_resp]).T,
                   (resultant_length, resultant_angle), circ_var, responsivity,
                   fit, fit_curve, fit_r2, pref_angle, real_pref_angle, p]

    # DataFrame.append was removed in pandas 2.0; build the one-row frame directly.
    # Each value (arrays, tuples, scalars) is stored as a single object cell.
    props = pd.DataFrame([dict(zip(prop_names, prop_values))], columns=prop_names, dtype='object')

    setattr(cell, f'{data_type}_{tuning_label}_props', props)

In [None]:
%%time
# Compute direction and orientation tuning for every cell of both sessions.
# NOTE(review): `free` reads data_type='norm_spikes_still', an attribute not set
# anywhere in this section — confirm it exists on the cell objects.
keys = ['direction_wrapped', 'orientation']

for key in keys:
    tuning_loop(fixed, key, data_type='norm_spikes')
    tuning_loop(free, key, data_type='norm_spikes_still')

In [None]:
# Interactive (holoviews) direction and orientation tuning plots for the last `fixed` cell.
# NOTE(review): calculate_tuning stores its results as `<data_type>_<label>_props`
# DataFrames; `cell.spikes_props` (with attributes such as
# `direction_wrapped_tuning_curve_norm`) is not set anywhere in this section — confirm.
cell_name, cell = list(fixed.cell_props.items())[-1]
direction_plot = plot_tuning_curve_hv(cell.spikes_props.direction_wrapped_tuning_curve_norm,
                                     error=cell.spikes_props.direction_wrapped_sem_resp_norm,
                                     fit=cell.spikes_props.direction_wrapped_fit_curve, 
                                     pref_angle=cell.spikes_props.direction_wrapped_pref,
                                     trials=cell.spikes_props.direction_wrapped_trial_resp_norm
                                     )

orientation_plot = plot_tuning_curve_hv(cell.spikes_props.orientation_tuning_curve_norm,
                                        error=cell.spikes_props.orientation_sem_resp_norm,
                                        fit=cell.spikes_props.orientation_fit_curve, 
                                        pref_angle=cell.spikes_props.orientation_pref,
                                        trials=cell.spikes_props.orientation_trial_resp_norm
                                       )
hv.Layout([direction_plot + orientation_plot]).opts(shared_axes=False)

In [None]:
# Cartesian matplotlib tuning plots (direction left, orientation right) for the last `fixed` cell.
# NOTE(review): relies on `cell.spikes_props`, which is not set in this section; also
# `fig` is rebound to the ErrorbarContainer returned by plot_tuning_curve, not a Figure.
cell_name, cell = list(fixed.cell_props.items())[-1]
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12,4))
fig = plot_tuning_curve(cell.spikes_props.direction_wrapped_tuning_curve_norm, ax=axes[0],
                        error=cell.spikes_props.direction_wrapped_std_resp_norm,
                        fit=cell.spikes_props.direction_wrapped_fit_curve, 
                        pref_angle=cell.spikes_props.direction_wrapped_pref,
                        trials=cell.spikes_props.direction_wrapped_trial_resp_norm,
                       )
fig = plot_tuning_curve(cell.spikes_props.orientation_tuning_curve_norm, ax=axes[1],
                        error=cell.spikes_props.orientation_std_resp_norm,
                        fit=cell.spikes_props.orientation_fit_curve, 
                        pref_angle=cell.spikes_props.orientation_pref, 
                        trials=cell.spikes_props.orientation_trial_resp_norm,
                       )

In [None]:
# Polar tuning plots (direction over 360 deg, orientation over 180 deg) for `fixed` cell #100.
# NOTE(review): relies on `cell.spikes_props`, which is not set in this section.
cell_name, cell = list(fixed.cell_props.items())[100]
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12,4), subplot_kw=dict(projection="polar"))
fig = plot_polar_tuning_curve(cell.spikes_props.direction_wrapped_tuning_curve_norm, ax=axes[0],
                        error=cell.spikes_props.direction_wrapped_std_resp_norm,
                        fit=cell.spikes_props.direction_wrapped_fit_curve, 
                        pref_angle=cell.spikes_props.direction_wrapped_pref,
                        trials=cell.spikes_props.direction_wrapped_trial_resp_norm,
                       )

fig = plot_polar_tuning_curve(cell.spikes_props.orientation_tuning_curve_norm, ax=axes[1],
                        error=cell.spikes_props.orientation_std_resp_norm,
                        fit=cell.spikes_props.orientation_fit_curve, 
                        pref_angle=cell.spikes_props.orientation_pref, 
                        trials=cell.spikes_props.orientation_trial_resp_norm,
                        theta_max=180
                       )

In [None]:
# Polar direction tuning of the first `free` cell with its resultant vector (red arrow)
# and fitted curve (red line).
# NOTE(review): relies on `cell.spikes_props` and `fp.plot_polar`, neither defined in this section.
cell_name, cell = list(free.cell_props.items())[0]
fig1 = fp.plot_polar(cell.spikes_props.direction_wrapped_tuning_curve)
ax = fig1.gca()
# `resultant` is (length, angle); the angle is in degrees (converted with deg2rad below).
# Named resultant_dir so the builtin `dir` is not shadowed in the notebook namespace.
mag, resultant_dir = cell.spikes_props.direction_wrapped_resultant
ax.arrow(np.deg2rad(resultant_dir), 0, 0, mag, width = 0.2, facecolor = 'red', lw=1, head_length=0.05)
plt.polar(np.deg2rad(cell.spikes_props.direction_wrapped_fit_curve[:,0]), cell.spikes_props.direction_wrapped_fit_curve[:,1], color='r')

# 0 deg on the left (west), angles increasing clockwise
ax.set_theta_zero_location("W")
ax.set_theta_direction(-1)

In [None]:
# Refit the direction tuning of `fixed` cell #200 with the von Mises fit on its stored
# normalized single trials (column 0 = angle, column 1 = response), and plot it next to
# the stored orientation fit.
# NOTE(review): relies on `cell.spikes_props`, which is not set in this section.
cell_name, cell = list(fixed.cell_props.items())[200]
trial_resp_dir = cell.spikes_props.direction_wrapped_trial_resp_norm
trial_resp_ori = cell.spikes_props.orientation_trial_resp_norm


fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15,5) , subplot_kw=dict(projection="polar"))

fit, fit_curve, pref_angle, real_pref_angle = tuning.calculate_pref_direction_vm(trial_resp_dir[:,0], trial_resp_dir[:,1])
fig = plot_polar_tuning_curve(cell.spikes_props.direction_wrapped_tuning_curve_norm, ax=axes[0],
                        error=cell.spikes_props.direction_wrapped_std_resp_norm,
                        fit=fit_curve,
                        pref_angle=pref_angle,
                        trials=cell.spikes_props.direction_wrapped_trial_resp_norm,
                       )

# fit, fit_curve, pref_angle, real_pref_angle = tuning.calculate_pref_orientation(trial_resp_ori[:,0], trial_resp_ori[:,1])
fig = plot_polar_tuning_curve(cell.spikes_props.orientation_tuning_curve_norm, ax=axes[1],
                        error=cell.spikes_props.orientation_std_resp_norm,
                        fit=cell.spikes_props.orientation_fit_curve, 
                        pref_angle=cell.spikes_props.orientation_pref, 
                        trials=cell.spikes_props.orientation_trial_resp_norm,
                        theta_max=180
                       )

In [None]:
# Null distribution of responsivity (1 - circular variance) from the permutation test,
# with the observed value as a red line and the p-value as text.
# NOTE(review): `responsivity`, `shuffled_responsivity` and `p` are locals of
# calculate_tuning and are not defined at notebook level in this section — define them
# for one cell before running this cell.
# np.linspace's `retstep` is a boolean flag (also return the step), not a step size;
# the default 50 evenly spaced bin edges are what gets used.
bins = np.linspace(0, max(responsivity, np.max(shuffled_responsivity)))
values, bins, _ = plt.hist(shuffled_responsivity, bins=bins, edgecolor='k', linewidth=1)
plt.axvline(responsivity, color='r', linewidth=3)
plt.xlabel('1 - circ_var')
plt.ylabel('# of runs')
# Axes-fraction coordinates keep the text visible whatever the histogram's y-range is.
plt.text(0.02, 0.95, str(p), transform=plt.gca().transAxes, va='top')
plt.show()

In [None]:
# allocate memory for the cell plots
cell_plots = []

# for all the files
for ds in df_list_circvar:
    for i, cell in ds.iterrows():
        
        cell_plot = plot_tuning_curve_hv(cell.ori_tuning_curve,
                                         sem=cell.ori_sem,
                                         fit=cell.ori_fit_curve, 
                                         pref_angle=cell.ori_pref)
        
        cell_plots.append(cell_plot.opts(xlim=(0,180), xlabel="angle [deg]", title=str(cell.cell)))
        
num_cells = np.max([len(ds.cell.unique()) for ds in df_list_circvar])
hv.Layout(cell_plots).opts(shared_axes=False).cols(num_cells)

In [None]:
# allocate memory for the cell plots
cell_plots = []

# for all the files
for ds in df_list_circvar:
    for i, cell in ds.iterrows():
        
        cell_plot = plot_tuning_curve_hv(cell.dir_tuning_curve, 
                                         sem=cell.dir_sem,
                                         fit=cell.dir_fit_curve, 
                                         pref_angle=cell.dir_pref)
        cell_plots.append(cell_plot.opts(xlim=(0, 360), xlabel="angle [deg]", title=str(cell.cell)))
        
num_cells = np.max([len(ds.cell.unique()) for ds in df_list_circvar])
hv.Layout(cell_plots).opts(shared_axes=False).cols(num_cells)

In [None]:
# Join the SVD-based significant cells with the circular-variance-based results
# (left join on 'cell': every SVD-selected cell is kept, CV columns suffixed '_cv').
# NOTE(review): sig_dir_sel_cv, sig_ori_sel_svd and sig_ori_sel_cv are not defined in
# this section, and sig_dir_sel_svd is only defined in a later cell.
shared_dir_sel = sig_dir_sel_svd.merge(sig_dir_sel_cv, on='cell', how='left', suffixes=['_svd', '_cv'])
shared_ori_sel = sig_ori_sel_svd.merge(sig_ori_sel_cv, on='cell', how='left', suffixes=['_svd', '_cv'])

In [None]:
# Compare the direction preferences found by the two methods
shared_dir_sel[['cell', 'dir_sel_p', 'dir_pref_svd', 'dir_pref_cv', 'dir_resp']]

In [None]:
# Compare the orientation preferences found by the two methods
# (orientation columns come from the orientation merge, shared_ori_sel)
shared_ori_sel[['cell', 'ori_sel_p', 'ori_pref_svd', 'ori_pref_cv', 'ori_resp']]

# Population Vector Decoding

Use the SVD-selected, direction-selective cells.

In [None]:
# Select direction-selective cells from the SVD-based properties of the last dataset
# NOTE(review): `df_list` is not defined in this section.
cell_properties_svd = df_list[-1]
# The cutoff of QI here is from the Baden paper
# NOTE(review): the filter below is on dir_sel_p (<= 0.15), not on a QI column — confirm
# the comment above still applies.
sig_dir_sel_svd = cell_properties_svd[(cell_properties_svd.dir_sel_p <= 0.15)].sort_values('dir_sel_p')
# Cell IDs and preferred directions, back in index order
cell_tuning = sig_dir_sel_svd[['cell', 'dir_pref']].sort_index()
tuned_cells = cell_tuning.cell.to_list()
dir_prefs = cell_tuning.dir_pref.to_list()

In [None]:
# Inspect the selected cells and their preferred directions
print(tuned_cells)
print(dir_prefs)

In [None]:
# Get the mean reponse per trial
trial_responses = data[-1].groupby(['trial_num'])[['direction'] + tuned_cells].agg(np.nanmean)

# Drop the inter-trial interval from df
trial_responses = trial_responses.drop(trial_responses[trial_responses.index == 0].index)

real_direction = trial_responses.direction.to_numpy()

decoded_angles = []
for i, row in trial_responses.iterrows():
    direction = row.direction
    cell_resps = row[tuned_cells]
    
    #-- Get preferred tuning --#
    decoded_length, decoded_angle = get_resultant_vector(np.deg2rad(dir_prefs), cell_resps, 'direction')
    decoded_angle = np.rad2deg(decoded_angle)
    if decoded_angle > 180:
        decoded_angle = decoded_angle % 180 - 180
    decoded_angles.append(decoded_angle)
    
decoded_angles = np.array(decoded_angles)

# for cell in cells:

#     current_cell_responses = trial_responses.iloc[:, trial_responses.columns.get_loc(cell)]

#     # Get mean response across trials
#     current_cell_mean_resp = current_cell_responses.groupby([key]).apply(np.nanmean)

#     # Normalize responses and fill any NaNs with zeros
#     current_cell_mean_resp = current_cell_mean_resp.fillna(0)

#     # Get a list of directions/orientations
#     angles = current_cell_mean_resp.index.to_numpy()

#     # Wrap angles from [-180, 180] to [0, 360] for direction tuning, and sort angles and cell responses
#     # Needed for pycircstat toolbox
#     angles_wrapped = fk.wrap(angles)
#     sort_idx = np.argsort(angles_wrapped)
#     angles_sorted = angles_wrapped[sort_idx]

#     current_cell_mean_resp = current_cell_mean_resp.to_numpy()
#     sorted_mean_resp = current_cell_mean_resp[sort_idx]

#     #-- Get preferred tuning --#
#     resultant_length, resultant_angle = get_resultant_vector(np.deg2rad(angles_sorted), sorted_mean_resp, key)
#     resultant_angle = fk.wrap(np.rad2deg(resultant_angle))


In [None]:
# Stimulus direction of each trial
print(real_direction)

In [None]:
# Decoded direction of each trial
print(decoded_angles)

In [None]:
# Per-trial decoding error in degrees
# NOTE(review): plain subtraction; differences are not wrapped into [-180, 180], so
# e.g. 179 vs -179 shows as 358 instead of -2.
real_direction - decoded_angles

In [None]:
# np.cov of two 1-D vectors is their 2x2 covariance matrix, shown as an image
plt.imshow(np.cov(real_direction, decoded_angles))

# Polar plots

In [None]:
# Test polar vector sum for direction
fig, axes = plt.subplots(nrows=len(data), ncols=len(cells), subplot_kw={'projection': 'polar'}, figsize=(20,40))

for j, ds in enumerate(data):
    cells = [el for el in ds.columns if 'cell' in el]
    tuning = ds.groupby(['direction'])[cells].mean()
    tuning_sem = ds.groupby(['direction'])[cells].sem()

    for i, cell in enumerate(cells):

        baseline = tuning.iloc[0, tuning.columns.get_loc(cell)]
        current_cell = tuning.iloc[1:, tuning.columns.get_loc(cell)] - baseline
        current_cell -= current_cell.min()
        current_cell_norm = current_cell / current_cell.max()
        current_sem = (tuning_sem.iloc[1:, tuning_sem.columns.get_loc(cell)])
        directions = current_cell.index.to_numpy()
        r, theta = polar_vector_sum(current_cell.to_numpy(), directions)

        axes[j,i].plot(np.deg2rad(directions), current_cell.to_numpy())
        axes[j,i].fill_between(np.deg2rad(directions), current_cell.to_numpy()+current_sem.to_numpy(), current_cell.to_numpy()-current_sem.to_numpy(), alpha=0.2)
        axes[j,i].plot([0, np.deg2rad(theta)], [0, current_cell.max()], color='r')
        axes[j,i].set_rmin(0)
        axes[j,i].set_theta_zero_location("W")
        axes[j,i].set_theta_direction(-1)
        axes[j,i].set_xticklabels(['0', '45', '90', '135', '+/-180', '-135', '-90', '-45'])
        axes[j,i].set_title(cell)

plt.tight_layout()

In [None]:
# Test polar vector sum for orientation
fig, axes = plt.subplots(nrows=len(data), ncols=len(cells), subplot_kw={'projection': 'polar'}, figsize=(20,40))

for j, ds in enumerate(data):
    cells = [el for el in ds.columns if 'cell' in el]
    tuning = ds.groupby(['orientation'])[cells].mean()
    tuning_sem = ds.groupby(['orientation'])[cells].sem()

    for i, cell in enumerate(cells):

        baseline = tuning.iloc[0, tuning.columns.get_loc(cell)]
        current_cell = tuning.iloc[1:, tuning.columns.get_loc(cell)] - baseline
        current_cell -= current_cell.min()
        current_cell_norm = current_cell / current_cell.max()
        current_sem = (tuning_sem.iloc[1:, tuning_sem.columns.get_loc(cell)])
        directions = current_cell.index.to_numpy()
        r, theta = polar_vector_sum(current_cell.to_numpy(), directions)

        axes[j,i].plot(np.deg2rad(directions), current_cell.to_numpy())
        axes[j,i].fill_between(np.deg2rad(directions), current_cell.to_numpy()+current_sem.to_numpy(), current_cell.to_numpy()-current_sem.to_numpy(), alpha=0.2)
        axes[j,i].plot([0, np.deg2rad(theta)], [0, current_cell.max()], color='r')
        axes[j,i].set_rmin(0)
        axes[j,i].set_theta_zero_location("W")
        axes[j,i].set_theta_direction(-1)
        axes[j,i].set_thetamax(180)
        axes[j,i].set_title(cell)
