# Single-cell analysis
This notebook contains:
+ quality metrics
+ stability
+ significance of up- or down-modulation by stimulation -> results are written to `UnitsDf`
+ single-cell plots (rasters / firing rates)

In [42]:
# imports
import os
import pickle
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pynapple as nap
import seaborn as sns
from scipy.stats import ttest_1samp
from tqdm import tqdm

import pyks_tools as pkt
from stim_tools import get_stim_dur_offset

In [5]:
# Path definitions for an earlier session (2023-02-17, stim_run_2).
# NOTE: superseded -- the next cell reassigns every one of these names.
exp_folder = Path('/media/georg/htcondor/shared-paton/georg/DAtime/data/2023-02-17_JJP-05313-dh_B_1-2-3/')
run_folder = exp_folder / 'stim_run_2_g0'
# binary file names are derived from the run folder name, so the run name is written only once
imec_bin_path = run_folder / f'{run_folder.name}_t0.imec0.ap.bin'
ni_bin_path = run_folder / f'{run_folder.name}_t0.nidq.bin'
ks_folder = run_folder / 'pyks2_output'
results_folder = ks_folder / 'results'

In [43]:
# Path definitions for the session analysed in this notebook (2024-06-06, stim_run_1).
exp_folder = Path('/media/georg/htcondor/shared-paton/georg/DAtime/data/batch_24a/2024-06-06_JJP-08672_dh_1-6-1')
run_folder = exp_folder / 'stim_run_1_g0'
# binary file names are derived from the run folder name, so the run name is written only once
imec_bin_path = run_folder / f'{run_folder.name}_t0.imec0.ap.bin'
ni_bin_path = run_folder / f'{run_folder.name}_t0.nidq.bin'
ks_folder = run_folder / 'ibl_sorter_results'
results_folder = ks_folder / 'results'
# results_folder receives UnitsDf.csv and plots further down; make sure it exists
os.makedirs(results_folder, exist_ok=True)

In [91]:
# --- stimulus metadata ---
# stim_id is cast to str so it can be compared against '1', '2', '3' further down
StimsDf = pd.read_csv(run_folder / 'StimsDf.csv').astype({'stim_id': str})

# stim_classes is indexed by stim_id string below (e.g. stim_classes['1']).
# pickle can execute code on load -- only use with trusted, locally produced files.
stim_classes = pickle.loads((run_folder / 'stim_classes.pkl').read_bytes())

# --- ephys data ---
UnitsDf = pd.read_csv(results_folder / 'UnitsDf.csv')
units = nap.load_file(str(results_folder / 'units.npz'))

## VPL

In [48]:
# VPL stim: select its trials and build the pre/post spike-counting windows.
# The id is defined once because it selects the trials AND looks up the stim class;
# the two must always agree.
VPL_STIM_ID = '1'

# NOTE(review): flagged by the original author as fragile ("will crash in the future");
# it relies on StimsDf.stim_id holding strings (cast when StimsDf is loaded).
stim_times = StimsDf.loc[StimsDf['stim_id'] == VPL_STIM_ID, 't'].to_numpy()

n_stims = stim_times.shape[0]
n_units = len(units)

stim_dur, t_offset = get_stim_dur_offset(stim_classes[VPL_STIM_ID])

# building windows for counts
w = 3  # symmetric window in seconds after and before stim
grace = 0.1  # extra separating pad around the stimulus

t0 = stim_times + t_offset  # the zero point for VPL

# pre: a window of size w ending `grace` before t0;
# post: a window of size w starting `grace` after the stimulus has ended (t0 + stim_dur)
intervals_pre = nap.IntervalSet(start=t0 - grace - w, end=t0 - grace)
intervals_post = nap.IntervalSet(start=t0 + stim_dur + grace, end=t0 + stim_dur + grace + w)

In [49]:
# Count the spikes of every unit inside each trial's pre and post window.
# Rows index trials, columns index units (unit keys are used as 0..len(units)-1).
n_stims = stim_times.shape[0]

N_pre = np.zeros((n_stims, n_units))
N_post = np.zeros((n_stims, n_units))

for windows, counts in ((intervals_pre, N_pre), (intervals_post, N_post)):
    for trial in tqdm(range(n_stims)):
        restricted = units.restrict(windows[trial])
        counts[trial, :] = [restricted[k].shape[0] for k in range(len(units))]


  0%|          | 0/198 [00:00<?, ?it/s]

