# Decoding analysis of single cell spiking activity

We consider the spiking activity of individual SOMI near reward onset and train a Bayesian decoder assuming Poisson spiking (see Methods Decoder). We evaluate the decoding accuracy in a 2-fold cross-validation scheme and check for significant decoding through label shuffling.

The sampling rate is always 30 kHz. Times are given in seconds and rates/frequencies in Hz. I keep the timestamps of the spikes and the behavioral data mostly as integer values, i.e. as multiples of the inverse sampling rate.

This code should work with Python 3.9 and higher. The analysis for all cells of the expert animals takes about 1 hour.


## Imports and module reloading

In [None]:
#%matplotlib inline

import os
from pathlib import Path
import json

import numpy as np
import matplotlib.pyplot as plt

from sklearn.neighbors import KernelDensity
#from scipy.stats import ecdf 

# Module for the decoding analysis
#import decoding
from decoding import *

## Load existing data

In [None]:
# All input and output files live in a 'data' folder inside the current working directory
base_p = Path.cwd()
start_p = base_p / 'data'
print(start_p)

In [None]:
# Choose the data set to analyze: set session_status to expert_s or non_expert_s
# (both names are presumably provided by `from decoding import *` -- verify there).
session_status = expert_s # non_expert_s

cells_data_fname = f'decoding_data__{session_status}.npy'
fname = start_p / cells_data_fname

# .item() unwraps the single stored object, which is used as a dict below.
# NOTE(review): allow_pickle=True unpickles the file content -- only load trusted files.
cells_data_d = np.load(fname, allow_pickle=True).item()

print('Number of cells:', len(cells_data_d))
# Last expression of the cell: shows the cell ids in the notebook output
list(cells_data_d)

## Perform decoding analysis for all cells

Because of the shuffling test, the full decoding analysis takes a while.

In [None]:
# Start decoding at this time before reward onset:
decoding_offset_time = 1.5 # in s
# The same offset in sampling steps (only needed when aligning raw spikes, see note in the loop)
decoding_offset = round(decoding_offset_time * sample_rate_Hz)

# Split variant for two-fold cross-validation:
split_variant = 0 # Split into even and odd trials
rng_split = None # np.random.default_rng(rng_split_seed) for split variants involving rng

split_variant_name = split_variant_names[split_variant]
print('Using split variant:', split_variant_name)

# Gaussian kernel bandwidth for instantaneous firing rate estimation
bw_ms = 200 # ms

# Suffix that tags output file names with the decoding settings
decoding_s = f'_bw{bw_ms}_split_{split_variant_name}'
# print(decoding_s)

# Per-cell summary statistics that are copied from the input data into the results
stats_keys = ('pooled_rate', 'pooled_cv', 'session_rate', 'session_cv')

cells_results_d = {}
num_cells = 0  # stays 0 if there are no cells; otherwise counts the cells processed so far
for num_cells, (cell_id_s, data_d) in enumerate(cells_data_d.items(), start=1):
    print('Cell:', cell_id_s)
    #if num_cells > 11: break # for testing (limits the run to the first 11 cells)

    # Alternative starting point from raw data (not used, the stored data is already aligned):
    # spikes = data_d['spikes'] # spike time steps
    # reward_start = data_d['reward_start'] # reward onset time steps
    # align_start = reward_start - decoding_offset # start timestep for alignment
    # spikes_aligned_list = align_spikes(spikes, align_start)
    # # Mean firing rate and CV of the aligned spiking activity
    # pooled_rate, pooled_cv = get_stats_aligned(spikes_aligned_list, plot_isi=False)

    # Get already aligned spikes:
    spikes_aligned_l = data_d['spikes_aligned_l']

    # Perform decoding analysis using aligned spikes
    # (ts_rate and rates_aligned are returned as well but not used further in this cell)
    decoding_results_d, ts_rate, rates_aligned = decoding_spikes_aligned(
        spikes_aligned_l, bw_ms, split_variant, rng_split, print_output=False)

    # Store the summary statistics alongside the decoding results
    decoding_results_d.update({stat_key: data_d[stat_key] for stat_key in stats_keys})

    cells_results_d[cell_id_s] = decoding_results_d

print(f'Data from {num_cells} cells has been analyzed.', len(cells_results_d))

In [None]:
# Save results of decoding analysis
cells_results_fname = f'decoding_results__{session_status}{decoding_s}.npy'
fname = start_p / cells_results_fname

# The results dict is stored as a single object; it is read back with allow_pickle=True
np.save(fname, cells_results_d)

## Select and sort cells according to the results


In [None]:
# Load results data if saved before (file name is built from session_status and decoding_s)
cells_results_fname = f'decoding_results__{session_status}{decoding_s}.npy'
fname = start_p / cells_results_fname

