In [1]:
import pickle as pk
import matplotlib.pyplot as plt
import numpy as np
import matplotlib
from datetime import datetime
from datetime import timedelta
import thunderfish.pulsetracker as pt
matplotlib.rcParams['font.size'] = 8.0
from scipy.signal import fftconvolve

from pulse_tracker import pd, cluster_object
#%matplotlib notebook

In [2]:
import io, os
import pickle

class RenameUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        renamed_module = module
        if module == "pulse_tracker_liz_multichannel_ml_new":
            renamed_module = "pulse_tracker_liz_2D"

        return super(RenameUnpickler, self).find_class(renamed_module, name)


def renamed_load(file_obj):
    return RenameUnpickler(file_obj).load()


def renamed_loads(pickled_bytes):
    file_obj = io.BytesIO(pickled_bytes)
    return renamed_load(file_obj)


def kde_init(sigma, sampling_rate):
    """
    creates the gaussian function which can be used to convolve with the PSTH to create a firing rate graph
    """
    t = np.arange(-4*sigma, 4*sigma, 1 / sampling_rate)
    fxn = np.exp(-0.5 * (t / sigma) ** 2) / np.sqrt(
        2 * np.pi) / sigma  # gaussian function 2.3 (analysis of parallel spike trains)

    return t, fxn

def gaussian_convolve(all_spike_trains, fxn, sampling_rate, avg_opt,trial_length):
    """
    Takes an array of spike trains of different sizes,
    convolves it with a gaussian, returns the average gaussian convolve spikes
    """
    all_convolve_spikes = []
    all_pos = []
    for spike_train in all_spike_trains:

        spike_train = spike_train

        # Boolean list in length of trial length, where 1 means spike happened, 0 means no spike
        trial_bool = np.zeros(trial_length)
        spike_indx = (spike_train * sampling_rate).astype(np.int)
        trial_bool[spike_indx] = 1

        # convolve gaussian with boolean spike list
        convolve_spikes = np.asarray(fftconvolve(trial_bool, fxn, mode='same'))
        all_convolve_spikes.append(convolve_spikes)

    return all_convolve_spikes

In [3]:
# only subtract if there are actually other classes present?

# 0: use two channels. use original features. no temporal features 
# 2: use two channels. use two after subtracting other channels. no temporal features
# 3: use whole spatial pattern. no temporal features
# 4: use whole spatial pattern. use temporal features
# 5: use whole spatial pattern. use temporal features. do not subtract eods that are not there (above threshold).
# 6: same as 5, but only save spatial features of presumably single eods.
# 7: same as 4, but only save spatial features of presumably single eods.

In [5]:
file_count = 0

date = '2019-10-17-19_48'
dati = datetime(2019,10,17,19,48)

path = 'data/results/' + date + '/'

print(os.listdir(path))

for file in os.listdir(path):
    if '0' in file:
        try:
            print(path+file)
            co = pk.load(open(path+file,'rb'))
        except:
            print('file %i didnt work'%i)

['0.pkl', 'pics', '3.pkl', '2.pkl', '1.pkl']
data/results/2019-10-17-19_48/0.pkl


In [6]:
clusters = co.clusters
dt = co.dt
starttime = co.starttime
endtime = co.endtime
plot_title = dati + timedelta(0,int(starttime))
x = np.arange(starttime,endtime,dt)

In [7]:
clusters

{'potential_eod': <pulse_tracker.cluster_object at 0x7f5128cd0b50>,
 '1': <pulse_tracker.cluster_object at 0x7f5128cb77d0>,
 '2': <pulse_tracker.cluster_object at 0x7f5128aa5c10>}

In [8]:
clusters['1'].debug.keys()

dict_keys(['c_tf', 'c_sf', 'c_eod'])

In [9]:
idx = []
used_clusters = []
peakwidth = 20*co.dt

#if co.potential_eod != []:
#    clusters.append(co.potential_eod)

