In [1]:
### Importing modules
import time
import numpy as np
import wavelets as wl
from scipy import signal
from matplotlib import pyplot as plt

In [2]:
### Parameters
# Downsampling parameters: final resolution = 1000 Hz
fs = 30000.0                # sampling rate of the stored epochs (Hz)
final_fs = 1000.0           # sampling rate after downsampling (Hz)
ds_factor = fs // final_fs  # floor division on floats -> 30.0

# Epoch length along the time axis, before and after downsampling
npoints = 180000
# Built-in int() replaces np.int, an alias of int that was deprecated in
# NumPy 1.20 and removed in 1.24 (np.int now raises AttributeError).
n_samples = int(npoints / ds_factor)  # Downsample to 1000 Hz
# From here on `fs` denotes the downsampled rate used by the cells below.
fs = final_fs

# --- Morlet wavelet settings ---
dt = 1 / fs                          # spacing between consecutive samples
frequencies = np.arange(1, 100, 1)   # 1, 2, ..., 99
time_windows = np.arange(0, 3, dt)   # 0 up to (not including) 3, in steps of dt
n_frequencies, time_points = frequencies.shape[0], time_windows.shape[0]

# Each frequency is turned into a period expressed in units of dt, then
# divided by the Morlet class attribute `fourierwl` to obtain the scales
# handed to the transform (presumably the Fourier-wavelength factor of the
# `wavelets` module -- confirm against its documentation).
periods = 1 / (frequencies * dt)
scales = periods / wl.Morlet.fourierwl

# --- Epoch cropping: indices along the time axis ---
start, stop = 1000, 4000
baseline = 2000

In [4]:
### Loading data

In [None]:
# Mouse IDs per genotype; the IDs listed in the trailing comments are not
# processed at the moment.
IDs_WT = ['SERT1597']  # others: 'SERT1659', 'SERT1678', 'SERT1908', 'SERT1984', 'SERT1985', 'SERT2014'
IDs_KO = ['SERT1668']  # others: 'SERT1665', 'SERT2018', 'SERT2024', 'SERT2013'

# One {mouse ID: resampled array} dict per brain structure and genotype.
WT = {region: {} for region in ('mPFC', 'BLA', 'NAC', 'vHip')}
KO = {region: {} for region in ('mPFC', 'BLA', 'NAC', 'vHip')}
clock = time.time()

# Both genotypes go through exactly the same load -> resample steps, so they
# share one loop driven by (banner, IDs, destination) triples.
cohorts = (
    ('###################\nDownsampling WTs to 1000 Hz...', IDs_WT, WT),
    ('###################\nDownsampling KOs to 1000 Hz...', IDs_KO, KO),
)
for banner, mouse_ids, resampled in cohorts:
    print(banner)
    for structure in resampled.keys():
        for ID in mouse_ids:
            print('Loading {} from {}...'.format(structure, ID))
            npys_dir = '/home/maspe/filer/SERT/' + ID + '/npys/'
            # NOTE: allow_pickle=True can execute arbitrary code when loading
            # untrusted files; only use it on data you produced yourself.
            x = np.load(npys_dir + structure + '_epochs.npy', allow_pickle=True)

            print('Downsampling...\n')
            # Fourier-domain resampling of axis 1 down to n_samples points.
            resampled[structure][ID] = signal.resample(x=x, num=n_samples, axis=1)


print('All mice downsampled in {:.2f} s.'.format(time.time() - clock))

In [7]:
# Sanity check: dimensions of one downsampled array; axis 1 was resampled,
# so its length should equal n_samples.
WT['BLA']['SERT1597'].shape

(9, 6000, 16)

In [8]:
# Dimensions of `x`, the raw (not yet resampled) array most recently loaded
# by the downsampling loop.
x.shape

(7, 180000, 4)

In [None]:
# Morlet transform
# For every WT mouse: compute the normalised Morlet power of each mPFC
# channel/epoch over the cropped window [start:stop], then z-score it along
# the time axis. Result per mouse has shape
# (n_frequencies, time_points, n_channels, n_epochs).

# The downsampled mPFC arrays live in WT['mPFC'] ({mouse ID: array}); binding
# the name here lets the cell run on a fresh kernel instead of relying on a
# variable that is defined nowhere in the notebook.
mPFC_WT = WT['mPFC']