# .item() unwraps the stored results dict
# NOTE(review): allow_pickle=True unpickles the file content -- only load trusted files.
cells_results_d = np.load(fname, allow_pickle=True).item()


In [None]:
# Put results for all units into arrays:
def get_results_arrs(units_results_d, min_rate=0.0):
    """Collect the per-unit decoding results into arrays, one row per included unit.

    Only units whose ``"pooled_rate"`` is at least ``min_rate`` are included,
    in the iteration order of ``units_results_d``.

    Parameters
    ----------
    units_results_d : dict
        Maps a unit key to its results dict. The entries read here are
        ``'fc_arr'`` and ``'pv_arr'`` (2-D, only column 0 is used),
        ``'win_arr'``, ``"session_rate"`` and ``"pooled_rate"``.
    min_rate : float, optional
        Minimum pooled rate for a unit to be included.

    Returns
    -------
    fcs, pvs : ndarray
        Column 0 of ``'fc_arr'`` / ``'pv_arr'`` of each included unit, stacked row-wise.
    ws : ndarray
        ``'win_arr'`` of the first included unit.
        NOTE(review): assumes all units share the same window lengths -- confirm.
    session_rates, pooled_rates : ndarray
        The rates of the included units.
    keys_l : list
        Keys of the included units, in row order.

    Raises
    ------
    IndexError
        If no unit passes the rate threshold (there is no first ``'win_arr'``).
    """
    selected = [(unit_key, unit_res) for unit_key, unit_res in units_results_d.items()
                if unit_res["pooled_rate"] >= min_rate]
    keys_l = [unit_key for unit_key, _ in selected]
    unit_results = [unit_res for _, unit_res in selected]

    fcs = np.array([unit_res['fc_arr'][:, 0] for unit_res in unit_results])
    pvs = np.array([unit_res['pv_arr'][:, 0] for unit_res in unit_results])
    wins = np.array([unit_res['win_arr'] for unit_res in unit_results])
    ws = wins[0]
    session_rates = np.array([unit_res["session_rate"] for unit_res in unit_results])
    pooled_rates = np.array([unit_res["pooled_rate"] for unit_res in unit_results])

    return fcs, pvs, ws, session_rates, pooled_rates, keys_l

# Get, for each unit, the index of the first window length whose p-value is at or below pv_level:
def get_inds_pv_level(pv_level, pvs, ws, fcs):
    """Find the first window with a p-value at or below ``pv_level`` for each unit.

    P-values are rounded to 8 decimals before the comparison.

    Parameters
    ----------
    pv_level : float
        Significance level.
    pvs : ndarray, 2-D
        P-values, one row per unit and one column per window length.
    ws : sequence
        Window lengths, indexed like the columns of ``pvs``.
    fcs : ndarray, 2-D
        Decoding accuracies with the same layout as ``pvs``.

    Returns
    -------
    inds_pv_level : ndarray of int
        Column index of the first p-value <= ``pv_level``; ``pvs.shape[1]``
        (last index + 1) for units without any such window.
    w_pv_level : ndarray of float
        Window length at that index; NaN for units without any such window.
    fc_pv_level : ndarray of float
        Accuracy at that index; the row maximum of ``fcs`` for units without
        any such window.
    """
    inds_pv_level = -np.ones(pvs.shape[0], dtype=int)
    w_pv_level = np.zeros(pvs.shape[0])
    fc_pv_level = np.zeros(fcs.shape[0])
    for i, (pv, fc) in enumerate(zip(np.round(pvs, 8), fcs)):
        significant = pv <= pv_level
        # Decide via any(): argmax alone returns 0 both for "first window is
        # significant" and for "no window is significant", and a NaN p-value
        # in the first window must not count as a crossing.
        if significant.any():
            ind = int(np.argmax(significant))
            inds_pv_level[i] = ind
            w_pv_level[i] = ws[ind]
            fc_pv_level[i] = fc[ind]
        else:
            inds_pv_level[i] = pvs.shape[1]  # last index + 1
            w_pv_level[i] = np.nan
            fc_pv_level[i] = np.max(fc)
    return inds_pv_level, w_pv_level, fc_pv_level

In [None]:
# Minimum pooled rate a unit needs to be included
min_rate = 1.5 # in Hz

# Number of units before applying the rate threshold
keys_l_orig = list(cells_results_d)
print(len(keys_l_orig))

# One row per included unit; keys_l holds the unit keys in row order
(fcs, pvs, ws,
 session_rates, pooled_rates, keys_l) = get_results_arrs(cells_results_d, min_rate=min_rate)
num_units = fcs.shape[0]

print('Number of units included:', num_units)
print('Number of window lengths:', ws.shape[0])

### Limit considered window range

In [None]:
# Window lengths to keep
w_min = 0.25
w_max = 2.75
# Index of the first entry >= w_min and of the first entry >= w_max
# (assumes ws is sorted in ascending order -- TODO confirm)
w_range = np.searchsorted(ws, [w_min, w_max])

