In [1]:
import h5py
import numpy as np
import pandas as pd
import wavelets as wl
from scipy import signal
from matplotlib import pyplot as plt

### Figure 2

In [2]:
IDs_WT = ['SERT1597', 'SERT1659', 'SERT1678', 'SERT1908', 'SERT1984', 'SERT1985', 'SERT2014']
IDs_KO = ['SERT1668', 'SERT1665', 'SERT2013', 'SERT2018', 'SERT2024']

all_info_WT = {}
all_mice_WT = {}

print('Processing wild-types')
for ID in IDs_WT:
    npys_dir = '/home/maspe/filer/SERT/' + ID + '/npys/'
    spikes_dir = '/home/maspe/filer/SERT/' + ID + '/spikes/results/'
    
    info = np.load(npys_dir + 'info.npy', allow_pickle=True).item()
    all_info_WT[ID] = info
    channels = info['channels_list']
    
    fit = []
    for channel in channels:
        path = spikes_dir + channel + '.result.hdf5'
    
        fit.append(h5py.File(path, 'r'))
        
    print('Loaded ' + spikes_dir)
        
    all_mice_WT[ID] = fit


print('Processing knock-out')
all_info_KO = {}
all_mice_KO = {}
for ID in IDs_KO:
    npys_dir = '/home/maspe/filer/SERT/' + ID + '/npys/'
    spikes_dir = '/home/maspe/filer/SERT/' + ID + '/spikes/results/'
    
    info = np.load(npys_dir + 'info.npy', allow_pickle=True).item()
    all_info_KO[ID] = info
    channels = info['channels_list']
    
    fit = []
    for channel in channels:
        path = spikes_dir + channel + '.result.hdf5'
    
        fit.append(h5py.File(path, 'r'))
        
    print('Loaded ' + spikes_dir)
        
    all_mice_KO[ID] = fit

all_units_WT = {}
all_channels_WT = {}


Processing wild-types
Loaded /home/maspe/filer/SERT/SERT1597/spikes/results/
Loaded /home/maspe/filer/SERT/SERT1659/spikes/results/
Loaded /home/maspe/filer/SERT/SERT1678/spikes/results/
Loaded /home/maspe/filer/SERT/SERT1908/spikes/results/
Loaded /home/maspe/filer/SERT/SERT1984/spikes/results/
Loaded /home/maspe/filer/SERT/SERT1985/spikes/results/
Loaded /home/maspe/filer/SERT/SERT2014/spikes/results/
Processing knock-out
Loaded /home/maspe/filer/SERT/SERT1668/spikes/results/
Loaded /home/maspe/filer/SERT/SERT1665/spikes/results/
Loaded /home/maspe/filer/SERT/SERT2013/spikes/results/
Loaded /home/maspe/filer/SERT/SERT2018/spikes/results/
Loaded /home/maspe/filer/SERT/SERT2024/spikes/results/


In [3]:
for mouse in all_mice_WT.keys():
    n_channels = 32
    units = []
    channels_id = []
    
    iteration = 0
    for channel in range(n_channels):
        for unit in all_mice_WT[mouse][channel]['spiketimes'].keys():
            units.append(all_mice_WT[mouse][channel]['spiketimes'][unit][()]) # Final "[()]" is to import values from h5py 
      
            channels_id.append(all_info_WT[mouse]['channels_locs'][iteration])
        
        iteration += 1
        
            
    all_units_WT[mouse] = units
    all_channels_WT[mouse] = channels_id
        
# print('n units = %i' %len(all_units))


all_units_KO = {}
all_channels_KO = {}
for mouse in all_mice_KO.keys():
    n_channels = 32
    units = []
    channels_id = []
    
    iteration = 0
    for channel in range(n_channels):
        for unit in all_mice_KO[mouse][channel]['spiketimes'].keys():
            units.append(all_mice_KO[mouse][channel]['spiketimes'][unit][()]) # Final "[()]" is to import values from h5py 
      
            channels_id.append(all_info_KO[mouse]['channels_locs'][iteration])
        
        iteration += 1
        
            
    all_units_KO[mouse] = units
    all_channels_KO[mouse] = channels_id
        
# print('n units = %i' %len(all_units))

secs = 2
sampleRate = 30000
window = int(sampleRate * secs)

In [4]:
all_epochs_WT = {}
all_perispikes_WT = {}
n_points = 30000 * 60 * 10
window = 30000 * 60 * 2
for mouse in all_mice_WT.keys():
    startOF = np.int(all_info_WT[mouse]['startOF']) * 30
    stopOF = startOF + n_points
    
    init = np.arange(startOF, startOF + window, 1)
    end = np.arange(stopOF - window, stopOF, 1)
    
    task_time = np.concatenate([init, end])

    all_epochs_WT[mouse] = task_time
    