#if co.eel != []:
#    clusters.append(co.eel)

#if co.artifacts != []:
#    clusters.append(co.artifacts)

#if co.deleted_clusters != []:
#    for dc in co.deleted_clusters:
#        clusters.append(dc)

for k,c in clusters.items():
    peakindices, peakx, peakh = pt.discardnearbyevents(c.ts,np.ones(len(c.ts)),peakwidth*2)
    
    if len(peakindices) > 0:
        used_clusters.append(c)
        if len(idx) > 0:
            nidx = np.zeros(len(c.ts))
            nidx[:len(peakindices)] = c.ts[peakindices]
            idx = np.vstack([idx,nidx])
        else:
            idx = np.zeros(len(c.ts))
            idx[:len(peakindices)] = c.ts[peakindices]

In [10]:
idx

array([[5.11696667, 5.26076667, 5.3085    , ..., 0.        , 0.        ,
        0.        ],
       [0.22363333, 0.876     , 7.4216    , ..., 0.        , 0.        ,
        0.        ],
       [5.11696667, 5.26076667, 5.3085    , ..., 0.        , 0.        ,
        0.        ]])

In [11]:
idx[idx == 0] = starttime
idx = idx - starttime
idx[idx<0] = 0

In [12]:
idx

array([[5.11696667, 5.26076667, 5.3085    , ..., 0.        , 0.        ,
        0.        ],
       [0.22363333, 0.876     , 7.4216    , ..., 0.        , 0.        ,
        0.        ],
       [5.11696667, 5.26076667, 5.3085    , ..., 0.        , 0.        ,
        0.        ]])

In [13]:
t, fxn = kde_init(0.04, 1/dt)
firing_rates = gaussian_convolve(idx,fxn,1/dt,'nonaverage',int((endtime-starttime)/dt))

In [14]:
%matplotlib tk

In [15]:
cur_x = 0
cur_y = 0

fig = plt.figure(figsize=(10,10))
fig.suptitle(plot_title)
ax0 = plt.subplot(313)
ax1 = plt.subplot(312)
ax2 = plt.subplot(351)
ax3 = plt.subplot(352)
ax4 = plt.subplot(353)
ax5 = plt.subplot(354)
ax6 = plt.subplot(355)

ax2.set_title('extracted EOD')
ax3.set_title('whole snip')
#fig, [ax1, ax2] = plt.subplots(2, 1,figsize=(10,5))

# create a horizontal plot
colors1 = ['C{}'.format(i) for i in range(idx.shape[0])]
ax1.eventplot(idx,colors=colors1)#,colors=plt.rcParams['axes.prop_cycle'].by_key()['color'][:len(used_clusters)])
ax0.plot(x-starttime,np.transpose(firing_rates))

ax1.set_xlabel('time [s]')
ax1.set_xlim([0,(endtime-starttime)])

ax0.set_xlabel('time [s]')
ax0.set_xlim([0,endtime-starttime])

def update_plot(ydata, i=None, xdata=None):
    cluster = int(np.round(ydata))

    if cluster < 0:
        cluster = 0

    elif cluster > len(used_clusters):
        cluster = len(used_clusters)

    if xdata is not None:
        tp = xdata
        i = (np.abs(used_clusters[cluster].ts-starttime-tp)).argmin()
        
    #ax1.cla()
    ax2.cla()
    ax3.cla()
    ax4.cla()
    ax5.cla()
    #ax5.axis('off')
    ax6.cla()
    #ax6.axis('off')

    ax3.set_title('whole snip')
    
    #current_maxchan = used_clusters[cluster].all_max_channels[i]
    #current_xpos = current_maxchan%8
    #current_ypos = np.floor(current_maxchan/8)
    
    #ax4.set_title('extracted EOD, maxchan=%i,(%i, %i)'%(current_maxchan,current_xpos,current_ypos))
    #ax5.set_title('temporal')
    
    print(i)
    print(len(used_clusters[cluster].debug['c_eod']))
    ax1.plot(used_clusters[cluster].ts[i]-starttime,cluster,'x',c='r')
    ax1.set_xlabel('time [s]')
    
    ax2.set_title('eod')
    ax2.plot(used_clusters[cluster].debug['c_eod'][i])
    
    ax3.set_title('temp feat')
    ax3.plot(used_clusters[cluster].debug['c_tf'][i])
    
    ax4.set_title('spatial pattern')
    ax4.imshow(used_clusters[cluster].debug['c_sf'][i].reshape(4,8))
    
    plt.tight_layout()
    plt.draw()
    
    return i, cluster

