# COURSE: Master Python for scientific programming by solving projects
## PROJECT: Time-frequency analysis of EEG data
#### TEACHER: Mike X Cohen, sincxpress.com
##### COURSE URL: udemy.com/course/maspy_x/?couponCode=202201

In [None]:
# import all necessary modules
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat

# Create real and complex Morlet wavelets

In [None]:
# functions to create wavelets

def createRealWavelet(time,freq,fwhm):
  """Build a real-valued Morlet wavelet: a cosine tapered by a Gaussian.

  time : sample times of the wavelet (should be zero-centered)
  freq : peak frequency of the cosine carrier
  fwhm : full-width at half-maximum of the Gaussian taper, in seconds

  Returns cos(2*pi*freq*t) * exp(-4*ln(2)*t^2 / fwhm^2), evaluated element-wise.
  """
  # oscillating carrier at the requested frequency
  carrier = np.cos(2*np.pi*freq*time)
  # Gaussian envelope; the -4*ln(2) factor makes its value 0.5 at t = +/- fwhm/2
  envelope = np.exp( (-4*np.log(2)*time**2)/(fwhm**2) )
  return carrier*envelope



In [None]:
# parameters
freq  = 5   # Hz, peak frequency of the wavelet
fwhm  = .5  # s, full-width at half-maximum of the Gaussian envelope
srate = 500 # Hz
time  = np.arange(-2*srate,2*srate)/srate  # 4 s of samples centered on t=0
npnts = len(time)

# now create one wavelet and visualize in time and in frequency domains
wavelet = createRealWavelet(time,freq,fwhm)

# power spectrum of wavelet
# The non-negative FFT bins sit at k*srate/npnts for k = 0..npnts//2, which is
# exactly what rfftfreq returns. The previous linspace(0,srate/2,npnts/2) had
# one point too few, which stretched the bin spacing and mislabeled frequencies.
hz = np.fft.rfftfreq(npnts,d=1/srate)
waveletX = np.abs(np.fft.fft(wavelet)/npnts)**2

# setup the figure
fig,ax = plt.subplots(1,2,figsize=(15,5))

# time-domain version
ax[0].plot(time,wavelet,'k')
ax[0].set_xlabel('Time (s)')
ax[0].set_ylabel('Amplitude (a.u.)')
ax[0].set_title('Time domain')

# frequency-domain version: only the non-negative frequencies are shown, and
# waveletX holds squared magnitudes (power), hence the axis labels
ax[1].stem(hz,waveletX[:len(hz)],'k')
ax[1].plot(hz,waveletX[:len(hz)],'m')
ax[1].set_xlim([0,20])
ax[1].set_xlabel('Frequency (Hz)')
ax[1].set_ylabel('Power (a.u.)')
ax[1].set_title('Frequency domain')

plt.show()

# Complex-valued Morlet wavelets

In [None]:
def createComplexWavelet(time,freq,fwhm):
  """Build a complex Morlet wavelet: a complex exponential tapered by a Gaussian.

  time : sample times of the wavelet (should be zero-centered)
  freq : peak frequency of the complex sinusoid
  fwhm : full-width at half-maximum of the Gaussian taper, in the units of `time`

  Returns exp(i*2*pi*freq*t) * exp(-4*ln(2)*t^2 / fwhm^2), evaluated element-wise.
  Its magnitude equals the Gaussian taper.
  """
  complex_sine = np.exp( 1j*2*np.pi*freq*time )
  gaussian_taper = np.exp( (-4*np.log(2)*time**2)/(fwhm**2) )
  return complex_sine*gaussian_taper

In [None]:
# create a complex Morlet wavelet (5 Hz peak, FWHM of 1 in the units of `time`)
wavelet = createComplexWavelet(time,5,1)

# plt.subplots returns (Figure, Axes); unpack it instead of binding the tuple
fig,ax = plt.subplots(1,figsize=(15,8))
ax.plot(time,np.real(wavelet),label='Real part')
ax.plot(time,np.imag(wavelet),label='Imaginary part')
ax.plot(time,np.abs(wavelet),'k',label='Magnitude')

ax.set_xlabel('Time (s)')
ax.set_ylabel('Amplitude (a.u.)')
ax.legend(fontsize=19)
plt.show()

In [None]:
# plot its magnitude and phase (np.angle gives radians in (-pi, pi])
# plt.subplots returns (Figure, Axes); unpack it instead of binding the tuple
fig,ax = plt.subplots(1,figsize=(15,8))
ax.plot(time,np.angle(wavelet),label='Phase')
ax.plot(time,np.abs(wavelet),'k',label='Magnitude')

ax.set_xlabel('Time (s)')
ax.set_ylabel('Angle (rad.) or amplitude (a.u.)')
ax.legend()
plt.show()

# Create a wavelet family

In [None]:
# parameters of the wavelet family
nfrex  = 40  # number of wavelets (frequencies)
lofreq = 2   # Hz, lowest peak frequency
hifreq = 80  # Hz, highest peak frequency

# linearly spaced peak frequencies, and FWHMs shrinking from 4 to 1 as frequency rises
frex  = np.linspace(start=lofreq,stop=hifreq,num=nfrex)
fwhms = np.linspace(start=4,stop=1,num=nfrex)

In [None]:
# create a family of wavelets: row k holds the wavelet for frex[k] / fwhms[k],
# sampled at every point of `time`
waveletfam = np.array([createComplexWavelet(time,f,w) for f,w in zip(frex,fwhms)],
                      dtype=complex)

fig,ax = plt.subplots(1,3,figsize=(15,5))