# The "+ 1" keeps the column found for w_max, so w_max itself is included when present in ws
w_lo_ind, w_hi_ind = w_range
win_sel = slice(w_lo_ind, w_hi_ind + 1)

fcs_lim = fcs[:, win_sel]
pvs_lim = pvs[:, win_sel]
ws_lim = ws[win_sel]

print(ws_lim)
#print(fcs_lim.shape)

# From here on, work with the limited window range only
fcs, pvs, ws = fcs_lim, pvs_lim, ws_lim

### Get maxima, minima, and level crossings for sorting

In [None]:
# Maximum decoding accuracy (fraction correct) for each unit ...
inds_fc_max = np.argmax(fcs, axis=1)
fc_max = np.max(fcs, axis=1)
# ... and the window length at which each maximum occurs
win_fc_max = np.asarray(ws)[inds_fc_max]

# Significance level for shuffle test
pv_level = 0.01
# First window length with p-value at or below pv_level
inds_pv_level, w_pv_level, fc_pv_level = get_inds_pv_level(pv_level, pvs, ws, fcs)

# np.lexsort uses the LAST key as the primary one: sort by decreasing index of the
# first significant window, and use the mean (pooled) rate as secondary sorting criterion.
sort_keys = (pooled_rates, -inds_pv_level)
inds_lexsort = np.lexsort(sort_keys) # in "increasing" order
#inds_lexsort = np.lexsort((fc_max, -inds_pv_level)) # in "increasing" order

# list(zip(inds_pv_level[inds_lexsort[::-1]], pooled_rates[inds_lexsort[::-1]]))

### Plot sorted decoding results

In [None]:
# Heat map of the decoding accuracy: one row per unit (in sorted order),
# one column per decoding window length.
# cmap = plt.cm.plasma
cmap = plt.cm.cool  # only used for the color of the vertical marker line

units = np.arange(1, num_units + 1)
fcs_sorted = fcs[inds_lexsort]
w_pv_level_sorted = w_pv_level[inds_lexsort]
keys_sorted = [keys_l[ind] for ind in inds_lexsort]

fig_width = 7.0 # inch
fig_height = 9.0 # inch
# fig_height = 5.0 # inch, for Non-expert
if session_status.startswith(non_expert_s):
    fig_height = 4.5

fig, ax = plt.subplots(1, 1, figsize=(fig_width, fig_height), dpi=150,
                       layout="constrained")
pc = ax.pcolormesh(ws, units, fcs_sorted,
                   shading='nearest', cmap='viridis',
                   vmin=0.35, vmax=0.95) #vmin = 0.25, vmax=0.95) #

# Vertical line at the decoding offset time
ax.axvline(decoding_offset_time, lw=1, color=cmap(0.95))
#ax.plot(win_fc_max[inds_lexsort], units, '+', color='w', alpha=0.5)
# White dot: first window length with a significant p-value (no dot where the value is NaN)
ax.plot(w_pv_level_sorted, units, 'w.', markersize=4)

ax.set_yticks(units, keys_sorted, fontsize='x-small')
ax.set_xlim([0.15, 2.85])
ax.set_ylim([0.25, num_units + 0.75])
ax.set_xlabel('decoding window (s)')
fig.colorbar(pc, ax=ax, location='top', fraction=0.1, shrink=0.66, pad=0.01, label='accuracy')
# pyplot.show() takes no figure argument (its only parameter, `block`, is keyword-only);
# passing `fig` positionally raises a TypeError on regular Matplotlib backends.
plt.show()

In [None]:
# Save figure as PDF and as PNG:
plots_p = start_p

# File name encodes data set, decoding offset, decoding settings and significance level (in %)
plot_name_s = (f'{session_status}__decoding_results_offset{decoding_offset_time:.1f}'
               f'{decoding_s}_pv{100*pv_level:02.0f}')
for fig_format in ('pdf', 'png'):
    fname = Path(plots_p) / f'{plot_name_s}.{fig_format}'
    #print(fname)
    fig.savefig(fname, dpi=300, format=fig_format, facecolor='w')
plt.close(fig)

## Save sorted cell keys

In [None]:
# File name encodes data set, decoding offset, decoding settings and significance level (in %)
keys_sorted_name_s = (f'{session_status}__keys_sorted_offset{decoding_offset_time:.1f}'
                      f'{decoding_s}_pv{100*pv_level:02.0f}')
fname = Path(start_p) / f'{keys_sorted_name_s}.txt'

# Write the keys as a JSON list, in reversed sorting order
with open(fname, 'w') as f:
    json.dump(keys_sorted[::-1], f)

# To read the file back into a Python list object:
# with open('test.txt', 'r') as f:
#     a = json.loads(f.read())