100%|██████████| 198/198 [00:09<00:00, 21.44it/s]
100%|██████████| 198/198 [00:08<00:00, 22.92it/s]


In [50]:
# Per-unit modulation test on the trial-wise count difference (post - pre).
N_diff = N_post - N_pre

from scipy.stats import ttest_1samp

# two-sided one-sample t-test against a zero mean, one test per unit (column of N_diff)
p_values = np.array([ttest_1samp(unit_diff, 0).pvalue for unit_diff in N_diff.T])
UnitsDf['sigmod_vpl'] = p_values < 0.05

# direction of the effect, taken from the difference summed over trials
UnitsDf['upmod_vpl'] = N_diff.sum(axis=0) > 0
UnitsDf['downmod_vpl'] = N_diff.sum(axis=0) < 0

# one-sided: are there more spikes post than pre?
p_values = np.array([ttest_1samp(unit_diff, 0, alternative='greater').pvalue for unit_diff in N_diff.T])
UnitsDf['sigupmod_vpl'] = p_values < 0.05

## Same analysis for DA stimulation (windows aligned to the trigger itself, no stimulus offset)

In [51]:
# DA stim: select its trials and build the pre/post spike-counting windows.
DA_STIM_ID = '3'

stim_times = StimsDf.loc[StimsDf['stim_id'] == DA_STIM_ID, 't'].to_numpy()

n_stims = stim_times.shape[0]
n_units = len(units)

# building windows for counts
w = 3  # symmetric window in seconds after and before stim
grace = 0.1  # extra separating pad around the trigger

# Unlike VPL, the zero point is the trigger time itself:
# no get_stim_dur_offset() offset and no stimulus duration are applied.
t0 = stim_times

intervals_pre = nap.IntervalSet(start=t0 - grace - w, end=t0 - grace)  # window of size w shifted by grace
# NOTE(review): the post window starts `grace` after the trigger, so it may overlap the
# DA stimulus itself if that lasts longer than `grace` -- confirm this is intended.
intervals_post = nap.IntervalSet(start=t0 + grace, end=t0 + grace + w)


In [52]:
# counting code is identical
n_stims = stim_times.shape[0]

N_pre = np.zeros((n_stims, n_units))
N_post = np.zeros((n_stims, n_units))

# pre
for i in tqdm(range(n_stims)):
    units_r = units.restrict(intervals_pre[i])
    N_pre[i,:] = [units_r[j].shape[0] for j in range(len(units))]

# post
for i in tqdm(range(n_stims)):
    units_r = units.restrict(intervals_post[i])
    N_post[i,:] = [units_r[j].shape[0] for j in range(len(units))]


  0%|          | 0/199 [00:00<?, ?it/s]

100%|██████████| 199/199 [00:08<00:00, 23.09it/s]
100%|██████████| 199/199 [00:08<00:00, 22.86it/s]


In [53]:
# DA modulation stats on the trial-wise count difference (post - pre).
# ttest_1samp comes from the imports cell at the top of the notebook.
N_diff = N_post - N_pre

# two-sided one-sample t-test against a zero mean, one test per unit
p_values = np.array([ttest_1samp(N_diff[:, j], 0).pvalue for j in range(n_units)])

UnitsDf['sigmod_da'] = p_values < 0.05
UnitsDf['upmod_da'] = N_diff.sum(axis=0) > 0
UnitsDf['downmod_da'] = N_diff.sum(axis=0) < 0

# TODO: consider a one-sided test as for VPL. If enabled, it must write to a DA
# column -- writing to 'sigupmod_vpl' would overwrite the VPL result.
# p_values = np.array([ttest_1samp(N_diff[:, j], 0, alternative='greater').pvalue for j in range(n_units)])
# UnitsDf['sigupmod_da'] = p_values < 0.05

In [54]:
# store result: overwrite UnitsDf.csv with the added modulation columns
# (index=False -- the row index is not a data column and would otherwise appear as one on reload)
UnitsDf.to_csv(results_folder / 'UnitsDf.csv', index=False)

## Plotting

In [55]:
import matplotlib.pyplot as plt
%matplotlib qt

# order cells by mean or median
order = np.argsort(np.average(N_diff,axis=0))