def on_key(event):
    global cur_x
    global cur_y
    if event.key == 'right':
        cur_x, cur_y = update_plot(cur_y, cur_x+1)
    elif event.key == 'left':
        cur_x, cur_y = update_plot(cur_y, cur_x-1)
    elif event.key == 'up':
        cur_x, cur_y = update_plot(cur_y+1, cur_x)
    elif event.key == 'down':
        cur_x, cur_y = update_plot(cur_y-1, cur_x)

def onclick(event):
    global cur_x
    global cur_y
    cur_x, cur_y = update_plot(event.ydata,xdata=event.xdata)
    
cid = fig.canvas.mpl_connect('button_press_event', onclick)
cid = fig.canvas.mpl_connect('key_press_event', on_key)

Traceback (most recent call last):
  File "/home/colombia/miniconda3/lib/python3.7/site-packages/matplotlib/cbook/__init__.py", line 216, in process
    func(*args, **kwargs)
  File "<ipython-input-15-6fb8a1c82bac>", line 94, in onclick
    cur_x, cur_y = update_plot(event.ydata,xdata=event.xdata)
  File "<ipython-input-15-6fb8a1c82bac>", line 40, in update_plot
    i = (np.abs(used_clusters[cluster].ts-starttime-tp)).argmin()
IndexError: list index out of range
Traceback (most recent call last):
  File "/home/colombia/miniconda3/lib/python3.7/site-packages/matplotlib/cbook/__init__.py", line 216, in process
    func(*args, **kwargs)
  File "<ipython-input-15-6fb8a1c82bac>", line 94, in onclick
    cur_x, cur_y = update_plot(event.ydata,xdata=event.xdata)
  File "<ipython-input-15-6fb8a1c82bac>", line 40, in update_plot
    i = (np.abs(used_clusters[cluster].ts-starttime-tp)).argmin()
IndexError: list index out of range
Traceback (most recent call last):
  File "/home/colombia/minicond

In [None]:
nidx = np.array(idx)
nidx = idx[:4]

In [13]:
nidx.shape

(7200,)

In [16]:
nidx[nidx>5/dt] = 0
cols=['b','r','g','k']

In [17]:
frs = []
tps = []
for i in range(4):
    frs.append(1/(np.diff(nidx[i][nidx[i]>0]*dt)))
    
    idxs = np.where((frs[i]<0.6*np.median(frs[i]))|(frs[i]>1.5*np.median(frs[i])))[0]
    idxs = np.append(idxs,np.where(frs[i]>1.5*np.median(frs[i]))[0]-1)
    idbool = np.ones(len(frs[i])).astype('bool')
    idbool[idxs] = 0
    
    tps.append((nidx[i][nidx[i]>0][:-1])[idbool]*dt)
    frs[i] = frs[i][idbool]
    #tps.append((nidx[i][nidx[i]>0][:-1])[frs[i]<np.mean(frs[i])])

In [18]:
nidx[nidx==0] = 100000000

In [19]:
cols= ['b','k','r','g']

In [20]:
%matplotlib inline
from matplotlib import gridspec

In [21]:
fig = plt.figure(figsize=(8,3*0.8))
gs = gridspec.GridSpec(2,2,hspace=0,width_ratios=[2,1])

