In [None]:
import numpy as np
import matplotlib.pyplot as plt

from scipy import signal
import scipy
import scipy.io as sio

In [None]:
# Load the code-challenge data: the raw signal, two reference filtered
# versions of it ('signalFIR', 'signalMW') and the sampling rate.
wavelet_challenge = sio.loadmat('wavelet_codeChallenge.mat')

# loadmat hands back 2-D arrays even for vectors, so flatten each signal
original, signalFIR, signalMW = (
    np.squeeze(wavelet_challenge[key]) for key in ('signal', 'signalFIR', 'signalMW')
)
# srate is stored as a 1x1 array; extract its single element
srate = wavelet_challenge['srate'][0, 0]

npnts = len(original)

In [None]:
# Overview of the original and reference-filtered signals in the time and
# frequency domains.

time_vec = np.arange(0, npnts) / srate
# FFT bin k sits at k*srate/npnts; endpoint=False produces exactly those
# values (including srate itself would stretch the axis by npnts/(npnts-1)).
hz_vec = np.linspace(0, srate, npnts, endpoint=False)

# Amplitude spectra: FFT magnitude normalized by the number of points
amp_original = np.abs(scipy.fft.fft(original) / npnts)
amp_signalFIR = np.abs(scipy.fft.fft(signalFIR) / npnts)
amp_signalMW = np.abs(scipy.fft.fft(signalMW) / npnts)

plt.figure()
plt.subplots_adjust(hspace=1)

plt.subplot(311)
plt.plot(time_vec, original, label='original')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude (a.u.)')
plt.legend()

plt.subplot(312)
plt.plot(time_vec, signalFIR, label='FIR')
plt.plot(time_vec, signalMW, label='MW')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude (a.u.)')
plt.legend()

plt.subplot(313)
plt.plot(hz_vec, amp_original, label='original')
plt.plot(hz_vec, amp_signalFIR, label='FIR')
plt.plot(hz_vec, amp_signalMW, label='MW')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Amplitude (a.u.)')
plt.xlim([0, 20])
plt.legend()
plt.show()




In [None]:
# Stage 1 of a two-stage band-pass: a least-squares (firls) high-pass FIR
# filter with a stop band 0..hp_cutoff Hz and a pass band from
# hp_cutoff + hp_transw Hz up to Nyquist.

hp_cutoff = 9  # Hz, upper edge of the stop band
hp_transw = 1  # Hz, width of the transition band
# Kernel length of ~11 cycles of the cutoff frequency (srate/hp_cutoff samples per cycle)
order = int(np.round(11 * srate / hp_cutoff))
# firls only accepts an odd number of taps; bump an even count up by one
if order % 2 == 0:
    order += 1
print(order)

frex = [0, hp_cutoff, hp_cutoff + hp_transw, srate / 2]
bshape = [0, 0, 1, 1]

hp_filter = signal.firls(order, bands=frex, desired=bshape, fs=srate)

# Evaluate filter characteristics: kernel in time, power response in frequency
plt.subplot(121)
plt.plot(hp_filter)
plt.xlabel('Tap')
plt.title('Filter kernel')

plt.subplot(122)
# FFT bin k of an `order`-point kernel sits at k*srate/order
hz_hpfilt = np.linspace(0, srate, order, endpoint=False)
filtpow = np.abs(scipy.fft.fft(hp_filter)) ** 2
plt.plot(frex, bshape, label='ideal')
plt.plot(hz_hpfilt, filtpow, label='actual')
plt.xlim([0, 20])
plt.xlabel('Frequency (Hz)')
plt.ylabel('Power')
plt.title('Frequency response')
plt.legend()
plt.show()

In [None]:
# Run the high-pass kernel forwards and backwards (filtfilt) over the raw
# signal, then compare with the original and the reference FIR output.
signal_hp = signal.filtfilt(hp_filter, 1, original)

# Top panel: time-domain traces
plt.subplot(211)
for hp_trace, hp_label in ((original, 'original'), (signal_hp, 'high-pass filtered')):
    plt.plot(time_vec, hp_trace, label=hp_label)
plt.legend()
plt.title('Time Domain')

plt.subplots_adjust(hspace = 1)

# Bottom panel: amplitude spectra
amp_hp = np.abs(scipy.fft.fft(signal_hp) / npnts)
plt.subplot(212)
hp_spectra = {
    'original': amp_original,
    'high-pass filtered': amp_hp,
    'FIR': amp_signalFIR,
}
for hp_label, hp_spectrum in hp_spectra.items():
    plt.plot(hz_vec, hp_spectrum, label=hp_label)
plt.xlabel('Frequency (Hz)')
plt.ylabel('Amplitude')
plt.xlim([0, 20])
plt.title('Frequency Domain')
plt.legend()



In [None]:
# Stage 2 of the two-stage band-pass: a least-squares (firls) low-pass FIR
# filter with a pass band 0..lp_cutoff Hz and a stop band from
# lp_cutoff + lp_transw Hz up to Nyquist.
lp_cutoff = 14  # Hz, upper edge of the pass band
lp_transw = 3  # Hz, width of the transition band
# Kernel length of ~13 cycles of the cutoff frequency (srate/lp_cutoff samples per cycle)
order = int(13 * srate / lp_cutoff)
# firls only accepts an odd number of taps; bump an even count up by one
if order % 2 == 0:
    order += 1
print(order)

frex = [0, lp_cutoff, lp_cutoff + lp_transw, srate / 2]
bshape = [1, 1, 0, 0]

lp_filter = signal.firls(order, bands=frex, desired=bshape, fs=srate)

plt.subplot(121)
plt.plot(lp_filter)
plt.xlabel('Tap')
plt.title('Filter kernel')