n_units = UnitsDf.shape[0]

In [56]:
# Summary: per-unit change in firing rate (post - pre), mean with 5-95 percentile range.
# NOTE: N_diff holds whichever stimulus was analysed last (DA, when run top to bottom).
rate_diff = N_diff[:, order] / w  # spike counts over a w-second window -> spikes/s

# RGBA colors: red = significant VPL up-modulation, blue = significant DA modulation
# (both -> magenta, neither -> black)
colors = np.zeros((UnitsDf.shape[0], 4))
colors[:, 3] = 1.0
colors[UnitsDf['sigupmod_vpl'].to_numpy(dtype=bool), 0] = 1.0
colors[UnitsDf['sigmod_da'].to_numpy(dtype=bool), 2] = 1.0

avgs = np.average(rate_diff, axis=0)
pct = np.percentile(rate_diff, (5, 95), axis=0)

fig, axes = plt.subplots()
x = np.arange(rate_diff.shape[1])
axes.vlines(x, pct[0], pct[1], colors=colors[order], linewidth=1.5)
axes.scatter(x, avgs, c=colors[order], s=4)
axes.axhline(0, linestyle=':', lw=1, color='k', alpha=0.8)
axes.set_xlabel('units (sorted by mean change)')
axes.set_ylabel('\u0394 spikes/s')
sns.despine(fig)


In [57]:
# trial x unit heatmap of post - pre spike counts, units sorted as in the summary plot
clim = 30  # symmetric color limit, in spikes per counting window
fig, ax = plt.subplots()
im = ax.matshow(N_diff[:, order], vmin=-clim, vmax=clim, cmap='PiYG', aspect='auto')
fig.colorbar(im, ax=ax, label='\u0394 spike count (post - pre)')
ax.set_xlabel('units (sorted)')
ax.set_ylabel('trial #');

<matplotlib.image.AxesImage at 0x7f0ae8142590>

## Single-unit raster plots

In [62]:
# per-unit rate traces saved alongside units.npz
rates_file = str(results_folder / 'unit_rates.npz')
Rates = nap.load_file(rates_file)
del rates_file

In [93]:
def reslice_timestamps(T, slice_times, pre, post):
    """Cut an array of event times into windows around each trigger.

    Parameters
    ----------
    T : numpy array of event (e.g. spike) times
    slice_times : iterable of trigger times
    pre, post : window edges relative to each trigger (pre is typically negative);
        both edges are exclusive

    Returns
    -------
    list with one array per trigger, holding the events that fall strictly inside
    (t + pre, t + post), expressed relative to that trigger t
    """
    return [T[(T > t + pre) & (T < t + post)] - t for t in slice_times]

pre, post = -2, 5  # window around each trigger, in seconds

# Every trigger, in StimsDf row order. Given its own name so the per-condition
# `stim_times` array used by the counting cells is not clobbered: re-running those
# cells after this one would otherwise silently count over all trials.
all_stim_times = nap.Ts(t=StimsDf['t'].values)
stim_intervals = nap.IntervalSet(start=all_stim_times.times() + pre, end=all_stim_times.times() + post)

t_rel = Rates.restrict(stim_intervals[0]).times() - all_stim_times[0].times()

# warning, duplication in memory
rate_slices = [Rates.restrict(interval).as_array() for interval in stim_intervals]
# windows can differ by a sample at their edges; crop to the shortest so np.stack cannot fail
n_samples = min(rate_slice.shape[0] for rate_slice in rate_slices)
t_rel = t_rel[:n_samples]
# stacked with trials on axis 2; indexed below as [time, unit, trial]
# (time x unit layout of as_array() presumed -- confirm)
Rates_resliced = np.stack([rate_slice[:n_samples] for rate_slice in rate_slices], axis=2)
del rate_slices

unit_id = 0

colors = {'1': '#E54444', '2': '#5F35E6', '3': '#D54CE6'}
fig, axes = plt.subplots(nrows=4, sharex=True, figsize=[5, 5])
for j, label in enumerate(['1', '2', '3']):
    # raster for this condition: one row per trial
    spikes_sliced = reslice_timestamps(units[unit_id].times(), StimsDf.loc[StimsDf.stim_id == label].t, pre, post)
    for i, t in enumerate(spikes_sliced):
        axes[j + 1].plot(t, np.full(t.shape[0], i), '.', color='k', markersize=1.5, alpha=0.35)

    # trial-averaged rate; Rates_resliced trials follow StimsDf row order
    trial_ix = np.where(StimsDf.stim_id == label)[0]
    rates_avg = np.average(Rates_resliced[:, unit_id, trial_ix], axis=1)
    axes[0].plot(t_rel, rates_avg, color=colors[label])
