In [4]:
import h5py
import numpy as np
import pandas as pd
import wavelets as wl
from scipy import signal
from matplotlib import pyplot as plt

In [52]:
ID    = 'SERT1985'
mydir = '/home/maspe/filer/SERT/' + ID + '/'

mPFC = np.load(mydir + 'mPFC_epochs.npy')

SRP = np.mean(mPFC, axis=2)
SRP.shape

(10, 6000)

### Figure 1

In [53]:
fs = 1000.0
dt = 1 / fs
time = np.arange(0, 6, dt)
frequencies = np.arange(1, 100, 1)
periods = 1 / (frequencies * dt)
scales = periods / wl.Morlet.fourierwl

# wavel1 = wl.Morlet(SRP[1,:], scales=scales)

In [None]:
for chan in range(SRP.shape[0]):
    print(chan)
    wavel1 = wl.Morlet(SRP[chan,:], scales=scales)

    pwr1=wavel1.getnormpower()

    fmin=min(frequencies)
    fmax=max(frequencies)

    plt.figure(1, figsize=(10, 4))
    plt.clf()

    ax1=plt.subplot2grid((1, 5),(0, 0),colspan=4)
    #
    plt.imshow(pwr1,cmap='RdBu',vmax=np.max(pwr1),vmin=-np.max(pwr1),
           extent=(min(time),max(time),fmin,fmax),origin='lower', 
           interpolation='none',aspect='auto')
    
    plt.axvline(x=3, color='black')
    
    locs, labels = plt.xticks()
    
    plt.xticks(locs, ['-3', '-2', '-1', '0', '1', '2', '3'])
    #ax1.set_yscale('log')
    ax1.set_ylabel('Frequency (Hz)')

    ax2=plt.subplot2grid((1,5),(0,4))
    ax2.plot(np.sum(pwr1[:,:3000],-1),frequencies, 'g')
    ax2.plot(np.sum(pwr1[:,-3000:],-1),frequencies, 'r')
    
    

    plt.savefig(mydir + str(chan) + ".png", dpi=50, facecolor='w', edgecolor='w',
            orientation='portrait', papertype=None, format='png',
            transparent=False)
    
    plt.close()

### Figure 2

In [1]:
mydir = '/home/maspe/filer/testFiles/spiking_data/SERT1985/'

In [2]:
chans = ['ch01', 'ch02', 'ch03', 'ch04', 'ch05', 'ch06', 'ch07', 'ch08', 'ch09', 'ch10',
         'ch11', 'ch12', 'ch13', 'ch14', 'ch15', 'ch16', 'ch17', 'ch18', 'ch19', 'ch20',
         'ch21', 'ch22', 'ch23', 'ch24', 'ch25', 'ch26', 'ch27', 'ch28', 'ch29', 'ch30',
         'ch31', 'ch32']

nchannels = len(chans)

In [5]:
fit = []
for chan in chans:
    path = mydir + chan + '/' + chan + '.result.hdf5'
    
    fit.append(h5py.File(path, 'r'))

In [6]:
all_units = []
for a in range(len(fit)):
    for key in fit[a]['spiketimes'].keys():
        all_units.append(fit[a]['spiketimes'][key][()]) # Final "[()]" is to import values from h5py 
        
len(all_units)

53

In [7]:
df = pd.read_excel("/home/maspe/filer/testFiles/spiking_data/SERT1985/entradas_sert1985.xlsx",
                   sheet_name=0, header=None, names=["locs"])
entrances_times = np.array(df['locs'].tolist(), dtype='int')
entrances_times = entrances_times * 30

n_epochs = len(entrances_times)
epochs_pre = []
epochs_post = []
# post = []
secs = 1.5
sampleRate = 30000
window = int(sampleRate * secs)

# for epoch in range(len(entrances_times)):
#     pre.append(np.arange(entrances_times[epoch] - window, entrances_times[epoch], 1))
#     post.append(np.arange(entrances_times[epoch], entrances_times[epoch] + window, 1))

for epoch in range(n_epochs):
    epochs_pre.append([entrances_times[epoch] - window, entrances_times[epoch]])
    epochs_post.append([entrances_times[epoch], entrances_times[epoch] + window])
#     post.append([entrances_times[epoch] + 1, entrances_times[epoch] + window])    
    
    
# Flatten vector of all points previous or post to center entrance
# pre_flatten = np.unique([item for sublist in pre for item in sublist])
# post_flatten = np.unique([item for sublist in post for item in sublist])

In [12]:
pre = []
post = []
for epoch in range(20):
    pre.append(np.arange(entrances_times[epoch] - window, entrances_times[epoch], 1))
    post.append(np.arange(entrances_times[epoch], entrances_times[epoch] + window, 1))

In [45]:
epoch_matrix=np.concatenate((np.array(pre), np.array(post)), axis=1)

In [46]:
epoch_matrix.shape

(20, 90000)

In [11]:
spikes_matrix = np.zeros([len(all_units), window * 2])
spikes_matrix.shape

(53, 90000)

In [8]:
spikes_epochs_pre = []
spikes_epochs_post = []
for unit in range(len(all_units)):
    spiketimes_pre = []
    spiketimes_post = []
    for epoch in range(n_epochs):
        spiketimes_pre.extend(all_units[unit][(all_units[unit] > epochs_pre[epoch][0]) & (all_units[unit] < epochs_post[epoch][1])])
        spiketimes_post.extend(all_units[unit][(all_units[unit] > epochs_post[epoch][0]) & (all_units[unit] < epochs_post[epoch][1])])
    spikes_epochs_pre.append(spiketimes_pre)
    spikes_epochs_post.append(spiketimes_post)

In [26]:
len(spikes_epochs_pre[52])

12613

In [37]:
spikes = []
for unit in range(len(spikes_epochs_pre)):
    # np.isin(element, test_elements)
    spikes.append(np.isin(pre, spikes_epochs_pre[unit] + spikes_epochs_post[unit]))

In [36]:
len(spikes)

53