all_epochs_KO = {}
all_perispikes_KO = {}
for mouse in all_mice_KO.keys():
    startOF = all_info_KO[mouse]['startOF'] * 30
    stopOF = startOF + n_points
    
    init = np.arange(startOF, startOF + window, 1)
    end = np.arange(stopOF - window, stopOF, 1)
    
    task_time = np.concatenate([init, end])

    all_epochs_KO[mouse] = task_time
    


In [6]:
task_npoints = task_time.shape[0]

In [None]:
all_perispikes_WT = {}

for mouse in all_mice_WT.keys():
    all_spikes = all_units_WT[mouse]
    peristimulus_spikes = np.zeros((len(all_spikes), task_npoints))   
    
    for unit in range(len(all_spikes)):
        print('Processing unit %i of mouse %s' %(unit, mouse))
        peristimulus_spikes[unit, :] = np.isin(all_epochs_WT[mouse], all_spikes[unit])#, axes=(1, 0))
#         peristimulus_spikes[:, :, ] = np.transpose(np.isin(epochs_matrix, all_spikes[unit]), axes=(0, 2, 1))
# #         peristimulus_spikes = np.transpose(np.array(peristimulus_spikes), axes=(0, 2, 1))
        
    all_perispikes_WT[mouse] = peristimulus_spikes

In [7]:
all_perispikes_KO = {}

for mouse in all_mice_KO.keys():
    all_spikes = all_units_KO[mouse] 
    peristimulus_spikes = np.zeros((len(all_spikes), task_npoints))   
    
    for unit in range(len(all_spikes)):
        print('Processing unit %i of mouse %s' %(unit, mouse))
        peristimulus_spikes[unit, :] = np.isin(all_epochs_KO[mouse], all_spikes[unit])
        
    all_perispikes_KO[mouse] = peristimulus_spikes

Processing unit 0 of mouse SERT2013
Processing unit 1 of mouse SERT2013
Processing unit 2 of mouse SERT2013
Processing unit 3 of mouse SERT2013
Processing unit 4 of mouse SERT2013
Processing unit 5 of mouse SERT2013
Processing unit 6 of mouse SERT2013
Processing unit 7 of mouse SERT2013
Processing unit 8 of mouse SERT2013
Processing unit 9 of mouse SERT2013
Processing unit 10 of mouse SERT2013
Processing unit 11 of mouse SERT2013
Processing unit 12 of mouse SERT2013
Processing unit 13 of mouse SERT2013
Processing unit 14 of mouse SERT2013
Processing unit 15 of mouse SERT2013
Processing unit 16 of mouse SERT2013
Processing unit 17 of mouse SERT2013
Processing unit 18 of mouse SERT2013
Processing unit 19 of mouse SERT2013
Processing unit 20 of mouse SERT2013
Processing unit 21 of mouse SERT2013
Processing unit 22 of mouse SERT2013
Processing unit 23 of mouse SERT2013
Processing unit 24 of mouse SERT2013
Processing unit 25 of mouse SERT2013
Processing unit 26 of mouse SERT2013
Processing 

Processing unit 18 of mouse SERT2018
Processing unit 19 of mouse SERT2018
Processing unit 20 of mouse SERT2018
Processing unit 21 of mouse SERT2018
Processing unit 22 of mouse SERT2018
Processing unit 23 of mouse SERT2018
Processing unit 24 of mouse SERT2018
Processing unit 25 of mouse SERT2018
Processing unit 26 of mouse SERT2018
Processing unit 27 of mouse SERT2018
Processing unit 28 of mouse SERT2018
Processing unit 29 of mouse SERT2018
Processing unit 30 of mouse SERT2018
Processing unit 31 of mouse SERT2018
Processing unit 32 of mouse SERT2018
Processing unit 33 of mouse SERT2018
Processing unit 34 of mouse SERT2018
Processing unit 35 of mouse SERT2018
Processing unit 36 of mouse SERT2018
Processing unit 37 of mouse SERT2018
Processing unit 38 of mouse SERT2018
Processing unit 39 of mouse SERT2018
Processing unit 40 of mouse SERT2018
Processing unit 41 of mouse SERT2018
Processing unit 42 of mouse SERT2018
Processing unit 43 of mouse SERT2018
Processing unit 44 of mouse SERT2018
P

In [8]:
all_perispikes_KO.keys()

['SERT2013', 'SERT1668', 'SERT2024', 'SERT2018', 'SERT1665']

In [13]:
np.save(npys_dir + '1665_perispikes_KO.npy', all_perispikes_KO['SERT1665'])

In [None]:
mPFC_WT = {}
NAC_WT  = {}
BLA_WT  = {}
vHip_WT = {}