# one panel per view of the family: (axes, image data, title, extra imshow options);
# only the real part gets a fixed color range
panels = [
  (ax[0], np.real(waveletfam),  'Real part', dict(vmin=-.8,vmax=.8)),
  (ax[1], np.angle(waveletfam), 'Phase',     {}),
  (ax[2], np.abs(waveletfam),   'Magnitude', {}),
]
for axis,values,title,opts in panels:
  axis.imshow(values,
              aspect='auto',origin='lower',
              extent=[time[0],time[-1],lofreq,hifreq],
              **opts)
  axis.set_xlabel('Time (s)')
  axis.set_ylabel('Frequency (Hz)')
  axis.set_title(title)

plt.show()

# Import the EEG data

In [None]:
# Upload sampleEEGdata.mat into a Google Colab session. Outside Colab the
# google.colab package is unavailable; instead of crashing, skip the upload and
# expect the .mat file to already be in the working directory.
try:
  from google.colab import files
except ImportError:
  uploaded = {}  # nothing uploaded (not running in Colab)
else:
  uploaded = files.upload()

In [None]:
# import the data to python
from scipy.io import loadmat
EEG = loadmat('sampleEEGdata.mat')

# unwrap the 'EEG' variable once ([0][0] selects its single element);
# its fields are then addressed by position
eegrec = EEG['EEG'][0][0]

# extract the necessary information
times = np.squeeze(eegrec[14])           # time vector, squeezed to 1-D
data  = eegrec[15]                       # later cells index it as [row, time, trial] -- confirm
fs    = eegrec[11][0][0].astype(int)     # sampling rate, cast to an integer

print(fs)
print(np.shape(data))

In [None]:
# compute ERP
# (compare the code below against the discussion in the video about axis)
# averaging over axis 2 (trials) gives a trial-averaged response for every
# row; one row is then selected
chanidx = 46  # 0-based row index; reported 1-based in the title
erp = np.mean(data,axis=2)[chanidx,:]

# plot trial-averaged response
plt.plot(times,erp)
plt.xlim([-200,1000])
plt.xlabel('Time (ms)')
# raw string: '\m' is not a valid escape sequence (SyntaxWarning on Python 3.12+)
plt.ylabel(r'Voltage ($\mu V$)')
plt.title(f'ERP from channel {chanidx+1}')
plt.show()

# Wavelet convolution

In [None]:
# need to make new wavelets with the new sampling rate
nfrex  = 40
lofreq = 2    # Hz
hifreq = 30   # Hz

frex   = np.linspace(lofreq,hifreq,nfrex)  # peak frequencies
fwhms  = np.linspace(1,.5,nfrex)           # FWHMs, narrowing as frequency rises

# 2*fs+1 samples from -1 to +1, so t=0 falls exactly on the center sample
wavtime = np.arange(-fs,fs+1)/fs # note the change in variable name
npnts   = len(wavtime)

# one complex wavelet per row, built against the EEG time grid
waveletfam = np.array([createComplexWavelet(wavtime,f,w) for f,w in zip(frex,fwhms)],
                      dtype=complex)

# plot a few wavelets to make sure they look good:
# rows 0, 4, 8, 12, each shifted up by 1.5 so they do not overlap
for offset,wi in enumerate(range(0,16,4)):
  plt.plot(wavtime,np.real(waveletfam[wi,:]) + offset*1.5)

plt.xlabel('time (s)')
plt.tick_params(labelleft=False)  # y tick labels are meaningless with stacked offsets
plt.show()

In [None]:
# run convolution for one frequency: the ERP against the first
# (lowest-frequency) wavelet of the family
kernel = waveletfam[0,:]
convres = np.convolve(erp,kernel,mode='same')

# real part and magnitude of the complex convolution result
plt.plot(times,np.real(convres),label='Real part')
plt.plot(times,abs(convres),label='Magnitude')

# horizontal zero line across the full time range
timespan = [times[0],times[-1]]
plt.plot(timespan,[0,0],'k--')

# vertical line at t=0 spanning the current y-range; the y-range is read
# before drawing and re-applied afterwards so this line does not widen it
ylim = plt.ylim()
plt.plot([0,0],ylim,'k:')

plt.xlim(timespan)
plt.ylim(ylim)
plt.xlabel('Time (ms)')
plt.ylabel('Voltage (a.u.)')
plt.legend()
plt.show()

# Create a time-frequency map

In [None]:
# time-frequency amplitude map: one row per wavelet, one column per time point
tf = np.zeros((nfrex,len(times)))

# each row is the magnitude of the ERP convolved with that row's wavelet
for fi in range(nfrex):
  tf[fi,:] = np.abs(np.convolve(erp,waveletfam[fi,:],mode='same'))

# rows run from lofreq (bottom) to hifreq (top); color range fixed at 0-100
tfextent = [times[0],times[-1],lofreq,hifreq]
plt.imshow(tf,
           aspect='auto',origin='lower',
           extent=tfextent,
           vmin=0,vmax=100)

plt.xlabel('Time (ms)')
plt.ylabel('Frequency (Hz)')
plt.show()

# Bonus: time-frequency phase map of a single trial

In [None]:
# single unaveraged signal: row 48 of data, first index along axis 2
# (axis 2 is averaged over as trials in the ERP cell)
trialsig = data[48,:,0]

# phase angle (radians) of the convolution with every wavelet, one row per frequency
phases = np.zeros((nfrex,len(times)))
for fi in range(nfrex):
  phases[fi,:] = np.angle(np.convolve(trialsig,waveletfam[fi,:],mode='same'))

# cyclic 'hsv' colormap over [-pi, pi] because phase wraps around
plt.imshow(phases,
           aspect='auto',origin='lower',
           extent=[times[0],times[-1],lofreq,hifreq],
           vmin=-np.pi,vmax=np.pi,
           cmap='hsv')

plt.xlabel('Time (ms)')
plt.ylabel('Frequency (Hz)')
plt.colorbar()
plt.show()