In [None]:
import numpy as np
import scipy as sp
import scipy.signal
import os 
import matplotlib.pyplot as plt
from scipy import fft, arange
%matplotlib inline
import sys
sys.path.append('../../')
import python.utils as ut
from definitions import DATA_PATH

In [None]:
# Load the SPM recording and pick out the single channel analysed below.
d = ut.load_data_spm('spmeeg_1.mat')
# Sampling rate; the value sits inside a nested array, hence the [0][0].
fs = d['fsample'][0][0]
# Row 5 of the data matrix (treated as a 1-D signal by the cells below).
lfp = d['data'][5, :]

In [None]:
# Band-stop Butterworth filter around 50 Hz (presumably to suppress mains
# line noise).
band = [48, 52]  # stop-band edges, same units as fs (Hz)
order = 2
nyquist = fs / 2
# butter() expects band edges normalised to the Nyquist frequency. Convert
# `band` to an array first: `list / fs` only works when `fs` happens to be a
# numpy scalar and raises TypeError for a plain Python int/float.
Wn = np.asarray(band) / nyquist
b, a = sp.signal.butter(order, Wn, btype='bandstop')
# lfilter runs a single forward (causal) pass, so the output is phase-shifted;
# sp.signal.filtfilt would give a zero-phase result if that matters.
# NOTE(review): lfp_flt is not used by the cells below, which all analyse the
# raw `lfp` -- confirm that is intended.
lfp_flt = sp.signal.lfilter(b, a, lfp)

In [None]:
# Baseline PSD: Welch's method on the continuous (un-epoched) signal using
# Hamming-windowed 1024-sample segments.
f, psd = scipy.signal.welch(lfp, fs, window='hamming', nperseg=1024)
# Selector for frequencies between 2 and 45 Hz (presumably the element-wise
# AND of the two conditions -- see ut.get_array_mask).
mask = ut.get_array_mask(f > 2, f < 45)

In [None]:
# PSD with explicit epoching: split the signal into non-overlapping epochs of
# `epoch_length` samples and compute a Welch PSD for each epoch.
epoch_length = 1024
# Number of whole epochs that fit; samples after the last full epoch are
# dropped. Integer floor division avoids the float round-trip of
# int(np.floor(a / b)).
multiple = lfp.shape[0] // epoch_length
idx = epoch_length * multiple
lfp_r = np.reshape(lfp[:idx], (multiple, epoch_length))

# welch works along the last axis, so every row (epoch) gets its own PSD.
# nperseg equals the epoch length, so each epoch is a single segment.
fr, psdr = sp.signal.welch(lfp_r, fs=fs, window='hamming', nperseg=epoch_length)
# 2-45 Hz selector built from this cell's own frequency axis `fr`.
mask = ut.get_array_mask(fr > 2, fr < 45)

In [None]:
# Compare the epoch-averaged Welch PSD with the continuous-signal Welch PSD,
# restricted by the 2-45 Hz `mask` from the PSD cells. Each curve is plotted
# against the frequency axis it was computed on.
plt.plot(fr[mask], psdr.mean(axis=0)[mask], label='epoching')
plt.plot(f[mask], psd[mask], label='normal')
plt.xlabel('Frequency [Hz]')
plt.ylabel('Power spectral density')
plt.title('Welch PSD: epoched vs. continuous')
plt.legend();

## Calculate the time-frequency decomposition

In [None]:
# Short-time power spectrum of the raw signal (1024-sample segments, scipy's
# default window and overlap). sxx is indexed (frequency, time).
f, t, sxx = sp.signal.spectrogram(lfp, fs, nperseg=1024)
# Switch to natural-log power in place; any exactly-zero bin becomes -inf.
np.log(sxx, out=sxx)
sxx.shape

In [None]:
# Latest spectrogram time stamp (seconds, per the time-axis labels). scipy
# reports segment-centre times, so this is slightly less than the recording length.
t.max()

In [None]:
# NOTE: this redefines `mask` (previously 2-45 Hz on the Welch frequency axis)
# as 4-21 Hz on the spectrogram axis; every plotting cell below relies on it.
mask = ut.get_array_mask(f > 4, f < 21)
# Time-averaged log power spectrum within the band.
plt.plot(f[mask], sxx.mean(axis=1)[mask])
plt.xlabel('Frequency [Hz]')
plt.ylabel('Mean log power')
plt.title('Time-averaged log spectrogram');

In [None]:
# Log-power spectrogram for the masked band. C has the same shape as the
# (t, f) coordinates, so request 'auto' shading explicitly: older Matplotlib
# defaults to 'flat', which drops the last row/column with a warning.
plt.pcolormesh(t, f[mask], sxx[mask, :], shading='auto')
plt.ylabel('Frequency [Hz]')
plt.xlabel('Time [sec]')
plt.colorbar(label='log power');

In [None]:
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

In [None]:
# Coordinate grids for the 3-D surface: frequency varies along columns and
# time along rows, giving arrays of shape (len(t), number of masked freqs).
xgrid, ygrid = np.meshgrid(f[mask], t, indexing='xy')

In [None]:
# Band-limited power matrix is (masked freqs, time bins); its transpose
# lines up with xgrid/ygrid.
sxx[mask].shape

In [None]:
# 3-D surface of log power over frequency (x) and time (y) in the masked band.
fig = plt.figure(figsize=(20, 10))
# Passing keyword arguments to Figure.gca() was deprecated in Matplotlib 3.4
# and later removed; add_subplot is the supported way to get a 3-D axes.
ax = fig.add_subplot(projection='3d')

# xgrid/ygrid have shape (len(t), n_freqs), so the (freq, time) power matrix
# is transposed to match.
surf = ax.plot_surface(X=xgrid, Y=ygrid, Z=sxx[mask, :].T, cmap=cm.viridis)
ax.set_xlabel('Frequency [Hz]')
ax.set_ylabel('Time [s]')
ax.set_zlabel('log power ')
fig.colorbar(surf, shrink=0.5, aspect=5);

In [None]:
# Undo the log transform and average linear power over time within the band.
plt.plot(f[mask], np.exp(sxx[mask]).mean(axis=1))
plt.xlabel('Frequency [Hz]')
plt.ylabel('Mean power')
plt.title('Time-averaged spectrogram power');

In [None]:
# Synthetic demo surface, unrelated to the LFP data: z = sin(sqrt(x^2 + y^2))
# sampled every 0.25 on [-5, 5), drawn in the top panel of a tall 2x1 figure.
fig = plt.figure(figsize=plt.figaspect(2.))
ax = fig.add_subplot(2, 1, 1, projection='3d')

X, Y = np.arange(-5, 5, 0.25), np.arange(-5, 5, 0.25)
xlen, ylen = len(X), len(Y)
X, Y = np.meshgrid(X, Y)
R = np.sqrt(X**2 + Y**2)
Z = np.sin(R)

surf = ax.plot_surface(
    X, Y, Z,
    rstride=1, cstride=1,
    linewidth=0, antialiased=False,
)