for mouse in all_mice_WT.keys():
    mPFC_WT_indexes  = [i for i,x in enumerate(all_channels_WT[mouse]) if x == 'mPFC_left']
    NAC_WT_indexes = [i for i,x in enumerate(all_channels_WT[mouse]) if x == 'NAC_left']
    BLA_WT_indexes  = [i for i,x in enumerate(all_channels_WT[mouse]) if x == 'BLA_left']
    vHip_WT_indexes  = [i for i,x in enumerate(all_channels_WT[mouse]) if x == 'vHipp_left']
    
    
    mPFC_WT[mouse] = np.sum(all_perispikes_WT[mouse][mPFC_WT_indexes, :], axis=(0, 2))
    NAC_WT[mouse] = np.sum(all_perispikes_WT[mouse][NAC_WT_indexes, :], axis=(0, 2))
    BLA_WT[mouse] = np.sum(all_perispikes_WT[mouse][BLA_WT_indexes, :], axis=(0, 2))
    vHip_WT[mouse] = np.sum(all_perispikes_WT[mouse][vHip_WT_indexes, :], axis=(0, 2))

    
mPFC_KO = {}
NAC_KO  = {}
BLA_KO  = {}
vHip_KO = {}

for mouse in all_mice_KO.keys():
    mPFC_KO_indexes  = [i for i,x in enumerate(all_channels_KO[mouse]) if x == 'mPFC_left']
    NAC_KO_indexes = [i for i,x in enumerate(all_channels_KO[mouse]) if x == 'NAC_left']
    BLA_KO_indexes  = [i for i,x in enumerate(all_channels_KO[mouse]) if x == 'BLA_left']
    vHip_KO_indexes  = [i for i,x in enumerate(all_channels_KO[mouse]) if x == 'vHipp_left']
    
    
    mPFC_KO[mouse] = np.sum(all_perispikes_KO[mouse][mPFC_KO_indexes, :], axis=(0, 2))
    NAC_KO[mouse] = np.sum(all_perispikes_KO[mouse][NAC_KO_indexes, :], axis=(0, 2))
    BLA_KO[mouse] = np.sum(all_perispikes_KO[mouse][BLA_KO_indexes, :], axis=(0, 2))
    vHip_KO[mouse] = np.sum(all_perispikes_KO[mouse][vHip_KO_indexes, :], axis=(0, 2))

In [None]:
myspikes = {'WT': {'BLA': BLA_WT, 'NAC': NAC_WT, 'vHip': vHip_WT}, 'KO': {'BLA': BLA_KO, 'NAC': NAC_KO, 'vHip': vHip_KO}}

In [None]:
np.save('/home/maspe/filer/SERT/ALL/npys/spikes', myspikes)

In [None]:
tx = range(0, 120000, 60)

# mPFC
iteration = 0
t = np.zeros((len(mPFC_WT.keys()), window*2 / 60))
for mouse in mPFC_WT.keys():
    t[iteration,:]=np.add.reduceat(mPFC_WT[mouse], range(0, mPFC_WT[mouse].shape[0], 60))
    
    iteration += 1
    
tmean_WT=np.mean(t, axis=0)

iteration = 0
t = np.zeros((len(mPFC_KO.keys()), window*2 / 60))
for mouse in mPFC_KO.keys():
    t[iteration,:]=np.add.reduceat(mPFC_KO[mouse], range(0, mPFC_KO[mouse].shape[0], 60))
    
    iteration += 1
    
tmean_KO=np.mean(t, axis=0)

plt.figure(figsize=(20,10))
plt.subplot(2,2,1)
plt.plot(tx, tmean_WT, '-', color='blue', alpha=0.7)
plt.plot(tx, tmean_KO, '-', color='red', alpha=0.7)
plt.xticks([0, 20000, 40000, 60000, 80000, 100000, 120000], ['-3', '-2', '-1', '0', '1', '2', '3'], fontsize=16)
plt.yticks(fontsize=16)
plt.ylim([0,70])
plt.axvline(x=60000, color='black')
#plt.xlabel('time (s)')
plt.ylabel('units', fontsize=18)
plt.title('mPFC', fontsize=24)


# NAC
iteration = 0
t = np.zeros((len(NAC_WT.keys()), window*2 / 60))
for mouse in NAC_WT.keys():
    t[iteration,:]=np.add.reduceat(NAC_WT[mouse], range(0, NAC_WT[mouse].shape[0], 60))
    
    iteration += 1
    
tmean_WT=np.mean(t, axis=0)

iteration = 0
t = np.zeros((len(NAC_KO.keys()), window*2 / 60))
for mouse in NAC_KO.keys():
    t[iteration,:]=np.add.reduceat(NAC_KO[mouse], range(0, NAC_KO[mouse].shape[0], 60))
    
    iteration += 1
    
tmean_KO=np.mean(t, axis=0)