ax = fig.add_subplot(gs[0, 0])
ax.eventplot(nidx*dt,colors=cols)
ax.set_xlim([0,5.1])
ax.axis('off')
ax.set_title('(a)')

ax = fig.add_subplot(gs[1, 0])
for i in range(4):
    ax.plot(tps[i],frs[i],c=cols[i])
ax.set_xlim([0,5.1])
ax.set_ylim([14,17])
ax.set_xlabel('time [s]')
ax.set_ylabel('EOD frequency [Hz]')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

ax = fig.add_subplot(gs[:, 1])
for i in range(4):
    ax.plot((used_clusters[i].positions[nidx[i]*dt<5,0]-7)*-1,(used_clusters[i].positions[nidx[i]*dt<5,1]-3)*-1,'o',c=cols[i],rasterized=True)
ax.axis('square')
ax.set_xticks([0,1,2,3,4,5,6,7])
ax.set_xlim([0,7])
ax.set_ylim([0,3])
#ax.axis('off')
# Create a Rectangle patch
ax.grid(True)
ax.set_title('(b)')

plt.tight_layout()
#plt.savefig('EOD_traces.svg')

In [22]:
for i,u in enumerate(used_clusters[2].all_temp_patterns):
    if len(u)<40:
        a = np.zeros(40)
        a[:len(u)] = u
        used_clusters[2].all_temp_patterns[i] = a

In [23]:
for c in np.unique(used_clusters[2].all_max_channels):
        plt.figure()
        ar = np.array(used_clusters[2].all_temp_patterns)
        #for i in range(370,400):
        #    plt.plot(ar[used_clusters[0].all_max_channels==c][i],'b')
        #    plt.title(i)
        plt.plot(ar[used_clusters[2].all_max_channels==c].T,alpha=0.2)
        #plt.figure()
        #plt.plot(ar[used_clusters[0].all_max_channels==c][-600:-400].T,alpha=0.2)
        plt.show()
#if used_clusters[0].all_max_channels[i] - used_clusters[0].all_max_channels[i+1] !=0:
#    print(i)
#    break
#plt.show()

In [24]:
from matplotlib.gridspec import GridSpec
from matplotlib.colors import LogNorm

%matplotlib inline
vmin=0.0001
vmax=4
fig = plt.figure(figsize=(8,5.5))

gs = GridSpec(3,4)
n = -18

ax = fig.add_subplot(gs[1, :])
ax.set_title('(c)')
ax.axis('off')

for i in range(4):
    ax = fig.add_subplot(gs[1, i])
    tp = used_clusters[i].all_spatial_features[n-1].reshape(4,8)*0.1
    tp[tp<=0] = np.min(tp[tp>0])
    ax.imshow(np.flip(tp,axis=1),cmap='Blues',norm=LogNorm(vmin=vmin,vmax=vmax))
    ax.axis('off')
    ax.set_title('fish #%i'%(i+1))

ax = fig.add_subplot(gs[0, :2])
ax.imshow(np.flip(used_clusters[2].all_spatial_patterns[n].reshape(4,8),axis=1)*0.1,cmap='Blues',norm=LogNorm(vmin=vmin,vmax=vmax))
ax.set_title('current spatial pattern')
ax.axis('off')

ax = fig.add_subplot(gs[1, :])
ax.set_title('(c)')
ax.axis('off')

ax = fig.add_subplot(gs[0, 2:])
ax.plot(['1','2','3','4'],used_clusters[2].reg_coefs[n],'o',c='k')
#ax.set_title('regression coeficients')
ax.set_xlabel('fish #')
ax.set_ylabel('regression coeficient')
ax.set_ylim([-0.1,1.5])
ax.set_title('(b)')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