mPFC_SRS = dict()
for key in mPFC_WT.keys():
    clock = time.time()
    # Report the mouse actually being processed (`key`), not the stale `ID`
    # left over from the loading loop.
    print('Morlet wavelet for {} mPFC started!'.format(key))

    # Array layout is indexed below as [channel, time, epoch].
    n_channels = mPFC_WT[key].shape[0]
    n_epochs = mPFC_WT[key].shape[2]

    mPFC_SRS[key] = np.zeros((n_frequencies, time_points, n_channels, n_epochs))
    for epoch in range(n_epochs):

        for channel in range(n_channels):
            ### DOCS says data: data in array to transform, length must be power of 2 !!!!
            wavel1 = wl.Morlet(mPFC_WT[key][channel, start:stop, epoch], scales=scales)
            mPFC_SRS[key][:, :, channel, epoch] = wavel1.getnormpower()

    print('Transforming to z-score...')
    # NOTE(review): `baseline:` takes the samples from index `baseline` to the
    # END of the cropped window as the reference period. Confirm this is the
    # intended interval and not `:baseline`.
    baseline_mean = np.mean(mPFC_SRS[key][:, baseline:, :, :], axis=1)
    baseline_sd = np.std(mPFC_SRS[key][:, baseline:, :, :], axis=1)
    # Re-insert the reduced time axis so the statistics broadcast over time.
    mPFC_SRS[key] = (mPFC_SRS[key] - baseline_mean[:, None, :, :]) / baseline_sd[:, None, :, :]

    print('Mouse {} Fourier transformed in {:3.2f} s.\n'.format(key, time.time() - clock))

print('Done!')

In [None]:
# Collapse the channel and epoch axes (2 and 3) of every mouse's power array,
# leaving one 2-D map per mouse.
mPFC_mean_SRS = {
    mouse: np.mean(power, axis=(2, 3))
    for mouse, power in mPFC_SRS.items()
}

In [None]:
# Persist the per-mouse mean maps. np.save stores the dict as a pickled
# object array, so it has to be read back with allow_pickle=True.
np.save(
    file='/home/maspe/filer/SERT/ALL/npys/' + 'mPFC_SRS.npy',
    arr=mPFC_mean_SRS,
)

In [None]:
# Read the saved per-mouse maps back from disk.
npys_dir = '/home/maspe/filer/SERT/ALL/npys/'
# .item() extracts the single Python object held by the loaded array.
mPFC_SRS_WT = np.load(
    file=npys_dir + 'mPFC_SRS.npy',
    allow_pickle=True,
).item()

In [None]:
# Grand average across mice.
# All per-mouse maps are stacked along a new third axis in ONE call and then
# averaged over it. Stacking pairwise inside a loop only worked for exactly
# two mice: with one mouse the array stayed 2-D (so mean(axis=2) failed), and
# with three or more np.stack was handed arrays of different rank.
per_mouse_maps = [mPFC_SRS_WT[key] for key in mPFC_SRS_WT.keys()]
grand_average = np.mean(np.stack(per_mouse_maps, axis=2), axis=2)
grand_average.shape
    

In [None]:
# Time-frequency map of the grand-average power.
pwr1 = grand_average
fmin, fmax = min(frequencies), max(frequencies)
color_limit = np.max(pwr1)  # colour scale is symmetric around zero

plt.figure(1, figsize=(10, 4))
plt.clf()

# Axes covering the first 4 of the 5 grid columns.
ax1 = plt.subplot2grid((1, 5), (0, 0), colspan=4)

plt.imshow(
    pwr1,
    cmap='RdBu',
    vmin=-color_limit,
    vmax=color_limit,
    extent=(min(time_windows), max(time_windows), fmin, fmax),
    origin='lower',
    interpolation='none',
    aspect='auto',
)
plt.colorbar(fraction=0.05, pad=0.02)

# Vertical marker at x = 2, the position relabelled '0' below.
plt.axvline(x=2, color='black')

# Relabel the automatically chosen x ticks.
# NOTE(review): this relies on matplotlib placing exactly seven ticks, one
# per label in the list -- confirm if the figure size or extent changes.
locs, labels = plt.xticks()
plt.xticks(locs, ['-2', '-1.5', '-1', '-0.5', '0', '0.5', '1'], fontsize=14)
plt.yticks(fontsize=14)

# Optional logarithmic frequency axis: ax1.set_yscale('log')
ax1.set_xlabel('Time (s)', fontsize=16)
ax1.set_ylabel('Frequency (Hz)', fontsize=16)
plt.title('SRS Grand-average', fontsize=20);


In [None]:
# Scratch example. An expression such as `a + '1'` cannot be the target of an
# assignment (Python rejects it with a SyntaxError), so a value that needs a
# dynamically built name is stored in a dict under that name instead.
a = 'caca'

named_values = {}
named_values[a + '1'] = [1, 4, 3, 2]