plt.subplot(122)
# FFT bin k of an `order`-point kernel sits at k*srate/order
hz_filt = np.linspace(0, srate, order, endpoint=False)
# Power is left unnormalized (no division by the tap count) so it can be
# compared directly against the 0/1 ideal shape
filtpow = np.abs(scipy.fft.fft(lp_filter)) ** 2
plt.plot(frex, bshape, label='ideal')
plt.plot(hz_filt, filtpow, label='actual')
plt.xlim([0, 20])
plt.xlabel('Frequency (Hz)')
plt.ylabel('Power')
plt.title('Frequency response')
plt.legend()
plt.show()

In [None]:
# Apply the low-pass kernel (forwards and backwards via filtfilt) to the
# already high-passed data, completing the two-stage band-pass filter.
signal_lp = signal.filtfilt(lp_filter, 1, signal_hp)

# Time-domain comparison of the two stages
plt.subplot(211)
plt.plot(time_vec, signal_hp, label='high pass')
plt.plot(time_vec, signal_lp, label='low pass')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.title('Time Domain')
plt.legend()

plt.subplots_adjust(hspace=1)

# Amplitude spectra, including the reference FIR result
amp_lp = np.abs(scipy.fft.fft(signal_lp) / npnts)

plt.subplot(212)
plt.plot(hz_vec, amp_hp, label='high-pass')
plt.plot(hz_vec, amp_lp, label='low-pass')
plt.plot(hz_vec, amp_signalFIR, label='FIR')
plt.legend()
plt.xlim([0, 20])
plt.xlabel('Frequency (Hz)')
plt.ylabel('Amplitude')
plt.title('Frequency Domain')
plt.show()


In [None]:
# Build a real-valued Morlet-style wavelet (cosine times Gaussian) to use as a
# band-pass kernel centered on ffreq.
ffreq = 12.5  # wavelet center frequency in Hz
fwhm = 1/4  # full-width at half-maximum of the Gaussian envelope, in seconds
# NOTE(review): the spectral FWHM of this Gaussian is 4*ln(2)/(pi*fwhm) ≈ 3.5 Hz,
# not the 5 Hz band originally described here; fwhm = 4*ln(2)/(5*pi) ≈ 0.18 s
# would give a 5 Hz band — confirm which is intended.
wavtime = np.arange(-3, 3, 1/srate)

morwav = np.cos(2*np.pi*ffreq*wavtime) * np.exp(-(4*np.log(2)*wavtime**2) / fwhm**2)

# The wavelet has a different length from the data, so it needs its own
# frequency axis (positive frequencies 0..Nyquist only)
wavehz = np.linspace(0, srate/2, int(np.floor(len(wavtime)/2) + 1))
morwavX = 2*np.abs(scipy.fft.fft(morwav))

plt.subplot(211)
plt.plot(wavtime, morwav, 'k')
# Limit the axis to the wavelet's actual support
plt.xlim([wavtime[0], wavtime[-1]])
plt.xlabel('Time (sec.)')
plt.title('Wavelet')

plt.subplots_adjust(hspace=1)

plt.subplot(212)
plt.plot(wavehz, morwavX[:len(wavehz)], 'k')
plt.xlim([0, ffreq*2])
plt.xlabel('Frequency (Hz)')
plt.title('Wavelet amplitude spectrum')
plt.show()



In [None]:
# Frequency-domain convolution of the data with the wavelet.

nConv = npnts + len(wavtime) - 1  # length of the full linear convolution
halfw = int(np.floor(len(wavtime)/2))  # offset from wavelet start to its center sample

# Zero-pad both spectra to nConv points so their product is a linear
# (not circular) convolution
morwavX = scipy.fft.fft(morwav, nConv)

# Normalize so the peak gain is 1. Use the magnitude: np.max on a complex
# array orders by real part (not modulus), and dividing by that complex value
# would rescale the filtered signal incorrectly after np.real below.
morwavX = morwavX / np.max(np.abs(morwavX))

convres = scipy.fft.ifft(morwavX * scipy.fft.fft(original, nConv))
numOfConvres = len(convres)  # equals nConv
# Trim the convolution "wings": keep npnts samples starting at the wavelet
# center. This yields exactly npnts points for odd or even wavelet lengths.
convres = np.real(convres[halfw:halfw + npnts])
assert convres.shape == (npnts,), convres.shape

# time domain
plt.plot(time_vec, original, 'k', label='original')
plt.plot(time_vec, convres, 'b', label='morlet wavelet filtered')
plt.plot(time_vec, signalFIR, label='FIR')
plt.plot(time_vec, signalMW, label='MW')
plt.legend()
plt.xlabel('Time (s)')
plt.ylabel('Amplitude (a.u.)')
plt.show()

# frequency domain
convresX = np.abs(scipy.fft.fft(convres) / npnts)
plt.plot(hz_vec, amp_original, label='original')
plt.plot(hz_vec, amp_signalFIR, label='FIR')
plt.plot(hz_vec, amp_signalMW, label='MW')
plt.plot(hz_vec, convresX, label='morlet wavelet conv')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Amplitude (a.u.)')
plt.xlim([0, 20])
plt.legend()
plt.show()


In [None]:
# Sanity-check the sizes involved in the frequency-domain convolution
for size in (nConv, numOfConvres, len(convres), halfw, len(morwav), len(original)):
    print(size)

In [None]:
# Peek at the zero-padded data spectrum used in the convolution: show its
# shape and first few bins instead of dumping all nConv complex values.
padded_spectrum = scipy.fft.fft(original, nConv)
padded_spectrum.shape, padded_spectrum[:5]