ax = fig.add_subplot(gs[0, :2])
im=ax.imshow(np.flip(used_clusters[2].all_spatial_patterns[n].reshape(4,8),axis=1)*0.1,cmap='Blues',norm=LogNorm(vmin=vmin,vmax=vmax))
ax.set_title('(a)')
ax.axis('off')

ax = fig.add_subplot(gs[2, :2])
tp = used_clusters[2].all_spatial_features[n].reshape(4,8)
tp[tp<=0] = np.min(tp[tp>0])
ax.imshow(np.flip(tp*0.1,axis=1),cmap='Blues',norm=LogNorm(vmin=vmin,vmax=vmax))
ax.set_title('(d)')
ax.axis('off')

ax = fig.add_subplot(gs[2, 2:])
tp = used_clusters[2].all_spatial_patterns[n].reshape(4,8)-used_clusters[2].all_spatial_features[n].reshape(4,8)
tp[tp<=0] = np.min(tp[tp>0])
ax.imshow(np.flip(tp*0.1,axis=1),cmap='Blues',norm=LogNorm(vmin=vmin,vmax=vmax))
ax.set_title('(e)')
ax.axis('off')

cbar=plt.colorbar(im)
cbar.set_label('Variance [$mV^2$]')

plt.tight_layout()
plt.savefig('spatial_feature_ex.svg')

In [25]:
from matplotlib.gridspec import GridSpec
%matplotlib inline

fig = plt.figure()
gs = GridSpec(6,3)

ax = fig.add_subplot(gs[:, 0])
ax.imshow(np.log(used_clusters[2].all_spatial_patterns[n].reshape(4,8)),cmap='Blues')
ax.set_title('current spatial pattern')
ax.axis('off')

ax = fig.add_subplot(gs[:3, 2])
tp = used_clusters[2].all_spatial_features[n].reshape(4,8)
tp[tp<=0] = np.min(tp[tp>0])
ax.imshow(np.log(tp),cmap='Blues')
ax.set_title('extracted spatial pattern for fish 3')
ax.axis('off')

ax = fig.add_subplot(gs[3:, 2])
tp = used_clusters[2].all_spatial_patterns[n].reshape(4,8)-used_clusters[2].all_spatial_features[n].reshape(4,8)
tp[tp<=0] = np.min(tp[tp>0])
ax.imshow(np.log(tp),cmap='Blues')
ax.set_title('extracted spatial pattern for fish 4')
ax.axis('off')

plt.tight_layout()
plt.savefig('position estimate.png')

In [26]:
def real_feat(mean,idxs=None,noise_thresh=0.05,ax1=None, ax2=None):
    mean = mean*1000
    
    slope_num=4
    w_num = slope_num - 2
    
    ax1.plot(np.arange(len(mean))/30,mean,c='b')
    
    argmaxmin,y = extract_maxmin(mean,slope_num-1,ax1)
    maxmin = y[argmaxmin]
    slopes = np.diff(maxmin)

    #print(argmaxmin)
    if len(argmaxmin) >= slope_num+1:
        features = np.ones((len(argmaxmin)-slope_num+2,slope_num+w_num))*-1
        best_i = 0
        s = 0
        for i in range(len(argmaxmin)-slope_num):
            if np.sum(np.abs(slopes[i:i+slope_num]))>s:
                s = np.sum(np.abs(slopes[i:i+slope_num]))
                best_i = i
            if np.min(np.diff(argmaxmin[i:i+slope_num])[1:-1]/len(mean)) > noise_thresh:
                features[i,:slope_num] = slopes[i:i+slope_num]
                features[i,slope_num:] = 10*np.diff(argmaxmin[i:i+slope_num])[1:-1]/len(mean)
    else:
        return np.ones((3,slope_num+w_num))*-1,0
    
    
    print('plotting')
    ax2.plot(np.arange(len(mean))/30,mean,c='b')
    ax2.plot(argmaxmin[best_i:best_i+slope_num+1]/30,y[argmaxmin][best_i:best_i+slope_num+1],'o',c='r')
    
    #ax2.plot(argmaxmin[best_i-1]/30,y[argmaxmin][best_i-1],'o',c='k',alpha=0.5)
    #ax2.plot(argmaxmin[best_i+slope_num+1]/30,y[argmaxmin][best_i+slope_num+1],'o',c='k',alpha=0.5)

    
    #return only the features around best i?
    if best_i == 0:
        idx = slice(best_i,best_i+2)
        bi = 0
    elif len(features) > best_i + 2:
        idx = slice(best_i-1,best_i+2)
        bi = 1
    else:
        idx = slice(best_i-1,best_i+1)
        bi = 1

    if np.isin(-1,features[idx]) and np.max(features[idx])>-1:
        features[idx,0] != -1
        #recompute best index??
        return features[idx][features[idx,0] != -1], np.argmax(np.sum(features[idx][features[idx,0] != -1],axis=1))
    else:
        return features[idx],bi
    