plt.subplot(2,2,2)
plt.plot(tx, tmean_WT, '-', color='blue', alpha=0.7)
plt.plot(tx, tmean_KO, '-', color='red', alpha=0.7)
plt.xticks([0, 20000, 40000, 60000, 80000, 100000, 120000], ['-3', '-2', '-1', '0', '1', '2', '3'], fontsize=16)
plt.yticks(fontsize=16)
plt.ylim([0,70])
plt.axvline(x=60000, color='black')
#plt.xlabel('time (s)')
#plt.ylabel('units')
plt.title('NAC', fontsize=22)


# BLA
iteration = 0
t = np.zeros((len(BLA_WT.keys()), window*2 / 60))
for mouse in BLA_WT.keys():
    t[iteration,:]=np.add.reduceat(BLA_WT[mouse], range(0, BLA_WT[mouse].shape[0], 60))
    
    iteration += 1
    
tmean_WT=np.mean(t, axis=0)

iteration = 0
t = np.zeros((len(BLA_KO.keys()), window*2 / 60))
for mouse in BLA_KO.keys():
    t[iteration,:]=np.add.reduceat(BLA_KO[mouse], range(0, BLA_KO[mouse].shape[0], 60))
    
    iteration += 1
    
tmean_KO=np.mean(t, axis=0)

plt.subplot(2,2,3)
plt.plot(tx, tmean_WT, '-', color='blue', alpha=0.7)
plt.plot(tx, tmean_KO, '-', color='red', alpha=0.7)
plt.xticks([0, 20000, 40000, 60000, 80000, 100000, 120000], ['-3', '-2', '-1', '0', '1', '2', '3'], fontsize=16)
plt.yticks(fontsize=16)
plt.ylim([0,70])
plt.axvline(x=60000, color='black')
plt.xlabel('time (s)', fontsize=18)
plt.ylabel('units', fontsize=18)
plt.title('BLA', fontsize=24)


# vHip
iteration = 0
t = np.zeros((len(vHip_WT.keys()), window*2 / 60))
for mouse in vHip_WT.keys():
    t[iteration,:]=np.add.reduceat(vHip_WT[mouse], range(0, vHip_WT[mouse].shape[0], 60))
    
    iteration += 1
    
tmean_WT=np.mean(t, axis=0)

iteration = 0
t = np.zeros((len(vHip_KO.keys()), window*2 / 60))
for mouse in vHip_KO.keys():
    t[iteration,:]=np.add.reduceat(vHip_KO[mouse], range(0, vHip_KO[mouse].shape[0], 60))
    
    iteration += 1
    
tmean_KO=np.mean(t, axis=0)

plt.subplot(2,2,4)
plt.plot(tx, tmean_WT, '-', color='blue', alpha=0.7)
plt.plot(tx, tmean_KO, '-', color='red', alpha=0.7)
plt.xticks([0, 20000, 40000, 60000, 80000, 100000, 120000], ['-3', '-2', '-1', '0', '1', '2', '3'], fontsize=16)
plt.yticks(fontsize=16)
plt.ylim([0,70])
plt.axvline(x=60000, color='black')
plt.xlabel('time (s)', fontsize=18)
#plt.ylabel('units')
plt.title('vHip', fontsize=24)


plt.savefig('/home/maspe/filer/SERT/ALL/figs/perispikes.png', dpi=150, facecolor='w', edgecolor='w',
            orientation='portrait', papertype=None, format='png', transparent=False)
    
#plt.close()

#####################################

#### Backup code

In [None]:
spikes = []
for unit in range(len(spikes_epochs_pre)):
    # np.isin(element, test_elements)
    spikes.append(np.isin(epoch_matrix, [spikes_epochs_pre[unit] + spikes_epochs_post[unit]]))

In [None]:
spikes_epochs_pre = []
spikes_epochs_post = []
for unit in range(len(all_units)):
    spiketimes_pre = []
    spiketimes_post = []
    for epoch in range(n_epochs):
        spiketimes_pre.extend(all_units[unit][(all_units[unit] > epochs_pre[epoch][0]) & (all_units[unit] < epochs_post[epoch][1])])
        spiketimes_post.extend(all_units[unit][(all_units[unit] > epochs_post[epoch][0]) & (all_units[unit] < epochs_post[epoch][1])])
    spikes_epochs_pre.append(spiketimes_pre)
    spikes_epochs_post.append(spiketimes_post)

In [None]:
len(spikes_epochs_pre[52])

In [None]:
spikes = []
for unit in range(len(spikes_epochs_pre)):
    # np.isin(element, test_elements)
    spikes.append(np.isin(epoch_matrix, [spikes_epochs_pre[unit] + spikes_epochs_post[unit]]))

In [None]:
np.sum(spikes[50], axis=0).shape