In [1]:
import matplotlib
%matplotlib tk
%autosave 180
%load_ext autoreload
%autoreload 2

import nest_asyncio
%config Completer.use_jedi = False

#
import matplotlib.pyplot as plt
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import scipy
import numpy as np
import pandas as pd
import os
import os
os.chdir('..')

from calcium import calcium
from wheel import wheel
from visualize import visualize
from tqdm import trange


from scipy.io import loadmat

import umap

from sklearn.decomposition import PCA
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# 
np.set_printoptions(suppress=True)


Autosaving every 180 seconds


In [26]:
# Spike Duration
def compute_spike_distribution(data,
                               max_dur=5,
                               width_dur=0.1,
                               max_isi=120,
                               width_isi=1,
                               max_amp = 50,
                               width_amp = .1):

    sample_rate = 30.
    #width = .1
    #max_dur = 10

    durs = []
    n_spikes = []
    isi = []
    amplitudes = []
    median_amplitudes = []
    for k in trange(data.shape[0]):

        #
        temp = data[k]

        # FIND BEGINNIGN AND ENDS OF FLUORescence above some threshold
        from scipy.signal import chirp, find_peaks, peak_widths

        #
        peaks, _ = find_peaks(temp)  # middle of the pluse/peak

        #
        widths, heights, starts, ends = peak_widths(temp, peaks)
        xys = np.int32(np.vstack((starts, ends)).T)
        
        #
        median_amplitudes.append(np.nanmedian(heights))
        y = np.histogram(heights, bins = np.arange(0,max_amp,width_amp))
        amplitudes.append(y[0]/np.nanmax(y[0]))
       
        # duration histgorams:
        durs_local = (xys[:,1]-xys[:,0])/sample_rate
        y = np.histogram(durs_local, bins = np.arange(0,max_dur,width_dur))
        durs.append(y[0]/np.nanmax(y[0]))
        
        # isi histograms
        isi_local = (xys[1:,0]-xys[:-1,1])/sample_rate
        #print (isi_local)
        y = np.histogram(isi_local, bins = np.arange(0,max_isi,width_isi))
        isi.append(y[0]/np.nanmax(y[0]))
       
        #
        n_spikes.append(durs_local.shape[0])
       
    
    # 
    durs = np.vstack(durs)
    #durs[]
    isi = np.vstack(isi)
    n_spikes = np.hstack(n_spikes)
    heights = np.vstack(amplitudes)
    median_amplitudes = np.hstack(median_amplitudes)
   
    return durs, n_spikes, isi, heights, median_amplitudes