def extract_maxmin(mean,k,ax1):
    x = range(len(mean))

    y = mean
    ig = argrelextrema(y, np.greater)
    il = argrelextrema(y, np.less)
    
    argmaxmin = np.sort(np.append(ig,il)).astype('int')
    argmaxmin = np.append(np.append(0,argmaxmin),len(mean)-1)
    
    ax1.plot(argmaxmin/30,y[argmaxmin],'o',c='r')

    return argmaxmin,y

In [27]:
from scipy.signal import argrelextrema

fig = plt.figure(figsize=(5,3.3))
gs = GridSpec(2,2)

#fig, (ax1,ax2,ax3) = plt.subplots(1,3,sharey='row',sharex='row')

ax3 = fig.add_subplot(gs[1, :])
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1],sharex=ax1,sharey=ax1)

ax2.plot([0.5,1],[-5.5,-5.5],c='k')
ax2.text(0.65,-7,'0.5 ms')

ax1.plot([-0.1,-0.1],[-1,1.5],c='k')
ax1.text(-0.23,-0.6,'2.5 mV',rotation='vertical')

#ax1.set_title('(a)')
#ax2.set_title('(b)')
#ax3.set_title('(c)')

ax1.axis('off')
ax2.axis('off')

ax1.set_xlabel('time [ms]')
ax2.set_xlabel('time [ms]')
ax1.set_ylabel('amplitude [mV]')
ax3.set_xlabel('feature #')
ax3.set_ylabel('feature value [a.u.]')

ax3.spines['top'].set_visible(False)
ax3.spines['right'].set_visible(False)
ax3.set_ylim([-10,10])

plt.setp(ax2.get_yticklabels(), visible=False)

rf, bi = real_feat(used_clusters[3].all_temp_patterns[n],ax1=ax1,ax2=ax2)

ax3.plot(range(1,7),rf[bi],'o',c='k')
#for rff in rf:
#    ax3.plot(rff,'--',c='k',alpha=0.5)
plt.tight_layout()
plt.savefig('temp_features.svg')

In [28]:
# compute maxima and minima of peak.
# select the n peaks that have the greatest difference
# features: 
plt.figure(figsize=(2,6))
for t in range(10):
    ax2 = plt.subplot(10,1,t+1)
    plt.gca().axis('off')
    #plt.plot(used_clusters[2].all_temp_patterns[n-t],c='b')
    rf, bi = real_feat(used_clusters[2].all_temp_patterns[n],ax1=ax1,ax2=ax2)
plt.tight_layout()
plt.savefig('10_temp.svg')

In [34]:
plt.figure(figsize=(3,2))
ax2 = plt.subplot(1,1,1)
plt.gca().axis('off')
    #plt.plot(used_clusters[2].all_temp_patterns[n-t],c='b')
rf, bi = real_feat(used_clusters[2].all_temp_patterns[n],ax1=ax1,ax2=ax2)
plt.savefig('new_temp.svg')