axes[0].set_xlim(-1, 5)

for ax in [axes[2], axes[3]]:
    ax.axvspan(0.0, 2.0, linewidth=0, color='darkcyan', alpha=0.15)

for ax in [axes[1], axes[2]]:
    ax.axvspan(0.25, 0.5, linewidth=0, color='firebrick', alpha=0.15)

for ax in axes[1:]:
    ax.set_ylabel('trial #')

axes[0].set_ylabel('rate (z)')
axes[-1].set_xlabel('time (s)')

fig.tight_layout()
fig.suptitle('unit_id:%i' % unit_id)
fig.subplots_adjust(top=0.925)
sns.despine(fig)


In [87]:
# quick look at the quality metrics and modulation flags of one unit (positional row 6)
UnitsDf.iloc[6, :]

cluster_id                    6
amp_max               11.871768
amp_min                5.844698
amp_median             7.182305
amp_std_dB             1.143595
contamination               0.0
contamination_alt           0.0
drift                 49.213257
missed_spikes_est      0.316301
noise_cutoff          10.604946
presence_ratio         0.455508
presence_ratio_std    19.574046
slidingRP_viol              1.0
spike_count              3717.0
firing_rate            0.789403
label                  0.666667
KSLabel                    good
ContamPct                   0.0
Amplitude                  57.5
good                       True
n_spikes                   3717
frate                  0.789393
sigmod_vpl                False
upmod_vpl                  True
downmod_vpl               False
sigupmod_vpl              False
sigmod_da                 False
upmod_da                  False
downmod_da                 True
Name: 6, dtype: object

In [None]:

# %% raster plots for every unit, saved as <results_folder>/plots/rasters/unit_<id>.png
# Reuses reslice_timestamps, Rates_resliced and t_rel from the single-unit cell above.
colors = {'1': 'firebrick', '2': 'purple', '3': 'darkcyan'}

raster_folder = results_folder / 'plots' / 'rasters'
raster_folder.mkdir(parents=True, exist_ok=True)  # savefig does not create directories

# NOTE(review): assumes unit keys run 0..len(units)-1 and match the unit axis of
# Rates_resliced, as the counting cells and the single-unit cell also assume
unit_ids = range(len(units))

for unit_id in tqdm(unit_ids):
    fig, axes = plt.subplots(nrows=4, sharex=True, figsize=[5, 5])
    for j, label in enumerate(['1', '2', '3']):
        # raster for this condition: one row per trial
        spikes_sliced = reslice_timestamps(units[unit_id].times(), StimsDf.loc[StimsDf.stim_id == label].t, -1, 5)
        for i, t in enumerate(spikes_sliced):
            axes[j + 1].plot(t, np.full(t.shape[0], i), '.', color='k', markersize=1, alpha=0.3)

        # trial-averaged rate; Rates_resliced trials follow StimsDf row order
        trial_ix = np.where(StimsDf.stim_id == label)[0]
        rates_avg = np.average(Rates_resliced[:, unit_id, trial_ix], axis=1)
        axes[0].plot(t_rel, rates_avg, color=colors[label])
    axes[0].set_xlim(-1, 5)

    for ax in [axes[2], axes[3]]:
        ax.axvspan(0.0, 2.0, linewidth=0, color='darkcyan', alpha=0.15)

    for ax in [axes[1], axes[2]]:
        ax.axvspan(0.25, 0.5, linewidth=0, color='firebrick', alpha=0.15)

    for ax in axes[1:]:
        ax.set_ylabel('trial #')

    axes[0].set_ylabel('rate (z)')
    axes[-1].set_xlabel('time (s)')

    fig.tight_layout()
    fig.suptitle('unit_id:%i' % unit_id)
    fig.subplots_adjust(top=0.925)
    sns.despine(fig)

    fig.savefig(raster_folder / ('unit_%i.png' % unit_id))
    plt.close(fig)  # free each figure; hundreds of open figures exhaust memory