#
def plot_metrics(savefig, d):
    max_dur = 60
    width_dur=0.033

    max_isi=30
    width_isi=0.033



    if savefig:
        fontsize=5
    else:
        fontsize=12


    #
    fig=plt.figure(figsize=(10,5))
    nrows = 3
    ncols = 5
    grid = GridSpec(nrows, ncols,
                    left=0.1, bottom=0.15, 
                    right=0.94, top=0.94, 
                    wspace=0.3, hspace=0.3)


    #
    for typ in range(3):

        # 
        if typ==0:
            data = d['events_threshold']
            spikes_thresh = d['oasis_thresh_prefilter']
            max_amp=20000
            width_amp=1
            xmax_amp=200
            xmax_dur=5
            title = 'Fluorescence thresholded\n (~Steffen method)'

        elif typ==1:
            data = d['events_upphase_scaled']
            spikes_thresh = d['oasis_thresh_prefilter']
            max_amp=20000
            width_amp=1
            xmax_amp=200
            xmax_dur=5
            title = 'Smooth oasis scaled by \n# spikes in window (novel method)'

        elif typ==2: 
            data = d['spikes']
            spikes_thresh = d['oasis_thresh_prefilter']
            max_amp=50
            width_amp=.1
            xmax_amp=20
            xmax_dur=0.3
            title = 'Oasis spikes (thresholded to ) ' +str(spikes_thresh)


            
        gaussian_width = .1
        duration = data.shape[1]/30.

        # 
        durs, n_spikes, isi, heights, median_amplitudes = compute_spike_distribution(data,
                                                                              max_dur,
                                                                              width_dur,
                                                                              max_isi,
                                                                              width_isi,
                                                                              max_amp,
                                                                              width_amp)

        ##############################################
        ############# plot spike rates ###############
        ##############################################
        ax1 = fig.add_subplot(grid[typ, 0])
        plot_n_spikes_distributions(ax1, 
                                    n_spikes, 
                                    duration,
                                    fontsize
                                   )
        ax1.set_ylabel(title + "\n# of events / second", fontsize=fontsize)

        ##############################################
        ###### plot median spike heigh vs. rates #####
        ##############################################
        ax14 = fig.add_subplot(grid[typ, 1])

        plot_frate_vs_peak(ax14, 
                            median_amplitudes, 
                            n_spikes,
                            heights,
                            duration,
                            fontsize
                           )

        ##############################################
        ######## plot amplitude distributions ########
        ##############################################
        ax12 = fig.add_subplot(grid[typ, 2])
        if False:
            print ("heights: ", heights.shape)
            #heights = gaussian_filter1d(heights, gaussian_width)
            idx = np.argsort(np.argmax(heights,axis=1))
            ylabel = "Neuron ID]\n(sorted by peak)"
            xlabel = "Event amplitudes"
            plot_imshow(ax12, heights, idx, 
                        max_amp, 
                        width_amp, 
                        xmax_amp,
                        xlabel, ylabel,
                        fontsize)
        else:
            idx = np.argsort(n_spikes)
            ylabel = "Neuron ID\n(sorted by # spikes)"
            xlabel = "Event amplitudes"
            #xmax=500
            plot_imshow(ax12, heights, idx, 
                        max_amp, 
                        width_amp, 
                        xmax_amp,
                        xlabel, ylabel,
                       fontsize)
            plt.legend()

        ##############################################
        ######## plot durations distributions ########
        ##############################################
        ax2 = fig.add_subplot(grid[typ, 3])
        if False:
            durs = gaussian_filter1d(durs, gaussian_width)
            idx = np.argsort(np.argmax(durs,axis=1))
            ylabel = "Neuron ID\n(sorted by peak)"
            xlabel = "Duration of event (sec)"
            plot_imshow(ax2, durs, idx, 
                        max_dur,
                        width_dur,
                        xmax_dur,
                        xlabel, ylabel,
                       fontsize)
        else:
            idx = np.argsort(n_spikes)
            ylabel="Neuron ID\n(sorted by # spikes in unit)"
            xlabel="Duration of event (sec)"
            xmax=5
            plot_imshow(ax2,durs,idx,
                        max_dur,
                        width_dur,
                        xmax_dur,
                        xlabel, ylabel,
                       fontsize)

        # ##############################################
        # ########### plot ISI distributions ###########
        # ##############################################
        ax4 = fig.add_subplot(grid[typ, 4])
        if False:
            isi = gaussian_filter1d(isi, gaussian_width)
            idx = np.argsort(np.argmax(isi,axis=1))
            ylabel = "Neuron ID\n(sorted by peak)"
            xlabel = "Inter-event-interval (sec)"
            xmax=10
            plot_imshow(ax4,isi,idx,
                        max_isi,
                        width_isi,
                        xmax,
                        xlabel, ylabel,
                       fontsize)
        else:
            idx = np.argsort(np.argmax(durs,axis=1))
            ylabel="Neuron ID\n(sorted by # spikes)"
            xlabel="Inter-event-interval (sec)"
            xmax=10
            plot_imshow(ax4,isi,idx,
                        max_isi,
                        width_isi, 
                        xmax,
                        xlabel, ylabel,
                       fontsize)




    if savefig:
        plt.savefig('/home/cat/fig.svg',dpi=600)
        plt.close()
    else:
        plt.show()
    
#
def plot_imshow(ax,
                data,
                idx,
                max_len,
                time_bin,
                xmax,
                xlabel,
                ylabel,
                fontsize=10,
                ):

    # 
    data2 = data[idx].copy()

    ax.imshow(data[idx],
          aspect='auto',
          interpolation=None,
          extent = [0 , max_len, 0 , data.shape[0]],
          cmap='Greys')
    
    # 
    mua = np.nansum(data, axis=0)
    mua = mua/np.nanmax(mua)*data.shape[0]/3.  # scale mua to look better in plot
    t=np.arange(0,mua.shape[0], 1)*time_bin
    ax.plot(t,mua,c='red',
           linewidth=3,
           alpha=.7)

    # 
    mean = np.average(t, weights=mua)
    vals = []
    for k in range(mua.shape[0]):
        if mua[k]>0:
            vals.append(np.zeros(int(mua[k]*100))+t[k])
        
    vals = np.hstack(vals)
    median=np.nanmedian(vals)
        
    plt.plot([median,median],[0,data.shape[0]],
            '--',linewidth=3, 
             c='green', label='median')
    
    plt.plot([mean,mean],[0,data.shape[0]],
            '--',linewidth=3, 
             c='blue', label='mean')
    
    # 
    plt.xlim(0,xmax)
    plt.xlabel(xlabel, fontsize=fontsize)
    plt.ylabel(ylabel, fontsize=fontsize)
#     ax.tick_params(axis='both', which='major', labelsize=fontsize)


# 
def plot_frate_vs_peak(ax, medians,
                       n_spikes,
                       heights,
                       duration,
                       fontsize=10):
    
    
    #
    ax.scatter(np.float32(n_spikes)/duration, 
               medians, c='grey',
               edgecolor='black',
               alpha=.7
              )
    
    plt.xlim(0.001,10)
    plt.ylim(0.1,np.nanmax(medians)*1.1)

    plt.semilogx()
    
    ax.tick_params(axis='both', which='major', labelsize=fontsize)
    plt.xlabel("Event rate", fontsize=fontsize)
    plt.ylabel("Median amplitude of events in unit", fontsize=fontsize)

    
#  
def plot_n_spikes_distributions(ax1, 
                                n_spikes, 
                                duration,
                                fontsize=10,
                                min_rate=1E-3,
                                max_rate=10):
    
    #
    #fontsize=4
    idx = np.argsort(n_spikes)
    ax1.scatter(np.arange(idx.shape[0]),
                n_spikes[idx]/duration,
               c='black',
               s=10,
                edgecolor='black',
               alpha=.1)
    
    
    plt.xlabel("Neuron ID (sorted)", fontsize=fontsize)
    plt.xlim(0,idx.shape[0])
    #min_rate = 1E-3
    #max_rate = 2E-1
    
    # 
    plt.ylim(min_rate,max_rate)
    plt.semilogy()
    
    # 
    ax1.tick_params(axis='both', which='major', labelsize=fontsize)
    ax1t=ax1.twinx() 
    y = np.histogram(n_spikes[idx]/duration, bins=np.arange(0,1,0.001))
    
    # 
    yy = y[0]/np.nanmax(y[0])*n_spikes.shape[0]
    
    plt.plot(np.nanmax(y[1])-yy+n_spikes.shape[0],
             y[1][:-1],
             c='red',
             linewidth=2)
    
    #yy = gaussian_filter1d(yy, 1)
    
    idx = np.argmax(y[0][2:])
    
    # 
    plt.plot([0,n_spikes.shape[0]],
             [y[1][:-1][idx+2],y[1][:-1][idx+2]],
             #[idx,idx],
             '--', 
             linewidth=3,
             c='Grey')
    ax1t.semilogy()
    plt.yticks([])
    plt.ylim(min_rate,max_rate)  

In [27]:
#################################################
#################################################
#################################################
from scipy.ndimage import gaussian_filter1d
from matplotlib.gridspec import  GridSpec

#
fname = '/media/cat/4TB/donato/steffen/DON-004366/20210228/suite2p/plane0/binarized_traces.npz'
d = np.load(fname,
              allow_pickle=True)

savefig = False

# 
plot_metrics(savefig, d)

  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  amplitudes.append(y[0]/np.nanmax(y[0]))
  durs.append(y[0]/np.nanmax(y[0]))
  isi.append(y[0]/np.nanmax(y[0]))
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 1028/1028 [00:00<00:00, 1353.03it/s]
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  amplitudes.append(y[0]/np.nanmax(y[0]))
  durs.append(y[0]/np.nanmax(y[0]))
  isi.append(y[0]/np.nanmax(y[0]))
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 1028/1028 [00:00<00:00, 1285.58it/s]
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  amplitudes.append(y[0]/np.nanmax(y[0]))
  durs.append(y[0]/np.nanmax(y[0]))
  isi.append(y[0]/np.nanmax(y[0]))
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 1028/1028 [00:01<00:00, 806.92it/s]
